
Accumulating gradients for a larger batch size with PyTorch


In order to mimic a larger batch size, I want to be able to accumulate gradients every N batches for a model in PyTorch, like:

def train(model, optimizer, dataloader, num_epochs, N):
    for epoch_num in range(1, num_epochs+1):
        for batch_num, data in enumerate(dataloader):
            ims = data.to('cuda:0')
            loss = model(ims)
            loss.backward()
            if batch_num % N == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

For this approach, do I need to add the flag retain_graph=True, i.e.

loss.backward(retain_graph=True)

With this approach, are the gradients from each backward call simply summed for each parameter?


Solution

  • You need to set retain_graph=True if you want to make multiple backward passes over the same computational graph, reusing the intermediate results from a single forward pass. This would be the case, for instance, if you called loss.backward() multiple times after computing loss once, or if you had multiple losses from different parts of the graph to backpropagate from (a good explanation can be found here; a short sketch also follows the list below).

    In your case, for each forward pass, you backpropagate exactly once. So you don't need to store the intermediate results from the computational graph once the gradients are computed.

    In short:

    • Intermediate outputs in the graph are cleared after a backward pass, unless explicitly preserved using retain_graph=True.
    • Gradients accumulate (they are summed into each parameter's .grad) by default, unless explicitly cleared using optimizer.zero_grad(); see the sketches after this list.
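
A minimal sketch of both points, using a small tensor made up purely for illustration (none of these names come from the question):

import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()

# Two backward passes over the *same* graph: the first call must keep the
# intermediate buffers, otherwise the second raises a runtime error along
# the lines of "Trying to backward through the graph a second time".
y.backward(retain_graph=True)
y.backward()

print(x.grad)  # 2*x + 2*x = 4*x: gradients from both calls were summed into .grad

In your training loop, every backward() call follows its own forward pass, so each call works on a fresh graph and retain_graph=True is not needed.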
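
And a sketch of the accumulation loop itself, under the same assumption as the question that the model's forward pass returns the loss directly. Two common refinements that are not in your original code: the loss is divided by N so the accumulated gradient approximates the average over one large batch, and the batch counter starts at 1 so the first optimizer step happens after N batches rather than after the first one.

def train(model, optimizer, dataloader, num_epochs, N):
    optimizer.zero_grad(set_to_none=True)
    for epoch_num in range(1, num_epochs + 1):
        for batch_num, data in enumerate(dataloader, start=1):
            ims = data.to('cuda:0')
            loss = model(ims) / N   # scale so the summed gradients average out
            loss.backward()         # adds this batch's gradients into each param.grad
            if batch_num % N == 0:
                optimizer.step()    # update using gradients accumulated over N batches
                optimizer.zero_grad(set_to_none=True)  # clear before the next group of N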