python, pytorch

How does retain_grad() work in PyTorch? I found that its position changes the grad result


In a simple test in PyTorch, I want to see the grad of a non-leaf tensor, so I use retain_grad():

import torch
a = torch.tensor([1.], requires_grad=True)
y = torch.zeros((10))
gt = torch.zeros((10))

y[0] = a
y[1] = y[0] * 2
y.retain_grad()

loss = torch.sum((y-gt) ** 2)
loss.backward()
print(y.grad)

It gives me the output I expect:

tensor([2., 4., 0., 0., 0., 0., 0., 0., 0., 0.])

But when I call retain_grad() after y[0] is assigned and before y[1]:

import torch
a = torch.tensor([1.], requires_grad=True)
y = torch.zeros((10))
gt = torch.zeros((10))

y[0] = a
y.retain_grad()
y[1] = y[0] * 2

loss = torch.sum((y-gt) ** 2)
loss.backward()
print(y.grad)

Now the output changes to:

tensor([10.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])

I can't understand the result at all.


Solution

  • Okay so what's going on is really weird.

    What .retain_grad() essentially does is make a non-leaf tensor keep its .grad attribute after backward(), just like a leaf tensor does (by default, PyTorch only populates .grad on leaf tensors).

    Hence, in your first example, you call y.retain_grad() after both in-place assignments, so the .grad you read back is the gradient of the loss with respect to the final state of y, which is the tensor([2., 4., 0., ...]) you expected.
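
    As a minimal sketch of that behaviour (this example is mine, not from the question), retaining the grad of an ordinary non-leaf tensor looks like this:

    import torch

    x = torch.tensor([3.], requires_grad=True)  # leaf tensor
    z = x * 2                                    # non-leaf tensor (has a grad_fn)
    z.retain_grad()                              # ask autograd to keep z.grad after backward()

    loss = (z ** 2).sum()
    loss.backward()

    print(x.grad)  # tensor([24.]) -> dloss/dx = 2*z * 2
    print(z.grad)  # tensor([12.]) -> dloss/dz = 2*z; without retain_grad() this would be None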

    However, in your second example, you call y.retain_grad() after y[0] = a but before y[1] = y[0] * 2. The retained .grad therefore belongs to that intermediate state of y, and the later in-place assignment builds a new node in the autograd graph on top of it, which is what causes the confusion.

    y = torch.zeros((10))  # y is a plain leaf tensor; it does not require grad yet
    
    y[0] = a  # in-place op: y now requires grad and becomes a non-leaf tensor (grad_fn=CopySlices)
    y.retain_grad()  # this intermediate state of y is the one whose .grad will be kept
    y[1] = y[0] * 2  # another in-place op records a new CopySlices node on top of the retained state
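
    If you want to check these claims, a quick sanity check (my addition, not part of the original answer) is to print is_leaf and grad_fn at each step:

    import torch

    a = torch.tensor([1.], requires_grad=True)
    y = torch.zeros((10))
    print(y.is_leaf, y.requires_grad)  # True False -> a plain tensor, nothing tracked yet

    y[0] = a
    print(y.is_leaf, y.grad_fn)        # False <CopySlices ...> -> y is now a non-leaf tensor

    y.retain_grad()
    y[1] = y[0] * 2
    print(y.grad_fn)                   # a new CopySlices node, stacked on top of the previous one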
    

    The confusing part is:

    The .grad that gets retained belongs to the state y was in when y.retain_grad() was called. The later assignment y[1] = y[0] * 2 never feeds a gradient back into index 1 of that state: in the autograd graph, the new y[1] is computed from the old y[0], and the old y[1] is simply overwritten, so it receives a gradient of 0.
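
    You can make those two states explicit by rewriting the second example without in-place ops (the names y_mid and y_final below are mine, just labels for the two states; this is a sketch, not the original code):

    import torch

    a = torch.tensor([1.], requires_grad=True)
    gt = torch.zeros(10)

    # y_mid plays the role of y right after `y[0] = a`, i.e. the state
    # that retain_grad() was attached to in the second example.
    y_mid = torch.cat([a, torch.zeros(9)])
    y_mid.retain_grad()

    # y_final plays the role of y after `y[1] = y[0] * 2`:
    # index 1 is recomputed from the old index 0, and the old index 1 is dropped.
    y_final = torch.cat([y_mid[0:1], y_mid[0:1] * 2, y_mid[2:]])

    loss = torch.sum((y_final - gt) ** 2)
    loss.backward()

    print(y_mid.grad)  # tensor([10., 0., 0., ..., 0.]) -> the same "surprising" output

    y_mid[1] gets no gradient because nothing in the graph uses it, while y_mid[0] collects contributions from both places it is used, which is where the 10 comes from.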

    Therefore, when calling loss.backward(), the chain rule of the loss with respect to that retained state of y sends two contributions into index 0 (one through y[0] itself and one through the new y[1] = y[0] * 2) and nothing into index 1, which looks something like this:


    [Image: chain rule of the loss w.r.t. the retained state of y]
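
    Writing it out by hand (my notation: y_old is the state of y that retain_grad() was attached to, y is the final state that enters the loss), the numbers work out as:

    \frac{\partial L}{\partial y_{old}[0]}
      = \frac{\partial L}{\partial y[0]} \cdot \frac{\partial y[0]}{\partial y_{old}[0]}
      + \frac{\partial L}{\partial y[1]} \cdot \frac{\partial y[1]}{\partial y_{old}[0]}
      = 2(1 - 0) \cdot 1 + 2(2 - 0) \cdot 2
      = 10

    \frac{\partial L}{\partial y_{old}[1]} = 0 \quad (\text{the old } y[1] \text{ is overwritten and never used})

    which matches the tensor([10., 0., 0., ...]) you observed.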