python · deep-learning · pytorch · backpropagation

Freezing an intermediate layer in PyTorch


Suppose I have a simple neural network defined as:

import torch
import torch.nn as nn

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)

My aim is to freeze the second layer while keeping the weights updating in the first and third ones.

I have tried doing

x = lin0(x)
with torch.no_grad():    
    x = lin1(x)
x = lin2(x)

in the forward function, which should freeze all the parameters of lin1. However, I am wondering whether backpropagation still reaches the first layer and, if so, how it updates its weights. Does it?
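
One way to check this directly is to run a single backward pass and inspect the .grad attributes, reusing the layers above. This is just a minimal sketch, assuming a dummy batch and a toy sum loss:

x = torch.randn(4, 2)        # dummy batch
out = lin0(x)
with torch.no_grad():        # lin1's output carries no graph history
    out = lin1(out)
out = lin2(out)
out.sum().backward()         # toy loss, just to trigger a backward pass

print(lin0.weight.grad)      # None: the graph is cut at the no_grad block, nothing reaches lin0
print(lin2.weight.grad)      # a tensor: lin2 still receives gradients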


Solution

  • I actually tried 2 things:

    1. Like you said, wrapping the layer whose parameters should not change in torch.no_grad() during the forward pass

    2. Setting requires_grad to False on that layer's parameters

    In the first case, the parameter values didn't change for layer 0 and layer 1. In the second case, however, the parameters didn't change only for layer 1 -- just like you wanted. It seems that when torch.no_grad() is used inside forward, the autograd graph is cut at that point, so gradients cannot flow back past the wrapped layer: that layer and every layer before it stop updating, while the layers after it still do.

    P.S.: I just tracked the values of each layer's parameters using the .parameters() method available in PyTorch (see the tracking sketch after the second instance below). Try it yourself and let me know!

    First instance:

    Model:

        import torch
        import torch.nn as nn

        class model(torch.nn.Module):

            def __init__(self):
                super(model, self).__init__()
                self.lin0 = nn.Linear(1, 2)
                self.lin1 = nn.Linear(2, 2)
                self.lin2 = nn.Linear(2, 10)

            def forward(self, x):
                x = self.lin0(x)
                with torch.no_grad():   # gradients are not tracked through lin1
                    x = self.lin1(x)
                x = self.lin2(x)
                return x

        model = model()

    Second instance:

        import torch
        import torch.nn as nn

        class model(torch.nn.Module):

            def __init__(self):
                super(model, self).__init__()
                self.lin0 = nn.Linear(1, 2)
                self.lin1 = nn.Linear(2, 2)
                self.lin2 = nn.Linear(2, 10)

            def forward(self, x):
                x = self.lin0(x)
                x = self.lin1(x)
                x = self.lin2(x)
                return x

        model = model()
        model.lin1.weight.requires_grad = False   # freeze lin1's weight
        model.lin1.bias.requires_grad = False     # freeze lin1's bias
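
    To see which layers actually move, the parameter values can be snapshotted before a training step and compared afterwards. This is only a rough sketch of the tracking mentioned above, assuming the second-instance model from just above, a dummy input of shape (4, 1), and a plain SGD optimizer:

        before = {name: p.detach().clone() for name, p in model.named_parameters()}

        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        x = torch.randn(4, 1)       # dummy batch matching lin0's input size
        loss = model(x).sum()       # toy loss, only there to drive a backward pass
        loss.backward()
        optimizer.step()

        for name, p in model.named_parameters():
            status = "unchanged" if torch.equal(p, before[name]) else "changed"
            print(name, status)
        # Expected output: lin0.* and lin2.* changed, lin1.* unchanged

    With the first-instance model, the same loop reports lin0.* and lin1.* as unchanged, since the no_grad block stops gradients from ever reaching lin0.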