When I try to save the PyTorch model with this piece of code:
checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
I get the following error:
E:\PROGRAM FILES\Anaconda\envs\staj_projesi\lib\site-packages\torch\serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.
...
"type " + obj.__name__ + ". It won't be checked "
Can't pickle local object 'trainModel.<locals>.Net'
When I instead save only the state dicts with this piece of code:
checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
I don't get any errors, but I want to save the ANN class as well. How can I solve this problem? Also, I was able to save the model with the first structure in other projects before.
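You can't, at least not the way the code is written. torch.save serializes objects with Python's pickle, and pickle cannot serialize a class that is defined inside a function, which is what the 'trainModel.<locals>.Net' in the error message refers to. The recommended approach is to save the object's state_dict() only.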
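The "Can't pickle local object" message can be reproduced without PyTorch, which may help show where it comes from. The following is only a minimal sketch: `make_local_instance`, `Local`, and `can_pickle` are hypothetical names made up for illustration and are not part of the original code.

```python
import pickle


def make_local_instance():
    # Hypothetical stand-in for trainModel: the class is defined inside
    # the function, so its qualified name contains "<locals>".
    class Local:
        pass

    return Local()


def can_pickle(obj):
    """Return True if pickle can serialize obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        # The exact exception type differs between Python versions.
        return False


local_ok = can_pickle(make_local_instance())
```

If this is indeed the cause, then defining `Net` at module level (outside `trainModel`) would presumably explain why storing the model object worked in the other projects.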
When you use the following:
checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
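You are trying to pickle the model object itself in addition to its weights. The weights are already contained in model.state_dict(), so the 'model' entry is not needed. When loading a model from a state_dict, you should first instantiate a model object and then load the weights into it.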
This is exactly the reason why the second method works properly:
checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
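To restore such a checkpoint, something along these lines should work. This is a sketch under assumptions, not the original poster's code: the `Net` below is a hypothetical one-layer stand-in defined at module level, the SGD optimizer and its learning rate are arbitrary choices, and a temporary directory is used only to keep the example self-contained.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical stand-in for the real network, defined at module level.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed optimizer

# Save only the state dicts, as in the second snippet of the question.
path = os.path.join(tempfile.mkdtemp(), 'Checkpoint.pth')
checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, path)

# Loading: create fresh objects first, then load the saved state into them.
restored = Net()
restored_optimizer = torch.optim.SGD(restored.parameters(), lr=0.1)
loaded = torch.load(path)
restored.load_state_dict(loaded['state_dict'])
restored_optimizer.load_state_dict(loaded['optimizer'])
restored.eval()  # switch to inference mode if the model is not trained further
```

The class definition still has to be available in the code that loads the checkpoint; only the parameters and optimizer state live in the file.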
I would suggest reading the PyTorch tutorial on how to properly save and load a model: https://pytorch.org/tutorials/beginner/saving_loading_models.html