New tensor based on old tensor and 2d indices

I previously asked: PyTorch tensors: new tensor based on old tensor and indices

I have the same problem now but need to use a 2d index tensor.

I have a tensor idx of size [batch_size, k] with values between 0 and k-1:

idx = tensor([[0,1,2,0],
              [0,3,2,2]])

and the following matrix:

x = tensor([[[0, 9],
             [1, 8],
             [2, 3],
             [4, 9]],
            [[0, 0],
             [1, 2],
             [3, 4],
             [5, 6]]])

I want to create a new tensor which contains, for each batch, the rows of x specified in idx, in that order. So I want:

tensor([[[0, 9],
         [1, 8],
         [2, 3],
         [0, 9]],
        [[0, 0],
         [5, 6],
         [3, 4],
         [3, 4]]])

Currently I'm doing it like this:

for i, batch in enumerate(x):
    x[i] = batch[idx[i]]

How can I do it more efficiently?


  • You should use torch.gather to achieve this. It would also work for the other question you linked, but this is left as an exercise to the reader :p

    Let us call idx your first tensor and source the second one. Their respective dimensions are (B,N) and (B, K, p) (with p=2 in your example), and all values of idx are between 0 and K-1.

    So to use torch.gather, we first need to express your operation as a nested for loop. In your case, what you actually want to achieve is:

    for b in range(B):
        for i in range(N):
            for j in range(p):
                # This kind of nested for loop is what torch.gather actually does
                target[b,i,j] = source[b, idx[b,i,j], j]

    But that does not work because idx is a 2D tensor, not a 3D one. Well, no big deal, let's make it a 3D tensor. We want it to have shape (B, N, p) and to be constant along the last dimension, so that every column j reads from the same row idx[b,i]. Then we can replace the for loop with a call to gather:

    reshaped_idx = idx.unsqueeze(-1).repeat(1,1,2)
    target = source.gather(1, reshaped_idx)
    # or : target = torch.gather(source, 1, reshaped_idx)
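
To check the gather approach end to end, here is a minimal, self-contained sketch using the tensors from the question. Two things in it are assumptions rather than part of the original posts: the second row of idx is inferred from the expected output, and expand is used in place of repeat(1,1,2) so that the last dimension is read from source instead of being hard-coded. The names looped and expected are only there for the comparison.

```python
import torch

# Data from the question. The second row of idx is an assumption,
# inferred from the expected output shown in the question.
idx = torch.tensor([[0, 1, 2, 0],
                    [0, 3, 2, 2]])
source = torch.tensor([[[0, 9], [1, 8], [2, 3], [4, 9]],
                       [[0, 0], [1, 2], [3, 4], [5, 6]]])

# (B, N) -> (B, N, 1) -> (B, N, p). expand returns a broadcast view,
# so the index is not physically copied p times.
reshaped_idx = idx.unsqueeze(-1).expand(-1, -1, source.size(-1))
target = torch.gather(source, 1, reshaped_idx)

# Reference result: per-batch row selection, as in the question's loop.
looped = torch.stack([batch[rows] for batch, rows in zip(source, idx)])

# Expected output copied from the question.
expected = torch.tensor([[[0, 9], [1, 8], [2, 3], [0, 9]],
                         [[0, 0], [5, 6], [3, 4], [3, 4]]])

print(torch.equal(target, expected), torch.equal(target, looped))
```

If p is known and small, repeat(1,1,p) as in the answer gives the same result; expand simply avoids materializing the repeated index.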