I have an ASGI middleware that adds fields to the POST request body before it reaches the route in my FastAPI app.
import json

from starlette.types import ASGIApp, Message, Receive, Scope, Send
class MyMiddleware:
    """Add ``some_field`` to the JSON object body of incoming HTTP requests.

    This middleware implements a raw ASGI middleware instead of a
    starlette.middleware.base.BaseHTTPMiddleware because the
    BaseHTTPMiddleware does not allow us to modify the request body.
    For documentation see https://www.starlette.io/middleware/#pure-asgi-middleware

    Only requests without a Content-Type, or with a JSON one
    (``application/json`` or ``application/*+json``), are inspected. Empty
    bodies, bodies that are not valid JSON, and JSON values that are not
    objects (lists, strings, numbers, ...) are forwarded unchanged.

    NOTE: the ``Content-Length`` header in ``scope`` is not rewritten, so
    after modification it no longer matches the body the app receives.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http" or not self._has_json_content_type(scope):
            # lifespan/websocket scopes and non-JSON requests (e.g. multipart
            # uploads) pass straight through without being buffered.
            await self.app(scope, receive, send)
            return

        async def modify_message() -> Message:
            message = await receive()
            if message.get("type") != "http.request":
                return message
            # The body may be streamed over several messages (``more_body``);
            # it can only be parsed as JSON once it is complete.
            chunks = [message.get("body", b"")]
            while message.get("more_body", False):
                message = await receive()
                if message.get("type") != "http.request":
                    # e.g. http.disconnect before the body finished: let the app see it.
                    return message
                chunks.append(message.get("body", b""))
            body = self._add_fields(b"".join(chunks))
            # Deliver the (possibly modified) body as one final chunk.
            return {**message, "body": body, "more_body": False}

        await self.app(scope, modify_message, send)

    @staticmethod
    def _has_json_content_type(scope: Scope) -> bool:
        """Return True when the request has no Content-Type header or a JSON one."""
        for name, value in scope.get("headers", []):
            if name.lower() == b"content-type":
                media_type = value.split(b";", 1)[0].strip().lower()
                return media_type == b"application/json" or (
                    media_type.startswith(b"application/")
                    and media_type.endswith(b"+json")
                )
        return True

    @staticmethod
    def _add_fields(body: bytes) -> bytes:
        """Return *body* with ``some_field`` set, or unchanged if it is not a JSON object."""
        if not body:
            return body
        try:
            payload = json.loads(body)
        except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
            return body
        if not isinstance(payload, dict):
            return body
        payload["some_field"] = "foobar"
        return json.dumps(payload).encode("utf-8")
Is there an example of how to unit test an ASGI middleware? I would like to test the `__call__` method directly, which is difficult because it does not return anything. Do I need to use a test client (e.g. FastAPI's `TestClient`) with a dummy endpoint that returns the request as its response, and check that way whether the middleware worked? Or is there a more "direct" way?
I faced a similar problem recently, so I want to share my solution for FastAPI and pytest.
I had to implement per-request logging for a FastAPI app using middleware.
I checked Starlette's test suite, as Marcelo Trylesinski suggested, and adapted the code to fit FastAPI. Thank you for the recommendation, Marcelo!
Here is my middleware, which logs information about every request and its response.
# middlewares.py
import logging
from starlette.types import ASGIApp, Scope, Receive, Send
logger = logging.getLogger("app")


class LogRequestsMiddleware:
    """Pure ASGI middleware that logs one line per HTTP response.

    The line has the form ``host:port - "METHOD path scheme/version" status``.
    Request details come from the connection scope; the status code comes
    from the ``http.response.start`` message.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(
        self, scope: Scope, receive: Receive, send: Send
    ) -> None:
        if scope["type"] != "http":
            # lifespan/websocket scopes never send http.response.start.
            await self.app(scope, receive, send)
            return

        async def send_with_logs(message) -> None:
            """Log the request line and status code, then forward the message."""
            if message["type"] == "http.response.start":
                # "client" and "scheme" are optional in the ASGI HTTP scope:
                # client may be None (e.g. Unix sockets), scheme defaults to "http".
                client = scope.get("client")
                client_addr = f"{client[0]}:{client[1]}" if client else "-"
                # %-style args so formatting is deferred to the logging framework.
                logger.info(
                    '%s - "%s %s %s/%s" %s',
                    client_addr,
                    scope["method"],
                    scope["path"],
                    scope.get("scheme", "http"),
                    scope["http_version"],
                    message["status"],
                )
            await send(message)

        await self.app(scope, receive, send_with_logs)
To test the middleware, I created a `test_client_factory` fixture:
# conftest.py
import pytest
from fastapi.testclient import TestClient
@pytest.fixture
def test_client_factory() -> "type[TestClient]":
    """Return the ``TestClient`` class itself (not an instance).

    A test calls ``test_client_factory(app)`` to build a client bound to an
    app it created, so each test can register only the middleware under test.
    """
    return TestClient
In the test, I mocked the `logger.info()` call within the middleware and asserted that it was called.
# test_middlewares.py
from unittest import mock
from fastapi.testclient import TestClient
from fastapi import FastAPI
from .middlewares import LogRequestsMiddleware
def test_log_requests_middleware(
    test_client_factory: "type[TestClient]",
) -> None:
    """LogRequestsMiddleware logs exactly one line per request, with method, path and status."""
    # Patch the logger object of the module actually imported above, rather
    # than a hard-coded dotted path that has to match the package layout.
    from . import middlewares

    # create a fresh app instance to isolate the tested middleware
    app = FastAPI()
    app.add_middleware(LogRequestsMiddleware)

    # create an endpoint to exercise the middleware
    @app.get("/")
    def homepage():
        return {"hello": "world"}

    # create a client for the app using the fixture
    client = test_client_factory(app)

    with mock.patch.object(middlewares.logger, "info") as mock_logger:
        response = client.get("/")

    # sanity check
    assert response.status_code == 200
    # the logger was called exactly once for the single request
    mock_logger.assert_called_once()
    # Render the logged line whether the middleware pre-formats it or
    # passes %-style arguments to the logger.
    fmt, *args = mock_logger.call_args.args
    line = fmt % tuple(args) if args else fmt
    assert '"GET / ' in line
    assert line.endswith(" 200")