Tags: c++, templates, c++17, metaprogramming, variadic-templates

Traversing trees at compile time with C++17 Variadic Templates


I'm currently looking into using C++17 variadic templates to generate efficient, real-time circuit simulations.

My goal is to use variadic templates to define a tree that can be traversed at compile time. To define such a tree, I use the following three structs:

#include <tuple>

template <auto Tag> struct Leaf
{
    static constexpr auto tag = Tag;
};

template <typename ... Children> struct Branch
{
    static constexpr auto child_count = sizeof ... (Children);
    
    template <typename Lambda> constexpr void for_each_child(Lambda && lambda)
    {
        // TODO: Execute 'lambda' on each child.
    }
    
    std::tuple<Children ...> m_children {};
};

template <typename Root> struct Tree
{
    template <auto Tag> constexpr auto & get_leaf()
    {
        // TODO: Traverse the tree and find the leaf with tag 'Tag'.
        
        // If there's no leaf with tag 'Tag' the program shouldn't compile.
    }
    
    Root root {};
};

Using the above definition of a tree, we can define a set of circuit components as follows:

template <auto Tag> struct Resistor : Leaf<Tag>
{
    float resistance() { return m_resistance; }
    
    float m_resistance {};
};

template <auto Tag> struct Capacitor : Leaf<Tag>
{
    float resistance() { return 0.0f; }
    
    float m_capacitance {};
};

template <typename ... Children> struct Series : Branch<Children ...>
{
    using Branch<Children ...>::for_each_child;
    
    float resistance()
    {
        float acc = 0.0f;
        
        for_each_child([&acc](auto child) { acc += child.resistance(); });
        
        return acc;
    }
};

template <typename ... Children> struct Parallel : Branch<Children ...>
{
    using Branch<Children ...>::for_each_child;
    
    float resistance()
    {
        float acc = 0.0f;
        
        for_each_child([&acc](auto child) { acc += 1.0f / child.resistance(); });
        
        return 1.0f / acc;
    }
};
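For reference, the two resistance() implementations above encode the standard rules for combining resistances R_1 to R_n. In series the resistances add up. In parallel the reciprocals add up. Since Capacitor::resistance() returns 0, a capacitor adds nothing to a series sum:

```latex
R_{\text{series}} = \sum_{i=1}^{n} R_i
\qquad\qquad
R_{\text{parallel}} = \left( \sum_{i=1}^{n} \frac{1}{R_i} \right)^{-1}
```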

Next, using the above components, we can express a specific circuit like this:

enum { R0, R1, C0, C1 };

using Circuit =
    Tree<
        Parallel<
            Series<
                Resistor<R0>,
                Capacitor<C0>
            >, // Series
            Series<
                Resistor<R1>,
                Capacitor<C1>
            > // Series
        > // Parallel
    >; // Tree

...where R0, R1, C0, and C1 are tags used to access components at compile time. For example, a very basic use case could be the following:

#include <iostream>

int main()
{
    Circuit circuit {};
    
    circuit.get_leaf<R0>().m_resistance  =  5.0E+3f;
    circuit.get_leaf<C0>().m_capacitance = 10.0E-3f;
    circuit.get_leaf<R1>().m_resistance  =  5.0E+6f;
    circuit.get_leaf<C1>().m_capacitance = 10.0E-6f;
    
    std::cout << circuit.root.resistance() << std::endl;
}

What I just can't wrap my head around is how to implement the functions for_each_child and get_leaf. I've tried different approaches using if constexpr statements and template structs without finding a good solution. Variadic templates are interesting but daunting at the same time. Any help would be greatly appreciated.


Solution

  • After studying various articles on C++ variadic templates, I've managed to put together a solution to the problem.

    First, to implement for_each_child, we use the following helper function, which works like a for loop that is unrolled at compile time:

    #include <type_traits> // std::integral_constant
    
    template <auto from, auto to, typename Lambda>
        static inline constexpr void for_constexpr(Lambda && lambda)
    {
        if constexpr (from < to)
        {
            constexpr auto i = std::integral_constant<decltype(from), from>();
            
            lambda(i);
            
            for_constexpr<from + 1, to>(lambda);
        }
    }
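    The same compile-time unrolling can also be written without recursion. The sketch below uses a different technique from the answer's approach. std::apply unpacks a tuple into a parameter pack, and a comma fold expression then calls a function once per element, left to right. for_each_in_tuple and digits_to_number are hypothetical names chosen for this example.

```cpp
#include <tuple>
#include <utility>

// Hypothetical helper: invokes `f` on each element of `tuple`, in order.
// std::apply unpacks the tuple into a parameter pack; the comma fold expression
// expands into one call per element, so no recursive instantiation is needed.
template <typename Tuple, typename F>
constexpr void for_each_in_tuple(Tuple && tuple, F && f)
{
    std::apply([&f](auto && ... elements)
    {
        (f(std::forward<decltype(elements)>(elements)), ...);
    }, std::forward<Tuple>(tuple));
}

// Usage: fold the digits 1, 2, 3 into the number 123 inside a constant
// expression, which also shows that elements are visited left to right.
constexpr int digits_to_number()
{
    int acc = 0;
    for_each_in_tuple(std::make_tuple(1, 2, 3), [&acc](int digit) { acc = acc * 10 + digit; });
    return acc;
}

static_assert(digits_to_number() == 123, "elements are visited left to right");
```

    With this helper, for_each_child could forward to for_each_in_tuple(m_children, lambda). Because m_children is an lvalue there, each child is passed to the lambda by reference.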
    

    Using for_constexpr, we can implement for_each_child as follows:

    template <typename ... Children> struct Branch
    {
        static constexpr auto children_count = sizeof ... (Children);
    
        template <typename Lambda> constexpr void for_each_child(Lambda && lambda)
        {
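            // Capture 'lambda' by reference: a by-value capture would copy it and make
            // it const inside this closure, so mutable lambdas could not be called.
            // 'i' is an std::integral_constant, so it converts to a constant index.
            for_constexpr<0, children_count>([&lambda, this](auto i)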
            {
                lambda(std::get<i>(m_children));
            });
        }
        
        std::tuple<Children ...> m_children {};
    };
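    As a usage check, here is a self-contained sketch that builds a two-resistor Series on top of this Branch and evaluates it in a constant expression. It is a trimmed-down copy of the question's and answer's types, with a few assumptions that are not part of the answer. The loop starts at std::size_t{0}, so both bounds share a type. The members are constexpr, so static_assert can check the result. Children are taken by reference. series_resistance is a hypothetical helper.

```cpp
#include <cstddef>
#include <tuple>
#include <type_traits>

// The answer's compile-time loop: one call per index in [from, to).
template <auto from, auto to, typename Lambda>
constexpr void for_constexpr(Lambda && lambda)
{
    if constexpr (from < to)
    {
        lambda(std::integral_constant<decltype(from), from>{});
        for_constexpr<from + 1, to>(lambda);
    }
}

template <auto Tag> struct Leaf
{
    static constexpr auto tag = Tag;
};

template <typename ... Children> struct Branch
{
    static constexpr std::size_t children_count = sizeof ... (Children);

    template <typename Lambda> constexpr void for_each_child(Lambda && lambda)
    {
        // std::size_t{0} keeps 'from' and 'to' the same unsigned type.
        for_constexpr<std::size_t{0}, children_count>([&lambda, this](auto i)
        {
            lambda(std::get<i>(m_children));
        });
    }

    std::tuple<Children ...> m_children {};
};

template <auto Tag> struct Resistor : Leaf<Tag>
{
    constexpr float resistance() const { return m_resistance; }

    float m_resistance {};
};

template <typename ... Children> struct Series : Branch<Children ...>
{
    using Branch<Children ...>::for_each_child;

    constexpr float resistance()
    {
        float acc = 0.0f;

        // 'auto & child' avoids copying each child.
        for_each_child([&acc](auto & child) { acc += child.resistance(); });

        return acc;
    }
};

enum { R0, R1 };

// Hypothetical helper: sets both resistances and sums them at compile time.
constexpr float series_resistance(float r0, float r1)
{
    Series<Resistor<R0>, Resistor<R1>> series {};

    std::get<0>(series.m_children).m_resistance = r0;
    std::get<1>(series.m_children).m_resistance = r1;

    return series.resistance();
}

static_assert(series_resistance(100.0f, 200.0f) == 300.0f, "series resistances add up");
```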
    

    Next, to implement get_leaf, we use several helper templates. As Caleth suggested, we can divide the problem into two steps: first, we compute the path from the root to the desired leaf; then, we follow that path to extract the leaf from the tree.

    A path can be represented as an index sequence, where each index selects a child of the node reached so far:

    #include <utility> // std::index_sequence
    
    template <auto ...indices> using Path = std::index_sequence<indices...>;
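    As a concrete illustration, take the question's Circuit, whose root is the Parallel node. Capacitor<C1> is the second child (index 1) of the second Series (index 1), so its path is Path<1, 1>. A short sketch follows. PathToC1 is a hypothetical alias introduced only for this example.

```cpp
#include <cstddef>
#include <type_traits>
#include <utility>

template <auto ...indices> using Path = std::index_sequence<indices...>;

// Root (Parallel) -> child 1 (second Series) -> child 1 (Capacitor<C1>).
using PathToC1 = Path<1, 1>;

// Path is just an alias, so it is the same type as the std::index_sequence.
static_assert(std::is_same_v<PathToC1, std::index_sequence<1, 1>>, "Path is an index_sequence");

// Its length is the number of steps taken from the root.
static_assert(PathToC1::size() == 2, "two steps from the root");

// The empty path selects the root itself.
static_assert(Path<>::size() == 0, "the empty path selects the root");
```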
    

    The first helper we need is a type trait that checks whether a node contains a leaf with a given tag:

    template <auto tag, class Node> struct has_path
    {
        static constexpr
            std::true_type
                match(const Leaf<tag>);
        
        template <class ...Children> static constexpr
            typename std::enable_if<
                (has_path<tag, Children>::type::value || ...),
                std::true_type
            >::type
                match(const Branch<Children...>);
        
        static constexpr
            std::false_type
                match(...);
        
        using type = decltype(match(std::declval<Node>()));
    };
    

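    We pattern match on the node via overload resolution. If the node is a leaf, it must have the requested tag. If it is a branch, at least one of its children must contain a leaf with that tag. In every other case, the catch-all match(...) overload is selected and the result is std::false_type.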
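    The trait can be exercised on its own with static_assert. The following self-contained sketch reduces the question's component types to empty shells and adds a hypothetical tag X that does not appear in the tree. The trait is the answer's has_path, written with std::enable_if_t in place of typename std::enable_if<...>::type.

```cpp
#include <tuple>
#include <type_traits>
#include <utility>

template <auto Tag> struct Leaf { static constexpr auto tag = Tag; };
template <typename ... Children> struct Branch { std::tuple<Children ...> m_children {}; };

template <auto tag, class Node> struct has_path
{
    // Chosen for a leaf (or a type derived from one) carrying the requested tag.
    static constexpr std::true_type match(const Leaf<tag>);

    // Chosen for a branch (deduction also sees through derived types such as
    // Series) if at least one child recursively contains the tag; SFINAE removes
    // this overload otherwise.
    template <class ...Children> static constexpr
        std::enable_if_t<(has_path<tag, Children>::type::value || ...), std::true_type>
            match(const Branch<Children...>);

    // Everything else: a leaf with another tag, or a branch without the tag.
    static constexpr std::false_type match(...);

    using type = decltype(match(std::declval<Node>()));
};

// Empty shells of the question's components, enough for the trait.
template <auto Tag> struct Resistor : Leaf<Tag> {};
template <auto Tag> struct Capacitor : Leaf<Tag> {};
template <typename ... Children> struct Series : Branch<Children ...> {};
template <typename ... Children> struct Parallel : Branch<Children ...> {};

enum { R0, R1, C0, C1, X };

using Root = Parallel<Series<Resistor<R0>, Capacitor<C0>>, Series<Resistor<R1>, Capacitor<C1>>>;

static_assert(has_path<C1, Root>::type::value, "C1 is somewhere in the tree");
static_assert(has_path<R0, Resistor<R0>>::type::value, "a leaf matches its own tag");
static_assert(!has_path<R0, Resistor<R1>>::type::value, "a leaf does not match other tags");
static_assert(!has_path<X, Root>::type::value, "X appears nowhere in the tree");
```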

    The next helper, find_path, is a bit more complicated:

    template <auto tag, class Node, auto ...indices> struct find_path
    {
        template <auto index, class Child, class ...Children> struct search_children
        {
            static constexpr auto fold()
            {
                if constexpr(has_path<tag, Child>::type::value)
                {
                    return typename find_path<tag, Child, indices..., index>::type();
                }
                else
                {
                    return typename search_children<index + 1, Children...>::type();
                }
            }
            
            using type = decltype(fold());
        };
        
        static constexpr
            Path<indices...>
                match(const Leaf<tag>);
        
        template <class ...Children> static constexpr
            typename search_children<0, Children...>::type
                match(const Branch<Children...>);
        
        using type = decltype(match(std::declval<Node>()));
    };
    

    We accumulate the path in the indices template parameter. If the node under investigation (the template parameter Node) is a leaf, the match(const Leaf<tag>) overload applies only when the leaf has the requested tag, and in that case it returns the accumulated path. If, instead, the node is a branch, we use the nested helper search_children, which iterates through the children of the branch. For each child, we first check whether it contains a leaf with the given tag. If so, we append the current index (given by the template parameter index) to the accumulated path and call find_path recursively on that child. Otherwise, we try the next child, and so on.

    Because find_path has no catch-all overload, and search_children has nothing to fall back on once it runs out of children, a tag that does not appear in the tree makes the program fail to compile. This is exactly the behavior required for get_leaf.
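    For the question's circuit shape, find_path should produce one two-step path per leaf. The following self-contained sketch checks this with static_assert. It repeats has_path and find_path from above, lightly reformatted, and reduces the component types to empty shells.

```cpp
#include <tuple>
#include <type_traits>
#include <utility>

template <auto Tag> struct Leaf { static constexpr auto tag = Tag; };
template <typename ... Children> struct Branch { std::tuple<Children ...> m_children {}; };

template <auto ...indices> using Path = std::index_sequence<indices...>;

template <auto tag, class Node> struct has_path
{
    static constexpr std::true_type match(const Leaf<tag>);

    template <class ...Children> static constexpr
        std::enable_if_t<(has_path<tag, Children>::type::value || ...), std::true_type>
            match(const Branch<Children...>);

    static constexpr std::false_type match(...);

    using type = decltype(match(std::declval<Node>()));
};

template <auto tag, class Node, auto ...indices> struct find_path
{
    // Walks the children left to right and descends into the first one that
    // contains the tag, appending that child's index to the path.
    template <auto index, class Child, class ...Children> struct search_children
    {
        static constexpr auto fold()
        {
            if constexpr (has_path<tag, Child>::type::value)
                return typename find_path<tag, Child, indices..., index>::type();
            else
                return typename search_children<index + 1, Children...>::type();
        }

        using type = decltype(fold());
    };

    // A leaf with the requested tag: the accumulated path is the answer.
    static constexpr Path<indices...> match(const Leaf<tag>);

    // A branch: search its children.
    template <class ...Children> static constexpr
        typename search_children<0, Children...>::type
            match(const Branch<Children...>);

    using type = decltype(match(std::declval<Node>()));
};

// Empty shells of the question's components.
template <auto Tag> struct Resistor : Leaf<Tag> {};
template <auto Tag> struct Capacitor : Leaf<Tag> {};
template <typename ... Children> struct Series : Branch<Children ...> {};
template <typename ... Children> struct Parallel : Branch<Children ...> {};

enum { R0, R1, C0, C1 };

using Root = Parallel<Series<Resistor<R0>, Capacitor<C0>>, Series<Resistor<R1>, Capacitor<C1>>>;

static_assert(std::is_same_v<find_path<R0, Root>::type, Path<0, 0>>, "first Series, first child");
static_assert(std::is_same_v<find_path<C0, Root>::type, Path<0, 1>>, "first Series, second child");
static_assert(std::is_same_v<find_path<R1, Root>::type, Path<1, 0>>, "second Series, first child");
static_assert(std::is_same_v<find_path<C1, Root>::type, Path<1, 1>>, "second Series, second child");
```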

    Finally, we define an overloaded helper function, get, that follows a path to extract a node from the tree:

    // Base case: the path is empty, so we have reached the requested node.
    template <class Node>
        static inline constexpr auto &
            get(Node & leaf, Path<>)
    {
        return leaf;
    }
    
    // Recursive case: descend into child 'index', then follow the remaining indices.
    template <auto index, auto ...indices, class Node>
        static inline constexpr auto &
            get(Node & branch, Path<index, indices...>)
    {
        auto & child = std::get<index>(branch.m_children);
        
        return get(child, Path<indices...>());
    }
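    To see get on its own, the sketch below follows a hand-written path through a small hypothetical tree, Parallel<Resistor<R0>, Series<Resistor<R1>, Resistor<R2>>>. The tags R0 to R2 are invented for this example. The sketch checks the returned type, and it also checks that a write lands in the tree itself rather than in a copy. The overloads are the answer's get, without the unused parameter names and the redundant static inline.

```cpp
#include <tuple>
#include <type_traits>
#include <utility>

template <auto Tag> struct Leaf { static constexpr auto tag = Tag; };
template <typename ... Children> struct Branch { std::tuple<Children ...> m_children {}; };

template <auto ...indices> using Path = std::index_sequence<indices...>;

// Base case: the path is empty, so 'node' is the requested node.
template <class Node>
constexpr auto & get(Node & node, Path<>)
{
    return node;
}

// Recursive case: step into child 'index', then follow the remaining indices.
template <auto index, auto ...indices, class Node>
constexpr auto & get(Node & branch, Path<index, indices...>)
{
    return get(std::get<index>(branch.m_children), Path<indices...>());
}

// Hypothetical components and tags for this demonstration only.
template <auto Tag> struct Resistor : Leaf<Tag> { float m_resistance {}; };
template <typename ... Children> struct Series : Branch<Children ...> {};
template <typename ... Children> struct Parallel : Branch<Children ...> {};

enum { R0, R1, R2 };

using Root = Parallel<Resistor<R0>, Series<Resistor<R1>, Resistor<R2>>>;

// Path<1, 1>: second child of the root (the Series), then its second child.
static_assert(std::is_same_v<decltype(get(std::declval<Root &>(), Path<1, 1>())), Resistor<R2> &>,
              "the path leads to Resistor<R2>");

// Writes through get() and reads the same member back through the tuples.
constexpr float write_then_read()
{
    Root root {};
    get(root, Path<1, 1>()).m_resistance = 42.0f;
    return std::get<1>(std::get<1>(root.m_children).m_children).m_resistance;
}

static_assert(write_then_read() == 42.0f, "get returns a reference into the tree");
```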
    

    Using find_path and get, which the code below expects to live in a namespace named implementation, we can implement get_leaf as follows:

    template <typename Root> struct Tree
    {
        template <auto tag> constexpr auto & get_leaf()
        {
            constexpr auto path = typename implementation::find_path<tag, Root>::type {};
            
            return implementation::get(root, path);
        }
        
        Root root {};
    };
    

    Here's a link to godbolt.org that demonstrates that the code compiles and works as expected with Clang:

    godbolt.org/...