python fastapi middleware starlette

How to write a custom FastAPI middleware class


I have read FastAPI's documentation on middleware (specifically the middleware tutorial, the CORS middleware section, and the advanced middleware guide), but neither there nor on this site could I find a concrete example of a middleware class that can be added with the add_middleware function, as opposed to a basic middleware function added with a decorator.

I prefer add_middleware over the app-based decorator because I want to write the middleware in a shared library that several different projects will use, so I can't tie it to a specific FastAPI instance.

So my question is: how do you do it?


Solution

  • Since FastAPI is built on Starlette, you can use Starlette's BaseHTTPMiddleware to implement a middleware class (you may want to have a look at this post as well). Two variants of the same approach are given below; in both, the middleware class is registered with add_middleware(). In Option 1, a plain callable class instance is passed to BaseHTTPMiddleware as its dispatch function. In Option 2, the class subclasses BaseHTTPMiddleware and overrides dispatch(). Please note that, at the time of writing, it is not possible to use BackgroundTasks (if that's a requirement for your task) with BaseHTTPMiddleware; check #1438 and #1640 for more details. Alternatives can be found in this answer and this answer.

    Option 1

    middleware.py

    from fastapi import Request
    
    class MyMiddleware:
        def __init__(
                self,
                some_attribute: str,
        ):
            self.some_attribute = some_attribute
    
        async def __call__(self, request: Request, call_next):
            # do something with the request object
            content_type = request.headers.get('Content-Type')
            print(content_type)
            
            # process the request and get the response    
            response = await call_next(request)
            
            return response
    

    app.py

    from fastapi import FastAPI
    from middleware import MyMiddleware
    from starlette.middleware.base import BaseHTTPMiddleware
    
    app = FastAPI()
    my_middleware = MyMiddleware(some_attribute="some_attribute_here_if_needed")
    app.add_middleware(BaseHTTPMiddleware, dispatch=my_middleware)
    

    Option 2

    middleware.py

    from fastapi import Request
    from starlette.middleware.base import BaseHTTPMiddleware
    
    class MyMiddleware(BaseHTTPMiddleware):
        def __init__(
                self,
                app,
                some_attribute: str,
        ):
            super().__init__(app)
            self.some_attribute = some_attribute
    
        async def dispatch(self, request: Request, call_next):
            # do something with the request object, for example
            content_type = request.headers.get('Content-Type')
            print(content_type)
            
            # process the request and get the response    
            response = await call_next(request)
            
            return response
    

    app.py

    from fastapi import FastAPI
    from middleware import MyMiddleware
    
    app = FastAPI()
    app.add_middleware(MyMiddleware, some_attribute="some_attribute_here_if_needed")
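
If the BaseHTTPMiddleware limitations mentioned above matter for your use case, one alternative is a "pure ASGI" middleware class, which wraps the application directly instead of going through BaseHTTPMiddleware. The following is only a minimal sketch under stated assumptions: the class name HeaderLoggingMiddleware, the seen_content_types list, the x-some-attribute response header, and the demo_app/run_demo helpers are all hypothetical names made up for illustration. The demo drives the middleware with a hand-built ASGI scope, so it runs with the standard library alone.

```python
import asyncio


class HeaderLoggingMiddleware:
    """Pure ASGI middleware sketch (hypothetical name, not part of FastAPI)."""

    def __init__(self, app, some_attribute: str):
        self.app = app  # the wrapped ASGI application
        self.some_attribute = some_attribute
        # Hypothetical: kept only so the demo below can inspect what was seen.
        self.seen_content_types = []

    async def __call__(self, scope, receive, send):
        # Pass non-HTTP traffic (lifespan, websocket) through untouched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # ASGI headers are a list of (name, value) byte pairs, names lowercased.
        headers = dict(scope.get("headers", []))
        content_type = headers.get(b"content-type", b"").decode("latin-1")
        self.seen_content_types.append(content_type)

        async def send_wrapper(message):
            # Add an illustrative header when the response starts.
            if message["type"] == "http.response.start":
                message = dict(message)  # copy rather than mutate the original
                message["headers"] = list(message.get("headers", [])) + [
                    (b"x-some-attribute", self.some_attribute.encode("latin-1"))
                ]
            await send(message)

        await self.app(scope, receive, send_wrapper)


async def demo_app(scope, receive, send):
    # Tiny stand-in ASGI app used only for this demo.
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"text/plain")],
    })
    await send({"type": "http.response.body", "body": b"ok"})


async def run_demo():
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    scope = {
        "type": "http",
        "method": "GET",
        "path": "/",
        "headers": [(b"content-type", b"application/json")],
    }
    middleware = HeaderLoggingMiddleware(demo_app, some_attribute="demo")
    await middleware(scope, receive, send)
    return middleware, sent


middleware, sent = asyncio.run(run_demo())
print(middleware.seen_content_types)
print(sent[0]["headers"])
```

Because the constructor takes the wrapped application as its first parameter, named app, a class shaped like this should be registrable in the same way as Option 2, e.g. app.add_middleware(HeaderLoggingMiddleware, some_attribute="some_attribute_here_if_needed"). The trade-off is that you work with the raw scope/receive/send interface rather than Request and Response objects.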