I am trying to use torch.bmm to do the following matrix operation:
If matrix is an M * N tensor and batch is an N * B tensor, how can I compute matrix @ batch_i for each batch item batch_i (the i-th column of batch, which gives a result of length M) and stack the B results together, so that the output tensor is M * B?
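A minimal sketch of the operation described above, assuming small example sizes and random data. It computes the per-column products in a loop and compares them with a single ordinary matrix product, which already yields the M * B result:

```python
import torch

M, N, B = 3, 4, 5  # example sizes, assumed for illustration
matrix = torch.randn(M, N)
batch = torch.randn(N, B)

# Loop version: matrix @ batch[:, i] has shape (M,); stack the B results as columns.
looped = torch.stack([matrix @ batch[:, i] for i in range(B)], dim=1)

# Vectorized version: a plain 2-D matrix product does the same thing in one call.
out = matrix @ batch  # shape (M, B)

print(out.shape)  # torch.Size([3, 5])
```

In other words, when each batch item is a single vector stored as a column, no batched multiply is needed.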
There are two questions here:
1. To use torch.bmm, it seems both inputs need to be batched (3-D tensors), but my first input is not.
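One way to still use torch.bmm, sketched here with assumed example sizes, is to give the matrix a batch dimension with expand (which does not copy data) and reshape each column of batch into an N * 1 matrix. torch.matmul, by contrast, broadcasts a 2-D first argument over the batch dimension of a 3-D second argument, so the expand step is not needed there:

```python
import torch

M, N, B = 3, 4, 5  # example sizes, assumed for illustration
matrix = torch.randn(M, N)
batch = torch.randn(N, B)

# torch.bmm expects (b, n, m) and (b, m, p) with the same b.
matrix_b = matrix.unsqueeze(0).expand(B, M, N)  # (B, M, N), a view without copying
batch_b = batch.t().unsqueeze(2)                # (N, B) -> (B, N, 1)

out_b = torch.bmm(matrix_b, batch_b)  # (B, M, 1)
out = out_b.squeeze(2).t()            # (B, M) -> (M, B)

# torch.matmul broadcasts the 2-D matrix against the batch dimension automatically.
out_mm = torch.matmul(matrix, batch_b)  # (B, M, 1)
```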
I guess the same question applies to NumPy users.
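For NumPy, a comparable sketch (sizes and data assumed for illustration) uses the @ operator, np.einsum, or np.matmul, which also broadcasts a 2-D matrix against a stack of shape (B, N, 1):

```python
import numpy as np

M, N, B = 3, 4, 5  # example sizes, assumed for illustration
rng = np.random.default_rng(0)
matrix = rng.standard_normal((M, N))
batch = rng.standard_normal((N, B))

out_matmul = matrix @ batch                          # (M, B)
out_einsum = np.einsum('ij,jb->ib', matrix, batch)   # contract over N, same result

# Batched form: treat each column as an N x 1 matrix, then broadcast the 2-D matrix.
out_stack = np.matmul(matrix, batch.T[:, :, None])   # (B, M, 1)
```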
It seems that torch.einsum('ij,jbc->ibc', A, B) solves the problem, where A is the M * N matrix and B is indexed as jbc, i.e. a 3-D tensor whose first dimension is N.
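A small check of that einsum, assuming B has shape (N, Bsz, C) so that each batch item is the N * C slice B[:, b, :] (the sizes and the name Bsz are illustrative). The einsum contracts over j for every b and c, which matches a plain matrix product after flattening the last two axes:

```python
import torch

M, N, Bsz, C = 3, 4, 5, 2  # example sizes, assumed for illustration
A = torch.randn(M, N)
B = torch.randn(N, Bsz, C)

# 'ij,jbc->ibc': sum over j (length N) for each (b, c).
out = torch.einsum('ij,jbc->ibc', A, B)  # (M, Bsz, C)

# Equivalent plain matmul: flatten (Bsz, C) into one axis, multiply, reshape back.
ref = (A @ B.reshape(N, Bsz * C)).reshape(M, Bsz, C)
```

With a 2-D N * B batch tensor, the same idea reduces to 'ij,jb->ib', which is just A @ B.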