pytorch-lightning

How to get total test accuracy for pytorch lightning?


How can the trainer.test method be used to get total accuracy over all batches?

I know I can implement model.test_step, but that is for a single batch only. I need the accuracy over the whole data set. I can use torchmetrics.Accuracy to accumulate accuracy, but what is the proper way to combine that and get the total accuracy out? And what is model.test_step supposed to return anyway, since batchwise test scores are not very useful? I could hack it somehow, but I'm surprised that I couldn't find any example on the internet that demonstrates how to get accuracy the native PyTorch Lightning way.


Solution

  • You can see here (https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#automatic-logging) that the on_epoch argument of self.log() accumulates the logged values over the whole epoch and logs the reduced result at the end of it. The right way of doing this would be:

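    import torch.nn.functional as F
    import pytorch_lightning as pl
    from torchmetrics import Accuracy

    class LitClassifier(pl.LightningModule):
        def __init__(self, model, num_classes):
            super().__init__()
            self.model = model
            # Create the metric once (not per step) so its state accumulates over all batches
            self.test_acc = Accuracy(task="multiclass", num_classes=num_classes, average="micro")

        def test_step(self, batch, batch_idx):
            x, y = batch
            logits = self.model(x)
            loss = F.cross_entropy(logits, y)
            self.test_acc.update(logits, y)
            # Logging the metric object lets Lightning call compute() and reset() at epoch end
            self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
            return loss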
    

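    If you log a plain tensor with on_epoch=True, Lightning reduces the per-batch values at the end of the epoch; the default reduction is torch.mean (in recent versions weighted by batch size), and you can set a custom one with the reduce_fx argument. If you log a torchmetrics Metric object instead, Lightning calls its compute() at the end of the epoch, which gives the accuracy over the whole dataset rather than an average of batch accuracies. log() can be called from most hooks of your LightningModule, including training_step, validation_step and test_step.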