Search code examples
pythonpython-3.xmultithreadingsemaphorepython-multithreading

How to create a joinable semaphore in Python?


In some use cases I need to wait for all already-created threads to finish, make some decisions based on their results, and decide whether to continue or not — without calling ThreadPoolExecutor.shutdown().

I implemented it like this:

from threading import BoundedSemaphore, Condition, Event, Lock


class JoinSemaphore(BoundedSemaphore):
    """BoundedSemaphore that can also wait until every permit is released.

    join() blocks until the internal count is back at its initial value,
    i.e. until every thread that acquired a permit has released it.

    The check is done with Condition.wait_for() on a condition that shares
    the semaphore's own lock: the predicate is evaluated while the lock is
    held, and the lock is released while sleeping, so release() in other
    threads is never blocked by a joiner.
    """

    def __init__(self, value=1):
        super().__init__(value)
        # Rebuild the internal condition on a lock we own so a second
        # condition can share it. Acquirers wait on _cond and joiners wait on
        # _all_released, so a joiner can never consume a notify() that was
        # meant to wake a blocked acquire().
        lock = Lock()
        self._cond = Condition(lock)
        self._all_released = Condition(lock)

    def join(self, timeout=None):
        """Block until no permit is held, or until ``timeout`` seconds pass.

        Args:
            timeout: maximum number of seconds to wait; None waits forever.

        Returns:
            True if every permit has been released, False on timeout.
        """
        with self._all_released:
            return self._all_released.wait_for(
                lambda: self._value == self._initial_value, timeout
            )

    def release(self, n=1):
        """Return ``n`` permits (default 1), waking acquirers and joiners.

        Raises:
            ValueError: if ``n`` < 1, or if releasing would push the count
                above its initial value.
        """
        if n < 1:
            raise ValueError("n must be one or more")
        with self._cond:
            if self._value + n > self._initial_value:
                raise ValueError("Semaphore released too many times")
            self._value += n
            self._cond.notify(n)
            if self._value == self._initial_value:
                # Every permit is back: wake all join() callers.
                self._all_released.notify_all()

    def acquired(self):
        """Return how many permits are currently held."""
        with self._cond:
            return self._initial_value - self._value

Here I compute `self._value < self._initial_value` without holding any lock, which is racy. If I instead write join() as below, holding the lock so the count cannot change while I compute `self._value < self._initial_value`, I get a deadlock: the main thread joins on the semaphore while holding the lock, so the other threads cannot acquire that lock inside release(), yet the main thread keeps waiting for them. So this implementation is not correct.

    def join(self, timeout=None):
        """Second attempt: check the count under the lock, then wait.

        Deadlock-prone; see the NOTE below.
        """
        # NOTE: Event.wait() below blocks while self._cond's lock is still
        # held. A release() that needs that same lock cannot run while we wait
        # here. So if that release() is what sets _empty, this call waits until
        # ``timeout`` expires, or forever when timeout is None.
        with self._cond:
            if self._value < self._initial_value:
                self._empty.wait(timeout)

In the third implementation (below), I check the count under the lock but wait for the `_empty` event only after releasing the lock, so I cannot guarantee that the other threads haven't already submitted their results and released the semaphore in between.

    def join(self, timeout=None):
        """Third attempt: check under the lock, then wait without holding it."""
        # Fast path: no permit is held, so there is nothing to wait for.
        with self._cond:
            if self._value == self._initial_value:
                return
        # The lock is released before waiting, so a release() can run in this
        # gap. NOTE(review): this is only safe if the final release() sets
        # _empty and acquire() clears it again; if acquire() never clears it,
        # a later join() returns without waiting.
        self._empty.wait(timeout)

The question is:

How can I compute `self._value < self._initial_value` correctly while holding the lock, and then wait for the `_empty` event with the lock released, so that there is no deadlock?

Thanks a lot


