I previously asked: PyTorch tensors: new tensor based on old tensor and indices
I have the same problem now, but need to use a 2D index tensor.
I have an index tensor idx of size [batch_size, k] with values between 0 and k-1:
idx = tensor([[0, 1, 2, 0],
              [0, 3, 2, 2],
              ...])
and the following tensor x:
x = tensor([[[0, 9],
             [1, 8],
             [2, 3],
             [4, 9]],
            [[0, 0],
             [1, 2],
             [3, 4],
             [5, 6]]])
I want to create a new tensor which contains, for each batch, the rows specified in idx, in that order. So I want:
tensor([[[0, 9],
         [1, 8],
         [2, 3],
         [0, 9]],
        [[0, 0],
         [5, 6],
         [3, 4],
         [3, 4]]])
Currently I'm doing it like this:
for i, batch in enumerate(x):
    x[i] = batch[idx[i]]
How can I do it more efficiently?
You should use torch.gather to achieve this. It would actually also work for the other question you linked, but this is left as an exercise to the reader :p
Let us call idx your first tensor and source the second one. Their respective dimensions are (B, N) and (B, K, p) (with p=2 in your example), and all values of idx are between 0 and K-1.
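To make the shapes concrete, here is the question's example written with those names (just a small setup snippet, nothing beyond the tensors already given above):

import torch

# Example from the question: B=2, N=4, K=4, p=2
idx = torch.tensor([[0, 1, 2, 0],
                    [0, 3, 2, 2]])                          # (B, N), values in [0, K-1]
source = torch.tensor([[[0, 9], [1, 8], [2, 3], [4, 9]],
                       [[0, 0], [1, 2], [3, 4], [5, 6]]])   # (B, K, p)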
So to use torch.gather, we first need to express your operation as a nested for loop. In your case, what you actually want to achieve is:
for b in range(B):
    for i in range(N):
        for j in range(p):
            # This kind of nested for loop is what torch.gather actually does
            target[b, i, j] = source[b, idx[b, i, j], j]
But that does not work, because idx is a 2D tensor, not a 3D one. Well, no big deal, let's make it a 3D tensor. We want it to have shape (B, N, p) and be constant along the last dimension. Then we can replace the for loop with a call to gather:
# (B, N) -> (B, N, 1) -> (B, N, p): the index is constant along the last dim
reshaped_idx = idx.unsqueeze(-1).repeat(1, 1, 2)
target = source.gather(1, reshaped_idx)
# or: target = torch.gather(source, 1, reshaped_idx)
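And as a quick sanity check (a minimal sketch reusing the example tensors above; expected is just the desired output from the question):

expected = torch.tensor([[[0, 9], [1, 8], [2, 3], [0, 9]],
                         [[0, 0], [5, 6], [3, 4], [3, 4]]])

reshaped_idx = idx.unsqueeze(-1).repeat(1, 1, 2)   # (B, N) -> (B, N, p)
target = source.gather(1, reshaped_idx)
print(torch.equal(target, expected))  # True

If you'd rather not materialize the repeated index, idx.unsqueeze(-1).expand(-1, -1, source.size(-1)) should work as well, since gather only reads from the index tensor.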