Tags: c++11, overriding, clone, smart-pointers, virtual-functions

Clone pattern for std::shared_ptr in C++


Why are the intermediate CloneImplementation function and std::static_pointer_cast (see Section 3 below) needed to make the Clone pattern compile with std::shared_ptr, instead of something closer (see Section 2 below) to the raw-pointer version (see Section 1 below)? As far as I understand, std::shared_ptr has a generalized copy constructor and a generalized assignment operator, so shouldn't std::shared_ptr<Derived> work as a covariant return type?

1. Clone pattern with raw pointers:

#include <iostream>

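struct Base {
    virtual ~Base() = default;  // needed to delete a Derived through a Base*

    virtual Base *Clone() const {
        std::cout << "Base::Clone\n";
        return new Base(*this);
    }
};

struct Derived : public Base {
    // Covariant return type: Derived* overrides Base*
    virtual Derived *Clone() const override {
        std::cout << "Derived::Clone\n";
        return new Derived(*this);
    }
};

int main() {
  Base *b = new Derived;
  Base *c = b->Clone();  // prints "Derived::Clone"
  delete c;
  delete b;
}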

2. Clone pattern with shared pointers (naive attempt):

#include <iostream>
#include <memory>

struct Base {
    virtual std::shared_ptr< Base > Clone() const {
        std::cout << "Base::Clone\n";
        return std::shared_ptr< Base >(new Base(*this));
    }
};
struct Derived : public Base {
    virtual std::shared_ptr< Derived > Clone() const override {
        std::cout << "Derived::Clone\n";
        return std::shared_ptr< Derived >(new Derived(*this));
    }
};

int main() {
  Base *b = new Derived;
  b->Clone();
}

Compiler output:

error: invalid covariant return type for 'virtual std::shared_ptr<Derived> Derived::Clone() const'
error:   overriding 'virtual std::shared_ptr<Base> Base::Clone() const'

3. Clone pattern with shared pointers:

#include <iostream>
#include <memory>

struct Base {

    std::shared_ptr< Base > Clone() const {
        std::cout << "Base::Clone\n";
        return CloneImplementation();
    }

private:

    virtual std::shared_ptr< Base > CloneImplementation() const {
        std::cout << "Base::CloneImplementation\n";
        return std::shared_ptr< Base >(new Base(*this));
    }
};
struct Derived : public Base {

    std::shared_ptr< Derived > Clone() const {
        std::cout << "Derived::Clone\n";
        return std::static_pointer_cast< Derived >(CloneImplementation());
    }

private:

    virtual std::shared_ptr< Base > CloneImplementation() const override {
        std::cout << "Derived::CloneImplementation\n";
        return std::shared_ptr< Derived >(new Derived(*this));
    }
};

int main() {
  std::shared_ptr< Base > b = std::make_shared< Derived >();
  b->Clone();  // prints "Base::Clone", then "Derived::CloneImplementation"
}

Solution

  • The general rule in C++ is that an overriding function must have the same return type as the function it overrides (a different parameter list would make it a separate function, not an override). The only exception is covariance, which is allowed for pointers and references to classes: if the inherited function returns A* or A&, the overrider can return B* or B& respectively, as long as A is an unambiguous, accessible base class of B. This rule is what allows Section 1 to work.

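    On the other hand, std::shared_ptr<Derived> and std::shared_ptr<Base> are two distinct class types with no inheritance relationship between them. The generalized (converting) copy constructor you mention does make std::shared_ptr<Derived> implicitly convertible to std::shared_ptr<Base>, but convertibility does not matter here: covariance is defined only for pointers and references, so an overrider cannot declare one of these types in place of the other. Section 2 is conceptually the same as trying to override virtual int f() with std::string f() override; even virtual long f() cannot be overridden by int f(), although int converts implicitly to long.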

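    That's why some extra mechanism is needed to make smart pointers behave covariantly. What you've shown as Section 3 is one such mechanism: a non-virtual Clone() in each class that forwards to a virtual implementation function and casts the result. It's the most general one, but in some cases alternatives exist. For example, when the implementation function may return a raw pointer, the covariance can be moved onto it so that no cast is needed: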

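    #include <iostream>
    #include <memory>

    struct Base {
        // Required: Clone() wraps a raw Base* that may point to a Derived,
        // so the shared_ptr deletes through a Base*
        virtual ~Base() = default;

        std::shared_ptr< Base > Clone() const {
            std::cout << "Base::Clone\n";
            return std::shared_ptr< Base >(CloneImplementation());
        }

    private:
        virtual Base* CloneImplementation() const {
            std::cout << "Base::CloneImplementation\n";
            return new Base(*this);
        }
    };

    struct Derived : public Base {
        std::shared_ptr< Derived > Clone() const {
            std::cout << "Derived::Clone\n";
            return std::shared_ptr< Derived >(CloneImplementation());
        }

    private:
        // Raw pointers are covariant, so no static_pointer_cast is needed
        virtual Derived* CloneImplementation() const override {
            std::cout << "Derived::CloneImplementation\n";
            return new Derived(*this);
        }
    };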