Tags: python, numpy, pytorch, torch, algebra

Numpy/PyTorch funny tensor product


I've got a 4-dimensional torch tensor parameter defined like this:

nn.parameter.Parameter(data=torch.empty(13, 13, 13, 13), requires_grad=True)
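A quick way to confirm that such a parameter has the intended shape is sketched here, assuming the goal is a 13×13×13×13 weight (the normal initialization is an arbitrary illustrative choice). A tuple passed as the only argument to the legacy torch.Tensor constructor is read as data, not as a shape, so torch.Tensor((13,13,13,13)) yields a 4-element vector. torch.empty with separate sizes allocates the 4-D tensor but leaves its values uninitialized.

```python
import torch
import torch.nn as nn

# A tuple as the single argument is treated as data: a 1-D tensor [13., 13., 13., 13.].
as_data = torch.Tensor((13, 13, 13, 13))
print(as_data.shape)  # torch.Size([4])

# Separate int sizes allocate a 13 x 13 x 13 x 13 tensor (values uninitialized).
A = nn.Parameter(torch.empty(13, 13, 13, 13))  # requires_grad defaults to True
nn.init.normal_(A)  # arbitrary init for illustration; fills A in place
print(A.shape)  # torch.Size([13, 13, 13, 13])
print(A.requires_grad)  # True
```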

and four tensors with dims (batch_size, 13) (or one tensor with dims (batch_size, 4, 13)). I'd like to get a tensor with dims (batch_size,) equal to the formula at the end of this picture [EDIT: I made a mistake in the first picture; I've corrected it]. I've seen the function tensordot in the torch documentation, but I can't manage to make it work by myself.
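The image with the formula is not reproduced here. Judging only from the subscripts of the einsum in the solution ('ijkl,bi,bj,bk,bl->b'), and writing n for the batch index (the b subscript there), the requested quantity is presumably:

```latex
B_n = \sum_{i=1}^{13} \sum_{j=1}^{13} \sum_{k=1}^{13} \sum_{l=1}^{13}
      A_{ijkl}\, a_{ni}\, b_{nj}\, c_{nk}\, d_{nl},
\qquad n = 1, \dots, \mathrm{batch\_size}
```

That is, for each batch element the four 13-vectors are contracted against the four axes of A, leaving one scalar per element.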


Solution

  • Whenever you have a funny tensor product, torch.einsum (or numpy.einsum) is your friend:

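    import torch

    batch_size = 5
    A = torch.rand(13, 13, 13, 13)  # the 4-D parameter (random values for the demo)
    a = torch.rand(batch_size, 13)  # four batched vectors, each (batch_size, 13)
    b = torch.rand(batch_size, 13)
    c = torch.rand(batch_size, 13)
    d = torch.rand(batch_size, 13)
    # Sum over i, j, k, l; only the batch index b is kept in the output.
    B = torch.einsum('ijkl,bi,bj,bk,bl->b', A, a, b, c, d)  # shape: (batch_size,)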
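Since numpy.einsum works the same way, and the question asks about tensordot, here is a NumPy sketch that cross-checks the einsum against two tensordot-based routes. It assumes the stacked (batch_size, 4, 13) input mentioned in the question; the names x, contract_one, B_loop and B_vec are hypothetical, chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 5
A = rng.random((13, 13, 13, 13))
# Hypothetical stacked input of shape (batch_size, 4, 13); moving axis 1 to the
# front and unpacking yields the four (batch_size, 13) operands.
x = rng.random((batch_size, 4, 13))
a, b, c, d = np.moveaxis(x, 1, 0)

# 1) einsum, mirroring the torch.einsum call.
B_einsum = np.einsum('ijkl,bi,bj,bk,bl->b', A, a, b, c, d)

# 2) tensordot per sample: each call sums one leading axis of the remaining tensor.
def contract_one(A, a_n, b_n, c_n, d_n):
    t = np.tensordot(a_n, A, axes=1)  # sum over i -> shape (13, 13, 13)
    t = np.tensordot(b_n, t, axes=1)  # sum over j -> shape (13, 13)
    t = np.tensordot(c_n, t, axes=1)  # sum over k -> shape (13,)
    return np.tensordot(d_n, t, axes=1)  # sum over l -> 0-d array

B_loop = np.array([contract_one(A, a[n], b[n], c[n], d[n]) for n in range(batch_size)])

# 3) one batched tensordot, then broadcasting for the remaining axes, because
# tensordot itself has no notion of a shared batch axis.
T = np.tensordot(a, A, axes=([1], [0]))    # (batch_size, 13, 13, 13)
T = (T * b[:, :, None, None]).sum(axis=1)  # (batch_size, 13, 13)
T = (T * c[:, :, None]).sum(axis=1)        # (batch_size, 13)
B_vec = (T * d).sum(axis=1)                # (batch_size,)

print(B_einsum.shape)  # (5,)
print(np.allclose(B_einsum, B_loop), np.allclose(B_einsum, B_vec))  # True True
```

All three routes agree up to floating-point rounding. Only the first contraction batches naturally with tensordot, since it has no shared batch axis; einsum keeps the batch index free and does the whole contraction in one call.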