
Dot product between two 3D tensors

I have two 3D tensors: tensor A, which has shape [B,N,S], and tensor B, which also has shape [B,N,S]. I want to get a third tensor C with shape [B,B,N], where the element C[i,j,k] = dot(A[i,k,:], B[j,k,:]). I also want to achieve this in a vectorized way.

Some further info: the two tensors A and B have shape [Batch_size, Num_vectors, Vector_size]. The tensor C is supposed to hold the dot product between each element of the batch from A and each element of the batch from B, computed separately for each of the Num_vectors vectors.

Hope that it is clear enough; looking forward to your answers!
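To pin down the definition, here is a slow, loop-based reference. It is a sketch assuming NumPy arrays; pairwise_dot_loop is a hypothetical name, and the function is only meant as ground truth for checking vectorized versions, not for real use.

```python
import numpy as np

def pairwise_dot_loop(A, B):
    """Reference: C[i, j, k] = dot(A[i, k, :], B[j, k, :]) via explicit loops."""
    if A.shape[1:] != B.shape[1:]:
        raise ValueError("A and B need the same [Num_vectors, Vector_size]")
    n_a, n_vec = A.shape[0], A.shape[1]
    n_b = B.shape[0]
    C = np.empty((n_a, n_b, n_vec), dtype=np.result_type(A, B))
    for i in range(n_a):          # batch element from A
        for j in range(n_b):      # batch element from B
            for k in range(n_vec):  # same vector index k in both
                C[i, j, k] = np.dot(A[i, k, :], B[j, k, :])
    return C

rng = np.random.default_rng(0)
A = rng.random((4, 3, 5))   # [Batch_size, Num_vectors, Vector_size]
B = rng.random((4, 3, 5))
C = pairwise_dot_loop(A, B)
print(C.shape)   # (4, 4, 3)
```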


  • In [331]: A=np.random.rand(100,200,300)                                                              
    In [332]: B=A

    The suggested einsum, working directly from the definition

    C[i,j,k] = dot(A[i,k,:], B[j,k,:])

    is:

    In [333]: np.einsum( 'ikm, jkm-> ijk', A, B).shape                                                   
    Out[333]: (100, 100, 200)
    In [334]: timeit np.einsum( 'ikm, jkm-> ijk', A, B).shape                                            
    800 ms ± 25.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

    matmul does a dot product over the last 2 dimensions and treats the leading one(s) as batch dimensions. In your case 'k' is the batch dimension, and 'm' is the summed dimension, which matmul requires to be the last axis of A and the second-to-last axis of B. So, rewriting ikm,jkm->ijk to fit that rule, and transposing A and B accordingly:

    In [335]: np.einsum('kim,kmj->kij', A.transpose(1,0,2), B.transpose(1,2,0)).shape                     
    Out[335]: (200, 100, 100)
    In [336]: timeit np.einsum('kim,kmj->kij',A.transpose(1,0,2), B.transpose(1,2,0)).shape              
    774 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

    Not much difference in performance. But now use matmul:

    In [337]: (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape                             
    Out[337]: (100, 100, 200)
    In [338]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape                      
    64.4 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

    and verify that the values match (though more often than not, if the shapes match, the values do too).

    In [339]: np.allclose((A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0), np.einsum( 'ikm, jkm-> ijk', A, B))
    Out[339]: True
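The transpose/matmul/transpose pipeline can be wrapped in a function. One caveat about the check above: with B=A, C[i,j,k] equals C[j,i,k], so a version that accidentally swapped i and j would still pass. This sketch (pairwise_dot_matmul and the array sizes are hypothetical) checks against einsum with independent random A and B instead.

```python
import numpy as np

def pairwise_dot_matmul(A, B):
    """C[i, j, k] = dot(A[i, k, :], B[j, k, :]) for A, B of shape (Bsz, N, S)."""
    # matmul treats leading axes as a batch and contracts the last axis of the
    # left operand with the second-to-last axis of the right operand.
    A_k = A.transpose(1, 0, 2)     # (N, Bsz, S): 'kim', summed axis m last
    B_k = B.transpose(1, 2, 0)     # (N, S, Bsz): 'kmj', summed axis m second-to-last
    C_k = A_k @ B_k                # (N, Bsz, Bsz): C_k[k] = A_k[k] @ B_k[k]
    return C_k.transpose(1, 2, 0)  # back to (Bsz, Bsz, N), i.e. 'ijk'

rng = np.random.default_rng(3)
A = rng.random((5, 7, 9))
B = rng.random((5, 7, 9))   # independent of A, so C is not symmetric in i and j

C = pairwise_dot_matmul(A, B)
C_ref = np.einsum('ikm,jkm->ijk', A, B)
print(C.shape)                                       # (5, 5, 7)
print(np.allclose(C, C_ref))                         # True
print(np.allclose(C_ref, C_ref.transpose(1, 0, 2)))  # False: i and j are not interchangeable
```

The returned array is a transposed view of C_k, so it is not C-contiguous; wrap it in np.ascontiguousarray if later code needs a C-ordered copy.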

    I won't try to measure memory usage, but the time improvement suggests it is better there too.
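If you do want a rough memory comparison, the standard tracemalloc module can report the peak traced allocation during each call. This is a sketch with hypothetical helper names and sizes; NumPy registers its array allocations with tracemalloc, but buffers a BLAS library allocates internally are probably not seen, so treat the numbers as a lower bound.

```python
import tracemalloc
import numpy as np

def peak_traced_bytes(fn, *args):
    """Peak bytes seen by tracemalloc while fn(*args) runs (result included)."""
    tracemalloc.start()
    try:
        fn(*args)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak

def via_einsum(A, B):
    return np.einsum('ikm,jkm->ijk', A, B)

def via_matmul(A, B):
    return (A.transpose(1, 0, 2) @ B.transpose(1, 2, 0)).transpose(1, 2, 0)

rng = np.random.default_rng(6)
A = rng.random((20, 40, 60))
B = rng.random((20, 40, 60))

print("einsum peak bytes:", peak_traced_bytes(via_einsum, A, B))
print("matmul peak bytes:", peak_traced_bytes(via_matmul, A, B))
```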

    In some cases einsum is optimized to use matmul. Here that doesn't seem to be the case, though we could play with its parameters (e.g. optimize). I'm a little surprised that matmul does so much better.
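A sketch for experimenting with einsum's optimize parameter (sizes are hypothetical): it compares the default call with optimize=True and prints the plan reported by np.einsum_path. My reading of NumPy's planner is that it only hands a pairwise contraction to BLAS (via tensordot) when no index is shared by both inputs and the output; here k is, which would fit the lack of speedup, but check the printed report rather than take that on trust.

```python
import numpy as np

rng = np.random.default_rng(4)
A = rng.random((10, 20, 30))
B = rng.random((10, 20, 30))

# Default einsum (optimize=False) evaluates the contraction in its own C loops.
C_plain = np.einsum('ikm,jkm->ijk', A, B)
# optimize=True lets NumPy plan the contraction and, where the pattern allows,
# dispatch pieces to tensordot/BLAS; it is not guaranteed to do so here.
C_opt = np.einsum('ikm,jkm->ijk', A, B, optimize=True)
print(np.allclose(C_plain, C_opt))   # True

# einsum_path returns the contraction order plus a human-readable report.
path, report = np.einsum_path('ikm,jkm->ijk', A, B, optimize='greedy')
print(path)
print(report)
```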


    I vaguely recall another SO question about matmul taking a shortcut when both operands are the same array, as in A@A. I used B=A in these tests, so here is a comparison with an independent B2:

    In [350]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape                      
    60.6 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    In [352]: B2=np.random.rand(100,200,300)                                                             
    In [353]: timeit (A.transpose(1,0,2)@B2.transpose(1,2,0)).transpose(1,2,0).shape                     
    97.4 ms ± 164 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

    But that only made a modest difference.
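The same-array vs distinct-array comparison can be reproduced outside IPython with the standard timeit module. This is a sketch with smaller, hypothetical sizes so it runs quickly; timings depend on the machine and the BLAS build, so no particular result is implied.

```python
import timeit
import numpy as np

def matmul_pairwise(A, B):
    # Same transpose/matmul/transpose pipeline as in the session above.
    return (A.transpose(1, 0, 2) @ B.transpose(1, 2, 0)).transpose(1, 2, 0)

def best_time(A, B, number=3, repeat=5):
    """Best-of-`repeat` seconds per call of matmul_pairwise(A, B)."""
    timer = timeit.Timer(lambda: matmul_pairwise(A, B))
    return min(timer.repeat(repeat=repeat, number=number)) / number

rng = np.random.default_rng(5)
# Smaller than the (100, 200, 300) arrays above so the sketch runs quickly.
A = rng.random((50, 100, 150))
B2 = rng.random((50, 100, 150))

t_same = best_time(A, A)     # both operands are literally the same array
t_diff = best_time(A, B2)    # operands hold different data
print(f"A with A : {t_same * 1e3:.1f} ms")
print(f"A with B2: {t_diff * 1e3:.1f} ms")
```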

    In [356]: np.__version__                                                                             
    Out[356]: '1.16.4'

    My BLAS etc. is the standard Linux setup, nothing special.