
How to load a PyTorch model when the parameters are saved as NumPy arrays?


From this GitHub repo, I downloaded the pretrained model senet50_ft.

I load it like so:

import pickle
f = open('pretrained_models/senet50_ft_weight.pkl', 'rb')
state_dict = pickle.load(f, encoding='latin1')
f.close()

The state loads fine. The GitHub repo also provides the SENet model class here.

So I managed to instantiate the model:

model = senet.senet50()

Then I tried to load the state, but I got an error:

model.load_state_dict(state_dict)

Traceback (most recent call last):
  File "...\module.py", line 982, in _load_from_state_dict
    param.copy_(input_param)
TypeError: copy_(): argument 'other' (position 1) must be Tensor, not numpy.ndarray

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "...\module.py", line 1037, in load_state_dict
    load(self)
  File "...\module.py", line 1035, in load
    load(child, prefix + name + '.')
  File "...\module.py", line 1032, in load
    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
  File "...\module.py", line 988, in _load_from_state_dict
    .format(key, param.size(), input_param.size(), ex.args))
TypeError: 'int' object is not callable

I tried to convert each ndarray to a Tensor as follows:

for key in state_dict.keys():
    state_dict[key] = torch.from_numpy(state_dict[key])

But I got another error, and I don't think I'm getting anywhere.

I'm new to PyTorch but I suspect that this model was serialized with an old version of PyTorch. Do you know if a solution exists?


Solution

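  • The repository provides its own load_state_dict function that does what you want: use it instead of calling model.load_state_dict directly on the unpickled dictionary of NumPy arrays.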
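
If you want to see what such a helper might look like, here is a minimal sketch. It is not the repository's code: the function name load_numpy_state_dict and the tiny demo model are hypothetical, chosen only for illustration. The idea is to copy each NumPy array into the matching entry of the model's own state dict, so entries that are absent from the saved weights simply keep their initial values.

```python
import numpy as np
import torch
import torch.nn as nn


def load_numpy_state_dict(model, numpy_state):
    """Copy NumPy arrays into the matching parameters/buffers of `model`.

    Hypothetical helper for illustration; `numpy_state` maps state-dict
    keys to NumPy arrays, as in the unpickled weight file.
    """
    own_state = model.state_dict()  # tensors share storage with the model
    with torch.no_grad():
        for name, array in numpy_state.items():
            if name not in own_state:
                raise KeyError(f"unexpected key {name!r} in state dict")
            tensor = torch.from_numpy(np.asarray(array))
            if tensor.shape != own_state[name].shape:
                raise ValueError(
                    f"shape mismatch for {name!r}: model has "
                    f"{tuple(own_state[name].shape)}, "
                    f"file has {tuple(tensor.shape)}"
                )
            # In-place copy; also converts dtype if the two differ.
            own_state[name].copy_(tensor)
    return model


# Tiny stand-in model and weights (assumptions, not the SENet from the question).
model = nn.Sequential(nn.Linear(3, 2), nn.BatchNorm1d(2))
numpy_state = {
    "0.weight": np.ones((2, 3), dtype=np.float32),
    "0.bias": np.zeros(2, dtype=np.float32),
    "1.weight": np.full(2, 0.5, dtype=np.float32),
}
load_numpy_state_dict(model, numpy_state)
```

With the model from the question, the call would presumably be load_numpy_state_dict(model, state_dict), using the state_dict obtained from pickle.load. Unlike model.load_state_dict with its default strict=True, this sketch does not complain about keys the model has but the file lacks, so check the result if you rely on every entry being loaded.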