Tags: c++, operators, polymorphism, virtual, equals-operator

Operator== in derived class never gets called


Can someone please put me out of my misery with this? I'm trying to figure out why a derived operator== never gets called in a loop. To simplify the example, here's my Base and Derived class:

class Base { // ... snipped
public:
  bool operator==( const Base& other ) const { return name_ == other.name_; }
};

class Derived : public Base { // ... snipped
public:
  bool operator==( const Derived& other ) const {
    return ( static_cast<const Base&>( *this ) ==
             static_cast<const Base&>( other ) ? age_ == other.age_ :
                                                 false );
  }
};

Now when I instantiate and compare like this ...

Derived p1("Sarah", 42);
Derived p2("Sarah", 42);
bool z = ( p1 == p2 );

... all is fine. Here the operator== from Derived gets called, but when I loop over a list, comparing items in a list of pointers to Base objects ...

list<Base*> coll;

coll.push_back( new Base("fred") );
coll.push_back( new Derived("sarah", 42) );
// ... snipped

// Get two items from the list.
Base& obj1 = **itr;
Base& obj2 = **itr2;

cout << obj1.asString() << " " << ( ( obj1 == obj2 ) ? "==" : "!=" ) << " "
     << obj2.asString() << endl;

Here asString() (which is virtual and not shown here for brevity) works fine, but obj1 == obj2 always calls the Base operator== even if the two objects are Derived.
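The behaviour can be reproduced with a minimal sketch. The constructors and the name_/age_ members below are assumptions filling in the snipped parts of the classes, and compareAsBase is a hypothetical helper standing in for the loop body, where both operands are seen only through Base references:

```cpp
#include <string>
#include <utility>

// Assumed minimal version of the question's Base (the real one is snipped).
class Base {
public:
    explicit Base(std::string name) : name_(std::move(name)) {}
    virtual ~Base() = default;  // Base is polymorphic (asString() is virtual)
    // Non-virtual: which operator== runs is decided by the static type.
    bool operator==(const Base& other) const { return name_ == other.name_; }
private:
    std::string name_;
};

// Assumed minimal version of the question's Derived.
class Derived : public Base {
public:
    Derived(std::string name, int age) : Base(std::move(name)), age_(age) {}
    // Different parameter type: this hides Base::operator==, it does not override it.
    bool operator==(const Derived& other) const {
        return static_cast<const Base&>(*this) == static_cast<const Base&>(other)
            && age_ == other.age_;
    }
private:
    int age_;
};

// Hypothetical stand-in for the loop body: both operands have static type Base.
bool compareAsBase(const Base& a, const Base& b) {
    return a == b;  // always Base::operator==, whatever the dynamic types are
}
```

Comparing two Derived objects directly selects Derived::operator== and checks the ages, but the same two objects passed to compareAsBase go to Base::operator==, so only the names are compared.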

I know I'm going to kick myself when I find out what's wrong, but if someone could let me down gently it would be much appreciated.


Solution

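  • The cause: operator== is not virtual, so in obj1 == obj2 the compiler chooses the overload from the static type of the operands, which is Base&. Marking it virtual would not be enough on its own either, because Derived::operator==( const Derived& ) takes a different parameter type, so it hides Base::operator== instead of overriding it. There are two ways to fix this.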

    First solution. I would suggest adding some extra type logic to the loop, so you know when you have a Base and when you have a Derived. If you're really only dealing with Derived objects, use

    list<Derived*> coll;
    

    otherwise, use dynamic_cast to find out which elements really are Derived objects before comparing them.
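A sketch of the first approach, assuming the question's non-virtual operators stay as they are. The constructors and members are assumed, and both the helper sameValue and the rule that a Derived never equals a non-Derived are illustrative choices rather than part of the original code:

```cpp
#include <string>
#include <utility>

// Assumed minimal classes; operator== is left non-virtual, as in the question.
class Base {
public:
    explicit Base(std::string name) : name_(std::move(name)) {}
    virtual ~Base() = default;  // dynamic_cast needs a polymorphic Base
    bool operator==(const Base& other) const { return name_ == other.name_; }
private:
    std::string name_;
};

class Derived : public Base {
public:
    Derived(std::string name, int age) : Base(std::move(name)), age_(age) {}
    bool operator==(const Derived& other) const {
        return static_cast<const Base&>(*this) == static_cast<const Base&>(other)
            && age_ == other.age_;
    }
private:
    int age_;
};

// Hypothetical helper for the loop body: recover the dynamic types first.
bool sameValue(const Base& a, const Base& b) {
    const Derived* da = dynamic_cast<const Derived*>(&a);
    const Derived* db = dynamic_cast<const Derived*>(&b);
    if (da && db) return *da == *db;  // both Derived: Derived::operator==
    if (da || db) return false;       // only one is Derived: treat as unequal (assumed policy)
    return a == b;                    // neither is Derived: Base::operator==
}
```

In the loop, ( obj1 == obj2 ) would then become sameValue( obj1, obj2 ).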

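    Second solution. Put the same kind of logic into your operator==. First make it virtual in Base, so the implementation is chosen by the dynamic type of the left-hand operand, and give the Derived version the same const Base& parameter so that it overrides Base::operator== rather than hiding it. Then check the type of the right-hand operand manually: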

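    // In Derived; Base::operator==( const Base& ) must also be declared virtual.
    virtual bool operator==( const Base& other ) const {
      if ( ! Base::operator==( other ) ) return false;
      // other is const, so the cast target must be const Derived* as well.
      const Derived *other_derived = dynamic_cast< const Derived * >( &other );
      if ( ! other_derived ) return false;
      return age_ == other_derived->age_;
    }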
    

    Since an object of a different type should never compare equal anyway, you can do the type check first and fold everything into a single expression:

    virtual bool operator==( const Base& other ) const {
      const Derived *other_derived = dynamic_cast< const Derived * >( &other );
      return other_derived
       && Base::operator==( other )
       && age_ == other_derived->age_;
    }
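Putting the second approach together gives the compilable sketch below. The constructors, members and the hypothetical countEqualTo helper are assumptions, override is used in place of repeating virtual in Derived, and the question's list<Base*> filled with new is swapped for std::list<std::unique_ptr<Base>> so that nothing leaks:

```cpp
#include <cstddef>
#include <list>
#include <memory>
#include <string>
#include <utility>

class Base {
public:
    explicit Base(std::string name) : name_(std::move(name)) {}
    virtual ~Base() = default;
    // Virtual: the dynamic type of the left operand selects the implementation.
    virtual bool operator==(const Base& other) const { return name_ == other.name_; }
private:
    std::string name_;
};

class Derived : public Base {
public:
    Derived(std::string name, int age) : Base(std::move(name)), age_(age) {}
    // Same signature as Base::operator==, so this overrides it.
    bool operator==(const Base& other) const override {
        const Derived* other_derived = dynamic_cast<const Derived*>(&other);
        return other_derived
            && Base::operator==(other)  // non-virtual call: compare names
            && age_ == other_derived->age_;
    }
private:
    int age_;
};

// Hypothetical helper mirroring the question's loop: every element is
// compared through a Base reference, exactly like obj1 and obj2.
std::size_t countEqualTo(const std::list<std::unique_ptr<Base>>& coll, const Base& target) {
    std::size_t n = 0;
    for (const auto& p : coll) {
        const Base& obj = *p;
        if (obj == target) ++n;  // virtual call: Derived::operator== for Derived elements
    }
    return n;
}
```

The comparison is still asymmetric: with a plain Base named "sarah" on the left and a Derived "sarah" on the right, Base::operator== only compares names and returns true, while the reverse order returns false. If that matters, one option is to also require typeid( *this ) == typeid( other ) in the Base version.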