
Consume all messages in a System.Threading.Channels.Channel


Suppose I have an unbounded Channel with many producers and one consumer, where the consumer runs:

await foreach (var message in channel.Reader.ReadAllAsync(cts.Token))
{
    await consume(message);
}

The problem is that the consume function does some IO access and potentially some network access too, so many more messages may be produced before a single one is consumed. But since the IO resources can't be accessed concurrently, I can't have multiple consumers, nor can I throw the consume function into a Task and forget it.

The consume function can easily be modified to take multiple messages and handle them all in a batch. So my question is whether there's a way to make the consumer take all the messages queued in the channel whenever it reads from it, something like this:

while (true) {
    // TakeAll is hypothetical: ChannelReader<T> has no such method
    Message[] messages = await channel.Reader.TakeAll();
    await consumeAll(messages);
}

Edit: One option I can come up with is:

List<Message> messages = new();
await foreach (var message in channel.Reader.ReadAllAsync(cts.Token))
{
    // Handle the message that ended the wait on its own...
    await consume(message);
    // ...then drain everything else already buffered, without waiting
    while (channel.Reader.TryRead(out var msg))
        messages.Add(msg);
    if (messages.Count > 0)
    {
        await consumeAll(messages);
        messages.Clear();
    }
}

But I feel like there should be a better way to do this.
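For comparison, a minimal sketch of the loop from the edit with the first message folded into the batch, so every message goes through a single batch call instead of one individual `consume` plus one `consumeAll`. `ConsumeInBatchesAsync` is a hypothetical helper name, and the `Func<IReadOnlyList<T>, Task>` callback is an assumed stand-in for `consumeAll`.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public static class BatchingConsumer
{
    // Hands every message in the channel to consumeAll, in batches.
    // ReadAllAsync waits for the first message of each batch; TryRead then
    // drains whatever is already buffered without waiting.
    // Returns when the writer completes and the channel is empty.
    public static async Task ConsumeInBatchesAsync<T>(
        ChannelReader<T> reader,
        Func<IReadOnlyList<T>, Task> consumeAll,
        CancellationToken cancellationToken = default)
    {
        await foreach (var first in reader.ReadAllAsync(cancellationToken))
        {
            // A fresh list per batch, so consumeAll may keep a reference to it.
            var batch = new List<T> { first };
            while (reader.TryRead(out var next))
                batch.Add(next);

            // Batches are awaited one at a time, so the IO is never accessed concurrently.
            await consumeAll(batch);
        }
    }
}
```

Allocating a new list per batch (rather than reusing and clearing one) means the callback can safely hold on to the batch after it returns.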


Solution

After reading Stephen Toub's primer on channels, I had a stab at writing an extension method that should do what you need (it's been a while since I did any C#, so this was fun).

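    using System.Collections.Generic;
    using System.Linq;
    using System.Runtime.CompilerServices;
    using System.Threading;
    using System.Threading.Channels;

    public static class ChannelReaderEx
    {
        // Yields one batch per wake-up: waits until at least one item is available,
        // then takes everything currently buffered in the channel.
        // The sequence ends when the writer completes and the channel is drained.
        public static async IAsyncEnumerable<IEnumerable<T>> ReadBatchesAsync<T>(
            this ChannelReader<T> reader,
            [EnumeratorCancellation] CancellationToken cancellationToken = default
        )
        {
            while (await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
            {
                // Materialize the batch so it is not re-read lazily after the yield
                var batch = reader.Flush().ToList();

                // Skip empty batches, e.g. if another reader drained the channel first
                if (batch.Count > 0)
                    yield return batch;
            }
        }

        // Lazily drains the items that are available right now, without waiting
        public static IEnumerable<T> Flush<T>(this ChannelReader<T> reader)
        {
            while (reader.TryRead(out T item))
            {
                yield return item;
            }
        }
    }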
    

    which can be used like this:

    await foreach (var batch in channel.Reader.ReadBatchesAsync(cts.Token))
    {
        await ConsumeBatch(batch);
    }
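One thing to keep in mind with an unbounded channel: `Flush` keeps draining for as long as items are available, so if producers write steadily, a single batch has no upper size. A minimal sketch of a capped variant, assuming a batch-size limit is acceptable for `consumeAll`; `ReadBoundedBatchesAsync` and `maxBatchSize` are hypothetical names, not part of the answer above or of System.Threading.Channels.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;

public static class ChannelReaderBatchEx
{
    // Like ReadBatchesAsync, but caps each batch at maxBatchSize items
    // (a hypothetical parameter). Leftover items go into the next batch.
    // Note: the argument check runs on first enumeration, not at the call site.
    public static async IAsyncEnumerable<List<T>> ReadBoundedBatchesAsync<T>(
        this ChannelReader<T> reader,
        int maxBatchSize,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (maxBatchSize < 1)
            throw new ArgumentOutOfRangeException(nameof(maxBatchSize));

        // WaitToReadAsync returns false once the writer has completed and the channel is empty.
        while (await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
        {
            var batch = new List<T>();
            while (batch.Count < maxBatchSize && reader.TryRead(out var item))
                batch.Add(item);

            // Another reader may have emptied the channel between WaitToReadAsync and TryRead.
            if (batch.Count > 0)
                yield return batch;
        }
    }
}
```

For example, `channel.Reader.ReadBoundedBatchesAsync(100, cts.Token)` hands over at most 100 messages at a time; anything left over is picked up on the next iteration, because `WaitToReadAsync` completes immediately while items remain.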