Tags: c++, templates, cuda, operator-overloading

Class template operator overloading for fundamental and specific non-fundamental types


I am currently writing a MathVector class:

template<typename T> class MathVector
{
public:
   using value_type = T;

   // further implementation
};

However, the class is intended to work with fundamental types and also with, let's say, a Complex class

template<typename T> class Complex
{
public:
   using value_type = T;

   // further implementation
};

which offers, for example, the member functions

template<typename T> Complex<T>& Complex<T>::operator*=(const Complex<T>& c);
template<typename T> Complex<T>& Complex<T>::operator*=(const T& c);
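To make the efficiency argument concrete, here is a minimal sketch of what these two overloads might look like. The class body (members re_ and im_, the accessors real() and imag()) is a hypothetical stand-in, not the asker's actual Complex, and the CUDA specifiers are omitted. Multiplying by a full Complex costs four multiplications plus an addition and a subtraction, while scaling by a real T costs only two multiplications.

```cpp
// Hypothetical minimal Complex<T>: only what is needed to show the two
// operator*= overloads (host code only, no __host__ __device__ specifiers).
template<typename T>
class Complex
{
public:
    using value_type = T;

    explicit Complex(T re = T(), T im = T()) : re_(re), im_(im) {}

    // Complex * Complex: 4 multiplications, 1 subtraction, 1 addition.
    Complex& operator*=(const Complex& c)
    {
        const T re = re_ * c.re_ - im_ * c.im_;
        const T im = re_ * c.im_ + im_ * c.re_;
        re_ = re;
        im_ = im;
        return *this;
    }

    // Complex * real scalar: only 2 multiplications.
    Complex& operator*=(const T& s)
    {
        re_ *= s;
        im_ *= s;
        return *this;
    }

    T real() const { return re_; }
    T imag() const { return im_; }

private:
    T re_;
    T im_;
};
```

Converting a double to Complex<double> first and then using the first overload would do twice the multiplications, plus extra work on a zero imaginary part.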

A multiplication is also defined for the MathVector class:

template<typename T> MathVector<T>& MathVector<T>::operator*=(const MathVector<T>& c);

This works fine for T=double, but for T=Complex<double> I would also like to be able to multiply by a double without first converting it to Complex<double> (which is much more efficient).

This is complicated by the fact that the code should also work in CUDA device code (I omitted the __host__ __device__ specifiers for brevity). This means that the standard library tools will not help.
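A common way to keep such snippets valid both for nvcc and for an ordinary host compiler is to hide the specifiers behind a macro. This is a sketch under assumptions: MV_HOST_DEVICE and scale are hypothetical names. It relies on nvcc defining __CUDACC__ while compiling CUDA sources.

```cpp
// Hypothetical helper macro: when compiled by nvcc (__CUDACC__ defined),
// functions are marked for both host and device; with any other compiler
// the macro expands to nothing and the code is plain host C++.
#if defined(__CUDACC__)
#define MV_HOST_DEVICE __host__ __device__
#else
#define MV_HOST_DEVICE
#endif

// Example use on a free function template (hypothetical name).
template<typename T>
MV_HOST_DEVICE inline T scale(const T& x, const T& s)
{
    return x * s;
}
```

The same macro can be placed in front of every member and non-member operator*= discussed here.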

First I thought of an additional template parameter:

template<typename T> template<typename U> MathVector<T>& MathVector<T>::operator*=(const U& c);

But this seems dangerous to me, because U can be many more types than just T or T::value_type. (In fact, I first had such a parameter in the Complex class as well, and the compiler could no longer decide which template to use: the one from the Complex class or the one from the MathVector class.)

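My second idea was to use something like template specialization (strictly speaking, the second function below is an additional overload, since function templates cannot be partially specialized):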

template<typename T> template<typename U>
MathVector<T>& MathVector<T>::operator*=(const U& c)
{
   // intended to reject every U that the overload below does not handle
   static_assert(sizeof(U) == 0, "Error...");
   return *this;
}
template<typename T>
MathVector<T>& MathVector<T>::operator*=(const typename T::value_type& c)
{
   // implementation
   return *this;
}

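But this no longer works with fundamental types! For T=double, the member declaration that uses typename T::value_type is instantiated together with the class, and double::value_type does not exist. That is a hard error, not a substitution failure.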

I have seen solutions to this (or a very similar) problem in "C++ Operator Overloading for a Matrix Class with Both Real and Complex Matrices" and "Return double or complex from template function", but they rely on the standard library in a way that is not possible in CUDA.

So my question is: is there a way to overload the operator so that it works with fundamental types and with types that provide a value_type, but not with any others, without using std:: facilities that the nvcc compiler will reject?
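One way to get the constraint asked for here without <type_traits> is to hand-write the few traits it needs. This is a minimal sketch, not taken from the original post. mv_enable_if, mv_is_same, mv_void and mv_is_scalar_for are hypothetical stand-ins for std::enable_if, std::is_same and std::void_t. Complex and MathVector are reduced to tiny hypothetical aggregates with a fixed size of 3. The traits are plain class templates evaluated at compile time, so nothing in them has to run on the device.

```cpp
// Hand-written replacements for std::enable_if, std::is_same, std::void_t
// (hypothetical names; nothing from the standard library is used).
template<bool B, typename R = void> struct mv_enable_if {};
template<typename R> struct mv_enable_if<true, R> { using type = R; };

template<typename A, typename B> struct mv_is_same { static constexpr bool value = false; };
template<typename A> struct mv_is_same<A, A> { static constexpr bool value = true; };

template<typename...> struct mv_void { using type = void; };

// Primary template: T has no value_type, so only U == T is accepted.
template<typename T, typename U, typename = void>
struct mv_is_scalar_for { static constexpr bool value = mv_is_same<T, U>::value; };

// Chosen only if T::value_type exists; then U == T or U == T::value_type.
template<typename T, typename U>
struct mv_is_scalar_for<T, U, typename mv_void<typename T::value_type>::type>
{
    static constexpr bool value = mv_is_same<T, U>::value
                               || mv_is_same<typename T::value_type, U>::value;
};

template<typename T>
struct Complex                        // hypothetical minimal stand-in
{
    using value_type = T;
    T re, im;
    Complex& operator*=(const T& s) { re *= s; im *= s; return *this; }
    Complex& operator*=(const Complex& c)
    {
        const T r = re * c.re - im * c.im;
        im = re * c.im + im * c.re;   // uses the old re and im
        re = r;
        return *this;
    }
};

template<typename T>
struct MathVector                     // hypothetical 3-element stand-in
{
    using value_type = T;
    T data[3];

    MathVector& operator*=(const MathVector& c)
    {
        for (int i = 0; i < 3; ++i) data[i] *= c.data[i];
        return *this;
    }

    // Takes part in overload resolution only if U is T or T::value_type;
    // any other U is a substitution failure, not a hard error.
    template<typename U,
             typename mv_enable_if<mv_is_scalar_for<T, U>::value, int>::type = 0>
    MathVector& operator*=(const U& s)
    {
        for (int i = 0; i < 3; ++i) data[i] *= s;
        return *this;
    }
};
```

With this constraint, m *= 2 on a MathVector<double> is rejected because int is neither T nor T::value_type. Relaxing the check to "convertible to" would need a hand-written convertibility trait as well.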


Solution

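  • You could make operator*= a set of non-member function templates and provide all the overloads, so that SFINAE removes the ones that do not apply. For T=double, the parameter type typename T::value_type in the third overload is a substitution failure, so that overload simply drops out of the candidate set instead of causing a hard error.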

    template<typename T>
    MathVector<T>& operator*=(MathVector<T>& m, const MathVector<T>& c);
    template<typename T>
    MathVector<T>& operator*=(MathVector<T>& m, const T& c);
    template<typename T>
    MathVector<T>& operator*=(MathVector<T>& m, const typename T::value_type& c);
    

    Then call them like this:

    MathVector<Complex<double>> m1;
    m1 *= MathVector<Complex<double>>{};  // call the 1st one
    m1 *= Complex<double>{};              // call the 2nd one
    m1 *= 0.0;                            // call the 3rd one
    
    MathVector<double> m2;
    m2 *= MathVector<double>{};           // call the 1st one
    m2 *= 0.0;                            // call the 2nd one
    

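A self-contained sketch of the non-member approach follows. It assumes hypothetical minimal Complex and MathVector types (fixed size 3, public data) that are not part of the original answer. In CUDA code each function would additionally need __host__ __device__.

```cpp
template<typename T>
struct Complex                        // hypothetical minimal stand-in
{
    using value_type = T;
    T re, im;
    Complex& operator*=(const T& s) { re *= s; im *= s; return *this; }
    Complex& operator*=(const Complex& c)
    {
        const T r = re * c.re - im * c.im;
        im = re * c.im + im * c.re;   // uses the old re and im
        re = r;
        return *this;
    }
};

template<typename T>
struct MathVector                     // hypothetical 3-element stand-in
{
    using value_type = T;
    T data[3];
};

// 1st: element-wise product with another vector of the same type.
template<typename T>
MathVector<T>& operator*=(MathVector<T>& m, const MathVector<T>& c)
{
    for (int i = 0; i < 3; ++i) m.data[i] *= c.data[i];
    return m;
}

// 2nd: scale by an element. T is deduced from both arguments,
// so the two deductions must agree.
template<typename T>
MathVector<T>& operator*=(MathVector<T>& m, const T& c)
{
    for (int i = 0; i < 3; ++i) m.data[i] *= c;
    return m;
}

// 3rd: scale by T::value_type. The second parameter is a non-deduced
// context; for T = double, double::value_type is a substitution failure,
// so this overload silently drops out.
template<typename T>
MathVector<T>& operator*=(MathVector<T>& m, const typename T::value_type& c)
{
    for (int i = 0; i < 3; ++i) m.data[i] *= c;
    return m;
}
```

For m1 *= 0.0 with m1 a MathVector<Complex<double>>, the second overload cannot be selected because T would have to be both Complex<double> and double. Only the third overload remains viable.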