
`IAsyncEnumerable` behaviour at end of stream


IAsyncEnumerable 'end of stream' seems to behave differently depending on where the stream comes from...

If you're consuming an IAsyncEnumerable produced by a network resource such as a gRPC stream, you get an error if you try to consume the stream twice, which is what I'd expect.

So I'd expect the second foreach in the following code to throw as well, but it doesn't: it yields the complete stream a second time!

var stream = GetStream();

await foreach (var item in stream)
{
    Console.WriteLine(item);
}

await foreach (var item in stream)
{
    Console.WriteLine(item);
}

async IAsyncEnumerable<string> GetStream()
{
    await Task.CompletedTask;
    yield return "yes";
    yield return "no";
    yield return "maybe";
}

What's causing this stream to reset to the beginning?

I came across this problem while writing an integration test for code that consumes an IAsyncEnumerable. In production the stream comes from a gRPC call, but in the test I create it as above. The production code was mistakenly reading the stream twice and throwing, but my integration test didn't catch it because of this behaviour.

How could I get my in-memory stream to behave like a network stream?


Solution

  • What's causing this stream to reset to the beginning?

    It's caused by calling GetAsyncEnumerator() twice, which each foreach loop does implicitly. Each call to GetAsyncEnumerator() effectively starts a new state machine at the beginning of the iterator method's body, so the second loop replays the whole sequence.
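    To make the implicit call visible, here is a simplified sketch of roughly what a single await foreach loop does (the real compiler expansion also handles things like ConfigureAwait and cancellation). The class name AwaitForeachExpansion and method ReadOnceAsync are hypothetical; GetStream is the iterator from the question.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

public static class AwaitForeachExpansion
{
    // The iterator method from the question.
    public static async IAsyncEnumerable<string> GetStream()
    {
        await Task.CompletedTask;
        yield return "yes";
        yield return "no";
        yield return "maybe";
    }

    // Roughly one `await foreach` loop: a fresh GetAsyncEnumerator() call,
    // MoveNextAsync()/Current until the end, then DisposeAsync().
    public static async Task<List<string>> ReadOnceAsync(IAsyncEnumerable<string> source)
    {
        var items = new List<string>();
        IAsyncEnumerator<string> enumerator = source.GetAsyncEnumerator();
        try
        {
            while (await enumerator.MoveNextAsync())
            {
                items.Add(enumerator.Current);
            }
        }
        finally
        {
            await enumerator.DisposeAsync();
        }
        return items;
    }
}
```

    Calling ReadOnceAsync twice on the same GetStream() result returns "yes", "no", "maybe" both times, because each call begins with its own GetAsyncEnumerator().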

    Note that this is not specific to async enumerables at all - you'd see the same thing with an iterator method returning an IEnumerable<string>.
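    A small synchronous illustration of the same point: two enumerators obtained from one IEnumerable<string> iterator don't share a position. The names SyncIteratorDemo and CompareEnumerators are hypothetical.

```csharp
using System.Collections.Generic;

public static class SyncIteratorDemo
{
    // Synchronous counterpart of GetStream from the question.
    public static IEnumerable<string> GetStream()
    {
        yield return "yes";
        yield return "no";
        yield return "maybe";
    }

    // Advances one enumerator twice and another once, then reports both positions.
    public static (string FirstCurrent, string SecondCurrent) CompareEnumerators()
    {
        IEnumerable<string> sequence = GetStream();
        using IEnumerator<string> first = sequence.GetEnumerator();
        using IEnumerator<string> second = sequence.GetEnumerator();
        first.MoveNext();  // first is on "yes"
        first.MoveNext();  // first is on "no"
        second.MoveNext(); // second starts again from the top: "yes"
        return (first.Current, second.Current);
    }
}
```

    CompareEnumerators returns ("no", "yes"): the second enumerator starts from the beginning no matter how far the first one has got.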

    How could I get my in-memory stream to behave like a network stream?

    You could implement IAsyncEnumerable<T> directly yourself, backing it with an iterator method that returns IAsyncEnumerator<T> so that you can still use yield return. Here's an example:

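    using System;
    using System.Collections.Generic;
    using System.Threading;
    using System.Threading.Tasks;
    
    // Usage of the code (top-level statements must come before type declarations)
    var stream = new OneShotAsyncEnumerable<string>(GetStream());
    
    // First loop prints "yes", "no" and "maybe"
    await foreach (var item in stream)
    {
        Console.WriteLine(item);
    }
    
    // Second loop throws InvalidOperationException from GetAsyncEnumerator
    await foreach (var item in stream)
    {
        Console.WriteLine(item);
    }
    
    // An async iterator can return IAsyncEnumerator<T>, so yield return still works
    async IAsyncEnumerator<string> GetStream()
    {
        await Task.CompletedTask;
        yield return "yes";
        yield return "no";
        yield return "maybe";
    }
    
    // General purpose code
    public class OneShotAsyncEnumerable<T> : IAsyncEnumerable<T>
    {
        private readonly IAsyncEnumerator<T> iterator;
        private bool consumed;
    
        public OneShotAsyncEnumerable(IAsyncEnumerator<T> iterator)
        {
            this.iterator = iterator;
        }
    
        // The enumerator already exists, so cancellationToken cannot be passed on to it
        public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            if (consumed)
            {
                throw new InvalidOperationException($"Can only call {nameof(GetAsyncEnumerator)} once");
            }
            consumed = true;
            return iterator;
        }
    }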
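    One consequence of wrapping an existing enumerator: OneShotAsyncEnumerable<T> ignores the cancellationToken passed to GetAsyncEnumerator (which is where WithCancellation delivers it), because the enumerator was created earlier. A minimal sketch, assuming you want cancellation in this setup, passes the token as an ordinary parameter when the enumerator is created. OneShotCancellationSketch and ObservesCancellationAsync are hypothetical names, and Task.Yield() only stands in for real asynchronous work.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

public static class OneShotCancellationSketch
{
    // Hypothetical variant of GetStream: the token is an ordinary parameter,
    // supplied when the enumerator is created rather than via GetAsyncEnumerator.
    public static async IAsyncEnumerator<string> GetStream(CancellationToken cancellationToken)
    {
        foreach (var item in new[] { "yes", "no", "maybe" })
        {
            cancellationToken.ThrowIfCancellationRequested();
            await Task.Yield(); // stand-in for real asynchronous work
            yield return item;
        }
    }

    // Reads with an already-cancelled token; returns true if cancellation was observed.
    public static async Task<bool> ObservesCancellationAsync()
    {
        using var cts = new CancellationTokenSource();
        cts.Cancel();
        IAsyncEnumerator<string> enumerator = GetStream(cts.Token);
        try
        {
            await enumerator.MoveNextAsync();
            return false;
        }
        catch (OperationCanceledException)
        {
            return true;
        }
        finally
        {
            await enumerator.DisposeAsync();
        }
    }
}
```

    The enumerator can then be wrapped as before, e.g. new OneShotAsyncEnumerable<string>(OneShotCancellationSketch.GetStream(token)).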
    

    Alternatively, you could write an extension method to wrap an IAsyncEnumerable<T>, e.g.

    using System;
    using System.Collections.Generic;
    using System.Threading;
    using System.Threading.Tasks;
    
    // Usage of the code (top-level statements must come before type declarations)
    var stream = GetStream().IterateOnlyOnce();
    
    // First loop prints "yes", "no" and "maybe"
    await foreach (var item in stream)
    {
        Console.WriteLine(item);
    }
    
    // Second loop throws InvalidOperationException from GetAsyncEnumerator
    await foreach (var item in stream)
    {
        Console.WriteLine(item);
    }
    
    async IAsyncEnumerable<string> GetStream()
    {
        await Task.CompletedTask;
        yield return "yes";
        yield return "no";
        yield return "maybe";
    }
    
    // General purpose code
    public static class Extensions
    {
        public static IAsyncEnumerable<T> IterateOnlyOnce<T>(this IAsyncEnumerable<T> iterable) =>
            new OneShotAsyncEnumerable<T>(iterable);
    
        private class OneShotAsyncEnumerable<T> : IAsyncEnumerable<T>
        {
            private readonly IAsyncEnumerable<T> iterable;
            private bool consumed;
    
            public OneShotAsyncEnumerable(IAsyncEnumerable<T> iterable)
            {
                this.iterable = iterable;
            }
    
            // Wrapping the enumerable (not an enumerator) lets the token flow through
            public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
            {
                if (consumed)
                {
                    throw new InvalidOperationException($"Can only call {nameof(GetAsyncEnumerator)} once");
                }
                consumed = true;
                return iterable.GetAsyncEnumerator(cancellationToken);
            }
        }
    }
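    For the integration-test scenario in the question, an alternative to throwing is to count enumerations and assert on the count afterwards. This is a hypothetical sketch: EnumerationCountingAsyncEnumerable<T> and CountItemsTwiceAsync are illustrative names, the latter standing in for production code that reads the stream twice. Interlocked keeps the count correct if enumerations start on different threads.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical test helper: records how many times the wrapped sequence is enumerated.
public sealed class EnumerationCountingAsyncEnumerable<T> : IAsyncEnumerable<T>
{
    private readonly IAsyncEnumerable<T> source;
    private int enumerationCount;

    public EnumerationCountingAsyncEnumerable(IAsyncEnumerable<T> source)
    {
        this.source = source ?? throw new ArgumentNullException(nameof(source));
    }

    // Number of GetAsyncEnumerator calls so far (each await foreach makes one).
    public int EnumerationCount => Volatile.Read(ref enumerationCount);

    public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
    {
        Interlocked.Increment(ref enumerationCount);
        return source.GetAsyncEnumerator(cancellationToken);
    }
}

public static class CountingDemo
{
    public static async IAsyncEnumerable<string> GetStream()
    {
        await Task.CompletedTask;
        yield return "yes";
        yield return "no";
        yield return "maybe";
    }

    // Hypothetical stand-in for production code that (wrongly) reads the stream twice.
    public static async Task<int> CountItemsTwiceAsync(IAsyncEnumerable<string> stream)
    {
        int count = 0;
        await foreach (var item in stream) count++;
        await foreach (var item in stream) count++;
        return count;
    }
}
```

    A test could then assert that EnumerationCount is 1, which would have flagged the double read described in the question. The trade-off is that, unlike the throwing wrapper, the faulty code path still runs to completion (here it counts 6 items and EnumerationCount ends at 2).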