Create a new tensor from another tensor of indexes


I have a tensor with shape (1,3,1):

topk_indices =
    tensor([[[6],
         [1],
         [0]]], device='cuda:0')

and a tensor with shape (1,7,16):

in_tensor = tensor([[[0.8359, 0.4812, 0.0297, 0.5219, 0.1595, 0.9066, 0.1965, 0.4639,
                      0.3890, 0.5890, 0.9705, 0.5475, 0.7896, 0.8881, 0.9037, 0.3273],
                     [0.3882, 0.7410, 0.3636, 0.7341, 0.3908, 0.1609, 0.7035, 0.5767,
                      0.7229, 0.9967, 0.8414, 0.9740, 0.5268, 0.0699, 0.1492, 0.1894],
                     [0.0594, 0.2494, 0.0397, 0.0387, 0.2012, 0.0071, 0.1931, 0.6907,
                      0.9170, 0.3513, 0.3546, 0.7670, 0.2533, 0.2636, 0.8081, 0.0643],
                     [0.5611, 0.9417, 0.5857, 0.6360, 0.2088, 0.4931, 0.5275, 0.6227,
                      0.6943, 0.9345, 0.1184, 0.5150, 0.2502, 0.1045, 0.4600, 0.0599],
                     [0.8489, 0.5579, 0.2305, 0.7613, 0.0268, 0.3066, 0.4026, 0.0751,
                      0.1821, 0.4184, 0.8794, 0.9828, 0.8181, 0.2014, 0.1729, 0.9363],
                     [0.6769, 0.5133, 0.5677, 0.0982, 0.3331, 0.9813, 0.3767, 0.4749,
                      0.0848, 0.2203, 0.4898, 0.1894, 0.4380, 0.7035, 0.0109, 0.6485],
                     [0.1694, 0.2560, 0.6920, 0.8976, 0.3633, 0.2947, 0.0479, 0.2422,
                      0.0622, 0.3856, 0.6020, 0.0316, 0.9366, 0.8137, 0.0105, 0.2612]]],
   device='cuda:0')

I would like to create a new tensor with shape (1,3,16) that keeps only the rows of in_tensor listed in topk_indices, but preserves the original order of in_tensor along dim=1. The result should be:

in_tensor = tensor([[[0.8359, 0.4812, 0.0297, 0.5219, 0.1595, 0.9066, 0.1965, 0.4639,
                      0.3890, 0.5890, 0.9705, 0.5475, 0.7896, 0.8881, 0.9037, 0.3273],
                     [0.3882, 0.7410, 0.3636, 0.7341, 0.3908, 0.1609, 0.7035, 0.5767,
                      0.7229, 0.9967, 0.8414, 0.9740, 0.5268, 0.0699, 0.1492, 0.1894],
                     [0.1694, 0.2560, 0.6920, 0.8976, 0.3633, 0.2947, 0.0479, 0.2422,
                      0.0622, 0.3856, 0.6020, 0.0316, 0.9366, 0.8137, 0.0105, 0.2612]]],
                      device='cuda:0')

so that only indexes 6, 1, 0 remain, but in their original order (0, 1, 6).

I tried torch.gather:

selected_tensor = torch.gather(in_tensor, 1, topk_indices.repeat(1, 1, in_tensor.shape[-1]).unsqueeze(3).squeeze(3))

but the rows of the output tensor come out in the order given by topk_indices (6, 1, 0) rather than in the original order.


Solution

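  • This can be done in two ways. In both, to restore the original ordering you first have to "unsort" the indexes in topk_indices, i.e. sort them in ascending order along dim=1 (for example with torch.sort); the snippets below call the result topk_sort. The first solution creates an empty tensor and fills it one batch element at a time: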

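    # Sort the indexes along dim=1 so they follow the row order of in_tensor.
    topk_sort, _ = torch.sort(topk_indices, dim=1)
    # Note shape[1], not shape(1); keep the dtype and device of in_tensor.
    X = torch.empty(in_tensor.shape[0], topk_indices.shape[1], in_tensor.shape[2], dtype=in_tensor.dtype, device=in_tensor.device)
    for i in range(in_tensor.shape[0]):
        X[i] = in_tensor[i, topk_sort[i], :].squeeze(1)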
    

    The second uses torch.gather:

    X = torch.gather(in_tensor, 1, topk_sort.repeat(1,1,in_tensor.shape[-1]))
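Putting the gather variant together, here is a minimal self-contained sketch. It assumes topk_sort is obtained with torch.sort along dim=1, and it uses a small CPU stand-in built with torch.arange instead of the CUDA tensor from the question; the name selected is only illustrative.

```python
import torch

# Stand-in for the question's data: shape (1, 7, 16), on CPU for portability.
in_tensor = torch.arange(7 * 16, dtype=torch.float32).reshape(1, 7, 16)
topk_indices = torch.tensor([[[6], [1], [0]]])  # shape (1, 3, 1)

# "Unsort" the indexes: ascending order along dim=1 matches the row order of in_tensor.
topk_sort, _ = torch.sort(topk_indices, dim=1)  # values 0, 1, 6

# gather expects an index with the same shape as the output,
# so repeat the indexes across the last dimension: (1, 3, 1) -> (1, 3, 16).
index = topk_sort.repeat(1, 1, in_tensor.shape[-1])
selected = torch.gather(in_tensor, 1, index)

print(selected.shape)  # torch.Size([1, 3, 16])
```

Rows 0, 1 and 6 of in_tensor end up in selected in that order, which is the ordering the question asks for. Sorting the indexes is the only change compared with the gather call in the question.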