c++ template-meta-programming std-variant

How to define class with overloaded methods for each std::variant alternative?


I have several std::variant types, each with several alternatives. I would like to define a visitor class template that takes a variant type as its template parameter and automatically declares a pure virtual void operator()(T const&) const for each alternative T in that variant. Subclasses that inherit from an instantiation of this template would then be forced to override every one of these operators, since each is pure virtual in the base class.

e.g.

#include <variant>

using VarA = std::variant<A1, A2, /* ... more alternatives ... */>;
using VarB = std::variant<B1, B2, /* ... more alternatives ... */>;

struct VarAVisitor : Visitor<VarA>
{
    // Must override 'void operator()(T const&) const' for each alternative type 'T' in VarA
};

struct VarBVisitor : Visitor<VarB>
{
    // Must override 'void operator()(T const&) const' for each alternative type 'T' in VarB
};

In short, how would I implement the Visitor class template in the above example?


Solution

  • After some googling and lots of trial and error, I managed to come up with something that does what I want. I'm sharing the solution here for anyone else who comes across the same issue.

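    Here is a proof of concept. Note that the call operators here return int rather than void, which makes the selected overload visible in the program's output.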

    #include <iostream>
    #include <variant>
    
    
    template <typename> class Test { };
    
    using Foo = std::variant<
        Test<struct A>,
        Test<struct B>,
        Test<struct C>,
        Test<struct D>
        >;
    
    using Bar = std::variant<
        Test<struct E>,
        Test<struct F>,
        Test<struct G>,
        Test<struct H>,
        Test<struct I>,
        Test<struct J>,
        Test<struct K>,
        Test<struct L>
        >;
    
    
    template <typename T>
    struct DefineVirtualFunctor
    {
        virtual int operator()(T const&) const = 0;
    };
    
    template <template <typename> typename Modifier, typename... Rest>
    struct ForEach { };
    template <template <typename> typename Modifier, typename T, typename... Rest>
    struct ForEach<Modifier, T, Rest...> : Modifier<T>, ForEach<Modifier, Rest...> { };
    
    template <typename Variant>
    struct Visitor;
    template <typename... Alts>
    struct Visitor<std::variant<Alts...>> : ForEach<DefineVirtualFunctor, Alts...> { };
    
    
    struct FooVisitor final : Visitor<Foo>
    {
        int operator()(Test<A> const&) const override { return  0; }
        int operator()(Test<B> const&) const override { return  1; }
        int operator()(Test<C> const&) const override { return  2; }
        int operator()(Test<D> const&) const override { return  3; }
    };
    
    struct BarVisitor final : Visitor<Bar>
    {
        int operator()(Test<E> const&) const override { return  4; }
        int operator()(Test<F> const&) const override { return  5; }
        int operator()(Test<G> const&) const override { return  6; }
        int operator()(Test<H> const&) const override { return  7; }
        int operator()(Test<I> const&) const override { return  8; }
        int operator()(Test<J> const&) const override { return  9; }
        int operator()(Test<K> const&) const override { return 10; }
        int operator()(Test<L> const&) const override { return 11; }
    };
    
    
    int main(int argc, char const* argv[])
    {
        Foo foo;
        Bar bar;
        
        switch (argc) {
        case  0: foo = Foo{ std::in_place_index<0> }; break;
        case  1: foo = Foo{ std::in_place_index<1> }; break;
        case  2: foo = Foo{ std::in_place_index<2> }; break;
        default: foo = Foo{ std::in_place_index<3> }; break;
        }
        switch (argc) {
        case  0: bar = Bar{ std::in_place_index<0> }; break;
        case  1: bar = Bar{ std::in_place_index<1> }; break;
        case  2: bar = Bar{ std::in_place_index<2> }; break;
        case  3: bar = Bar{ std::in_place_index<3> }; break;
        case  4: bar = Bar{ std::in_place_index<4> }; break;
        case  5: bar = Bar{ std::in_place_index<5> }; break;
        case  6: bar = Bar{ std::in_place_index<6> }; break;
        default: bar = Bar{ std::in_place_index<7> }; break;
        }
        
        std::cout << std::visit(FooVisitor{ }, foo) << "\n";
        std::cout << std::visit(BarVisitor{ }, bar) << "\n";
    
        return 0;
    }
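
The recursive ForEach helper above is one way to build the list of base classes. As a possible simplification, the same interface can probably be produced with a pack expansion directly in the base-specifier list. The sketch below assumes C++17 and uses the void return type from the question; A1, A2, NameVisitor and name_of are hypothetical names introduced only for illustration, and the alternatives are assumed to be distinct types.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical alternative types, used only for illustration.
struct A1 { };
struct A2 { };
using VarA = std::variant<A1, A2>;

// Declares one pure virtual call operator for the alternative type T.
template <typename T>
struct DefineVirtualFunctor
{
    virtual void operator()(T const&) const = 0;
};

template <typename Variant>
struct Visitor;

// The pack expansion names one DefineVirtualFunctor<T> base per alternative.
// Assumption: the alternatives are distinct types. A repeated alternative
// would name the same direct base twice, which is ill-formed.
template <typename... Alts>
struct Visitor<std::variant<Alts...>> : DefineVirtualFunctor<Alts>... { };

// Hypothetical visitor that records which overload std::visit dispatched to.
struct NameVisitor final : Visitor<VarA>
{
    explicit NameVisitor(std::string& out) : out_(out) { }

    void operator()(A1 const&) const override { out_ = "A1"; }
    void operator()(A2 const&) const override { out_ = "A2"; }

private:
    std::string& out_;  // written through by the const call operators
};

// Hypothetical helper: returns the name of the alternative held by v.
inline std::string name_of(VarA const& v)
{
    assert(!v.valueless_by_exception());
    std::string out;
    std::visit(NameVisitor{ out }, v);
    return out;
}
```

Calling name_of(VarA{ A2{} }) should dispatch to the A2 overload. Because the call operators return void, the result is passed back through a reference member instead of a return value.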
    

    As you can see, the Visitor class template accepts a std::variant type as its template parameter and derives from it an interface that every concrete subclass must implement. If a subclass fails to override one of the pure virtual call operators, it remains abstract, and attempting to instantiate it produces an error like the following.

    $ g++ -std=c++17 -o example example.cc
    example.cc: In function ‘int main(int, const char**)’:
    example.cc:87:41: error: invalid cast to abstract class type ‘BarVisitor’
       87 |     std::cout << std::visit(BarVisitor{ }, bar) << "\n";
          |                                         ^
    example.cc:51:8: note:   because the following virtual functions are pure within ‘BarVisitor’:
       51 | struct BarVisitor final : Visitor<Bar>
          |        ^~~~~~~~~~
    example.cc:29:17: note:     ‘int DefineVirtualFunctor<T>::operator()(const T&) const [with T = Test<J>]’
       29 |     virtual int operator()(T const&) const = 0;
          |                 ^~~~~~~~
    

    This is much easier to understand than the error messages that the compiler usually generates when using std::visit().
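
Since a visitor with a missing override remains an abstract class, the check can also be stated next to the class definition instead of at the std::visit call site. The following is a minimal sketch of that idea using std::is_abstract from <type_traits>; B1, B2, CompleteVisitor, IncompleteVisitor and overrides_all_v are hypothetical names, not part of the original solution.

```cpp
#include <type_traits>
#include <variant>

// Hypothetical alternative types, used only for illustration.
struct B1 { };
struct B2 { };
using VarB = std::variant<B1, B2>;

// Same building blocks as in the proof of concept.
template <typename T>
struct DefineVirtualFunctor
{
    virtual int operator()(T const&) const = 0;
};

template <template <typename> typename Modifier, typename... Rest>
struct ForEach { };
template <template <typename> typename Modifier, typename T, typename... Rest>
struct ForEach<Modifier, T, Rest...> : Modifier<T>, ForEach<Modifier, Rest...> { };

template <typename Variant>
struct Visitor;
template <typename... Alts>
struct Visitor<std::variant<Alts...>> : ForEach<DefineVirtualFunctor, Alts...> { };

// Overrides the call operator for every alternative of VarB.
struct CompleteVisitor : Visitor<VarB>
{
    int operator()(B1 const&) const override { return 1; }
    int operator()(B2 const&) const override { return 2; }
};

// Incomplete on purpose: the B2 overload is missing, so the class stays abstract.
struct IncompleteVisitor : Visitor<VarB>
{
    int operator()(B1 const&) const override { return 1; }
};

// Hypothetical helper: true when V has no remaining pure virtual functions.
template <typename V>
inline constexpr bool overrides_all_v = !std::is_abstract_v<V>;

static_assert(overrides_all_v<CompleteVisitor>,
              "CompleteVisitor must override operator() for every alternative");
static_assert(!overrides_all_v<IncompleteVisitor>,
              "IncompleteVisitor is expected to remain abstract");
```

A static_assert like the first one, placed right after a visitor's definition, would fail at that point if an override were missing. It does not name the missing overload, so the compiler diagnostic shown above remains the more informative of the two.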