Tags: python, algorithm, pytorch, cuda

"Reverse" map of `inverse` in torch.unique?


I want to implement a function that computes a "reverse" map of the inverse returned by torch.unique.

For example, torch.unique can return the unique values uniques of a long-type input x, together with an inverse tensor that maps each position of x to the index of its value in uniques (so uniques[inverse] reconstructs x):

x = torch.LongTensor([9, 10, 9, 9, 10, 9])
uniques, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)
# uniques = [9, 10]
# inverse = [0, 1, 0, 0, 1, 0] 
# counts = [4, 2]
print((uniques[inverse] == x).all())  # True

My question: is there an efficient way to also get the "reverse" of inverse, i.e. a back_map that maps from uniques back to positions in x?

def reverse_unique(x): ...

uniques, inverse, counts, back_map = reverse_unique(x)
# uniques = [9, 10]
# inverse = [0, 1, 0, 0, 1, 0] 
# counts = [4, 2]
# back_map = [0, 2, 3, 5, 1, 4]
print((x[back_map] == uniques.repeat_interleave(counts)).all()) # True

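In the code above, back_map lists the positions in x grouped by unique value: the first counts[0] entries are the positions where x == uniques[0] (in ascending order), the next counts[1] entries are the positions where x == uniques[1], and so on.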

I know this is not difficult to implement with a Python loop, but in my case x can have on the order of 1e8 elements, so the time overhead is intolerable.
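For reference, the Python-loop version could look like the sketch below (the name `reverse_unique_loop` is hypothetical). It pins down the intended semantics, namely positions grouped by unique value and ascending within each group, and is handy as a correctness oracle on small inputs, but it is far too slow for ~1e8 elements.

```python
import torch


def reverse_unique_loop(x: torch.Tensor):
    """Naive reference (hypothetical helper): group the positions of x by value.

    Runs O(n) Python steps, so it is only suitable for testing on small inputs.
    """
    uniques, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)
    # One bucket of positions per unique value, in the order of `uniques`.
    buckets = [[] for _ in range(uniques.numel())]
    for pos, k in enumerate(inverse.tolist()):
        buckets[k].append(pos)  # positions arrive in ascending order
    back_map = torch.tensor([p for bucket in buckets for p in bucket], dtype=torch.long)
    return uniques, inverse, counts, back_map


x = torch.LongTensor([9, 10, 9, 9, 10, 9])
uniques, inverse, counts, back_map = reverse_unique_loop(x)
print(back_map.tolist())  # [0, 2, 3, 5, 1, 4]
```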

So is there a high-level implementation using the PyTorch API, or a CUDA kernel? (I tried to parallelize it with a CUDA extension, but my CUDA kernel is extremely slow.)

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

// One thread per unique value: thread n scans all of x sequentially and writes
// every position i with x[i] == uniques[n] into its output slots. Total work is
// O(num_uni * num_x), which is why this kernel is so slow for large inputs.
__global__ void unique_back_map_kernel(
    int32_t num_uni,
    int32_t num_x,
    const int64_t* __restrict__ uniques,
    const int64_t* __restrict__ cumsum_counts,  // exclusive prefix sum of counts
    const int64_t* __restrict__ x,
    int64_t* __restrict__ out) {
  int32_t n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n >= num_uni) {
    return;
  }

  const int64_t value = __ldg(&uniques[n]);
  int64_t idx = __ldg(&cumsum_counts[n]);  // first output slot for this value

  for (int64_t i = 0; i < num_x; ++i) {
    if (__ldg(&x[i]) == value) {
      out[idx++] = i;
    }
  }
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> reverse_unique(at::Tensor x) {
  TORCH_CHECK(x.dtype() == at::kLong && x.device().is_cuda());
  x = x.contiguous();  // the kernel indexes raw memory
  auto [uniques, inverse, counts] = at::_unique2(x, false, true, true);
  // Exclusive prefix sum of counts; counts itself stays intact so it can be returned.
  // python: cumsum_counts = torch.cat([torch.zeros(1), counts.cumsum(0)[:-1]])
  auto cumsum_counts = at::cat({at::zeros({1}, counts.options()), counts.cumsum(0).slice(0, 0, -1)});

  auto back_map = at::empty_like(x);

  int32_t threads = (uniques.numel() > 256) ? 256 : 32;
  int32_t blocks = (uniques.numel() + threads - 1) / threads;
  if (blocks > 0) {  // skip the launch for empty input (zero blocks is invalid)
    unique_back_map_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        uniques.numel(),
        x.numel(),
        uniques.data_ptr<int64_t>(),
        cumsum_counts.data_ptr<int64_t>(),
        x.data_ptr<int64_t>(),
        back_map.data_ptr<int64_t>());
  }

  return std::make_tuple(uniques, inverse, counts, back_map);
}

Solution

  • @alexey-birukov probably means this:

    If you read the documentation for torch.unique, you will learn that it basically just sorts the values and then does a torch.unique_consecutive:

    Currently in the CUDA implementation and the CPU implementation when dim is specified, torch.unique always sort [sic] the tensor at the beginning regardless of the sort argument. Sorting could be slow, so if your input tensor is already sorted, it is recommended to use torch.unique_consecutive() which avoids the sorting.

    A look at the documentation of torch.sort then shows that it returns not only the sorted values but also the indices of those values in the original tensor, i.e. the permutation with x[indices] == sorted values. This permutation is exactly what you are after.

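    import torch
    x = torch.LongTensor([9, 10, 9, 9, 10, 9])
    # stable=True keeps equal values in input order, so back_map = [0, 2, 3, 5, 1, 4]
    sorted_x, back_map = torch.sort(x, stable=True)
    # note: this inverse indexes sorted_x, not the original x
    uniques, inverse, counts = torch.unique_consecutive(sorted_x, return_inverse=True, return_counts=True)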
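Putting the answer together, a minimal sketch of the question's reverse_unique using only PyTorch ops might look like this. Two details are additions on my part, not from the answer: `stable=True` on torch.sort, so positions within each group come out ascending as in the question's expected back_map, and scattering the inverse from unique_consecutive (which indexes the sorted tensor) back through the permutation, so that it matches what torch.unique would return for the original x.

```python
import torch


def reverse_unique(x: torch.Tensor):
    """Sketch: uniques, inverse, counts and back_map from a single stable sort.

    back_map[j] is the position in x of the j-th element of sorted x, so
    x[back_map] == uniques.repeat_interleave(counts).
    """
    # stable=True keeps equal values in their original order.
    sorted_x, back_map = torch.sort(x, stable=True)
    uniques, inverse_sorted, counts = torch.unique_consecutive(
        sorted_x, return_inverse=True, return_counts=True)
    # inverse_sorted refers to positions of sorted_x; undo the permutation
    # so that uniques[inverse] == x for the original x.
    inverse = torch.empty_like(inverse_sorted)
    inverse[back_map] = inverse_sorted
    return uniques, inverse, counts, back_map


x = torch.LongTensor([9, 10, 9, 9, 10, 9])
uniques, inverse, counts, back_map = reverse_unique(x)
print(uniques.tolist(), inverse.tolist(), counts.tolist(), back_map.tolist())
# [9, 10] [0, 1, 0, 0, 1, 0] [4, 2] [0, 2, 3, 5, 1, 4]
```

Everything stays on the tensor's device as a handful of tensor ops, replacing the per-unique-value scan of the kernel in the question with one sort. If you already have inverse from torch.unique, `torch.sort(inverse, stable=True).indices` yields the same back_map.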