I have a numpy ndarray a, a 3x3 matrix that looks like this:
a = np.array([[uu, uv, uw],
              [uv, vv, vw],
              [uw, vw, ww]])
Each component is itself a 2D array of shape (N, M), so a has shape (3, 3, N, M).
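A minimal sketch of how such an array can be built, assuming small hypothetical sizes (N=4, M=5) and random data standing in for the real uu, uv, ... fields: np.array stacks the nested list of equally shaped 2D arrays into a single 4-D array.

```python
import numpy as np

# Hypothetical small sizes for illustration; the question uses N=1218, M=540.
N, M = 4, 5
rng = np.random.default_rng(0)

# Six independent (N, M) components of the symmetric 3x3 "matrix of fields"
# (random placeholders, not real data).
uu, uv, uw, vv, vw, ww = (rng.random((N, M)) for _ in range(6))

# np.array stacks the nested lists of (N, M) arrays into one (3, 3, N, M) array.
a = np.array([[uu, uv, uw],
              [uv, vv, vw],
              [uw, vw, ww]])

print(a.shape)  # (3, 3, 4, 5)
```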
How can I compute the matrix product of a with itself (treating a as a 3x3 matrix) in a Pythonic way?
Using a@a throws the following error (for N=1218 and M=540):
ValueError: shapes (3,3,1218,540) and (3,3,1218,540) not aligned: 540 (dim 3) != 1218 (dim 2)
I want to perform this operation as if the elements of a were plain scalars, in which case a@a would not raise a shape error, since it would just be a simple 3x3 matrix multiplication.
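The error comes from how @ (np.matmul) interprets arrays with more than two dimensions: it treats the last two axes as the matrix axes and broadcasts over the leading ones, which is the opposite of the layout here. The sketch below (random data, small hypothetical sizes) shows the mismatch, and also that when N == M no error is raised but the result is not the 3x3 product one wants.

```python
import numpy as np

np.random.seed(0)
a = np.random.rand(3, 3, 4, 5)  # hypothetical small stand-in for (3, 3, 1218, 540)

# @ multiplies over the LAST two axes: it tries (4, 5) @ (4, 5) for every
# (i, j), which fails because 5 != 4.
raised = False
try:
    a @ a
except ValueError as err:
    raised = True
    print("ValueError:", err)

# With a square trailing shape (N == M) nothing is raised, but each b[i, j]
# (an N x N slice) is multiplied by itself instead of doing a 3x3 product.
b = np.random.rand(3, 3, 4, 4)
wrong = b @ b
right = np.einsum('ijkl,jmkl->imkl', b, b)  # 3x3 product at each (k, l)
print(wrong.shape, right.shape)   # (3, 3, 4, 4) (3, 3, 4, 4)
print(np.allclose(wrong, right))  # same shape, different numbers
```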
Assuming that you want to perform the 3x3 matrix multiplication separately at every (k, l) position along the last two axes, we can use np.einsum:
np.einsum('ijkl,jmkl->imkl',a,a)
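Written out, the subscript string 'ijkl,jmkl->imkl' sums over the repeated index j, which appears in both inputs but not in the output, and carries k and l through unchanged as batch indices:

```latex
c_{imkl} = \sum_{j=1}^{3} a_{ijkl}\, a_{jmkl},
\qquad i, m \in \{1, 2, 3\},\quad k = 1, \dots, N,\quad l = 1, \dots, M
```

So for every fixed (k, l), the slice c[:, :, k, l] is the ordinary 3x3 product of a[:, :, k, l] with itself.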
Sample run for verification:

In [43]: np.random.seed(0)

In [44]: a = np.random.rand(3,3,4,5)

In [45]: a[:,:,0,0].dot(a[:,:,0,0])
array([[0.71750146, 1.17057872, 1.11135764],
       [0.62938365, 0.86437796, 0.74541383],
       [1.04636618, 1.62011127, 1.35483565]])

In [46]: np.einsum('ijkl,jmkl->imkl',a,a)[:,:,0,0]
array([[0.71750146, 1.17057872, 1.11135764],
       [0.62938365, 0.86437796, 0.74541383],
       [1.04636618, 1.62011127, 1.35483565]])
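As a further check over all positions rather than just (0, 0), the sketch below compares the einsum result with a plain Python loop over (k, l), and with an alternative that is not part of the original answer: using np.moveaxis to put the two 3x3 axes last, so that @ sees a stack of 3x3 matrices, and then moving them back. Sizes and data are hypothetical.

```python
import numpy as np

np.random.seed(0)
a = np.random.rand(3, 3, 4, 5)

# Reference: loop over every (k, l) position and do a plain 3x3 product.
ref = np.empty_like(a)
for k in range(a.shape[2]):
    for l in range(a.shape[3]):
        ref[:, :, k, l] = a[:, :, k, l] @ a[:, :, k, l]

# einsum version from the answer: sum over j, keep k and l as batch axes.
c_einsum = np.einsum('ijkl,jmkl->imkl', a, a)

# Alternative (not from the original answer): move the 3x3 axes to the end so
# that @ sees a (4, 5) stack of 3x3 matrices, then move them back.
a_t = np.moveaxis(a, (0, 1), (2, 3))                # shape (4, 5, 3, 3)
c_matmul = np.moveaxis(a_t @ a_t, (2, 3), (0, 1))   # back to (3, 3, 4, 5)

print(np.allclose(ref, c_einsum), np.allclose(ref, c_matmul))  # True True
```

np.moveaxis returns a view, so the axis shuffling itself does not copy the data. Which of the two approaches is faster will depend on the array sizes and the NumPy build, so it is worth timing both on the real (3, 3, 1218, 540) array.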