
PyTorch: how to let only chosen elements in a tensor be differentiable?


Right now, I have a use case like the following.

A is a torch.Tensor like the following:

A = 
[[1,x,y], 
 [1,2,3],
 [1,z,3]]

Only the elements x, y, z in A are differentiable; the other elements are just constants.

For example, if cost = tr(A·A), then

cost = 14 + 2x + 2y + 6z
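
(Expanding the diagonal of A·A term by term: tr(A·A) = (1 + x + y) + (x + 4 + 3z) + (y + 3z + 9) = 14 + 2x + 2y + 6z.)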

When I do backpropagation, I only want to differentiate and update with respect to x, y, z. Of course, this is just a toy example; the real one is more complicated.

How can I implement such a use case?


Solution

  • I've figured out a method:

    import torch
    
    # Define variables x, y, z as differentiable tensors
    x = torch.tensor([2.0], requires_grad=True)
    y = torch.tensor([2.0], requires_grad=True)
    z = torch.tensor([2.0], requires_grad=True)
    
    print(x.grad, y.grad, z.grad)  # None, None, None
    
    # Create tensor A as a constant (non-differentiable) tensor;
    # x.item(), y.item(), z.item() are just placeholder values here
    A = torch.tensor([[1, x.item(), y.item()],
                      [1, 2, 3],
                      [1, z.item(), 3]], requires_grad=False)
    
    # Write x, y, z back into A; these in-place assignments are tracked
    # by autograd, so gradients will flow back to x, y, z
    A[0, 1] = x
    A[0, 2] = y
    A[2, 1] = z
    
    # Compute the loss function
    cost = torch.trace(A @ A)
    
    # Backpropagate to compute the gradients
    cost.backward()
    
    # Output the gradients
    print(x.grad, y.grad, z.grad)  # tensor([2.]) tensor([2.]) tensor([6.])
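
  • Why this works, and how to update x, y, z: building A from x.item(), y.item(),
    z.item() only fills in constant placeholder values, but the in-place assignments
    A[0, 1] = x, A[0, 2] = y, A[2, 1] = z are recorded by autograd, so gradients flow
    back exclusively to x, y, z. To also perform updates, here is a minimal sketch
    (assuming plain SGD with an arbitrary learning rate and reusing the toy cost
    above; A is rebuilt each step so the graph stays connected to x, y, z):

    import torch
    
    x = torch.tensor([2.0], requires_grad=True)
    y = torch.tensor([2.0], requires_grad=True)
    z = torch.tensor([2.0], requires_grad=True)
    
    # Only x, y, z are given to the optimizer, so only they get updated
    opt = torch.optim.SGD([x, y, z], lr=0.01)
    
    for step in range(5):
        opt.zero_grad()
        # Rebuild A from constants each step; the zeros are placeholders
        # that the in-place writes below overwrite with x, y, z
        A = torch.tensor([[1.0, 0.0, 0.0],
                          [1.0, 2.0, 3.0],
                          [1.0, 0.0, 3.0]])
        A[0, 1] = x
        A[0, 2] = y
        A[2, 1] = z
        cost = torch.trace(A @ A)
        cost.backward()
        opt.step()  # the constant entries of A are never modified

    Note that this toy cost is linear in x, y, z, so the loop above only demonstrates
    the update mechanics, not convergence to a minimum.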