Tags: c++, polymorphism, function-calls

How do I avoid type casting and typeid in C++ with polymorphic function arguments from inherited classes?


I have an abstract Shape class that is extended by several concrete shape implementations. The overlap between different shapes needs to be calculated. At first I thought I could do this with simple polymorphic function arguments, as shown in code snippet 1.

#include <iostream>
#include <vector>
#include <typeinfo>
#include <string>

class Circle;   // forward declaration at namespace scope (not nested inside Square)

class Shape {
public:
    virtual void getOverlap(Shape *obj) = 0;
};

class Square : public Shape {
public:
    virtual void getOverlap(Circle *obj)
        { std::cout << "Square overlap with Circle" << std::endl; }
    virtual void getOverlap(Square *obj)
        { std::cout << "Square overlap with Square" << std::endl; }
};

class Circle : public Shape {
public:
    virtual void getOverlap(Circle *obj)
        { std::cout << "Circle overlap with Circle" << std::endl; }
    virtual void getOverlap(Square *obj)
        { std::cout << "Circle overlap with Square" << std::endl; }
};

int main() {
    std::vector<Shape*> shapes = { new Square, new Circle };
    shapes[0]->getOverlap(shapes[1]);
    shapes[1]->getOverlap(shapes[1]);
}

This of course does not compile: the derived classes only overload getOverlap with derived-class parameters instead of overriding the pure virtual getOverlap(Shape *obj), so Square and Circle remain abstract and cannot be instantiated. I hope, however, that it makes my intention clear.

After not finding a suitable answer to this problem, I came up with the workaround shown in the second code snippet.

#include <iostream>
#include <vector>
#include <typeinfo>

// Forward declarations at namespace scope; declaring them inside
// Square/Circle would create unrelated nested classes.
class Square;
class Circle;

class Shape {
public:
    virtual ~Shape() = default;
    virtual void getOverlap(Shape *obj) = 0;
};

class Square : public Shape {
public:
    void getOverlap(Shape *obj) override;  // defined below, once Circle is complete

private:
    void getOverlap(Circle *obj)
        { std::cout << "Square overlap with Circle" << std::endl; }
    void getOverlap(Square *obj)
        { std::cout << "Square overlap with Square" << std::endl; }
};

class Circle : public Shape {
public:
    void getOverlap(Shape *obj) override;

private:
    void getOverlap(Circle *obj)
        { std::cout << "Circle overlap with Circle" << std::endl; }
    void getOverlap(Square *obj)
        { std::cout << "Circle overlap with Square" << std::endl; }
};

// typeid(...).name() is implementation-defined ("class Circle" is MSVC's
// spelling), so compare the std::type_info objects directly instead.
void Square::getOverlap(Shape *obj) {
    if (typeid(*obj) == typeid(Circle))
        getOverlap(static_cast<Circle*>(obj));
    else if (typeid(*obj) == typeid(Square))
        getOverlap(static_cast<Square*>(obj));
}

void Circle::getOverlap(Shape *obj) {
    if (typeid(*obj) == typeid(Circle))
        getOverlap(static_cast<Circle*>(obj));
    else if (typeid(*obj) == typeid(Square))
        getOverlap(static_cast<Square*>(obj));
}

int main() {
    std::vector<Shape*> shapes = { new Square, new Circle };
    shapes[0]->getOverlap(shapes[1]);
    shapes[1]->getOverlap(shapes[1]);
    for (Shape *s : shapes) delete s;
}

I strongly dislike querying the typeid of the object and explicitly casting to the correct pointer type, since the object is, in principle, already of the correct type.

What is a good way of doing this using C++ polymorphism?


Solution

  • What you're trying to achieve is called double dispatch, and the object-oriented way to implement it is the Visitor pattern.

    In your case, it would look somewhat like this:

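    class Circle;   // forward declarations: Shape refers to both
    class Square;   // derived classes before they are defined

    class Shape
    {
       public:
         virtual ~Shape() = default;
         virtual void getOverlap(Shape* s)=0;
    
         virtual void getOverlapCircle(Circle* c)=0;
         virtual void getOverlapSquare(Square* s)=0;
    };
    
    class Square : public Shape
    {
       public:
         // The first dispatch selected Square for *this; the virtual call
         // on s dispatches a second time, on the dynamic type of s.
         void getOverlap(Shape* s) override
         {
             s->getOverlapSquare(this);
         }
    
         void getOverlapSquare(Square* s) override
         {
            // code for overlapping 2 squares
         }
    
         // Reached from Circle::getOverlap, i.e. when a circle is the caller
         void getOverlapCircle(Circle* c) override
         {
            // code for overlapping a circle & a square
         }
    };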
    

    The same goes for the Circle class, obviously.

    This is the standard textbook approach; notice that it doesn't require any casts.

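    Personally, however, I find it quite ugly: it has no notion of the symmetry of overlapping shapes (square-with-circle and circle-with-square are separate functions), and it is not easy to extend, since every new shape needs a new virtual function in Shape and an implementation in every existing subclass.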

    Better techniques, however, require some low-level tweaking, up to the point of reimplementing the whole vtable mechanism yourself.