python python-asyncio python-trio structured-concurrency

How to force an Async Context Manager to Exit


I've been getting into Structured Concurrency recently and this is a pattern that keeps cropping up:

It's nice to use async context managers to access a resource, say a websocket. That's all great if the websocket stays open, but what if it closes? Well, we expect our context to be forcefully exited, normally through an exception.

How can I write a context manager that exhibits this behaviour? How can I throw an exception 'into' the calling code's open context? How can I forcefully exit a context?

Here's a simple setup, just for argument's sake:

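import asyncio


class ConnectionLostException(Exception):
    """What I'd like my consumer to receive when the connection drops."""


# Let's pretend I'm implementing this: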
class SomeServiceContextManager:
    def __init__(self, service):
        self.service = service

    async def __aenter__(self):
        await self.service.connect(self.connection_state_callback)
        return self.service

    async def __aexit__(self, exc_type, exc, tb):
        self.service.disconnect()
        return False

    def connection_state_callback(self, state):
        if state == "connection lost":
            print("WHAT DO I DO HERE? how do I inform my consumer and force the exit of their context manager?")

class Consumer:
    async def send_stuff(self):
        try:
            async with SomeServiceContextManager(self.service) as connected_service:
                while True:
                    await asyncio.sleep(1)
                    connected_service.send("hello")
        except ConnectionLostException: #<< how do I implement this from the ContextManager?
            print("Oh no my connection was lost!!")

How is this generally handled? It seems to be something I've run into a couple of times when writing context managers!

Here's a slightly more interesting example (hopefully) to demonstrate how things get a bit messy. Say you are receiving messages in an async loop and want to close your connection if something downstream disconnects:

# Let's pretend I'm implementing this:
class SomeServiceContextManager:
    def __init__(self, service):
        self.service = service

    async def __aenter__(self):
        await self.service.connect(self.connection_state_callback)
        return self.service

    async def __aexit__(self, exc_type, exc, tb):
        self.service.disconnect()
        return False

    def connection_state_callback(self, state):
        if state == "connection lost":
            print("WHAT DO I DO HERE? how do I inform my consumer and force the exit of their context manager?")

class Consumer:
    async def translate_stuff(self):
        async with SomeOtherServiceContextManager(self.otherservice) as connected_other_service:
            try:
                async with SomeServiceContextManager(self.service) as connected_service:
                    async for message in connected_other_service.messages():
                        connected_service.send("message received: " + message.text)
            except ConnectionLostException: #<< how do I implement this from the ContextManager?
                print("Oh no my connection was lost - I'll also drop out of the other service connection!!")

Solution

  • Before we get started, let's replace the manual __aenter__() and __aexit__() implementations with contextlib.asynccontextmanager. This takes care of handling exceptions properly and is especially useful when you have nested context managers, as we're going to have in this answer. Here's your snippet rewritten in this way:

    from contextlib import asynccontextmanager
    
    class SomeServiceConnection:
        def __init__(self, service):
            self.service = service
    
        async def _connect(self, connection_state_callback):
            await self.service.connect(connection_state_callback)
    
        async def _disconnect(self):
            self.service.disconnect()
    
    @asynccontextmanager
    async def get_service_connection(service):
        connection = SomeServiceConnection(service)
        await connection._connect(
            ...  # to be done
        )
        try:
            yield connection
        finally:
            await connection._disconnect()
    

    OK, with that out of the way, the core of the answer is this: if you want to stop running tasks in response to some event, use a cancel scope.

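    import trio

    @asynccontextmanager
    async def get_service_connection(service):
        connection = SomeServiceConnection(service)
        with trio.CancelScope() as cancel_scope:
            # The service calls cancel_scope.cancel() when the connection drops,
            # which cancels whatever the consumer is awaiting inside its block.
            await connection._connect(cancel_scope.cancel)
            try:
                yield connection
            finally:
                await connection._disconnect()
                # Turn the silent cancellation into an error the consumer sees.
                if cancel_scope.cancel_called:
                    raise RuntimeError("connection lost")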
    

    But wait... what if some other exception (or exceptions!) was raised at roughly the same time that the connection was closed? It would be masked by the exception you raise. This is handily dealt with by using a nursery instead. A nursery has its own cancel scope doing the cancellation work, but it also deals with creating ExceptionGroup objects (formerly known as MultiErrors). Now you just need something running inside the nursery to raise an exception when the connection is lost. As a bonus, there is a good chance you needed to run a background task to make the callback happen anyway. (If not, e.g., your callback is called from another thread via a Trio token, then use a trio.Event as another answer suggested, and await it from within the nursery.)

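    import trio
    from contextlib import asynccontextmanager

    async def check_connection(connection):
        # wait_disconnected() is assumed to return once the connection drops.
        await connection.wait_disconnected()
        raise RuntimeError("connection lost")

    @asynccontextmanager
    async def get_service_connection(service):
        connection = SomeServiceConnection(service)
        await connection._connect()
        try:
            async with trio.open_nursery() as nursery:
                # Watch the connection in the background. If the watcher raises,
                # the nursery cancels the consumer's block and re-raises the
                # error (wrapped in an ExceptionGroup on recent Trio versions).
                nursery.start_soon(check_connection, connection)
                yield connection
                # The consumer's block finished normally: stop the watcher.
                nursery.cancel_scope.cancel()
        finally:
            await connection._disconnect()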