Search code examples
Tags: c++, network-programming, boost-asio, boost-beast

Handling multiple WebSocket write requests in a continuous loop with Boost.Beast


template<typename Derived>
class Websocket: public std::enable_shared_from_this<Websocket<Derived>>
{
public:
    using InternalWSType = boost::beast::websocket::stream<
        boost::asio::ssl::stream<boost::asio::ip::tcp::socket>>;

    using SSLHandshakeConfigurator = std::function<void(Websocket *, InternalWSType &)>;

    explicit Websocket(boost::asio::io_context &io_context)
        : m_io_context { io_context }
        , m_resolver { m_io_context }
        , m_ws { m_io_context, m_ssl_context }
    {
        m_response.reserve(100'000'000);
    }

    Websocket(
        boost::asio::io_context &io_context,
        SSLHandshakeConfigurator ssl_handshake_configurator
    )
        : m_io_context { io_context }
        , m_resolver { m_io_context }
        , m_ws { m_io_context, m_ssl_context }
        , m_ssl_handshake_configurator(std::move(ssl_handshake_configurator))
    {
        m_ws.read_message_max(0);
        m_ws.auto_fragment(false);

        m_response.reserve(100'000'000);
    }

    virtual ~Websocket() = default;

    Websocket(const Websocket &) = delete;

    Websocket(Websocket &&) noexcept = delete;

    auto operator=(const Websocket &) -> Websocket & = delete;

    auto operator=(Websocket &&) noexcept -> Websocket & = delete;

    void async_start(std::string_view host, std::string_view port, std::string_view target)
    {
        m_host = std::string { host };
        m_target = std::string { target };

        // Look up the domain name
        m_resolver.async_resolve(
            m_host,
            std::string { port },
            [this](
                boost::system::error_code error_code,
                boost::asio::ip::tcp::resolver::results_type res
            ) {
                if (error_code)
                {
                    // handle
                    SPDLOG_ERROR("Failed to async_start");
                    SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                    SPDLOG_ERROR("    error_code.what: {}", error_code.what());
                }
                else
                {
                    async_connect(std::move(res));
                }
            }
        );
    }

    void async_stop()
    {
        m_stop_requested = true;
        auto holder = this->shared_from_this();

        if (m_ws.next_layer().next_layer().is_open())
        {
            m_ws.async_close(
                boost::beast::websocket::close_code::normal,
                [holder = std::move(holder)](const boost::system::error_code &) {}
            );
        }
    }

    void send_message(std::string_view message)
    {
        m_ws.async_write(
            boost::beast::net::buffer(std::string { message }),
            [self = this->shared_from_this()](boost::system::error_code error_code, size_t) {
                if (error_code)
                {
                    SPDLOG_ERROR("Failed to send message");
                    SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                    SPDLOG_ERROR("    error_code.what: {}", error_code.what());
                }
            }
        );
    }

    void poll()
    {
        m_io_context.poll();
    }

    void run()
    {
        m_io_context.run();
    }


protected:
    [[nodiscard]] auto get_io_context() const -> boost::asio::io_context &
    {
        return m_io_context;
    }

    virtual void send_subscribe_message() {};

private:
    void async_connect(boost::asio::ip::tcp::resolver::results_type result)
    {
        if (!SSL_set_tlsext_host_name(m_ws.next_layer().native_handle(), m_host.c_str()))
        {
            SPDLOG_ERROR("Boost::Beast error: async connect");
            return;
        }

        boost::asio::async_connect(
            m_ws.next_layer().next_layer(),
            result.begin(),
            result.end(),
            [this](boost::system::error_code error_code, boost::asio::ip::tcp::resolver::iterator) {
                if (error_code)
                {
                    SPDLOG_ERROR("Failed to async connect");
                    SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                    SPDLOG_ERROR("    error_code.what: {}", error_code.what());

                    if (!m_stop_requested)
                    {
                        // handle the beast error
                    }
                }
                else
                {
                    on_connected();
                }
            }
        );
    }

