
Is there a way to do validation as part of a filter in Warp?


I have a route and an endpoint function defined. I've also injected some dependencies.

pub fn route1() -> BoxedFilter<(String, ParamType)> {
    warp::get()
        .and(warp::path::param())
        .and(warp::filters::query::query())
        .and(warp::path::end())
        .boxed()
}

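pub async fn handler1(
    query: String,
    param: ParamType,
    dependency: DependencyType,
) -> Result<impl warp::Reply, warp::Rejection> {
    // ... handler body ...
    Ok(warp::reply::reply())
}

let api = api::routes::route1()
    // `map` needs a reusable (`Fn`) closure, so hand out a clone per request
    .and(warp::any().map(move || dependency.clone()))
    .and_then(api::handlers::handler1);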

This all seems to work fine.

However, I want to be able to have something that sits in front of several endpoints that checks for a valid key in the query parameter. Inside handler1 I can add:

if !param.key_valid {
    return Ok(warp::reply::with_status(
        warp::reply::json(&""),
        StatusCode::BAD_REQUEST,
    ));
}

I do not want to add this to every handler individually.

It seems like I should be able to do this with a filter, but I can't figure out how. I've tried using .map(), but returning multiple values from it wraps them in a tuple, so I have to change the downstream function signature. Ideally, I want a way to add verification (or other filters) that can reject the request without any downstream values knowing about them.


Solution

    This is effectively demonstrated by warp's rejection example:

    Rejections represent cases where a filter should not continue processing the request, but a different filter could process it.

    Extract a denominator from a "div-by" header, or reject with DivideByZero.

    You need to:

    1. Use Filter::and_then to take the existing filter (in this case query()) and perform the validation. If the validation fails, return a custom rejection.
    2. Use Filter::recover to appropriately handle the custom rejection and any other possible errors.

    Applied to your situation:

    use serde::Deserialize;
    use std::{convert::Infallible, net::IpAddr};
    use warp::{filters::BoxedFilter, http::StatusCode, reject::Reject, Filter, Rejection, Reply};
    
    fn route1() -> BoxedFilter<(String, ParamType)> {
        warp::get()
            .and(warp::path::param())
            .and(validated_query())
            .and(warp::path::end())
            .boxed()
    }
    
    #[derive(Debug)]
    struct Invalid;
    impl Reject for Invalid {}
    
    fn validated_query() -> impl Filter<Extract = (ParamType,), Error = Rejection> + Copy {
        warp::filters::query::query().and_then(|param: ParamType| async move {
            if param.valid {
                Ok(param)
            } else {
                Err(warp::reject::custom(Invalid))
            }
        })
    }
    
    async fn report_invalid(r: Rejection) -> Result<impl Reply, Infallible> {
        let reply = warp::reply::reply();
    
        if let Some(Invalid) = r.find() {
            Ok(warp::reply::with_status(reply, StatusCode::BAD_REQUEST))
        } else {
            // Do better error handling here
            Ok(warp::reply::with_status(
                reply,
                StatusCode::INTERNAL_SERVER_ERROR,
            ))
        }
    }
    
    async fn handler1(
        _query: String,
        _param: ParamType,
        _dependency: DependencyType,
    ) -> Result<impl warp::Reply, warp::Rejection> {
        Ok(warp::reply::reply())
    }
    
    struct DependencyType;
    
    #[derive(Deserialize)]
    struct ParamType {
        valid: bool,
    }
    
    #[tokio::main]
    async fn main() {
        let api = route1()
            .and(warp::any().map(move || DependencyType))
            .and_then(handler1)
            .recover(report_invalid);
    
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        let port = 8888;
        warp::serve(api).run((ip, port)).await;
    }
    

    And the output of curl with irrelevant lines removed:

    % curl -v '127.0.0.1:8888/dummy/?valid=false'
    < HTTP/1.1 400 Bad Request
    
    % curl -v '127.0.0.1:8888/dummy/?valid=true'
    < HTTP/1.1 200 OK
    

    Cargo.toml

    [dependencies]
    warp = "0.2.2"
    serde = { version = "1.0.104", features = ["derive"] }
    tokio = { version = "0.2.13", features = ["full"] }