
PyTorch broadcast multiplication of 4D and 2D matrix?


How do I broadcast to multiply these two matrices together?

x: torch.Size([10, 120, 180, 30]) # (N, H, W, C)
W: torch.Size([64, 30]) # (Y, C)

The output should be:

(10, 120, 180, 64) == (N, H, W, Y)

Solution

  • I assume x is a batch of inputs and W is the corresponding weight matrix. In that case you can simply do:

    out = x @ W.T  # torch.Size([10, 120, 180, 64])
    

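    This is a matrix multiplication over the last (C) dimension, not an element-wise product: @ treats the leading dimensions of x as batch dimensions and contracts the C axis of x with the C axis of W. A plain element-wise x * W cannot produce this shape. The two shapes are not even broadcastable, because after aligning the trailing dimensions, 180 is matched against 64. What you could do is unsqueeze both tensors so that their shapes become broadcastable, multiply element-wise, and then reduce over the dimension you don't want: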

    x = x.unsqueeze(-1)               # torch.Size([10, 120, 180, 30, 1])
    W = W.T.reshape(1, 1, 1, 30, 64)  # transpose, then add leading dims: torch.Size([1, 1, 1, 30, 64])
    

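    After this unsqueezing you can compute (x * W).sum(dim=3), i.e. sum over the C axis (the fourth dimension, index 3), to get the desired shape.

    For clarity, the two ways are not equally good. Summing gives the same values as the matrix multiplication up to floating-point rounding, but it first materializes a (10, 120, 180, 30, 64) intermediate of about 415 million elements (roughly 1.7 GB in float32). Taking the mean instead of the sum divides every output entry by 30, so that variant is not equivalent to x @ W.T.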
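The snippet below is a small sketch, assuming a recent PyTorch, that checks the approaches above against each other. It uses toy sizes (N, H, Wd, C, Y are illustrative names chosen here; Wd stands for the width dimension so it doesn't clash with the weight W) so the 5-D broadcast intermediate stays tiny. It also shows torch.einsum and torch.nn.functional.linear, which compute the same contraction as x @ W.T.

```python
import torch

# Toy sizes; the question's shapes would be N, H, Wd, C, Y = 10, 120, 180, 30, 64.
N, H, Wd, C, Y = 2, 4, 5, 3, 6
torch.manual_seed(0)
x = torch.randn(N, H, Wd, C)  # (N, H, W, C)
W = torch.randn(Y, C)         # (Y, C)

# 1) Matrix multiplication: leading dims of x act as batch dims,
#    the last axis of x is contracted with the first axis of W.T.
out_matmul = x @ W.T                                # (N, H, W, Y)

# 2) The same contraction written explicitly with einsum.
out_einsum = torch.einsum("nhwc,yc->nhwy", x, W)    # (N, H, W, Y)

# 3) F.linear computes input @ weight.T, with weight shaped (out_features, in_features).
out_linear = torch.nn.functional.linear(x, W)       # (N, H, W, Y)

# 4) Broadcasting: (N, H, W, C, 1) * (1, 1, 1, C, Y) -> (N, H, W, C, Y),
#    then reduce over the C axis (dim 3).
prod = x.unsqueeze(-1) * W.T.reshape(1, 1, 1, C, Y)
out_sum = prod.sum(dim=3)                           # matches the matmul
out_mean = prod.mean(dim=3)                         # equals out_sum / C

print(out_matmul.shape)  # torch.Size([2, 4, 5, 6])
for other in (out_einsum, out_linear, out_sum):
    print(torch.allclose(out_matmul, other, atol=1e-5))      # True
print(torch.allclose(out_mean * C, out_matmul, atol=1e-5))   # True
```

F.linear fits naturally here because W already has the (out_features, in_features) layout that nn.Linear uses for its weight.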