Search code examples
Tags: .net-7.0, executioncontext, principalcontext, azure-functions-isolated

How to overwrite the ClaimsIdentity in HttpRequestData.Identities when working with an isolated Azure Function


I'm working with a .NET 7 isolated Azure Function. Before running the business logic, I need to add more information to the ClaimsIdentity, as in the code below:

// Question code: when the "id", "username" and "email" headers are all present, builds a
// ClaimsIdentity with authentication type "ByPassAuth" from the first value of each header,
// then calls ContinueToProcessBusinessLogicAsync(logic, request, false).
// NOTE: the **...** around the parameter and around the Append line are markdown emphasis
// carried over from the original post; they are not valid C#.
public async Task<IActionResult> PreprocessingRequestAsync(**HttpRequestData request**, Func<Task<IActionResult>> logic)
        {
            var hasId = request.Headers.TryGetValues("id", out var id);
            var hasUsername = request.Headers.TryGetValues("username", out var username);
            var hasEmail = request.Headers.TryGetValues("email", out var email);

            if (hasId &&
                hasUsername &&
                hasEmail)
            {
                // Uses the WS-* claim type URIs; note the "email" header is stored under the
                // ".../claims/name" claim type, not an e-mail claim type.
                var user = new ClaimsIdentity(new[]
                    {
                        new Claim("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier",
                            id.FirstOrDefault()),
                        new Claim("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname",
                            username.FirstOrDefault()),
                        new Claim("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name",
                            email.FirstOrDefault())
                    }
                    , "ByPassAuth");

                // This is why the new claims never appear: Identities is a read-only
                // IEnumerable<ClaimsIdentity>, so Append resolves to LINQ's Enumerable.Append,
                // which returns a NEW sequence and leaves request.Identities unchanged.
                // The returned sequence is discarded on this line.
                **request.Identities.Append(user);**

                var response = await ContinueToProcessBusinessLogicAsync(logic, request, false);
                return response;
            }
            // NOTE(review): nothing is returned when any header is missing, so this does not
            // compile as written (CS0161: not all code paths return a value).
        }

The code above does not work as expected: the extra information is never added to the ClaimsIdentity. I know that Identities is a read-only property, but is there any way to do this?

Thank you,


Solution

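  • You can set up your own middleware that pushes data into the FunctionContext. (The original attempt has no effect because `request.Identities.Append(user)` calls LINQ's `Enumerable.Append`, which returns a new sequence and leaves `Identities` unchanged; that returned sequence is simply discarded.)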

    using System;
    using System.Collections.Generic;
    using System.IdentityModel.Tokens.Jwt;
    using System.Linq;
    using System.Net.Http;
    using System.Security.Claims;
    using System.Security.Cryptography;
    using System.Text;
    using System.Text.Json;
    using System.Threading.Tasks;
    using Microsoft.Azure.Functions.Worker;
    using Microsoft.Azure.Functions.Worker.Middleware;
    using Microsoft.IdentityModel.Tokens;
    
    namespace AzureFunctionApp.Middleware
    {
        /// <summary>
        /// Worker middleware that validates a Google-issued JWT bearer token from the
        /// <c>Authorization</c> header and, when it is valid, stores the resulting
        /// <see cref="ClaimsPrincipal"/> in <c>FunctionContext.Items["User"]</c>.
        /// </summary>
        /// <remarks>
        /// <c>HttpRequestData.Identities</c> cannot be modified in the isolated worker, so the
        /// per-invocation <c>FunctionContext.Items</c> dictionary carries the principal to the
        /// function instead. Invocations without a valid token continue unchanged and simply have
        /// no "User" item; the function decides how to respond.
        /// </remarks>
        public class TokenMiddleware : IFunctionsWorkerMiddleware
        {
            private const string UserItemKey = "User";
            private const string BearerPrefix = "Bearer ";

            // Google documents both spellings of the issuer as valid for its ID tokens.
            private static readonly string[] GoogleIssuers = { "https://accounts.google.com", "accounts.google.com" };

            // How long fetched signing keys are reused before discovery + JWKS are fetched again.
            private static readonly TimeSpan SigningKeyCacheLifetime = TimeSpan.FromHours(1);

            // Shared for the whole process: an HttpClient per instance or per request can exhaust sockets.
            private static readonly HttpClient s_httpClient = new HttpClient();

            private readonly string _issuer = "https://accounts.google.com";

            // The former "$(GoogleClientID)" was a literal string (C# never expands it), so no real
            // token could ever match it. Azure Functions exposes app settings as environment variables.
            private readonly string _audience = Environment.GetEnvironmentVariable("GoogleClientID");

            // Replaced as a whole (never mutated) so concurrent readers see keys and expiry together.
            private volatile SigningKeyCache _signingKeyCache;

            /// <summary>
            /// Validates the bearer token, if one is present, publishes the principal under
            /// <c>Items["User"]</c> on success, and always continues the pipeline.
            /// </summary>
            public async Task Invoke(FunctionContext context, FunctionExecutionDelegate next)
            {
                var token = await GetBearerTokenAsync(context);

                if (token != null)
                {
                    ClaimsPrincipal claimsPrincipal = await ValidateAndExtractClaimsAsync(token);

                    if (claimsPrincipal != null)
                    {
                        // Indexer rather than Add: Add throws if the key was already set upstream.
                        context.Items[UserItemKey] = claimsPrincipal;
                    }
                }

                await next(context);
            }

            /// <summary>Validates <paramref name="jwtToken"/> against Google's published signing keys.</summary>
            /// <param name="jwtToken">The raw JWT (without the "Bearer " prefix).</param>
            /// <returns>
            /// The validated principal, or <c>null</c> when the token is empty or fails validation,
            /// or when no client ID (audience) is configured.
            /// </returns>
            /// <remarks>HTTP failures while fetching signing keys are not caught and propagate.</remarks>
            public async Task<ClaimsPrincipal> ValidateAndExtractClaimsAsync(string jwtToken)
            {
                if (string.IsNullOrWhiteSpace(jwtToken) || string.IsNullOrEmpty(_audience))
                {
                    // Without a configured client ID no audience can match, so nothing can be trusted.
                    return null;
                }

                var validationParameters = new TokenValidationParameters
                {
                    ValidIssuers = GoogleIssuers,
                    ValidAudience = _audience,
                    IssuerSigningKeys = await GetSigningKeysAsync(),
                    ValidateIssuerSigningKey = true,
                    ValidateIssuer = true,
                    ValidateAudience = true,
                };

                try
                {
                    return new JwtSecurityTokenHandler().ValidateToken(jwtToken, validationParameters, out _);
                }
                catch (Exception)
                {
                    // Deliberate: only ValidateToken runs here, and any failure (expired, forged,
                    // wrong audience, malformed) means the caller is simply unauthenticated.
                    return null;
                }
            }

            // Returns the token from an "Authorization: Bearer <token>" header, or null when the
            // invocation is not HTTP-triggered, the header is missing, or another scheme is used.
            private static async Task<string> GetBearerTokenAsync(FunctionContext context)
            {
                // Non-HTTP triggers (timer, queue, ...) have no request; this yields null for them.
                var request = await context.GetHttpRequestDataAsync();

                // TryGetValues instead of GetValues: GetValues throws when the header is absent.
                if (request is null || !request.Headers.TryGetValues("Authorization", out var values))
                {
                    return null;
                }

                var header = values.FirstOrDefault();
                if (header is null || !header.StartsWith(BearerPrefix, StringComparison.OrdinalIgnoreCase))
                {
                    return null;
                }

                var token = header.Substring(BearerPrefix.Length).Trim();
                return token.Length == 0 ? null : token;
            }

            // Returns cached signing keys, refetching Google's discovery document and JWKS once the
            // cache has expired. Concurrent refreshes are harmless (the last one wins).
            // HTTP failures propagate to the caller.
            private async Task<IEnumerable<SecurityKey>> GetSigningKeysAsync()
            {
                var cached = _signingKeyCache;
                if (cached != null && cached.ExpiresAt > DateTimeOffset.UtcNow)
                {
                    return cached.Keys;
                }

                var discoveryJson = await s_httpClient.GetStringAsync(_issuer + "/.well-known/openid-configuration");
                string jwksUri;
                using (var discovery = JsonDocument.Parse(discoveryJson))
                {
                    jwksUri = discovery.RootElement.GetProperty("jwks_uri").GetString();
                }

                var jwksJson = await s_httpClient.GetStringAsync(jwksUri);
                List<SecurityKey> keys;
                using (var jwks = JsonDocument.Parse(jwksJson))
                {
                    // ExtractSigningKeys materializes a list, so nothing references the disposed document.
                    keys = ExtractSigningKeys(jwks.RootElement);
                }

                // Don't pin an empty key set for a whole cache lifetime; retry on the next request instead.
                if (keys.Count > 0)
                {
                    _signingKeyCache = new SigningKeyCache(keys, DateTimeOffset.UtcNow + SigningKeyCacheLifetime);
                }

                return keys;
            }

            // Builds RSA signing keys from a JWKS document ({"keys": [...]}). Entries that are not RSA
            // signature keys or lack "n"/"e" are skipped. "kid" is kept so the token handler can select
            // the matching key directly.
            private static List<SecurityKey> ExtractSigningKeys(JsonElement jwks)
            {
                var signingKeys = new List<SecurityKey>();
                if (jwks.ValueKind != JsonValueKind.Object
                    || !jwks.TryGetProperty("keys", out var keys)
                    || keys.ValueKind != JsonValueKind.Array)
                {
                    return signingKeys;
                }

                foreach (var key in keys.EnumerateArray())
                {
                    var keyType = ReadString(key, "kty");
                    var use = ReadString(key, "use");
                    var modulus = ReadString(key, "n");
                    var exponent = ReadString(key, "e");

                    if (keyType != "RSA" || (use != null && use != "sig") || modulus is null || exponent is null)
                    {
                        continue;
                    }

                    // The RSAParameters overload avoids allocating RSA instances that are never disposed.
                    var securityKey = new RsaSecurityKey(new RSAParameters
                    {
                        Modulus = Base64UrlDecode(modulus),
                        Exponent = Base64UrlDecode(exponent),
                    });
                    securityKey.KeyId = ReadString(key, "kid");
                    signingKeys.Add(securityKey);
                }

                return signingKeys;
            }

            // Returns the named string property of a JSON object, or null if absent or not a string.
            private static string ReadString(JsonElement element, string propertyName) =>
                element.ValueKind == JsonValueKind.Object
                && element.TryGetProperty(propertyName, out var value)
                && value.ValueKind == JsonValueKind.String
                    ? value.GetString()
                    : null;

            // JWK "n"/"e" values are unpadded base64url; convert to standard padded base64 before decoding.
            private static byte[] Base64UrlDecode(string input)
            {
                var base64 = input.Replace('-', '+').Replace('_', '/');
                var padding = (4 - base64.Length % 4) % 4;
                return Convert.FromBase64String(base64 + new string('=', padding));
            }

            // Immutable snapshot of fetched signing keys and the moment they should be refetched.
            private sealed record SigningKeyCache(IEnumerable<SecurityKey> Keys, DateTimeOffset ExpiresAt);
        }
    }
    
    

    Now register the middleware in Program.cs:

    .ConfigureFunctionsWorkerDefaults(workerApp =>
        // Insert TokenMiddleware into the worker pipeline so it runs before each function.
        workerApp.UseMiddleware<TokenMiddleware>())
    

    Then read the validated user from FunctionContext.Items in the function:

    [Function("FunctionName")]
    public async Task<HttpResponseData> Run([HttpTrigger(AuthorizationLevel.Function, "post")] HttpRequestData req, FunctionContext functionContext)
    {
        // The middleware only stores "User" when a valid bearer token was presented, so the key can be
        // absent; indexing Items directly would throw KeyNotFoundException instead of rejecting the call.
        if (!functionContext.Items.TryGetValue("User", out var user) || user is not ClaimsPrincipal principal)
        {
            return req.CreateResponse(System.Net.HttpStatusCode.Unauthorized);
        }

        // FindFirst returns null when the claim is missing; SingleOrDefault would throw on duplicates.
        var name = principal.FindFirst(ClaimTypes.NameIdentifier)?.Value;

        // the rest of the function code
    }