Tags: python, numpy, linear-algebra, tensordot

Most efficient way to perform large dot/tensor dot products while only keeping diagonal entries


I'm trying to figure out a way to use numpy to perform the following algebra in the most time-efficient way possible:

Given a 3D matrix/tensor A with shape (n, m, p) and a 2D matrix/tensor B with shape (n, p), I want to calculate C_ij = sum_over_k (A_ijk * B_ik), where the resulting matrix C has shape (n, m).
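To make the formula concrete, here is a small sketch that spells it out twice: once with explicit Python loops over i, j and k, and once with broadcasting, where B[:, None, :] gets a length-1 middle axis so it lines up with A's (n, m, p) shape before the product is summed over k. The sizes and random data are arbitrary assumptions. The broadcasting version materializes an (n, m, p) temporary, so it is mostly useful as a readable reference.

```python
import numpy as np

# Arbitrary small sizes, chosen only for illustration
n, m, p = 5, 3, 4
rng = np.random.default_rng(0)
A = rng.standard_normal((n, m, p))
B = rng.standard_normal((n, p))

# C_ij = sum_k A_ijk * B_ik as an elementwise product plus a sum over the last axis (k)
C_broadcast = (A * B[:, None, :]).sum(axis=-1)

# The same formula with explicit loops, index by index
C_loops = np.zeros((n, m))
for i in range(n):
    for j in range(m):
        for k in range(p):
            C_loops[i, j] += A[i, j, k] * B[i, k]

print(C_broadcast.shape)                  # (5, 3)
print(np.allclose(C_broadcast, C_loops))  # True
```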

I've tried two approaches. One loops over the first dimension and computes a regular dot product each time. The other uses np.tensordot(A, B.T, axes=1) to compute a result with shape (n, m, n) and then takes the diagonal along the 1st and 3rd dimensions. Both methods are shown below.

First method:

import numpy as np

n, m, p = A.shape
C = np.zeros((n, m))
for i in range(n):
    C[i] = np.dot(A[i], B[i])  # (m, p) dot (p,) -> (m,)
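One side effect of this loop: np.zeros((n, m)) is float64 by default, so integer inputs such as the example arrays in the answer below come out as floats. A minimal sketch, assuming you want to keep the input dtype, preallocates with np.result_type instead; the variable names C_float and C_keep are made up for this example.

```python
import numpy as np

# The integer example arrays used in the answer
A = np.arange(2*3*4).reshape(2, 3, 4)
B = np.arange(2*4).reshape(2, 4)
n, m, p = A.shape

C_float = np.zeros((n, m))                             # float64 by default
C_keep = np.empty((n, m), dtype=np.result_type(A, B))  # dtype that the dot products produce
for i in range(n):
    C_float[i] = np.dot(A[i], B[i])
    C_keep[i] = np.dot(A[i], B[i])

print(C_float.dtype, C_keep.dtype)
```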

Second method:

# tensordot gives shape (n, m, n); np.diagonal keeps the entries where axis 0 == axis 2
C = np.diagonal(np.tensordot(A, B.T, axes=1), axis1=0, axis2=2).T
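To show where the second method's cost comes from, this sketch wraps both methods as functions (the names loop_method and tensordot_method are hypothetical) and prints the intermediate shape. The tensordot result has n*m*n entries, but np.diagonal keeps only the n*m where the first and third indices match, so only a fraction 1/n of the computed entries is used. The sizes are arbitrary assumptions.

```python
import numpy as np


def loop_method(A, B):
    """First method: one (m, p) dot (p,) product per index i."""
    n, m, _ = A.shape
    C = np.zeros((n, m))
    for i in range(n):
        C[i] = np.dot(A[i], B[i])
    return C


def tensordot_method(A, B):
    """Second method: build the full (n, m, n) product, keep its diagonal."""
    full = np.tensordot(A, B.T, axes=1)         # shape (n, m, n)
    diag = np.diagonal(full, axis1=0, axis2=2)  # shape (m, n): diagonal axis goes last
    return diag.T                               # shape (n, m)


n, m, p = 6, 3, 4
rng = np.random.default_rng(1)
A = rng.standard_normal((n, m, p))
B = rng.standard_normal((n, p))

full_shape = np.tensordot(A, B.T, axes=1).shape
print(full_shape)  # (6, 3, 6): n*m*n entries, of which only n*m are kept
print(np.allclose(loop_method(A, B), tensordot_method(A, B)))  # True
```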

However, because n is very large, the loop over n in the first method takes a lot of time. The second method computes many unnecessary entries to build the huge (n, m, n) matrix, so it is also too slow. Is there a more efficient way to do this?


Solution

  • Define 2 arrays:

    In [168]: A = np.arange(2*3*4).reshape(2,3,4); B = np.arange(2*4).reshape(2,4)
    

    Your iterative approach:

    In [169]: [np.dot(a,b) for a,b in zip(A,B)]
    Out[169]: [array([14, 38, 62]), array([302, 390, 478])]
    

    The einsum expression practically writes itself from your formula C_ij = sum_over_k (A_ijk * B_ik):

    In [170]: np.einsum('ijk,ik->ij', A, B)
    Out[170]: 
    array([[ 14,  38,  62],
           [302, 390, 478]])
    

    The @ operator (matmul) was added to perform batched dot products; here the i dimension is the batch dimension. Since matmul sums over the last axis of A and the second-to-last axis of B, we have to temporarily expand B to shape (2,4,1):

    In [171]: A@B[...,None]
    Out[171]: 
    array([[[ 14],
            [ 38],
            [ 62]],
    
           [[302],
            [390],
            [478]]])
    In [172]: (A@B[...,None])[...,0]
    Out[172]: 
    array([[ 14,  38,  62],
           [302, 390, 478]])
    

    Typically matmul is the fastest of these, since it passes the work to BLAS-like code.
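Relative speed depends on the shapes, the NumPy build and the BLAS library it links against, so it is worth measuring on data shaped like yours. Below is a minimal timing sketch with timeit; the sizes are arbitrary assumptions, and the code only checks that the three approaches agree, it does not predict which one wins. The matmul variant uses np.squeeze(..., axis=-1), which is equivalent to the [...,0] indexing in the session above.

```python
import timeit

import numpy as np

# Hypothetical sizes: a large n with small m and p, as in the question
n, m, p = 20_000, 3, 4
rng = np.random.default_rng(3)
A = rng.standard_normal((n, m, p))
B = rng.standard_normal((n, p))


def with_loop():
    # The question's first method: one small dot product per i, stacked to (n, m)
    return np.stack([a @ b for a, b in zip(A, B)])


def with_einsum():
    return np.einsum('ijk,ik->ij', A, B)


def with_matmul():
    # (n, m, p) @ (n, p, 1) -> (n, m, 1), then drop the trailing length-1 axis
    return np.squeeze(A @ B[..., None], axis=-1)


# Confirm the approaches agree before timing them
reference = with_matmul()
for fn in (with_loop, with_einsum):
    assert np.allclose(fn(), reference)

for fn in (with_loop, with_einsum, with_matmul):
    seconds = timeit.timeit(fn, number=5) / 5
    print(f"{fn.__name__}: {seconds * 1e3:.2f} ms per call")
```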