
PyTorch model saving error: "Can't pickle local object"


When I try to save the PyTorch model with this piece of code:

    checkpoint = {'model': Net(), 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, 'Checkpoint.pth')

I get the following error:

    E:\PROGRAM FILES\Anaconda\envs\staj_projesi\lib\site-packages\torch\serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.
...

      "type " + obj.__name__ + ". It won't be checked "
    Can't pickle local object 'trainModel.<locals>.Net'
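The error can be reproduced without PyTorch, because `torch.save` serializes objects with Python's `pickle` module by default (its `pickle_module` parameter). The sketch below uses hypothetical names (`train_model`, `ModuleLevelNet`) rather than the question's actual code: an instance of a class defined inside a function fails to pickle, while an instance of a module-level class pickles fine. The exception type may vary between Python versions, so the sketch catches both `AttributeError` and `pickle.PicklingError`.

```python
import pickle


class ModuleLevelNet:
    """Hypothetical class defined at module level, so pickle can find it by name."""

    def __init__(self):
        self.weights = [0.1, 0.2]


def train_model():
    # Hypothetical stand-in for the question's trainModel(): the class is
    # defined inside the function, so its qualified name contains "<locals>".
    class Net:
        def __init__(self):
            self.weights = [0.1, 0.2]

    return Net()


def try_pickle(obj):
    """Return None if obj pickles, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except (AttributeError, pickle.PicklingError) as exc:
        return str(exc)


local_error = try_pickle(train_model())
module_error = try_pickle(ModuleLevelNet())

print(type(train_model()).__qualname__)  # train_model.<locals>.Net
print(local_error)   # error message about the local class
print(module_error)  # None
```

Pickle records a class as "module + qualified name" and looks it up again when loading; a name containing `<locals>` cannot be looked up, so pickling is refused up front.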

When I instead save the checkpoint with this code:

    checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, 'Checkpoint.pth')

I don't get any errors, but I want to save the ANN class as well. How can I solve this problem? I was able to save models with the first structure in other projects before.
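A minimal sketch of why the second snippet works, assuming a small hypothetical `Net` (one `nn.Linear` layer) defined inside a hypothetical `train_model()`, as in the question. A checkpoint holding only the state dicts saves even though the class is local, whereas saving the model object itself fails. An in-memory `io.BytesIO` buffer stands in for `'Checkpoint.pth'`.

```python
import io
import pickle

import torch
import torch.nn as nn


def train_model():
    # Hypothetical stand-in for the question's trainModel(): Net is local.
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    return model, optimizer


model, optimizer = train_model()

# State dicts contain only tensors and plain Python containers/values,
# so pickling them never has to look up the local Net class.
checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
buffer = io.BytesIO()  # stands in for 'Checkpoint.pth'
torch.save(checkpoint, buffer)

# Saving the module object itself requires pickling its class by name,
# which fails for a class defined inside a function.
try:
    torch.save({"model": model}, io.BytesIO())
    whole_model_saved = True
except (AttributeError, pickle.PicklingError):
    whole_model_saved = False

print(sorted(checkpoint["state_dict"].keys()))  # ['fc.bias', 'fc.weight']
print(whole_model_saved)                         # False
```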


Solution

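  • The problem is where `Net` is defined, not `torch.save` itself. `torch.save` pickles whatever object you pass it, and pickle stores a class by reference (its module and qualified name), not by value. Because `Net` is defined inside `trainModel`, its qualified name is `trainModel.<locals>.Net`, which pickle cannot look up again, hence "Can't pickle local object". If you really need the class in the checkpoint, move `class Net` to the top level of a module; that is most likely the difference from your other projects, where the same code worked.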

    When you use the following:

    checkpoint = {'model': Net(), 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, 'Checkpoint.pth')
    

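    You are not actually saving your trained model here: `Net()` creates a brand-new, untrained instance, and pickling it requires pickling the local `Net` class, which is what fails. The learned parameters already live in `model.state_dict()`; to restore them, first instantiate a model object and then load the saved state into it with `load_state_dict()`.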

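    This is why the second method works: the checkpoint contains only the state dicts (tensors and plain Python containers), so pickle never needs to look up the `Net` class: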

    checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, 'Checkpoint.pth')
    

    I suggest reading the PyTorch tutorial on how to properly save and load a model: https://pytorch.org/tutorials/beginner/saving_loading_models.html
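Following the state-dict pattern from that tutorial, here is a minimal sketch of the full round trip. It assumes a hypothetical one-layer `Net`, now defined at module level, and uses an `io.BytesIO` buffer in place of `'Checkpoint.pth'`. One dummy training step is taken so the optimizer has momentum buffers worth saving.

```python
import io

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model, defined at module level (not inside a function)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# --- Saving: store only the state dicts ---------------------------------
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One dummy training step so the optimizer holds momentum buffers.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

buffer = io.BytesIO()  # stands in for 'Checkpoint.pth'
torch.save({"state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()}, buffer)

# --- Loading: instantiate first, then load the saved states -------------
buffer.seek(0)
checkpoint = torch.load(buffer, weights_only=True)

restored = Net()
restored.load_state_dict(checkpoint["state_dict"])
restored_optimizer = torch.optim.SGD(restored.parameters(), lr=0.01, momentum=0.9)
restored_optimizer.load_state_dict(checkpoint["optimizer"])
restored.eval()  # or restored.train() to resume training

same_weights = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same_weights)  # True
```

`weights_only=True` restricts `torch.load` to tensors and primitive containers, which is all a state-dict checkpoint needs; the class itself is supplied by your code when you call `Net()`.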