
Multiplication of 4D and 2D matrices using PyTorch


I'm trying to do matrix multiplication in PyTorch. I want to multiply a 4D tensor by a 2D matrix. For example, the 4D tensor has size (A,B,C,D) and the 2D matrix has size (D,C).

I want the 2D matrix to be matrix-multiplied with every (C,D) slice of the 4D tensor, once for each index along A and B, so that the final result has size (A,B,C,C).

I'd be grateful if you could point out how to solve this.


Solution

  • What you are looking to compute is:

    >>> for a, b in AxB:
    ...     R[a, b] = X[a, b]@Y
    

    You can start with a loop over A and B and compute each matrix multiplication (C,D)@(D,C), which yields (C,C). Overall you get a tensor of shape (A, B, C, C), i.e. A*B matrices of size CxC.

    You can perform this operation using torch.einsum (see the torch.einsum documentation for details):

    >>> R = torch.einsum('abcd,de->abce', X, Y)
    

    Notice that the output uses the subscript e rather than reusing c: both dimensions have size C, but they are indexed independently.
    This corresponds to the pseudo-code:

    >>> R = torch.zeros(A,B,C,C)
    >>> for a,b,c,d,e in AxBxCxDxC:
    ...     R[a,b,c,e] += X[a,b,c,d]*Y[d,e]
    

    However, in this simple use case you can multiply X and Y directly, because the @ operator (torch.matmul) broadcasts the 2D matrix over the leading batch dimensions A and B. So it comes down to:

    >>> R = X@Y
    

    Simple!
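
To check that the three approaches above agree, here is a minimal runnable sketch. The sizes A, B, C, D and the random inputs are arbitrary values chosen for illustration, not taken from the question, and the names R_loop, R_einsum and R_matmul are just labels for this example.

```python
import torch

# Illustrative sizes (assumed for this sketch, not from the question).
A, B, C, D = 2, 3, 4, 5

torch.manual_seed(0)
X = torch.randn(A, B, C, D)  # 4D tensor of shape (A, B, C, D)
Y = torch.randn(D, C)        # 2D matrix of shape (D, C)

# 1) Explicit loop: one (C, D) @ (D, C) product per (a, b) pair.
R_loop = torch.empty(A, B, C, C)
for a in range(A):
    for b in range(B):
        R_loop[a, b] = X[a, b] @ Y

# 2) einsum: contract over d; the output subscript e is distinct from c.
R_einsum = torch.einsum('abcd,de->abce', X, Y)

# 3) Broadcasting matmul: Y is applied to every (C, D) slice of X.
R_matmul = X @ Y

print(R_matmul.shape)  # torch.Size([2, 3, 4, 4])
print(torch.allclose(R_loop, R_einsum, atol=1e-5))
print(torch.allclose(R_loop, R_matmul, atol=1e-5))
```

The loop is the easiest version to reason about, but X @ Y is the one you would normally use, since it avoids Python-level iteration over A and B.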