Search code examples
cudathrust

Using Thrust counting iterators with strides


I am looking for a way to use thrust::counting_iterator to parallelize the following for loop:

// Sequential reference loop: `stride` starts at 0, grows by M after each pass,
// and the loop ends as soon as `stride` reaches N * M.
for (int stride = 0 ; stride < N * M ; stride+=M) // N iterations
{
    // Body of the loop
}

Here is what the code looks like:

// Function object passed to thrust::for_each; operator() is invoked once for
// every value in the iterator range, receiving that value as `i`.
// (A struct definition takes no parameter list and must end with ';'.)
struct functor
{
   // Callable from both host and device code.
   __host__ __device__ void operator() (const int i)
   {
      // Body of the loop
   }
};

// Start of the counting range: the first value produced is 0.
thrust::counting_iterator<int> it1(0);
// End of the range, N * M positions past it1.
thrust::counting_iterator<int> it2 = it1 + N * M;
// Apply functor to every value in [it1, it2).
thrust::for_each (it1 , it2 , functor());

I'm aware that counting_iterator increments iterators by 1, so is there a way to increment by M?


Solution

  • Why not just multiply the i variable by M in your functor?

    If M is known at compile time, it could be:

    // Function object for thrust::for_each: it is handed the plain counter
    // value and rescales it by M, which must be visible at compile time.
    struct functor {
       __host__ __device__
       void operator()(const int my_i)
       {
          // Turn the counter value into the strided loop variable.
          int i = my_i * M;
          // Body of the loop
       }
    };
    
    // Start of the counting range: the first value produced is 0.
    thrust::counting_iterator<int> it1(0);
    // End of the range, N positions past it1 (one value per loop iteration).
    thrust::counting_iterator<int> it2 = it1 + N;
    // Apply functor to every value in [it1, it2); the multiplication by M
    // happens inside the functor.
    thrust::for_each (it1 , it2 , functor());
    

    If M is known only at runtime, we can pass it as an initializing parameter to the functor:

    // Function object for thrust::for_each whose stride is supplied at run
    // time: the multiplier is stored at construction and applied to every
    // counter value passed to operator().
    struct functor
    {
       int my_M;  // stride multiplier captured at construction
       // The constructor needs a body `{}` after the initializer list.
       // The parameter is named `m` rather than `_M` because identifiers that
       // start with an underscore followed by an uppercase letter are reserved.
       functor(int m) : my_M(m) {}
       __host__ __device__ void operator() (const int my_i)
       {
          // Turn the counter value into the strided loop variable.
          int i = my_i * my_M;
          // Body of the loop
       }
    };
    
    // Start of the counting range: the first value produced is 0.
    thrust::counting_iterator<int> it1(0);
    // End of the range, N positions past it1 (one value per loop iteration).
    thrust::counting_iterator<int> it2 = it1 + N;
    // Apply the functor to every value in [it1, it2); M is handed to the
    // functor's constructor here instead of being baked in at compile time.
    thrust::for_each (it1 , it2 , functor(M));
    

    You could also wrap a counting iterator in a transform iterator, which takes the counting iterator and multiplies it by M:

    // Function object applied by thrust::for_each. It uses `i` exactly as
    // received; any scaling is expected to be done by the iterator feeding it.
    struct functor {
       __host__ __device__
       void operator()(const int i)
       {
          // Body of the loop
       }
    };
    
    // Brings the placeholder _1 into scope for the `_1 * M` expressions below.
    using namespace thrust::placeholders;
    // Counting range [it1, it2) covering N consecutive values starting at 0.
    thrust::counting_iterator<int> it1(0);
    thrust::counting_iterator<int> it2 = it1 + N;
    // Wrap both ends of the range in transform iterators that multiply each
    // counter value by M, so the functor receives the strided value directly.
    // Both calls are qualified with thrust:: so neither depends on
    // argument-dependent lookup to be found.
    thrust::for_each (thrust::make_transform_iterator(it1, _1 * M) ,
                      thrust::make_transform_iterator(it2, _1 * M) ,
                      functor());
    

    This last example uses thrust placeholder expressions, although it could be realized equivalently with an additional trivial functor which returns its argument multiplied by its parameter.

    Here is a fully worked example showing all 3 methods:

    $ cat t492.cu
    #include <stdio.h>
    #include <thrust/transform.h>
    #include <thrust/for_each.h>
    #include <thrust/execution_policy.h>
    #include <thrust/iterator/counting_iterator.h>
    #include <thrust/iterator/transform_iterator.h>
    #include <thrust/host_vector.h>
    #include <thrust/functional.h>
    #define N 5
    #define M 4
    using namespace thrust::placeholders;
    
    // Prints the value it is called with, unmodified.
    struct my_functor_1 {
      __host__ __device__
      void operator()(const int i)
      {
        printf("functor 1 value: %d\n", i);
      }
    };
    
    // Multiplies the incoming counter value by the macro M, then prints the
    // product.
    struct my_functor_2 {
       __host__ __device__
       void operator()(const int my_i)
       {
          // Strided value corresponding to counter value my_i.
          int i = my_i * M;
          printf("functor 2 value: %d\n", i);
       }
    };
    
    // Multiplies the incoming counter value by a stride chosen at run time,
    // then prints the product. The stride is fixed when the object is built.
    struct my_functor_3
    {
       int my_M;  // stride multiplier captured at construction
       // The parameter is named `m` rather than `_M` because identifiers that
       // start with an underscore followed by an uppercase letter are reserved
       // for the implementation.
       my_functor_3(int m) : my_M(m) {}
       __host__ __device__ void operator() (const int my_i)
       {
          // Strided value corresponding to counter value my_i.
          int i = my_i * my_M;
          printf("functor 3 value: %d\n", i);
       }
    };
    
    
    int main(){
      thrust::counting_iterator<int> it1(0);
      thrust::counting_iterator<int> it2 = it1 + N;
      thrust::for_each(thrust::host, it1, it2, my_functor_1());
      thrust::for_each(thrust::host, it1, it2, my_functor_2());
      thrust::for_each(thrust::host, it1, it2, my_functor_3(M));
      thrust::for_each(thrust::host, thrust::make_transform_iterator(it1, _1 * M), thrust::make_transform_iterator(it2, _1 * M), my_functor_1());
      return 0;
    }
    
    
    $ nvcc -arch=sm_20 -o t492 t492.cu
    $ ./t492
    functor 1 value: 0
    functor 1 value: 1
    functor 1 value: 2
    functor 1 value: 3
    functor 1 value: 4
    functor 2 value: 0
    functor 2 value: 4
    functor 2 value: 8
    functor 2 value: 12
    functor 2 value: 16
    functor 3 value: 0
    functor 3 value: 4
    functor 3 value: 8
    functor 3 value: 12
    functor 3 value: 16
    functor 1 value: 0
    functor 1 value: 4
    functor 1 value: 8
    functor 1 value: 12
    functor 1 value: 16
    $