
How to use a trained PyTorch model for prediction


I have a pretrained PyTorch model saved in .pth format. How can I use it for prediction on a new dataset from a separate Python file?

I need a detailed guide.


Solution

  • To use a pretrained model, load the saved state dict into a new instance of the same architecture, as explained in the PyTorch docs/tutorials:

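    Here, models comes from torchvision, so import it first:

    import torch
    from torchvision import models
    model = models.vgg16()  # build the VGG16 architecture (random weights at this point)
    model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))  # read the .pth file and copy its weights into the architecture; map_location='cpu' also works for GPU-saved weights
    model.eval()  # evaluation mode: dropout is disabled and batch norm uses its running statistics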
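Once the model is in eval mode, predicting on a single new sample mostly means adding a batch dimension and turning off gradient tracking. The sketch below assumes a tiny hypothetical nn.Sequential stand-in (3x32x32 input, 10 classes) so it runs without the .pth file; with the real VGG16 the input would typically be a 3x224x224 image preprocessed the same way as during training. The helper name predict_one is also hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the loaded network (e.g. models.vgg16()):
# it maps a 3x32x32 image to 10 class scores. Swap in your real model.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 10),
)
model.eval()  # inference behaviour for dropout / batch norm layers


def predict_one(model: nn.Module, sample: torch.Tensor) -> tuple[int, float]:
    """Return (predicted class index, its softmax probability) for one sample."""
    with torch.no_grad():                    # no gradient bookkeeping at inference
        logits = model(sample.unsqueeze(0))  # add the batch dimension: (1, C, H, W)
        probs = torch.softmax(logits, dim=1)
        conf, idx = probs.max(dim=1)         # highest probability and its index
    return idx.item(), conf.item()


# A random tensor stands in for a real, preprocessed input image.
sample = torch.rand(3, 32, 32)
label, prob = predict_one(model, sample)
print(f"predicted class {label} with probability {prob:.3f}")
```

Mapping the index back to a human-readable name requires the class list used during training, which is not stored in a state_dict.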
    

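    If you are using a custom architecture, only the line that builds the model changes. The class (and its constructor arguments) must match the one used when the weights were saved: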

    model = MyCustomModel()
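For the custom case, a minimal end-to-end sketch of the save/load round trip follows. MyCustomModel's layers, the sizes, and the temporary file path are hypothetical; in a real separate prediction file you would import the same class definition used for training (for example from a shared module) instead of redefining it.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical custom architecture; the prediction script must use the
# same definition that was used for training.
class MyCustomModel(nn.Module):
    def __init__(self, in_features: int = 4, num_classes: int = 3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 16),
            nn.ReLU(),
            nn.Linear(16, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# Training-script side: save only the weights (the state_dict).
trained = MyCustomModel()
path = os.path.join(tempfile.mkdtemp(), "model_weights.pth")
torch.save(trained.state_dict(), path)

# Prediction-script side: rebuild the architecture, then load the weights.
model = MyCustomModel()                       # constructor args must match training
state = torch.load(path, map_location="cpu")  # lets GPU-saved weights load on CPU
model.load_state_dict(state)                  # strict by default: errors on key/shape mismatch
model.eval()

# Both models now hold identical weights, so they give identical outputs.
x = torch.rand(5, 4)
with torch.no_grad():
    same = torch.equal(model(x), trained.eval()(x))
print("outputs match:", same)
```

If the .pth file was instead created with torch.save(model, path) (the whole pickled model rather than a state_dict), torch.load returns a module rather than a dict, so load_state_dict does not apply; recent PyTorch releases also require weights_only=False to unpickle it, and the class definition still has to be importable.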
    

    After enabling the eval mode, you can proceed as follows:

    • Load your data into a Dataset instance and wrap it in a DataLoader.
    • Run the model on each batch inside torch.no_grad() to get predictions.
    • Calculate metrics on the results (if the new data has labels).

    More about Dataset and DataLoader in the PyTorch documentation for torch.utils.data.
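The three steps above can be sketched as follows. This is a minimal illustration under stated assumptions: the "new dataset" is random tensors wrapped in a TensorDataset, the model is a small hypothetical stand-in for the one restored with load_state_dict, and accuracy is used as the example metric.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# 1. Data: hypothetical new dataset with 100 samples, 4 features, labels in [0, 3).
#    Replace with your own Dataset subclass that reads and preprocesses real data.
features = torch.rand(100, 4)
labels = torch.randint(0, 3, (100,))
loader = DataLoader(TensorDataset(features, labels), batch_size=32, shuffle=False)

# Stand-in for the model restored with load_state_dict(...) and put in eval mode.
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 2. Predictions: iterate over batches without tracking gradients.
all_preds = []
correct = 0
with torch.no_grad():
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(dim=1)        # predicted class per sample
        correct += (preds == y).sum().item()  # running count for the metric
        all_preds.append(preds.cpu())

predictions = torch.cat(all_preds)

# 3. Metrics: accuracy over the whole dataset.
accuracy = correct / len(loader.dataset)
print(f"{len(predictions)} predictions, accuracy = {accuracy:.2%}")
```

shuffle=False keeps the predictions in the same order as the input samples. For unlabeled data, drop y and the accuracy bookkeeping and keep only the prediction loop.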