Tags: python, optimization, deep-learning, pytorch, gradient-descent

How does a decaying learning rate schedule with AdamW influence the weight decay parameter?


According to the PyTorch documentation

https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html

the AdamW optimiser computes at each step the product of the learning rate gamma and the weight decay coefficient lambda. The product

gamma*lambda =: p

is then used as the actual weight for the weight decay step. To see this, consider the second line within the for-loop in the AdamW algorithm:

    theta_t <- theta_{t-1} - gamma*lambda*theta_{t-1}

But what if the learning rate gamma shrinks after each epoch because we use (say) an exponential learning-rate decay schedule? Is p computed once from the initial learning rate gamma, so that p stays constant during the whole training process? Or does p shrink dynamically as gamma shrinks, due to an implicit interaction with the learning-rate decay schedule?
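
For concreteness, here is a minimal sketch of the setup I have in mind; the toy model, data, and the 0.9 decay factor are placeholders I made up, and the effective product p is simply read back from the optimizer's param_groups:

    import torch
    from torch import nn

    # Illustrative setup: AdamW with decoupled weight decay plus an
    # exponential learning-rate schedule.
    model = nn.Linear(10, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    x, y = torch.randn(8, 10), torch.randn(8, 1)

    for epoch in range(3):
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()        # uses whatever lr currently sits in param_groups
        optimizer.zero_grad()

        group = optimizer.param_groups[0]
        p = group["lr"] * group["weight_decay"]   # effective decay factor this epoch
        print(f"epoch {epoch}: lr={group['lr']:.6f}  p=lr*weight_decay={p:.8f}")

        scheduler.step()        # shrinks group["lr"] for the next epoch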

Thanks!


Solution

  • The function torch.optim._functional.adamw is called each time you step the optimizer, and it receives the optimizer's current hyperparameters (the call occurs at torch/optim/adamw.py:145). This is the function that actually updates the model parameter values. So after a learning-rate scheduler changes the learning rate stored in the optimizer, all subsequent steps use that updated value, not the initial one.

    You can verify this in the code: the product lr * weight_decay is recomputed from the current values at every step, at torch/optim/_functional.py:137, so p shrinks along with the learning rate rather than staying fixed at its initial value; see the schematic sketch below.
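
To make the interaction concrete, here is a schematic sketch of the decoupled weight-decay step. It is a simplification written for illustration (the helper name is made up, and it is not a verbatim copy of the library code), but it shows why a scheduler that shrinks the learning rate also shrinks the effective decay factor p = lr * weight_decay:

    import torch

    def decoupled_weight_decay_step(param: torch.Tensor, lr: float, weight_decay: float) -> None:
        # Schematic version of the decoupled decay applied inside the
        # functional AdamW update: theta <- theta - lr * weight_decay * theta.
        # The lr passed in is whatever currently sits in the optimizer's
        # param_group, so a scheduler that shrinks lr also shrinks this factor.
        param.mul_(1 - lr * weight_decay)

    w = torch.ones(3)
    decoupled_weight_decay_step(w, lr=1e-3, weight_decay=1e-2)  # p = 1e-5
    decoupled_weight_decay_step(w, lr=5e-4, weight_decay=1e-2)  # smaller lr -> smaller p
    # Each entry of w has now been multiplied by (1 - 1e-5) and then (1 - 5e-6).

Because the decay is multiplicative in (1 - lr * weight_decay), halving the learning rate also halves how strongly each step pulls the weights toward zero.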