Nested IAsyncEnumerable function requires EnumeratorCancellation attribute?


I'd like to know if I need to use EnumeratorCancellation when passing a cancellation token to my local function. I am thinking of using this code pattern often in the future:

public static IAsyncEnumerable<string> MyOuterFunctionAsync(
    this Client client,
    CancellationToken cancellationToken,
    int? limit = null)
{
    // The indexed TakeWhile overload comes from the System.Linq.Async package.
    return MyLocalFunction()
        .TakeWhile(
            (_, index) =>
                limit is null ||
                index < limit.Value);

    async IAsyncEnumerable<string> MyLocalFunction()
    {
        var request = CreateRequest();

        do
        {
            // The outer cancellationToken is captured by the local function.
            var page = await request.GetAsync(cancellationToken);
            foreach (var item in page)
            {
                yield return item;
            }
            request = GetNextPageRequest();
        }
        while (request is not null);
    }
}

ReSharper doesn't say that EnumeratorCancellation is needed. When I add it to the outer function, ReSharper warns that it will have no effect. When I add it to the local function, ReSharper raises no warning, and it raises none without it either. Should I use it anywhere? I checked the IL viewer, but I don't see any difference between the versions.

Will MyOuterFunctionAsync work properly? Or do I need to change the signature of MyLocalFunction to:

async IAsyncEnumerable<string> MyLocalFunction(
    [EnumeratorCancellation] CancellationToken cancellationToken)

Solution

  • To answer your question directly: MyOuterFunctionAsync will correctly observe a CancellationToken passed as an argument, but not a CancellationToken attached with the WithCancellation operator. For example, if you write:

    var sequence = client.MyOuterFunctionAsync(CancellationToken.None)
        .WithCancellation(token);
    

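    ...the token will be ignored, because MyLocalFunction has no CancellationToken parameter that could receive it. For correct behavior, MyLocalFunction needs a CancellationToken parameter decorated with the EnumeratorCancellation attribute. If you also pass the outer cancellationToken as the argument (MyLocalFunction(cancellationToken)), the compiler-generated code combines it with the token supplied through WithCancellation, so either one can cancel the enumeration.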
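A minimal, self-contained sketch of the difference follows. The names CancellationDemo, WithoutAttribute, WithAttribute and CountAsync are made up for illustration. An already-canceled token is attached through WithCancellation. The iterator without a CancellationToken parameter runs to completion. The one with an [EnumeratorCancellation] parameter is canceled before it yields anything.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical demo class, not part of any library.
public static class CancellationDemo
{
    // No CancellationToken parameter: the token handed to GetAsyncEnumerator
    // (for example via WithCancellation) has nowhere to go and is dropped.
    public static async IAsyncEnumerable<int> WithoutAttribute()
    {
        for (int i = 0; i < 3; i++)
        {
            await Task.Yield();
            yield return i;
        }
    }

    // The decorated parameter receives the token handed to GetAsyncEnumerator.
    public static async IAsyncEnumerable<int> WithAttribute(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        for (int i = 0; i < 3; i++)
        {
            // Throws immediately if the token is already canceled.
            await Task.Delay(1, cancellationToken);
            yield return i;
        }
    }

    // Enumerates 'source' with 'token' attached through WithCancellation and
    // reports how many items were produced and whether cancellation stopped it.
    public static async Task<(int Count, bool Canceled)> CountAsync(
        IAsyncEnumerable<int> source, CancellationToken token)
    {
        int count = 0;
        try
        {
            await foreach (var item in source.WithCancellation(token))
            {
                count++;
            }
            return (count, false);
        }
        catch (OperationCanceledException)
        {
            return (count, true);
        }
    }
}
```

With a canceled token, CountAsync(WithoutAttribute(), token) still counts all three items, while CountAsync(WithAttribute(), token) reports cancellation with zero items.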

    You could consider adopting the pattern used in the System.Linq.Async library, for example in the Where LINQ operator:

    public static IAsyncEnumerable<TSource> Where<TSource>(
        this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        if (source == null)
            throw Error.ArgumentNull(nameof(source));
        if (predicate == null)
            throw Error.ArgumentNull(nameof(predicate));
    
        return Core(source, predicate);
    
        static async IAsyncEnumerable<TSource> Core(IAsyncEnumerable<TSource> source,
            Func<TSource, bool> predicate,
            [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            await foreach (var element in source.WithCancellation(cancellationToken)
                .ConfigureAwait(false))
            {
                if (predicate(element))
                {
                    yield return element;
                }
            }
        }
    }
    

    As you can see, the outer Where is not async, while the inner Core is. The public signature of Where doesn't include a CancellationToken parameter. The caller can always use the WithCancellation operator to attach a CancellationToken to an asynchronous sequence, so such a parameter would be redundant. For example:

    var query = sequence.Where(x => x.IsAvailable).WithCancellation(token);
    

    On the other hand, the local Core implementation of the operator does include a CancellationToken parameter, and that parameter is decorated with the EnumeratorCancellation attribute. When the caller uses Where and attaches a token with WithCancellation, the compiler-generated enumerator assigns that token to Core's cancellationToken parameter, because of the EnumeratorCancellation attribute. Core then forwards it to the source sequence with source.WithCancellation(cancellationToken).
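The Where pattern can be reproduced without the System.Linq.Async package. This is a sketch, not the library's code. AsyncOperatorSketch, WhereSketch, Numbers and LastSeenToken are hypothetical names. The upstream Numbers iterator records the token it receives, so you can check that a token attached to the result of WhereSketch reaches the source through Core.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical sketch of the non-async wrapper + async Core pattern.
public static class AsyncOperatorSketch
{
    // Public entry point: not async, no CancellationToken parameter.
    // Arguments are validated eagerly, when the method is called.
    public static IAsyncEnumerable<T> WhereSketch<T>(
        this IAsyncEnumerable<T> source, Func<T, bool> predicate)
    {
        if (source is null) throw new ArgumentNullException(nameof(source));
        if (predicate is null) throw new ArgumentNullException(nameof(predicate));
        return Core(source, predicate);

        // Async iterator: a token attached to the returned sequence with
        // WithCancellation is assigned to cancellationToken by the compiler.
        static async IAsyncEnumerable<T> Core(
            IAsyncEnumerable<T> source, Func<T, bool> predicate,
            [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Forward the token to the upstream sequence.
            await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                if (predicate(item)) yield return item;
            }
        }
    }

    // Records the token the upstream source was enumerated with (demo only).
    public static CancellationToken LastSeenToken;

    // Hypothetical upstream source producing 1..4.
    public static async IAsyncEnumerable<int> Numbers(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        LastSeenToken = cancellationToken;
        for (int i = 1; i <= 4; i++)
        {
            await Task.Yield();
            yield return i;
        }
    }
}
```

A caller would write Numbers().WhereSketch(x => x % 2 == 0).WithCancellation(token). The token then arrives in Numbers even though neither public method exposes a CancellationToken parameter.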

    So the general rule is this. A method that returns an IAsyncEnumerable<T> should include a CancellationToken parameter only if it's implemented with async. In that case, the parameter should also be decorated with the EnumeratorCancellation attribute.
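The following sketch shows one way the rule could apply to the question's method. Several parts are stand-ins so that the example compiles on its own:

- GetItemsAsync is a hypothetical name standing in for MyOuterFunctionAsync, and the Client extension receiver is dropped.
- The question's request API is replaced by a hypothetical GetPageAsync that serves fixed in-memory pages.
- The limit is enforced inside the iterator instead of with System.Linq.Async's TakeWhile.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

public static class PagedClientSketch
{
    // Hypothetical stand-in for the question's request API: returns page
    // 'pageIndex' of a fixed data set, or null after the last page.
    static async Task<string[]> GetPageAsync(int pageIndex, CancellationToken cancellationToken)
    {
        await Task.Delay(1, cancellationToken); // simulated I/O that observes the token
        string[][] pages = { new[] { "a", "b" }, new[] { "c", "d" }, new[] { "e" } };
        return pageIndex < pages.Length ? pages[pageIndex] : null;
    }

    // Public method: not async, so no CancellationToken parameter.
    // Callers attach a token with WithCancellation.
    public static IAsyncEnumerable<string> GetItemsAsync(int? limit = null)
    {
        if (limit < 0) throw new ArgumentOutOfRangeException(nameof(limit));
        return Core(limit);

        static async IAsyncEnumerable<string> Core(
            int? limit,
            [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            int produced = 0;
            // 'produced != limit' is always true when limit is null.
            for (int pageIndex = 0; produced != limit; pageIndex++)
            {
                var page = await GetPageAsync(pageIndex, cancellationToken).ConfigureAwait(false);
                if (page is null) yield break;
                foreach (var item in page)
                {
                    if (produced == limit) yield break;
                    produced++;
                    yield return item;
                }
            }
        }
    }
}
```

The limit is checked before each page fetch. As a result, no further page is requested once enough items have been produced.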

    Ideally, public APIs that return IAsyncEnumerable<T> should not be implemented with async. Giving the caller two different ways to pass a token, either directly or through WithCancellation, creates confusion without adding any value to the API. As an example of what not to do, see the ChannelReader<T>.ReadAllAsync API, whose public signature takes a CancellationToken directly even though a token can also be attached with WithCancellation.
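To see why having both options is redundant, consider the sketch below. TwoTokensDemo, CountUp and RunAsync are hypothetical names. CountUp is an async iterator shaped like ReadAllAsync. It is enumerated with one token passed as an argument and another attached with WithCancellation. When both tokens are supplied, the compiler-generated code combines them, so canceling either one stops the enumeration.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

public static class TwoTokensDemo
{
    // Async iterator that exposes its token as a regular, decorated parameter.
    // Without cancellation it yields 0..9.
    public static async IAsyncEnumerable<int> CountUp(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        for (int i = 0; i < 10; i++)
        {
            cancellationToken.ThrowIfCancellationRequested();
            await Task.Yield();
            yield return i;
        }
    }

    // Passes one token as an argument and attaches another with WithCancellation,
    // then cancels one of them after 'after' items. Returns the number of items seen.
    public static async Task<int> RunAsync(bool cancelArgumentToken, int after)
    {
        using var argumentCts = new CancellationTokenSource();
        using var attachedCts = new CancellationTokenSource();
        int count = 0;
        try
        {
            await foreach (var item in CountUp(argumentCts.Token).WithCancellation(attachedCts.Token))
            {
                count++;
                if (count == after)
                    (cancelArgumentToken ? argumentCts : attachedCts).Cancel();
            }
        }
        catch (OperationCanceledException)
        {
            // Expected: the combined token was canceled.
        }
        return count;
    }
}
```

Both RunAsync(true, 3) and RunAsync(false, 3) stop after three items. The two ways of passing a token end up doing the same thing.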