Tags: python, pool, python-asyncio

How to make an asyncio pool cancelable?


I have a pool_map function that can be used to limit the number of simultaneously executing functions.

The idea is to take a coroutine function that accepts a single parameter and map it over a list of arguments, while wrapping every call in a semaphore acquisition so that only a limited number run at once:

from typing import Awaitable, Callable, Iterable, Iterator, TypeVar
from asyncio import Semaphore

A = TypeVar('A')
V = TypeVar('V')

# A plain function: it hands back lazily created coroutines and awaits nothing itself.
def pool_map(
    func: Callable[[A], Awaitable[V]],
    arg_it: Iterable[A],
    size: int = 10
) -> Iterator[Awaitable[V]]:
    """
    Maps an async function to iterables
    ensuring that only some are executed at once.
    """
    semaphore = Semaphore(size)

    async def sub(arg):
        async with semaphore:
            return await func(arg)

    return map(sub, arg_it)

I modified the code above for the sake of the example and didn’t test it, but my own variant works well. You can use it like this:

from asyncio import get_event_loop, as_completed
from contextlib import closing

URLS = [...]

async def run_all(awaitables):
    for a in as_completed(awaitables):
        result = await a
        print('got result', result)

async def download(url): ...


if __name__ == '__main__':
    pool = pool_map(download, URLS)

    with closing(get_event_loop()) as loop:
        loop.run_until_complete(run_all(pool))

But a problem arises if an exception is thrown while awaiting one of the futures. I can’t see how to cancel all the scheduled or still-running tasks, nor the ones still waiting to acquire the semaphore.
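A minimal sketch of that failure mode, for illustration only: `fetch`, `demo`, and the timings are hypothetical stand-ins, not the real download code. It shows that once one awaitable raises, the loop over `as_completed` stops, but the other coroutines keep running because nothing holds a handle to cancel them.

```python
import asyncio


async def fetch(n, finished):
    # Hypothetical stand-in for download(): item 2 fails, the others sleep briefly.
    if n == 2:
        raise RuntimeError(f"item {n} failed")
    await asyncio.sleep(0.01 * n)
    finished.append(n)
    return n


async def demo():
    finished = []
    coros = [fetch(n, finished) for n in range(1, 5)]
    error = None
    try:
        for fut in asyncio.as_completed(coros):
            await fut
    except RuntimeError as exc:
        # The loop above is abandoned here, but nothing has been cancelled.
        error = exc
    # as_completed wrapped the coroutines in tasks we hold no reference to,
    # so all we can do is wait and observe that they run to completion.
    await asyncio.sleep(0.2)
    return error, sorted(finished)


if __name__ == "__main__":
    error, finished = asyncio.run(demo())
    print("error:", error)
    print("finished anyway:", finished)
```

The items other than the failing one still end up in `finished`, which is exactly the work that should have been cancelled.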

Is there a library or an elegant building block for this that I don’t know of, or do I have to build all the parts myself (i.e. a Semaphore with access to its waiters, an as_completed variant that provides access to its running task queue, …)?


Solution

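  • Use ensure_future to get a Task instead of a bare coroutine. A Task is scheduled immediately and has a cancel() method, so the remaining ones can be cancelled when one fails: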

    import asyncio
    from contextlib import closing
    
    
    def pool_map(func, args, size=10):
        """
        Maps an async function to iterables
        ensuring that only some are executed at once.
        """
        semaphore = asyncio.Semaphore(size)
    
        async def sub(arg):
            async with semaphore:
                return await func(arg)
    
        tasks = [asyncio.ensure_future(sub(x)) for x in args]
    
        return tasks
    
    
    async def f(n):
        print(">>> start", n)
    
        if n == 7:
            raise Exception("boom!")
    
        await asyncio.sleep(n / 10)
    
        print("<<< end", n)
        return n
    
    
    async def run_all(tasks):
        exc = None
        for a in asyncio.as_completed(tasks):
            try:
                result = await a
                print('=== result', result)
            except asyncio.CancelledError as e:
                print("!!! cancel", e)
            except Exception as e:
                print("Exception in task, cancelling!")
                for t in tasks:
                    t.cancel()
                exc = e
        if exc:
            raise exc
    
    
    async def main():
        # Create the tasks while the loop is running, then drive them.
        pool = pool_map(f, range(1, 20), 3)
        await run_all(pool)
    
    
    asyncio.run(main())
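As a possible alternative to looping over `as_completed`, the standard library's `asyncio.wait` with `return_when=asyncio.FIRST_EXCEPTION` returns as soon as any task fails, which leaves a `pending` set to cancel. The following is only a sketch under assumptions: `run_until_first_error`, `job`, and `main` are hypothetical names introduced here, and the semaphore wrapper simply mirrors the `sub` helper above.

```python
import asyncio


async def run_until_first_error(tasks):
    """Wait for all tasks; on the first failure cancel whatever is still pending."""
    tasks = list(tasks)
    if not tasks:
        return []
    # Returns when any task raises, or when every task has finished.
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
    for t in pending:
        t.cancel()
    # Let the cancellations be delivered before inspecting the outcome.
    await asyncio.gather(*pending, return_exceptions=True)
    for t in done:
        if t.exception() is not None:
            raise t.exception()
    return [t.result() for t in tasks]


async def job(n):
    # Hypothetical workload: job 3 fails, the others just sleep.
    if n == 3:
        raise ValueError("boom")
    await asyncio.sleep(0.05 * n)
    return n


async def main():
    semaphore = asyncio.Semaphore(2)

    async def limited(n):
        async with semaphore:
            return await job(n)

    tasks = [asyncio.ensure_future(limited(n)) for n in range(1, 7)]
    error = None
    try:
        await run_until_first_error(tasks)
    except ValueError as exc:
        error = exc
    return error, tasks


if __name__ == "__main__":
    error, tasks = asyncio.run(main())
    print("error:", error)
    print("cancelled:", [t.cancelled() for t in tasks])
```

Tasks still waiting on the semaphore are cancelled along with the running ones, since cancelling a Task interrupts whatever it is currently awaiting. On Python 3.11 and later, `asyncio.TaskGroup` offers similar behavior out of the box: when one task in the group fails, the remaining ones are cancelled.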