
How to get the index of a different top-k in each row of a 2D tensor in PyTorch?


Given:

  • a non-negative integer tensor A of shape (batch_size, N), in which zero is the smallest value. For example:
tensor([[4, 3, 1, 4, 2],
        [0, 0, 2, 3, 4],
        [4, 4, 3, 0, 3]])

I want to get the index of a different k-th value in each row, where:

  • k is a list of batch_size elements chosen randomly, and each element selects one of only two cases: the largest value (k = 1) or the smallest value ignoring zeros (k = 0). For example, if the row is [2,3,4,0], the smallest non-zero value is 2, at index 0. The largest value is chosen with probability 0.7 and the smallest with probability 0.3.
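A k vector with this 0.7 / 0.3 split can itself be drawn without a loop. This is a minimal sketch; `sample_k` and `p_largest` are hypothetical names, not part of the question:

```python
import torch

def sample_k(batch_size: int, p_largest: float = 0.7) -> torch.Tensor:
    """Hypothetical helper: (batch_size,) long tensor, 1 = largest, 0 = smallest non-zero."""
    # each entry is 1 with probability p_largest, otherwise 0
    return (torch.rand(batch_size) < p_largest).long()

torch.manual_seed(0)
k = sample_k(3)
print(k)
```

`torch.bernoulli(torch.full((batch_size,), 0.7)).long()` would be an equivalent way to draw the same distribution.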

With the example above, if k = [1,0,0] (1 means take the largest, 0 means take the smallest non-zero), the output indices should be

output = [0, 2, 2], and the corresponding values are [4, 2, 3].
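To pin down the expected semantics before vectorizing, a plain row-by-row reference can be handy for checking any vectorized answer. This sketch assumes ties resolve to the first occurrence and that every row with k = 0 has at least one non-zero entry; `reference_indices` is a hypothetical name:

```python
import torch

def reference_indices(A: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Non-vectorized reference: index of the largest value where k == 1,
    index of the smallest non-zero value where k == 0 (first occurrence on ties)."""
    out = []
    for row, kind in zip(A.tolist(), k.tolist()):
        if kind == 1:
            # first position of the maximum
            out.append(row.index(max(row)))
        else:
            # first position of the smallest value among the non-zero entries
            smallest = min(v for v in row if v != 0)
            out.append(row.index(smallest))
    return torch.tensor(out)

A = torch.tensor([[4, 3, 1, 4, 2],
                  [0, 0, 2, 3, 4],
                  [4, 4, 3, 0, 3]])
k = torch.tensor([1, 0, 0])
idx = reference_indices(A, k)
# look up the selected value in each row
values = A.gather(1, idx.unsqueeze(1)).squeeze(1)
print(idx)     # tensor([0, 2, 2])
print(values)  # tensor([4, 2, 3])
```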

Note: please vectorize this calculation.


Solution

  • You can do something like this

    import torch
    
    x = torch.tensor([[4, 3, 1, 4, 2],
                      [0, 0, 2, 3, 4],
                      [4, 4, 3, 0, 3]]).float() # float is required so masked_fill can use -inf
    
    k = torch.tensor([1, 0, 0]).long()
    
    # map 0 to -1, i.e. [1, -1, -1]
    k_sign = k + (-1 * (k==0).float())
    
    # flip sign for rows where we want the smallest nonzero index
    x_signed = x * k_sign.unsqueeze(1)
    
    # fill zeros with -inf
    x_filled = x_signed.masked_fill(x==0, float('-inf'))
    
    # index of the maximum in each row; argmax returns the first maximum on ties,
    # whereas the tie order of topk is not guaranteed
    output = x_filled.argmax(dim=1)
    
    print(output)
    # tensor([0, 2, 2])
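An alternative that avoids the sign flip is to compute both candidates for every row and pick one per row with `torch.where`. This is a sketch of a different technique, not the answer above; it does roughly twice the reduction work but may be easier to read, and it also recovers the values with `gather`:

```python
import torch

x = torch.tensor([[4, 3, 1, 4, 2],
                  [0, 0, 2, 3, 4],
                  [4, 4, 3, 0, 3]])
k = torch.tensor([1, 0, 0])

# index of the largest value per row (argmax returns the first maximum on ties)
idx_max = x.argmax(dim=1)

# index of the smallest non-zero value: hide zeros behind +inf, then argmin
x_nonzero = x.float().masked_fill(x == 0, float('inf'))
idx_min = x_nonzero.argmin(dim=1)

# per row: k == 1 -> largest, k == 0 -> smallest non-zero
output = torch.where(k == 1, idx_max, idx_min)
values = x.gather(1, output.unsqueeze(1)).squeeze(1)
print(output)  # tensor([0, 2, 2])
print(values)  # tensor([4, 2, 3])
```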