
replacing einsum with normal operations


I need to replace the einsum operation in the following code with standard NumPy operations:

import numpy as np
a = np.random.rand(128, 16, 8, 32)
b = np.random.rand(256, 8, 32)
output = np.einsum('aijb,rjb->ira', a, b)

How would I do that?
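For reference, the subscripts 'aijb,rjb->ira' mean output[i, r, a] = sum over j and b of a[a, i, j, b] * b[r, j, b]. The explicit-loop sketch below spells that contraction out on deliberately small, hypothetical shapes so it runs quickly; it is only meant to make the index bookkeeping concrete, not to replace einsum in practice.

```python
import numpy as np

# Small hypothetical shapes with the same axis layout as the question:
# a has axes (a, i, j, b), b has axes (r, j, b).
A, I, J, B, R = 4, 3, 2, 5, 6
rng = np.random.default_rng(0)
a = rng.random((A, I, J, B))
b = rng.random((R, J, B))

# output[i, r, a] = sum_j sum_b a[a, i, j, b] * b[r, j, b]
loop_out = np.zeros((I, R, A))
for ai in range(A):
    for ii in range(I):
        for ri in range(R):
            # a[ai, ii] and b[ri] both have shape (J, B); multiply and sum over j and b
            loop_out[ii, ri, ai] = np.sum(a[ai, ii] * b[ri])

einsum_out = np.einsum('aijb,rjb->ira', a, b)
print(np.allclose(loop_out, einsum_out))  # True
```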


Solution

  • One option is to align both arrays to a common shape, multiply them with broadcasting, then sum over the contracted axes (j and b) and reorder the remaining axes into (i, r, a):

    output2 = (b[None, None]*a[:,:,None]).sum(axis=(-1, -2)).transpose((1, 2, 0))
    
    # assert np.allclose(output, output2)
    

    This is much less efficient than einsum, though, because it materializes a large intermediate of shape (128, 16, 256, 8, 32) (about 134 million float64 values, or 1 GiB) before summing:

    # np.einsum('aijb,rjb->ira', a, b)
    68.9 ms ± 23.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    # (b[None, None]*a[:,:,None]).sum(axis=(-1, -2)).transpose((1, 2, 0))
    4.66 s ± 1.65 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    Shapes of the two broadcast operands, with the einsum axis labels above each:

    # b[None, None].shape
    #a  i    r  j   b
    (1, 1, 256, 8, 32)
    
    # a[:,:,None].shape
    #  a   i  r  j   b
    (128, 16, 1, 8, 32)
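The shape comments in the answer can be checked without allocating anything using np.broadcast_shapes (available in NumPy 1.20 and later). The same arithmetic shows why the broadcast approach is expensive: the intermediate product holds 128 * 16 * 256 * 8 * 32 float64 values. This is a small sketch that only works with the shapes quoted above.

```python
import numpy as np

# Shapes of the two broadcast operands from the answer.
b_shape = (1, 1, 256, 8, 32)   # b[None, None]
a_shape = (128, 16, 1, 8, 32)  # a[:, :, None]

# Compute the broadcast result shape without creating any arrays.
prod_shape = np.broadcast_shapes(a_shape, b_shape)
print(prod_shape)  # (128, 16, 256, 8, 32)

# Memory needed for the intermediate product in float64 (8 bytes per element).
n_elements = int(np.prod(prod_shape))
n_bytes = n_elements * np.dtype(np.float64).itemsize
print(n_elements, n_bytes / 2**30)  # 134217728 1.0  (elements, GiB)
```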
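If the goal is to avoid both einsum and the large intermediate, one option not covered by the answer is np.tensordot, which contracts the listed axes through a single matrix product instead of building the full broadcast product. The sketch contracts axes j and b of a (positions 2 and 3) with axes j and b of b (positions 1 and 2), which leaves axes (a, i, r), and then transposes to (i, r, a). An equivalent reshape-plus-matmul version is included as well; no timings are claimed for either.

```python
import numpy as np

a = np.random.rand(128, 16, 8, 32)
b = np.random.rand(256, 8, 32)
output = np.einsum('aijb,rjb->ira', a, b)

# Contract a's axes (j, b) = (2, 3) with b's axes (j, b) = (1, 2).
# The result has axes (a, i, r); reorder them to (i, r, a).
out_td = np.tensordot(a, b, axes=([2, 3], [1, 2])).transpose(1, 2, 0)

# The same contraction as a matrix product: flatten (j, b) into a single axis
# of length 8 * 32 = 256 on both operands, then multiply by b's transpose.
a_flat = a.reshape(128, 16, 8 * 32)   # axes (a, i, j*b)
b_flat = b.reshape(256, 8 * 32)       # axes (r, j*b)
out_mm = (a_flat @ b_flat.T).transpose(1, 2, 0)  # (a, i, r) -> (i, r, a)

print(out_td.shape)  # (16, 256, 128)
print(np.allclose(output, out_td), np.allclose(output, out_mm))  # True True
```

Both variants rely on (j, b) being flattened in the same C order for a and b, which holds here because both arrays are freshly created and contiguous.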