
Numpy matrix multiplication with 2D elements


I have a NumPy ndarray representing a 3x3 matrix that looks like this:

a = np.array([[uu, uv, uw],
              [uv, vv, vw],
              [uw, vw, ww]])

Each component is itself a 2D array of shape (N, M), so a has shape (3, 3, N, M).

How can I compute the matrix product of a with itself in a Pythonic way? Using a@a throws the following error (for N=1218 and M=540):

ValueError: shapes (3,3,1218,540) and (3,3,1218,540) not aligned: 540 (dim 3) != 1218 (dim 2)

I want to perform this operation as if the elements of a were just scalars. In that case a@a would be a simple 3x3 matrix multiplication and would not raise a shape error.

Thanks.
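To pin down what "as if the elements were scalars" means: entry (i, m) of the product is the sum over j of a[i, j] * a[j, m], where * is the elementwise product of the (N, M) arrays. The following is a minimal reference sketch using plain loops. The function name blockwise_matmul_loop is hypothetical, and the small random array stands in for the real data. It is slow, but it is handy for checking a faster vectorized version.

```python
import numpy as np

def blockwise_matmul_loop(a, b):
    """Multiply two (3, 3, N, M) arrays as 3x3 matrices whose entries are (N, M) arrays.

    Entry (i, m) of the result is sum_j a[i, j] * b[j, m], where * is the
    elementwise product of (N, M) arrays, exactly as if each entry were a scalar.
    """
    n_rows, n_inner = a.shape[:2]
    n_cols = b.shape[1]
    out = np.zeros((n_rows, n_cols) + a.shape[2:], dtype=np.result_type(a, b))
    for i in range(n_rows):
        for m in range(n_cols):
            for j in range(n_inner):
                out[i, m] += a[i, j] * b[j, m]  # elementwise product of (N, M) arrays
    return out

# Hypothetical stand-in data with N=4, M=5.
rng = np.random.default_rng(0)
a = rng.random((3, 3, 4, 5))
c = blockwise_matmul_loop(a, a)
print(c.shape)  # (3, 3, 4, 5)
# Each (k, l) slice of the result is an ordinary 3x3 matrix product.
print(np.allclose(c[:, :, 2, 3], a[:, :, 2, 3] @ a[:, :, 2, 3]))  # True
```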


Solution

  • Assuming you want an ordinary 3x3 matrix multiplication at each (k, l) position along the last two axes, we can use np.einsum, summing over the shared index j while carrying k and l through unchanged -

    np.einsum('ijkl,jmkl->imkl',a,a)
    

    Sample run for verification -

    In [43]: np.random.seed(0)
    
    In [44]: a = np.random.rand(3,3,4,5)
    
    In [45]: a[:,:,0,0].dot(a[:,:,0,0])
    Out[45]: 
    array([[0.71750146, 1.17057872, 1.11135764],
           [0.62938365, 0.86437796, 0.74541383],
           [1.04636618, 1.62011127, 1.35483565]])
    
    In [46]: np.einsum('ijkl,jmkl->imkl',a,a)[:,:,0,0]
    Out[46]: 
    array([[0.71750146, 1.17057872, 1.11135764],
           [0.62938365, 0.86437796, 0.74541383],
           [1.04636618, 1.62011127, 1.35483565]])
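An alternative that does use @ is sketched below, assuming NumPy's documented matmul behaviour: np.matmul (the @ operator) treats the last two axes as the matrix axes and broadcasts over the leading ones. With a (3, 3, N, M) array it therefore tries to multiply (N, M) by (N, M) matrices, which is why the shapes do not line up. Moving the two size-3 axes to the end with np.moveaxis, multiplying, and moving them back gives the same result as the einsum call. The helper name blockwise_matmul is hypothetical.

```python
import numpy as np

def blockwise_matmul(a, b):
    """Matrix-multiply the leading 3x3 axes of two (3, 3, N, M) arrays.

    np.matmul treats the LAST two axes as the matrix axes and broadcasts over
    the rest, so the 3x3 axes are moved to the end, multiplied, and moved back.
    """
    a_last = np.moveaxis(a, (0, 1), (-2, -1))   # (N, M, 3, 3)
    b_last = np.moveaxis(b, (0, 1), (-2, -1))   # (N, M, 3, 3)
    prod = a_last @ b_last                      # (N, M, 3, 3), one 3x3 product per (k, l)
    return np.moveaxis(prod, (-2, -1), (0, 1))  # back to (3, 3, N, M)

np.random.seed(0)
a = np.random.rand(3, 3, 4, 5)
c = blockwise_matmul(a, a)
print(c.shape)  # (3, 3, 4, 5)
print(np.allclose(c, np.einsum('ijkl,jmkl->imkl', a, a)))  # True
```

The returned array is a view with non-contiguous strides; np.ascontiguousarray can copy it if later code needs C-contiguous memory. Which of the two approaches is faster for N=1218 and M=540 likely depends on the NumPy build, so timing both on the real data is worthwhile.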