
Why does this Y Combinator using code fail to compile?


I've been reading about combinators for three days now and I finally started writing them in code (more like copying stuff from places and making sense of things).

Here's some code that I'm trying to run:

#include <cstdint>
#include <iostream>
#include <utility>

template <typename Lambda>
class y_combinator {
  private:
    Lambda lambda;
  public:
    template <typename T>
    constexpr explicit y_combinator (T&& lambda)
        : lambda (std::forward <T> (lambda))
    { }

    template <typename...Args>
    decltype(auto) operator () (Args&&... args) {
        return lambda((decltype(*this)&)(*this), std::forward <Args> (args)...);
    }
};

template <typename Lambda>
decltype(auto) y_combine (Lambda&& lambda) {
    return y_combinator <std::decay_t <Lambda>> (std::forward <Lambda> (lambda));
}

int main () {
    auto factorial = y_combine([&] (auto self, int64_t n) {
        return n == 1 ? (int64_t)1 : n * self(n - 1);
    });
    
    int n;
    std::cin >> n;

    std::cout << factorial(n) << '\n';
}

If I explicitly state the return type of the lambda as -> int64_t, everything works well. However, when I remove it, the compiler complains. The error:

main.cpp|16|error: use of 'main()::<lambda(auto:11, int64_t)> [with auto:11 = y_combinator<main()::<lambda(auto:11, int64_t)> >; int64_t = long long int]' before deduction of 'auto'

Why can't the compiler figure out the return type and deduce auto? I first thought that maybe I needed to change ... ? 1 : n * self(n - 1) to ... ? int64_t(1) : n * self(n - 1) so that the type of both return values ends up as int64_t and no possible ambiguities remain. This doesn't seem to be the case though. What am I missing?

Also, in the y_combinator class, declaring lambda as an object of type Lambda&& seems to cause problems. Why is this the case? This only happens when I write the cast in the operator () overload as (decltype(*this)&) instead of std::ref(*this). Are they doing different things?


Solution

  • Type deduction

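    The type of n == 1 ? (int64_t)1 : n * self(n - 1) depends on the return type of self, and self just calls this same lambda, whose return type is the one being deduced. You would think that int64_t is an obvious candidate, but float and double would satisfy that recursive equation just as well. The compiler doesn't try every possible return type and pick the best one; the rule is simply that a function whose return type hasn't been deduced yet can't be used in an expression.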

    To fix this, use an if-else block instead of a ternary expression:

    int main () {
        auto factorial = y_combine([&] (auto self, int64_t n) {
            if (n == 1) {
                return (int64_t)1;
            } else {
                return n * self(n - 1);
            }
        });
        // ...
    }
    

    With this, the first return statement doesn't depend on the return type of self, so type deduction can happen.
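    For reference, here is the question's program with only the lambda body changed, as a self-contained sketch. It assumes C++14 or later, adds <cstdint>, writes the question's cast as plain *this (the cast produces exactly that lvalue), and wraps the combinator in a hypothetical factorial function instead of reading from std::cin.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <utility>

template <typename Lambda>
class y_combinator {
  private:
    Lambda lambda;
  public:
    template <typename T>
    constexpr explicit y_combinator(T&& l) : lambda(std::forward<T>(l)) {}

    template <typename... Args>
    decltype(auto) operator()(Args&&... args) {
        // *this is already an lvalue of type y_combinator&, which is all the
        // question's (decltype(*this)&) cast produced.
        return lambda(*this, std::forward<Args>(args)...);
    }
};

template <typename Lambda>
decltype(auto) y_combine(Lambda&& lambda) {
    return y_combinator<std::decay_t<Lambda>>(std::forward<Lambda>(lambda));
}

// Hypothetical wrapper so the example can be called without std::cin.
std::int64_t factorial(std::int64_t value) {
    assert(value >= 1);  // the recursion only stops at n == 1
    auto fact = y_combine([](auto self, std::int64_t n) {
        if (n == 1) {
            return std::int64_t{1};  // fixes the return type as std::int64_t
        } else {
            return n * self(n - 1);  // self's return type is now known
        }
    });
    return fact(value);
}
```

    factorial(20) yields 2432902008176640000, the largest factorial that fits in std::int64_t.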

    When deducing the return type of a function, the compiler looks at the return statements in the body of the function in sequence and tries to deduce their type. If it fails, you get a compilation error.

    With the ternary operator, the type of the return statement return n == 1 ? (int64_t)1 : n * self(n - 1); depends on the return type of self, which is not yet known. Therefore you get a compilation error.

    When using an if statement and multiple return statements, the compiler can deduce the return type from the first one it encounters, since

    If there are multiple return statements, they must all deduce to the same type.

    and

    Once a return statement has been seen in a function, the return type deduced from that statement can be used in the rest of the function, including in other return statements.

    as seen on cppreference. This is why

            if (n == 1) {
                return (int64_t)1;
            } else {
                return n * self(n - 1);
            }
    

    can be deduced to return an int64_t.
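    The same rule applies to any function with a deduced return type, not only to lambdas. A minimal sketch with a hypothetical recursive free function sum_to (C++14 or later): because the base case comes first, the recursive call in the second return statement already has a known type.

```cpp
// The first return statement deduces int, so using sum_to in the second
// return statement is allowed.
constexpr auto sum_to(int n) {
    if (n <= 1) {
        return n;
    } else {
        return sum_to(n - 1) + n;
    }
}

// Swapping the branches (recursive call first) is rejected with a
// "use of ... before deduction of 'auto'" error, like the one in the question.
static_assert(sum_to(4) == 10, "1 + 2 + 3 + 4");
```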

    As a side note

            if (n == 1) {
                return 1; // no explicit cast to int64_t
            } else {
                return n * self(n - 1);
            }
    

    would fail to compile because from the first return statement the return type of the function will be deduced as int, and from the second return statement as int64_t.
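    The mismatch can be checked at compile time. This sketch models the two return statements with a hypothetical std::int64_t constant n: once return 1; has fixed the return type as int, self(n - 1) returns int, and multiplying an int64_t by an int still gives an int64_t.

```cpp
#include <cstdint>
#include <type_traits>

constexpr std::int64_t n = 5;  // stands in for the lambda's parameter

// `return 1;` deduces int ...
static_assert(std::is_same<decltype(1), int>::value, "literal 1 is an int");

// ... while `return n * self(n - 1);`, with self now returning int, deduces
// std::int64_t, because the int operand is converted to std::int64_t.
static_assert(std::is_same<decltype(n * 1), std::int64_t>::value,
              "int64_t * int is int64_t");
```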

            if (n != 1) {
                return n * self(n - 1);
            } else {
                return (int64_t)1;
            }
    

    would also fail, because the first return statement it encounters depends on the return type of the function, so it can't be deduced.
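    An explicit trailing return type, which the question already found to work, sidesteps the ordering issue entirely: nothing has to be deduced, so the recursive call can come first or sit inside a ternary. A sketch reusing the question's class (C++14 or later; the factorial wrapper is hypothetical):

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <utility>

template <typename Lambda>
class y_combinator {
  private:
    Lambda lambda;
  public:
    template <typename T>
    constexpr explicit y_combinator(T&& l) : lambda(std::forward<T>(l)) {}

    template <typename... Args>
    decltype(auto) operator()(Args&&... args) {
        return lambda(*this, std::forward<Args>(args)...);
    }
};

template <typename Lambda>
decltype(auto) y_combine(Lambda&& lambda) {
    return y_combinator<std::decay_t<Lambda>>(std::forward<Lambda>(lambda));
}

std::int64_t factorial(std::int64_t value) {
    assert(value >= 1);  // the recursion only stops at n == 1
    auto fact = y_combine([](auto self, std::int64_t n) -> std::int64_t {
        // Recursive case first, inside a ternary: fine, because the return
        // type is spelled out and nothing has to be deduced.
        return n != 1 ? n * self(n - 1) : 1;
    });
    return fact(value);
}
```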

  • Second question

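    The error happens because, when calling lambda, you're trying to make a copy of *this to initialize the lambda's auto self parameter, and a class with an rvalue reference member can't be copied: its implicitly declared copy constructor is defined as deleted. (see clang's great error message on godbolt)

    So yes, the two casts do different things. decltype(*this) is already y_combinator&, so (decltype(*this)&)(*this) is a no-op and the lambda still receives an lvalue that auto self has to copy. std::ref(*this) passes a std::reference_wrapper instead, which can always be copied, so the y_combinator itself is never copied.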
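    The copy rule can be checked with type traits. In this sketch the structs holds_value and holds_rvalue_ref are hypothetical stand-ins for the question's class with a Lambda member and with a Lambda&& member (C++11 or later).

```cpp
#include <functional>
#include <type_traits>

struct holds_value { int member; };          // like `Lambda lambda;`
struct holds_rvalue_ref { int&& member; };   // like `Lambda&& lambda;`

// A value member keeps the implicit copy constructor usable ...
static_assert(std::is_copy_constructible<holds_value>::value,
              "value member: copyable");

// ... but a non-static data member of rvalue reference type makes the
// implicitly declared copy constructor deleted, so `auto self` can't copy it.
static_assert(!std::is_copy_constructible<holds_rvalue_ref>::value,
              "rvalue reference member: not copyable");

// std::ref(*this) hands over a reference_wrapper, which is always copyable.
static_assert(
    std::is_copy_constructible<std::reference_wrapper<holds_rvalue_ref>>::value,
    "reference_wrapper: copyable");
```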

    To fix this, either use std::ref as mentioned in the original question or give the lambda an auto &&self parameter (and use std::forward<decltype(self)>(self)).
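    A sketch of the auto &&self fix, assuming C++17 (for guaranteed copy elision when y_combine returns). To show that nothing is copied, this hypothetical variant of the class deletes its copy constructor, which is what the Lambda&& member did implicitly. Since *this is an lvalue, self is always y_combinator&, so calling self(n - 1) directly behaves the same as forwarding it.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <utility>

template <typename Lambda>
class y_combinator {
  private:
    Lambda lambda;
  public:
    template <typename T>
    constexpr explicit y_combinator(T&& l) : lambda(std::forward<T>(l)) {}

    // Deleted on purpose: the recursion below must never copy the combinator.
    y_combinator(const y_combinator&) = delete;

    template <typename... Args>
    decltype(auto) operator()(Args&&... args) {
        return lambda(*this, std::forward<Args>(args)...);
    }
};

template <typename Lambda>
decltype(auto) y_combine(Lambda&& lambda) {
    // Returning a prvalue needs no copy or move constructor since C++17.
    return y_combinator<std::decay_t<Lambda>>(std::forward<Lambda>(lambda));
}

std::int64_t factorial(std::int64_t value) {
    assert(value >= 1);  // the recursion only stops at n == 1
    auto fact = y_combine([](auto&& self, std::int64_t n) {
        if (n == 1) {
            return std::int64_t{1};
        } else {
            return n * self(n - 1);  // self is y_combinator&: no copy
        }
    });
    return fact(value);
}
```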

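    Also note that the member Lambda &&lambda is an rvalue reference, so you can only construct an instance of y_combinator with a temporary or moved lambda. With y_combine as written in the question, that temporary is destroyed at the end of the statement that initializes factorial, so the stored reference would also dangle. In general, copying functors is not a big deal; the standard library also takes functor parameters by copy.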
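    Storing the functor by value, as the question's original class does, also accepts a lambda that is a named variable, which a Lambda&& member would reject. A sketch with a hypothetical fibonacci wrapper (C++14 or later), where y_combine copies the named lambda body into the combinator:

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <utility>

template <typename Lambda>
class y_combinator {
  private:
    Lambda lambda;  // stored by value: copied or moved in, never dangles
  public:
    template <typename T>
    constexpr explicit y_combinator(T&& l) : lambda(std::forward<T>(l)) {}

    template <typename... Args>
    decltype(auto) operator()(Args&&... args) {
        return lambda(*this, std::forward<Args>(args)...);
    }
};

template <typename Lambda>
decltype(auto) y_combine(Lambda&& lambda) {
    return y_combinator<std::decay_t<Lambda>>(std::forward<Lambda>(lambda));
}

std::int64_t fibonacci(std::int64_t value) {
    assert(value >= 0);
    // A named (lvalue) lambda; a `Lambda&&` member could not bind to it.
    auto body = [](auto&& self, std::int64_t n) {
        if (n < 2) {
            return n;  // deduces std::int64_t before the recursive calls
        }
        return self(n - 1) + self(n - 2);
    };
    auto fib = y_combine(body);  // copies `body` into the combinator
    return fib(value);
}
```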