Tags: python, optimization, deep-learning, pytorch, gradient-descent

How is the optimization done when we use zero_grad() in PyTorch?


The zero_grad() method is used when we want to "conserve" RAM with massive datasets. There is already an answer on that here: Why do we need to call zero_grad() in PyTorch?

Gradients are used to update the parameters during backpropagation. But if we delete the gradients by setting them to 0, how can the optimization be done during the backward pass? There are models where we use this method and optimization still occurs, so how is this possible?
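
For concreteness, here is a minimal sketch of the kind of training loop the question describes (the toy model, dummy data, and learning rate are placeholders, not taken from the original post):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                          # toy model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

x = torch.randn(32, 10)                           # dummy batch (assumption)
y = torch.randn(32, 1)

for step in range(5):
    optimizer.zero_grad()                         # clear gradients left over from the previous iteration
    loss = loss_fn(model(x), y)
    loss.backward()                               # fresh gradients are written into each parameter's .grad
    optimizer.step()                              # parameters are updated using those fresh gradients

Note the order: zero_grad() runs before loss.backward() recomputes the gradients, so optimizer.step() always sees gradients from the current batch.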


Solution

  • You don't "delete the gradients"; you simply clear the cached gradients from the previous iteration. This cache exists to make it easy to implement specific techniques, such as simulating a big batch (gradient accumulation) when there isn't enough memory to actually process the whole batch at once, as in the sketch below.
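
A rough sketch of that big-batch simulation (gradient accumulation), assuming a toy model and an accumulation factor of 4, neither of which comes from the original answer:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                          # toy model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

accum_steps = 4                                   # assumed: one "big batch" = 4 small batches
optimizer.zero_grad()
for i in range(8):                                # 8 small mini-batches of dummy data
    xb, yb = torch.randn(8, 10), torch.randn(8, 1)
    loss = loss_fn(model(xb), yb) / accum_steps   # scale so the accumulated gradient matches a big-batch mean
    loss.backward()                               # backward() *adds* to .grad, it does not overwrite it
    if (i + 1) % accum_steps == 0:
        optimizer.step()                          # one update with the accumulated ("big batch") gradient
        optimizer.zero_grad()                     # only now clear the cache, starting the next big batch

Because loss.backward() accumulates into .grad rather than overwriting it, skipping zero_grad() lets several small batches contribute to a single parameter update; zero_grad() is then called only once per simulated big batch.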