Search code examples
Tags: python, deep-learning, pytorch, torchvision

How to load checkpoints from a model trained and saved with nn.DataParallel onto a model that doesn't use nn.DataParallel?


How can I load checkpoints from a model trained and saved with nn.DataParallel onto a model that doesn't use nn.DataParallel? I tried removing the 'module.' prefix from the state_dict keys, but now I get a different error. The checkpoint I am loading is the ResNet-50 file res50-debiased.pth.tar (the original link was not preserved in this copy).

import torch
from collections import OrderedDict

from torchvision.models import ResNet50_Weights, resnet50

# Load the model
model = resnet50()
checkpoint_path = 'C:/res50-debiased.pth.tar'
# map_location='cpu' lets a checkpoint that was saved from GPU(s) load on a
# machine without CUDA; load_state_dict copies the tensors into the CPU model.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

state_dict = checkpoint['state_dict']

# Every key saved from the nn.DataParallel wrapper starts with `module.`.
# Strip the prefix only where it is present: blindly slicing k[7:] would
# silently corrupt any key that does not carry it.
prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in state_dict.items()
)
# load params
# NOTE: with a stock torchvision resnet50 this still fails on the extra
# `*.aux_bn.*` keys shown in the error below.
model.load_state_dict(new_state_dict)

This gives the following error:

RuntimeError: Error(s) in loading state_dict for ResNet:
    Unexpected key(s) in state_dict: "bn1.aux_bn.weight", "bn1.aux_bn.bias", "bn1.aux_bn.running_mean", "bn1.aux_bn.running_var", "bn1.aux_bn.num_batches_tracked", "layer1.0.bn1.aux_bn.weight", "layer1.0.bn1.aux_bn.bias", "layer1.0.bn1.aux_bn.running_mean", "layer1.0.bn1.aux_bn.running_var", "layer1.0.bn1.aux_bn.num_batches_tracked", "layer1.0.bn2.aux_bn.weight", "layer1.0.bn2.aux_bn.bias", "layer1.0.bn2.aux_bn.running_mean", "layer1.0.bn2.aux_bn.running_var", "layer1.0.bn2.aux_bn.num_batches_tracked", "layer1.0.bn3.aux_bn.weight", "layer1.0.bn3.aux_bn.bias", "layer1.0.bn3.aux_bn.running_mean", "layer1.0.bn3.aux_bn.running_var", "layer1.0.bn3.aux_bn.num_batches_tracked", "layer1.0.downsample.1.aux_bn.weight", "layer1.0.downsample.1.aux_bn.bias", "layer1.0.downsample.1.aux_bn.running_mean", "layer1.0.downsample.1.aux_bn.running_var", "layer1.0.downsample.1.aux_bn.num_batches_tracked", "layer1.1.bn1.aux_bn.weight", "layer1.1.bn1.aux_bn.bias", "layer1.1.bn1.aux_bn.running_mean", "layer1.1.bn1.aux_bn.running_var", "layer1.1.bn1.aux_bn.num_batches_tracked", "layer1.1.bn2.aux_bn.weight", "layer1.1.bn2.aux_bn.bias", "layer1.1.bn2.aux_bn.running_mean", "layer1.1.bn2.aux_bn.running_var", "layer1.1.bn2.aux_bn.num_batches_tracked", "layer1.1.bn3.aux_bn.weight", "layer1.1.bn3.aux_bn.bias", 

Loading it directly, without stripping the prefix, like this:

import torch
from torchvision.models import resnet50

# Load the model
model = resnet50()
checkpoint_path = 'C:/res50-debiased.pth.tar'
# map_location='cpu' allows loading a GPU-saved checkpoint without CUDA.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

state_dict = checkpoint['state_dict']

# Expected to fail: every saved key carries a `module.` prefix (from the
# nn.DataParallel wrapper) that the plain model's parameter names lack, so
# load_state_dict reports both missing and unexpected keys.
model.load_state_dict(state_dict)

gives an error listing missing keys and unexpected keys such as "module.conv1.weight":

RuntimeError: Error(s) in loading state_dict for ResNet:
    Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", ...

Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num_batches_tracked", "module.bn1.aux_bn.weight", "module.bn1.aux_bn.bias", "module.bn1.aux_bn.running_mean", "module.bn1.aux_bn.running_var", "module.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean", "module.layer1.0.bn1.running_var", "module.layer1.0.bn1.num_batches_tracked", "module.layer1.0.bn1.aux_bn.weight", "module.layer1.0.bn1.aux_bn.bias", "module.layer1.0.bn1.aux_bn.running_mean", "module.layer1.0.bn1.aux_bn.running_var", "module.layer1.0.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv2.weight", "module.layer1.0.bn2.weight", "module.layer1.0.bn2.bias", "module.layer1.0.bn2.running_mean", "module.layer1.0.bn2.running_var", "module.layer1.0.bn2.num_batches_tracked", "module.layer1.0.bn2.aux_bn.weight", "module.layer1.0.bn2.aux_bn.bias", "module.layer1.0.bn2.aux_bn.running_mean", "module.layer1.0.bn2.aux_bn.running_var", "module.layer1.0.bn2.aux_bn.num_batches_tracked", "module.layer1.0.conv3.weight", "module.layer1.0.bn3.weight", "module.layer1.0.bn3.bias", "module.layer1.0.bn3.running_mean", "module.layer1.0.bn3.running_var", "module.layer1.0.bn3.num_batches_tracked", "module.layer1.0.bn3.aux_bn.weight", "module.layer1.0.bn3.aux_bn.bias", "module.layer1.0.bn3.aux_bn.running_mean", "module.layer1.0.bn3.aux_bn.running_var", "module.layer1.0.bn3.aux_bn.num_batches_tracked", "module.layer1.0.downsample.0.weight", "module.layer1.0.downsample.1.weight", "module.layer1.0.downsample.1.bias", "module.layer1.0.downsample.1.running_mean", "module.layer1.0.downsample.1.running_var", "module.

Many thanks.


Solution

  • You did the right thing by removing the "module." prefix, but the remaining issue is that this resnet50 model was initialized with a custom normalization layer, MixBatchNorm2d, defined in aux_bn.py. You can see the ResNet being initialized here.
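    That layer produces extra keys of the form "*.aux_bn.*" (e.g. "bn1.aux_bn.weight", "layer1.0.downsample.1.aux_bn.running_mean"), which a standard torchvision ResNet does not have.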

    Your code should work once the model is built with the matching normalization layer:

    import torch
    from torchvision.models import resnet50

    # MixBatchNorm2d is defined in aux_bn.py from the code that trained the
    # checkpoint; adjust this import path to wherever that file lives.
    from aux_bn import MixBatchNorm2d

    # map_location='cpu' allows loading a GPU-saved checkpoint without CUDA.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Strip the `module.` prefix added by nn.DataParallel; keys without the
    # prefix are kept unchanged instead of being blindly sliced.
    prefix = 'module.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['state_dict'].items()
    }

    # Build the model with the same normalization layer used during training
    # so the `*.aux_bn.*` entries have matching parameters and buffers.
    model = resnet50(num_classes=1_000, norm_layer=MixBatchNorm2d)
    model.load_state_dict(state_dict)