Tags: c++, multithreading, c++20, visual-studio-2022

C++20 std::latch: latch.wait() blocks the whole program


I am learning C++20 and trying to practice with std::latch,

but my example does not work as expected. Could you please help me debug this code?

My program blocks when it calls taskDone.wait().

#include <chrono>
#include <cstdio>
#include <iostream>
#include <thread>
#include <latch>

using namespace std::chrono_literals;
std::latch  latch{ 3 };
std::latch taskDone{ 1 };

void worker_f(int this_id)
{
    if (latch.try_wait())
    {
        std::printf("worker %d comes late, full slot ... waiting for the task done\n", this_id);
        taskDone.wait();
        std::printf("task done, worker %d exit\n", this_id);
    }
    else
    {
        std::printf("worker %d comes on time, there is a task for him\n", this_id);
        latch.count_down();
    }
    std::this_thread::sleep_for(100ms);
}
    

int main()
{
    for (int i = 0; i < 5; ++i)
    {
        std::jthread ac{ worker_f, i };
    }
    std::cout << "hello" << std::endl;
    std::printf("waiting for worker\n");
    latch.wait();
    std::printf("full slot, doing task\n");

    taskDone.count_down();
    
    std::printf("task done\n");
    std::this_thread::sleep_for(1000ms);
}

The output:

worker 0 comes on time, there is a task for him
worker 1 comes on time, there is a task for him
worker 2 comes on time, there is a task for him
worker 3 comes late, full slot ... waiting for the task done

Thank you very much for your help!


Solution

  • Your problem is that your threads aren't running in parallel:

    for (int i = 0; i < 5; ++i) {
        std::jthread ac{ worker_f, i };
    }
    

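    ac is a local variable, so it is destroyed at the end of each loop iteration. When a std::jthread that is still joinable is destroyed, its destructor calls ac.request_stop(); ac.join();, i.e. it waits for the thread to finish.
    This means you start a thread and wait for it to end before going on to the next iteration of the loop. Workers 0 to 2 each count down latch and return; worker 3 then sees try_wait() == true and blocks in taskDone.wait(). But main is itself blocked in the destructor of ac, waiting for worker 3, so it never reaches taskDone.count_down(): a deadlock.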

    You need to keep the threads alive beyond the loop body, for example by storing them in an array declared before the loop:

    std::jthread acs[5];  // outlives the loop; all threads are joined when acs goes out of scope
    for (int i = 0; i < 5; ++i) {
        acs[i] = std::jthread{ worker_f, i };
    }