c++ cuda thrust

How to use thrust::transform on a larger vector derived from a smaller vector?


Input and starting arrays:

dv_A = { 5, -3, 2, 6 }  // 4 elements
dv_B = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }  // 12 elements

Expected output:

dv_B = { 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1 }

For every element in dv_A, there are (dv_A.size() - 1) elements in dv_B. This is because each element of dv_A should have a child element in dv_B for each of the other dv_A elements (i.e. excluding itself). Therefore, with 4 elements in dv_A, dv_B holds 3 elements for each dv_A element, 12 in total.
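To make the "excluding itself" layout concrete, here is a small host-only sketch (plain C++, no Thrust) that computes the size of dv_B and, for each dv_B slot, which dv_A element owns it and which other dv_A element it stands for. The names child_count and child_pair are hypothetical, and the sketch assumes dv_A has at least 2 elements.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical helper: number of dv_B slots for an n-element dv_A
// (n parents, n - 1 children each).
std::size_t child_count(std::size_t n)
{
    assert(n >= 2);
    return n * (n - 1);
}

// Hypothetical helper: for dv_B slot i, return {parent, other}, where parent
// is the dv_A element that owns the slot and other is the dv_A element it
// pairs with. The parent's own index is skipped, so other != parent.
std::pair<std::size_t, std::size_t> child_pair(std::size_t i, std::size_t n)
{
    assert(n >= 2 && i < child_count(n));
    const std::size_t parent = i / (n - 1);            // which block of n - 1 slots
    const std::size_t k = i % (n - 1);                 // position inside that block
    const std::size_t other = k < parent ? k : k + 1;  // skip the parent itself
    return {parent, other};
}
```

For n = 4, slots 0 to 2 map to (0, 1), (0, 2), (0, 3) and slot 3 maps to (1, 0), so each block skips its own parent index.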

I want to transform each dv_B element to 1 if its corresponding dv_A element is > 0, and to 0 otherwise. Correspondence is determined by the position of the element in dv_B. For example:

The first 3 dv_B values are determined by dv_A[0], the next 3 by dv_A[1], and so on.
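Written as a formula, with n = dv_A.size(), A = dv_A and B = dv_B (a notation introduced here only for illustration), the rule behind the expected output is:

```latex
B_i =
\begin{cases}
1, & A_{\lfloor i/(n-1) \rfloor} > 0 \\
0, & \text{otherwise}
\end{cases}
\qquad 0 \le i < n(n-1)
```

For example, with n = 4, slot i = 7 reads A_{\lfloor 7/3 \rfloor} = A_2 = 2 > 0, so B_7 = 1, while slot i = 4 reads A_1 = -3, so B_4 = 0; both match the expected output.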

Here's my attempt so far:

thrust::transform(
    dv_B.begin(),
    dv_B.end(),
    thrust::make_transform_iterator(
        dv_A.begin(),
        _1 % dv_A
    ), 
    dv_B.begin(),
    _2 > 0 //When 2nd argument is greater than 0 then set the position in dv_A to 1.
);

Solution

  • The serial code could look something like this:

    for (std::size_t i = 0; i < dv_b.size(); i++) {
        // Each dv_a element owns a block of (dv_a.size() - 1) dv_b slots.
        const std::size_t readIndex = i / (dv_a.size() - 1);
        if (dv_a[readIndex] > 0) dv_b[i] = 1;
        else dv_b[i] = 0;
    }
    

    This can easily be written with thrust::for_each. I think this makes the code clearer than using transform together with various fancy iterators.

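    #include <thrust/execution_policy.h>
    #include <thrust/for_each.h>
    #include <thrust/iterator/counting_iterator.h>

    // The __device__ lambda requires nvcc --extended-lambda
    // (--expt-extended-lambda on older CUDA versions).
    thrust::for_each(
        thrust::device,
        thrust::make_counting_iterator(0),
        thrust::make_counting_iterator(0) + dv_b.size(),
        [
         s = dv_a.size() - 1,                           // children per dv_a element
         // device_vector cannot be used inside device code, so capture raw device pointers
         dv_a = thrust::raw_pointer_cast(dv_a.data()),
         dv_b = thrust::raw_pointer_cast(dv_b.data())
        ] __device__ (int i){
            const int readIndex = i / s;                // dv_a element that owns slot i
            if(dv_a[readIndex] > 0) dv_b[i] = 1;
            else dv_b[i] = 0;
        }
    );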