Tags: pytorch, fastai

fastai v2 to PyTorch for TorchServe


I usually use fastai (v2 or v1) for fast prototyping. Now I'd like to deploy one of my models, trained with fastai, to TorchServe.

Let's say that we have a simple model like this one:

    learn = cnn_learner(data, 
                    models.resnet34, 
                    metrics=[accuracy, error_rate, score])
    # after the training 
    torch.save(learn.model.state_dict(), "./test1.pth")
    state = torch.load("./test1.pth")
    model_torch_rep = models.resnet34()
    model_torch_rep.load_state_dict(state)

I've tried many different things, but I always get the same result:

    RuntimeError Traceback (most recent call last)
    <ipython-input-284-e4dbdce23d43> in <module>
    ----> 1 model_torch_rep.load_state_dict(state);
    
    /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
        837         if len(error_msgs) > 0:
        838             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    --> 839                                self.__class__.__name__, "\n\t".join(error_msgs)))
        840         return _IncompatibleKeys(missing_keys, unexpected_keys)
        841 
    
    RuntimeError: Error(s) in loading state_dict for ResNet:
        Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight"

This happens with fastai 1.0.6 as well as fastai 2.3.1 + PyTorch 1.8.1 ...
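
A quick way to see exactly where the mismatch is, before calling `load_state_dict`, is to diff the checkpoint's keys against the keys the model expects. A minimal sketch with a few mock key names standing in for `state.keys()` and `model_torch_rep.state_dict().keys()`:

```python
# Mock keys standing in for the real checkpoint and model state_dicts.
ckpt_keys = {"module.conv1.weight", "module.bn1.weight", "module.bn1.bias"}
model_keys = {"conv1.weight", "bn1.weight", "bn1.bias"}

missing = model_keys - ckpt_keys      # keys the model wants but the checkpoint lacks
unexpected = ckpt_keys - model_keys   # keys the checkpoint has but the model doesn't

# Every model key shows up as "missing" because the checkpoint keys
# carry an extra prefix -- which is exactly what the traceback reports.
print(sorted(missing))
print(sorted(unexpected))
```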


Solution

  • Just figured this out.

    For some reason, saving the state_dict this way prepends the string "module." to every key in the loaded state_dict (I assume because the model isn't being saved through fastai's Learner class).

    Simply strip the "module." prefix from each key in the state_dict and you're all good.

        from collections import OrderedDict

        learn = cnn_learner(data, 
                            models.resnet34, 
                            metrics=[accuracy, error_rate, score])
    
        # after the training 
        torch.save(learn.model.state_dict(), "./test1.pth")
        state = torch.load("./test1.pth")
    
        # fix dict keys: strip the "module." prefix from each key
        new_state = OrderedDict([(k.partition('module.')[2], v) for k, v in state.items()])
    
        model_torch_rep = models.resnet34()
        model_torch_rep.load_state_dict(new_state)
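
Note that `k.partition('module.')[2]` returns an empty string for any key that does *not* contain the prefix, which would silently corrupt such keys. A slightly more defensive version only strips the prefix when it is actually present; sketched here on mock keys (integers in place of tensors) so it runs without torch:

```python
from collections import OrderedDict

def strip_prefix(state_dict, prefix="module."):
    """Remove `prefix` from keys that start with it; leave other keys intact."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )

# Mock state_dict: real values would be tensors loaded from the checkpoint.
state = OrderedDict([("module.conv1.weight", 1), ("fc.bias", 2)])
print(list(strip_prefix(state).keys()))  # ['conv1.weight', 'fc.bias']
```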