
How to find argmax/argmin in only selected indices of a PyTorch tensor

I have a distance tensor

tensor([ 5,  10,  2,  3,  4], device='cuda:0')

And an indices tensor

tensor([ 0,  2,  3], device='cuda:0')

I want to find the argmax of the distance tensor, but only over the subset of indices specified by the indices tensor.

In this example, I would look only at the 0th, 2nd and 3rd elements of the distance tensor (values 5, 2 and 3) and return index 0, since the largest of those values, 5, is at position 0 of the distance tensor:

tensor([ 0], device='cuda:0')

Is something like this feasible without using for loops? Thanks


  • Here is an example. You can check that the maximum dist value within the selected subset of items is at index zero, and that the final output tensor contains zero too. Note that since we index along the first dimension of dist, the dim argument of torch.index_select is 0.

    import torch
    dist = torch.randn(5)
    # tensor([ 0.3392,  0.4472,  0.1398, -1.0379,  0.2950])
    idx = torch.tensor([0, 2, 3])
    # tensor([0, 2, 3])

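    Just using the max function and tensor filtering:

    # Values of dist at the selected positions only
    selected = torch.index_select(dist, 0, idx)
    max_val = torch.max(selected)
    # Compare against the selected values rather than all of dist, so an equal value
    # at an unselected position cannot match; then map back to dist positions via idx
    idx[(selected == max_val).nonzero(as_tuple=True)[0]]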
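A more direct variant, offered as a sketch and not as the answer's own method: gather the selected values with dist[idx], take torch.argmax (or torch.argmin, for the other half of the title) over them, and map the resulting subset position back through idx. The helper names subset_argmax and subset_argmin are hypothetical, and the example runs on CPU with the question's data; the same calls should work on CUDA tensors as long as dist and idx share a device.

```python
import torch


def subset_argmax(dist: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: position in dist of the largest value among dist[idx]."""
    # dist[idx] holds only the selected values; argmax gives a position *within* that
    # subset, and keepdim=True keeps a 1-element result like tensor([0]).
    pos_in_subset = torch.argmax(dist[idx], dim=0, keepdim=True)
    # Translate the subset position back into a position in the original dist tensor.
    return idx[pos_in_subset]


def subset_argmin(dist: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: position in dist of the smallest value among dist[idx]."""
    pos_in_subset = torch.argmin(dist[idx], dim=0, keepdim=True)
    return idx[pos_in_subset]


dist = torch.tensor([5, 10, 2, 3, 4])
idx = torch.tensor([0, 2, 3])
print(subset_argmax(dist, idx))  # tensor([0])
print(subset_argmin(dist, idx))  # tensor([2])
```

Unlike the equality-based filtering, this returns exactly one position even when the maximum is tied; in that case torch.argmax reports the first maximal entry of dist[idx], i.e. the tied index that appears first in idx.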