I have two 3-dimensional arrays, A and B, where A has shape (500, 500, 80) and B has shape (500, 80, 2000).
In both arrays, the dimension of size 80 can be called 'time' (e.g. 80 timepoints i). The dimension of size 2000 can be called 'scenario' (we have 2000 scenarios).
What I need to do is take the 500 x 500 matrix A[:, :, i] and multiply it by each 500-element column vector B[:, i, scenario] at the corresponding time, for each scenario and time i.
I eventually ended up with the code below:
from scipy.stats import norm
import numpy as np
A = norm.rvs(size = (500, 500, 80), random_state = 0)
B = norm.rvs(size = (500, 80, 2000), random_state = 0)
result = np.einsum('ijk,jkl->ikl', A, B, optimize=True)
A naive approach to the same problem would be a nested for loop:
out = np.empty((500, 80, 2000))
for scenario in range(2000):
    for i in range(80):
        out[:, i, scenario] = A[:, :, i] @ B[:, i, scenario]
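As a sanity check that the einsum expression and the nested loop describe the same computation, here is a minimal sketch on much smaller arrays. The sizes `m, r, n = 5, 4, 6` and the use of `np.random.default_rng` are assumptions made only to keep the example quick; they are not part of the original problem.

```python
import numpy as np

rng = np.random.default_rng(0)
m, r, n = 5, 4, 6  # small stand-ins for 500, 80, 2000 (illustrative only)
A = rng.standard_normal((m, m, r))
B = rng.standard_normal((m, r, n))

# Reference: one matrix-vector product per (time, scenario) pair.
out = np.empty((m, r, n))
for scenario in range(n):
    for i in range(r):
        out[:, i, scenario] = A[:, :, i] @ B[:, i, scenario]

# Same contraction over j, with time (k) and scenario (l) kept as free indices.
result = np.einsum('ijk,jkl->ikl', A, B, optimize=True)
same = np.allclose(out, result)
```

If `same` is True, any faster rewrite can be validated against either form in the same way.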
I expected einsum to be quite fast because the problem 'only' involves simple operations on a large array, but it actually runs for several minutes.
I compared the speed of the einsum above to the case where each matrix in A is the same. Then we can keep A as a 500 x 500 matrix (instead of a 3D array), and the whole problem can be written as
A = norm.rvs(size = (500, 500), random_state = 0)
B = norm.rvs(size = (500, 80, 2000), random_state = 0)
result = np.einsum('ij,jkl->ikl', A, B, optimize=True)
This is fast and runs for only a few seconds, much faster than the 'slightly' more general case above.
My question is: how do I write the general case with the slow einsum in a computationally efficient form?
You can do better than the existing two nested loops by using a single loop instead:
m = A.shape[0]
n = B.shape[2]
r = A.shape[2]
out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
for i in range(r):
    out1[:,i,:] = A[:, :, i] @ B[:, i,:]
Alternatively, with the np.matmul/@ operator:
out = (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
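To spell out what the transposes in the one-liner are doing, here is a small sketch under assumed toy sizes (`m, r, n = 5, 4, 6`; the names `A_batched` and `B_batched` are introduced here for illustration). The idea is that `@` treats the leading axis as a batch, so moving 'time' to the front turns the problem into `r` independent matrix products.

```python
import numpy as np

rng = np.random.default_rng(0)
m, r, n = 5, 4, 6  # illustrative sizes, not the original 500, 80, 2000
A = rng.standard_normal((m, m, r))
B = rng.standard_normal((m, r, n))

A_batched = A.transpose(2, 0, 1)  # shape (r, m, m): one matrix per timepoint
B_batched = B.transpose(1, 0, 2)  # shape (r, m, n): one (m, n) block per timepoint

# Batched matmul gives shape (r, m, n); swapaxes restores the (m, r, n) layout.
out = (A_batched @ B_batched).swapaxes(0, 1)

expected = np.einsum('ijk,jkl->ikl', A, B)
matches = np.allclose(out, expected)
```

Note that `transpose` and `swapaxes` return views, so `out` here is a non-contiguous view of the batched result rather than a fresh copy.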
These two seem to scale much better than einsum.
Case #1 : Scaled 1/4th sizes
In [44]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = m//4, n//4, r//4
...: A = norm.rvs(size = (m, m, r), random_state = 0)
...: B = norm.rvs(size = (m, r, n), random_state = 0)
In [45]: %%timeit
...: out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
...: for i in range(r):
...:     out1[:,i,:] = A[:, :, i] @ B[:, i,:]
175 ms ± 6.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [46]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
165 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [47]: %timeit np.einsum('ijk,jkl->ikl', A, B, optimize=True)
483 ms ± 13.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
As we scale up, memory congestion starts to favour the one-loop version.
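For a rough sense of how much memory is involved at the full problem size, here is a back-of-the-envelope calculation. It assumes float64 arrays (8 bytes per element, which is what `norm.rvs` produces) and counts only the inputs and the output, not any temporaries; it is an estimate, not a measurement, and does not by itself explain the timings.

```python
import numpy as np

m, r, n = 500, 80, 2000
itemsize = np.dtype(np.float64).itemsize  # 8 bytes per element

bytes_A = m * m * r * itemsize    # A has shape (m, m, r)
bytes_B = m * r * n * itemsize    # B has shape (m, r, n)
bytes_out = m * r * n * itemsize  # the result has shape (m, r, n)

for name, size in [("A", bytes_A), ("B", bytes_B), ("out", bytes_out)]:
    print(f"{name}: {size / 1e6:.0f} MB")
# prints:
# A: 160 MB
# B: 640 MB
# out: 640 MB
```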
Case #2 : Scaled 1/2 sizes
In [48]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = m//2, n//2, r//2
...: A = norm.rvs(size = (m, m, r), random_state = 0)
...: B = norm.rvs(size = (m, r, n), random_state = 0)
In [49]: %%timeit
...: out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
...: for i in range(r):
...:     out1[:,i,:] = A[:, :, i] @ B[:, i,:]
2.9 s ± 58.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [50]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
3.02 s ± 94.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Case #3 : Scaled 67% sizes
In [59]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = int(m/1.5), int(n/1.5), int(r/1.5)
In [60]: A = norm.rvs(size = (m, m, r), random_state = 0)
...: B = norm.rvs(size = (m, r, n), random_state = 0)
In [61]: %%timeit
...: out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
...: for i in range(r):
...:     out1[:,i,:] = A[:, :, i] @ B[:, i,:]
25.8 s ± 4.9 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [62]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
29.2 s ± 2.41 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
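With numba, the same loop can be compiled and run in parallel:
from numba import njit, prange

# parallel=True is needed for prange to actually run iterations in parallel
@njit(parallel=True)
def func1(A, B):
    m = A.shape[0]
    n = B.shape[2]
    r = A.shape[2]
    out = np.empty((m,r,n))
    for i in prange(r):
        out[:,i,:] = A[:, :, i] @ B[:, i,:]
    return out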
Timings with case #3 -
In [80]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = int(m/1.5), int(n/1.5), int(r/1.5)
In [81]: A = norm.rvs(size = (m, m, r), random_state = 0)
...: B = norm.rvs(size = (m, r, n), random_state = 0)
In [82]: %timeit func1(A, B)
653 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)