I'm trying to do matrix multiplication with PyTorch. What I want to do is multiply a 4D tensor by a 2D matrix. For example, the 4D tensor has size (A,B,C,D), and the 2D matrix has size (D,C).
What I want is for the 2D matrix to be matrix-multiplied repeatedly over the A and B dimensions of the 4D tensor, so that the final result has shape (A,B,C,C).
I'd be grateful if you could point out how to solve this.
What you are looking to compute is:
>>> for a, b in AxB:
... R[a, b] = X[a, b]@Y
You can start with a loop over A and B, computing each matrix multiplication (C,D)@(D,C), which yields (C,C). Overall you get a tensor of shape (A,B,C,C), i.e. A*B matrices of size CxC.
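For concreteness, here is a minimal runnable sketch of that loop, with hypothetical sizes A=2, B=3, C=4, D=5 and random tensors X and Y chosen just for illustration:
>>> import torch
>>> A, B, C, D = 2, 3, 4, 5          # hypothetical sizes for illustration
>>> X = torch.randn(A, B, C, D)
>>> Y = torch.randn(D, C)
>>> R = torch.zeros(A, B, C, C)
>>> for a in range(A):
...     for b in range(B):
...         R[a, b] = X[a, b] @ Y    # (C,D)@(D,C) -> (C,C)
>>> R.shape
torch.Size([2, 3, 4, 4])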
You can perform such an operation using torch.einsum
(read more about it here):
>>> R = torch.einsum('abcd,de->abce', X, Y)
Notice how we used a subscript e different from c. This corresponds to the pseudo-code:
>>> R = torch.zeros(A,B,C,C)
>>> for a,b,c,d,e in AxBxCxDxC:
...     R[a,b,c,e] += X[a,b,c,d]*Y[d,e]
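Continuing the small example above (with the hypothetical X, Y and the loop result R from before), the einsum result has the expected shape and should match the explicit loop up to floating-point tolerance:
>>> R_einsum = torch.einsum('abcd,de->abce', X, Y)
>>> R_einsum.shape
torch.Size([2, 3, 4, 4])
>>> torch.allclose(R, R_einsum)
True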
However, in this simple use case you can get away with multiplying X and Y straight away, because the @ operator (torch.matmul) broadcasts over the leading batch dimensions. So it comes down to:
>>> R = X@Y
Simple!
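With the same hypothetical tensors as above, a quick sanity check that the broadcasted product agrees with the einsum version:
>>> R_matmul = X @ Y
>>> R_matmul.shape
torch.Size([2, 3, 4, 4])
>>> torch.allclose(R_matmul, R_einsum)
True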