Torch - Query matrix with another matrix


I have an m x n tensor (Tensor 1) and a k x 2 tensor (Tensor 2), and I want to extract the values of Tensor 1 at the positions given by the rows of Tensor 2, where each row is a (row, column) pair. For example:

Tensor1
  1   2   3   4   5
  6   7   8   9  10
 11  12  13  14  15
 16  17  18  19  20
[torch.DoubleTensor of size 4x5]

Tensor2
 2  1
 3  5
 1  1
 4  3
[torch.DoubleTensor of size 4x2]

The function would then yield:

6
15
1
18

Solution

  • The first solution that comes to mind is to simply loop through the indices and pick the corresponding values:

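    function get_elems_simple(tensor, indices)
        -- one output element per row of indices, same type as tensor
        local res = torch.Tensor(indices:size(1)):typeAs(tensor)
        local i = 0
        -- :apply calls the function once per element of res and stores the result
        res:apply(
            function ()
                i = i + 1
                -- the storage of row i acts as a multi-dimensional index
                return tensor[indices[i]:clone():storage()]
            end)
        return res
    end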
    

    Here tensor[indices[i]:clone():storage()] is just a generic way to pick a single element from a multi-dimensional tensor. In the k-dimensional case it is equivalent to tensor[{indices[i][1], ... , indices[i][k]}].

    This method works fine if you only need to extract a few values. The bottleneck is the :apply method, which cannot benefit from optimizations such as SIMD instructions because the function it calls is a black box. The job can be done far more efficiently: the :index method does exactly what you need, but only with a one-dimensional index tensor. Multi-dimensional target and index tensors therefore need to be flattened first:

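    -- sp_indices: LongTensor of 1-based indices, one row per element
    function flatten_indices(sp_indices, shape)
        sp_indices = sp_indices - 1   -- shift to 0-based indices
        local n_elem, n_dim = sp_indices:size(1), sp_indices:size(2)
        -- start at 1 because the flat index is 1-based again
        local flat_ind = torch.LongTensor(n_elem):fill(1)

        -- walk from the last (fastest-varying) dimension to the first;
        -- mult is the number of elements spanned by one step in dimension d
        local mult = 1
        for d = n_dim, 1, -1 do
            flat_ind:add(sp_indices[{{}, d}] * mult)
            mult = mult * shape[d]
        end
        return flat_ind
    end

    function get_elems_efficient(tensor, sp_indices)
        local flat_indices = flatten_indices(sp_indices, tensor:size())
        local flat_tensor = tensor:view(-1)   -- 1-D view of the same data
        return flat_tensor:index(1, flat_indices)
    end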
    

    The difference is drastic:

    n = 500000
    k = 100
    a = torch.rand(n, k)
    ind = torch.LongTensor(n, 2)
    ind[{{}, 1}]:random(1, n)
    ind[{{}, 2}]:random(1, k)
    
    elems1 = get_elems_simple(a, ind)      -- 4.53 sec
    elems2 = get_elems_efficient(a, ind)   -- 0.05 sec

    print(torch.all(elems1:eq(elems2)))    -- true
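
To make the arithmetic inside flatten_indices easier to follow, here is a minimal sketch in plain Lua that needs no Torch installation. The helper flat_index is a hypothetical name introduced only for this illustration; it assumes 1-based indices and a row-major (contiguous) layout, and it handles one index at a time instead of a whole tensor of indices.

```lua
-- Hypothetical helper for illustration only (not part of Torch).
-- Converts one 1-based multi-dimensional index into a 1-based flat index,
-- assuming a row-major layout with the given shape.
function flat_index(index, shape)
    local flat, mult = 1, 1
    for d = #shape, 1, -1 do
        flat = flat + (index[d] - 1) * mult  -- 0-based offset times the stride
        mult = mult * shape[d]               -- stride of the next dimension
    end
    return flat
end

-- The shape and index pairs from the question (Tensor1 is 4x5).
local shape = {4, 5}
local indices = {{2, 1}, {3, 5}, {1, 1}, {4, 3}}

local flat = {}
for i, idx in ipairs(indices) do
    flat[i] = flat_index(idx, shape)
end
print(table.concat(flat, " "))  -- 6 15 1 18
```

Because Tensor1 in the question stores the numbers 1 to 20 in row-major order, each flat position equals the value stored there, so the printed positions coincide with the expected result 6, 15, 1, 18.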