The model architecture I'm working with has k parallel convolution layers. The individual outputs of these layers are combined linearly using weights predicted by an MLP layer. The model takes a batch of images as input and outputs a batch of images (with a different channel count). A simplified forward() function is as follows:
# the batch input
# input channels = 3
# input dimensions = 32x32
# x.size() = [128, 3, 32, 32]
x
# self.mlp = Flatten -> Linear -> LeakyReLU
# conv_weights.size() = [128, 8]
# batch size = 128
# one scalar weight for each of the 8 convolution layers
conv_weights = self.mlp(x)
# run convolutions
# self.convs = ModuleList([Conv2d -> BatchNorm2d -> LeakyReLU])
conv_outputs = []
for conv in self.convs:
    # size = [128, 32, 32, 32]
    # output channels = 32
    # output dimensions = 32x32
    conv_outputs.append(conv(x))
result = ???
I'm having trouble multiplying the convolution outputs with the scalar factors because of the additional batch dimension. Without the batch dimension, I could simply multiply each convolution layer's output by its scalar.
How do I multiply the output of each convolution layer with the corresponding MLP output?
To linearly combine your n=8 outputs, you can first stack conv_outputs along dim=1. This leaves you with a tensor of shape (b,n,c_out,h,w):
>>> conv_outputs = torch.stack(conv_outputs, dim=1)
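With the shapes from your question (b=128, n=8, c_out=32, h=w=32), a quick sanity check:

>>> conv_outputs.shape
torch.Size([128, 8, 32, 32, 32])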
Then broadcast conv_weights to (b,n,1,1,1) and multiply it with conv_outputs. What matters is that the dimensions of size b and n stay in the first two positions. The last three dimensions of conv_weights are expanded automatically to match conv_outputs when computing the resulting tensor linear_comb:
>>> linear_comb = conv_weights[:,:,None,None,None]*conv_outputs
Now linear_comb has shape (b,n,c_out,h,w); finish by reducing dim=1:
>>> result = linear_comb.sum(1)
The final shape is (b,c_out,h,w).
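If you prefer a one-liner, the same broadcast-multiply-and-sum can be written with torch.einsum, which sums over the n dimension without materializing linear_comb:

>>> result = torch.einsum('bn,bnchw->bchw', conv_weights, conv_outputs)

For reference, here is a minimal self-contained sketch of the whole module under the assumptions stated in your question (3-channel 32x32 inputs, 8 branches, 32 output channels). The class name and the kernel size are illustrative, not taken from your post:

import torch
import torch.nn as nn

class WeightedConvCombine(nn.Module):
    def __init__(self, in_ch=3, out_ch=32, n=8, hw=32):
        super().__init__()
        # predicts one mixing weight per conv branch, per sample
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_ch * hw * hw, n),
            nn.LeakyReLU(),
        )
        # n parallel conv branches; kernel_size=3 with padding=1
        # keeps the 32x32 spatial size (an assumption on my part)
        self.convs = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(),
            )
            for _ in range(n)
        )

    def forward(self, x):
        conv_weights = self.mlp(x)  # (b, n)
        # stack branch outputs on dim=1 -> (b, n, c_out, h, w)
        conv_outputs = torch.stack([conv(x) for conv in self.convs], dim=1)
        # broadcast weights to (b, n, 1, 1, 1), multiply, reduce dim=1
        linear_comb = conv_weights[:, :, None, None, None] * conv_outputs
        return linear_comb.sum(dim=1)  # (b, c_out, h, w)

x = torch.randn(128, 3, 32, 32)
result = WeightedConvCombine()(x)
print(result.shape)  # torch.Size([128, 32, 32, 32])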