Search code examples
rust rust-axum

Axum Middleware to log the response body


I want to log the responses to my HTTP requests, so I looked at some examples in the axum GitHub repository and found the following.

...
.layer(axum::middleware::from_fn(print_request_response))
...

/// Middleware that logs the request body, runs the rest of the stack via `next`, and then
/// logs the response body.
///
/// Both bodies are consumed by `buffer_and_print` and rebuilt from the returned bytes so
/// they can still be read afterwards. A failure to buffer either body is propagated as the
/// `(StatusCode, String)` error returned by `buffer_and_print`.
///
/// NOTE(review): this is the version that triggers the E0308 error quoted below — the
/// request is rebuilt with a concrete `hyper::Body`, while `next` is a `Next<B>` for the
/// generic parameter `B`.
async fn print_request_response<B>(
    req: Request<B>,
    next: Next<B>
) -> Result<impl IntoResponse, (StatusCode, String)> {
    // Split off the body so it can be buffered and logged, then reassemble the request.
    let (parts, body) = req.into_parts();
    let bytes = buffer_and_print("request", body).await?;
    let req = Request::from_parts(parts, hyper::Body::from(bytes));
    
    // This is the call the compiler rejects: `req` is now `Request<Body>`, not `Request<B>`.
    let res = next.run(req).await;
    
    // Same buffer-log-rebuild sequence for the response.
    let (parts, body) = res.into_parts();
    let bytes = buffer_and_print("response", body).await?;
    let res = Response::from_parts(parts, Body::from(bytes));

    Ok(res)
}
/// Reads `body` to completion with `hyper::body::to_bytes`, logs it at debug level when it
/// is valid UTF-8, and returns the buffered bytes.
///
/// `direction` is only used as a label in the log and error messages. A read failure is
/// returned as `(StatusCode::BAD_REQUEST, message)`.
///
/// NOTE(review): `B` carries no trait bounds here; the working version further down adds
/// `B: axum::body::HttpBody<Data = Bytes>` and `B::Error: std::fmt::Display`.
async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, (StatusCode, String)>
{
    let bytes = match hyper::body::to_bytes(body).await {
        Ok(bytes) => bytes,
        Err(err) => {
            return Err((
                StatusCode::BAD_REQUEST,
                format!("failed to read {} body: {}", direction, err),
            ));
        }
    };

    // Bodies that are not valid UTF-8 are passed through without being logged.
    if let Ok(body) = std::str::from_utf8(&bytes) {
        tracing::debug!("{} body = {:?}", direction, body);
    }

    Ok(bytes)
}

In the example no types were given, but the compiler immediately told me that I need to specify types for Request, Next and the functions. I've been struggling to get it to work. Right now the problem is the following. At the line

let res = next.run(req).await;

I get this error:

error[E0308]: mismatched types
   --> src\core.rs:302:24
    |
