Search code examples
c++c++17condition-variable

Multi-consumer condition variable wait in the same instance's member function


I'm having trouble finding a way to properly implement a signalling mechanism in which multiple listeners wait in the same function for a producer that continuously signals new data, without a listener being "signalled" again for data it has already seen.

I want all listeners to always see the latest available data (not caring about missed signals if they are busy), without repeats.

My attempt so far:

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <unordered_map>

class Signaller {
public:
    // Used by the producer: runs fnIn while holding the mutex exclusively (so fnIn may modify
    // the shared data), then advances the generation counter and wakes every waiting consumer.
    void Signal(const std::function<void()>& fnIn) {
        std::unique_lock lock(m_mtx);
        fnIn();
        ++m_generation;
        m_cv.notify_all();
    }

    // Used by consumers: blocks until this instance has been signalled more recently than the
    // last time the calling thread returned from Wait(), then runs fnIn while holding the mutex
    // in SHARED mode (fnIn should only read the shared data).
    // Every consumer tracks its own "last seen" generation, so one consumer finishing never hides
    // new data from another, and a consumer never gets the same data twice. Signals issued while
    // a consumer is busy are coalesced: it simply sees the latest data on its next Wait().
    void Wait(const std::function<void()>& fnIn) {
        // Per-thread record of the last generation consumed, keyed by instance id so a thread may
        // wait on several Signallers. Being thread_local, it needs no locking. Entries are never
        // erased (one small entry per instance a thread has waited on).
        thread_local std::unordered_map<std::uint64_t, std::uint64_t> lastSeenByInstance;
        std::uint64_t& lastSeen = lastSeenByInstance[m_id];

        std::shared_lock lock(m_mtx);
        // The predicate also filters out spurious wakeups.
        m_cv.wait(lock, [this, &lastSeen]() { return m_generation != lastSeen; });
        fnIn();
        lastSeen = m_generation;
    }
private:
    // Process-wide unique id per instance; avoids keying on `this`, whose address can be reused
    // by a later Signaller and would then inherit a stale "last seen" value.
    static std::uint64_t NextInstanceId() {
        static std::atomic<std::uint64_t> nextId{0};
        return nextId.fetch_add(1, std::memory_order_relaxed);
    }

    std::condition_variable_any m_cv;
    std::shared_mutex m_mtx;
    std::uint64_t m_generation{0}; // Incremented by every Signal(); guarded by m_mtx
    const std::uint64_t m_id{NextInstanceId()};
};

class Example {
public:
    // Consumer loop: multiple threads call this on the same Example instance. Each iteration
    // blocks in Wait() until new data has been signalled, then copies it under the shared lock.
    void ConsumerLoop() 
    {
        int latestData{0};
        while (true){
            m_signaller.Wait([this, &latestData](){ latestData = m_latestData; });
            
            // process latestData...

            // latestData here should always be the latest value
            // (it's OK to miss a few signals while this thread is busy processing)
        }
    }

    // Producer loop: a single thread publishes a new random value roughly every millisecond.
    void ProducerLoop(){
        while(true){
            int newData = rand();
            m_signaller.Signal([this, newData](){ m_latestData = newData; });
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
    }

private:
    Signaller m_signaller;
    int m_latestData{0}; // Written only inside Signal(); brace-initialized so it is never indeterminate
};

My main issue (I think) is how to prevent spurious wakeups while also preventing the same data from waking the same thread twice. I've thought about keeping a counter per thread to track whether it is receiving the same data again, but couldn't get anywhere with that idea (unless perhaps I build some sort of map keyed by std::this_thread::get_id()?). Is there a better way to do this?

EDIT: Expanding on my idea of a map keyed by thread ID, I think I've found a solution:

#include <functional>
#include <shared_mutex>
#include <condition_variable>
#include <unordered_map>
#include <thread>

class Signaller {
public:
    // Used by the producer: runs fnIn while holding the mutex exclusively (so fnIn may modify the
    // shared data), then bumps the signal counter and wakes every waiting consumer.
    void Signal(const std::function<void()>& fnIn) {
        std::unique_lock lock(m_mtx);
        fnIn();
        m_ctr++;
        m_cv.notify_all();
    }

    // Registers (or re-registers) the calling thread as a consumer. Its "last seen" counter is set
    // to the current one, so its next Wait() blocks until a Signal() issued after this call.
    void RegisterWaiter(){
        std::unique_lock lock(m_mtx);
        m_threadCtrMap.insert_or_assign(std::this_thread::get_id(), m_ctr);
    }

    // Used by consumers: blocks until a Signal() this thread has not consumed yet, then runs fnIn
    // while holding the mutex in SHARED mode (fnIn should only read the shared data).
    // A thread that never called RegisterWaiter() gets an entry with counter 0 on first use.
    void Wait(const std::function<void()>& fnIn) {
        std::uint32_t& lastSeen = LastSeenCtrForThisThread();
        std::shared_lock lock(m_mtx);
        // Each consumer reads/writes only its own map element under the shared lock, which is
        // race-free; the map's structure is only changed under the exclusive lock.
        m_cv.wait(lock, [this, &lastSeen](){ return lastSeen != m_ctr; });
        fnIn();
        lastSeen = m_ctr;
    }
private:
    // Returns the calling thread's entry, inserting it (value 0) if absent. Insertion mutates the
    // map, so it is done only under the exclusive lock (operator[] under a shared lock is a data
    // race). unordered_map references stay valid across rehashing, so the returned reference may
    // be used after the lock is released.
    std::uint32_t& LastSeenCtrForThisThread() {
        const auto id = std::this_thread::get_id();
        {
            std::shared_lock lock(m_mtx);
            if (auto it = m_threadCtrMap.find(id); it != m_threadCtrMap.end()) {
                return it->second;
            }
        }
        std::unique_lock lock(m_mtx);
        return m_threadCtrMap.try_emplace(id, 0).first->second;
    }

