
Disable grad and backward Globally?


How can I globally disable grad, backward(), and any other non-forward() functionality in PyTorch?

I see examples of how to do it locally, but not globally.

The docs suggest that what I may be looking for is inference-only mode, but how do I set it globally?


Solution

  • You can use torch.set_grad_enabled(False) to disable gradient computation globally for the current thread. After calling torch.set_grad_enabled(False), any attempt to call backward() will raise an exception.

    import numpy as np
    import torch
    
    a = torch.tensor(np.random.rand(64, 5), dtype=torch.float32)
    l = torch.nn.Linear(5, 10)
    
    o = torch.sum(l(a))
    print(o.requires_grad)  # True
    o.backward()
    print(l.weight.grad)  # gradients are populated
    
    torch.set_grad_enabled(False)  # disable autograd for this thread
    
    o = torch.sum(l(a))
    print(o.requires_grad)  # False
    o.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
    print(l.weight.grad)
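
    Since the question mentions inference mode: torch.inference_mode is not a global switch but a context manager (it can also be used as a decorator). It gives stronger guarantees than plain no-grad mode, since tensors created inside it can never be used by autograd afterwards. A minimal sketch of wrapping a forward pass in it, which is the closest equivalent to an "inference only" mode:

    ```python
    import torch

    l = torch.nn.Linear(5, 10)

    # Everything inside this block runs without autograd tracking;
    # tensors created here are marked as inference tensors.
    with torch.inference_mode():
        o = torch.sum(l(torch.rand(64, 5)))

    print(o.requires_grad)  # False
    ```

    You could also decorate your whole evaluation function with @torch.inference_mode() to get the same effect for every call.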