Search code examples
Tags: python, asynchronous, async-await, python-asyncio

Python asyncio: wait for the first n tasks from a list to complete


I want to use async/await in Python and return once the first n tasks have completed. However, asyncio.wait() only offers three return_when options: ALL_COMPLETED, FIRST_COMPLETED, and FIRST_EXCEPTION.

What I am trying to do is return once the first n tasks have completed, effectively a FIRST_N_COMPLETED option. For example, if I have 10 tasks in task_list, I want to wait until the first 3 of them complete.

# task_list: list of asyncio.create_task()
# Desired result: `done` holds the first 3 finished tasks and `pending` the other 7.
# asyncio.wait() only accepts return_when=FIRST_COMPLETED, FIRST_EXCEPTION or
# ALL_COMPLETED, so there is no built-in "first n" mode. Note also that
# asyncio.wait() must be awaited inside a coroutine (`await asyncio.wait(...)`);
# the `await` is missing on the line below.
done, pending = asyncio.wait(task_list)

`done` should contain the 3 completed tasks, and `pending` the 7 incomplete ones.

I tried creating another task with a while loop that awaits task_list with return_when=FIRST_COMPLETED, counts the completed tasks, and stops once it reaches 3. However, I don't think it really works.

async def partialAwait(task_list, n=None):
    """Wait until at least *n* of the tasks in *task_list* have completed.

    Args:
        task_list: Tasks or futures to wait on (for example, created with
            ``asyncio.create_task()``). The argument itself is not modified.
        n: Minimum number of tasks that must complete before returning.
            ``None`` or a value larger than the number of distinct tasks means
            "wait for all of them"; ``0`` or less returns without waiting.

    Returns:
        The set of tasks that are still pending. Each round waits with
        ``FIRST_COMPLETED``, and several tasks can finish in the same round.
        So more than *n* tasks may have completed by the time this returns.
        The completed tasks are ``set(task_list) - pending``.
    """
    # Work on a set: that is what asyncio.wait() hands back, so the return type
    # is a set whether or not the loop runs, and duplicates are counted once.
    pending = set(task_list)
    if n is None or n > len(pending):
        n = len(pending)

    completed = 0
    # asyncio.wait() raises ValueError on an empty collection, hence the
    # ``pending`` check alongside the counter.
    while pending and completed < n:
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        completed += len(done)
    return pending

Solution

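  • asyncio.as_completed() works for this: take the first n results with itertools.islice(), then call .cancel() on the remaining tasks as housekeeping. Tasks that finish at the same instant as the n-th one will already be done, so they end up FINISHED rather than CANCELLED. In the output below only three results are printed, yet all five tasks with nap_time 1 or 3 finished.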

    import time
    from pprint import pprint
    import asyncio
    import random
    import itertools
    from time import perf_counter
    
    
    async def worker(nap_time):
        """Announce the start, sleep for *nap_time* seconds, then return a completion message."""
        announcement = f"Task {nap_time=} starting."
        print(announcement)
        await asyncio.sleep(nap_time)
        summary = f"Task {nap_time=} done."
        return summary
    
    
    def _task_state(task):
        """Return 'PENDING', 'CANCELLED' or 'FINISHED' for *task* via the public Task API."""
        if task.cancelled():
            return "CANCELLED"
        if task.done():
            return "FINISHED"
        return "PENDING"


    async def main(n=3, num_tasks=10):
        """Start *num_tasks* random-length workers, print the first *n* results, cancel the rest.

        Task states are printed before and after, so the effect of cancellation is
        visible. The elapsed wall-clock time in seconds is printed last.
        """
        start_time = perf_counter()
        tasks = [asyncio.create_task(worker(random.randint(1, 10))) for _ in range(num_tasks)]
        pprint([(t.get_name(), _task_state(t)) for t in tasks])

        # as_completed() yields awaitables in completion order; islice stops after n.
        for next_done in itertools.islice(asyncio.as_completed(tasks), n):
            print(await next_done)

        # Housekeeping: cancel only what is still running. Tasks that have already
        # finished are left alone (cancelling them would be a no-op anyway).
        for t in tasks:
            if not t.done():
                t.cancel()

        # Let the cancellations be processed before inspecting final states;
        # asyncio.wait() rejects an empty collection, hence the guard.
        if tasks:
            await asyncio.wait(tasks)
        pprint([(t.get_name(), _task_state(t)) for t in tasks])
        print(perf_counter() - start_time)
    
    # Run the demo only when executed as a script, not when imported.
    if __name__ == "__main__":
        asyncio.run(main())
    

    Output:

    [('Task-2', 'PENDING'),
     ('Task-3', 'PENDING'),
     ('Task-4', 'PENDING'),
     ('Task-5', 'PENDING'),
     ('Task-6', 'PENDING'),
     ('Task-7', 'PENDING'),
     ('Task-8', 'PENDING'),
     ('Task-9', 'PENDING'),
     ('Task-10', 'PENDING'),
     ('Task-11', 'PENDING')]
    
    Task nap_time=7 starting.
    Task nap_time=3 starting.
    Task nap_time=1 starting.
    Task nap_time=9 starting.
    Task nap_time=4 starting.
    Task nap_time=6 starting.
    Task nap_time=4 starting.
    Task nap_time=1 starting.
    Task nap_time=3 starting.
    Task nap_time=3 starting.
    
    Task nap_time=1 done.
    Task nap_time=1 done.
    Task nap_time=3 done.
    
    [('Task-2', 'CANCELLED'),
     ('Task-3', 'FINISHED'),
     ('Task-4', 'FINISHED'),
     ('Task-5', 'CANCELLED'),
     ('Task-6', 'CANCELLED'),
     ('Task-7', 'CANCELLED'),
     ('Task-8', 'CANCELLED'),
     ('Task-9', 'FINISHED'),
     ('Task-10', 'FINISHED'),
     ('Task-11', 'FINISHED')]
    
    3.002996250000251