
The well-known overloads helper for std::visit does not work with std::reference_wrapper


Here is some sample code: http://coliru.stacked-crooked.com/a/5f630d2d65cd983e

#include <variant>
#include <functional>

template<class... Ts> struct overloads : Ts... { using Ts::operator()...; };
template<class... Ts> overloads(Ts &&...) -> overloads<std::remove_cvref_t<Ts>...>;

template<typename... Ts, typename... Fs>
constexpr inline auto transform(const std::variant<Ts...> &var, Fs &&... fs)
    -> decltype(auto) { return std::visit(overloads{fs...}, var); }

template<typename... Ts, typename... Fs>
constexpr inline auto transform_by_ref(const std::variant<Ts...> &var, Fs &&... fs)
    -> decltype(auto) { return std::visit(overloads{std::ref(fs)...}, var); }
    
int main()
{
    transform(
        std::variant<int, double>{1.0},
        [](int) { return 1; },
        [](double) { return 2; }); // fine
    transform_by_ref(
        std::variant<int, double>{1.0},
        [](int) { return 1; },
        [](double) { return 2; }); // compilation error
    return 0;
}

Here, I have adopted the well-known overloads helper type to invoke std::visit() with multiple lambdas.

transform() copies the function objects, so I wrote a new function, transform_by_ref(), which uses std::reference_wrapper to avoid copying them.

Even though the original lambdas are temporary objects, they live until the end of the full-expression containing the call, so they outlive the execution of transform_by_ref(). I don't think lifetime is a problem here.

transform() works as expected, but transform_by_ref() causes a compilation error:

main.cpp: In instantiation of 'constexpr decltype(auto) transform_by_ref(const std::variant<_Types ...>&, Fs&& ...) [with Ts = {int, double}; Fs = {main()::<lambda(int)>, main()::<lambda(double)>}]':
main.cpp:18:21:   required from here
main.cpp:13:42: error: no matching function for call to 'visit(overloads<std::reference_wrapper<main()::<lambda(int)> >, std::reference_wrapper<main()::<lambda(double)> > >, const std::variant<int, double>&)'
   13 |     -> decltype(auto) { return std::visit(overloads{std::ref(fs)...}, var); }
      |                                ~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from main.cpp:1:
/usr/local/include/c++/12.1.0/variant:1819:5: note: candidate: 'template<class _Visitor, class ... _Variants> constexpr std::__detail::__variant::__visit_result_t<_Visitor, _Variants ...> std::visit(_Visitor&&, _Variants&& ...)'
 1819 |     visit(_Visitor&& __visitor, _Variants&&... __variants)
      |     ^~~~~
/usr/local/include/c++/12.1.0/variant:1819:5: note:   template argument deduction/substitution failed:
In file included from /usr/local/include/c++/12.1.0/variant:37:
/usr/local/include/c++/12.1.0/type_traits: In substitution of 'template<class _Fn, class ... _Args> using invoke_result_t = typename std::invoke_result::type [with _Fn = overloads<std::reference_wrapper<main()::<lambda(int)> >, std::reference_wrapper<main()::<lambda(double)> > >; _Args = {const int&}]':
/usr/local/include/c++/12.1.0/variant:1093:11:   required by substitution of 'template<class _Visitor, class ... _Variants> using __visit_result_t = std::invoke_result_t<_Visitor, std::__detail::__variant::__get_t<0, _Variants, decltype (std::__detail::__variant::__as(declval<_Variants>())), typename std::variant_alternative<0, typename std::remove_reference<decltype (std::__detail::__variant::__as(declval<_Variants>()))>::type>::type>...> [with _Visitor = overloads<std::reference_wrapper<main()::<lambda(int)> >, std::reference_wrapper<main()::<lambda(double)> > >; _Variants = {const std::variant<int, double>&}]'
/usr/local/include/c++/12.1.0/variant:1819:5:   required by substitution of 'template<class _Visitor, class ... _Variants> constexpr std::__detail::__variant::__visit_result_t<_Visitor, _Variants ...> std::visit(_Visitor&&, _Variants&& ...) [with _Visitor = overloads<std::reference_wrapper<main()::<lambda(int)> >, std::reference_wrapper<main()::<lambda(double)> > >; _Variants = {const std::variant<int, double>&}]'
main.cpp:13:42:   required from 'constexpr decltype(auto) transform_by_ref(const std::variant<_Types ...>&, Fs&& ...) [with Ts = {int, double}; Fs = {main()::<lambda(int)>, main()::<lambda(double)>}]'
main.cpp:18:21:   required from here
/usr/local/include/c++/12.1.0/type_traits:3034:11: error: no type named 'type' in 'struct std::invoke_result<overloads<std::reference_wrapper<main()::<lambda(int)> >, std::reference_wrapper<main()::<lambda(double)> > >, const int&>'
 3034 |     using invoke_result_t = typename invoke_result<_Fn, _Args...>::type;
      |           ^~~~~~~~~~~~~~~
