Suppose we have an `initial_tensor` of shape `(batch_size, N, N)` and a tensor `indexes` of shape `(batch_size, N)` that specifies the new order of rows and columns for each 2D matrix in the batch. The goal is to rearrange the elements of each matrix in the batch according to its row of `indexes` to obtain a `target` tensor.
Currently I am able to do it on CPU using the following loop:
```python
import torch

target = torch.empty_like(initial_tensor)
for batch in range(batch_size):
    old_ids = indexes[batch]
    for i in range(N):
        for j in range(N):
            target[batch][i][j] = initial_tensor[batch][old_ids[i]][old_ids[j]]
```
I am looking for an equivalent vectorized solution so that this no longer runs as a Python loop on the CPU.
I tried various combinations of scattering and slicing, but could not figure out an equivalent for the loop.
What you are looking for is to gather values along two axes:
out[b, i, j] = x[b, index[b,i], index[b,j]]
A single `torch.gather` call cannot do this out of the box, so you need to work around it. Compare your setup with the use case of `torch.gather`, where `x` is indexed along a single axis only (`dim=1`):
out[b,i] = x[b, index[b,i]]
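To make the single-axis semantics concrete, here is a small sketch of `torch.gather` on a 2D tensor; the values of `x2d` and `index` are made up purely for illustration. Each output row `b` picks entries from row `b` of the input at the column positions listed in `index[b]`, and the output takes the shape of `index`.

```python
import torch

# Toy input: 2 "batch" rows with 4 values each (illustrative values only).
x2d = torch.tensor([[10., 11., 12., 13.],
                    [20., 21., 22., 23.]])
# index[b, i] says which column of row b ends up at output position i.
index = torch.tensor([[3, 0, 0],
                      [1, 2, 3]])

out = x2d.gather(1, index)  # out[b, i] = x2d[b, index[b, i]]
# out == [[13., 10., 10.], [21., 22., 23.]], shape (2, 3) like index
```

Note that the index may repeat positions and may be shorter than the gathered axis; only the non-gathered dimensions must fit inside the input.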
So the idea is to flatten `x` over its last two dimensions and to flatten the indices accordingly. Here is a basic setup:
```python
import torch
B, N = 4, 5  # example sizes
x = torch.rand(B, N, N)
indices = torch.randint(0, N, (B, N))
```
You can easily get the flattened indices with:
```python
# findex[b, i*N + j] == indices[b, i]*N + indices[b, j]
findex = indices.repeat_interleave(N, 1) * N + indices.repeat(1, N)
```
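A tiny worked example may help show why this expression is right: in a row-major `N x N` matrix, element `(r, c)` sits at flat position `r * N + c`. The sketch below uses a single batch element with `N = 3` (toy values chosen for illustration) and splits the expression into its row and column parts. The broadcasting variant at the end is an alternative construction added here for comparison, not part of the original answer.

```python
import torch

N = 3
indices = torch.tensor([[2, 0, 1]])  # one batch element (B = 1), toy values

# Row part: each indices[b, i] repeated N times, scaled by N -> indices[b, i] * N
rows = indices.repeat_interleave(N, 1) * N   # [[6, 6, 6, 0, 0, 0, 3, 3, 3]]
# Column part: the whole index row tiled N times -> indices[b, j]
cols = indices.repeat(1, N)                  # [[2, 0, 1, 2, 0, 1, 2, 0, 1]]
findex = rows + cols                         # [[8, 6, 7, 2, 0, 1, 5, 3, 4]]

# Alternative (not from the answer): build the (B, N, N) grid of flat
# positions by broadcasting, then flatten it in row-major order.
findex_bc = (indices[:, :, None] * N + indices[:, None, :]).flatten(1)
assert torch.equal(findex, findex_bc)
```

For instance, output position `(0, 1)` maps to flat index `0 * 3 + 1 = 1`, whose value is `2 * 3 + 0 = 6`, i.e. element `(2, 0)` of the original matrix, exactly `x[b, indices[b, 0], indices[b, 1]]`.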
Then simply flatten the `(N, N)` dimensions of `x` and apply the indexing on `dim=1` using `findex`:
```python
x.flatten(1).gather(1, findex).view(B, N, N)
```
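As a sanity check, the steps above can be wrapped in a function and compared against the original nested loop. The function names `reorder_2d` and `reorder_loop` are hypothetical, introduced only for this sketch. The sketch also compares against plain advanced indexing with broadcast index tensors, which is a different technique that expresses `out[b, i, j] = x[b, index[b,i], index[b,j]]` directly; it is included only as a cross-check.

```python
import torch

def reorder_2d(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[b, i, j] = x[b, indices[b, i], indices[b, j]].

    x: (B, N, N); indices: (B, N) int64 tensor on the same device as x.
    """
    B, N, _ = x.shape
    # Flat position of (row, col) inside an N x N matrix is row * N + col.
    findex = indices.repeat_interleave(N, 1) * N + indices.repeat(1, N)
    return x.flatten(1).gather(1, findex).view(B, N, N)

def reorder_loop(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: the nested loop from the question."""
    B, N, _ = x.shape
    target = torch.empty_like(x)
    for b in range(B):
        old_ids = indices[b]
        for i in range(N):
            for j in range(N):
                target[b, i, j] = x[b, old_ids[i], old_ids[j]]
    return target

B, N = 3, 5
x = torch.rand(B, N, N)
indices = torch.randint(0, N, (B, N))

fast = reorder_2d(x, indices)
slow = reorder_loop(x, indices)
# Cross-check via advanced indexing: the three index tensors broadcast
# to (B, N, N), selecting x[b, indices[b, i], indices[b, j]].
adv = x[torch.arange(B)[:, None, None], indices[:, :, None], indices[:, None, :]]

print(torch.equal(fast, slow), torch.equal(fast, adv))  # True True
```

Since the gather version uses only tensor operations, moving `x` and `indices` to the same GPU device should let it run there without the Python loop; `gather` copies values exactly, so `torch.equal` (rather than `allclose`) is appropriate for the comparison.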