Given:
tensor([[4, 3, 1, 4, 2],
        [0, 0, 2, 3, 4],
        [4, 4, 3, 0, 3]])
I want to get, for each row, the index of the k-th largest value, where k differs from row to row.
k has batch_size elements chosen randomly, and its values express only 2 cases:
the first is the largest (so k = 1); the second is the smallest, ignoring zeros, e.g. if the row is [2,3,4,0] then the index of the smallest is 0 (value 2). The largest is chosen with probability 0.7 and the smallest with probability 0.3.
With the example above, if k = [1,0,0] (1 means get the largest, 0 means get the smallest), then the output indices will be
output = [0, 2, 2]
The corresponding values are [4, 2, 3].
Note: please vectorize these calculations.
You can do something like this:
import torch
x = torch.tensor([[4, 3, 1, 4, 2],
                  [0, 0, 2, 3, 4],
                  [4, 4, 3, 0, 3]]).float() # float required for later ops
k = torch.tensor([1, 0, 0]).long()
# set 0 to -1, ie [1, -1, -1]
k_sign = k + (-1 * (k==0).float())
# flip sign for rows where we want the smallest nonzero index
x_signed = x * k_sign.unsqueeze(1)
# fill zeros with -inf so they are never selected
x_filled = x_signed.masked_fill(x==0, float('-inf'))
# grab the top-1 index of each row
_, output = x_filled.topk(1, dim=1)
output = output.squeeze(1) # squeeze only dim 1 so a single-row batch keeps its shape
output
> tensor([0, 2, 2])
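
If it helps, here is the same idea wrapped in a function, as a minimal sketch rather than a definitive implementation. The name select_indices is my own, and swapping topk(1) for argmax is an assumption that only the top-1 index is needed. Rows 0 and 2 contain ties, so instead of relying on which tied index comes back, the check gathers the selected values and compares them with the ones in the question. It also shows one way to draw k with probability 0.7 for the largest.

```python
import torch

def select_indices(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per row, index of the largest value where k == 1,
    or of the smallest nonzero value where k == 0."""
    x = x.float()
    want_max = k.bool().unsqueeze(1)           # shape (batch, 1), broadcasts over columns
    # negate rows where the minimum is wanted, so argmax finds the minimum
    signed = torch.where(want_max, x, -x)
    # zeros become -inf so they can never be selected
    signed = signed.masked_fill(x == 0, float("-inf"))
    return signed.argmax(dim=1)

x = torch.tensor([[4, 3, 1, 4, 2],
                  [0, 0, 2, 3, 4],
                  [4, 4, 3, 0, 3]])
k = torch.tensor([1, 0, 0])

idx = select_indices(x, k)
# look up the selected values; they are 4, 2, 3 as in the question
vals = x.gather(1, idx.unsqueeze(1)).squeeze(1)

# drawing k at random: 1 (largest) with probability 0.7, else 0 (smallest nonzero)
k_random = (torch.rand(x.size(0)) < 0.7).long()
idx_random = select_indices(x, k_random)
```

Note that a row containing only zeros ends up all -inf, so whatever index comes back for it is meaningless; mask such rows separately if they can occur.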