Tags: pytorch, tensor, einsum, matmul

PyTorch - matrix multiplication of corresponding slices of 2 tensors


Suppose there are 2 tensors of the following sizes:

A = [N x L x T]

B = [N x T x K]

I would like to matrix-multiply corresponding slices of the 2 tensors, like this:

matmul_slice = A[0,:,:] @ B[0,:,:] = [L x T] @ [T x K] = [L x K]

I would like to do this for all N slices along dimension 0, so that I end up with a final tensor of size [N, L, K].
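To pin down the intended result, here is a minimal sketch of the slow loop version described above. The sizes (N=4, L=3, T=5, K=2) are arbitrary values chosen only for illustration; the loop multiplies slice i of A with slice i of B and stacks the N results along dimension 0.

```python
import torch

torch.manual_seed(0)
N, L, T, K = 4, 3, 5, 2  # illustrative sizes only
A = torch.randn(N, L, T)  # [N x L x T]
B = torch.randn(N, T, K)  # [N x T x K]

# Slow reference: one [L x T] @ [T x K] product per index i, stacked along dim 0.
C_loop = torch.stack([A[i] @ B[i] for i in range(N)], dim=0)

print(C_loop.shape)  # torch.Size([4, 3, 2]), i.e. [N, L, K]
```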

I do not want to loop over N because that slows down the computation. I have been playing around with torch.matmul and einsum, but I cannot get the correct answer.

How can I achieve this in a compact way?


Solution

  • torch.bmm is what you need, although torch.matmul should give the same result in your case. I think you should recheck your computation.
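A minimal sketch comparing the batched options against the loop, assuming the same illustrative sizes as above. torch.bmm expects two 3-D tensors with the same batch size and treats dim 0 as the batch dimension; torch.matmul (and the @ operator) also multiplies batch-wise for 3-D inputs, with broadcasting over the batch dimensions. The einsum subscript string "nlt,ntk->nlk" is one way to write the same contraction: t is summed over and n is kept as the batch index.

```python
import torch

torch.manual_seed(0)
N, L, T, K = 4, 3, 5, 2  # illustrative sizes only
A = torch.randn(N, L, T)
B = torch.randn(N, T, K)

# Batched matrix multiply: dim 0 is the batch dimension.
C_bmm = torch.bmm(A, B)
# torch.matmul on 3-D tensors performs the same batched product.
C_matmul = torch.matmul(A, B)
# einsum: contract over t, keep n as the batch index.
C_einsum = torch.einsum("nlt,ntk->nlk", A, B)

# Loop version, used here only to check the batched results.
C_loop = torch.stack([A[i] @ B[i] for i in range(N)], dim=0)

print(C_bmm.shape)  # torch.Size([4, 3, 2])
print(torch.allclose(C_bmm, C_loop, atol=1e-5),
      torch.allclose(C_matmul, C_loop, atol=1e-5),
      torch.allclose(C_einsum, C_loop, atol=1e-5))
```

If the three comparisons disagree on your own data, the usual culprit is an input whose dimensions are not in the order [N, L, T] and [N, T, K].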