Tags: matrix-multiplication, pytorch, dot-product

Product of PyTorch tensors along arbitrary axes à la NumPy's `tensordot`


NumPy provides the very useful `tensordot` function, which computes the product of two ndarrays along any axes whose sizes match. I'm having a hard time finding anything similar in PyTorch: `mm` works only with 2D tensors, and `matmul` has some undesirable broadcasting properties.

Am I missing something? Am I really meant to reshape the arrays to mimic the products I want using mm?
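
For reference, the reshape-based workaround the question alludes to might look like the sketch below. The tensor shapes and the choice of contracted axes are illustrative assumptions, not something the question specifies: the contracted axes of `a` are moved to the end and flattened, the contracted axes of `b` are flattened at the front, and `mm` does the rest.

```python
import torch

# Illustrative shapes (assumed): contract a's axes (1, 0) with b's axes (0, 1).
a = torch.arange(36, dtype=torch.float64).reshape(3, 4, 3)
b = torch.arange(24, dtype=torch.float64).reshape(4, 3, 2)

# Move the contracted axes of `a` to the end, in the order they pair with b's
# axes, then flatten them into one axis of length 4 * 3 = 12.
a_mat = a.permute(2, 1, 0).reshape(3, 12)

# b's contracted axes are already leading and in matching order, so a plain
# reshape flattens them into the same length-12 axis.
b_mat = b.reshape(12, 2)

# An ordinary 2D matrix product now performs the contraction.
c_mm = torch.mm(a_mat, b_mat)
print(c_mm.shape)
```

This works, but the permutation and the flattened sizes have to be worked out by hand for each contraction, which is the bookkeeping `tensordot` normally hides.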


Solution

  • As mentioned by @McLawrence, this feature is currently being discussed (issue thread).

    In the meantime, you could consider torch.einsum(), e.g.:

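    import numpy as np
    import torch
    
    # Reference result: contract a's axes (1, 0) with b's axes (0, 1) in NumPy.
    a = np.arange(36.).reshape(3, 4, 3)
    b = np.arange(24.).reshape(4, 3, 2)
    c = np.tensordot(a, b, axes=([1, 0], [0, 1]))
    print(c)
    # [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]
    
    # Same contraction in PyTorch: the indices i and j appear in both operands
    # but not in the output, so they are summed over.
    a = torch.from_numpy(a)
    b = torch.from_numpy(b)
    c = torch.einsum("ijk,jil->kl", a, b)
    print(c)
    # tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)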
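
Newer PyTorch releases also provide `torch.tensordot`, which takes the contracted axes through a `dims` argument much like NumPy's `axes`. The sketch below assumes an installed version that includes it, and simply checks that it agrees with the `einsum` form above on the same inputs.

```python
import torch

a = torch.arange(36, dtype=torch.float64).reshape(3, 4, 3)
b = torch.arange(24, dtype=torch.float64).reshape(4, 3, 2)

# dims lists the axes of `a` and the axes of `b` to contract, paired in order:
# a's axis 1 with b's axis 0, and a's axis 0 with b's axis 1.
c_td = torch.tensordot(a, b, dims=([1, 0], [0, 1]))

# The einsum formulation of the same contraction, for comparison.
c_es = torch.einsum("ijk,jil->kl", a, b)

print(torch.allclose(c_td, c_es))
```

If the installed version lacks `torch.tensordot`, the `einsum` call remains a drop-in alternative for this kind of contraction.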