
How to avoid recalculating a function when we need to backpropagate through it twice?


In PyTorch, I want to do the following calculation:

l1 = f(x.detach(), y)
l1.backward(retain_graph=True)
l2 = -1*f(x, y.detach())
l2.backward()

where f is some function, and x and y are tensors that require gradient. Notice that x and y may both be the results of previous calculations which utilize shared parameters (for example, maybe x=g(z) and y=g(w) where g is an nn.Module).

The issue is that l1 and l2 are numerically identical up to the minus sign, so it seems wasteful to compute f(x, y) twice. It would be nicer to calculate it once and call backward twice on the result. Is there any way of doing this?

One possibility is to manually call autograd.grad and update the w.grad field of each nn.Parameter w. But I'm wondering if there is a more direct and clean way to do this, using the backward function.
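
For concreteness, such a manual version might look roughly like this (my own sketch, not a tested snippet; here params stands for the shared parameters, e.g. list(g.parameters())):

l = f(x, y)  # single forward pass
# gradients of l with respect to the two intermediate tensors
grad_x, grad_y = torch.autograd.grad(l, (x, y), retain_graph=True)
# backpropagate -grad_x through x and +grad_y through y into the shared parameters
param_grads = torch.autograd.grad((x, y), params, grad_outputs=(-grad_x, grad_y))
for p, g_p in zip(params, param_grads):
    p.grad = g_p if p.grad is None else p.grad + g_p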


Solution

  • I took this answer from here.

    We can calculate f(x,y) once, without detaching either x or y, as long as we multiply the gradient flowing through x by -1. This produces the same parameter gradients as the two backward passes in the question: f(x.detach(), y) sends gradient only through y, while -1*f(x, y.detach()) sends only the negated gradient through x. The sign flip can be done using register_hook:

    x.register_hook(lambda t: -t)
    l = f(x,y)
    l.backward()
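
    Note that the hook stays attached to that particular tensor x, so in a training loop where x is recomputed on every forward pass, register_hook has to be called again on the new x (as in the demo below, which registers it right after the forward).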
    

    Here is code demonstrating that this works:

    import torch
    
    lin = torch.nn.Linear(1, 1, bias=False)
    lin.weight.data[:] = 1.0
    a = torch.tensor([1.0])
    b = torch.tensor([2.0])
    loss_func = lambda x, y: (x - y).abs()
    
    # option 1: this is the inefficient option, presented in the original question
    lin.zero_grad()
    x = lin(a)
    y = lin(b)
    loss1 = loss_func(x.detach(), y)
    loss1.backward(retain_graph=True)
    loss2 = -1 * loss_func(x, y.detach())  # second invocation of `loss_func` - not efficient!
    loss2.backward()
    print(lin.weight.grad)
    
    # option 2: this is the efficient method, suggested in this answer. 
    lin.zero_grad()
    x = lin(a)
    y = lin(b)
    x.register_hook(lambda t: -t)
    loss = loss_func(x, y)  # only one invocation of `loss_func` - more efficient!
    loss.backward()
    print(lin.weight.grad)  # the output of this is identical to the previous print, which confirms the method
    
    # option 3: a plain single backward, not equivalent to the previous options; included just for comparison
    lin.zero_grad()
    x = lin(a)
    y = lin(b)
    loss = loss_func(x, y)
    loss.backward()
    print(lin.weight.grad)
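
    With these numbers (weight 1, a = 1, b = 2, so x = 1, y = 2, and the gradient of |x - y| is -1 w.r.t. x and +1 w.r.t. y), the first two prints should both show tensor([[3.]]) (1*a + 1*b, once the gradient flowing through x is negated), while option 3 should print tensor([[1.]]) (-1*a + 1*b), confirming that options 1 and 2 agree and that a plain single backward computes something different.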