Let's say I want to create a torch.tensor object of size [2,3] filled with random elements, and I intend to use this matrix in the network and optimize its values. However, I want to update only some of the values in the matrix.
I know that this can be done for a whole tensor by setting its requires_grad attribute to True or False. However, the following code
z = torch.rand([2,3], requires_grad=True)
z[-1][-1].requires_grad=False
does not work as expected; it raises:
RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().
How can I fix this RuntimeError? How can I initialize a torch tensor and then specify which of its elements have requires_grad=True?
If I write the code the other way around:
z = torch.rand([2,3], requires_grad=False)
z[-1][-1].requires_grad=True
There is no error, but requires_grad does not change either.
It does not really make much sense to have a single tensor that requires_grad for only part of its entries: requires_grad is a flag on the tensor as a whole, not on its individual elements.
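This also explains why the second attempt in the question silently does nothing. A minimal sketch: indexing z returns a new tensor object (a view of z), so setting the flag on that object leaves z itself untouched.

```python
import torch

# requires_grad is one flag for the whole tensor, not per element
z = torch.rand([2, 3], requires_grad=False)

# Indexing returns a new tensor object (a view into z) ...
elem = z[-1][-1]
elem.requires_grad = True  # ... so this only sets the flag on that new object

print(z.requires_grad)     # False: z itself is unaffected
print(elem.requires_grad)  # True, but elem is a temporary that nothing else uses
```

Re-evaluating z[-1][-1] afterwards creates yet another fresh view whose flag is False again, which is why no change is visible.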
Why not have two separate tensors, one that is updated (requires_grad=True) and another that is fixed (requires_grad=False)? You can then merge them for computational ease:
import torch
fixed = torch.rand([2, 3], requires_grad=False)  # entries that stay constant
upd = torch.rand([2, 3], requires_grad=True)     # entries that get optimized
mask = torch.tensor([[0., 1., 0.], [1., 0., 1.]])  # 1 -> take entry from fixed, 0 -> from upd
# combine them using the fixed mask:
z = mask * fixed + (1 - mask) * upd
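Because z is computed from upd, it is not a leaf tensor: only upd is handed to the optimizer, and z has to be rebuilt after every step. A minimal training-loop sketch under that assumption; the zero target, the squared-error loss, plain SGD and the learning rate are arbitrary placeholders, not something the answer prescribes.

```python
import torch

torch.manual_seed(0)

fixed = torch.rand([2, 3])                    # constant entries (no gradient)
upd = torch.rand([2, 3], requires_grad=True)  # entries to be optimized
# 1 -> entry comes from fixed, 0 -> entry comes from upd
mask = torch.tensor([[0., 1., 0.], [1., 0., 1.]])

target = torch.zeros(2, 3)                    # hypothetical target
optimizer = torch.optim.SGD([upd], lr=0.1)    # only upd is optimized

z_before = (mask * fixed + (1 - mask) * upd).detach().clone()

for _ in range(10):
    optimizer.zero_grad()
    z = mask * fixed + (1 - mask) * upd       # rebuild z from both parts every step
    loss = ((z - target) ** 2).sum()
    loss.backward()
    optimizer.step()

z_after = (mask * fixed + (1 - mask) * upd).detach()
print(upd.grad)  # zero wherever mask == 1, since those upd entries never reach z
```

The entries of z selected by the mask come from fixed, which the optimizer never sees, so they stay exactly as initialized while the other entries move toward the target.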
There are, of course, other ways to combine fixed and upd besides a binary mask.
For example, if upd occupies the first two columns of z and fixed the rest, then:
fixed = torch.rand([2, 1], requires_grad=False)  # last column
upd = torch.rand([2, 2], requires_grad=True)     # first two columns
# combine them using concatenation along the column dimension:
z = torch.cat((upd, fixed), dim=1)  # shape [2, 3]
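To check that gradients reach only upd in the concatenation variant, here is a small sketch; the sum loss is an arbitrary stand-in for a real objective.

```python
import torch

fixed = torch.rand([2, 1])                    # last column: constant
upd = torch.rand([2, 2], requires_grad=True)  # first two columns: trainable
z = torch.cat((upd, fixed), dim=1)            # shape [2, 3]

z.sum().backward()                            # placeholder loss
print(z.shape)     # torch.Size([2, 3])
print(upd.grad)    # all ones: each upd entry appears exactly once in z
print(fixed.grad)  # None: fixed does not require grad
```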
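Or, if you know the flat (row-major) indices of each group of entries:
fidx = torch.tensor([0, 2], dtype=torch.long)        # positions of the fixed entries
uidx = torch.tensor([1, 3, 4, 5], dtype=torch.long)  # positions of the updated entries
fixed = torch.rand([2], requires_grad=False)
upd = torch.rand([4], requires_grad=True)
z = torch.empty([6])  # flat buffer for the 2 * 3 = 6 entries
z[fidx] = fixed
z[uidx] = upd         # autograd records this write, so z now depends on upd
z = z.view(2, 3)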
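If you would rather avoid in-place writes into a preallocated buffer, Tensor.index_copy gives an out-of-place way to put both parts into position. The sketch below assumes fidx and uidx are flat (row-major) positions in the [2, 3] matrix that together cover all six entries; the sum loss is only there to show where gradients go.

```python
import torch

# Flat (row-major) positions in the [2, 3] matrix; together they cover all 6 entries
fidx = torch.tensor([0, 2])        # fixed entries
uidx = torch.tensor([1, 3, 4, 5])  # trainable entries

fixed = torch.rand(2)                    # no gradient
upd = torch.rand(4, requires_grad=True)  # optimized

# index_copy (no trailing underscore) is out-of-place: each call returns a new tensor
z = (
    torch.zeros(6)
    .index_copy(0, fidx, fixed)
    .index_copy(0, uidx, upd)
    .view(2, 3)
)

z.sum().backward()  # placeholder loss
print(upd.grad)     # all ones: each upd entry lands in exactly one position of z
```

Compared with writing into torch.empty, this never mutates a tensor that autograd is tracking, which can be easier to reason about when z is rebuilt inside a training loop.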