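I'm trying to move my project from TensorFlow to PyTorch to compare the accuracy. The overall data flow is as follows (the number in brackets refers to the layer output size):
Input (21) → three parallel branches of two layers each: [(4) → (4)], [(3) → (3)], [(4) → (4)] → Concatenate (11)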
In TensorFlow, I can use the functional API and three `tf.keras.Sequential` sub-models:
```python
import tensorflow as tf

def single_model(topology):
    # Two Dense layers of the same width, both with ReLU activation
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(topology, activation='relu'),
        tf.keras.layers.Dense(topology, activation='relu')])
    return model

input_ANN = tf.keras.layers.Input(shape=(21,), name="Input")
model1 = single_model(4)(input_ANN)
model2 = single_model(3)(input_ANN)
model3 = single_model(4)(input_ANN)
concat = tf.keras.layers.Concatenate(axis=-1, name='Concatenate')([model1, model2, model3])
model = tf.keras.Model(inputs=[input_ANN], outputs=[concat])
```
From this point on I have a complete single model that I can work with like a typical single-path ANN. But how can I do the same in PyTorch? I tried the following (21 is the input size):
```python
import torch.nn as nn

# define one sub-model (branch)
def single_model(topology):
    model = nn.Sequential(
        nn.Linear(21, topology),
        nn.ReLU(),
        nn.Linear(topology, topology),
        nn.ReLU())
    return model

# build the three sub-models and combine them
model1 = single_model(3)
model2 = single_model(4)
model3 = single_model(3)
model = nn.Sequential(*model1.children(), *model2.children(), *model3.children())
```
But it fails with the error `running_mean should contain 3 elements not 21`. The error is raised on the line where I try to get the model prediction in the training loop:
```python
for e in range(epochs):
    train_loss = 0.0
    model.train()  # Optional when not using Model Specific layer
    for data, labels in train_dataloader:
        optimizer.zero_grad()
        target = model(data)  # Here it crashes
        loss = MARELoss(target, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
```
Can you please advise me how to connect the sub-models in the correct way?
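For reference, a minimal sketch of why the attempt above fails: unpacking `.children()` of the three sub-models into one `nn.Sequential` stacks all layers in series rather than side by side, so the second branch's `Linear(21, 4)` receives the 3-feature output of the first branch. The sketch assumes the layers exactly as posted and a dummy batch of 8 samples. Since the posted code contains no batch-norm layers, PyTorch reports a shape-mismatch `RuntimeError` here; the `running_mean` message is the kind of error produced by batch-normalization layers such as `nn.BatchNorm1d`, so the real project code probably differs slightly from what is shown.

```python
import torch
import torch.nn as nn

def single_model(topology):
    # Two Linear+ReLU layers, as in the question; 21 is the input size
    return nn.Sequential(
        nn.Linear(21, topology), nn.ReLU(),
        nn.Linear(topology, topology), nn.ReLU(),
    )

# Unpacking .children() into one nn.Sequential puts all 12 layers in SERIES:
# Linear(21,3) -> ReLU -> Linear(3,3) -> ReLU -> Linear(21,4) -> ...
chained = nn.Sequential(*single_model(3).children(),
                        *single_model(4).children(),
                        *single_model(3).children())
print(chained)

x = torch.randn(8, 21)  # dummy batch: 8 samples, 21 features
try:
    chained(x)
    failed = False
except RuntimeError as err:
    # The 5th layer, Linear(21, 4), gets an [8, 3] tensor instead of [8, 21]
    failed = True
    print("RuntimeError:", err)
```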
In case someone runs into the same problem: the solution is to subclass `torch.nn.Module` and call the branches in parallel inside `forward`:
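```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Three parallel branches, each Linear -> ReLU -> Linear -> ReLU,
        # matching the Keras Dense(..., activation='relu') layers
        self.d1 = nn.Sequential(nn.Linear(21, 4), nn.ReLU(), nn.Linear(4, 4), nn.ReLU())
        self.d2 = nn.Sequential(nn.Linear(21, 3), nn.ReLU(), nn.Linear(3, 3), nn.ReLU())
        self.d3 = nn.Sequential(nn.Linear(21, 4), nn.ReLU(), nn.Linear(4, 4), nn.ReLU())

    def forward(self, x):
        # Every branch sees the same input; the outputs are joined along
        # the feature dimension: 4 + 3 + 4 = 11 features
        return torch.cat((self.d1(x), self.d2(x), self.d3(x)), dim=1)
```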
Now an instance of the model can be created with `model = Model()`.
Keep in mind that `torch.cat` concatenates tensors, not modules: you cannot pass the `nn.Sequential` sub-models to it directly. It has to be applied to their outputs, which is why it is called inside `forward`.
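If the number of branches varies, the same idea can be generalized. Below is a minimal sketch, assuming a hypothetical helper class `ParallelBranches` (not a built-in PyTorch module) that stores the branches in an `nn.ModuleList`, so their parameters are registered, and concatenates their outputs in `forward`. It also counts the parameters as a sanity check: with branches of width 4, 3 and 4 on 21 inputs, the total should match what Keras `model.summary()` reports for the TensorFlow model in the question.

```python
import torch
import torch.nn as nn

def single_model(topology, in_features=21):
    # One branch: two Linear layers with ReLU, mirroring the Keras Dense(relu) pair
    return nn.Sequential(
        nn.Linear(in_features, topology), nn.ReLU(),
        nn.Linear(topology, topology), nn.ReLU(),
    )

class ParallelBranches(nn.Module):
    """Hypothetical helper: feeds the same input to every branch and
    concatenates the branch outputs along the last (feature) dimension."""

    def __init__(self, *branches):
        super().__init__()
        # nn.ModuleList registers each branch, so model.parameters() sees them
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        # torch.cat receives tensors (the branch outputs), not the modules
        return torch.cat([branch(x) for branch in self.branches], dim=-1)

model = ParallelBranches(single_model(4), single_model(3), single_model(4))
out = model(torch.randn(8, 21))  # dummy batch of 8 samples
n_params = sum(p.numel() for p in model.parameters())
print(out.shape)   # torch.Size([8, 11])
print(n_params)    # 294 = 108 + 78 + 108
```

Using `dim=-1` mirrors `Concatenate(axis=-1)` in the Keras version and also works for inputs with extra leading dimensions.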