Tags: python, thread-safety, python-multithreading

Handling KeyboardInterrupt in context manager __enter__ method


I'm writing code that I want to be usable in an interactive shell, like IPython. That means the code needs to be able to handle a completely asynchronous, unexpected exception like KeyboardInterrupt and still be functional after it. I guess I should use context managers to handle locking and unlocking.

The problem is: what if the exception is raised during the execution of the __enter__ or __exit__ method, and I need to acquire multiple locks in __enter__ and release multiple ones in __exit__?

My understanding is that calls into C code are atomic, as are individual bytecode instructions. So the call to threading.Lock.acquire() won't be interrupted, and neither will the assignment of its result to a variable. But we have no guarantee that an exception won't be raised between the return from acquire() and the assignment of the return value to the variable, because those are handled by two different operations at the bytecode level.
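
One way to see that the acquire-and-assign really is two separate operations is to disassemble such an assignment (a small illustration of my own, not part of the original question):

import dis
import threading

lock = threading.Lock()

def attempt():
    locked = lock.acquire()

dis.dis(attempt)
# the output shows a call instruction followed by a separate STORE_FAST for
# 'locked'; an interrupt delivered between the two leaves the lock acquired
# without any record of it in 'locked'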

In other words, code like this:

import threading
lock = threading.Lock()
while True:
    locked = False
    try:
        # an interrupt can be raised after acquire() returns but before
        # its result is stored into 'locked'
        locked = lock.acquire()
    except KeyboardInterrupt:
        if locked:
            lock.release()
            locked = False
        raise
    finally:
        if locked:
            lock.release()

won't always release the lock on Ctrl+C.

Similarly, the following script will sometimes fail the last assert check:

import sys
import threading

class ComplexLock():
    def __init__(self):
        self.lock_a = threading.Lock()
        self.lock_b = threading.Lock()

    def __enter__(self):
        self.lock_a.acquire()
        # a KeyboardInterrupt raised here leaves lock_a held with nothing to release it
        self.lock_b.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.lock_b.release()
        self.lock_a.release()

    def state_consistent(self):
        return not self.lock_a.locked() and not self.lock_b.locked()


a = ComplexLock()
assert a.state_consistent()

try:
    while True:
        with a:
            pass
except KeyboardInterrupt:
    sys.exit(0)
finally:
    assert a.state_consistent()

(side note: I can't use lock.locked() in place of the locked flag, because I don't know whether it's the current thread that is holding the lock; lock is global and shared, while locked is local to the thread)
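
To illustrate that side note, here is a minimal demonstration of my own (not from the question) that Lock.locked() only reports whether the lock is held, not which thread holds it:

import threading
import time

lock = threading.Lock()
acquired = threading.Event()

def hold_briefly():
    with lock:
        acquired.set()
        time.sleep(0.5)   # keep the lock held for a moment

t = threading.Thread(target=hold_briefly)
t.start()
acquired.wait()
# the main thread never acquired the lock, yet locked() reports True
print(lock.locked())   # True
t.join()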

So, the question is how to handle KeyboardInterrupt in a way that doesn't corrupt global objects?

A practical example use case: consider that you want to calculate e**x, but, for the sake of argument, the computation takes a long time and depends on the value of x. As a library vendor, I provide e as an export so the user can do

from example_library import e
e**12

Now, because the calculation takes a lot of time, I want to cache some precomputed variables internally (not results; I'm aware of lru_cache and it won't work for my real use case). So even though e doesn't change externally, it does change internally. But that's an implementation detail unknown to the user. So when the user hits ^C because e**10000 takes too long, it is reasonable for them to expect that e**12 will work just fine after that ^C.

But it won't work fine if the lock used to synchronise the cache of precomputed variables was acquired but not released during the previous calculation.
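
For illustration only, the library internals might look something like this hypothetical sketch (none of these names are real; the lock-protected cache is the point):

import math
import threading

class _E:
    def __init__(self):
        self._cache_lock = threading.Lock()   # guards the internal cache
        self._precomputed = {}                # cached intermediate values

    def __pow__(self, x):
        # if a previous Ctrl+C left _cache_lock acquired, this call blocks
        # forever, even though e looks unchanged from the outside
        with self._cache_lock:
            if x not in self._precomputed:
                # stands in for the long, interruptible precomputation
                self._precomputed[x] = [x ** n / math.factorial(n)
                                        for n in range(50)]
            terms = self._precomputed[x]
        return sum(terms)

e = _E()
print(e ** 12)   # approximates math.exp(12)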


Solution

  • Actually, it looks like it's not possible in the general sense; it's a bug in Python: https://bugs.python.org/issue29988, and it's also a known deficiency, see PEP 419 (unfortunately deferred). More specific to this case: https://bugs.python.org/issue31388

    The short of it is that even if __enter__ and __exit__ are C functions, it's not guaranteed that __exit__ will get called even when __enter__ executed successfully. We can only patch it up at the Python level to make the failure less likely, not impossible.

    By implementing our own signal handler and deferring the signal delivery, we can make it fairly robust, but only if the critical section is expected to be short, so that the deferral won't noticeably impact the user (a standalone sketch of this deferral helper follows the worked example below).

    To continue the example from the question, something like this seems to work as expected, even when we introduce threads into the mix:

    import sys
    import signal
    import threading
    
    class ComplexLock():
        def __init__(self):
            self.lock_a = threading.Lock()
            self.lock_b = threading.Lock()
    
        def __enter__(self):
            if threading.current_thread().__class__.__name__ == '_MainThread':
                # only the main thread can install signal handlers;
                # defer SIGINT for the duration of the critical section
                self.signal_received = False
                self.old_handler = signal.signal(signal.SIGINT, self._handler)
    
            self.lock_a.acquire()
            self.lock_b.acquire()
    
        def __exit__(self, exc_type, exc_value, traceback):
            self.lock_b.release()
            self.lock_a.release()
    
            if threading.current_thread().__class__.__name__ == '_MainThread':
                # restore the original handler, then re-deliver any SIGINT
                # that arrived while the locks were held
                signal.signal(signal.SIGINT, self.old_handler)
                if self.signal_received:
                    self.old_handler(*self.signal_received)
    
        def _handler(self, sig, frame):
            # just record the signal; it is re-delivered in __exit__
            self.signal_received = (sig, frame)
    
        def state_consistent(self):
            return not self.lock_a.locked() and not self.lock_b.locked()
    
    lock = ComplexLock()
    assert lock.state_consistent()

    def countdown(x):
        # a second thread contending for the same locks
        while x[0]:
            with x[1]:
                pass
    
    param = [True, lock]
    
    t = threading.Thread(target=countdown, args=(param, ))
    t.start()
    
    try:
        while True:
            with lock:
                pass
    except KeyboardInterrupt:
        param[0] = False
        t.join()
        sys.exit(0)
    finally:
        assert lock.state_consistent()
    

    (there are obviously races when multiple signals get delivered, and some can get dropped, but that's less of a problem than a broken environment after Ctrl+C)
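
    For reference, the same deferral idea can be factored into a standalone helper. This is my own sketch of the technique described above, only valid in the main thread and assuming the previously installed SIGINT handler is callable (as the default one is):

    import signal
    from contextlib import contextmanager

    @contextmanager
    def defer_sigint():
        received = []
        # install a handler that only records the signal
        previous = signal.signal(signal.SIGINT,
                                 lambda sig, frame: received.append((sig, frame)))
        try:
            yield
        finally:
            signal.signal(signal.SIGINT, previous)
            if received:
                # re-deliver the first deferred Ctrl+C now that it is safe
                previous(*received[0])

    # usage: keep the guarded section short, as noted above
    # with defer_sigint():
    #     lock_a.acquire()
    #     lock_b.acquire()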