
Faster alternative for numpy einsum in Python


I am trying to perform the following operations on some tensors. Currently I am using einsum, and I wonder if there is a way (maybe using dot or tensordot) to make things faster, since I feel like what I am doing is more or less some outer and inner products.

Scenario 1:

import numpy as np

A = np.arange(6).reshape((2, 3))
B = np.arange(12).reshape((2, 3, 2))
res1 = np.einsum('ij, kjh->ikjh', A, B)

>>> res1 = 
[[[[ 0  0]
   [ 2  3]
   [ 8 10]]

  [[ 0  0]
   [ 8  9]
   [20 22]]]


 [[[ 0  3]
   [ 8 12]
   [20 25]]

  [[18 21]
   [32 36]
   [50 55]]]]

Scenario 2:

C = np.arange(12).reshape((2, 3, 2))
D = np.arange(6).reshape((3, 2))
res2 = np.einsum('ijk, jk->ij', C, D)

>>> res2 = 
[[ 1 13 41]
 [ 7 43 95]]

I have tried using tensordot and dot, but I cannot figure out the right way to set the axes.


Solution

  • Let's explore your first calculation. I'll start with a small example, to make sure values match. Timings on this size may not reflect your real-world needs.

    In [138]: n,m,k = 2,3,4
    In [141]: A = np.arange(n*m).reshape(n,m)
    In [142]: B = np.arange(n*m*k).reshape(n,m,k)
    
    
    In [144]: res1 = np.einsum('ij, kjh->ikjh', A, B)    
    In [145]: res1.shape
    Out[145]: (2, 2, 3, 4)
    

    Since there's no sum-of-products (no index is summed over; j appears in both inputs and in the output), we can do it with a broadcasted multiply:

    In [146]: x=A[:,None,:,None]*B
    In [147]: x.shape
    Out[147]: (2, 2, 3, 4)
    

    And the results and shapes match:

    In [148]: np.allclose(res1,x)
    Out[148]: True
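
    As a self-contained sketch, here is the same broadcasted multiply applied to the arrays from the question. The function name `scale_by_shared_axis` is just an illustrative label, not anything from NumPy.

```python
import numpy as np

def scale_by_shared_axis(A, B):
    """Return out[i, k, j, h] = A[i, j] * B[k, j, h] using broadcasting."""
    # A: (n, m) -> (n, 1, m, 1); B: (p, m, q) is treated as (1, p, m, q).
    # The shared j axis lines up, and the size-1 axes are stretched.
    return A[:, None, :, None] * B

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(2, 3, 2)
res1 = scale_by_shared_axis(A, B)
print(res1.shape)  # (2, 2, 3, 2)
```

    No index is summed, so nothing here needs dot or tensordot; the only work is inserting the new axes in the right places.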
    

    Some timings (with the usual caveat about scale, since these arrays are tiny):

    In [149]: timeit res1 = np.einsum('ij, kjh->ikjh', A, B)
    13.6 µs ± 67.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
    
    In [150]: timeit x=A[:,None,:,None]*B
    7.2 µs ± 74.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
    
    In [151]: timeit res1 = np.einsum('ij, kjh->ikjh', A, B, optimize=True)
    90.5 µs ± 2.09 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
    

    The broadcasted multiply is fastest here, and I expect that to hold for larger arrays. The optimize option does not help; it actually slows this small case down. We could look at the einsum_path, but with only 2 arguments there isn't much to improve on.
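
    To check how this scales on your own machine, something like the following sketch could be used. The sizes, the repeat count, and the helper name `compare_scenario1` are arbitrary assumptions of mine; substitute shapes that match your real data.

```python
import timeit
import numpy as np

def compare_scenario1(n=20, m=30, k=40, repeats=20):
    """Time einsum against the broadcasted multiply on random float arrays."""
    rng = np.random.default_rng(0)
    A = rng.random((n, m))
    B = rng.random((n, m, k))
    candidates = {
        "einsum": lambda: np.einsum('ij,kjh->ikjh', A, B),
        "broadcast": lambda: A[:, None, :, None] * B,
    }
    reference = candidates["einsum"]()
    timings = {}
    for name, fn in candidates.items():
        # Make sure every candidate computes the same thing before timing it.
        assert np.allclose(fn(), reference)
        timings[name] = timeit.timeit(fn, number=repeats) / repeats
    return timings

timings = compare_scenario1()
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e6:.1f} µs per call")
```

    The numbers will vary with array size, dtype, and NumPy build, so treat any single run as indicative only.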

    Second calculation:

    In [152]: C = np.arange(n*m*k).reshape(n,m,k)
         ...: D = np.arange(m*k).reshape(m,k)
    
    In [153]: res2 = np.einsum('ijk, jk->ij', C, D)
    In [154]: res2.shape
    Out[154]: (2, 3)
    

    These shapes broadcast without change:

    In [155]: (C*D).shape
    Out[155]: (2, 3, 4)    
    In [156]: y=(C*D).sum(2)  # sum-of-products on last dimension    
    In [157]: y.shape
    Out[157]: (2, 3)
    

    which matches the einsum:

    In [158]: np.allclose(res2,y)
    Out[158]: True
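
    Applied to the arrays from the question, the multiply-then-sum version looks like this (a minimal sketch; note that `C * D` builds a full temporary of the same shape as `C` before the sum):

```python
import numpy as np

C = np.arange(12).reshape(2, 3, 2)
D = np.arange(6).reshape(3, 2)

# D (3, 2) broadcasts against the trailing two axes of C (2, 3, 2);
# summing the last axis then gives out[i, j] = sum_k C[i, j, k] * D[j, k].
res2 = (C * D).sum(axis=-1)
print(res2)
# [[ 1 13 41]
#  [ 7 43 95]]
```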
    

    A matmul approach:

    In [159]: (C@D.T).shape
    Out[159]: (2, 3, 3)    
    In [160]: np.allclose((C@D.T)[:,np.arange(3),np.arange(3)],res2)
    Out[160]: True
    

    I don't like having to take the last diagonal; I'll have to play with it some more.
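
    The same issue shows up with tensordot, which may be why setting the axes felt awkward: tensordot sums over the chosen axes and keeps every remaining axis from both inputs, so j appears twice in the result. A sketch of that, using `np.diagonal` to pick out the matching-j entries:

```python
import numpy as np

n, m, k = 2, 3, 4
C = np.arange(n * m * k).reshape(n, m, k)
D = np.arange(m * k).reshape(m, k)

# Contract the last axis of C with the last axis of D.
# Result has shape (n, m, m): one m axis from C and one from D.
full = np.tensordot(C, D, axes=([2], [1]))
same = C @ D.T  # identical intermediate

# Only entries where both m indices agree are wanted.
diag = np.diagonal(full, axis1=1, axis2=2)  # shape (n, m)
print(full.shape, diag.shape)  # (2, 3, 3) (2, 3)
```

    Either way, m*m sums are computed per row and only m of them are kept, so this gets more wasteful as m grows.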

    For these small arrays, einsum is still the fastest of the three:

    In [164]: timeit res2 = np.einsum('ijk, jk->ij', C, D)
    11.9 µs ± 31 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
    
    In [165]: timeit y=(C*D).sum(2)
    13.9 µs ± 25.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
    
    In [166]: timeit (C@D.T)[:,np.arange(3),np.arange(3)]
    21.4 µs ± 78.8 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
    

    2nd matmul

    The correct, and faster, matmul:

    In [167]: (C[:,:,None,:]@D[:,:,None]).shape
    Out[167]: (2, 3, 1, 1)
    

    Index out the trailing size-1 dimensions:

    In [168]: np.allclose((C[:,:,None,:]@D[:,:,None])[:,:,0,0],res2)
    Out[168]: True
    
    In [169]: timeit (C[:,:,None,:]@D[:,:,None])
    6.63 µs ± 24.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
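
    Wrapped up as a function and run on the question's arrays, the batched matmul could look like the sketch below. `rowwise_dot` is a name I made up for illustration.

```python
import numpy as np

def rowwise_dot(C, D):
    """Return out[i, j] = sum_k C[i, j, k] * D[j, k] via batched matmul."""
    # C: (n, m, k) -> (n, m, 1, k), a stack of 1 x k row vectors.
    # D: (m, k)    -> (m, k, 1),    a stack of k x 1 column vectors.
    out = C[:, :, None, :] @ D[:, :, None]  # batch dims broadcast to (n, m)
    return out[:, :, 0, 0]                  # drop the two size-1 axes

C = np.arange(12).reshape(2, 3, 2)
D = np.arange(6).reshape(3, 2)
print(rowwise_dot(C, D))
# [[ 1 13 41]
#  [ 7 43 95]]
```

    Here i and j are both treated as batch dimensions, so matmul only ever computes the products that are actually needed.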
    

    I could probably use a similar trick to perform the first example with matmul, with sum-of-products on a dummy size 1 dimension.
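
    One possible way to do that, as an untimed sketch: treat i, k and j as batch dimensions and multiply a 1 x 1 matrix by a 1 x h matrix, so the "sum" runs over a single element. I have not benchmarked this against the broadcasted multiply.

```python
import numpy as np

n, m, k = 2, 3, 4
A = np.arange(n * m).reshape(n, m)
B = np.arange(n * m * k).reshape(n, m, k)

# A: (n, m)    -> (n, 1, m, 1, 1): a 1 x 1 matrix for each (i, j).
# B: (n, m, k) -> (1, n, m, 1, k): a 1 x k matrix for each (k, j).
# Batch dims (n, 1, m) and (1, n, m) broadcast to (n, n, m).
out = A[:, None, :, None, None] @ B[None, :, :, None, :]  # (n, n, m, 1, k)
res1_mm = out[..., 0, :]                                   # (n, n, m, k)

print(np.array_equal(res1_mm, np.einsum('ij, kjh->ikjh', A, B)))  # True
```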