
Pytorch: How to reorder the tensor by given sorted indices


Given a tensor A of shape (d0, d1, ..., dn, dn+1) and a tensor I of sorted indices with shape (d0, d1, ..., dn), I want to reorder A along dimension n (its second-to-last dimension) using the indices in I.

The leading dimensions (d0, ..., dn) of A and I are equal; the last dimension of A, dn+1, can be any size.
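To pin down the intended semantics for the n = 1 case used in the example (A of shape (d0, d1, d2), I of shape (d0, d1)), here is a minimal sketch: a slow loop-based reference and an equivalent vectorized version using advanced indexing. The function names `reorder_reference` and `reorder_batched` are hypothetical, chosen for illustration only.

```python
import torch


def reorder_reference(A: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    # Slow, loop-based reference: out[b, j, :] = A[b, I[b, j], :]
    # Assumes A has shape (d0, d1, d2) and I has shape (d0, d1).
    out = torch.empty_like(A)
    for b in range(A.shape[0]):
        for j in range(A.shape[1]):
            out[b, j] = A[b, I[b, j]]
    return out


def reorder_batched(A: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    # Vectorized equivalent: the (d0, 1) row index broadcasts against I's
    # (d0, d1) shape, so element [b, j] of the result is A[b, I[b, j], :].
    rows = torch.arange(A.shape[0], device=A.device).unsqueeze(1)
    return A[rows, I]


# Usage: random data with one permutation per leading index.
A = torch.rand(3, 5, 4)
I = torch.stack([torch.randperm(5) for _ in range(3)])
print(torch.equal(reorder_reference(A, I), reorder_batched(A, I)))
```

The advanced-indexing form only covers a single leading batch dimension as written; for more leading dimensions the gather-based approaches below generalize more directly.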

Example

Given A and I:

>>> A.shape
torch.Size([8, 8, 4])
>>> A
tensor([[[5.6065e-01, 3.1521e-01, 5.7780e-01, 6.7756e-01],
         [9.9534e-04, 7.6054e-01, 9.0428e-01, 4.1251e-01],
         [8.1525e-01, 3.0477e-01, 3.9605e-01, 2.9155e-01],
         [4.9588e-01, 7.4128e-01, 8.8521e-01, 6.1442e-01],
         [4.3290e-01, 2.4908e-01, 9.0862e-01, 2.6999e-01],
         [9.8264e-01, 4.9388e-01, 4.9769e-01, 2.7884e-02],
         [5.7816e-01, 7.5621e-01, 7.0113e-01, 4.4830e-01],
         [7.2809e-01, 8.6010e-01, 7.8921e-01, 1.1440e-01]],
        ...])
>>> I.shape
torch.Size([8, 8])
>>> I
tensor([[2, 7, 4, 6, 1, 3, 0, 5],
        ...])

After reordering, the rows along the second-to-last dimension of A should look like this:

>>> A
tensor([[[8.1525e-01, 3.0477e-01, 3.9605e-01, 2.9155e-01],
         [7.2809e-01, 8.6010e-01, 7.8921e-01, 1.1440e-01],
         [4.3290e-01, 2.4908e-01, 9.0862e-01, 2.6999e-01],
         [5.7816e-01, 7.5621e-01, 7.0113e-01, 4.4830e-01],
         [9.9534e-04, 7.6054e-01, 9.0428e-01, 4.1251e-01],
         [4.9588e-01, 7.4128e-01, 8.8521e-01, 6.1442e-01],
         [5.6065e-01, 3.1521e-01, 5.7780e-01, 6.7756e-01],
         [9.8264e-01, 4.9388e-01, 4.9769e-01, 2.7884e-02]],
        ...])

For simplicity, I have included only the first slice of each tensor (A[0] and I[0]).
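For a single slice, the target ordering is plain integer-array indexing: row j of the result is row I[0][j] of A[0]. A quick check using only the first-slice values shown above (the rest of A is not given, so it is not reproduced here):

```python
import torch

# First slice A[0] and I[0], copied from the example above.
A0 = torch.tensor([
    [5.6065e-01, 3.1521e-01, 5.7780e-01, 6.7756e-01],
    [9.9534e-04, 7.6054e-01, 9.0428e-01, 4.1251e-01],
    [8.1525e-01, 3.0477e-01, 3.9605e-01, 2.9155e-01],
    [4.9588e-01, 7.4128e-01, 8.8521e-01, 6.1442e-01],
    [4.3290e-01, 2.4908e-01, 9.0862e-01, 2.6999e-01],
    [9.8264e-01, 4.9388e-01, 4.9769e-01, 2.7884e-02],
    [5.7816e-01, 7.5621e-01, 7.0113e-01, 4.4830e-01],
    [7.2809e-01, 8.6010e-01, 7.8921e-01, 1.1440e-01],
])
I0 = torch.tensor([2, 7, 4, 6, 1, 3, 0, 5])

# Integer-array indexing picks rows 2, 7, 4, ... of A0 in that order.
A0_sorted = A0[I0]
print(A0_sorted[:, 0])
```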

Solution

Based on the accepted answer, I implemented a generalized version that can sort a tensor with any number of dimensions (d0, d1, ..., dn, dn+1, dn+2, ..., dn+k), given a tensor of sorted indices of shape (d0, d1, ..., dn).

Here is the code snippet:

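import torch
from torch import LongTensor, Tensor


def sort_by_indices(values: Tensor, indices: LongTensor) -> Tensor:
    """Reorder `values` along dim `indices.dim() - 1` using `indices`.

    values:  shape (d0, ..., dn, dn+1, ..., dn+k)
    indices: int64, shape (d0, ..., dn)
    """
    num_dims = indices.dim()

    # Append one singleton dim to `indices` per trailing dim of `values`.
    new_shape = tuple(indices.shape) + tuple(
        1
        for _ in range(values.dim() - num_dims)
    )
    # Keep the leading dims as-is and tile across the trailing dims.
    repeats = tuple(
        1
        for _ in range(num_dims)
    ) + tuple(values.shape[num_dims:])

    # gather() needs an index with the same number of dims as `values`.
    repeated_indices = indices.reshape(*new_shape).repeat(*repeats)
    return torch.gather(values, num_dims - 1, repeated_indices)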
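A possible variant, sketched here as an assumption rather than a drop-in replacement: `Tensor.expand` returns a broadcast view instead of copying the index data the way `repeat` does, and `torch.gather` accepts such an expanded index. The name `sort_by_indices_expand` is hypothetical.

```python
import torch
from torch import Tensor


def sort_by_indices_expand(values: Tensor, indices: Tensor) -> Tensor:
    # Same idea as sort_by_indices, but expand() creates a stride-0 view
    # over the trailing dims instead of materializing repeated indices.
    num_dims = indices.dim()
    # One singleton dim per trailing dim of `values`.
    view_shape = tuple(indices.shape) + (1,) * (values.dim() - num_dims)
    index = indices.reshape(view_shape).expand(
        *indices.shape, *values.shape[num_dims:]
    )
    # Gather along the last index dimension (dn).
    return torch.gather(values, num_dims - 1, index)


# Usage: n = 1 with two trailing dims (k = 2).
values = torch.rand(2, 3, 4, 5)
indices = torch.stack([torch.randperm(3) for _ in range(2)])
out = sort_by_indices_expand(values, indices)
print(out.shape)  # torch.Size([2, 3, 4, 5])
```

The result should match the `repeat` version; the difference is only memory use for the index tensor when the trailing dimensions are large.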

Solution

  • You can use torch.gather, but you first need to add a trailing dimension to the indices and tile them along the last dimension of A, as below:

    (For readability, I changed the shapes from (8, 8, 4) to (4, 4, 2) and from (8, 8) to (4, 4).)

    import torch
    torch.manual_seed(2023)
    A = torch.rand(4, 4, 2)
    # First A
    # >>> A
    # tensor([[[0.4290, 0.7201],
    #          [0.9481, 0.4797],
    #          [0.5414, 0.9906],
    #          [0.4086, 0.2183]],
    
    #         [[0.1834, 0.2852],
    #          [0.7813, 0.1048],
    #          [0.6550, 0.8375],
    #          [0.1823, 0.5239]],
    
    #         [[0.2432, 0.9644],
    #          [0.5034, 0.0320],
    #          [0.8316, 0.3807],
    #          [0.3539, 0.2114]],
    
    #         [[0.9839, 0.6632],
    #          [0.7001, 0.0155],
    #          [0.3840, 0.7968],
    #          [0.4917, 0.4324]]])
    B = torch.tensor([
        [0, 2, 3, 1],
        [1, 3, 0, 2],
        [3, 1, 2, 0],
        [2, 0, 1, 3]
    ])
    # Add a trailing dim and tile it to A's last dim: shape (4, 4) -> (4, 4, 2).
    B_changed = torch.tile(B[..., None], (1, 1, A.shape[2]))
    A_new = torch.gather(A, dim=1, index=B_changed)
    print(A_new)
    

    Output:

    tensor([[[0.4290, 0.7201],
             [0.5414, 0.9906],
             [0.4086, 0.2183],
             [0.9481, 0.4797]],
    
            [[0.7813, 0.1048],
             [0.1823, 0.5239],
             [0.1834, 0.2852],
             [0.6550, 0.8375]],
    
            [[0.3539, 0.2114],
             [0.5034, 0.0320],
             [0.8316, 0.3807],
             [0.2432, 0.9644]],
    
            [[0.3840, 0.7968],
             [0.9839, 0.6632],
             [0.7001, 0.0155],
             [0.4917, 0.4324]]])
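As an alternative sketch (assuming PyTorch 1.9 or later, where `torch.take_along_dim` is available): `take_along_dim` broadcasts the index against the input in every dimension except `dim`, so a (4, 4, 1) index can stand in for the explicitly tiled (4, 4, 2) one. The check below compares it with the gather-and-tile approach from the answer rather than with hard-coded values.

```python
import torch

torch.manual_seed(2023)
A = torch.rand(4, 4, 2)
B = torch.tensor([
    [0, 2, 3, 1],
    [1, 3, 0, 2],
    [3, 1, 2, 0],
    [2, 0, 1, 3],
])

# Approach from the answer: tile the indices across the last dim, then gather.
via_gather = torch.gather(A, 1, torch.tile(B[..., None], (1, 1, A.shape[2])))

# take_along_dim broadcasts the (4, 4, 1) index over A's last dim, so no
# explicit tile is needed. The index must be int64 and have A.dim() dims.
via_take = torch.take_along_dim(A, B.unsqueeze(-1), dim=1)

print(torch.equal(via_gather, via_take))  # True
```

For more than one trailing dimension, the same idea should apply by appending one singleton dimension per trailing dimension of A to the index before calling `take_along_dim`.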