How to find out how often a Python coroutine is suspended when using Trio


Is there a way to find out how often a coroutine has been suspended?

For example:

import trio

async def sleep(count: int):
   for _ in range(count):
      await trio.sleep(1)

async def test():
   with track_suspensions() as ctx:
      await sleep(1)
   assert ctx.suspension_count == 1

   with track_suspensions() as ctx:
      await sleep(3)
   assert ctx.suspension_count == 3

trio.run(test)

...except that (as far as I know) track_suspensions() does not exist. Is there a way to implement it, or some other way to get this information? (I need it to write a unit test.)


Solution

  • You can't do it with a context manager.

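    However, the async protocol is just a beefed-up iterator protocol, and you can intercept it: every suspension is a value yielded up through __await__(), and every resumption is a send() (or throw()) back down. Thus the following Tracker object counts how often the wrapped function yields.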

    import trio
    
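    class Tracker:
        """Awaitable wrapper that counts how often the wrapped coroutine is resumed."""

        def __init__(self):
            self.n = 0  # number of suspensions seen so far

        # The slash makes fn positional-only, so a keyword argument
        # named "fn" ends up in **k instead of clashing with it.
        def __call__(self, fn, /, *a, **k):
            self.fn = fn
            self.a = a
            self.k = k
            return self

        def __await__(self):
            # Drive the wrapped coroutine by hand through its iterator.
            it = self.fn(*self.a, **self.k).__await__()
            try:
                r = None
                while True:
                    r = it.send(r)   # run until the next suspension
                    r = (yield r)    # pass the suspension up to the event loop
                    self.n += 1      # resumed: one suspension completed
            except StopIteration as stop:
                # The wrapped coroutine finished; hand back its return value.
                return stop.value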
    
    
    async def sleep(n):
        for i in range(n):
            await trio.sleep(0.1)
        return n**2
    
    
    async def main():
        t = Tracker()
        res = await t(sleep,10)
        assert t.n == 10
        assert res == 100
    
    trio.run(main)
    
    

    This code also works with asyncio (and thus anyio) mainloops.
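
As a rough illustration of the asyncio claim, here is a minimal sketch of the same idea driven by asyncio instead of Trio. The names `sleep` and `main` are just for this example, and the expected count assumes that each `asyncio.sleep()` call with a positive delay suspends exactly once, which is how current CPython behaves but is not something the asyncio documentation promises.

```python
import asyncio


class Tracker:
    """Counts how often the wrapped coroutine is resumed after suspending."""

    def __init__(self):
        self.n = 0

    def __call__(self, fn, /, *args, **kwargs):
        self.fn, self.args, self.kwargs = fn, args, kwargs
        return self

    def __await__(self):
        it = self.fn(*self.args, **self.kwargs).__await__()
        try:
            value = None
            while True:
                # Forward whatever the inner coroutine yields to the event
                # loop, then feed the loop's answer back in on resumption.
                value = yield it.send(value)
                self.n += 1
        except StopIteration as stop:
            return stop.value


async def sleep(count):
    for _ in range(count):
        await asyncio.sleep(0.01)
    return count ** 2


async def main():
    tracker = Tracker()
    result = await tracker(sleep, 3)
    return tracker.n, result


if __name__ == "__main__":
    print(asyncio.run(main()))
```

The counter is read after the await completes, so a unit test can assert on `tracker.n` and on the wrapped coroutine's return value at the same time.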

    NB: The single slash in the __call__ argument list prevents a name clash when your function happens to require a keyword argument named "fn".
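
To see why the position of the slash matters, here is a small sketch comparing two signatures. Parameters listed before `/` are positional-only, so the slash has to come after `fn` for a keyword argument named "fn" to pass through untouched. The functions `slash_after_self`, `slash_after_fn`, `target` and `clashes` are hypothetical names used only for this demonstration.

```python
def slash_after_self(self, /, fn, *args, **kwargs):
    # Only `self` is positional-only here; `fn` can still be bound by keyword.
    return fn, args, kwargs


def slash_after_fn(self, fn, /, *args, **kwargs):
    # `self` and `fn` are positional-only; a keyword `fn` lands in kwargs.
    return fn, args, kwargs


def target(fn):
    # Hypothetical wrapped function that takes a keyword argument named "fn".
    return fn


def clashes(func):
    """Return True if passing fn= as a keyword raises TypeError."""
    try:
        func(None, target, fn="data.txt")
    except TypeError:
        return True
    return False


if __name__ == "__main__":
    print(clashes(slash_after_self))  # True
    print(clashes(slash_after_fn))    # False
```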