
Top K indices of a multi-dimensional tensor


I have a 2D tensor and I want to get the indices of its top k values. I know about PyTorch's torch.topk function, but it only finds the top k values along a single dimension (the last one by default). I want the top k values over the whole tensor, across both dimensions.

For example, for the following tensor:

a = torch.tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])

Calling torch.topk directly ranks each row separately and gives me the following:

values, indices = torch.topk(a, 3)

print(indices)
# tensor([[1, 2, 0],
#        [0, 2, 1],
#        [0, 1, 4],
#        [1, 4, 3],
#        [1, 0, 4]])
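For reference, torch.topk ranks along one dimension chosen by its dim argument (the last dimension by default). A minimal sketch, reusing the tensor a from the question, showing that switching dim only changes which slices get ranked; neither call ranks the tensor as a whole:

```python
import torch

a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])

# dim=1 (same as the default, the last dim): top 3 within each row -> shape (5, 3)
row_vals, row_idx = torch.topk(a, 3, dim=1)

# dim=0: top 3 within each column -> shape (3, 5)
col_vals, col_idx = torch.topk(a, 3, dim=0)

print(row_vals.shape, col_vals.shape)  # torch.Size([5, 3]) torch.Size([3, 5])
```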

But I want to get the following:

tensor([[0, 1],
        [2, 0],
        [3, 1]])

These are the (row, column) indices of the three 9s in the 2D tensor.
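In this particular example the top 3 values are all 9, so the desired output happens to equal the positions of the maximum. A quick sanity check of the expected answer (not a general top-k solution, since it only works when the top k values are all equal to the maximum):

```python
import torch

a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])

# Coordinates of every element equal to the maximum, returned in row-major order.
max_pos = torch.nonzero(a == a.max())
print(max_pos)
# tensor([[0, 1],
#         [2, 0],
#         [3, 1]])
```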

Is there a way to do this in PyTorch?


Solution

  • import numpy as np
    v, i = torch.topk(a.flatten(), 3)  # top 3 over all 25 elements
    print(np.array(np.unravel_index(i.numpy(), a.shape)).T)  # flat index -> (row, col)
    

    Output (all three values are tied at 9, so the order of the rows is not guaranteed):

    [[3 1]
     [2 0]
     [0 1]]
    
    1. Flatten the tensor and find the top k values of the 1D view.
    2. Convert the flat indices back to 2D (row, column) indices with np.unravel_index.
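The same two steps can also be done without the NumPy round trip, which keeps everything in PyTorch (so it also works on CUDA tensors, where .numpy() would fail) and generalizes to any number of dimensions. A minimal sketch; topk_indices_nd is a hypothetical helper name, and it assumes the tensor has at least one dimension:

```python
import torch


def topk_indices_nd(t: torch.Tensor, k: int) -> torch.Tensor:
    """Return a (k, t.dim()) tensor with the coordinates of the k largest elements.

    Hypothetical helper: flattens t (row-major), runs torch.topk on the 1D view,
    then converts each flat index back to coordinates with modulo and floor division.
    """
    _, flat_idx = torch.topk(t.flatten(), k)
    coords = []
    # Peel off the last dimension first: index % size is the coordinate in that
    # dimension, and floor division by size moves on to the dimension to its left.
    for size in reversed(t.shape):
        coords.append(flat_idx % size)
        flat_idx = torch.div(flat_idx, size, rounding_mode="floor")
    # coords was built last-dimension-first, so reverse it before stacking.
    return torch.stack(coords[::-1], dim=1)


a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])

idx = topk_indices_nd(a, 3)
print(idx)  # one [row, col] pair per row; order among the tied 9s may vary
```

As in the NumPy version, ties are returned in no guaranteed order. Recent PyTorch releases also ship torch.unravel_index, which can replace the manual loop if your installed version has it.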