How can the trainer.test method be used to get total accuracy over all batches? I know I can implement model.test_step, but that is for a single batch only, and I need the accuracy over the whole data set. I can use torchmetrics.Accuracy to accumulate accuracy, but what is the proper way to combine the batch results and get the total accuracy out? What is model.test_step supposed to return anyway, since batch-wise test scores are not very useful? I could hack it somehow, but I'm surprised that I couldn't find any example on the internet demonstrating how to get accuracy in the native pytorch-lightning way.
You can see in the automatic-logging docs (https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#automatic-logging) that the on_epoch argument of log automatically accumulates the logged value and logs it at the end of the epoch. The right way of doing this would be:
from torchmetrics import Accuracy

# Create the metric once, e.g. in __init__:
#     self.accuracy = Accuracy(task="multiclass", num_classes=...)
# Instantiating it inside the step would create a fresh metric
# every batch and discard the accumulated state.

def validation_step(self, batch, batch_idx):
    x, y = batch
    preds = self.forward(x)
    loss = self.criterion(preds, y)
    acc = self.accuracy(preds, y)
    self.log('accuracy', acc, on_epoch=True)
    return loss
If you want a custom reduction function you can set it using the reduce_fx argument; the default is torch.mean. log() can be called from any method in your LightningModule.