Tags: python, pytorch, matrix-multiplication

Batch multiplication of layer outputs and scalar factors in PyTorch


The model architecture I'm working with has k parallel convolution layers. The individual outputs of these layers are combined linearly, using weights predicted by an MLP. The model takes a batch of images as input and outputs a batch of images (with a different channel count).

A simplified forward() function is as follows:

# the batch input
#   input channels = 3
#   input dimensions = 32x32
# x.size() = [128, 3, 32, 32]
x

# self.mlp = Flatten -> Linear -> LeakyReLU
# conv_weights.size() = [128, 8]
#   batch size = 128
#   one scalar weight for each of the 8 convolution layers
conv_weights = self.mlp(x)

# run convolutions
# self.convs = ModuleList([Conv2d -> BatchNorm2d -> LeakyReLU])
conv_outputs = []
for conv in self.convs:
    # size = [128, 32, 32, 32]
    #   output channels = 32
    #   output dimensions = 32x32
    conv_outputs.append(conv(x))

result = ???

I'm having trouble multiplying the convolution outputs by the scalar factors because of the additional batch dimension. Without the batch dimension, I could simply multiply each convolution output by its scalar factor.
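
For completeness, here is a self-contained sketch of the constructor (the class name, kernel size, and padding are placeholders; only the shapes match the snippet above):

import torch.nn as nn

class Model(nn.Module):
    def __init__(self, in_ch=3, out_ch=32, n_convs=8, img_size=32):
        super().__init__()
        # Flatten -> Linear -> LeakyReLU: predicts one scalar weight
        # per convolution layer, i.e. output shape [batch, n_convs]
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_ch * img_size * img_size, n_convs),
            nn.LeakyReLU(),
        )
        # k = n_convs parallel Conv2d -> BatchNorm2d -> LeakyReLU branches,
        # each mapping [batch, 3, 32, 32] to [batch, 32, 32, 32]
        self.convs = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(),
            )
            for _ in range(n_convs)
        )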

How do I multiply the output of each convolution layer by its corresponding weight from the MLP output?


Solution

  • To linearly combine your n=8 outputs, you can first stack conv_outputs on dim=1. This leaves you with a tensor of shape (b,n,c_out,h,w):

    >>> conv_outputs = torch.stack(conv_outputs, dim=1)
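
    With the sizes from the question (b=128, n=8, c_out=32, h=w=32), a quick shape check confirms the stacked layout:

    >>> conv_outputs.shape
    torch.Size([128, 8, 32, 32, 32])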
    

    Then broadcast conv_weights to (b,n,1,1,1) and multiply it with conv_outputs. What matters is that the dimensions of size b and n stay in the first two positions; the three trailing singleton dimensions of conv_weights are expanded automatically to match conv_outputs when computing the resulting tensor linear_comb:

    >>> linear_comb = conv_weights[:,:,None,None,None]*conv_outputs
    

    Now linear_comb has shape (b,n,c_out,h,w); finish by reducing dim=1:

    >>> result = linear_comb.sum(1)
    

    Final shape is (b,c_out,h,w).
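
    Equivalently, the broadcast multiplication and the reduction can be fused into a single torch.einsum call, with the shared index n summed out:

    >>> result = torch.einsum('bn,bnchw->bchw', conv_weights, conv_outputs)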