c++ performance templates

nested switch statements for template arguments


We have a function template that takes two compile-time arguments, e.g.

template<int a, int b> 
void SomeFunc(double *x, double *y, double *z);

For performance reasons, we want specialized instantiations of this function for a predefined set of values of a and b. About 10 values for each is reasonable, so at most 100 instantiations of SomeFunc are possible. The reason is that, for the application considered, knowing those values at compile time lets the compiler optimize SomeFunc better, particularly since it involves heavy dense computations.

One way to do this is to have two nested switch statements that are evaluated at runtime each time the function gets called. However, this needlessly re-evaluates the same switch statements on every call, since the values of a and b do not change.

Another approach is to resolve the values of a and b once and store a function pointer to the matching instantiation. SomeFunc is big enough that I would not expect it to be inlined anyway, so calling it through a pointer should not cost any performance here. Still, the problem is that we would have to write those two tedious nested switch statements, e.g.

auto DecideFuncPtr(int A, int B)
{
  switch(A)
  {
    case(1): return fa1(B);
    case(2): return fa2(B);
    ...
  }
}

with one helper per value of A, each looking like:

// A = 1
auto fa1(int B)
{
  switch(B)
  {
    case(1): return SomeFunc<1,1>;
    case(2): return SomeFunc<1,2>;
    ...
  }
}

// A = 2
auto fa2(int B)
{
  switch(B)
  {
    case(1): return SomeFunc<2,1>;
    case(2): return SomeFunc<2,2>;
    ...
  }
}
...

And finally, preprocessing the relevant function once (e.g. in a constructor) as such:

auto f = DecideFuncPtr(A, B);

// Use the function later on as such:
f(matrix_A, matrix_B, matrix_C);

Thus, you can see how messy this design gets, especially if one needs to add or adjust the supported combinations of a and b. Is there any way to make this nicer, while still ensuring that a and b are known at compile time inside SomeFunc?

For completeness, I can use C++17 or even C++20.


EDIT: I forgot to add the function's return type (void); not that it changes the pseudo-code or the question posed. Also, note that the values of a and b are not necessarily sequential. The values 1 and 2 here are for demonstration purposes only.


Solution

  • One way to turn a runtime value into a compile-time value is to use std::variant:

    std::variant<
        std::integral_constant<enumA, enumA::Value0>,
        std::integral_constant<enumA, enumA::Value1>,
        std::integral_constant<enumA, enumA::Value2>
        // ...
    > to_variant(enumA a) {
        switch (a) {
            case enumA::Value0: return std::integral_constant<enumA, enumA::Value0>{};
            case enumA::Value1: return std::integral_constant<enumA, enumA::Value1>{};
            case enumA::Value2: return std::integral_constant<enumA, enumA::Value2>{};
            // ...
        }
        std::unreachable(); // C++23; with C++17/C++20, throw instead
    }
    
    // Similar for enumB
    

    You do this once per type; you don't have to write out the Cartesian product yourself.

    Then, you can use std::visit to do the job for you:

    auto foo(enumA a, enumB b)
    -> void (*)(double*, double*, double*)
    {
        return std::visit([](auto a, auto b){ return &SomeFunc<a(), b()>; },
                          to_variant(a), to_variant(b));
    }
    // or directly
    void foo(enumA a, enumB b, double* x, double* y, double* z)
    {
        std::visit([&](auto a, auto b){ return SomeFunc<a(), b()>(x, y, z); },
                   to_variant(a), to_variant(b));
    }
    

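    Note: I used enum instead of int here, as it better expresses that the possible values are "limited", but you can do it with int if you prefer. If the enums are scoped (enum class), SomeFunc's template parameters need to be of the enum types as well, since scoped enums do not convert implicitly to int.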
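As the note says, the same idea works with plain int. The following is a minimal C++17 sketch of that variation, not a definitive implementation. The variadic `to_variant<Vs...>` helper, the `Kernel` wrapper, the placeholder body of `SomeFunc`, and the supported value lists (1, 2, 4, 8 for a and 1, 3, 5 for b) are all assumptions made up for illustration. The values in each list must be distinct, and an unsupported value throws.

```cpp
#include <stdexcept>
#include <type_traits>
#include <variant>

// Stand-in for the real kernel; the body is a placeholder for illustration.
template <int a, int b>
void SomeFunc(double* x, double* y, double* z) {
    z[0] = a * x[0] + b * y[0];
}

using FuncPtr = void (*)(double*, double*, double*);

// Hypothetical helper: maps a runtime int onto a variant holding the
// matching std::integral_constant<int, V>. Throws if v is not in Vs...
template <int... Vs>
std::variant<std::integral_constant<int, Vs>...> to_variant(int v) {
    std::variant<std::integral_constant<int, Vs>...> result;
    // Fold over ||: stops at the first value that matches.
    const bool found =
        ((v == Vs ? (result = std::integral_constant<int, Vs>{}, true)
                  : false) || ...);
    if (!found) throw std::invalid_argument("unsupported template value");
    return result;
}

// std::visit generates every (a, b) combination; the value lists below
// are assumed for the example and need not be sequential.
FuncPtr DecideFuncPtr(int A, int B) {
    return std::visit(
        [](auto a, auto b) -> FuncPtr { return &SomeFunc<a(), b()>; },
        to_variant<1, 2, 4, 8>(A), to_variant<1, 3, 5>(B));
}

// Resolve the pointer once (e.g. in a constructor), then reuse it.
struct Kernel {
    FuncPtr f;
    Kernel(int A, int B) : f(DecideFuncPtr(A, B)) {}
    void operator()(double* x, double* y, double* z) const { f(x, y, z); }
};
```

With this layout, adding or removing a supported value only means editing one of the two value lists in `DecideFuncPtr`; the lookup cost is paid once when `Kernel` is constructed, and every later call goes through the stored pointer.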