
PyTorch `torch.no_grad()` doesn't affect modules


I was under the (evidently wrong) impression from the documentation that torch.no_grad(), as a context manager, was supposed to make everything have requires_grad=False. Indeed, that's what I intended to use torch.no_grad() for: a convenient context manager for instantiating a bunch of things that I want to stay constant through training. But that only seems to be the case for torch.Tensors; it doesn't affect torch.nn.Modules, as the following example code shows:

import torch

with torch.no_grad():
    linear = torch.nn.Linear(2, 3)

for p in linear.parameters():
    print(p.requires_grad)

This will output:

True
True

That's a bit counterintuitive in my opinion. Is this the intended behaviour? If so, why? And is there a similarly convenient context manager under which I can be assured that anything I instantiate will not require gradients?


Solution

  • This is expected behavior, but I agree it is somewhat unclear from the documentation. Note that the documentation says:

    In this mode, the result of every computation will have requires_grad=False, even when the inputs have requires_grad=True.

    This context manager disables gradient tracking on the output of any computation performed within it. Technically, declaring/creating a layer is not a computation, so the parameters' requires_grad remains True. However, for any calculation you do inside this context, gradients cannot be computed: the output of the calculation will have requires_grad=False. This is probably best explained by extending your code snippet as below:

    import torch

    with torch.no_grad():
        linear = torch.nn.Linear(2, 3)
        # Parameter creation is not a computation, so these stay True:
        for p in linear.parameters():
            print(p.requires_grad)
        # A forward pass *is* a computation, so its output is detached:
        out = linear(torch.rand(10, 2))
        print(out.requires_grad)

    # Outside the context, outputs track gradients again:
    out = linear(torch.rand(10, 2))
    print(out.requires_grad)

    This will output:

    True
    True
    False
    True


    Even though requires_grad is True for the layer's parameters, you cannot backpropagate through the output, because the output itself has requires_grad=False and is detached from the computation graph.
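
    To address the last part of the question: as far as I know, there is no built-in context manager that makes newly constructed parameters have requires_grad=False. A common workaround is to freeze the parameters right after construction. Below is a minimal sketch (nn.Module.requires_grad_ and Tensor.requires_grad_ are real PyTorch methods; the variable names are illustrative):

    import torch

    linear = torch.nn.Linear(2, 3)

    # Backpropagating through a no_grad output fails, since the output
    # has requires_grad=False and no grad_fn:
    with torch.no_grad():
        out = linear(torch.rand(10, 2))
    try:
        out.sum().backward()
    except RuntimeError as err:
        print(err)  # "element 0 of tensors does not require grad ..."

    # To freeze the parameters themselves, flip requires_grad after
    # construction; nn.Module.requires_grad_ toggles it in place for
    # every parameter of the module:
    linear.requires_grad_(False)
    print([p.requires_grad for p in linear.parameters()])  # [False, False]

    If you only need per-parameter control, the same method exists on individual tensors, e.g. linear.weight.requires_grad_(False).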