Tags: python, deep-learning, pytorch, gradient-descent, learning-rate

Why do we multiply learning rate by gradient accumulation steps in PyTorch?


Loss functions in PyTorch use "mean" reduction by default, which means the model's gradient has roughly the same magnitude regardless of batch size. It therefore makes sense to scale the learning rate up when you increase the batch size, because the gradient itself doesn't grow with batch size.

With gradient accumulation in PyTorch, the gradients are summed over the N calls to backward() you make before calling step(). My intuition is that this increases the magnitude of the gradient, so you should reduce the learning rate, or at least not increase it.
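Here is a minimal sketch of that summing behavior (the toy model and data are made up purely for illustration): calling backward() twice without zeroing the gradients doubles what ends up in .grad.

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()  # default reduction="mean"

x, y = torch.randn(8, 10), torch.randn(8, 1)

# First backward() fills .grad
loss_fn(model(x), y).backward()
grad_once = model.weight.grad.clone()

# Second backward() without zero_grad() accumulates (sums) into .grad
loss_fn(model(x), y).backward()
grad_twice = model.weight.grad

print(torch.allclose(grad_twice, 2 * grad_once))  # True: gradients are summed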

But I saw people multiply the learning rate by the number of gradient accumulation steps in this repo:

if args.scale_lr:
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

I also see similar code in this repo:

model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr

I understand why you would want to scale the learning rate with the batch size. But I don't understand why they also scale it by the number of accumulation steps.

  1. Do they divide the loss by N to reduce the magnitude of the gradient? Otherwise, why do they multiply the learning rate by the number of accumulation steps?
  2. How are gradients from different GPUs combined? Is it a mean or a sum? If it's a sum, why do they multiply the learning rate by the number of GPUs?

Solution

  • I found that they do indeed divide the loss by N (the number of gradient accumulation steps). You can see sample code from the accelerate package here: https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation

    Notice the following line of code from the guide above: loss = loss / gradient_accumulation_steps

    This is why you need to multiply the learning rate by the number of gradient accumulation steps: it cancels the division above (see the sketch at the end of this answer).

    I assume the same procedure happens in PyTorch Lightning. I asked a related question in a Lightning GitHub discussion here: https://github.com/Lightning-AI/lightning/discussions/17035

    I hope someone there will confirm that Lightning's Trainer performs the same division. The evidence from the accelerate package also makes me think that gradients from different GPUs are averaged, not summed; if they were summed, the loss on each GPU would have to be divided by the number of GPUs.

    This leads to a simple intuition about gradients in most PyTorch training workflows: no matter how big or small the batch is, the gradient always has roughly the same magnitude. If you check the gradient's magnitude right before the step() call, it should stay roughly the same even if you vary the batch size, the number of gradient accumulation steps, the number of GPUs, or even the number of machines.
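    To make this concrete, here is a plain-PyTorch sketch of the loss / N pattern (the model, data, and hyperparameters are placeholders for illustration). The gradient norm printed right before step() should stay roughly constant as you vary the batch size or the number of accumulation steps.

    import torch
    import torch.nn as nn

    accumulation_steps = 4
    model = nn.Linear(10, 1)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # Placeholder data: 8 mini-batches of size 8
    batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches):
        loss = loss_fn(model(x), y)
        # Divide by N so the N summed gradients average out to the
        # same magnitude as a single-batch gradient
        (loss / accumulation_steps).backward()

        if (i + 1) % accumulation_steps == 0:
            # Should stay roughly constant across batch sizes and
            # accumulation steps
            print("grad norm before step:", model.weight.grad.norm().item())
            optimizer.step()
            optimizer.zero_grad()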