
3D tensor * 2D tensor dot in Torch


In Theano, when I have a 3D tensor x with shape [A,B,C] and a 2D tensor y with shape [C,D], then theano.tensor.dot(x, y) returns a 3D tensor with shape [A,B,D].
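In index notation, this follows NumPy-style `dot` semantics: the last axis of x is summed against the first axis of the 2D tensor y, while the leading axes of x are kept:

$$\mathrm{dot}(x, y)_{a,b,d} = \sum_{c=1}^{C} x_{a,b,c}\, y_{c,d}, \qquad x \in \mathbb{R}^{A \times B \times C},\; y \in \mathbb{R}^{C \times D}$$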

What is the equivalent operation in Torch? torch.dot doesn't seem to do this; x * y and torch.mm complain that both arguments must be 2D tensors, and torch.bmm requires both arguments to be 3D tensors.


Solution

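  • As @smhx proposed, one possible solution is to repeat the second tensor k times along a new leading dimension, where k is the first dimension of the 3D tensor (this can be done with expand, without allocating new memory), and then perform a batched matrix-matrix product with torch.bmm: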

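    require 'torch'

    -- Return a view of `tensor` repeated k times along a new leading dimension.
    -- expand() gives that new dimension a stride of 0, so no data is copied.
    function repeatNoCopy(tensor, k)
        local tens_size = tensor:size():totable()
        return torch.expand(tensor:view(1, unpack(tens_size)), k, unpack(tens_size))
    end
    
    A = torch.rand(3, 2, 5)    -- 3D tensor, shape 3x2x5
    B = torch.rand(5, 4)       -- 2D tensor, shape 5x4
    B_rep = repeatNoCopy(B, 3) -- 3x5x4 view sharing B's storage
    
    -- batched matrix product: result[i] = A[i] * B_rep[i] for i = 1..3
    result = torch.bmm(A, B_rep)
    
    print(result)
    -- prints the values, followed by: [torch.DoubleTensor of size 3x2x4]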
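As an alternative, here is a sketch of a different technique (not the one proposed above) that avoids the batched product entirely: collapse the first two dimensions of the 3D tensor, do a single 2D torch.mm, and view the result back to 3D. The helper name dot3d2d is hypothetical, and the sketch assumes Torch7's Tensor `view`, `contiguous` and `torch.mm`.

```lua
require 'torch'

-- Hypothetical helper: Theano-style dot of a 3D tensor x [A,B,C] with a
-- 2D tensor y [C,D], returning a tensor of shape [A,B,D].
function dot3d2d(x, y)
    local a, b, c = x:size(1), x:size(2), x:size(3)
    assert(y:size(1) == c, 'inner dimensions must match')
    -- view() needs contiguous memory; contiguous() copies only if x is not
    local x2d = x:contiguous():view(a * b, c)       -- [A*B, C]
    -- torch.mm returns a fresh contiguous tensor, so it can be viewed as 3D
    return torch.mm(x2d, y):view(a, b, y:size(2))   -- [A*B, D] -> [A, B, D]
end

x = torch.rand(3, 2, 5)
y = torch.rand(5, 4)
print(dot3d2d(x, y):size())   -- sizes 3, 2, 4
```

Because each row of the collapsed [A*B, C] matrix is multiplied by the same y, a single matrix product computes every batch at once.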