Search code examples
c# .net-core json-rpc taskcompletionsource

JSON-RPC - Handling exceptions with TaskCompletionSource (SetException)


There is a JSON-RPC API, which I'm currently implementing. It can be tested here.

The problem is that if an incorrect DTO model is passed to SendAsync<TResponse>, JsonSerializer.Deserialize is going to throw a JsonException, which is not handled by my code. I know I've got to use SetException in some way, but I don't know how to do it, so here is the question. The exception message should be printed in the console as well.

/// <summary>
/// JSON-RPC client as posted in the question. Incoming messages that carry an "id" are matched
/// against pending requests; messages without an "id" are published through
/// <c>_messageReceivedSubject</c>.
/// </summary>
public sealed class Client : IDisposable
{
    ...

    // Handlers for requests that are still waiting for a response, keyed by request id.
    // Entries are added in SendAsync and removed in the StartAsync receive loop.
    private readonly ConcurrentDictionary<long, IResponseHandler> _handlers = new();

    ...

    /// <summary>
    /// Starts a background loop over <c>_client.Start</c> that dispatches every incoming message.
    /// </summary>
    /// <param name="cancellationToken">Passed to both <c>_client.Start</c> and <c>Task.Run</c>.</param>
    /// <returns><see cref="Task.CompletedTask"/>; the receive loop itself is not awaited.</returns>
    public Task StartAsync(CancellationToken cancellationToken)
    {
        // The task returned by Task.Run is discarded, so an exception that escapes the loop
        // is never observed by the caller of StartAsync.
        _ = Task.Run(async () =>
        {
            await foreach (var message in _client.Start(cancellationToken))
            {
                // Parsed outside the try block: a parse failure is not covered by the catch below.
                using var response = JsonDocument.Parse(message);

                try
                {
                    // Read as Int32; the value is widened to the dictionary's long key on lookup.
                    var requestId = response.RootElement.GetProperty("id").GetInt32();

                    // TODO: Handle JsonException errors via `SetException`?
                    // TODO: Show error when incorrect input parameters are filled
                    // Lookup, read and removal are three separate dictionary operations.
                    // A message whose id has no registered handler is dropped without notice.
                    if (_handlers.ContainsKey(requestId))
                    {
                        // SetResult deserializes the payload (see ResponseHandlerBase). A JsonException
                        // thrown there is not caught here, because only KeyNotFoundException is handled.
                        _handlers[requestId].SetResult(message);
                        _handlers.TryRemove(requestId, out _);
                    }
                }
                catch (KeyNotFoundException)
                {
                    // My point is that a message should be processed only if it doesn't include `id`,
                    // because that means that the message is an actual web socket subscription.
                    _messageReceivedSubject.OnNext(message);
                }
            }
        }, cancellationToken);

        ...

        return Task.CompletedTask;
    }

    /// <summary>
    /// Builds a JSON-RPC 2.0 request, registers a handler for its id, sends the serialized request
    /// and returns the handler's task.
    /// </summary>
    /// <typeparam name="TResponse">Type the response payload is deserialized into.</typeparam>
    /// <param name="method">Value assigned to the request's <c>Method</c>.</param>
    /// <param name="params">Value assigned to the request's <c>Params</c>.</param>
    /// <returns>The task exposed by the handler registered for this request.</returns>
    public Task<TResponse> SendAsync<TResponse>(string method, object @params)
    {
        var request = new JsonRpcRequest<object>
        {
            JsonRpc = "2.0",
            Id = NextId(),
            Method = method,
            Params = @params
        };

        //var tcs = new TaskCompletionSource<TResponse>();
        //_requestManager.Add(request.Id, request, tcs);

        // The handler is registered before sending, so the receive loop can find it by id.
        var handler = new ResponseHandlerBase<TResponse>();
        _handlers[request.Id] = handler;

        var message = JsonSerializer.Serialize(request);

        // The send is not awaited and its task is discarded; a send failure is not reported to the caller.
        _ = _client.SendAsync(message);

        return handler.Task;

        //return tcs.Task;
    }

    /// <summary>
    /// Sends a "public/auth" request with client credentials and returns the <c>Result</c> of the response.
    /// </summary>
    /// <param name="clientId">Sent as "client_id".</param>
    /// <param name="clientSecret">Sent as "client_secret".</param>
    public async Task<AuthResponse?> AuthenticateAsync(string clientId, string clientSecret)
    {
        var @params = new Dictionary<string, string>
        {
            {"grant_type", "client_credentials"},
            {"client_id", clientId},
            {"client_secret", clientSecret}
        };

        var response = await SendAsync<SocketResponse<AuthResponse>>("public/auth", @params).ConfigureAwait(false);
        return response.Result;
    }

