pytorch · loss-function

Pytorch: correct way to sum batch loss with epoch loss


I'm calculating two losses: one per batch, and one per epoch computed at the end of the batch loop. When I try to sum these two losses I get the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by   an inplace operation: [torch.FloatTensor [64, 49]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I have my reasons for summing these two losses.

The general idea of the code is something like this:

loss_epoch = 0  # it's zero in the first epoch

for epoch in epochs:
    for batch in batches:
        optimizer.zero_grad()

        loss_batch = criterion_batch(output_batch, target_batch)
        loss = loss_batch + loss_epoch  # adds zero in the first epoch

        loss.backward()
        optimizer.step()

    loss_epoch = criterion_epoch(output_epoch, target_epoch)


I gather that the problem is that computing another loss at the end of the first loop (the one over the batches) interferes with the gradient computation, but I couldn't solve it.

It might also have something to do with the order of the operations (loss calculation, backward, zero_grad, step).

I need to calculate loss_epoch at the end of the batch loop because I'm using the entire dataset to compute this loss.


Solution

  • Assuming you do not want to backpropagate epoch_loss through a forward pass over the entire dataset (which would be computationally infeasible for any non-trivial dataset), you can detach the epoch_loss and add it as a scalar that is updated once per epoch. Not entirely sure this is the behavior you want, though.
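A minimal, self-contained sketch of that suggestion, using a hypothetical linear model and random data in place of the question's criterion_batch/criterion_epoch setup. Calling .detach() turns the epoch loss into a plain constant, so the next epoch's backward() only traverses the current batch's graph and never revisits parameters that optimizer.step() has since modified in place:

```python
import torch

# Hypothetical stand-ins for the question's model, criteria, and data
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

x = torch.randn(8, 4)
y = torch.randn(8, 1)

loss_epoch = torch.tensor(0.0)  # zero in the first epoch

for epoch in range(3):
    for start in range(0, 8, 4):  # two "batches" per epoch
        xb, yb = x[start:start + 4], y[start:start + 4]
        optimizer.zero_grad()

        loss_batch = criterion(model(xb), yb)
        loss = loss_batch + loss_epoch  # loss_epoch is a detached scalar

        loss.backward()
        optimizer.step()

    # detach() cuts the epoch loss out of the autograd graph; without it,
    # the next epoch's backward() would try to backprop through a stale
    # graph whose tensors step() has modified in place, raising the
    # RuntimeError from the question.
    loss_epoch = criterion(model(x), y).detach()
```

With the detach removed from the last line, this loop reproduces the in-place-modification error on the second epoch.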