Tags: python, numpy, broadcasting

Broadcasting 3d arrays for elementwise multiplication


Good evening,

I need some help understanding advanced broadcasting with multidimensional numpy arrays.

I have:

array A: 50000x2000

array B: 2000x10x10

Implementation with a for loop:

import numpy as np

finalarray = np.zeros((50000, 10, 10))  # preallocate the 50000 x 10 x 10 result
for k in range(50000):
    temp = A[k, :].reshape(2000, 1, 1)
    finalarray[k, :, :] = np.sum(B * temp, axis=0)

I want an element-wise multiplication followed by a summation over the axis with 2000 elements, with end product:

finalarray: 50000x10x10

Is it possible to avoid the for loop? Thank you!


Solution

  • For something like this I'd use np.einsum, which makes it pretty easy to write down what you want to happen in terms of the indices involved:

    fast = np.einsum('ij,jkl->ikl', A, B)
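    # i.e. out[i,k,l] = sum over j of A[i,j] * B[j,k,l]: the shared j axis (length 2000) is summed away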
    

    which gives me the same result (dropping 50000->500 so the loopy one finishes quickly):

    A = np.random.random((500, 2000))
    B = np.random.random((2000, 10, 10))
    finalarray = np.zeros((500, 10, 10))
    for k in range(500):
        temp = A[k, :].reshape(2000, 1, 1)
        finalarray[k, :, :] = np.sum(B * temp, axis=0)
    
    fast = np.einsum('ij,jkl->ikl', A, B)
    

    gives me

    In [81]: (finalarray == fast).all()
    Out[81]: True
    

    and reasonable performance even in the 50000 case:

    In [88]: %time fast = np.einsum('ij,jkl->ikl', A, B)
    Wall time: 4.93 s
    
    In [89]: fast.shape
    Out[89]: (50000, 10, 10)
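
    For reference, the subscripts 'ij,jkl->ikl' just spell out the question's broadcast-and-sum pattern: multiply A and B along the shared j axis of length 2000, then sum that axis away. Below is only an illustrative sketch of that equivalence using plain broadcasting (the variable names are mine); it materializes a (500, 2000, 10, 10) intermediate of roughly 800 MB, which is exactly why this approach doesn't scale to the full 50000 rows, while einsum never forms that intermediate.

    import numpy as np

    A_small = np.random.random((500, 2000))
    B_demo = np.random.random((2000, 10, 10))

    # A_small[:, :, None, None] has shape (500, 2000, 1, 1); multiplying by B_demo
    # broadcasts to (500, 2000, 10, 10), and axis 1 (the 2000-long axis) is summed out.
    by_broadcast = (A_small[:, :, None, None] * B_demo).sum(axis=1)  # shape (500, 10, 10)

    assert np.allclose(by_broadcast, np.einsum('ij,jkl->ikl', A_small, B_demo))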
    

    Alternatively, in this case, you could use tensordot:

    faster = np.tensordot(A, B, axes=1)
    

    which turns out to be roughly ten times faster here (at the cost of being less general):

    In [29]: A = np.random.random((50000, 2000))
    
    In [30]: B = np.random.random((2000, 10, 10))
    
    In [31]: %time fast = np.einsum('ij,jkl->ikl', A, B)
    Wall time: 5.08 s
    
    In [32]: %time faster = np.tensordot(A, B, axes=1)
    Wall time: 504 ms
    
    In [33]: np.allclose(fast, faster)
    Out[33]: True
    

    I had to use allclose here because the values wind up being very slightly different (floating-point rounding: the two routines accumulate the 2000-term sums in different orders):

    In [34]: abs(fast - faster).max()
    Out[34]: 2.7853275241795927e-12
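
    As a side note on what axes=1 means: tensordot contracts the last axis of A against the first axis of B. Purely as an illustrative sketch (the variable names are mine), that contraction boils down to a reshaped matrix product, which is essentially what tensordot does internally:

    import numpy as np

    A = np.random.random((500, 2000))
    B = np.random.random((2000, 10, 10))

    # flatten B's trailing 10x10 block, do an ordinary matmul, then restore the shape
    flat = A @ B.reshape(B.shape[0], -1)                 # (500, 2000) @ (2000, 100) -> (500, 100)
    by_matmul = flat.reshape(A.shape[0], *B.shape[1:])   # -> (500, 10, 10)

    assert np.allclose(by_matmul, np.tensordot(A, B, axes=1))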