asp.net-core, authentication, .net-core, jwt, .net-8.0

Double Response Issue with JWT Authentication in ASP.NET Core


I'm facing an issue with my JWT authentication in ASP.NET Core where I'm getting duplicate responses.

Here's the relevant code for the JWT authentication:

public static class JwtAuthExtension
{
    public static void AddGatewayCustomJwtAuthentication(this IServiceCollection services)
    {
        AuthDbConfigurations _authConfiguration = AuthDbConfigurations.Instance;
        _authConfiguration.Initialize();

        services.AddAuthentication(options =>
        {
            options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme;
            options.DefaultChallengeScheme = JwtBearerDefaults.AuthenticationScheme;
        })
        .AddJwtBearer(JwtBearerDefaults.AuthenticationScheme, options =>
        {
            options.RequireHttpsMetadata = true;
            options.SaveToken = true;
            options.TokenValidationParameters = new TokenValidationParameters
            {
                ValidateIssuerSigningKey = true,
                ValidateIssuer = true,
                ValidateAudience = true,
                ValidateLifetime = true,
                ClockSkew = TimeSpan.Zero,
                IssuerSigningKey = new SymmetricSecurityKey(Encoding.UTF8.GetBytes("F-JaNdRfUserjd89#5*6Xn2r5usErw8x/A?D(G+KbPeShV")),
                ValidIssuer = "SecureIssuer",
                ValidAudience = "SecureIssuerClient"
            };
            options.Events = new JwtBearerEvents
            {
                OnMessageReceived = ctx =>
                {
                    var actionName = ctx.Request.RouteValues["action"]?.ToString() ?? String.Empty;

                    if (ctx.Request.Cookies.ContainsKey(CommonConstants.TokenName) &&
                    !actionName.Equals("Authenticate", StringComparison.OrdinalIgnoreCase) &&
                    !actionName.Equals("RenewToken", StringComparison.OrdinalIgnoreCase))
                    {
                        ctx.Token = ctx.Request.Cookies[CommonConstants.TokenName];
                    }
                    return Task.CompletedTask;
                },
                OnAuthenticationFailed = ctx =>
                {
                    if (!ctx.Response.HasStarted)
                    {
                        if (ctx.Exception is SecurityTokenInvalidSignatureException || ctx.Exception is SecurityTokenMalformedException)
                        {
                            ctx.Response.Headers.Add("X-TOKEN-TAMPERED", "true");
                            ctx.Response.StatusCode = 498;
                            ctx.Response.ContentType = "application/json";
                            var tamperedPayload = new TokenTamperedResponse();
                            var tamperedMessage = JsonConvert.SerializeObject(tamperedPayload);
                            return ctx.Response.WriteAsync(tamperedMessage);
                        }
                        else if (ctx.Exception is SecurityTokenExpiredException)
                        {
                            ctx.Response.Headers.Add("IS-TOKEN-EXPIRED", "true");
                            ctx.Response.StatusCode = StatusCodes.Status401Unauthorized;
                            ctx.Response.ContentType = "application/json";
                            var payload = new UnAuthorizedResponse("Token Expired");
                            var errorMessage = JsonConvert.SerializeObject(payload);
                            return ctx.Response.WriteAsync(errorMessage);
                        }
                    }

                    return Task.CompletedTask;
                },
                OnChallenge = ctx =>
                {
                    if (!ctx.Response.HasStarted)
                    {
                        ctx.HandleResponse();

                        if (ctx.Response.Headers.ContainsKey("X-TOKEN-TAMPERED"))
                        {
                            ctx.Response.StatusCode = 498;
                            ctx.Response.ContentType = "application/json";
                            var tamperedPayload = new TokenTamperedResponse();
                            var tamperedMessage = JsonConvert.SerializeObject(tamperedPayload);
                            return ctx.Response.WriteAsync(tamperedMessage);
                        }
                        else
                        {
                            ctx.Response.StatusCode = StatusCodes.Status401Unauthorized;
                            ctx.Response.ContentType = "application/json";
                            var payload = new UnAuthorizedResponse();
                            var errorMessage = JsonConvert.SerializeObject(payload);
                            return ctx.Response.WriteAsync(errorMessage);
                        }
                    }

                    return Task.CompletedTask;
                },
                OnTokenValidated = async ctx =>
                {
                    await Task.CompletedTask;
                }
            };
        });
    }
}

I have tried various things, such as wrapping the code in an if condition that checks whether the response has already started. While debugging, I found that the request goes through my RequestResponseLoggingMiddleware, which is registered in my Gateway class library project. Below is the code of RequestResponseLoggingMiddleware:

public class RequestResponseLoggingMiddleware
{
    private readonly ILogger<RequestResponseLoggingMiddleware> _logger;
    private readonly RequestDelegate _next;

    public RequestResponseLoggingMiddleware(RequestDelegate next, ILogger<RequestResponseLoggingMiddleware> logger)
    {
        _next = next;
        _logger = logger;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        context.Request.EnableBuffering();

        var builder = new StringBuilder();

        var request = await FormatRequest(context.Request);

        builder.Append(Environment.NewLine);
        builder.Append(Environment.NewLine);
        builder.Append("Request: ").AppendLine(request);
        builder.AppendLine("Request headers:");
        foreach (var header in context.Request.Headers)
        {
            builder.Append(header.Key).Append(':').AppendLine(header.Value);
        }

        //Copy a pointer to the original response body stream
        var originalBodyStream = context.Response.Body;

        //Create a new memory stream...
        using var responseBody = new MemoryStream();
        //...and use that for the temporary response body
        context.Response.Body = responseBody;

        //Continue down the Middleware pipeline, eventually returning to this class
        await _next(context);

        //Format the response from the server
        var response = await FormatResponse(context.Response);
        builder.Append(Environment.NewLine);
        builder.Append("Response: ").AppendLine(response);
        builder.AppendLine("Response headers: ");
        foreach (var header in context.Response.Headers)
        {
            builder.Append(header.Key).Append(':').AppendLine(header.Value);
        }

        //Save log to chosen datastore
        _logger.LogInformation(builder.ToString());

        //Copy the contents of the new memory stream (which contains the response) to the original stream, which is then returned to the client.
        await responseBody.CopyToAsync(originalBodyStream);
    }

    private static async Task<string> FormatRequest(HttpRequest request)
    {
        // Leave the body open so the next middleware can read it.
        using var reader = new StreamReader(
            request.Body,
            encoding: Encoding.UTF8,
            detectEncodingFromByteOrderMarks: false,
            leaveOpen: true);
        var body = await reader.ReadToEndAsync();
        // Do some processing with body…

        var formattedRequest = $"{request.Scheme} {request.Host}{request.Path} {request.QueryString} {body}";

        // Reset the request body stream position so the next middleware can read it
        request.Body.Position = 0;

        return formattedRequest;
    }

    private async Task<string> FormatResponse(HttpResponse response)
    {
        //We need to read the response stream from the beginning...
        response.Body.Seek(0, SeekOrigin.Begin);

        //...and copy it into a string
        string text = await new StreamReader(response.Body).ReadToEndAsync();

        //We need to reset the reader for the response so that the client can read it.
        response.Body.Seek(0, SeekOrigin.Begin);

        //Return the string for the response, including the status code (e.g. 200, 404, 401, etc.)
        return $"{response.StatusCode}: {text}";
    }
}

After tampering with the token value, I expect every subsequent HTTP request to return my TokenTampered response. Instead, I get two responses appended together. API response after token tampering:

{ "StatusCode": 498, "Succeeded": false, "Message": "Token Tampered" }{ "StatusCode": 498, "Succeeded": false, "Message": "Token Tampered" }

I expect only a single response.


Solution

  • I encountered the exact same problem recently, and I've left a working code example below. When reading the response body, pay attention to these steps:

    1- Store the original response body stream in a variable:

    var originalBodyStream = context.Response.Body;
    

    2- After reading the body, reset the position of the stream you read it from to 0, then copy it to the original body stream:

    emptyResponseBody.Seek(0, SeekOrigin.Begin);
    await emptyResponseBody.CopyToAsync(originalBodyStream);
    

    3- Assign the original body stream you saved in step 1 back to the response body:

     context.Response.Body = originalBodyStream;
    

    These three steps buffer the response so it can be read and then passed on to the client, much as context.Request.EnableBuffering() (context.Request.EnableRewind() in ASP.NET Core 2.x) does for the request body.

    You've already done steps 1 and 2. I think the problem will be solved if you add the assignment from step 3 to the end of the middleware, ideally in a finally block so it also runs when the pipeline throws.
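    To see why the rewind in step 2 matters, here is a BCL-only sketch. It uses no ASP.NET types: two MemoryStream instances stand in for the server's response stream saved in step 1 and for the buffer swapped into Response.Body (an assumption for illustration only). After the buffer is read for logging, its position sits at the end, so copying without rewinding transfers nothing.

```csharp
using System;
using System.IO;
using System.Text;
using System.Threading.Tasks;

// Stand-ins (assumption): "original" plays the server's response stream saved in step 1,
// "buffer" plays the MemoryStream swapped into Response.Body.
var original = new MemoryStream();
var buffer = new MemoryStream();

// Downstream middleware writes its response into the buffer.
var payload = Encoding.UTF8.GetBytes("{\"ok\":true}");
await buffer.WriteAsync(payload, 0, payload.Length);

// Read the buffer for logging; leaveOpen keeps the buffer usable afterwards.
buffer.Seek(0, SeekOrigin.Begin);
string logged;
using (var reader = new StreamReader(buffer, Encoding.UTF8, false, 1024, leaveOpen: true))
{
    logged = await reader.ReadToEndAsync();
}

// Reading left Position == Length, so a copy at this point transfers nothing.
var withoutRewind = new MemoryStream();
await buffer.CopyToAsync(withoutRewind);

// Step 2: rewind first, then copy the buffered bytes to the original stream.
buffer.Seek(0, SeekOrigin.Begin);
await buffer.CopyToAsync(original);

Console.WriteLine($"logged: {logged}");
Console.WriteLine($"copied without rewind: {withoutRewind.Length} bytes, with rewind: {original.Length} bytes");
```

    In the real middleware the same rule applies to the swapped-in response stream: seek to 0 before CopyToAsync, then restore the original stream as in step 3.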

    I am sharing my own working code block to help you.

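    PS: Using RecyclableMemoryStreamManager (from the Microsoft.IO.RecyclableMemoryStream NuGet package) can improve performance in operations like this, because it pools buffers instead of allocating a new one for every request.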

    Also, as @Md Farid Uddin Kiron mentioned in the comments, make sure the middleware order is correct. My RequestResponseLoggingMiddleware is registered at the top of the request pipeline.


    Here is my RequestResponseLoggingMiddleware.cs:

    
        using System.Diagnostics;
        using System.Text;
        using System.Text.RegularExpressions;
        using Microsoft.IO;

        namespace MyApp.Api.Middlewares;

        public class RequestResponseLoggingMiddleware(RequestDelegate next, ILoggerFactory loggerFactory)
        {
            private readonly RequestDelegate _next = next;
            private readonly ILogger _logger = loggerFactory.CreateLogger<RequestResponseLoggingMiddleware>();

            // Thread-safe and designed to be shared, so one instance per middleware is fine.
            private readonly RecyclableMemoryStreamManager _recyclableMemoryStreamManager = new();

            public async Task Invoke(HttpContext context)
            {
                // "using" stops and disposes the activity even if the pipeline throws.
                using var activity = new Activity("LoggingActivity");
                activity.Start();

                // Middleware instances are singletons, so per-request state lives in locals, not fields.
                var sw = Stopwatch.StartNew();
                var requestInfo = new RequestInfo(); // your own request DTO (not shown in this answer)

                await ReadRequestAsync(context, requestInfo);
                await LogResponseAsync(context, requestInfo, sw);
            }

            private async Task ReadRequestAsync(HttpContext context, RequestInfo requestInfo)
            {
                context.Request.EnableBuffering();

                await using var requestStream = _recyclableMemoryStreamManager.GetStream();

                await context.Request.Body.CopyToAsync(requestStream);

                requestInfo.Method = context.Request.Method;

                requestInfo.Body = await ReadStreamInChunksAsync(requestStream);

                requestInfo.Headers = context.Request.Headers;

                requestInfo.ContentLength = context.Request.ContentLength ?? 0;

                requestInfo.AbsoluteUri = string.Concat(context.Request.Scheme,
                                                        "://",
                                                        context.Request.Host.ToUriComponent(),
                                                        context.Request.PathBase.ToUriComponent(),
                                                        context.Request.Path.ToUriComponent(),
                                                        context.Request.QueryString.ToUriComponent());

                requestInfo.QueryString = context.Request.QueryString.ToUriComponent();

                // Rewind so the next middleware can read the request body.
                context.Request.Body.Position = 0;
            }

            private async Task LogResponseAsync(HttpContext context, RequestInfo requestInfo, Stopwatch sw)
            {
                // Step 1: keep a reference to the original response body stream.
                var originalBodyStream = context.Response.Body;

                try
                {
                    await using var emptyResponseBody = _recyclableMemoryStreamManager.GetStream();

                    context.Response.Body = emptyResponseBody;

                    await _next(context);

                    emptyResponseBody.Seek(0, SeekOrigin.Begin);

                    // leaveOpen: the buffer is still needed for the copy below.
                    using var streamReader = new StreamReader(emptyResponseBody, Encoding.UTF8, detectEncodingFromByteOrderMarks: true, bufferSize: 4096, leaveOpen: true);

                    var responseBody = await streamReader.ReadToEndAsync();

                    if (!IsUIRequest(context))
                        responseBody = Regex.Unescape(responseBody);

                    // Step 2: rewind before copying the buffered response to the original stream.
                    emptyResponseBody.Seek(0, SeekOrigin.Begin);

                    sw.Stop();

                    _logger.LogInformation("your-log");

                    await emptyResponseBody.CopyToAsync(originalBodyStream);
                }
                finally
                {
                    // Step 3: always restore the original stream, even if the pipeline throws.
                    context.Response.Body = originalBodyStream;
                }
            }

            private static async Task<string> ReadStreamInChunksAsync(Stream stream)
            {
                const int readChunkBufferLength = 4096;

                stream.Seek(0, SeekOrigin.Begin);

                using var textWriter = new StringWriter();

                // leaveOpen: the caller owns (and disposes) the stream.
                using var reader = new StreamReader(stream, Encoding.UTF8, detectEncodingFromByteOrderMarks: true, bufferSize: readChunkBufferLength, leaveOpen: true);

                var readChunk = new char[readChunkBufferLength];

                int readChunkLength;

                do
                {
                    readChunkLength = await reader.ReadBlockAsync(readChunk, 0, readChunkBufferLength);

                    await textWriter.WriteAsync(readChunk, 0, readChunkLength);

                } while (readChunkLength > 0);

                return Regex.Unescape(textWriter.ToString());
            }

            private static bool IsUIRequest(HttpContext httpContext)
            {
                if (httpContext.Request.Path.StartsWithSegments("/api/documentation")
                    || httpContext.Request.Path.StartsWithSegments("/api/hc")
                    || httpContext.Request.Path.StartsWithSegments("/api/health-check")
                    || httpContext.Request.Path.StartsWithSegments("/api/hc-ui"))
                    return true;

                return false;
            }
        }
    

    Here's the relevant code for the JWT authentication:

    
        services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme)
            .AddJwtBearer(options =>
        {
            options.Events = new JwtBearerEvents
            {    
                 // This event fires when an unauthenticated request reaches an endpoint that requires authorization:
                 // either no token was sent, or OnAuthenticationFailed has already run for this request.
                 OnChallenge = async context =>
                 {
                      // Rewrite the response here only when no token was provided.
                      // OnAuthenticationFailed (and OnForbidden) already set a 401/403 status code, so skip those cases to avoid writing the response a second time.
                      if (!(context.Response.StatusCode == 403 || context.Response.StatusCode == 401))
                      {
                            // Since this scenario will work when a token is not sent to an endpoint that requires authorization, I set the response to 401.
                            context.Response.StatusCode = 401;
                            await WriteResponseAsync(context, HttpStatusCode.Unauthorized);
                      }
                 },
                 OnForbidden = async context =>
                 {
                      // Invalid permissions
                      context.Response.StatusCode = 403;
                      await WriteResponseAsync(context, HttpStatusCode.Forbidden);
                 },
                 OnAuthenticationFailed = async context =>
                 {
                      // Invalid token
                      context.Response.StatusCode = 401;
                      await WriteResponseAsync(context, HttpStatusCode.Unauthorized);
                 },
            };
        
            static Task WriteResponseAsync(BaseContext<JwtBearerOptions> context, HttpStatusCode statusCode)
            {
                if (!(context.Response.StatusCode is >= 200 and <= 299))
                    context.Response.OnStarting(async () =>
                    {              
                        if (!context.Response.HasStarted)
                        {                        
                            context.Response.ContentType = System.Net.Mime.MediaTypeNames.Application.Json;
                            context.Response.StatusCode = (int)statusCode;
        
                            // Placeholder body: substitute your own response model here.
                            await context.Response.WriteAsJsonAsync(new { StatusCode = (int)statusCode, Succeeded = false });
                        }
                    });
        
                return Task.CompletedTask;
            }
        });
    
    

    EDIT:

    I reproduced the problem with your code. It is caused by how the JWT events are used.

    • OnChallenge: This event is triggered when an authentication challenge occurs, such as when a 401 Unauthorized response is about to be sent.
    • OnForbidden: This event is triggered when a 403 Forbidden response is about to be sent.
    • OnAuthenticationFailed: This event is triggered when token validation fails with an exception (for example, an invalid signature or an expired token).

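    So, in your code, the response body is written first in OnAuthenticationFailed and then again in OnChallenge. The second WriteAsync does not overwrite the first: it continues from the current position of the body stream, so the JSON is appended to the end and your response is duplicated. Your HasStarted checks don't catch this, because the logging middleware has replaced Response.Body with a MemoryStream; nothing reaches the server until the middleware copies the buffer out, so HasStarted is still false in both events.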
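    The append behaviour can be reproduced without ASP.NET. In this sketch a MemoryStream stands in for the buffered Response.Body (an assumption for illustration), and the same JSON is written twice, once per event, exactly as the two handlers in the question do. The result is the concatenated body from the question.

```csharp
using System;
using System.IO;
using System.Text;
using System.Threading.Tasks;

// Stand-in (assumption) for the MemoryStream the logging middleware installs as Response.Body.
var body = new MemoryStream();

// Payload shaped like the question's TokenTamperedResponse (illustrative only).
const string tampered = "{ \"StatusCode\": 498, \"Succeeded\": false, \"Message\": \"Token Tampered\" }";

// Each event handler in the question writes its JSON with a plain write call.
async Task WriteJsonAsync(Stream stream, string json)
{
    var bytes = Encoding.UTF8.GetBytes(json);
    await stream.WriteAsync(bytes, 0, bytes.Length);
}

await WriteJsonAsync(body, tampered); // OnAuthenticationFailed
await WriteJsonAsync(body, tampered); // OnChallenge: Position is at the end, so this appends

// What the client receives once the middleware copies the buffer to the real stream.
var result = Encoding.UTF8.GetString(body.ToArray());
Console.WriteLine(result);
```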

    My answer includes a working example that does not duplicate the response. Adapt the event logic to your own needs; I only left one example because there are many ways to do this (you could even put a flag in HttpContext.Items that records whether an event has already written the response).
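    The HttpContext.Items flag idea can be sketched as a small helper. This is a hypothetical helper, not part of ASP.NET Core: TryMarkResponseWritten and the key "JwtErrorResponseWritten" are illustrative names. It takes an IDictionary<object, object?>, which is the type of HttpContext.Items, so the logic can be exercised here with a plain Dictionary simulating one request.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

// Hypothetical Items key; any unique value works.
const string ResponseWrittenKey = "JwtErrorResponseWritten";

// Hypothetical helper: returns true for the first caller on a request, false for every later caller.
// In ASP.NET Core you would pass ctx.HttpContext.Items (an IDictionary<object, object?>).
bool TryMarkResponseWritten(IDictionary<object, object?> items)
{
    if (items.ContainsKey(ResponseWrittenKey))
        return false;

    items[ResponseWrittenKey] = true;
    return true;
}

// Simulate one request: a fresh Items dictionary, then the two events in order.
var requestItems = new Dictionary<object, object?>();
var bodyWrites = 0;

if (TryMarkResponseWritten(requestItems)) bodyWrites++; // OnAuthenticationFailed: first, so it writes
if (TryMarkResponseWritten(requestItems)) bodyWrites++; // OnChallenge: flag already set, so it skips

Console.WriteLine($"Body writes for this request: {bodyWrites}");
```

    In the question's events, each WriteAsync would be wrapped in if (TryMarkResponseWritten(ctx.HttpContext.Items)) { ... }. Keep calling ctx.HandleResponse() in OnChallenge either way, so the JWT handler does not apply its default 401 challenge on top of the response you already wrote.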

    PS: I am not deleting the parts about rewriting Response.Body from my answer, because they might be useful to others.