Tags: neural-network, pytorch, pytorch-lightning

pytorch lightning epoch_end/validation_epoch_end


Could anybody break down this code and explain it to me? The parts I need help with are marked with "# This part". I would greatly appreciate any help, thanks.

def validation_epoch_end(self, outputs):
    batch_losses = [x["val_loss"] for x in outputs]  # This part
    epoch_loss = torch.stack(batch_losses).mean()
    batch_accs = [x["val_acc"] for x in outputs]     # This part
    epoch_acc = torch.stack(batch_accs).mean()
    return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

def epoch_end(self, epoch, result):
    print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(
        epoch, result['val_loss'], result['val_acc']))  # This part

Solution

  • In your provided snippet, outputs is a list of dict elements, each of which contains at least the keys "val_loss" and "val_acc". It is fair to assume these correspond to the validation loss and validation accuracy of a single batch, respectively.

    Those two lines (annotated with the # This part comment) are list comprehensions over the elements of the outputs list. The first gathers the value of the "val_loss" key for each element in outputs; the second does the same for the "val_acc" key.

    A minimal example would be:

    ## before
    outputs = [{'val_loss': tensor(a),  # element 0
                'val_acc': tensor(b)},

               {'val_loss': tensor(c),  # element 1
                'val_acc': tensor(d)}]

    ## after
    batch_losses = [tensor(a), tensor(c)]
    batch_accs = [tensor(b), tensor(d)]
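    Putting the pieces together, here is a runnable sketch of what validation_epoch_end computes, with the placeholders a, b, c, d replaced by made-up loss/accuracy values for two validation batches:

    ```python
    import torch

    # Made-up per-batch validation results, as Lightning would pass them in.
    outputs = [
        {'val_loss': torch.tensor(0.90), 'val_acc': torch.tensor(0.60)},  # batch 0
        {'val_loss': torch.tensor(0.70), 'val_acc': torch.tensor(0.80)},  # batch 1
    ]

    # Gather the per-batch scalar tensors into lists ...
    batch_losses = [x["val_loss"] for x in outputs]
    batch_accs = [x["val_acc"] for x in outputs]

    # ... then stack each list into a 1-D tensor and average over the epoch.
    epoch_loss = torch.stack(batch_losses).mean()
    epoch_acc = torch.stack(batch_accs).mean()

    result = {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    print(result)  # val_loss ≈ 0.8, val_acc ≈ 0.7
    ```

    torch.stack turns the list of 0-dimensional tensors into a single 1-D tensor so that .mean() can reduce over the batch dimension, and .item() converts the resulting scalar tensors into plain Python floats for logging/printing.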