
Rearrange 2D tensors in a batch in PyTorch


Suppose we have an initial_tensor of size (batch_size, N, N) and an index tensor indexes of size (batch_size, N) that specifies the new order of the rows and columns of each 2D tensor in the batch. The goal is to rearrange the elements of each 2D tensor according to the index tensor to obtain a target tensor.

Currently I am able to do it on CPU using the following loop:

    for batch in range(batch_size):
        old_ids = indexes[batch]

        for i in range(N):
            for j in range(N):
                target[batch][i][j] = initial_tensor[batch][old_ids[i]][old_ids[j]]

I am looking for an equivalent vectorized solution that avoids the Python loop on the CPU.

I tried various combinations of scatter operations and slicing, but could not find an equivalent of the loop.


Solution

  • What you are looking for is to gather values along two axes:

    out[b, i, j] = x[b, index[b,i], index[b,j]]
    

    There is no single built-in function for this, so you need to work around it. Compare your setup with the use case of torch.gather, where x is indexed along a single axis only (here dim=1):

    out[b,i] = x[b, index[b,i]]
    

    So what you want to do is flatten the last two dimensions of x and convert each pair of indices into a single flat index accordingly. Here is a basic setup:

    import torch
    B, N = 2, 3  # example sizes, chosen arbitrarily
    x = torch.rand(B, N, N)
    indices = torch.randint(0, N, (B, N))
    

    After flattening, element (p, q) of an (N, N) matrix sits at position p*N + q, so the flattened indices are:

    findex = indices.repeat_interleave(N,1)*N + indices.repeat(1,N)
    

    Then simply flatten the (N,N) dimensions of x and apply the indexing on dim=1 using findex:

    out = x.flatten(1).gather(1, findex).view(B, N, N)
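
To sanity-check the flatten-and-gather recipe above, here is a minimal sketch that compares it against the nested loop from the question. The function names (rearrange_loop, rearrange_gather, rearrange_indexing) and the example sizes are hypothetical, chosen only for illustration. The third function is not part of the answer above: it is an alternative that uses broadcasted advanced indexing, which should produce the same result.

```python
import torch


def rearrange_loop(x, indices):
    """Reference implementation: the nested loop from the question."""
    B, N, _ = x.shape
    out = torch.empty_like(x)
    for b in range(B):
        ids = indices[b]
        for i in range(N):
            for j in range(N):
                out[b, i, j] = x[b, ids[i], ids[j]]
    return out


def rearrange_gather(x, indices):
    """Flatten-and-gather approach described in the answer."""
    B, N, _ = x.shape
    # findex[b, i*N + j] = indices[b, i] * N + indices[b, j]
    findex = indices.repeat_interleave(N, dim=1) * N + indices.repeat(1, N)
    return x.flatten(1).gather(1, findex).view(B, N, N)


def rearrange_indexing(x, indices):
    """Alternative (an assumption, not from the answer): advanced indexing."""
    B = x.shape[0]
    batch = torch.arange(B, device=x.device)[:, None, None]  # (B, 1, 1)
    rows = indices[:, :, None]                                # (B, N, 1)
    cols = indices[:, None, :]                                # (B, 1, N)
    # The three index tensors broadcast to (B, N, N)
    return x[batch, rows, cols]


torch.manual_seed(0)
B, N = 3, 4  # example sizes, chosen arbitrarily
x = torch.rand(B, N, N)
# One random permutation of range(N) per batch element
indices = torch.stack([torch.randperm(N) for _ in range(B)])

expected = rearrange_loop(x, indices)
assert torch.equal(rearrange_gather(x, indices), expected)
assert torch.equal(rearrange_indexing(x, indices), expected)
```

Both vectorized versions use only tensor operations, so they should run on whatever device x and indices live on. The gather version materializes an index tensor of size (B, N*N), while the indexing version relies on broadcasting the three index tensors.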