
Thread pool design


I have a design question about thread pool. Consider the following code:

#include <algorithm>
#include <execution>
#include <vector>

int main() {
  // Placeholder data and function so the example compiles.
  std::vector<int> data(1000, 1);
  auto unaryFun = [](int& x) { x *= 2; };
  auto po = std::execution::par_unseq;
  
  // If the Parallel STL uses TBB as its backend, I think the thread pool
  //   is created the first time for_each() is called.
  std::for_each(po, data.begin(), data.end(), unaryFun);
  // Worker threads are put to sleep.
  
  // ... Other things are done by the main thread.
  
  std::for_each(po, data.begin(), data.end(), unaryFun); // Threads wake up and run again.
  
  // ... Other things are done by the main thread.
  
} // Right before going out of scope, how does the thread pool know to destroy
  //   itself?

TBB once had a memory leak problem (see the question "C++: memory leak in tbb"). Whenever I compiled my program with sanitizers, I had to set export ASAN_OPTIONS=alloc_dealloc_mismatch=0 to avoid crashing. I always assumed the leak was caused precisely by the thread pool not being deleted when it goes out of scope.

However, the newer oneTBB no longer has this problem. How did they solve it? I doubt the answer is as simple as constructing and destroying the thread pool inside every for_each() call. How does the thread pool know to destroy itself when it goes out of scope? I want to apply the same design to some other structures.


Solution

  • You don't need to destroy the thread pool at the end of a scope. On any operating system you can run code when a library is being unloaded, and in C++ the destructors of all objects with static storage duration are called when the program exits normally or when the library containing them is unloaded. You just need to make the thread pool a singleton.
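A minimal sketch of the mechanism this relies on, a function-local static (often called a Meyers singleton). The class name Resource and its construction counter are hypothetical and exist only to make the lifetime observable: the object is created on the first call to Instance(), every later call returns the same object, and its destructor runs automatically during normal program termination, after main returns.

```cpp
#include <atomic>
#include <cassert>
#include <cstdio>

// Hypothetical stand-in for a thread pool, used only to show the lifetime
// of a function-local static.
class Resource
{
public:
    static Resource& Instance()
    {
        // Initialized on the first call (thread-safe since C++11); destroyed
        // automatically during normal program termination, after main returns.
        static Resource instance;
        return instance;
    }

    // How many times the constructor has run (for demonstration only).
    static int constructions() { return s_constructions.load(); }

    Resource(const Resource&) = delete;
    Resource& operator=(const Resource&) = delete;

private:
    Resource() { ++s_constructions; }
    ~Resource() { std::puts("Resource destroyed"); }

    static inline std::atomic<int> s_constructions{0};
};
```

Calling Resource::Instance() from main any number of times yields one object, and the "Resource destroyed" message is printed only after main has returned.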

    Here is a naive implementation of such a design (working godbolt demo):

    #include <thread>
    #include <vector>
    #include <queue>
    #include <future>
    #include <optional>
    #include <mutex>
    #include <condition_variable>
    #include <iostream>
    #include <atomic>
    #include <iterator>
    #include <utility>
    
    class TaskQueue
    {
        public:
        void push(std::optional<std::packaged_task<void()>> task)
        {
            {
                std::unique_lock lk{m_mutex};
                m_queue.push(std::move(task));
            }
            m_cv.notify_one();
        }
        std::optional<std::packaged_task<void()>> pop()
        {
            std::optional<std::packaged_task<void()>> task;
            {
                std::unique_lock lk(m_mutex);
                m_cv.wait(lk, [this](){ return !this->m_queue.empty();});
                task = std::move(m_queue.front());
                m_queue.pop();
            }
            return task;
        }
        private:
        std::queue<std::optional<std::packaged_task<void()>>> m_queue;
        std::mutex m_mutex;
        std::condition_variable m_cv;
    
    };
    
    class ThreadPool
    {
        public:
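        static ThreadPool& Instance()
        {
            // Constructed on first use, destroyed automatically at program exit.
            // hardware_concurrency() may return 0 when the value is unknown,
            // so fall back to a single worker.
            const unsigned hw = std::thread::hardware_concurrency();
            static ThreadPool pool(hw == 0 ? 1 : static_cast<int>(hw));
            return pool;
        }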
        
        template<typename Func>
        std::future<void> push_task(Func&& f)
        {
            // Forward rather than move, so an lvalue callable is copied, not moved from.
            std::packaged_task<void()> task{
                [func = std::forward<Func>(f)] { func(); }
            };
            auto fut = task.get_future();
            m_queue.push(std::move(task));
            return fut;
        }
    
        private:
        ThreadPool(int thread_count)
        : m_thread_count{thread_count}
        {
            Initialize();
        }
        void worker_task()
        {
            while (m_running)
            {
                auto task = m_queue.pop();
                if (task)
                {
                    (*task)();
                }
                else
                {
                    break;
                }
            }
        }
        void Initialize()
        {
            m_running = true;
            for (int i = 0; i < m_thread_count; i++)
            {
                m_workers.push_back(std::thread{[this]{this->worker_task();}});
            }
        }
    
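        ~ThreadPool()
        {
            m_running = false;
            // Push one empty task ("poison pill") per worker so every thread
            // blocked in pop() wakes up and leaves its loop.
            for (int i = 0; i < m_thread_count; i++)
            {
                m_queue.push(std::nullopt);
            }
            for (auto&& worker: m_workers)
            {
                if (worker.joinable())
                {
                    worker.join();
                }
            }
            // Tasks still queued are destroyed with m_queue, so their futures
            // report std::future_error (broken_promise). Maybe set a more
            // specific exception on every item left in the queue instead.
        }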
        TaskQueue m_queue;
        std::vector<std::thread> m_workers;
        std::atomic<bool> m_running = false;
        int m_thread_count;
    };
    
    template<typename RndIter, typename Func>
    void foreach_par(RndIter begin, RndIter end, Func&& func)
    {
        std::vector<std::future<void>> futures;
        futures.reserve(std::distance(begin,end));
        auto&& threadpool = ThreadPool::Instance();
        while (begin != end)
        {
            futures.push_back(threadpool.push_task([begin = begin, &func]{ func(*begin);}));
            begin++;
        }
        for (auto&& future: futures)
        {
            future.get();
        }
    }
    
    int main()
    {
        std::vector<int> vals{1,2,3,4,5};
        foreach_par(vals.begin(), vals.end(), [](int& value) { value *= 2; });
    
        for (auto&& value: vals)
        {
            std::cout << value << ' ';
        }
    }
    

    Production libraries typically add further optimizations, such as:

    • reducing heap allocations
    • reducing locks by using lock-free structures
    • chunking tasks
    • per-worker queue and work stealing
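A sketch of the chunking idea: instead of one task per element, as foreach_par above does, the range is split into a few contiguous chunks and each chunk becomes one unit of work. To keep the example self-contained it starts one std::thread per chunk rather than using the pool; foreach_chunked and its chunks parameter are hypothetical names. With a pool, each chunk would be one pushed task.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <thread>
#include <vector>

// Hypothetical helper: split [begin, end) into at most `chunks` contiguous
// sub-ranges and process each sub-range on its own thread, so scheduling
// cost is paid once per chunk instead of once per element.
template <typename RndIter, typename Func>
void foreach_chunked(RndIter begin, RndIter end, Func func, std::size_t chunks)
{
    const auto n = static_cast<std::size_t>(std::distance(begin, end));
    if (n == 0)
    {
        return;
    }
    chunks = std::clamp<std::size_t>(chunks, 1, n);  // 1 <= chunks <= n
    const std::size_t base = n / chunks;             // minimum chunk length
    const std::size_t extra = n % chunks;            // first `extra` chunks get one more

    std::vector<std::thread> threads;
    threads.reserve(chunks);
    RndIter first = begin;
    for (std::size_t c = 0; c < chunks; ++c)
    {
        const std::size_t len = base + (c < extra ? 1 : 0);
        const RndIter last = first + static_cast<std::ptrdiff_t>(len);
        // func is shared by reference, so it must be safe to call concurrently.
        threads.emplace_back([first, last, &func]
        {
            for (RndIter it = first; it != last; ++it)
            {
                func(*it);
            }
        });
        first = last;
    }
    for (auto& t : threads)
    {
        t.join();  // wait for every chunk to finish
    }
}
```

For example, 7 elements split into 3 chunks gives sub-ranges of length 3, 2 and 2, so only three units of work are scheduled instead of seven.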

    This implementation is not optimized, but it is thread-safe and doesn't leak. You could still run into problems, such as a thread using the thread pool after it has been destroyed. A more robust pool would also revive worker threads that terminate for an unknown reason, or add timeouts.
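For the timeout point, one option (an assumption about what "timeout" should mean here) is to let callers bound how long they wait for a result with std::future::wait_for instead of blocking in get(). get_with_timeout is a hypothetical helper; it only stops waiting and does not cancel the task, since std::future has no cancellation. For the std::future<void> returned by push_task above, the same idea is simply comparing fut.wait_for(timeout) with std::future_status::ready.

```cpp
#include <cassert>
#include <chrono>
#include <future>
#include <optional>

// Hypothetical helper: wait at most `timeout` for a result instead of
// blocking indefinitely in future.get(). Returns std::nullopt if the result
// is not ready in time; the task keeps running. T must not be void
// (std::optional<void> is invalid). A deferred std::async task is never
// started by wait_for, so it also yields std::nullopt.
template <typename T, typename Rep, typename Period>
std::optional<T> get_with_timeout(std::future<T>& fut,
                                  std::chrono::duration<Rep, Period> timeout)
{
    if (fut.wait_for(timeout) == std::future_status::ready)
    {
        return fut.get();  // consumes the result; fut becomes invalid
    }
    return std::nullopt;
}
```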

    If you are designing your own thread pool, I would recommend not making it a singleton. Instead, create it in main and pass references to it around (or reach it through a service locator); this avoids the uncontrollable order of construction and destruction of static objects. The code above could break if someone tried to use the thread pool during the construction or destruction of another static object, and singletons have many other downsides.
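A sketch of that recommendation, assuming a simplified pool rather than the one above: the pool is an ordinary object whose destructor lets the workers drain the queue and then joins them, and code that needs it takes a ThreadPool& parameter instead of calling a global. ThreadPool, submit and double_all are hypothetical names for this example.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Minimal pool whose lifetime is owned by whoever creates it (e.g. main).
class ThreadPool
{
public:
    explicit ThreadPool(std::size_t threads)
    {
        if (threads == 0) threads = 1;  // hardware_concurrency() may be 0
        for (std::size_t i = 0; i < threads; ++i)
            m_workers.emplace_back([this] { run(); });
    }
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    ~ThreadPool()
    {
        {
            std::lock_guard lk{m_mutex};
            m_stopping = true;
        }
        m_cv.notify_all();
        for (auto& w : m_workers) w.join();  // workers finish queued tasks first
    }

    std::future<void> submit(std::function<void()> f)
    {
        // shared_ptr makes the move-only packaged_task fit in a std::function.
        auto task = std::make_shared<std::packaged_task<void()>>(std::move(f));
        auto fut = task->get_future();
        {
            std::lock_guard lk{m_mutex};
            m_queue.push([task] { (*task)(); });
        }
        m_cv.notify_one();
        return fut;
    }

private:
    void run()
    {
        for (;;)
        {
            std::function<void()> job;
            {
                std::unique_lock lk{m_mutex};
                m_cv.wait(lk, [this] { return m_stopping || !m_queue.empty(); });
                if (m_stopping && m_queue.empty()) return;
                job = std::move(m_queue.front());
                m_queue.pop();
            }
            job();  // packaged_task stores any exception in the future
        }
    }

    std::mutex m_mutex;
    std::condition_variable m_cv;
    std::queue<std::function<void()>> m_queue;
    bool m_stopping = false;
    std::vector<std::thread> m_workers;  // declared last: threads start after the rest exists
};

// Code that needs the pool receives it explicitly.
void double_all(ThreadPool& pool, std::vector<int>& vals)
{
    std::vector<std::future<void>> futures;
    for (int& v : vals)
        futures.push_back(pool.submit([&v] { v *= 2; }));
    for (auto& f : futures) f.get();
}
```

In a program you would write ThreadPool pool(std::thread::hardware_concurrency()); near the top of main and pass pool to whatever needs it. The workers are then joined at a well-defined point, when main's scope ends, before any static objects are destroyed.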