
How to load pretrained weights in modified vgg19 network in pytorch?


I am trying to load the VGG19 network with a modified number of input channels: 4 in my case. I am also replacing the classifier with my own, and I have removed the adaptive average pooling layer from the network. How should I load the pre-trained weights into this modified model in PyTorch?

Say the modified model is stored in the variable myModel. How can I load the pre-trained VGG19 weights into it?


Solution

  • Option 1. If you want to use the original pre-trained weights that come with VGG19, load them into the original network first and modify it afterwards. The pre-trained weights were trained for the original architecture, so its first convolution expects 3 input channels. You can then add an extra layer at the beginning as the input layer (mapping your 4 channels to 3), replace the classifier, and remove the adaptive average pooling layer.

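    Option 2. Load the weights for every layer except the input layer, since its dimensions do not match: its weight has shape (64, 4, 3, 3) in your model but (64, 3, 3, 3) in VGG19. The same applies to any classifier layers you replaced with different sizes.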

    In code, it would look something like this:

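      from torchvision import models

      # corresp_name is a dict mapping each layer name in the original
      # VGG19 state dict to the corresponding layer name in myModel
      p_dict = models.vgg19(weights=models.VGG19_Weights.DEFAULT).state_dict()  # pretrained VGG19
      s_dict = myModel.state_dict()  # current weights of myModel
      for name, param in p_dict.items():
          if name not in corresp_name:
              continue
          target = corresp_name[name]
          # skip tensors whose shapes differ, e.g. the 4-channel input conv
          if s_dict[target].shape == param.shape:
              s_dict[target] = param
      myModel.load_state_dict(s_dict)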
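If `myModel` keeps the original VGG19 layer names (for example, because it was built by modifying `torchvision.models.vgg19` in place), the `corresp_name` mapping is unnecessary. In that case Option 2 could be written as the hypothetical helper below, which copies every pretrained tensor whose name and shape match and loads them with `strict=False`.

```python
import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, pretrained_state: dict) -> list:
    """Copy every pretrained tensor whose name and shape match `model`.

    Returns the names of the pretrained tensors that were skipped.
    """
    own_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in pretrained_state.items()
        if name in own_state and own_state[name].shape == tensor.shape
    }
    # strict=False leaves parameters that are absent from `compatible`
    # at their current (randomly initialised) values instead of raising.
    model.load_state_dict(compatible, strict=False)
    return [name for name in pretrained_state if name not in compatible]
```

For example, `load_matching_weights(myModel, models.vgg19(weights=models.VGG19_Weights.DEFAULT).state_dict())` returns the skipped names, typically `features.0.weight` plus the entries of any replaced classifier layers. Note that `features.0.bias` still has shape (64,) and is therefore copied.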