python tensorflow keras pytorch

How to build a PyTorch model with 11 outputs by combining three sequential models with output sizes 4, 3 and 4?


I'm trying to move my project from TensorFlow to PyTorch to compare the accuracy. The overall data flow is as follows (the number in brackets refers to the layer output size): [image: data-flow diagram]

Now, in TensorFlow I can use the functional API and build three tf.keras.Sequential branches:

import tensorflow as tf

def single_model(topology):
    # one branch: two Dense layers of the same width, both with ReLU
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(topology, activation="relu"),
        tf.keras.layers.Dense(topology, activation="relu")])
    return model

input_ANN = tf.keras.layers.Input(shape=(21,), name="Input")
# every branch receives the same 21-feature input
model1 = single_model(4)(input_ANN)
model2 = single_model(3)(input_ANN)
model3 = single_model(4)(input_ANN)
# join the branch outputs along the feature axis: 4 + 3 + 4 = 11
concat = tf.keras.layers.Concatenate(axis=-1, name='Concatenate')([model1, model2, model3])
model = tf.keras.Model(inputs=[input_ANN], outputs=[concat])

From this point on I have a single complete model that I can work with like a typical single-path ANN. But how do I do the same in PyTorch? I tried the following (21 is the input size):

import torch.nn as nn

# define one branch
def single_model(topology):
    model = nn.Sequential(
        nn.Linear(21, topology),
        nn.ReLU(),
        nn.Linear(topology, topology),
        nn.ReLU())
    return model

# build three branches and chain all their layers into one nn.Sequential
model1=single_model(3)
model2=single_model(4)
model3=single_model(3)
model=nn.Sequential(*model1.children(),*model2.children(),*model3.children())

But it fails with the error "running_mean should contain 3 elements not 21". The error occurs on the line where I generate the model prediction in the training loop:

for e in range(epochs):
    train_loss = 0.0
    model.train()     # optional when the model has no mode-specific layers
    for data, labels in train_dataloader:
        optimizer.zero_grad()
        target = model(data)  # crashes here
        loss = MARELoss(target, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

Can you please advise me how to connect the sub-models in the correct way?
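A minimal sketch of why the chained version cannot work, using only the code shown above. nn.Sequential feeds each layer's output into the next layer, so the second branch's nn.Linear(21, 4) receives the 3 features produced by the first branch instead of the original 21 inputs. With the layers shown here PyTorch raises a RuntimeError about mismatched shapes; the "running_mean" wording in the reported error normally comes from a batch-norm layer, so the real model presumably contains one that is not shown (an assumption). The batch size of 8 and the random input are made up for illustration.

```python
import torch
from torch import nn

def single_model(topology):
    # same branch definition as in the question
    return nn.Sequential(
        nn.Linear(21, topology), nn.ReLU(),
        nn.Linear(topology, topology), nn.ReLU())

# Flattening the three branches into one Sequential runs them one after another
chained = nn.Sequential(*single_model(3).children(),
                        *single_model(4).children(),
                        *single_model(3).children())

x = torch.randn(8, 21)  # hypothetical batch: 8 samples, 21 features

# The first four layers are the first branch; they reduce 21 features to 3
first_branch_out = chained[:4](x)
print(first_branch_out.shape)

# The fifth layer is nn.Linear(21, 4), but it now receives only 3 features
try:
    chained(x)
    failed = False
except RuntimeError as err:
    failed = True
    print("chained model failed:", err)
```

The branches need to run side by side on the same input, which a plain nn.Sequential cannot express.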


Solution

  • In case someone runs into the same problem, the solution is to subclass torch.nn.Module:

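    import torch
    from torch import nn

    class Model(torch.nn.Module):
      def __init__(self):
        super().__init__()

        # three parallel branches, each fed the same 21-feature input
        # (note: no activations between the Linear layers here, unlike the Keras version)
        self.d1 = nn.Sequential(nn.Linear(21, 4), nn.Linear(4, 4))
        self.d2 = nn.Sequential(nn.Linear(21, 3), nn.Linear(3, 3))
        self.d3 = nn.Sequential(nn.Linear(21, 4), nn.Linear(4, 4))

      def forward(self, x):
        # concatenate the branch outputs along the feature axis: 4 + 3 + 4 = 11
        return torch.cat((self.d1(x), self.d2(x), self.d3(x)), dim=1)
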

    An instance of the model can now be created with model = Model().

    Keep in mind that torch.cat expects tensors, not nn.Sequential modules. It therefore has to be applied to the branch outputs inside forward; it cannot be used outside the class definition to concatenate the sub-models themselves.
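For anyone who wants the PyTorch version to mirror the Keras topology from the question more closely, here is a rough sketch of the same idea with ReLU after each linear layer and the branch widths passed in as a parameter. The names BranchedModel, in_features and branch_sizes are hypothetical and not from the original post, and the random batch of 8 samples is there only to check the output shape.

```python
import torch
from torch import nn

class BranchedModel(nn.Module):
    """Parallel branches on a shared input, concatenated on the feature axis."""

    def __init__(self, in_features=21, branch_sizes=(4, 3, 4)):
        super().__init__()
        # nn.ModuleList registers every branch so its parameters are trained
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features, size), nn.ReLU(),
                nn.Linear(size, size), nn.ReLU())
            for size in branch_sizes])

    def forward(self, x):
        # each branch sees the same x; the outputs are tensors, so torch.cat applies
        return torch.cat([branch(x) for branch in self.branches], dim=-1)

model = BranchedModel()
batch = torch.randn(8, 21)  # hypothetical batch: 8 samples, 21 features
out = model(batch)
print(out.shape)  # 4 + 3 + 4 = 11 features per sample
```

Because the branches live in an nn.ModuleList, model.parameters() includes all of them, so the model can be passed to an optimizer and used in the training loop from the question unchanged.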