    void on_connected()
    {
        m_ws.next_layer().async_handshake(
            boost::asio::ssl::stream_base::client,
            [this](boost::system::error_code error_code) {
                if (error_code)
                {
                    if (!m_stop_requested)
                    {
                    }
                }
                else
                {
                    on_async_ssl_handshake();
                }
            }
        );
    }

    void on_async_ssl_handshake()
    {
        if (m_ssl_handshake_configurator)
        {
            m_ssl_handshake_configurator(this, m_ws);
        }

        m_ws.async_handshake(m_host, m_target, [this](boost::system::error_code error_code) {
            if (!error_code)
                send_subscribe_message();
            else
            {
                SPDLOG_ERROR("Failed to async ssl handshake");
                SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                SPDLOG_ERROR("    error_code.what: {}", error_code.what());
            }


            start_read(error_code);
        });
    }

    void start_read(boost::system::error_code error_code)
    {
        if (error_code)
        {
            if (!m_stop_requested)
            {
                SPDLOG_ERROR("Failed to start_read");
                SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                SPDLOG_ERROR("    error_code.what: {}", error_code.what());
            }

            async_stop();

            return;
        }

        m_buffer.prepare(512'000'000);
        m_ws.async_read(m_buffer, [this](boost::system::error_code error_code, size_t size) {
            on_read(error_code, size);
        });
    }

    void on_read(boost::system::error_code error_code, [[maybe_unused]] size_t size)
    {
        if (error_code)
        {
            if (!m_stop_requested)
            {
                SPDLOG_ERROR("Failed to on_read");
                SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                SPDLOG_ERROR("    error_code.what: {}", error_code.what());
                SPDLOG_ERROR("    reason: {}", m_ws.reason().reason.data());
                SPDLOG_ERROR("    reason_code: {}", m_ws.reason().code);
            }

            return;
        }

        m_response.clear();

        for (const auto &bytes : m_buffer.data())
        {
            m_response.append(static_cast<const char *>(bytes.data()), bytes.size());
        }
        m_buffer.consume(m_buffer.size());

        static_cast<Derived *>(this)->handle_message(m_response);

        // restart
        start_read(boost::system::error_code {});
    }

    boost::asio::io_context &m_io_context;
    boost::asio::ssl::context m_ssl_context { boost::asio::ssl::context::sslv23_client };
    boost::asio::ip::tcp::resolver m_resolver;
    InternalWSType m_ws;
    boost::beast::multi_buffer m_buffer;
    std::string m_response;
    std::string m_host;
    std::string m_target;
    bool m_stop_requested {};

    SSLHandshakeConfigurator m_ssl_handshake_configurator;
};

/// Concrete client built on the Websocket CRTP base.
class EntryPoint final: public Websocket<EntryPoint>
{
public:
    EntryPoint(boost::asio::io_context &io_context): Websocket<EntryPoint>(io_context)
    {
    }

    /// Called by the Websocket base for each complete incoming message.
    /// `response` views the base's internal buffer, which is cleared before
    /// the next message. Copy it if it must outlive this call.
    void handle_message([[maybe_unused]] std::string_view response)
    {
        // handle the response
    }

    /// Begin resolving, connecting and handshaking with the endpoint.
    /// Public so that owners can actually initiate the connection.
    void subscribe()
    {
        async_start("url", "443", "/ws/private/orders/control");
    }
};

/// Demo driver: starts the client, then repeatedly queues a request and polls
/// the io_context on this thread.
auto main() -> int32_t {
    std::cout << "entrypoint" << std::endl;
    boost::asio::io_context ioctx;
    // poll() stops the io_context once it has no outstanding work (e.g. after
    // the connection closes); later polls would then run nothing. The guard
    // keeps it alive so that messages queued later are still processed.
    auto work_guard = boost::asio::make_work_guard(ioctx);

    auto entry_point_ws = std::make_shared<EntryPoint>(ioctx);
    // async_start requires host, port and target.
    entry_point_ws->async_start("url", "443", "/ws/private/orders/control");

    while (true) {
        // Placeholder for a request popped from a queue filled by another thread.
        const auto *pseudo_request = "pseudo_request_popped_from_queue_pushed_by_another_thread";
        entry_point_ws->send_message(pseudo_request);
        ioctx.poll();
    }
}

