
Composing boost::variant visitors for recursive variants


I have an application with several boost::variants which share many of the fields. I would like to be able to compose the visitors for these variants into visitors for "larger" variants without copying and pasting a bunch of code. It seems straightforward to do this for non-recursive variants, but once you have a recursive one, the self-references within the visitor (of course) point to the wrong class. To make this concrete (and cribbing from the boost::variant docs):

#include "boost/variant.hpp"
#include <cassert>
#include <iostream>

struct add;
struct sub;
template <typename OpTag> struct binop;

typedef boost::variant<
  int
  , boost::recursive_wrapper< binop<add> >
  , boost::recursive_wrapper< binop<sub> >
  > expression;

template <typename OpTag>
struct binop
{
  expression left;
  expression right;

  binop( const expression & lhs, const expression & rhs )
    : left(lhs), right(rhs)
  {
  }

};

// Add multiplication
struct mult;
typedef boost::variant<
  int
  , boost::recursive_wrapper< binop<add> >
  , boost::recursive_wrapper< binop<sub> >
  , boost::recursive_wrapper< binop<mult> >
  > mult_expression;

class calculator : public boost::static_visitor<int>
{
public:

  int operator()(int value) const
  {
    return value;
  }

  int operator()(const binop<add> & binary) const
  {
    return boost::apply_visitor( *this, binary.left )
      + boost::apply_visitor( *this, binary.right );
  }

  int operator()(const binop<sub> & binary) const
  {
    return boost::apply_visitor( *this, binary.left )
      - boost::apply_visitor( *this, binary.right );
  }

};

class mult_calculator : public boost::static_visitor<int>
{
public:

  int operator()(int value) const
  {
    return value;
  }

  int operator()(const binop<add> & binary) const
  {
    return boost::apply_visitor( *this, binary.left )
      + boost::apply_visitor( *this, binary.right );
  }

  int operator()(const binop<sub> & binary) const
  {
    return boost::apply_visitor( *this, binary.left )
      - boost::apply_visitor( *this, binary.right );
  }

  int operator()(const binop<mult> & binary) const
  {
    return boost::apply_visitor( *this, binary.left )
      * boost::apply_visitor( *this, binary.right );
  }

};

// I'd like something like this to compile
// class better_mult_calculator : public calculator
// {
// public:

//   int operator()(const binop<mult> & binary) const
//   {
//     return boost::apply_visitor( *this, binary.left )
//       * boost::apply_visitor( *this, binary.right );
//   }

// };


int main(int argc, char **argv)
{
  // result = ((7-3)+8) = 12
  expression result(binop<add>(binop<sub>(7,3), 8));

  assert( boost::apply_visitor(calculator(),result) == 12 );

  std::cout << "Success add" << std::endl;

  // result2 = ((7-3)+8)*2 = 24
  mult_expression result2(binop<mult>(binop<add>(binop<sub>(7,3), 8),2));
  assert( boost::apply_visitor(mult_calculator(),result2) == 24 );

  std::cout << "Success mult" << std::endl;
}

I would really like something like the commented-out better_mult_calculator to compile (and work), but it doesn't, because *this inside the base calculator visitor is a calculator, which only knows how to visit expression, not mult_expression.

Does anyone have suggestions for overcoming this, or am I just barking up the wrong tree?


