I am trying to implement sparsely connected weight matrices for my simple 3-layer feedforward model. To do this, I created a mask for each layer with a certain percentage of zeros. The idea is to zero out the same set of weights after every optimizer step so that my layers are not fully connected. The problem is that when I multiply the weight matrices element-wise by the mask, the weights stop changing in subsequent backward passes. To check whether the mask itself is the cause, I instead multiplied the weight matrices by the scalar 1.0, and this reproduces the issue. What might be happening here? Gradients are still being calculated; the loss just stops going down and the weights don't change. Does this multiplication somehow disconnect the weights from the graph?
My model:
```python
import torch
import torch.nn as nn


class TSP(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(TSP, self).__init__()
        self.sc1 = nn.Linear(input_size, hidden_size)
        self.sc2 = nn.Linear(hidden_size, input_size)
        torch.nn.init.normal_(self.sc1.weight, mean=0, std=0.1)
        torch.nn.init.normal_(self.sc2.weight, mean=0, std=0.1)

    def forward(self, x):
        x = torch.relu(self.sc1(x))
        x = self.sc2(x)
        return x

    def predict_hidden(self, x):
        x = torch.relu(self.sc1(x))
        return x
```
To reproduce the issue, all that is needed is the following; afterwards the weights stop being updated:
```python
model.sc1.weight = nn.Parameter(1. * model.sc1.weight)
model.sc2.weight = nn.Parameter(1. * model.sc2.weight)
```
When you run

```python
model.sc1.weight = nn.Parameter(1. * model.sc1.weight)
model.sc2.weight = nn.Parameter(1. * model.sc2.weight)
```

you are not "multiplying by a scalar". You are creating an entirely new object (`nn.Parameter(1. * model.sc1.weight)`) and assigning it to the `.weight` attribute.
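To see that the reassignment does not cut anything out of the autograd graph but simply swaps in a different leaf tensor, here is a minimal sketch on a standalone `nn.Linear` (the layer sizes are arbitrary assumptions for illustration):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)
old_weight = layer.weight

# 1. * old_weight allocates a new tensor; wrapping it in nn.Parameter
# produces a fresh leaf tensor with no autograd history.
layer.weight = nn.Parameter(1. * old_weight)

print(layer.weight is old_weight)                        # False: a different object
print(layer.weight.is_leaf)                              # True: a new graph leaf
print(torch.equal(layer.weight, old_weight))             # True: same values
print(layer.weight.data_ptr() == old_weight.data_ptr())  # False: different storage
```

So the model's weight is still a perfectly good leaf in the graph; it is just no longer the object anything else (such as an optimizer) was holding on to.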
I assume you are updating your model with a standard PyTorch optimizer, something like:

```python
model = TSP(...)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
```
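When you run `model.sc1.weight = nn.Parameter(1. * model.sc1.weight)`, you store an entirely new object in `model.sc1.weight`, but the optimizer still references the old object. The new parameter is a fresh leaf tensor, so the forward pass uses it and `backward()` fills in its `.grad` (which is why you still see gradients). However, `opt.step()` only updates the tensors the optimizer was constructed with, and the model no longer uses those.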
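The symptom from the question can be reproduced end to end. This minimal sketch swaps the weight after building the optimizer and then takes one training step; the single `nn.Linear`, the random data and `lr=0.1` are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Replace the weight *after* the optimizer was built.
old_weight = model.weight
model.weight = nn.Parameter(1. * model.weight)

weight_before = model.weight.detach().clone()
bias_before = model.bias.detach().clone()

x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()

print(model.weight.grad is not None)               # True: gradient was computed
print(torch.equal(model.weight, weight_before))    # True: but the weight did not move
print(old_weight.grad is None)                     # True: the optimizer's tensor got no gradient
print(torch.equal(model.bias, bias_before))        # False: the untouched bias still trains
```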
You can validate this as follows:

```python
# data pointer of weight
model.sc1.weight.data_ptr()
# -> 124805056

# data pointer of weight in the optimizer
opt.param_groups[0]['params'][0].data_ptr()
# -> 124805056

# now create new weight object
model.sc1.weight = nn.Parameter(1. * model.sc1.weight)

# data pointer of model weight has changed
model.sc1.weight.data_ptr()
# -> 139582720

# data pointer of optimizer has not
opt.param_groups[0]['params'][0].data_ptr()
# -> 124805056
```
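A related check compares objects by identity (`is` / `id`) instead of by `data_ptr()`, which asks directly whether the optimizer holds the model's current parameters. The helper name `untracked_parameters` is hypothetical, and the `nn.Sequential` model is only a stand-in for `TSP`:

```python
import torch
import torch.nn as nn


def untracked_parameters(model: nn.Module, opt: torch.optim.Optimizer) -> list[str]:
    """Hypothetical helper: names of model parameters the optimizer does not hold."""
    tracked = {id(p) for group in opt.param_groups for p in group["params"]}
    return [name for name, p in model.named_parameters() if id(p) not in tracked]


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 4))
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
print(untracked_parameters(model, opt))  # []

# Reassigning a parameter leaves the optimizer pointing at the old object.
model[0].weight = nn.Parameter(1. * model[0].weight)
print(untracked_parameters(model, opt))  # ['0.weight']
```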
To avoid this, modify the existing object in place instead of creating a new object:
```python
# data pointer of weight
model.sc1.weight.data_ptr()
# -> 124805056

# data pointer of weight in the optimizer
opt.param_groups[0]['params'][0].data_ptr()
# -> 124805056

# update data of weight tensor with in-place operation
model.sc1.weight.data.mul_(2.)

# weight and optimizer still have same data pointer
model.sc1.weight.data_ptr()
# -> 124805056
opt.param_groups[0]['params'][0].data_ptr()
# -> 124805056
```
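Applied to the sparse-mask goal from the question, the same in-place idea might look like the sketch below. It uses `torch.no_grad()` with an in-place `mul_` rather than `.data`, which is the form usually recommended for modifying parameters outside autograd. The helper `make_mask`, the `sparsity=0.8` value, the layer sizes and the autoencoder-style random data are all hypothetical choices for illustration; the model is a trimmed copy of `TSP`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TSP(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.sc1 = nn.Linear(input_size, hidden_size)
        self.sc2 = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        return self.sc2(torch.relu(self.sc1(x)))


def make_mask(weight, sparsity):
    """Hypothetical helper: fixed 0/1 mask with roughly `sparsity` fraction of zeros."""
    return (torch.rand_like(weight) >= sparsity).to(weight.dtype)


def apply_masks(masks):
    # Modify the existing Parameter objects in place; never reassign them,
    # so the optimizer keeps updating the tensors the model actually uses.
    with torch.no_grad():
        for layer, mask in masks:
            layer.weight.mul_(mask)


model = TSP(input_size=10, hidden_size=20)
masks = [
    (model.sc1, make_mask(model.sc1.weight, sparsity=0.8)),
    (model.sc2, make_mask(model.sc2.weight, sparsity=0.8)),
]
apply_masks(masks)  # start from a sparse initialization
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(64, 10)
losses = []
for step in range(100):
    loss = nn.functional.mse_loss(model(x), x)
    losses.append(loss.item())
    opt.zero_grad()
    loss.backward()
    opt.step()
    apply_masks(masks)  # re-zero the same connections after every step
```

As an alternative, `torch.nn.utils.prune.custom_from_mask(module, name, mask)` reparametrizes the weight as `weight_orig * weight_mask`, so the mask is applied on every forward pass instead of after every step. With stateful optimizers (momentum, Adam), the masked entries still accumulate optimizer state, but re-masking after each step keeps those weights at zero.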