Consider the following two examples:
import torch

# Example 1: the two results are kept in separate tensors
x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2,), float("nan"))  # unused in this example
z0 = x * y / y
z1 = x + y
print(z0, z1)  # tensor(nan, grad_fn=<DivBackward0>) tensor(1., grad_fn=<AddBackward0>)
z1.backward()
print(x.grad)  # tensor(1.)
# Example 2: both results are written into slices of the same tensor z
x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2,), float("nan"))
z[0] = x * y / y
z[1] = x + y
print(z)  # tensor([nan, 1.], grad_fn=<CopySlices>)
z[1].backward()
print(x.grad)  # tensor(nan)
In example 1, z0 does not affect z1; the backward() of z1 executes as expected and x.grad is not nan. However, in example 2, the backward() of z[1] seems to be affected by z[0], and x.grad is nan.
How do I prevent this (example 1 is the desired behaviour)? Specifically, I need to retain the nan in z[0], so adding an epsilon to the division does not help.
When indexing the tensor in the assignment, PyTorch accesses all elements of the tensor (under the hood it uses binary multiplicative masking to maintain differentiability), and this is where it picks up the nan of the other element, since 0 * nan -> nan.
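To see why a zero upstream gradient is not enough to block the contamination, here is a minimal sketch that reproduces the effect in isolation; the explicit zero gradient plays the role of the mask entry for the slice that was not selected:

import torch

a = torch.tensor(1., requires_grad=True)
b = torch.tensor(0., requires_grad=True)
out = a * b / b                  # forward value is nan
out.backward(torch.tensor(0.))   # zero upstream gradient, like the masked-out slice
print(a.grad)                    # tensor(nan): 0 * inf from the division backward is nan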
We can see this in the computational graph:
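One way to inspect that graph programmatically is to walk the grad_fn chain directly. This is a rough sketch; the exact node names and layout can vary between PyTorch versions:

import torch

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2,), float("nan"))
z[0] = x * y / y
z[1] = x + y

def walk(fn, depth=0):
    # recursively print every autograd node reachable from fn
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)

walk(z.grad_fn)
# Both CopySlices assignments appear in the same graph, with the
# AddBackward0 branch (from z[1]) and the DivBackward0/MulBackward0
# branch (from z[0]) hanging off them, so backward from z[1] still
# reaches the nan-producing division.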
If you wish to avoid this behaviour, either mask the nans, or do as you have done in the first example: keep the two results in separate objects.
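For the masking option, one possible sketch (for the scalar case in the question; the detach-based masking here is my suggestion, not part of the original answer) is to detach exactly the values that came out as nan, so the nan stays in z but its branch never enters z's graph:

import torch

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2,), float("nan"))

val = x * y / y
# keep the nan value in z[0] but detach it so its branch is cut out of the graph
z[0] = val.detach() if torch.isnan(val) else val
z[1] = x + y

z[1].backward()
print(x.grad)  # tensor(1.) in this sketch: the nan branch no longer participates

Note that simply zeroing the gradient with a multiplicative mask would not help here, since 0 * inf is still nan (as shown above), which is why the nan entries need to be detached rather than masked in the gradient.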