Solution

  • First, I'd suggest having the variant include all possible node types, rather than distinguishing between mult_expression and expression. This distinction makes no sense at the AST level, only at the parser stage (if you implement operator precedence in recursive-descent/PEG fashion).

    Other than that, here are a few observations:

    • if you encapsulate the apply_visitor dispatch into your evaluation functor you can reduce the code duplication by a big factor

    • your real question seems not to be about composing variants, but composing visitors, more specifically, by inheritance.

      You can use a using-declaration to pull the inherited overloads into scope for overload resolution (otherwise the derived operator() hides them), so this might be the most direct answer:


       struct better_mult_calculator : calculator {
           using calculator::operator();
      
           auto operator()(const binop<mult>& binary) const
           {
               return boost::apply_visitor(*this, binary.left) *
                   boost::apply_visitor(*this, binary.right);
           }
       };
      

    IMPROVING!

    Starting from that listing, let's shave off some noise!

    1. remove the unnecessary AST distinction (-40 lines, down to 55 lines of code)

    2. generalize the operations; the standard <functional> header already provides std::plus, std::minus and std::multiplies:

      namespace AST {
          using boost::recursive_wrapper;
          template <typename> struct binop;
          using add  = binop<std::plus<>>;
          using sub  = binop<std::minus<>>;
          using mult = binop<std::multiplies<>>;
          using expr = boost::variant<int,
              recursive_wrapper<add>,
              recursive_wrapper<sub>,
              recursive_wrapper<mult>>;
      
          template <typename> struct binop { expr left, right; };
      } // namespace AST
      

      Now the entire calculator can be:

      struct calculator : boost::static_visitor<int> {
          int operator()(int value) const { return value; }
      
          template <typename Op> 
          int operator()(AST::binop<Op> const& binary) const {
              return Op{}(boost::apply_visitor(*this, binary.left),
                          boost::apply_visitor(*this, binary.right));
          }
      };
      

      Now you can add arbitrary binary operations to the variant without touching the calculator at all.

      That brings the full demo down to 43 lines of code.

    3. As I mentioned at the start, encapsulate visitation!

      struct Calculator {
          template <typename... T> int operator()(boost::variant<T...> const& v) const {
              return boost::apply_visitor(*this, v);
          }
      
          template <typename T> 
          int operator()(T const& lit) const { return lit; }
      
          template <typename Op> 
          int operator()(AST::binop<Op> const& bin) const {
              return Op{}(operator()(bin.left), operator()(bin.right));
          }
      };
      

      Now you can just call your calculator, as intended:

      Calculator calc;
      auto result1 = calc(e1);
      

      It will work when you extend the variant with operations or even other literal types (e.g. double), although the literal overload still returns int, so a double would be converted (truncated) to int. It even works if you pass it a different variant type that holds only a subset of the node types.

    4. To finish off, for maintainability and readability I'd suggest making operator() a pure dispatch function that forwards to private call overloads:

    Full Demo


    #include <boost/variant.hpp>
    #include <functional>
    #include <iostream>
    
    namespace AST {
        using boost::recursive_wrapper;
    
        template <typename> struct binop;
        using add  = binop<std::plus<>>;
        using sub  = binop<std::minus<>>;
        using mult = binop<std::multiplies<>>;
        using expr = boost::variant<int,
            recursive_wrapper<add>,
            recursive_wrapper<sub>,
            recursive_wrapper<mult>>;
    
        template <typename> struct binop { expr left, right; };
    } // namespace AST
    
    struct Calculator {
        // C++20 abbreviated function template: forwards every argument to call()
        auto operator()(auto const& v) const { return call(v); }
    
      private:
        template <typename... T> int call(boost::variant<T...> const& v) const {
            return boost::apply_visitor(*this, v);
        }
    
        template <typename T> 
        int call(T const& lit) const { return lit; }
    
        template <typename Op> 
        int call(AST::binop<Op> const& bin) const {
            return Op{}(call(bin.left), call(bin.right));
        }
    };
    
    int main()
    {
        using namespace AST;
        std::cout << std::boolalpha;
        auto sub_expr = add{sub{7, 3}, 8};
        expr e1       = sub_expr;
        expr e2       = mult{sub_expr, 2};
    
        Calculator calc;
    
        auto result1 = calc(e1);
        std::cout << "result1:  " << result1 << " Success? " << (12 == result1) << "\n";
    
        // result2 = ((7-3)+8)*2 = 24
        auto result2 = calc(e2);
        std::cout << "result2:  " << result2 << " Success? " << (24 == result2) << "\n";
    }
    

    Still prints

    result1:  12 Success? true
    result2:  24 Success? true