Tags: c++, boost, c++20, coroutine, asio

Is there any elegant way to combine asio::co_composed and std::variant?


Question

I want to handle connection classes whose member functions share the same signatures in a uniform way. For example, tcp and tls are both connection classes, and each has a send() member function template that accepts a CompletionToken. I use std::variant to hold the connection.

In one part of my application, I define a function template non_awaitable_func. It accepts a CompletionToken, and I want to use co_await in its implementation. For this, we can use asio::experimental::co_composed.

So far, so good.

However, when I call the connection's send() member function template with boost::asio::deferred through std::visit, compilation fails with an assertion. The reason is that the visitor returns a different type for each alternative: with boost::asio::deferred, each operation yields its own distinct return type, even though all of them complete with the same signature. In this situation I would typically use boost::asio::use_awaitable instead of boost::asio::deferred, because boost::asio::use_awaitable always produces the same return type. Inside the co_composed implementation, however, I can't use boost::asio::use_awaitable.
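To illustrate the std::visit constraint outside Asio, here is a minimal sketch using only the standard library. The types conn_a and conn_b and the functions send_lazy(), send_erased() and run_any() are hypothetical stand-ins, not part of the code in question: send_lazy() returns a closure with a unique type (loosely analogous to a deferred operation), while send_erased() returns one fixed type (loosely analogous to what use_awaitable provides).

```cpp
#include <cassert>
#include <functional>
#include <type_traits>
#include <utility>
#include <variant>

// Hypothetical connection types, used only for illustration.
struct conn_a {
    auto send_lazy() const { return [] { return 1; }; }                 // unique closure type
    std::function<int()> send_erased() const { return [] { return 1; }; } // one fixed type
};
struct conn_b {
    auto send_lazy() const { return [] { return 2; }; }
    std::function<int()> send_erased() const { return [] { return 2; }; }
};
using any_conn = std::variant<conn_a, conn_b>;

// Both closures are callable as int(), yet their types differ.
static_assert(!std::is_same_v<decltype(std::declval<conn_a const&>().send_lazy()),
                              decltype(std::declval<conn_b const&>().send_lazy())>);

// This would be rejected, because std::visit needs a single return type:
//   std::visit([](auto const& c) { return c.send_lazy(); }, v);

// With a common return type, the same visitor is accepted.
inline int run_any(any_conn const& v) {
    std::function<int()> op =
        std::visit([](auto const& c) { return c.send_erased(); }, v);
    return op();
}
```

Calling run_any(conn_a{}) runs the first alternative's operation and run_any(conn_b{}) the second; the price of the common type is the indirection inside std::function.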

Is there any good way to solve this situation?

Demonstration code

#include <chrono>
#include <iostream>
#include <memory>
#include <variant>

#include <boost/asio.hpp>
#include <boost/asio/experimental/co_composed.hpp>

struct tcp {
    template <typename CompletionToken> // void(boost::system::error_code)
    auto send(CompletionToken&& token) {
        // pseudo implementation
        auto tim = std::make_shared<boost::asio::steady_timer>(exe_, std::chrono::seconds(1));
        return tim->async_wait(
            boost::asio::consign(
                std::forward<CompletionToken>(token),
                tim
            )
        );
    }
    boost::asio::any_io_executor exe_;
};

struct tls {
    template <typename CompletionToken> // void(boost::system::error_code)
    auto send(CompletionToken&& token) {
        // pseudo implementation
        return boost::asio::dispatch(
            boost::asio::append(
                std::forward<CompletionToken>(token),
                boost::system::errc::make_error_code(boost::system::errc::bad_message)
            )
        );
    }
    boost::asio::any_io_executor exe_;
};

#if 0 // with 0 no error happens because all visit return types are the same; set to 1 to reproduce the error
using connection = std::variant<tcp, tls>;
#else
using connection = std::variant<tcp>;
#endif

template <typename CompletionToken>
auto non_awaitable_func(
    connection& con,
    CompletionToken&& token
) {
    return boost::asio::async_initiate<
        CompletionToken,
        void(boost::system::error_code)
    >(
        boost::asio::experimental::co_composed<
            void(boost::system::error_code)
        >(
            [](auto /*state*/, connection& con) -> void {
                auto [ec] = co_await std::visit(
                    [](auto& c) {
                        // use_awaitable can't be used here because of co_composed
                        return c.send(boost::asio::as_tuple(boost::asio::deferred));
                    },
                    con
                );
                // user defined implementation
                co_return {ec};
            }
        ),
        token,
        con
    );
}

int main() {
    boost::asio::io_context ioc;
    connection con = tcp{ioc.get_executor()};
    non_awaitable_func(
        con,
        [&] (boost::system::error_code ec) {
            std::cout << "cb called:" << ec << std::endl;
        }
    );
    ioc.run();
}

godbolt link: https://godbolt.org/z/b7zac6Pn1

My workaround

I replaced std::visit with std::get_if.

        boost::asio::experimental::co_composed<
            void(boost::system::error_code)
        >(
            [](auto /*state*/, connection& con) -> void {
                // Approach without std::visit.
                // It works, but it is not elegant: similar code is repeated for each alternative.
                if (auto* p = std::get_if<0>(&con)) {
                    auto [ec] = co_await p->send(boost::asio::as_tuple(boost::asio::deferred));
                    // user defined implementation
                    co_return {ec};
                }
                else {
                    // named q because p is still in scope here
                    auto* q  = std::get_if<1>(&con);
                    BOOST_ASSERT(q);
                    auto [ec] = co_await q->send(boost::asio::as_tuple(boost::asio::deferred));
                    // user defined implementation
                    co_return {ec};
                }
            }
        ),

godbolt link: https://godbolt.org/z/z8YMfMvMv

It works as I expected, but it is not elegant. I tried to use a preprocessor macro to avoid the code repetition, but so far I couldn't find a good way.

Environment

  • Boost 1.84.0
  • clang++ 17.0.1
  • compiler option -std=c++20

UPDATE

I noticed that the name non_awaitable_func() is misleading, so let me clarify what I mean. It is an async function whose return type is not boost::asio::awaitable<T>. Instead, it takes a CompletionToken parameter, which can be a callback function or any CompletionToken type that Boost.Asio supports, for example use_future, use_awaitable, or deferred.


Solution

  • You'd need to type-erase the deferred async operations, or write out the entire composed operation so that the internal type discrepancies don't leak out to co_composed.

    Neither option is free of overhead, so you might instead flip the visitation inside out, so that the composed operation never deals with the variant to begin with:

    template <typename Connection, typename CompletionToken> auto non_awaitable_func_impl(Connection& con, CompletionToken&& token) {
        return asio::async_initiate<CompletionToken, Sig>(
            asio::experimental::co_composed<Sig>([](auto state, Connection& con) -> void {
                auto [ec] = co_await con.send(as_tuple(asio::deferred));
                co_yield state.complete(ec);
            }),
            token, con);
    }
    
    template <typename CompletionToken> auto non_awaitable_func(connection& con, CompletionToken&& token) {
        // return the result so that tokens such as use_future hand their value back to the caller
        return std::visit(
            [&token](auto& con) { return non_awaitable_func_impl(con, std::forward<CompletionToken>(token)); },
            con);
    }
    

    This works. You can combine the two if you don't mind readability:

    template <typename CompletionToken> auto non_awaitable_func(connection& con, CompletionToken&& token) {
        return std::visit(
            [&token](auto& con) {
                return asio::async_initiate<CompletionToken, Sig>(
                    asio::experimental::co_composed<Sig>([&con](auto state) -> void {
                        auto [ec] = co_await con.send(as_tuple(asio::deferred));
                        co_yield state.complete(ec);
                    }),
                    token);
            },
            con);
    }
    

    See it Live On Coliru (or Godbolt)

    #include <chrono>
    #include <iostream>
    #include <memory>
    #include <variant>
    #include <boost/asio.hpp>
    #include <boost/asio/experimental/co_composed.hpp>
    using namespace std::chrono_literals;
    namespace asio   = boost::asio;
    using error_code = boost::system::error_code;
    using Sig        = void(error_code);
    
    struct tcp {
        template <asio::completion_token_for<Sig> Token> auto send(Token&& token) {
            // pseudo implementation
            auto tim = std::make_unique<asio::steady_timer>(exe_, 1s);
            return tim->async_wait(consign(std::forward<Token>(token), std::move(tim)));
        }
        asio::any_io_executor exe_;
    };
    
    struct tls {
        template <asio::completion_token_for<Sig> Token> auto send(Token&& token) {
            return dispatch( // pseudo implementation
                append(std::forward<Token>(token), make_error_code(boost::system::errc::bad_message)));
        }
        asio::any_io_executor exe_;
    };
    
    using connection = std::variant<tcp, tls>;
    
    template <asio::completion_token_for<Sig> Token> auto non_awaitable_func(connection& con, Token&& token) {
        return std::visit(
            [&token](auto& con) {
                return asio::async_initiate<Token, Sig>(
                    asio::experimental::co_composed<Sig>([&con](auto state) -> void {
                        auto [ec] = co_await con.send(as_tuple(asio::deferred));
                        co_return state.complete(ec);
                    }),
                    token);
            },
            con);
    }
    
    int main() {
        asio::io_context ioc;
    
        connection 
            con1 = tls{ioc.get_executor()},
            con2 = tcp{ioc.get_executor()};
    
        non_awaitable_func(con1, [&](error_code ec) { std::cout << "cb1:" << ec.message() << std::endl; });
        non_awaitable_func(con2, [&](error_code ec) { std::cout << "cb2:" << ec.message() << std::endl; });
        ioc.run();
    }
    

    Note that make_error_code is called unqualified on purpose: it is important to let ADL find the correct make_error_code overload.

    Prints:

    cb1:Bad message
    cb2:Success
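The inside-out visitation used above can also be illustrated without Asio. The following is a minimal standard-library sketch under stated assumptions: plain_link, secure_link, any_link and send_and_report() are hypothetical names, and the handlers run synchronously, which is why capturing by reference is safe here (a real asynchronous operation would need to own its handler). The point it shows is that the whole operation is instantiated once per alternative, so the visitor returns void in every case and no heterogeneous type has to cross std::visit.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <variant>

// Hypothetical connection types: each send() passes an error number to a callback.
struct plain_link {
    template <typename Handler> void send(Handler&& h) { std::forward<Handler>(h)(0); }
};
struct secure_link {
    template <typename Handler> void send(Handler&& h) { std::forward<Handler>(h)(74); }
};
using any_link = std::variant<plain_link, secure_link>;

// The visitor wraps the complete operation (send plus post-processing).
// Every instantiation returns void, so std::visit sees a single return type.
template <typename Handler>
void send_and_report(any_link& con, Handler&& handler) {
    std::visit(
        [&handler](auto& c) {
            // Synchronous in this sketch, so the reference capture stays valid.
            c.send([&handler](int ec) {
                handler(ec == 0 ? std::string("Success")
                                : "error " + std::to_string(ec));
            });
        },
        con);
}
```

A caller passes the variant and a callback taking std::string, for example send_and_report(con, [](std::string msg) { /* use msg */ }). The trade-off is the same as in the Asio version: the operation body is compiled once per alternative.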
    

    UPDATE: Promise!

    I had a brainwave. There is another experimental type, asio::experimental::promise<>, which, like std::promise, apparently does some type erasure internally, yet, unlike std::future, can also be await-transformed in Asio coroutines.

    And indeed it works:

    template <asio::completion_token_for<Sig> Token> //
    auto async_send(connection& con, Token&& token) {
        return asio::async_initiate<Token, Sig>(
            boost::asio::experimental::co_composed<Sig>([&con](auto /*state*/) -> void {
                auto [ec] = co_await std::visit(
                    [](auto& c) { return c.send(asio::as_tuple(asio::experimental::use_promise)); }, con);
                co_return {ec};
            }),
            token);
    }
    

    Here's a way more complete test program:

    Live On Coliru or Godbolt

    #include <boost/asio.hpp>
    #include <boost/asio/experimental/co_composed.hpp>
    #include <boost/asio/experimental/promise.hpp>
    #include <boost/asio/experimental/use_coro.hpp>
    #include <boost/asio/experimental/use_promise.hpp>
    #include <boost/core/demangle.hpp>
    #include <chrono>
    #include <iostream>
    #include <memory>
    #include <syncstream>
    #include <variant>
    using namespace std::chrono_literals;
    namespace asio   = boost::asio;
    using error_code = boost::system::error_code;
    using Sig        = void(error_code);
    
    static inline auto out() { return std::osyncstream(std::clog); }
    
    struct tcp {
        template <asio::completion_token_for<Sig> Token> //
        auto send(Token&& token) {
            // pseudo implementation
            auto tim = std::make_unique<asio::steady_timer>(exe_, 1s);
            return tim->async_wait(consign(std::forward<Token>(token), std::move(tim)));
        }
        asio::any_io_executor exe_;
    };
    
    struct tls {
        template <asio::completion_token_for<Sig> Token> //
        auto send(Token&& token) {
            return dispatch( // pseudo implementation
                append(std::forward<Token>(token), make_error_code(boost::system::errc::bad_message)));
        }
        asio::any_io_executor exe_;
    };
    
    using connection = std::variant<tcp, tls>;
    
    template <asio::completion_token_for<Sig> Token> //
    auto async_send(connection& con, Token&& token) {
        return asio::async_initiate<Token, Sig>(
            boost::asio::experimental::co_composed<Sig>([&con](auto /*state*/) -> void {
                auto [ec] = co_await std::visit(
                    [](auto& c) { return c.send(asio::as_tuple(asio::experimental::use_promise)); }, con);
                co_return {ec};
            }),
            token);
    }
    
    template <class V> // HT: https://stackoverflow.com/a/53697591/85371
    std::type_info const& var_type(V const& v) {
        return std::visit([](auto&& x) -> decltype(auto) { return typeid(x); }, v);
    }
    
    int main() {
        asio::thread_pool ioc(1);
    
        connection 
            con1 = tls{ioc.get_executor()},
            con2 = tcp{ioc.get_executor()};
    
        { // callback
            async_send(con1, [&](error_code ec) { out() << "cb1:" << ec.message() << std::endl; });
            async_send(con2, [&](error_code ec) { out() << "cb2:" << ec.message() << std::endl; });
        }
    
        { // use_future
            auto f1 = async_send(con1, as_tuple(asio::use_future));
            auto f2 = async_send(con2, as_tuple(asio::use_future));
            out() << "f1: " << std::get<0>(f1.get()).message() << std::endl;
            out() << "f2: " << std::get<0>(f2.get()).message() << std::endl;
    
            try {
                async_send(con1, asio::use_future).get();
            } catch (boost::system::system_error const& se) {
                out() << "alternatively: " << se.code().message() << std::endl;
            }
        }
    
        { // use_awaitable
            for (connection& con : {std::ref(con1), std::ref(con2)}) {
                auto name = "coro-" + boost::core::demangle(var_type(con).name());
                co_spawn(
                    ioc,
                    [&con, name]() -> asio::awaitable<void> {
                        auto [ec_defer] = co_await async_send(con, as_tuple(asio::deferred));
                        auto [ec_aw]    = co_await async_send(con, as_tuple(asio::use_awaitable));
                        out() << name << ": " << ec_defer.message() << "/" << ec_aw.message() << std::endl;
                        co_await async_send(con, asio::deferred); // will throw
                    },
                    [name](std::exception_ptr e) {
                        try {
                            if (e)
                                std::rethrow_exception(e);
                        } catch (boost::system::system_error const& se) {
                            out() << name << " threw " << se.code().message() << std::endl;
                        }
                    });
            }
        }
        ioc.join();
    }
    

    Printing e.g.

    cb1:Bad message
    f1: Bad message
    cb2:Success
    f2: Success
    alternatively: Bad message
    coro-tls: Bad message/Bad message
    coro-tls threw Bad message
    coro-tcp: Success/Success