NumPy provides the very useful tensordot function. It allows you to compute the product of two ndarrays along any axes (whose sizes match). I'm having a hard time finding anything similar in PyTorch. mm works only with 2D arrays, and matmul has some undesirable broadcasting properties.
Am I missing something? Am I really meant to reshape the arrays to mimic the products I want using mm?
As mentioned by @McLawrence, this feature is currently being discussed (issue thread).
In the meantime, you could consider torch.einsum(), e.g.:
import torch
import numpy as np
# Contract axis 1 of a with axis 0 of b, and axis 0 of a with axis 1 of b
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640. 2838.]
#  [ 2772. 2982.]
#  [ 2904. 3126.]]

# Same contraction in PyTorch: the repeated indices i and j are summed over
a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640., 2838.],
#         [ 2772., 2982.],
#         [ 2904., 3126.]], dtype=torch.float64)
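For what it's worth, more recent PyTorch releases (1.0 and later, as far as I know) ship torch.tensordot, which takes the contracted axes through a dims argument much like NumPy's axes. The sketch below assumes such a version is installed. It also shows the manual permute/reshape/mm route the question asks about, purely as an illustration of what the contraction boils down to; the names a2d, b2d, c_tensordot and c_mm are only for this example.

```python
import torch

# Same inputs as in the example above, built directly as float64 tensors.
a = torch.arange(36., dtype=torch.float64).reshape(3, 4, 3)
b = torch.arange(24., dtype=torch.float64).reshape(4, 3, 2)

# Assumes a PyTorch version that provides torch.tensordot.
# dims pairs axis 1 of a with axis 0 of b, and axis 0 of a with axis 1 of b,
# mirroring np.tensordot(a, b, axes=([1, 0], [0, 1])).
c_tensordot = torch.tensordot(a, b, dims=([1, 0], [0, 1]))

# Manual fallback using mm: reorder a so its contracted axes come last and in
# the same order as b's leading axes, flatten them, then do a 2D product.
a2d = a.permute(2, 1, 0).reshape(3, 12)  # rows: k, columns: (j, i) flattened
b2d = b.reshape(12, 2)                   # rows: (j, i) flattened, columns: l
c_mm = torch.mm(a2d, b2d)

print(c_tensordot)
print(torch.allclose(c_tensordot, c_mm))
```

Both routes should agree with the einsum result above. If tensordot is available in your version, it is probably the closest drop-in for the NumPy call, while einsum stays handy when you prefer to name the indices explicitly.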