
How to implement the BlockingCollection.TakeFromAny equivalent for Channels?


I am trying to implement an asynchronous method that takes an array of ChannelReader<T>s and takes a value from whichever channel has an item available. It is similar in functionality to the BlockingCollection<T>.TakeFromAny method, which has this signature:

public static int TakeFromAny(BlockingCollection<T>[] collections, out T item,
    CancellationToken cancellationToken);

This method returns the index in the collections array from which the item was removed. An async method cannot have out parameters, so the API that I am trying to implement is this:

public static Task<(T Item, int Index)> TakeFromAnyAsync<T>(
    ChannelReader<T>[] channelReaders,
    CancellationToken cancellationToken = default);

The TakeFromAnyAsync<T> method should asynchronously read an item and return the consumed item along with the index of the associated channel in the channelReaders array. If all the channels are completed (either successfully or with an error), or all become completed during the await, the method should asynchronously throw a ChannelClosedException.

My question is: how can I implement the TakeFromAnyAsync<T> method? The implementation looks quite tricky. Obviously, under no circumstances should the method consume more than one item from the channels. It should also not leave behind fire-and-forget tasks or leave disposable resources undisposed. The method will typically be called in a loop, so it should also be reasonably efficient: its complexity should be no worse than O(n), where n is the number of channels.

To get an idea of where this method can be useful, take a look at the select statement of the Go language. From the tour:

The select statement lets a goroutine wait on multiple communication operations.

A select blocks until one of its cases can run, then it executes that case. It chooses one at random if multiple are ready.

select {
case msg1 := <-c1:
    fmt.Println("received", msg1)
case msg2 := <-c2:
    fmt.Println("received", msg2)
}

In the above example either a value will be taken from the channel c1 and assigned to the variable msg1, or a value will be taken from the channel c2 and assigned to the variable msg2. The Go select statement is not restricted to reading from channels. It can include multiple heterogeneous cases like writing to bounded channels, waiting for timers etc. Replicating the full functionality of the Go select statement is beyond the scope of this question.
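
For comparison, the select statement only handles a fixed set of cases written out in the source, so the closest Go counterpart to the API I am after (an arbitrary slice of channels, with the index of the chosen one returned) is the reflect.Select function from the standard library. The following is only a minimal sketch of that idea: the helper name takeFromAny and its (item, index, ok) result shape are my own assumptions, chosen to mirror the TakeFromAnyAsync<T> signature, and the generic type parameter requires Go 1.18 or later.

```go
package main

import (
	"fmt"
	"reflect"
)

// takeFromAny is a hypothetical helper (not part of any library). It blocks
// until one of the channels yields a value and returns that value together
// with the index of its channel. ok is false, and index is -1, once every
// channel has been closed and drained.
func takeFromAny[T any](chans []<-chan T) (item T, index int, ok bool) {
	cases := make([]reflect.SelectCase, len(chans))
	for i, ch := range chans {
		cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}
	}
	open := len(cases)
	for open > 0 {
		// reflect.Select picks one ready case, pseudo-randomly if several are ready.
		chosen, recv, recvOK := reflect.Select(cases)
		if recvOK {
			return recv.Interface().(T), chosen, true
		}
		// The chosen channel is closed and empty. A zero Chan value makes
		// reflect.Select ignore this case on the next iteration.
		cases[chosen].Chan = reflect.Value{}
		open--
	}
	return item, -1, false
}

func main() {
	c1 := make(chan string, 1)
	c2 := make(chan string, 1)
	c2 <- "hello"
	close(c1)
	close(c2)

	readers := []<-chan string{c1, c2}
	for {
		item, index, ok := takeFromAny(readers)
		if !ok {
			fmt.Println("all channels closed")
			return
		}
		fmt.Printf("received %q from channel %d\n", item, index)
	}
}
```

At most one item is consumed per call, because reflect.Select performs exactly one receive. The "all channels closed" result plays the role that ChannelClosedException plays in the C# API described above.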


Solution

  • I came up with something like this:

    
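    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Threading;
    using System.Threading.Channels;
    using System.Threading.Tasks;

    public static async Task<(T Item, int Index)> TakeFromAnyAsync<T>(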
        ChannelReader<T>[] channelReaders,
        CancellationToken cancellationToken = default)
    {
        if (channelReaders == null)
        {
            throw new ArgumentNullException(nameof(channelReaders));
        }
    
        if (channelReaders.Length == 0)
        {
            throw new ArgumentException("The list cannot be empty.", nameof(channelReaders));
        }
    
        if (channelReaders.Length == 1)
        {
            return (await channelReaders[0].ReadAsync(cancellationToken), 0);
        }
    
        // First attempt to read an item synchronously from any of the readers
        for (int i = 0; i < channelReaders.Length; ++i)
        {
            if (channelReaders[i].TryRead(out var item))
            {
                return (item, i);
            }
        }
    
        using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
        {
    
            var waitToReadTasks = channelReaders
                    .Select(it => it.WaitToReadAsync(cts.Token).AsTask())
                    .ToArray();
    
            var pendingTasks = new List<Task<bool>>(waitToReadTasks);
    
            while (pendingTasks.Count > 1)
            {
                var t = await Task.WhenAny(pendingTasks);
    
                if (t.IsCompletedSuccessfully && t.Result)
                {
                    int index = Array.IndexOf(waitToReadTasks, t);
                    var reader = channelReaders[index];
    
                    // Attempt to read an item synchronously
                    if (reader.TryRead(out var item))
                    {
                        if (pendingTasks.Count > 1)
                        {
                            // Cancel pending "wait to read" on the remaining readers
                            // then wait for the completion 
                            try
                            {
                                cts.Cancel();
                                await Task.WhenAll((IEnumerable<Task>)pendingTasks);
                            }
                            catch { }
                        }
                        return (item, index);
                    }
    
                    // Due to a race with another consumer, the item is no longer available
                    if (!reader.Completion.IsCompleted)
                    {
                        // .. but the channel appears to be still open, so we retry
                        var waitToReadTask = reader.WaitToReadAsync(cts.Token).AsTask();
                        waitToReadTasks[index] = waitToReadTask;
                        pendingTasks.Add(waitToReadTask);
                    }
    
                }
    
                // Remove the task just handled and all other completed tasks that cannot yield an item
                pendingTasks.RemoveAll(tt => tt == t || 
                    tt.IsCompletedSuccessfully && !tt.Result || 
                    tt.IsFaulted || tt.IsCanceled);
    
            }
    
            int lastIndex = 0;
            if (pendingTasks.Count > 0)
            {
                lastIndex = Array.IndexOf(waitToReadTasks, pendingTasks[0]);
                await pendingTasks[0];
            }
    
            var lastItem = await channelReaders[lastIndex].ReadAsync(cancellationToken);
            return (lastItem, lastIndex);
        }
    }