Tags: python, pytorch, neural-network

Learning stops when you multiply the weights of a layer by a scalar?


I am trying to implement sparsely connected weight matrices for my simple 3-layer feedforward model. To do this I created a mask for each layer containing a certain percentage of zeros, the idea being to zero out the same set of weights after every optimizer step so that my layers are not fully connected. But I am having trouble with this: when I multiply the mask element-wise with the weight matrices, the weights stop changing in subsequent backward passes. To check whether the mask itself is the problem, I simply multiplied my weight matrices by the scalar 1.0, and this reproduces the issue. What might be happening here? I checked and gradients still get calculated; it's just that the loss no longer goes down and the weights don't change. Does this multiplication somehow disconnect the weights from the graph?
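For reference, this is roughly how I build a mask and re-apply it after each optimizer step (the ~90% sparsity and the layer shown are just placeholders for my actual setup):

# fixed binary mask with ~90% zeros, same shape as sc1.weight
mask1 = (torch.rand(hidden_size, input_size) > 0.9).float()

# re-applied after every optimizer step; this is where the weights stop updating
model.sc1.weight = nn.Parameter(model.sc1.weight * mask1)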

My model:

import torch
import torch.nn as nn

class TSP(nn.Module):

  def __init__(self, input_size, hidden_size):
    super(TSP, self).__init__()
    self.sc1 = nn.Linear(input_size, hidden_size)
    self.sc2 = nn.Linear(hidden_size, input_size)

    torch.nn.init.normal_(self.sc1.weight, mean=0, std=0.1)
    torch.nn.init.normal_(self.sc2.weight, mean=0, std=0.1)

  def forward(self, x):
    x = torch.relu(self.sc1(x))
    x = self.sc2(x)
    return x

  def predict_hidden(self, x):
    x = torch.relu(self.sc1(x))
    return x

To reproduce the issue, all that is needed is the following; after this, the weights stop getting updated:

model.sc1.weight = nn.Parameter(1. * model.sc1.weight)
model.sc2.weight = nn.Parameter(1. * model.sc2.weight)

Solution

  • When you run

    model.sc1.weight = nn.Parameter(1. * model.sc1.weight)
    model.sc2.weight = nn.Parameter(1. * model.sc2.weight)
    

    You are not "multiplying by a scalar". You are creating an entirely new object (nn.Parameter(1. * model.sc1.weight)) and assigning it to the .weight attribute.

    I assume you are updating your model with a standard PyTorch optimizer, something like:

    model = TSP(...)
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    

    When you run model.sc1.weight = nn.Parameter(1. * model.sc1.weight), you create an entirely new object and assign it to model.sc1.weight, but the optimizer still references the old object.

    You can validate this as follows:

    # data pointer of weight
    model.sc1.weight.data_ptr()
    > 124805056
    
    # data pointer of weight in the optimizer 
    opt.param_groups[0]['params'][0].data_ptr()
    > 124805056
    
    # now create new weight object
    model.sc1.weight = nn.Parameter(1. * model.sc1.weight)
    
    # data pointer of model weight has changed
    model.sc1.weight.data_ptr()
    > 139582720
    
    # data pointer of optimizer has not
    opt.param_groups[0]['params'][0].data_ptr()
    > 124805056
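    This also explains why "gradients still get calculated": after the reassignment, backward() populates .grad on the new weight object, while opt.step() only updates the old object, which no longer takes part in the forward pass. A quick check (the dummy batch below is just illustrative):

    # forward/backward after the reassignment above
    x = torch.randn(4, model.sc1.in_features)
    model(x).sum().backward()

    # the new weight object does receive a gradient ...
    model.sc1.weight.grad is None
    > False

    # ... but the optimizer still holds (and will keep updating) the old object
    opt.param_groups[0]['params'][0] is model.sc1.weight
    > False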
    

    To avoid this, update the existing weight tensor in place instead of replacing it with a new object:

    # data pointer of weight
    model.sc1.weight.data_ptr()
    > 124805056
    
    # data pointer of weight in the optimizer 
    opt.param_groups[0]['params'][0].data_ptr()
    > 124805056
    
    # update data of weight tensor with in-place operation
    model.sc1.weight.data.mul_(2.)
    
    # weight and optimizer still have same data pointer
    model.sc1.weight.data_ptr()
    > 124805056
    
    opt.param_groups[0]['params'][0].data_ptr()
    > 124805056
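
    Applied back to the original sparsity use case, a minimal sketch of a training step would re-apply the mask in place after every opt.step() (criterion, the batch (x, y), and the masks mask1/mask2 are placeholders for your actual setup):

    opt.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    opt.step()

    # zero out the masked weights in place; the parameter objects
    # (and hence the optimizer's references) stay the same
    with torch.no_grad():
        model.sc1.weight.mul_(mask1)
        model.sc2.weight.mul_(mask2)

    Equivalently, you can use model.sc1.weight.data.mul_(mask1) as in the snippet above; the important point is that the weight tensors are modified in place rather than replaced.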