How to iterate over multiple streams at once in anyio, interleaving the items as they appear?
Let's say I want a simple equivalent of `annotate-output`. The simplest I could make is:
#!/usr/bin/env python3
import dataclasses
from collections.abc import Sequence
from contextlib import asynccontextmanager
from typing import TypeVar

import anyio
import anyio.abc
import anyio.streams.text
# Demo workload for the child process: five iterations that each print the loop
# index and a timestamp (`date -Ins`), pausing 0.08 s between iterations, then a
# final "." line.
SCRIPT = r"""
for idx in $(seq 1 5); do
printf "%s " "$idx"
date -Ins
sleep 0.08
done
echo "."
"""
# `bash -x` traces each executed command to stderr, so the child produces output
# on both stdout and stderr, which the demo then interleaves.
CMD = ["bash", "-x", "-c", SCRIPT]
def print_data(data: str, is_stderr: bool) -> None:
    """Print one chunk of child output, prefixed with its numeric stream tag.

    The tag is ``int(is_stderr)`` (``1`` for stderr, ``0`` for stdout). The
    chunk is printed via ``repr`` so newlines and whitespace stay visible.
    """
    print(int(is_stderr), repr(data), sep=": ")
# Item type of the source streams that `CombinedReceiveStream` merges into `(index, item)` pairs.
T_Item = TypeVar("T_Item")  # TODO: covariant=True?
@dataclasses.dataclass(eq=False)
class CombinedReceiveStream(anyio.abc.ObjectReceiveStream[tuple[int, T_Item]]):
    """Merge several object receive streams into one, tagging items with their source.

    Every received value is ``(idx, item)``, where ``idx`` is the position of the
    originating stream in ``streams``. One copier task per source forwards items
    into a bounded in-memory queue. The copier tasks are started lazily by the
    first ``receive`` call and cancelled by ``aclose``. ``receive`` raises
    ``anyio.EndOfStream`` only once every source is exhausted *and* every queued
    item has been returned.
    """

    streams: Sequence[anyio.abc.ObjectReceiveStream[T_Item]]
    # Capacity of the internal queue shared by all copier tasks.
    max_buffer_size_items: int = 32

    def __post_init__(self) -> None:
        # The queue carries `(idx, item)` pairs plus `None` wake-up markers that a
        # copier sends after its source stream ends.
        self._queue_send, self._queue_receive = anyio.create_memory_object_stream(
            max_buffer_size=self.max_buffer_size_items,
            # Should be: `item_type=tuple[int, T_Item] | None`
        )
        # Indices of source streams whose copier has not finished yet.
        self._pending = set(range(len(self.streams)))
        self._started = False
        self._task_group = anyio.create_task_group()

    async def _copier(self, idx: int) -> None:
        """Forward every item of ``streams[idx]`` into the queue as ``(idx, item)``."""
        assert idx in self._pending
        stream = self.streams[idx]
        async for item in stream:
            await self._queue_send.send((idx, item))
        assert idx in self._pending
        self._pending.remove(idx)
        await self._queue_send.send(None)  # Wake up the `receive` waiters, if any.

    async def _start(self) -> None:
        """Enter the task group and spawn one copier task per source stream."""
        assert not self._started
        await self._task_group.__aenter__()
        for idx in range(len(self.streams)):
            self._task_group.start_soon(
                self._copier, idx, name=f"_combined_receive_copier_{idx}"
            )
        self._started = True

    async def receive(self) -> tuple[int, T_Item]:
        """Return the next ``(idx, item)`` pair from any source stream.

        Raises:
            anyio.EndOfStream: all sources are exhausted and the queue is empty.
        """
        if not self._started:
            await self._start()
        while True:
            # Always drain what is already buffered before deciding the stream
            # has ended: real items can still sit in the queue behind `None`
            # markers after every copier has finished.
            try:
                item = self._queue_receive.receive_nowait()
            except anyio.WouldBlock:
                # A copier sends all of its items before leaving `_pending`, so an
                # empty queue with no pending copiers means nothing more can arrive.
                if not self._pending:
                    raise anyio.EndOfStream from None
                item = await self._queue_receive.receive()
            if item is not None:
                return item

    async def aclose(self) -> None:
        """Cancel and await the copier tasks (if started), then close the internal queue."""
        try:
            if self._started:
                self._task_group.cancel_scope.cancel()
                self._started = False
                await self._task_group.__aexit__(None, None, None)
        finally:
            self._queue_send.close()
            self._queue_receive.close()
