Search code examples
pythonfastapistarlette

make fastapi middleware returning custom http status instead of AuthenticationError status 400


In the following example when you pass a username in the basic auth field it raise a basic 400 error, but i want to return 401 since it's related to the authentication system.

I did tried Fastapi exceptions classes but they do not raise (i presume since we are in a starlette middleware). I also tried JSONResponse from starlette but it doesn't work either.

AuthenticationError work and raise a 400 but it's just an empty class that inherit from Exception so no status code can be given.

Fully working example:

import base64
import binascii

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, HTTPBasic
from starlette.authentication import AuthenticationBackend, AuthCredentials, AuthenticationError, BaseUser
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.responses import JSONResponse


class SimpleUserTest(BaseUser):
    """
    user object returned to route
    """
    def __init__(self, username: str, test1: str, test2: str) -> None:
        self.username = username
        self.test1 = test1
        self.test2 = test2

    @property
    def is_authenticated(self) -> bool:
        return True


async def jwt_auth(auth: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False))):
    if auth:
        return True


async def key_auth(apikey_header=Depends(HTTPBasic(auto_error=False))):
    if apikey_header:
        return True


class BasicAuthBackend(AuthenticationBackend):
    async def authenticate(self, conn):
        if "Authorization" not in conn.headers:
            return

        auth = conn.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
            if scheme.lower() == 'bearer':
                # check bearer content and decode it
                user: dict = {"username": "bearer", "test1": "test1", "test2": "test2"}

            elif scheme.lower() == 'basic':

                decoded = base64.b64decode(credentials).decode("ascii")
                username, _, password = decoded.partition(":")
                if username:
                    # check redis here instead of return dict
                    print("try error raise")
                    raise AuthenticationError('Invalid basic auth credentials')  # <= raise 400, we need 401
                    # user: dict = {"username": "basic auth", "test1": "test1", "test2": "test2"}
                else:
                    print("error should raise")
                    return JSONResponse(status_code=401, content={'reason': str("You need to provide a username")})
            else:
                return JSONResponse(status_code=401, content={'reason': str("Authentication type is not supported")})

        except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
            raise AuthenticationError('Invalid basic auth credentials')

        return AuthCredentials(["authenticated"]), SimpleUserTest(**user)


async def jwt_or_key_auth(jwt_result=Depends(jwt_auth), key_result=Depends(key_auth)):
    if not (key_result or jwt_result):
        raise HTTPException(status_code=401, detail="Not authenticated")


app = FastAPI(
    dependencies=[Depends(jwt_or_key_auth)],
    middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())]
)


@app.get("/")
async def read_items(request: Request) -> str:
    return request.user.__dict__


if __name__ == "__main__":
    uvicorn.run("main:app", host="127.0.0.1", port=5000, log_level="info")

if we set username in basic auth:

INFO:     127.0.0.1:22930 - "GET / HTTP/1.1" 400 Bad Request

Solution

  • so i ended up using on_error as suggested by @MatsLindh

    old app:

    app = FastAPI(
        dependencies=[Depends(jwt_or_key_auth)],
        middleware=[
            Middleware(
                AuthenticationMiddleware, 
                backend=BasicAuthBackend(),
            )
        ],
    )
    

    new version:

    app = FastAPI(
        dependencies=[Depends(jwt_or_key_auth)],
        middleware=[
            Middleware(
                AuthenticationMiddleware,
                backend=BasicAuthBackend(),
                on_error=lambda conn, exc: JSONResponse({"detail": str(exc)}, status_code=401),
            )
        ],
    )
    

    I choose to use JSONResponse and return a "detail" key/value to emulate a classic 401 fastapi httperror