
Correct way to create a thread-safe shared_ptr without a lock?


I'm trying to create a class with a thread-safe shared_ptr member. My use case is that the shared_ptr belongs to an object of the class and behaves somewhat like a singleton: the CreateIfNotExist() function can be called by any thread at any point in time.

Essentially, if the pointer is null, the first thread that sets its value wins, and all other threads that try to create it at the same time use the winning thread's value.

Here is what I have so far (note that the only function in question is CreateIfNotExist(); the rest is for testing purposes):

#include <memory>
#include <iostream>
#include <thread>
#include <vector>
#include <mutex>

struct A {
    A(int a) : x(a) {}
    int x;
};

struct B {
    B() : test(nullptr) {}

    void CreateIfNotExist(int val) {
        std::shared_ptr<A> newPtr = std::make_shared<A>(val);
        std::shared_ptr<A> _null = nullptr;
        std::atomic_compare_exchange_strong(&test, &_null, newPtr);
    }

    std::shared_ptr<A> test;
};

int gRet = -1;
std::mutex m;

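void Func(B* b, int val) {
    b->CreateIfNotExist(val);
    // Read through atomic_load: a plain read of b->test would race with the
    // compare-exchange running in other threads.
    int ret = std::atomic_load(&b->test)->x;

    int expected;
    {
        // gRet is shared, so both the first write and every read happen under the lock.
        std::lock_guard<std::mutex> l(m);
        if(gRet == -1) {
            gRet = ret;
        }
        expected = gRet;
    }

    if(ret != expected) {
        std::cout << " FAILED " << std::endl;
    }
}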

int main() {
    B b;

    std::vector<std::thread> threads;
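    for(int round = 0; round < 10000; ++round) {
        threads.clear();
        for(int i = 0; i < 8; ++i) threads.emplace_back(&Func, &b, i);
        for(int i = 0; i < 8; ++i) threads[i].join();
        // Reset between rounds so every round races on an empty pointer again
        // (all threads are joined, so these plain writes are safe).
        b.test.reset();
        gRet = -1;
    }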
}

Is this the correct way to do this? Is there a better way to ensure that all threads calling CreateIfNotExist() concurrently end up using the same shared_ptr?
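For comparison, the compare-exchange approach in CreateIfNotExist() can hand every caller the winning pointer directly, so no caller needs a separate, unsynchronized read of test. This is a minimal sketch, assuming the C++11 atomic free functions for shared_ptr from <memory> are available (they are deprecated since C++20 in favour of std::atomic<std::shared_ptr<T>>). GetOrCreate is a hypothetical name, and A is repeated from the question so the block compiles on its own:

```cpp
#include <memory>

struct A {
    explicit A(int a) : x(a) {}
    int x;
};

struct B {
    // Hypothetical accessor: returns whichever pointer won the race.
    std::shared_ptr<A> GetOrCreate(int val) {
        // Fast path: if a pointer is already installed, skip the allocation.
        std::shared_ptr<A> current = std::atomic_load(&test);
        if (current) return current;

        std::shared_ptr<A> fresh = std::make_shared<A>(val);
        // Install `fresh` only if `test` is still empty. On failure, the
        // compare-exchange copies the pointer another thread installed into
        // `current`, so the losing thread returns the winner's object.
        if (std::atomic_compare_exchange_strong(&test, &current, fresh)) {
            return fresh;
        }
        return current;
    }

    // Must only be accessed through the std::atomic_* free functions.
    std::shared_ptr<A> test;
};
```

Threads that lose the race simply release the A they allocated. Note also that standard libraries often implement these shared_ptr atomics with an internal lock, so std::atomic_is_lock_free(&test) may well return false.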


Solution

  • Something along these lines, perhaps, using std::call_once:

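    // Needs <memory> and <mutex>; A is the struct from the question.
    struct B {
      void CreateIfNotExist(int val) {
        // Only one caller runs the lambda; any concurrent callers wait until it
        // has finished, and all of them then see the same `test` pointer.
        std::call_once(test_init,
                       [this, val](){test = std::make_shared<A>(val);});
      }
    
      std::shared_ptr<A> test;
      std::once_flag test_init;  // note: makes B non-copyable and non-movable
    };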
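The call_once snippet can be made self-contained and extended so that the call itself returns the shared pointer. This is a minimal sketch; GetOrCreate is a hypothetical name, and A is repeated from the question so the block compiles on its own:

```cpp
#include <memory>
#include <mutex>

struct A {
    explicit A(int a) : x(a) {}
    int x;
};

struct B {
    // Hypothetical accessor: runs the initializer at most once, then returns
    // the pointer. Returning from call_once guarantees this thread sees the
    // write made by whichever caller's lambda actually ran.
    std::shared_ptr<A> GetOrCreate(int val) {
        std::call_once(test_init, [this, val] { test = std::make_shared<A>(val); });
        // Plain read is safe as long as nothing else assigns to `test` later.
        return test;
    }

    std::shared_ptr<A> test;
    std::once_flag test_init;  // makes B non-copyable and non-movable
};
```

Unlike the compare-exchange approach in the question, other callers wait while the winning initializer runs, so this version blocks; in exchange, no A is ever constructed and then thrown away.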