Solution

  • This is the approach I found. I wanted to inherit from the base Semaphore class but couldn't find an easy way to fix the problem, so I decided to copy and modify the base BoundedSemaphore class:

    from threading import Event, Condition, Lock
    from time import monotonic as _time


    class JoinSemaphore:
        """Bounded semaphore that can also wait until every permit is returned.

        A modified copy of ``threading.BoundedSemaphore`` (acquire / release /
        ``with`` support) plus ``join()``, which blocks until the count is back
        at its initial value, and ``acquired()``, which reports how many
        permits are currently held.

        ``_empty`` is only set or cleared while ``_cond``'s lock is held, in the
        same critical section that changes ``_value``, so it always agrees with
        the count. That lets ``join()`` wait on the event without holding the
        lock, so threads calling ``release()`` are never blocked by a joiner.
        """

        def __init__(self, value=1):
            if value < 0:
                raise ValueError("semaphore initial value must be >= 0")
            self._cond = Condition(Lock())
            self._value = value
            self._initial_value = value
            # Set while no permit is held; nothing is acquired yet, so start set.
            self._empty = Event()
            self._empty.set()

        def acquire(self, blocking=True, timeout=None):
            """Take one permit, waiting for one to be released if necessary.

            Args:
                blocking: if False, return immediately instead of waiting.
                timeout: maximum seconds to wait when blocking (None = forever).

            Returns:
                True if a permit was taken, False otherwise.

            Raises:
                ValueError: if ``timeout`` is given together with blocking=False.
            """
            if not blocking and timeout is not None:
                raise ValueError("can't specify timeout for non-blocking acquire")
            rc = False
            endtime = None
            with self._cond:
                while self._value == 0:
                    if not blocking:
                        break
                    if timeout is not None:
                        # Track an absolute deadline so repeated wake-ups do
                        # not extend the total time spent waiting.
                        if endtime is None:
                            endtime = _time() + timeout
                        else:
                            timeout = endtime - _time()
                            if timeout <= 0:
                                break
                    self._cond.wait(timeout)
                else:
                    # Runs only when the loop ends because a permit is free
                    # (not via break). A permit is now held, so clear the
                    # "empty" flag in the same critical section as the decrement.
                    self._empty.clear()
                    self._value -= 1
                    rc = True
            return rc

        __enter__ = acquire

        def __exit__(self, exc_type, exc_value, traceback):
            """Release the permit taken on entry to the ``with`` block."""
            self.release()

        def join(self, timeout=None):
            """Block until no permit is held or ``timeout`` seconds elapse.

            Returns:
                True if every permit has been released, False on timeout.
            """
            # No lock needed: acquire()/release() keep _empty in sync with
            # _value under the lock, and waiting here must not block release().
            return self._empty.wait(timeout)

        def release(self, n=1):
            """Return ``n`` permits (default 1) and wake waiting threads.

            Raises:
                ValueError: if ``n`` < 1, or if releasing would push the count
                    above its initial value.
            """
            if n < 1:
                raise ValueError("n must be one or more")
            with self._cond:
                if self._value + n > self._initial_value:
                    raise ValueError("Semaphore released too many times")
                elif self._value + n == self._initial_value:
                    # The last held permits are coming back: wake join() callers.
                    self._empty.set()

                self._value += n
                self._cond.notify(n)

        def acquired(self):
            """Return how many permits are currently held."""
            with self._cond:
                return self._initial_value - self._value
    
    

    The differences from threading.BoundedSemaphore are:

    # Excerpt only: the lines that differ from threading.BoundedSemaphore.
    # Each `...` stands for stdlib code copied unchanged (not runnable as-is).
    class JoinSemaphore:
    
        def __init__(self, value=1):
            ...
            # New: _empty is set while no permit is held; it starts set.
            self._empty = Event()
            self._empty.set()
    
        def acquire(self, blocking=True, timeout=None):
            ...
            with self._cond:
                while ...:
            ...
                else:
                    # A permit was taken (under the lock): no longer empty.
                    self._empty.clear()
            ...
    
        def join(self, timeout=None):
            # Waits without holding the lock; acquire()/release() update _empty
            # under the lock together with _value.
            self._empty.wait(timeout)
    
        def release(self):
            ...
                # The release that returns the last held permit marks it empty.
                elif self._value == self._initial_value - 1:
                    self._empty.set()
            ...
    
        def acquired(self):
            # New helper: number of permits currently held.
            with self._cond:
                return self._initial_value - self._value