rust, async-await, rust-tower

Implement tower::Layer using async block and ServiceFn


I'm trying to implement a tower Layer using the tower::layer_fn and tower::service_fn helper functions. The following compiles fine:

use std::convert::Infallible;
use tower::Service;
use tower::util::ServiceFn;


tower::ServiceBuilder::new()
    .layer_fn(|mut service: ServiceFn<_>| {
        // Just do nothing except calling the downstream service
        tower::service_fn(move |request| {
            service.call(request)
        })
    })
    .service_fn(|request: String| {
        // Echo service
        async move {
            let response = request;
            Ok::<_, Infallible>(response)
        }
    });

Because my real code has multiple .await points, I would like to avoid implementing Layer (and with it Service::call()) by hand and write the logic in an async block instead. That works for the echo service, as shown above. For the service inside layer_fn, however, it does not compile:

tower::ServiceBuilder::new()
    .layer_fn(|mut service: ServiceFn<_>| {
        tower::service_fn(move |request| {
            // Just do nothing except calling the downstream service
            async move {
                let response: Result<String, Infallible> = service.call(request).await;
                // Do something with response, await'ing thereby
                response
            }
        })
    })
    .service_fn(|request: String| {
        // Echo service
        async move {
            let response = request;
            Ok::<_, Infallible>(response)
        }
    });

I get the following error, and I don't know how to give the compiler the type information it needs:

error[E0698]: type inside `async` block must be known in this context
  --> src/main.rs:32:64
   |
32 |                     let response: Result<String, Infallible> = service.call(request).await;
   |                                                                ^^^^^^^ cannot infer type
   |
note: the type is part of the `async` block because of this `await`
  --> src/main.rs:32:85
   |
32 |                     let response: Result<String, Infallible> = service.call(request).await;
   |                                                                                     ^^^^^^

Solution

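  • Type inference in async contexts is sometimes less powerful than in sync contexts. Here, the type hidden behind ServiceFn<_> is not yet known at the point where the async block awaits service.call(request). Unfortunately, the only solution I can see that avoids boxing is the nightly type_alias_impl_trait feature, which lets you give the closure and its future a name: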

    #![feature(type_alias_impl_trait)]

    use std::future::Future;

    // Named aliases for the otherwise unnameable closure and future types.
    type Fut = impl Future<Output = Result<String, Infallible>>;
    type Callback = impl Fn(String) -> Fut;
    tower::ServiceBuilder::new()
        .layer_fn(|mut service: ServiceFn<Callback>| {
            tower::service_fn(move |request| {
                // Just do nothing except calling the downstream service
                async move {
                    let response: Result<String, Infallible> = service.call(request).await;
                    // Do something with response, await'ing thereby
                    response
                }
            })
        })
        .service_fn::<Callback>(|request: String| {
            // Echo service
            async move {
                let response = request;
                Ok::<_, Infallible>(response)
            }
        });
    

    Alternatively, on stable Rust, box both the callback and the future so that their types can be written out:

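    use std::future::Future;
    use std::pin::Pin;

    // Trait objects make both types nameable, at the cost of one allocation each.
    type Fut = Pin<Box<dyn Future<Output = Result<String, Infallible>>>>;
    type Callback = Box<dyn Fn(String) -> Fut>;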
    tower::ServiceBuilder::new()
        .layer_fn(|mut service: ServiceFn<Callback>| {
            tower::service_fn(move |request| {
                // Just do nothing except calling the downstream service
                async move {
                    let response: Result<String, Infallible> = service.call(request).await;
                    // Do something with response, await'ing thereby
                    response
                }
            })
        })
        .service_fn::<Callback>(Box::new(|request: String| {
            // Echo service
            Box::pin(async move {
                let response = request;
                Ok::<_, Infallible>(response)
            })
        }));
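To see why boxing helps independently of tower, here is a minimal std-only sketch of the same idea. It reuses the Fut and Callback aliases from the boxed variant above. Everything else is an assumption made up for illustration: shout, shouting_layer, demo, and the tiny block_on executor are hypothetical names, plain boxed closures stand in for ServiceFn, and the executor is only suitable for futures that complete without waiting on anything.

```rust
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Same aliases as in the boxed variant: both types can be written out by name.
type Fut = Pin<Box<dyn Future<Output = Result<String, Infallible>>>>;
type Callback = Box<dyn Fn(String) -> Fut>;

// Echo callback: returns the request unchanged.
fn echo() -> Callback {
    Box::new(|request: String| -> Fut {
        Box::pin(async move { Ok::<_, Infallible>(request) })
    })
}

// Hypothetical post-processing step, only here to give the layer a second await point.
async fn shout(text: String) -> String {
    text.to_uppercase()
}

// Hypothetical "layer": wraps a callback and post-processes its response.
fn shouting_layer(inner: Callback) -> Callback {
    Box::new(move |request: String| -> Fut {
        // Start the inner call outside the async block, so the block owns the
        // returned future instead of borrowing `inner`.
        let pending = inner(request);
        Box::pin(async move {
            // The type of `pending` is fully known here, so inference succeeds.
            let response: Result<String, Infallible> = pending.await;
            match response {
                Ok(text) => Ok(shout(text).await),
                Err(e) => Err(e),
            }
        })
    })
}

// Waker that does nothing; enough for futures that never actually suspend.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Toy executor (an assumption for this sketch, not a real runtime): polls in a
// loop until the future is ready.
fn block_on<F: Future>(future: F) -> F::Output {
    let mut future = Box::pin(future);
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(output) = future.as_mut().poll(&mut cx) {
            return output;
        }
    }
}

// Builds the two-level stack and sends one request through it.
fn demo() -> Result<String, Infallible> {
    let stack = shouting_layer(echo());
    block_on(stack("hello".to_string()))
}
```

Because Callback is a concrete, nameable type, the layer can state exactly what it wraps, and the compiler no longer has to infer anything across the .await. The price is one heap allocation for the callback and one per call for the future.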