Tags: python, python-asyncio

Is there a way to access the original task passed to asyncio.as_completed?


I'm trying to pull tasks from an asyncio queue and call a given error handler if an exception occurred. Queued items are given as dictionaries (enqueued by enqueue_task) which contain the task, a possible error handler, and any args/kwargs the error handler may require. Since I'd like to handle any errors as the tasks complete, I map each task to the dequeued dictionary and attempt to access it if an exception occurs.

async def _check_tasks(self):
    try:
        while self._check_tasks_task or not self._check_task_queue.empty():
            tasks = []
            details = {}
            try:
                while len(tasks) < self._CHECK_TASKS_MAX_COUNT:
                    detail = self._check_task_queue.get_nowait()
                    task = detail['task']
                    tasks.append(task)
                    details[task] = detail
            except asyncio.QueueEmpty:
                pass

            if tasks:
                for task in asyncio.as_completed(tasks):
                    try:
                        await task
                    except Exception as e:
                        logger.exception('')
                        detail = details[task]
                        error_handler = detail.get('error_handler')
                        error_handler_args = detail.get('error_handler_args', [])
                        error_handler_kwargs = detail.get('error_handler_kwargs', {})

                        if error_handler:
                            logger.info('calling error handler')
                            if inspect.iscoroutinefunction(error_handler):
                                self.enqueue_task(
                                    task=error_handler(
                                        e,
                                        *error_handler_args,
                                        **error_handler_kwargs
                                    )
                                )
                            else:
                                error_handler(e, *error_handler_args, **error_handler_kwargs)
                        else:
                            logger.exception(f'Exception encountered while handling task: {str(e)}')
            else:
                await asyncio.sleep(self._QUEUE_EMPTY_SLEEP_TIME)
    except:
        logger.exception('')


def enqueue_task(self, task, error_handler=None, error_handler_args=[],
                 error_handler_kwargs={}):
    if not asyncio.isfuture(task):
        task = asyncio.ensure_future(task)

    self._app.gateway._check_task_queue.put_nowait({
        'task': task,
        'error_handler': error_handler,
        'error_handler_args': error_handler_args,
        'error_handler_kwargs': error_handler_kwargs,
    })

However, when an exception occurs, it appears the task being used as a key is not found in the details dictionary, and I receive the following error:

KeyError: <generator object as_completed.<locals>._wait_for_one at 0x7fc2d1cea308>
Exception encountered while handling task: <generator object as_completed.<locals>._wait_for_one at 0x7fc2d1cea308>
Traceback (most recent call last):
  File "/app/app/gateway/gateway.py", line 64, in _check_tasks
    detail = details[task]
KeyError: <generator object as_completed.<locals>._wait_for_one at 0x7fc2d1cea308>

When task is yielded by asyncio.as_completed, it appears to be a generator

<generator object as_completed.<locals>._wait_for_one at 0x7fc2d1cea308>

when I expect it to be a task

<Task pending coro=<GatewayL1Component._save_tick_to_stream() running at /app/app/gateway/l1.py:320> wait_for=<Future pending cb=[<TaskWakeupMethWrapper object at 0x7fc2d4380d98>()]>>

Why is task a generator instead of the original task after being yielded by asyncio.as_completed? Is there a way to access the original task?


Solution

  • Why is task a generator instead of the original task after being yielded by asyncio.as_completed?

    The problem is that as_completed is not an async iterator (which you'd exhaust with async for), but an ordinary iterator. Where an async iterator's __anext__ can suspend while awaiting an async event, an ordinary iterator's __next__ must immediately provide a result. It obviously cannot yield a completed task because no tasks have yet had time to complete, so it yields an awaitable object that actually waits for a task to complete. This is the object that looks like a generator.
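    This behavior is easy to observe directly. The following is a minimal standalone sketch (not taken from the question's code): the objects yielded by the plain for loop are not the original tasks, and awaiting one of them produces a task's result rather than the task itself.

```python
import asyncio

async def demo():
    t1 = asyncio.ensure_future(asyncio.sleep(0.02, result='slow'))
    t2 = asyncio.ensure_future(asyncio.sleep(0.01, result='fast'))
    for item in asyncio.as_completed([t1, t2]):
        # `item` is not one of the original tasks - it is an internal awaitable
        print(item is t1 or item is t2)   # False, both times
        # awaiting it gives the *result* of whichever task finished first,
        # not a reference to the finished task
        print(await item)                 # 'fast', then 'slow'

asyncio.run(demo())
```

    This is exactly why using `item` as a key into the `details` dictionary raises KeyError.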

    As another consequence of the implementation, awaiting this object gives you the result of the original task rather than a reference to the task object - in contrast to the original concurrent.futures.as_completed, which yields the futures you submitted. This makes asyncio.as_completed less intuitive and harder to use, and there is a bug report that argues that as_completed should be made usable as an async iterator as well, providing the correct semantics. (This can be done in a backward-compatible way.)
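    For comparison, a quick sketch of the concurrent.futures behavior referred to above, where as_completed hands back the very Future objects you submitted, so identity-based lookups (such as a dict keyed by future) work:

```python
import concurrent.futures

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as ex:
    futs = [ex.submit(pow, 2, n) for n in range(4)]
    for fut in concurrent.futures.as_completed(futs):
        # concurrent.futures.as_completed yields the *original* Future
        # objects, in completion order
        assert fut in futs
        print(fut.result())
```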

  • Is there a way to access the original task?

    As a workaround, you can create an async version of as_completed by wrapping the original task into a coroutine that finishes when the task does, and has the task as its result:

    async def as_completed_async(futures):
        loop = asyncio.get_event_loop()
        wrappers = []
        for fut in futures:
            assert isinstance(fut, asyncio.Future)  # we need Future or Task
            # Wrap the future in one that completes when the original does,
            # and whose result is the original future object.
            wrapper = loop.create_future()
            fut.add_done_callback(wrapper.set_result)
            wrappers.append(wrapper)
    
        for next_completed in asyncio.as_completed(wrappers):
            # awaiting next_completed will dereference the wrapper and get
            # the original future (which we know has completed), so we can
            # just yield that
            yield await next_completed
    

    That should allow you to get the original tasks - here is a simple test case:

    async def main():
        loop = asyncio.get_event_loop()
        fut1 = loop.create_task(asyncio.sleep(.2))
        fut1.t = .2
        fut2 = loop.create_task(asyncio.sleep(.3))
        fut2.t = .3
        fut3 = loop.create_task(asyncio.sleep(.1))
        fut3.t = .1
        async for fut in as_completed_async([fut1, fut2, fut3]):
            # using the `.t` attribute shows that we've got the original tasks
            print('completed', fut.t)
    
    asyncio.get_event_loop().run_until_complete(main())
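    As an alternative workaround (not part of the original answer), asyncio.wait with return_when=FIRST_COMPLETED also yields the original Task objects in its done set, so a dictionary keyed by task, like the question's details dict, keeps working. A minimal sketch:

```python
import asyncio

async def main():
    t1 = asyncio.ensure_future(asyncio.sleep(0.02, result='slow'))
    t2 = asyncio.ensure_future(asyncio.sleep(0.01, result='fast'))
    pending = {t1, t2}
    while pending:
        # wait() returns the *original* tasks in `done`, one batch at a
        # time, as soon as at least one of them completes
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            assert task in (t1, t2)
            print(task.result())   # 'fast', then 'slow'

asyncio.run(main())
```

    The trade-off is that wait() delivers completions in unordered batches rather than one at a time, so the loop body has to iterate over the done set itself.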