I'm dealing with an object that is an AsyncIterator[str].
It gets messages from the network, and yields them as strings.
I want to create a wrapper for this stream that buffers these messages, and yields them at a regular interval.
My code looks like this:
    async def buffer_stream(stream: AsyncIterator[str], buffer_time: Optional[float]) -> AsyncIterator[str]:
        """
        Buffer messages from the stream, and yields them at regular intervals.
        """
        last_sent_at = time.perf_counter()
        buffer = ''
        stop = False
        while not stop:
            time_to_send = False
            timeout = (
                max(buffer_time - (time.perf_counter() - last_sent_at), 0)
                if buffer_time else None
            )
            try:
                buffer += await asyncio.wait_for(
                    stream.__anext__(),
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                time_to_send = True
            except StopAsyncIteration:
                time_to_send = True
                stop = True
            else:
                if time.perf_counter() - last_sent_at >= buffer_time:
                    time_to_send = True
            if not buffer_time or time_to_send:
                if buffer:
                    yield buffer
                    buffer = ''
                    last_sent_at = time.perf_counter()
As far as I can tell, the logic makes sense, but as soon as it hits the first timeout, it interrupts the stream, and exits early, before the stream is done.
I think this might be because the documentation for asyncio.wait_for specifically says:
    When a timeout occurs, it cancels the task and raises TimeoutError. To avoid the task cancellation, wrap it in shield().
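The effect of that cancellation on an async generator can be reproduced in isolation. This is only a minimal sketch: slow_numbers is a hypothetical stand-in for the network stream, and the timings are arbitrary. The timeout cancels the pending __anext__() call, the CancelledError propagates out of the generator body, and the generator is finished from then on, so the next __anext__() raises StopAsyncIteration, which is what the loop above treats as the end of the stream.

```python
import asyncio
from typing import AsyncIterator, List


async def slow_numbers() -> AsyncIterator[str]:
    # Hypothetical stand-in for the network stream: one message every 50 ms.
    for i in range(3):
        await asyncio.sleep(0.05)
        yield str(i)


async def demo() -> List[str]:
    gen = slow_numbers()
    events = []
    try:
        # The timeout is shorter than the first sleep, so wait_for cancels
        # the pending __anext__() call while the generator is suspended.
        await asyncio.wait_for(gen.__anext__(), timeout=0.01)
    except asyncio.TimeoutError:
        events.append('timeout')
    try:
        # The cancellation escaped the generator body, which finalized it.
        await gen.__anext__()
    except StopAsyncIteration:
        events.append('generator already finished')
    return events


if __name__ == '__main__':
    print(asyncio.run(demo()))
```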
I tried wrapping it in shield:
            buffer += await asyncio.wait_for(
                shield(stream.__anext__()),
                timeout=timeout
            )
This errors out for a different reason: RuntimeError: anext(): asynchronous generator is already running. From what I understand, that means it is still in the process of getting the previous anext() when it tries to get the next one, which causes an error.
Is there a proper way to do this?
Demo: https://www.sololearn.com/en/compiler-playground/cBCVnVAD4H7g
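You can turn the result of stream.__anext__() into a task (or, more generally, a future) and keep awaiting that same task, shielded from cancellation, until it yields a result. A new task is created only after the previous one has finished, so the generator is never asked for the next item while the previous request is still running: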
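    import asyncio
    import time
    from typing import AsyncIterator, Optional

    async def buffer_stream(stream: AsyncIterator[str], buffer_time: Optional[float]) -> AsyncIterator[str]:
        last_sent_at = time.perf_counter()
        buffer = ''
        stop = False
        await_next = None  # task for the in-flight __anext__() call, if any
        while not stop:
            time_to_send = False
            timeout = (
                max(buffer_time - (time.perf_counter() - last_sent_at), 0)
                if buffer_time else None
            )
            # Start a new __anext__() call only when the previous one has finished.
            if await_next is None:
                await_next = asyncio.ensure_future(stream.__anext__())
            try:
                # shield() keeps the timeout from cancelling the task itself,
                # so the same task can be awaited again on the next iteration.
                buffer += await asyncio.wait_for(
                    asyncio.shield(await_next),
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                time_to_send = True
            except StopAsyncIteration:
                time_to_send = True
                stop = True
            else:
                await_next = None
                # Guard against buffer_time=None, which cannot be compared to a float.
                if buffer_time and time.perf_counter() - last_sent_at >= buffer_time:
                    time_to_send = True
            if not buffer_time or time_to_send:
                if buffer:
                    yield buffer
                    buffer = ''
                    last_sent_at = time.perf_counter()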
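For comparison, here is a sketch of the same idea built on asyncio.wait instead of wait_for plus shield. This is a swapped-in variant, not the code above: asyncio.wait returns on timeout without cancelling the task, so no shield is needed. The names buffer_stream_wait, fake_network and collect are hypothetical, and two design choices are assumptions of this sketch rather than part of the original: the timer only runs while something is buffered, and the in-flight task is cancelled if the consumer stops early.

```python
import asyncio
import time
from typing import AsyncIterator, List, Optional


async def buffer_stream_wait(
    stream: AsyncIterator[str], buffer_time: Optional[float]
) -> AsyncIterator[str]:
    buffer = ''
    last_sent_at = time.perf_counter()
    pending = None  # task for the in-flight __anext__() call, if any
    stop = False
    try:
        while not stop:
            if pending is None:
                pending = asyncio.ensure_future(stream.__anext__())
            # Assumption of this sketch: only run a timer while something is
            # buffered, so an idle stream is awaited without a zero timeout.
            timeout = None
            if buffer_time and buffer:
                timeout = max(buffer_time - (time.perf_counter() - last_sent_at), 0)
            # asyncio.wait does not cancel the task when the timeout expires.
            done, _ = await asyncio.wait({pending}, timeout=timeout)
            if done:
                try:
                    buffer += pending.result()
                except StopAsyncIteration:
                    stop = True
                pending = None
            elapsed = time.perf_counter() - last_sent_at
            if buffer and (stop or not buffer_time or elapsed >= buffer_time):
                yield buffer
                buffer = ''
                last_sent_at = time.perf_counter()
    finally:
        # If the consumer stops iterating early, do not leave the task dangling.
        if pending is not None:
            pending.cancel()


async def fake_network() -> AsyncIterator[str]:
    # Hypothetical stand-in for the network stream: five messages, 20 ms apart.
    for word in ['a', 'b', 'c', 'd', 'e']:
        await asyncio.sleep(0.02)
        yield word


async def collect(buffer_time: Optional[float]) -> List[str]:
    return [chunk async for chunk in buffer_stream_wait(fake_network(), buffer_time)]


if __name__ == '__main__':
    print(asyncio.run(collect(0.05)))
```

With buffer_time=None every message is passed through as it arrives, and with a buffer_time longer than the whole stream everything is flushed as a single chunk when the stream ends. The exact grouping for intermediate values depends on timing.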