Search code examples
Tags: c++, boost, boost-asio, shared-ptr

C++ boost::asio bad_weak_ptr when using shared_from_this


I'm building an asynchronous socket server using boost::asio and am running into trouble with shared pointers.

SocketServer.hpp

#pragma once

#include <memory>
#include <boost/asio.hpp>

#include "SocketConnection.hpp"

// TCP port the acceptor listens on (bound on IPv4 in SocketServer::SocketServer).
#define SOCKET_PORT 8080

// NOTE(review): this is defined after <boost/asio.hpp> was already included at
// the top of this header, so it cannot affect Boost.Asio here. To take effect it
// must precede every Boost.Asio include and be set identically in all
// translation units (e.g. on the compiler command line).
#define BOOST_ASIO_ENABLE_HANDLER_TRACKING 1 // Enable tracking

/// Asynchronous TCP server: accepts connections and keeps non-owning
/// references to them so that messages can be broadcast.
/// NOTE(review): derives from enable_shared_from_this but never calls
/// shared_from_this(); the base class is unused.
class SocketServer : public std::enable_shared_from_this<SocketServer>
{
private:
    boost::asio::io_context _ioc;
    // Raw owning pointer: allocated with `new` in the constructor and never deleted.
    boost::asio::ip::tcp::acceptor* _acceptor;

    // Weak references only; expired entries are never removed from this vector.
    std::vector<std::weak_ptr<SocketConnection>> _registered;
    // Guards _registered.
    std::mutex _mutex;

public:
    // Opens the acceptor on SOCKET_PORT and arms the first accept.
    SocketServer();
    ~SocketServer();
    // Initiates one async_accept; its handler calls accept_loop() again.
    void accept_loop();
    // Appends a weak reference to `connection` under _mutex.
    void register_connection(std::weak_ptr<SocketConnection> connection);
    // Calls do_write(msg) on every registered connection.
    void broadcast(std::string const &msg);
    // Calls _ioc.run() on the calling thread.
    void run();
};

SocketServer.cpp

#include <algorithm>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "SocketServer.hpp"


/// Creates the acceptor on IPv4 port SOCKET_PORT and arms the first accept.
SocketServer::SocketServer()
{
    // NOTE(review): raw `new` with no matching delete anywhere in the class;
    // a by-value tcp::acceptor member would make ownership automatic.
    _acceptor = new boost::asio::ip::tcp::acceptor(_ioc, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), SOCKET_PORT));
    accept_loop();
}

/// Logs destruction.
/// NOTE(review): does not delete _acceptor, which the constructor allocated
/// with `new`, so the acceptor leaks.
SocketServer::~SocketServer()
{
    std::cout << "Server destructor" << std::endl;
}

/// Appends a non-owning reference to `connection` to the registry.
/// NOTE(review): entries whose connection has died are never pruned, so the
/// vector only grows.
void SocketServer::register_connection(std::weak_ptr<SocketConnection> connection)
{
    std::lock_guard<std::mutex> lock(_mutex); // Will be released automatically
    _registered.push_back(connection);
}

/// Initiates one asynchronous accept. The completion handler wraps the
/// accepted socket in a SocketConnection, registers it and calls run(), then
/// re-arms the accept by calling accept_loop() again. It re-arms on success
/// and on error alike.
void SocketServer::accept_loop()
{
    _acceptor->async_accept(
        [this](boost::system::error_code ec, boost::asio::ip::tcp::socket socket)
        {
            std::cout << "ACCEPT LOOP " << ec << std::endl;
            if (!ec)
            {
                // BUG(review): std::bad_weak_ptr is thrown from inside this call.
                // The SocketConnection constructor calls do_read(), which calls
                // shared_from_this() before make_shared has produced an owner.
                auto connection = std::make_shared<SocketConnection>(std::move(socket));
                // NOTE(review): get_remote_endpoint() uses the throwing overload and
                // can throw if the peer has already disconnected.
                std::cout << "Connection accepted from " << connection->get_remote_endpoint() << " (" << ec.message() << ")" << std::endl;
                // Only a weak_ptr is registered; keeping the connection alive past this
                // handler relies on the `self` copies captured by its async handlers.
                register_connection(connection);
                connection->run();
            }

            accept_loop();
        });
}

/// Calls do_write(msg) on every registered connection while holding _mutex
/// for the whole loop.
void SocketServer::broadcast(std::string const &msg)
{
    std::lock_guard<std::mutex> lock(_mutex);

    // NOTE(review): unused.
    std::vector<std::weak_ptr<SocketConnection>> active;

    for (auto &conn : _registered)
    {
        // BUG(review): lock() yields an empty shared_ptr for a connection that
        // has been destroyed, and the next line dereferences it unconditionally.
        auto c = conn.lock();
        c->do_write(msg);
    }
}

/// Drives the server's io_context by calling _ioc.run() on the calling thread.
void SocketServer::run()
{
    _ioc.run();
}

SocketConnection.hpp

#pragma once

#include <memory>
#include <iostream>
#include <boost/asio.hpp>

// Size in bytes of each fixed per-connection read/write buffer.
#define MAX_BUFFER_LEN 4096

/// One accepted TCP connection. Its async handlers capture shared_from_this()
/// so that the object stays alive while operations are pending.
/// NOTE(review): shared_from_this() only works once a std::shared_ptr already
/// owns the object, i.e. never from inside the constructor.
class SocketConnection : public std::enable_shared_from_this<SocketConnection>
{
private:
    boost::asio::ip::tcp::socket _socket;
    // Target of async_read_some; the read does not NUL-terminate it.
    char read_buffer[MAX_BUFFER_LEN];
    // One buffer shared by every do_write() call, so overlapping writes clobber each other.
    char write_buffer[MAX_BUFFER_LEN];


public:
    SocketConnection(boost::asio::ip::tcp::socket socket);
    ~SocketConnection();
    // Peer endpoint; may throw if the socket is not connected.
    boost::asio::ip::tcp::endpoint get_remote_endpoint();
    void do_read();
    void on_read();
    void do_write(std::string msg);
    void on_write();
    // Second-phase start hook, called after make_shared in SocketServer::accept_loop.
    void run();
};

SocketConnection.cpp

#include "SocketConnection.hpp"


/// Takes ownership of an accepted socket.
SocketConnection::SocketConnection(boost::asio::ip::tcp::socket socket)
    : _socket(std::move(socket)) // socket cannot be copied
{
  // BUG(review): this is the cause of the std::bad_weak_ptr. do_read() calls
  // shared_from_this(), but while the constructor runs no std::shared_ptr owns
  // *this yet. Start reading from run() instead, after make_shared returns.
  do_read();
}

/// Logs destruction; the socket member is closed by its own destructor.
SocketConnection::~SocketConnection()
{
    std::cout << "SocketConnection destructor" << std::endl;
}

/// Returns the peer's endpoint.
/// NOTE(review): uses the throwing remote_endpoint() overload, so this may throw
/// if the socket is no longer connected; the overload taking an error_code
/// reports the failure without throwing.
boost::asio::ip::tcp::endpoint SocketConnection::get_remote_endpoint()
{
  return _socket.remote_endpoint();
}

/// Initiates one asynchronous read into read_buffer and calls on_read() on success.
/// The handler captures `self` so that the connection outlives the pending read.
void SocketConnection::do_read()
{
  std::cout << "GETTING HERE BEFORE CRASH" << std::endl;
  // Throws std::bad_weak_ptr when reached from the constructor (no owning shared_ptr yet).
  auto self(shared_from_this());
  std::cout << "DOES NOT GET HERE" << std::endl;

  // NOTE(review): std::ref is unnecessary; boost::asio::buffer(read_buffer) deduces the size itself.
  _socket.async_read_some(boost::asio::buffer(std::ref(read_buffer), MAX_BUFFER_LEN),
                          [this, self](boost::system::error_code ec, std::size_t length)
                          {
                            std::cout << "INSIDE DO_READ: " << ec << " " << ec.message() << std::endl;

                            if (!ec)
                            {
                              // `length` is not passed on, and no further read is started afterwards.
                              on_read();
                            }
                          });
}

/// Echoes the received data back to the peer, prefixed with "REPLY: ".
/// NOTE(review): std::string(read_buffer) scans for a NUL terminator that the
/// read never writes, and the received byte count is discarded. It may pick up
/// stale bytes or run past the buffer, and embedded NULs truncate the data.
void SocketConnection::on_read()
{
  std::string s(read_buffer);
  std::string reply = "REPLY: " + s;
  do_write(reply);
}

/// Copies `msg` into write_buffer and writes it asynchronously; calls
/// on_write() on success. The handler captures `self` to keep the object alive.
/// NOTE(review): strcpy/strlen come from <cstring>, which this file does not
/// include directly.
void SocketConnection::do_write(std::string msg)
{
  auto self(shared_from_this());

  // BUG(review): unbounded copy. A reply from on_read() is "REPLY: " plus up
  // to MAX_BUFFER_LEN bytes, which can overflow write_buffer. Calling this
  // again before the previous write finishes also overwrites the data that
  // write still references.
  strcpy(write_buffer, msg.c_str());

  boost::asio::async_write(_socket, boost::asio::buffer(std::ref(write_buffer), strlen(write_buffer)),
                           [this, self](boost::system::error_code ec, std::size_t /*length*/)
                           {
                             std::cout << "INSIDE DO_WRITE: " << ec << std::endl;

                             if (!ec)
                             {
                               on_write();
                             }
                           });
}

/// Write-completion hook; currently only copies write_buffer into a local that is never used.
void SocketConnection::on_write()
{
  std::string s(write_buffer);
}

/// Second-phase start, called after the connection is owned by a shared_ptr.
/// NOTE(review): as written it only logs and never calls do_read(); this is
/// the safe place to start the first read.
void SocketConnection::run()
{
    std::cout << "Running" << std::endl;
}

I'm using Ubuntu Linux and issuing the following command:

$ telnet 127.0.0.1 8080
Trying 127.0.0.1...
Connected to 127.0.0.1.
Escape character is '^]'.
Connection closed by foreign host.

And getting the following error:

ACCEPT LOOP system:0
GETTING HERE BEFORE CRASH
terminate called after throwing an instance of 'std::bad_weak_ptr'
  what():  bad_weak_ptr
Aborted (core dumped)

Why am I getting bad_weak_ptr? I believe I've correctly declared class SocketConnection : public std::enable_shared_from_this<SocketConnection>. What am I missing?


Solution

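  • As others have pointed out, shared_from_this() cannot be used in the constructor. While the constructor runs, no std::shared_ptr owns the object yet, so there is no ownership to share. Since C++17 the call throws std::bad_weak_ptr, which is exactly the error you see.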

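    The common pattern is two-step initialization: construct the object first, then start async operations from a separate member function once a shared_ptr owns it. You already have such a function, run(), but it does nothing, so move the do_read() call there: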

    void SocketConnection::run() {
        std::cout << "Running" << std::endl;
        // run() is called only after make_shared has returned, so *this is now
        // owned by a shared_ptr and do_read() may safely call shared_from_this().
        do_read();
    }
    

    Side Note(s)

    • Don't put ODR-changing defines in a header that may not be included in all translation units. Such a define also only takes effect if it appears before the first Boost.Asio include:

       #define BOOST_ASIO_ENABLE_HANDLER_TRACKING 1 // Enable tracking
      

      See why "#define BOOST_ASIO_ENABLE_HANDLER_TRACKING 1" doesn't work as expected? and Valgrind errors from boost::asio

    • Don't [unnecessarily] use new/delete.

    • the std::ref() calls are out of place: pass the buffers to boost::asio::buffer directly.

    • don't use strcpy, which is not bounds-checked (the reply can be longer than write_buffer)

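    • better yet, don't copy into a fixed buffer at all: keep each message in a std::string that lives until its write completes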

    • don't assume the data contains no embedded NUL characters; use the byte count the read reports

    • make sure writes don't overlap, especially since all writes use the same shared buffer!

    • SocketServer derives from enable_shared_from_this but never calls shared_from_this(). Drop the unused complexity.

    • the mutex implies multiple threads; consider using a strand per connection

    • get_remote_endpoint may throw if the connection isn't active/valid

    • minimize lock duration

    • consider garbage-collecting expired connections

    • anticipate weak_ptr::lock() returning nullptr (an empty shared_ptr) once a connection has gone away:

       void SocketServer::broadcast(std::string const& msg) {
           auto snapshot = [&] {
               std::lock_guard<std::mutex> lock(_mutex);
               return _registered;
           }();
      
           for (auto& conn : snapshot)
               if (auto c = conn.lock())
                   c->send(msg);
       }
      

    Addressing all of the above:

    Live On Coliru

    // #pragma once
    
    #include <boost/asio.hpp>
    #include <deque>
    #include <iomanip>
    #include <iostream>
    #include <list>
    namespace asio = boost::asio;
    using asio::ip::tcp;
    using boost::system::error_code;
    // Listening port. `inline constexpr` yields a single definition shared by
    // every translation unit that includes this header; the original `static`
    // contradicted `inline` by forcing a separate copy per translation unit.
    inline constexpr uint16_t SOCKET_PORT = 8080;
    
    /// One accepted TCP connection. Received data is echoed back through send().
    /// Outgoing messages queue in outbox_ so that at most one async_write is in
    /// flight. outbox_ and read_buffer_ are only touched from handlers running on
    /// the socket's executor (a per-connection strand when created by
    /// SocketServer::accept_loop).
    class SocketConnection : public std::enable_shared_from_this<SocketConnection> {
      private:
        tcp::socket             socket_;
        std::array<char, 4086>  read_buffer_; // NOTE(review): 4086 is probably meant to be 4096
        std::deque<std::string> outbox_;
        tcp::endpoint           endpoint_;

      public:
        /// Takes ownership of an accepted socket and caches the peer endpoint.
        /// Uses the non-throwing remote_endpoint overload: if the peer has
        /// already disconnected, endpoint_ stays default-constructed. The
        /// throwing overload would otherwise propagate out of the accept handler
        /// and out of io_context::run(). Reading starts only in run().
        SocketConnection(tcp::socket socket) //
            : socket_(std::move(socket)) {
            error_code ec;
            endpoint_ = socket_.remote_endpoint(ec);
        }

        /// Peer endpoint captured at construction (default-constructed if it was unavailable).
        tcp::endpoint get_remote_endpoint() const {
            return endpoint_;
        }

        /// Queues `msg` for writing by posting to the socket's executor.
        void send(std::string msg);
        /// Starts the read loop; *this must already be owned by a std::shared_ptr.
        void run();

      private:
        void do_read_loop();
        void do_write_loop();

        void on_read(std::string msg);
        void on_write(std::string msg);
    };
    // #pragma once
    
    // #include "SocketConnection.hpp"
    
    /// Listens on SOCKET_PORT (IPv4, any address) and keeps weak handles to the
    /// connections it accepts so that broadcast() can reach them.
    class SocketServer {
        // Non-owning handle to an accepted connection.
        using Handle = std::weak_ptr<SocketConnection>;

      public:
        SocketServer();
        void accept_loop();
        void register_connection(Handle connection);
        void broadcast(std::string const& msg);
        void run();

      private:
        // Declaration order matters: _ioc must be constructed before _acceptor.
        asio::io_context  _ioc;
        tcp::acceptor     _acceptor{_ioc, tcp::endpoint{tcp::v4(), SOCKET_PORT}};
        std::list<Handle> _registered; // guarded by _mutex
        std::mutex        _mutex;
    };
    
    // #include "SocketConnection.hpp"
    
    /// Hops onto the socket's executor, appends `msg` to outbox_, and starts the
    /// write loop only if no write was already queued.
    void SocketConnection::send(std::string msg) {
        auto keep_alive = shared_from_this();
        auto enqueue = [this, self = std::move(keep_alive), payload = std::move(msg)]() mutable {
            // on strand
            bool const was_idle = outbox_.empty();
            outbox_.push_back(std::move(payload));
            if (was_idle)
                do_write_loop();
        };
        asio::post(socket_.get_executor(), std::move(enqueue));
    }
    
    void SocketConnection::run() {
        std::cout << "Running" << std::endl;
        asio::post(socket_.get_executor(), //
                   [self = shared_from_this()] { self->do_read_loop(); });
    }
    
    /// Logs the received chunk (quoted) and echoes it back with a "REPLY: " prefix.
    void SocketConnection::on_read(std::string msg) {
        std::cout << "on_read: " << std::quoted(msg) << std::endl;
        std::string reply = "REPLY: ";
        reply += msg;
        send(std::move(reply));
    }
    
    /// Writes outbox_.front(). On completion it hands that message to
    /// on_write(), drops it from the queue, and continues with the next one
    /// unless the write failed. Does nothing when the queue is empty.
    void SocketConnection::do_write_loop() {
        if (!outbox_.empty()) {
            auto on_complete = [this, self = shared_from_this()](error_code ec, size_t /*length*/) {
                std::cout << "INSIDE DO_WRITE: " << ec.message() << std::endl;

                on_write(std::move(outbox_.front()));
                outbox_.pop_front();

                if (ec)
                    return;
                do_write_loop();
            };
            // outbox_.front() stays valid until popped: deque::push_back keeps element references.
            asio::async_write(socket_, asio::buffer(outbox_.front()), std::move(on_complete));
        }
    }
    
    /// Reads into read_buffer_, hands each received chunk to on_read() and
    /// re-arms. The loop ends at the first error (e.g. end of file).
    void SocketConnection::do_read_loop() {
        auto on_chunk = [this, self = shared_from_this()](error_code ec, size_t length) {
            std::cout << "INSIDE DO_READ_LOOP: " << ec.message() << std::endl;

            if (ec)
                return;
            on_read(std::string(read_buffer_.data(), length));
            do_read_loop();
        };
        socket_.async_read_some(asio::buffer(read_buffer_), std::move(on_chunk));
    }
    
    /// Hook invoked with each message after its write completes; currently a no-op.
    void SocketConnection::on_write([[maybe_unused]] std::string sent) {}
    
    #include <algorithm>
    #include <cstdlib>
    #include <functional>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <thread>
    #include <vector>
    
    // #include "SocketServer.hpp"
    
    /// The acceptor is already bound by its member initializer; just arm the first accept.
    SocketServer::SocketServer() { this->accept_loop(); }
    
    /// Records a weak handle to `connection` under _mutex. Handles whose
    /// connection has already been destroyed are pruned first.
    void SocketServer::register_connection(Handle connection) {
        std::lock_guard<std::mutex> lock(_mutex);
        // garbage collect. A lambda is used because taking the address of a
        // standard-library member function (&std::weak_ptr<T>::expired) is
        // unspecified and not portable across implementations.
        _registered.remove_if([](Handle const& handle) { return handle.expired(); });
        _registered.push_back(std::move(connection));
    }
    
    void SocketServer::accept_loop() {
        _acceptor.async_accept(make_strand(_acceptor.get_executor()), [this](error_code ec, tcp::socket socket) {
            std::cout << __FUNCTION__ << " " << ec.message() << std::endl;
            if (!ec) {
                auto conn = std::make_shared<SocketConnection>(std::move(socket));
                std::cout << "Accepted from " << conn->get_remote_endpoint() << std::endl;
                register_connection(conn);
                conn->run();
            }
    
            accept_loop();
        });
    }
    
    void SocketServer::broadcast(std::string const& msg) {
        auto snapshot = [&] {
            std::lock_guard<std::mutex> lock(_mutex);
            return _registered;
        }();
    
        for (auto& conn : snapshot)
            if (auto c = conn.lock())
                c->send(msg);
    }
    
    /// Drives the server's io_context on the calling thread.
    void SocketServer::run() { _ioc.run(); }
    
    int main() {
        SocketServer ss;
        std::thread  sthread([&] { ss.run(); });
    
        for (int i = 1; ; ++i) {
            std::this_thread::sleep_for(std::chrono::seconds(3));
            ss.broadcast("Ping " + std::to_string(i) + "\n");
        }
    
        sthread.join();
    }
    

    Local testing:

    [Screenshot of the local test run]