
How to iterate through all parameters in a neural network using pytorch


I have the following simple fully-connected neural network:

import torch
import torch.nn as nn

class Neural_net(nn.Module):
    def __init__(self):
        super(Neural_net, self).__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)
        self.fc_out = nn.Linear(1, 1)

    def forward(self, x, train=True):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = self.fc_out(x)
        return x

net = Neural_net()

How can I loop through all the parameters of the network and check, for example, whether they are greater than a certain value? I am using PyTorch, and if I do:

for n, p in net.named_parameters():
    if p > value:
        ...

I get an error, since p is not a single number but a tensor of either weights or biases for each layer.

My goal is to check if a criterion is satisfied for each of the parameters and flag them e.g. with 1 if it is or 0 if it is not, storing it in a dictionary with the same structure as net.parameters(). Yet, I am having trouble figuring out how to loop through them.

I thought about creating a parameter vector:

param_vec = torch.cat([p.view(-1) for p in net.parameters()])

and then accessing and checking the parameter values would be easy, but then I can't think of a way to go back to the dictionary form to flag them.

Thank you for any help!


Solution

  • First, I would define the criterion as an operation on a tensor. In your case, it could look like this:

    cond = lambda tensor: tensor.gt(value)
    

    Then you just need to apply it to each tensor in net.parameters(). To keep the same structure, you can use a dict comprehension:

    cond_parameters = {n: cond(p) for n, p in net.named_parameters()}
    

    Let's see it in practice!

    net = Neural_net()
    print(dict(net.named_parameters()))
    #> {'fc1.weight': Parameter containing:
    #>  tensor([[-0.4767,  0.0771],
    #>          [ 0.2874,  0.5474]], requires_grad=True),
    #>  'fc1.bias': Parameter containing:
    #>  tensor([ 0.0405, -0.1997], requires_grad=True),
    #>  'fc2.weight': Parameter containing:
    #>  tensor([[0.5400, 0.3241]], requires_grad=True),
    #>  'fc2.bias': Parameter containing:
    #>  tensor([-0.5306], requires_grad=True),
    #>  'fc_out.weight': Parameter containing:
    #>  tensor([[-0.9706]], requires_grad=True),
    #>  'fc_out.bias': Parameter containing:
    #> tensor([-0.4174], requires_grad=True)}
    

    Let's set value to zero and build the dict of flags:

    value = 0
    cond = lambda tensor: tensor.gt(value)
    cond_parameters = {n: cond(p) for n, p in net.named_parameters()}
    #>{'fc1.weight': tensor([[False,  True],
    #>         [ True,  True]]),
    #> 'fc1.bias': tensor([ True, False]),
    #> 'fc2.weight': tensor([[True, True]]),
    #> 'fc2.bias': tensor([False]),
    #> 'fc_out.weight': tensor([[False]]),
    #> 'fc_out.bias': tensor([False])}
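If you want the 1/0 flags from the question rather than booleans, cast the mask with .int(). And your param_vec idea can work too: torch.cat flattens parameters in the fixed order of net.parameters(), so a flat vector of flags can be split back into named, original-shaped tensors with torch.split. A minimal sketch (it reuses the Neural_net class from the question and assumes value = 0):

```python
import torch
import torch.nn as nn

class Neural_net(nn.Module):
    def __init__(self):
        super(Neural_net, self).__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)
        self.fc_out = nn.Linear(1, 1)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc_out(x)

net = Neural_net()
value = 0

# Flatten every parameter into one vector and test the condition there,
# producing 1/0 integer flags instead of booleans.
param_vec = torch.cat([p.view(-1) for p in net.parameters()])
flags_vec = param_vec.gt(value).int()

# Split the flat flag vector back into per-parameter chunks (torch.cat
# preserved the order of net.parameters()), then restore names and shapes.
sizes = [p.numel() for p in net.parameters()]
chunks = torch.split(flags_vec, sizes)
flag_dict = {n: c.view_as(p)
             for (n, p), c in zip(net.named_parameters(), chunks)}

print(flag_dict['fc1.weight'].shape)  # torch.Size([2, 2])
```

This produces the same structure as the dict-comprehension answer, just via the flattened vector, which can be handy if the criterion needs to look at all parameters at once (e.g. a global threshold such as a quantile of param_vec).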