Tags: python, machine-learning, pytorch

AdamW and Adam with weight decay


Is there any difference between torch.optim.Adam(weight_decay=0.01) and torch.optim.AdamW(weight_decay=0.01)?

Link to the docs: torch.optim.


Solution

  • Yes, Adam and AdamW handle weight decay differently.

    Loshchilov and Hutter pointed out in their paper (Decoupled Weight Decay Regularization) that the way weight decay is implemented in Adam in every library seems to be wrong, and proposed a simple fix, which they call AdamW.

    In Adam, weight decay is usually implemented by adding wd * w (where wd is the weight decay factor) to the gradients (first case below), rather than actually subtracting it from the weights (second case below).

    # first case: L2 regularization, which is what Adam's weight_decay does --
    # the penalty is added to the loss, so wd * w ends up in the gradients
    final_loss = loss + wd * all_weights.pow(2).sum() / 2
    # second case: weight decay -- wd * w is subtracted directly from the
    # weights; shown here as the equivalent plain-SGD update
    w = w - lr * w.grad - lr * wd * w
    

    These two methods are the same for vanilla SGD, but as soon as we add momentum or use a more sophisticated optimizer like Adam, L2 regularization (the first equation) and weight decay (the second equation) become different.

    AdamW follows the second equation for weight decay.
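
    You can see the two behaviours diverge directly in PyTorch. The following is a minimal sketch, assuming a recent PyTorch version; the parameter values and hyperparameters are chosen arbitrarily just to show that the two optimizers update the same weights differently:

    import torch

    # two identical parameters with identical gradients
    p_adam = torch.nn.Parameter(torch.ones(3))
    p_adamw = torch.nn.Parameter(torch.ones(3))

    opt_adam = torch.optim.Adam([p_adam], lr=0.1, weight_decay=0.01)
    opt_adamw = torch.optim.AdamW([p_adamw], lr=0.1, weight_decay=0.01)

    for p in (p_adam, p_adamw):
        p.grad = torch.full_like(p, 0.5)

    opt_adam.step()   # wd * w was folded into the gradient (L2 style)
    opt_adamw.step()  # wd * w was applied to the weights directly (decoupled)

    print(p_adam, p_adamw)  # the parameters already differ after one step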

    In Adam

    weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

    In AdamW

    weight_decay (float, optional) – weight decay coefficient (default: 1e-2)
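
    As a quick way to confirm the differing defaults (a small check using the optimizer's defaults attribute, assuming a recent PyTorch version):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    print(torch.optim.Adam(params).defaults["weight_decay"])   # 0
    print(torch.optim.AdamW(params).defaults["weight_decay"])  # 0.01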

    Read more on the fastai blog.