// CRTP base class (question version) for a TLS WebSocket client built on
// Boost.Beast / Boost.Asio. Derived must provide handle_message(...), which
// on_read() calls with the text of each received message (a std::string).
//
// NOTE(review): most async handlers below capture raw `this` (only async_stop
// and send_message hold a shared_from_this() reference), so the object must
// outlive every pending operation.
template<typename Derived>
class Websocket: public std::enable_shared_from_this<Websocket<Derived>>
{
public:
// Beast WebSocket stream layered over a TLS stream layered over a TCP socket.
using InternalWSType = boost::beast::websocket::stream<
boost::asio::ssl::stream<boost::asio::ip::tcp::socket>>;
// Optional hook, invoked after the TLS handshake and before the WebSocket handshake.
using SSLHandshakeConfigurator = std::function<void(Websocket *, InternalWSType &)>;
// NOTE(review): unlike the two-argument constructor below, this one does not call
// read_message_max(0) / auto_fragment(false), so the stream keeps Beast's defaults.
explicit Websocket(boost::asio::io_context &io_context)
: m_io_context { io_context }
, m_resolver { m_io_context }
, m_ws { m_io_context, m_ssl_context }
{
m_response.reserve(100'000'000);
}
// Same as above, but additionally stores the SSL handshake hook and configures the
// stream: no limit on incoming message size, no automatic fragmentation.
Websocket(
boost::asio::io_context &io_context,
SSLHandshakeConfigurator ssl_handshake_configurator
)
: m_io_context { io_context }
, m_resolver { m_io_context }
, m_ws { m_io_context, m_ssl_context }
, m_ssl_handshake_configurator(std::move(ssl_handshake_configurator))
{
m_ws.read_message_max(0);
m_ws.auto_fragment(false);
m_response.reserve(100'000'000);
}
virtual ~Websocket() = default;
// Non-copyable and non-movable: pending handlers refer to this object.
Websocket(const Websocket &) = delete;
Websocket(Websocket &&) noexcept = delete;
auto operator=(const Websocket &) -> Websocket & = delete;
auto operator=(Websocket &&) noexcept -> Websocket & = delete;
// Stores host/target, then resolves host:port asynchronously; on success the chain
// continues with async_connect(). Resolve errors are only logged.
void async_start(std::string_view host, std::string_view port, std::string_view target)
{
m_host = std::string { host };
m_target = std::string { target };
// Look up the domain name
m_resolver.async_resolve(
m_host,
std::string { port },
[this](
boost::system::error_code error_code,
boost::asio::ip::tcp::resolver::results_type res
) {
if (error_code)
{
// handle
SPDLOG_ERROR("Failed to async_start");
SPDLOG_ERROR(" error_code.message: {}", error_code.message());
SPDLOG_ERROR(" error_code.what: {}", error_code.what());
}
else
{
async_connect(std::move(res));
}
}
);
}
// Marks the client as stopping and, if the TCP socket is open, starts a WebSocket
// close. The completion handler only keeps the object alive until the close finishes.
void async_stop()
{
m_stop_requested = true;
auto holder = this->shared_from_this();
if (m_ws.next_layer().next_layer().is_open())
{
m_ws.async_close(
boost::beast::websocket::close_code::normal,
[holder = std::move(holder)](const boost::system::error_code &) {}
);
}
}
// Starts an asynchronous write of `message`.
// NOTE(review): the buffer refers to a temporary std::string that is destroyed at the
// end of this full expression, long before the write completes (undefined behavior).
// Also, a new async_write is started on every call even if the previous one is still
// outstanding, which Beast does not allow.
void send_message(std::string_view message)
{
m_ws.async_write(
boost::beast::net::buffer(std::string { message }),
[self = this->shared_from_this()](boost::system::error_code error_code, size_t) {
if (error_code)
{
SPDLOG_ERROR("Failed to send message");
SPDLOG_ERROR(" error_code.message: {}", error_code.message());
SPDLOG_ERROR(" error_code.what: {}", error_code.what());
}
}
);
}
// Runs the ready handlers of the referenced io_context without blocking.
void poll()
{
m_io_context.poll();
}
// Runs the referenced io_context until it has no more work.
void run()
{
m_io_context.run();
}
protected:
// Returns the io_context passed to the constructor (the object is not owned).
[[nodiscard]] auto get_io_context() const -> boost::asio::io_context &
{
return m_io_context;
}
// Hook called after a successful WebSocket handshake; does nothing by default.
virtual void send_subscribe_message() {};
private:
// Sets the TLS SNI host name, then connects to one of the resolved endpoints; on
// success the chain continues with on_connected().
// NOTE(review): the `!m_stop_requested` error branch is empty, so connect failures
// are logged but not recovered from.
void async_connect(boost::asio::ip::tcp::resolver::results_type result)
{
if (!SSL_set_tlsext_host_name(m_ws.next_layer().native_handle(), m_host.c_str()))
{
SPDLOG_ERROR("Boost::Beast error: async connect");
return;
}
boost::asio::async_connect(
m_ws.next_layer().next_layer(),
result.begin(),
result.end(),
[this](boost::system::error_code error_code, boost::asio::ip::tcp::resolver::iterator) {
if (error_code)
{
SPDLOG_ERROR("Failed to async connect");
SPDLOG_ERROR(" error_code.message: {}", error_code.message());
SPDLOG_ERROR(" error_code.what: {}", error_code.what());
if (!m_stop_requested)
{
// handle the beast error
}
}
else
{
on_connected();
}
}
);
}
// Performs the client-side TLS handshake; on success continues with
// on_async_ssl_handshake().
// NOTE(review): handshake failures are silently ignored (both branches are empty).
void on_connected()
{
m_ws.next_layer().async_handshake(
boost::asio::ssl::stream_base::client,
[this](boost::system::error_code error_code) {
if (error_code)
{
if (!m_stop_requested)
{
}
}
else
{
on_async_ssl_handshake();
}
}
);
}
// Runs the optional configurator hook, then performs the WebSocket handshake. On
// success it calls send_subscribe_message(). In either case it calls start_read(),
// which stops the client when given an error.
void on_async_ssl_handshake()
{
if (m_ssl_handshake_configurator)
{
m_ssl_handshake_configurator(this, m_ws);
}
m_ws.async_handshake(m_host, m_target, [this](boost::system::error_code error_code) {
if (!error_code)
send_subscribe_message();
else
{
SPDLOG_ERROR("Failed to async ssl handshake");
SPDLOG_ERROR(" error_code.message: {}", error_code.message());
SPDLOG_ERROR(" error_code.what: {}", error_code.what());
}
start_read(error_code);
});
}
// Starts the next asynchronous read, or stops the client if `error_code` is set.
// NOTE(review): prepare(512'000'000) requests 512 MB of writable space before every
// read, which may allocate that much memory.
void start_read(boost::system::error_code error_code)
{
if (error_code)
{
if (!m_stop_requested)
{
SPDLOG_ERROR("Failed to start_read");
SPDLOG_ERROR(" error_code.message: {}", error_code.message());
SPDLOG_ERROR(" error_code.what: {}", error_code.what());
}
async_stop();
return;
}
m_buffer.prepare(512'000'000);
m_ws.async_read(m_buffer, [this](boost::system::error_code error_code, size_t size) {
on_read(error_code, size);
});
}
// Read completion: on error, logs it (including the close reason) and stops reading.
// Otherwise copies the buffered message into m_response, empties the buffer, passes
// the message to Derived::handle_message, and starts the next read.
void on_read(boost::system::error_code error_code, [[maybe_unused]] size_t size)
{
if (error_code)
{
if (!m_stop_requested)
{
SPDLOG_ERROR("Failed to on_read");
SPDLOG_ERROR(" error_code.message: {}", error_code.message());
SPDLOG_ERROR(" error_code.what: {}", error_code.what());
SPDLOG_ERROR(" reason: {}", m_ws.reason().reason.data());
SPDLOG_ERROR(" reason_code: {}", m_ws.reason().code);
}
return;
}
m_response.clear();
for (const auto &bytes : m_buffer.data())
{
m_response.append(static_cast<const char *>(bytes.data()), bytes.size());
}
m_buffer.consume(m_buffer.size());
static_cast<Derived *>(this)->handle_message(m_response);
// restart
start_read(boost::system::error_code {});
}
boost::asio::io_context &m_io_context;
// Declared before m_ws so that it is constructed first: m_ws is built with it.
boost::asio::ssl::context m_ssl_context { boost::asio::ssl::context::sslv23_client };
boost::asio::ip::tcp::resolver m_resolver;
InternalWSType m_ws;
boost::beast::multi_buffer m_buffer;
std::string m_response;
std::string m_host;
std::string m_target;
bool m_stop_requested {};
SSLHandshakeConfigurator m_ssl_handshake_configurator;
};
// Concrete client from the question; Websocket<EntryPoint>::on_read() delivers
// each received message to handle_message().
class EntryPoint final: public Websocket<EntryPoint>
{
public:
EntryPoint(boost::asio::io_context &io_context): Websocket<EntryPoint>(io_context)
{
}
// Called by the base class with each received message.
void handle_message(std::string_view response)
{
// handle the response
}
private:
// Starts the connection (host "url", port 443, target /ws/private/orders/control).
// NOTE(review): private and not called anywhere in this snippet.
void subscribe()
{
async_start("url", "443", "/ws/private/orders/control");
}
};
// Question's driver: creates the client, then loops forever sending a message and
// polling the io_context.
auto main() -> int32_t {
std::cout << "entrypoint" << std::endl;
boost::asio::io_context ioctx;
auto entry_point_ws = std::make_shared<EntryPoint>(ioctx);
// NOTE(review): async_start is declared with (host, port, target) and no defaults,
// so this zero-argument call does not compile as written.
entry_point_ws->async_start();
while (true) {
// NOTE(review): unused; the literal "pseudo-request" is sent below instead.
auto pseudo_request = "pseudo_request_popped_from_queue_pushed_by_another_thread";
// NOTE(review): each iteration starts another async_write, whether or not the
// previous one has completed.
entry_point_ws->send_message("pseudo-request");
ioctx.poll();
}
}
My understanding is that there can be no more than one active `async_write` call at a time on a WebSocket connection. However, in my main loop I'm continuously popping requests from a queue, calling `send_message` for each one, and calling `poll` after that.
My question is:
How should I structure my program to handle multiple write requests when only one `async_write` can be active at a time? I want to ensure that every message in the queue is sent, without losing any to overlapping `async_write` calls.
Any suggestions on how to queue these write operations, or how to restructure the program to handle this scenario, would be greatly appreciated.
My possible solution:
Run `io_context::run` in a separate thread. From the main thread, I push each request onto a lock-free queue, from which it is handed to `async_write`. The responses come back through another queue, an SPSC lock-free queue that receives all the responses.
code : https://gist.github.com/Naseefabu/966d4469980977450f2746db73a43065
Your problems are way bigger. Your entire write invokes UB because you, by definition, pass a stale buffer to `async_write`:
asio::buffer(std::string{message})
That's a temporary string object. It is destroyed at the end of the full expression, long before the asynchronous write completes, so the write operation is left referring to freed memory.
What you typically do is move the message into a queue local to the websocket:
// Enqueue a copy of `message`; the deque keeps the bytes alive until the write
// that uses them has completed.
void send_message(std::string_view message) {
    m_outgoing_messages.emplace_back(message);

    // A queue of exactly one element means no write was in flight before this
    // push, so the write loop has to be started; otherwise it is already running
    // and will reach this message on its own.
    bool const writer_idle = m_outgoing_messages.size() == 1;
    if (writer_idle) {
        do_write_loop();
    }
}
And then
private:
// Outgoing messages in send order; front() is the message currently being written.
std::deque<std::string> m_outgoing_messages;

