c#, asynchronous, task-parallel-library, .net-8.0

How to implement a custom Task.WhenAll that preserves the order of exceptions?


The behavior of the non-generic Task.WhenAll changed in .NET 8: the exceptions of the faulted tasks are now propagated in chronological order (the order in which the tasks faulted) instead of positional order. The new behavior has its merits, but some of my tests depend on the old Task.WhenAll behavior to pass. So instead of changing my tests, I would prefer to use a method similar to Task.WhenAll that propagates the exceptions in the same order in which the faulted tasks are positioned in the tasks array. How can I implement this method? Its signature:

/// <summary>
/// Creates a task that will complete when all of the supplied tasks have completed.
/// This task contains the exceptions of the faulted tasks in their positional order
/// (not chronological order).
/// </summary>
public static Task WhenAll_OrderedExceptions(params Task[] tasks);

Here is a simple program that demonstrates the desired behavior of this method:

async Task DoAsync(string title, int delay)
{
    await Task.Delay(delay);
    throw new Exception(title);
}

Task t1 = DoAsync("A", 200);
Task t2 = DoAsync("B", 100);

try
{
    WhenAll_OrderedExceptions(t1, t2).Wait();
}
catch (AggregateException aex)
{
    Console.Write($"Error message: {aex.Message}");
}

Desirable output:

Error message: One or more errors occurred. (A) (B)

Undesirable output:

Error message: One or more errors occurred. (B) (A)

The program on .NET Fiddle.

Does anyone have an idea how to implement this method?


Solution

  • Based on the comments, here is a less naive implementation that also handles tasks that fail with multiple (wrapped) exceptions, as well as canceled tasks.

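    public static Task WhenAll_OrderedExceptions(params Task[] tasks) {
        ArgumentNullException.ThrowIfNull(tasks);
        tasks = (Task[])tasks.Clone(); // snapshot, in case the caller mutates the array
        if (tasks.Any(t => t is null))
            throw new ArgumentException("The tasks array contains a null task.", nameof(tasks));

        var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
        _ = LocalStateMachine(tcs, tasks); // fire-and-forget: it completes tcs itself

        static async Task LocalStateMachine(TaskCompletionSource tcs, Task[] tasks) {
            // Only wait for every task to complete. The outcome is not taken from the
            // await: awaiting a faulted task rethrows only its first exception, and
            // awaiting a canceled task throws as well.
            foreach (var task in tasks) {
                try {
                    await task.ConfigureAwait(false);
                } catch {
                    // Ignored here; the task's state is inspected below.
                }
            }

            // All inner exceptions of the faulted tasks, in positional order.
            // Canceled tasks contribute nothing, as with Task.WhenAll.
            var exceptions = tasks
                .Where(t => t.IsFaulted)
                .SelectMany(t => t.Exception!.InnerExceptions)
                .ToList();

            if (exceptions.Count != 0) {
                // The IEnumerable<Exception> overload stores each exception as a separate
                // inner exception of tcs.Task. Throwing an AggregateException from this
                // async method instead would not work: the async method builder records
                // the thrown exception as a single fault, so it would end up wrapped
                // inside another AggregateException.
                tcs.TrySetException(exceptions);
            } else if (tasks.Any(t => t.IsCanceled)) {
                tcs.TrySetCanceled();
            } else {
                tcs.TrySetResult();
            }
        }

        return tcs.Task;
    }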
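A possible variation, sketched here under the assumption of .NET 8 with top-level statements: instead of a hand-written loop of awaits, let the built-in Task.WhenAll act purely as a "everything has completed" signal, and then rebuild the outcome from the original array in positional order. The name WhenAllOrderedViaContinuation and the FailAsync helper are hypothetical, introduced only for this illustration. The program replays the scenario from the question (B faults before A).

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Scenario from the question: B faults first (100 ms), A faults later (200 ms).
Task t1 = FailAsync("A", 200);
Task t2 = FailAsync("B", 100);

string message = "";
try
{
    WhenAllOrderedViaContinuation(t1, t2).Wait();
}
catch (AggregateException aex)
{
    message = aex.Message;
}
Console.WriteLine($"Error message: {message}"); // Error message: One or more errors occurred. (A) (B)

// Hypothetical helper: waits, then faults with the given message.
static async Task FailAsync(string title, int delay)
{
    await Task.Delay(delay);
    throw new Exception(title);
}

// Hypothetical alternative: Task.WhenAll only signals that every task is done;
// the result is then rebuilt from the original array, in positional order.
static Task WhenAllOrderedViaContinuation(params Task[] tasks)
{
    ArgumentNullException.ThrowIfNull(tasks);
    Task[] snapshot = (Task[])tasks.Clone(); // the caller may mutate the array later

    var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    Task.WhenAll(snapshot).ContinueWith(whenAll =>
    {
        _ = whenAll.Exception; // observe it, so it is not reported as unobserved

        // All inner exceptions of the faulted tasks, in array order; canceled tasks skipped.
        var exceptions = snapshot
            .Where(t => t.IsFaulted)
            .SelectMany(t => t.Exception!.InnerExceptions)
            .ToList();

        if (exceptions.Count > 0) tcs.TrySetException(exceptions);
        else if (whenAll.IsCanceled) tcs.TrySetCanceled();
        else tcs.TrySetResult();
    }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);

    return tcs.Task;
}
```

ContinueWith avoids a second async state machine, but unlike await it needs the scheduler passed explicitly (TaskScheduler.Default here), so that the continuation does not pick up an ambient TaskScheduler.Current.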
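The reason for completing a TaskCompletionSource, rather than simply throwing an AggregateException at the end of the async method, can be checked with a small console program (a sketch assuming .NET 8 top-level statements; ThrowAggregateAsync is a hypothetical helper). A task faulted by throwing an AggregateException from an async method holds that AggregateException as its single inner exception, whereas TrySetException(IEnumerable<Exception>) keeps each exception separate.

```csharp
using System;
using System.Threading.Tasks;

// Fault a task by throwing an AggregateException from an async method.
Task viaThrow = ThrowAggregateAsync();
try { viaThrow.Wait(); } catch (AggregateException) { /* inspected below */ }

// Fault a task through the IEnumerable<Exception> overload of TrySetException.
var tcs = new TaskCompletionSource();
tcs.TrySetException(new[] { new Exception("A"), new Exception("B") });
Task viaTcs = tcs.Task;

// The thrown AggregateException becomes ONE inner exception of the task...
Console.WriteLine(viaThrow.Exception!.InnerExceptions.Count); // 1
// ...while TrySetException(IEnumerable<Exception>) keeps A and B as two inner exceptions.
Console.WriteLine(viaTcs.Exception!.InnerExceptions.Count);   // 2

// Hypothetical helper that faults with two exceptions bundled in an AggregateException.
static async Task ThrowAggregateAsync()
{
    await Task.Yield();
    throw new AggregateException(new Exception("A"), new Exception("B"));
}
```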