Tags: c++, templates, friend

Why does this work? (template friend function with a template class taking different arguments)


I have a matrix class:

template <typename T, const int N, const int M>
class TMatrixNxM
{
    (...)
    friend TMatrixNxM operator*(const TMatrixNxM&, const TMatrixNxM&);
    (...)
};

Now, in maths, multiplying an NxM matrix by an MxP matrix gives an NxP matrix. So I needed an operator that takes an NxM matrix and an MxP matrix as arguments and returns an NxP matrix, like so:

template <typename T, const int N, const int M, const int P>
TMatrixNxM<T, N, P> operator*(const TMatrixNxM<T, N, M> &par_value1, const TMatrixNxM<T, M, P> &par_value2)
{
    TMatrixNxM<T, N, P> result;

    (...) //Calculate

    return result;
}

When I test it:

TMatrixNxM<float, 2, 3> m1;
(...) //Set the values

TMatrixNxM<float, 3, 4> m2;
(...) //Set the values

TMatrixNxM<float, 2, 4> m3 = m1 * m2;

m1.print(); //Matrix class has a print function for testing
printf("\n");
m2.print();
printf("\n");
m3.print();
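For reference, here is a minimal self-contained sketch of the setup above. The question elides the class body, so the row-major std::array storage and the public get/set accessors are assumptions made only for this illustration. Because operator* goes through the public accessors, this version needs no friend declaration at all.

```cpp
#include <array>
#include <cstdio>

// Hypothetical minimal matrix: row-major storage, zero-initialized.
template <typename T, int N, int M>
class TMatrixNxM
{
public:
    T get(int row, int col) const { return data_[row * M + col]; }
    void set(int row, int col, T value) { data_[row * M + col] = value; }

    void print() const
    {
        for (int i = 0; i < N; ++i)
        {
            for (int j = 0; j < M; ++j)
                std::printf("%g ", static_cast<double>(get(i, j)));
            std::printf("\n");
        }
    }

private:
    std::array<T, N * M> data_{};
};

// Free function template with 4 template parameters. T, N, M and P are all
// deduced from the two operand types; the shared M forces the inner
// dimensions of the operands to match.
template <typename T, int N, int M, int P>
TMatrixNxM<T, N, P> operator*(const TMatrixNxM<T, N, M>& lhs,
                              const TMatrixNxM<T, M, P>& rhs)
{
    TMatrixNxM<T, N, P> result;
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < P; ++j)
        {
            T sum{};
            for (int k = 0; k < M; ++k)
                sum += lhs.get(i, k) * rhs.get(k, j);
            result.set(i, j, sum);
        }
    return result;
}
```

With these definitions, the test code from the question should compile as written: after filling m1 and m2, m1 * m2 yields a TMatrixNxM<float, 2, 4>.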

It works just like that. How and why exactly does this work? The overloaded operator takes an extra template argument, while the class takes only 3, and in the friend declaration inside the class I didn't specify anything. If I declare it like this, however:

template <typename T, const int N, const int M>
class TMatrixNxM
{
    (...)
    template<typename T, const int N, const int M, const int P> friend TMatrixNxM<N, P> operator*(const TMatrixNxM<N, M>&, const TMatrixNxM<M, P>&);
    (...)
};

Then the compiler complains that there are too few template arguments. I hope I'm not missing something obvious here.

Thanks!

EDIT

I see now that the "too few arguments" complaint is aimed at the fact that I didn't include T as well. It should have been TMatrixNxM<T, N, P>, etc.


Solution

  Two things:

    1) The compiler distinguishes between the template parameters of a class template and those of a function template. You can have a class with 3 template parameters and a function with 4.

    So when you declare:

    template <typename T, const int N, const int M, const int P>
    TMatrixNxM<T, N, P> operator*(const TMatrixNxM<T, N, M> &par_value1, const TMatrixNxM<T, M, P> &par_value2)
    

    you have defined a function that takes 4 template arguments. Then, when the compiler sees:

    TMatrixNxM<float, 2, 4> m3 = m1 * m2;
    

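    the compiler deduces the 4 template arguments T, N, M and P from the types of the two operands m1 and m2 (the return type, and the type of m3, play no part in deduction): T = float, N = 2, M = 3 (deduced from both operands, which must agree), and P = 4. The return type TMatrixNxM<float, 2, 4> then follows from those deduced values.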
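The deduction described above can be checked at compile time. This sketch reduces the class to an empty shell (an assumption, since only the types matter here), declares the operator* from the question without defining it, and uses decltype plus a small hypothetical can_multiply trait to show that the result type is 2x4 and that a pair with mismatched inner dimensions simply fails deduction.

```cpp
#include <type_traits>
#include <utility>

// Only the types matter here, so the class body is left empty.
template <typename T, int N, int M>
class TMatrixNxM {};

// Declaration only: decltype does not evaluate the call, so no
// definition is needed for these compile-time checks.
template <typename T, int N, int M, int P>
TMatrixNxM<T, N, P> operator*(const TMatrixNxM<T, N, M>&,
                              const TMatrixNxM<T, M, P>&);

// Deduced from the operands: T = float, N = 2, M = 3 (both agree), P = 4.
using Product = decltype(TMatrixNxM<float, 2, 3>{} * TMatrixNxM<float, 3, 4>{});
static_assert(std::is_same<Product, TMatrixNxM<float, 2, 4>>::value,
              "2x3 times 3x4 gives 2x4");

// Hypothetical helper trait: true when a * b finds a viable operator*.
template <typename A, typename B, typename = void>
struct can_multiply : std::false_type {};

template <typename A, typename B>
struct can_multiply<A, B, decltype(void(std::declval<const A&>() *
                                        std::declval<const B&>()))>
    : std::true_type {};

// 2x3 times 2x3: M would be deduced as 3 from the left operand and as 2
// from the right one, so deduction fails and no operator* is viable.
static_assert(!can_multiply<TMatrixNxM<float, 2, 3>,
                            TMatrixNxM<float, 2, 3>>::value,
              "inner dimensions must match");
```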

    Also note that the template parameter names of the function don't have to be the same as those of the class:

    template <typename FOO, const int BAR, const int BAZ, const int QUX>
    TMatrixNxM<FOO, BAR, QUX> operator*(const TMatrixNxM<FOO, BAR, BAZ> &par_value1, const TMatrixNxM<FOO, BAZ, QUX> &par_value2)
    

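    2) In your second example, you are indeed missing a template argument. You are trying to return a TMatrixNxM<N, P>, but TMatrixNxM requires 3 arguments. It would work if you wrote TMatrixNxM<T, N, P> (and likewise TMatrixNxM<T, N, M> and TMatrixNxM<T, M, P> for the parameters), which is what you did in the first part. Note also that the friend template cannot reuse the names T, N and M: a template parameter may not be redeclared inside the class template's scope, so the friend needs its own parameter names.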
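    If the operator needs access to private members, the friend declaration from the second example could be written along these lines. This is a sketch: the data_ array and the get/set accessors are hypothetical, and the friend template deliberately uses the names U, R, K and C so that it does not redeclare the class's T, N and M.

```cpp
// Hypothetical minimal matrix whose storage is private.
template <typename T, int N, int M>
class TMatrixNxM
{
    // Friend template with its own parameter names. Every specialization
    // of this operator* becomes a friend of every TMatrixNxM specialization.
    template <typename U, int R, int K, int C>
    friend TMatrixNxM<U, R, C> operator*(const TMatrixNxM<U, R, K>&,
                                         const TMatrixNxM<U, K, C>&);

public:
    void set(int row, int col, T value) { data_[row][col] = value; }
    T get(int row, int col) const { return data_[row][col]; }

private:
    T data_[N][M] = {};  // zero-initialized
};

// Definition of the befriended template. It reads and writes data_ of all
// three specializations involved (R x K, K x C and R x C).
template <typename U, int R, int K, int C>
TMatrixNxM<U, R, C> operator*(const TMatrixNxM<U, R, K>& lhs,
                              const TMatrixNxM<U, K, C>& rhs)
{
    TMatrixNxM<U, R, C> result;  // starts at zero, so += accumulates sums
    for (int i = 0; i < R; ++i)
        for (int j = 0; j < C; ++j)
            for (int k = 0; k < K; ++k)
                result.data_[i][j] += lhs.data_[i][k] * rhs.data_[k][j];
    return result;
}
```

    Befriending the whole template is broader than strictly necessary, but it keeps the declaration simple and gives the operator access to the private data of both operands and of the result.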