    ...

    // Type-erased view of a pending request, so handlers for different response types can share one dictionary.
    private interface IResponseHandler
    {
        // Receives the raw message text of the response.
        void SetResult(string payload);
    }

    // Completes a TaskCompletionSource<TRes> with the deserialized response payload.
    private class ResponseHandlerBase<TRes> : IResponseHandler
    {
        private readonly TaskCompletionSource<TRes> _tcs = new();

        // Task handed back to the caller of SendAsync.
        public Task<TRes> Task => _tcs.Task;

        public void SetResult(string payload)
        {
            // If Deserialize throws, the exception propagates to the caller of SetResult and
            // _tcs is never completed. This is the unhandled JsonException the question asks about.
            var result = JsonSerializer.Deserialize(payload, typeof(TRes));
            _tcs.SetResult((TRes) result);
        }
    }
}

Solution

  • Coincidentally, I did something very similar while live-coding a TCP/IP chat application last week.

    Since in this case you already have an IAsyncEnumerable<string> - and since you can get messages other than responses - I recommend also exposing that IAsyncEnumerable<string>:

    public sealed class Client : IDisposable
    {
      /// <summary>
      /// Relays every message produced by the underlying client to the caller, unchanged.
      /// </summary>
      /// <param name="cancellationToken">Forwarded to <c>_client.Start</c>.</param>
      public async IAsyncEnumerable<string> Start(CancellationToken cancellationToken)
      {
        var incomingMessages = _client.Start(cancellationToken);

        await foreach (var incoming in incomingMessages)
        {
          // TODO: parse and handle responses for our requests
          yield return incoming;
        }
      }
    }
    

    You can change this to be Rx-based if you want (_messageReceivedSubject.OnNext), but I figure if you already have IAsyncEnumerable<T>, then you may as well keep the same abstraction.

    Then, you can parse and detect responses, passing along all other messages:

    public sealed class Client : IDisposable
    {
      /// <summary>
      /// Streams messages from the underlying client. Messages that are JSON objects with a numeric
      /// "id" are treated as responses; every other message is yielded to the caller unchanged.
      /// </summary>
      /// <param name="cancellationToken">Forwarded to <c>_client.Start</c>.</param>
      public async IAsyncEnumerable<string> Start(CancellationToken cancellationToken)
      {
        await foreach (var message in _client.Start(cancellationToken))
        {
          var (requestId, response) = TryParseResponse(message);
          if (requestId != null)
          {
            // JsonDocument is disposable, so release it once the response has been handled.
            using (response)
            {
              // TODO: handle
            }
          }
          else
          {
            yield return message;
          }
        }
    
        // Returns the request id and parsed document when the message is a JSON object with an
        // integer "id"; otherwise returns (null, null) and leaves no document undisposed.
        (long? RequestId, JsonDocument? Response) TryParseResponse(string message)
        {
          JsonDocument? document = null;
          try
          {
            document = JsonDocument.Parse(message);
            var root = document.RootElement;

            // Check the value kinds first so that TryGetProperty/TryGetInt64 cannot throw.
            if (root.ValueKind == JsonValueKind.Object
                && root.TryGetProperty("id", out var idElement)
                && idElement.ValueKind == JsonValueKind.Number
                && idElement.TryGetInt64(out var parsedId))
            {
              return (parsedId, document);
            }
          }
          catch (JsonException)
          {
            // Not valid JSON, so it cannot be a response; fall through and report "no response".
          }

          document?.Dispose();
          return (null, null);
        }
      }
    }
    

    Then, you can define your collection of outstanding requests and handle messages that are for those requests:

    public sealed class Client : IDisposable
    {
      // Requests that are still waiting for a response, keyed by request id.
      // The key is long so that it matches the id type produced by TryParseResponse.
      private readonly ConcurrentDictionary<long, TaskCompletionSource<JsonDocument>> _requests = new();
    
      /// <summary>
      /// Streams messages from the underlying client. A message that answers a pending request
      /// completes that request's task; every other message is yielded to the caller unchanged.
      /// </summary>
      /// <param name="cancellationToken">Forwarded to <c>_client.Start</c>.</param>
      public async IAsyncEnumerable<string> Start(CancellationToken cancellationToken)
      {
        await foreach (var message in _client.Start(cancellationToken))
        {
          var (requestId, response) = TryParseResponse(message);
          if (requestId != null && response is not null && _requests.TryRemove(requestId.Value, out var tcs))
          {
            // Ownership of the document passes to whoever awaits the task. If the task was already
            // completed nobody will read this document, so release it here.
            if (!tcs.TrySetResult(response))
            {
              response.Dispose();
            }
          }
          else
          {
            // Not a response to one of our requests: release any parsed document and pass the raw message on.
            response?.Dispose();
            yield return message;
          }
        }
    
        // Returns the request id and parsed document when the message is a JSON object with an
        // integer "id"; otherwise returns (null, null) and leaves no document undisposed.
        (long? RequestId, JsonDocument? Response) TryParseResponse(string message)
        {
          JsonDocument? document = null;
          try
          {
            document = JsonDocument.Parse(message);
            var root = document.RootElement;

            // Check the value kinds first so that TryGetProperty/TryGetInt64 cannot throw.
            if (root.ValueKind == JsonValueKind.Object
                && root.TryGetProperty("id", out var idElement)
                && idElement.ValueKind == JsonValueKind.Number
                && idElement.TryGetInt64(out var parsedId))
            {
              return (parsedId, document);
            }
          }
          catch (JsonException)
          {
            // Not valid JSON, so it cannot be a response; fall through and report "no response".
          }

          document?.Dispose();
          return (null, null);
        }
      }
    }
    

    Note the usage of ConcurrentDictionary.TryRemove, which is safer than accessing the value and then removing it.

    Now you can write your general SendAsync. As I note in my video, I prefer to split up the code that runs synchronously in SendAsync and the code that awaits the response:

    public sealed class Client : IDisposable
    {
      ...
    
      /// <summary>
      /// Sends a JSON-RPC 2.0 request and asynchronously waits for the response with the same id.
      /// </summary>
      /// <typeparam name="TResponse">Type the response document is deserialized into.</typeparam>
      /// <param name="method">Value assigned to the request's <c>Method</c>.</param>
      /// <param name="params">Value assigned to the request's <c>Params</c>.</param>
      /// <returns>A task that completes with the deserialized response.</returns>
      /// <exception cref="JsonException">
      /// Surfaced through the returned task when the response cannot be deserialized into
      /// <typeparamref name="TResponse"/>.
      /// </exception>
      public Task<TResponse> SendAsync<TResponse>(string method, object @params)
      {
        var request = new JsonRpcRequest<object>
        {
          JsonRpc = "2.0",
          Id = NextId(),
          Method = method,
          Params = @params,
        };
    
        // Register before sending so that a fast response cannot arrive before the entry exists.
        var tcs = new TaskCompletionSource<JsonDocument>(TaskCreationOptions.RunContinuationsAsynchronously);
        _requests.TryAdd(request.Id, tcs);
        return SendRequestAndWaitForResponseAsync();
    
        async Task<TResponse> SendRequestAndWaitForResponseAsync()
        {
          try
          {
            var message = JsonSerializer.Serialize(request);
            await _client.SendAsync(message);
          }
          catch
          {
            // The request never went out, so no response will arrive: drop the pending entry.
            _requests.TryRemove(request.Id, out _);
            throw;
          }

          // The awaiting side owns the document and disposes it after deserializing.
          using var response = await tcs.Task;

          // Any exception thrown here (e.g. JsonException) faults the returned task via the async state machine.
          return response.Deserialize<TResponse>()!;
        }
      }
    }
    

    I've removed the "handler" concept completely, since it was just providing the type for JsonSerializer.Deserialize. Also, by using a local async method, I can use the async state machine to propagate exceptions naturally.

    Then, your higher-level methods can be built on this:

    public sealed class Client : IDisposable
    {
      ...
    
      /// <summary>
      /// Sends a "public/auth" request with client credentials and returns the <c>Result</c> of the response.
      /// </summary>
      /// <param name="clientId">Sent as "client_id".</param>
      /// <param name="clientSecret">Sent as "client_secret".</param>
      public async Task<AuthResponse?> AuthenticateAsync(string clientId, string clientSecret)
      {
        var credentials = new Dictionary<string, string>
        {
          ["grant_type"] = "client_credentials",
          ["client_id"] = clientId,
          ["client_secret"] = clientSecret
        };
    
        var envelope = await SendAsync<SocketResponse<AuthResponse>>("public/auth", credentials);
        return envelope.Result;
      }
    }
    

    So the final code ends up being:

    public sealed class Client : IDisposable
    {
      // Requests that are still waiting for a response, keyed by request id.
      // The key is long so that it matches the id type produced by TryParseResponse.
      private readonly ConcurrentDictionary<long, TaskCompletionSource<JsonDocument>> _requests = new();
    
      /// <summary>
      /// Streams messages from the underlying client. A message that answers a pending request
      /// completes that request's task; every other message is yielded to the caller unchanged.
      /// </summary>
      /// <param name="cancellationToken">Forwarded to <c>_client.Start</c>.</param>
      public async IAsyncEnumerable<string> Start(CancellationToken cancellationToken)
      {
        await foreach (var message in _client.Start(cancellationToken))
        {
          var (requestId, response) = TryParseResponse(message);
          if (requestId != null && response is not null && _requests.TryRemove(requestId.Value, out var tcs))
          {
            // Ownership of the document passes to SendAsync, which disposes it. If the task was
            // already completed nobody will read this document, so release it here.
            if (!tcs.TrySetResult(response))
            {
              response.Dispose();
            }
          }
          else
          {
            // Not a response to one of our requests: release any parsed document and pass the raw message on.
            response?.Dispose();
            yield return message;
          }
        }
    
        // Returns the request id and parsed document when the message is a JSON object with an
        // integer "id"; otherwise returns (null, null) and leaves no document undisposed.
        (long? RequestId, JsonDocument? Response) TryParseResponse(string message)
        {
          JsonDocument? document = null;
          try
          {
            document = JsonDocument.Parse(message);
            var root = document.RootElement;

            // Check the value kinds first so that TryGetProperty/TryGetInt64 cannot throw.
            if (root.ValueKind == JsonValueKind.Object
                && root.TryGetProperty("id", out var idElement)
                && idElement.ValueKind == JsonValueKind.Number
                && idElement.TryGetInt64(out var parsedId))
            {
              return (parsedId, document);
            }
          }
          catch (JsonException)
          {
            // Not valid JSON, so it cannot be a response; fall through and report "no response".
          }

          document?.Dispose();
          return (null, null);
        }
      }
    
      /// <summary>
      /// Sends a JSON-RPC 2.0 request and asynchronously waits for the response with the same id.
      /// </summary>
      /// <typeparam name="TResponse">Type the response document is deserialized into.</typeparam>
      /// <param name="method">Value assigned to the request's <c>Method</c>.</param>
      /// <param name="params">Value assigned to the request's <c>Params</c>.</param>
      /// <returns>A task that completes with the deserialized response.</returns>
      /// <exception cref="JsonException">
      /// Surfaced through the returned task when the response cannot be deserialized into
      /// <typeparamref name="TResponse"/>.
      /// </exception>
      public Task<TResponse> SendAsync<TResponse>(string method, object @params)
      {
        var request = new JsonRpcRequest<object>
        {
          JsonRpc = "2.0",
          Id = NextId(),
          Method = method,
          Params = @params,
        };
    
        // Register before sending so that a fast response cannot arrive before the entry exists.
        var tcs = new TaskCompletionSource<JsonDocument>(TaskCreationOptions.RunContinuationsAsynchronously);
        _requests.TryAdd(request.Id, tcs);
        return SendRequestAndWaitForResponseAsync();
    
        async Task<TResponse> SendRequestAndWaitForResponseAsync()
        {
          try
          {
            var message = JsonSerializer.Serialize(request);
            await _client.SendAsync(message);
          }
          catch
          {
            // The request never went out, so no response will arrive: drop the pending entry.
            _requests.TryRemove(request.Id, out _);
            throw;
          }

          // The awaiting side owns the document and disposes it after deserializing.
          using var response = await tcs.Task;

          // Any exception thrown here (e.g. JsonException) faults the returned task via the async state machine.
          return response.Deserialize<TResponse>()!;
        }
      }
    
      /// <summary>
      /// Sends a "public/auth" request with client credentials and returns the <c>Result</c> of the response.
      /// </summary>
      /// <param name="clientId">Sent as "client_id".</param>
      /// <param name="clientSecret">Sent as "client_secret".</param>
      public async Task<AuthResponse?> AuthenticateAsync(string clientId, string clientSecret)
      {
        var @params = new Dictionary<string, string>
        {
          {"grant_type", "client_credentials"},
          {"client_id", clientId},
          {"client_secret", clientSecret}
        };
    
        var response = await SendAsync<SocketResponse<AuthResponse>>("public/auth", @params);
        return response.Result;
      }
    }
    

    You may also want to check out David Fowler's Project Bedrock, which may simplify this code quite a bit.