pytorch, pytorch-lightning, pytorch-forecasting

Deactivate PyTorch Lightning Module Model Logging During Prediction


I am trying to serve a PyTorch Forecasting model using FastAPI. I load the model from a checkpoint at startup with the following code:

model = BaseModel.load_from_checkpoint(model_path)

model.eval()

The predictions themselves come out fine, but every prediction call creates a new version directory under the lightning_logs folder, with the hyperparameters stored in a new file each time. I use the following code for the predictions:

raw_predictions = model.predict(df, mode="raw", return_x=True)

How can I disable logging when serving the model for predictions?


Solution

  • Someone posted the answer on GitHub around the same time I found it myself after a lot of reading. It's not obvious, at least it wasn't to me:

    trainer_kwargs={'logger':False}
    

    For the code in my question, the prediction call becomes:

    raw_predictions = model.predict(df, mode="raw", return_x=True, trainer_kwargs=dict(accelerator="cpu|gpu", logger=False))
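
    The reason this works: pytorch-forecasting's `predict` builds a throwaway Lightning `Trainer` for each call, and `trainer_kwargs` is forwarded to that Trainer's constructor, so `logger=False` stops it from writing a new version directory. The stand-in below sketches that forwarding mechanism with plain Python; `StubTrainer` and the free-standing `predict` are illustrative stand-ins, not the real library API:

    ```python
    class StubTrainer:
        """Stand-in for lightning's Trainer; logger=False disables log writing."""
        def __init__(self, logger=True, **kwargs):
            self.logger = logger  # False means: no new lightning_logs/version_* dir

    def predict(data, mode="raw", return_x=True, trainer_kwargs=None):
        # predict() forwards trainer_kwargs straight into the Trainer it creates
        trainer = StubTrainer(**(trainer_kwargs or {}))
        # ... the real library would run inference through this trainer ...
        return {"logging_enabled": bool(trainer.logger), "mode": mode}

    result = predict([1, 2, 3], mode="raw", trainer_kwargs={"logger": False})
    print(result["logging_enabled"])  # → False: no log version is created
    ```

    The same `trainer_kwargs` dict also accepts other Trainer arguments (such as `accelerator`), which is why the answer above passes both in one place.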