Tags: python, flask, auth0, flask-restful

Working with Flask Request Context for user management


I am currently working on a project where I am building an API using Flask and Flask-RESTful, with Auth0 for user authentication. I am struggling to pass the Auth0 `sub` (subject/user ID) claim from the example decorator to the Flask-RESTful method. Here is the code:

Auth decorator provided by the Auth0 docs:

def get_token_auth_header():
    """Obtain the access token from the ``Authorization`` request header.

    Expects a header of the form ``Authorization: Bearer <token>``.

    Returns:
        The token string that follows the ``Bearer`` scheme keyword.

    Raises:
        AuthError: with status 401 when the header is absent or empty, does
            not start with ``Bearer`` (including a whitespace-only value), has
            no token after the scheme, or has more than two space-separated
            parts.
    """
    auth = request.headers.get("Authorization", None)
    if not auth:
        raise AuthError(
            {
                "code": "authorization_header_missing",
                "description": "Authorization header is expected",
            },
            401,
        )

    parts = auth.split()

    # A whitespace-only header is truthy but splits to [], so check for that
    # before indexing parts[0]; otherwise it escapes as an unhandled
    # IndexError instead of a 401 AuthError.
    if not parts or parts[0].lower() != "bearer":
        raise AuthError(
            {
                "code": "invalid_header",
                "description": "Authorization header must start with" " Bearer",
            },
            401,
        )
    if len(parts) == 1:
        raise AuthError(
            {"code": "invalid_header", "description": "Token not found"}, 401
        )
    if len(parts) > 2:
        raise AuthError(
            {
                "code": "invalid_header",
                "description": "Authorization header must be" " Bearer token",
            },
            401,
        )

    return parts[1]

def _fetch_jwks():
    """Download and parse the Auth0 tenant's JSON Web Key Set (JWKS).

    Returns:
        The decoded JSON document; callers read its ``"keys"`` list.
    """
    jwks_url = "https://" + AUTH0_DOMAIN + "/.well-known/jwks.json"
    # Use the response as a context manager so the HTTP connection is closed
    # promptly instead of whenever the response object is garbage-collected.
    with urlopen(jwks_url) as response:
        return json.loads(response.read())


def _find_rsa_key(jwks, kid):
    """Return the public-key fields of the JWKS entry whose ``kid`` matches.

    Args:
        jwks: Parsed JWKS document containing a ``"keys"`` list.
        kid: Key ID taken from the token's unverified header, or None.

    Returns:
        A dict with the ``kty``, ``kid``, ``use``, ``n`` and ``e`` fields of the
        matching key, or an empty dict when ``kid`` is None or nothing matches.
    """
    if kid is None:
        return {}
    for key in jwks["keys"]:
        if key.get("kid") == kid:
            return {
                "kty": key["kty"],
                "kid": key["kid"],
                "use": key["use"],
                "n": key["n"],
                "e": key["e"],
            }
    return {}


def _decode_token(token, rsa_key):
    """Verify ``token`` with ``rsa_key`` and return its decoded claims.

    The allowed algorithms (``ALGORITHMS``), expected audience
    (``API_IDENTIFIER``) and expected issuer (the Auth0 domain URL) are passed
    to ``jwt.decode`` for verification.

    Raises:
        AuthError: with status 401 if the token has expired, carries the wrong
            audience/issuer, or cannot be decoded for any other reason.
    """
    try:
        return jwt.decode(
            token,
            rsa_key,
            algorithms=ALGORITHMS,
            audience=API_IDENTIFIER,
            issuer="https://" + AUTH0_DOMAIN + "/",
        )
    except jwt.ExpiredSignatureError as err:
        raise AuthError(
            {"code": "token_expired", "description": "token is expired"}, 401
        ) from err
    except jwt.JWTClaimsError as err:
        raise AuthError(
            {
                "code": "invalid_claims",
                "description": "incorrect claims,"
                " please check the audience and issuer",
            },
            401,
        ) from err
    except Exception as err:
        # Deliberately broad: any other verification failure is reported to
        # the client as a 401 rather than surfacing as a server error.
        raise AuthError(
            {
                "code": "invalid_header",
                "description": "Unable to parse authentication" " token.",
            },
            401,
        ) from err


def requires_auth(f):
    """Decorate a view so it only runs for requests with a valid Auth0 token.

    On success the decoded token payload is stored as
    ``_request_ctx_stack.top.current_user`` for the wrapped view to read.

    NOTE(review): ``_request_ctx_stack`` is deprecated since Flask 2.2 and
    removed in Flask 2.3; on newer Flask store the payload on ``flask.g``
    instead — confirm against the Flask version in use.

    Raises:
        AuthError: with status 401 if the token is missing or malformed, is
            declared HS256, has no matching key in the tenant's JWKS, or fails
            verification.
    """

    @wraps(f)
    def decorated(*args, **kwargs):
        token = get_token_auth_header()
        try:
            unverified_header = jwt.get_unverified_header(token)
        except jwt.JWTError as err:
            raise AuthError(
                {
                    "code": "invalid_header",
                    "description": "Invalid header. "
                    "Use an RS256 signed JWT Access Token",
                },
                401,
            ) from err
        # .get(): a header lacking "alg"/"kid" must end in a 401 below rather
        # than a KeyError escaping as a server error.
        if unverified_header.get("alg") == "HS256":
            raise AuthError(
                {
                    "code": "invalid_header",
                    "description": "Invalid header. "
                    "Use an RS256 signed JWT Access Token",
                },
                401,
            )
        # Only go to the network once the token header is known to be usable.
        rsa_key = _find_rsa_key(_fetch_jwks(), unverified_header.get("kid"))
        if not rsa_key:
            raise AuthError(
                {"code": "invalid_header", "description": "Unable to find appropriate key"},
                401,
            )
        payload = _decode_token(token, rsa_key)
        # Publish the verified claims on the request context for the view.
        _request_ctx_stack.top.current_user = payload
        return f(*args, **kwargs)

    return decorated

Example Use:

class HealthCheck(Resource):
    """Health-check resource that is only reachable with a valid Auth0 token.

    POST answers with the decoded token payload that ``requires_auth`` left on
    the request context.
    """

    # NOTE(review): "Access-Control-Allow-Origin" (a response header) and
    # "http://localhost:3000" (an origin) are listed as allowed request
    # headers; they probably belong in a different CORS setting (e.g. an
    # origins option) — confirm the intended configuration.
    @requires_auth
    @cross_origin(headers=["Content-Type", "Authorization", "Access-Control-Allow-Origin", "http://localhost:3000"])
    def post(self):
        """Echo the caller's decoded JWT claims as ``{"message": <claims>}``."""
        claims = _request_ctx_stack.top.current_user
        print(f'Req: {claims}')
        return jsonify(message=claims)

Specific code:

_request_ctx_stack.top.current_user = payload

Currently, the `requires_auth` wrapper stores the decoded Auth0 token payload in `_request_ctx_stack.top.current_user`. Is there a best practice for working with this? Is the implementation in my resource above the best way to get the payload? Something tells me there is a better way. I tried passing it through `kwargs` so the method could receive it as an argument, but Flask-RESTful did not like that. I also experimented with `flask.g`, but that seemed equally unsafe since it lives on the application context. Thanks!


Solution

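  • It really does look odd with the leading underscore, but this does appear to be the correct way to handle it. (Note: `_request_ctx_stack` was deprecated in Flask 2.2 and removed in 2.3. On current Flask, store the payload on `flask.g` instead; an application context is pushed for every request, so `g` does not leak data between requests.) If you'd like to hide the complexity a bit, you can use a `LocalProxy` like so: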

    auth.py

    from flask import _request_ctx_stack
    from werkzeug.local import LocalProxy
    
    # Module-level proxy: every access re-evaluates the lambda and resolves to
    # the payload that requires_auth stored on the active request context.
    # Accessing it outside a request, or before requires_auth has run for the
    # current request, raises (AttributeError) instead of returning a value.
    # NOTE(review): _request_ctx_stack is deprecated since Flask 2.2 and
    # removed in Flask 2.3; on newer Flask store the payload on flask.g and
    # proxy that instead — confirm against the Flask version in use.
    # NOTE(review): Flask's JSON encoding does not unwrap LocalProxy objects;
    # pass current_user._get_current_object() (not the proxy) to jsonify.
    current_user = LocalProxy(lambda: _request_ctx_stack.top.current_user)
    
    # ...
    def requires_auth(f):
    # ...
                    # Publish the decoded payload where the current_user proxy
                    # above looks it up.
                    _request_ctx_stack.top.current_user = payload
    

    routes.py

    from .auth import current_user
    
    class HealthCheck(Resource):
        """Health-check resource that is only reachable with a valid Auth0 token."""

        @requires_auth
        @cross_origin(headers=["Content-Type", "Authorization", "Access-Control-Allow-Origin", "http://localhost:3000"])
        def post(self):
            """Echo the caller's decoded JWT claims as ``{"message": <claims>}``."""
            # current_user is a werkzeug LocalProxy. Flask's JSON encoding does
            # not unwrap proxies, so handing the proxy to jsonify can fail with
            # "Object of type LocalProxy is not JSON serializable"; serialize
            # the real payload object instead.
            claims = current_user._get_current_object()
            print(f'Req: {claims}')
            return jsonify(message=claims)