Tags: numpy, pytorch, openblas

Pytorch or Numpy Batch Matrix Operation


I am trying to use torch.bmm to do the following matrix operation:

If `matrix` is an M × N tensor and `batch` is an N × B tensor, how can I compute `matrix @ batch_i` for every batch item i (each product gives a vector of length M) and stack the results along the batch dimension, so that the output tensor is M × B?
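For the shapes exactly as stated (a 2-D `batch` whose columns are the batch items), the per-item products stacked side by side are just one ordinary matrix product. The sketch below checks this in NumPy; the sizes and random data are assumptions chosen only for illustration.

```python
import numpy as np

# Assumed example sizes; any M, N, B work.
M, N, B = 4, 3, 5
rng = np.random.default_rng(0)
matrix = rng.standard_normal((M, N))  # M x N
batch = rng.standard_normal((N, B))   # N x B, one column per batch item

# Column by column: matrix @ batch_i is a length-M vector for each i,
# stacked as columns to give an M x B result.
per_item = np.stack([matrix @ batch[:, i] for i in range(B)], axis=1)

# A single matrix product computes the same M x B result in one call.
out = matrix @ batch

print(out.shape)                   # (4, 5)
print(np.allclose(out, per_item))  # True
```

`torch.matmul` (or `@`) on PyTorch tensors of the same shapes behaves the same way, so no batched routine is needed in the 2-D case.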

There are two problems here:

1. `torch.bmm` requires both inputs to be batched (3-D) tensors, but my first input is not.
2. `torch.bmm` expects the batch dimension to come first, while my batch dimension is last.

I assume NumPy users face the same question.
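Both constraints can be sidestepped by reordering axes. `torch.bmm` needs two 3-D tensors with the same batch size and does not broadcast (in PyTorch one would first call `matrix.expand(B, M, N)`), whereas `np.matmul` and `torch.matmul` broadcast a 2-D operand over leading batch dimensions. The NumPy sketch below assumes each batch item is an N × K matrix stacked along the last axis (a 3-D `batch` of shape N × K × B); K and the sizes are illustrative assumptions.

```python
import numpy as np

# Assumed shapes: each batch item is N x K, stacked along the last axis.
M, N, K, B = 4, 3, 2, 5
rng = np.random.default_rng(1)
matrix = rng.standard_normal((M, N))    # M x N, not batched
batch = rng.standard_normal((N, K, B))  # N x K x B, batch axis last

# Move the batch axis to the front: B x N x K.
batch_first = np.moveaxis(batch, -1, 0)

# np.matmul broadcasts the 2-D matrix over the leading batch axis: B x M x K.
prod = np.matmul(matrix, batch_first)

# Move the batch axis back to the end: M x K x B.
out = np.moveaxis(prod, 0, -1)

# Reference result: loop over the batch explicitly.
ref = np.stack([matrix @ batch[:, :, b] for b in range(B)], axis=-1)

print(out.shape)              # (4, 2, 5)
print(np.allclose(out, ref))  # True
```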


Solution

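  • `torch.einsum('ij,jbc->ibc', A, B)` solves this without moving the batch dimension: it multiplies the M × N matrix `A` into the first axis of an N × K × B tensor `B` and returns an M × K × B tensor. For the purely 2-D case in the question (`batch` of shape N × B), the subscripts reduce to `'ij,jb->ib'`, which is just the ordinary product `A @ B`.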
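`np.einsum` accepts the same subscript string as `torch.einsum`, so the solution can be checked in NumPy. In this sketch the 3-D operand is named `Bt` only to avoid clashing with the batch size `B`; the sizes and data are assumptions for illustration.

```python
import numpy as np

M, N, K, B = 4, 3, 2, 5
rng = np.random.default_rng(2)
A = rng.standard_normal((M, N))      # M x N
Bt = rng.standard_normal((N, K, B))  # N x K x B

# Sum over the shared index j; b and c pass through unchanged.
out = np.einsum('ij,jbc->ibc', A, Bt)  # M x K x B

# Same contraction written as a tensordot over A's axis 1 and Bt's axis 0.
ref = np.tensordot(A, Bt, axes=([1], [0]))

print(out.shape)              # (4, 2, 5)
print(np.allclose(out, ref))  # True

# 2-D batch (N x B): the contraction is an ordinary matrix product.
batch2d = rng.standard_normal((N, B))
print(np.allclose(np.einsum('ij,jb->ib', A, batch2d), A @ batch2d))  # True
```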