
Extracting the top-k value-indices from a 1-D Tensor


Given a 1-D tensor in Torch (torch.Tensor), containing values which can be compared (say floating point), how can we extract the indices of the top-k values in that tensor?

Apart from the brute-force method, I am looking for an API call provided by Torch/Lua that performs this task efficiently.


Solution

  • As of pull request #496, Torch includes a built-in function named torch.topk. It returns the k smallest elements by default. Example:

    > t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}
    
    -- obtain the 3 smallest elements
    > res = t:topk(3)
    > print(res)
     1
     2
     3
    [torch.DoubleTensor of size 3]
    
    -- you can also obtain the indices of those elements
    > res, ind = t:topk(3)
    > print(ind)
     2
     4
     6
    [torch.LongTensor of size 3]
    
    -- alternatively, you can obtain the k largest elements as follows
    -- (see the API documentation for more details)
    > res = t:topk(3, true)
    > print(res)
     9
     8
     7
    [torch.DoubleTensor of size 3]
    

    At the time of writing, the CPU implementation follows a sort-and-narrow approach: it sorts the whole tensor and keeps the first k entries (there are plans to improve it in the future). An optimized GPU implementation for cutorch is currently under review.
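
    To make the sort-and-narrow idea concrete, here is a minimal sketch in plain Python (standard library only). It is not Torch's actual implementation; the function name `topk_sort_and_narrow` and its `largest` flag are hypothetical names chosen for illustration. It sorts every position by value and then keeps the first k, returning 1-based indices so the results line up with the Lua output above.

```python
def topk_sort_and_narrow(values, k, largest=False):
    """Return (top_values, top_indices) by sorting everything, then keeping k.

    Hypothetical helper for illustration only; indices are 1-based to
    mirror Lua's indexing convention.
    """
    if not 0 <= k <= len(values):
        raise ValueError("k must be between 0 and len(values)")
    # Sort all positions by their value: a full O(n log n) sort.
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=largest)
    # "Narrow": keep only the first k positions of the sorted order.
    kept = order[:k]
    return [values[i] for i in kept], [i + 1 for i in kept]


t = [9, 1, 8, 2, 7, 3, 6, 4, 5]
smallest, smallest_idx = topk_sort_and_narrow(t, 3)
largest, largest_idx = topk_sort_and_narrow(t, 3, largest=True)
print(smallest, smallest_idx)
print(largest, largest_idx)
```

    The cost is dominated by the full sort, so the work does not shrink when k is small; that is presumably the kind of overhead a dedicated selection routine would avoid.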