python · pytorch · tensor · gradient-descent

How to create a torch.tensor object and update only some of its elements?


Let's say I want to create a torch.tensor object of size [2,3] filled with random elements, and I intend to use this matrix in a network and optimize its values. However, I want to update only some of the values in the matrix.

I know that for a tensor this can be done by setting its requires_grad parameter to True or False. However, the following code

z = torch.rand([2,3], requires_grad=True)
z[-1][-1].requires_grad=False

does not work as expected:

RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().

How can I fix this RuntimeError? How can I initialize a torch tensor and then define which of its elements should have requires_grad=True?

If I write the code the other way around:

z = torch.rand([2,3], requires_grad=False)
z[-1][-1].requires_grad=True

there is no error, but requires_grad does not change either.


Solution

  • It does not really make much sense to have a single tensor that requires_grad for only part of its entries.
    Why not have two separate tensors: one that is updated (requires_grad=True) and one that stays fixed (requires_grad=False)? You can then merge them for computational ease:

    fixed = torch.rand([2, 3], requires_grad=False)
    upd = torch.rand([2, 3], requires_grad=True)
    mask = torch.tensor([[0., 1., 0.], [1., 0., 1.]])  # 1 = take fixed, 0 = take upd
    # combine them using the fixed "mask":
    z = mask * fixed + (1 - mask) * upd
    

    You can obviously combine fixed and upd in other ways than with a binary mask.
    For example, if upd occupies the first two columns of z and fixed the last one, then:

    fixed = torch.rand([2, 1], requires_grad=False)
    upd = torch.rand([2, 2], requires_grad=True)
    # combine them using concatenation
    z = torch.cat((upd, fixed), dim=1)
    

    Or, if you know the (flat) indices:

    fidx = torch.tensor([0, 2], dtype=torch.long)
    uidx = torch.tensor([1, 3, 4, 5], dtype=torch.long)
    fixed = torch.rand([2], requires_grad=False)
    upd = torch.rand([4], requires_grad=True)
    z = torch.empty([6])  # flat buffer: fill it by flat index, then reshape
    z[fidx] = fixed
    z[uidx] = upd
    z = z.view(2, 3)
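
    In all three variants only upd receives gradients, so an optimizer given just upd leaves the fixed entries untouched. A minimal sketch of a few training steps with the mask variant (the sum-of-squares loss and the learning rate are placeholder choices, not from the question):

        import torch

        fixed = torch.rand([2, 3], requires_grad=False)
        upd = torch.rand([2, 3], requires_grad=True)
        mask = torch.tensor([[0., 1., 0.], [1., 0., 1.]])  # 1 = fixed, 0 = updated

        opt = torch.optim.SGD([upd], lr=0.1)
        before = (mask * fixed).clone()  # the entries that must not move

        for _ in range(3):
            opt.zero_grad()
            z = mask * fixed + (1 - mask) * upd  # rebuild z every step
            loss = (z ** 2).sum()                # placeholder loss
            loss.backward()
            opt.step()

        z = mask * fixed + (1 - mask) * upd
        assert torch.equal(mask * z, before)  # masked entries unchanged

    Note that z must be rebuilt inside the loop: it is a derived tensor, so after opt.step() changes upd, the old z is stale.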