python pytorch tensor elementwise-operations

PyTorch batch matrix-matrix outer product


Similar to the question "Pytorch batch matrix vector outer product", I have two batched matrices and would like to compute their outer product, in other words the elementwise product of every pair of rows: row i of the first matrix times row j of the second, for all i and j.

Shape example: if X1 and X2 both have shape torch.Size([32, 300, 8]), the result should have shape torch.Size([32, 300, 300, 8]).


Solution

  • You can add singleton dimensions:

    X1[:, :, None, :] * X2[:, None, :, :]
    

    But Usman Ali's comment is also a good idea. Use torch.einsum:

    torch.einsum('bik,bjk->bijk', X1, X2)
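
As a quick sanity check, here is a minimal sketch comparing the broadcasting form with the einsum form. The tensor sizes are made up for illustration (small, and with different row counts for X1 and X2 so the index order is visible); they are not the shapes from the question.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes only: batch=4, X1 has 5 rows, X2 has 6 rows, 3 features.
X1 = torch.randn(4, 5, 3)
X2 = torch.randn(4, 6, 3)

# Broadcasting: X1 becomes [4, 5, 1, 3], X2 becomes [4, 1, 6, 3].
out_broadcast = X1[:, :, None, :] * X2[:, None, :, :]

# einsum: out[b, i, j, k] = X1[b, i, k] * X2[b, j, k]
out_einsum = torch.einsum('bik,bjk->bijk', X1, X2)

print(out_broadcast.shape)                        # torch.Size([4, 5, 6, 3])
print(torch.allclose(out_broadcast, out_einsum))  # True

# Spot check one entry: row 2 of X1 times row 3 of X2 in batch 1.
print(torch.allclose(out_einsum[1, 2, 3], X1[1, 2] * X2[1, 3]))  # True
```

With the shapes from the question ([32, 300, 8] for both inputs), the same two expressions should give a [32, 300, 300, 8] result. Note that swapping which operand gets the singleton dimension in which position transposes the two middle axes of the output.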