
Understanding tensordot


After learning how to use einsum, I am now trying to understand how np.tensordot works.

However, I am a bit lost, especially regarding the various possibilities for the axes parameter.

To understand it, as I have never practiced tensor calculus, I use the following example:

A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))

In this case, what are the different possible np.tensordot computations, and how would you compute them manually?


Solution

  • The idea behind tensordot is pretty simple - we feed in the arrays and the respective axes along which the sum-reductions are intended. The axes that take part in the sum-reduction are removed from the output, and all of the remaining axes from the input arrays are spread out as separate axes in the output, keeping the order in which the input arrays were fed in.

    Let's look at a few sample cases with one and two axes of sum-reduction, and also swap the inputs to see how the order is kept in the output.

    I. One axis of sum-reduction

    Inputs :

    In [7]: A = np.random.randint(2, size=(2, 6, 5))
       ...: B = np.random.randint(2, size=(3, 2, 4))
       ...: 
    

    Case #1:

    In [9]: np.tensordot(A, B, axes=((0,),(1,))).shape
    Out[9]: (6, 5, 3, 4)
    
    A : (2, 6, 5) -> reduction of axis=0
    B : (3, 2, 4) -> reduction of axis=1
    
    Output : `(2, 6, 5)`, `(3, 2, 4)` ===(2 gone)==> `(6,5)` + `(3,4)` => `(6,5,3,4)`
    

    Case #2 (same as case #1 but the inputs are fed swapped):

    In [8]: np.tensordot(B, A, axes=((1,),(0,))).shape
    Out[8]: (3, 4, 6, 5)
    
    B : (3, 2, 4) -> reduction of axis=1
    A : (2, 6, 5) -> reduction of axis=0
    
    Output : `(3, 2, 4)`, `(2, 6, 5)` ===(2 gone)==> `(3,4)` + `(6,5)` => `(3,4,6,5)`.
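
    As a sanity check (my own addition, not part of the original answer), the two one-axis cases above can be cross-checked with einsum, pairing A's axis 0 against B's axis 1 (both of length 2); swapping the inputs only reorders the output axes:

    ```python
    import numpy as np

    A = np.random.randint(2, size=(2, 6, 5))
    B = np.random.randint(2, size=(3, 2, 4))

    case1 = np.tensordot(A, B, axes=((0,), (1,)))   # shape (6, 5, 3, 4)
    case2 = np.tensordot(B, A, axes=((1,), (0,)))   # shape (3, 4, 6, 5)

    # Same contraction spelled out with einsum: i is the contracted index.
    assert np.array_equal(case1, np.einsum('ijk,lim->jklm', A, B))

    # Feeding the inputs swapped just permutes the output axes.
    assert np.array_equal(case2, case1.transpose(2, 3, 0, 1))
    ```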
    

    II. Two axes of sum-reduction

    Inputs :

    In [11]: A = np.random.randint(2, size=(2, 3, 5))
        ...: B = np.random.randint(2, size=(3, 2, 4))
        ...: 
    

    Case #1:

    In [12]: np.tensordot(A, B, axes=((0,1),(1,0))).shape
    Out[12]: (5, 4)
    
    A : (2, 3, 5) -> reduction of axis=(0,1)
    B : (3, 2, 4) -> reduction of axis=(1,0)
    
    Output : `(2, 3, 5)`, `(3, 2, 4)` ===(2,3 gone)==> `(5)` + `(4)` => `(5,4)`
    

    Case #2:

    In [14]: np.tensordot(B, A, axes=((1,0),(0,1))).shape
    Out[14]: (4, 5)
    
    B : (3, 2, 4) -> reduction of axis=(1,0)
    A : (2, 3, 5) -> reduction of axis=(0,1)
    
    Output : `(3, 2, 4)`, `(2, 3, 5)` ===(2,3 gone)==> `(4)` + `(5)` => `(4,5)`
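
    The two-axis reduction can be cross-checked the same way (again my own addition): A's axes (0, 1) are paired with B's axes (1, 0), which in einsum notation are the shared indices i and j:

    ```python
    import numpy as np

    A = np.random.randint(2, size=(2, 3, 5))
    B = np.random.randint(2, size=(3, 2, 4))

    out = np.tensordot(A, B, axes=((0, 1), (1, 0)))  # shape (5, 4)

    # i and j are summed over; k and l survive, in input order.
    ref = np.einsum('ijk,jil->kl', A, B)

    assert out.shape == (5, 4)
    assert np.array_equal(out, ref)
    ```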
    

    We can extend this scheme to any number of axes of sum-reduction.
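
    To address the "how would you compute it manually" part of the question, here is a sketch (mine, not from the answer above) that reproduces the two-axis case with explicit Python loops - for each surviving index pair (k, l), sum A[i, j, k] * B[j, i, l] over the contracted indices i and j:

    ```python
    import numpy as np

    A = np.random.randint(2, size=(2, 3, 5))
    B = np.random.randint(2, size=(3, 2, 4))

    # Manual version of np.tensordot(A, B, axes=((0, 1), (1, 0))).
    manual = np.zeros((5, 4), dtype=A.dtype)
    for k in range(5):          # surviving axis of A
        for l in range(4):      # surviving axis of B
            for i in range(2):  # A axis 0 paired with B axis 1
                for j in range(3):  # A axis 1 paired with B axis 0
                    manual[k, l] += A[i, j, k] * B[j, i, l]

    assert np.array_equal(manual, np.tensordot(A, B, axes=((0, 1), (1, 0))))
    ```

    This is of course far slower than tensordot itself, but it makes the sum-reduction explicit.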