Search code examples
python, python-trio, python-anyio

How to combine streams in anyio?


How can I iterate over multiple streams at once in anyio, interleaving the items as they arrive?

Let's say I want a simple equivalent of annotate-output. The simplest version I could come up with is:

#!/usr/bin/env python3

import dataclasses
from collections.abc import Sequence
from typing import TypeVar

import anyio
import anyio.abc
import anyio.streams.text

# Demo shell script: five iterations, each printing the loop index followed by
# a nanosecond-precision ISO timestamp (`date -Ins`) and then sleeping 0.08 s;
# after the loop it prints a lone ".".
SCRIPT = r"""
for idx in $(seq 1 5); do
    printf "%s  " "$idx"
    date -Ins
    sleep 0.08
done
echo "."
"""
# Run the script under `bash -x`, so bash's command trace (written to stderr)
# is produced alongside the script's regular stdout output.
CMD = ["bash", "-x", "-c", SCRIPT]


def print_data(data: str, is_stderr: bool) -> None:
    """Print one chunk of output, tagged with the stream it came from.

    The printed line has the form ``<flag>: <repr of data>``, where the flag
    is ``1`` for stderr and ``0`` otherwise.
    """
    origin_flag = int(is_stderr)
    print(f"{origin_flag}: {data!r}")


# Type of the items produced by the source streams that get combined.
T_Item = TypeVar("T_Item")  # TODO: covariant=True?


@dataclasses.dataclass(eq=False)
class CombinedReceiveStream(anyio.abc.ObjectReceiveStream[tuple[int, T_Item]]):
    """Combines multiple streams into a single one, annotating each item with position index of the origin stream.

    One copier task per source stream forwards ``(idx, item)`` tuples into a
    shared in-memory queue.  The copiers are started lazily by the first
    ``receive()`` call and are cancelled by ``aclose()``.  ``receive()`` raises
    ``anyio.EndOfStream`` once every source stream is exhausted *and* the
    queue has been fully drained.
    """

    # Source streams; an item read from ``streams[idx]`` is delivered as ``(idx, item)``.
    streams: Sequence[anyio.abc.ObjectReceiveStream[T_Item]]
    # Capacity of the in-memory queue that the copier tasks feed.
    max_buffer_size_items: int = 32

    def __post_init__(self) -> None:
        """Create the internal queue and bookkeeping; no task is started here."""
        self._queue_send, self._queue_receive = anyio.create_memory_object_stream(
            max_buffer_size=self.max_buffer_size_items,
            # Should be: `item_type=tuple[int, T_Item] | None`
        )
        # Indices of the source streams that have not been exhausted yet.
        self._pending = set(range(len(self.streams)))
        self._started = False
        self._task_group = anyio.create_task_group()

    async def _copier(self, idx: int) -> None:
        """Forward every item of ``streams[idx]`` into the queue as ``(idx, item)``.

        When the source stream is exhausted, ``idx`` is removed from the
        pending set and a ``None`` sentinel is queued.
        """
        assert idx in self._pending
        stream = self.streams[idx]
        async for item in stream:
            await self._queue_send.send((idx, item))
        assert idx in self._pending
        self._pending.remove(idx)
        await self._queue_send.send(None)  # Wake up the `receive` waiters, if any.

    async def _start(self) -> None:
        """Enter the task group and spawn one copier task per source stream."""
        assert not self._started
        # NOTE(review): the task group is entered here and exited in `aclose()`;
        # presumably both must happen in the same task - confirm against the
        # anyio cancel scope rules.
        await self._task_group.__aenter__()
        for idx in range(len(self.streams)):
            self._task_group.start_soon(self._copier, idx, name=f"_combined_receive_copier_{idx}")
        self._started = True

    async def receive(self) -> tuple[int, T_Item]:
        """Return the next ``(stream index, item)`` pair from any source stream.

        Starts the copier tasks on first use.

        Raises:
            anyio.EndOfStream: all source streams are exhausted and nothing is
                left in the queue.
        """
        if not self._started:
            await self._start()

        while True:
            # Drain whatever is already queued before looking at `_pending`.
            # A `None` sentinel may sit in front of real items from another
            # stream, so a single non-blocking attempt is not enough: stopping
            # after the first sentinel would drop the items queued behind it.
            try:
                item = self._queue_receive.receive_nowait()
            except anyio.WouldBlock:
                queue_empty = True
            else:
                queue_empty = False

            if queue_empty:
                if not self._pending:
                    raise anyio.EndOfStream
                # Some copier is still running; it will queue at least its
                # final `None` sentinel, so this wait is woken up.
                item = await self._queue_receive.receive()

            if item is not None:
                return item

    async def aclose(self) -> None:
        """Cancel the copier tasks and leave the task group, if it was started."""
        if self._started:
            self._task_group.cancel_scope.cancel()
            self._started = False
            await self._task_group.__aexit__(None, None, None)


