Search code examples
Tags: python, amazon-web-services, aws-lambda, jwt, amazon-cognito

AWS Lambda Custom JWT Validation



I've built a script that first validates the JWT and then extracts the user's unique ID (the "sub" claim).

In a non-Lambda environment the script works fine; however, in AWS Lambda I get the following error:

What could be the problem?

Unexpected error during JWT validation: Unable to find an algorithm for key: {'alg': 'RS256', 'e': 'AQAB', 'kid': 'dmAQX7bVDINFkTGxZc5YCxF5ZA/pcaRsQMUoBbRt4bw=', 'kty': 'RSA', 'n': 'u9hHbyMaI-PWsTG9MtaHjxwBmMez6VeV-ScqIgllBUSQkx8Ao...vGUIG39rb3nPmNVCunBw', 'use': 'sig'}

This is my AWS Lambda code:

import json
import os
import requests
from jose import jwt, jwk

def get_efs_keys(file_name="/mnt/efs/jwks.json"):
    """Load the JWT signing keys from a JWKS document stored on EFS.

    The file is a copy of the user pool's JWKS document, obtained from
    https://cognito-idp.<Region>.amazonaws.com/<userPoolId>/.well-known/jwks.json

    Args:
        file_name: Path of the JWKS JSON file.

    Returns:
        The document's ``keys`` entry, or ``[]`` when the entry is absent.
        Any failure (missing file, invalid JSON, non-object document) is
        printed and also yields ``[]``; nothing is raised.
    """
    try:
        with open(file_name, 'r') as jwks_file:
            document = json.load(jwks_file)
        signing_keys = document.get('keys', [])
    except Exception as err:
        # Best effort: callers treat an empty key list as "no token can verify".
        print(f"An error occurred while fetching keys: {err}")
        signing_keys = []
    return signing_keys

def validate_jwt(jwt_token, keys):
    """Verify a JWT's signature and claims, then extract its ``sub`` claim.

    Args:
        jwt_token: The raw JWT string (without the ``Bearer `` prefix), or a
            falsy value when no token was supplied.
        keys: JWK dicts from the user pool's JWKS document (see
            ``get_efs_keys``). The entry whose ``kid`` matches the token
            header's ``kid`` is used for verification.

    Returns:
        ``(True, sub)`` when the token verifies, where ``sub`` is the token's
        ``sub`` claim or ``False`` if it has none. Returns ``(False, False)``
        otherwise. Errors are printed, never raised.

    The expected app client id is read from the ``APP_CLIENT_ID`` environment
    variable. It is passed to python-jose as the audience, which is checked
    against ``aud`` when the token has one (e.g. ID tokens). For tokens
    without ``aud`` (Cognito access tokens name the app client in
    ``client_id`` instead), it is compared against the ``client_id`` claim.
    """
    if not jwt_token:
        return False, False

    client_id = os.environ.get('APP_CLIENT_ID')

    try:
        headers = jwt.get_unverified_headers(jwt_token)
        kid = headers.get('kid')
        if not kid:
            return False, False

        # .get() so a JWKS entry without a 'kid' is skipped instead of raising KeyError.
        key = next((candidate for candidate in keys if candidate.get('kid') == kid), None)
        if key is None:
            return False, False

        # If this raises "Unable to find an algorithm for key", python-jose has
        # no backend registered for the key's 'alg' (RS256). Typically the RSA
        # backend (e.g. the `cryptography` package) is missing from the Lambda
        # deployment package/layer; the key itself is usually fine.
        public_key = jwk.construct(key)
        claims = jwt.decode(jwt_token, public_key, algorithms=['RS256'], audience=client_id)
    except jwt.JWTError as e:
        print(f"JWT token validation error: {e}")
        return False, False
    except Exception as e:
        print(f"Unexpected error during JWT validation: {e}")
        return False, False

    # python-jose only checks the audience when the token has an 'aud' claim.
    # Cognito access tokens have none and carry the app client in 'client_id',
    # so without this check an access token issued to any app client of the
    # pool would be accepted.
    if client_id and 'aud' not in claims and claims.get('client_id') != client_id:
        print("JWT token validation error: client_id does not match APP_CLIENT_ID")
        return False, False

    return True, claims.get('sub', False)

def _extract_bearer_token(event):
    """Return the bearer token from the event's Authorization header, or None."""
    # API Gateway can deliver "headers": null (e.g. console test events), so
    # fall back to an empty mapping instead of crashing on None.get().
    headers = event.get('headers') or {}

    # Header names are case-insensitive; HTTP APIs (payload v2.0) and Lambda
    # function URLs deliver them lowercased ("authorization").
    authorization_header = next(
        (value for name, value in headers.items() if name.lower() == 'authorization'),
        '',
    ) or ''

    # The scheme comparison is case-insensitive ("Bearer" / "bearer").
    scheme, _, token = authorization_header.partition(' ')
    if scheme.lower() != 'bearer':
        return None
    token = token.strip()
    return token or None


def lambda_handler(event, context):
    """Validate the request's bearer token and report the result as JSON.

    Args:
        event: The Lambda proxy-integration event; only its ``headers`` are read.
        context: Lambda context object (unused).

    Returns:
        A proxy-integration response with status 200 whose JSON body holds
        ``access_token`` (the extracted token, or None when absent) plus
        ``jwt_valid`` and ``sub`` as returned by ``validate_jwt``.
    """
    access_token = _extract_bearer_token(event)

    # The signing keys are re-read from EFS on every invocation.
    keys = get_efs_keys()

    jwt_valid, sub = validate_jwt(access_token, keys)

    # NOTE(review): the token is echoed back and the status is 200 even when
    # validation fails; confirm this is intended beyond debugging.
    response_body = {
        'access_token': access_token,
        'jwt_valid': jwt_valid,
        'sub': sub,
    }

    return {
        'statusCode': 200,
        'headers': {
            'Content-Type': 'application/json',
        },
        'body': json.dumps(response_body),
    }

If validation succeeds, "jwt_valid" should be true and "sub" should contain the user's unique ID.


Solution

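  • This thread suggests that you are missing the `cryptography` dependency. python-jose raises "Unable to find an algorithm for key" when it has no backend implementing the key's algorithm (here RS256). `cryptography` provides that RSA implementation, so you need to install it in your Lambda deployment package or layer. It works locally because the package is already installed there. Make sure the copy you ship is built for the Lambda runtime's platform (Linux), not for your development machine.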

    pip install cryptography