
Python fast array multiplication for multidimensional arrays


I have two 3-dimensional arrays, A and B, where

  1. A has dimensions (500 x 500 x 80), and
  2. B has dimensions (500 x 80 x 2000).

In both arrays, the dimension of size 80 can be called 'time' (80 time points, indexed by i). The dimension of size 2000 can be called 'scenario' (there are 2000 scenarios).

For each time i and each scenario, I need to multiply the 500 x 500 matrix A[:, :, i] by the 500-element column vector B[:, i, scenario] taken at the same time i.
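Written elementwise with the index letters of the einsum call below (there, k is the time index that this paragraph calls i, and l is the scenario), the requested product is:

```latex
\mathrm{out}_{ikl} = \sum_{j=1}^{500} A_{ijk}\, B_{jkl},
\qquad i = 1,\dots,500,\quad k = 1,\dots,80,\quad l = 1,\dots,2000
```

Only j is summed. The time index k appears in both inputs and in the output, so it is a batch index that is carried through rather than contracted.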

I ended up with the code below:

from scipy.stats import norm
import numpy as np
A = norm.rvs(size = (500, 500, 80),  random_state = 0)
B = norm.rvs(size = (500, 80, 2000), random_state = 0)
result = np.einsum('ijk,jkl->ikl', A, B, optimize=True)

whereas a naive approach to the same problem would use a nested for loop:

out = np.empty((500, 80, 2000))  # preallocate the result
for scenario in range(2000):
    for i in range(80):
        out[:, i, scenario] = A[:, :, i] @ B[:, i, scenario]
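
At reduced sizes it is easy to confirm that the einsum call and the nested loop compute the same thing. This sketch swaps scipy's norm.rvs for NumPy's default_rng generator (only to drop the scipy dependency) and uses small, hypothetical sizes m, r, n in place of 500, 80, 2000:

```python
import numpy as np

# Hypothetical small sizes standing in for 500 (rows), 80 (time), 2000 (scenarios)
m, r, n = 6, 4, 5
rng = np.random.default_rng(0)
A = rng.standard_normal((m, m, r))
B = rng.standard_normal((m, r, n))

# Reference result: one matrix-vector product per (time, scenario) pair
out = np.empty((m, r, n))
for scenario in range(n):
    for i in range(r):
        out[:, i, scenario] = A[:, :, i] @ B[:, i, scenario]

# The einsum formulation from the question
result = np.einsum('ijk,jkl->ikl', A, B, optimize=True)

print(np.allclose(out, result))
```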

I expected einsum to be quite fast because the problem 'only' involves simple operations on a large array, but it actually runs for several minutes.

For comparison, I timed the case where every matrix in A is the same. A can then be kept as a single (500 x 500) matrix instead of a 3-D array, and the whole problem can be written as

A = norm.rvs(size = (500, 500),      random_state = 0)
B = norm.rvs(size = (500, 80, 2000), random_state = 0)
result = np.einsum('ij,jkl->ikl', A, B, optimize=True)

This is fast and runs in only a few seconds, much faster than the 'slightly' more general case above.
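
One way to see why the shared-matrix case is easy: when A no longer depends on time, the time and scenario axes of B can be flattened together, and the whole contraction becomes a single (500 x 500) @ (500 x 160000) matrix product. Both this case and the general one perform 500 x 500 x 80 x 2000 = 4 x 10^10 multiply-adds, so the gap in run time comes from how the work is executed, not from how much of it there is. The sketch below (small hypothetical sizes, default_rng instead of norm.rvs) only checks that the reshape is equivalent; it does not claim this is the path einsum takes internally.

```python
import numpy as np

m, r, n = 6, 4, 5  # hypothetical small sizes
rng = np.random.default_rng(0)
A = rng.standard_normal((m, m))
B = rng.standard_normal((m, r, n))

# Flatten (time, scenario) into one axis: B2[j, k*n + l] == B[j, k, l]
B2 = B.reshape(m, r * n)
# One ordinary 2-D matrix product, then restore the (time, scenario) axes
flat = (A @ B2).reshape(m, r, n)

ref = np.einsum('ij,jkl->ikl', A, B, optimize=True)
print(np.allclose(flat, ref))
```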

My question is: how can I write the general case, which currently uses the slow einsum, in a computationally efficient form?


Solution

  • You can do better than the two nested loops by using a single loop instead:

    m = A.shape[0]
    n = B.shape[2]
    r = A.shape[2]
    out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
    for i in range(r):
        out1[:,i,:] = A[:, :, i] @ B[:, i,:]
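
The one-loop version can be wrapped in a function and checked against the einsum result at small sizes. The function name batched_matvec_loop and the sizes are illustrative, not from the answer. Each iteration handles all scenarios of one time point in a single matrix-matrix product, replacing n separate matrix-vector products from the nested loop.

```python
import numpy as np

def batched_matvec_loop(A, B):
    """Return out with out[:, i, :] = A[:, :, i] @ B[:, i, :] for every time index i."""
    m = A.shape[0]   # rows of each matrix in A
    n = B.shape[2]   # number of scenarios
    r = A.shape[2]   # number of time points
    out = np.empty((m, r, n), dtype=np.result_type(A.dtype, B.dtype))
    for i in range(r):
        # One matrix-matrix product per time point covers all scenarios at once
        out[:, i, :] = A[:, :, i] @ B[:, i, :]
    return out

rng = np.random.default_rng(1)
A = rng.standard_normal((6, 6, 4))
B = rng.standard_normal((6, 4, 5))
print(np.allclose(batched_matvec_loop(A, B), np.einsum('ijk,jkl->ikl', A, B)))
```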
    

    Alternatively, with np.matmul (the @ operator):

    out = (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
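
To see what the stacked np.matmul call does, track the shapes: A.transpose(2,0,1) puts time first, giving a stack of r matrices of shape (m, m); B.transpose(1,0,2) gives a stack of r matrices of shape (m, n); @ multiplies them pairwise into shape (r, m, n); and swapaxes(0,1) restores the (m, r, n) layout. swapaxes returns a view, so the result is not C-contiguous; np.ascontiguousarray makes a contiguous copy if later code needs one. A small-size sketch with hypothetical sizes:

```python
import numpy as np

m, r, n = 6, 4, 5  # hypothetical small sizes
rng = np.random.default_rng(2)
A = rng.standard_normal((m, m, r))
B = rng.standard_normal((m, r, n))

At = A.transpose(2, 0, 1)     # (r, m, m): one matrix per time point
Bt = B.transpose(1, 0, 2)     # (r, m, n): scenario columns grouped per time point
stacked = At @ Bt             # (r, m, n): pairwise matrix products along the first axis
out = stacked.swapaxes(0, 1)  # (m, r, n): a view in the question's layout

print(At.shape, Bt.shape, stacked.shape, out.shape)
print(out.flags['C_CONTIGUOUS'])
print(np.allclose(out, np.einsum('ijk,jkl->ikl', A, B)))
```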
    

    These two seem to scale much better than the einsum version.

    Timings

    Case #1: sizes scaled to 1/4

    In [44]: m = 500
        ...: n = 2000
        ...: r = 80
        ...: m,n,r = m//4, n//4, r//4
        ...: 
        ...: A = norm.rvs(size = (m, m, r),  random_state = 0)
        ...: B = norm.rvs(size = (m, r, n), random_state = 0)
    
    In [45]: %%timeit
        ...: out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
        ...: for i in range(r):
        ...:     out1[:,i,:] = A[:, :, i] @ B[:, i,:]
    175 ms ± 6.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    In [46]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
    165 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    In [47]: %timeit np.einsum('ijk,jkl->ikl', A, B, optimize=True)
    483 ms ± 13.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    As we scale up, memory congestion would start to favour the one-loop version.
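
For a sense of scale at the full problem size, the float64 arrays alone are large: A holds 500 x 500 x 80 = 2 x 10^7 values (160 MB), and B and the output each hold 500 x 80 x 2000 = 8 x 10^7 values (640 MB each). The snippet below just does that arithmetic without allocating anything; it assumes float64, which is what norm.rvs returns.

```python
import numpy as np

itemsize = np.dtype(np.float64).itemsize  # 8 bytes per float64 value

sizes = {
    "A": (500, 500, 80),
    "B": (500, 80, 2000),
    "out": (500, 80, 2000),
}
for name, shape in sizes.items():
    nbytes = int(np.prod(shape)) * itemsize
    print(f"{name}: {shape} -> {nbytes / 1e6:.0f} MB")
```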

    Case #2: sizes scaled to 1/2

    In [48]: m = 500
        ...: n = 2000
        ...: r = 80
        ...: m,n,r = m//2, n//2, r//2
        ...: 
        ...: A = norm.rvs(size = (m, m, r),  random_state = 0)
        ...: B = norm.rvs(size = (m, r, n), random_state = 0)
    
    In [49]: %%timeit
        ...: out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
        ...: for i in range(r):
        ...:     out1[:,i,:] = A[:, :, i] @ B[:, i,:]
    2.9 s ± 58.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    In [50]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
    3.02 s ± 94.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    Case #3: sizes scaled to about 67% (divided by 1.5)

    In [59]: m = 500
        ...: n = 2000
        ...: r = 80
        ...: m,n,r = int(m/1.5), int(n/1.5), int(r/1.5)
    
    In [60]: A = norm.rvs(size = (m, m, r),  random_state = 0)
        ...: B = norm.rvs(size = (m, r, n), random_state = 0)
    
    In [61]: %%timeit
        ...: out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
        ...: for i in range(r):
        ...:     out1[:,i,:] = A[:, :, i] @ B[:, i,:]
    25.8 s ± 4.9 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    In [62]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
    29.2 s ± 2.41 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
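
To reproduce the comparison on another machine, a small harness built on the standard-library timeit module is enough. The function names and the scale argument are hypothetical, and default_rng replaces norm.rvs. Timings depend heavily on the BLAS library and core count, so no numbers are claimed here; compare(scale=1/4) corresponds to case #1 above.

```python
import timeit
import numpy as np

def one_loop(A, B):
    out = np.empty((A.shape[0], A.shape[2], B.shape[2]),
                   dtype=np.result_type(A.dtype, B.dtype))
    for i in range(A.shape[2]):
        out[:, i, :] = A[:, :, i] @ B[:, i, :]
    return out

def stacked_matmul(A, B):
    return (A.transpose(2, 0, 1) @ B.transpose(1, 0, 2)).swapaxes(0, 1)

def einsum_version(A, B):
    return np.einsum('ijk,jkl->ikl', A, B, optimize=True)

def compare(scale=1/4, repeat=3, seed=0):
    """Best-of-`repeat` wall time in seconds for each method at the scaled size."""
    m, r, n = int(500 * scale), int(80 * scale), int(2000 * scale)
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((m, m, r))
    B = rng.standard_normal((m, r, n))
    times = {}
    for f in (one_loop, stacked_matmul, einsum_version):
        # number=1: each timing is a single call; keep the fastest of `repeat` runs
        times[f.__name__] = min(timeit.repeat(lambda: f(A, B), number=1, repeat=repeat))
    return times
```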
    

    Numba spin-off

    import numpy as np
    from numba import njit, prange

    @njit(parallel=True)
    def func1(A, B):
        m = A.shape[0]
        n = B.shape[2]
        r = A.shape[2]
        out = np.empty((m,r,n))
        # prange lets Numba run the independent time-point iterations in parallel
        for i in prange(r):
            out[:,i,:] = A[:, :, i] @ B[:, i,:]
        return out
    

    Timings with case #3:

    In [80]: m = 500
        ...: n = 2000
        ...: r = 80
        ...: m,n,r = int(m/1.5), int(n/1.5), int(r/1.5)
    
    In [81]: A = norm.rvs(size = (m, m, r),  random_state = 0)
        ...: B = norm.rvs(size = (m, r, n), random_state = 0)
    
    In [82]: %timeit func1(A, B)
    653 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)