Tags: c++, multithreading, concurrency, deadlock, race-condition

Unexpected output of multithreaded C++ program


I'm studying concurrency in C++ and I'm trying to implement a multithreaded callback registration system. I came up with the following code, which is supposed to accept registration requests until an event occurs. After that, it should execute all the registered callbacks in the order in which they were registered. The registration order doesn't have to be deterministic. The code doesn't work as expected. First, it rarely prints the "Pushing callback with id" message. Second, it sometimes hangs (a deadlock caused by a race condition, I assume). I'd appreciate help figuring out what's going on here. If you see that I overcomplicate or misuse some parts of the code, please point that out too.

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

class CallbackRegistrar{
public:
    void registerCallbackAndExecute(std::function<void()> callback) {
        if (!eventTriggered) {
            std::unique_lock<std::mutex> lock(callbackMutex);
            auto saved_id = callback_id;
            std::cout << "Pushing callback with id " << saved_id << std::endl;
            registeredCallbacks.push(std::make_pair(callback_id, callback));
            ++callback_id;
            callbackCond.wait(lock, [this, saved_id]{return releasedCallback.first == saved_id;});
            releasedCallback.second();
            callbackExecuted = true;
            eventCond.notify_one();
        }
        else {
            callback();
        }
    }
    void registerEvent() {
        eventTriggered = true;
        while (!registeredCallbacks.empty()) {
            releasedCallback = registeredCallbacks.front();
            callbackCond.notify_all();
            std::unique_lock<std::mutex> lock(eventMutex);
            eventCond.wait(lock, [this]{return callbackExecuted;});
            callbackExecuted = false;
            registeredCallbacks.pop();
        }
    }
private:
    std::queue<std::pair<unsigned, std::function<void()>>> registeredCallbacks;
    bool eventTriggered{false};
    bool callbackExecuted{false};
    std::mutex callbackMutex;
    std::mutex eventMutex;
    std::condition_variable callbackCond;
    std::condition_variable eventCond;
    unsigned callback_id{1};
    std::pair<unsigned, std::function<void()>> releasedCallback;
};

int main()
{
    CallbackRegistrar registrar;
    std::thread t1(&CallbackRegistrar::registerCallbackAndExecute, std::ref(registrar), []{std::cout << "First!\n";});
    std::thread t2(&CallbackRegistrar::registerCallbackAndExecute, std::ref(registrar), []{std::cout << "Second!\n";});
    
    registrar.registerEvent();
    
    t1.join();
    t2.join();

    return 0;
}

Solution

  • This answer has been edited in response to additional information the OP provided in a comment; the edit is at the bottom of the answer.

    Along with the excellent suggestions in the comments, the main problem I found in your code is the predicate you pass to callbackCond.wait. What happens if releasedCallback.first never equals saved_id?

    When I ran your code (with a thread-safe queue and eventTriggered made atomic), I found that the problem was in this wait. If you put a print statement in the predicate, you get something like this:

    releasedCallback.first: 0, savedId: 1
    

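    This then waits forever. releasedCallback starts out as a default-constructed pair with id 0, and it is only changed inside registerEvent. If registerEvent runs before the thread has pushed its callback, it sees an empty queue and returns, so nothing ever updates releasedCallback or notifies callbackCond again. The same race most likely explains why you rarely see the "Pushing callback with id" message: the main thread usually calls registerEvent before the new threads get going, so they find eventTriggered already set and call their callback directly.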
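    To see the hang in isolation, the sketch below strips the question's wait down to its essentials. MissedRelease and waitForRelease are hypothetical names, and wait_for with a timeout replaces wait only so that the demo reports the problem instead of blocking: the predicate asks for id 1, nothing ever sets it, and nobody notifies, so the wait can only end by timing out.

```cpp
#include <chrono>
#include <condition_variable>
#include <mutex>

// Hypothetical stand-in for the question's shared state. releasedId plays the role of
// releasedCallback.first: it keeps its initial value 0 because registerEvent() found an
// empty queue and returned without releasing anything.
struct MissedRelease {
    std::mutex mutex;
    std::condition_variable cond;
    unsigned releasedId = 0;
};

// Same predicate as the question's callbackCond.wait(), but with a timeout so the demo
// terminates. Returns true if the waiter was released, false if it timed out.
bool waitForRelease(MissedRelease& state, unsigned savedId,
                    std::chrono::milliseconds timeout) {
    std::unique_lock<std::mutex> lock(state.mutex);
    return state.cond.wait_for(lock, timeout,
                               [&state, savedId] { return state.releasedId == savedId; });
}
```

    With nothing left to set releasedId or to call notify, waitForRelease(state, 1, ...) can only return false; the question's untimed wait() in the same situation never returns.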

    In fact, you don't need both of the condition variables in your code. One is enough, and it can live inside the thread-safe queue that you are going to build after some searching ;)
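    If you want something to start from, here is one minimal ThreadSafeQueue sketch. It only provides the three operations the code in this answer calls (push, waitAndPop and empty); those names come from this answer, not from any standard library. The single condition variable lives inside the queue and is used by waitAndPop to sleep until an element arrives.

```cpp
#include <condition_variable>
#include <mutex>
#include <queue>
#include <utility>

// Minimal sketch of the ThreadSafeQueue interface assumed by this answer.
template <typename T>
class ThreadSafeQueue {
public:
    void push(T value) {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_queue.push(std::move(value));
        }
        // Every waiter waits for the same thing (a non-empty queue) and one element
        // was added, so waking a single waiter is enough.
        m_cond.notify_one();
    }

    // Blocks until an element is available, then moves the front element into 'out'.
    void waitAndPop(T& out) {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cond.wait(lock, [this] { return !m_queue.empty(); });
        out = std::move(m_queue.front());
        m_queue.pop();
    }

    // Snapshot only: another thread may push or pop right after this returns.
    bool empty() const {
        std::lock_guard<std::mutex> lock(m_mutex);
        return m_queue.empty();
    }

private:
    mutable std::mutex m_mutex;
    std::condition_variable m_cond;
    std::queue<T> m_queue;
};
```

    Because empty() is only a snapshot, another thread can pop the last element between a call to empty() and a following waitAndPop(), in which case waitAndPop() blocks until something else is pushed. Callers that check and then pop therefore need to serialize those two steps themselves.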

    After you have the thread-safe queue, the code from above can be reduced to:

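    #include <atomic>   // in addition to the headers used in the question
    #include <string>
    #include <vector>

    class CallbackRegistrar{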
    public:
      using NumberedCallback = std::pair<unsigned int, std::function<void()>>;
    
      void postCallback(std::function<void()> callback) {
    
        if (!eventTriggered)
        {
          std::unique_lock<std::mutex> lock(mutex);
          auto saved_id = callback_id;
          std::cout << "Pushing callback with id " << saved_id << std::endl;
          registeredCallbacks.push(std::make_pair(callback_id, callback));
          ++callback_id;
        }
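        else
        {
          // Hold the same mutex that guards push() so that only one thread drains the
          // queue at a time; otherwise two threads could both see a non-empty queue and
          // the second waitAndPop() could block indefinitely.
          std::unique_lock<std::mutex> lock(mutex);
          while (!registeredCallbacks.empty())
          {
            NumberedCallback releasedCallback;
            registeredCallbacks.waitAndPop(releasedCallback);
            releasedCallback.second();
          }
          lock.unlock();
          callback();
        }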
      }
      void registerEvent() {
        eventTriggered = true;
      }
    private:
      ThreadSafeQueue<NumberedCallback> registeredCallbacks;
      std::atomic<bool> eventTriggered{false};
      std::mutex mutex;
      unsigned int callback_id{1};
    };
    
    int main()
    {
      CallbackRegistrar registrar;
      std::vector<std::thread> threads;
    
      for (int i = 0; i < 10; i++)
      {
        threads.push_back(std::thread(&CallbackRegistrar::postCallback, 
                                      std::ref(registrar), 
                                      [i]{std::cout << std::to_string(i) <<"\n";}
                                      ));
      }
    
      registrar.registerEvent();
    
      for (auto& thread : threads)
      {
        thread.join();
      }
    
      return 0;
    }
    

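    I'm not sure if this does exactly what you want (callbacks registered before the event only run when some later postCallback call arrives after the event and drains the queue, and they run on that caller's thread), but it doesn't deadlock the way your version does. It's a good starting point in any case, but you need to bring your own implementation of ThreadSafeQueue.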

    Edit

    This edit is in response to the comment by the OP stating that "once the event occurs, all the callbacks should be executed in [the] order that they've been pushed to the queue and by the same thread that registered them".

    This was not mentioned in the original question post. However, if that is the required behaviour, then we need a condition variable wait in the postCallback method. I think this is also the reason why the OP had the condition variable in the postCallback method in the first place.

    In the code below I have made a few edits to the callbacks: they now take input parameters. I did this to print some useful information while the code is running, so that it is easier to see how it works and, importantly, how the condition variable wait is working.

    The basic idea is similar to what you had done; I've just trimmed out the parts you didn't need.

    #include <atomic>   // in addition to the headers used in the question

    class CallbackRegistrar{
    public:
      using NumberedCallback = std::pair<unsigned int, std::function<void(int, int)>>;
    
      void postCallback(std::function<void(int, int)> callback, int threadId) {
    
        if (!m_eventTriggered)
        {
          // Lock the m_mutex
          std::unique_lock<std::mutex> lock(m_mutex);
    
          // Save the current callback ID and push the callback to the queue
          auto savedId = m_currentCallbackId++;
          std::cout << "Pushing callback with ID " << savedId << "\n";
          m_registeredCallbacks.push(std::make_pair(savedId, callback));
    
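          // Wait until our thread's callback is next in the queue, which is the case when
          // the ID of the last called callback is one less than our saved ID.
          // Note: this predicate does not check m_eventTriggered, so on its own it does
          // not make the callback wait for registerEvent().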
          m_conditionVariable.wait(lock, [this, savedId, threadId] () -> bool
          {
            std::cout << "Waiting on thread " << threadId << " last: " << m_lastCalledCallbackId << ", saved - 1: " << (savedId - 1) << "\n";
            return (m_lastCalledCallbackId == (savedId - 1));
          });
    
          // Once we are finished waiting, get the callback out of the queue
          NumberedCallback retrievedCallback;
          m_registeredCallbacks.waitAndPop(retrievedCallback);
    
          // Update last callback ID and call the callback
          m_lastCalledCallbackId = retrievedCallback.first;
          retrievedCallback.second(m_lastCalledCallbackId, threadId);
    
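          // Wake all waiting threads: each one waits for a different ID, so notify_one()
          // could wake a thread whose turn has not come yet.
          m_conditionVariable.notify_all();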
        }
        else
        {
          // If the event is already triggered, call the callback straight away
          callback(-1, threadId);
        }
      }
    
      void registerEvent() {
        // This is all we have to do here.
        m_eventTriggered = true;
      }
    
    private:
      ThreadSafeQueue<NumberedCallback> m_registeredCallbacks;
      std::atomic<bool> m_eventTriggered{false};
      std::mutex m_mutex;
      std::condition_variable m_conditionVariable;
      unsigned int m_currentCallbackId{1};
      std::atomic<unsigned int> m_lastCalledCallbackId{0};
    };
    

    The main function is as above, except I am creating 100 threads instead of 10, and I have made the callback print out information about how it was called.

    for (int createdThreadId = 0; createdThreadId < 100; createdThreadId++)
    {
      threads.push_back(std::thread(&CallbackRegistrar::postCallback,
                                    std::ref(registrar),
                                    [createdThreadId](int registeredCallbackId, int callingThreadId)
                                    {
                                      if (registeredCallbackId < 0)
                                      {
                                        std::cout << "Callback " << createdThreadId;
                                        std::cout << " called immediately, from thread: " << callingThreadId << "\n";
                                      }
                                      else
                                      {
                                        std::cout << "Callback " << createdThreadId;
                                        std::cout << " called from thread " << callingThreadId;
                                        std::cout << " after being registered as " << registeredCallbackId << "\n";
                                      }
                                    },
                                    createdThreadId));
    }
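    If the goal really is that each callback waits for the event and then runs on the thread that registered it, in registration order, one possible shape is a ticket scheme guarded by a single mutex and condition variable. This is a sketch rather than a drop-in replacement, and OrderedCallbackGate, registeredCount and runOrderedDemo are hypothetical names. The wait predicate checks both that the event has fired and that it is this ticket's turn, and registerEvent sets the flag under the mutex before notifying, so the wakeup cannot be lost.

```cpp
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical class: callbacks registered before registerEvent() block their own
// thread until the event fires, then run one at a time in registration (ticket) order.
class OrderedCallbackGate {
public:
    void registerCallbackAndExecute(const std::function<void()>& callback) {
        std::unique_lock<std::mutex> lock(m_mutex);
        if (m_eventTriggered) {
            lock.unlock();
            callback();  // registered after the event: run immediately
            return;
        }
        const unsigned myTicket = m_nextTicket++;
        // Block until the event has fired AND every earlier ticket has finished.
        m_cond.wait(lock, [this, myTicket] {
            return m_eventTriggered && m_nowServing == myTicket;
        });
        lock.unlock();
        callback();  // runs on the registering thread, outside the lock
        lock.lock();
        ++m_nowServing;       // hand over to the next ticket
        m_cond.notify_all();  // waiters have different predicates, so wake them all
    }

    void registerEvent() {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_eventTriggered = true;  // change shared state under the mutex...
        }
        m_cond.notify_all();          // ...then wake everyone to re-check
    }

    // Demo helper (hypothetical): number of callbacks registered so far.
    unsigned registeredCount() const {
        std::lock_guard<std::mutex> lock(m_mutex);
        return m_nextTicket;
    }

private:
    mutable std::mutex m_mutex;
    std::condition_variable m_cond;
    bool m_eventTriggered{false};
    unsigned m_nextTicket{0};
    unsigned m_nowServing{0};
};

// Starts n threads one after another, waiting until each has taken its ticket,
// then fires the event and returns the order in which the callbacks ran.
std::vector<int> runOrderedDemo(int n) {
    OrderedCallbackGate gate;
    std::mutex orderMutex;
    std::vector<int> order;
    std::vector<std::thread> threads;
    for (int i = 0; i < n; ++i) {
        threads.emplace_back([&gate, &orderMutex, &order, i] {
            gate.registerCallbackAndExecute([&orderMutex, &order, i] {
                std::lock_guard<std::mutex> lock(orderMutex);
                order.push_back(i);
            });
        });
        while (gate.registeredCount() != static_cast<unsigned>(i) + 1) {
            std::this_thread::yield();
        }
    }
    gate.registerEvent();
    for (auto& t : threads) {
        t.join();
    }
    return order;
}
```

    notify_all is used because each waiter has a different predicate: notify_one could wake a thread whose turn has not come, which would go back to sleep while the thread whose turn it is keeps waiting. The callback runs outside the lock so that a slow callback doesn't hold the mutex, but a callback that throws would leave m_nowServing unchanged and block every later ticket; an RAII guard around the increment would address that.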
    

    I am not entirely sure why you want to do this, as running the callbacks strictly one after another seems to defeat the point of having multiple threads, although I may be missing something there. Regardless, I hope this helps you better understand the problem you are trying to solve.