
How to select values from a tensor according to indices taken from another tensor in PyTorch?


I have two tensors, a and b, and I want to retrieve the values of b at the positions of the maximum values in a. I get those positions with:

    max_values, indices = torch.max(a, dim=0, keepdim=True)

However, I do not know how to use indices to retrieve the corresponding values of b. Can anybody help me solve this? Thanks a lot!

Edit:

Sorry for not describing my problem concretely. As a minimal example, the values of tensors a and b are:

    a = torch.tensor([[1,2,4],[2,1,3]])
    b = torch.tensor([[10,24,2],[23,4,5]])

If I use torch.max(a, dim=0, keepdim=True), it will return:

    max:  tensor([[2, 2, 4]])
    indices:   tensor([[1, 0, 0]])
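The keepdim=True flag matters here because it keeps the reduced dimension, so indices has the same number of dimensions as b, which torch.gather requires of its index argument. A minimal sketch (assuming a recent PyTorch release; flat_indices is just an illustrative name) that reproduces the call and compares the shapes with and without keepdim:

```python
import torch

a = torch.tensor([[1, 2, 4], [2, 1, 3]])

# keepdim=True keeps the reduced dimension, so both results have shape (1, 3)
max_values, indices = torch.max(a, dim=0, keepdim=True)
print(max_values)     # tensor([[2, 2, 4]])
print(indices)        # tensor([[1, 0, 0]])
print(indices.shape)  # torch.Size([1, 3])

# Without keepdim the reduced dimension is dropped, leaving shape (3,)
_, flat_indices = torch.max(a, dim=0)
print(flat_indices.shape)  # torch.Size([3])
```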

What I want is the values of b at the positions given by the indices of the maximum values of a along dim=0, that is:

    tensor([[23, 24, 2]])

I have tried b[indices], but the result is not what I want:

    tensor([[[23,  4,  5],
             [10, 24,  2],
             [10, 24,  2]]])
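The b[indices] attempt fails because of how integer-tensor indexing works: a single index tensor applies to dim 0 only, so every entry of indices is treated as a row number of b and a whole row is returned for it. A (1, 3) index tensor therefore yields a (1, 3, 3) result rather than one element per column. A small sketch that checks this on the example tensors from the question:

```python
import torch

b = torch.tensor([[10, 24, 2], [23, 4, 5]])
indices = torch.tensor([[1, 0, 0]])

# A single integer index tensor indexes dim 0 only: each entry selects a full row of b
rows = b[indices]
print(rows.shape)  # torch.Size([1, 3, 3])

# Equivalent to stacking rows b[1], b[0], b[0] and adding a leading dimension
expected = torch.stack([b[1], b[0], b[0]]).unsqueeze(0)
assert torch.equal(rows, expected)
```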

Solution

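  • You can use torch.gather. Along dim=0 it keeps the column fixed and takes the row from the index, so each output element is b[indices[0][j]][j]: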

    torch.gather(b, dim=0, index=indices)
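A self-contained sketch of the whole pipeline on the example tensors, assuming PyTorch 1.9 or later (needed only for torch.take_along_dim). It checks torch.gather against an explicit loop that spells out the dim=0 rule, selected[i][j] = b[indices[i][j]][j], and against two equivalent alternatives; manual and cols are illustrative names:

```python
import torch

a = torch.tensor([[1, 2, 4], [2, 1, 3]])
b = torch.tensor([[10, 24, 2], [23, 4, 5]])

# Row index of the maximum in each column of a; keepdim=True keeps shape (1, 3)
_, indices = torch.max(a, dim=0, keepdim=True)

# gather along dim 0: selected[i][j] = b[indices[i][j]][j]
selected = torch.gather(b, dim=0, index=indices)
print(selected)  # tensor([[23, 24,  2]])

# The same rule written as an explicit loop, for comparison
manual = torch.empty_like(indices)
for i in range(indices.size(0)):
    for j in range(indices.size(1)):
        manual[i, j] = b[indices[i, j], j]
assert torch.equal(selected, manual)

# Equivalent advanced indexing: pair each row index with its column index
# (indices has shape (1, 3), cols has shape (3,), and they broadcast to (1, 3))
cols = torch.arange(b.size(1))
assert torch.equal(b[indices, cols], selected)

# Equivalent helper that mirrors numpy.take_along_axis (PyTorch 1.9+)
assert torch.equal(torch.take_along_dim(b, indices, dim=0), selected)
```

Because gather returns a tensor with the shape of the index, the result is (1, 3); call selected.squeeze(0) if a 1-D result is preferred. Assuming a and b have the same shape, the same pattern works for any dimension d: pass the same d to both torch.max and torch.gather.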