Tags: c++ · boost · boost-asio

C++20 coroutines read/write websocket


I want to make a websocket running on a single thread using coroutines and boost::asio. One coroutine would be responsible for writing (async_write) and the other would take care of reading (async_read).

If either coroutine gets an exception (for now I assume every exception means connection broken) I will try reconnecting.

For writing, I want a writeBuffer which serves as a queue for messages. Clients of the websocket would call ws.Send(data) and instead of sending immediately, it'll remain in the buffer until the next ws.Run() call.

To make it work as described, I need a way to suspend the write coroutine while writeBuffer is empty. If I don't, it'll spin forever waiting for the buffer to fill. But I get an error when trying to suspend it with std::suspend_always{}:

error C2665: 'boost::asio::detail::awaitable_frame_base<Executor>::await_transform': no overloaded function could convert all the argument types

So I guess that's not how you suspend a coroutine with asio::awaitable. I really need this proxy buffer as a queue for my messages. I could probably use something else from Boost - maybe signals provide a way to co_await on them as well - but I'm afraid I'd have to ask three more questions to understand those.

Here's my code distilled to minimum:

#include <iostream>
#include <coroutine>
#include <optional>

#include <boost/asio.hpp>
#include <boost/beast.hpp>
#include <boost/asio/awaitable.hpp>
#include <boost/asio/experimental/awaitable_operators.hpp>

namespace asio = boost::asio;
namespace beast = boost::beast;
namespace websocket = beast::websocket;
using namespace std::chrono_literals;
using namespace asio::experimental::awaitable_operators;

struct CoroWebsocket {

    CoroWebsocket(std::string host, std::string port)
    : _host(std::move(host))
    , _port(std::move(port)) {
        asio::co_spawn(_ws.get_executor(), do_run(), asio::detached);
    }

    void Run() {
        _ioc.run_for(50ms);
    }

    void Write(std::string data) {
        // TODO: mutex
        _writeBuffer.push_back(std::move(data));
    }

    std::optional<std::string> Read(){ 
        // TODO: mutex
        if (_readBuffer.empty())
            return {};
        const auto message = _readBuffer.back();
        _readBuffer.pop_back();
        return message;
    }

private:
    const std::string _host, _port;
    using tcp = asio::ip::tcp;
    std::vector<std::string>       _writeBuffer; // Will be filled externally.
    std::vector<std::string>       _readBuffer;
    boost::asio::io_context        _ioc;
    websocket::stream<tcp::socket> _ws{_ioc};

    asio::awaitable<void> do_run() {
        while(true) {
            try {
                co_await do_connect();
                co_await asio::co_spawn(_ws.get_executor(), do_write() || do_read(), asio::use_awaitable); // If either ends, it must've been an exception. Reconnect.
            } catch (const boost::system::system_error& se) {
                std::cerr << "Error: " << se.code().message() << std::endl;
            }
        }
    }

    asio::awaitable<void> do_connect() {
        try {
            while(true) {
                co_await async_connect(_ws.next_layer(), tcp::resolver(_ioc).resolve(_host, _port), asio::use_awaitable);
                _ws.set_option(websocket::stream_base::decorator([](websocket::request_type& req) {
                    req.set(beast::http::field::user_agent, BOOST_BEAST_VERSION_STRING " WsConnect");
                }));

                co_await _ws.async_handshake(_host + ':' + _port, "/", asio::use_awaitable);
                _readBuffer.clear();
            }
        } catch (boost::system::system_error const& se) {
            std::cerr << "Error: " << se.code().message() << std::endl;
        }
    }

    asio::awaitable<void> do_write() {
        try {
            while(true) {
                while (_writeBuffer.empty()) {
                    co_await std::suspend_always{}; // I want to switch context but ERROR
                }

                for (const auto& message : _writeBuffer) {
                    co_await _ws.async_write(boost::asio::buffer(message), asio::use_awaitable);
                }
                _writeBuffer.clear();
            }
        } catch (boost::system::system_error const& se) {
            std::cerr << "Error: " << se.code().message() << std::endl;
        }
    }

