c# · memory-leaks · producer-consumer · iasyncenumerable · system.threading.channels

How to batch a ChannelReader<T>, enforcing a maximum interval policy between consuming and processing any individual item?


I am using a Channel<T> in a producer-consumer scenario, and I have a requirement to consume the channel in batches of 10 items each, without letting any consumed item sit idle in a buffer for more than 5 seconds. This duration is the maximum latency allowed between reading an item from the channel and processing a batch that contains it. The maximum-latency policy takes precedence over the batch-size policy, so a batch should be processed even with fewer than 10 items, in order to satisfy the max-latency requirement.

I was able to implement the first requirement, in the form of a ReadAllBatches extension method for the ChannelReader<T> class:

public static async IAsyncEnumerable<T[]> ReadAllBatches<T>(
    this ChannelReader<T> channelReader, int batchSize)
{
    List<T> buffer = new();
    while (true)
    {
        T item;
        try { item = await channelReader.ReadAsync(); }
        catch (ChannelClosedException) { break; }
        buffer.Add(item);
        if (buffer.Count == batchSize)
        {
            yield return buffer.ToArray();
            buffer.Clear();
        }
    }
    if (buffer.Count > 0) yield return buffer.ToArray();
    await channelReader.Completion; // Propagate possible failure
}

I am planning to use it like this:

await foreach (Item[] batch in myChannel.Reader.ReadAllBatches(10))
{
    Console.WriteLine($"Processing batch of {batch.Length} items");
}
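For context, the batch-size-only version can be exercised with a small self-contained demo. The unbounded channel, the integer items, and the item counts below are illustrative assumptions on my part; the sketch assumes the ReadAllBatches extension above is defined in a static class in scope:

```csharp
using System;
using System.Threading.Channels;
using System.Threading.Tasks;

class Demo
{
    static async Task Main()
    {
        Channel<int> channel = Channel.CreateUnbounded<int>();

        // Producer: write 25 items, then complete the channel.
        _ = Task.Run(async () =>
        {
            for (int i = 1; i <= 25; i++)
                await channel.Writer.WriteAsync(i);
            channel.Writer.Complete();
        });

        // Consumer: two full batches of 10, then a final batch of 5.
        await foreach (int[] batch in channel.Reader.ReadAllBatches(10))
            Console.WriteLine($"Processing batch of {batch.Length} items");
    }
}
```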

My question is: how can I enhance my ReadAllBatches<T> implementation with an additional TimeSpan timeout parameter that enforces the aforementioned maximum-latency policy, without installing third-party packages in my project?

Important: The requested implementation should not be susceptible to the memory-leak issue that has been reported here: the loop that consumes the channel should not cause a steady increase in the application's memory usage while the producer that writes items to the channel stays idle for a prolonged period of time.

Note: I am aware that a similar question exists regarding batching the IAsyncEnumerable<T> interface, but I am not interested in that. I am interested in a method that targets the ChannelReader<T> type directly, for performance reasons.


Solution

  • Below is an implementation of an idea that was posted on GitHub by tkrafael.

    I had the same "leak" issue and resolved by:

    • First read uses main token (If I have no items to handle, just wait until one arrives)
    • All the remaining items must be read in x milliseconds

    This way I will never get an empty read due to timeout cancellation token (ok, maybe one empty read when the application is being shut down), and I get correct behaviour when items arrive from the channel's writer.

    The internal CancellationTokenSource is scheduled for cancellation with a timer, immediately after the first element of each batch is consumed.
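    To clarify the timer mechanics before the full implementation, here is a minimal, standalone sketch of the CancelAfter/TryReset pattern it relies on (this is an illustration, not part of the answer's code):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

class TimerPatternSketch
{
    static async Task Main()
    {
        CancellationTokenSource timerCts = new();

        timerCts.CancelAfter(TimeSpan.FromMilliseconds(100)); // arm the timer
        await Task.Delay(200); // simulate waiting longer than the timeout
        Console.WriteLine(timerCts.IsCancellationRequested);  // True: the timer has fired

        // TryReset (available since .NET 6) succeeds only if the CTS has not
        // been canceled yet. Here it has fired, so it returns False and a
        // fresh CTS must be created instead -- exactly what the answer does.
        Console.WriteLine(timerCts.TryReset());

        timerCts.Dispose();
    }
}
```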

    /// <summary>
    /// Reads all of the data from the channel in batches, enforcing a maximum
    /// interval policy between consuming an item and emitting it in a batch.
    /// </summary>
    public static IAsyncEnumerable<T[]> ReadAllBatches<T>(
        this ChannelReader<T> source, int batchSize, TimeSpan timeSpan)
    {
        ArgumentNullException.ThrowIfNull(source);
        if (batchSize < 1) throw new ArgumentOutOfRangeException(nameof(batchSize));
        if (timeSpan < TimeSpan.Zero)
            throw new ArgumentOutOfRangeException(nameof(timeSpan));
        return Implementation();
    
        async IAsyncEnumerable<T[]> Implementation(
            [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            CancellationTokenSource timerCts = CancellationTokenSource
                .CreateLinkedTokenSource(cancellationToken);
            try
            {
                List<T> buffer = new();
                while (true)
                {
                    CancellationToken token = buffer.Count == 0 ?
                        cancellationToken : timerCts.Token;
                    (T Value, bool HasValue) item;
                    try
                    {
                        item = (await source.ReadAsync(token).ConfigureAwait(false), true);
                    }
                    catch (ChannelClosedException) { break; }
                    catch (OperationCanceledException)
                    {
                        if (cancellationToken.IsCancellationRequested) break;
                        // Timeout occurred.
                        Debug.Assert(timerCts.IsCancellationRequested);
                        Debug.Assert(buffer.Count > 0);
                        item = default;
                    }
                    if (buffer.Count == 0) timerCts.CancelAfter(timeSpan);
                    if (item.HasValue)
                    {
                        buffer.Add(item.Value);
                        if (buffer.Count < batchSize) continue;
                    }
                    yield return buffer.ToArray();
                    buffer.Clear();
                    if (!timerCts.TryReset())
                    {
                        timerCts.Dispose();
                        timerCts = CancellationTokenSource
                            .CreateLinkedTokenSource(cancellationToken);
                    }
                }
                // Emit what's left before throwing exceptions.
                if (buffer.Count > 0) yield return buffer.ToArray();
    
                cancellationToken.ThrowIfCancellationRequested();
    
                // Propagate possible failure of the channel.
                if (source.Completion.IsCompleted)
                    await source.Completion.ConfigureAwait(false);
            }
            finally { timerCts.Dispose(); }
        }
    }
    

    Usage example:

    await foreach (Item[] batch in myChannel.Reader
        .ReadAllBatches(10, TimeSpan.FromSeconds(5)))
    {
        Console.WriteLine($"Processing batch of {batch.Length} items");
    }
    
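    To observe the max-latency policy in action, a deliberately slow producer can be simulated. The timings and item counts below are illustrative assumptions, and the sketch assumes the ReadAllBatches extension above is in scope:

```csharp
using System;
using System.Threading.Channels;
using System.Threading.Tasks;

class LatencyDemo
{
    static async Task Main()
    {
        Channel<int> channel = Channel.CreateUnbounded<int>();

        // Producer: one item every 2 seconds -- far slower than needed
        // to fill a batch of 10 within the 5-second window.
        _ = Task.Run(async () =>
        {
            for (int i = 1; i <= 6; i++)
            {
                await channel.Writer.WriteAsync(i);
                await Task.Delay(TimeSpan.FromSeconds(2));
            }
            channel.Writer.Complete();
        });

        // Batches of up to 10, but no item waits more than 5 seconds,
        // so small batches (roughly 2-3 items each) are emitted on timeout.
        await foreach (int[] batch in channel.Reader
            .ReadAllBatches(10, TimeSpan.FromSeconds(5)))
        {
            Console.WriteLine($"Batch of {batch.Length}: [{string.Join(", ", batch)}]");
        }
    }
}
```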

    This implementation is non-destructive, meaning that no items that have been consumed from the channel are in danger of being lost. In case the enumeration is canceled or the channel is faulted, any consumed items will be emitted in a final batch, before the propagation of the error.

    Note: In case the source ChannelReader<T> is completed at the same time that the cancellationToken is canceled, the cancellation takes precedence over completion. This is the same behavior as all native methods of the ChannelReader<T> and ChannelWriter<T> classes. It means that it's possible (although rare) for an OperationCanceledException to be thrown, even if all the work has completed.