I am using Python 3.6 with PyTorch 1.3.1. I have noticed that some saved nn.Modules cannot be loaded when the file that defines them is imported into another module. To illustrate, here is the template of a minimal working example.
#!/usr/bin/env python3
# encoding: utf-8
# file 'dnn_predict.py'
import os

import torch
from torch import nn


class NN(nn.Module):  # NN network
    # Initialisation and other class methods
    ...


# Module-level code: this runs on import as well as when the file is run directly
networks = [torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_{fold}.pkl'.format(fold=fold))) for fold in range(5)]
...

if __name__ == '__main__':
    # Some testing snippets
    pass
The whole file works just fine when I run it in the shell directly. However, when I want to use the class and load the neural network in another file using this code, it fails.
#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *
The error reads AttributeError: Can't get attribute 'NN' on <module '__main__'>
Does loading saved variables or importing modules work differently in PyTorch than in other common Python libraries? Any help or a pointer to the root cause would be much appreciated.
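When you save a model with torch.save(model, PATH), the whole object gets serialised with pickle. Pickle does not save the class itself, only a reference to it: the name of the module that defines the class, plus the class name. Loading the model therefore needs the class to be importable under exactly that module name, which in practice means the same directory and file structure. When a Python file is run as a script, its module name is __main__, so a model saved from that script records its class as __main__.NN. In your case torch.load runs while dnn_predict is being imported, and at that point the __main__ module is the importing script, which has no NN. To load such a model, your NN class must be defined in the script you're running.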
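The lookup-by-name behaviour can be reproduced with plain pickle, without PyTorch. This is only a sketch: the module name training_script and the plain-Python NN class are stand-ins invented for the illustration, and the fake module is registered by hand so the snippet stays in one file.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the script that defined and saved the model.
training_module = types.ModuleType("training_script")
sys.modules["training_script"] = training_module


class NN:
    """Plain-Python stand-in for the nn.Module subclass."""

    def __init__(self):
        self.weights = [0.1, 0.2]


# Pretend NN was defined at the top level of training_script.
NN.__module__ = "training_script"
NN.__qualname__ = "NN"
training_module.NN = NN

payload = pickle.dumps(NN())

# The payload carries the names needed to find the class, not its code.
print(b"training_script" in payload, b"NN" in payload)

# Loading works while training_script.NN can still be found ...
restored = pickle.loads(payload)

# ... and fails once the class is no longer reachable under that name.
del training_module.NN
try:
    pickle.loads(payload)
    error_message = None
except AttributeError as exc:
    error_message = str(exc)
print(error_message)
```

The second load fails with an AttributeError of the same kind as the one in the question, because the unpickler looks for NN in the recorded module and does not find it.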
That is very inflexible, so the recommended approach is to not save the entire model, but instead just save the state dictionary, which only saves the parameters of the model.
# Save the state dictionary of the model
torch.save(model.state_dict(), PATH)
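To get a feel for what ends up in that file, here is a small sketch that inspects a state dictionary. The nn.Linear layer and its sizes are placeholders chosen for the illustration, not part of the model in the question.

```python
from torch import nn

# Placeholder layer; the sizes are arbitrary.
layer = nn.Linear(in_features=4, out_features=2)
state = layer.state_dict()

# The state dictionary maps parameter names to tensors; no class is involved.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

entries = {name: tuple(tensor.shape) for name, tensor in state.items()}
```

Since only names and tensors are stored, the file does not depend on where the model class was defined.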
Afterwards, the state dictionary can be loaded and applied to your model.
import torch

from dnn_predict import NN

# Create the model (will have randomly initialised parameters)
model = NN()
# Load the previously saved state dictionary
state_dict = torch.load(PATH)
# Apply the state dictionary to the model
model.load_state_dict(state_dict)
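Putting the two steps together, a self-contained round trip might look like the sketch below. The layer sizes, the constructor arguments and the temporary file name are assumptions made up for the example, since the real NN is not shown in the question.

```python
import os
import tempfile

import torch
from torch import nn


class NN(nn.Module):
    """Hypothetical small classifier standing in for the real NN."""

    def __init__(self, n_inputs=4, n_classes=2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_inputs, 8),
            nn.ReLU(),
            nn.Linear(8, n_classes),
        )

    def forward(self, x):
        return self.layers(x)


trained = NN()

with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical file name, used only inside the temporary directory.
    path = os.path.join(tmp_dir, "nn-state.pt")
    # Save only the parameters, not the pickled object.
    torch.save(trained.state_dict(), path)

    # Rebuild the architecture, then copy the saved parameters into it.
    restored = NN()
    restored.load_state_dict(torch.load(path, map_location="cpu"))

restored.eval()
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))
print(same_output)
```

With this approach dnn_predict.py would only need to define NN; the loading code can live wherever the class is imported, because the saved file no longer refers to __main__.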
More details on the state dictionary and saving/loading the models: PyTorch - Saving and Loading Models