Between multiple .backward() passes I'd like to set the gradients to zero. Right now I have to do this for every component separately (here these are x and t); is there a way to do this "globally" for all affected variables? (I imagine something like z.set_all_gradients_to_zero().)
I know there is optimizer.zero_grad() if you use an optimizer, but is there also a direct way without using an optimizer?
import torch
x = torch.randn(3, requires_grad=True)
t = torch.randn(3, requires_grad=True)
y = x + t
z = y + y.flip(0)
z.backward(torch.tensor([1., 0., 0.]), retain_graph=True)
print(x.grad)
print(t.grad)
x.grad.zero_()  # both gradients need to be set to zero
t.grad.zero_()
z.backward(torch.tensor([0., 1., 0.]), retain_graph=True)
print(x.grad)
print(t.grad)
You can also use nn.Module.zero_grad(). In fact, optim.zero_grad() does essentially the same thing: it iterates over the parameters that were passed to it and clears each parameter's .grad.
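A minimal sketch of that approach, assuming you are willing to register x and t as nn.Parameters inside a small container module (the Leaves class below is purely illustrative and not part of the original question):
import torch
import torch.nn as nn

class Leaves(nn.Module):
    def __init__(self):
        super().__init__()
        # register the leaves as parameters so the module tracks them
        self.x = nn.Parameter(torch.randn(3))
        self.t = nn.Parameter(torch.randn(3))

m = Leaves()
y = m.x + m.t
z = y + y.flip(0)
z.backward(torch.tensor([1., 0., 0.]), retain_graph=True)
print(m.x.grad, m.t.grad)
m.zero_grad()  # clears .grad of every registered parameter in one call
z.backward(torch.tensor([0., 1., 0.]), retain_graph=True)
print(m.x.grad, m.t.grad)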
There is no reasonable way to do it globally. You can collect your variables in a list:
grad_vars = [x, t]
for var in grad_vars:
    var.grad = None
Alternatively, you could write some hacky function based on vars(). Perhaps it's also possible to inspect the computation graph and zero the gradient of all leaf nodes, but I am not familiar with the graph API. Long story short, you're expected to use the object-oriented interface of torch.nn instead of manually creating tensor variables.
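For the graph-inspection idea mentioned above, here is a hedged sketch that walks output.grad_fn through next_functions and clears the gradient of every leaf it reaches. It relies on the undocumented detail that AccumulateGrad nodes expose their leaf tensor as .variable, so treat it as illustrative rather than a supported API:
import torch

def zero_leaf_grads(output):
    # Depth-first walk over the autograd graph behind `output`.
    # AccumulateGrad nodes (an internal detail) hold the leaf tensor in `.variable`.
    seen = set()
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if hasattr(fn, 'variable'):
            fn.variable.grad = None  # clear the leaf's gradient
        stack.extend(next_fn for next_fn, _ in fn.next_functions)

x = torch.randn(3, requires_grad=True)
t = torch.randn(3, requires_grad=True)
y = x + t
z = y + y.flip(0)
z.backward(torch.tensor([1., 0., 0.]), retain_graph=True)
zero_leaf_grads(z)
print(x.grad, t.grad)  # both are None again
Setting .grad to None (rather than zeroing it in place) also releases the gradient memory, which is the same trade-off optimizers expose through zero_grad(set_to_none=True).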