Tags: python, numpy, linear-algebra, numpy-ndarray, numpy-einsum

Dot-product a list of matrices in numpy


Let's generate a 'list' of three 2x2 matrices, stored as one array of shape (3, 2, 2), that I call M1, M2 and M3:

import numpy as np
arr = np.arange(3*2*2).reshape((3, 2, 2))  # M1, M2, M3 = arr[0], arr[1], arr[2]

I want to take the matrix product of all these matrices, in order:

 A = M1 @ M2 @ M3
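For reference, here is a small sketch of the explicit product the question describes. Iterating over the first axis of the stacked array unpacks it into the three 2x2 matrices, and, since the question is tagged numpy-einsum, the same fixed-length chain can also be written with `np.einsum`. The name `A_einsum` is just an illustrative variable, and this only works when the number of matrices is known up front.

```python
import numpy as np

# Stack of three 2x2 matrices, as in the question.
arr = np.arange(3 * 2 * 2).reshape((3, 2, 2))

# Iterating over the first axis unpacks the stack into its three 2x2 matrices.
M1, M2, M3 = arr
A = M1 @ M2 @ M3

# For a fixed number of matrices, einsum spells out the same chain:
# the shared indices j and k are summed over, exactly as in M1 @ M2 @ M3.
A_einsum = np.einsum("ij,jk,kl->il", M1, M2, M3)

print(A)
# [[118 131]
#  [518 575]]
```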

What's the easiest and fastest way to do this? I'm basically looking for something similar to '.sum(axis=0)', but for matrix multiplication.
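The closest analogue to `.sum(axis=0)` for matrix multiplication is a fold over the first axis. The sketch below uses `functools.reduce` with `np.matmul`, a swapped-in technique rather than the answer's method. It also shows a hypothetical helper, `pairwise_matmul` (not a NumPy function), that multiplies neighbouring pairs with one batched `@` per pass. This is an assumption about what "fastest" means here: it reduces Python-level calls to roughly log2(n) for long stacks of square matrices. Associativity keeps the result equal to the left-to-right product.

```python
import functools

import numpy as np

# Stack of three 2x2 matrices, as in the question (M1, M2, M3 = arr).
arr = np.arange(3 * 2 * 2).reshape((3, 2, 2))

# Fold np.matmul over the first axis: reduce iterates over the 2-D slices
# and computes ((arr[0] @ arr[1]) @ arr[2]) @ ...
A_reduce = functools.reduce(np.matmul, arr)


def pairwise_matmul(stack):
    """Hypothetical helper: multiply a non-empty stack of square matrices in order.

    Each pass does one batched matmul over pairs (0,1), (2,3), ..., halving
    the stack, so the number of passes grows like log2(n) rather than n.
    """
    stack = np.asarray(stack)
    while stack.shape[0] > 1:
        n_even = stack.shape[0] - stack.shape[0] % 2
        # Multiply neighbouring pairs in a single batched call.
        paired = stack[0:n_even:2] @ stack[1:n_even:2]
        if n_even < stack.shape[0]:
            # Odd count: carry the last matrix into the next pass, keeping order.
            paired = np.concatenate([paired, stack[-1:]])
        stack = paired
    return stack[0]


A_pairwise = pairwise_matmul(arr)
```

For three tiny matrices either approach is fine. The batched variant only pays off when the stack is long enough that per-call overhead dominates.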


Solution

  • You are probably looking for np.linalg.multi_dot:

    arr = np.arange(3*2*2).reshape((-1, 2, 2))
    np.linalg.multi_dot(arr)
    

    This gives you the matrix product of arr[0], arr[1] and arr[2], the same result as arr[0] @ arr[1] @ arr[2].
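A quick check of the answer's claim. It relies, as the answer does, on `np.linalg.multi_dot` accepting the 3-D array as a sequence of its 2-D slices. It compares the result against the explicit chain.

```python
import numpy as np

arr = np.arange(3 * 2 * 2).reshape((-1, 2, 2))

# multi_dot takes a sequence of 2-D arrays; iterating over the 3-D array
# yields arr[0], arr[1], arr[2], so it can be passed directly.
A = np.linalg.multi_dot(arr)

assert np.array_equal(A, arr[0] @ arr[1] @ arr[2])
print(A)
# [[118 131]
#  [518 575]]
```

multi_dot needs at least two matrices and picks the cheapest parenthesization from the matrix shapes. When every matrix has the same square shape, all orderings cost the same, so its advantage mainly shows up with chains of differently shaped matrices.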