Clarification about Gradient Accumulation

I'm trying to get a better understanding of how Gradient Accumulation works and why it is useful. To this end, I wanted to ask what the difference is (if any) between these two possible PyTorch-like implementations of a custom training loop with gradient accumulation:

gradient_accumulation_steps = 5
for batch_idx, batch in enumerate(dataset):
  x_batch, y_true_batch = batch
  y_pred_batch = model(x_batch)

  loss = loss_fn(y_true_batch, y_pred_batch)
  loss.backward()  # adds this mini-batch's gradients to the ones already stored

  if (batch_idx + 1) % gradient_accumulation_steps == 0: # (assumption: the number of batches is a multiple of gradient_accumulation_steps)
    optimizer.step()
    optimizer.zero_grad()

y_true_batches, y_pred_batches = [], []
gradient_accumulation_steps = 5
for batch_idx, batch in enumerate(dataset):
  x_batch, y_true_batch = batch
  y_pred_batch = model(x_batch)

  y_true_batches.append(y_true_batch)
  y_pred_batches.append(y_pred_batch)

  if (batch_idx + 1) % gradient_accumulation_steps == 0: # (assumption: the number of batches is a multiple of gradient_accumulation_steps)
    y_true = stack_vertically(y_true_batches)  # e.g. torch.cat(..., dim=0)
    y_pred = stack_vertically(y_pred_batches)

    loss = loss_fn(y_true, y_pred)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    y_true_batches, y_pred_batches = [], []


Also, a somewhat unrelated question: since the purpose of gradient accumulation is to mimic a larger batch size when memory is constrained, does that mean I should also increase the learning rate proportionally?


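    1. The difference between the two programs:
    Conceptually, your two implementations are the same: in both, you forward gradient_accumulation_steps batches for each weight update.
    However, the second method requires more memory than the first one: it keeps the predictions of all gradient_accumulation_steps mini-batches, and with them their autograd graphs and intermediate activations, alive until the single backward pass, whereas the first method can free each mini-batch's graph right after its own backward() call.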

    There is, however, a slight difference: loss function implementations usually reduce the loss over the batch with a mean. With gradient accumulation (the first implementation), the loss is averaged within each mini-batch, but the gradients of the gradient_accumulation_steps mini-batches are summed, because every backward() call adds to the gradients already stored. With equally sized mini-batches, the accumulated gradient is therefore gradient_accumulation_steps times larger than the gradient of the equivalent large batch. To make the accumulated-gradient implementation identical to the large-batch implementation, you need to be very careful about how the loss is reduced: in many cases you will need to divide each mini-batch loss (equivalently, the accumulated loss) by gradient_accumulation_steps. See this answer for a detailed implementation.

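    2. Batch size and learning rate: Learning rate and batch size are indeed related. For SGD, the scale of the gradient noise is roughly proportional to the learning rate divided by the batch size, so increasing the batch size has an effect similar to decaying the learning rate. Conversely, to keep training dynamics similar when the effective batch size grows, one usually increases the learning rate roughly in proportion, although the right value still needs tuning.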
    See, e.g.:
    Samuel L. Smith, Pieter-Jan Kindermans, Chris Ying, Quoc V. Le, Don't Decay the Learning Rate, Increase the Batch Size (ICLR 2018).
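To make the mean-versus-sum issue from point 1 concrete, here is a framework-free sketch in plain Python (no PyTorch), under assumed toy conditions: a hypothetical one-parameter model y_hat = w * x, a mean-squared-error loss, and three equally sized mini-batches of made-up data. The derivative is written out by hand, standing in for what loss.backward() would compute. It compares the gradient of the concatenated large batch (the second implementation) with the sum of per-mini-batch gradients (the first implementation), with and without dividing by the number of accumulation steps.

```python
# Toy check (hypothetical model and data): y_hat = w * x, loss = mean((y_hat - y)**2).


def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)**2) w.r.t. w over one batch: mean(2*(w*x - y)*x)."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def large_batch_grad(w, batches):
    """Second implementation: stack all mini-batches, then compute one mean loss."""
    xs = [x for bx, _ in batches for x in bx]
    ys = [y for _, by in batches for y in by]
    return mse_grad(w, xs, ys)


def accumulated_grad(w, batches, divide_by_steps):
    """First implementation: one backward per mini-batch; the gradients are summed."""
    steps = len(batches)
    total = 0.0
    for bx, by in batches:
        g = mse_grad(w, bx, by)
        # Dividing a mini-batch loss by `steps` divides its gradient by `steps`.
        total += g / steps if divide_by_steps else g
    return total


w = 0.5
batches = [  # three mini-batches of two samples each
    ([1.0, 2.0], [2.0, 3.0]),
    ([3.0, 4.0], [5.0, 9.0]),
    ([0.5, -1.0], [1.0, -2.0]),
]

print("large batch:        ", large_batch_grad(w, batches))
print("accumulated, raw:   ", accumulated_grad(w, batches, divide_by_steps=False))  # 3x larger
print("accumulated, scaled:", accumulated_grad(w, batches, divide_by_steps=True))   # equal up to rounding
```

With unequal mini-batch sizes (e.g. a smaller final batch), dividing by the number of steps is no longer exact; weighting each mini-batch's mean loss by its share of the total number of samples would restore the equivalence in this setting.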
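For the learning-rate question, a common heuristic (not a guarantee) is the linear scaling rule: multiply the learning rate by the same factor as the effective batch size, which with gradient accumulation is micro_batch_size * gradient_accumulation_steps, so that the ratio learning_rate / batch_size stays roughly constant. The sketch below only does that arithmetic; the function name, its parameters and the numbers are hypothetical, and the result is a starting point for tuning, typically less reliable for adaptive optimizers such as Adam than for plain SGD.

```python
def linear_scaled_lr(base_lr, base_batch_size, micro_batch_size, accumulation_steps):
    """Linear scaling heuristic: grow the learning rate with the effective batch size.

    Keeps learning_rate / batch_size (and hence, roughly, the SGD noise scale)
    the same as in the configuration where base_lr was tuned.
    """
    effective_batch_size = micro_batch_size * accumulation_steps
    return base_lr * effective_batch_size / base_batch_size


# Hypothetical setting: base_lr = 0.1 was tuned for batches of 32 samples, and we now
# accumulate 5 micro-batches of 32 samples (effective batch size 160).
lr = linear_scaled_lr(base_lr=0.1, base_batch_size=32,
                      micro_batch_size=32, accumulation_steps=5)
print(lr)
```

If base_lr was already tuned for the effective batch size that the accumulation reproduces (e.g. linear_scaled_lr(0.1, 160, 32, 5)), the factor is 1 and the learning rate stays unchanged: gradient accumulation by itself does not call for a different learning rate; only a change in the effective batch size does.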