I've encountered many code fragments like the following for choosing an action, which mix torch.no_grad and detach (where actor is some actor and SomeDistribution your preferred distribution), and I'm wondering whether they make sense:
def f():
    with torch.no_grad():
        x = actor(observation)
        dist = SomeDistribution(x)
        sample = dist.sample()
    return sample.detach()
Isn't the detach in the return statement unnecessary, given that x already has requires_grad set to False, so all computations using x should already be detached from the graph? Or do the computations after the torch.no_grad wrapper somehow end up on the graph again, so that we need to detach them once more at the end (in which case it seems to me that no_grad would be unnecessary)?
Also, if I'm right, I suppose that instead of omitting detach one could also omit torch.no_grad and end up with the same functionality but worse performance, so torch.no_grad is to be preferred?
While it may be redundant, it depends on the internals of actor and SomeDistribution. In general, there are three cases I can think of where detach would be necessary in this code. Since you've already observed that x has requires_grad set to False, cases 2 and 3 don't apply to your specific case.
1. If SomeDistribution has internal parameters (leaf tensors with requires_grad=True), then dist.sample() may result in a computation graph connecting sample to those parameters. Without detaching, that computation graph, including those parameters, would be unnecessarily kept in memory after returning.
2. The expected behavior of operations inside a torch.no_grad context is to return results with requires_grad set to False. However, if actor(observation) for some reason explicitly sets requires_grad of its return value to True before returning, then a computation graph may be created that connects x to sample. Without detaching, that computation graph, including x, would be unnecessarily kept in memory after returning.
3. If actor(observation) actually just returns a reference to observation, and observation.requires_grad is True, then a computation graph all the way from observation to sample may be constructed during dist.sample().
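To make case 1 concrete, here is a sketch of how a graph can reappear after the no_grad block. Note that PyTorch's built-in Distribution.sample() is itself typically wrapped in no_grad, so this mainly matters for a custom distribution; the parameter theta and the addition below stand in for what such a custom sample() might do:

```python
import torch

# Stand-in for a distribution's learnable parameter (leaf with requires_grad=True).
theta = torch.randn(3, requires_grad=True)

with torch.no_grad():
    x = torch.randn(3)       # x.requires_grad is False

# Outside no_grad, an op mixing x with a requires_grad tensor builds a graph:
sample = x + theta           # stand-in for a hypothetical custom sample()
print(sample.requires_grad)  # True
print(sample.grad_fn is not None)  # True

# detach drops that graph so it isn't kept alive by the returned tensor.
detached = sample.detach()
print(detached.requires_grad)  # False
```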
As for the suggestion of removing the no_grad context in lieu of detach: this may result in the construction of a computation graph connecting observation (if it requires gradients) and/or the parameters of the distribution (if it has any) to sample. The graph would be discarded after detach, but it takes time and memory to create, so there may be a performance penalty.
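The performance point can be checked directly: without no_grad, autograd records the forward pass and detach then throws that work away. A minimal sketch (toy linear actor is my assumption):

```python
import torch

actor = torch.nn.Linear(4, 2)   # toy actor (assumption)
obs = torch.randn(4)

# Without no_grad: autograd records the forward pass through actor's weights...
x = actor(obs)
print(x.grad_fn is not None)    # True: a graph was built

# ...which detach then discards, so the graph was built for nothing:
y = x.detach()
print(y.requires_grad)          # False

# With no_grad: no graph is recorded in the first place.
with torch.no_grad():
    x2 = actor(obs)
print(x2.grad_fn)               # None
```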
In conclusion, it's safer to do both no_grad and detach, though the necessity of either depends on the details of your distribution and actor.