
How can I iterate over an AsyncIterator stream in Python with a timeout, without cancelling the stream?


I'm dealing with an object that is an AsyncIterator[str]. It receives messages from the network and yields them as strings. I want to create a wrapper for this stream that buffers these messages and yields them at regular intervals.

My code looks like this:

import asyncio
import time
from typing import AsyncIterator, Optional


async def buffer_stream(stream: AsyncIterator[str], buffer_time: Optional[float]) -> AsyncIterator[str]:
    """
    Buffer messages from the stream and yield them at regular intervals.
    """
    last_sent_at = time.perf_counter()
    buffer = ''

    stop = False
    while not stop:
        time_to_send = False

        # Time left until the next scheduled flush (None means "wait forever").
        timeout = (
            max(buffer_time - (time.perf_counter() - last_sent_at), 0)
            if buffer_time else None
        )
        try:
            buffer += await asyncio.wait_for(
                stream.__anext__(),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            time_to_send = True
        except StopAsyncIteration:
            time_to_send = True
            stop = True
        else:
            # Guard against buffer_time being None before comparing.
            if buffer_time and time.perf_counter() - last_sent_at >= buffer_time:
                time_to_send = True

        if not buffer_time or time_to_send:
            if buffer:
                yield buffer
                buffer = ''
            last_sent_at = time.perf_counter()

As far as I can tell, the logic makes sense, but as soon as the first timeout is hit, the stream is interrupted and the wrapper exits early, before the stream is actually done.
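A minimal reproduction of the early exit, as a sketch: ticker() is a hypothetical generator standing in for the network stream, and the timings are arbitrary. The timeout cancels the pending __anext__(); the resulting CancelledError is thrown into the generator and finishes it, so the next __anext__() raises StopAsyncIteration and the loop above treats the stream as done.

```python
import asyncio


async def ticker():
    # Hypothetical stand-in for the network stream: one message every 50 ms.
    for i in range(5):
        await asyncio.sleep(0.05)
        yield f"msg{i}"


async def main():
    stream = ticker()
    try:
        # The timeout expires while the generator is suspended in asyncio.sleep().
        await asyncio.wait_for(stream.__anext__(), timeout=0.01)
    except asyncio.TimeoutError:
        pass
    # wait_for() cancelled the pending __anext__(); the CancelledError went into
    # the generator and finished it, so the stream now looks exhausted.
    try:
        await stream.__anext__()
    except StopAsyncIteration:
        return "stream closed"
    return "stream still alive"


if __name__ == "__main__":
    print(asyncio.run(main()))
```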

I think this might be because asyncio.wait_for specifically says:

When a timeout occurs, it cancels the task and raises TimeoutError. To avoid the task cancellation, wrap it in shield().
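The documented behaviour can be checked in isolation with a plain task. In this minimal sketch (slow_value() is a hypothetical placeholder), wait_for() times out on a shielded task: only the shield wrapper is cancelled, while the task itself keeps running and can still be awaited afterwards.

```python
import asyncio


async def slow_value():
    # Hypothetical slow operation that finishes after 50 ms.
    await asyncio.sleep(0.05)
    return "done"


async def main():
    task = asyncio.create_task(slow_value())
    try:
        # shield() gives wait_for() an outer future to cancel on timeout.
        await asyncio.wait_for(asyncio.shield(task), timeout=0.01)
    except asyncio.TimeoutError:
        pass
    still_running = not task.done()
    # The inner task was never cancelled, so its result is still available.
    return still_running, await task


if __name__ == "__main__":
    print(asyncio.run(main()))
```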

I tried wrapping it in shield:

buffer += await asyncio.wait_for(
    asyncio.shield(stream.__anext__()),
    timeout=timeout
)

This fails for a different reason: RuntimeError: anext(): asynchronous generator is already running. From what I understand, shield() keeps the previous __anext__() call running after the timeout, so when the loop calls stream.__anext__() again on the next iteration, the generator is still busy with the previous call and refuses to be re-entered.
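The error can be reproduced without any timeout, which supports that reading. In this sketch, ticker() is a hypothetical stand-in for the stream; one __anext__() is left running in its own task (which is what shield() does after a timeout), and a second __anext__() is awaited while the first is still pending.

```python
import asyncio


async def ticker():
    # Hypothetical stand-in for the network stream.
    for i in range(3):
        await asyncio.sleep(0.01)
        yield f"msg{i}"


async def main():
    stream = ticker()
    # The first __anext__() keeps running in its own task.
    first = asyncio.ensure_future(stream.__anext__())
    await asyncio.sleep(0)  # let the task start; the generator is now suspended
    try:
        # A second __anext__() while the first is still pending is rejected.
        await stream.__anext__()
        error = None
    except RuntimeError as exc:
        error = str(exc)
    # The original call is unaffected and still produces the next message.
    return await first, error


if __name__ == "__main__":
    print(asyncio.run(main()))
```

On recent CPython versions the second call raises the RuntimeError quoted above, while the first call still completes with the next message.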

Is there a proper way to do this?

Demo: https://www.sololearn.com/en/compiler-playground/cBCVnVAD4H7g


Solution

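  • Turn the awaitable returned by stream.__anext__() into a task (or, more generally, a future) and keep it around between loop iterations. Wait on it through shield(), so that a timeout cancels only the shield wrapper and not the pending fetch. On the next iteration, wait on the same future again instead of calling __anext__() a second time, which is what triggered the "already running" error: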

    import asyncio
    import time
    from typing import AsyncIterator, Optional


    async def buffer_stream(stream: AsyncIterator[str], buffer_time: Optional[float]) -> AsyncIterator[str]:
        last_sent_at = time.perf_counter()
        buffer = ''
    
        stop = False
        await_next = None  # pending __anext__() future, reused across timeouts
        while not stop:
            time_to_send = False
    
            timeout = (
                max(buffer_time - (time.perf_counter() - last_sent_at), 0)
                if buffer_time else None
            )
            if await_next is None:
                # Start fetching the next message only when no fetch is in flight.
                await_next = asyncio.ensure_future(stream.__anext__())
            try:
                # shield(): a timeout cancels only the wrapper, not await_next.
                buffer += await asyncio.wait_for(
                    asyncio.shield(await_next),
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                time_to_send = True
            except StopAsyncIteration:
                time_to_send = True
                stop = True
            else:
                await_next = None  # consumed; start a fresh fetch next time
                # Guard against buffer_time being None before comparing.
                if buffer_time and time.perf_counter() - last_sent_at >= buffer_time:
                    time_to_send = True
    
            if not buffer_time or time_to_send:
                if buffer:
                    yield buffer
                    buffer = ''
                last_sent_at = time.perf_counter()
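A variant of the same idea, sketched under the same assumptions (a single consumer, str messages): asyncio.wait() with a timeout does not cancel the futures it is waiting on, so it can stand in for the wait_for() plus shield() pair. Only one __anext__() is ever in flight, and a timeout simply leaves it pending for the next iteration. The name buffer_stream_wait and the try/finally cleanup are illustrative additions, not part of the answer above.

```python
import asyncio
import time
from typing import AsyncIterator, Optional


async def buffer_stream_wait(
    stream: AsyncIterator[str], buffer_time: Optional[float]
) -> AsyncIterator[str]:
    """Buffer messages from stream and yield them roughly every buffer_time seconds."""
    buffer = ""
    last_sent_at = time.perf_counter()
    pending: Optional[asyncio.Future] = None
    try:
        while True:
            if pending is None:
                # At most one __anext__() is in flight, so the generator is never re-entered.
                pending = asyncio.ensure_future(stream.__anext__())
            timeout = (
                max(buffer_time - (time.perf_counter() - last_sent_at), 0)
                if buffer_time else None
            )
            # Unlike wait_for(), wait() leaves `pending` running when the timeout expires.
            done, _ = await asyncio.wait({pending}, timeout=timeout)
            if done:
                finished, pending = pending, None
                try:
                    buffer += finished.result()
                except StopAsyncIteration:
                    break  # the underlying stream is exhausted
            if (not buffer_time or not done
                    or time.perf_counter() - last_sent_at >= buffer_time):
                if buffer:
                    yield buffer
                    buffer = ""
                last_sent_at = time.perf_counter()
        if buffer:
            yield buffer  # flush whatever arrived before the stream ended
    finally:
        # If the generator is closed early, don't leave an orphaned __anext__() running.
        if pending is not None:
            pending.cancel()
```

Both versions keep the stream alive across timeouts; the try/finally here additionally cancels the outstanding fetch if the wrapper is closed early (for example with aclose()), whereas the shield() version leaves that task running.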