How to set gradients to zero without an optimizer?


Between multiple .backward() passes I'd like to set the gradients to zero. Right now I have to do this for every component separately (here these are x and t); is there a way to do this "globally" for all affected variables? (I imagine something like z.set_all_gradients_to_zero().)

I know there is optimizer.zero_grad() if you use an optimizer, but is there also a direct way without using an optimizer?

import torch

x = torch.randn(3, requires_grad=True)
t = torch.randn(3, requires_grad=True)
y = x + t
z = y + y.flip(0)

z.backward(torch.tensor([1., 0., 0.]), retain_graph=True)
print(x.grad)
print(t.grad)
x.grad.zero_()  # both gradients need to be set to zero
t.grad.zero_()
z.backward(torch.tensor([0., 1., 0.]), retain_graph=True)
print(x.grad)
print(t.grad)

Solution

  • You can also use nn.Module.zero_grad(). In fact, optim.zero_grad() does essentially the same thing: it resets the .grad attribute of every parameter that was passed to the optimizer. A minimal sketch of the nn.Module approach appears after the list example below.

    There is no built-in way to do it globally for free-standing tensors. You can collect your variables in a list:

    grad_vars = [x, t]
    for var in grad_vars:
        var.grad = None  # clears the gradient; a fresh one is allocated on the next backward()
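
    For the nn.Module route mentioned above, here is a minimal sketch. The wrapper class (GradHolder) and its attribute names are made up for illustration; the point is that registering x and t as nn.Parameters of one module lets a single zero_grad() call cover both.

    import torch
    import torch.nn as nn

    # Hypothetical wrapper: register both leaf tensors as parameters of one
    # module so that module.zero_grad() resets all of their gradients at once.
    class GradHolder(nn.Module):
        def __init__(self):
            super().__init__()
            self.x = nn.Parameter(torch.randn(3))
            self.t = nn.Parameter(torch.randn(3))

    holder = GradHolder()
    y = holder.x + holder.t
    z = y + y.flip(0)

    z.backward(torch.tensor([1., 0., 0.]), retain_graph=True)
    print(holder.x.grad, holder.t.grad)

    holder.zero_grad()  # resets .grad for every registered parameter
    z.backward(torch.tensor([0., 1., 0.]), retain_graph=True)
    print(holder.x.grad, holder.t.grad)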
    

    Alternatively, you could create some hacky function based on vars(). It may also be possible to inspect the computation graph and zero the gradients of all leaf nodes (a rough sketch of that idea follows below), but I am not familiar with the graph API. Long story short, you're expected to use the object-oriented interface of torch.nn instead of manually creating tensor variables.
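
    To illustrate the graph-inspection idea, here is a rough sketch. It relies on undocumented autograd internals (grad_fn.next_functions and the .variable attribute of AccumulateGrad nodes), so treat it as an assumption that may break between PyTorch versions rather than a supported API.

    def zero_leaf_grads(output):
        # Walk the autograd graph behind `output` and clear the .grad of every
        # leaf tensor feeding into it. Uses internal attributes
        # (grad_fn.next_functions, AccumulateGrad.variable) -- not a public API.
        seen = set()
        stack = [output.grad_fn]
        while stack:
            node = stack.pop()
            if node is None or node in seen:
                continue
            seen.add(node)
            # Leaf tensors show up as AccumulateGrad nodes holding a .variable
            var = getattr(node, "variable", None)
            if var is not None and var.grad is not None:
                var.grad = None  # or var.grad.zero_()
            stack.extend(fn for fn, _ in node.next_functions)

    zero_leaf_grads(z)  # clears x.grad and t.grad from the question's example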