
How to get topk's values with its indices (2D)?


I have two 3D tensors, and I want to use the top-k indices of one tensor to get the corresponding entries of the other.

For example, for the following tensors:

a = torch.tensor([[[1], [2], [3]],
                  [[4], [5], [6]]])

b = torch.tensor([[[7,1], [8,2], [9,3]],
                  [[10,4],[11,5],[12,6]]])

PyTorch's topk function gives me the following:

top_tensor, indices = torch.topk(a, 2, dim=1)

# top_tensor: tensor([[[3], [2]],
#                    [[6],  [5]]])

# indices: tensor([[[2], [1]],
#                 [[2],  [1]]])

But I want to use the indices obtained from a to select the corresponding entries of b.

# use indices to select from b, getting torch.tensor([[[9,3], [8,2]],
#                                                     [[12,6],[11,5]]])

In this case, I don't know the values of b in advance, so I can't apply topk to b directly.

In other words, I want a function foo_slice such that:

top_tensor, indices = torch.topk(a, 2, dim=1)
# top_tensor == foo_slice(a, indices)
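One way to write such a foo_slice, offered as a sketch rather than a definitive answer, is torch.gather, which picks entries along a dimension using an index tensor. It assumes inputs shaped (batch, n, d) and indices returned by torch.topk(..., dim=1) on a tensor whose trailing dimension has size 1; that trailing dimension is expanded to the input's last dimension so the same indices work for both a and b. The name foo_slice comes from the question; its body here is illustrative.

```python
import torch

def foo_slice(x, indices):
    """Select rows of x along dim 1 using top-k indices (illustrative helper).

    Assumes x has shape (batch, n, d) and indices has shape (batch, k, 1),
    as returned by torch.topk(..., dim=1) on a tensor with trailing size 1.
    """
    # gather needs an index with as many dims as x; broadcast the size-1
    # trailing dim of `indices` across x's last dimension (no copy is made).
    idx = indices.expand(-1, -1, x.shape[-1])
    return torch.gather(x, 1, idx)

a = torch.tensor([[[1], [2], [3]],
                  [[4], [5], [6]]])
b = torch.tensor([[[7, 1], [8, 2], [9, 3]],
                  [[10, 4], [11, 5], [12, 6]]])

top_tensor, indices = torch.topk(a, 2, dim=1)
print(torch.equal(foo_slice(a, indices), top_tensor))  # True
print(foo_slice(b, indices))  # rows of b in the order given by indices
```

Because gather follows the order of indices, the rows of b come out in the same descending order that topk produced for a.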

Is there a way to achieve this in PyTorch?

Thanks!


Solution

  • The solution you are looking for is here.

    So a code-based solution to your problem is as follows:

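    import torch

    # Inputs are reordered relative to the question: the rows of the second
    # batch are shuffled, so the top-k indices differ between batches.
    a = torch.tensor([[[1], [2], [3]],
                      [[5], [6], [4]]])

    b = torch.tensor([[[7,1], [8,2], [9,3]],
                      [[11,5],[12,6],[10,4]]])

    top_tensor, indices = torch.topk(a, 2, dim=1)

    # indices has shape (batch, k, 1); drop the trailing dim to get one
    # 1-D index vector of length k per batch element.
    v = [indices[i].squeeze(-1) for i in range(indices.shape[0])]

    # For each batch element, select the rows of b at those indices.
    new_tensor = []
    for i, f in enumerate(v):
        new_tensor.append(torch.index_select(b[i], 0, f))
    print(new_tensor)  # [tensor([[9, 3],
                       #         [8, 2]]),
                       #  tensor([[12,  6],
                       #          [11,  5]])]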
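As a loop-free alternative to the per-batch index_select above (a different technique, not part of the original answer), PyTorch advanced indexing can select all batches at once: a batch index from torch.arange is broadcast against the top-k indices. This sketch assumes the same reordered inputs and returns a single (batch, k, d) tensor instead of a list.

```python
import torch

a = torch.tensor([[[1], [2], [3]],
                  [[5], [6], [4]]])
b = torch.tensor([[[7, 1], [8, 2], [9, 3]],
                  [[11, 5], [12, 6], [10, 4]]])

top_tensor, indices = torch.topk(a, 2, dim=1)

# Drop the trailing size-1 dim: idx has shape (batch, k).
idx = indices.squeeze(-1)
# Batch index of shape (batch, 1); it broadcasts against idx to (batch, k).
batch = torch.arange(b.shape[0]).unsqueeze(1)
# new_tensor[i, j] = b[batch[i, 0], idx[i, j]], giving shape (batch, k, d).
new_tensor = b[batch, idx]
print(new_tensor)
```

The result holds the same rows as the list built in the loop, stacked into one tensor.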