Search code examples
Tags: rust, async-await, blocking, rust-tokio, peek

How to peek into a TcpStream and block until enough bytes are available?


I would like to classify incoming TCP streams by their first n bytes and then pass them to different handlers according to the classification.

I do not want to consume any of the bytes in the stream; otherwise the handlers would receive invalid streams that start at the nth byte instead of at the beginning.

So poll_peek looks almost like what I need, as it waits for data to be available before it peeks.

However I think what I would ideally need would be a poll_peek_exact that does not return until the passed buffer is full. This method does not seem to exist in TcpStream, so I'm not sure what the correct way would be to peek the first n bytes of a TcpStream without consuming them.

I could do something like:

    // Keep peeking until we have enough bytes to decide.
    // Each iteration awaits one `poll_peek` call (wrapped in `poll_fn`) and
    // compares the returned byte count against `n`.
    // NOTE(review): `.await?` already unwraps the `Result`, so the `Ok(..)`
    // pattern is matched against the unwrapped value; this probably does not
    // type-check as written. Confirm against the real `poll_peek` signature.
    while let Ok(num_bytes) = poll_fn(|cx| {
        tcp_stream.poll_peek(cx, &mut buf)
    }).await? {
        // NOTE(review): presumably `n` is the number of bytes `classify` needs
        // and `buf` can hold at least `n` bytes; neither is defined in this
        // snippet, so verify against the surrounding code.
        if num_bytes >= n {
            return classify(&buf);
        }
        // Fewer than `n` bytes were reported: loop and peek again.
        // NOTE(review): if the peek completes as soon as any data is buffered,
        // this re-polls immediately without waiting for more bytes. TODO confirm.
    }

But I think that would be busy waiting, so it seems like a bad idea, right? I could of course add a sleep to the loop, but that also does not seem like good style to me.

So what's the right way to do that?


Solution

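  • Here is my attempt: instead of peeking, read the first S bytes eagerly into a buffer and wrap the stream in an adapter that replays those bytes to the reader before forwarding reads to the socket: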

    use pin_project::pin_project;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
    use tokio::net::{TcpListener, TcpStream};
    
    use std::error::Error;
    
    /// A `TcpStream` wrapper that stores the first `S` bytes read from the
    /// stream and hands them back to readers before any further socket data.
    #[pin_project]
    struct HeaderExtractor<const S: usize> {
        /// The wrapped connection; reads past the header and all writes go here.
        #[pin]
        socket: TcpStream,
        /// The `S` bytes filled in by `read_header`.
        header: [u8; S],
        /// How many bytes of `header` have already been copied out by `poll_read`.
        num_forwarded: usize,
    }
    
    impl<const S: usize> HeaderExtractor<S> {
        /// Reads exactly `S` bytes from `socket` and returns a wrapper that owns
        /// both the socket and those bytes.
        ///
        /// `num_forwarded` starts at zero, so the `AsyncRead` implementation will
        /// hand the stored header bytes to the first reader(s) before touching the
        /// socket again.
        ///
        /// # Errors
        ///
        /// Returns the (boxed) I/O error reported by `read_exact`.
        pub async fn read_header(socket: TcpStream) -> Result<Self, Box<dyn Error>> {
            let mut this = Self {
                socket,
                header: [0; S],
                num_forwarded: 0,
            };

            this.socket.read_exact(&mut this.header).await?;

            Ok(this)
        }

        /// Returns the stored header bytes.
        ///
        /// Only a shared borrow is needed because nothing is mutated; callers
        /// that hold a mutable binding can still call this unchanged.
        pub fn get_header(&self) -> &[u8; S] {
            &self.header
        }
    }
    
    impl<const S: usize> AsyncRead for HeaderExtractor<S> {
        /// Serves the stored header bytes first; once all of them have been
        /// copied out, every read is delegated to the wrapped socket.
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let me = self.project();
            let already_sent = *me.num_forwarded;

            // Header fully replayed: behave exactly like the inner socket.
            if already_sent >= me.header.len() {
                return me.socket.poll_read(cx, buf);
            }

            // Copy as much of the unsent header as the caller's buffer can take.
            let pending = &me.header[already_sent..];
            let take = std::cmp::min(pending.len(), buf.remaining());
            buf.put_slice(&pending[..take]);
            *me.num_forwarded = already_sent + take;

            std::task::Poll::Ready(Ok(()))
        }
    }
    
    /// Writes are passed straight through to the wrapped socket; the stored
    /// header only affects the read side.
    impl<const S: usize> AsyncWrite for HeaderExtractor<S> {
        /// Forwards the write to the inner socket.
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<Result<usize, std::io::Error>> {
            let this = self.project();
            this.socket.poll_write(cx, buf)
        }

        /// Forwards vectored writes to the inner socket instead of relying on
        /// the trait's default implementation.
        fn poll_write_vectored(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            bufs: &[std::io::IoSlice<'_>],
        ) -> std::task::Poll<Result<usize, std::io::Error>> {
            let this = self.project();
            this.socket.poll_write_vectored(cx, bufs)
        }

        /// Reports whatever the inner socket reports, so callers can decide
        /// whether vectored writes are worthwhile.
        fn is_write_vectored(&self) -> bool {
            self.socket.is_write_vectored()
        }

        /// Forwards the flush to the inner socket.
        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), std::io::Error>> {
            let this = self.project();
            this.socket.poll_flush(cx)
        }

        /// Forwards the shutdown to the inner socket.
        fn poll_shutdown(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), std::io::Error>> {
            let this = self.project();
            this.socket.poll_shutdown(cx)
        }
    }
    
    /// Demo server: accepts TCP connections on 127.0.0.1:12345, reads a 3-byte
    /// header from each one, prints it, then prints everything read through the
    /// wrapper until the peer closes the connection.
    #[tokio::main]
    async fn main() -> Result<(), Box<dyn Error>> {
        let listener = TcpListener::bind("127.0.0.1:12345").await?;

        loop {
            // Asynchronously wait for an inbound socket.
            let (socket, _) = listener.accept().await?;

            // Handle the whole connection, header included, in its own task.
            // Waiting for the header here in the accept loop would let one slow
            // client hold up every later connection, and a failed header read
            // would end the server through `?`.
            tokio::spawn(async move {
                let mut socket = match HeaderExtractor::<3>::read_header(socket).await {
                    Ok(socket) => socket,
                    Err(err) => {
                        eprintln!("failed to read header: {}", err);
                        return;
                    }
                };
                println!("Got header: {:?}", socket.get_header());

                let mut buf = vec![0; 1024];

                // In a loop, read data from the socket and print it.
                loop {
                    // A read error only ends this connection's task; it is
                    // logged rather than turned into a panic.
                    let n = match socket.read(&mut buf).await {
                        Ok(n) => n,
                        Err(err) => {
                            eprintln!("failed to read data from socket: {}", err);
                            return;
                        }
                    };

                    if n == 0 {
                        println!("Connection closed.");
                        return;
                    }

                    println!("Received: {:?}", &buf[..n]);
                }
            });
        }
    }
    

    When I run `echo "123HelloWorld!" | nc -N localhost 12345` in another console, I get:

    Got header: [49, 50, 51]
    Received: [49, 50, 51]
    Received: [72, 101, 108, 108, 111, 87, 111, 114, 108, 100, 33, 10]
    Connection closed.