
In C++, how can we write a recursive function that returns a range?


In C++ we face two problems when writing a recursive function that returns a range:

  • the base case type is different from the recursive case type
  • range pipelines often involve lambdas, whose types cannot be named, so the return type has to be auto, which gets in the way of recursion
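To make the second point more precise: C++ does let a function with a deduced (auto) return type call itself, but only after a return statement has already fixed that type. The sketch below (SumTo is a hypothetical name, not part of the question) compiles for that reason. In SumDecompLazy, the base case fixes the type as a single_view, so the recursive case, whose type is different, is rejected. That is how the two problems combine.

```cpp
// A deduced return type may be used recursively once a return statement
// earlier in the function has fixed it.
auto SumTo(unsigned int n)
{
    if (n == 0)
        return 0u;             // the return type is deduced here as unsigned int
    return n + SumTo(n - 1);   // OK: SumTo's return type is already known
}
```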

Let me illustrate with sum decompositions, that is, ways of writing n as an ordered sum of positive integers. Here is a strict (non-lazy) solution:

#include <list>

std::list<std::list<unsigned int>> SumDecompStrict(unsigned int n)
{
    if (n == 0)
        return { {} }; // the sum of the empty list is 0 by convention
    std::list<std::list<unsigned int>> ret;
    for (unsigned int k = 1; k <= n; k++)
    {
        std::list<std::list<unsigned int>> ll = SumDecompStrict(n - k);
        for (auto& l : ll)
            l.push_front(k);
        ret.splice(ret.end(), std::move(ll));
    }
    return ret;
}
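The strict version builds every decomposition up front, and their number grows quickly. Writing c(n) for the number of lists returned by SumDecompStrict(n), the loop over k gives the recurrence below, so c(n) = 2^(n-1) for n ≥ 1 (for example, n = 3 yields {1,1,1}, {1,2}, {2,1}, {3}).

```latex
\begin{aligned}
c(0) &= 1, \qquad c(n) = \sum_{k=1}^{n} c(n-k) = \sum_{j=0}^{n-1} c(j) \quad (n \ge 1),\\
c(n) - c(n-1) &= c(n-1) \;\Rightarrow\; c(n) = 2\,c(n-1) \quad (n \ge 2), \qquad c(1) = c(0) = 1,\\
c(n) &= 2^{\,n-1} \quad (n \ge 1).
\end{aligned}
```

This exponential growth is what makes a lazy version attractive when only some of the decompositions are needed.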

Now I want to turn this into a lazy generator of sum decompositions by returning a range, which could look like this:

auto SumDecompLazy(unsigned int n)
{
    if (n == 0)
        return std::ranges::single_view<std::list<unsigned int>>({}); // the sum of the empty list is 0 by convention
    return std::views::iota(1u, n + 1)
        | std::views::transform([n](unsigned int k) {
            return SumDecompLazy(n - k)
                | std::views::transform([k](std::list<unsigned int> l) { l.push_front(k); return l; });
            })
        | std::views::join;
}

But SumDecompLazy fails to compile because of the two problems mentioned above:

  • the type of std::ranges::single_view<std::list<unsigned int>>({}) is different from the type of the recursive case
  • the auto return type prevents SumDecompLazy from calling itself

This problem is solved in other programming languages, for example by the IEnumerable type in C#, or its alias seq in F#. Is there a solution in C++?
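C#'s IEnumerable works because every sequence, whether base case or recursive case, is seen through one interface type. Before C++23, a similar effect can be approximated with type erasure. The sketch below is one hypothetical way to do it (Decomp, LazySeq and SumDecompErased are made-up names): a lazy sequence is a std::function that returns the next element, or std::nullopt once the sequence is exhausted. Because both cases return this one named type, neither of the two problems arises.

```cpp
#include <functional>
#include <list>
#include <optional>

// Hypothetical type-erased lazy sequence: each call returns the next element,
// or std::nullopt once the sequence is exhausted.
using Decomp = std::list<unsigned int>;
using LazySeq = std::function<std::optional<Decomp>()>;

// Same decompositions (and same order) as SumDecompStrict, produced on demand.
// The return type is named, so the function may call itself freely.
LazySeq SumDecompErased(unsigned int n)
{
    if (n == 0)
    {
        // Base case: yield the empty list exactly once.
        return [done = false]() mutable -> std::optional<Decomp> {
            if (done)
                return std::nullopt;
            done = true;
            return Decomp{};
        };
    }
    // Recursive case: for k = 1..n, pull from SumDecompErased(n - k)
    // and prepend k to each element it yields.
    return [n, k = 1u, inner = SumDecompErased(n - 1)]() mutable -> std::optional<Decomp> {
        while (k <= n)
        {
            if (std::optional<Decomp> l = inner())
            {
                l->push_front(k);
                return l;
            }
            if (++k <= n)
                inner = SumDecompErased(n - k); // move on to the next k
        }
        return std::nullopt;
    };
}
```

Calling seq() repeatedly on auto seq = SumDecompErased(4) returns {1,1,1,1}, {1,1,2}, and so on up to {4}, then std::nullopt. The trade-off is that this is a pull function rather than a range, so it does not plug into std::views pipelines, and every element passes through indirect std::function calls.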


Solution

  • C++23 has std::generator, which lets you express this as a coroutine, either by explicit looping:

    #include <generator>
    #include <list>

    std::generator<std::list<unsigned int>> SumDecompLoop(unsigned int n)
    {
        if (n == 0)
            co_yield {};
    
        for (unsigned int k = 1; k <= n; k++)
        {
            for (auto l : SumDecompLoop(n - k))
            {
                l.push_front(k);
                co_yield std::move(l);
            }
        }
    }
    

    or by using range adaptors:

    #include <generator>
    #include <list>
    #include <ranges>

    std::generator<std::list<unsigned int>> SumDecompRange(unsigned int n)
    {
        if (n == 0)
            co_yield {}; 
        co_yield std::ranges::elements_of(std::views::iota(1u, n+1)
            | std::views::transform([n](unsigned int k) {
                return SumDecompRange(n - k)
                    | std::views::transform([k](std::list<unsigned int> l) { l.push_front(k); return l; });
                })
            | std::views::join);
    }