I am probably stupid, but I have been staring at this issue for days now and cannot figure it out. The middleware functionality in axum has changed so much recently that 90% of what is on the internet is outdated. The docs haven't helped me either, probably because I am new to Rust and am missing something. I am using axum 0.7.5. Anyway, here is the code:
main.rs:
#[tokio::main]
async fn main() {
    // Initialize environment variables from .env file
    dotenv::dotenv().ok();
    if let Err(e) = run().await {
        eprintln!("Application error: {:?}", e);
    }
}
lib.rs:
use std::{error::Error, net::SocketAddr, sync::Arc};
pub async fn run() -> Result<(), Box<dyn Error>> {
    // Establish the database connection pool
    let db_pool = Arc::new(db::connection::establish_connection().await?);
    // Create the application router, passing it the connection pool
    let app = create_routes(db_pool.clone());
    // Use the ? operator to propagate errors instead of unwrapping
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await?;
    Ok(())
}
state.rs:
use std::sync::Arc;
#[derive(Clone)]
pub struct AppState {
    pub access_token_secret: Arc<String>,
    pub refresh_token_secret: Arc<String>,
}
routes/mod.rs:
pub mod analytics;
pub mod auth;
pub mod legal;
pub mod misc;
pub mod users;
use std::{env, sync::Arc};
use axum::middleware::from_fn;
use axum::{Extension, Router};
use axum_client_ip::SecureClientIpSource;
use sqlx::MySqlPool;
use crate::{
    middleware::{auth_middleware::auth_middleware, input_validation::input_validation},
    state::AppState,
};
pub fn create_routes(pool: Arc<MySqlPool>) -> Router {
    // Load the token secrets from environment variables
    let access_token_secret =
        env::var("ACCESS_TOKEN_SECRET").expect("ACCESS_TOKEN_SECRET must be set");
    let refresh_token_secret =
        env::var("REFRESH_TOKEN_SECRET").expect("REFRESH_TOKEN_SECRET must be set");
    let app_state = AppState {
        access_token_secret: Arc::new(access_token_secret),
        refresh_token_secret: Arc::new(refresh_token_secret),
    };
    Router::new()
        .nest("/v1/users", users::routes())
        .nest("/v1/auth", auth::routes(pool.clone()))
        .nest("/v1/legal", legal::routes(pool.clone()))
        .nest("/v1", misc::routes())
        .layer(from_fn(input_validation))
        .layer(from_fn(auth_middleware))
        .layer(Extension(app_state.clone()))
        .layer(SecureClientIpSource::ConnectInfo.into_extension())
}
middleware/input_validation.rs (this one works):
use axum::{
    body::to_bytes, body::Body, http::Request, http::StatusCode, middleware::Next,
    response::Response,
};
use hyper::header::{HeaderValue, CONTENT_TYPE};
use serde_json::Value;
use std::collections::HashMap;
use url::form_urlencoded;
use crate::validation::validate_fields::validate_fields;
const MAX_BODY_SIZE: usize = 1024 * 1024; // 1 MB
pub async fn input_validation(req: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    let mut fields = HashMap::new();
    // Extract query parameters
    if let Some(query) = req.uri().query() {
        for (key, value) in form_urlencoded::parse(query.as_bytes()) {
            fields.insert(key.into_owned(), value.into_owned());
        }
    }
    // Extract content type before moving the request
    let content_type = req.headers().get(CONTENT_TYPE).cloned();
    // Split the request into parts
    let (parts, body) = req.into_parts();
    // Extract body if it's JSON
    if content_type == Some(HeaderValue::from_static("application/json")) {
        // Attempt to read the whole body with a size limit
        let whole_body = match to_bytes(body, MAX_BODY_SIZE).await {
            Ok(bytes) => bytes,
            Err(_) => return Err(StatusCode::BAD_REQUEST),
        };
        // Check if the body exceeds the maximum size
        if whole_body.len() > MAX_BODY_SIZE {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
        // Parse the JSON body
        let json_body: Value = match serde_json::from_slice(&whole_body) {
            Ok(json) => json,
            Err(_) => return Err(StatusCode::BAD_REQUEST),
        };
        // Extract fields from the JSON object
        if let Some(object) = json_body.as_object() {
            for (key, value) in object.iter() {
                fields.insert(key.clone(), value.to_string());
            }
        }
        // Validate all extracted fields and clean them
        match validate_fields(&fields) {
            Ok(cleaned_fields) => {
                // Optionally reconstruct JSON body with cleaned values
                let new_body = match serde_json::to_string(&cleaned_fields) {
                    Ok(json_str) => json_str,
                    Err(_) => return Err(StatusCode::BAD_REQUEST),
                };
                // Reconstruct the request with the cleaned body
                let new_request = Request::from_parts(parts, Body::from(new_body));
                return Ok(next.run(new_request).await);
            }
            Err(ref err) if err == "Invalid credentials." => Err(StatusCode::UNAUTHORIZED),
            Err(_) => Err(StatusCode::BAD_REQUEST),
        }
    } else {
        // Validate all extracted fields for non-JSON bodies
        match validate_fields(&fields) {
            Ok(_) => {
                // Reconstruct the request with the original body
                let original_request = Request::from_parts(parts, body);
                return Ok(next.run(original_request).await);
            }
            Err(ref err) if err == "Invalid credentials." => Err(StatusCode::UNAUTHORIZED),
            Err(_) => Err(StatusCode::BAD_REQUEST),
        }
    }
}
middleware/auth_middleware.rs (this one does not work):
use crate::{services::token_services::decode_token::decode_token, state::AppState};
use axum::{
    body::Body,
    http::{Request, StatusCode},
    middleware::Next,
    response::Response,
};
pub async fn auth_middleware(mut req: Request<Body>, next: Next) -> Result<Response, StatusCode> {
    // Get the state from request extensions
    let state = req
        .extensions()
        .get::<AppState>()
        .cloned()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
    // Get the request path
    let path = req.uri().path();
    // Define public paths that bypass authentication
    let public_paths = [
        "/v1/auth/login",
        "/v1/auth/register",
        // Add more public routes if necessary
    ];
    // If the path is public, skip authentication
    if public_paths.contains(&path) {
        return Ok(next.run(req).await);
    }
    // Extract the "x-access-token" header
    let token = req
        .headers()
        .get("x-access-token")
        .and_then(|t| t.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?;
    // Decode and validate the token using your existing decode_token function
    match decode_token(
        token,
        true, // Indicates that we're validating an access token
        &state.access_token_secret,
        &state.refresh_token_secret,
    ) {
        Ok(token_content) => {
            // Insert the TokenContent into request extensions for access in handlers
            req.extensions_mut().insert(token_content);
            Ok(next.run(req).await)
        }
        Err(_) => Err(StatusCode::UNAUTHORIZED),
    }
}
rust-analyzer is only complaining about routes/mod.rs. I am getting this compile error:
error[E0277]: the trait bound `axum::middleware::FromFn<fn(hyper::Request<axum::body::Body>, Next) -> impl Future<Output = Result<Response<axum::body::Body>, StatusCode>> {auth_middleware}, (), Route, _>: tower_service::Service<hyper::Request<axum::body::Body>>` is not satisfied
--> src/routes/mod.rs:40:16
|
40 | .layer(from_fn(auth_middleware))
| ----- ^^^^^^^^^^^^^^^^^^^^^^^^ the trait `tower_service::Service<hyper::Request<axum::body::Body>>` is not implemented for `axum::middleware::FromFn<fn(hyper::Request<axum::body::Body>, Next) -> impl Future<Output = Result<Response<axum::body::Body>, StatusCode>> {auth_middleware}, (), Route, _>`
| |
| required by a bound introduced by this call
|
= help: the following other types implement trait `tower_service::Service<Request>`:
axum::middleware::FromFn<F, S, I, (T1, T2)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8)>
axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8, T9)>
and 8 others
note: required by a bound in `Router::<S>::layer`
--> /home/admin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/axum-0.7.7/src/routing/mod.rs:279:21
|
276 | pub fn layer<L>(self, layer: L) -> Router<S>
| ----- required by a bound in this associated function
...
279 | L::Service: Service<Request> + Clone + Send + 'static,
| ^^^^^^^^^^^^^^^^ required by this bound in `Router::<S>::layer`
For more information about this error, try `rustc --explain E0277`.
Any and all suggestions are welcome. I have spent days beating my head against this issue. I have tried good old Google, reading the docs, and a lot of AI bots: ChatGPT 4, 4o, o1-mini, o1-preview, Claude, Perplexity, etc. Nothing has helped. I have just moved from one issue to the next, and I believe my current code is close to working but is missing something.
Edit:
services/token_services/decode_token.rs:
use chrono::Utc;
use jsonwebtoken::{decode, DecodingKey, TokenData, Validation};
use crate::models::{token::Token, token_content::TokenContent};
pub fn decode_token(
    token: &str,
    is_access_token: bool,
    access_secret: &str,
    refresh_secret: &str,
) -> Result<TokenContent, Box<dyn std::error::Error>> {
    // Determine which secret to use based on the token type
    let secret = if is_access_token {
        access_secret
    } else {
        refresh_secret
    };
    // Decode the token
    let token_data: TokenData<Token> = decode(
        token,
        &DecodingKey::from_secret(secret.as_ref()),
        &Validation::default(),
    )?;
    // Check if the token is expired
    let current_timestamp = Utc::now().timestamp() as usize;
    if token_data.claims.exp < current_timestamp {
        return Err("Token is expired".into());
    }
    // Construct the TokenContent struct from the decoded token data
    let user = TokenContent {
        user_id: token_data.claims.user_id,
        email: token_data.claims.email,
    };
    Ok(user)
}
Adding #[axum::debug_middleware] to auth_middleware gave me some new errors with new clues when compiling:
error: future cannot be sent between threads safely
--> src/middleware/auth_middleware.rs:9:1
|
9 | #[axum::debug_middleware]
| ^^^^^^^^^^^^^^^^^^^^^^^^^ future returned by `auth_middleware` is not `Send`
|
= help: the trait `Send` is not implemented for `dyn StdError`, which is required by `impl Future<Output = Result<Response<axum::body::Body>, StatusCode>>: Send`
note: future is not `Send` as this value is used across an await
--> src/middleware/auth_middleware.rs:50:30
|
41 | match decode_token(
| ___________-
42 | | token,
43 | | true, // Indicates that we're validating an access token
44 | | &state.access_token_secret,
45 | | &state.refresh_token_secret,
46 | | ) {
| |_____- has type `Result<TokenContent, Box<dyn StdError>>` which is not `Send`
...
50 | Ok(next.run(req).await)
| ^^^^^ await occurs here, with the value maybe used later
note: required by a bound in `__axum_macros_check_auth_middleware_future::check`
--> src/middleware/auth_middleware.rs:9:1
|
9 | #[axum::debug_middleware]
| ^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `check`
= note: this error originates in the attribute macro `axum::debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info)
Edit2:
Changing the return type of the decode_token function to:
Result<TokenContent, Box<dyn std::error::Error + Send + Sync>>
did the trick.
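To illustrate why that change helps, here is a minimal std-only sketch that mirrors the shape of auth_middleware without axum or jsonwebtoken. Everything in it is an assumption made for illustration: decode_token is a simplified stand-in that does no JWT work, the TokenContent field types are guessed, and next_run and assert_send are hypothetical helpers. The point is only that a Result whose error type is Box<dyn Error + Send + Sync> can stay alive across an .await without making the future !Send.

```rust
use std::error::Error;
use std::future::Future;

// Hypothetical stand-in for TokenContent; the real field types are not shown in the post.
#[derive(Debug, PartialEq)]
pub struct TokenContent {
    pub user_id: u64,
    pub email: String,
}

// Boxed error that is Send + Sync, matching the changed return type.
type BoxError = Box<dyn Error + Send + Sync>;

// Simplified stand-in for decode_token: only the return type matters here.
pub fn decode_token(token: &str) -> Result<TokenContent, BoxError> {
    if token.is_empty() {
        // `.into()` converts the &str into Box<dyn Error + Send + Sync>.
        return Err("Token is empty".into());
    }
    Ok(TokenContent {
        user_id: 1,
        email: "user@example.com".to_string(),
    })
}

// Hypothetical stand-in for next.run(req).await; returns a status code.
async fn next_run() -> u16 {
    200
}

// Same shape as auth_middleware: the match scrutinee is alive across the .await.
pub async fn middleware(token: String) -> u16 {
    match decode_token(&token) {
        Ok(_content) => next_run().await,
        Err(_) => 401,
    }
}

// Compile-time check: accepts only futures that are Send + 'static.
pub fn assert_send<F: Future + Send + 'static>(fut: F) -> F {
    fut
}

fn main() {
    // Compiles only because the Result held across the .await is Send.
    let _fut = assert_send(middleware("abc".to_string()));
    println!("middleware future is Send");
}
```

Switching BoxError back to Box<dyn Error> in this sketch should make the assert_send call fail to compile, for the same reason the debug_middleware output above gives.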
As all visible aspects of the middleware function signatures are the same, I suggest considering invisible auto traits. For example, my guess is that:
- token_content is !Send (perhaps it uses Rc instead of Arc)
- it is held across an .await point
- the futures produced by the function passed to from_fn are therefore !Send
- FromFn<...> no longer implements Tower Service
Here's a reproducible example of this issue:
use std::rc::Rc;
pub struct NotSend(Rc<()>);
async fn run(_req: Option<NotSend>) {}
pub async fn broken<'a>(mut req: Option<NotSend>) {
    let not_send = NotSend(Rc::default());
    req = Some(not_send);
    run(req).await;
}
pub fn main() {
    fn is_send_and_static<T: Send + 'static>(_t: T) {}
    // Intentionally fails to compile: the future returned by broken holds an Rc, so it is not Send
    is_send_and_static(broken(None));
}
You can falsify this explanation by showing that token_content is Send:
fn is_send<T: Send>(_t: &T) {}
is_send(&token_content);
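The same helper can also be pointed at the whole value that lives across the .await, not just the Ok payload. The sketch below is a self-contained version of that check. TokenContent and its field types are hypothetical stand-ins, and build is an illustrative helper that is not part of the original code.

```rust
use std::error::Error;

// Same helper as in the snippet above: it only compiles when T is Send.
fn is_send<T: Send>(_t: &T) {}

// Hypothetical stand-in for the real TokenContent struct (field types assumed).
pub struct TokenContent {
    pub user_id: u64,
    pub email: String,
}

// Illustrative helper that produces a Result with a Send + Sync boxed error.
pub fn build() -> Result<TokenContent, Box<dyn Error + Send + Sync>> {
    Ok(TokenContent {
        user_id: 1,
        email: String::from("user@example.com"),
    })
}

fn main() {
    let token_content = TokenContent {
        user_id: 1,
        email: String::from("user@example.com"),
    };
    // Checks the payload on its own.
    is_send(&token_content);

    // Checks the full Result, including the error type.
    let result = build();
    is_send(&result);

    // The following does not compile, because Box<dyn Error> is not Send:
    // let not_sendable: Result<TokenContent, Box<dyn Error>> = Err("boom".into());
    // is_send(&not_sendable);

    println!("checked user {}", token_content.user_id);
}
```

Checking the full Result is useful because a payload that is Send can still be wrapped in a Result whose error type is not.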
If token_content turns out to be Send, I still think it's worth looking into other reasons why the middleware future could be !Send.
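When a function you cannot change returns a !Send value, one possible workaround is to make sure that value is gone before the next .await. The sketch below is std-only and built on assumptions: decode is a hypothetical decoder that keeps a plain Box<dyn Error>, next_run stands in for next.run(req).await, and block_on is a tiny busy-polling executor included only so the example can run without tokio.

```rust
use std::error::Error;
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical decoder that keeps a non-Send error type.
fn decode(token: &str) -> Result<u64, Box<dyn Error>> {
    Ok(token.parse::<u64>()?)
}

// Hypothetical stand-in for next.run(req).await; returns a status code.
async fn next_run() -> u16 {
    200
}

pub async fn middleware(token: String) -> u16 {
    // .ok() consumes the Result within this statement, so the Box<dyn Error>
    // is dropped here and is not alive across the .await below.
    let user_id: Option<u64> = decode(&token).ok();
    match user_id {
        Some(_id) => next_run().await,
        None => 401,
    }
}

// Compile-time check: accepts only futures that are Send + 'static.
pub fn assert_send<F: Future + Send + 'static>(fut: F) -> F {
    fut
}

// Waker that does nothing; enough for futures that never return Pending.
struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Minimal executor: polls the future in a loop until it is ready.
pub fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

fn main() {
    let status = block_on(assert_send(middleware("42".to_string())));
    println!("status: {status}");
}
```

The trade-off is that the error detail is discarded. If the message is needed later, converting it to a String before the .await keeps the future Send as well.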