c++ polymorphism subtype

How to make the return type of a virtual member function parametric


I am implementing a visitor pattern for a syntax tree in C++. Every node in the tree is derived from a base class Expr, which declares a pure virtual method accept. The accept method takes a reference to an instance of Visitor. Visitor is also an abstract class, declaring a visit method for every type of node in the syntax tree. So far so good: classic visitor pattern. However, I need to make accept's return type parametric, and I haven't figured out how to do it properly yet. Because accept is a virtual method, it cannot also be a member function template, so the following does not compile:

class Expr
{
public:
    template <typename T>
    virtual T accept(Visitor<T> &visitor) = 0;
};   

I can use templates to declare the Expr class itself, though:

template <typename T>
class Binary;

template <typename T>
class Grouping;

template <typename T>
class Literal;

template <typename T>
class Unary;


template <typename T>
class Visitor
{
public:
    virtual ~Visitor() = default;
    virtual T visitBinaryExpr(Binary<T> &expr) = 0;
    virtual T visitGroupingExpr(Grouping<T> &expr) = 0;
    virtual T visitLiteralExpr(Literal<T> &expr) = 0;
    virtual T visitUnaryExpr(Unary<T> &expr) = 0;
};

template <typename T>
class Expr
{
public:
    virtual ~Expr() = default;
    virtual T accept(Visitor<T> &visitor) = 0;
};

template <typename T>
class Binary : public Expr<T>
{
    std::shared_ptr<Expr<T>> left_;
    std::shared_ptr<Token> op_;
    std::shared_ptr<Expr<T>> right_;

public:
    Binary(std::shared_ptr<Expr<T>> left, std::shared_ptr<Token> op, std::shared_ptr<Expr<T>> right)
    {
        left_ = left;
        op_ = op;
        right_ = right;
    }

    T accept(Visitor<T> &visitor) override
    {
        return visitor.visitBinaryExpr(*this);
    }

    std::shared_ptr<Expr<T>> getLeft()
    {
        return left_;
    }

    std::shared_ptr<Token> getOp()
    {
        return op_;
    }

    std::shared_ptr<Expr<T>> getRight()
    {
        return right_;
    }
};

Because abstract classes cannot be instantiated, the above seems to work. This solution, however, does not feel right (please keep in mind I am a C++ novice). For example, if I have a node of type Binary and I want to visit it with two functions that have different return types, I am forced to keep a separate object instance for every return-type specialization and copy the original object back and forth among them to get the appropriate behavior for accept.

What would be a better way of implementing the accept's intended behavior?


Solution

  • For completeness, another approach I found to solve this problem is described here: https://www.codeproject.com/Tips/1018315/Visitor-with-the-Return-Value. It follows the classic visitor pattern (i.e., the visit and accept methods return void). Different return types are handled by a separate class template, plus some template magic, as follows:

    template <typename VisitorImpl, typename ResultType>
    class ValueGetter
    {
        ResultType value_;
    
    public:
        static ResultType getValue(Expr &e)
        {
            VisitorImpl visitor;
            e.accept(visitor);
            return visitor.value_;
        }
    
        void returnValue(ResultType value)
        {
            value_ = value;
        }
    };
    
    class AstPrint : public Visitor, public ValueGetter<AstPrint, std::string>
    {
    public:
        void visitBinary(Binary &expr) override
        {
            std::string str;
            // getLeft()/getRight() return shared_ptr, getValue takes Expr&, so dereference
            str += getValue(*expr.getLeft());
            str += expr.getOp()->getLexeme();
            str += getValue(*expr.getRight());
            returnValue(str);
        }
     
        /* ... */
    };
    

    The AstPrint class derives from a base class template, ValueGetter, and passes itself as a template argument. This lets getValue instantiate a fresh visitor of the derived type for every node it visits. The pattern is called the Curiously Recurring Template Pattern, or CRTP (more information on the pattern can be found here: https://www.fluentcpp.com/2017/05/12/curiously-recurring-template-pattern).

    visitBinary prepares the return data by storing it in the attribute value_. Finally, data is returned by getValue when visitBinary is done visiting the node. A use case for this follows:

    Expr *expression = new Binary(
            std::make_shared<Unary>(
                std::make_shared<Token>(TokenType::eMINUS, "-", nullptr, 1),
                std::make_shared<Literal>(
                    std::make_shared<NumericLiteral>("123")
                    )
                ),
            std::make_shared<Token>(TokenType::eSTAR, "*", nullptr, 1),
            std::make_shared<Grouping>(
                std::make_shared<Literal>(
                    std::make_shared<NumericLiteral>("45.67")
                )
            )
        );
    
    /* print -123 * (45.67) */
    std::cout << AstPrint::getValue(*expression) << std::endl;
    

    I think this is what @Igor was suggesting.
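
For anyone who wants to compile the idea in isolation, here is a minimal self-contained sketch of the same technique. It is not the code from the linked article: the node types are deliberately simplified (Literal holds an int, Binary stores its operator as a char instead of a Token), and the names Printer, Evaluator and sampleTree are hypothetical, introduced only for illustration. It shows the point of the question: one tree, visited by two visitors with different result types, with no templates on Expr.

```cpp
#include <memory>
#include <string>
#include <utility>

struct Literal;
struct Binary;

// Classic visitor interface: every visit method returns void.
struct Visitor {
    virtual ~Visitor() = default;
    virtual void visitLiteral(Literal &expr) = 0;
    virtual void visitBinary(Binary &expr) = 0;
};

struct Expr {
    virtual ~Expr() = default;
    virtual void accept(Visitor &visitor) = 0;
};

// Simplified nodes (assumption: int literals and a char operator).
struct Literal : Expr {
    int value;
    explicit Literal(int v) : value(v) {}
    void accept(Visitor &visitor) override { visitor.visitLiteral(*this); }
};

struct Binary : Expr {
    std::shared_ptr<Expr> left;
    char op;
    std::shared_ptr<Expr> right;
    Binary(std::shared_ptr<Expr> l, char o, std::shared_ptr<Expr> r)
        : left(std::move(l)), op(o), right(std::move(r)) {}
    void accept(Visitor &visitor) override { visitor.visitBinary(*this); }
};

// CRTP helper: stores the result of one visit and hands it back by value.
// VisitorImpl must be default-constructible and derive from Visitor.
template <typename VisitorImpl, typename ResultType>
class ValueGetter {
    ResultType value_{};

public:
    static ResultType getValue(Expr &e) {
        VisitorImpl visitor;   // a fresh visitor for every node
        e.accept(visitor);     // the visit method calls returnValue()
        return visitor.value_;
    }

    void returnValue(ResultType value) { value_ = std::move(value); }
};

// Hypothetical visitor #1: renders the tree as a string.
struct Printer : Visitor, ValueGetter<Printer, std::string> {
    void visitLiteral(Literal &expr) override {
        returnValue(std::to_string(expr.value));
    }
    void visitBinary(Binary &expr) override {
        returnValue("(" + getValue(*expr.left) + " " + expr.op + " " +
                    getValue(*expr.right) + ")");
    }
};

// Hypothetical visitor #2: evaluates the tree (only '+' and '*' are handled).
struct Evaluator : Visitor, ValueGetter<Evaluator, int> {
    void visitLiteral(Literal &expr) override { returnValue(expr.value); }
    void visitBinary(Binary &expr) override {
        int l = getValue(*expr.left);
        int r = getValue(*expr.right);
        returnValue(expr.op == '*' ? l * r : l + r);
    }
};

// Builds the tree for 2 * (3 + 4).
std::shared_ptr<Expr> sampleTree() {
    return std::make_shared<Binary>(
        std::make_shared<Literal>(2), '*',
        std::make_shared<Binary>(std::make_shared<Literal>(3), '+',
                                 std::make_shared<Literal>(4)));
}
```

With this sketch, Printer::getValue(*sampleTree()) returns the string "(2 * (3 + 4))" and Evaluator::getValue(*sampleTree()) returns 14, both from the same tree object. The trade-off is that getValue default-constructs a new visitor for every node, so a visitor that needs constructor arguments or shared state would need a different arrangement.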