Search code examples
Tags: c#, asynchronous, interceptor, grpc-c#

How to make C# GRPC ClientInterceptor with streaming response wait until completion before exiting


Starting with this sample:

https://github.com/grpc/grpc-dotnet/tree/master/examples/Interceptor/Interceptor.sln

I want to change this interceptor:

https://github.com/grpc/grpc-dotnet/blob/master/examples/Interceptor/Client/ClientLoggerInterceptor.cs

I want to change this method so that it does not return (and therefore does not enter the "finally" block) until the streaming call started by the continuation has completed.

I can't seem to figure out how to make the await happen without modifying the method signature and breaking the override.

/// <summary>
/// Intercepts a server-streaming call: logs the method, adds caller metadata,
/// invokes <paramref name="continuation"/> and returns the call object it produces.
/// Any exception thrown while invoking the continuation is logged and rethrown, and
/// the elapsed time is logged in the <c>finally</c> block.
/// </summary>
/// <typeparam name="TRequest">Request message type.</typeparam>
/// <typeparam name="TResponse">Response message type.</typeparam>
/// <param name="request">Request message, handed to the continuation as-is.</param>
/// <param name="context">Call context; passed by reference to <c>AddCallerMetadata</c> before the continuation runs.</param>
/// <param name="continuation">Delegate invoked with the request and context; its result is returned to the caller.</param>
/// <returns>The call object returned by <paramref name="continuation"/>.</returns>
public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(
    TRequest request,
    ClientInterceptorContext<TRequest, TResponse> context,
    AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
{
    LogCall(context.Method);
    AddCallerMetadata(ref context);

    // added to the example:
    var timer = new Stopwatch();
    timer.Start();

    try
    {
        // The continuation's result is returned directly; nothing here awaits the response stream.
        return continuation(request, context);
    }
    catch (Exception ex)
    {
        // Only exceptions thrown synchronously by the continuation call reach this handler.
        LogError(ex);
        throw;
    }
    // added to the example:
    finally
    {
        // I want the timer to include the complete runtime of "continuation"
        // NOTE(review): this block runs as soon as the method returns; per the question,
        // that is before the streamed responses have completed, so the logged value
        // presumably covers only the time taken to start the call.
        _logger.LogInformation(
            $"Total elapsed time: {timer.ElapsedMilliseconds}");
    }
}

Solution

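  • The override has to return synchronously, so it cannot await the stream itself. Instead, wrap the response stream in a reader that runs a callback once the stream has finished (or failed), and return a new call object built around that wrapper: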

    /// <summary>
    /// Intercepts a server-streaming call and logs the total elapsed time once the
    /// response stream has actually finished, rather than when this method returns.
    /// </summary>
    /// <remarks>
    /// The override must return synchronously, so the timing is deferred: the response
    /// stream is wrapped in an <see cref="AsyncStreamReaderErrorHandler{TResponse}"/> that
    /// reports completion exactly once, whether the stream ends, fails, or the call is
    /// disposed before being fully read.
    /// </remarks>
    /// <typeparam name="TRequest">Request message type.</typeparam>
    /// <typeparam name="TResponse">Response message type.</typeparam>
    /// <param name="request">Request message, handed to the continuation as-is.</param>
    /// <param name="context">Call context; caller metadata is added before the continuation runs.</param>
    /// <param name="continuation">Delegate that starts the underlying call.</param>
    /// <returns>A call object exposing the wrapped response stream and the original headers, status and trailers.</returns>
    public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(
        TRequest request,
        ClientInterceptorContext<TRequest, TResponse> context,
        AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
    {
        LogCall(context.Method);
        AddCallerMetadata(ref context);

        // added to the example:
        var timer = Stopwatch.StartNew();

        void LogElapsed() =>
            _logger.LogInformation($"Total elapsed time: {timer.ElapsedMilliseconds}");

        AsyncServerStreamingCall<TResponse> call;
        try
        {
            call = continuation(request, context);
        }
        catch (Exception ex)
        {
            // The call never started, so no stream reader exists to report the failure or the timing.
            LogError(ex);
            LogElapsed();
            throw;
        }

        // The reader cannot reach this interceptor's LogError on its own, so it is passed in.
        var responseStream = new AsyncStreamReaderErrorHandler<TResponse>(
            call.ResponseStream,
            LogElapsed,
            ex => LogError(ex));

        return new AsyncServerStreamingCall<TResponse>(
            responseStream,
            call.ResponseHeadersAsync,
            call.GetStatus,
            call.GetTrailers,
            () =>
            {
                try
                {
                    call.Dispose();
                }
                finally
                {
                    // Covers consumers that dispose the call without reading the stream to its end.
                    responseStream.Complete();
                }
            });
    }

    /// <summary>
    /// Decorates an <see cref="IAsyncStreamReader{T}"/> so that a completion callback runs
    /// exactly once when the stream ends, when reading fails, or when
    /// <see cref="Complete"/> is called explicitly.
    /// </summary>
    /// <typeparam name="TResponse">Type of the streamed messages.</typeparam>
    public class AsyncStreamReaderErrorHandler<TResponse> : IAsyncStreamReader<TResponse>
    {
        private readonly IAsyncStreamReader<TResponse> _streamReader;
        private readonly Action _completionCallback;
        private readonly Action<Exception> _errorCallback;

        // 0 until the completion callback has been claimed, then 1.
        private int _completed;

        /// <summary>
        /// Creates a wrapper that reports completion but does not report errors.
        /// </summary>
        /// <param name="streamReader">Reader to delegate to.</param>
        /// <param name="completionCallback">Invoked once when the stream completes or fails.</param>
        /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
        public AsyncStreamReaderErrorHandler(IAsyncStreamReader<TResponse> streamReader, Action completionCallback)
            : this(streamReader, completionCallback, _ => { })
        {
        }

        /// <summary>
        /// Creates a wrapper that reports completion and <see cref="RpcException"/> failures.
        /// </summary>
        /// <param name="streamReader">Reader to delegate to.</param>
        /// <param name="completionCallback">Invoked once when the stream completes or fails.</param>
        /// <param name="errorCallback">Invoked with each <see cref="RpcException"/> thrown while reading, before it is rethrown.</param>
        /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
        public AsyncStreamReaderErrorHandler(
            IAsyncStreamReader<TResponse> streamReader,
            Action completionCallback,
            Action<Exception> errorCallback)
        {
            _streamReader = streamReader ?? throw new ArgumentNullException(nameof(streamReader));
            _completionCallback = completionCallback ?? throw new ArgumentNullException(nameof(completionCallback));
            _errorCallback = errorCallback ?? throw new ArgumentNullException(nameof(errorCallback));
        }

        /// <summary>The current element of the wrapped reader.</summary>
        public TResponse Current => _streamReader.Current;

        /// <summary>
        /// Invokes the completion callback if it has not been invoked yet; later calls do nothing.
        /// </summary>
        public void Complete()
        {
            // The exchange ensures only one caller runs the callback, even if end-of-stream and dispose race.
            if (Interlocked.Exchange(ref _completed, 1) == 0)
            {
                _completionCallback();
            }
        }

        /// <summary>
        /// Advances the wrapped reader, reporting completion when the stream ends or reading throws.
        /// </summary>
        /// <param name="cancellationToken">Token forwarded to the wrapped reader.</param>
        /// <returns><c>true</c> if another element is available; <c>false</c> at the end of the stream.</returns>
        public async Task<bool> MoveNext(CancellationToken cancellationToken)
        {
            try
            {
                var hasNext = await _streamReader.MoveNext(cancellationToken).ConfigureAwait(false);

                if (!hasNext)
                {
                    Complete();
                }

                return hasNext;
            }
            catch (RpcException ex)
            {
                _errorCallback(ex);
                Complete();
                throw;
            }
            catch
            {
                // Non-RPC failures still end the stream, so completion must be reported for them too.
                Complete();
                throw;
            }
        }
    }