python, machine-learning, pytorch

Why is nn.Parameter not being saved in model state_dict?


I'm trying to train a custom model in PyTorch (version 1.9.1) and save the model weights at each training step. To do this, I'm using torch.save(model.state_dict(), 'filename.pt'). After training, when I try to load the saved weights, I get an error saying there is a missing key in the model's state_dict, which indicates that some part of the model isn't being saved.

I tracked down the culprit, which is a nn.Parameter object initialized like so:

self.class_token = nn.Parameter(torch.rand(1, self.hidden_d)).to(self.device)

I have no idea why it isn't being saved. Reports of similar issues suggest that it might be the conversion between devices, but all the other layers in the network are converted the same way as this parameter. Additionally, the model code was written by a colleague of mine, and saving/loading worked for him without issues. It's possible that this is a version issue and that he was running a more recent version of PyTorch, but I'm running on a system where upgrading to PyTorch 2.0 would be difficult, if not impossible.
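The behaviour can be reproduced without a GPU. The minimal sketch below uses a hypothetical BrokenModel and assumes that casting to a different dtype is a fair stand-in for a device move: in both cases .to() has to create a new tensor, and that new tensor is a plain torch.Tensor rather than an nn.Parameter, so nn.Module stores it as an ordinary attribute instead of registering it.

```python
import torch
import torch.nn as nn


class BrokenModel(nn.Module):
    """Hypothetical minimal module reproducing the question's pattern."""

    def __init__(self, hidden_d: int = 8):
        super().__init__()
        self.hidden_d = hidden_d
        self.linear = nn.Linear(hidden_d, hidden_d)
        # Casting to a different dtype forces .to() to build a new tensor, just as a
        # CPU -> GPU move would; that new tensor is a plain torch.Tensor.
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d)).to(torch.float64)


model = BrokenModel()
print(type(model.class_token))              # <class 'torch.Tensor'>
print("class_token" in model.state_dict())  # False
print([name for name, _ in model.named_parameters()])  # ['linear.weight', 'linear.bias']

# When no conversion is needed, .to() returns the very same Parameter object.
p = nn.Parameter(torch.rand(1, 8))
print(p.to(p.device) is p)                  # True
```

The last check suggests why the same code can appear to work on some machines: if the tensor is already on the target device (for example, when everything runs on the CPU), .to() hands back the original Parameter and it stays registered.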

Is there another way to convert between devices that won't detach it from the model? Otherwise, is there something I need to do to manually add this parameter so that it's saved in the state_dict?


Solution

  • Move the .to(self.device) operation into the Parameter call:

    self.class_token = nn.Parameter(torch.rand(1, self.hidden_d).to(self.device))
    

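    The reason is that calling .to() on an nn.Parameter returns a plain torch.Tensor, not a Parameter, whenever it has to create a new tensor (for example, when moving from CPU to GPU). nn.Module only registers values of type nn.Parameter as parameters, so in your version self.class_token ends up as an ordinary tensor attribute: it is not returned by model.parameters() and is not included in the state_dict. When the tensor is already on the target device, .to() returns the original Parameter unchanged, which could explain why the same code worked for your colleague (for example, if he ran it on the CPU).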

    Check a similar issue here.
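A minimal sketch of the fix, using a hypothetical FixedModel that mirrors the question's attribute names (hidden_d, device, class_token); the single linear layer stands in for the rest of the network. It checks that class_token now appears in the state_dict and survives a torch.save / torch.load round-trip, with an in-memory io.BytesIO buffer used in place of 'filename.pt'.

```python
import io

import torch
import torch.nn as nn


class FixedModel(nn.Module):
    """Hypothetical module mirroring the question's attribute names."""

    def __init__(self, hidden_d: int, device: torch.device):
        super().__init__()
        self.hidden_d = hidden_d
        self.device = device
        # Module.to() returns the module itself, so the layer stays registered.
        self.linear = nn.Linear(hidden_d, hidden_d).to(self.device)
        # Move the plain tensor first, then wrap it: the attribute is a Parameter.
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d).to(self.device))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FixedModel(hidden_d=8, device=device)
print(sorted(model.state_dict().keys()))  # ['class_token', 'linear.bias', 'linear.weight']

# Save and reload the weights through an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

restored = FixedModel(hidden_d=8, device=device)
result = restored.load_state_dict(torch.load(buffer, map_location=device))
print(result.missing_keys, result.unexpected_keys)  # [] []
print(torch.equal(restored.class_token, model.class_token))  # True
```

A common alternative is to create all parameters without any .to() calls and move the whole model once with model.to(device); Module.to() moves registered parameters in place and keeps them as nn.Parameter objects.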