deep-learning pytorch recurrent-neural-network gradient-descent

Gradient accumulation in an RNN


I ran into some memory issues (GPU) when running a large RNN. I want to keep my batch size reasonable, so I wanted to try out gradient accumulation. In a network where you predict the output in one go, that seems self-evident. In an RNN, however, you do multiple forward passes for each input step. Because of that, I fear that my implementation does not work as intended. I started from user albanD's excellent examples here, but I think they should be modified when using an RNN, because you accumulate many more gradients when you do multiple forwards per sequence.

My current implementation looks like this. It also allows for AMP in PyTorch 1.6, which seems important, because everything needs to be called in the right place. Note that this is just an abstract version; it might seem like a lot of code, but it is mostly comments.

def train(epochs):
    """Main training loop. Loops for `epoch` number of epochs. Calls `process`."""
    for epoch in range(1, epochs + 1):
        train_loss = process("train")
        valid_loss = process("valid")
        # ... check whether we improved over earlier epochs
        if lr_scheduler:
            lr_scheduler.step(valid_loss)
        
def process(do):
    """Do a single epoch run through the dataloader of the training or validation set. 
       Also takes care of optimizing the model after every `gradient_accumulation_steps` steps.
       Calls `step` for each batch where it gets the loss from."""
    if do == "train":
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)
    
    loss = 0.
    for batch_idx, batch in enumerate(dataloaders[do]):
        step_loss, avg_step_loss = step(batch)
        loss += avg_step_loss

        if do == "train":
            if amp:
                scaler.scale(step_loss).backward()

                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # clip in-place
                    clip_grad_norm_(model.parameters(), 2.0)
                    scaler.step(optimizer)
                    scaler.update()
                    model.zero_grad()
            else:
                step_loss.backward()
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    clip_grad_norm_(model.parameters(), 2.0)
                    optimizer.step()
                    model.zero_grad()
        
    # return average loss
    return loss / len(dataloaders[do])

def step(batch):
    """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
    # do stuff... init hidden state and first input etc.
    loss = torch.zeros(1, device=device)
    
    for i in range(target_len):
        with torch.cuda.amp.autocast(enabled=amp):
            # overwrite previous decoder_hidden
            output, decoder_hidden = model(decoder_input, decoder_hidden)

            # compute loss between predicted classes (bs x classes) and correct classes for _this word_
            item_loss = criterion(output, target_tensor[i])

            # We calculate the gradients for the average step so that when
            # we do take an optimizer.step, it takes into account the mean step_loss
            # across batches. So basically (A+B+C)/3 = A/3 + B/3 + C/3
            loss += (item_loss / gradient_accumulation_steps)

        topv, topi = output.topk(1)
        decoder_input = topi.detach()
    
    return loss, loss.item() / target_len

The above does not seem to work as I had hoped, i.e. it still runs into out-of-memory issues very quickly. I think the reason is that step already accumulates so much information, but I am not sure.
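The scaling in the comment inside step, (A+B+C)/3 = A/3 + B/3 + C/3, is what makes accumulation equivalent to a larger batch. The quick CPU check below is a sketch that assumes a hypothetical linear model and MSELoss with the default mean reduction. It compares one backward pass over 6 samples against 3 accumulated backward passes over 2 samples each, with each loss divided by the number of accumulation steps:

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: a linear model and a "large" batch of 6 samples
model = torch.nn.Linear(3, 1)
x = torch.randn(6, 3)
y = torch.randn(6, 1)
criterion = torch.nn.MSELoss()  # reduction="mean" by default

# 1) Reference: one backward on the full batch
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# 2) Accumulation: 3 equally sized micro-batches, each loss divided by 3
accumulation_steps = 3
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (criterion(model(xb), yb) / accumulation_steps).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-5))  # True
```

The equality holds because the micro-batches have equal size and the loss is a mean. With unequal micro-batches or a sum reduction, the divisor would have to change accordingly.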


Solution

  • For simplicity, I will only cover AMP-enabled gradient accumulation; without AMP the idea is the same. The step you presented runs under AMP anyway, so let's stick to that.

    step

    The PyTorch AMP documentation has an example of gradient accumulation. You should do the backward pass inside step. Each time you run loss.backward(), gradients are accumulated in the .grad attribute of the leaf tensors (the parameters the optimizer updates). Hence, your step should look like this (see comments):

    def step(batch):
        """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
        # You should not accumulate loss on the GPU; RAM and CPU are better for that
        # Use the GPU only for calculations, not for gathering metrics etc.
        loss = 0.0

        for i in range(target_len):
            with torch.cuda.amp.autocast(enabled=amp):
                # where decoder_input is from?
                # I assume there is one in real code
                # Detach the hidden state: the previous step's graph was already
                # freed by its backward() (for an LSTM, detach both h and c)
                output, decoder_hidden = model(decoder_input, decoder_hidden.detach())
                # Here you divide by accumulation steps
                item_loss = criterion(output, target_tensor[i]) / (
                    gradient_accumulation_steps * target_len
                )

            # Backward right away (only when training; validation has grad disabled)
            if model.training:
                scaler.scale(item_loss).backward()
            loss += item_loss.item()

            # Not sure what topv was for here
            _, topi = output.topk(1)
            decoder_input = topi.detach()

        # backward was already done above, so only a plain float is returned
        return loss / target_len
    

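    Because backward is called after every time step, there is no need for backward in process. This only works if a step's graph does not reach back into the previous step, whose graph was already freed by its own backward call. decoder_input is cut off anyway, because topk indices carry no gradient. decoder_hidden is not cut off, so it has to be detached too (model(decoder_input, decoder_hidden.detach())). Otherwise the second backward fails with "Trying to backward through the graph a second time". As a consequence, parameters are optimized based on each time step on its own, not on the whole unrolled sequence. This is truncated backpropagation through time with a window of one step. Also, if your model wraps a built-in nn.RNN, nn.GRU or nn.LSTM and decoder_hidden isn't passed, a zero-filled initial state is used implicitly. In that case, though, every time step starts from zeros and the decoder has no memory across steps.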

    Also, we should divide by gradient_accumulation_steps * target_len, as that's how many backward calls we will run before a single optimization step.
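    The accumulation itself is just summation into .grad. The sketch below uses a single scalar parameter w and hypothetical per-piece losses c * w**2, with 2 batches of 2 time steps each. Every piece is therefore divided by gradient_accumulation_steps * target_len = 4, and the summed gradient comes out equal to the gradient of the mean loss:

```python
import torch

gradient_accumulation_steps, target_len = 2, 2
w = torch.tensor(1.0, requires_grad=True)  # stands in for a model parameter
# Hypothetical coefficients, one per (batch, time step): piece loss = c * w**2
coeffs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

for batch in coeffs:              # gradient_accumulation_steps batches
    for c in batch:               # target_len time steps per batch
        item_loss = c * w ** 2 / (gradient_accumulation_steps * target_len)
        item_loss.backward()      # each call ADDS into w.grad

# d/dw mean(c * w**2) = 2 * w * mean(c) = 2 * 1 * 2.5 = 5
print(w.grad)  # tensor(5.)
```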

    Since some of your variables are undefined, I assume this is just a sketch of what's going on.
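    As a stand-in for the undefined variables, here is a self-contained toy version of the per-step variant that runs on CPU. The model (an nn.Embedding, an nn.GRUCell and an nn.Linear), the start token 0 and all sizes are hypothetical. AMP is left out; with AMP you would wrap the forward in autocast and call scaler.scale(item_loss).backward() instead. The hidden state is detached at each step so that a backward call never reaches into a graph that an earlier backward already freed:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy decoder: embedding -> GRUCell -> linear over classes (CPU, no AMP)
num_classes, hidden_size, batch_size = 5, 8, 4
target_len, gradient_accumulation_steps = 3, 2

embed = nn.Embedding(num_classes, hidden_size)
cell = nn.GRUCell(hidden_size, hidden_size)
head = nn.Linear(hidden_size, num_classes)
params = [*embed.parameters(), *cell.parameters(), *head.parameters()]
criterion = nn.CrossEntropyLoss()


def step(target_tensor):
    """One batch; backward after every time step, so only one step's graph is alive."""
    decoder_input = torch.zeros(batch_size, dtype=torch.long)  # assumed start token 0
    decoder_hidden = torch.zeros(batch_size, hidden_size)
    loss = 0.0
    for i in range(target_len):
        # Detach: the previous step's graph was already freed by its backward()
        decoder_hidden = cell(embed(decoder_input), decoder_hidden.detach())
        output = head(decoder_hidden)
        item_loss = criterion(output, target_tensor[i]) / (
            gradient_accumulation_steps * target_len
        )
        item_loss.backward()  # adds into p.grad and frees this step's graph
        loss += item_loss.item()
        decoder_input = output.argmax(dim=1)  # greedy choice; carries no gradient
    return loss / target_len


targets = torch.randint(0, num_classes, (target_len, batch_size))
avg_loss = step(targets)
```

Peak memory here is bounded by a single time step. The trade-off is that no gradient flows from a step's loss back through earlier hidden states.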

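    Also, if you want the history to be kept (backpropagation through the whole sequence), you shouldn't detach decoder_hidden. In that case backward is called once per sequence instead of once per time step. Detaching decoder_input makes no difference either way, since topk indices carry no gradient. In this case it would look like this: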

    def step(batch):
        """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
        loss = 0

        for i in range(target_len):
            with torch.cuda.amp.autocast(enabled=amp):
                # decoder_hidden is NOT detached: the graph of every time step
                # is kept alive until the single backward below
                output, decoder_hidden = model(decoder_input, decoder_hidden)
                item_loss = criterion(output, target_tensor[i]) / (
                    gradient_accumulation_steps * target_len
                )

            # topk indices carry no gradient, detached or not
            _, topi = output.topk(1)
            decoder_input = topi

            loss += item_loss
        # One backward for the whole sequence (only when training)
        if model.training:
            scaler.scale(loss).backward()
        return loss.item() / target_len
    

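    This keeps the graph of every time step in memory until the single backward call, so memory grows with target_len and will probably raise OOM again. Structurally it is what your original code does, and accumulating over batches does not change that. If full history is what you are after, there's not much you can do AFAIK, as the unrolled RNN computation is simply too large to fit into the GPU.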
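    The two variants differ in where the gradient history lives. The CPU sketch below uses a hypothetical nn.GRUCell as the decoder and checks three things. First, topk indices never require grad. Second, calling backward per step without detaching the hidden state fails on the second step. Third, a single backward over the unrolled sequence works, but only because every step's activations were kept until that call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.GRUCell(4, 4)  # hypothetical stand-in for the decoder
x = torch.randn(2, 4)

# 1) topk indices are integer tensors without gradient history
logits = torch.randn(2, 3, requires_grad=True)
_, topi = logits.topk(1)
print(topi.requires_grad)  # False

# 2) The history lives in the hidden state: per-step backward without
#    detaching it fails, because step 1's graph was already freed
h1 = cell(x, torch.zeros(2, 4))
h1.sum().backward()
h2 = cell(x, h1)
try:
    h2.sum().backward()
    per_step_failed = False
except RuntimeError:
    per_step_failed = True
print(per_step_failed)  # True

# 3) Full BPTT: one backward over the whole unrolled sequence works,
#    but every step's activations stay in memory until this call
h = torch.zeros(2, 4)
total = 0.0
for _ in range(3):
    h = cell(x, h)
    total = total + h.sum()
total.backward()
```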

    process

    Only the relevant part of the code is shown:

    loss = 0.0
    for batch_idx, batch in enumerate(dataloaders[do]):
        # Here everything is detached from graph so we're safe
        avg_step_loss = step(batch)
        loss += avg_step_loss
    
        if do == "train":
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                # You can use unscale as in the example in PyTorch's docs
                # just like you did
                scaler.unscale_(optimizer)
                # clip in-place
                clip_grad_norm_(model.parameters(), 2.0)
                scaler.step(optimizer)
                scaler.update()
                # IMO in this case optimizer.zero_grad is more readable,
                # but it's nitpicking
                optimizer.zero_grad()
    
    # return average loss
    return loss / len(dataloaders[do])
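    For reference, here is a minimal runnable sketch of this loop on CPU, without AMP. The model, the data and step are hypothetical stand-ins (a linear classifier instead of the RNN). The point is the control flow: step does the backward and returns a plain float, and process only steps the optimizer every gradient_accumulation_steps batches. With AMP, the clipping and stepping part would use scaler.unscale_, clip_grad_norm_, scaler.step and scaler.update as in the code above.

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

torch.manual_seed(0)

# Hypothetical stand-ins for the model, data and step(); no AMP, runs on CPU
gradient_accumulation_steps = 2
model = nn.Linear(3, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
dataloaders = {
    "train": [(torch.randn(4, 3), torch.randint(0, 2, (4,))) for _ in range(4)],
    "valid": [(torch.randn(4, 3), torch.randint(0, 2, (4,))) for _ in range(2)],
}


def step(batch):
    """Forward one batch; backward only in training mode; return a float."""
    inputs, targets = batch
    loss = criterion(model(inputs), targets)
    if model.training:
        (loss / gradient_accumulation_steps).backward()
    return loss.item()


def process(do):
    model.train(do == "train")  # same as model.eval() for "valid"
    loss = 0.0
    # Context manager: restores the previous grad mode when the loop ends
    with torch.set_grad_enabled(do == "train"):
        for batch_idx, batch in enumerate(dataloaders[do]):
            loss += step(batch)  # plain float: nothing keeps the graph alive
            if do == "train" and (batch_idx + 1) % gradient_accumulation_steps == 0:
                clip_grad_norm_(model.parameters(), 2.0)
                optimizer.step()
                optimizer.zero_grad()
    return loss / len(dataloaders[do])


train_loss = process("train")
valid_loss = process("valid")
```

Using torch.set_grad_enabled as a context manager, rather than calling it as a function, avoids leaving the global grad mode switched off after validation. The model.training check keeps validation from calling backward on a loss that has no graph.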
    

    Questions

    [...] in an RNN you do multiple forward passes for each input step. Because of that, I fear that my implementation does not work as intended.

    It does not matter. Usually you do one backward for each forward, or one for the whole sequence (see the two step variants above for the options). After that we (usually) don't need the loss connected to the graph, as we have already performed backpropagation, got our gradients and are ready to optimize the parameters.

    "That loss needs to have history, as it goes back to the process loop where backward will be called on it."

    There is no need to call backward in process as presented, because backward is already called inside step.
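    A backward in process would not work anyway once step has already called it. After backward, autograd frees the tensors saved for the backward pass, so a second backward through the same graph raises a RuntimeError unless the first call passed retain_graph=True. A minimal scalar sketch, where w stands in for a model parameter:

```python
import torch

w = torch.tensor(1.0, requires_grad=True)  # stands in for a model parameter
h = w * 3
loss = h * h  # MulBackward saves h for the backward pass

loss.backward()  # gradients computed, saved tensors freed
grad_after_first = w.grad.item()  # d/dw (3w)^2 = 18w = 18.0

try:
    loss.backward()  # a second backward, e.g. again in process()
    second_backward_ok = True
except RuntimeError:
    second_backward_ok = False
print(grad_after_first, second_backward_ok)  # 18.0 False
```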