Tags: c++, boost-asio, googletest

boost::asio async_write interleaving in googletest


I have been attempting to write a TCP server using boost::asio. The server will be sending data to any number of connected clients, and I have been writing a few tests for it using Google Test.

I am having some trouble with the test below. It is designed to test for interleaving: it fills 5 buffers, each with a single repeated number, and when reading back I use a map to count how many distinct values appear in each received buffer.

I understand that async_write calls async_write_some under the hood, so to avoid interleaving, strands are recommended, as they guarantee sequential execution of operations on the same strand.

That doesn't seem to be happening, and I am not sure why. Any help is appreciated.

#include <boost/asio.hpp>
#include <gtest/gtest.h>
using tcp = boost::asio::ip::tcp;

// tcp_session.h
using boost_tcp = boost::asio::ip::tcp;

class tcp_session : public std::enable_shared_from_this<tcp_session> {
  public:
    explicit tcp_session(boost_tcp::socket socket);
    ~tcp_session();

    void   async_write(char const* data, size_t size);
    void   async_write(boost::asio::const_buffer buff);
    size_t write(char const* data, size_t size);
    size_t write(boost::asio::const_buffer buff);

    void close();

  private:
    boost_tcp::socket                                     m_socket;
    std::mutex                                            m_socket_mutex{};
    boost::asio::strand<boost_tcp::socket::executor_type> m_strand;

    void handle_async_write(boost::system::error_code const& err, size_t bytes_transferred);
};

// tcp_svr.h
class tcp_svr {
  public:
    explicit tcp_svr(int16_t port);
    ~tcp_svr();

    void async_write(char const* data, size_t size);
    void async_write(boost::asio::const_buffer buff);
    void run();
    void stop();

    size_t              get_session_count();
    boost_tcp::endpoint get_local_endpoint() { return m_acceptor.local_endpoint(); }

  private:
    using session_ptr = std::shared_ptr<tcp_session>;
    void                     do_accept();
    std::atomic_bool         m_stopped;
    std::vector<session_ptr> m_sessions{};
    std::mutex               m_sessions_mutex{};
    boost::asio::io_service  m_io_service{};
    boost_tcp::acceptor      m_acceptor;
};
// tcp_svr.cpp

tcp_svr::tcp_svr(int16_t port) : m_acceptor(m_io_service, tcp::endpoint(tcp::v4(), port)) { do_accept(); }

tcp_svr::~tcp_svr() { stop(); }

void tcp_svr::async_write(boost::asio::const_buffer buff) {
    std::lock_guard lock(m_sessions_mutex);
    for (auto const& client : m_sessions) {
        client->async_write(buff);
    }
}

void tcp_svr::run() { m_io_service.run(); }

void tcp_svr::stop() {
    if (m_stopped)
        return;

    m_stopped   = true;
    auto thread = std::thread{[this]() { m_acceptor.cancel(); }};
    thread.join();
    {
        std::lock_guard lock(m_sessions_mutex);
        for (auto const& client : m_sessions) {
            client->close();
        }

        m_sessions.clear();
    }

    m_io_service.stop();
}

size_t tcp_svr::get_session_count() {
    std::lock_guard lock(m_sessions_mutex);
    return m_sessions.size();
}

void tcp_svr::do_accept() {
    m_acceptor.async_accept(boost::asio::make_strand(m_io_service),
                            [this](boost::system::error_code err, boost_tcp::socket socket) {
                                if (m_stopped) {
                                    return;
                                }

                                if (!err) {
                                    std::lock_guard lock(m_sessions_mutex);
                                    m_sessions.push_back(std::make_shared<tcp_session>(std::move(socket)));
                                }

                                do_accept();
                            });
}

// tcp_session.cpp
tcp_session::tcp_session(boost_tcp::socket socket)
    : m_socket(std::move(socket))
    , m_strand(socket.get_executor()) {}

tcp_session::~tcp_session() { close(); }

void tcp_session::async_write(boost::asio::const_buffer buff) {
    boost::asio::post(m_strand, [this, buff = buff]() {
        std::lock_guard lock(m_socket_mutex);
        boost::asio::async_write(m_socket, buff,
                                 [this, self = shared_from_this()](const boost::system::error_code& ec,
                                                                   std::size_t bytes_transferred) {
                                     handle_async_write(ec, bytes_transferred);
                                 });
    });
}

void tcp_session::async_write(char const* data, size_t size) { async_write(boost::asio::buffer(data, size)); }

void tcp_session::close() {
    std::lock_guard lock(m_socket_mutex);
    if (m_socket.is_open()) {
        m_socket.close();
    }
}

void tcp_session::handle_async_write(boost::system::error_code const& err,
                                     [[maybe_unused]] size_t          bytes_transferred) {
    if (err) {
        // log
    }
}

class TcpFixture : public testing::Test {
  protected:
    boost::asio::io_service  m_client_io_service{};
    std::unique_ptr<tcp_svr> m_server;
    std::thread              m_client_io_service_thread;
    std::thread              m_server_thread;

    void SetUp() override {
        int16_t port               = 1234;
        m_server                   = std::make_unique<tcp_svr>(port);
        m_client_io_service_thread = std::thread([&]() { m_client_io_service.run(); });
        m_server_thread            = std::thread([&]() { m_server->run(); });
    }

    void TearDown() override {
        m_client_io_service.stop();
        m_client_io_service_thread.join();
        m_server->stop();
        m_server_thread.join();
    }
};

TEST_F(TcpFixture, NoInterleavingAsyncWrite) {
    tcp::socket client_socket(m_client_io_service);
    client_socket.connect(m_server->get_local_endpoint());

    static constexpr int    kMessageCount = 5;
    static constexpr size_t kDataSize     = 65536;

    std::array<std::array<int, kDataSize>, kMessageCount> data{};
    for (int i = 0; i < kMessageCount; i++) {
        std::fill(data.at(i).begin(), data.at(i).end(), i + 1);
    }

    while (m_server->get_session_count() != 1) {
    }

    for (auto const& arr : data) {
        m_server->async_write(boost::asio::buffer(arr));
    }

    for (int i = 0; i < kMessageCount; i++) {
        std::array<int, kDataSize> read_buffer{};
        boost::asio::read(client_socket, boost::asio::buffer(read_buffer));

        std::map<int, int> value_count{};
        for (int val : read_buffer) {
            value_count[val]++;
        }

        EXPECT_EQ(value_count.size(), 1);
    }
}

This prints, e.g.:

Running main() from ./googletest/src/gtest_main.cc
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from TcpFixture
[ RUN      ] TcpFixture.NoInterleavingAsyncWrite
/home/sehe/Projects/stackoverflow/test.cpp:190: Failure
Expected equality of these values:
  value_count.size()
    Which is: 4
  1
/home/sehe/Projects/stackoverflow/test.cpp:190: Failure
Expected equality of these values:
  value_count.size()
    Which is: 4
  1
/home/sehe/Projects/stackoverflow/test.cpp:190: Failure
Expected equality of these values:
  value_count.size()
    Which is: 4
  1
/home/sehe/Projects/stackoverflow/test.cpp:190: Failure
Expected equality of these values:
  value_count.size()
    Which is: 4
  1
/home/sehe/Projects/stackoverflow/test.cpp:190: Failure
Expected equality of these values:
  value_count.size()
    Which is: 4
  1
[  FAILED  ] TcpFixture.NoInterleavingAsyncWrite (102 ms)
[----------] 1 test from TcpFixture (102 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (102 ms total)
[  PASSED  ] 0 tests.
[  FAILED  ] 1 test, listed below:
[  FAILED  ] TcpFixture.NoInterleavingAsyncWrite

 1 FAILED TEST

Solution

  • You are confused about strands.

    You don't need both a mutex and a strand. And since you already associate a strand executor with the accepted socket (via make_strand in do_accept), you don't need another strand at all: just use m_socket.get_executor().

    Also, strands do prevent concurrent execution of (intermediate) handlers. They do NOT prevent overlapping async operations. It's just that they will not be concurrently initiated. ¯\_(ツ)_/¯

    Besides, your whole operation is single-threaded, so the strand is pretty redundant.

    If you want to avoid interleaving of composed async write operations, you have to serialize the operations yourself, e.g. by queueing them (see the sketch below).
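
    To illustrate, here is a stripped-down sketch of that queueing idea in isolation (lifetime handling omitted). The `writer` type and its members are made up for this sketch; the full example below applies the same pattern inside `Session`:

    #include <boost/asio.hpp>
    #include <deque>
    #include <string>

    struct writer {
        boost::asio::ip::tcp::socket sock;
        std::deque<std::string>      outbox_;

        void send(std::string msg) {
            // all queue manipulation happens on the socket's executor
            post(sock.get_executor(), [this, msg = std::move(msg)]() mutable {
                outbox_.push_back(std::move(msg));
                if (outbox_.size() == 1) // no write currently in flight
                    do_write_loop();
            });
        }

        void do_write_loop() {
            if (outbox_.empty())
                return;
            // only one composed async_write is ever outstanding; the next one
            // is started from the completion handler of the previous one
            async_write(sock, boost::asio::buffer(outbox_.front()),
                        [this](boost::system::error_code ec, std::size_t) {
                            if (!ec) {
                                outbox_.pop_front();
                                do_write_loop();
                            }
                        });
        }
    };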

    Besides, there are session lifetime issues (sessions will NEVER go away, because your svr keeps shared ownership of them). More seriously, the operations posted to the strand fail to capture shared_from_this.

    You also didn't handle any errors (including EOF) in the client.

    I've taken the liberty of addressing all of the above, plus some unrelated improvements. Here's the test case, fully working as expected:

    Live On Compiler Explorer

    #include <boost/asio.hpp>
    #include <deque>
    #include <list>
    
    namespace Tcp {
        namespace asio = boost::asio;
        using tcp      = asio::ip::tcp;
        using boost::system::error_code;
    
        struct Session : std::enable_shared_from_this<Session> {
            Session(tcp::socket socket) : m_socket(std::move(socket)) {}
    
            void start() {
                // keep a running read so we can detect errors and keep the session alive without the server
                // owning the lifetime
                post(m_socket.get_executor(), std::bind(&Session::do_read_loop, shared_from_this()));
            }
    
            void send(asio::const_buffer buff) {
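                // enqueue on the socket's (strand) executor; start the write loop
                // only if it was idle, i.e. this is the only queued buffer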
                post(m_socket.get_executor(), [this, self = shared_from_this(), buff]() {
                    outbox_.push_back(std::move(buff));
                    if (outbox_.size() == 1)
                        do_write_loop();
                });
            }
    
            void close() {
                post(m_socket.get_executor(), std::bind(&Session::do_close, shared_from_this()));
            }
    
          private:
            tcp::socket                    m_socket;
            std::array<char, 16>           incoming_;
            std::deque<asio::const_buffer> outbox_;
    
            void do_read_loop() { // on strand
                m_socket.async_read_some(asio::buffer(incoming_),
                                         [this, self = shared_from_this()](error_code ec, size_t) {
                                             if (!ec)
                                                 do_read_loop();
                                         });
            }
    
            void do_write_loop() { // on strand
                if (outbox_.empty())
                    return;
    
                async_write(m_socket, outbox_.front(), //
                            [this, self = shared_from_this()](error_code ec, [[maybe_unused]] size_t n) {
                                if (!ec) {
                                    outbox_.pop_front();
                                    do_write_loop();
                                }
                            });
            }
    
            void do_close() { // on strand
                if (m_socket.is_open())
                    m_socket.close();
            }
        };
    
        class server {
          public:
            server(asio::any_io_executor ex, uint16_t port) //
                : ex_(ex)
                , acc_(make_strand(ex), {{}, port}) {
                do_accept();
            }
    
            ~server() { stop(); }
    
            void send(asio::const_buffer buff) {
                for (auto& handle : sessions_) {
                    if (auto sess = handle.lock())
                        sess->send(buff);
                }
            }
    
            void stop() {
                post(acc_.get_executor(), [this] {
                    for (auto& handle : sessions_) {
                        if (auto sess = handle.lock())
                            sess->close();
                    }
    
                    sessions_.clear();
                    acc_.cancel();
                });
            }
    
            size_t get_session_count() {
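                // run the query on the acceptor's executor and block on a future
                // for the result, so sessions_ is only ever touched from that executor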
                return asio::post(              //
                           acc_.get_executor(), //
                           asio::use_future([this] { return sessions_.size(); }))
                    .get();
            }
            tcp::endpoint local_endpoint() { return acc_.local_endpoint(); }
    
          private:
            using session_ptr = std::shared_ptr<Session>;
            using handle      = std::weak_ptr<Session>;
            void do_accept() {
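                // every accepted socket gets its own strand, so all per-session
                // handlers are serialized on that strand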
                acc_.async_accept(asio::make_strand(ex_), [this](error_code err, tcp::socket socket) {
                    if (!err) {
                        do_accept();
    
                        auto sess = std::make_shared<Session>(std::move(socket));
                        sessions_.push_back(sess);
                        sess->start();
                    }
                });
            }
    
            asio::any_io_executor ex_;
            tcp::acceptor         acc_;
            std::list<handle>     sessions_{};
        };
    
    } // namespace Tcp
    
    #include <gtest/gtest.h>
    
    class TcpFixture : public testing::Test {
      protected:
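        // a single 1-thread pool drives all server I/O and provides the
        // client socket's executor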
        boost::asio::thread_pool ioc_{1};
        uint16_t                 port_ = 1234;
        Tcp::server              svr_{ioc_.get_executor(), port_};
    
        void SetUp() override {}
    
        void TearDown() override {
            svr_.stop();
            ioc_.join(); // optionally wait
        }
    };
    
    TEST_F(TcpFixture, NoInterleavingAsyncWrite) {
        boost::asio::ip::tcp::socket client_socket(ioc_);
        client_socket.connect(svr_.local_endpoint());
    
        constexpr size_t LEN = 65536;
        using Message        = std::array<int, LEN>;
        std::vector<Message> data(5);
    
        for (int i = 0; auto& msg : data)
            std::fill(msg.begin(), msg.end(), ++i);
    
        for (auto const& msg : data)
            svr_.send(boost::asio::buffer(msg));
    
        size_t total_received = 0;
        for ([[maybe_unused]] auto&& _ : data) {
            Message msg;
    
            boost::system::error_code ec;
            auto n = read(client_socket, boost::asio::buffer(msg), ec);
    
            if (ec == boost::asio::error::eof) {
                if (n == 0) // error condition without (partial) read success
                    break;
            } else {
                EXPECT_EQ(n, sizeof(Message));
                EXPECT_FALSE(ec.failed());
            }
    
            if (n) {
                ++total_received;
                EXPECT_EQ(n, sizeof(Message));
                EXPECT_EQ(msg.size(), std::count(msg.begin(), msg.end(), msg.front()));
            }
        }
    
        EXPECT_EQ(data.size(), total_received);
    }
    

    Showing the output

    Running main() from /opt/compiler-explorer/libs/googletest/release-1.10.0/googletest/src/gtest_main.cc
    [==========] Running 1 test from 1 test suite.
    [----------] Global test environment set-up.
    [----------] 1 test from TcpFixture
    [ RUN      ] TcpFixture.NoInterleavingAsyncWrite
    [       OK ] TcpFixture.NoInterleavingAsyncWrite (2 ms)
    [----------] 1 test from TcpFixture (2 ms total)
    
    [----------] Global test environment tear-down
    [==========] 1 test from 1 test suite ran. (2 ms total)
    [  PASSED  ] 1 test.
    

    Local Demo:

    (screenshot of the local demo run omitted)