Tags: c++, templates, cuda, metaprogramming, functor

How to instantiate templated functors F<D> over multiple functors F1, F2 and multiple template parameters D1, D2?


I need to instantiate a number of functor templates

template<typename DataType>
struct Functor1{
   int a;
   Functor1(int a_){ a = a_; }
//    __device__
   void operator()(DataType &elem) {
       elem.x +=1;
   }
};

template<typename DataType>
struct Functor2{
   int a;
   Functor2(int a_){ a = a_; }
//    __device__
   void operator()(DataType &elem) {
       elem.x +=10;
   }
};

for CUDA, over a set of data structs D1, D2, ...:

struct D1{
    int x;
};
struct D2{
    int x;
    int y;
};

I want all of the following explicit instantiations to be generated automatically:

template class Functor1<D1>;
template class Functor1<D2>;
template class Functor2<D1>;
template class Functor2<D2>;

I want a macro or metaprogramming trick that generates the code above from two lists, along these lines:

#define DATATYPE_LIST (D1)(D2)
#define FUNCTOR_LIST (Functor1)(Functor2)
EXPLICIT_FUNCTOR_INSTANTIATION(FUNCTOR_LIST, DATATYPE_LIST)

How can this be done with macros or SFINAE?
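One plain-preprocessor way to get close to the requested syntax, without any library, is the X-macro pattern: each list becomes a macro that applies a caller-supplied macro to every entry, and nesting the two lists yields the cross product. This is a sketch only; FUNCTOR_LIST, DATATYPE_LIST, INSTANTIATE_ONE and INSTANTIATE_FOR_TYPE are hypothetical names, and the lists use X-macro form rather than the (D1)(D2) sequence form shown above.

```cpp
// Data types and functors as in the question (constructor written with an initializer list).
struct D1 { int x; };
struct D2 { int x; int y; };

template<typename DataType>
struct Functor1 {
    int a;
    explicit Functor1(int a_) : a(a_) {}
    void operator()(DataType& elem) { elem.x += 1; }
};

template<typename DataType>
struct Functor2 {
    int a;
    explicit Functor2(int a_) : a(a_) {}
    void operator()(DataType& elem) { elem.x += 10; }
};

// Hypothetical X-macro lists: each one applies the macro it is given to its entries.
#define FUNCTOR_LIST(X, D) X(Functor1, D) X(Functor2, D)
#define DATATYPE_LIST(Y) Y(D1) Y(D2)

// One explicit instantiation definition per (functor, data type) pair.
#define INSTANTIATE_ONE(F, D) template struct F<D>;
#define INSTANTIATE_FOR_TYPE(D) FUNCTOR_LIST(INSTANTIATE_ONE, D)

// Expands to:
//   template struct Functor1<D1>; template struct Functor2<D1>;
//   template struct Functor1<D2>; template struct Functor2<D2>;
DATATYPE_LIST(INSTANTIATE_FOR_TYPE)
```

Adding a functor or a data type then means editing a single list. If Boost is available, Boost.Preprocessor's BOOST_PP_SEQ_FOR_EACH_PRODUCT is built for cross products of sequences written in the (a)(b) form and may match the requested syntax more closely.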


Solution

  • I want a macro/metaprogramming trick for the code above

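    Here is a way to do this with templates instead of macros. The program below works for an arbitrary number of Functors and Ds; the instantiations produced by different combinations of Functors and Ds are listed at the end of this answer. It relies on C++17 features: if constexpr and fold expressions. Note that it instantiates each Functor<D> implicitly, by constructing it and calling its operator() on the host, rather than by emitting explicit instantiation definitions such as template class Functor1<D1>;. For CUDA this matters: if operator() is declared __device__ only, it cannot be called from host code this way.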

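    #include <iostream>
    
    // Functor1, Functor2, D1 and D2 are as in the question (with operator() callable
    // on the host); Functor3 and D3, used in Tests 2-4, are assumed to be defined analogously.
    
    template<template<typename>typename Functor, template<typename>typename... Functors, typename... Args> void f(Args... args)
    {
        (Functor<Args>(55)(args), ...);  // construct Functor<Args>(55) and call it on each arg, instantiating Functor<Args>
        if constexpr(sizeof...(Functors)>0)
        {
             f<Functors...>(args...);    // recurse for the remaining Functors with all of args
        }
    }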
    
    int main()
    {
        std::cout << "Test 1: "<< std::endl;
        f<Functor1, Functor2>(D1(), D2());
        std::cout <<"--------------------------------------" << std::endl;
        
        std::cout << "Test 2: "<< std::endl;
        f<Functor1, Functor2, Functor3>(D1(), D2());
        std::cout <<"--------------------------------------" << std::endl;
        
        std::cout << "Test 3: "<< std::endl;
        f<Functor1, Functor2>(D1(), D2(), D3());
        std::cout <<"--------------------------------------" << std::endl;
        
        std::cout << "Test 4: "<< std::endl;
        f<Functor1, Functor2, Functor3>(D1(), D2(), D3());
        std::cout <<"--------------------------------------" << std::endl;
    }
    

    Working demo (C++17)
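The recursion in f can also be replaced by a second fold expression. In this sketch, apply_one and apply_all are hypothetical helper names, and a hypothetical g_calls counter is added to the question's functors only to make the calls observable. The outer fold expands the Functors pack, while the inner call expands args:

```cpp
// Hypothetical call counter, only there so the effect of each call is observable.
inline int g_calls = 0;

struct D1 { int x; };
struct D2 { int x; int y; };

template<typename DataType>
struct Functor1 {
    int a;
    explicit Functor1(int a_) : a(a_) {}
    void operator()(DataType& elem) { elem.x += 1; ++g_calls; }
};

template<typename DataType>
struct Functor2 {
    int a;
    explicit Functor2(int a_) : a(a_) {}
    void operator()(DataType& elem) { elem.x += 10; ++g_calls; }
};

// Applies one functor template to every argument: Functor<Args>(55)(args) for each Args.
template<template<typename> class Functor, typename... Args>
void apply_one(Args... args) {
    (Functor<Args>(55)(args), ...);  // unary fold over the comma operator, left to right
}

// Applies every functor template to every argument without recursion:
// the outer fold expands Functors, the inner call expands args.
template<template<typename> class... Functors, typename... Args>
void apply_all(Args... args) {
    (apply_one<Functors>(args...), ...);
}
```

For example, apply_all<Functor1, Functor2>(D1{}, D2{}) makes four calls, one per (functor, data type) pair. The trade-off against the answer's version is one extra helper template in exchange for no if constexpr recursion.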


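    Below is the output of the above program, i.e. the instantiations generated by each call expression. (The functors in the question print nothing, so the functors in the linked demo presumably print their own specialization when called.)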

    Test 1: 
    template Functor1<D1>
    template Functor1<D2>
    template Functor2<D1>
    template Functor2<D2>
    --------------------------------------
    Test 2: 
    template Functor1<D1>
    template Functor1<D2>
    template Functor2<D1>
    template Functor2<D2>
    template Functor3<D1>
    template Functor3<D2>
    --------------------------------------
    Test 3: 
    template Functor1<D1>
    template Functor1<D2>
    template Functor1<D3>
    template Functor2<D1>
    template Functor2<D2>
    template Functor2<D3>
    --------------------------------------
    Test 4: 
    template Functor1<D1>
    template Functor1<D2>
    template Functor1<D3>
    template Functor2<D1>
    template Functor2<D2>
    template Functor2<D3>
    template Functor3<D1>
    template Functor3<D2>
    template Functor3<D3>
    --------------------------------------
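To see how a listing like this can be produced, here is a sketch in which each functor records its own specialization. The static name member on each data type and the global g_log are hypothetical additions (they are not in the question), and entries are collected in a vector instead of printed so the order can be checked:

```cpp
#include <string>
#include <vector>

// Hypothetical: each data type carries a printable name.
struct D1 { int x; static constexpr const char* name = "D1"; };
struct D2 { int x; int y; static constexpr const char* name = "D2"; };

// Hypothetical log of "FunctorN<D>" entries, in call order.
inline std::vector<std::string> g_log;

template<typename DataType>
struct Functor1 {
    int a;
    explicit Functor1(int a_) : a(a_) {}
    void operator()(DataType& elem) {
        elem.x += 1;
        g_log.push_back(std::string("Functor1<") + DataType::name + ">");
    }
};

template<typename DataType>
struct Functor2 {
    int a;
    explicit Functor2(int a_) : a(a_) {}
    void operator()(DataType& elem) {
        elem.x += 10;
        g_log.push_back(std::string("Functor2<") + DataType::name + ">");
    }
};

// The answer's f, with the fold written as a unary comma fold.
template<template<typename> typename Functor, template<typename> typename... Functors, typename... Args>
void f(Args... args) {
    (Functor<Args>(55)(args), ...);   // every Args for the first Functor, left to right
    if constexpr (sizeof...(Functors) > 0) {
        f<Functors...>(args...);      // then the remaining Functors
    }
}
```

Calling f<Functor1, Functor2>(D1(), D2()) fills g_log with Functor1<D1>, Functor1<D2>, Functor2<D1>, Functor2<D2>, the same order as Test 1: the comma fold visits all data types for one functor before the recursion moves on to the next functor.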
    

    This also works when the numbers of functors and data types differ: for example, f<Functor1, Functor2, Functor3>(D1(), D2(), D3(), D4()); instantiates all 3 × 4 = 12 combinations.
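If you want the instantiations without constructing objects or running any functor, one option is to form a pointer to each F<D>::operator(). That odr-uses the member function, which makes the compiler instantiate (and therefore type-check) its definition. Like the answer's approach, this is implicit instantiation, so it does not by itself guarantee an out-of-line symbol for other translation units. The sketch assumes hypothetical Functor3, D3 and D4 mirroring the question's types, plus hypothetical helpers touch, touch_functors and instantiate_functors. I have not checked how nvcc treats a pointer to a __device__-only member formed in host code, so the CUDA side is untested.

```cpp
#include <cstddef>

// Question's types plus hypothetical D3, D4 and Functor3 for the asymmetric case.
struct D1 { int x; };
struct D2 { int x; int y; };
struct D3 { int x; };
struct D4 { int x; };

template<typename DataType>
struct Functor1 {
    int a;
    explicit Functor1(int a_) : a(a_) {}
    void operator()(DataType& elem) { elem.x += 1; }
};

template<typename DataType>
struct Functor2 {
    int a;
    explicit Functor2(int a_) : a(a_) {}
    void operator()(DataType& elem) { elem.x += 10; }
};

template<typename DataType>
struct Functor3 {
    int a;
    explicit Functor3(int a_) : a(a_) {}
    void operator()(DataType& elem) { elem.x += 100; }
};

// Forming &F<D>::operator() odr-uses the member function, so its definition is
// instantiated; nothing is constructed or called. Returns 1 so pairs can be counted.
template<template<typename> class F, typename D>
constexpr std::size_t touch() {
    void (F<D>::*op)(D&) = &F<D>::operator();
    (void)op;
    return 1;
}

// Touches every functor template for one data type D.
template<typename D, template<typename> class... Fs>
constexpr std::size_t touch_functors() {
    return (touch<Fs, D>() + ... + 0);
}

// Touches every (functor, data type) pair and returns how many there were.
template<template<typename> class... Fs>
struct instantiate_functors {
    template<typename... Ds>
    static constexpr std::size_t over() {
        return (touch_functors<Ds, Fs...>() + ... + 0);
    }
};

// 3 functors x 4 data types = 12 pairs, checked at compile time.
static_assert(instantiate_functors<Functor1, Functor2, Functor3>::over<D1, D2, D3, D4>() == 12);
```

The static_assert doubles as a count check for the asymmetric case; if any Functor<D> does not compile (say, a data type without an x member), the error surfaces there rather than at a call site.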