I'm trying to implement JWT authentication with a custom Axum extractor. Inside the extractor I can print the router State with `{:?}`, but I can't access its fields (for example, `secret`), so for now I have hard-coded the secret in the extractor instead. Below is the JWT validation code I've written. How can I modify it so that the extractor reads the secret from the State?
//main.rs
use axum::body::Bytes;
use axum::extract::{Json, Request, State};
use axum::{
routing::{get, post},
Router,
};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tower_http::trace::TraceLayer;
use tracing::Span;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
mod authentication_token;
use authentication_token::ExtractAuthorization;
/// Shared application state handed to every handler via `Router::with_state`.
#[derive(Clone)]
struct AppState {
    /// HMAC key used to sign tokens in `register` and to verify them in `login`.
    secret: String,
}

// Hand-written instead of derived so that formatting the state with `{:?}`
// (e.g. for debugging inside an extractor) never prints the signing secret.
impl std::fmt::Debug for AppState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AppState")
            .field("secret", &"<redacted>")
            .finish()
    }
}
/// Credentials posted as JSON to `/register`.
#[derive(Deserialize, PartialEq)]
struct User {
    account: usize,
    password: String,
}

// Hand-written instead of derived so the plaintext password can never end up
// in logs through `{:?}` formatting.
impl std::fmt::Debug for User {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("account", &self.account)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// JWT payload issued by `register` and echoed back by `login`.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Claims {
    /// Account number of the user the token was issued to.
    pub id: usize,
    /// Expiry time, in seconds since the Unix epoch.
    pub exp: usize,
}
/// Issues a signed JWT for the single hard-coded demo user.
///
/// The posted credentials are compared with a fixed record (`account: 195`,
/// `password: "world"`). On a match, the response body is an HS256 token whose
/// `exp` claim lies 30 minutes in the future, signed with `AppState::secret`.
/// On a mismatch the body is the literal string `"hello, world!"`.
///
/// # Panics
///
/// Only if the system clock is set before the Unix epoch.
async fn register(State(state): State<AppState>, Json(user): Json<User>) -> String {
    /// How long an issued token stays valid.
    const TOKEN_TTL: Duration = Duration::from_secs(30 * 60);

    let store_user = User {
        account: 195,
        password: "world".to_string(),
    };
    if user != store_user {
        return "hello, world!".to_string();
    }

    let exp_timestamp = (SystemTime::now() + TOKEN_TTL)
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs();
    let claims = Claims {
        id: user.account,
        // u64 -> usize: lossless on 64-bit targets, which is what `Claims` assumes.
        exp: exp_timestamp as usize,
    };
    // `Header::default()` selects HS256, which matches `EncodingKey::from_secret`,
    // and the claims are two plain integers, so signing cannot fail here.
    encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(state.secret.as_bytes()),
    )
    .expect("HS256 signing of integer-only claims with a secret key cannot fail")
}
async fn login(State(secret): State<AppState>, req: Request) -> Json<Claims> {
let token = req
.headers()
.get("Authorization")
.unwrap()
.to_str()
.unwrap();
let payload = decode::<Claims>(
token,
&DecodingKey::from_secret(secret.secret.as_bytes()),
&Validation::new(Algorithm::HS256),
)
.unwrap();
Json(payload.claims)
}
async fn protected(_auth_token: ExtractAuthorization, req: Request) -> String {
println!("{:?}", req);
"World!".to_string()
}
#[tokio::main]
async fn main() {
let state = AppState {
secret: "baby195lxl".to_string(),
};
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new("debug"))
.with(tracing_subscriber::fmt::layer())
.init();
let app = Router::new()
.route("/register", post(register))
.route("/login", post(login))
.route("/protected", get(protected))
.with_state(state)
.layer(TraceLayer::new_for_http().on_body_chunk(
|chunk: &Bytes, latency: Duration, _span: &Span| {
tracing::debug!("streaming {} bytes in {:?}", chunk.len(), latency);
},
));
let listener = tokio::net::TcpListener::bind("127.0.0.1:5000")
.await
.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
//authentication_token.rs
use axum::{
async_trait,
extract::FromRequestParts,
http::{header::AUTHORIZATION, request::Parts, StatusCode},
};
use jsonwebtoken::{
decode, errors::Error as JwtError, Algorithm, DecodingKey, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
/// Extractor produced when the request's `Authorization` header held a JWT
/// that verified successfully; carries the token's `id` claim.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExtractAuthorization {
    id: usize,
}

impl ExtractAuthorization {
    /// The `id` claim of the verified token, so handlers can identify the caller.
    pub fn id(&self) -> usize {
        self.id
    }
}
/// Claims decoded from an incoming token.
///
/// Field-for-field the same as the `Claims` struct that `register` in main.rs
/// encodes; the two must stay compatible for tokens to decode.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct Claims {
    /// Account number of the token holder.
    pub id: usize,
    /// Expiry time, in seconds since the Unix epoch.
    pub exp: usize,
}
#[async_trait]
impl<S> FromRequestParts<S> for ExtractAuthorization
where
S: Send + Sync + Debug,
{
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
println!("{:?}", state);
let secret = "baby195lxl";
let auth_header = parts.headers.get(AUTHORIZATION);
if auth_header.is_none() {
return Err((StatusCode::BAD_REQUEST, "Authorization is Missing"));
}
let auth_token: String = auth_header.unwrap().to_str().unwrap_or("").to_string();
if auth_token.is_empty() {
return Err((
StatusCode::BAD_REQUEST,
"Authentication token has foreign chars!",
));
}
let token_result: Result<TokenData<Claims>, JwtError> = decode::<Claims>(
&auth_token,
&DecodingKey::from_secret(secret.as_bytes()),
&Validation::new(Algorithm::HS256),
);
match token_result {
Ok(token) => Ok(ExtractAuthorization {
id: token.claims.id,
}),
Err(_) => Err((StatusCode::BAD_REQUEST, "Token Error")),
}
}
}
Every change I have tried so far produces compile errors. My Cargo.toml is shown below, and I am using rustc 1.76.0 (07dca489a 2024-02-04). I would appreciate any help clearing up this confusion; all replies are welcome. Thanks.
[dependencies]
axum = "^0.7"
tokio = { version = "^1.36", features = ["full"] }
tower-http = { version = "^0.5", features = ["trace"] }
tracing = "^0.1"
tracing-subscriber = { version = "^0.3", features = ["env-filter"] }
serde = { version = "1.0", features = ["derive"] }
jsonwebtoken = "9.2.0"
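Instead of implementing the extractor generically for all state types, implement it for the specific state you have. With a generic `S`, the compiler only knows the bounds you wrote (`Send + Sync + Debug`), so it lets you format the state but not read fields like `secret`. With the concrete `AppState` (brought into the module with `use crate::AppState;`), those fields are available. If you would rather keep the extractor generic, bound it with `AppState: FromRef<S>` and call `AppState::from_ref(state)` instead.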
#[async_trait]
impl FromRequestParts<AppState> for ExtractAuthorization {
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> {
// ...
}
}