
How to combine two trained models using PyTorch?


I'm working on two models that use different types of data but are related. I want to create a combined model that takes one instance of each data type, runs each one through its own pre-trained model independently, and then passes the combined output of the two models through a few feed-forward layers on top. So far I've learned that I can change forward() to accept both inputs, so I copied the structures of my individual models into the combined one, ran each input through its own layers in forward(), and then merged the outputs as described. What I'm having trouble with is figuring out how to actually implement this.


Solution

  • As I understand your question, you can keep the two models as separate modules and write a third model that wraps both of them and combines their outputs in its forward() method. Then, in __main__, you load each pre-trained model's weights with load_state_dict, for example:

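    The first model:

    # Imports used by all of the snippets below
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FirstM(nn.Module):
        def __init__(self):
            super(FirstM, self).__init__()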
            self.fc1 = nn.Linear(20, 2)
            
        def forward(self, x):
            x = self.fc1(x)
            return x
    

    The second model:

    class SecondM(nn.Module):
        def __init__(self):
            super(SecondM, self).__init__()
            self.fc1 = nn.Linear(20, 2)
            
        def forward(self, x):
            x = self.fc1(x)
            return x
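The loading step in this answer assumes each model's weights were saved earlier as a state_dict. A minimal, self-contained sketch of that save/load round trip follows; the temporary directory and the file names first_model.pt and second_model.pt are hypothetical stand-ins for real checkpoint paths, and freshly initialised models stand in for the trained ones.

```python
import os
import tempfile

import torch
import torch.nn as nn


class FirstM(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.fc1(x)


class SecondM(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.fc1(x)


# Hypothetical checkpoint locations; in practice these are wherever
# each model was saved after its own training run.
ckpt_dir = tempfile.mkdtemp()
path_a = os.path.join(ckpt_dir, "first_model.pt")
path_b = os.path.join(ckpt_dir, "second_model.pt")

# Stand-ins for the two separately trained models: save only their weights.
trained_a, trained_b = FirstM(), SecondM()
torch.save(trained_a.state_dict(), path_a)
torch.save(trained_b.state_dict(), path_b)

# Later (e.g. in another script): rebuild the architectures, then load weights.
model_a, model_b = FirstM(), SecondM()
model_a.load_state_dict(torch.load(path_a))
model_b.load_state_dict(torch.load(path_b))

# The reloaded weights match the saved ones exactly.
same_a = torch.equal(model_a.fc1.weight, trained_a.fc1.weight)
same_b = torch.equal(model_b.fc1.weight, trained_b.fc1.weight)
print(same_a, same_b)  # True True
```

Saving the state_dict rather than the whole module keeps the checkpoint independent of the class definition's location, which is why the architecture has to be rebuilt before load_state_dict is called.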
    
    

    Next, define a model that contains both models and merges their outputs:

    class Combined_model(nn.Module):
        def __init__(self, modelA, modelB):
            super(Combined_model, self).__init__()
            self.modelA = modelA
            self.modelB = modelB
            # Input size 4 = 2 outputs from modelA + 2 outputs from modelB
            self.classifier = nn.Linear(4, 2)
            
        def forward(self, x1, x2):
            x1 = self.modelA(x1)
            x2 = self.modelB(x2)
            x = torch.cat((x1, x2), dim=1)
            x = self.classifier(F.relu(x))
            return x
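The question asks for "a few feed-forward layers" on top, and pre-trained branches are often kept fixed while only the new head is trained. One way to do that, sketched here as an assumption rather than as the answer's method, is to set requires_grad = False on both branches and use an nn.Sequential head. FrozenCombinedModel, hidden=16 and the layer sizes are illustrative choices, not from the original.

```python
import torch
import torch.nn as nn


class FirstM(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.fc1(x)


class SecondM(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.fc1(x)


class FrozenCombinedModel(nn.Module):
    """Hypothetical variant: frozen pre-trained branches + a small MLP head."""

    def __init__(self, model_a, model_b, hidden=16, num_classes=2):
        super().__init__()
        self.model_a = model_a
        self.model_b = model_b
        # Freeze both pre-trained branches so only the head gets gradients.
        for p in list(self.model_a.parameters()) + list(self.model_b.parameters()):
            p.requires_grad = False
        # Each branch outputs 2 features, so the concatenation has 2 + 2 = 4.
        self.head = nn.Sequential(
            nn.Linear(4, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x1, x2):
        f1 = self.model_a(x1)
        f2 = self.model_b(x2)
        return self.head(torch.cat((f1, f2), dim=1))


model = FrozenCombinedModel(FirstM(), SecondM())
x1, x2 = torch.randn(3, 20), torch.randn(3, 20)
out = model(x1, x2)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(out.shape)   # torch.Size([3, 2])
print(trainable)   # ['head.0.weight', 'head.0.bias', 'head.2.weight', 'head.2.bias']
```

Setting requires_grad only stops gradient updates. If the branches contain dropout or batch normalization, you would probably also put them in eval mode (model.model_a.eval(), model.model_b.eval()) after any call to model.train(), so they behave as they did after pre-training.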
    
    

    Then, outside the class definitions (for example in __main__), build the models, load their weights and run a forward pass:

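    # Create the two models and load their pre-trained weights.
    # PATH_A / PATH_B are placeholders for wherever each model's
    # state_dict was saved, e.g. with torch.save(model.state_dict(), path).
    modelA = FirstM()
    modelB = SecondM()
    modelA.load_state_dict(torch.load(PATH_A))
    modelB.load_state_dict(torch.load(PATH_B))

    model = Combined_model(modelA, modelB)
    # Both sub-models expect 20 input features (nn.Linear(20, 2)).
    x1, x2 = torch.randn(1, 20), torch.randn(1, 20)
    output = model(x1, x2)  # shape: (1, 2)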
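To train the combined model, one option is to fine-tune everything end to end while giving the pre-trained branches a smaller learning rate than the new classifier, using optimizer parameter groups. The sketch below assumes dummy data (8 paired samples with binary labels) and arbitrary learning rates; in practice you would load the pre-trained weights before building the optimizer and iterate over a DataLoader.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FirstM(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.fc1(x)


class SecondM(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.fc1(x)


class Combined_model(nn.Module):
    def __init__(self, modelA, modelB):
        super().__init__()
        self.modelA = modelA
        self.modelB = modelB
        self.classifier = nn.Linear(4, 2)

    def forward(self, x1, x2):
        x = torch.cat((self.modelA(x1), self.modelB(x2)), dim=1)
        return self.classifier(F.relu(x))


torch.manual_seed(0)
# Pre-trained weights would be loaded into these two before wrapping them.
model = Combined_model(FirstM(), SecondM())

# Two parameter groups: a small learning rate for the pre-trained branches,
# a larger one for the freshly initialised classifier (values are arbitrary).
branch_params = list(model.modelA.parameters()) + list(model.modelB.parameters())
optimizer = torch.optim.SGD(
    [
        {"params": branch_params, "lr": 1e-3},
        {"params": model.classifier.parameters(), "lr": 1e-2},
    ],
    momentum=0.9,
)
loss_fn = nn.CrossEntropyLoss()

# Dummy paired batch: 8 samples per input, 20 features each, binary labels.
x1, x2 = torch.randn(8, 20), torch.randn(8, 20)
y = torch.randint(0, 2, (8,))

model.train()
for step in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(x1, x2), y)
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss {loss.item():.4f}")
```

Because both inputs go through a single forward call, one backward pass propagates gradients into the classifier and both branches; the parameter groups only control how far each part moves per step.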