I am implementing a visitor pattern for a syntax tree in C++. Every node in the tree is derived from a base class `Expr`, which declares a pure virtual method `accept`. The `accept` method takes a reference to an instance of `Visitor`. `Visitor` is also an abstract class, declaring a `visit` method for every type of node in the syntax tree. So far so good: the classic visitor pattern. However, I need to make the return type of `accept` parametric, and I haven't figured out how to do it properly yet. Because `accept` is a virtual method, I cannot use templates with it like this:
```cpp
class Expr
{
public:
    // Ill-formed: a member function template cannot be virtual.
    template <typename T>
    virtual T accept(Visitor<T> &visitor) = 0;
};
```
I can, however, use templates to declare the `Expr` class itself:
```cpp
#include <memory>
#include <utility>

class Token; // defined elsewhere (the scanner's token type)

template <typename T>
class Binary;
template <typename T>
class Grouping;
template <typename T>
class Literal;
template <typename T>
class Unary;

template <typename T>
class Visitor
{
public:
    // Virtual destructor, since concrete visitors derive from this class.
    virtual ~Visitor() = default;
    virtual T visitBinaryExpr(Binary<T> &expr) = 0;
    virtual T visitGroupingExpr(Grouping<T> &expr) = 0;
    virtual T visitLiteralExpr(Literal<T> &expr) = 0;
    virtual T visitUnaryExpr(Unary<T> &expr) = 0;
};

template <typename T>
class Expr
{
public:
    // Nodes are owned through std::shared_ptr<Expr<T>>, so the destructor must be virtual.
    virtual ~Expr() = default;
    virtual T accept(Visitor<T> &visitor) = 0;
};

template <typename T>
class Binary : public Expr<T>
{
    std::shared_ptr<Expr<T>> left_;
    std::shared_ptr<Token> op_;
    std::shared_ptr<Expr<T>> right_;

public:
    Binary(std::shared_ptr<Expr<T>> left, std::shared_ptr<Token> op, std::shared_ptr<Expr<T>> right)
        : left_(std::move(left)), op_(std::move(op)), right_(std::move(right))
    {
    }

    // Double dispatch: forward to the visitor method for this node type.
    T accept(Visitor<T> &visitor) override
    {
        return visitor.visitBinaryExpr(*this);
    }

    std::shared_ptr<Expr<T>> getLeft() const { return left_; }
    std::shared_ptr<Token> getOp() const { return op_; }
    std::shared_ptr<Expr<T>> getRight() const { return right_; }
};
```
The above seems to work, since the abstract `Expr<T>` itself is never instantiated. This solution, however, does not feel right (please keep in mind that I am a C++ novice). For example, if I have a node of type `Binary` and I want to visit it with two functions that have different return types, I am forced to keep one object instance per return-type specialization and copy the original object back and forth among them to get the appropriate behavior from `accept`.
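To make this limitation concrete, here is a stripped-down, self-contained sketch of the templated `Expr<T>` design. It is an illustration under assumptions, not the original code: the `Literal<T>` and `Binary<T>` nodes are hypothetical simplifications that store an `int` and a `char` operator instead of a `Token`, and `Evaluator`, `Printer` and `makeTree` are invented for the example. Because `Expr<int>` and `Expr<std::string>` are unrelated types, the same expression has to be built once per return type.

```cpp
#include <memory>
#include <string>
#include <utility>

template <typename T> class Literal;
template <typename T> class Binary;

template <typename T>
class Visitor {
public:
    virtual ~Visitor() = default;
    virtual T visitLiteralExpr(Literal<T> &expr) = 0;
    virtual T visitBinaryExpr(Binary<T> &expr) = 0;
};

template <typename T>
class Expr {
public:
    virtual ~Expr() = default;
    virtual T accept(Visitor<T> &visitor) = 0;
};

// Hypothetical simplified node: stores an int instead of a Token.
template <typename T>
class Literal : public Expr<T> {
    int value_;
public:
    explicit Literal(int value) : value_(value) {}
    int getValue() const { return value_; }
    T accept(Visitor<T> &visitor) override { return visitor.visitLiteralExpr(*this); }
};

// Hypothetical simplified node: the operator is a plain char.
template <typename T>
class Binary : public Expr<T> {
    std::shared_ptr<Expr<T>> left_;
    char op_;
    std::shared_ptr<Expr<T>> right_;
public:
    Binary(std::shared_ptr<Expr<T>> left, char op, std::shared_ptr<Expr<T>> right)
        : left_(std::move(left)), op_(op), right_(std::move(right)) {}
    Expr<T> &getLeft() { return *left_; }
    char getOp() const { return op_; }
    Expr<T> &getRight() { return *right_; }
    T accept(Visitor<T> &visitor) override { return visitor.visitBinaryExpr(*this); }
};

// Can only visit Expr<int> trees.
class Evaluator : public Visitor<int> {
public:
    int visitLiteralExpr(Literal<int> &expr) override { return expr.getValue(); }
    int visitBinaryExpr(Binary<int> &expr) override {
        int left = expr.getLeft().accept(*this);
        int right = expr.getRight().accept(*this);
        return expr.getOp() == '*' ? left * right : left + right; // only * and + in this sketch
    }
};

// Can only visit Expr<std::string> trees.
class Printer : public Visitor<std::string> {
public:
    std::string visitLiteralExpr(Literal<std::string> &expr) override {
        return std::to_string(expr.getValue());
    }
    std::string visitBinaryExpr(Binary<std::string> &expr) override {
        return "(" + expr.getLeft().accept(*this) + " " + expr.getOp() + " " +
               expr.getRight().accept(*this) + ")";
    }
};

// The same expression, 2 * 21, must be built separately for every return type T.
template <typename T>
std::shared_ptr<Expr<T>> makeTree() {
    return std::make_shared<Binary<T>>(std::make_shared<Literal<T>>(2), '*',
                                       std::make_shared<Literal<T>>(21));
}
```

With an `Evaluator eval` and a `Printer print`, `makeTree<int>()->accept(eval)` yields 42 and `makeTree<std::string>()->accept(print)` yields `(2 * 21)`. Passing `print` to the `Expr<int>` tree does not compile, because `Expr<int>::accept` only takes a `Visitor<int>&`.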
What would be a better way of implementing the intended behavior of `accept`?
For completeness, another approach I found to solve this problem is described here: https://www.codeproject.com/Tips/1018315/Visitor-with-the-Return-Value. It follows the classic visitor pattern (i.e., the `visit` and `accept` methods return `void`). Different return types are handled by a separate class template and some template magic, as follows:
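```cpp
#include <string>

// The derived visitor (VisitorImpl) passes itself as a template argument.
template <typename VisitorImpl, typename ResultType>
class ValueGetter
{
    ResultType value_;

public:
    // Visits e with a fresh VisitorImpl and returns the value it stored.
    static ResultType getValue(Expr &e)
    {
        VisitorImpl visitor;
        e.accept(visitor);
        return visitor.value_;
    }

    // Called from a visit method to "return" its result.
    void returnValue(ResultType value)
    {
        value_ = value;
    }
};

class AstPrint : public Visitor, public ValueGetter<AstPrint, std::string>
{
public:
    void visitBinary(Binary &expr) override
    {
        std::string str;
        // getLeft()/getRight() return std::shared_ptr<Expr>, so dereference them.
        str += getValue(*expr.getLeft());
        str += expr.getOp()->getLexeme();
        str += getValue(*expr.getRight());
        returnValue(str);
    }
    /* ... */
};
```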
The `AstPrint` class derives from the class template `ValueGetter`, passing itself as the template argument; this allows `getValue` to instantiate a fresh visitor for every node it visits. This pattern is called the Curiously Recurring Template Pattern (CRTP); more information on it can be found here: https://www.fluentcpp.com/2017/05/12/curiously-recurring-template-pattern.
`visitBinary` prepares the return data by storing it in the attribute `value_` through `returnValue`. Once `visitBinary` has finished visiting the node, `getValue` returns that data. A use case follows:
```cpp
// Owned by a shared_ptr rather than a raw new, so the tree is not leaked.
auto expression = std::make_shared<Binary>(
    std::make_shared<Unary>(
        std::make_shared<Token>(TokenType::eMINUS, "-", nullptr, 1),
        std::make_shared<Literal>(
            std::make_shared<NumericLiteral>("123"))),
    std::make_shared<Token>(TokenType::eSTAR, "*", nullptr, 1),
    std::make_shared<Grouping>(
        std::make_shared<Literal>(
            std::make_shared<NumericLiteral>("45.67"))));

// prints the expression -123 * (45.67)
std::cout << AstPrint::getValue(*expression) << std::endl;
```
I think this is what @Igor was suggesting.