Tags: numpy, pytorch, slice, tensor, numpy-slicing

Batched tensor slice: slice B x M x N with B x 1


I have a B x M x N tensor, X, and a B x 1 tensor, Y, which holds, for each batch element, the index along dimension 1 of X that I want to keep. What is the shorthand for this slice so that I can avoid a loop?

Essentially I want to do this:

Z = torch.zeros(B,N)

for i in range(B):
    Z[i] = X[i][Y[i]]
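
For reference, the loop above can be made self-contained as follows. This is only a minimal sketch: the sizes B, M, N and the random contents of X and Y are assumptions chosen for illustration, and Y is assumed to hold int64 indices in the range [0, M).

```python
import torch

# Illustrative sizes (assumptions, not taken from the question)
B, M, N = 4, 5, 3

torch.manual_seed(0)
X = torch.randn(B, M, N)         # values, shape (B, M, N)
Y = torch.randint(0, M, (B, 1))  # one row index per batch element, shape (B, 1)

# Loop version: for each batch element i, keep row Y[i] of X[i]
Z = torch.zeros(B, N)
for i in range(B):
    Z[i] = X[i, Y[i, 0]]         # row of shape (N,)

print(Z.shape)  # torch.Size([4, 3])
```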

Solution

  • The answer provided by @Hammad is short and perfect for the job. Here's an alternative solution if you're interested in some lesser-known PyTorch built-ins. We will use torch.gather (the closest NumPy analogue is numpy.take_along_axis).

    The idea behind torch.gather is to construct a new tensor from two tensors with the same number of dimensions: one containing the indices (here, Y) and one containing the values (here, X). The output has the same shape as the index tensor.

    With dim=1, the operation performed is Z[i][j][k] = X[i][Y[i][j][k]][k].

    Since X's shape is (B, M, N) and Y's shape is (B, 1), we need to fill in the blanks in Y so that its shape becomes (B, 1, N).

    This can be achieved with some axis manipulation:

    >>> Y.expand(-1, N)[:, None]  # expand dim=1 to size N, then unsqueeze at dim=1
    

    The actual call to torch.gather will be:

    >>> X.gather(dim=1, index=Y.expand(-1, N)[:, None])
    

    The result has shape (B, 1, N), which you can reduce to (B, N) by appending [:, 0].


    This function can be very effective in tricky scenarios...
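
Putting the steps of this answer together, here is a minimal end-to-end sketch that checks the gather result against the element-wise rule Z[i][j][k] = X[i][Y[i][j][k]][k]. The sizes and random data are assumptions for illustration only; note that torch.gather expects the index tensor to have dtype int64, which torch.randint produces by default.

```python
import torch

# Illustrative sizes (assumptions, not taken from the question)
B, M, N = 4, 5, 3

torch.manual_seed(0)
X = torch.randn(B, M, N)                # values, shape (B, M, N)
Y = torch.randint(0, M, (B, 1))         # int64 indices, shape (B, 1)

index = Y.expand(-1, N)[:, None]        # (B, 1) -> (B, N) -> (B, 1, N)
Z = X.gather(dim=1, index=index)[:, 0]  # (B, 1, N) -> (B, N)

# Check against the element-wise definition of gather along dim=1
expected = torch.empty(B, N)
for i in range(B):
    for k in range(N):
        expected[i, k] = X[i, index[i, 0, k], k]

print(torch.equal(Z, expected))  # True
```

Because expand returns a view rather than a copy, building the (B, 1, N) index this way does not allocate N copies of Y.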