I have a pretrained PyTorch model saved in .pth format. How can I use it for prediction on a new dataset in a separate Python file?
I need a detailed guide.
To use a pretrained model, load its saved state (the state_dict) into a new instance of the same architecture, as explained in the docs/tutorials.
Here, torch and the torchvision models module are imported beforehand:
import torch
from torchvision import models

model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth'))  # torch.load() reads the .pth file; load_state_dict() copies the weights into the architecture.
model.eval()  # switch to evaluation mode before testing on new samples.
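Why eval mode matters: layers such as Dropout (and BatchNorm) behave differently during training and inference. The minimal sketch below uses a hypothetical two-layer stand-in network, not VGG16, to show that after model.eval() the dropout layer becomes a no-op, so repeated predictions on the same input are identical. It also wraps inference in torch.no_grad(), which is a separate switch that only stops PyTorch from building the gradient graph.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in network with a Dropout layer (illustration only).
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(8, 4)

net.train()  # training mode: dropout randomly zeroes activations on each pass

net.eval()  # evaluation mode: dropout passes activations through unchanged
with torch.no_grad():  # no gradient graph is built, which saves memory at inference
    out1 = net(x)
    out2 = net(x)

print(torch.equal(out1, out2))  # True: inference is deterministic in eval mode
print(net.training)             # False
```

Note that model.eval() does not disable gradient tracking by itself; that is why torch.no_grad() is used as well.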
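If you are using a custom architecture, you only need to change the first line (the class must be defined in, or imported into, the prediction file, because the .pth file stores only the weights, not the model code):
model = MyCustomModel()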
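For a custom architecture, the full round trip between the training script and a separate prediction script might look like the sketch below. MyCustomModel, its layer sizes and the temporary file path are hypothetical placeholders; substitute your own class and the path to your .pth file. The map_location="cpu" argument is an optional addition that lets weights saved on a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MyCustomModel(nn.Module):
    """Hypothetical custom architecture; replace with your own class."""

    def __init__(self, in_features: int = 10, num_classes: int = 3):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.relu(self.fc1(x)))


# --- Training script: save only the learned parameters (the state_dict) ---
trained = MyCustomModel()
path = os.path.join(tempfile.mkdtemp(), "model_weights.pth")  # illustrative path
torch.save(trained.state_dict(), path)

# --- Separate prediction script: rebuild the architecture, then load the weights ---
model = MyCustomModel()  # must use the same constructor arguments as during training
state = torch.load(path, map_location="cpu")  # maps GPU tensors to the CPU if needed
model.load_state_dict(state)  # raises an error if layer names or shapes do not match
model.eval()
```

Because load_state_dict() is strict by default, a mismatch between the saved weights and the rebuilt class is reported immediately rather than silently ignored.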
After enabling eval mode, you can proceed as follows: load your new data into a Dataset instance and then wrap it in a DataLoader. More about Dataset and DataLoader here.
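Putting the pieces together, a prediction loop over a Dataset and DataLoader could be sketched as follows. NewDataset, predict, the 10-feature input and the 3-class output are all hypothetical assumptions for illustration; a real dataset would read your files and apply the same preprocessing used during training. The stand-in nn.Sequential model takes the place of the model loaded above.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class NewDataset(Dataset):
    """Hypothetical dataset wrapping new, unlabeled samples stored in a tensor."""

    def __init__(self, features: torch.Tensor):
        self.features = features

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.features[idx]


def predict(model: nn.Module, loader: DataLoader, device: str = "cpu") -> torch.Tensor:
    """Run the model over every batch and return the predicted class indices."""
    model.to(device)
    model.eval()  # make sure dropout/batch-norm layers are in inference mode
    predictions = []
    with torch.no_grad():  # gradients are not needed for prediction
        for batch in loader:
            logits = model(batch.to(device))
            predictions.append(logits.argmax(dim=1).cpu())  # most likely class per sample
    return torch.cat(predictions)


# Stand-in for the loaded model (assumption: 10 input features, 3 classes).
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
new_data = torch.randn(25, 10)  # 25 new samples, random values for illustration

# shuffle=False keeps predictions in the same order as the input samples.
loader = DataLoader(NewDataset(new_data), batch_size=8, shuffle=False)
preds = predict(model, loader)
print(preds.shape)  # torch.Size([25])
```

For a GPU, pass device="cuda"; the predictions are moved back to the CPU so they can be converted with preds.numpy() or saved alongside the input data.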