Tags: python, deep-learning, pytorch, onnx, nnet

Saving and Loading a PyTorch NN model (.nnet or .onnx format)


I am trying to train and save a PyTorch model locally on my computer (preferably in .nnet or .onnx format).

import torch
import torch.nn as nn

# Defining the neural network class
class Net(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        x = self.output(x)
        return x

# Defining the input size, hidden layer sizes, and output size
input_size = 5
hidden_size1 = 2
hidden_size2 = 3
output_size = 5

# Creating an instance of the neural network
model = Net(input_size, hidden_size1, hidden_size2, output_size)

# Printing the model architecture
print(model)

I saved the model to a file with a .nnet extension using the following code:

torch.save(model, 'theModel.nnet')

Later, I want to load the model into a PyTorch object and use it independently, without rewriting the same code. How can I do this?

I tried loading the model with

saved_model = torch.load('theModel.nnet')

but it throws this error:

AttributeError                            Traceback (most recent call last)
Cell In[7], line 1
----> 1 saved_model=torch.load('theModel.nnet')

File ~\anaconda3\lib\site-packages\torch\serialization.py:712, in load(f, map_location, pickle_module, **pickle_load_args)
    710             opened_file.seek(orig_position)
    711             return torch.jit.load(opened_file)
--> 712         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

File ~\anaconda3\lib\site-packages\torch\serialization.py:1049, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
   1047 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1048 unpickler.persistent_load = persistent_load
-> 1049 result = unpickler.load()
   1051 torch._utils._validate_loaded_sparse_tensors()
   1053 return result

File ~\anaconda3\lib\site-packages\torch\serialization.py:1042, in _load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
   1040         pass
   1041 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1042 return super().find_class(mod_name, name)

AttributeError: Can't get attribute 'Net' on <module '__main__'>
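If you want to keep saving the whole model, loading can work as long as the Net class is defined (or importable from the same module) in the loading script, because the file stores only a reference to __main__.Net, not the class code. Below is a minimal sketch, assuming PyTorch 1.13 or later (where torch.load accepts weights_only). Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle a full nn.Module, so the sketch passes weights_only=False explicitly. Only do that for files you trust, because unpickling can execute arbitrary code.

```python
import torch
import torch.nn as nn


# The class must be defined (or importable) in the loading process,
# under the same module path it had when the model was saved.
class Net(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        return self.output(x)


model = Net(5, 2, 3, 5)

# Pickles the whole module object; the class is recorded only by reference.
torch.save(model, "theModel.nnet")

# weights_only=False is needed to unpickle a full nn.Module on PyTorch >= 2.6,
# where the default became True. Use it only on trusted files.
saved_model = torch.load("theModel.nnet", weights_only=False)
saved_model.eval()  # evaluation mode before inference
```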

Is there an alternative way to do this?
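If the goal is to load the model without the Net class at all, one option is TorchScript, which is a different mechanism from the pickle-based torch.save(model, ...). torch.jit.script compiles the module, and the saved archive holds both the compiled code and the weights, so torch.jit.load does not need the Python class. This is a sketch that assumes a hypothetical file name theModel.pt and the layer sizes from the question. The class appears only because the model has to be built once before it can be scripted.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        return self.output(x)


model = Net(5, 2, 3, 5)
model.eval()

# Compile to TorchScript; the saved archive contains code and weights.
scripted = torch.jit.script(model)
scripted.save("theModel.pt")  # hypothetical file name

# In a separate script, this call works without Net being defined.
loaded = torch.jit.load("theModel.pt")

# Sanity check: the reloaded module gives the same output as the original.
x = torch.randn(1, 5)
with torch.no_grad():
    same = torch.allclose(loaded(x), model(x))
```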


Solution

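  • Save only the model's parameters (its state_dict) instead of the whole model object:

    torch.save(model.state_dict(), 'theModel.nnet')

    and later load them into a new instance of the same class:

    model = Net(input_size, hidden_size1, hidden_size2, output_size)
    state_dict = torch.load('theModel.nnet')
    model.load_state_dict(state_dict)
    model.eval()  # put the model in evaluation mode before using it

    torch.save(model, ...) pickles the model, and pickle records the class only by name (__main__.Net), which is why loading fails when Net is not defined. A state_dict contains only tensors, but the Net class must still be defined or imported when you rebuild the model.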
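Put together, the state_dict round trip might look like the following sketch, assuming the Net class from the question is available (defined in the same file or imported) when loading. Because only tensors are stored, torch.load can be called with weights_only=True (PyTorch 1.13 or later). map_location="cpu" lets a file saved from a GPU run load on a CPU-only machine.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size1)
        self.hidden2 = nn.Linear(hidden_size1, hidden_size2)
        self.output = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        return self.output(x)


# --- Saving (e.g. at the end of the training script) ---
model = Net(5, 2, 3, 5)
torch.save(model.state_dict(), "theModel.nnet")

# --- Loading (later, in any script where Net is available) ---
restored = Net(5, 2, 3, 5)  # sizes must match the saved architecture
state_dict = torch.load("theModel.nnet", map_location="cpu", weights_only=True)
restored.load_state_dict(state_dict)
restored.eval()  # evaluation mode before inference
```

The file extension has no effect here. torch.save writes the same PyTorch format whether the file is named .nnet, .pt or anything else.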