    asio::awaitable<void> do_read() {
        try {
            while(true) {
                if (0 != co_await _ws.async_read_some(boost::asio::buffer(_readBuffer), asio::use_awaitable)) {
                    while (!_ws.is_message_done()) {
                        co_await _ws.async_read_some(boost::asio::buffer(_readBuffer), asio::use_awaitable);
                    }
                    // Signal new message.
                }
            }
        } catch (boost::system::system_error const& se) {
            std::cerr << "Error: " << se.code().message() << std::endl;
        }
    }
};

Ideally, I'd also like to replace io_context with a thread_pool holding 1 thread, so I can really run these coroutines detached. However, if I just let them spin on a separate thread when there's no reading or writing to do, I think the thread would burn cycles just switching contexts as both suspend. To solve this I thought about adding a 3rd coroutine to the mix that would stall the thread with a basic this_thread::sleep(50ms) if no reading or writing was done on the last spin.


Solution

  • I want a websocket running on a single thread serving reads and writes as they come. If there's nothing to do - check for work without burning cycles

    You're in luck, that's precisely what asynchronous completion is about.

    If there's nothing to do, nothing will happen and the thread(s) in the pool will effectively sleep, meaning other processes in the system can get a turn. The difference is that instead of "naked sleep" it is "smart sleep" as in: the sleep gets woken up as soon as any relevant IO event appears. This can be any event supported by an Asio service like files, pipes, UNIX or internet sockets, serial port, async process signals¹.

    Regarding the code, let me first point out that the await on a spawn:

        co_await asio::co_spawn(
            _ws.get_executor(), do_write() || do_read(),
            asio::use_awaitable);
    

is a potentially more costly way of simply writing

        co_await (do_write() || do_read()); // idiomatic full-duplex!
    

    And the

        co_await std::suspend_always{}; // I want to switch context but ERROR
    

