Tags: logging, pytorch, tensorboard, pytorch-lightning

How to extract loss and accuracy from logger by each epoch in pytorch lightning?


I want to extract all of the logged data so I can make the plots myself, without TensorBoard. My understanding is that all logged loss and accuracy values are stored in a defined directory, since TensorBoard draws its line graphs from there:

%reload_ext tensorboard
%tensorboard --logdir lightning_logs/


However, I don't know how to extract all of the logged values from the logger in PyTorch Lightning. Here is the training part of my code:

# Trainer arguments as in Lightning 1.x; progress_bar_refresh_rate and gpus are not accepted by newer releases
import pytorch_lightning as pl

#model
ssl_classifier = SSLImageClassifier(lr=lr)

#train
logger = pl.loggers.TensorBoardLogger(name=f'ssl-{lr}-{num_epoch}', save_dir='lightning_logs')

trainer = pl.Trainer(progress_bar_refresh_rate=20,
                     gpus=1,
                     max_epochs=max_epoch,
                     logger=logger,
                     )

trainer.fit(ssl_classifier, train_loader, val_loader)

I confirmed that trainer.logger.log_dir returns the directory where the logs seem to be saved, and that trainer.logger.log_metrics returns <bound method TensorBoardLogger.log_metrics of <pytorch_lightning.loggers.tensorboard.TensorBoardLogger object at 0x7efcb89a3e50>>.

However, trainer.logged_metrics returns only the values from the final epoch, for example:

{'epoch': 19,
 'train_acc': tensor(1.),
 'train_loss': tensor(0.1038),
 'val_acc': 0.6499999761581421,
 'val_loss': 1.2171183824539185}

How can I get the loss and accuracy for every epoch?


Solution

  • Lightning does not store all logs by itself. It only streams them to the logger instance, and the logger decides what to do with them.

    The best way to retrieve all logged metrics is to use a custom callback:

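    from pytorch_lightning.callbacks import Callback

    class MetricTracker(Callback):

        def __init__(self):
            self.collection = []

        def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
            # you can access them here (assumes validation_step returns a dict with 'val_acc')
            vacc = outputs['val_acc']
            self.collection.append(vacc)  # track them

        def on_validation_epoch_end(self, trainer, pl_module):
            # access it here; copy it, since the same dict may be updated in later epochs
            elogs = dict(trainer.logged_metrics)
            self.collection.append(elogs)
            # do whatever is needed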
    

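    You can then access everything that was logged from the callback instance:

    cb = MetricTracker()
    trainer = pl.Trainer(callbacks=[cb])  # plus your other Trainer arguments
    trainer.fit(ssl_classifier, train_loader, val_loader)

    cb.collection  # do your plotting and stuff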
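To turn `cb.collection` into something plottable, here is a minimal sketch. It assumes the collection holds per-batch `val_acc` values plus dict snapshots of `trainer.logged_metrics` taken in `on_validation_epoch_end`, as the callback appends; the helper name `epoch_series` is hypothetical. Non-dict entries are skipped, and each metric becomes a list of `(epoch, value)` pairs.

```python
from typing import Any, Dict, List, Optional, Tuple


def epoch_series(collection: List[Any]) -> Dict[str, List[Tuple[Optional[int], float]]]:
    """Group dict snapshots of trainer.logged_metrics into per-metric series.

    Hypothetical helper. Non-dict entries (e.g. per-batch val_acc values) are skipped.
    """
    series: Dict[str, List[Tuple[Optional[int], float]]] = {}
    for entry in collection:
        if not isinstance(entry, dict):
            continue  # per-batch value, not an epoch snapshot
        epoch = entry.get("epoch")
        epoch = int(epoch) if epoch is not None else None
        for name, value in entry.items():
            if name == "epoch":
                continue  # used as the x coordinate instead
            # float() accepts Python numbers and single-element torch tensors
            series.setdefault(name, []).append((epoch, float(value)))
    return series
```

With matplotlib, for example, `epochs, values = zip(*series["val_loss"])` followed by `plt.plot(epochs, values)` would draw the validation-loss curve. Depending on the Lightning version, a snapshot taken in `on_validation_epoch_end` may not yet contain that epoch's aggregated values, so it is worth checking the epoch numbers against what you expect.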
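Because the logger decides what happens to the values, another option is to add Lightning's `CSVLogger` alongside the TensorBoard logger; as far as I know it writes every logged value to a `metrics.csv` file inside its `log_dir`. The reader below is a sketch that assumes that file has an `epoch` column, a `step` column, one column per metric, and empty cells where a metric was not logged on a given row; the function name `read_epoch_metrics` is hypothetical, and the Lightning wiring at the end is left as comments.

```python
import csv
from typing import Dict, Iterable


def read_epoch_metrics(lines: Iterable[str]) -> Dict[str, Dict[int, float]]:
    """Return {metric_name: {epoch: value}} from CSVLogger-style rows.

    Assumed layout: a header with 'epoch', 'step' and one column per metric;
    a cell is empty when that metric was not logged on that row.
    If a metric appears on several rows of the same epoch, the last value wins.
    """
    result: Dict[str, Dict[int, float]] = {}
    for row in csv.DictReader(lines):
        epoch_cell = row.get("epoch") or ""
        if not epoch_cell:
            continue  # without an epoch the row cannot be placed on the x-axis
        epoch = int(float(epoch_cell))  # tolerate "3" as well as "3.0"
        for name, cell in row.items():
            if name in ("epoch", "step") or not cell:
                continue  # skip index columns and metrics absent from this row
            result.setdefault(name, {})[epoch] = float(cell)
    return result


# Hypothetical wiring, reusing names from the question's code:
# import os
# csv_logger = pl.loggers.CSVLogger(save_dir="lightning_logs", name=f"ssl-{lr}-{num_epoch}")
# trainer = pl.Trainer(max_epochs=max_epoch, logger=[logger, csv_logger])
# trainer.fit(ssl_classifier, train_loader, val_loader)
# with open(os.path.join(csv_logger.log_dir, "metrics.csv"), newline="") as f:
#     metrics = read_epoch_metrics(f)
```

Passing a list of loggers keeps the TensorBoard curves while also leaving a plain-text copy of every value on disk, so the plots can be rebuilt later without keeping the callback's in-memory state.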