Tags: c++, templates, metaprogramming, template-meta-programming, complex-numbers

Creating a function template for double and std::complex that does not require specialization


As a learning exercise, I am trying to write a function that computes the Hermitian conjugate (conjugate transpose) of a square matrix in place. When all entries are real it should behave like a plain transpose, so it should also work with double. I know I could write a separate specialization for double, and that is doable in this particular example. But I imagine that specializing would become tedious for larger problems such as ODE solvers.
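For reference, here is the operation being computed. The matrix is stored row-major, so entry A_ij is buffer[i * width + j]. The block also shows how the operation reduces to a transpose in the real case, with a small 2 × 2 worked example:

```latex
(A^{H})_{ij} = \overline{A_{ji}}, \qquad
\text{if all } A_{ij} \in \mathbb{R}:\; A^{H} = A^{T}

\begin{pmatrix} 1+i & 2+3i \\ 4+5i & 6-i \end{pmatrix}^{H}
= \begin{pmatrix} 1-i & 4-5i \\ 2-3i & 6+i \end{pmatrix}
```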

I tried the following:

#include <complex>

const size_t ZERO = 0ul;

template <class value_type,
          class container_type = value_type*>
auto
hermitianConjugate(container_type buffer, size_t width)
{
    for (size_t row = ZERO; row < width; row++)
    {
        for (size_t col = ZERO; col < width; col++)
        {
            auto temp = std::conj(buffer[col * width + row]);
            if (std::imag(temp) == 0)
            {
                // works for both double and std::complex
                buffer[row * width + col] = buffer[col * width + row];
            } else 
            {
                // for std::complex
                buffer[row * width + col] = temp;
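                // compile error when value_type is double: std::conj(double)
                // returns std::complex<double>, which cannot be assigned to a
                // double, and this branch must compile even though it is
                // never taken at runtime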
            }
        }
    }
}

Is there a workaround that does not involve explicit specialization? Is there any way to use conditional branching "statically", if that makes sense?


Solution

  • If you have C++17, you can use if constexpr. Inside a template, the branch that is not taken is discarded and never instantiated for that value_type. This effectively gives you one specialization per type without writing separate functions.

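    #include <complex>
    #include <cstddef>
    #include <type_traits>
    
    // value_type is not deduced from the arguments, so call this as
    // hermitianConjugate<double>(buf, n) or
    // hermitianConjugate<std::complex<double>>(buf, n).
    template <class value_type,
              class container_type = value_type*>
    void
    hermitianConjugate(container_type buffer, std::size_t width)
    {
        for (std::size_t row = 0; row < width; row++)
        {
            // start at col = row so each off-diagonal pair is swapped exactly
            // once; visiting (col, row) as well would undo the swap
            for (std::size_t col = row; col < width; col++)
            {
                auto upper = buffer[row * width + col];
                auto lower = buffer[col * width + row];
                if constexpr (std::is_same_v<value_type, std::complex<double>>)
                {
                    // only compiled for std::complex<double>
                    buffer[row * width + col] = std::conj(lower);
                    buffer[col * width + row] = std::conj(upper);
                }
                else
                {
                    // for double: plain transpose
                    buffer[row * width + col] = lower;
                    buffer[col * width + row] = upper;
                }
            }
        }
    }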
    
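The version above has two limitations. First, the element type must be spelled out at the call site, because value_type cannot be deduced from container_type. Second, std::is_same_v only matches std::complex<double>. Below is a minimal C++17 sketch that addresses both. It assumes a raw pointer to a row-major width × width buffer. The names is_complex, is_complex_v and conjugateTranspose are hypothetical, not standard library facilities:

```cpp
#include <complex>
#include <cstddef>
#include <type_traits>
#include <utility>

// Hypothetical trait (not in the standard library): true for any std::complex<T>.
template <class T>
struct is_complex : std::false_type {};

template <class T>
struct is_complex<std::complex<T>> : std::true_type {};

template <class T>
inline constexpr bool is_complex_v = is_complex<T>::value;

// Hypothetical name: in-place conjugate transpose of a width x width
// row-major matrix. T is deduced from the pointer argument.
template <class T>
void conjugateTranspose(T* buffer, std::size_t width)
{
    for (std::size_t row = 0; row < width; ++row)
    {
        // col starts at row: each off-diagonal pair is swapped exactly once
        for (std::size_t col = row; col < width; ++col)
        {
            T& upper = buffer[row * width + col];
            T& lower = buffer[col * width + row];
            if (row != col)
                std::swap(upper, lower);
            // Discarded (not instantiated) for real T, so the complex result
            // of std::conj is never assigned into a real element.
            if constexpr (is_complex_v<T>)
            {
                upper = std::conj(upper);
                if (row != col) // on the diagonal upper and lower are the same element
                    lower = std::conj(lower);
            }
        }
    }
}
```

With this sketch, conjugateTranspose(buf, n) works for double*, float*, std::complex<float>* and std::complex<double>* without an explicit template argument.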

    If you do not have C++17, you can move the type-dependent part into an overloaded helper function. Overload resolution then picks the right version at compile time:

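    #include <complex>
    #include <cstddef>
    
    // overload resolution picks one of these at compile time
    constexpr double myconj(double x) noexcept { return x; }
    std::complex<double> myconj(std::complex<double> x) { return std::conj(x); }
    
    // value_type must be given explicitly, e.g. hermitianConjugate<double>(buf, n)
    template <class value_type,
              class container_type = value_type*>
    void
    hermitianConjugate(container_type buffer, std::size_t width)
    {
        for (std::size_t row = 0; row < width; row++)
        {
            // col starts at row: each off-diagonal pair is swapped exactly once
            for (std::size_t col = row; col < width; col++)
            {
                value_type upper = buffer[row * width + col];
                value_type lower = buffer[col * width + row];
                buffer[row * width + col] = myconj(lower);
                buffer[col * width + row] = myconj(upper);
            }
        }
    }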
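The overload approach can be extended beyond double and std::complex<double> without if constexpr. The idea is to make the complex overload a template and to deduce the element type from a pointer argument. Below is a sketch assuming C++11 and a row-major width × width buffer. conjugateTranspose is a hypothetical name:

```cpp
#include <complex>
#include <cstddef>

// Real scalars: conjugation is the identity.
inline float myconj(float x) { return x; }
inline double myconj(double x) { return x; }
inline long double myconj(long double x) { return x; }

// Any std::complex<T>: std::conj keeps the element type T.
template <class T>
std::complex<T> myconj(const std::complex<T>& x) { return std::conj(x); }

// Hypothetical name: in-place conjugate transpose of a width x width
// row-major matrix. T is deduced from the pointer; overload resolution
// on myconj selects the real or complex behaviour at compile time.
template <class T>
void conjugateTranspose(T* buffer, std::size_t width)
{
    for (std::size_t row = 0; row < width; ++row)
    {
        for (std::size_t col = row; col < width; ++col)
        {
            // Read both entries before writing; on the diagonal they coincide.
            T upper = buffer[row * width + col];
            T lower = buffer[col * width + row];
            buffer[row * width + col] = myconj(lower);
            buffer[col * width + row] = myconj(upper);
        }
    }
}
```

With three real overloads, an int argument to myconj would be ambiguous. This sketch therefore only targets floating-point and std::complex element types.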
    

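    Note that std::conj already has an overload for double (and, since C++11, for other arithmetic types). Even in that case it returns std::complex<double>, not double. That is why the question's code fails to compile for double, and why the answer uses its own myconj overloads.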
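The point about std::conj can be checked at compile time. This sketch uses only standard type traits:

```cpp
#include <complex>
#include <type_traits>

// Since C++11, std::conj accepts arithmetic arguments, but it returns a
// complex number even for a double argument.
static_assert(std::is_same<decltype(std::conj(1.0)), std::complex<double>>::value,
              "std::conj(double) returns std::complex<double>");

// std::complex<double> has no conversion to double, so assigning the
// result back into a double element does not compile.
static_assert(!std::is_convertible<std::complex<double>, double>::value,
              "no implicit complex -> double conversion");
```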