How to prevent asyncio.Task from being cancelled


I am implementing a graceful shutdown that needs to wait for certain tasks to finish before the application exits. In the shutdown handler I wait for the tasks with asyncio.gather(*asyncio.Task.all_tasks()).

The problem, however, is that the tasks I create and need to wait for get cancelled as soon as I kill the application, so they no longer appear in asyncio.Task.all_tasks(). How can I prevent that?


Solution

  • Note: asyncio.Task.all_tasks() is deprecated (and removed in Python 3.9), so I will refer to it as asyncio.all_tasks() instead.


    TL;DR Demo code

    The solution differs per OS type.

    • *nix: terminated by sending SIGINT
    • Windows: terminated by Ctrl+C

    Each task runs for 10 seconds, so terminate the process before the tasks complete.

    Pure asyncio (*nix only)

    Complex, long, and reinvents the wheel. It adds a custom signal handler to prevent error propagation.

    The demo spawns 3 shielded and 3 unshielded tasks: the former run to completion, the latter get cancelled.

    """
    Task shielding demonstration with pure asyncio, *nix only
    """
    import asyncio
    import signal
    import os
    
    
    # Set of tasks we shouldn't cancel
    REQUIRE_SHIELDING = set()
    
    
    async def work(n):
        """Some random io intensive work to test shielding"""
        print(f"[{n}] Task start!")
        try:
            await asyncio.sleep(10)
    
        except asyncio.CancelledError:
            # shielded tasks should never print the following
            print(f"[{n}] Canceled!")
            return
    
        print(f"[{n}] Task done!")
    
    
    def install_handler():
    
        def handler(sig_name):
            print(f"Received {sig_name}")
    
            # Distinguish what to await from what to cancel. We have to await all tasks,
            # but only need to manually cancel a subset of them.
            to_await = asyncio.all_tasks()
            to_cancel = to_await - REQUIRE_SHIELDING
    
            # cancel tasks that don't require shielding
            for task in to_cancel:
                task.cancel()
    
            print(f"Cancelling {len(to_cancel)} out of {len(to_await)}")
    
        loop = asyncio.get_running_loop()
    
        # install for SIGINT and SIGTERM
        for signal_name in ("SIGINT", "SIGTERM"):
            loop.add_signal_handler(getattr(signal, signal_name), handler, signal_name)
    
    
    async def main():
        print(f"PID: {os.getpid()}")
    
        # If main task is done - errored or not - all other tasks are canceled.
        # So we need to shield main task.
        REQUIRE_SHIELDING.add(asyncio.current_task())
    
        # install handler
        install_handler()
    
        # spawn tasks that will be shielded
        for n in range(3):
            REQUIRE_SHIELDING.add(asyncio.create_task(work(n)))
    
        # spawn tasks that won't be shielded, for comparison
        for n in range(3, 6):
            asyncio.create_task(work(n))
    
        # keep the main task alive until the shielded tasks, excluding itself, are done.
        await asyncio.gather(*(REQUIRE_SHIELDING - {asyncio.current_task()}))
    
    asyncio.run(main())
    
    PID: 10778
    [0] Task start!
    [1] Task start!
    [2] Task start!
    [3] Task start!
    [4] Task start!
    [5] Task start!
    Received SIGINT
    Cancelling 3 out of 7
    [3] Canceled!
    [5] Canceled!
    [4] Canceled!
    [0] Task done!
    [1] Task done!
    [2] Task done!
    

    asyncio + aiorun (All OS)

    Demonstrates the same thing as above.

    """
    Task shielding demonstration with asyncio + aiorun, all OS
    """
    import asyncio
    import os
    
    from aiorun import run, shutdown_waits_for
    
    
    async def work(n):
        """Some random io intensive work to test shielding"""
        print(f"[{n}] Task start!")
        try:
            await asyncio.sleep(10)
    
        except asyncio.CancelledError:
            print(f"[{n}] Canceled!")
            return
    
        print(f"[{n}] Task done!")
    
    
    async def main():
        print(f"PID: {os.getpid()}")
        child_tasks = []
    
        # spawn tasks that will be shielded
        child_tasks.extend(
            asyncio.create_task(shutdown_waits_for(work(n))) for n in range(3)
        )
    
        # spawn tasks without shielding for comparison
        child_tasks.extend(asyncio.create_task(work(n)) for n in range(3, 6))
    
        # aiorun runs forever by default, even with no coroutines left to run.
        # We have to stop the loop manually, but can't check asyncio.all_tasks()
        # because it includes aiorun's internal tasks, which run forever.
        # Instead, keep the child tasks spawned by the main task and await those.
        await asyncio.gather(*child_tasks)
        asyncio.get_running_loop().stop()
    
    
    run(main())
    
    PID: 26548
    [0] Task start!
    [1] Task start!
    [2] Task start!
    [3] Task start!
    [4] Task start!
    [5] Task start!
    Stopping the loop
    [4] Canceled!
    [5] Canceled!
    [3] Canceled!
    [1] Task done!
    [0] Task done!
    [2] Task done!
    

    Switching to trio (All OS)

    A ground-up, pure-Python asynchronous event loop without callback soup.

    """
    Task shielding demonstration with trio, all OS
    """
    import os
    
    import trio
    
    
    async def work(n):
        """Some random io intensive work to test shielding"""
        print(f"[{n}] Task start!")
        try:
            await trio.sleep(10)
    
        except trio.Cancelled:
            print(f"[{n}] Canceled!")
            raise
    
        print(f"[{n}] Task done!")
    
    
    async def shielded():
        # Open an explicit concurrency context.
        # All concurrency in trio is explicit, via a nursery that takes care of its tasks.
        async with trio.open_nursery() as nursery:
    
            # Shield the nursery from cancellation. Now all tasks in this scope are shielded.
            nursery.cancel_scope.shield = True
    
            # spawn tasks
            for n in range(3):
                nursery.start_soon(work, n)
    
    
    async def main():
        print(f"PID: {os.getpid()}")
    
        try:
            async with trio.open_nursery() as nursery:
                nursery.start_soon(shielded)
    
                for n in range(3, 6):
                    nursery.start_soon(work, n)
    
        except (trio.Cancelled, KeyboardInterrupt):
            # A nursery always makes sure all child tasks are done, whether cancelled or not.
            # This try-except only suppresses the traceback; it is not strictly required.
            print("Nursery Cancelled!")
    
    
    trio.run(main)
    
    PID: 23684
    [3] Task start!
    [4] Task start!
    [5] Task start!
    [0] Task start!
    [1] Task start!
    [2] Task start!
    [3] Canceled!
    [4] Canceled!
    [5] Canceled!
    [0] Task done!
    [1] Task done!
    [2] Task done!
    Nursery Cancelled!
    

    Below is a slightly more in-depth ramble on asyncio's signal handling flow.


    Pure asyncio's signal handling

    I spent a full day digging into this issue (tracing, searching, reading source code) and still couldn't piece together the complete flow. The following flow is my best guess.

    Without custom signal handlers

    1. Receives SIGINT
    2. Somehow signal._signal.default_int_handler is called, raising KeyboardInterrupt
    # signal/_signal.py - probably C code
    def default_int_handler(*args, **kwargs): # real signature unknown
        """
        The default handler for SIGINT installed by Python.
        
        It raises KeyboardInterrupt.
        """
    
    3. The exception propagates and the finally block in asyncio.run runs, calling asyncio.runners._cancel_all_tasks()
    # asyncio.runners
    def run(main, *, debug=None):
        ...
        loop = events.new_event_loop()
        try:
            events.set_event_loop(loop)
            if debug is not None:
                loop.set_debug(debug)
            return loop.run_until_complete(main)
        finally:
            try:
                _cancel_all_tasks(loop)  # <---- this is called
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.run_until_complete(loop.shutdown_default_executor())
            finally:
                events.set_event_loop(None)
                loop.close()
    
    4. asyncio.runners._cancel_all_tasks() cancels all tasks returned by asyncio.all_tasks
    # asyncio/runners.py
    def _cancel_all_tasks(loop):
        to_cancel = tasks.all_tasks(loop)  # <---- gets all running tasks
        if not to_cancel:                  # internally list of weakref.WeakSet '_all_tasks'
            return
    
        for task in to_cancel:  # <---- cancels all of it
            task.cancel()
    
        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
        ...
    

    At the end of execution, successful or not, any remaining tasks are eventually cancelled in step 4.
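
    This cleanup path can be observed without sending a real signal. The following is a minimal sketch, not part of the original demos: it raises KeyboardInterrupt inside the main coroutine as a stand-in for Ctrl+C, and the names background, main and demo are made up for illustration. It should show the background task being cancelled by asyncio.run's cleanup before the exception reaches the caller.

```python
import asyncio


async def background(log):
    """A task we would like to survive, but which asyncio.run cancels on exit."""
    try:
        await asyncio.sleep(10)
    except asyncio.CancelledError:
        log.append("background cancelled")
        raise


async def main(log):
    # keep a reference so the task is not garbage collected
    task = asyncio.create_task(background(log))

    # give the background task time to start and reach its sleep
    await asyncio.sleep(0.01)

    # stand-in for Ctrl+C (assumption: a real SIGINT ends up here similarly)
    raise KeyboardInterrupt


def demo():
    log = []
    try:
        asyncio.run(main(log))
    except KeyboardInterrupt:
        # asyncio.run has already cancelled the leftover tasks at this point
        log.append("KeyboardInterrupt reached caller")
    return log


if __name__ == "__main__":
    print(demo())
```

    The order of the two log entries is the point: the cancellation happens inside asyncio.run, so by the time the caller sees KeyboardInterrupt there is nothing left to wait for.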

    Since asyncio.shield also adds the shielded task to _all_tasks, it won't help either.
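
    A small sketch of why asyncio.shield is not enough, under the assumption that shutdown cancels everything in asyncio.all_tasks() the way _cancel_all_tasks does. The helper names job and await_shielded are hypothetical. Case 1 cancels only the awaiting task, which shield handles; case 2 cancels every task, which reaches the inner task directly.

```python
import asyncio


async def job(name, log):
    try:
        await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        log.append(f"{name} cancelled")
        raise
    log.append(f"{name} done")


async def await_shielded(name, log):
    # asyncio.shield wraps the coroutine in its own Task,
    # and that Task is visible through asyncio.all_tasks()
    await asyncio.shield(job(name, log))


async def main():
    log = []
    me = asyncio.current_task()

    # Case 1: cancel only the awaiting task; the shielded job keeps running
    outer = asyncio.create_task(await_shielded("case1", log))
    await asyncio.sleep(0.01)  # let the inner job start sleeping
    outer.cancel()
    await asyncio.sleep(0.2)   # long enough for the inner job to finish

    # Case 2: blanket cancel over all tasks; this hits the inner job as well
    outer = asyncio.create_task(await_shielded("case2", log))
    await asyncio.sleep(0.01)
    for task in asyncio.all_tasks() - {me}:
        task.cancel()
    await asyncio.sleep(0.2)

    return log


def demo():
    return asyncio.run(main())


if __name__ == "__main__":
    print(demo())
```

    So shield only protects against cancellation that arrives through the awaiting task, not against code that enumerates and cancels every task on the loop.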

    However, if we add custom handlers - things get a bit different.

    With custom signal handlers

    1. We add our custom signal handler via loop.add_signal_handler
    # asyncio/unix_events.py
    class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
        ...
        def add_signal_handler(self, sig, callback, *args):
            """Add a handler for a signal.  UNIX only.
    
            Raise ValueError if the signal number is invalid or uncatchable.
            Raise RuntimeError if there is a problem setting up the handler.
            """
            ...
            handle = events.Handle(callback, args, self, None)
            self._signal_handlers[sig] = handle  # <---- added to sig handler dict
            ...
    
    2. Receives SIGINT
    3. Somehow our event loop's _handle_signal is called, gets the matching signal handler from the dictionary, and adds it as a callback
    # asyncio/unix_events.py
    class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
        ...
        def _handle_signal(self, sig):
            """Internal helper that is the actual signal handler."""
            handle = self._signal_handlers.get(sig)  # <---- fetches added handler
            if handle is None:
                return  # Assume it's some race condition.
            if handle._cancelled:
                self.remove_signal_handler(sig)
            else:
                self._add_callback_signalsafe(handle)  # <---- adds as callback
        ...
    
    4. Our custom callback is called

    Now the default signal handler is not called, so KeyboardInterrupt is never raised, and asyncio.run's try-finally block has not reached finally yet. Therefore asyncio.runners._cancel_all_tasks is not called.

    All tasks finally survive! Cancel the non-essential tasks manually in the handler and we're good to go.
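
    To check this conclusion without a second terminal, here is a condensed sketch that sends SIGINT to its own process. It is *nix only and must run in the main thread, because it relies on loop.add_signal_handler. The names work, main and demo, and the shortened durations, are illustrative assumptions rather than part of the demos above.

```python
import asyncio
import os
import signal


async def work(name, log):
    try:
        await asyncio.sleep(0.3)
    except asyncio.CancelledError:
        log.append(f"{name} cancelled")
        raise
    log.append(f"{name} done")


async def main(log):
    loop = asyncio.get_running_loop()

    # tasks the handler must leave alone, starting with the main task itself
    keep = {asyncio.current_task()}

    def handler():
        # runs as a normal loop callback, so no KeyboardInterrupt is raised
        for task in asyncio.all_tasks() - keep:
            task.cancel()

    loop.add_signal_handler(signal.SIGINT, handler)
    try:
        essential = asyncio.create_task(work("essential", log))
        keep.add(essential)
        optional = asyncio.create_task(work("optional", log))

        # simulate Ctrl+C by signalling our own process after 0.1 seconds
        loop.call_later(0.1, os.kill, os.getpid(), signal.SIGINT)

        # return_exceptions=True keeps the cancelled task from failing gather
        await asyncio.gather(essential, optional, return_exceptions=True)
    finally:
        loop.remove_signal_handler(signal.SIGINT)


def demo():
    log = []
    asyncio.run(main(log))
    return log


if __name__ == "__main__":
    print(demo())
```

    The essential task finishes even though SIGINT arrived while it was running, while the optional task is cancelled by the handler.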