Search code examples
Tags: python, jwt, header, aws-api-gateway, fastapi

Check FastApi headers within AWS Lambda before routes


I am trying to deploy an API with FastAPI, hosted on AWS Lambda and accessible through AWS API Gateway. Every request must include a valid JWT in its headers before being processed. I would like main.py to call a function "check_jwt()" before each request is processed, instead of repeating this check in every route of the API.

So far here is the code I have for the file main.py:

from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException, status
from mangum import Mangum
from pymongo import MongoClient
import config
import auth
from routes import project_router, forms_models_router, inspections_router, structures_router

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the MongoDB connection for the app's lifetime and close it afterwards.

    Passed to ``FastAPI(lifespan=...)``: code before ``yield`` runs at
    application startup, code after it runs at shutdown.  The client and the
    database handle are stored as attributes on ``app`` (presumably read by
    the route handlers via ``request.app`` -- confirm).
    """
    # NOTE: the JWT check originally attempted here (kept below for reference)
    # cannot work in a lifespan handler: it runs once for the whole app, with
    # no request in scope, so there are no headers to inspect.  A per-request
    # check belongs in an HTTP middleware instead.
    # headers = app.head
    # print(f"Headers: {headers}")
    # if not auth.check_jwt(headers, config.jwt_secret):
    #     raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Invalid authentication.")
    # else:
    print("Starting connection to database...")
    app.mongodb_client = MongoClient(config.mongodb_url)
    app.database = app.mongodb_client[config.mongo_database]
    try:
        yield
    finally:
        # Runs on normal shutdown AND when an exception is thrown into the
        # context manager, so the Mongo client is never left open.
        app.mongodb_client.close()
        print("Database connection closed.")

app = FastAPI(lifespan=lifespan)

# Mount each feature router under "/<tag>"; the tag also groups the router's
# endpoints in the generated OpenAPI docs.
for _router, _tag in (
    (project_router, "projects"),
    (forms_models_router, "forms_models"),
    (inspections_router, "inspections"),
    (structures_router, "structures"),
):
    app.include_router(_router, tags=[_tag], prefix=f"/{_tag}")
del _router, _tag  # keep the loop variables out of the module namespace

# Mangum wraps the ASGI app so AWS Lambda can invoke it through ``handler``.
handler = Mangum(app)

My file routes.py looks like this:

from json import JSONDecoder
from bson import ObjectId
from fastapi import APIRouter, Request, HTTPException, status
from typing import List, Union, Optional
from models import FormModel, Inspection, Structure, Project, PostResponse, GetResponse, JSONEncoder, JSONDecoder
from schemas import is_valid_object_id

# Router for the project endpoints; main.py mounts it with prefix="/projects".
# Route bodies are elided in this excerpt.  A function whose body holds only
# comments is a SyntaxError, so each one carries a `...` placeholder.
# Routes are documented with comments rather than docstrings because FastAPI
# copies a route's docstring into its OpenAPI description.
project_router = APIRouter()

# GET /projects/ -- list every project.
@project_router.get("/", response_description="Get all the projects", status_code=status.HTTP_200_OK, response_model=GetResponse)
async def get_all_projects(request: Request):
    # some code to interact with MongoDB
    ...

# GET /projects/{project_id} -- fetch one project by id.
@project_router.get("/{project_id}", response_description="Get a specific project", status_code=status.HTTP_200_OK, response_model=GetResponse)
async def get_project(request: Request, project_id: str):
    # some other code to interact with MongoDB
    ...

# POST /projects/ -- store the project sent in the request body.
@project_router.post("/", response_description="Save a project", status_code=status.HTTP_201_CREATED, response_model=PostResponse)
async def post_project(request: Request, project: Project):
    # some final code to interact with MongoDB
    ...

I have about 25 routes, so I would like to do the header check once in main.py if possible, instead of doing it in each route using the request object.

EDIT: As suggested by Waket Zheng, I have been able to move forward with FastAPI middleware. I can now access the headers directly from main.py, modified as shown below:

from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException, status
from mangum import Mangum
from pymongo import MongoClient
import config
from routes import project_router, forms_models_router, inspections_router, structures_router
import auth


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the MongoDB connection for the app's lifetime and close it afterwards.

    Passed to ``FastAPI(lifespan=...)``: code before ``yield`` runs at
    application startup, code after it runs at shutdown.  The client and the
    database handle are stored as attributes on ``app`` (presumably read by
    the route handlers via ``request.app`` -- confirm).
    """
    print("Starting connection to database...")
    app.mongodb_client = MongoClient(config.mongodb_url)
    app.database = app.mongodb_client[config.mongo_database]
    try:
        yield
    finally:
        # Runs on normal shutdown AND when an exception is thrown into the
        # context manager, so the Mongo client is never left open.
        app.mongodb_client.close()
        print("Database connection closed.")

app = FastAPI(lifespan=lifespan)

@app.middleware("http")
async def check_token(request: Request, call_next):
    """Gate every HTTP request on a valid JWT in the ``Authorization`` header.

    Registered with ``@app.middleware("http")``, so it runs before the request
    reaches any router.

    Args:
        request: the incoming request; only ``request.headers`` is read.
        call_next: passes the request on to the rest of the application and
            returns its response.

    Returns:
        The downstream response when ``auth.check_jwt`` accepts the token.

    Raises:
        HTTPException: (403) when the header is missing or the token is
            rejected.

    NOTE(review): raising here does NOT produce a 403 response.  Exceptions
    raised from an ``@app.middleware("http")`` function are not handled by
    FastAPI's exception handlers, so the request ends in a server error
    instead.  Returning a ``JSONResponse`` with status 403 avoids this.
    """
    headers = request.headers
    try:
        # A missing header makes this lookup raise KeyError.  The try also
        # covers auth.check_jwt, so a KeyError raised inside it would land in
        # the same branch as a missing header.
        encoded_jwt = headers["Authorization"]
        # "mekas" is passed straight through to auth.check_jwt; its meaning is
        # defined in the auth module (not shown) -- TODO confirm.
        if not auth.check_jwt(encoded_jwt, "mekas", config.jwt_secret):
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid authentication.")
    except KeyError:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid authentication.") 
    # Token accepted: hand the request on to the routes.
    response = await call_next(request)
    return response

# Mount each feature router under "/<tag>"; the tag also groups the router's
# endpoints in the generated OpenAPI docs.
for _router, _tag in (
    (project_router, "projects"),
    (forms_models_router, "forms_models"),
    (inspections_router, "inspections"),
):
    app.include_router(_router, tags=[_tag], prefix=f"/{_tag}")
del _router, _tag  # keep the loop variables out of the module namespace

# NOTE(review): unlike the routers above (and the first version of this file,
# which passed prefix="/structures"), structures_router is mounted without a
# prefix, so its paths are served from the application root -- confirm this
# is intended.
app.include_router(structures_router, tags=["structures"])

# Mangum wraps the ASGI app so AWS Lambda can invoke it through ``handler``.
handler = Mangum(app)

However, I am unsure how to handle the cases where the token is missing or invalid. I would like to block the request before it even reaches the routes, and return status code 403. Currently, successful requests are processed properly, but unauthorized ones trigger server errors caused by raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid authentication."):

[ERROR] An error occurred running the application.
Traceback (most recent call last):
  File "/var/lang/lib/python3.9/site-packages/starlette/middleware/base.py", line 108, in __call__
response = await self.dispatch_func(request, call_next)
File "/var/task/main.py", line 65, in check_token
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid authentication.")
fastapi.exceptions.HTTPException
During handling of the above exception, another exception occurred:
+ Exception Group Traceback (most recent call last):
|   File "/var/lang/lib/python3.9/site-packages/mangum/protocols/http.py", line 58, in run
|     await app(self.scope, self.receive, self.send)
|   File "/var/lang/lib/python3.9/site-packages/fastapi/applications.py", line 290, in __call__
|     await super().__call__(scope, receive, send)
|   File "/var/lang/lib/python3.9/site-packages/starlette/applications.py", line 122, in __call__
|     await self.middleware_stack(scope, receive, send)
|   File "/var/lang/lib/python3.9/site-packages/starlette/middleware/errors.py", line 184, in __call__
|     raise exc
|   File "/var/lang/lib/python3.9/site-packages/starlette/middleware/errors.py", line 162, in __call__
|     await self.app(scope, receive, _send)
|   File "/var/lang/lib/python3.9/site-packages/starlette/middleware/base.py", line 110, in __call__
|     response_sent.set()
|   File "/var/lang/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 664, in __aexit__
|     raise BaseExceptionGroup(
| exceptiongroup.ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
+-+---------------- 1 ----------------
| Traceback (most recent call last):
|   File "/var/lang/lib/python3.9/site-packages/starlette/middleware/base.py", line 108, in __call__
|     response = await self.dispatch_func(request, call_next)
|   File "/var/task/main.py", line 65, in check_token
|     raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid authentication.")
| fastapi.exceptions.HTTPException

Solution

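  • Return a response instead of raising an exception. An exception raised inside an @app.middleware("http") function is not converted into an HTTP response by FastAPI's exception handlers, so it surfaces as a server error (as in the traceback above). Returning a JSONResponse with status 403 stops the request before it reaches any route: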

    from fastapi.responses import JSONResponse

    @app.middleware("http")
    async def check_token(request: Request, call_next):
        """Let a request through only if it carries a JWT accepted by ``auth.check_jwt``.

        Runs for every HTTP request before routing.  When the
        ``Authorization`` header is absent or empty, or ``auth.check_jwt``
        rejects it, a 403 JSON body ``{"detail": ...}`` is returned directly.
        Returning (instead of raising ``HTTPException``) matters: exceptions
        raised from this middleware are not handled by FastAPI's exception
        handlers and would end the request as a server error.
        """
        token = request.headers.get("Authorization")
        # check_jwt is only consulted when a non-empty header is present.
        if token and auth.check_jwt(token, "mekas", config.jwt_secret):
            return await call_next(request)
        return JSONResponse(
            content={"detail": "Invalid authentication."},
            status_code=status.HTTP_403_FORBIDDEN,
        )