Rust axum middleware always returning an error when applying to routes


I am probably stupid, but I have been staring at this issue for days now and cannot figure it out. The middleware functionality in axum has changed so much recently that 90% of what is on the internet is outdated. The docs have not helped me either, probably because I am new to Rust and am missing something. I am using axum 0.7.5. Anyway, here is the code:

main.rs:

#[tokio::main]
async fn main() {
    // Initialize environment variables from .env file
    dotenv::dotenv().ok();

    if let Err(e) = run().await {
        eprintln!("Application error: {:?}", e); 
    }
}

lib.rs:

pub async fn run() -> Result<(), Box<dyn Error>> {
    // Establish the database connection pool
    let db_pool = Arc::new(db::connection::establish_connection().await?);

    // Create the application router and add the connection pool as an extension
    let app = create_routes(db_pool.clone());

    // Use the ? operator to propagate errors instead of unwrapping
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await?;

    Ok(())
}

state.rs:

use std::sync::Arc;

#[derive(Clone)]
pub struct AppState {
    pub access_token_secret: Arc<String>,
    pub refresh_token_secret: Arc<String>,
}

routes/mod.rs:

pub mod analytics;
pub mod auth;
pub mod legal;
pub mod misc;
pub mod users;

use std::{env, sync::Arc};

use axum::middleware::from_fn;
use axum::{Extension, Router};
use axum_client_ip::SecureClientIpSource;
use sqlx::MySqlPool;

use crate::{
    middleware::{auth_middleware::auth_middleware, input_validation::input_validation},
    state::AppState,
};

pub fn create_routes(pool: Arc<MySqlPool>) -> Router {
    // Load the token secrets from environment variables
    let access_token_secret =
        env::var("ACCESS_TOKEN_SECRET").expect("ACCESS_TOKEN_SECRET must be set");
    let refresh_token_secret =
        env::var("REFRESH_TOKEN_SECRET").expect("REFRESH_TOKEN_SECRET must be set");

    let app_state = AppState {
        access_token_secret: Arc::new(access_token_secret),
        refresh_token_secret: Arc::new(refresh_token_secret),
    };

    Router::new()
        .nest("/v1/users", users::routes())
        .nest("/v1/auth", auth::routes(pool.clone()))
        .nest("/v1/legal", legal::routes(pool.clone()))
        .nest("/v1", misc::routes())
        .layer(from_fn(input_validation))
        .layer(from_fn(auth_middleware))
        .layer(Extension(app_state.clone()))
        .layer(SecureClientIpSource::ConnectInfo.into_extension())
}

middleware/input_validation.rs (this one works):

use axum::{
    body::to_bytes, body::Body, http::Request, http::StatusCode, middleware::Next,
    response::Response,
};
use hyper::header::{HeaderValue, CONTENT_TYPE};
use serde_json::Value;
use std::collections::HashMap;
use url::form_urlencoded;

use crate::validation::validate_fields::validate_fields;

const MAX_BODY_SIZE: usize = 1024 * 1024; // 1 MB

pub async fn input_validation(req: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    let mut fields = HashMap::new();

    // Extract query parameters
    if let Some(query) = req.uri().query() {
        for (key, value) in form_urlencoded::parse(query.as_bytes()) {
            fields.insert(key.into_owned(), value.into_owned());
        }
    }

    // Extract content type before moving the request
    let content_type = req.headers().get(CONTENT_TYPE).cloned();

    // Split the request into parts
    let (parts, body) = req.into_parts();

    // Extract body if it's JSON
    if content_type == Some(HeaderValue::from_static("application/json")) {
        // Attempt to read the whole body with a size limit
        let whole_body = match to_bytes(body, MAX_BODY_SIZE).await {
            Ok(bytes) => bytes,
            Err(_) => return Err(StatusCode::BAD_REQUEST),
        };

        // Check if the body exceeds the maximum size
        if whole_body.len() > MAX_BODY_SIZE {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }

        // Parse the JSON body
        let json_body: Value = match serde_json::from_slice(&whole_body) {
            Ok(json) => json,
            Err(_) => return Err(StatusCode::BAD_REQUEST),
        };

        // Extract fields from the JSON object
        if let Some(object) = json_body.as_object() {
            for (key, value) in object.iter() {
                fields.insert(key.clone(), value.to_string());
            }
        }

        // Validate all extracted fields and clean them
        match validate_fields(&fields) {
            Ok(cleaned_fields) => {
                // Optionally reconstruct JSON body with cleaned values
                let new_body = match serde_json::to_string(&cleaned_fields) {
                    Ok(json_str) => json_str,
                    Err(_) => return Err(StatusCode::BAD_REQUEST),
                };
                // Reconstruct the request with the cleaned body
                let new_request = Request::from_parts(parts, Body::from(new_body));

                return Ok(next.run(new_request).await);
            }
            Err(ref err) if err == "Invalid credentials." => Err(StatusCode::UNAUTHORIZED),
            Err(_) => Err(StatusCode::BAD_REQUEST),
        }
    } else {
        // Validate all extracted fields for non-JSON bodies
        match validate_fields(&fields) {
            Ok(_) => {
                // Reconstruct the request with the original body
                let original_request = Request::from_parts(parts, body);

                return Ok(next.run(original_request).await);
            }
            Err(ref err) if err == "Invalid credentials." => Err(StatusCode::UNAUTHORIZED),
            Err(_) => Err(StatusCode::BAD_REQUEST),
        }
    }
}

middleware/auth_middleware.rs (this one does not work):

use crate::{services::token_services::decode_token::decode_token, state::AppState};
use axum::{
    body::Body,
    http::{Request, StatusCode},
    middleware::Next,
    response::Response,
};

pub async fn auth_middleware(mut req: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    // Get the state from request extensions
    let state = req
        .extensions()
        .get::<AppState>()
        .cloned()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    // Get the request path
    let path = req.uri().path();

    // Define public paths that bypass authentication
    let public_paths = [
        "/v1/auth/login",
        "/v1/auth/register",
        // Add more public routes if necessary
    ];

    // If the path is public, skip authentication
    if public_paths.contains(&path) {
        return Ok(next.run(req).await);
    }

    // Extract the "x-access-token" header
    let token = req
        .headers()
        .get("x-access-token")
        .and_then(|t| t.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?;

    // Decode and validate the token using your existing decode_token function
    match decode_token(
        token,
        true, // Indicates that we're validating an access token
        &state.access_token_secret,
        &state.refresh_token_secret,
    ) {
        Ok(token_content) => {
            // Insert the TokenContent into request extensions for access in handlers
            req.extensions_mut().insert(token_content);
            Ok(next.run(req).await)
        }
        Err(_) => Err(StatusCode::UNAUTHORIZED),
    }
}

rust-analyzer only complains about routes/mod.rs. I am getting this compile error:

error[E0277]: the trait bound `axum::middleware::FromFn<fn(hyper::Request<axum::body::Body>, Next) -> impl Future<Output = Result<Response<axum::body::Body>, StatusCode>> {auth_middleware}, (), Route, _>: tower_service::Service<hyper::Request<axum::body::Body>>` is not satisfied
   --> src/routes/mod.rs:40:16
    |
40  |         .layer(from_fn(auth_middleware))
    |          ----- ^^^^^^^^^^^^^^^^^^^^^^^^ the trait `tower_service::Service<hyper::Request<axum::body::Body>>` is not implemented for `axum::middleware::FromFn<fn(hyper::Request<axum::body::Body>, Next) -> impl Future<Output = Result<Response<axum::body::Body>, StatusCode>> {auth_middleware}, (), Route, _>`
    |          |
    |          required by a bound introduced by this call
    |
    = help: the following other types implement trait `tower_service::Service<Request>`:
              axum::middleware::FromFn<F, S, I, (T1, T2)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8)>
              axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8, T9)>
            and 8 others
note: required by a bound in `Router::<S>::layer`
   --> /home/admin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/axum-0.7.7/src/routing/mod.rs:279:21
    |
276 |     pub fn layer<L>(self, layer: L) -> Router<S>
    |            ----- required by a bound in this associated function
...
279 |         L::Service: Service<Request> + Clone + Send + 'static,
    |                     ^^^^^^^^^^^^^^^^ required by this bound in `Router::<S>::layer`

For more information about this error, try `rustc --explain E0277`.

Any and all suggestions are welcome. I have spent days beating my head against this issue. I have tried good old Google, reading the docs, and a lot of AI bots: ChatGPT 4, 4o, o1-mini, o1-preview, Claude, Perplexity, etc. Nothing has helped. I have just moved from one issue to the next, and I believe my current code is close to working but is missing something.

edit:

services/token_services/decode_token.rs:

use chrono::Utc;
use jsonwebtoken::{decode, DecodingKey, TokenData, Validation};

use crate::models::{token::Token, token_content::TokenContent};

pub fn decode_token(
    token: &str,
    is_access_token: bool,
    access_secret: &str,
    refresh_secret: &str,
) -> Result<TokenContent, Box<dyn std::error::Error>> {
    // Determine which secret to use based on the token type
    let secret = if is_access_token {
        access_secret
    } else {
        refresh_secret
    };

    // Decode the token
    let token_data: TokenData<Token> = decode(
        token,
        &DecodingKey::from_secret(secret.as_ref()),
        &Validation::default(),
    )?;

    // Check if the token is expired
    let current_timestamp = Utc::now().timestamp() as usize;
    if token_data.claims.exp < current_timestamp {
        return Err("Token is expired".into());
    }

    // Construct the User struct from the decoded token data
    let user = TokenContent {
        user_id: token_data.claims.user_id,
        email: token_data.claims.email,
    };

    Ok(user)
}

Adding #[axum::debug_middleware] to auth_middleware gave me some new errors with new clues when compiling:

error: future cannot be sent between threads safely
  --> src/middleware/auth_middleware.rs:9:1
   |
9  | #[axum::debug_middleware]
   | ^^^^^^^^^^^^^^^^^^^^^^^^^ future returned by `auth_middleware` is not `Send`
   |
   = help: the trait `Send` is not implemented for `dyn StdError`, which is required by `impl Future<Output = Result<Response<axum::body::Body>, StatusCode>>: Send`
note: future is not `Send` as this value is used across an await
  --> src/middleware/auth_middleware.rs:50:30
   |
41 |       match decode_token(
   |  ___________-
42 | |         token,
43 | |         true, // Indicates that we're validating an access token
44 | |         &state.access_token_secret,
45 | |         &state.refresh_token_secret,
46 | |     ) {
   | |_____- has type `Result<TokenContent, Box<dyn StdError>>` which is not `Send`
...
50 |               Ok(next.run(req).await)
   |                                ^^^^^ await occurs here, with the value maybe used later
note: required by a bound in `__axum_macros_check_auth_middleware_future::check`
  --> src/middleware/auth_middleware.rs:9:1
   |
9  | #[axum::debug_middleware]
   | ^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `check`
   = note: this error originates in the attribute macro `axum::debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info)

Edit 2: Changing the return type of the decode_token function to Result<TokenContent, Box<dyn std::error::Error + Send + Sync>> did the trick.
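
Why the extra bound matters can be sketched without axum. The following is a minimal, std-only illustration and not the original code: `decode`, `next_run`, `middleware`, and `assert_send` are hypothetical stand-ins for `decode_token`, `next.run(req)`, `auth_middleware`, and the `Send` requirement axum places on the middleware future. The `match` scrutinee stays alive until the end of the `match`, so it is held across the `.await`. With `Box<dyn Error + Send + Sync>` that value is `Send`, and the check compiles.

```rust
use std::error::Error;
use std::future::Future;

// Hypothetical stand-in for `decode_token`, using the error type from Edit 2.
fn decode(token: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
    if token.is_empty() {
        return Err("empty token".into());
    }
    Ok(token.to_uppercase())
}

// Hypothetical stand-in for `next.run(req)`; returns a status code.
async fn next_run(_user: String) -> u16 {
    200
}

// Same shape as the middleware in the question: the `match` scrutinee
// (the whole `Result`) is still alive when the `.await` runs.
async fn middleware(token: &'static str) -> u16 {
    match decode(token) {
        Ok(user) => next_run(user).await,
        Err(_) => 401,
    }
}

// Compile-time check only: accepts a future only if it is `Send`.
fn assert_send<F: Future + Send>(fut: F) -> F {
    fut
}

fn main() {
    // Compiles because every value held across the `.await` is `Send`.
    let _fut = assert_send(middleware("abc"));
}
```

Removing `+ Send + Sync` from the return type of `decode` should bring back the "future cannot be sent between threads safely" error, this time at the `assert_send` call.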


Solution

  • As all visible aspects of the middleware function signatures are the same, I suggest considering invisible auto traits. For example, my guess is that:

    1. token_content is !Send (perhaps it uses Rc instead of Arc)
    2. Rust's approximate static analysis considers it to be held over the subsequent .await point
    3. The futures returned by the function passed to from_fn are therefore !Send
    4. FromFn<...> therefore no longer implements tower's Service trait, which Router::layer requires

    Here's a reproducible example of this issue:

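    use std::rc::Rc;
    // Rc is !Send, so any type containing it is !Send as well.
    pub struct NotSend(Rc<()>);
    
    async fn run(_req: Option<NotSend>) {}
    
    pub async fn broken(mut req: Option<NotSend>) {
        let not_send = NotSend(Rc::default());
        req = Some(not_send);
        // The future awaited here owns a !Send value, so `broken`'s future is !Send too.
        run(req).await;
    }
    
    pub fn main() {
        fn is_send_and_static<T: Send + 'static>(_t: T) {}
        // Expected to fail to compile: the future returned by `broken` is not Send.
        is_send_and_static(broken(None));
    }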
    

    You can falsify this explanation by showing that token_content is Send.

    fn is_send<T: Send>(_t: &T) {}
    is_send(&token_content);
    

    If token_content turns out to be Send, I still think it's worth looking into other reasons why the middleware future could be !Send.
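
In the question, the compiler note points at the `match decode_token(...)` scrutinee, a `Result<TokenContent, Box<dyn StdError>>`, as the value held across the `.await`. Besides adding `Send + Sync` to the error type, another option is to convert the error before awaiting, so the non-`Send` value is dropped at the end of its own statement. Below is a minimal std-only sketch of that idea. `decode`, `next_run`, `middleware`, and `assert_send` are hypothetical stand-ins rather than axum APIs, and plain `u16` values stand in for `StatusCode`.

```rust
use std::error::Error;
use std::future::Future;

// Hypothetical stand-in for `decode_token`, keeping the original
// non-Send error type `Box<dyn Error>`.
fn decode(token: &str) -> Result<String, Box<dyn Error>> {
    if token.is_empty() {
        return Err("empty token".into());
    }
    Ok(token.to_uppercase())
}

// Hypothetical stand-in for `next.run(req)`; returns a status code.
async fn next_run(_user: String) -> u16 {
    200
}

async fn middleware(token: &'static str) -> Result<u16, u16> {
    // The boxed error is mapped to a plain status code in this statement,
    // so no non-Send value is left alive when the `.await` below runs.
    let user = decode(token).map_err(|_| 401u16)?;
    Ok(next_run(user).await)
}

// Compile-time check only: accepts a future only if it is `Send`.
fn assert_send<F: Future + Send>(fut: F) -> F {
    fut
}

fn main() {
    let _fut = assert_send(middleware("abc"));
}
```

The trade-off is that the original error is discarded at the conversion point. If it needs to be logged, that has to happen inside the `map_err` closure, before the `.await`.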