Tags: c++, virtual, abstract-methods

C++ override virtual method with abstract class as a parameter


I have the following abstract class:

class A {
public:
    virtual void foo(A* a) = 0;
};

and several classes that inherit from it, e.g.:

class B : public A {
public:
    void foo(A* a); // implemented in a separate file
};

However, I want B::foo to accept only objects of type B as its argument:

void foo(B* b);

Is it possible to do this in C++? I've considered a template, but the syntax allows too much flexibility: I can write class B : public A<B>, but I want class B : public A<C> to be a compiler error.

-- Edit --

It seems my use of an abstract class is not justified. Let me clarify my situation.

I am using the polymorphic behavior of A in a separate function. In addition, I want to define a function that takes an argument of the same type, like the one above. Specifically, I am trying to write a function that computes the distance between two objects of a derived class. Distance is defined only between objects of the same class (b1 and b2, or c1 and c2, but not b1 and c2). I would also like to access this distance function in as general a way as possible.

-- Edit 2 --

Cássio showed why it is not possible to perform compile-time checking. zar's solution adds slightly more structure to the code, with runtime error checking.


Solution

  • I understand your question is more about the syntax. What you have is right: just pass a pointer to a B. The parameter is still declared as A*, but it will accept a pointer to any derived class, and inside foo you can use dynamic_cast to check that the argument really is a B. You don't need any special declaration for this.

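    #include <iostream>

    class A {
    public:
        virtual ~A() = default;  // allows safe deletion through an A*
        virtual void foo(A* a) = 0;
    };

    class B : public A {
    public:
        void foo(A* a) override
        {
            // dynamic_cast yields nullptr if a does not point to a B
            if (dynamic_cast<B*>(a) == nullptr)
                std::cout << "wrong type, expecting type B\n";
        }
    };

    class C : public A {
    public:
        void foo(A* a) override
        {
            // dynamic_cast yields nullptr if a does not point to a C
            if (dynamic_cast<C*>(a) == nullptr)
                std::cout << "wrong type, expecting type C\n";
        }
    };

    int main()
    {
        B * b1 = new B;
        B * b2 = new B;

        C * c1 = new C;
        C * c2 = new C;

        b2->foo(c1); // bad: prints "wrong type, expecting type B"

        c1->foo(b1); // bad: prints "wrong type, expecting type C"

        b2->foo(b1); // good: prints nothing

        c1->foo(c2); // good: prints nothing

        delete b1;
        delete b2;
        delete c1;
        delete c2;
    }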
    

    See also dynamic_cast.