
PyTorch dictionary keys not matching


I am trying to use a convolutional LSTM implementation I found online, but the dictionary keys of its pre-trained weights do not match the keys of the model:

The pre-trained weights are in a pickled dictionary with the following keys:

pkl_load = torch.load(trained_model_dir)
print(pkl_load.keys())

odict_keys(['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias', ....

However, the keys in the state_dict for the actual NN model are:

"E.conv1.weight", "E.bn1.weight", "E.bn1.bias", ....

I am getting an error when trying to load the pre-trained weights into the state_dict because the keys don't match. What are ways to work around this? (Sorry if this is easy, I am new to PyTorch).


Solution

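  • The "module." prefix usually means the weights were saved from a model wrapped in torch.nn.DataParallel. You could strip it from each key with something like: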

    keys = ['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias']
    prefix = 'module.'
    res = []
    for key in keys:
        # Drop the leading 'module.' only when present; leave other keys unchanged
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        res.append(new_key)
    print(res)
    

    output:

    ['E.conv1.weight', 'E.bn1.weight', 'E.bn1.bias']
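
Renaming the keys alone does not load anything: the renamed keys still have to be paired with their values and passed to the model. The sketch below shows one way this might look. The helper name strip_prefix and the variable model are hypothetical (not from the question), and the demo dictionary uses integers as stand-ins for the real tensors so that it runs without PyTorch installed.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only strip when the prefix is actually present, so plain keys pass through
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in values; in practice these would be the tensors returned by torch.load(...)
demo = OrderedDict([
    ("module.E.conv1.weight", 1),
    ("module.E.bn1.weight", 2),
    ("module.E.bn1.bias", 3),
])
cleaned_demo = strip_prefix(demo)
print(list(cleaned_demo.keys()))

# With PyTorch, assuming `model` is the instantiated network (hypothetical name):
# pkl_load = torch.load(trained_model_dir)
# model.load_state_dict(strip_prefix(pkl_load))
```

Another option, if it suits the setup, is to wrap the model in torch.nn.DataParallel before calling load_state_dict, so that the model's own keys carry the same "module." prefix as the saved weights.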