Tags: c++, cuda, thrust

How to pass additional array to Thrust's min_element predicate


I'm trying to use Thrust's min_element reduction to find the next edge in Prim's algorithm. I iterate over the graph's edges. This is my comparison function:

struct compareEdge {
    __host__ /*__device__*/ bool operator()(Edge l, Edge r) {
        if (visited[l.u] != visited[l.v] && visited[r.u] != visited[r.v]) {
            return l.cost < r.cost;
        } else if (visited[l.u] != visited[l.v]) {
            return true;
        } else {
            return false;
        }
    }
};
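The predicate ranks every edge with exactly one visited endpoint (an edge crossing the cut) ahead of every other edge, and orders those crossing edges by cost. The host-only sketch below makes that behavior visible with std::min_element. Edge, CompareEdgeHost and cheapest_crossing are hypothetical names, and the visited flags are handed to the functor through a pointer member instead of being read from a global. Note the final check: when no edge crosses the cut, all edges compare as equivalent and min_element simply returns the first one.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical edge layout; the question does not show its definition.
struct Edge {
    int u, v;
    int cost;
};

// Host-only restatement of the question's predicate. The visited flags are
// passed in through a pointer member rather than read from a global array.
struct CompareEdgeHost {
    const std::vector<int>* visited;

    // An edge crosses the cut when exactly one endpoint is visited.
    bool crossing(const Edge& e) const {
        return (*visited)[e.u] != (*visited)[e.v];
    }

    bool operator()(const Edge& l, const Edge& r) const {
        if (crossing(l) && crossing(r)) return l.cost < r.cost;
        // At most one of l, r is crossing here: l is "smaller" only if it crosses.
        return crossing(l);
    }
};

// Index of the cheapest crossing edge, or -1 if no edge crosses the cut.
int cheapest_crossing(const std::vector<Edge>& edges, const std::vector<int>& visited) {
    CompareEdgeHost cmp{&visited};
    auto it = std::min_element(edges.begin(), edges.end(), cmp);
    // With no crossing edges, min_element returns the first edge, so re-check it.
    if (it == edges.end() || !cmp.crossing(*it)) return -1;
    return static_cast<int>(it - edges.begin());
}
```

For example, with visited = {1, 0, 0, 0} and edges {0,1,5}, {1,2,1}, {0,2,3}, {2,3,2}, cheapest_crossing returns 2, the index of {0,2,3}: edge {1,2,1} is cheaper but has no visited endpoint.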

Unfortunately, this code cannot run on the device, because it uses the visited array, in which I mark already-visited nodes. How can I pass this array to my predicate so that it is usable from device-executed code?


Solution

    There are probably a number of ways to handle this; I will present one approach. Please note that your question is how to pass an arbitrary data set to a functor, and that is what I'm trying to show. I'm not trying to address whether your proposed functor is a useful comparison predicate for thrust::min_element (I'm not sure it is).

    One approach is simply to define the array statically in device memory:

    __device__ int d_visited[DSIZE];
    

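    Then, in your host code, copy the host visited array into it before using the functor. Because Prim's algorithm marks a new vertex on every iteration, this copy has to be repeated each time the host array changes: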

    cudaMemcpyToSymbol(d_visited, visited, DSIZE*sizeof(int));
    

    The functor also has to be modified. Since you may want it to be usable on either the host or the device, use the __CUDA_ARCH__ macro (defined only while compiling device code) to select which array it reads:

    struct compareEdge {
        // Marked const so Thrust can also invoke it through a const copy of the functor.
        __host__ __device__ bool operator()(const Edge& l, const Edge& r) const {
    #ifdef __CUDA_ARCH__
            // Device compilation pass: read the __device__ copy of the flags.
            if (d_visited[l.u] != d_visited[l.v] && d_visited[r.u] != d_visited[r.v]) {
                return l.cost < r.cost;
            } else if (d_visited[l.u] != d_visited[l.v]) {
                return true;
            } else {
                return false;
            }
    #else
            // Host compilation pass: read the ordinary host array.
            if (visited[l.u] != visited[l.v] && visited[r.u] != visited[r.v]) {
                return l.cost < r.cost;
            } else if (visited[l.u] != visited[l.v]) {
                return true;
            } else {
                return false;
            }
    #endif
        }
    };
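For reference, a minimal end-to-end sketch of the __device__-array approach driving Prim's algorithm might look like the following. It assumes a hypothetical Edge struct (u, v, cost), a hypothetical DSIZE of 1024 and a hypothetical helper prim_mst_cost, and it omits CUDA error checking. It refreshes d_visited with cudaMemcpyToSymbol before every thrust::min_element call and verifies that the returned edge really crosses the cut, since with no crossing edge left the predicate treats all edges as equivalent.

```cuda
#include <cuda_runtime.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/extrema.h>
#include <vector>

// Hypothetical edge layout; the question does not show its definition.
struct Edge {
    int u, v;
    int cost;
};

// Hypothetical upper bound on the number of vertices.
constexpr int DSIZE = 1024;

__device__ int d_visited[DSIZE];  // device copy of the visited flags
int visited[DSIZE];               // host copy of the visited flags

struct compareEdge {
    __host__ __device__ bool operator()(const Edge& l, const Edge& r) const {
#ifdef __CUDA_ARCH__
        const int* vis = d_visited;  // device pass
#else
        const int* vis = visited;    // host pass
#endif
        bool lCross = vis[l.u] != vis[l.v];
        bool rCross = vis[r.u] != vis[r.v];
        if (lCross && rCross) return l.cost < r.cost;
        return lCross;  // at most one of l, r crosses the cut here
    }
};

// Total MST cost for n vertices (n <= DSIZE), or -1 if the graph is disconnected.
long long prim_mst_cost(const std::vector<Edge>& hostEdges, int n) {
    if (hostEdges.empty()) return n <= 1 ? 0 : -1;
    thrust::device_vector<Edge> edges(hostEdges.begin(), hostEdges.end());

    for (int i = 0; i < n; ++i) visited[i] = 0;
    visited[0] = 1;  // grow the tree from vertex 0

    long long total = 0;
    for (int added = 1; added < n; ++added) {
        // The predicate reads d_visited on the device, so refresh it every iteration.
        cudaMemcpyToSymbol(d_visited, visited, n * sizeof(int));
        auto it = thrust::min_element(thrust::device, edges.begin(), edges.end(),
                                      compareEdge());
        Edge e = *it;  // copies the winning edge back to the host
        if (visited[e.u] == visited[e.v]) return -1;  // no crossing edge left
        total += e.cost;
        visited[visited[e.u] ? e.v : e.u] = 1;  // mark the newly reached vertex
    }
    return total;
}
```

Copying the whole flag array on every step is simple but adds one host-to-device transfer per iteration; for large graphs it may be cheaper to keep the flags in a thrust::device_vector and update a single element per step.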