How do I globally disable grad, backward() and any other non-forward() functionality in Torch?
I can find examples of how to do it locally, but not globally.
The docs suggest that what I'm looking for may be inference-only mode, but how do I enable it globally?
You can call torch.set_grad_enabled(False) to disable gradient tracking globally for the current thread. After that call, tensors produced by new operations do not require grad and have no grad_fn, so calling backward() on them raises a RuntimeError.
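Because grad mode is thread-local, a thread started after the call still has gradients enabled by default. The sketch below checks this using only torch.set_grad_enabled, torch.is_grad_enabled and the standard threading module; the helper name grad_enabled_in_new_thread is illustrative, not part of PyTorch.

```python
import threading

import torch


def grad_enabled_in_new_thread() -> bool:
    """Hypothetical helper: report torch.is_grad_enabled() as seen from a fresh thread."""
    result = {}

    def worker():
        result["enabled"] = torch.is_grad_enabled()

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    return result["enabled"]


torch.set_grad_enabled(False)                  # disable autograd for the main thread only
main_enabled = torch.is_grad_enabled()         # False in the main thread
other_enabled = grad_enabled_in_new_thread()   # True: grad mode does not carry over to new threads

torch.set_grad_enabled(True)                   # restore the default for the main thread
restored = torch.is_grad_enabled()

print(main_enabled, other_enabled, restored)   # False True True
```

If your program starts its own threads, call torch.set_grad_enabled(False) at the top of each thread's function as well.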
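A full example with a small Linear layer:
```python
import numpy as np
import torch

a = torch.tensor(np.random.rand(64, 5), dtype=torch.float32)
l = torch.nn.Linear(5, 10)
o = torch.sum(l(a))
print(o.requires_grad)  # True
o.backward()
print(l.weight.grad)    # shows the gradients
torch.set_grad_enabled(False)  # disable autograd globally for this thread
o = torch.sum(l(a))
print(o.requires_grad)  # False
o.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```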
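If inference mode specifically is what you want, torch.inference_mode() (available since PyTorch 1.9) is also thread-local and is normally used as a context manager or decorator rather than a global switch. One hedged way to make it cover everything a program does in that thread is to decorate the entry point; the function name run and the layer sizes below are illustrative assumptions.

```python
import torch


@torch.inference_mode()
def run(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Inside this function autograd is off and tensors created here are
    # "inference tensors" (no version counter or view tracking).
    return model(x).sum()


model = torch.nn.Linear(5, 10)  # created outside inference mode; parameters keep requires_grad=True
x = torch.rand(64, 5)

out = run(model, x)
print(out.requires_grad)                  # False
print(out.is_inference())                 # True: out was created inside inference mode
print(torch.is_inference_mode_enabled())  # False again once run() has returned
```

Inference tensors cannot later be used in autograd-recorded operations, so this fits programs that never train, whereas torch.set_grad_enabled(False) can be switched back on at any point.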