
How to understand the matmul function when matrix a is three-dimensional and matrix b is two-dimensional?


import numpy as np

a = np.arange(8).reshape(2, 2, 2)
b = np.arange(4).reshape(2, 2)
print(np.matmul(a, b))

The result is:

[[[ 2  3]
  [ 6 11]]

 [[10 19]
  [14 27]]]

I don't understand this result; can someone please explain it?


Solution

  • Short answer: it "broadcasts" the second (2-D) matrix to a 3-D array, treats both arrays as stacks of 2x2 matrices, and multiplies the corresponding submatrices pairwise; the products are then stacked to form the result.
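A minimal sketch of this short answer, reusing the a and b from the question: multiplying each 2x2 matrix in a by b and stacking the products should reproduce np.matmul(a, b).

```python
import numpy as np

a = np.arange(8).reshape(2, 2, 2)  # a stack of two 2x2 matrices
b = np.arange(4).reshape(2, 2)     # a single 2x2 matrix

# Multiply each matrix in the stack by b, then stack the products again.
manual = np.stack([a_i @ b for a_i in a])

print(np.array_equal(manual, np.matmul(a, b)))  # True
```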

    As the documentation on np.matmul [numpy-doc] says:

    numpy.matmul(a, b, out=None)

    Matrix product of two arrays.

    The behavior depends on the arguments in the following way.

    1. If both arguments are 2-D they are multiplied like conventional matrices.
    2. If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
    3. If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
    4. If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
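To make the four rules concrete, here is a small shape check. The arrays m, n, stack and v are arbitrary all-ones arrays chosen only for illustration; only their shapes matter.

```python
import numpy as np

m = np.ones((3, 4))
n = np.ones((4, 5))
stack = np.ones((2, 3, 4))  # a stack of two 3x4 matrices
v = np.ones(4)

# Rule 1: 2-D @ 2-D is an ordinary matrix product
print(np.matmul(m, n).shape)      # (3, 5)
# Rule 2: a stack of 3x4 matrices times one 4x5 matrix gives a stack of 3x5 results
print(np.matmul(stack, n).shape)  # (2, 3, 5)
# Rule 3: a 1-D first argument acts like shape (1, 4); the 1 is dropped afterwards
print(np.matmul(v, n).shape)      # (5,)
# Rule 4: a 1-D second argument acts like shape (4, 1); the 1 is dropped afterwards
print(np.matmul(m, v).shape)      # (3,)
```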

    So here the second rule applies. The second matrix is first broadcast to 3-D as well, which means that we multiply:

    array([[[0, 1],
            [2, 3]],
    
           [[4, 5],
            [6, 7]]])
    

    with:

    array([[[0, 1],
            [2, 3]],
    
           [[0, 1],
            [2, 3]]])
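The broadcast can be made explicit with np.broadcast_to. This sketch builds the repeated stack shown above (named b3 here purely for illustration) and checks that matmul gives the same result with it as with the original 2-D b.

```python
import numpy as np

a = np.arange(8).reshape(2, 2, 2)
b = np.arange(4).reshape(2, 2)

# Repeat b along the leading "stack" axis of a: (2,) + (2, 2) -> (2, 2, 2)
b3 = np.broadcast_to(b, a.shape[:-2] + b.shape)
print(b3.shape)  # (2, 2, 2)

# Multiplying by the explicitly broadcast stack gives the same result.
print(np.array_equal(np.matmul(a, b3), np.matmul(a, b)))  # True
```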
    

    and we treat both as stacks of matrices. So first we multiply:

    array([[0, 1],      array([[0, 1],
           [2, 3]])  x        [2, 3]])
    

    which gives us:

    array([[ 2,  3],
           [ 6, 11]])
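Each entry of this product is a row of the left matrix times a column of the right one. A small explicit loop (written for clarity, not speed) reproduces it; A and B are just local names for a[0] and b.

```python
import numpy as np

A = np.array([[0, 1], [2, 3]])  # a[0]
B = np.array([[0, 1], [2, 3]])  # b

# C[i, j] = sum over k of A[i, k] * B[k, j]
# e.g. C[1, 1] = 2*1 + 3*3 = 11
C = np.zeros((2, 2), dtype=int)
for i in range(2):
    for j in range(2):
        C[i, j] = sum(A[i, k] * B[k, j] for k in range(2))

print(C)
# [[ 2  3]
#  [ 6 11]]
```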
    

    and then the second pair of submatrices:

    array([[4, 5],      array([[0, 1],
           [6, 7]])  x        [2, 3]])
    

    and this gives us:

    array([[10, 19],
           [14, 27]])
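The same check for the second pair, with one entry worked out by hand in a comment; it also confirms that this product is exactly the second slice of np.matmul(a, b).

```python
import numpy as np

a = np.arange(8).reshape(2, 2, 2)
b = np.arange(4).reshape(2, 2)

second = a[1] @ b
# entry (0, 1): row [4, 5] of a[1] times column [1, 3] of b = 4*1 + 5*3 = 19
print(second)
# [[10 19]
#  [14 27]]

print(np.array_equal(second, np.matmul(a, b)[1]))  # True
```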
    

    We then stack these products together and obtain:

    >>> np.matmul(a, b)
    array([[[ 2,  3],
            [ 6, 11]],
    
           [[10, 19],
            [14, 27]]])
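The whole operation can also be written as a single index formula, result[n, i, k] = sum over j of a[n, i, j] * b[j, k]. As a sketch, np.einsum expresses exactly that, and for NumPy arrays the @ operator is shorthand for np.matmul:

```python
import numpy as np

a = np.arange(8).reshape(2, 2, 2)
b = np.arange(4).reshape(2, 2)

# n indexes the stack; j is the summed (contracted) index
via_einsum = np.einsum('nij,jk->nik', a, b)

print(np.array_equal(via_einsum, np.matmul(a, b)))  # True
print(np.array_equal(a @ b, np.matmul(a, b)))       # True
```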
    

    Although the behavior is thus well defined, use this feature with care: there are other sensible definitions of what a "matrix product" of a 3-D array with a 2-D array might look like, and np.matmul does not use them.
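One way this matters in practice: np.matmul always takes the last two axes as the matrix dimensions and everything before them as the stack. In the sketch below, a_last is a hypothetical alternative layout holding the same two matrices along the last axis instead of the first; it has to be moved back to the front before calling matmul, otherwise a different product results.

```python
import numpy as np

a = np.arange(8).reshape(2, 2, 2)
b = np.arange(4).reshape(2, 2)

# Hypothetical layout: the k-th matrix is a_last[:, :, k] instead of a[k]
a_last = np.moveaxis(a, 0, -1)

# Move the stack axis back to the front so matmul sees the intended matrices
fixed = np.matmul(np.moveaxis(a_last, -1, 0), b)
print(np.array_equal(fixed, np.matmul(a, b)))                 # True

# Without moving it, matmul multiplies different 2x2 slices
print(np.array_equal(np.matmul(a_last, b), np.matmul(a, b)))  # False
```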