I've got a 4-dimensional torch tensor parameter defined like this:
nn.parameter.Parameter(data=torch.empty(13, 13, 13, 13), requires_grad=True)  # shape (13, 13, 13, 13), uninitialized until you initialize it
and four tensors of shape (batch_size, 13) (or, equivalently, one tensor of shape (batch_size, 4, 13)).
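A minimal setup sketch (the random initialization and the names x, a, b, c, d are only for illustration): passing the shape as a tuple, as in torch.Tensor((13,13,13,13)), builds a 1-D tensor holding four 13s rather than a (13, 13, 13, 13) one, so the shape is better requested explicitly, e.g. with torch.empty or torch.randn. The (batch_size, 4, 13) form is just the four vectors stacked along dim=1.

```python
import torch
import torch.nn as nn

# A tuple passed to torch.Tensor is read as *data*, not as a shape:
wrong = torch.Tensor((13, 13, 13, 13))
print(wrong)        # tensor([13., 13., 13., 13.])

# Ask for the shape explicitly; randn is only an example initialization.
A = nn.Parameter(torch.randn(13, 13, 13, 13), requires_grad=True)
print(A.shape)      # torch.Size([13, 13, 13, 13])

batch_size = 5
a, b, c, d = (torch.rand(batch_size, 13) for _ in range(4))

# Four (batch_size, 13) tensors -> one (batch_size, 4, 13) tensor.
x = torch.stack([a, b, c, d], dim=1)
print(x.shape)      # torch.Size([5, 4, 13])
```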
I'd like to get a tensor of shape (batch_size,) given by the formula at the end of this picture:
[EDIT: I made a mistake in the first picture; I've corrected it.]
I've seen the tensordot function in the torch documentation, but I can't get it to work myself.
Whenever you have a funny tensor product, torch.einsum (or numpy.einsum) is your friend:
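import torch

batch_size = 5
A = torch.rand(13, 13, 13, 13)  # stands in for the (13, 13, 13, 13) parameter
a = torch.rand(batch_size, 13)
b = torch.rand(batch_size, 13)
c = torch.rand(batch_size, 13)
d = torch.rand(batch_size, 13)
# For each batch index n: B[n] = sum over i, j, k, l of
# A[i,j,k,l] * a[n,i] * b[n,j] * c[n,k] * d[n,l]; B has shape (batch_size,)
B = torch.einsum('ijkl,bi,bj,bk,bl->b', A, a, b, c, d)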
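To sanity-check what the einsum computes, here is a sketch that reproduces it step by step. torch.tensordot handles the first contraction (A against d over the l axis) because A has no batch axis; the remaining contractions must keep the batch index aligned across c, b and a, so they use broadcasting plus sum instead. This is an illustrative alternative, not a required approach. It also shows that einsum works directly on an nn.Parameter, that gradients reach it, and how to feed a single (batch_size, 4, 13) tensor by unbinding it (the names B_steps, B_stacked and x are just for this example).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size = 5
A = nn.Parameter(torch.rand(13, 13, 13, 13))
a, b, c, d = (torch.rand(batch_size, 13) for _ in range(4))

# Reference: the one-line einsum.
B = torch.einsum('ijkl,bi,bj,bk,bl->b', A, a, b, c, d)

# Step 1: contract A's last axis (l) with d's feature axis.
# Result dims: d's remaining dim (batch), then A's remaining dims (i, j, k).
T = torch.tensordot(d, A, dims=([1], [3]))   # (batch, i, j, k)

# Steps 2-4: contract k, j, i with c, b, a while keeping the batch axis aligned.
T = (T * c[:, None, None, :]).sum(dim=3)     # (batch, i, j)
T = (T * b[:, None, :]).sum(dim=2)           # (batch, i)
B_steps = (T * a).sum(dim=1)                 # (batch,)

print(B.shape)                                            # torch.Size([5])
print(torch.allclose(B, B_steps, rtol=1e-4, atol=1e-4))   # True

# Single (batch, 4, 13) input: unbind dim 1 into four (batch, 13) views.
x = torch.stack([a, b, c, d], dim=1)
B_stacked = torch.einsum('ijkl,bi,bj,bk,bl->b', A, *x.unbind(dim=1))

# The einsum is differentiable with respect to the Parameter.
B.sum().backward()
print(A.grad.shape)                                       # torch.Size([13, 13, 13, 13])
```

Summation order differs between the two routes, so the comparison uses a tolerance rather than exact equality.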