is also always suboptimal, but if you insist, it could be

        co_await post(asio::deferred);
    

    How To Approach

    I'd make it so that the write-loop is started only when there is something queued. This immediately makes it trivial to synchronize access to the queue using a strand.

    Alternatively look at a channel as a replacement for the queue. The write thread could asynchronously receive from the channel. It also gives you control over queue capacity: https://www.boost.org/doc/libs/1_84_0/doc/html/boost_asio/overview/channels.html

    Channels seem like the better idea here because it also allows you to write your Read interface naturally and optimally.

    Demo

    Using channels seems like a good match. It all boils down to this essence:

    const std::string _host, _port;
    Stream            _ws;
    Channel           _in{_ws.get_executor()}, _out{_ws.get_executor()};
    
    Task do_run() {
        while (true) {
            try {
                co_await do_connect();
                co_await (do_write_loop() || do_read_loop());
            } catch (boost::system::system_error const& se) {
                std::cerr << "Error: " << se.code().message() << std::endl;
            }
        }
    }
    
    Task do_write_loop() {
        for (;;)
            co_await _ws.async_write(asio::buffer(co_await _out.async_receive()));
    }
    
    Task do_read_loop() {
        for (Message msg;; msg.clear()) {
            auto buf = asio::dynamic_buffer(msg);
    
            auto [ec, bytes] = co_await _ws.async_read(buf, as_tuple(asio::deferred));
            co_await _in.async_send(ec, std::move(msg));
        }
    }
    

    Adding a graceful shutdown flag and a back-off delay in case connection fails:

    Task do_run() {
        for (; !_close_requested; co_await delay(200ms)) {
            try {
                co_await do_connect();
                co_await (do_write_loop() || do_read_loop());
            } catch (boost::system::system_error const& se) {
                std::cerr << "Error: " << se.code().message() << std::endl;
            }
        }
    }
    

    Full Listing

    Live On Coliru

    #include <chrono>   // std::chrono literals
    #include <iomanip>  // std::quoted, used in main()
    #include <iostream>
    #include <optional> // std::optional, used in Read()
    #include <thread>   // std::this_thread::sleep_for, used in main()
    #include <boost/asio.hpp>
    #include <boost/asio/experimental/awaitable_operators.hpp>
    #include <boost/asio/experimental/concurrent_channel.hpp>
    #include <boost/beast.hpp>
    
    namespace asio      = boost::asio;
    namespace beast     = boost::beast;
    namespace websocket = beast::websocket;
    using namespace std::chrono_literals;
    using namespace asio::experimental::awaitable_operators;
    using boost::system::error_code;
    using tcp = asio::ip::tcp;
    
    struct Server {
        using Message  = std::string;
        using Task     = asio::awaitable<void>;
        using Channel  = asio::deferred_t::as_default_on_t< //
            asio::experimental::concurrent_channel<void(error_code, Message)>>;
        using Stream   = asio::deferred_t::as_default_on_t<websocket::stream<tcp::socket>>;
        using Resolver = asio::deferred_t::as_default_on_t<tcp::resolver>;
        using Opts     = websocket::stream_base;
    
        Server(asio::any_io_executor ex, std::string host, std::string port)
            : _host(std::move(host))
            , _port(std::move(port))
            , _ws(ex) {
            co_spawn(ex, do_run(), asio::detached);
        }
    
        bool Write(std::string data) { return _out.try_send(error_code{}, std::move(data)); }
    
        std::optional<std::string> Read() {
            Message ret;
            if (_in.try_receive([&](error_code ec, Message msg_) {
                    if (ec)
                        throw boost::system::system_error(ec);
                    ret = std::move(msg_);
                }))
                return ret;
            return std::nullopt;
        }
    
        void close() {
            _close_requested = true;
            post(_ws.get_executor(), [this] {
                if (error_code ignore; _ws.is_open())
                    _ws.close({}, ignore);
            });
        }
    
      private:
        const std::string _host, _port;
        Stream            _ws;
        Channel           _in{_ws.get_executor(), 10}, _out{_ws.get_executor(), 10};
        std::atomic_bool  _close_requested{false};
    
        Task delay(auto duration_or_timepoint) {
            auto ex = co_await asio::this_coro::executor;
            co_await asio::steady_timer(ex, duration_or_timepoint).async_wait(asio::deferred);
        }
    
        Task do_run() {
            for (; !_close_requested; co_await delay(200ms)) {
                try {
                    co_await do_connect();
                    co_await (do_write_loop() || do_read_loop());
                } catch (boost::system::system_error const& se) {
                    std::cerr << "Error: " << se.code().message() << std::endl;
                }
            }
        }
    
        Task do_write_loop() {
            for (;;)
                co_await _ws.async_write(asio::buffer(co_await _out.async_receive()));
        }
    
        Task do_read_loop() {
            for (Message msg;; msg.clear()) {
                auto buf = asio::dynamic_buffer(msg);
    
                auto [ec, bytes] = co_await _ws.async_read(buf, as_tuple(asio::deferred));
                co_await _in.async_send(ec, std::move(msg));
    
                if (ec)
                    break;
            }
        }
    
        asio::awaitable<void> do_connect() {
            auto ex = co_await asio::this_coro::executor;
            if (error_code ignore; _ws.is_open()) {
                _ws.close({}, ignore);
                _in.reset();
                _out.reset();
            }
    
            auto eps = co_await Resolver(ex).async_resolve(_host, _port);
            co_await async_connect(beast::get_lowest_layer(_ws), eps);
    
            _ws.set_option(Opts::decorator([](websocket::request_type& req) {
                req.set(beast::http::field::user_agent, BOOST_BEAST_VERSION_STRING " WsConnect");
            }));
    
            co_await _ws.async_handshake(_host + ':' + _port, "/");
        }
    };
    
    int main() {
        asio::thread_pool ioc;
        Server s(make_strand(ioc), "localhost", "8989");
    
        using std::this_thread::sleep_for;
    
        for (auto msg : {"foo", "bar", "qux"}) {
            s.Write(msg);
            sleep_for(1s);
    
            while (auto response = s.Read())
                std::cout << "Received response " << quoted(*response) << std::endl;
        }
    
        s.close();
        ioc.join();
    }
    

    With a live demo against:

     websocketd --port 8989 -- \
          bash -c 'tee log | while read line; do echo "Responding to ($line)"; done'
    

(screenshot of the live demo session)

    ¹ and some platform specific things like completion ports on windows