
Dimension of tensordot between 2 3D tensors


I have a quick question about the tensordot operation. I'm trying to figure out if there is a way to perform a tensordot product between two tensors to get the output shape that I want. One tensor has shape B X L X D and the other has shape B X 1 X D, and I'm trying to figure out whether it's possible to end up with a B X D matrix at the end.

Currently I'm looping over the B dimension, performing a matrix multiplication between the 1 X D and D X L (the L X D transposed) matrices, and stacking the results to end up with a B X L matrix. This is obviously not the fastest way possible, as a loop can be expensive. Would it be possible to get the desired output of shape B X D with a single tensordot call? I can't seem to figure out how to get rid of one of the B's.
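For reference, the loop-based approach described above can be sketched like this (the sizes are assumed purely for illustration):

```python
import torch

B, L, D = 32, 10, 512           # assumed example sizes
x = torch.randn(B, 1, D)        # shape (B X 1 X D)
y = torch.randn(B, L, D)        # shape (B X L X D)

# loop over the batch dimension: (1, D) @ (D, L) -> (1, L) per item
rows = [x[i] @ y[i].T for i in range(B)]

# stack to (B, 1, L), then drop the singleton dimension -> (B, L)
result = torch.stack(rows).squeeze(1)
print(result.shape)             # torch.Size([32, 10])
```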

Any insight or direction would be very much appreciated.


Solution

  • One option

    Use torch.bmm(), which does exactly that (docs).

    It takes tensors of shape (b, n, m) and (b, m, p) and returns the batch matrix multiplication of shape (b, n, p).

    (I assume you meant a result of B X L, since the matrix multiplication of 1 X D and D X L has shape 1 X L, not 1 X D.)

    In your case:

    import torch
    B, L, D = 32, 10, 512
    
    a = torch.randn(B, 1, D)    #shape (B X 1 X D)
    b = torch.randn(B, L, D)    #shape (B X L X D)
    
    b = b.transpose(1,2)        #shape (B X D X L)
    
    result = torch.bmm(a, b)
    
    result = result.squeeze(1)  # (B, 1, L) -> (B, L)
    print(result.shape)
    >>> torch.Size([32, 10])
    
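As a side note, torch.matmul() broadcasts over leading batch dimensions, so the same result can be obtained without an explicit bmm call (same assumed sizes as above):

```python
import torch

B, L, D = 32, 10, 512
a = torch.randn(B, 1, D)
b = torch.randn(B, L, D)

# matmul treats the leading dimension as a batch, like bmm
result = torch.matmul(a, b.transpose(1, 2)).squeeze(1)  # (B, L)
print(result.shape)
```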

    Alternatively

    You can use torch.einsum(), which is more compact but less readable in my opinion:

    import torch
    B, L, D = 32, 10, 512
    
    a = torch.randn(B, 1, D)
    b = torch.randn(B, L, D)
    
    result = torch.einsum('abc, adc->ad', a, b)
    
    print(result.shape)
    >>> torch.Size([32, 10])
    

    The squeeze in the first example is there to bring the result to shape (32, 10) instead of (32, 1, 10); squeezing dimension 1 explicitly is safer than a bare squeeze(), which would also remove the batch dimension if B happened to be 1.
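As a quick sanity check (using the same assumed sizes), the two approaches produce the same values up to floating-point tolerance:

```python
import torch

B, L, D = 32, 10, 512
a = torch.randn(B, 1, D)
b = torch.randn(B, L, D)

via_bmm = torch.bmm(a, b.transpose(1, 2)).squeeze(1)  # (B, L)
via_einsum = torch.einsum('abc, adc->ad', a, b)       # (B, L)

print(torch.allclose(via_bmm, via_einsum, atol=1e-5))
```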