Tags: python, python-3.x, asynchronous, python-asyncio, aiohttp

Python Asyncio task is running without gather()


I was trying to reproduce & better understand the TaskPool example in this blog post by Cristian Garcia, and I ran into a very interesting result.

Here are the two scripts I used. I swapped out the actual network request for a random sleep call:

# task_pool.py
import asyncio

class TaskPool(object):

    def __init__(self, workers):
        self._semaphore = asyncio.Semaphore(workers)
        self._tasks = set()

    async def put(self, coro):
        await self._semaphore.acquire()
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task):
        self._tasks.remove(task)
        self._semaphore.release()

    async def join(self):
        await asyncio.gather(*self._tasks)

    async def __aenter__(self):
        return self

    def __aexit__(self, exc_type, exc, tb):
        print("aexit triggered")
        return self.join()

And

# main.py
import asyncio
import sys
from task_pool import TaskPool
import random
limit = 3

async def fetch(i):
    timereq = random.randrange(5)
    print("request: {} start, delay: {}".format(i, timereq))
    await asyncio.sleep(timereq)
    print("request: {} end".format(i))
    return (timereq,i)

async def _main(total_requests):
    async with TaskPool(limit) as tasks:
        for i in range(total_requests):
            await tasks.put(fetch(i))

loop = asyncio.get_event_loop()
loop.run_until_complete(_main(int(sys.argv[1])))

Running python main.py 10 on Python 3.7.1 yields the following result.

request: 0 start, delay: 3
request: 1 start, delay: 3
request: 2 start, delay: 3
request: 0 end
request: 1 end
request: 2 end
request: 3 start, delay: 4
request: 4 start, delay: 1
request: 5 start, delay: 0
request: 5 end
request: 6 start, delay: 1
request: 4 end
request: 6 end
request: 7 start, delay: 1
request: 8 start, delay: 4
request: 7 end
aexit triggered
request: 9 start, delay: 1
request: 9 end
request: 3 end
request: 8 end

I have a few questions based on this result.

  1. I would not have expected the tasks to run until the context manager exited and triggered __aexit__, because that is the only place asyncio.gather is called. However, the print statements strongly suggest that the fetch jobs are occurring even before the aexit. What's happening, exactly? Are the tasks running? If so, what started them?
  2. Related to (1). Why is the context manager exiting before all the jobs have returned?
  3. The fetch job is supposed to return a tuple. How can I access this value? For a web-based application, I imagine the developer may want to do operations on the data returned by the website.

Any help is greatly appreciated!


Solution

    1. A task starts as soon as create_task is called.

      Straight from the documentation, first line:

      Wrap the coro coroutine into a Task and schedule its execution.
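
      A minimal sketch showing this: the task's prints appear without any gather() call, as soon as the coroutine that created it yields control to the event loop.

      import asyncio

      async def work():
          print("task started")
          await asyncio.sleep(0)
          print("task finished")

      async def main():
          asyncio.create_task(work())   # scheduled immediately, no gather() needed
          print("task created")
          await asyncio.sleep(0.1)      # yield to the loop; work() runs here

      asyncio.run(main())
      # Output:
      # task created
      # task started
      # task finished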

    2. It should not, but your output makes it look that way. Look at the code in your question:

      def __aexit__(self, exc_type, exc, tb):
          print("aexit triggered")
          return self.join()
      

      There are three issues:

      • This is a regular synchronous function. Change it to async def and add the mandatory await when invoking self.join(). As written, you never actually run join: calling self.join() merely creates a coroutine object without executing it. Python normally complains when a coroutine is never awaited, and those warnings must never be ignored, because they mean something is going very wrong in your program.

        [edit:] As user4815162342 pointed out below, the construction you wrote will actually work, though probably not for the reasons you intended: the coroutine object returned by calling self.join() without awaiting it becomes __aexit__'s return value, and the async with machinery awaits it as if it were __aexit__'s own. You don't want to rely on this; make the method async and use await. (A small demo of this mechanism follows the fixed code below.)

      • Once this is fixed, __aexit__ will print "aexit triggered" and then call join, which waits for the tasks to complete. Messages from tasks that have not yet completed will therefore appear after the "aexit triggered" message.

      • The return value of __aexit__ is ignored unless the exit happens because an exception was raised; in that case, returning a true value will swallow the exception. Drop the return. (A tiny demo of this behavior follows this list.)
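
      A tiny demo of that last point: returning a true value from __aexit__ suppresses an exception raised inside the block.

      import asyncio

      class Swallow:
          async def __aenter__(self):
              return self

          async def __aexit__(self, exc_type, exc, tb):
              return True   # swallows any exception from the block

      async def main():
          async with Swallow():
              raise RuntimeError("never propagates")
          print("still alive")

      asyncio.run(main())   # prints "still alive"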

      So that part, fixed:

      async def __aexit__(self, exc_type, exc, tb):
          print("aexit triggered")
          await self.join()
          print("aexit completed")
      
    3. Your TaskPool must make the results of its tasks available. That part is yours to design; Python will not do any magic under the hood. Given what you have, a simple way would be for join to store the result of gather as an attribute of the task pool.
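
      A minimal sketch of that idea, with one adjustment: because _on_task_done removes finished tasks from self._tasks, the gather() inside join() only sees tasks still pending at exit, so this version collects each result in the done callback instead. The results attribute name is my own choice, not from the original post, and the class assumes the fixed __aexit__ from point 2.

      import asyncio

      class TaskPool(object):

          def __init__(self, workers):
              self._semaphore = asyncio.Semaphore(workers)
              self._tasks = set()
              self.results = []   # filled in completion order, not submission order

          async def put(self, coro):
              await self._semaphore.acquire()
              task = asyncio.create_task(coro)
              self._tasks.add(task)
              task.add_done_callback(self._on_task_done)

          def _on_task_done(self, task):
              self.results.append(task.result())  # raises into the loop if the task failed
              self._tasks.remove(task)
              self._semaphore.release()

          async def join(self):
              await asyncio.gather(*self._tasks)

          async def __aenter__(self):
              return self

          async def __aexit__(self, exc_type, exc, tb):
              await self.join()

      With main.py unchanged, once the async with block in _main exits, tasks.results holds the (timereq, i) tuples returned by fetch, in completion order.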