I'm writing code that I want to be usable in an interactive shell like IPython. That means the code needs to survive a completely asynchronous, unexpected exception like KeyboardInterrupt and still be functional afterwards. I assume context managers are the right tool for handling locking and unlocking.
The problem is: what if the exception is raised during the execution of the __enter__ or __exit__ method, and I need to acquire multiple locks in __enter__ and release multiple locks in __exit__?
My understanding is that calls into C code and individual bytecode instructions are atomic with respect to signal handling. So the call to threading.Lock.acquire() won't be interrupted, and neither will the assignment of its result to a variable. But there is no guarantee that an exception won't be raised between the return from acquire() and the assignment of the return value to a variable, since those are two separate operations at the bytecode level.
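You can see this with dis (illustrative only; the exact opcode names vary between Python versions):

    import dis

    def try_acquire(lock):
        locked = lock.acquire()
        return locked

    dis.dis(try_acquire)
    # The disassembly shows the call to acquire() and the store into
    # `locked` as separate instructions; a KeyboardInterrupt delivered
    # between them leaves the lock held without the flag ever being set.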
In other words, code like this:
import threading

lock = threading.Lock()

while True:
    locked = False
    try:
        locked = lock.acquire()
    except KeyboardInterrupt:
        if locked:
            lock.release()
            locked = False
        raise
    finally:
        if locked:
            lock.release()
won't always release the lock on Ctrl+C.
Similarly, the following script will sometimes fail the last assert check:
import sys
import threading

class ComplexLock():
    def __init__(self):
        self.lock_a = threading.Lock()
        self.lock_b = threading.Lock()

    def __enter__(self):
        self.lock_a.acquire()
        self.lock_b.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.lock_b.release()
        self.lock_a.release()

    def state_consistent(self):
        return not self.lock_a.locked() and not self.lock_b.locked()

a = ComplexLock()
assert a.state_consistent()

try:
    while True:
        with a:
            pass
except KeyboardInterrupt:
    sys.exit(0)
finally:
    assert a.state_consistent()
(Side note: I can't use lock.locked() in place of the locked flag, because I don't know whether it's the current thread that holds the lock; lock is global and shared, while locked is local to the thread.)
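To illustrate the point (a contrived sketch, not part of the original code): locked() only reports the global state of the lock, so it can't tell us whether we are the thread that should release it:

    import threading

    lock = threading.Lock()
    acquired = threading.Event()
    release = threading.Event()

    def hold_lock():
        # Another thread takes the lock and holds it until told to let go.
        with lock:
            acquired.set()
            release.wait()

    t = threading.Thread(target=hold_lock)
    t.start()
    acquired.wait()

    # The main thread does NOT hold the lock, yet locked() is True,
    # so locked() alone can't decide whether *we* should call release().
    print(lock.locked())   # True
    release.set()
    t.join()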
So the question is: how do I handle KeyboardInterrupt in a way that doesn't corrupt global objects?
A practical example use case:
Consider that you want to calculate e**x, but, for the sake of argument, the computation takes a long time and depends on the value of x.
As a library vendor, I export e so that the user can do:

    from example_library import e
    e**12
Now, because the calculation takes a lot of time, I want to internally cache some precomputed values (not results; I'm aware of lru_cache and it won't work for my real use case). So even though e doesn't change externally, it does change internally. But that's an implementation detail unknown to the user.
So when the user presses ^C because e**10000 takes too long, it is reasonable for them to expect that e**12 will work just fine after that ^C.
But it won't work fine if the lock used to synchronise the cache of precomputed values was acquired but never released during the previous calculation.
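For concreteness, the internals might look roughly like this (a minimal sketch; names such as _CachedE and _precompute are made up for illustration and not part of any real library):

    import threading

    class _CachedE:
        # Hypothetical sketch of the exported `e`: exponentiation reuses
        # precomputed terms kept in an internal, lock-protected cache.
        def __init__(self):
            self._cache_lock = threading.Lock()
            self._precomputed = {}

        def __pow__(self, x):
            # If a KeyboardInterrupt ever leaves _cache_lock held, every
            # later e**x call deadlocks here -- the failure described above.
            with self._cache_lock:
                if x not in self._precomputed:
                    self._precomputed[x] = self._precompute(x)  # slow part
                terms = self._precomputed[x]
            return self._evaluate(terms)

        def _precompute(self, x):
            ...  # long-running work whose duration depends on x

        def _evaluate(self, terms):
            ...

    e = _CachedE()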
Actually, it looks like this is not possible in the general case; it's a bug in Python: https://bugs.python.org/issue29988, and a known deficiency, see PEP 419 (unfortunately deferred). More specific to this case: https://bugs.python.org/issue31388
The short of it is that even if __enter__ and __exit__ are C functions, there is no guarantee that __exit__ will be called when __enter__ has executed successfully. We can only patch things up at the Python level to make the failure less likely, not impossible.
By installing our own signal handler and deferring signal delivery, we can make it fairly robust, but only if the critical section is expected to be short, so that the deferral doesn't impact the user.
To continue the example from the question, something like this seems to work as expected, even when threads are introduced into the mix:
import sys
import signal
import threading

class ComplexLock():
    def __init__(self):
        self.lock_a = threading.Lock()
        self.lock_b = threading.Lock()

    def __enter__(self):
        if threading.current_thread().__class__.__name__ == '_MainThread':
            # only the main thread can handle signals; defer SIGINT while
            # the locks are held
            self.signal_received = False
            self.old_handler = signal.signal(signal.SIGINT, self._handler)
        self.lock_a.acquire()
        self.lock_b.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        self.lock_b.release()
        self.lock_a.release()
        if threading.current_thread().__class__.__name__ == '_MainThread':
            # restore the original handler and re-deliver any deferred signal
            signal.signal(signal.SIGINT, self.old_handler)
            if self.signal_received:
                self.old_handler(*self.signal_received)

    def _handler(self, sig, frame):
        # remember the signal instead of raising KeyboardInterrupt right away
        self.signal_received = (sig, frame)

    def state_consistent(self):
        return not self.lock_a.locked() and not self.lock_b.locked()

lock = ComplexLock()
assert lock.state_consistent()

def countdown(x):
    while x[0]:
        with x[1]:
            pass

param = [True, lock]
t = threading.Thread(target=countdown, args=(param,))
t.start()

try:
    while True:
        with lock:
            pass
except KeyboardInterrupt:
    param[0] = False
    t.join()
    sys.exit(0)
finally:
    assert lock.state_consistent()
(There are obviously races when multiple signals get delivered, and some can be dropped, but that's less of a problem than a broken environment after Ctrl+C.)