
Guarding critical section in a multithreaded program


I have a multithreaded Python program (financial trading) in which certain threads execute critical sections (for example, the middle of executing a trade). The threads executing the critical sections are daemon threads. The main thread of the program captures SIGINT and tries to exit the program gracefully by releasing all resources held by child threads. To prevent the main thread from terminating the child threads abruptly, the main thread loops through the list of child thread objects and calls each one's shutdown() function. This function blocks until the thread's critical section completes, and only then returns.

The following is a basic implementation:

from threading import Thread


class ChildDaemonThread(Thread):

    def __init__(self):
        super().__init__(daemon=True)
        self._critical_section = False
        # other initialisations

    def shutdown(self):
        # called by parent thread before calling sys.exit(0)

        # busy-wait until the thread is outside its critical section
        while True:
            if not self._critical_section:
                break

        # add code to prevent entering critical section
        # do resource deallocation

    def do_critical_stuff(self):
        self._critical_section = True
        # do critical stuff
        self._critical_section = False

    def run(self):
        while True:
            self.do_critical_stuff()

I am not sure if my implementation will work. While the ChildDaemonThread is executing a critical section through do_critical_stuff(), the parent thread may call the child's shutdown(), which blocks until the critical section completes. At that point two methods of the same ChildDaemonThread object, shutdown() and do_critical_stuff(), are executing at the same time (I am not sure if this is even legal). Is this possible? Is my implementation correct? Is there a better way to achieve this?


Solution

  • There are some race conditions in this implementation.

    You have no guarantee that the main thread will check the value of _critical_section at the right time to see a False value. The worker thread may leave and re-enter the critical section before the main thread gets around to checking the value again. This may not cause any issues of correctness but it could cause your program to take longer to shut down (since when the main thread "misses" a safe time to shut down it will have to wait for another critical section to complete).

    Additionally, the worker thread may re-enter the critical section after the main thread has noticed that _critical_section is False but before the main thread manages to cause the process to exit. This could pose real correctness issues, since it effectively breaks your attempt to make sure the critical section completes.
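One way to close this second race while keeping a blocking shutdown() is to have the worker hold a lock for the whole critical section. The following is only a minimal sketch under stated assumptions: the class name LockedWorker, the attributes _critical_lock, _stopping and completed, and the time.sleep stand-in for real work are all hypothetical and not part of the original code.

```python
import threading
import time


class LockedWorker(threading.Thread):
    """Hypothetical worker that holds a lock for each whole critical section."""

    def __init__(self):
        super().__init__(daemon=True)
        self._critical_lock = threading.Lock()
        self._stopping = False
        self.completed = 0  # critical sections that ran to completion

    def do_critical_stuff(self):
        time.sleep(0.01)  # stand-in for the real critical work
        self.completed += 1

    def run(self):
        while not self._stopping:
            with self._critical_lock:
                # Re-check under the lock: shutdown() may have set the flag
                # between the loop test and the lock acquisition.
                if self._stopping:
                    break
                self.do_critical_stuff()

    def shutdown(self):
        # Set the flag first so no new critical section can start ...
        self._stopping = True
        # ... then wait for one that is already in progress to finish.
        with self._critical_lock:
            pass
```

The flag is set before the lock is taken so that the worker cannot keep re-acquiring the lock ahead of the caller. Once shutdown() returns, the worker can only acquire the lock, see the flag, and leave its loop.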

    Of course, the program may also crash due to some other issue. Therefore, it may be better if you implement the ability to recover from an interrupted critical section.

    However, if you want to improve this strategy to the greatest extent possible, I would suggest something more like this:

    from threading import Thread


    class ChildDaemonThread(Thread):

        def __init__(self):
            super().__init__(daemon=True)
            self._keep_running = True
            # other initialisations

        def shutdown(self):
            # called by parent thread before calling sys.exit(0)
            # only requests a stop; it does not block
            self._keep_running = False

        def do_critical_stuff(self):
            # do critical stuff
            pass

        def run(self):
            # the flag is only checked between critical sections
            while self._keep_running:
                self.do_critical_stuff()
            # do resource deallocation
    
    
    workers = [ChildDaemonThread(), ...]
    
    # Install your SIGINT handler which calls shutdown on all of workers
    # ...
    
    # Start all the workers
    for w in workers:
        w.start()
    
    # Wait for the run method of all the workers to return
    for w in workers:
        w.join()
    

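    The key here is that join will block until the thread's run method has returned. Because run only checks _keep_running between calls to do_critical_stuff, a worker is never stopped in the middle of a critical section, and the main thread does not exit until every worker has finished.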
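
For reference, the same idea can be written as a self-contained sketch that uses threading.Event for the stop flag. This is an illustration under stated assumptions rather than a definitive implementation: Worker, run_until_sigint, the completed and in_critical attributes, and the time.sleep stand-in for real work are hypothetical names introduced here.

```python
import signal
import threading
import time


class Worker(threading.Thread):
    """Hypothetical worker that stops only between critical sections."""

    def __init__(self):
        super().__init__()
        self._stop_requested = threading.Event()
        self.completed = 0        # critical sections that ran to completion
        self.in_critical = False  # True only while a critical section runs

    def shutdown(self):
        # Non-blocking: just ask the worker to stop after its current section.
        self._stop_requested.set()

    def do_critical_stuff(self):
        self.in_critical = True
        time.sleep(0.01)  # stand-in for the real critical work
        self.completed += 1
        self.in_critical = False

    def run(self):
        while not self._stop_requested.is_set():
            self.do_critical_stuff()
        # resource deallocation would go here


def run_until_sigint(workers):
    """Start the workers and block until SIGINT has stopped all of them."""

    def handle_sigint(signum, frame):
        for w in workers:
            w.shutdown()

    # signal.signal may only be called from the main thread.
    signal.signal(signal.SIGINT, handle_sigint)
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```

Calling run_until_sigint([Worker(), Worker()]) from the main thread would block until Ctrl-C is pressed and then return once each worker has finished the critical section it was in.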