Tags: python, pytorch, tensor, autograd

Usage of retain_graph in PyTorch


I get an error if I don't pass retain_graph=True to y1.backward():

   import torch
   x = torch.tensor([2.0], requires_grad=True)
   y = torch.tensor([3.0], requires_grad=True)
   f = x+y
   z = 2*f
   y1 = z**2
   y2 = z**3
   y1.backward()
   y2.backward()

Traceback (most recent call last):
  File "/Users/a0m08er/pytorch/pytorch_tutorial/tensor.py", line 58, in <module>
    y2.backward()
  File "/Users/a0m08er/pytorch/lib/python3.11/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/Users/a0m08er/pytorch/lib/python3.11/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/Users/a0m08er/pytorch/lib/python3.11/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

But I don't get an error when I do this:

   import torch

   x = torch.tensor([2.0], requires_grad=True)
   y = torch.tensor([3.0], requires_grad=True)
   z = x+y
   y1 = z**2
   y2 = z**3
   y1.backward()
   y2.backward()

Since z is a common node for y1 and y2 in both cases, why is there no error in the second case when I call y2.backward()?


Solution

  • The error

    Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

    is raised when the backward pass tries to access tensors that were saved for it (using ctx.save_for_backward) and those tensors are no longer present, usually because they were freed by a first backward pass run without retain_graph=True.

    So the computation graph is still there after the first backward pass; only the tensors saved in the context were freed.

    Addition, however, does not need to save any tensors for the backward pass. The gradient with respect to each input is the same as the gradient with respect to the sum, so it is passed along the graph unchanged and there is nothing to save. Thus the error doesn't happen if the only shared node is an addition node, as with z = x+y in the second snippet.

    In comparison, multiplication needs to save its input values for the backward pass (since the gradient for a * b along b is a * grad(a * b)). In the first snippet, z = 2*f is such a shared multiplication node: its saved values are freed by y1.backward(), so the exception is raised when y2.backward() tries to access them.
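
The behaviour described above can be checked with a minimal sketch that reruns both snippets from the question. The helper run_twice and its scale parameter are hypothetical names introduced here for illustration; they are not part of PyTorch. Note that gradients accumulate in x.grad across the two backward calls, so the printed values are the sum of both contributions.

```python
import torch


def run_twice(scale, retain_graph):
    """Build the graph from the question and call backward on both outputs.

    scale=None gives the addition-only version (z = x + y);
    otherwise z = scale * (x + y), which adds a shared multiplication node.
    Returns (x.grad, None) on success or (None, error message) on failure.
    """
    x = torch.tensor([2.0], requires_grad=True)
    y = torch.tensor([3.0], requires_grad=True)
    f = x + y
    z = f if scale is None else scale * f
    y1 = z ** 2
    y2 = z ** 3
    try:
        # Only the first call decides whether saved tensors survive.
        y1.backward(retain_graph=retain_graph)
        y2.backward()
    except RuntimeError as err:
        return None, str(err)
    return x.grad.clone(), None


# Shared node is an addition: nothing was saved, so nothing is missing.
grad_add, err_add = run_twice(scale=None, retain_graph=False)
print(grad_add)  # tensor([85.])  (2*5 + 3*5**2)

# Shared node is a multiplication: its saved input was freed.
grad_mul, err_mul = run_twice(scale=2, retain_graph=False)
print(err_mul is not None)  # True

# Same graph, but the first backward keeps the saved tensors alive.
grad_kept, err_kept = run_twice(scale=2, retain_graph=True)
print(grad_kept)  # tensor([640.])  (2*10*2 + 3*10**2*2)
```

Passing retain_graph=True only to the first call is enough here, because the second call is the last one that needs the saved tensors.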