
How to perform multiplication along axes in pytorch?


I have 2 tensors X and Y: X has shape (20,4,300) and Y has shape (20,300). How can I multiply them so that the result has shape (20,4)? The corresponding technique in Keras is:

doc_product = Dot(axes=(2,1))([X,Y])

How can the same be done in PyTorch?
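Written out index by index, the Keras call contracts axis 2 of X with axis 1 of Y, separately for each of the 20 batch entries:

$$
Z_{bi} = \sum_{j=1}^{300} X_{bij}\,Y_{bj}, \qquad b = 1,\dots,20,\quad i = 1,\dots,4
$$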


Solution

  • Your most versatile function for matrix multiplication is torch.einsum: it lets you specify the dimensions along which to multiply and the order of the dimensions of the output tensor.
    In your case it would look like:

    dot_product = torch.einsum('bij,bj->bi', X, Y)
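Below is a minimal runnable sketch, assuming random input data with the shapes from the question. Besides the einsum call, it shows two equivalent formulations (a batched `torch.matmul` and a broadcast-and-sum); these alternatives are added here for comparison and are not part of the original answer.

```python
import torch

torch.manual_seed(0)

# Random tensors with the shapes from the question (values are arbitrary).
X = torch.randn(20, 4, 300)
Y = torch.randn(20, 300)

# einsum: for each batch b, dot every row X[b, i, :] with Y[b, :].
# 'j' appears in both inputs but not in the output, so it is summed over.
dot_einsum = torch.einsum('bij,bj->bi', X, Y)

# Batched matmul: treat Y[b] as a (300, 1) column vector,
# (20, 4, 300) @ (20, 300, 1) -> (20, 4, 1), then drop the trailing axis.
dot_matmul = torch.matmul(X, Y.unsqueeze(-1)).squeeze(-1)

# Broadcast-and-sum: (20, 4, 300) * (20, 1, 300), then sum over the last axis.
dot_broadcast = (X * Y.unsqueeze(1)).sum(dim=-1)

print(dot_einsum.shape)  # torch.Size([20, 4])
```

All three produce the same (20, 4) tensor up to floating-point rounding; einsum is the most self-documenting, while the matmul form dispatches directly to a batched matrix multiply.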