
Derived Class of PyTorch nn.Module Cannot Be Loaded by Module Import in Python


I am using Python 3.6 with PyTorch 1.3.1. I have noticed that some saved nn.Module objects cannot be loaded when the module that loads them is imported into another module. Here is a template of a minimal working example:

#!/usr/bin/env python3
#encoding:utf-8
# file 'dnn_predict.py'

import os

import torch
from torch import nn

class NN(nn.Module):  # NN network
    # Initialisation and other class methods
    ...

# resource_directory (definition not shown) holds the saved models
networks = [
    torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_{fold}.pkl'.format(fold=fold)))
    for fold in range(5)
]
...
if __name__ == '__main__':
    # Some testing snippets
    pass

The whole file works fine when I run it directly from the shell. However, when I import it into another file to use the class and the loaded neural networks, it fails:

#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *

The error reads: AttributeError: Can't get attribute 'NN' on <module '__main__'>

Does loading saved objects or importing modules work differently in PyTorch than in other common Python libraries? Any help or pointers to the root cause would be really appreciated.


Solution

  • When you save a model with torch.save(model, PATH), the whole object is serialised with pickle. Pickle does not save the class itself, only a reference to it: the name of the module that defines the class plus the class name. Hence, when loading the model, that same module must be importable and must contain the class. When you run a Python script directly, its module is __main__, so a model saved from a script in which NN is defined refers to __main__.NN, which is what the error message shows. When dnn_predict is imported from another script instead, __main__ is that other script, which does not define NN at the moment torch.load runs, so unpickling fails. Therefore, to load such a model, your NN class must be defined in the script you're running.

    That is very inflexible, so the recommended approach is not to save the entire model but only its state dictionary, which contains just the model's parameters (and buffers).

    # Save the state dictionary of the model
    torch.save(model.state_dict(), PATH)
    

    Afterwards, in any module that can import the NN class, the state dictionary can be loaded and applied to a new instance of your model.

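    import torch
    
    # Assumes dnn_predict.py no longer calls torch.load on whole pickled
    # models at import time; otherwise this import fails the same way
    from dnn_predict import NN
    
    # Create the model (will have randomly initialised parameters)
    model = NN()
    
    # Load the previously saved state dictionary
    state_dict = torch.load(PATH)
    
    # Apply the state dictionary to the model
    model.load_state_dict(state_dict)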
    

    More details on the state dictionary and saving/loading the models: PyTorch - Saving and Loading Models
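The pickle mechanism described in the answer can be checked without PyTorch. The following is a minimal sketch using only the standard library: a plain class named NN stands in for the nn.Module subclass (an assumption made so the demo needs no torch), and the class is temporarily removed from its module to imitate loading from a script in which NN is not defined.

```python
import pickle
import sys


class NN:
    """Stand-in for the question's nn.Module subclass (no torch needed)."""

    def __init__(self):
        self.weights = [0.1, 0.2]


# Pickle an instance; torch.save(model, PATH) relies on pickle the same way.
payload = pickle.dumps(NN())

# The payload names the class by module and class name; it holds no class code.
print(NN.__module__, NN.__qualname__)     # e.g. "__main__ NN" when run as a script
print(NN.__module__.encode() in payload)  # True
print(b"NN" in payload)                    # True

# Simulate loading where the recorded module has no attribute NN,
# which is the situation when another script is __main__.
defining_module = sys.modules[NN.__module__]
saved_class = defining_module.NN
del defining_module.NN
try:
    pickle.loads(payload)
    load_error = None
except AttributeError as exc:
    load_error = str(exc)
finally:
    defining_module.NN = saved_class  # put the class back

print(load_error)  # mentions 'NN'; exact wording varies by Python version

# Once NN is resolvable again, the same bytes load fine.
restored = pickle.loads(payload)
print(restored.weights)  # [0.1, 0.2]
```

The bytes never change between the failing and the succeeding load; only whether the recorded module has an NN attribute does, which is why the fix has to be on the loading side (or the class reference has to be avoided altogether).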
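As a sketch of how the state-dictionary approach could replace the question's list of five pickled networks, the code below saves one state dictionary per fold and rebuilds each network from the class. The NN architecture (layer sizes), the helper names save_folds and load_folds, and the .pt file names are all hypothetical; a temporary directory stands in for resource_directory, and freshly initialised models stand in for trained ones.

```python
import os
import tempfile

import torch
from torch import nn


class NN(nn.Module):
    """Hypothetical stand-in for the question's NN; layer sizes are assumptions."""

    def __init__(self, in_features=10, hidden=16, n_classes=2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        return self.layers(x)


def fold_path(directory, fold):
    # Hypothetical file naming, one file per cross-validation fold
    return os.path.join(directory, "nn-classify-cpu_{fold}.pt".format(fold=fold))


def save_folds(directory, n_folds=5):
    """Save only the parameters of each fold's model and return the models."""
    models = []
    for fold in range(n_folds):
        model = NN()  # stand-in for the trained model of this fold
        torch.save(model.state_dict(), fold_path(directory, fold))
        models.append(model)
    return models


def load_folds(directory, n_folds=5):
    """Rebuild each network from the class, then copy in the saved parameters."""
    networks = []
    for fold in range(n_folds):
        model = NN()
        model.load_state_dict(torch.load(fold_path(directory, fold), map_location="cpu"))
        model.eval()  # inference mode (matters for dropout/batch norm, if present)
        networks.append(model)
    return networks


with tempfile.TemporaryDirectory() as resource_directory:
    originals = save_folds(resource_directory)
    networks = load_folds(resource_directory)

print(len(networks))  # 5
```

Because the saved files contain only tensors keyed by parameter name, loading them never has to resolve NN through pickle, so it no longer matters which script is __main__; the loading module only needs to be able to construct NN itself.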