
Quickly memoize anonymous recursive functions using lambdas and the "fix function"


Background

I recently learned that the fixed-point combinator makes it easy to define recursive functions without naming them. It is mainly used in functional programming languages (e.g. the fix function), but you can mimic it in C++20 as follows:

#include <functional> // std::ref
#include <iostream>

template <typename F>
struct Fix {
    F f;
    decltype(auto) operator()(auto arg) {
        return f(std::ref(*this), arg);
    }
};

int main() {
    auto fact = Fix{[](auto self, int n) -> int {
        return (n <= 1) ? 1 : n * self(n - 1);
    }};

    std::cout << fact(5) << std::endl; // 120
}

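To achieve anonymous recursion, the lambda only needs to take auto self as its first parameter and call self(...) wherever it would otherwise call itself by name. Here self is a std::reference_wrapper to the enclosing Fix object.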

I like this sort of syntax because of its simplicity and generality.

Question

Is it possible to combine the "fix function" from functional programming languages with memoization in a single C++ class, such as the MemoFix class below, so that memoized recursive functions can be written concisely?

Details

In programming competitions, I often want to write memoized functions.

So I started writing a small library containing a MemoFix class, which adds memoization to the "fix function": it turns a lambda into a recursive function and memoizes it at the same time.

I have finished writing the code, but unfortunately it fails to compile.

#include <functional>
#include <iostream>
#include <unordered_map>

template <typename F, typename Ret, typename Arg>
struct MemoFix {
    F f;
    std::unordered_map<Arg, Ret> cache{};

    Ret operator()(Arg x) {
        if (!cache.contains(x)) {
            cache[x] = f(std::ref(*this), x);
        }
        return cache[x];
    }
};

int main() {
    // does not compile; Clang reports something like "no viable constructor or deduction guide for deduction of template arguments"
    auto fact = MemoFix{[](auto self, int n) -> int {
        return (n <= 1) ? 1 : n * self(n - 1);
    }};

    // compiles, but is too verbose: I don't want to spell out `decltype`, the types, and a name for the lambda
    // auto body = [](auto self, int n) -> int {
    //     return (n <= 1) ? 1 : n * self(n - 1);
    // };
    // auto fact = MemoFix<decltype(body), int, int>{body};

    std::cout << fact(5) << std::endl; // 120
}

A std::unordered_map is used to cache the return values. Because std::unordered_map needs explicit key and value types, the extra template parameters Ret and Arg were added.

These types are identified correctly only when the template arguments are spelled out, as in MemoFix<decltype(body), int, int>. As soon as they are omitted, compilation fails, because nothing in the initializer lets the compiler deduce Ret and Arg.

What changes should I make to MemoFix so that the compiler can infer these types?

Other notes:

  • Compiler: Clang 18.0, C++20 (or newer)
  • I'm reluctant to use std::function due to its perceived performance disadvantages.

How can I write memoized recursion with the simplest possible notation, without writing types over and over or giving extra names to lambda expressions? Approaches from another perspective are welcome too. I would be happy to receive advice from someone more experienced.


Solution

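  • I wrote a simplified version of LambdaTraits that extracts both the parameter types and the return type of a lambda. For a generic lambda, its operator() is first instantiated with the placeholder type Any in place of auto self; the second parameter type and the return type are then read from the resulting member-function pointer.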

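    #include <functional>
    #include <iostream>
    #include <tuple>
    #include <unordered_map>

    template <typename...> struct LambdaTraits;

    // Non-generic lambda: inspect its single operator().
    template <typename F>
    struct LambdaTraits<F> : public LambdaTraits<decltype(&F::operator())> {};

    // Generic lambda: instantiate operator() with TArgs first, then inspect it.
    template<typename F, typename ...TArgs>
    struct LambdaTraits<F, TArgs...> : LambdaTraits<decltype(&F::template operator()<TArgs...>)> {};

    // Pointer to a const member function: expose parameter and return types.
    template <typename C, typename Ret, typename... Args>
    struct LambdaTraits<Ret(C::*)(Args...) const> {
        using args_type = std::tuple<Args...>;
        using return_type = Ret;
    };

    // Placeholder for the `auto self` parameter; only used inside decltype.
    struct Any {
        // template <typename T>
        // constexpr operator T&() const;

        template <typename T>
        T operator()(T) const;
    };

    template <typename F>
    struct MemoFix {
        F f;
        // Parameter 0 is `self`, parameter 1 is the memoized argument.
        using Arg = std::tuple_element_t<1, typename LambdaTraits<F, Any>::args_type>;
        using Ret = typename LambdaTraits<F, Any>::return_type;
        std::unordered_map<Arg, Ret> cache{};

        Ret operator()(Arg x) {
            if (!cache.contains(x)) {
                cache[x] = f(std::ref(*this), x);
            }
            return cache[x];
        }
    };

    int main() {
        auto fact = MemoFix{[](auto self, int n) -> int {
            return (n <= 1) ? 1 : n * self(n - 1);
        }};
        std::cout << fact(5) << std::endl; // 120
    }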
    
