rust, middleware, rust-axum

How to pass an optional parameter to middleware created with from_fn in axum?


In my axum backend I want to control what my auth middleware adds to the request: either the user_id or the full user model. How can I pass an optional full_user parameter to the middleware when attaching it to a router? This is how I currently use the middleware:

    .route("/", post(some_handlers::some_handler::post_smth))
    .route_layer(middleware::from_fn_with_state(
        client.clone(),
        auth_middleware::auth,
    ));

Here is the auth middleware:

use axum::{
    extract::State,
    http::{self, Request, StatusCode},
    middleware::Next,
    response::Response,
};
use mongodb::{
    bson::{doc, oid::ObjectId},
    Client, Collection,
};
// `User` and `verify_token` are defined elsewhere in the project.

pub async fn auth<B>(
    State(client): State<Client>,
    mut req: Request<B>,
    next: Next<B>,
) -> Result<Response, StatusCode> {
    // A missing or non-ASCII Authorization header is treated as "no token".
    let auth_header = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    let jwt_secret =
        std::env::var("JWT_SECRET").map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    let token = auth_header.ok_or(StatusCode::UNAUTHORIZED)?;

    let token_claims = verify_token(token, &jwt_secret).map_err(|_| StatusCode::UNAUTHORIZED)?;

    let user_id = ObjectId::parse_str(&token_claims.sub).map_err(|_| StatusCode::UNAUTHORIZED)?;

    // Reject tokens for users that no longer exist; the loaded user is not used yet.
    let collection: Collection<User> = client.database("Merume").collection("users");
    let _user = match collection.find_one(doc! {"_id": user_id}, None).await {
        Ok(Some(user)) => user,
        Ok(None) => return Err(StatusCode::UNAUTHORIZED),
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };

    // Handlers currently only see the id (via `Extension<ObjectId>`).
    req.extensions_mut().insert(user_id);
    Ok(next.run(req).await)
}

I tried something like this, but it doesn't compile because the arguments passed to auth no longer match its parameters:

.layer(middleware::from_fn_with_state(
    client.clone(),
    |req, next| auth(req, next, Some(true)),
));

Solution

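  • As @cdhowie said in the comments, I had left out the state parameter. The closure has to accept all three arguments that from_fn_with_state passes to the middleware (state, request, next) and forward them to auth, which now takes a fourth full_user: Option<bool> parameter:

    .route_layer(middleware::from_fn_with_state(
        client.clone(),
        |state, req, next| auth_middleware::auth(state, req, next, Some(false)),
    ));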
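To illustrate why the closure works without pulling in axum and MongoDB, here is a minimal std-only model. It is a sketch, not axum's API: from_fn_with_state below is a hypothetical stand-in that, like the real layer, always calls the middleware with exactly (state, request, next), and Client, User, Request and Next are simplified placeholders with a faked database lookup. The full_user flag reaches auth only because it is written into the closure body when the layer is built.

```rust
// Std-only model of the closure trick. `Client`, `User`, `Request`, `Next` and
// `from_fn_with_state` are hypothetical stand-ins, NOT the axum/mongodb types.

#[derive(Clone)]
struct Client; // plays the role of the shared state (mongodb::Client)

#[derive(Debug, Clone, PartialEq)]
struct User {
    id: u32,
    name: String,
}

// Simplified request: two slots instead of a real extensions map.
#[derive(Default)]
struct Request {
    user_id: Option<u32>,
    user: Option<User>,
}

// Plays the role of `Next`: the "inner handler" just returns the request
// so the caller can inspect what the middleware inserted.
struct Next;

impl Next {
    fn run(self, req: Request) -> Request {
        req
    }
}

// Middleware with the extra, non-extractor flag as a fourth parameter.
fn auth(_client: Client, mut req: Request, next: Next, full_user: Option<bool>) -> Request {
    // Pretend this user was loaded from the database.
    let user = User { id: 42, name: "alice".to_string() };
    if full_user.unwrap_or(false) {
        req.user = Some(user); // insert the whole model
    } else {
        req.user_id = Some(user.id); // insert only the id
    }
    next.run(req)
}

// Model of the calling convention: `f` is always called with exactly
// (state, request, next), so any extra argument must come from the closure.
fn from_fn_with_state<F>(state: Client, f: F) -> impl Fn(Request) -> Request
where
    F: Fn(Client, Request, Next) -> Request,
{
    move |req| f(state.clone(), req, Next)
}

fn main() {
    // Same middleware, two configurations, fixed when each layer is built.
    let ids_only =
        from_fn_with_state(Client, |state, req, next| auth(state, req, next, Some(false)));
    let full = from_fn_with_state(Client, |state, req, next| auth(state, req, next, Some(true)));

    let a = ids_only(Request::default());
    // prints "ids_only: user_id=Some(42), has_user=false"
    println!("ids_only: user_id={:?}, has_user={}", a.user_id, a.user.is_some());

    let b = full(Request::default());
    if let Some(u) = &b.user {
        // prints "full: user 42 (alice)"
        println!("full: user {} ({})", u.id, u.name);
    }
}
```

In the real middleware the same idea means adding full_user: Option<bool> as the last parameter of auth and inserting either the loaded user or user_id into req.extensions_mut() depending on the flag. Because the flag is always supplied by the closure, a plain bool would work just as well as Option<bool>.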