
Select multiple indices in an axis of pytorch tensor


My actual problem is in a higher dimension, but I am posting it in a smaller dimension to make it easy to visualize.

I have a tensor of shape (2, 3, 4), created with x = torch.randn(2, 3, 4):

tensor([[[-0.9118,  1.4676, -0.4684, -0.6343],
         [ 1.5649,  1.0218, -1.3703,  1.8961],
         [ 0.8652,  0.2491, -0.2556,  0.1311]],

        [[ 0.5289, -1.2723,  2.3865,  0.0222],
         [-1.5528, -0.4638, -0.6954,  0.1661],
         [-1.8151, -0.4634,  1.6490,  0.6957]]])

From this tensor, I need to select one row along axis 1 for each item along axis 0, with the row numbers given by a list of indices.

For example:

indices = torch.tensor([0, 2])

Expected Output:

tensor([[[-0.9118,  1.4676, -0.4684, -0.6343]],
        [[-1.8151, -0.4634,  1.6490,  0.6957]]])

Output Shape: (2,1,4)

Explanation: select row 0 from x[0] and row 2 from x[1], as given by indices.

I tried using index_select like this:

torch.index_select(x, 1, indices)

The problem is that it selects both the 0th and the 2nd row for every item in x, giving a result of shape (2, 2, 4). It looks like it needs some modification, but I have not been able to figure out what.
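To make the mismatch concrete, here is a minimal sketch of what index_select returns. It samples a fresh x, so the values will differ from the ones printed above; only the shapes matter here.

```python
import torch

x = torch.randn(2, 3, 4)
indices = torch.tensor([0, 2])

# index_select applies the same index list to every item along dim 0,
# so rows 0 and 2 are returned for x[0] and again for x[1].
selected = torch.index_select(x, 1, indices)
print(selected.shape)  # torch.Size([2, 2, 4])

# The rows actually wanted are selected[0, 0] (row 0 of x[0])
# and selected[1, 1] (row 2 of x[1]).
wanted = torch.stack([selected[0, 0], selected[1, 1]])
print(wanted.shape)  # torch.Size([2, 4])
```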


Solution

  • In your case, this is quite straightforward. An easy way to step through two dimensions in parallel is to index the first axis with a range and the second axis with your indexing tensor, so that item i is paired with indices[i]:

    >>> x[range(len(indices)), indices]
    tensor([[-0.9118,  1.4676, -0.4684, -0.6343],
            [-1.8151, -0.4634,  1.6490,  0.6957]])
    

    The result has shape (2, 4); call .unsqueeze(1) on it if you need the (2, 1, 4) shape from the question.

    In more general cases, though, this requires torch.gather:

    • First expand indices such that it has enough dimensions:

      index = indices[:,None,None].expand(x.size(0), -1, x.size(-1))
      
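    • Then apply gather to x and index along dim=-2. The gathered tensor has shape (2, 1, 4); the trailing [:,0] drops the singleton dim 1, so omit it if you want to keep that dimension: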

      >>> x.gather(dim=-2, index=index)[:,0]
      tensor([[-0.9118,  1.4676, -0.4684, -0.6343],
              [-1.8151, -0.4634,  1.6490,  0.6957]])