
Python threading.Barrier: why does this code work, and is there a better way?


I searched for information on Python's Barrier but found very few related questions. I am still confused about barrier.wait(), even though my code works.

I use threading.Barrier to implement the following: a main thread and n sub-threads. In each round, the main thread waits for all the sub-threads to finish their current work, and then all threads proceed to the next round, until some termination condition is met. A barrier seemed like the right tool for this. Here is my code for the main thread:

def superstep(self):
    workers = []
    barrier = threading.Barrier(self.num_workers+1)
    for vertex in self.vertices:
        worker = Worker(vertex, barrier)
        workers.append(worker)
        worker.start()

    while self.flag:
        barrier.wait()
        self.distributeMessages()
        self.count += 1
        print ("superstep: ", self.count)
        self.flag = self.isTerminated()

    for worker in workers:
        worker.flag = False

    for worker in workers:
        worker.join()
  1. The first 'for' loop creates n worker threads and stores them in the list workers.
  2. The 'while' loop is the main thread's loop: it waits for the sub-threads at the barrier and exits when self.flag is False.
  3. The second 'for' loop sets flag to False in each worker (sub-thread), telling it to exit its loop.

Here is my Worker class:

class Worker(threading.Thread):
    def __init__(self, vertex, barrier):
        threading.Thread.__init__(self)
        self.vertex = vertex
        self.flag = True
        self.barrier = barrier

    def run(self):
        while self.flag:
            self.barrier.wait()
            # do something

The code works well and all the threads join(). But as I understand barriers, all threads are released simultaneously once every one of them has called wait(). What if the main thread breaks out of its while loop at the very moment all the other threads are waiting for it at the barrier? In that case the second 'for' loop is useless, and the sub-threads will never return, so join() blocks forever.

So why does this code work, and is there any way to get threads out of the barrier other than raising BrokenBarrierError? In addition, if I add some code to the second 'for' loop, such as printing some information, the program blocks. I guess some sub-threads are already in wait() and never get a chance to check flag, so they cannot exit run().
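For reference, the BrokenBarrierError route mentioned above could look roughly like the sketch below. It is only an illustration under assumptions: run_with_abort, rounds_done, num_workers and max_rounds are made-up names, and a fixed round count stands in for the real termination check. One caveat, as far as I can tell from CPython's implementation: calling abort() right after the last wait() may break the barrier while workers are still leaving it, so the sketch first polls n_waiting until every worker is parked at the barrier again.

```python
import threading
import time


class Worker(threading.Thread):
    def __init__(self, barrier):
        super().__init__()
        self.barrier = barrier
        self.rounds_done = 0  # hypothetical counter standing in for real work

    def run(self):
        while True:
            try:
                self.barrier.wait()
            except threading.BrokenBarrierError:
                break  # the main thread aborted the barrier: leave the loop
            self.rounds_done += 1  # stand-in for "do something"


def run_with_abort(num_workers=3, max_rounds=4):
    barrier = threading.Barrier(num_workers + 1)
    workers = [Worker(barrier) for _ in range(num_workers)]
    for w in workers:
        w.start()

    for _ in range(max_rounds):
        barrier.wait()  # one rendezvous with all workers per round

    # Wait until every worker is blocked in wait() again, so that abort()
    # cannot interrupt a worker that is still leaving the previous round.
    while barrier.n_waiting < num_workers:
        time.sleep(0.001)

    barrier.abort()  # waiting workers get BrokenBarrierError and exit
    for w in workers:
        w.join(timeout=5)
    return [w.rounds_done for w in workers], [w.is_alive() for w in workers]


if __name__ == "__main__":
    print(run_with_abort())
```

The polling loop is the price of this approach: the workers only learn about termination through the exception, so the main thread has to make sure they are all back at the barrier before breaking it.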


Solution

  • If you don't want to use abort, you could call Barrier.wait twice per round in each thread. This splits each round into two parts. In the first part, the worker threads do their work and the main thread updates the flag. In the second part, every thread checks the flag and exits its loop if necessary. Because the flags are written before the second wait returns, no worker can go back to the first wait after the main thread has decided to stop.

    At the code level, it would look something like this:

    # Main
    def superstep(self):
        workers = []
        barrier = threading.Barrier(self.num_workers+1)
        for vertex in self.vertices:
            worker = Worker(vertex, barrier)
            workers.append(worker)
            worker.start()
    
        while self.flag:
            barrier.wait()
            self.distributeMessages()
            self.count += 1
            print ("superstep: ", self.count)
            self.flag = self.isTerminated()
            for worker in workers:
                worker.flag = self.flag
            barrier.wait()
    
        for worker in workers:
            worker.join()
    
    # Worker
    def run(self):
        while self.flag:
            self.barrier.wait()
            # do something
            self.barrier.wait()
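To try the two-wait pattern end to end, here is a minimal self-contained sketch. NUM_WORKERS, MAX_ROUNDS, rounds_done and run_supersteps are made up for the demo; they stand in for the vertex work, distributeMessages() and isTerminated() from the question, which are not shown there.

```python
import threading

NUM_WORKERS = 3  # assumed worker count for the demo
MAX_ROUNDS = 4   # assumed stop condition: run a fixed number of rounds


class Worker(threading.Thread):
    def __init__(self, name, barrier):
        super().__init__(name=name)
        self.barrier = barrier
        self.flag = True
        self.rounds_done = 0  # hypothetical counter standing in for real work

    def run(self):
        while self.flag:
            self.barrier.wait()    # first wait: start of the round
            self.rounds_done += 1  # stand-in for "do something"
            self.barrier.wait()    # second wait: main has updated self.flag by now


def run_supersteps():
    barrier = threading.Barrier(NUM_WORKERS + 1)  # workers plus the main thread
    workers = [Worker(f"worker-{i}", barrier) for i in range(NUM_WORKERS)]
    for w in workers:
        w.start()

    count = 0
    keep_going = True
    while keep_going:
        barrier.wait()                    # first wait: release everyone into the round
        count += 1
        keep_going = count < MAX_ROUNDS   # stand-in for the termination check
        for w in workers:
            w.flag = keep_going           # written before the second wait completes
        barrier.wait()                    # second wait: flags are now visible to workers

    for w in workers:
        w.join(timeout=5)
    return count, [w.rounds_done for w in workers], [w.is_alive() for w in workers]


if __name__ == "__main__":
    print(run_supersteps())
```

Every thread reads its flag only after the second wait, and the main thread writes all flags before it, so all threads agree on whether another round follows and nobody is left blocked at the barrier.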