c++ polymorphism abstract-class

Generic object for derived classes for instantiation and returning purposes


I want to create a graph object that can switch between different mathematical functions as I plot along it. The graph object currently stores the active math function as a member, so it knows which function to follow when I call its plotting function, and I am trying to use polymorphism to describe the different kinds of math functions. My current (non-working) code looks like this:

#pragma once
#include<array>

class MathExpression
{
public:
    virtual ~MathExpression() = default;
    virtual MathExpression integrate() const = 0;
};

class Polynomial : public MathExpression
{
public:
    Polynomial(std::array<double, 5> coefficients) : coefficients(coefficients) {};
    Polynomial() : coefficients(std::array<double, 5>{0.0, 0.0, 0.0, 0.0, 0.0}) {};
    Polynomial(const Polynomial& p) : coefficients(p.coefficients) {};
    ~Polynomial() = default;

    Polynomial integrate() const;
    std::array<double, 5> getCoefficients() const;
private:
    std::array<double, 5> coefficients;
};

class Exponential : public MathExpression
{
public:
    Exponential(std::array<double, 3> coefficients) : coefficients(coefficients) {};
    Exponential() : coefficients(std::array<double, 3>{0.0, 0.0, 0.0}) {};
    Exponential(const Exponential& e) : coefficients(e.coefficients) {};
    ~Exponential() = default;

    Exponential integrate() const;
    std::array<double, 3> getCoefficients() const;
private:
    const std::array<double, 3> coefficients;
};

class Graph
{
public:
    Graph() = default;
    MathExpression getCurrentCurve() {return current_curve;}
    void setCurrentCurve(const MathExpression& curve) {
        current_curve = curve;
        return;
    }
private:
    MathExpression current_curve;
};

This code does not work because MathExpression is an abstract class (it contains a pure virtual function), so I can neither create instances of it nor return it by value from functions. I am also getting errors from the integrate() functions of the derived classes: the return type (Polynomial/Exponential) is neither identical to nor covariant with the return type MathExpression of the overridden integrate() function.

I need to find a solution to these two problems:

  • How to have a generic member in Graph which could be any derived class of MathExpression and which can be reassigned at will, even to other derived classes (going from a polynomial to an exponential, for example).
  • How to return that derived class type from member functions inside Graph or the derived classes (such as getCurrentCurve() or integrate()).

How can I solve this/work around these issues?


Solution

  • As @PaulMcKenzie suggested, just use polymorphism through pointers (for example, smart pointers). That is, you manage pointers to the base class, MathExpression, and create heap instances of the derived classes, Polynomial and Exponential (with new, std::make_unique, or std::make_shared). Since you seem to be returning instances of your derived classes and sharing them, you may want to use std::shared_ptr.

    A possible implementation to start with:

    #include <array>
    #include <memory>  // make_shared, shared_ptr
    
    class MathExpression
    {
    public:
        virtual ~MathExpression() = default;
        virtual std::shared_ptr<MathExpression> integrate() const = 0;
    };
    
    class Polynomial : public MathExpression
    {
    public:
        Polynomial(std::array<double, 5> coefficients) : coefficients(coefficients) {};
        Polynomial() : coefficients(std::array<double, 5>{0.0, 0.0, 0.0, 0.0, 0.0}) {};
        Polynomial(const Polynomial& p) : coefficients(p.coefficients) {};
        ~Polynomial() = default;
    
        // Placeholder body: returns a default (all-zero) polynomial.
        std::shared_ptr<MathExpression> integrate() const override { return std::make_shared<Polynomial>(); }
        std::array<double, 5> getCoefficients() const;
    private:
        std::array<double, 5> coefficients;
    };
    
    class Exponential : public MathExpression
    {
    public:
        Exponential(std::array<double, 3> coefficients) : coefficients(coefficients) {};
        Exponential() : coefficients(std::array<double, 3>{0.0, 0.0, 0.0}) {};
        Exponential(const Exponential& e) : coefficients(e.coefficients) {};
        ~Exponential() = default;
    
        // Placeholder body: returns a default (all-zero) exponential.
        std::shared_ptr<MathExpression> integrate() const override { return std::make_shared<Exponential>(); }
        std::array<double, 3> getCoefficients() const;
    private:
        const std::array<double, 3> coefficients;
    };
    
    class Graph
    {
    public:
        Graph() = default;
        std::shared_ptr<MathExpression> getCurrentCurve() { return current_curve; }
        void setCurrentCurve(const std::shared_ptr<MathExpression>& curve) {
            current_curve = curve;
            return;
        }
    private:
        std::shared_ptr<MathExpression> current_curve{};
    };
    
    int main() {
        std::shared_ptr<MathExpression> p{std::make_shared<Polynomial>()};
        std::shared_ptr<MathExpression> e{std::make_shared<Exponential>()};
    }
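
To see the approach do some actual work, here is a minimal sketch that fills in the function bodies. Several things in it are assumptions and not part of the original code: the evaluate() member, the integralAt() helper, the reading of Polynomial as c[0] + c[1]*x + ... + c[4]*x^4, and a simplified two-parameter Exponential meaning a * exp(b * x). Note that covariant return types only apply to raw pointers and references, so with smart pointers every override returns std::shared_ptr<MathExpression>.

```cpp
#include <array>
#include <cmath>
#include <cstddef>
#include <memory>
#include <utility>

class MathExpression {
public:
    virtual ~MathExpression() = default;
    // Hypothetical addition (not in the original): evaluate the curve at x.
    virtual double evaluate(double x) const = 0;
    // Returns a base-class pointer, so the result may be any derived type.
    virtual std::shared_ptr<MathExpression> integrate() const = 0;
};

// Assumed meaning: c[0] + c[1]*x + ... + c[4]*x^4.
class Polynomial : public MathExpression {
public:
    explicit Polynomial(std::array<double, 5> c = {}) : coefficients(c) {}

    double evaluate(double x) const override {
        double result = 0.0;
        for (std::size_t i = coefficients.size(); i-- > 0;) {
            result = result * x + coefficients[i];  // Horner's scheme
        }
        return result;
    }

    // Term-by-term antiderivative with integration constant 0.
    // Simplification: the x^4 term is dropped, because x^5 does not fit.
    std::shared_ptr<MathExpression> integrate() const override {
        std::array<double, 5> out{};
        for (std::size_t i = 0; i + 1 < coefficients.size(); ++i) {
            out[i + 1] = coefficients[i] / static_cast<double>(i + 1);
        }
        return std::make_shared<Polynomial>(out);
    }

    const std::array<double, 5>& getCoefficients() const { return coefficients; }

private:
    std::array<double, 5> coefficients;
};

// Assumed meaning: a * exp(b * x).
class Exponential : public MathExpression {
public:
    Exponential(double a, double b) : a(a), b(b) {}

    double evaluate(double x) const override { return a * std::exp(b * x); }

    std::shared_ptr<MathExpression> integrate() const override {
        if (b == 0.0) {
            // a * exp(0 * x) is the constant a; its antiderivative is a * x.
            return std::make_shared<Polynomial>(
                std::array<double, 5>{0.0, a, 0.0, 0.0, 0.0});
        }
        return std::make_shared<Exponential>(a / b, b);
    }

private:
    double a;
    double b;
};

class Graph {
public:
    std::shared_ptr<MathExpression> getCurrentCurve() const { return current_curve; }
    void setCurrentCurve(std::shared_ptr<MathExpression> curve) {
        current_curve = std::move(curve);
    }

private:
    std::shared_ptr<MathExpression> current_curve;
};

// Hypothetical helper: integrate whatever curve the graph currently holds
// and evaluate the result at x. The graph must hold a curve.
double integralAt(const Graph& graph, double x) {
    return graph.getCurrentCurve()->integrate()->evaluate(x);
}
```

From main you would call g.setCurrentCurve(std::make_shared<Polynomial>(...)) and later g.setCurrentCurve(std::make_shared<Exponential>(...)) on the same Graph, which is the switching the question asks for. The b == 0 branch also shows why returning a base-class pointer is useful: the integral of one kind of expression can legitimately be another kind. When you do need the concrete type back (for example, to call getCoefficients()), std::dynamic_pointer_cast<Polynomial>(ptr) yields an empty pointer if the object is not a Polynomial.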