    std::condition_variable_any m_cv;
    std::shared_mutex m_mtx;
    std::uint32_t m_ctr{0};
    std::unordered_map<std::thread::id, std::uint32_t> m_threadCtrMap; // Stores the last signalled ctr for that thread
};

class Example {
public:
    // Consumer loop: multiple threads call this on the same Example instance. Each thread
    // registers once, then repeatedly blocks until data it has not yet seen is signalled.
    void ConsumerLoop() 
    {
        int latestData{0};
        m_signaller.RegisterWaiter();
        while (true){
            m_signaller.Wait([this, &latestData](){ latestData = m_latestData; });
            // process latestData...
        }
    }

    // Producer loop: a single thread publishes a new random value roughly every millisecond.
    void ProducerLoop(){
        while(true){
            int newData = rand();
            m_signaller.Signal([this, newData](){ m_latestData = newData; });
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
    }

private:
    Signaller m_signaller;
    int m_latestData{0}; // Written only inside Signal(); brace-initialized so it is never indeterminate
};

Solution

  • Here's my implementation:

    #include <unordered_map>
    #include <condition_variable>
    #include <shared_mutex>
    #include <thread>
    /*
    Example usage:
    
    struct MyClass {
        MultiCVSignaller m_signaller;
        int m_latestData;
        std::atomic<bool> m_stop{false};
    
        ~MyClass(){ 
            m_stop = true;
            m_signaller.Shutdown();
        }
    
        void FuncToWaitOnData() { // e.g. Multiple threads call this fn to "subscribe" to the signal
            auto& signalCtr = m_signaller.RegisterListener();
            while(!m_stop.load(std::memory_order_relaxed)) {
                int latestDataInLocalThread;
                // WaitForSignal() calls the provided function while holding on to the shared mutex
                m_signaller.WaitForSignal(signalCtr, [this, &latestDataInLocalThread](){
                    latestDataInLocalThread = m_latestData;
                });
                // Make use of latest data...
            }
        }
    
        void ProducerLoop() {
            while(!m_stop.load(std::memory_order_relaxed)) {
                // Signal() holds on to the mutex uniquely while calling the provided function.
                m_signaller.Signal([this](){
                    m_latestData = rand();
                });
            }
        }
    };
    */
    
    class MultiCVSignaller
    {
    public:
        using SignalCtr = std::uint32_t; 
    public:
        MultiCVSignaller() = default;
        // Wakes any remaining waiters. Every thread that may still call WaitForSignal() must be
        // joined before this object is destroyed: waking them does not keep the mutex/cv alive.
        ~MultiCVSignaller() { Shutdown(); }
        /*
            Call to set and signal shutdown state, cancelling waits (and skipping the functions provided if any)
            This should be added in the class' destructor before threads are joined.
        */
        void Shutdown() { 
            std::unique_lock lock(m_mtx);
            m_shutdown = true;
            m_cv.notify_all();
        }
    
        // Calls the function if specified while holding on to the mutex with a UNIQUE lock,
        // then bumps the signal counter and wakes all listeners.
        template<class Func = void(*)()>
        void Signal(Func fnIn = +[]{})
        {
            std::unique_lock lock(m_mtx);
            fnIn();
            m_ctr++;
            m_cv.notify_all();
        }
        
        /*
            Registers (or re-registers) the calling thread as a listener and returns its counter slot,
            set to the current signal count (so the next WaitForSignal() waits for a newer Signal()).
            The reference stays valid for the lifetime of this object: unordered_map references are
            not invalidated by inserting other listeners.
        */
        MultiCVSignaller::SignalCtr& RegisterListener(){
            std::unique_lock lock(m_mtx);
            auto [itr, inserted] = m_threadCtrMap.insert_or_assign(std::this_thread::get_id(), m_ctr);
            (void)inserted;
            return itr->second;
        }
    
        /* 
            Blocks until a Signal() newer than signalCtr arrives, or until Shutdown().
            - New signal: calls the optional function while holding the SHARED lock, updates signalCtr
              and returns true.
            - Shutdown: skips the function, leaves signalCtr unchanged and returns false. After shutdown
              every call returns false immediately, so consumer loops should stop on false instead of
              spinning.
            The signalCtr argument should be provided by the return of RegisterListener() (see example)
        */ 
        template<class Func = void(*)()>
        bool WaitForSignal(MultiCVSignaller::SignalCtr& signalCtr, Func fnIn = +[]{})
        {
            std::shared_lock lock(m_mtx);
            // The predicate also filters out spurious wakeups.
            m_cv.wait(lock, [this, &signalCtr](){ return ( m_shutdown ||  signalCtr != m_ctr); });
            if (m_shutdown)
            {
                return false;
            }
            fnIn();
            // Each listener writes only its own map element, which is race-free under the shared lock.
            signalCtr = m_ctr;
            return true;
        }
    
    private:
        std::condition_variable_any m_cv;
        std::shared_mutex m_mtx;
        bool m_shutdown{false};
        SignalCtr m_ctr{0}; // Latest ctr from Signal()
        // This map stores the signal count received for registered listeners.
        // We use an unordered_map as references are never invalidated (unless erased), 
        // which is not the case for a vector
        std::unordered_map<std::thread::id, SignalCtr> m_threadCtrMap;
    };