async def amain(max_buffer_size_items: int = 32) -> None:
    """Run ``CMD`` and print its stdout/stderr chunks interleaved as they arrive.

    Args:
        max_buffer_size_items: Capacity of the queue inside the combined
            stream; forwarded to ``CombinedReceiveStream``.
    """
    async with await anyio.open_process(CMD) as proc:
        assert proc.stdout is not None
        assert proc.stderr is not None
        raw_streams = [proc.stdout, proc.stderr]
        idx_to_is_stderr = {0: False, 1: True}  # just making it explicit
        streams = [anyio.streams.text.TextReceiveStream(stream) for stream in raw_streams]
        # Pass the buffer size through; previously the parameter was silently ignored.
        combined = CombinedReceiveStream(streams, max_buffer_size_items=max_buffer_size_items)
        async with combined as outputs:
            async for idx, data in outputs:
                print_data(data, is_stderr=idx_to_is_stderr[idx])


def main() -> None:
    """Synchronous entry point: run ``amain`` via ``anyio.run``."""
    anyio.run(amain)


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()

However, this CombinedReceiveStream solution is somewhat ugly, and I would expect that some ready-made solution already exists. What am I overlooking?


Solution

  • This should be safer and more idiomatic.

    class CtxObj:
        """
        Mixin that makes an object usable with ``async with``.

        The object must provide ``_ctx``, which returns an async context
        manager. Entering the object enters that manager and returns whatever
        it yields; leaving the object leaves that same manager.

        Usage::
            class Foo(CtxObj):
                @asynccontextmanager
                async def _ctx(self):
                    yield self # or whatever

            async with Foo() as self_or_whatever:
                pass
        """

        async def __aenter__(self):
            manager = self._ctx()  # pylint: disable=E1101
            # Keep the manager around so that `__aexit__` can leave it later.
            self.__ctx = manager  # pylint: disable=W0201
            return await manager.__aenter__()

        def __aexit__(self, *exc_info):
            # Plain function on purpose: it returns the awaitable produced by
            # the stored manager, which `async with` then awaits.
            active = self.__ctx
            return active.__aexit__(*exc_info)
    
    
    @dataclasses.dataclass(eq=False)
    class CombinedReceiveStream(CtxObj):
        """Combines multiple streams into a single one, annotating each item with position index of the origin stream

        Use it as ``async with CombinedReceiveStream(streams) as combined:`` and
        iterate ``combined`` (or call ``receive()``) inside the block.  The
        copier tasks only run while the context is active; leaving the block
        cancels them.

        NOTE: ``asynccontextmanager`` is expected to be imported from
        ``contextlib`` by the surrounding module.
        """

        # Source streams; an item read from ``streams[idx]`` is delivered as ``(idx, item)``.
        streams: Sequence[anyio.abc.ObjectReceiveStream[T_Item]]
        # Capacity of the in-memory queue that the copier tasks feed.
        max_buffer_size_items: int = 32

        def __post_init__(self) -> None:
            """Create the internal queue and the set of not-yet-exhausted stream indices."""
            self._queue_send, self._queue_receive = anyio.create_memory_object_stream(
                max_buffer_size=self.max_buffer_size_items,
                # Should be: `item_type=tuple[int, T_Item] | None`
            )
            self._pending = set(range(len(self.streams)))

        @asynccontextmanager
        async def _ctx(self):
            """Run one copier task per source stream for the duration of the context."""
            async with anyio.create_task_group() as tg:
                if not self._pending:
                    # No copier will run, so nothing would ever close the send
                    # side; close it now so `receive()` reports end-of-stream
                    # instead of blocking forever.
                    await self._queue_send.aclose()

                # Start from a snapshot: copiers remove entries from `_pending`.
                for i in sorted(self._pending):
                    tg.start_soon(self._copier, i)

                yield self
                # The consumer is done; stop copiers that are still running.
                tg.cancel_scope.cancel()

        async def _copier(self, idx: int) -> None:
            """Forward every item of ``streams[idx]`` into the queue as ``(idx, item)``.

            The copier that finishes last closes the send side of the queue,
            which ends the combined stream once the queue is drained.
            """
            stream = self.streams[idx]
            async for item in stream:
                await self._queue_send.send((idx, item))
            self._pending.remove(idx)
            if not self._pending:
                await self._queue_send.aclose()

        async def receive(self) -> tuple[int, T_Item]:
            """Return the next ``(stream index, item)`` pair.

            Raises:
                anyio.EndOfStream: the send side has been closed and the queue
                    is empty.
            """
            return await self._queue_receive.receive()

        def __aiter__(self):
            return self

        async def __anext__(self):
            # Translate the stream's end-of-stream signal into the async
            # iteration protocol.
            try:
                return await self.receive()
            except anyio.EndOfStream:
                raise StopAsyncIteration() from None