
Batchwise convolutions – a different weight for each batch element


Given an input x of shape (32, 4, 16): the batch size is 32, there are 4 input channels, and each channel has a dimensionality of 16. A convolution (conv1d) shall be applied to this input. In the following, the number of output channels is assumed to be 8, so the weight has the shape (8, 4, 1).

Overall, we get the following code:

import torch
import torch.nn.functional as F

batch_size = 32
input_channels = 4
output_channels = 8
dim = 16

x = torch.zeros(size=(batch_size, input_channels, dim))
weight = torch.rand(size=(output_channels, input_channels, 1))
y = F.conv1d(x, weight)  # y.shape == (32, 8, 16)

Now to my question: instead of applying the same weight to every batch element, I want to apply a different weight to each batch element. In other words, the weight must have the shape (32, 8, 4, 1). How can I implement this batchwise convolution operation?

If weight is assigned torch.rand(size=(batch_size, output_channels, input_channels, 1)), the code does not work, since F.conv1d expects a 3-dimensional weight tensor. Of course there is a simple solution based on a for-loop, but I am looking for a solution without one.
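
For reference, the for-loop version I would like to avoid looks roughly like this (a minimal sketch; it calls F.conv1d once per batch element with that element's own weight):

import torch
import torch.nn.functional as F

x = torch.rand(32, 4, 16)
weight = torch.rand(32, 8, 4, 1)

# One conv1d call per batch element, each with its own (8, 4, 1) weight.
y = torch.stack(
    [F.conv1d(x[b:b + 1], weight[b]) for b in range(x.shape[0])]
).squeeze(1)
# y.shape == (32, 8, 16)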


Solution

  • Conv1d with a kernel of size 1 is the same as a broadcasted multiply followed by a sum-reduction over the input channels. Specifically:

    # x.shape (32, 4, 16)
    # weight.shape (8, 4, 1)
    
    # (32, 1, 4, 16) * (1, 8, 4, 1) -> (32, 8, 4, 16), then sum over the channel dim
    y = (x.unsqueeze(1) * weight.unsqueeze(0)).sum(dim=2)
    # equivalent to y = F.conv1d(x, weight)
    # equivalent to y = torch.einsum('bin,oin->bon', x, weight)
    
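    A quick numerical check of these equivalences (a minimal sketch with random inputs; the einsum line relies on einsum broadcasting the size-1 kernel dimension, as asserted above):

    import torch
    import torch.nn.functional as F

    x = torch.rand(32, 4, 16)
    weight = torch.rand(8, 4, 1)

    y_conv = F.conv1d(x, weight)
    y_mul = (x.unsqueeze(1) * weight.unsqueeze(0)).sum(dim=2)
    y_ein = torch.einsum('bin,oin->bon', x, weight)

    print(torch.allclose(y_conv, y_mul, atol=1e-6))  # True
    print(torch.allclose(y_conv, y_ein, atol=1e-6))  # True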

    So if weight instead has shape (32, 8, 4, 1), we simply don't need to broadcast over the first (batch) dimension:

    # x.shape (32, 4, 16)
    # weight.shape (32, 8, 4, 1)
    
    # (32, 1, 4, 16) * (32, 8, 4, 1) -> (32, 8, 4, 16), then sum over the channel dim
    y = (x.unsqueeze(1) * weight).sum(dim=2)
    # equivalent to y = torch.einsum('bin,boin->bon', x, weight)
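
    To verify, we can compare against the for-loop baseline (a minimal sketch; one F.conv1d call per batch element):

    import torch
    import torch.nn.functional as F

    x = torch.rand(32, 4, 16)
    weight = torch.rand(32, 8, 4, 1)

    # Vectorised batchwise convolution.
    y = (x.unsqueeze(1) * weight).sum(dim=2)

    # Reference: one conv1d per batch element, each with its own weight.
    y_loop = torch.stack(
        [F.conv1d(x[b:b + 1], weight[b]) for b in range(x.shape[0])]
    ).squeeze(1)

    print(torch.allclose(y, y_loop, atol=1e-6))  # True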