Search code examples
cuda thrust

How to get the mapping array when compacting with Thrust


I need to know which position each element is mapped to when compacting with Thrust.

For example:

arr:       5  2  -1  3  -1  6 -1  7
compacted: 5  2   3  6  7     --------------(remove all -1 elements)
map Arr:   0  1  -1  2  -1  3 -1  4

By "mapping array" I mean the array indicating the position each element is moved to in the compacted output (with -1 for removed elements). Sorry, I can't come up with a better name for this; I hope that makes things clear.

Compacting with Thrust is easy, but I'm wondering whether I can get that mapping array while Thrust is compacting.


Solution

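  • The following sequence of Thrust calls generates the desired mapping array. The steps are:
    1. Write -1 at every removed (-1) position and leave 0 elsewhere.
    2. Run an inclusive scan, so each position holds minus the number of removed elements up to and including it.
    3. Add each element's own index, which yields its position in the compacted array.
    4. Write -1 back into the removed positions.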

    The output for your example is:

    compacted: 5 2 3 6 7 
    map: 0  0 -1  0 -1  0 -1  0 
    map: 0  0 -1 -1 -2 -2 -3 -3 
    map: 0  1  1  2  2  3  3  4 
    map: 0  1 -1  2 -1  3 -1  4 
    

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <string>
    
    #include <thrust/copy.h>
    #include <thrust/functional.h>
    #include <thrust/remove.h>
    #include <thrust/scan.h>
    #include <thrust/scatter.h>
    #include <thrust/transform.h>
    #include <thrust/iterator/counting_iterator.h>
    #include <thrust/iterator/constant_iterator.h>
    
    // Writes `name`, a colon, then every value in [begin, end) followed by a
    // single space, and finally a newline, to std::cout.
    // The range is only read, so it is taken as const int*; plain int* callers
    // still convert implicitly. The pointers are dereferenced on the host by
    // std::copy, so they must refer to host-accessible memory.
    void print(const std::string& name, const int* begin, const int* end)
    {
      std::cout << name << ": ";
      std::copy(begin, end, std::ostream_iterator<int>(std::cout, " "));
      std::cout << '\n';  // same text as std::endl without forcing a flush per line
    }
    
    // Stencil predicate for thrust::scatter_if: reports true exactly when the
    // stencil value equals -1, the sentinel this program removes during
    // compaction.
    struct is_marker
    {
      __host__ __device__
      bool operator()(int value) const
      {
        const bool is_sentinel = (-1 == value);
        return is_sentinel;
      }
    };
    
    // Demonstrates compaction of an int array (dropping every -1) and builds the
    // corresponding mapping array: for each kept element, its index in the
    // compacted output; for each removed element, -1. Each intermediate state
    // of the mapping array is printed.
    int main()
    {
      const int kCount = 8;
      int input[kCount] = {5, 2, -1, 3, -1, 6, -1, 7};

      // Plain stream compaction: copy everything except -1.
      int packed[kCount] = {0};
      int* const packed_end = thrust::remove_copy(input, input + kCount, packed, -1);
      print("compacted", packed, packed_end);

      // Stores -1 into positions[i] for every i where input[i] is a marker;
      // all other entries of positions are left untouched.
      auto stamp_markers = [&](int* positions) {
        thrust::scatter_if(thrust::make_constant_iterator(-1),
                           thrust::make_constant_iterator(-1) + kCount,
                           thrust::make_counting_iterator(0),
                           input,
                           positions,
                           is_marker());
      };

      int positions[kCount] = {0};
      int* const positions_end = positions + kCount;

      // Step 1: -1 at removed positions, 0 elsewhere.
      stamp_markers(positions);
      print("map", positions, positions_end);

      // Step 2: running sum -> minus the count of removed elements in input[0..i].
      thrust::inclusive_scan(positions, positions_end, positions);
      print("map", positions, positions_end);

      // Step 3: add i, giving (i - removed so far), which for a kept element is
      // its index in the compacted array.
      thrust::transform(positions, positions_end,
                        thrust::make_counting_iterator(0),
                        positions,
                        thrust::plus<int>());
      print("map", positions, positions_end);

      // Step 4: removed elements have no destination; mark them -1 again.
      stamp_markers(positions);
      print("map", positions, positions_end);

      return 0;
    }