async def amain(max_buffer_size_items: int = 32) -> None:
    """Run ``CMD`` and print its stdout/stderr chunks in the order they are received.

    Args:
        max_buffer_size_items: capacity of the queue inside
            ``CombinedReceiveStream`` that merges the two output streams.
    """
    async with await anyio.open_process(CMD) as proc:
        # Narrow the Optional types; `open_process` pipes stdout/stderr by default.
        assert proc.stdout is not None
        assert proc.stderr is not None
        raw_streams = [proc.stdout, proc.stderr]
        idx_to_is_stderr = {0: False, 1: True}  # just making it explicit
        streams = [anyio.streams.text.TextReceiveStream(stream) for stream in raw_streams]
        async with CombinedReceiveStream(
            streams, max_buffer_size_items=max_buffer_size_items
        ) as outputs:
            async for idx, data in outputs:
                is_stderr = idx_to_is_stderr[idx]
                print_data(data, is_stderr=is_stderr)
def main() -> None:
    """Script entry point: drive the async demo with ``anyio.run``."""
    anyio.run(amain)


if __name__ == "__main__":
    main()
However, this `CombinedReceiveStream` solution is somewhat ugly, and I would expect a ready-made solution to already exist. What am I overlooking?
The following version should be safer and more idiomatic:
class CtxObj:
    """Mixin that provides ``async with`` support by delegating to ``self._ctx()``.

    Subclasses implement ``_ctx`` as an async context manager factory, typically
    decorated with ``contextlib.asynccontextmanager``. Entering the object
    enters a fresh ``_ctx()`` and returns whatever it yields. Exiting delegates
    to that same context manager, including its return value, so exception
    suppression is preserved.

    Usage::

        class Foo(CtxObj):
            @asynccontextmanager
            async def _ctx(self):
                yield self  # or whatever

        async with Foo() as self_or_whatever:
            pass

    An instance may be entered again after it has been exited, but not while it
    is still entered. A nested or overlapping entry raises ``RuntimeError``
    instead of silently replacing (and thus never exiting) the active context.
    """

    # Context manager returned by `_ctx()` while entered; None otherwise.
    __ctx = None

    async def __aenter__(self):
        if self.__ctx is not None:
            raise RuntimeError(f"{type(self).__name__} is already entered")
        ctx = self._ctx()  # pylint: disable=E1101
        result = await ctx.__aenter__()
        # Remember the context only once it was entered successfully: if entering
        # raised, `async with` will not call `__aexit__`.
        self.__ctx = ctx  # pylint: disable=W0201
        return result

    async def __aexit__(self, *exc_info):
        # Forget the context before delegating, so the instance is reusable (and
        # the finished context manager is released) even if exiting raises.
        ctx, self.__ctx = self.__ctx, None
        return await ctx.__aexit__(*exc_info)
@dataclasses.dataclass(eq=False)
class CombinedReceiveStream(CtxObj):
    """Merge several object receive streams into one, tagging items with their source.

    Use as ``async with CombinedReceiveStream(streams) as combined:`` and iterate
    ``combined``. Each item is ``(idx, item)``, where ``idx`` is the position of
    the originating stream in ``streams``. One copier task per source runs while
    the context is entered. Iteration ends once every source is exhausted and
    the buffered items have been consumed. Leaving the context early cancels the
    copiers.
    """

    streams: Sequence[anyio.abc.ObjectReceiveStream[T_Item]]
    # Capacity of the internal queue shared by all copier tasks.
    max_buffer_size_items: int = 32

    def __post_init__(self) -> None:
        self._queue_send, self._queue_receive = anyio.create_memory_object_stream(
            max_buffer_size=self.max_buffer_size_items,
            # Item type is `tuple[int, T_Item]`; anyio >= 4 spells this as
            # `create_memory_object_stream[tuple[int, T_Item]](...)`.
        )
        # Indices of source streams whose copier has not finished yet.
        self._pending = set(range(len(self.streams)))

    @asynccontextmanager
    async def _ctx(self):
        try:
            async with anyio.create_task_group() as tg:
                if not self._pending:
                    # No copier will run, so nothing else would close the send
                    # side and `receive` would wait forever.
                    self._queue_send.close()
                # Iterate a snapshot: copiers remove themselves from `_pending`.
                for idx in sorted(self._pending):
                    tg.start_soon(self._copier, idx)
                yield self
                # The consumer is done (possibly early); stop remaining copiers.
                tg.cancel_scope.cancel()
        finally:
            # With the send side closed, any later `receive` drains the buffer and
            # then raises EndOfStream instead of blocking forever.
            self._queue_send.close()

    async def _copier(self, idx: int) -> None:
        """Forward ``streams[idx]`` as ``(idx, item)`` pairs; the last copier to finish closes the queue."""
        async for item in self.streams[idx]:
            await self._queue_send.send((idx, item))
        self._pending.remove(idx)
        if not self._pending:
            # Closing the send side makes `receive` raise EndOfStream once the
            # buffered items have been consumed.
            await self._queue_send.aclose()

    async def receive(self) -> tuple[int, T_Item]:
        """Return the next ``(idx, item)``; raise ``anyio.EndOfStream`` when all sources are done."""
        return await self._queue_receive.receive()

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return await self.receive()
        except anyio.EndOfStream:
            raise StopAsyncIteration() from None