python, pytorch, nan, autograd

PyTorch backward() on a tensor element affected by nan in other elements


Consider the following two examples:

import torch

# example 1: z0 and z1 are two separate tensors
x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2, ), float("nan"))
z0 = x * y / y
z1 = x + y
print(z0, z1) # tensor(nan, grad_fn=<DivBackward0>) tensor(1., grad_fn=<AddBackward0>)
z1.backward()
print(x.grad) # tensor(1.)


# example 2: z0 and z1 are written into slices of the same tensor z
x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2, ), float("nan"))
z[0] = x * y / y
z[1] = x + y
print(z) # tensor([nan, 1.], grad_fn=<CopySlices>)
z[1].backward()
print(x.grad) # tensor(nan)

In example 1, z0 does not affect z1, so the backward() of z1 executes as expected and x.grad is not nan. In example 2, however, the backward() of z[1] appears to be affected by z[0], and x.grad is nan.

How do I prevent this (example 1 shows the desired behaviour)? Specifically, I need to retain the nan in z[0], so adding an epsilon to the division does not help.


Solution

  • When you index the tensor in the assignment, PyTorch accesses all elements of the tensor (under the hood it uses binary multiplicative masking to maintain differentiability), and this is where it picks up the nan from the other element (since 0 * nan -> nan); the toy reproduction at the end of this answer illustrates the effect.

    We can see this in the computational graphs:

    torchviz.make_dot(z1, params={'x': x, 'y': y})
    torchviz.make_dot(z[1], params={'x': x, 'y': y})

    [computational graph of z1]  [computational graph of z[1]]

    If you wish to avoid this behaviour, either mask the nans or, as you have done in the first example, keep them as two separate objects (see the sketch at the end of this answer).
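
    As a toy reproduction of the masking effect described above (a hand-rolled illustration, not the actual kernel PyTorch uses): weighting the nan-producing branch by zero, the way a binary mask would, neither cleans up the forward value nor keeps the nan out of the backward pass:

    import torch

    x = torch.tensor(1., requires_grad=True)
    y = torch.tensor(0., requires_grad=True)

    bad = x * y / y      # nan, like z[0]
    good = x + y         # 1., like z[1]

    # weight the nan branch by 0 and the good branch by 1, as a binary mask would
    out = 0 * bad + 1 * good
    print(out)           # tensor(nan, grad_fn=<AddBackward0>) -- 0 * nan is already nan
    out.backward()
    print(x.grad)        # tensor(nan) -- the zero-weighted branch still poisons the gradient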
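
    If you only need the nan as a value in z[0] and do not need gradients through it, one way to follow the "separate objects" advice while still filling a single tensor is to detach the nan-producing expression before the slice assignment, so that z[1]'s backward never traverses its graph. This is a sketch of that idea, not code from the original answer:

    import torch

    x = torch.tensor(1., requires_grad=True)
    y = torch.tensor(0., requires_grad=True)

    z = torch.full((2,), float("nan"))
    z[0] = (x * y / y).detach()   # the nan value is kept, but its graph is cut off
    z[1] = x + y

    print(z)        # tensor([nan, 1.], grad_fn=<CopySlices>)
    z[1].backward()
    print(x.grad)   # tensor(1.)

    The detach only affects what is written into z[0]; if you do need to backpropagate through the nan-producing expression itself, keep a separate handle on it (as with z0 in example 1).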