294 | async fn print_request_response<B>(
    |                                 - this type parameter
...
302 |     let res = next.run(req).await;
    |                    --- ^^^ expected type parameter `B`, found struct `Body`
    |                    |
    |                    arguments to this function are incorrect
    |
    = note: expected struct `hyper::Request<B>`
               found struct `hyper::Request<Body>`

I understand the type mismatch. But according to the implementation, next.run() accepts a generic type?

I tried different type parameters and changing the return type of

let req = Request::from_parts(parts, hyper::Body::from(bytes));

but it didn't work.

I also don't need this exact example to work; I just want to get the responses to my HTTP requests logged.

Edit: minimal reproducible example:

cargo.toml

[package]
name = "test"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
axum = { version = "0.6.18", features = ["http2"] }
hyper = { version = "0.14", features = ["full"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.4", features = ["util", "filter"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

main.rs

use std::net::SocketAddr;
use axum::{
    body::{Body, Bytes},
    http::StatusCode,
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Router,
};
use hyper::Request;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", post(|| async move { "Hello from `POST /`" }))
        .layer(middleware::from_fn(print_request_response));

    let addr = SocketAddr::from(([0, 0, 0, 0], 8080));
    axum::Server::bind(&addr)
        // .http2_only(true)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

/// Middleware that logs the request body, runs the rest of the stack via `next`, and then
/// logs the response body.
///
/// Both bodies are consumed by `buffer_and_print` and rebuilt from the returned bytes. A
/// failure to buffer either body is propagated as the `(StatusCode, String)` error
/// returned by `buffer_and_print`.
///
/// NOTE(review): this is the reproduction of the E0308 error from the question — the
/// request is rebuilt with the concrete `Body` type, while `next` is a `Next<B>` for the
/// generic parameter `B`.
async fn print_request_response<B>(
    req: Request<B>,
    next: Next<B>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    // Split off the body so it can be buffered and logged, then reassemble the request.
    let (parts, body) = req.into_parts();
    let bytes = buffer_and_print("request", body).await?;
    let req = Request::from_parts(parts, Body::from(bytes));

    // `req` is now `Request<Body>`, which is what the compiler complains about here.
    let res = next.run(req).await;

    // Same buffer-log-rebuild sequence for the response.
    let (parts, body) = res.into_parts();
    let bytes = buffer_and_print("response", body).await?;
    let res = Response::from_parts(parts, Body::from(bytes));

    Ok(res)
}

/// Reads `body` to completion with `hyper::body::to_bytes`, logs it at debug level when it
/// is valid UTF-8, and returns the buffered bytes.
///
/// `direction` is only used as a label in the log and error messages. A read failure is
/// returned as `(StatusCode::BAD_REQUEST, message)`.
///
/// NOTE(review): `B` carries no trait bounds here; the working version further down adds
/// `B: axum::body::HttpBody<Data = Bytes>` and `B::Error: std::fmt::Display`.
async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, (StatusCode, String)>
{
    let bytes = match hyper::body::to_bytes(body).await {
        Ok(bytes) => bytes,
        Err(err) => {
            return Err((
                StatusCode::BAD_REQUEST,
                format!("failed to read {} body: {}", direction, err),
            ));
        }
    };

    // Bodies that are not valid UTF-8 are passed through without being logged.
    if let Ok(body) = std::str::from_utf8(&bytes) {
        tracing::debug!("{} body = {:?}", direction, body);
    }

    Ok(bytes)
}

Solution

  • The solution that works for me now.

    use axum::{middleware, Router};
    use axum::body::Bytes;
    use axum::http::{Request, Response, StatusCode};
    use axum::middleware::Next;
    use axum::response::IntoResponse;
    use axum::routing::{get, post};
    use hyper::Body;
    use log::info;
    use tower::ServiceExt;
    
    /// Axum middleware that logs request and response bodies.
    ///
    /// Logging is suppressed when the request path ends in one of the listed file
    /// extensions or path suffixes. Both bodies are buffered in full by `buffer_and_print`
    /// and rebuilt from the returned bytes before being passed on.
    ///
    /// # Errors
    /// Propagates the `(StatusCode, String)` error from `buffer_and_print` when either body
    /// cannot be read.
    pub async fn log_request_response(
        req: Request<axum::body::Body>,
        next: Next<axum::body::Body>,
    ) -> Result<impl IntoResponse, (StatusCode, String)> {
        // Don't log these extensions
        const SKIPPED_EXTENSIONS: [&str; 5] = [".js", ".html", ".css", ".png", ".jpeg"];
        // Want to skip logging these paths
        const SKIPPED_PATHS: [&str; 1] = ["/example/path"];

        // Owned copy of the path, because `req` is taken apart below.
        let path = req.uri().path().to_owned();

        let should_log = !SKIPPED_EXTENSIONS
            .iter()
            .chain(SKIPPED_PATHS.iter())
            .any(|suffix| path.ends_with(*suffix));

        // Buffer (and possibly log) the request, then hand an equivalent request onwards.
        let (head, incoming) = req.into_parts();
        let request_bytes = buffer_and_print("request", &path, incoming, should_log).await?;
        let rebuilt = Request::from_parts(head, hyper::Body::from(request_bytes));

        // Same treatment for whatever the inner service answers with.
        let (mut head, outgoing) = next.run(rebuilt).await.into_parts();
        let response_bytes = buffer_and_print("response", &path, outgoing, should_log).await?;

        // When your encoding is chunked there can be problems without removing the header
        head.headers.remove("transfer-encoding");

        Ok(Response::from_parts(head, Body::from(response_bytes)))
    }
    
    /// Consumes `body`, optionally logs it, and returns the buffered bytes.
    ///
    /// * `direction` – label used in log and error messages (e.g. `"request"`).
    /// * `path` – request path included in the log line.
    /// * `body` – the body to read to completion.
    /// * `log` – when `false` nothing is logged; the body is still buffered and returned.
    ///
    /// Only non-empty bodies that are valid UTF-8 are logged. Bodies longer than 2000 bytes
    /// are cut off at the last character boundary at or before byte 2000 and suffixed with
    /// `...`.
    ///
    /// # Errors
    /// Returns `(StatusCode::BAD_REQUEST, message)` when the body cannot be read.
    async fn buffer_and_print<B>(direction: &str, path: &str, body: B, log: bool) -> Result<Bytes, (StatusCode, String)>
        where
            B: axum::body::HttpBody<Data=Bytes>,
            B::Error: std::fmt::Display,
    {
        /// Upper bound, in bytes, on how much of a body is written to the log.
        const MAX_LOGGED_BODY_BYTES: usize = 2000;

        let bytes = hyper::body::to_bytes(body).await.map_err(|err| {
            (
                StatusCode::BAD_REQUEST,
                format!("failed to read {} body: {}", direction, err),
            )
        })?;

        if !log {
            return Ok(bytes);
        }

        if let Ok(body) = std::str::from_utf8(&bytes) {
            if body.len() > MAX_LOGGED_BODY_BYTES {
                // Slicing a `&str` in the middle of a multi-byte character panics, so step
                // back to the nearest character boundary. Index 0 is always a boundary, so
                // the loop terminates.
                let mut end = MAX_LOGGED_BODY_BYTES;
                while !body.is_char_boundary(end) {
                    end -= 1;
                }
                info!("{} for req: {} with body: {}...", direction, path, &body[..end]);
            } else if !body.is_empty() {
                info!("{} for req: {} with body: {}", direction, path, body);
            }
        }

        Ok(bytes)
    }
    
    #[tokio::test]
    async fn test_log_request_response() {
        // create a request to be passed to the middleware
        let req = Request::new(Body::from("Hello, Axum!"));
    
        // create a simple router to test the middleware
        let app = Router::new()
            .route("/", get(|| async { "Hello, World!" }))
            .layer(middleware::from_fn(log_request_response));
    
        // send the request through the middleware
        let res = app.clone().oneshot(req).await.unwrap();
    
        // make sure the response has a status code of 200
        assert_eq!(res.status(), StatusCode::OK);
    }