I am trying to get a deeper understanding of how PyTorch's autograd works. I am unable to explain the following results:
import torch
def fn(a):
    b = torch.tensor(5, dtype=torch.float32, requires_grad=True)
    return a*b
a = torch.tensor(10, dtype=torch.float32, requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
The output is tensor(5.). My question: the variable b is created inside the function, so shouldn't it be removed from memory once the function returns a*b? When I call backward, how is the value of b still available for this computation? As far as I understand, each operation in PyTorch has a context variable that tracks which tensors to use for the backward computation, and each tensor also carries a version; if the version changes, backward should raise an error, right?
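The name b goes away when fn returns, but the tensor object does not: the MulBackward0 node created by a*b keeps a reference to each operand it needs for the backward pass, and b, being a leaf, is also reachable through the graph. A minimal sketch to inspect this, assuming a PyTorch version (roughly 1.10 or later) that exposes saved tensors through the private grad_fn._saved_other attribute; next_functions and AccumulateGrad.variable are likewise internals rather than stable public API.

```python
import torch

def fn(a):
    # The *name* b is local, but the tensor is referenced by the graph node
    b = torch.tensor(5, dtype=torch.float32, requires_grad=True)
    return a * b

a = torch.tensor(10, dtype=torch.float32, requires_grad=True)
output = fn(a)

node = output.grad_fn                        # the node that computes a * b
saved_b = node._saved_other                  # private attribute (assumed PyTorch >= 1.10)
leaf_b = node.next_functions[1][0].variable  # AccumulateGrad node for the leaf b

print(type(node).__name__)  # MulBackward0
print(saved_b)              # tensor(5., requires_grad=True)

output.backward()
print(a.grad)       # tensor(5.)   d(a*b)/da = b
print(leaf_b.grad)  # tensor(10.)  d(a*b)/db = a
```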
Now, when I run the following code:
import torch
def fn(a):
    b = a**2
    for i in range(5):
        b *= b
    return b
a = torch.tensor(10, dtype=torch.float32, requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
I get the following error: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor []], which is output 0 of MulBackward0, is at version 5; expected version 4 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
But if I run the following code, there is no error:
import torch
def fn(a):
    b = a**2
    for i in range(2):
        b = b*b
    return b
def fn2(a):
    b = a**2
    c = a**2
    for i in range(2):
        c *= b
    return c
a = torch.tensor(5, dtype=torch.float32, requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
output2 = fn2(a)
output2.backward()
print(a.grad)
The output is:
tensor(625000.)
tensor(643750.)
So for a standard computation graph with quite a few variables in the same function, I can understand how the graph works. But when a variable is modified before backward is called, I have a lot of trouble understanding the results. Can someone explain?
Please note that b *= b is not the same as b = b*b.
It may look confusing, but the underlying operations differ.
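In the case of b *= b, the multiplication is done in place: b's storage is overwritten and its internal version counter is incremented. Autograd saved b as an operand of that very multiplication at the earlier version, so backward detects that the saved value has changed and raises the RuntimeError.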
In the case of b = b*b, the two operands are multiplied into a new tensor object, and the name b is then rebound to that result. The tensor autograd saved is never modified, so there is no RuntimeError when you run it this way.
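The difference can be observed directly. A minimal sketch, assuming the private Tensor._version attribute (PyTorch's internal version counter, which may change between releases): rebinding leaves the earlier tensor untouched, while the in-place multiply bumps the version of the very tensor autograd saved, so backward fails.

```python
import torch

a = torch.tensor(10, dtype=torch.float32, requires_grad=True)

# Out-of-place: b * b builds a new tensor and the name b is rebound to it
b = a ** 2
old = b
old_version = old._version
b = b * b
print(b is old, old._version == old_version)  # False True

# In-place: c *= c overwrites c and bumps its version counter
c = a ** 2
before = c._version
c *= c                      # c is both the saved operand and the target
print(c._version - before)  # 1

raised = False
try:
    c.backward()            # the saved operand is now stale
except RuntimeError as err:
    raised = True
    print(str(err).splitlines()[0])
print(raised)               # True
```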
Here is an SO question on the underlying Python behavior: "The difference between x += y and x = x + y".
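Now, what is the difference between fn in the first case and fn2 in the second? The operation c *= b does not destroy the graph links from c back to b: b is a separate tensor that is never modified after autograd saves it. An operation like c *= c (which is what b *= b does in the first fn), on the other hand, uses the same tensor as both input and output, which makes it impossible to have a graph connecting two distinct tensors via that operation.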
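The numbers in the question can be checked by hand: fn computes a**8 and fn2 computes a**6, so at a = 5 the gradients are 8 * 5**7 = 625000 and 6 * 5**5 = 18750. The second printed value, 643750, is their sum, because .grad accumulates across backward calls unless it is reset. A minimal sketch reproducing this:

```python
import torch

def fn(a):
    b = a ** 2
    for _ in range(2):
        b = b * b       # out-of-place: b -> a**4 -> a**8
    return b

def fn2(a):
    b = a ** 2
    c = a ** 2
    for _ in range(2):
        c *= b          # in-place on c, but b (the saved operand) is never modified
    return c            # c == a**6

a = torch.tensor(5, dtype=torch.float32, requires_grad=True)

fn(a).backward()
g1 = a.grad.item()      # d(a**8)/da = 8 * 5**7

fn2(a).backward()       # .grad accumulates: it is NOT reset between calls
g_total = a.grad.item()

a.grad = None           # reset, then compute fn2's gradient on its own
fn2(a).backward()
g2 = a.grad.item()      # d(a**6)/da = 6 * 5**5

print(g1, g2, g_total)  # 625000.0 18750.0 643750.0
```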
Tensors that require grad raise the RuntimeError, so to showcase the idea I'll use Python lists instead.
>>> x = [1,2]
>>> y = [3]
>>> id(x), id(y)
(140192646516680, 140192646927112)
>>>
>>> x += y
>>> x, y
([1, 2, 3], [3])
>>> id(x), id(y)
(140192646516680, 140192646927112)
Notice that no new object is created: x has the same id before and after. So it is not possible to trace from the output back to the initial variables; we cannot tell whether object_140192646516680 is the input or the output. How would one build a graph from that?
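The same identity experiment can be repeated with tensors that do not require grad, where in-place operations raise no error at all. A sketch; data_ptr() returns the address of the tensor's underlying storage, and _version is the same private counter autograd checks.

```python
import torch

x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])

# In-place: same Python object, same storage, version counter bumped
x_id, x_ptr, x_ver = id(x), x.data_ptr(), x._version
x += y
print(id(x) == x_id, x.data_ptr() == x_ptr, x._version - x_ver)  # True True 1

# Out-of-place: a brand-new tensor (old reference kept so we can compare)
z_old = torch.tensor([1.0, 2.0])
z = z_old + y
print(z is z_old, z.data_ptr() == z_old.data_ptr())  # False False
```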
Consider the following alternate case:
>>> a = [1,2]
>>> b = [3]
>>>
>>> id(a), id(b)
(140192666168008, 140192666168264)
>>>
>>> a = a + b
>>> a, b
([1, 2, 3], [3])
>>> id(a), id(b)
(140192666168328, 140192666168264)
>>>
Notice that the new list a is in fact a new object, with id 140192666168328. Here we can trace that object_140192666168328 came from the addition operation between two other objects, object_140192666168008 and object_140192666168264. Thus a graph can be created dynamically and gradients can be propagated back from the output to previous layers.
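To make the "new object per operation" idea concrete, here is a toy scalar graph in plain Python. Node is a hypothetical class written for illustration only; it is not how PyTorch implements autograd. Each multiplication returns a new node that remembers its parents and the local derivatives, and backward applies the chain rule along every path.

```python
class Node:
    """Toy scalar graph node (hypothetical, not PyTorch's implementation)."""

    def __init__(self, value, parents=(), local_grads=()):
        self.value = value
        self.parents = parents          # nodes this one was computed from
        self.local_grads = local_grads  # d(self)/d(parent) for each parent
        self.grad = 0.0

    def __mul__(self, other):
        # Out-of-place: returns a NEW node that remembers its inputs
        return Node(self.value * other.value,
                    parents=(self, other),
                    local_grads=(other.value, self.value))

    def backward(self, upstream=1.0):
        # Naive recursive chain rule: every path contributes to .grad
        self.grad += upstream
        for parent, local in zip(self.parents, self.local_grads):
            parent.backward(upstream * local)


a = Node(5.0)
b = a * a          # a**2, a new node with parents (a, a)
for _ in range(2):
    b = b * b      # each step creates a new node; earlier nodes stay in the graph
b.backward()
print(b.value, a.grad)  # 390625.0 625000.0
```

Since Node defines no __imul__, writing b *= b here would fall back to __mul__ plus rebinding and still build a new node; Python only mutates in place when the class defines __imul__, as torch.Tensor does.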