c++ | boost-asio

TCP proxy - Closing the connection with server on receiving the client disconnect


In my proxy code, when a disconnection signal is received from the client side (with the proxy acting as the server), it doesn't properly terminate the associated session on the server side (where the proxy acts as a client). When the client reconnects, it establishes a new connection to the server through the proxy acting as a client. However, if the client repeatedly restarts due to a software bug, there is a potential concern of exhausting the server's maximum file descriptors, given that the proxy keeps initiating new connections as a client. I've implemented a function called close_client_connection that's designed to terminate the connection with the server when a socket disconnect event is detected from the client. Is this the correct approach, or do you have any alternative recommendations?

How should I handle this scenario in the code below?
#include <ctime>
#include "proxy.h"
#include "parser.h"
#include "send.h"
#include "defines.h"
#include "log.h"

std::shared_ptr<OAmipTimerLocation> location_timer = NULL;

bridge::bridge(asio::io_context &io_context,
                         AppConfig *appconfig)
    : io_context_(io_context),
      acceptor_(io_context, asio::ip::tcp::endpoint(asio::ip::address::from_string(appconfig->serverConfig.ip), appconfig->serverConfig.port)),
      resolver_(io_context_),
      thread_pool_(),
      _stopped(false),
      app_config(appconfig)
{
}

bridge::~bridge()
{
    // Stop and join the io_context to prevent memory leaks
    io_context_.stop();
    for (auto &thread : thread_pool_)
    {
        thread.join();
    }
}

void bridge::stop(){
    acceptor_.cancel();
}

void bridge::start_server()
{
    try
    {

        socket_t socket = std::make_shared<boost::asio::ip::tcp::socket>(io_context_);
        acceptor_.async_accept(*socket, [this, socket](auto ec)
                               { 
                                    if(ec)
                                        log_error("Error accepting connection from : %s",ec.message().c_str());
                                    else
                                        this->handle_accept(ec, socket); });
    }
    catch (const std::exception &e)
    {
        log_error("Start_server, Exception caught: %s", e.what());
    }
}

void bridge::handle_accept(error_code const &error, socket_t upstream_socket)
{
    if (!error)
    {
        socket_t downstream_socket = std::make_shared<boost::asio::ip::tcp::socket>(io_context_);
        auto handler = std::make_shared<msg_handler>(io_context_, upstream_socket, downstream_socket, app_config);
        handler->start_async_connect();
    }
    else
    {
        log_error("Error! Error code = %d \t  Message : %s", error.value(), error.message().c_str());
    }
    if (!_stopped)
    {
        log_error("Connection closed by the client");
        start_server();
    }
    else
    {
        acceptor_.close();
    }
}

msg_handler::msg_handler(asio::io_context &io_context,
                         socket_t upstream_socket,
                         socket_t downstream_socket,
                         AppConfig *appConfig)
    : io_context_(io_context),
      server_socket_(upstream_socket),
      client_socket_(downstream_socket),
      resolver_(io_context_),
      app_config(appConfig)
{
}

msg_handler::~msg_handler()
{
    // Explicitly clear the buffer to release the allocated memory
    in_packet_.consume(in_packet_.size());
}
std::string msg_handler::get_server()
{
    std::string server_s = app_config->clientConfig.ip + ":" + std::to_string(app_config->clientConfig.port);
    return server_s;
}

void msg_handler::start_async_connect()
{
    auto upstream_host = app_config->clientConfig.ip;
    auto upstream_port = app_config->clientConfig.port;
    asio::ip::tcp::resolver::query query(upstream_host, std::to_string(upstream_port), asio::ip::resolver_query_base::numeric_service);
    asio::ip::tcp::resolver::iterator client_endpoint_iterator_ = resolver_.resolve(query);
    bool connected = false;
    int server_connect_retry_timeout = app_config->clientConfig.timeout;

    asio::async_connect(*client_socket_, client_endpoint_iterator_,
                        [me = shared_from_this(), server_connect_retry_timeout](const boost::system::error_code &ec, const asio::ip::tcp::resolver::iterator)
                        {
                            if (!ec)
                            {
                                log_info("Connected to the server %s",me->get_server().c_str());
                                // me->connected = true;
                                me->handle_server_connection(ec);
                            }
                            else
                            {

                                log_error("Async connect failed for server %s, Error:%s", me->get_server().c_str(), ec.message().c_str());
                                // me->connected = false;
                                me->stop();
                                log_info("Retrying in %d seconds", server_connect_retry_timeout);
                                std::this_thread::sleep_for(std::chrono::seconds(server_connect_retry_timeout));
                                me->start_async_connect();
                            }
                        });
}

void msg_handler::stop()
{
    if (server_socket_->is_open())
        server_socket_->cancel(); // closing would be a race condition
    if (client_socket_->is_open())
        client_socket_->cancel();
};
void msg_handler::close_client_connection()
{
    if (client_socket_->is_open())
    {
        client_socket_->cancel();
        client_socket_->close();
    }
}
void msg_handler::stop_timer_thread()
{
    if (location_timer)
    {
        location_timer->setPeriod(0);
        location_timer = NULL;
    }
}

void msg_handler::handle_server_connection(error_code const &error)
{
    if (!error)
    {
        read_cmd_from_client();
        read_cmd_from_server();
    }
    else
    {
        log_error("Connect Error! Error code = %d . Message : %s", error.value(), error.message().c_str());
        stop();
    }
}

asio::ip::tcp::endpoint msg_handler::remote_endpoint(boost::system::error_code &ec)
{
    return server_socket_->remote_endpoint(ec);
}

asio::ip::tcp::endpoint msg_handler::c_remote_endpoint(boost::system::error_code &ec)
{
    return client_socket_->remote_endpoint(ec);
}

void msg_handler::read_cmd_from_server()
{
    boost::system::error_code error_code;
    asio::ip::tcp::endpoint endpoint = c_remote_endpoint(error_code);
    auto upstream_host = app_config->clientConfig.ip;
    auto upstream_port = app_config->clientConfig.port;
    if (!error_code)
    {
        auto remote_ip = endpoint.address().to_string();
        auto port = endpoint.port();
        std::string key = remote_ip + ":" + std::to_string(port);
        log_debug("Read cmd from %s", key.c_str());
    }
    else
    {
        log_error("Error in remote endpoint server: %s", error_code.message().c_str());
        stop();
        start_async_connect();
        return;
    }
    asio::async_read_until(*client_socket_,
                           c_in_packet_,
                           '\n',
                           [me = shared_from_this()](boost::system::error_code const &ec, std::size_t bytes_xfer)
                           {
                               if (ec == asio::error::eof)
                               {
                                   log_error("Connection closed by server %s",me->get_server().c_str() );
                                   me->stop();
                                   me->stop_timer_thread();
                                   me->start_async_connect();
                                   return; // No need to read further; the connection is closed.
                               }
                               else if (ec)
                               {
                                   log_error("Error in async_read_until: %s",ec.message().c_str());
                                   return;
                               }
                               else
                               {
                                   me->read_cmd_from_server_done(ec, bytes_xfer);
                               }
                           });
}

void msg_handler::read_cmd_from_server_done(boost::system::error_code const &ec, std::size_t bytes_transferred)
{
    if (ec == asio::error::eof)
    {

        log_error("Connection closed by Server");
        return; // No need to read further; the connection is closed.
    }
    else if (ec)
    {
        log_error("Error accepting packet from the server, Error: %s", ec.message().c_str());
        return;
    }

    log_debug("Reading socket fd: %d,Sending socket fd: %d",client_socket_->is_open(),server_socket_->is_open());
    std::string command(buffers_begin(c_in_packet_.data()), buffers_begin(c_in_packet_.data()) + bytes_transferred);
    c_in_packet_.consume(bytes_transferred);
    log_info("From server: <--- %s", command.c_str());
    CmdSender sender(server_socket_);
    log_info("To Client: ---> %s", recv_cmd.c_str());
    sender.send_cmd(recv_cmd);
    read_cmd_from_server();
}

void msg_handler::read_cmd_from_client()
{
    boost::system::error_code error_code;
    asio::ip::tcp::endpoint endpoint = remote_endpoint(error_code);
    std::string key = "";
    if (!error_code)
    {
        auto remote_ip = endpoint.address().to_string();
        auto port = endpoint.port();
        key = remote_ip + ":" + std::to_string(port);
        log_debug("Read cmd from %s", key.c_str());
    }
    else
    {
        if (server_socket_ && client_socket_)
            log_debug("Reading socket fd: %d,Sending socket fd:%d", server_socket_->is_open(), client_socket_->is_open());
        log_error("Error in remote endpoint client: %s", error_code.message().c_str());
        return;
    }
    asio::async_read_until(*server_socket_,
                           in_packet_,
                           '\n',
                           [me = shared_from_this(), key](boost::system::error_code const &ec, std::size_t bytes_xfer)
                           {
                               if (ec == asio::error::eof)
                               {
                                   log_error("Connection closed by client %s", key.c_str());
                                   me->close_client_connection();
                                   me->stop_timer_thread();
                                   return; // No need to read further; the connection is closed.
                               }
                               else if (ec)
                               {
                                   log_error("Error in async_read_until: %s", ec.message().c_str());
                                   return;
                               }
                               else
                               {
                                   me->read_cmd_from_client_done(ec, bytes_xfer);
                               }
                           });
}


void msg_handler::read_cmd_from_client_done(boost::system::error_code const &ec, std::size_t bytes_transferred)
{
    log_debug("Reading socket fd: %d,Sending socket fd:%d", server_socket_->is_open(), client_socket_->is_open());
    if (ec == asio::error::eof)
    {

        log_error("Connection closed by client");
        stop_timer_thread();
        close_client_connection();
        return; // No need to read further; the connection is closed.
    }
    else if (ec)
    {

        log_error("Error accepting packet from the client: %s", ec.message().c_str());
        return;
    }
    std::string command(buffers_begin(in_packet_.data()), buffers_begin(in_packet_.data()) + bytes_transferred);
    in_packet_.consume(bytes_transferred);
    log_info("From Client: <--- %s", command.c_str());

    Parser parser(app_config);
    bool W_cmd = parser.check_for_cmd(command, 'W');

    CmdSender sender(client_socket_);
    log_info("To Server: ---> %s", recv_cmd.c_str());
    sender.send_cmd(recv_cmd);
    if (recv_cmd == "period"){
        int period = app_config->serverConfig.default_location_period;
        if (!location_timer)
            location_timer = std::make_shared<OAmipTimerLocation>(
                            io_context_, server_socket_, app_config);
        location_timer->setPeriod(period);
    }
    read_cmd_from_client();
}

int main()
{
    try
    {
        int thread_count = 1;
        int port = 5002;
        int s_port = 5004;
        std::vector<std::thread> thread_pool;
        std::string ip_addr = "127.0.0.1";
        std::string s_ip_addr = "127.0.0.1";

        asio::io_context io_context;
        AppConfig appConfig;
        appConfig.read_config(config_file, &appConfig); // config_file: configuration path, defined elsewhere

        ProxyServer proxy_server(io_context, &appConfig);
        proxy_server.start_server();

        int num_threads = 4; // Adjust the number of threads based on your requirements
        for (int i = 0; i < num_threads; ++i)
        {
            thread_pool.emplace_back([&]()
                                     { io_context.run(); });
        }
        for (auto &thread : thread_pool)
        {
            thread.join();
        }
    }
    catch (const std::exception &e)
    {
        pac_error("Exception caught: {%s}", e.what());
    }

    return 0;
}

Solution

  • Sorry, the code is incomplete. From what is shown, my read is that the is_open() calls are a race condition and that calling cancel() from another thread is a data race. There isn't even enough code to see which of the sockets matches upstream_socket in your msg_handler. Doing a blocking DNS resolve inside "start_async_connect" makes the "async" ... a lie and is asking for stability issues. I don't understand why the sockets are pointers if the entire msg_handler is shared_from_this ...

    These were just the things that caught my eye when looking for the place where you handle client EOF/disconnect.

    And having found that,

    • why is that logic duplicated? It's in the lambda, but repeated in read_cmd_from_client_done

    • you have the explicit

       if (ec == asio::error::eof)
       {
      
           log_error("Connection closed by client");
           close_client_connection();
           return; // No need to read further; the connection is closed.
       }
      

      It would seem that your bug is fixed when you add something like close_server_connection()? I mean, that's the missing behaviour, right? (A minimal sketch of such a function follows after this list.)

    • Note that if the server might still send things back (unclear because it depends on protocol) you might want to make sure that is completed.

    • This completion handler as is currently discards data read with partial success (error == eof). That's ... probably also a bug.
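
    For illustration only, here is a minimal sketch of what such a close_server_connection() companion could look like, written in the question's own style. The member names server_socket_/client_socket_ are taken from the question; the error_code overloads are used so an already-closed socket doesn't throw:

    // Hypothetical companion to close_client_connection(): tear down the
    // accepting side of the bridge as well, so no descriptor is leaked when
    // the peer goes away.
    void msg_handler::close_server_connection()
    {
        boost::system::error_code ignored;
        if (server_socket_->is_open())
        {
            // shutdown() first so the peer sees a clean EOF, then close()
            // to actually release the file descriptor
            server_socket_->shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignored);
            server_socket_->close(ignored);
        }
    }

    That said, the cleaner structure in the UPDATE below avoids the need for these ad-hoc close helpers altogether.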

    UPDATE

    I minimized and cleaned up the code from the previous answer that you based your code on. A number of the things I mentioned above were merely copy/pasted from that code - which ... wasn't so good¹.

    So I got rid of the unnecessary pointers and unscrupulous cancel() invocations. Below, I've even added DNS resolution for the upstream - but asynchronously, as required²:

    void start(Settings::endpoint const& upstream) {
        res.async_resolve(
            upstream.query(),
            [=, this, self = shared_from_this()](error_code ec, tcp::resolver::results_type eps) {
                if (ec)
                    std::cout << "Could not resolve upstream (" << upstream << "): " << ec.message()
                              << std::endl;
                else
                    asio::async_connect(
                        up, eps, [this, self = shared_from_this()](error_code ec, tcp::endpoint /*ep*/) {
                            if (!ec) {
                                up_down.start(self);
                                down_up.start(self);
                            } else {
                                std::cerr << "Connect: " << ec.message() << std::endl;
                            };
                        });
            });
    }
    

    Next up, I removed code duplication between the up/down half-duplex transfer loops with a Pump type:

    struct Pump {
        using Ref = std::shared_ptr<void const>;
        Pump(socket_t& source, socket_t& dest) : src(source), dst(dest) {}
    
        void pump(Ref self) {
            src.async_read_some(
                asio::buffer(buf), [=, this](error_code ec, size_t n) {
                    bool const eof = (ec == asio::error::eof);
                    // std::cout << "Read(" << n << ") " << *this << ": " << ec.message() << std::endl;
                    if (!ec || eof) {
                        async_write(dst, asio::buffer(buf, n), [=, this](error_code ec, size_t) {
                            if (eof) {
                                dst.shutdown(socket_t::shutdown_send, ec);
                                std::cout << "Shutdown " << *this << ": " << ec.message() << std::endl;
                            } else if (!ec)
                                pump(self);
                        });
                    } else {
                        dst.shutdown(socket_t::shutdown_send, ec);
                        std::cout << "Shutdown " << *this << ": " << ec.message() << std::endl;
                    }
                });
        }
    
        socket_t &             src, &dst;
        tcp::endpoint          sep, dep;
        std::array<char, 4096> buf;
    };
    

    Note the logic

    • that performs a shutdown_send on the opposite end of the channel if there is EOF or an error
    • Note especially how the shutdown is only performed after writing any data received with the EOF condition
    • Perhaps superfluously, note that the sockets are closed by shared_from_this going out of scope, destructing the Bridge (see the small sketch of that keep-alive pattern right below). I've added minor console logging to make this visible in the demo
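
    To make that last point concrete, here is a small, self-contained sketch of the keep-alive idea with a hypothetical Session type (not part of the listing below): every pending handler captures shared_from_this, so the object - and with it the socket - is destroyed automatically once the last outstanding handler completes, without any explicit close():

    #include <boost/asio.hpp>
    #include <array>
    #include <memory>

    // Hypothetical Session type demonstrating the shared_ptr keep-alive pattern:
    // the socket is closed by ~Session when the last handler holding `self` runs
    struct Session : std::enable_shared_from_this<Session> {
        explicit Session(boost::asio::ip::tcp::socket s) : sock(std::move(s)) {}

        void run() {
            sock.async_read_some(
                boost::asio::buffer(buf),
                [self = shared_from_this()](boost::system::error_code ec, std::size_t /*n*/) {
                    if (!ec)
                        self->run(); // chain another read, keeping the Session alive
                    // on EOF or error the last copy of `self` dies here and
                    // ~Session closes the socket automatically
                });
        }

        boost::asio::ip::tcp::socket sock;
        std::array<char, 4096>       buf;
    };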

    Now we instantiate Pumps for both directions:

    socket_t down, up;
    Pump up_down{up, down}, down_up{down, up};
    

    And we fire both up after the Bridge connects (as shown already above). The fact that both directions share exactly the same code ensures we treat failure/close on either end with the same fidelity.

    Full Listing

    Live On Coliru

    #include <boost/asio.hpp>
    #include <iostream>
    
    using namespace std::placeholders;
    namespace asio = boost::asio;
    using asio::ip::tcp;
    using error_code = boost::system::error_code;
    using socket_t   = tcp::socket;
    
    struct Settings {
        struct endpoint {
            std::string host;
            uint16_t    port;
    
            friend std::ostream& operator<<(std::ostream& os, endpoint const& ep) { return os << ep.host << ":" << ep.port; }
            // conversions
            tcp::resolver::query query() const { return {host, std::to_string(port)}; }
            operator tcp::endpoint() const { return {asio::ip::address::from_string(host), port}; }
        } remote, local;
    };
    
    class Bridge : public std::enable_shared_from_this<Bridge> {
      public:
        Bridge(socket_t downstream_socket) : down(std::move(downstream_socket)) {}
        ~Bridge() { std::cout << __FUNCTION__ << " " << down_up << "/" << up_down << std::endl; }
    
        void start(Settings::endpoint const& upstream) {
            res.async_resolve(
                upstream.query(),
                [=, this, self = shared_from_this()](error_code ec, tcp::resolver::results_type eps) {
                    if (ec)
                        std::cout << "Could not resolve upstream (" << upstream << "): " << ec.message()
                                  << std::endl;
                    else
                        asio::async_connect(
                            up, eps, [this, self = shared_from_this()](error_code ec, tcp::endpoint /*ep*/) {
                                if (!ec) {
                                    up_down.start(self);
                                    down_up.start(self);
                                } else {
                                    std::cerr << "Connect: " << ec.message() << std::endl;
                                };
                            });
                });
        }
    
      private:
        socket_t down, up{down.get_executor()};
        tcp::resolver res{down.get_executor()};
    
        struct Pump {
            using Ref = std::shared_ptr<void const>;
            Pump(socket_t& source, socket_t& dest) : src(source), dst(dest) {}
    
            void start(Ref self) {
                sep = src.remote_endpoint();
                dep = dst.remote_endpoint();
                std::cout << "Starting pump " << *this << std::endl;
                pump(self);
            }
    
          private:
            void pump(Ref self) {
                src.async_read_some(
                    asio::buffer(buf), [=, this](error_code ec, size_t n) {
                        bool const eof = (ec == asio::error::eof);
                        // std::cout << "Read(" << n << ") " << *this << ": " << ec.message() << std::endl;
                        if (!ec || eof) {
                            async_write(dst, asio::buffer(buf, n), [=, this](error_code ec, size_t) {
                                if (eof) {
                                    dst.shutdown(socket_t::shutdown_send, ec);
                                    std::cout << "Shutdown " << *this << ": " << ec.message() << std::endl;
                                } else if (!ec)
                                    pump(self);
                            });
                        } else {
                            dst.shutdown(socket_t::shutdown_send, ec);
                            std::cout << "Shutdown " << *this << ": " << ec.message() << std::endl;
                        }
                    });
            }
    
            socket_t &             src, &dst;
            tcp::endpoint          sep, dep;
            std::array<char, 4096> buf;
            friend std::ostream&   operator<<(std::ostream& os, Pump const& d) { return os << "(" << d.sep << " -> " << d.dep << ")"; }
        } up_down{up, down}, down_up{down, up};
    };
    
    class Acceptor {
      private:
        Settings      cfg_;
        tcp::acceptor acc_;
    
      public:
        Acceptor(asio::any_io_executor ex, Settings cfg)
            : cfg_(std::move(cfg))
            , acc_(ex, {asio::ip::address::from_string(cfg_.local.host), cfg_.local.port}) //
        {
            acc_.listen();
            accept_loop();
        }
    
      private:
        void accept_loop() {
            acc_.async_accept(make_strand(acc_.get_executor()), [this](error_code ec, socket_t s) {
                if (!ec) {
                    std::cout << "New session from " << s.remote_endpoint() << std::endl;
                    std::make_shared<Bridge>(std::move(s))->start(cfg_.remote);
                    accept_loop();
                } else {
                    std::cerr << "Error: " << ec.message() << std::endl;
                }
            });
        }
    };
    
    int main() {
        std::cout << "Proxy started." << std::endl;
        Settings cfg {
            .remote = {"127.0.0.1", 3306},
            .local = {"127.0.0.1", 9000},
        };
    
        std::cout << "cfg.local  = " << cfg.local << " cfg.remote = " << cfg.remote << std::endl;
    
        asio::io_context ioc;
        Acceptor proxy(ioc.get_executor(), cfg);
        ioc.run();
    }
    

    A local demo proxying MySql connections, showing interruptions of either server or client:

    • first, the mysql client tries to connect before the proxy is started, demonstrating that it fails to connect
    • the proxy is started
    • now the mysql client can connect and responds interactively
    • exiting mysql client is reflected in the console output and shows ~Bridge freeing all resources
    • starting another client runs another full session
    • running a very long running client session that is interrupted with Ctrl-C is accurately handled, and the proxy server keeps accepting new connections
    • running the same very long running client session that is interrupted by force-restarting Mysql server is accurately handled
    • attempting to connect the client before the remote Mysql server is ready to accept connections fails with Connection refused logged
    • a later connection succeeds again



    ¹ (perhaps why the original bitbucket project has since been deleted)

    ² it's required because the server is single-threaded and we cannot afford to block all connections on a potentially slow DNS query