Search code examples
asynchronousrustlifetimerust-tokiohyper

Lifetime error when implementing AsyncWrite to hyper Sender


I need to covert the hyper::body::Sender to a tokio::io::AsyncWrite and pass it to one of my reusable function. That function is platform agnostic and can be used for any io operation. That's why I am taking an AsyncWrite as a parameter.

First I tried to use the stream-body crate and found that it using old version of tokio. So I decided to implement the AsyncWrite to the Sender. Then I got a lifetime error when storing the future in my struct.

This is my try:- Playground

use hyper::{Request, Body, body::Sender, Response}; // 0.14.26
use futures::{future::BoxFuture, Future}; // 0.3.28
use std::task::Poll;
use pin_project::pin_project; // 1.1.0
use tokio::io::AsyncWrite; // 1.28.2
use bytes::Bytes; // 1.4.0

/// Wrapper around a hyper body `Sender` that is meant to be driven through `AsyncWrite`.
#[pin_project]
pub struct SenderWriter {
    /// Sending half that every write is forwarded to.
    sender: Sender,
    /// The in-flight `send_data` future, or `None` when no send is pending.
    #[pin]
    write_fut: Option<BoxFuture<'static, hyper::Result<()>>>,
    /// Length of the buffer handed to the pending send; reported once that send completes.
    last_len: usize
}

impl SenderWriter {
    pub fn new(sender: Sender) -> SenderWriter {
        SenderWriter { sender, write_fut: None, last_len: 0 }
    }
}

// Attempted `AsyncWrite` adapter: each write is turned into one `send_data` future that is
// stored in the struct and polled until it completes.
impl AsyncWrite for SenderWriter {
    /// Starts a `send_data` future for `buf` if none is stored, polls the stored future, and
    /// reports the recorded buffer length once it resolves successfully.
    ///
    /// While a future is already stored, the `buf` passed to this call is not read; the length
    /// recorded when that future was created is what gets returned.
    /// A failed send is wrapped in a `std::io::Error` of kind `Other`.
    fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, std::io::Error>> {
        let mut this = self.project();
        
        if this.write_fut.is_none() {
            // Remember how many bytes this send covers
            *this.last_len = buf.len();
            // Create the send future from a copy of the caller's bytes.
            // NOTE(review): this is presumably where the reported lifetime error arises — the
            // future returned by `send_data` appears to borrow `this.sender`, so it cannot
            // satisfy the `'static` bound of the stored `BoxFuture`; confirm against the
            // signature of hyper's `Sender::send_data`.
            let fut = this.sender.send_data(Bytes::copy_from_slice(buf));
            *this.write_fut = Some(Box::pin(fut));
        }

        // Copy the length now because it is reset below before the result is returned
        let last_len = this.last_len.clone();

        // The `unwrap` cannot fail: the slot was filled above whenever it was `None`.
        let polled = this.write_fut.as_mut().as_pin_mut().unwrap().poll(cx);

        if polled.is_ready() {
            // Reset the state so the next call starts a new send for the next set of bytes
            *this.last_len = 0;
            *this.write_fut = None;
        }

        // Success yields the recorded length; a send error becomes an `io::Error` of kind `Other`.
        polled.map(move |res|res.map(|_|last_len).map_err(|e|std::io::Error::new(std::io::ErrorKind::Other, e)))
    }

    /// Polls the sender's `poll_ready` and converts its error into a `std::io::Error` of kind `Other`.
    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        let res = this.sender.poll_ready(cx);
        res.map(|r|r.map_err(|e|std::io::Error::new(std::io::ErrorKind::Other, e)))
    }

    /// Delegates to `poll_flush`; no separate shutdown step is performed here.
    fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
        self.poll_flush(cx)
    }
}

/// Stand-in for the reusable, platform-agnostic routine that accepts any `AsyncWrite`.
/// The body is empty in this example and `_writer` is unused.
pub async fn my_reusable_fn<W: AsyncWrite+ Send + Unpin + 'static>(_writer: W) {
    
}

/// Creates a channel-backed body, spawns `my_reusable_fn` on a `SenderWriter` wrapping the
/// sending half, and returns a response carrying the receiving half as its body.
pub async fn download_handler(_req: Request<Body>) -> Response<Body> {
    let (tx, streamed_body) = Body::channel();
    // The spawned task's join handle is dropped, so the task is not awaited here.
    tokio::spawn(my_reusable_fn(SenderWriter::new(tx)));
    Response::builder().body(streamed_body).unwrap()
}

Then I changed the 'static lifetime parameter in the BoxFuture to a generic lifetime parameter. But then the self.project() statement returned a lifetime error.


Solution

  • All Sender::send_data() does is to wait the sender to become ready then call try_send_data(). We can do that manually:

    use std::io::{Error, ErrorKind};
    use std::pin::Pin;
    use std::task::{ready, Context, Poll};
    
    use hyper::body::Sender;
    use tokio::io::AsyncWrite;
    
    /// Wrapper that exposes a hyper body `Sender` through `AsyncWrite`.
    pub struct SenderWriter {
        /// Sending half that receives one chunk per write.
        sender: Sender,
    }
    
    impl SenderWriter {
        pub fn new(sender: Sender) -> SenderWriter {
            SenderWriter { sender }
        }
    }
    
    // Adapts the wrapped `Sender` to `AsyncWrite`: every non-empty write becomes one body chunk.
    impl AsyncWrite for SenderWriter {
        /// Copies `buf` into a new chunk and hands it to the sender once the sender is ready.
        ///
        /// Returns `Poll::Pending` while `poll_ready` is pending, an `ErrorKind::Other` error
        /// when `poll_ready` or `try_send_data` fails, and `Ok(buf.len())` once the chunk has
        /// been accepted. An empty `buf` yields `Ok(0)` without touching the sender.
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, Error>> {
            // Nothing to transmit: report zero bytes written instead of waiting for readiness
            // and queueing a zero-length chunk.
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }

            // `?` propagates a readiness error; `ready!` returns early on `Poll::Pending`.
            ready!(self
                .sender
                .poll_ready(cx)
                .map_err(|e| Error::new(ErrorKind::Other, e))?);

            // The sender needs owned data, so the borrowed bytes are copied into a boxed slice.
            match self.sender.try_send_data(Box::<[u8]>::from(buf).into()) {
                Ok(()) => Poll::Ready(Ok(buf.len())),
                Err(_) => Poll::Ready(Err(Error::new(ErrorKind::Other, "Body closed"))),
            }
        }

        /// Polls the sender's `poll_ready` and converts its error into an `ErrorKind::Other` error.
        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
            let res = self.sender.poll_ready(cx);
            res.map_err(|e| Error::new(ErrorKind::Other, e))
        }

        /// Delegates to `poll_flush`; no separate shutdown step is performed here.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
            self.poll_flush(cx)
        }
    }