// Writes the queued messages one at a time. Beast allows only one outstanding
// async_write per stream, so each next write is started from the completion
// handler of the previous one. A message stays in the deque, and therefore stays
// alive, until its write completes.
void do_write_loop() {
    if (m_outgoing_messages.empty()) {
        return;
    }
    m_ws.async_write(asio::buffer(m_outgoing_messages.front()),
        [this, self = this->shared_from_this()](error_code ec, size_t) {
            if (ec) {
                SPDLOG_ERROR("Failed to send message(s)");
                SPDLOG_ERROR(" error_code.message: {}", ec.message());
                SPDLOG_ERROR(" error_code.what: {}", ec.what());
                // The failed message stays at the front and the loop ends here.
            } else {
                m_outgoing_messages.pop_front();
                do_write_loop();
            }
        });
}
There are a lot of unrelated things to improve. Perhaps you'd want to look at Asio's channels. At the very least, consider removing the coupling to `io_context&` by using executors.
Here are some simplifications/generalizations that might help you get there:
#include <boost/asio.hpp>
#include <boost/beast.hpp>
#include <boost/beast/ssl.hpp>
#include <deque>
#include <fmt/format.h>
#include <iostream>
namespace spdlog {
class logger {
public:
template <typename... Args>
static constexpr void error(std::string_view level, std::string_view format, Args const&... args) {
std::cerr << level << "\t" << fmt::format(fmt::runtime(format), args...) << std::endl;
}
};
constexpr logger instance{};
} // namespace spdlog
#define SPDLOG_ERROR(...) spdlog::instance.error("ERR", __VA_ARGS__)
#define SPDLOG_WARN(...) spdlog::instance.error("WRN", __VA_ARGS__)
#define SPDLOG_INFO(...) spdlog::instance.error("INF", __VA_ARGS__)
#define TRACE SPDLOG_INFO("{}:{}", __PRETTY_FUNCTION__, __LINE__)
namespace beast = boost::beast;
namespace asio = boost::asio;
namespace ssl = boost::asio::ssl;
namespace websocket = boost::beast::websocket;
using tcp = boost::asio::ip::tcp;
using beast::error_code;
/// CRTP base for a TLS WebSocket client driven entirely by the executor passed to
/// the constructor.
///
/// - `Derived` must provide `handle_message(...)`; on_read() calls it with each
///   complete incoming message (as a `std::string`).
/// - The public entry points (`start`, `stop`, `send_message`) only `asio::post`
///   work to that executor, so they may be called from other threads. All other
///   member functions run inside handlers on that executor and touch members
///   without locking, so the executor must not run two of them concurrently
///   (e.g. a strand, as `main` uses).
/// - Every pending handler holds `shared_from_this()`, so instances must be owned
///   by a `std::shared_ptr` (e.g. created with `std::make_shared`).
template <typename Derived> class Websocket : public std::enable_shared_from_this<Websocket<Derived>> {
  public:
    using Stream = websocket::stream<ssl::stream<tcp::socket>>;
    using SSLConfigurator = std::function<void(Websocket*, Stream&)>;

    /// @param ex executor on which all I/O and handlers run.
    /// @param ssl_handshake_configurator optional hook, invoked after the TLS
    ///        handshake and before the WebSocket handshake.
    explicit Websocket(asio::any_io_executor ex, SSLConfigurator ssl_handshake_configurator = {})
        : m_resolver{ex}
        , m_ws{ex, m_ssl_context}
        , m_ssl_handshake_configurator(std::move(ssl_handshake_configurator)) {
        TRACE;
        m_ws.read_message_max(0); // 0: no limit on incoming message size
        m_ws.auto_fragment(false);
        m_response.reserve(100'000'000);
    }

    virtual ~Websocket() {
        TRACE;
    }

    /// Resolve, connect, and perform the TLS and WebSocket handshakes.
    /// The arguments are copied, so callers may pass temporaries.
    void start(std::string_view host, std::string_view port, std::string_view target) {
        TRACE;
        asio::post( //
            m_ws.get_executor(),
            [this, self = shared_from_this(), host = std::string{host}, port = std::string{port},
             target = std::string{target}]() mutable { //
                do_start(host, port, target);
            });
    }

    /// Request a graceful close. Calling it more than once has no further effect.
    void stop() {
        TRACE;
        asio::post(m_ws.get_executor(), [this, self = shared_from_this()] { do_stop(); });
    }

    /// Queue `message` for sending; the text is copied, so the caller's buffer may
    /// go away immediately. Messages are sent one at a time, in call order.
    /// Messages queued before the WebSocket handshake completes are sent once it
    /// does; messages queued after stop() are not sent.
    void send_message(std::string_view message) {
        TRACE;
        asio::post(m_ws.get_executor(),
                   [this, self = shared_from_this(), m = std::string{message}]() mutable {
                       do_send_message(std::move(m));
                   });
    }

  protected:
    /// Hook called right after a successful WebSocket handshake.
    virtual void send_subscribe_message() {
        TRACE; // send the subscribe message
    }

  private:
    using std::enable_shared_from_this<Websocket<Derived>>::shared_from_this;

    void do_start(std::string_view host, std::string_view port, std::string_view target) { // on the strand
        TRACE;
        m_host = std::string{host};
        m_target = std::string{target};
        SPDLOG_INFO("Connecting to {}:{}", m_host, port);
        // Look up the domain name
        m_resolver.async_resolve( //
            m_host, port, [this, self = shared_from_this()](error_code ec, tcp::resolver::results_type res) {
                if (!check_fail(ec, "async_start")) {
                    async_connect(std::move(res));
                }
            });
    }

    void do_stop() { // on the strand
        TRACE;
        if (m_stop_requested.exchange(true))
            return;
        m_ws_open = false; // no new writes after a stop
        if (m_ws.next_layer().next_layer().is_open()) {
            m_ws.async_close(websocket::close_code::normal, [self = shared_from_this()](error_code) {});
        }
    }

    void do_send_message(std::string message) { // on the strand
        TRACE;
        m_outgoing_messages.push_back(std::move(message));
        // Start writing only if the stream is usable and no write is in flight (the
        // in-flight message stays at the front of the queue until it completes).
        // Before the handshake, on_ssl_handshake() flushes the queue instead.
        if (m_ws_open && m_outgoing_messages.size() == 1) {
            do_write_loop();
        }
    }

    /// Writes queued messages one at a time: Beast allows only one outstanding
    /// async_write per stream, so each write is started from the previous one's
    /// completion handler.
    void do_write_loop() { // on the strand
        TRACE;
        if (m_outgoing_messages.empty()) {
            return;
        }
        // front() stays in the deque until this write completes; push_back on a
        // std::deque does not invalidate references to existing elements.
        m_ws.async_write( //
            asio::buffer(m_outgoing_messages.front()),
            [this, self = shared_from_this()](error_code ec, size_t) {
                if (check_fail(ec, "do_write_loop")) {
                    // The failed message remains queued, so no later write would ever
                    // start; close the connection rather than stalling silently.
                    return do_stop();
                }
                m_outgoing_messages.pop_front();
                do_write_loop();
            });
    }

    void async_connect(tcp::resolver::results_type result) {
        TRACE;
        // Set the TLS SNI host name before connecting.
        if (!SSL_set_tlsext_host_name(m_ws.next_layer().native_handle(), m_host.c_str())) {
            SPDLOG_ERROR("Boost::Beast error: async connect");
            return;
        }
        asio::async_connect( //
            m_ws.next_layer().next_layer(), result,
            [this, self = shared_from_this()](error_code ec, tcp::endpoint) {
                if (check_fail(ec, "async_connect")) {
                    if (!m_stop_requested) {
                        // handle the beast error
                    }
                } else {
                    on_connected();
                }
            });
    }

    void on_connected() {
        TRACE;
        m_ws.next_layer().async_handshake( //
            ssl::stream_base::client, [this, self = shared_from_this()](error_code ec) {
                if (!check_fail(ec, "ssl handshake")) {
                    on_ssl_handshake();
                }
            });
    }

    void on_ssl_handshake() {
        TRACE;
        if (m_ssl_handshake_configurator) {
            m_ssl_handshake_configurator(this, m_ws);
        }
        m_ws.async_handshake(m_host, m_target, [this, self = shared_from_this()](error_code ec) {
            if (!check_fail(ec, "ws handshake")) {
                TRACE;
                m_ws_open = true;
                send_subscribe_message();
                start_read(ec);
                do_write_loop(); // flush messages queued before the handshake completed
            }
        });
    }

    void start_read(error_code ec) {
        TRACE;
        if (check_fail(ec, "start_read")) {
            return stop();
        }
        m_buffer.prepare(512'000'000);
        m_ws.async_read( //
            m_buffer, //
            [this, self = shared_from_this()](error_code ec, size_t size) { //
                on_read(ec, size);
            });
    }

    /// Copies one complete message out of m_buffer, hands it to Derived, and
    /// starts the next read. On error, logs the close reason and stops.
    void on_read(error_code ec, [[maybe_unused]] size_t size) {
        TRACE;
        if (check_fail(ec, "on_read")) {
            SPDLOG_ERROR(" reason: {}", m_ws.reason().reason.data());
            SPDLOG_ERROR(" reason_code: {}", m_ws.reason().code);
            return stop();
        }
        m_response.clear();
        for (auto const& bytes : m_buffer.data()) {
            m_response.append(static_cast<char const*>(bytes.data()), bytes.size());
        }
        m_buffer.consume(m_buffer.size());
        static_cast<Derived*>(this)->handle_message(m_response);
        // restart
        start_read(error_code{});
    }

    /// Returns true when the operation should not continue: either a stop was
    /// requested (silently) or `ec` reports a failure (logged).
    bool check_fail(error_code ec, std::string_view task) const {
        if (m_stop_requested)
            return true;
        if (ec) {
            SPDLOG_ERROR("Failed to {}", task);
            SPDLOG_ERROR(" error_code.message: {}", ec.message());
            SPDLOG_ERROR(" error_code.what: {}", ec.what());
        }
        return ec.failed();
    }

    // Declared before m_ws: m_ws is constructed with a reference to it.
    ssl::context m_ssl_context{ssl::context::sslv23_client};
    tcp::resolver m_resolver;
    Stream m_ws;
    beast::multi_buffer m_buffer;
    std::string m_response;
    std::string m_host, m_target;
    // Pending outgoing messages; front() is the one being written, if any.
    // Only touched on the executor.
    std::deque<std::string> m_outgoing_messages;
    // Set after a successful WebSocket handshake, cleared by do_stop(); writes are
    // only started while it is set. Only touched on the executor.
    bool m_ws_open{false};
    std::atomic_bool m_stop_requested{false};
    SSLConfigurator m_ssl_handshake_configurator;
};
/// Concrete client: Websocket<EntryPoint>::on_read() delivers every received
/// message to handle_message().
class EntryPoint final : public Websocket<EntryPoint> {
  public:
    /// @param ex executor for all of this client's I/O (main passes a strand).
    /// Explicit, so an executor never converts implicitly into a client.
    explicit EntryPoint(asio::any_io_executor ex) : Websocket<EntryPoint>(std::move(ex)) {}

    /// Called once per complete incoming WebSocket message.
    void handle_message(std::string_view /*response*/) {
        TRACE;
        // handle the response
    }

    /// Connects to localhost:1443 and opens /ws/private/orders/control.
    void subscribe() {
        TRACE;
        start("localhost", "1443", "/ws/private/orders/control");
    }
};
// Pending handlers capture `this`, so a client must never be copied or
// move-assigned; these checks keep that property enforced at compile time.
static_assert(!std::is_copy_constructible_v<EntryPoint>);
static_assert(!std::is_copy_assignable_v<EntryPoint>);
static_assert(!std::is_move_assignable_v<EntryPoint>);

using std::this_thread::sleep_for;
using namespace std::chrono_literals;
int main() {
std::cout << "entrypoint" << std::endl;
asio::thread_pool ioc(1);
auto ws = std::make_shared<EntryPoint>(make_strand(ioc));
ws->subscribe();
for (sleep_for(1s);; sleep_for(100ms)) {
ws->send_message("pseudo_request_popped_from_queue_pushed_by_another_thread");
}
ioc.join();
}
I have not been able to test this comprehensively, because I couldn't set up a WSS server quickly enough.