Search code examples
c++ shared-ptr smart-pointers

Simplified shared_ptr implementation


I am attempting to implement a simplified shared_ptr that does not include weak_ptr functionality. It does support make_shared, so that the control block and the object are created with only a single allocation (in fact, this is the only way to create a new shared_ptr in my implementation).

The problem is, something about my implementation causes my code to crash (hard fault on an ARM MCU).

/**
 * @file shared_ptr.h
 *
 */

#ifndef TW_SHARED_PTR_H_
#define TW_SHARED_PTR_H_

//==============================================================================
// INCLUDES
//==============================================================================

#include "_common.h"

/*
 * If TW_CONFIG_USE_STD_SHARED_PTR == 1, simply define tw::shared_ptr (and other
 * helper classes/methods) as wrappers around std::shared_ptr
 */
#if TW_CONFIG_USE_STD_SHARED_PTR == 1

#include <memory>

namespace tw {

/// tw::shared_ptr<T> is exactly std::shared_ptr<T> in this configuration.
template <class Pointee>
using shared_ptr = ::std::shared_ptr<Pointee>;

/// tw::enable_shared_from_this<T> is exactly std::enable_shared_from_this<T>.
template <class Derived>
using enable_shared_from_this = ::std::enable_shared_from_this<Derived>;

/**
 * @brief Create a T managed by a tw::shared_ptr.
 *
 * Perfect-forwards every argument to std::make_shared<T>.
 *
 * @param args Arguments handed on to std::make_shared<T>
 * @return The shared_ptr produced by std::make_shared<T>
 */
template <class T, class... ARGS>
inline auto make_shared(ARGS&&... args) -> shared_ptr<T> {
    shared_ptr<T> created = ::std::make_shared<T>(::std::forward<ARGS>(args)...);
    return created;
}

} //namespace tw

#else //TW_CONFIG_USE_STD_SHARED_PTR == 0

#include <atomic>
#include <type_traits>

//==============================================================================
// DEFINES
//==============================================================================

#ifndef TW_SHARED_PTR_ASSERT
    #include <cstdlib>
    /*
     * Default invariant check: abort the program when x_ evaluates to false.
     * A project may define TW_SHARED_PTR_ASSERT before including this header
     * to substitute its own handler.
     *
     * The body is wrapped in do { } while (0) so that the macro expands to
     * exactly one statement.  With a bare `if`, code such as
     *     if (a) TW_SHARED_PTR_ASSERT(b); else f();
     * would attach the `else` to the macro's hidden `if` instead of `if (a)`.
     */
    #define TW_SHARED_PTR_ASSERT(x_) do { if (!(x_)) { ::std::abort(); } } while (0)
#endif

namespace tw {

//==============================================================================
// CLASSES
//==============================================================================

// Forward declarations: the classes below befriend and reference each other.
template <typename T>
class shared_ptr;

class _enable_shared_from_this;

template <typename T>
class enable_shared_from_this;

//-----[ CLASS: _shared_ptr_base ]-------------------------------------------------
/**
 * @brief Non-template base of every shared_ptr<T>
 *
 * Holds the definitions that do not depend on T: the reference-count type and
 * the control block layout.  make_shared() and _enable_shared_from_this are
 * friends so they can name the protected control_block type.
 */
class _shared_ptr_base {
    template <typename T_, typename ... ARGS_>
    friend
    shared_ptr<T_> make_shared(ARGS_&&...);

    friend class _enable_shared_from_this;

protected:
    using ref_count_type = std::atomic<unsigned int>;

    /**
     * @brief Bookkeeping shared by all shared_ptr instances of one object
     */
    struct control_block {
        //Incremented by the shared_ptr constructors in this file and
        //decremented by shared_ptr::reset()
        ref_count_type refCount;

        //Type-erased deleter: 'destroy' is called with 'p' when reset() sees
        //the count reach zero
        struct {
            const void* p;                  //argument handed to destroy()
            void(*destroy)(const void*);    //function that releases 'p'

            //Invoke the deleter
            inline void operator()(){ destroy(p); }
            //True only when both the pointer and the function are set
            inline operator bool() { return (p != nullptr) && (destroy != nullptr); }
        } destructor;
    };
};

//-----[ TEMPLATE CLASS: shared_ptr<T> ]-----------------------------------------
/**
 * @brief Reference-counted pointer to a T created by make_shared()
 *
 * NOTE: no non-template copy/move constructor or copy/move assignment operator
 * is declared in this class.  A constructor (or assignment operator) template
 * is never a copy/move constructor (or assignment operator) in C++, so
 * operations between two shared_ptr<T> of the SAME T use the compiler-generated
 * member-wise versions, which copy _block and _ptr without incrementing
 * refCount.  Only conversions between different pointee types go through the
 * templates below.
 */
template <typename T>
class shared_ptr : public _shared_ptr_base {
    //Every shared_ptr<U> may read the private members of every shared_ptr<T>
    //(needed by the converting constructors/assignments)
    template <typename T_>
    friend
    class shared_ptr;

    //enable_shared_from_this<> uses the protected (block, ptr) constructor
    template <typename T_>
    friend
    class enable_shared_from_this;

public:
    using value_type = T;

public:
    /**
     * @brief Default constructor: creates a null shared_ptr (no block, no object)
     */
    shared_ptr():
        _block(nullptr),
        _ptr(nullptr)
    {

    }

    /**
     * @brief Destructor: drops this instance's reference via reset()
     */
    ~shared_ptr() {
        reset();
    }

    //Converting copy constructor
    //Shares other's control block and increments the reference count.
    template <typename U_>
    shared_ptr(const shared_ptr<U_>& other):
        _block(other._block),
        _ptr(other._ptr)
    {
        if (_block != nullptr) {
            TW_SHARED_PTR_ASSERT(_ptr != nullptr);
            //Increment ref count
            ++(_block->refCount);
        }
    }

    //Converting move constructor
    //Takes over other's reference (count unchanged) and leaves other null.
    template <typename U_>
    shared_ptr(shared_ptr<U_>&& other):
        _block(other._block),
        _ptr(other._ptr)
    {
        other._block = nullptr;
        other._ptr = nullptr;
    }


public:
    //Converting copy assignment operator
    //Releases the current reference first, then shares rhs's block.
    template <typename U_>
    shared_ptr& operator=(const shared_ptr<U_>& rhs) {
        if (static_cast<void*>(this) != static_cast<void*>(&rhs)) {
            reset();

            _block = rhs._block;
            _ptr = rhs._ptr;

            if (_block != nullptr) {
                TW_SHARED_PTR_ASSERT(_ptr != nullptr);
                //Increment ref count
                ++(_block->refCount);
            }
        }

        return *this;
    }

    //Converting move assignment operator
    //Releases the current reference first, then takes over rhs's reference and
    //leaves rhs null.
    template <typename U_>
    shared_ptr& operator=(shared_ptr<U_>&& rhs) {
        if (static_cast<void*>(this) != static_cast<void*>(&rhs)) {
            reset();

            _block = rhs._block;
            _ptr = rhs._ptr;

            rhs._block = nullptr;
            rhs._ptr = nullptr;
        }

        return *this;
    }

    //Member access; asserts that the pointer is not null
    inline T* operator->() const noexcept {
        TW_SHARED_PTR_ASSERT(_ptr != nullptr);
        return _ptr;
    }
    //Implicit conversion to the raw pointer (may be nullptr)
    inline operator T*() const noexcept { return _ptr; }
    //Dereference; asserts that the pointer is not null
    inline T& operator*() const noexcept {
        TW_SHARED_PTR_ASSERT(_ptr != nullptr);
        return *_ptr;
    }

    //True when the stored object pointer is not null
    inline operator bool() const noexcept { return (_ptr != nullptr); }

public:
    //Raw pointer to the object (may be nullptr)
    inline T* get() const noexcept { return _ptr; }

    /**
     * @brief Drop this instance's reference and become null
     *
     * When the decrement brings the count to zero, the control block's
     * deleter is invoked.
     */
    void reset() {
        if (_block != nullptr) {
            TW_SHARED_PTR_ASSERT(_ptr != nullptr);
            if (--(_block->refCount) == 0) {
                TW_SHARED_PTR_ASSERT(_block->destructor);
                //Free resources by calling deleter
                _block->destructor();
            }
        }

        _block = nullptr;
        _ptr = nullptr;
    }

    //Current reference count, or 0 when there is no control block
    inline ref_count_type::value_type use_count() const noexcept {
        return (_block == nullptr) ? 0 : _block->refCount.load();
    }

protected:
    /**
     * @brief Attach to an existing control block (used by make_shared() and
     *        enable_shared_from_this<>)
     *
     * Increments the reference count when block is not null.
     */
    shared_ptr(control_block* block, value_type* ptr):
        _block(block),
        _ptr(ptr)
    {
        if (_block != nullptr) {
            TW_SHARED_PTR_ASSERT(_ptr != nullptr);
            //Increment ref count
            ++(_block->refCount);
        }
    }

private:
    control_block* _block;  //shared bookkeeping, nullptr when not attached
    value_type* _ptr;       //the object itself

    //make_shared() calls the protected (block, ptr) constructor
    template <typename T_, typename ... ARGS_>
    friend
    shared_ptr<T_> make_shared(ARGS_&&...);
};

//-----[ CLASS: _enable_shared_from_this ]--------------------------------------
/**
 * @brief Non-template base of enable_shared_from_this<T>
 *
 * Stores the control block pointer that make_shared() fills in after
 * constructing the object; it is nullptr until then.
 */
class _enable_shared_from_this {
protected:
    _enable_shared_from_this():
        _block(nullptr)
    {

    }

protected:
    _shared_ptr_base::control_block* _block;
};

//-----[ TEMPLATE CLASS: enable_shared_from_this ]------------------------------
/**
 * @brief Base class giving T the ability to create shared_ptr<T> to itself
 *
 * make_shared() is a friend so it can store the control block pointer in the
 * inherited _block member.
 */
template <typename T>
class enable_shared_from_this : public _enable_shared_from_this {
    template <typename T_, typename ... ARGS_>
    friend
    shared_ptr<T_> make_shared(ARGS_&&...);

public:
    //New shared_ptr<T> attached to the stored control block.  When _block is
    //still nullptr the result has a non-null pointer but no control block.
    shared_ptr<T> shared_from_this() noexcept {
        return shared_ptr<T>(_block, static_cast<T*>(this));
    }

    //Const overload of the above, yielding shared_ptr<const T>
    shared_ptr<const T> shared_from_this() const noexcept {
        return shared_ptr<const T>(_block, static_cast<const T*>(this));
    }
};

//==============================================================================
// FUNCTIONS
//==============================================================================

/**
 * @brief Create a T and the shared_ptr<T> that manages it
 *
 * The control block and the T live in one heap object (FullBlock), so a single
 * 'new' creates both.  T is brace-initialized from the forwarded arguments.
 *
 * @param args Arguments forwarded to the initialization of T
 * @return shared_ptr<T> holding the first reference (count == 1)
 */
template <typename T, typename ... ARGS>
shared_ptr<T> make_shared(ARGS&&... args) {
    struct FullBlock {
        _shared_ptr_base::control_block block;
        T value;
    };

    //Allocate block on heap
    //({0} sets refCount to 0; the shared_ptr constructed in the return
    //statement raises it to 1)
    auto block = new FullBlock{
        {0},
        T{std::forward<ARGS>(args)...} //value
    };

    //Deleter: frees the whole FullBlock (control block + object)
    block->block.destructor.p = block;
    block->block.destructor.destroy = [](const void* x){
        delete static_cast<const FullBlock*>(x);
    };
    //Let objects derived from enable_shared_from_this find their control block
    if constexpr (std::is_base_of_v<_enable_shared_from_this, T>) {
        block->value._block = &block->block;
    }

    /*
     * Up until this point, the make_shared function "owns" the pointer to
     * 'block'.  It now "transfers" ownership of this pointer to a shared_ptr
     * instance.
     */
    return shared_ptr<T>(&block->block, &block->value);
}

} //namespace tw

#endif

#endif /* TW_SHARED_PTR_H_ */

The define TW_CONFIG_USE_STD_SHARED_PTR allows me to swap between my shared_ptr implementation and the standard library's implementation. The standard library's implementation does NOT cause a hard fault in my application so I know that there must be something wrong with my implementation.

Can anyone spot any obvious problems in my implementation?


Solution

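  • Even though this question was a bit open-ended and had no straightforward answer, those who commented helped me out quite a bit. @WhozCraig pointed out that there were no copy or move constructors. I did not realize this was the case, as I thought that my "converting" (i.e., templated) copy and move constructors would also cover the default case (same pointee type). However, this is NOT how the C++ compiler handles it: a constructor template is never treated as the copy or move constructor, so the compiler still generates the implicit member-wise ones, and those copy the internal pointers without touching the reference count. The same applies to the templated assignment operators.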

    So, for reference, here is a version of my pointer which works correctly:

    /**
     * @file shared_ptr.h
     *
     */
    
    #ifndef TW_SHARED_PTR_H_
    #define TW_SHARED_PTR_H_
    
    //==============================================================================
    // INCLUDES
    //==============================================================================
    
    /*
     * If TW_CONFIG_USE_STD_SHARED_PTR == 1, simply define tw::shared_ptr (and other
     * helper classes/methods) as wrappers around std::shared_ptr
     */
    #if TW_CONFIG_USE_STD_SHARED_PTR == 1
    
    #include <memory>
    
    namespace tw {

    /// tw::shared_ptr<T> is exactly std::shared_ptr<T> in this configuration.
    template <class Pointee>
    using shared_ptr = ::std::shared_ptr<Pointee>;

    /// tw::enable_shared_from_this<T> is exactly std::enable_shared_from_this<T>.
    template <class Derived>
    using enable_shared_from_this = ::std::enable_shared_from_this<Derived>;

    /**
     * @brief Create a T managed by a tw::shared_ptr.
     *
     * Perfect-forwards every argument to std::make_shared<T>.
     *
     * @param args Arguments handed on to std::make_shared<T>
     * @return The shared_ptr produced by std::make_shared<T>
     */
    template <class T, class... ARGS>
    inline auto make_shared(ARGS&&... args) -> shared_ptr<T> {
        shared_ptr<T> created = ::std::make_shared<T>(::std::forward<ARGS>(args)...);
        return created;
    }

    } //namespace tw
    
    #else //TW_CONFIG_USE_STD_SHARED_PTR == 0
    
    #include <atomic>
    #include <type_traits>
    #include <utility>
    
    //==============================================================================
    // DEFINES
    //==============================================================================
    
    #ifndef TW_SHARED_PTR_ASSERT
        #include <cstdlib>
        /*
         * Default invariant check: abort the program when x_ evaluates to
         * false.  A project may define TW_SHARED_PTR_ASSERT before including
         * this header to substitute its own handler.
         *
         * The body is wrapped in do { } while (0) so that the macro expands to
         * exactly one statement.  With a bare `if`, code such as
         *     if (a) TW_SHARED_PTR_ASSERT(b); else f();
         * would attach the `else` to the macro's hidden `if`.
         */
        #define TW_SHARED_PTR_ASSERT(x_) do { if (!(x_)) { ::std::abort(); } } while (0)
    #endif

    namespace tw {

    //==============================================================================
    // CLASSES
    //==============================================================================

    // Forward declarations: the classes below befriend and reference each other.
    template <typename T>
    class shared_ptr;

    class _enable_shared_from_this;

    template <typename T>
    class enable_shared_from_this;

    //-----[ CLASS: _shared_ptr_base ]-------------------------------------------------
    /**
     * @brief Non-template base of every shared_ptr<T>
     *
     * Holds the definitions that do not depend on T: the reference-count type
     * and the control block layout.  make_shared() and _enable_shared_from_this
     * are friends so they can name the protected control_block type.
     */
    class _shared_ptr_base {
        template <typename T_, typename ... ARGS_>
        friend
        shared_ptr<T_> make_shared(ARGS_&&...);

        friend class _enable_shared_from_this;

    protected:
        using ref_count_type = std::atomic<unsigned int>;

        /**
         * @brief Bookkeeping shared by all shared_ptr instances of one object
         */
        struct control_block {
            //Number of shared_ptr instances currently attached to this block
            ref_count_type refCount;

            //Type-erased deleter: 'destroy' is called with 'p' when the last
            //reference is dropped
            struct {
                const void* p;                  //argument handed to destroy()
                void(*destroy)(const void*);    //function that releases 'p'

                //Invoke the deleter
                void operator()() const { destroy(p); }
                //True only when both the pointer and the function are set
                explicit operator bool() const {
                    return (p != nullptr) && (destroy != nullptr);
                }
            } destructor;
        };
    };

    //-----[ TEMPLATE CLASS: shared_ptr<T> ]-----------------------------------------
    /**
     * @brief Reference-counted pointer to a T created by make_shared()
     *
     * Both the same-type (non-template) and the converting (template) copy/move
     * operations are declared explicitly: a constructor or assignment template
     * is never used as the copy/move constructor or assignment operator, so
     * without the non-template versions the compiler would generate member-wise
     * ones that skip the reference count.
     */
    template <typename T>
    class shared_ptr : public _shared_ptr_base {
        //Every shared_ptr<U> may read the private members of every
        //shared_ptr<T> (needed by the converting operations)
        template <typename T_>
        friend
        class shared_ptr;

        //enable_shared_from_this<> uses the protected (block, ptr) constructor
        template <typename T_>
        friend
        class enable_shared_from_this;

    private:
        //Enables a converting operation only when U_* implicitly converts to
        //T*, so that unrelated pointee types are rejected during overload
        //resolution instead of failing inside the function body
        template <typename U_>
        using _if_convertible =
            std::enable_if_t<std::is_convertible_v<U_*, T*>>;

    public:
        using value_type = T;

    public:
        /**
         * @brief Default constructor
         *
         * By default, a shared_ptr is null (nullptr)
         */
        shared_ptr() noexcept:
            _block(nullptr),
            _ptr(nullptr)
        {
        }

        /**
         * @brief Destructor: drops this instance's reference via reset()
         */
        ~shared_ptr() {
            reset();
        }

        /**
         * @brief Copy constructor: shares other's object, count + 1
         */
        shared_ptr(const shared_ptr& other) noexcept:
            _block(other._block),
            _ptr(other._ptr)
        {
            _incr();
        }

        /**
         * @brief Converting copy constructor (U_* must convert to T*)
         */
        template <typename U_, typename = _if_convertible<U_>>
        shared_ptr(const shared_ptr<U_>& other) noexcept:
            _block(other._block),
            _ptr(other._ptr)
        {
            _incr();
        }

        /**
         * @brief Move constructor: takes over other's reference (count
         *        unchanged) and leaves other null
         */
        shared_ptr(shared_ptr&& other) noexcept:
            _block(std::exchange(other._block, nullptr)),
            _ptr(std::exchange(other._ptr, nullptr))
        {
        }

        /**
         * @brief Converting move constructor (U_* must convert to T*)
         */
        template <typename U_, typename = _if_convertible<U_>>
        shared_ptr(shared_ptr<U_>&& other) noexcept:
            _block(std::exchange(other._block, nullptr)),
            _ptr(std::exchange(other._ptr, nullptr))
        {
        }

    public:
        /*
         * All assignment operators first build the replacement value and only
         * then release the old one (copy-and-swap).  Releasing first would be
         * unsafe when rhs lives inside the object this pointer keeps alive,
         * e.g. `node = node->next;`: destroying the old object would destroy
         * rhs before it is read.  The same idiom makes self-assignment
         * harmless without an explicit check.
         */

        /**
         * @brief Copy assignment operator
         */
        shared_ptr& operator=(const shared_ptr& rhs) noexcept {
            shared_ptr replacement(rhs);
            swap(replacement);
            return *this;
        }

        /**
         * @brief Converting copy assignment operator (U_* must convert to T*)
         */
        template <typename U_, typename = _if_convertible<U_>>
        shared_ptr& operator=(const shared_ptr<U_>& rhs) noexcept {
            shared_ptr replacement(rhs);
            swap(replacement);
            return *this;
        }

        /**
         * @brief Move assignment operator: rhs is null afterwards (unless it is
         *        this very object)
         */
        shared_ptr& operator=(shared_ptr&& rhs) noexcept {
            shared_ptr replacement(std::move(rhs));
            swap(replacement);
            return *this;
        }

        /**
         * @brief Converting move assignment operator (U_* must convert to T*)
         */
        template <typename U_, typename = _if_convertible<U_>>
        shared_ptr& operator=(shared_ptr<U_>&& rhs) noexcept {
            shared_ptr replacement(std::move(rhs));
            swap(replacement);
            return *this;
        }

        /**
         * @brief Member access; asserts that the pointer is not null
         */
        T* operator->() const noexcept {
            TW_SHARED_PTR_ASSERT(_ptr != nullptr);
            return _ptr;
        }

        /**
         * @brief Implicit conversion to the raw pointer (may be nullptr)
         */
        operator T*() const noexcept { return _ptr; }

        /**
         * @brief Dereference; asserts that the pointer is not null
         */
        T& operator*() const noexcept {
            TW_SHARED_PTR_ASSERT(_ptr != nullptr);
            return *_ptr;
        }

        /**
         * @brief True when the stored object pointer is not null
         *
         * Explicit, so that it takes part only in boolean contexts (if, !,
         * &&, ...).  An implicit operator bool next to the implicit operator
         * T* makes expressions such as `a == b` ambiguous.
         */
        explicit operator bool() const noexcept { return (_ptr != nullptr); }

    public:
        /**
         * @brief Raw pointer to the object (may be nullptr)
         */
        T* get() const noexcept { return _ptr; }

        /**
         * @brief Exchange the contents of two pointers; no count changes
         */
        void swap(shared_ptr& other) noexcept {
            std::swap(_block, other._block);
            std::swap(_ptr, other._ptr);
        }

        /**
         * @brief Drop this instance's reference and become null
         *
         * The members are cleared BEFORE the reference is released, so this
         * instance is already null if the object's destructor reaches back to
         * it.  When the decrement brings the count to zero, the control
         * block's deleter is invoked.
         */
        void reset() {
            TW_SHARED_PTR_ASSERT((_block == nullptr) || (_ptr != nullptr));

            control_block* const released = std::exchange(_block, nullptr);
            _ptr = nullptr;

            if ((released != nullptr) && (--(released->refCount) == 0)) {
                TW_SHARED_PTR_ASSERT(released->destructor);
                //Free resources by calling deleter
                released->destructor();
            }
        }

        /**
         * @brief Current reference count, or 0 when there is no control block
         */
        ref_count_type::value_type use_count() const noexcept {
            return (_block == nullptr) ? 0 : _block->refCount.load();
        }

    protected:
        /**
         * @brief Attach to an existing control block (used by make_shared()
         *        and enable_shared_from_this<>)
         *
         * Increments the reference count when block is not null.
         */
        shared_ptr(control_block* block, value_type* ptr):
            _block(block),
            _ptr(ptr)
        {
            _incr();
        }

    private:
        /**
         * @brief Increment the reference count (no-op without control block)
         */
        void _incr() noexcept {
            if (_block != nullptr) {
                TW_SHARED_PTR_ASSERT(_ptr != nullptr);
                ++(_block->refCount);
            }
        }

    private:
        control_block* _block;  //shared bookkeeping, nullptr when not attached
        value_type* _ptr;       //the object itself

        //make_shared() calls the protected (block, ptr) constructor
        template <typename T_, typename ... ARGS_>
        friend
        shared_ptr<T_> make_shared(ARGS_&&...);
    };

    //-----[ CLASS: _enable_shared_from_this ]--------------------------------------
    /**
     * @brief Non-template base of enable_shared_from_this<T>
     *
     * Stores the control block pointer that make_shared() fills in after
     * constructing the object; it is nullptr until then.
     *
     * The control block belongs to one specific object, so it must not travel
     * with the object's value: a copy starts out unmanaged, and assignment
     * leaves the target's own control block untouched.  Otherwise a copy's
     * shared_from_this() would attach to the source object's control block.
     */
    class _enable_shared_from_this {
    protected:
        _enable_shared_from_this() noexcept:
            _block(nullptr)
        {
        }

        //A copy is a different object: it is not managed by the source's block
        _enable_shared_from_this(const _enable_shared_from_this&) noexcept:
            _block(nullptr)
        {
        }

        //Assigning a value does not change which block manages the target
        _enable_shared_from_this& operator=(const _enable_shared_from_this&) noexcept {
            return *this;
        }

    protected:
        _shared_ptr_base::control_block* _block;
    };

    //-----[ TEMPLATE CLASS: enable_shared_from_this ]------------------------------
    /**
     * @brief Base class giving T the ability to create shared_ptr<T> to itself
     *
     * make_shared() is a friend so it can store the control block pointer in
     * the inherited _block member.
     */
    template <typename T>
    class enable_shared_from_this : public _enable_shared_from_this {
        template <typename T_, typename ... ARGS_>
        friend
        shared_ptr<T_> make_shared(ARGS_&&...);

    public:
        /**
         * @brief New shared_ptr<T> sharing ownership of this object
         *
         * For an object that was not created by make_shared() (_block is
         * nullptr) the result points at the object but owns nothing:
         * use_count() is 0 and destroying it releases nothing.
         */
        shared_ptr<T> shared_from_this() noexcept {
            return shared_ptr<T>(_block, static_cast<T*>(this));
        }

        /**
         * @brief Const overload, yielding shared_ptr<const T>
         */
        shared_ptr<const T> shared_from_this() const noexcept {
            return shared_ptr<const T>(_block, static_cast<const T*>(this));
        }
    };

    //==============================================================================
    // FUNCTIONS
    //==============================================================================

    /**
     * @brief Create a T and the shared_ptr<T> that manages it
     *
     * The control block and the T live in one heap object (FullBlock), so a
     * single 'new' creates both.
     *
     * NOTE: T is brace-initialized, T{args...}.  std::make_shared uses
     * parentheses, so the two can differ for types with an initializer_list
     * constructor and for narrowing conversions.
     *
     * @param args Arguments forwarded to the initialization of T
     * @return shared_ptr<T> holding the first reference (use_count() == 1)
     */
    template <typename T, typename ... ARGS>
    shared_ptr<T> make_shared(ARGS&&... args) {
        struct FullBlock {
            _shared_ptr_base::control_block block;
            T value;
        };

        //Allocate block on heap.  {0} sets refCount to 0; the shared_ptr
        //constructed in the return statement raises it to 1.
        auto* const storage = new FullBlock{
            {0},
            T{std::forward<ARGS>(args)...} //value
        };

        //Deleter: frees the whole FullBlock (control block + object)
        storage->block.destructor.p = storage;
        storage->block.destructor.destroy = [](const void* x) {
            delete static_cast<const FullBlock*>(x);
        };

        //Let objects derived from enable_shared_from_this find their block
        if constexpr (std::is_base_of_v<_enable_shared_from_this, T>) {
            storage->value._block = &storage->block;
        }

        /*
         * Up until this point, the make_shared function "owns" the pointer to
         * 'storage'.  It now "transfers" ownership of this pointer to a
         * shared_ptr instance.
         */
        return shared_ptr<T>(&storage->block, &storage->value);
    }

    } //namespace tw
    
    #endif
    
    #endif /* TW_SHARED_PTR_H_ */