Tags: c++, templates, operator-overloading, non-member-functions

What's the syntax to overload operator== as a free function with templated parameters?


I have a set of polymorphic classes, such as:

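class Apple {
public:  // virtual members make the hierarchy polymorphic; Copy<Cat> calls clone()
    virtual ~Apple() {}
    virtual Apple* clone() const { return new Apple(*this); } };
class Red   : public Apple { public: Red*   clone() const { return new Red(*this); } };
class Green : public Apple { public: Green* clone() const { return new Green(*this); } };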

And free functions which compare them:

bool operator==(const Apple&, const Apple&);
bool operator< (const Apple&, const Apple&);

I'm designing a copyable wrapper class which will allow me to use classes Red and Green as keys in STL maps while retaining their polymorphic behaviour.

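template<typename Cat>
class Copy
{
public:
    Copy(const Cat& inCat) : type(inCat.clone()) {}
    Copy(const Copy& other) : type(other.type->clone()) {}  // deep copy, no double delete
    Copy& operator=(const Copy& other)
    {
        Cat* fresh = other.type->clone();  // clone first: safe on self-assignment
        delete type;
        type = fresh;
        return *this;
    }
    ~Copy() { delete type; }
    Cat* operator->() { return type; }
    const Cat* operator->() const { return type; }
    Cat& operator*() { return *type; }
    const Cat& operator*() const { return *type; }  // needed when called through const Copy&
private:
    Copy() : type(0) {}
    Cat* type;
};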

I want the Copy<Apple> type to be as interchangeable with Apple as possible. There are a few more functions I'll have to add to the Copy class above, but for now I'm working on a free function for operator==, as follows:

template<typename Cat>
bool operator==(const Copy<Cat>& copy, const Cat& e) {
    return *copy == e;
}

Here is part of my testing code:

Red red;
Copy<Apple> redCopy = red;
Copy<Apple> redCopy2 = redCopy;
assert(redCopy == Red());

But the compiler is telling me

../src/main.cpp:91: error: no match for ‘operator==’ in ‘redCopy == Red()’

How do I get it to recognize my operator== above? I suspect the answer might be in adding some implicit conversion code somewhere but I'm not sure what to do.


Solution

  • Your template is declared as

    template <typename Cat>
    bool operator==(const Copy<Cat>& copy, const Cat& e)
    

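    This doesn't match redCopy == Red() because template argument deduction is done on each parameter separately. From the first argument (redCopy, of type Copy<Apple>) the compiler deduces Cat = Apple; from the second (Red(), of type Red) it deduces Cat = Red. The two deductions conflict, so deduction fails and this operator== is never considered. Implicit conversions such as Red to const Apple& are not applied during deduction, so the fact that Red derives from Apple doesn't help.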
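To see the deduction failure concretely, here is a small C++11 sketch. It assumes the classes provide a virtual clone(), uses a hypothetical typeid-based operator== for Apple, trims Copy<Cat> down to a const operator*, and adds a hypothetical CanCompare helper that reports whether an == expression compiles. The explicit call operator==<Apple>(...) is only there to confirm the diagnosis (once Cat is named, nothing is deduced and Red() converts to const Apple&); it is not offered as the fix.

```cpp
#include <cassert>
#include <type_traits>
#include <typeinfo>
#include <utility>

// Minimal stand-ins for the question's classes; clone() is assumed.
struct Apple {
    virtual ~Apple() {}
    virtual Apple* clone() const { return new Apple(*this); }
};
struct Red : Apple {
    Red* clone() const override { return new Red(*this); }
};

// Hypothetical comparison: two apples are equal when their dynamic types match.
bool operator==(const Apple& a, const Apple& b) { return typeid(a) == typeid(b); }

// Trimmed Copy<Cat>; copying is disabled because this sketch never copies.
template <typename Cat>
class Copy {
public:
    Copy(const Cat& c) : type(c.clone()) {}
    Copy(const Copy&) = delete;
    Copy& operator=(const Copy&) = delete;
    ~Copy() { delete type; }
    const Cat& operator*() const { return *type; }
private:
    Cat* type;
};

// The question's operator==, unchanged: Cat is deduced from both parameters.
template <typename Cat>
bool operator==(const Copy<Cat>& copy, const Cat& e) { return *copy == e; }

// Hypothetical helper: true when `a == b` compiles for a const T& and a const U&.
template <typename T, typename U, typename = void>
struct CanCompare : std::false_type {};
template <typename T, typename U>
struct CanCompare<T, U, decltype(void(std::declval<const T&>() == std::declval<const U&>()))>
    : std::true_type {};

// Both arguments deduce Cat = Apple, so the overload is found.
static_assert(CanCompare<Copy<Apple>, Apple>::value, "consistent deduction");
// Argument 1 deduces Cat = Apple, argument 2 deduces Cat = Red: no match.
static_assert(!CanCompare<Copy<Apple>, Red>::value, "conflicting deduction");

// Naming Cat explicitly skips deduction, so Red() converts to const Apple&.
bool explicitCatCompiles() {
    Red red;
    Copy<Apple> redCopy(red);
    return operator==<Apple>(redCopy, Red());
}
```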

    What you really want to express is something like

    template <typename Cat>
    bool operator==(const Copy<Cat>& copy, const something-that-derives-from-Cat& e)
    

    The easiest way to do this is to add a second template parameter:

    template <typename Cat, typename DerivedFromCat>
    bool operator==(const Copy<Cat>& copy, const DerivedFromCat& e)
    
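Here is a self-contained C++11 sketch of the two-parameter version running the question's test. It assumes a virtual clone() on the classes and a hypothetical typeid-based operator== for Apple. Copy<Cat> is given a deep-copying copy constructor and assignment (so Copy<Apple> redCopy2 = redCopy is safe) plus a const operator*, which the free operator== needs because it takes const Copy<Cat>&.

```cpp
#include <cassert>
#include <typeinfo>

// Minimal stand-ins for the question's classes; clone() is assumed to exist.
struct Apple {
    virtual ~Apple() {}
    virtual Apple* clone() const { return new Apple(*this); }
};
struct Red : Apple {
    Red* clone() const override { return new Red(*this); }
};
struct Green : Apple {
    Green* clone() const override { return new Green(*this); }
};

// Hypothetical equality: two apples are equal when their dynamic types match.
bool operator==(const Apple& a, const Apple& b) { return typeid(a) == typeid(b); }

// Copy<Cat> with deep copying and const access added.
template <typename Cat>
class Copy {
public:
    Copy(const Cat& c) : type(c.clone()) {}
    Copy(const Copy& other) : type(other.type->clone()) {}
    Copy& operator=(const Copy& other) {
        Cat* fresh = other.type->clone();  // clone before deleting: safe on self-assignment
        delete type;
        type = fresh;
        return *this;
    }
    ~Copy() { delete type; }
    const Cat& operator*() const { return *type; }
private:
    Cat* type;
};

// Cat is deduced only from the first argument; the second has its own parameter.
template <typename Cat, typename DerivedFromCat>
bool operator==(const Copy<Cat>& copy, const DerivedFromCat& e) {
    return *copy == e;  // e binds to const Apple& in the Apple comparison
}

// The question's test, plus a negative check against a different subclass.
bool redCopyEqualsRed() {
    Red red;
    Copy<Apple> redCopy = red;
    Copy<Apple> redCopy2 = redCopy;  // deep copy via clone()
    return redCopy == Red() && redCopy2 == Red() && !(redCopy == Green());
}
```

Because nothing constrains DerivedFromCat, comparing against an unrelated type still selects this overload and only fails inside its body, which is what the enable_if version addresses.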

    Of course, this doesn't get the compiler to enforce that DerivedFromCat is actually derived from Cat. If you want this, you can use boost::enable_if:

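    #include <boost/utility/enable_if.hpp>
    #include <boost/type_traits/is_base_of.hpp>

    template <typename Cat, typename DerivedFromCat>
    typename boost::enable_if<boost::is_base_of<Cat, DerivedFromCat>, bool>::type
    operator==(const Copy<Cat>& copy, const DerivedFromCat& e)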
    

    But that may be a bit of overkill...
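If C++11 is available, the same constraint can be sketched with the standard library's std::enable_if and std::is_base_of from <type_traits> instead of Boost. One difference to keep in mind: std::enable_if takes a bool (hence ::value), whereas boost::enable_if takes a type with a ::value member. The classes, the typeid-based equality, the unrelated Pear type and the CanCompare helper below are all hypothetical and only serve to show that the constrained overload drops out for non-derived types.

```cpp
#include <cassert>
#include <type_traits>
#include <typeinfo>
#include <utility>

struct Apple {
    virtual ~Apple() {}
    virtual Apple* clone() const { return new Apple(*this); }
};
struct Red : Apple {
    Red* clone() const override { return new Red(*this); }
};
struct Pear {};  // hypothetical type unrelated to Apple

// Hypothetical equality: equal when the dynamic types match.
bool operator==(const Apple& a, const Apple& b) { return typeid(a) == typeid(b); }

// Trimmed Copy<Cat>; copying is disabled because this sketch never copies.
template <typename Cat>
class Copy {
public:
    Copy(const Cat& c) : type(c.clone()) {}
    Copy(const Copy&) = delete;
    Copy& operator=(const Copy&) = delete;
    ~Copy() { delete type; }
    const Cat& operator*() const { return *type; }
private:
    Cat* type;
};

// Viable only when DerivedFromCat is Cat or a class derived from Cat.
template <typename Cat, typename DerivedFromCat>
typename std::enable_if<std::is_base_of<Cat, DerivedFromCat>::value, bool>::type
operator==(const Copy<Cat>& copy, const DerivedFromCat& e) {
    return *copy == e;
}

// Hypothetical helper: true when `a == b` compiles for a const T& and a const U&.
template <typename T, typename U, typename = void>
struct CanCompare : std::false_type {};
template <typename T, typename U>
struct CanCompare<T, U, decltype(void(std::declval<const T&>() == std::declval<const U&>()))>
    : std::true_type {};

static_assert(CanCompare<Copy<Apple>, Red>::value, "Red derives from Apple");
static_assert(!CanCompare<Copy<Apple>, Pear>::value, "Pear is removed by enable_if");

bool constrainedCompareWorks() {
    Red red;
    Copy<Apple> redCopy(red);
    Apple plain;  // is_base_of<Apple, Apple> is true, so plain Apples are accepted
    return redCopy == Red() && !(redCopy == plain);
}
```

With the constraint in place, writing redCopy == Pear() is rejected at the call site with "no match for operator==" rather than with an error from inside the operator's body.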