python pytorch pytorch-lightning

PyTorch Lightning inference after each epoch


I'm using PyTorch Lightning, and after each epoch I run inference on a small dataset to produce a figure that I monitor with Weights & Biases.

I thought the natural way to do this was to use a Callback whose on_train_epoch_end method generates the plot. That method needs to run some inference, so I wanted to use trainer.predict. However, doing this raises the error below, so I guess it's not the intended way to do it.

Minimal reproducible example:

import lightning as L
from lightning.pytorch.callbacks import Callback

import torch
from torch.utils.data import DataLoader
from torch import nn, optim

class Model(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.f = nn.Linear(10, 1)
        
    def training_step(self, batch, *args):
        out = self(batch)
        return out.mean() ** 2
    
    def forward(self, x):
        return self.f(x)[:, 0]

    def train_dataloader(self):
        return DataLoader(torch.randn((100, 10)))
    
    def predict_dataloader(self):
        return DataLoader(torch.randn((100, 10)))
    
    def predict_step(self, batch):
        return self(batch)
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
class CallbackExample(Callback):
    def on_train_epoch_end(self, trainer: L.Trainer, model: Model) -> None:
        loader = model.predict_dataloader()
        trainer.predict(model, loader)
        
        ... # save figure to wandb

model = Model()
callback = CallbackExample()
trainer = L.Trainer(max_epochs=2, callbacks=callback, accelerator="mps")

trainer.fit(model)

Error (last frame of the traceback):

File ~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:233, in _LoggerConnector.metrics(self)
    231 """This function returns either batch or epoch metrics."""
    232 on_step = self._first_loop_iter is not None
--> 233 assert self.trainer._results is not None
    234 return self.trainer._results.metrics(on_step)

AssertionError: 

What is the most natural and elegant way to do it?


Solution

  • Calling model.transfer_batch_to_device and then predict_step directly, instead of trainer.predict, solved it:

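    class PlotCallback(Callback):
        def on_train_epoch_end(self, trainer: L.Trainer, model: Model) -> None:
            loader = model.predict_dataloader()
            model.eval()  # switch off training-only behaviour (dropout, batch-norm updates)
            predictions = []
            with torch.no_grad():  # no gradients are needed just to draw a figure
                for batch in loader:
                    # move the batch to the model's device (dataloader_idx=0)
                    batch = model.transfer_batch_to_device(batch, model.device, 0)
                    predictions.append(model.predict_step(batch).cpu())
            model.train()  # restore training mode before the next epoch
            predictions = torch.cat(predictions)

            ... # save figure to wandb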
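A slightly more general variant of the same idea, as a minimal sketch that depends only on plain PyTorch: the helper run_inference below is hypothetical (it is not part of Lightning). It remembers whether the module was in training mode, switches it to eval mode, disables gradient tracking, runs the forward pass on every batch and restores the previous mode afterwards, so the next training epoch is unaffected. It assumes each batch is a single tensor, as in the example above, and moves it with a plain .to(device); for nested batches (tuples, dicts), model.transfer_batch_to_device is the more general choice. It also calls the module's forward directly, which matches predict_step only because predict_step is just self(batch) in this example.

```python
from typing import Iterable

import torch
from torch import nn
from torch.utils.data import DataLoader


def run_inference(module: nn.Module, loader: Iterable[torch.Tensor],
                  device: torch.device) -> torch.Tensor:
    """Run module on every batch of loader without gradients.

    The module's previous train/eval mode is restored before returning,
    and the concatenated outputs are returned on the CPU for plotting.
    """
    was_training = module.training
    module.eval()
    outputs = []
    try:
        with torch.no_grad():
            for batch in loader:
                # Assumes each batch is a single tensor.
                outputs.append(module(batch.to(device)).cpu())
    finally:
        # Put the module back in whatever mode it was in before.
        module.train(was_training)
    return torch.cat(outputs)


if __name__ == "__main__":
    net = nn.Linear(10, 1)
    data = DataLoader(torch.randn(100, 10), batch_size=16)
    preds = run_inference(net, data, torch.device("cpu"))
    print(preds.shape)
```

Inside the callback this would be called as predictions = run_inference(model, model.predict_dataloader(), model.device) before building the figure. Using try/finally means the training mode is restored even if the plotting inference raises.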