Tags: machine-learning, deep-learning, neural-network, pytorch, gradient-descent

torch.no_grad() and detach() combined


I have encountered many code fragments like the following for choosing an action. They mix torch.no_grad and detach (here actor is some actor network and SomeDistribution is whichever distribution you prefer), and I'm wondering whether that makes sense:

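import torch

def f(actor, observation):
    # Forward pass through the actor without recording a computation graph.
    with torch.no_grad():
        x = actor(observation)
    # SomeDistribution is a placeholder for whichever distribution class is used.
    dist = SomeDistribution(x)
    sample = dist.sample()
    return sample.detach()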

Isn't the detach in the return statement unnecessary? x already has requires_grad set to False, so all computations using x should already be detached from the graph. Or do the computations after the torch.no_grad block somehow end up on the graph again, so that we need to detach them once more at the end (in which case it seems to me that no_grad would be unnecessary)? Also, if I'm right, I suppose that instead of omitting detach one could omit torch.no_grad and get the same functionality with worse performance, so torch.no_grad is to be preferred?


Solution

  • The detach may well be redundant, but that depends on the internals of actor and SomeDistribution. In general, there are three cases I can think of where detach would be necessary in this code. Since you've already observed that x has requires_grad set to False, cases 2 and 3 don't apply to your specific situation.

    1. If SomeDistribution has internal parameters (leaf tensors with requires_grad=True) then dist.sample() may result in a computation graph connecting sample to those parameters. Without detaching, that computation graph, including those parameters, would be unnecessarily kept in memory after returning.
    2. By default, the result of any tensor operation performed inside a torch.no_grad context has requires_grad set to False. However, if actor(observation) for some reason explicitly sets requires_grad of its return value to True before returning, then a computation graph may be created that connects x to sample. Without detaching, that computation graph, including x, would be unnecessarily kept in memory after returning.
    3. This one seems even more unlikely, but if actor(observation) actually just returns a reference to observation, and observation.requires_grad is True, then a computation graph all the way from observation to sample may be constructed during dist.sample().
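
To make the three cases concrete, here is a minimal sketch. Everything in it is made up for illustration: ToyDistribution is a hypothetical stand-in for SomeDistribution whose sample() is differentiable, and linear, flagging_actor and nn.Identity() play the role of actor. As far as I know, the built-in torch.distributions classes draw sample() without gradient tracking, so case 1 is mostly a concern for custom classes like this one or for rsample().

```python
import torch
from torch import nn


class ToyDistribution:
    """Hypothetical stand-in for SomeDistribution with a differentiable sample()."""

    def __init__(self, mean, log_std):
        self.mean = mean
        self.log_std = log_std

    def sample(self):
        # Reparameterized draw: gradients can flow to mean and log_std.
        return self.mean + self.log_std.exp() * torch.randn_like(self.mean)


torch.manual_seed(0)
linear = nn.Linear(4, 2)          # toy actor network
observation = torch.randn(4)      # does not require gradients
fixed_log_std = torch.zeros(2)    # plain tensor, no gradients

# Case 1: the distribution owns a leaf tensor that requires gradients.
learned_log_std = torch.zeros(2, requires_grad=True)
with torch.no_grad():
    x1 = linear(observation)
sample1 = ToyDistribution(x1, learned_log_std).sample()


# Case 2: the actor explicitly re-enables requires_grad on its output.
def flagging_actor(obs):
    return linear(obs).requires_grad_(True)


with torch.no_grad():
    x2 = flagging_actor(observation)
sample2 = ToyDistribution(x2, fixed_log_std).sample()

# Case 3: the actor returns its input, which already requires gradients.
grad_observation = torch.randn(2, requires_grad=True)
with torch.no_grad():
    x3 = nn.Identity()(grad_observation)
sample3 = ToyDistribution(x3, fixed_log_std).sample()

# Columns: x.requires_grad, sample.requires_grad, sample.detach().requires_grad
for name, x, s in [("case 1", x1, sample1), ("case 2", x2, sample2), ("case 3", x3, sample3)]:
    print(name, x.requires_grad, s.requires_grad, s.detach().requires_grad)
```

In all three cases the sample still requires gradients despite the no_grad block, and only the detached copy is free of the graph.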

    As for the suggestion of removing the no_grad context in lieu of detach, this may result in the construction of a computation graph connecting observation (if it requires gradients) and the actor's parameters (which normally do) to x, and possibly on to sample. The graph is dropped once detach has been called and nothing else references the tracked tensors, but it does take time and memory to create the computation graph, so there may be a performance penalty.
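
The difference between the two variants can be seen in a small sketch. The actor here is a hypothetical nn.Linear, and torch.distributions.Normal with a fixed scale of 1.0 is only an assumed stand-in for SomeDistribution.

```python
import torch
from torch import nn

torch.manual_seed(0)
actor = nn.Linear(4, 2)        # hypothetical actor; its parameters require gradients
observation = torch.randn(4)

# Variant A: no no_grad, detach only at the end.
# Autograd records the forward pass through the actor.
x_tracked = actor(observation)
action_a = torch.distributions.Normal(x_tracked, 1.0).sample().detach()

# Variant B: no_grad around the forward pass.
# No graph is recorded for the actor in the first place.
with torch.no_grad():
    x_untracked = actor(observation)
action_b = torch.distributions.Normal(x_untracked, 1.0).sample()

print(x_tracked.requires_grad, x_tracked.grad_fn is not None)      # graph was built
print(x_untracked.requires_grad, x_untracked.grad_fn is not None)  # no graph
print(action_a.requires_grad, action_b.requires_grad)              # both graph-free
```

Both variants return an action with no graph attached; the difference is only whether a graph for the actor's forward pass was built and thrown away along the way.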

    In conclusion, it's safer to do both no_grad and detach, though the necessity of either depends on the details of the distribution and actor.