Tags: c++, templates, lambda, c++20

Pass C++20 Templated Lambda to Function Then Call it


I'm trying to pass a templated lambda to a function and then call it with a template parameter, to enable specialization of that function for custom types.

But when I try to call the lambda, I get this error: error: invalid operands to binary expression

Here's a godbolt link for anyone who wants to play around with this: https://gcc.godbolt.org/z/qYPcea

#include <cstdint>
#include <string>
#include <cstring>

enum class Alignment : uint8_t {
    
    one,
    two,
    four,
    eight
};

template <Alignment alignment, typename T>
static void align(T& pointer)
{
    intptr_t& value = reinterpret_cast<intptr_t&>(pointer);
    value += (-value) & ((uint64_t)alignment - 1);
}

template<typename Lambda, typename T>
static void specialization(Lambda&& lambda, const T& t) 
{
   lambda<Alignment::eight>(t.data(), t.size());
}

int main()
{
    uint8_t buffer[1024];

    void *writeTo = buffer;

    auto lambda = [&] <Alignment alignment> (const void *input, uint32_t inputSize) -> void
    {
        align<alignment>(writeTo);
        writeTo = memcpy(writeTo, buffer, inputSize);
    };
    
    std::string input("helloworld");

    specialization(lambda, input);

    return 0;
}

Solution

  • The issue is that a lambda is not a class template. Its closure type is an ordinary class whose member call operator, operator(), is a template.

    When the template arguments of a generic lambda can be deduced from the call arguments, this distinction is invisible (which is a very good thing): you simply write lambda(args...).

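    In your example, however, alignment cannot be deduced, and lambda is an object (a function parameter), not a template. So in lambda<Alignment::eight>(t.data(), t.size()) the compiler parses < and > as the less-than and greater-than operators, i.e. (lambda < Alignment::eight) > (t.data(), t.size()). Comparing the closure object with an Alignment value is what produces "invalid operands to binary expression".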

    To specify the template arguments explicitly, call the member operator() by name. Because Lambda is a dependent type inside specialization, you also need the template keyword to tell the compiler that the < after operator() starts a template argument list:

    lambda.template operator()<Alignment::eight>(t.data(), t.size());
    

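    Here's a version of your code that compiles. Apart from the call syntax, it also fixes a few runtime bugs, which the comments point out.

    #include <cstdint>
    #include <cstring>
    #include <string>

    // Each enumerator holds its alignment in bytes. With plain one, two,
    // four, eight (values 0..3) the mask computed in align() was wrong.
    enum class Alignment : std::uint8_t {
        one = 1,
        two = 2,
        four = 4,
        eight = 8
    };

    // Round pointer up to the next multiple of alignment (a power of two).
    // Working on a std::uintptr_t copy avoids type-punning the pointer itself.
    template <Alignment alignment, typename T>
    static void align(T*& pointer)
    {
        auto value = reinterpret_cast<std::uintptr_t>(pointer);
        value += (-value) & (static_cast<std::uintptr_t>(alignment) - 1);
        pointer = reinterpret_cast<T*>(value);
    }

    template<typename Lambda, typename T>
    static void specialization(Lambda&& lambda, const T& t)
    {
       // lambda<Alignment::eight>(...) would parse as comparisons, so name the
       // templated call operator; `template` is needed because Lambda is dependent.
       lambda.template operator()<Alignment::eight>(t.data(), t.size());
    }

    int main()
    {
        std::uint8_t buffer[1024];

        void *writeTo = buffer;

        auto lambda = [&] <Alignment alignment> (const void *input, std::uint32_t inputSize) -> void
        {
            align<alignment>(writeTo);
            // Copy from input (not from buffer), then advance past the copied bytes:
            // memcpy returns its destination, so it never moves writeTo by itself.
            std::memcpy(writeTo, input, inputSize);
            writeTo = static_cast<std::uint8_t*>(writeTo) + inputSize;
        };

        std::string input("helloworld");

        specialization(lambda, input);

        return 0;
    }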