Tags: python, vectorization, pytorch, batch-processing, matrix-multiplication

Batch matrix multiplication in PyTorch - confused by the handling of the output's dimensions


I have two arrays:

A
B

Array A contains a batch of RGB images, with shape:

[batch, Width, Height, 3]

whereas Array B contains coefficients needed for a "transformation-like" operation on images, with shape:

[batch, 4, 4, 3]

To put it simply, the operation for a single image is a multiplication that outputs an environment map (normalMap * Coefficients).

The output I want should have shape:

[batch, Width, Height, 3]

I tried using torch.bmm but failed. Is this possible somehow?


Solution

  • I think you need to take into account that PyTorch works with the

    BxCxHxW : number of mini-batches, channels, height, width

    format, and also use matmul, since bmm only works with tensors of ndim/dim/rank = 3.

    You can probably find this online, but just in case:

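    import torch

    # Two batches of 10 items with 3 channels each; the inner dimensions (10) match.
    batch1 = torch.randn(10, 3, 20, 10)
    batch2 = torch.randn(10, 3, 10, 30)
    # matmul multiplies the last two dimensions and treats the leading ones as batch dimensions.
    res = torch.matmul(batch1, batch2)
    res.size()  # torch.Size([10, 3, 20, 30])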
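
To connect this back to the channels-last layout in the question, here is a minimal sketch of the two points above: moving the channel axis with permute, and the difference between bmm and matmul on 4-D input. The sizes and the matrix `m` are hypothetical, chosen only so the inner dimensions line up; this is not the actual coefficient layout from the question, which would need its own reshaping.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch, height, width = 2, 5, 7

# Channels-last images, as in the question: [batch, Width, Height, 3].
a = torch.randn(batch, width, height, 3)

# Reorder the axes into B x C x H x W: [batch, 3, Height, Width].
a_bchw = a.permute(0, 3, 2, 1)

# Hypothetical per-image, per-channel matrices whose row count matches Width.
m = torch.randn(batch, 3, width, 4)

# torch.bmm only accepts 3-D tensors, so 4-D inputs are rejected.
try:
    torch.bmm(a_bchw, m)
    bmm_failed = False
except RuntimeError:
    bmm_failed = True

# torch.matmul treats all leading dimensions as batch dimensions.
res = torch.matmul(a_bchw, m)  # [batch, 3, Height, 4]

# The same product with bmm, after folding batch and channel into one axis.
res_bmm = torch.bmm(
    a_bchw.reshape(batch * 3, height, width),
    m.reshape(batch * 3, width, 4),
).reshape(batch, 3, height, 4)

print(bmm_failed)                               # True
print(res.shape)                                # torch.Size([2, 3, 5, 4])
print(torch.allclose(res, res_bmm, atol=1e-5))  # True
```

If the result has to go back to a channels-last layout, another permute on the output does that; whether this particular product matches the intended "normalMap * Coefficients" operation depends on how the coefficients are meant to be applied.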