main.cpp: In instantiation of 'constexpr decltype(auto) transform_by_ref(const std::variant<_Types ...>&, Fs&& ...) [with Ts = {int, double}; Fs = {main()::<lambda(int)>, main()::<lambda(double)>}]':
main.cpp:18:21:   required from here
/usr/local/include/c++/12.1.0/variant:1859:5: note: candidate: 'template<class _Res, class _Visitor, class ... _Variants> constexpr _Res std::visit(_Visitor&&, _Variants&& ...)'
 1859 |     visit(_Visitor&& __visitor, _Variants&&... __variants)
      |     ^~~~~
/usr/local/include/c++/12.1.0/variant:1859:5: note:   template argument deduction/substitution failed:
main.cpp:13:42: note:   couldn't deduce template parameter '_Res'
   13 |     -> decltype(auto) { return std::visit(overloads{std::ref(fs)...}, var); }
      |                                ~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

I think I could work around this by not using std::visit() and implementing my own visit function instead. However, I want to know why this code does not work as expected.

Why does my transform_by_ref() cause a compilation error, and how can I fix it without a custom visit implementation?


Solution

  • Each std::reference_wrapper has an operator() overload that can be called with any argument list that the referenced lambda would accept as arguments.

    That means the reference wrappers for both [](int) { return 1; } and [](double) { return 2; } have operator() overloads that accept an int argument as well as a double argument, both without conversion of the argument. std::reference_wrapper::operator() is a function template taking its arguments as forwarding references (ArgTypes&&...), so any int/double conversion only happens later, inside the call to the wrapped lambda.

    So when std::visit performs overload resolution for a specific alternative of the variant, the operator() overloads made visible via using Ts::operator()...; for both reference wrappers are viable. In contrast to the non-reference-wrapper case, both are viable without any conversion of the argument, so they are equally good matches and overload resolution is ambiguous.

    The ambiguity can be avoided by enforcing that the lambdas take only exactly the type they are supposed to match as argument (assuming C++20 here):

    transform_by_ref(
        std::variant<int, double>{1.0},
        [](std::same_as<int> auto) { return 1; },
        [](std::same_as<double> auto) { return 2; });
    

    or by using a single overload with if constexpr in its body to branch on the type of the argument.

    While it is possible to make the operator() of a wrapper class SFINAE-friendly, so that it is not considered viable if the wrapped callable isn't (std::reference_wrapper's already is), it is impossible, at least in general, to "forward" the conversion rank of the wrapped call through such a wrapper. For non-generic lambdas specifically, it is theoretically possible to extract the parameter type in the wrapper and use it as the parameter type of the wrapper's operator(), but that is messy and doesn't work with generic callables. Proper reflection would be required to implement such a wrapper in general.


    In your transform you use fs directly as lvalues instead of forwarding their value category via std::forward<Fs>(fs). If you did forward them, the temporaries passed from main would be move-constructed into overloads instead of copied.

    If the goal is to avoid the move construction as well, the usual approach of constructing overloads in the caller already achieves that:

    template<typename... Ts, typename Fs>
    constexpr inline auto transform(const std::variant<Ts...> &var, Fs && fs)
        -> decltype(auto) { return std::visit(std::forward<Fs>(fs), var); }
        
    int main()
    {
        transform(
            std::variant<int, double>{1.0},
            overloads{
                [](int) { return 1; },
                [](double) { return 2; }});
        return 0;
    }
    

    This uses aggregate-initialization of overloads from prvalues, which means mandatory copy elision applies and no lambdas will be copied or moved.

    Even if it did work, the std::ref approach would also waste memory storing references to non-capturing lambdas, which carry no state of their own.