c++ templates polymorphism derived-class

How do I properly derive from a nested struct?


I have an abstract (templated) class that I want to have its own return type InferenceData.

template <typename StateType>
class Model {
public:
    struct InferenceData;
    virtual InferenceData inference () = 0;
};

Below is an attempt to derive from it:

template <typename StateType>
class MonteCarlo : public Model<StateType> {
public:
    
    // struct InferenceData {};
    
    typename MonteCarlo::InferenceData inference () {
        typename MonteCarlo::InferenceData x;
        return x;
    }
};

This compiles, but only because the definition of MonteCarlo::InferenceData is commented out. If I uncomment it, I get an "invalid covariant return type" error. I want each ModelDerivation<StateType>::InferenceData to be its own type, with its own implementation as a struct. How do I achieve this?


Solution

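  • An overriding virtual method must have the same return type as the base method. The only exception is a covariant return, where both methods return a pointer or reference to a class type and the derived method's class derives from the base method's class. Returning a struct by value does not qualify, which is why compilation fails once MonteCarlo defines its own InferenceData and MonteCarlo::inference() returns it.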

    To achieve what you need, use a polymorphic return type, which requires pointer or reference semantics. Your derived InferenceData has to inherit from the base InferenceData, and inference() should return a pointer or reference to the base InferenceData.
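
    As a rough illustration of what the "invalid covariant return type" error is about, the sketch below uses plain, non-templated classes and raw pointers. All names in it (BaseData, DerivedData, Base, Derived, samples) are hypothetical and not taken from the question. It only shows which override the language accepts; it is not a recommendation to hand out raw pointers.

    ```cpp
    #include <cassert>

    // Hypothetical data types: DerivedData inherits from BaseData.
    struct BaseData {
        virtual ~BaseData() = default;
    };
    struct DerivedData : BaseData {
        int samples = 0;  // illustrative payload only
    };

    struct Base {
        virtual ~Base() = default;
        // If this returned BaseData by value, every override would have to
        // return exactly BaseData as well.
        virtual BaseData* data() = 0;
    };

    struct Derived : Base {
        DerivedData storage;
        // Accepted: DerivedData* is covariant with BaseData*, because both are
        // pointers to classes and DerivedData derives from BaseData.
        DerivedData* data() override { return &storage; }
    };
    ```

    Called through a Base reference, data() yields a BaseData*; called on a Derived object directly, it yields a DerivedData* without any cast.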

    One way to do it is with a smart pointer - e.g. a std::unique_ptr - see the code below:

    #include <memory>
    
    template <typename StateType>
    class Model {
    public:
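        // Virtual destructor: derived data is deleted through a pointer to this base type.
        struct InferenceData { virtual ~InferenceData() = default; };
        virtual ~Model() = default;
        virtual std::unique_ptr<InferenceData> inference() = 0;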
    };
    
    
    template <typename StateType>
    class MonteCarlo : public Model<StateType> {
    public:
        struct InferenceDataSpecific : public Model<StateType>::InferenceData {};
    
        // The base class is a dependent type, so its nested type is named in full.
        std::unique_ptr<typename Model<StateType>::InferenceData> inference() override {
            return std::make_unique<InferenceDataSpecific>();
        }
    };
    
    int main()
    {
        MonteCarlo<int> m;
        auto d = m.inference();
        return 0;
    }
    

    Note: if you need to share the data, you can use a std::shared_ptr.
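
The std::shared_ptr variant mentioned in the note could look roughly like the sketch below, which also shows one way a caller might get back to the derived data. The BaseData alias, the rollouts member, and the rollouts_of helper are assumptions made up for illustration; they do not appear in the question.

```cpp
#include <cassert>
#include <memory>

template <typename StateType>
class Model {
public:
    struct InferenceData {
        // Makes the type polymorphic, which dynamic_pointer_cast requires.
        virtual ~InferenceData() = default;
    };
    virtual ~Model() = default;
    virtual std::shared_ptr<InferenceData> inference() = 0;
};

template <typename StateType>
class MonteCarlo : public Model<StateType> {
public:
    using BaseData = typename Model<StateType>::InferenceData;

    // Each derived model defines its own data struct on top of the base one.
    struct InferenceDataSpecific : BaseData {
        int rollouts = 0;  // hypothetical payload
    };

    std::shared_ptr<BaseData> inference() override {
        auto data = std::make_shared<InferenceDataSpecific>();
        data->rollouts = 100;  // placeholder value for the sketch
        return data;           // converts implicitly to shared_ptr<BaseData>
    }
};

// Hypothetical helper: recovers the MonteCarlo-specific data from the base
// pointer, or returns -1 if the model produced some other data type.
template <typename StateType>
int rollouts_of(Model<StateType>& model) {
    using Specific = typename MonteCarlo<StateType>::InferenceDataSpecific;
    auto base = model.inference();
    auto specific = std::dynamic_pointer_cast<Specific>(base);
    return specific ? specific->rollouts : -1;
}
```

If callers need the derived data often, a downcast at every call site gets noisy; putting virtual accessors on the base InferenceData is a possible alternative.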