My actual problem is higher-dimensional, but I am posting a smaller version to make it easy to visualize.
I have a tensor of shape (2,3,4):
x = torch.randn(2, 3, 4)
tensor([[[-0.9118, 1.4676, -0.4684, -0.6343],
[ 1.5649, 1.0218, -1.3703, 1.8961],
[ 0.8652, 0.2491, -0.2556, 0.1311]],
[[ 0.5289, -1.2723, 2.3865, 0.0222],
[-1.5528, -0.4638, -0.6954, 0.1661],
[-1.8151, -0.4634, 1.6490, 0.6957]]])
From this tensor, I need to select one row per item along axis 1, where the row number for each item is given by a list of indices.
For example:
indices = torch.tensor([0, 2])
Expected Output:
tensor([[[-0.9118, 1.4676, -0.4684, -0.6343]],
[[-1.8151, -0.4634, 1.6490, 0.6957]]])
Output shape: (2,1,4)
Explanation: select row 0 from x[0] and row 2 from x[1] (the row numbers come from indices).
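To pin down the intended semantics, here is a minimal loop-based reference (slow, but unambiguous): for each batch item i, take row indices[i] of x[i]. It uses a deterministic arange tensor in place of the random x above, an assumption made only so the values are easy to check.

```python
import torch

# Deterministic stand-in for the random x in the question (illustrative values).
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
indices = torch.tensor([0, 2])

# Reference semantics: out[i] = x[i, indices[i]], keeping a singleton dim
# so the result matches the expected (2, 1, 4) shape.
expected = torch.stack(
    [x[i, int(indices[i])].unsqueeze(0) for i in range(x.size(0))]
)
print(expected.shape)  # torch.Size([2, 1, 4])
```

Any vectorized solution should produce the same tensor as this loop.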
I tried using index_select like this:
torch.index_select(x, 1, indices)
The problem is that it selects rows 0 and 2 from every item in x, giving a tensor of shape (2,2,4). It seems to need some modification, but I have not been able to figure it out.
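A small sketch of why index_select does not fit here, again assuming a deterministic arange tensor for x: it applies the same index list to every batch item, so y[b, j] = x[b, indices[j]]. The wanted rows are the "diagonal" entries y[i, i].

```python
import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
indices = torch.tensor([0, 2])

# Same index list for every batch item: y[b, j] = x[b, indices[j]].
y = torch.index_select(x, 1, indices)
print(y.shape)  # torch.Size([2, 2, 4])

# The rows we actually want sit at y[i, i]. This works, but it gathers
# B * K rows only to keep B of them, so it is wasteful for large inputs.
b = torch.arange(x.size(0))
wanted = y[b, b]  # shape (2, 4)
```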
In your case, this is quite straightforward. An easy way to index two dimensions in parallel is to use a range over the first axis and your index tensor on the second:
>>> x[range(len(indices)), indices]
tensor([[-0.9118, 1.4676, -0.4684, -0.6343],
[-1.8151, -0.4634, 1.6490, 0.6957]])
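Note that this returns shape (2,4), not the (2,1,4) asked for. A minimal sketch of both variants with advanced indexing (x is again a deterministic arange stand-in): giving both index tensors a trailing axis of size 1 keeps the singleton dimension.

```python
import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
indices = torch.tensor([0, 2])

# Pair batch position i with row indices[i]; the two index tensors broadcast together.
batch = torch.arange(x.size(0))
rows = x[batch, indices]  # shape (2, 4)

# Index tensors of shape (2, 1) broadcast to (2, 1), and the last dim is appended,
# giving shape (2, 1, 4). rows.unsqueeze(1) would give the same result.
rows_keepdim = x[batch[:, None], indices[:, None]]
```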
In more general cases, though, this requires torch.gather:
First, expand indices so that it has the same number of dimensions as x (gather requires this):
index = indices[:,None,None].expand(x.size(0), -1, x.size(-1))
Then apply gather to x with index. The result has the shape of index, (2,1,4); the [:,0] below drops dim 1 (omit it to keep the (2,1,4) shape from the question):
>>> x.gather(dim=-2, index=index)[:,0]
tensor([[-0.9118, 1.4676, -0.4684, -0.6343],
[-1.8151, -0.4634, 1.6490, 0.6957]])
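Since the real problem is higher-dimensional, the gather approach can be wrapped in a helper that works for any x of shape (B, N, *rest). This is a sketch; the name select_per_batch is hypothetical, and it keeps the singleton dim 1 so the output matches the (B, 1, *rest) layout from the question.

```python
import torch

def select_per_batch(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[b, 0] = x[b, indices[b]] for x of shape (B, N, *rest).

    Returns a tensor of shape (B, 1, *rest).
    """
    B = x.size(0)
    rest = x.shape[2:]
    # Reshape indices to (B, 1, 1, ..., 1) so it has x.dim() dimensions ...
    index = indices.reshape(B, 1, *([1] * len(rest)))
    # ... then expand (no copy) to (B, 1, *rest) so every trailing element is gathered.
    index = index.expand(B, 1, *rest)
    return x.gather(dim=1, index=index)

# Usage on a 4-D tensor: pick row indices[b] along dim 1 for each batch item b.
x4 = torch.randn(2, 3, 4, 5)
out = select_per_batch(x4, torch.tensor([0, 2]))  # shape (2, 1, 4, 5)
```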