I have a B x M x N tensor, X, and a B x 1 tensor, Y, whose entries are the indices of X along dimension 1 that I want to keep. What is the shorthand for this slice so that I can avoid a loop?
Essentially I want to do this:
import torch

Z = torch.zeros(B, N)
for i in range(B):
    Z[i] = X[i][Y[i]]
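For reference, here is a self-contained version of that loop with small, made-up sizes (B, M, N and the random data are assumptions for illustration), next to one common loop-free alternative based on advanced indexing. The one-liner is a standard PyTorch idiom and is not necessarily the answer referred to below.

```python
import torch

# Hypothetical small sizes for illustration
B, M, N = 4, 3, 5
X = torch.randn(B, M, N)
Y = torch.randint(0, M, (B, 1))  # int64 indices into dim=1, shape (B, 1)

# Loop version from the question
Z_loop = torch.zeros(B, N)
for i in range(B):
    Z_loop[i] = X[i, Y[i, 0]]

# Loop-free version using advanced indexing: row i picks X[i, Y[i, 0], :]
Z_idx = X[torch.arange(B), Y[:, 0]]

print(Z_idx.shape)                 # torch.Size([4, 5])
print(torch.equal(Z_loop, Z_idx))  # True
```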
The answer provided by @Hammad is short and perfect for the job. Here's an alternative solution if you're interested in some lesser-known PyTorch built-ins: we will use torch.gather (in NumPy, the corresponding function is numpy.take_along_axis).
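In NumPy, the closest counterpart of torch.gather is numpy.take_along_axis. A minimal sketch, assuming NumPy arrays with the same (made-up) shapes as in the question; the index array only needs to match X's number of dimensions and broadcast against it on the other axes:

```python
import numpy as np

B, M, N = 4, 3, 5
X = np.random.randn(B, M, N)
Y = np.random.randint(0, M, size=(B, 1))

# Add a trailing axis so the index has the same ndim as X: (B, 1) -> (B, 1, 1).
# It broadcasts over the last axis, so the result has shape (B, 1, N).
Z = np.take_along_axis(X, Y[:, :, None], axis=1)[:, 0]  # -> (B, N)

# Reference loop for comparison
Z_loop = np.stack([X[i, Y[i, 0]] for i in range(B)])
print(Z.shape)                    # (4, 5)
print(np.array_equal(Z, Z_loop))  # True
```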
The idea behind torch.gather is to construct a new tensor from two tensors with the same number of dimensions: one containing the indices (here ~ Y) and one containing the values (here ~ X). The output has the same shape as the index tensor.
With dim=1, the operation performed is Z[i][j][k] = X[i][Y[i][j][k]][k].
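That rule can be checked element by element on a small example. The sizes and the generic index tensor `idx` below are made up for illustration (`idx` plays the role of Y in the formula); the point is that the output takes the shape of the index and each entry follows the formula:

```python
import torch

B, M, N, J = 2, 4, 3, 2
X = torch.randn(B, M, N)
# Index tensor: same ndim as X; its values index dim=1 of X, so they lie in [0, M)
idx = torch.randint(0, M, (B, J, N))

Z = X.gather(dim=1, index=idx)  # output has the shape of idx: (B, J, N)

# Check Z[i][j][k] == X[i][idx[i][j][k]][k] for every element
for i in range(B):
    for j in range(J):
        for k in range(N):
            assert Z[i, j, k] == X[i, idx[i, j, k], k]
print(Z.shape)  # torch.Size([2, 2, 3])
```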
Since X's shape is (B, M, N) and Y's shape is (B, 1), we need to fill in the missing dimensions of Y so that its shape becomes (B, 1, N), repeating each row's index across all N columns.
This can be achieved with some axis manipulation:
>>> Y.expand(-1, N)[:, None]  # expand dim=1 from 1 to N, then unsqueeze a new dim=1
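To see the intermediate shapes, here is the same expression on a small hand-written Y (the values are arbitrary). Note that expand returns a view, so the repeated index is not copied in memory:

```python
import torch

B, N = 4, 5
Y = torch.tensor([[2], [0], [1], [2]])  # shape (B, 1), dtype int64

expanded = Y.expand(-1, N)  # (B, 1) -> (B, N): each row repeats its single index N times
index = expanded[:, None]   # (B, N) -> (B, 1, N): insert a new dim=1
print(expanded.shape)  # torch.Size([4, 5])
print(index.shape)     # torch.Size([4, 1, 5])
print(index[0])        # tensor([[2, 2, 2, 2, 2]])
```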
The actual call to torch.gather will be:
>>> X.gather(dim=1, index=Y.expand(-1, N)[:, None])
This returns a (B, 1, N) tensor, which you can reduce to (B, N) by indexing with [:, 0].
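Putting the pieces together, a minimal end-to-end sketch compared against the loop from the question. The helper name `select_rows` and the random test data are hypothetical; torch.gather expects an int64 (long) index, so a Y of another integer dtype would need `Y.long()` first.

```python
import torch

def select_rows(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Z[i] = X[i, Y[i, 0]] for X (B, M, N) and Y (B, 1)."""
    N = X.shape[2]
    index = Y.expand(-1, N)[:, None]           # (B, 1) -> (B, 1, N)
    return X.gather(dim=1, index=index)[:, 0]  # (B, 1, N) -> (B, N)

B, M, N = 4, 3, 5
X = torch.randn(B, M, N)
Y = torch.randint(0, M, (B, 1))

Z = select_rows(X, Y)
Z_loop = torch.stack([X[i, Y[i, 0]] for i in range(B)])
print(Z.shape)                 # torch.Size([4, 5])
print(torch.equal(Z, Z_loop))  # True
```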
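This function can be very effective in trickier scenarios too, because the index tensor can pick a different position for every element rather than a single one per row.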
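As one illustration of such a case (our own example, not taken from the answer): if each row needs a different M-index in every column, the index tensor simply carries those values instead of a repeated one. The per-column index W below is hypothetical:

```python
import torch

B, M, N = 2, 3, 4
X = torch.arange(B * M * N).reshape(B, M, N)
# Hypothetical per-column indices: W[i, k] picks which of the M rows to read in column k
W = torch.tensor([[0, 1, 2, 0],
                  [2, 2, 1, 0]])  # shape (B, N)

Z = X.gather(dim=1, index=W[:, None, :])[:, 0]  # (B, 1, N) -> (B, N)
print(Z.tolist())  # [[0, 5, 10, 3], [20, 21, 18, 15]]
```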