Could anybody break down the code and explain it to me? The parts I need help with are marked with "#This part". I would greatly appreciate any help, thanks.
import torch

def validation_epoch_end(self, outputs):
    batch_losses = [x["val_loss"] for x in outputs]  #This part
    epoch_loss = torch.stack(batch_losses).mean()
    batch_accs = [x["val_acc"] for x in outputs]  #This part
    epoch_acc = torch.stack(batch_accs).mean()
    return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

def epoch_end(self, epoch, result):
    print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))  #This part
In your provided snippet, outputs is a list of dict elements, each of which seems to contain at least the keys "val_loss" and "val_acc". It would be fair to assume they correspond to the validation loss and validation accuracy, respectively.
The first two lines annotated with the #This part comment are list comprehensions that iterate over the elements of the outputs list. The first one gathers the value of the key "val_loss" for each element in outputs. The second one does the same, this time gathering the values of the "val_acc" key.
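If the comprehension syntax itself is what's confusing, it may help to see that each comprehension is shorthand for an explicit loop that appends to a list. A minimal sketch, assuming plain floats stand in for the tensors (the values are made up for illustration):

```python
# Hypothetical outputs: plain floats stand in for the tensors in the real code.
outputs = [
    {"val_loss": 0.5, "val_acc": 0.75},
    {"val_loss": 0.25, "val_acc": 1.0},
]

# List comprehension, as in the question.
batch_losses = [x["val_loss"] for x in outputs]

# Equivalent explicit loop: x is one dict from outputs on each iteration.
batch_losses_loop = []
for x in outputs:
    batch_losses_loop.append(x["val_loss"])

print(batch_losses)       # [0.5, 0.25]
print(batch_losses_loop)  # [0.5, 0.25]
```

Both forms build the same list; the comprehension is just the more compact, idiomatic spelling.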
A minimal example would be:
## before
outputs = [{'val_loss': tensor(a),  # element 0
            'val_acc': tensor(b)},
           {'val_loss': tensor(c),  # element 1
            'val_acc': tensor(d)}]
## after
batch_losses = [tensor(a), tensor(c)]
batch_accs = [tensor(b), tensor(d)]
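To see the whole flow end to end, including the third "#This part" line, here is a small runnable sketch. It assumes outputs holds one dict per validation batch with 0-d tensors as values; where outputs actually comes from isn't shown in the question, so the data below is hypothetical. torch.stack turns the list of 0-d tensors into a 1-d tensor so .mean() can average it, and .item() converts the result to a Python float.

```python
import torch

# Hypothetical per-batch results (assumed shape: one dict per validation batch).
outputs = [
    {"val_loss": torch.tensor(0.5), "val_acc": torch.tensor(0.75)},
    {"val_loss": torch.tensor(0.25), "val_acc": torch.tensor(1.0)},
]

batch_losses = [x["val_loss"] for x in outputs]  # [tensor(0.5000), tensor(0.2500)]
epoch_loss = torch.stack(batch_losses).mean()    # stack to shape (2,), then average
batch_accs = [x["val_acc"] for x in outputs]
epoch_acc = torch.stack(batch_accs).mean()

# .item() turns each 0-d tensor into a plain Python float.
result = {"val_loss": epoch_loss.item(), "val_acc": epoch_acc.item()}
print(result)  # {'val_loss': 0.375, 'val_acc': 0.875}

# Same format string as epoch_end: {} inserts the value as-is,
# {:.4f} formats a float with exactly 4 digits after the decimal point.
epoch = 0
line = "Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(
    epoch, result["val_loss"], result["val_acc"]
)
print(line)  # Epoch [0], val_loss: 0.3750, val_acc: 0.8750
```

The values were chosen so the averages are exact in float32; with arbitrary values, .item() may show small rounding noise, which the {:.4f} formatting hides.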