Tags: python, python-asyncio

Automatically awaiting coroutines when their actual values are needed?


Assume I want to implement the following using asyncio:

def f(): 
    val1 = a()  # a() takes 1 sec
    val2 = b()  # b() takes 3 sec
    val3 = c(val1, val2)  # c() takes 1 sec, but must wait for a() and b() to finish
    val4 = d(val1)  # d() takes 1 sec, but must wait for a() to finish

All functions a, b, c, d are asynchronous and could potentially run in parallel. The optimal schedule would be: 1) run a() and b() in parallel; 2) when a() is done, run d(); 3) when both a() and b() are done, run c(). Altogether this should take 4 seconds.
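The 4-second figure is the length of the longest (critical) path through the dependency graph. With $t_a = 1$, $t_b = 3$ and $t_c = t_d = 1$ seconds:

$$
T = \max\bigl(\max(t_a, t_b) + t_c,\ t_a + t_d\bigr) = \max(3 + 1,\ 1 + 1) = 4\ \text{s}
$$

For comparison, running everything sequentially would take $t_a + t_b + t_c + t_d = 6$ s.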

I find that implementing this with asyncio is not ideal:

import time
import asyncio

async def a():
    await asyncio.sleep(1)

async def b():
    await asyncio.sleep(3)

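async def c(val1, val2):
    await val2  # waits only for b(); val1 is not awaited here, which is safe only because a() finishes before b()
    await asyncio.sleep(1)

async def d(val1):
    await val1  # the only place where the coroutine returned by a() is awaited
    await asyncio.sleep(1)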

async def f():
    val1 = a()
    val2 = b()
    val3 = c(val1, val2)
    val4 = d(val1)
    return await asyncio.gather(val3, val4)

t1 = time.time()
asyncio.run(f())  # a bare top-level "await f()" only works in an async REPL such as Jupyter
t2 = time.time()
print(t2 - t1)  # This will indeed be about 4 seconds

The above implementation works, but its main flaw is that I need to know that a() finishes before b() in order to await val1 in d() and not in c(). In other words, given a (possibly complex) execution graph, I have to know which functions finish before others in order to place each "await" statement in the right place. If I await the same coroutine in two places, I get a RuntimeError.
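The exception comes from the fact that a coroutine object is single-use: once it has finished, awaiting it again raises RuntimeError. An asyncio.Task, by contrast, runs its coroutine once and keeps the result, so it can be awaited repeatedly. A small sketch contrasting the two; the 0.1 s sleep and the return value 1 are placeholders:

```python
import asyncio


async def a():
    await asyncio.sleep(0.1)  # placeholder computation
    return 1


async def main():
    # A bare coroutine object can only be awaited once.
    coro = a()
    first = await coro
    try:
        await coro  # second await of the same, already finished coroutine
    except RuntimeError as exc:
        error = str(exc)  # e.g. "cannot reuse already awaited coroutine"

    # A Task runs the coroutine once and caches its result,
    # so every later await simply returns the same value.
    task = asyncio.create_task(a())
    results = (await task, await task)
    return first, error, results


first, error, results = asyncio.run(main())
print(first, error, results)
```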

My question is the following: is there a mechanism in asyncio (or another Python module) that automatically awaits coroutines at the point where their actual values are needed? I know that such a mechanism exists in other parallel execution frameworks.
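Within asyncio itself, the closest built-in mechanism is probably asyncio.Task: asyncio.create_task() schedules a coroutine immediately and returns a task that, unlike a bare coroutine, can be awaited from any number of places. A minimal sketch of the question's graph rewritten that way, assuming placeholder return values (1 from a(), 100 from b()) and placeholder arithmetic in c() and d(). It is still explicit awaiting rather than implicit resolution, but no consumer needs to know which input finishes first:

```python
import asyncio
import time


async def a():
    await asyncio.sleep(1)
    return 1  # placeholder value


async def b():
    await asyncio.sleep(3)
    return 100  # placeholder value


async def c(val1, val2):
    # val1 and val2 are tasks: await each where its value is needed.
    # Awaiting a task that has already finished just returns its result.
    x, y = await val1, await val2
    await asyncio.sleep(1)
    return x + y


async def d(val1):
    x = await val1  # the same task may also be awaited inside c()
    await asyncio.sleep(1)
    return x * 2


async def f():
    # create_task() schedules each coroutine right away,
    # so a() and b() start running in parallel here.
    val1 = asyncio.create_task(a())
    val2 = asyncio.create_task(b())
    val3 = asyncio.create_task(c(val1, val2))
    val4 = asyncio.create_task(d(val1))
    return await asyncio.gather(val3, val4)


t1 = time.time()
result = asyncio.run(f())
elapsed = time.time() - t1
print(result, round(elapsed, 1))
```

Because both c() and d() await val1, the order in which a() and b() finish no longer affects where the awaits have to go.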


Solution

  • There are many ways to do it. One possibility is to use synchronization primitives such as asyncio.Event. For example:

    import time
    import asyncio
    
    
    val1 = None
    val2 = None
    
    event_a = None
    event_b = None
    
    
    async def a():
        global val1
        await asyncio.sleep(1)  # some computation
        val1 = 1
        event_a.set()
    
    
    async def b():
        global val2
        await asyncio.sleep(3)
        val2 = 100
        event_b.set()
    
    
    async def c():
        await event_a.wait()
        await event_b.wait()
    
        await asyncio.sleep(1)
    
        return val1 + val2
    
    
    async def d():
        await event_a.wait()
    
        await asyncio.sleep(1)
    
        return val1 * 2
    
    
    async def f():
        global event_a
        global event_b
    
        event_a = asyncio.Event()
        event_b = asyncio.Event()
    
        out = await asyncio.gather(a(), b(), c(), d())
        assert out[2] == 101
        assert out[3] == 2
    
    
    async def main():
        t1 = time.time()
        await f()
        t2 = time.time()
        print(t2 - t1)
    
    
    if __name__ == "__main__":
        asyncio.run(main())
    

    Prints:

    4.0029356479644775
    

    Another option is to split the computation into more coroutines, for example:

    import time
    import asyncio
    
    
    async def a():
        await asyncio.sleep(1)  # some computation
        return 1
    
    
    async def b():
        await asyncio.sleep(3)
        return 100
    
    
    async def c(val1, val2):
        await asyncio.sleep(1)
        return val1 + val2
    
    
    async def d(val1):
        await asyncio.sleep(1)
        return val1 * 2
    
    
    async def f():
        async def task1():
            params = await asyncio.gather(a(), b())  # <-- run a() and b() in parallel
            return await c(*params)
    
        async def task2():
            # note: this calls a() a second time, so a() is computed twice
            return await d(await a())
    
        out = await asyncio.gather(task1(), task2())
        assert out[0] == 101
        assert out[1] == 2
    
    
    async def main():
        t1 = time.time()
        await f()
        t2 = time.time()
        print(t2 - t1)
    
    
    if __name__ == "__main__":
        asyncio.run(main())
    

    Prints:

    4.00294041633606
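For a larger, possibly complex execution graph, the task-based idea can be generalized into a small runner where each node only declares the names of its inputs and the tasks sort out the ordering. The sketch uses a hypothetical helper run_graph() (not an asyncio API) and the same a(), b(), c() and d() as the second example; it assumes the graph is acyclic and every dependency name is a key of the graph. Unlike task2 above, each function runs exactly once, because every node is backed by a single shared task:

```python
import asyncio
import time


async def a():
    await asyncio.sleep(1)
    return 1


async def b():
    await asyncio.sleep(3)
    return 100


async def c(val1, val2):
    await asyncio.sleep(1)
    return val1 + val2


async def d(val1):
    await asyncio.sleep(1)
    return val1 * 2


async def run_graph(graph):
    """Hypothetical helper: run every node once, as soon as its inputs are ready.

    `graph` maps a node name to (coroutine_function, [names of its inputs]).
    Assumes the graph has no cycles.
    """
    tasks = {}

    def get(name):
        # Create each node's task lazily and exactly once.
        if name not in tasks:
            func, deps = graph[name]

            async def node():
                # The input tasks already run concurrently, so awaiting them
                # one by one only waits as long as the slowest input.
                args = [await get(dep) for dep in deps]
                return await func(*args)

            tasks[name] = asyncio.create_task(node())
        return tasks[name]

    for name in graph:
        get(name)
    names = list(tasks)
    results = await asyncio.gather(*(tasks[n] for n in names))
    return dict(zip(names, results))


graph = {
    "val1": (a, []),
    "val2": (b, []),
    "val3": (c, ["val1", "val2"]),
    "val4": (d, ["val1"]),
}

t1 = time.time()
out = asyncio.run(run_graph(graph))
elapsed = time.time() - t1
print(out, round(elapsed, 1))
```

The total time is still governed by the critical path a()/b() then c(), so this graph also finishes in roughly 4 seconds.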