Search code examples
Tags: cuda, thrust

Thrust scan of just one class member


I have a custom class, myClass, with members weight and config. I'd like to run an inclusive scan over a sequence of myClass objects, but only over their weights. Basically, what I want is to take:

[ {configA, weightA}, {configB, weightB}, {configC, weightC}, ...]

to:

[ {configA, weightA}, {configB, weightA + weightB}, {configC, weightA + weightB + weightC}, ...]

Is there a simple way to do this using Thrust's fancy iterators? Since the binaryOp is required to be associative, I don't see how to do this with just overloading operator+.


Solution

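  • inclusive_scan requires an associative operator, but the operator need not be commutative. Write a binary function that copies the config member of its second argument into the result and adds the two weights. Evaluating op(op(x, y), z) and op(x, op(y, z)) then gives the same value, {z.config, x.weight + y.weight + z.weight}. The operator is therefore associative even though op(x, y) != op(y, x), so it works: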

    #include <cstddef>
    #include <iostream>
    #include <thrust/device_vector.h>
    #include <thrust/host_vector.h>
    #include <thrust/scan.h>
    
    // Element type for the scan: a configuration tag plus a weight.
    // The scan accumulates only `weight`; `config` is carried through.
    struct my_struct
    {
      // Zero-initializes both members, so default-constructed elements
      // (such as those of device_vector<my_struct>(n) before they are
      // assigned) hold defined values instead of indeterminate ones.
      __host__ __device__
      my_struct()
        : config(0), weight(0)
      {}

      __host__ __device__
      my_struct(char c, double w)
        : config(c), weight(w)
      {}

      // No user-declared copy constructor: the implicit one copies the same
      // two members. Leaving it implicit keeps the type trivially copyable,
      // which lets libraries that check for that trait use plain memory copies.
      // NOTE(review): relies on CUDA treating implicitly-declared special
      // members as __host__ __device__ so device code can still copy it.

      char config;
      double weight;
    };
    
    
    // Binary operator for thrust::inclusive_scan over my_struct:
    //
    //   op(a, b) = { b.config, a.weight + b.weight }
    //
    // It is associative: op(op(x, y), z) and op(x, op(y, z)) both yield
    // { z.config, x.weight + y.weight + z.weight }. It is not commutative,
    // because config comes from the right operand, and inclusive_scan does
    // not require commutativity. Each scanned element therefore keeps its
    // own config and receives the running sum of the weights up to and
    // including itself.
    //
    // NOTE: floating-point addition is only approximately associative, so a
    // parallel scan may differ from a sequential sum in the last bits.
    struct functor
    {
      // const-qualified so the operator can also be invoked through a const
      // functor; operands are taken by const reference to avoid copies.
      __host__ __device__
      my_struct operator()(const my_struct &a, const my_struct &b) const
      {
        return my_struct(b.config, a.weight + b.weight);
      }
    };
    
    int main()
    {
      thrust::device_vector<my_struct> vec(3);
    
      vec[0] = my_struct('a', 1);
      vec[1] = my_struct('b', 2);
      vec[2] = my_struct('c', 3);
    
      std::cout << "input: ";
      for(int i = 0; i < vec.size(); ++i)
      {
        my_struct x = vec[i];
        std::cout << "{" << x.config << ", " << x.weight << "} ";
      }
      std::cout << std::endl;
    
      thrust::inclusive_scan(vec.begin(), vec.end(), vec.begin(), functor());
    
      std::cout << "result: ";
      for(int i = 0; i < vec.size(); ++i)
      {
        my_struct x = vec[i];
        std::cout << "{" << x.config << ", " << x.weight << "} ";
      }
      std::cout << std::endl;
    
      return 0;
    }
    

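    The output (note: -arch=sm_20 targets Fermi GPUs, which CUDA 9.0 and later no longer support; substitute your GPU's compute capability):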

    $ nvcc -arch=sm_20 test.cu -run
    input: {a, 1} {b, 2} {c, 3} 
    result: {a, 1} {b, 3} {c, 6}