In PyTorch, I want to do the following calculation:
l1 = f(x.detach(), y)
l1.backward(retain_graph=True)
l2 = -1*f(x, y.detach())
l2.backward()
where f is some function, and x and y are tensors that require gradients. Notice that x and y may both be the results of previous calculations which use shared parameters (for example, maybe x = g(z) and y = g(w), where g is an nn.Module).
The issue is that l1 and l2 are numerically identical, up to the minus sign, and it seems wasteful to repeat the calculation f(x, y) twice. It would be nicer to be able to calculate it once and apply backward twice on the result. Is there any way of doing this?
One possibility is to manually call autograd.grad and update the w.grad field of each nn.Parameter w. But I'm wondering if there is a more direct and clean way to do this, using the backward function.
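For concreteness, a rough sketch of that alternative might look like the following (g, z, w and f here are just stand-ins for the real shared module, inputs and loss, not code from any library):
import torch

g = torch.nn.Linear(1, 1, bias=False)    # hypothetical shared module
z = torch.tensor([1.0])                  # hypothetical inputs
w = torch.tensor([2.0])
f = lambda x, y: (x - y).abs()           # hypothetical loss

x = g(z)
y = g(w)
l = f(x, y)                              # computed only once
gx, gy = torch.autograd.grad(l, (x, y))  # gradients of l w.r.t. x and y

# continue the backward pass into the shared parameters,
# flipping the sign of the gradient flowing through x
grads_x = torch.autograd.grad(x, g.parameters(), grad_outputs=-gx)
grads_y = torch.autograd.grad(y, g.parameters(), grad_outputs=gy)

# manually accumulate into each parameter's .grad field
for p, dx, dy in zip(g.parameters(), grads_x, grads_y):
    p.grad = dx + dy if p.grad is None else p.grad + dx + dy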
I took this answer from here.
We can calculate f(x, y) once, without detaching either x or y, if we make sure that the gradient flowing through x is multiplied by -1. This can be done using register_hook: the hook receives the gradient of the loss with respect to x, and whatever it returns is used in its place, so returning -t flips the sign of the gradient flowing back through x (and through anything x was computed from) while leaving the y path untouched:
x.register_hook(lambda t: -t)
l = f(x,y)
l.backward()
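One thing to be aware of: the hook stays attached to x for any later backward pass through that same tensor. register_hook returns a handle, and calling its remove() method detaches the hook again if the sign flip should only apply once. In a typical training loop this usually doesn't matter, since x is recreated on every forward pass and the hook has to be registered again anyway.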
Here is code demonstrating that this works:
import torch
lin = torch.nn.Linear(1, 1, bias=False)
lin.weight.data[:] = 1.0
a = torch.tensor([1.0])
b = torch.tensor([2.0])
loss_func = lambda x, y: (x - y).abs()
# option 1: this is the inefficient option, presented in the original question
lin.zero_grad()
x = lin(a)
y = lin(b)
loss1 = loss_func(x.detach(), y)
loss1.backward(retain_graph=True)
loss2 = -1 * loss_func(x, y.detach()) # second invocation of `loss_func` - not efficient!
loss2.backward()
print(lin.weight.grad)
# option 2: this is the efficient method, suggested in this answer.
lin.zero_grad()
x = lin(a)
y = lin(b)
x.register_hook(lambda t: -t)
loss = loss_func(x, y) # only one invocation of `loss_func` - more efficient!
loss.backward()
print(lin.weight.grad) # the output of this is identical to the previous print, which confirms the method
# option 3: no detaching and no hook - not equivalent to the previous options, shown just for comparison
lin.zero_grad()
x = lin(a)
y = lin(b)
loss = loss_func(x, y)
loss.backward()
print(lin.weight.grad)
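For reference, the numbers can be checked by hand: with the weight at 1.0 we get x = 1 and y = 2, so d|x - y|/dx = -1 and d|x - y|/dy = +1, while dx/dweight = a = 1 and dy/dweight = b = 2. Options 1 and 2 both accumulate (+1)*1 + (+1)*2 = 3 into lin.weight.grad (the gradient along the x path is flipped from -1 to +1, by the explicit minus sign in option 1 and by the hook in option 2), whereas option 3 accumulates (-1)*1 + (+1)*2 = 1, which is why the first two prints agree and the third differs.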