
Automating the generation of function template specializations in C++


I have a list of templated operations with the same signature. I want to read the template parameters (the data type and the length) and the operation from user input, and then call the corresponding template specialization. Right now my solution is a long list of nested switch statements, as follows:


#include <iostream>
#include <string>

template<typename T, unsigned L>
void square() {
    T num = 1;
    for(unsigned l = 0; l < L; ++l){
        num += 1;
        std::cout << num*num << "\n";
    }
}

template<typename T, unsigned L>
void cube(){
    T num = 1;
    for(unsigned l = 0; l < L; ++l){
        num += 1;
        std::cout << num*num*num << "\n";
    }
}
/*..
template<typename T, unsigned L>
int nextOp(){
    stuff...
}
..*/
int main(int argc, char* argv[]){
    if (argc < 4) {
        std::cerr << "usage: " << argv[0] << " <type> <len> <opType>\n";
        return 1;
    }
    int type = std::stoi(argv[1]);
    int len = std::stoi(argv[2]);
    int opType = std::stoi(argv[3]);

    switch (opType){
        case 2: // op type 2
        switch(type){
            case 0: // datatype 0
            switch(len){
                case 1: // length 1
                    square<int, 1>();
                    break;
                case 2: // length 2
                    square<int, 2>();
            }
            break;
    
            case 1: // datatype 1
            switch(len){
                case 1:
                    square<float,1>();
                    break;
                case 2:
                    square<float, 2>();
                
            }
        }
        break;

        case 3: // op type 3
        switch(type){
            case 0:
            switch(len){
                case 1:
                    cube<int, 1>();
                    break;
                case 2:
                    cube<int, 2>();
            }
            break;
    
            case 1:
            switch(len){
                case 1:
                    cube<float,1>();
                    break;
                case 2:
                    cube<float, 2>();
                
            }
        }
        /*..
        repeat same nested switch with nextOp<type, len> ??
        ..*/
    }
}

I want to know if there is a better way of doing this.

Could I, for example, write a function that takes the template function 'op' as an input and generates all the specialization permutations for T and L? That would require passing a function template as an argument, which I am not sure is possible. Another option might be macros, but again I am not sure exactly how. So my question is: what is the best way to generate such template specializations?


Solution

  • std::variant, combined with std::visit, can replace the nested switch statements:

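    // Requires C++20 (std::type_identity and lambdas with a template parameter list).
    // Assumes square and cube from the question are declared above.
    #include <cstddef>
    #include <stdexcept>
    #include <type_traits>
    #include <variant>
    
    std::variant<std::integral_constant<std::size_t, 1>,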
                 std::integral_constant<std::size_t, 2>
    > to_int_constant_var(int n)
    {
        switch (n) {
            case 1: return std::integral_constant<std::size_t, 1>();
            case 2: return std::integral_constant<std::size_t, 2>();
        }
        throw std::runtime_error("Invalid argument");
    }
    
    std::variant<std::type_identity<int>,
                 std::type_identity<float>
    > to_type_var(int n)
    {
        switch (n) {
            case 0: return std::type_identity<int>();
            case 1: return std::type_identity<float>();
        }
        throw std::runtime_error("Invalid argument");
    }
    
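    // The standard library has no wrapper or container for a function template,
    // so a template such as square cannot be stored in a variant or passed around directly.
    // You could turn each function into a class template (functor) instead. There is no
    // standard wrapper for a class template either, but you could write one if needed.
    // Here the functions are simply called inside each visitor lambda, without wrapping them.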
    void call(int op_type, int type, int value)
    {
        auto value_c = to_int_constant_var(value);
        auto type_c = to_type_var(type);
    
        switch (op_type) {
            case 2:
                 std::visit([]<typename T, std::size_t N>(
                     std::type_identity<T>,
                     std::integral_constant<std::size_t, N>) {
                         square<T, N>();
                      }, type_c, value_c);
                 break;
            case 3:
                 std::visit([]<typename T, std::size_t N>(
                     std::type_identity<T>,
                     std::integral_constant<std::size_t, N>) {
                         cube<T, N>();
                     }, type_c, value_c);
                 break;
            default:
                 throw std::runtime_error("Invalid op type");
        }
    }
    

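
A different technique from the answer's, sketched here under assumptions, follows the comment's suggestion of turning each operation into a class template and lets the compiler fill an array of function pointers with every (type, length) instantiation. A class template, unlike a function template, can be passed as a template template parameter, which is one way to write the "function that takes 'op' and generates all permutations" the question asks about. The names `Square`, `Cube`, `Types`, `MaxLen`, `make_table`, `call_op` and `run_op` are hypothetical, and the operations return a sum instead of printing so the result can be checked. This needs only C++17 (std::tuple_size_v).

```cpp
#include <array>
#include <cstddef>
#include <stdexcept>
#include <tuple>
#include <utility>

// Operations as class templates: these can be template template arguments.
template <typename T, unsigned L>
struct Square {
    static double run()
    {
        T num = 1, total = 0;
        for (unsigned l = 0; l < L; ++l) {
            num += 1;
            total += num * num;
        }
        return static_cast<double>(total);
    }
};

template <typename T, unsigned L>
struct Cube {
    static double run()
    {
        T num = 1, total = 0;
        for (unsigned l = 0; l < L; ++l) {
            num += 1;
            total += num * num * num;
        }
        return static_cast<double>(total);
    }
};

// Supported types (runtime index 0, 1, ...) and lengths (1..MaxLen).
using Types = std::tuple<int, float>;
constexpr std::size_t MaxLen = 2;
constexpr std::size_t NumTypes = std::tuple_size_v<Types>;

using Fn = double (*)();

// Entry I holds Op<type number I / MaxLen, length I % MaxLen + 1>::run.
template <template <typename, unsigned> class Op, std::size_t... Is>
constexpr std::array<Fn, sizeof...(Is)> make_table(std::index_sequence<Is...>)
{
    return {{&Op<std::tuple_element_t<Is / MaxLen, Types>,
                 static_cast<unsigned>(Is % MaxLen + 1)>::run...}};
}

// Builds the table once per operation, then does a bounds-checked lookup.
template <template <typename, unsigned> class Op>
double call_op(int type, int len)
{
    static constexpr auto table =
        make_table<Op>(std::make_index_sequence<NumTypes * MaxLen>{});
    if (type < 0 || static_cast<std::size_t>(type) >= NumTypes ||
        len < 1 || static_cast<std::size_t>(len) > MaxLen)
        throw std::out_of_range("Invalid type or length");
    return table[static_cast<std::size_t>(type) * MaxLen +
                 static_cast<std::size_t>(len - 1)]();
}

double run_op(int op_type, int type, int len)
{
    switch (op_type) {
        case 2: return call_op<Square>(type, len);
        case 3: return call_op<Cube>(type, len);
    }
    throw std::out_of_range("Invalid op type");
}
```

The trade-off is that the table is built at compile time and each call is a bounds check plus one indirect call, but the supported lengths must form a contiguous range starting at 1, which is less flexible than listing arbitrary values in a std::variant.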