I have a tensor with shape (1,3,1):
topk_indices =
tensor([[[6],
[1],
[0]]], device='cuda:0')
and a tensor with shape (1,7,16):
in_tensor = tensor([[[0.8359, 0.4812, 0.0297, 0.5219, 0.1595, 0.9066, 0.1965, 0.4639,
0.3890, 0.5890, 0.9705, 0.5475, 0.7896, 0.8881, 0.9037, 0.3273],
[0.3882, 0.7410, 0.3636, 0.7341, 0.3908, 0.1609, 0.7035, 0.5767,
0.7229, 0.9967, 0.8414, 0.9740, 0.5268, 0.0699, 0.1492, 0.1894],
[0.0594, 0.2494, 0.0397, 0.0387, 0.2012, 0.0071, 0.1931, 0.6907,
0.9170, 0.3513, 0.3546, 0.7670, 0.2533, 0.2636, 0.8081, 0.0643],
[0.5611, 0.9417, 0.5857, 0.6360, 0.2088, 0.4931, 0.5275, 0.6227,
0.6943, 0.9345, 0.1184, 0.5150, 0.2502, 0.1045, 0.4600, 0.0599],
[0.8489, 0.5579, 0.2305, 0.7613, 0.0268, 0.3066, 0.4026, 0.0751,
0.1821, 0.4184, 0.8794, 0.9828, 0.8181, 0.2014, 0.1729, 0.9363],
[0.6769, 0.5133, 0.5677, 0.0982, 0.3331, 0.9813, 0.3767, 0.4749,
0.0848, 0.2203, 0.4898, 0.1894, 0.4380, 0.7035, 0.0109, 0.6485],
[0.1694, 0.2560, 0.6920, 0.8976, 0.3633, 0.2947, 0.0479, 0.2422,
0.0622, 0.3856, 0.6020, 0.0316, 0.9366, 0.8137, 0.0105, 0.2612]]],
device='cuda:0')
I would like to create a new tensor with shape (1,3,16) that keeps only the rows of in_tensor whose indices appear in topk_indices, but preserves their original order along dim=1. The result should be:
result = tensor([[[0.8359, 0.4812, 0.0297, 0.5219, 0.1595, 0.9066, 0.1965, 0.4639,
0.3890, 0.5890, 0.9705, 0.5475, 0.7896, 0.8881, 0.9037, 0.3273],
[0.3882, 0.7410, 0.3636, 0.7341, 0.3908, 0.1609, 0.7035, 0.5767,
0.7229, 0.9967, 0.8414, 0.9740, 0.5268, 0.0699, 0.1492, 0.1894],
[0.1694, 0.2560, 0.6920, 0.8976, 0.3633, 0.2947, 0.0479, 0.2422,
0.0622, 0.3856, 0.6020, 0.0316, 0.9366, 0.8137, 0.0105, 0.2612]]],
device='cuda:0')
That is, only rows 6, 1 and 0 remain, but in their original order (0, 1, 6).
I tried torch.gather:
selected_tensor = torch.gather(in_tensor, 1, topk_indices.repeat(1, 1, in_tensor.shape[-1]).unsqueeze(3).squeeze(3))
but in the result the rows follow the order of topk_indices (6, 1, 0) instead of the original order.
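A small sketch of why the attempt behaves this way: torch.gather along dim=1 computes out[b, j, f] = in_tensor[b, index[b, j, f], f], so output row j always comes from whichever row index[b, j] names, and the output inherits the order of the index tensor. (The .unsqueeze(3).squeeze(3) pair in the attempt cancels out and has no effect.) The input below is a hypothetical stand-in in which row r is filled with the value r, so it is easy to see where each output row came from.

```python
import torch

# Hypothetical stand-in: row r of in_tensor is filled with the value r.
in_tensor = torch.arange(7.0).view(1, 7, 1).expand(1, 7, 16).contiguous()
topk_indices = torch.tensor([[[6], [1], [0]]])

# Broadcast the (1, 3, 1) indices across the 16 features -> (1, 3, 16).
index = topk_indices.expand(-1, -1, in_tensor.shape[-1])

# out[b, j, f] = in_tensor[b, index[b, j, f], f]: rows follow the index order.
selected = torch.gather(in_tensor, 1, index)

print(selected[0, :, 0])  # tensor([6., 1., 0.])
```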
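OK, the answer can be achieved with two methods. To get the initial ordering, you first have to "unsort" the indices in topk_indices, i.e. sort them in ascending order along dim=1 (called topk_sort in the snippets below, e.g. topk_sort, _ = torch.sort(topk_indices, dim=1)). The first solution creates an empty tensor and fills it one batch element at a time: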
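X = torch.empty(in_tensor.shape[0], topk_indices.shape[1], in_tensor.shape[2], dtype=in_tensor.dtype, device=in_tensor.device)
for i in range(in_tensor.shape[0]):
    # topk_sort[i] has shape (k, 1), so indexing gives (k, 1, D); squeeze(1) -> (k, D)
    X[i] = in_tensor[i, topk_sort[i], :].squeeze(1)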
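The per-batch loop can likely be replaced by a single advanced-indexing call that pairs each sorted row index with its batch index. This is a sketch under assumed inputs (a hypothetical batch of 2 so the batching is visible), not part of the original answer.

```python
import torch

# Hypothetical inputs: batch of 2, seq length 7, 16 features, k = 3.
in_tensor = torch.rand(2, 7, 16)
topk_indices = torch.tensor([[[6], [1], [0]],
                             [[3], [5], [2]]])

# Sort the indices ascending along dim=1 (topk_sort in the answer).
topk_sort, _ = torch.sort(topk_indices, dim=1)

# batch_idx is (B, 1) and broadcasts against the (B, k) row indices.
batch_idx = torch.arange(in_tensor.shape[0], device=in_tensor.device).unsqueeze(1)
row_idx = topk_sort.squeeze(-1)  # (B, k)

# Picks in_tensor[b, row_idx[b, j], :] for every (b, j) -> shape (B, k, 16).
X = in_tensor[batch_idx, row_idx]
```

Because both index tensors broadcast to (B, k), the trailing feature dimension is carried along unchanged, which is what the loop does one batch element at a time.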
The second solution uses torch.gather with the sorted indices:
X = torch.gather(in_tensor, 1, topk_sort.repeat(1,1,in_tensor.shape[-1]))
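A different technique that avoids sorting altogether (not from the original answer) is to mark the wanted rows in a boolean mask and index with it. Boolean indexing returns the selected positions in row-major order, so rows come out in their original order automatically. This sketch assumes the indices within each batch element are unique (as torch.topk guarantees) so every batch element keeps exactly k rows; the inputs are hypothetical.

```python
import torch

# Hypothetical inputs: batch of 2, seq length 7, 16 features, k = 3.
in_tensor = torch.rand(2, 7, 16)
topk_indices = torch.tensor([[[6], [1], [0]],
                             [[3], [5], [2]]])

batch, seq_len, features = in_tensor.shape
k = topk_indices.shape[1]

# Mark the selected rows in a (batch, seq_len) boolean mask.
# Assumes unique indices per batch element, so each batch keeps exactly k rows.
mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=in_tensor.device)
batch_idx = torch.arange(batch, device=in_tensor.device).unsqueeze(1)
mask[batch_idx, topk_indices.squeeze(-1)] = True

# in_tensor[mask] walks the mask in row-major order (batch by batch, rows
# ascending) and returns shape (batch * k, features); reshape to (batch, k, features).
X = in_tensor[mask].view(batch, k, features)
```

If the indices could repeat or k varied per batch element, the view would fail or drop duplicates, so the sorted-index gather above is the safer general choice.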