Search code examples
Tags: c++, cuda, thrust

Issue trying to use Templates on a Thrust device comparator


The declaration of the struct is pretty simple. For some reason, whenever I try to use templates to define a comparator, I am unable to call `.get()` on the tuple: the following code throws an error at `t1.get<0>()` and at every other `.get` call. I would like to understand why, when templates are used, the tuple no longer seems to have `.get()` as a member function.

// Templated lexicographic comparator as originally written in the question.
// It does NOT compile; the working version is shown under "Solution".
template<typename FirstType, typename SecondType>
struct TupleComp{

    // Iterator types of device_vectors holding each element type.
    typedef typename thrust::device_vector<FirstType >::iterator firstIter;
    typedef typename thrust::device_vector<SecondType>::iterator secondIter;

    // NOTE: a tuple of *iterators*, not of element values, so the comparisons
    // below would order iterator positions rather than the data they refer to.
    typedef typename thrust::tuple<firstIter,secondIter> TupleType;

     __host__ __device__
    bool operator()(const TupleType &t1, const TupleType &t2)
    {
        // thrust::tuple<thrust::device_vector<long long>::iterator > tup;
         // Unused local; it plays no part in the result.
         TupleType tup;


         // Why this fails to compile: TupleType depends on the template
         // parameters, so the type of t1 is dependent and `get` is a member
         // template of a dependent type. Without the `template` keyword the
         // compiler parses `t1.get<0>()` as `(t1.get < 0) > ()`. It must be
         // written `t1.template get<0>()`, or use the free function
         // `thrust::get<0>(t1)`. The non-template comparator below compiles
         // because its tuple type is not dependent.
         if(t1.get<0>() < t2.get<0>()){
             return true;
         }

         if(t1.get<0>() > t2.get<0>()){
             return false;
         }

         // First elements are equal: break the tie on the second element.
         return (t1.get<1>() < t2.get<1>());

     }
};

Below is similar, non-templated code that does compile:

struct TupleCompUllFirstLLSecond{


    typedef typename thrust::tuple<thrust::device_vector<unsigned long long>::iterator,thrust::device_vector<long long>::iterator> TupleType;

     __host__ __device__
    bool operator()(const TupleType &t1, const TupleType &t2)
    {

         if(t1.get<0>() < t2.get<0>()){
             return true;
         }

        if(t1.get<0>() > t2.get<0>()){
             return false;
        }

         return (t1.get<1>() < t2.get<1>());

    }
};

Solution

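  • Thanks to Robert Crovella (who coincidentally has solved all of my Thrust questions to date), the solution was to fix an error in the kind of tuple I was comparing (it should hold the element types, not `device_vector` iterators) and to use `thrust::get` instead of the member call `t1.get`. Inside a template, `t1` has a dependent type, so the member call would have to be written `t1.template get<0>()`; the free function `thrust::get<0>(t1)` needs no such disambiguator. The working comparison functor is: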

    template<typename FirstType, typename SecondType>
    struct TupleComp{
    
    
    
        typedef typename thrust::tuple<FirstType,SecondType> TupleType;
    
         __host__ __device__
        bool operator()(const TupleType &t1, const TupleType &t2)
        {
    
             FirstType leftFirst = thrust::get<0>(t1);
             FirstType rightFirst = thrust::get<0>(t2);
    
    
             if(leftFirst < rightFirst){
                 return true;
             }
    
             if(leftFirst > rightFirst){
                 return false;
             }
    
             SecondType leftSecond = thrust::get<1>(t1);
             SecondType rightSecond = thrust::get<1>(t2);
    
    
             return leftSecond < rightSecond;
    
        }
    };