Tags: c#, .net, exception, async-await, task-parallel-library

How to properly cancel Task.WhenAll and throw the first exception?


I have multiple tasks that accept a cancellation token and call ThrowIfCancellationRequested accordingly. These tasks run concurrently using Task.WhenAll. I want all tasks to be cancelled when any task throws an exception. I achieved this using Select and ContinueWith:

var cts = new CancellationTokenSource();

try
{
    var tasks = new Task[] { DoSomethingAsync(cts.Token), ... } // multiple tasks here
        .Select(task => task.ContinueWith(t =>
        {
            if (t.IsFaulted)
            {
                cts.Cancel();
            }
        }));

    await Task.WhenAll(tasks).ConfigureAwait(false);
}
catch (SpecificException)
{
    // Why is this block never reached?
}

I'm not sure this is the best way to do it, and it seems to have some issues. The exception appears to be swallowed internally: the code after WhenAll is always reached. I don't want that code to run when an exception has occurred; I'd rather have the exception thrown so I can catch it myself at another level of the call stack. What's the best way to achieve this? If possible, I'd like the call stack to remain intact. If multiple exceptions occur, it would be best if only the first exception is rethrown rather than an AggregateException.
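The symptom can be reproduced with a small, self-contained sketch. It assumes two already-faulted tasks created with Task.FromException as stand-ins for failing DoSomethingAsync calls; the class name WhyCatchIsSkipped is made up for illustration. It compares Task.WhenAll over the tasks returned by ContinueWith with Task.WhenAll over the original tasks, and shows that awaiting the latter surfaces a single exception rather than an AggregateException.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

static class WhyCatchIsSkipped
{
    public static async Task<(TaskStatus ViaContinuations, string ViaOriginals, int TotalFaults)> RunAsync()
    {
        using var cts = new CancellationTokenSource();

        // Stand-ins for two failing DoSomethingAsync calls (assumption: both fail).
        Task[] originals =
        {
            Task.FromException(new InvalidOperationException("first")),
            Task.FromException(new ArgumentException("second")),
        };

        // The question's pattern: WhenAll over the continuations returned by ContinueWith.
        Task[] continuations = originals
            .Select(task => task.ContinueWith(t =>
            {
                if (t.IsFaulted) cts.Cancel();
            }))
            .ToArray();
        Task viaContinuations = Task.WhenAll(continuations);
        await viaContinuations; // does not throw: every continuation ran to completion

        // WhenAll over the original tasks instead.
        Task viaOriginals = Task.WhenAll(originals);
        try
        {
            await viaOriginals;
            return (viaContinuations.Status, "nothing thrown", 0);
        }
        catch (Exception ex)
        {
            // await rethrows only the first inner exception (in array order), not the
            // AggregateException; all of them remain available on viaOriginals.Exception.
            return (viaContinuations.Status, ex.GetType().Name,
                    viaOriginals.Exception?.InnerExceptions.Count ?? 0);
        }
    }

    public static async Task Main()
    {
        var (viaContinuations, viaOriginals, totalFaults) = await RunAsync();
        Console.WriteLine($"{viaContinuations} / {viaOriginals} / {totalFaults}");
    }
}
```

The continuations' WhenAll ends in RanToCompletion even though both antecedents faulted, while awaiting WhenAll over the originals throws the InvalidOperationException on its own.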


On a related note, I tried passing the cancellation token to ContinueWith, like so: task.ContinueWith(lambda, cts.Token). However, when an exception occurs in any task, this eventually throws a TaskCanceledException instead of the exception I'm interested in. I figured I shouldn't pass the cancellation token to ContinueWith, because that would cancel the continuation itself, which I don't think is what I want.
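The TaskCanceledException can be traced with a minimal sketch. Assumptions: Task.FromException stands in for a DoSomethingAsync call that has already failed, a TaskCompletionSource that is never completed stands in for a sibling task that is still running, and the class name ContinueWithTokenDemo is made up. When the first continuation calls cts.Cancel(), the sibling's continuation, registered with the same token, is cancelled without waiting for its antecedent, so Task.WhenAll ends Canceled rather than Faulted.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

static class ContinueWithTokenDemo
{
    public static async Task<(TaskStatus Status, string Caught)> RunAsync()
    {
        using var cts = new CancellationTokenSource();

        // Hypothetical stand-ins: one operation has already failed,
        // a sibling is still running (this TaskCompletionSource is never completed).
        Task failing = Task.FromException(new InvalidOperationException("boom"));
        var stillRunning = new TaskCompletionSource<bool>();

        // The question's variant: the token is also passed to ContinueWith.
        Task[] continuations = new[] { failing, stillRunning.Task }
            .Select(task => task.ContinueWith(t =>
            {
                if (t.IsFaulted) cts.Cancel();
            }, cts.Token))
            .ToArray();

        Task all = Task.WhenAll(continuations);
        try
        {
            await all;
            return (all.Status, "nothing thrown");
        }
        catch (Exception ex)
        {
            // The sibling's continuation was cancelled by cts.Cancel(), so WhenAll
            // ends Canceled and await throws TaskCanceledException, not the original error.
            return (all.Status, ex.GetType().Name);
        }
    }

    public static async Task Main()
    {
        var (status, caught) = await RunAsync();
        Console.WriteLine($"{status}: {caught}");
    }
}
```

The InvalidOperationException is still stored on the failing task, but nothing in the WhenAll chain observes it.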


Solution

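  • You shouldn't use ContinueWith here. The task it returns represents the continuation, not the original operation, and because your continuation lambda returns normally, every continuation task completes successfully. Task.WhenAll therefore never sees the exception, which is why your catch block is never reached. Instead of attaching a continuation to each task, introduce another "higher-level" async method that cancels the token and rethrows: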

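    // Wraps one operation: on failure, cancel the shared token so the sibling
    // operations stop early, then rethrow so Task.WhenAll still observes the fault.
    private async Task DoSomethingWithCancel(CancellationTokenSource cts)
    {
        try
        {
            await DoSomethingAsync(cts.Token).ConfigureAwait(false);
        }
        catch
        {
            cts.Cancel();
            throw; // rethrows the original exception with its stack trace intact
        }
    }
    
    // Calling code:
    using var cts = new CancellationTokenSource(); // C# 8 using declaration
    try
    {
        var tasks = new Task[] { DoSomethingWithCancel(cts), ... };
        // await rethrows the first faulted task's exception (in array order),
        // not an AggregateException; siblings that end Canceled add no exception.
        await Task.WhenAll(tasks).ConfigureAwait(false);
    }
    catch (SpecificException)
    {
        // ...
    }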
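The same wrapper idea can be packaged as a reusable helper and exercised end to end. This is a sketch under stated assumptions: WhenAllCancelOnFailure, FailsAsync and LoopsAsync are hypothetical names, not part of any library; FailsAsync throws after a short delay and LoopsAsync polls ThrowIfCancellationRequested until it is cancelled.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

static class CancelOnFirstFailure
{
    // Hypothetical helper generalizing DoSomethingWithCancel: every operation gets the
    // same token, and the first operation to throw cancels that token for the others.
    public static async Task WhenAllCancelOnFailure(params Func<CancellationToken, Task>[] operations)
    {
        using var cts = new CancellationTokenSource();

        async Task RunOne(Func<CancellationToken, Task> operation)
        {
            try
            {
                await operation(cts.Token).ConfigureAwait(false);
            }
            catch
            {
                cts.Cancel(); // a no-op if cancellation was already requested
                throw;        // keeps the original exception and its stack trace
            }
        }

        // await rethrows the exception of the first faulted task in array order;
        // operations that end Canceled add no exception of their own.
        await Task.WhenAll(operations.Select(op => RunOne(op)).ToArray()).ConfigureAwait(false);
    }

    // Demo with two hypothetical operations: one fails, the other runs until cancelled.
    public static async Task<(string Caught, bool SiblingCancelled)> DemoAsync()
    {
        bool siblingCancelled = false;

        async Task FailsAsync(CancellationToken token)
        {
            await Task.Delay(50, token);
            throw new InvalidOperationException("first failure");
        }

        async Task LoopsAsync(CancellationToken token)
        {
            try
            {
                while (true)
                {
                    token.ThrowIfCancellationRequested();
                    await Task.Delay(10, token);
                }
            }
            catch (OperationCanceledException)
            {
                siblingCancelled = true;
                throw;
            }
        }

        try
        {
            await WhenAllCancelOnFailure(FailsAsync, LoopsAsync);
            return ("nothing thrown", siblingCancelled);
        }
        catch (InvalidOperationException ex) // the original exception, not AggregateException
        {
            return (ex.Message, siblingCancelled);
        }
    }

    public static async Task Main()
    {
        var (caught, siblingCancelled) = await DemoAsync();
        Console.WriteLine($"caught: {caught}, sibling cancelled: {siblingCancelled}");
    }
}
```

Note that "first" here means the first faulted task in the order the operations were passed, not necessarily the one that failed earliest in time. In practice the two usually coincide, because the remaining operations are cancelled and end Canceled rather than Faulted; an operation that reacts to cancellation by throwing some other exception type would be Faulted and could take precedence if it appears earlier in the array.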