My understanding is that there can be no more than one active async_write call at the same time for a WebSocket connection. However, in my main loop, I'm continuously popping requests from a queue and calling send_message and calling poll after that.

My question is:

How should I structure my program to handle multiple write requests when there can only be one active async_write at a time? I want to ensure that all messages in the queue are sent without losing any due to overlapping async_write calls.

Any suggestions on how to queue these written operations or restructure the program to handle this scenario would be greatly appreciated.

My possible solution:

Run io_context::run in a separate thread. From the main thread, push each request onto a lock-free queue, from which it is handed to async_write. The responses come back through another lock-free SPSC (single-producer, single-consumer) queue that collects all of them.

code : https://gist.github.com/Naseefabu/966d4469980977450f2746db73a43065


Solution

  • Your problems are bigger than that. Every write invokes undefined behavior (UB), because you pass async_write a buffer that is guaranteed to be stale:

    asio::buffer(std::string{message})
    

    `std::string{message}` is a temporary object. It is destroyed at the end of the full expression, long before the asynchronous write completes.

    The usual approach is to store the message in a queue owned by the websocket object and write from there, one message at a time:

    // Enqueue a copy of `message`. If the queue was idle (no write in flight),
    // kick off the write loop; otherwise the running loop will get to it.
    void send_message(std::string_view message) {
        bool const was_idle = m_outgoing_messages.empty();
        m_outgoing_messages.emplace_back(message);
        if (was_idle) {
            do_write_loop();
        }
    }
    

    And then

    private:
    // Messages waiting to be written; front() is the one currently in flight.
    // std::deque keeps element references stable on push_back, so the buffer
    // handed to async_write stays valid while new messages are queued.
    std::deque<std::string> m_outgoing_messages;

    // Write queued messages one at a time. Beast allows only one outstanding
    // async_write per stream, so each write is started from the previous
    // one's completion handler.
    // NOTE(review): on error the front message stays queued, so a later
    // send_message() will not restart the loop. Consider clearing the queue.
    void do_write_loop() {
        if (m_outgoing_messages.empty()) {
            return;
        }

        m_ws.async_write(asio::buffer(m_outgoing_messages.front()),
                         [this, self = this->shared_from_this()](error_code ec, size_t) {
                             if (ec) {
                                 SPDLOG_ERROR("Failed to send message(s)");
                                 SPDLOG_ERROR("    error_code.message: {}", ec.message());
                                 SPDLOG_ERROR("    error_code.what: {}", ec.what());
                             } else {
                                 m_outgoing_messages.pop_front();
                                 do_write_loop();
                             }
                         });
    }
    

    There are a lot of unrelated things to improve. You may also want to look at Asio's experimental channels (`asio::experimental::channel`) for handing messages between threads.

    At a minimum, consider removing the coupling to `io_context&` by accepting an executor instead.

    Here are some simplifications/generalizations that might help you get there:

    Live On Coliru

    #include <boost/asio.hpp>
    #include <boost/beast.hpp>
    #include <boost/beast/ssl.hpp>
    #include <deque>
    #include <fmt/format.h>
    #include <iostream>
    namespace spdlog {
        class logger {
          public:
            template <typename... Args>
            static constexpr void error(std::string_view level, std::string_view format, Args const&... args) {
                std::cerr << level << "\t" << fmt::format(fmt::runtime(format), args...) << std::endl;
            }
        };
        constexpr logger instance{};
    } // namespace spdlog
    
    // Logging macros routed to the local spdlog shim; the first argument tags severity.
    #define SPDLOG_ERROR(...) ::spdlog::instance.error("ERR", __VA_ARGS__)
    #define SPDLOG_WARN(...) ::spdlog::instance.error("WRN", __VA_ARGS__)
    #define SPDLOG_INFO(...) ::spdlog::instance.error("INF", __VA_ARGS__)
    // Logs the enclosing function signature and line (GCC/Clang __PRETTY_FUNCTION__).
    #define TRACE SPDLOG_INFO("{}:{}", __PRETTY_FUNCTION__, __LINE__)

    // Short aliases used throughout the demo.
    using tcp           = boost::asio::ip::tcp;
    namespace asio      = boost::asio;
    namespace beast     = boost::beast;
    namespace ssl       = boost::asio::ssl;
    namespace websocket = boost::beast::websocket;
    using beast::error_code;
    
    template <typename Derived> class Websocket : public std::enable_shared_from_this<Websocket<Derived>> {
      public:
        using Stream          = websocket::stream<ssl::stream<tcp::socket>>;
        using SSLConfigurator = std::function<void(Websocket*, Stream&)>;
    
        explicit Websocket(asio::any_io_executor ex, SSLConfigurator ssl_handshake_configurator = {})
            : m_resolver{ex}
            , m_ws{ex, m_ssl_context}
            , m_ssl_handshake_configurator(std::move(ssl_handshake_configurator)) {
            TRACE;
            m_ws.read_message_max(0);
            m_ws.auto_fragment(false);
    
            m_response.reserve(100'000'000);
        }
    
        virtual ~Websocket() {
            TRACE;
        }
    
        void start(std::string_view host, std::string_view port, std::string_view target) {
            TRACE;
            asio::post( //
                m_ws.get_executor(),
                [this, self = shared_from_this(), host = std::string{host}, port = std::string{port},
                 target = std::string{target}]() mutable { //
                    do_start(host, port, target);
                });
        }
    
        void stop() {
            TRACE;
            asio::post(m_ws.get_executor(), [this, self = shared_from_this()] { do_stop(); });
        }
    
        void send_message(std::string_view message) {
            return;
            TRACE;
            asio::post(m_ws.get_executor(),
                       [this, self = shared_from_this(), m = std::string{message}]() mutable {
                           do_send_message(std::move(m));
                       });
        }
    
      protected:
        virtual void send_subscribe_message() {
            TRACE; // send the subscribe message
        };
    
      private:
        using std::enable_shared_from_this<Websocket<Derived>>::shared_from_this;
    
        void do_start(std::string_view host, std::string_view port, std::string_view target) { // on the strand
            TRACE;
            m_host   = std::string{host};
            m_target = std::string{target};
            SPDLOG_INFO("Connecting to {}:{}", m_host, port);
    
            // Look up the domain name
            m_resolver.async_resolve( //
                m_host, port, [this, self = shared_from_this()](error_code ec, tcp::resolver::results_type res) {
                    if (!check_fail(ec, "async_start")) {
                        async_connect(std::move(res));
                    }
                });
        }
    
        void do_stop() {
            TRACE;
            if (m_stop_requested.exchange(true))
                return;
    
            if (m_ws.next_layer().next_layer().is_open()) {
                m_ws.async_close(websocket::close_code::normal, [self = shared_from_this()](error_code) {});
            }
        }
    
        std::deque<std::string> m_outgoing_messages;
    
        void do_send_message(std::string_view message) { // on the strand
            TRACE;
            m_outgoing_messages.emplace_back(message);
            if (m_outgoing_messages.size() == 1) {
                do_write_loop();
            }
        }
    
        void do_write_loop() { // on the strand
            TRACE;
            if (m_outgoing_messages.empty()) {
                return;
            }
    
            m_ws.async_write( //
                asio::buffer(m_outgoing_messages.front()),
                [this, self = shared_from_this()](error_code ec, size_t) {
                    if (!check_fail(ec, "do_write_loop")) {
                        m_outgoing_messages.pop_front();
                        do_write_loop();
                    }
                });
        }
    
        void async_connect(tcp::resolver::results_type result) {
            TRACE;
            if (!SSL_set_tlsext_host_name(m_ws.next_layer().native_handle(), m_host.c_str())) {
                SPDLOG_ERROR("Boost::Beast error: async connect");
                return;
            }
    
            asio::async_connect( //
                m_ws.next_layer().next_layer(), result,
                [this, self = shared_from_this()](error_code ec, tcp::endpoint) {
                    if (check_fail(ec, "async_connect")) {
                        if (!m_stop_requested) {
                            // handle the beast error
                        }
                    } else {
                        on_connected();
                    }
                });
        }
    
        void on_connected() {
            TRACE;
            m_ws.next_layer().async_handshake( //
                ssl::stream_base::client, [this, self = shared_from_this()](error_code ec) {
                    if (!check_fail(ec, "ssl handshake")) {
                        on_ssl_handshake();
                    }
                });
        }
    
        void on_ssl_handshake() {
            TRACE;
            if (m_ssl_handshake_configurator) {
                m_ssl_handshake_configurator(this, m_ws);
            }
    
            m_ws.async_handshake(m_host, m_target, [this, self = shared_from_this()](error_code ec) {
                if (!check_fail(ec, "ws handshake")) {
                    TRACE;
                    send_subscribe_message();
                    start_read(ec);
                }
            });
        }
    
        void start_read(error_code ec) {
            TRACE;
            if (check_fail(ec, "start_read")) {
                return stop();
            }
    
            m_buffer.prepare(512'000'000);
            m_ws.async_read(                                                    //
                m_buffer,                                                       //
                [this, self = shared_from_this()](error_code ec, size_t size) { //
                    on_read(ec, size);
                });
        }
    
        void on_read(error_code ec, [[maybe_unused]] size_t size) {
            TRACE;
            if (check_fail(ec, "on_read")) {
                SPDLOG_ERROR("    reason: {}", m_ws.reason().reason.data());
                SPDLOG_ERROR("    reason_code: {}", m_ws.reason().code);
                return stop();
            }
    
            m_response.clear();
    
            for (auto const& bytes : m_buffer.data()) {
                m_response.append(static_cast<char const*>(bytes.data()), bytes.size());
            }
            m_buffer.consume(m_buffer.size());
    
            static_cast<Derived*>(this)->handle_message(m_response);
    
            // restart
            start_read(error_code{});
        }
    
        bool check_fail(error_code ec, std::string_view task) const {
            if (m_stop_requested)
                return true;
    
            if (ec) {
                SPDLOG_ERROR("Failed to {}", task);
                SPDLOG_ERROR("    error_code.message: {}", ec.message());
                SPDLOG_ERROR("    error_code.what: {}", ec.what());
            }
            return ec.failed();
        }
    
        ssl::context        m_ssl_context{ssl::context::sslv23_client};
        tcp::resolver       m_resolver;
        Stream              m_ws;
        beast::multi_buffer m_buffer;
        std::string         m_response;
        std::string         m_host, m_target;
        std::atomic_bool    m_stop_requested{false};
    
        SSLConfigurator m_ssl_handshake_configurator;
    };
    
    /// Demo client: traces each received message (content currently ignored)
    /// and connects to a local test server via subscribe().
    class EntryPoint final : public Websocket<EntryPoint> {
        using Base = Websocket<EntryPoint>;

      public:
        EntryPoint(asio::any_io_executor executor) : Base(executor) {}

        /// Invoked by the CRTP base for every complete incoming message.
        void handle_message(std::string_view /*response*/) {
            TRACE; // handle the response
        }

        /// Start resolve/connect/handshake against localhost:1443.
        void subscribe() {
            TRACE;
            start("localhost", "1443", "/ws/private/orders/control");
        }
    };
    
    // Compile-time guarantees: the client can be neither copied nor move-assigned.
    static_assert(!std::is_copy_constructible_v<EntryPoint>, "EntryPoint must not be copy-constructible");
    static_assert(!std::is_copy_assignable_v<EntryPoint>, "EntryPoint must not be copy-assignable");
    static_assert(!std::is_move_assignable_v<EntryPoint>, "EntryPoint must not be move-assignable");

    using std::this_thread::sleep_for;
    using namespace std::chrono_literals;
    
    int main() {
        std::cout << "entrypoint" << std::endl;
        asio::thread_pool ioc(1);
    
        auto ws = std::make_shared<EntryPoint>(make_strand(ioc));
        ws->subscribe();
    
        for (sleep_for(1s);; sleep_for(100ms)) {
            ws->send_message("pseudo_request_popped_from_queue_pushed_by_another_thread");
        }
    
        ioc.join();
    }
    

    I have not been able to test this comprehensively, because I couldn't set up a wss server quickly enough.