
Return thrust binary function


I'm trying to define a function that returns the desired comparison operator based on the content of a string. I have tried this, but it doesn't work:

impl.cpp

template <typename T> thrust::binary_function<T,T,bool>
 get_filter_operator(const std::string &op)
{
    if (op == "!=")
        return thrust::not_equal_to<T>();
    else if (op == ">")
        return thrust::greater<T>();
    else if (op == "<")
        return thrust::less<T>();
    else if (op == ">=")
        return thrust::greater_equal<T>();
    else if (op == "<=")
        return thrust::less_equal<T>();
    else
    {
        return thrust::equal_to<T>();
    }
}

template thrust::binary_function<float,float,bool> get_filter_operator<float>(const std::string &);

impl.h

template <typename T> thrust::binary_function<T, T, bool> get_filter_operator(const std::string &op);

How can I return a pointer to an arbitrary function like thrust::not_equal_to<int>(), or thrust::equal_to<int>()? I can't find the correct type to return.

EDIT

As requested, the compiler error:

In instantiation of ‘thrust::binary_function<T, T, bool> get_filter_operator(const string&) [with T = float; std::string = std::basic_string<char>]’:

error: could not convert ‘thrust::equal_to<float>()’ from ‘thrust::equal_to<float>’ to ‘thrust::binary_function<float, float, bool>’
         return thrust::equal_to<T>();

Update

OK, sorry for not mentioning this before: the problem is that I can't use std::function, because it only works in host code. I wanted to use thrust binary functions so that I could use them on both the GPU and the CPU.


Solution

  • How can I return a pointer to an arbitrary function like thrust::not_equal_to<int>(), or thrust::equal_to<int>()? I can't find the correct type to return

    Each of the things that you are trying to return is a function of two arguments, each of some type T, that returns bool. The correct return type is

    std::function<bool(T, T)>
    

    As in:

    #include <thrust/functional.h>
    #include <functional>
    #include <string>
    
    template<typename T>
    std::function<bool(T, T)>
    get_filter_operator(const std::string &op)
    {
        if (op == "!=")
            return thrust::not_equal_to<T>();
        else if (op == ">")
            return thrust::greater<T>();
        else if (op == "<")
            return thrust::less<T>();
        else if (op == ">=")
            return thrust::greater_equal<T>();
        else if (op == "<=")
            return thrust::less_equal<T>();
        else
        {
            return thrust::equal_to<T>();
        }
    }
    
    #include <iostream>
    
    using namespace std;
    
    int main()
    {
        auto relop = get_filter_operator<int>("!=");
        cout << boolalpha << relop(1,0) << endl;
        cout << boolalpha << relop(1,1) << endl;
    
        return 0;
    }
    

    Now, you may wish to reiterate your comment to @MohamadElghawi:

    Yeah, I knew that worked, but the problem is that I'm trying to return a thrust::binary_function, not from std

    That may be what you are trying to do, but it is the wrong thing to try, and an impossible one. Look at the definition of template<typename A1, typename A2, typename R> struct thrust::binary_function in <thrust/functional.h> and at the associated documentation. Note:

    binary_function is an empty base class: it contains no member functions or member variables, but only type information

    In particular, thrust::binary_function<A1,A2,R> has no operator(). It is not callable. It cannot store any other callable object (or anything at all). See also the definitions of equal_to, not_equal_to, etc. in the same file: binary_function is not even a base of any of them, and there is no conversion from any of them to binary_function.

    Note too:

    binary_function is currently redundant with the C++ STL type std::binary_function. We reserve it here for potential additional functionality at a later date.

    (std::binary_function is itself deprecated as of C++11 and was removed in C++17.)

    thrust::binary_function<T,T,bool> is not what you are looking for. std::function<bool(T, T)> is.

    std::function<bool(int, int)> f = thrust::greater<int>(); 
    

    makes f encapsulate a callable object that is a thrust::greater<int>.

    Later

    The problem with this is that it can only be used in host code, doesn't it? The beauty of thrust binary functions is that they can be used on both the GPU and the CPU.

    I think you may be under the impression that, e.g.

    std::function<bool(int, int)> f = thrust::greater<int>();  /*A*/
    

    takes a thrust::greater<int> and in some manner downgrades it into a std::function<bool(int, int)> that has similar but more restricted ("std") execution capabilities.

    Nothing like that is the case. An std::function<bool(int, int)> foo is simply a receptacle for anything bar that can be called with two int arguments and returns something implicitly convertible to bool, such that if:

    std::function<bool(int, int)> foo = bar; 
    

    then when you call foo(i,j) you are returned the result, as bool, of executing bar(i,j). Not the result of executing anything that is in any way different from bar(i,j).

    Thus in /*A*/ above, the callable object contained by, and called by, f is a thrust binary function: it is a thrust::greater<int>(). The method invoked by f's operator() is thrust::greater<int>::operator().

    Here is a program:

    #include <thrust/functional.h>
    #include <functional>
    #include <iostream>
    
    using namespace std;
    
    int main()
    {
        auto thrust_greater_than_int = thrust::greater<int>();
        std::function<bool(int, int)> f = thrust_greater_than_int;
        cout << "f " 
            << (f.target<thrust::greater<int>>() ? "calls" : "does not call") 
            << " a thrust::greater<int>" << endl;
        cout << "f " 
            << (f.target<thrust::equal_to<int>>() ? "calls" : "does not call") 
            << " a thrust::equal_to<int>" << endl;
        cout << "f " 
            << (f.target<std::greater<int>>() ? "calls" : "does not call") 
            << " an std::greater<int>" << endl;
        cout << "f " 
            << (f.target<std::function<bool(int,int)>>() ? "calls" : "does not call") 
            << " an std::function<bool(int,int)>" << endl;
        return 0;
    }
    

    that stores a thrust::greater<int> in a std::function<bool(int, int)> f and then informs you that:

    f calls a thrust::greater<int>
    f does not call a thrust::equal_to<int>
    f does not call an std::greater<int>
    f does not call an std::function<bool(int,int)>