Right now, I have a use case like the following. A is a torch.tensor:

A = [[1, x, y],
     [1, 2, 3],
     [1, z, 3]]

Only the elements x, y, z in A are differentiable; the other elements are just constants.
For example, if cost = tr(A @ A), then cost = 14 + 2x + 2y + 6z.
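To see where this comes from, expand the trace over the diagonal entries of A @ A:

tr(A @ A) = (1 + x + y) + (x + 4 + 3z) + (y + 3z + 9) = 14 + 2x + 2y + 6z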
When I do backpropagation, I only want to differentiate and update with respect to x, y, z. Of course, this is just a toy example, not the real, more complicated one. How can I realize such a use case?
I've figured out a method:
import torch
# Define variables x, y, z as differentiable tensors
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = torch.tensor([2.0], requires_grad=True)
print(x.grad, y.grad, z.grad) # None, None, None
# Create tensor A with only x, y, z being differentiable
A = torch.tensor([[1, x.item(), y.item()],
                  [1, 2, 3],
                  [1, z.item(), 3]], requires_grad=False)
# Put x, y, z back into A, this time they are differentiable
A[0, 1] = x
A[0, 2] = y
A[2, 1] = z
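# These in-place assignments give A a grad_fn (CopySlices), so gradients
# can flow back to x, y, z even though A itself was created with
# requires_grad=False; the other entries stay constant.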
# Compute the loss function
cost = torch.trace(A @ A)
# Backpropagate to compute the gradients
cost.backward()
# Output the gradients
print(x.grad, y.grad, z.grad)  # tensor([2.]) tensor([2.]) tensor([6.])
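Since the goal is to update x, y, z, here is a minimal sketch of how the same idea extends to an optimization loop. The optimizer choice, learning rate, and step count are illustrative assumptions on my part, not part of the method above; the key point is that A must be rebuilt each iteration, because backward() consumes the graph created by the in-place assignments.

import torch

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = torch.tensor([2.0], requires_grad=True)
# Hypothetical optimizer and learning rate, chosen only for illustration
opt = torch.optim.SGD([x, y, z], lr=0.05)

for step in range(10):
    # Rebuild A each iteration so the current x, y, z enter a fresh graph
    A = torch.tensor([[1.0, 0.0, 0.0],
                      [1.0, 2.0, 3.0],
                      [1.0, 0.0, 3.0]])
    A[0, 1] = x
    A[0, 2] = y
    A[2, 1] = z
    cost = torch.trace(A @ A)
    opt.zero_grad()
    cost.backward()
    opt.step()  # only x, y, z move; the constant entries are untouched
    print(step, cost.item())

Note that this toy cost is linear in x, y, z, so the values just decrease steadily; the loop is only meant to show the mechanics of updating the differentiable entries.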