
NumPy Tensordot axes=2


I know there are many questions about tensordot, and I've skimmed some of the 15-page mini-book answers that people I'm sure spent hours making, but I haven't found an explanation of what axes=2 does.

The following made me think that np.tensordot(b,c,axes=2) is the same as np.sum(b * c), just returned as a 0-d array:

import numpy as np
b = np.array([[1,10],[100,1000]])
c = np.array([[2,3],[5,7]])
np.tensordot(b,c,axes=2)
Out: array(7532)
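For 2-D inputs, axes=2 contracts both axes of b against both axes of c, so nothing survives and the result is a 0-d array holding sum(b * c). A quick check of that equivalence, using np.einsum only as an independent cross-check:

```python
import numpy as np

b = np.array([[1, 10], [100, 1000]])
c = np.array([[2, 3], [5, 7]])

# axes=2: sum over the last 2 axes of b and the first 2 axes of c
t = np.tensordot(b, c, axes=2)

# The same contraction written out: elementwise product, then sum everything
s = np.sum(b * c)

# The same contraction in einsum notation: both indices repeated, none kept
e = np.einsum('ij,ij->', b, c)

print(t, t.shape)  # 7532 ()
print(s, e)        # 7532 7532
```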

But then this failed with ValueError: shape-mismatch for sum:

a = np.arange(30).reshape((2,3,5))
np.tensordot(a,a,axes=2)
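Looking at which axes get paired shows why this one fails: with axes=2 the last two axes of the first argument are matched against the first two axes of the second, and for this a those lengths differ. A small sketch that prints the paired shapes and catches the error:

```python
import numpy as np

a = np.arange(30).reshape((2, 3, 5))

# axes=2 pairs the last two axes of the first argument with the
# first two axes of the second argument, position by position
left = a.shape[-2:]   # (3, 5)
right = a.shape[:2]   # (2, 3)
print(left, right)

failed = False
try:
    np.tensordot(a, a, axes=2)
except ValueError as err:
    # Raised because the paired lengths differ (3 vs 2, 5 vs 3)
    failed = True
    print("ValueError:", err)
```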

If anyone can provide a short, concise explanation of np.tensordot(x,y,axes=2), and only axes=2, then I would gladly accept it.


Solution

    In [70]: a = np.arange(24).reshape(2,3,4)
    In [71]: np.tensordot(a,a,axes=2)
    Traceback (most recent call last):
      File "<ipython-input-71-dbe04e46db70>", line 1, in <module>
        np.tensordot(a,a,axes=2)
      File "<__array_function__ internals>", line 5, in tensordot
      File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
        raise ValueError("shape-mismatch for sum")
    ValueError: shape-mismatch for sum
    

    In my previous post I deduced that axes=2 translates to axes=([-2,-1],[0,1]), i.e. the last two axes of the first array paired with the first two axes of the second:

    How does numpy.tensordot function works step-by-step?
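That equivalence can be checked on a pair of arrays whose shapes do line up. The arrays x and y below are hypothetical examples chosen so that x's last two axes (3, 4) match y's first two axes (3, 4):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
y = np.arange(60).reshape(3, 4, 5)

# Integer form: last 2 axes of x with first 2 axes of y
r_int = np.tensordot(x, y, axes=2)

# Explicit form: the same axis pairs spelled out
r_tup = np.tensordot(x, y, axes=([-2, -1], [0, 1]))

print(r_int.shape)                   # (2, 5)
print(np.array_equal(r_int, r_tup))  # True
```

The surviving axes are x's leading (2,) followed by y's trailing (5,), which is why the result has shape (2, 5).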

    In [72]: np.tensordot(a,a,axes=([-2,-1],[0,1]))
    Traceback (most recent call last):
      File "<ipython-input-72-efdbfe6ff0d3>", line 1, in <module>
        np.tensordot(a,a,axes=([-2,-1],[0,1]))
      File "<__array_function__ internals>", line 5, in tensordot
      File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
        raise ValueError("shape-mismatch for sum")
    ValueError: shape-mismatch for sum
    

    So that's trying to do a double-axis reduction on the last 2 dimensions of the first a, shape (3,4), and the first 2 dimensions of the second a, shape (2,3). With this a those sizes don't match, hence the error. Evidently this axes value was intended for 2d arrays, without much thought given to 3d ones. It is not a 3-axis reduction.
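The integer rule can be sketched in a few lines: an integer N always pairs the last N axes of the first argument with the first N axes of the second, regardless of how many dimensions the arrays have, and the contraction then reduces to one matrix product after flattening the paired axes. The function name tensordot_int is hypothetical; this is an illustration of the rule, not NumPy's source.

```python
import numpy as np
from math import prod

def tensordot_int(x, y, n):
    """Hypothetical sketch of np.tensordot(x, y, axes=n) for an integer n."""
    # The last n axes of x must match the first n axes of y exactly
    if x.shape[x.ndim - n:] != y.shape[:n]:
        raise ValueError("shape-mismatch for sum")
    kept_x = x.shape[:x.ndim - n]  # axes of x that survive
    kept_y = y.shape[n:]           # axes of y that survive
    k = prod(y.shape[:n])          # size of the flattened contracted block
    # Flatten the contracted axes into one, do an ordinary matrix
    # product, then restore the surviving axes
    out = x.reshape(-1, k) @ y.reshape(k, -1)
    return out.reshape(kept_x + kept_y)

a = np.arange(24).reshape(2, 3, 4)
y = np.arange(60).reshape(3, 4, 5)
print(np.array_equal(tensordot_int(a, y, 2), np.tensordot(a, y, axes=2)))  # True
```

Called as tensordot_int(a, a, 2), the same check compares (3, 4) with (2, 3) and raises, just as np.tensordot does.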

    These single-digit axes values are something that some developer thought would be convenient, but that does not mean they were rigorously thought out or tested.

    The tuple form of axes gives you more control, because you name exactly which axes are paired:

    In [74]: np.tensordot(a,a,axes=[(0,1,2),(0,1,2)])
    Out[74]: array(4324)
    In [75]: np.tensordot(a,a,axes=[(0,1),(0,1)])
    Out[75]: 
    array([[ 880,  940, 1000, 1060],
           [ 940, 1006, 1072, 1138],
           [1000, 1072, 1144, 1216],
           [1060, 1138, 1216, 1294]])
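Both tuple-axes results can be cross-checked with plain NumPy operations: pairing every axis with itself is a full sum of a * a, and contracting the first two axes is the same as flattening them into rows and computing A.T @ A. The einsum strings are just another way of writing the same axis pairings:

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)

# axes=[(0,1,2),(0,1,2)]: every axis paired with itself -> full reduction
full = np.tensordot(a, a, axes=[(0, 1, 2), (0, 1, 2)])
print(full, np.sum(a * a), np.einsum('ijk,ijk->', a, a))  # 4324 4324 4324

# axes=[(0,1),(0,1)]: first two axes contracted, last axis of each kept
gram = np.tensordot(a, a, axes=[(0, 1), (0, 1)])

# Equivalent: flatten the contracted axes (2*3 = 6 rows), then A.T @ A
flat = a.reshape(6, 4)
print(np.array_equal(gram, flat.T @ flat))                  # True
print(np.array_equal(gram, np.einsum('ijk,ijl->kl', a, a)))  # True
```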