Consider the following in Python: A has dimension (T,), U has dimension (L,T), G has dimension (K,T), and Y has dimension (L,L,T). My code produces numer1 and numer2 with dimensions (T, L*K, 1), while numer3 is (L*K, 1).
# MWE
import numpy as np
np.random.seed(0)
T=250
K = 15
L = 20
A = np.random.normal(size=(T))
Y = np.random.normal(size=(L,L,T))
G = np.random.normal(size=(K,T))
U = np.random.normal(size=(L,T))
# Calculations
L, T = U.shape
K = G.shape[0]
Y_transposed = Y.transpose(2, 0, 1)                 # (T, L, L)
sG = np.einsum('it,jt->tij', G, G)                  # (T, K, K), outer products of G's columns
A_transposed = A[None, :, None].transpose(1, 0, 2)  # (T, 1, 1)
numer1 = np.einsum('ai,ak->aik', U.T, G.T).reshape(T, L * K, 1)
numer2 = numer1 * A_transposed
numer3 = numer2.sum(axis=0)
It turns out that np.reshape is very slow in this piece of code; by slow I mean compared to other, loop-based solutions. Is there a way to avoid the reshape by writing the einsum differently, since I take sums in the last step anyway?
A similar approach would also benefit the other piece of code, where the problem is even more pronounced:
denom1 = np.einsum('aij,akl->aikjl', Y_transposed, sG).reshape(T, L * K, L *K)
denom2 = denom1 * A_transposed
denom3 = denom2.sum(axis=0)
You are using einsum to do broadcasted outer products.
With a small example:
In [63]: A = np.arange(3); U = np.arange(4*3).reshape(4,3); G=np.arange(5*3).reshape(5,3); Y=np.arange(4*4*3).reshape(4,4,3)
Your numer? calculations time as:
In [64]: %%timeit
...: A_transposed = A[None, :, None].transpose(1, 0, 2)
...: numer1 = np.einsum('ai,ak->aik', U.T, G.T).reshape(T, L * K, 1)
...: numer2 = numer1 * A_transposed
...: numer3 = numer2.sum(axis=0)
...:
...:
37.8 µs ± 74.2 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
An equivalent with just broadcasting is modestly faster:
In [65]: timeit x=((U[:,None,:]*G[None,:,:]).reshape(-1,3)*A).sum(axis=1)
24.8 µs ± 144 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
They are the same, except for that added trailing dimension:
In [66]: numer3.shape, x.shape
Out[66]: ((20, 1), (20,))
In [67]: np.allclose(x, numer3[:,0])
Out[67]: True
Or, moving the reshape to the end:
In [73]: timeit y=(U[:,None,:]*G[None,:,:]*A).sum(axis=-1).reshape(-1,1)
25.5 µs ± 153 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In the previous calculation I did two things: replaced the einsum with a broadcasted multiply, and moved the T dimension to the end, removing the need for the transposing.
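A further variation, which I haven't timed here, is to let a single einsum call apply the A weighting and the sum over t as well (numer3_alt is just an illustrative name):

# untimed sketch: one einsum does the outer product, the A weighting and the t-sum
numer3_alt = np.einsum('lt,kt,t->lk', U, G, A).reshape(-1, 1)
# np.allclose(numer3_alt, numer3)  -> True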
In [76]: denom1 = np.einsum('aij,akl->aikjl', Y_transposed, sG).reshape(T, L * K, L *K)
...: denom2 = denom1 * A_transposed
...: denom3 = denom2.sum(axis=0)
In [77]: Y_transposed.shape, sG.shape
Out[77]: ((3, 4, 4), (3, 5, 5))
In [78]: denom1.shape
Out[78]: (3, 20, 20)
In [79]: Y.shape, G.shape
Out[79]: ((4, 4, 3), (5, 3))
Focusing on moving T/t/a to the end:
In [80]: sg1 = np.einsum('it,jt->ijt', G, G)
Doing [80] with broadcasting instead of einsum would look just like before.
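For reference, a broadcasted version of that outer product (sg1_b is just an illustrative name) might look like:

sg1_b = G[:, None, :] * G[None, :, :]   # (K, K, T), same values as sg1
# np.allclose(sg1_b, sg1)  -> True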
In [82]: denom4 = (np.einsum('ija,kla->ikjla', Y, sg1)*A).sum(axis=-1)
In [83]: denom4.shape
Out[83]: (4, 5, 4, 5)
There are some more dimensions in [82] but the same idea applies.
Now do the reshape and check for equality:
In [84]: denom3.shape
Out[84]: (20, 20)
In [85]: denom4 = denom4.reshape(L*K,L*K)
In [86]: np.allclose(denom3, denom4)
Out[86]: True
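If you would rather stay with a single einsum call, you can also fold A and the t-sum into it directly, working from the original Y and G. A sketch I haven't timed (denom4_alt is just an illustrative name); optimize=True lets einsum choose a contraction order, which usually matters once there are several operands:

# untimed sketch: one einsum builds the double outer product, applies A and sums over t
denom4_alt = np.einsum('ijt,kt,lt,t->ikjl', Y, G, G, A, optimize=True).reshape(L * K, L * K)
# np.allclose(denom4_alt, denom3)  -> True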