How to quickly invert a permutation using PyTorch?


I am unsure how to quickly restore an array that has been shuffled by a permutation.

Example #1:

  • [x, y, z] shuffled by P: [2, 0, 1] gives [z, x, y]
  • the corresponding inverse should be P^-1: [1, 2, 0]

Example #2:

  • [a, b, c, d, e, f] shuffled by P: [5, 2, 0, 1, 4, 3] gives [f, c, a, b, e, d]
  • the corresponding inverse should be P^-1: [2, 3, 1, 5, 4, 0]
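
To make the convention in the two examples concrete, here is a minimal pure-Python sketch (no PyTorch yet). The helper names apply_permutation and invert_permutation are hypothetical, chosen only for illustration; the sketch assumes that "shuffled by P" means position i of the result takes the element at index P[i].

```python
def apply_permutation(items, perm):
    # Position i of the result takes the element found at index perm[i].
    return [items[j] for j in perm]


def invert_permutation(perm):
    # If perm[i] == j, the inverse must send j back to i.
    inv = [0] * len(perm)
    for i, j in enumerate(perm):
        inv[j] = i
    return inv


p = [5, 2, 0, 1, 4, 3]
original = ["a", "b", "c", "d", "e", "f"]

shuffled = apply_permutation(original, p)      # Example #2 above
p_inv = invert_permutation(p)                  # candidate for P^-1
restored = apply_permutation(shuffled, p_inv)  # should match `original`

print(shuffled)
print(p_inv)
print(restored)
```

Applying P and then P^-1 with the same indexing rule should give back the original list, which is the property any faster implementation has to preserve.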

I wrote the following code based on matrix multiplication (the transpose of a permutation matrix is its inverse), but this approach is too slow when I use it during model training. Is there a faster implementation?

import torch

n = 10
x = torch.arange(n, dtype=torch.float)
print('Original array', x)

# Shuffle x with a random permutation.
random_perm_indices = torch.randperm(n)
# perm_matrix[j, i] == 1 exactly when random_perm_indices[i] == j.
perm_matrix = torch.eye(n)[random_perm_indices].t()
x = x[random_perm_indices]
print('Shuffled', x)

# Multiplying by the column vector [0, 1, ..., n-1] yields the inverse permutation.
restore_indices = torch.arange(n, dtype=torch.float).view(n, 1)
restore_indices = perm_matrix.mm(restore_indices).view(n).long()
x = x[restore_indices]
print('Restored', x)

Solution

  • I found the solution on the PyTorch Forums: applying torch.argsort to a permutation returns its inverse.

    >>> import torch
    >>> torch.__version__
    '1.7.1'
    >>> p1 = torch.tensor([2, 0, 1])
    >>> torch.argsort(p1)
    tensor([1, 2, 0])
    >>> p2 = torch.tensor([5, 2, 0, 1, 4, 3])
    >>> torch.argsort(p2)
    tensor([2, 3, 1, 5, 4, 0])
    

    Update: The following solution is more efficient because it runs in linear time, whereas argsort has to sort the permutation, which takes O(n log n):

    def inverse_permutation(perm):
        inv = torch.empty_like(perm)
        inv[perm] = torch.arange(perm.size(0), device=perm.device)
        return inv
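
    As a quick sanity check, the sketch below repeats inverse_permutation so that it runs on its own, then compares its result with torch.argsort and confirms that a shuffled tensor is restored. The seed, the size n = 10, and the variable names are arbitrary choices for illustration and not part of the original answer.

```python
import torch


def inverse_permutation(perm):
    # Same scatter-style function as above, repeated so this snippet is self-contained.
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0), device=perm.device)
    return inv


torch.manual_seed(0)  # arbitrary seed, only for reproducibility
n = 10
x = torch.arange(n, dtype=torch.float32)

perm = torch.randperm(n)         # random permutation of 0..n-1 (int64)
shuffled = x[perm]               # shuffle x
inv = inverse_permutation(perm)  # linear-time inverse
restored = shuffled[inv]         # undo the shuffle

# The two approaches should agree, and the round trip should recover x.
agrees_with_argsort = torch.equal(inv, torch.argsort(perm))
round_trip_ok = torch.equal(restored, x)
print(agrees_with_argsort, round_trip_ok)
```

    Both checks should print True for any valid permutation. No timings are claimed here; whether the linear-time version is faster in practice for a given n and device would have to be measured.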