
np.einsum performance of 4 matrix multiplications


Given the following four arrays:

import numpy as np
M = np.arange(35 * 37 * 59).reshape([35, 37, 59])
A = np.arange(35 * 51 * 59).reshape([35, 51, 59])
B = np.arange(37 * 51 * 51 * 59).reshape([37, 51, 51, 59])
C = np.arange(59 * 27).reshape([59, 27])

I'm using einsum to compute:

D1 = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize=True)

But I found it to be much slower than:

tmp = np.einsum('xyf,xtf->tfy', A, M, optimize=True)
tmp = np.einsum('ytpf,yft->ftp', B, tmp, optimize=True)
D2 = np.einsum('fr,ftp->tpr', C, tmp, optimize=True)
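A quick way to confirm that the two versions compute the same thing is a sketch like the following. It assumes random float arrays with the question's shapes instead of the np.arange integers (an assumption made only so that np.allclose is a plain floating-point check). Note that the first step of the manual chain reuses the letters y and t for each other's axes, but its output 'tfy' still lines up with the 'yft' operand of the second step, so the result matches the one-shot expression.

```python
import numpy as np

# Random floats with the question's shapes (assumption for this check only).
rng = np.random.default_rng(0)
M = rng.standard_normal((35, 37, 59))
A = rng.standard_normal((35, 51, 59))
B = rng.standard_normal((37, 51, 51, 59))
C = rng.standard_normal((59, 27))

# One-shot contraction; optimize=True lets einsum choose the order.
D1 = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize=True)

# Manual three-step chain from the question, unchanged.
tmp = np.einsum('xyf,xtf->tfy', A, M, optimize=True)
tmp = np.einsum('ytpf,yft->ftp', B, tmp, optimize=True)
D2 = np.einsum('fr,ftp->tpr', C, tmp, optimize=True)

# Both results have axes (t, p, r) and should agree up to rounding.
print(D1.shape, np.allclose(D1, D2))
```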

I can't understand why.
Overall, I'm trying to optimize this piece of code as much as possible. I've read about the np.tensordot function, but I can't figure out how to apply it to this computation.


Solution

  • With optimize=True, np.einsum uses the "greedy" path search, and it looks like you stumbled onto a case where the greedy path has non-optimal scaling (5 instead of 4):

    >>> path, desc = np.einsum_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="greedy")
    >>> print(desc)
      Complete contraction:  xyf,xtf,ytpf,fr->tpr
             Naive scaling:  6
         Optimized scaling:  5
          Naive FLOP count:  3.219e+10
      Optimized FLOP count:  4.165e+08
       Theoretical speedup:  77.299
      Largest intermediate:  5.371e+06 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       5              ytpf,xyf->xptf                         xtf,fr,xptf->tpr
       4               xptf,xtf->ptf                              fr,ptf->tpr
       4                 ptf,fr->tpr                                 tpr->tpr
    
    >>> path, desc = np.einsum_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal")
    >>> print(desc)
      Complete contraction:  xyf,xtf,ytpf,fr->tpr
             Naive scaling:  6
         Optimized scaling:  4
          Naive FLOP count:  3.219e+10
      Optimized FLOP count:  2.744e+07
       Theoretical speedup:  1173.425
      Largest intermediate:  1.535e+05 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       4                xtf,xyf->ytf                         ytpf,fr,ytf->tpr
       4               ytf,ytpf->ptf                              fr,ptf->tpr
       4                 ptf,fr->tpr                                 tpr->tpr
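Each row of the optimal table is one pairwise contraction. As a sketch (again assuming random float arrays in place of the question's data), here are those three steps written out by hand. Only the last step maps onto np.tensordot: in the first two, f appears in both inputs and in the output, and tensordot can only sum over shared axes, not carry one along as a batch axis, so those steps stay as einsum calls.

```python
import numpy as np

# Random floats with the question's shapes (assumption for illustration).
rng = np.random.default_rng(0)
M = rng.standard_normal((35, 37, 59))
A = rng.standard_normal((35, 51, 59))
B = rng.standard_normal((37, 51, 51, 59))
C = rng.standard_normal((59, 27))

# Step 1 (xtf,xyf->ytf): x is summed, f is kept, so this is a batched
# product over f, which tensordot cannot express directly.
ytf = np.einsum('xtf,xyf->ytf', A, M)

# Step 2 (ytf,ytpf->ptf): y is summed, t and f are batch axes.
ptf = np.einsum('ytf,ytpf->ptf', ytf, B)

# Step 3 (ptf,fr->tpr): f is the only shared axis and it is summed, which is
# exactly what tensordot does. Its result has axes (p, t, r); swap to (t, p, r).
tpr = np.tensordot(ptf, C, axes=([2], [0])).transpose(1, 0, 2)

# Reference: the one-shot contraction with the optimal path.
D = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize='optimal')
print(tpr.shape, np.allclose(tpr, D))
```

The batched steps could also be phrased with np.matmul after transposing f (and t) to the front, but einsum keeps the subscripts identical to the path table.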
    

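    Using np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal") makes einsum follow the scaling-4 path shown above, which needs roughly 15x fewer estimated FLOPs than the greedy path (2.744e+07 vs. 4.165e+08) and never builds the 5.4-million-element intermediate. I'll look into this edge case to see whether the greedy algorithm can be made to find that path.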
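If the same contraction runs many times on arrays of the same shapes, one option is to search for the optimal path once with np.einsum_path and pass the returned path back through einsum's optimize argument, which accepts an explicit contraction list produced by np.einsum_path. This is a minimal sketch assuming random float arrays in place of the question's data; the variable name subscripts is just for illustration.

```python
import numpy as np

# Random floats with the question's shapes (assumption for illustration).
rng = np.random.default_rng(0)
M = rng.standard_normal((35, 37, 59))
A = rng.standard_normal((35, 51, 59))
B = rng.standard_normal((37, 51, 51, 59))
C = rng.standard_normal((59, 27))

subscripts = 'xyf,xtf,ytpf,fr->tpr'

# Search for the optimal order once. path is a list whose first entry is the
# string 'einsum_path', followed by one tuple of operand positions per
# pairwise contraction; desc is the human-readable report printed above.
path, desc = np.einsum_path(subscripts, M, A, B, C, optimize='optimal')

# Reuse the precomputed path so einsum skips the search on this call.
D = np.einsum(subscripts, M, A, B, C, optimize=path)
print(path)
```

For a single call with four operands the path search itself is cheap, so this mostly pays off when the einsum sits inside a loop.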