Search code examples
rustproxyiostreambuffer

How to "fork" an AsyncReader into two in Rust?


I'm currently working on implementing a TCP proxy in Rust that can detect the server name in HTTPS connections and decide which proxy to use based on it. Specifically, I need to copy data from a TcpStream and pass it to tokio_rustls::LazyConfigAcceptor to detect ClientHello Messages.

I'm looking for a solution on how to efficiently fork an AsyncReader into two separate readers that can both read data from the underlying reader without blocking each other. Additionally, I want the buffer to dynamically grow to accommodate varying amounts of data read by each reader.

Example

let (reader, reader_copy) = fork(reader);
// ClientHello data will be read by detect_server_name *and* forwarded to the proxy
let server_name = detect_server_name(reader_copy).await;
proxy_map.get(&server_name).forward(reader).await; 

Solution

  • Solved by sharing a VecDeque buffer and tracking two cursor indices, one per reader

    use std::cmp::min;
    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};
    use std::task::{ready, Waker};
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, ReadBuf};
    
    /// State shared by the two halves produced by [`AsyncForkReader::new`].
    struct Inner<R> {
        /// The wrapped source that both halves ultimately read from.
        reader: R,
        /// Bytes already pulled from `reader` that at least one half has not
        /// consumed yet.
        buffer: VecDeque<u8>,
        /// Read positions into `buffer`: `.0` for the left half, `.1` for the
        /// right half.
        cursors: (usize, usize),
        /// Most recently stored task waker of each half (`.0` left, `.1` right).
        wakers: (Option<Waker>, Option<Waker>),
    }

    /// One half of a forked reader; both halves share the same [`Inner`] state
    /// behind an `Arc<Mutex<_>>`.
    pub struct AsyncForkReader<R> {
        inner: Arc<Mutex<Inner<R>>>,
        /// `true` for the first half returned by `new`, `false` for the second.
        /// Selects which cursor/waker slot of `Inner` belongs to this half.
        is_left: bool,
    }

    impl<R> AsyncForkReader<R> {
        /// Wraps `reader` and returns two handles `(left, right)` that share it.
        ///
        /// Both handles start with an empty buffer, both cursors at `0` and no
        /// stored wakers.
        pub fn new(reader: R) -> (Self, Self) {
            let shared = Arc::new(Mutex::new(Inner {
                reader,
                buffer: VecDeque::new(),
                cursors: (0, 0),
                wakers: (None, None),
            }));

            let left = Self {
                inner: Arc::clone(&shared),
                is_left: true,
            };
            let right = Self {
                inner: shared,
                is_left: false,
            };
            (left, right)
        }
    }
    
    impl<R: AsyncRead> AsyncRead for AsyncForkReader<R> {
        /// Reads the next bytes of the shared stream as seen by this fork.
        ///
        /// * If this fork still has unread bytes in the shared buffer, they are
        ///   copied into `buf` (as many as fit) and the call completes
        ///   immediately. Bytes that no fork needs any more are then released.
        /// * If this fork has consumed everything buffered, the underlying
        ///   reader is polled. Newly read bytes are handed to the caller and,
        ///   while the other fork is still alive, also appended to the shared
        ///   buffer so that the other fork can read them later.
        ///
        /// Once the other fork has been dropped nothing is buffered any more,
        /// so a long-lived survivor does not accumulate the whole stream in
        /// memory.
        ///
        /// # Errors
        ///
        /// Propagates any I/O error returned by the underlying reader.
        ///
        /// # Panics
        ///
        /// Panics if the shared mutex is poisoned, i.e. the other fork panicked
        /// while it was inside `poll_read`.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let mut inner = self
                .inner
                .lock()
                .expect("fork state poisoned: the other fork panicked inside poll_read");

            // No new handles can be created after `new`, so a strong count of 1
            // means the other fork has been dropped for good. A concurrent drop
            // that we miss here is simply noticed on the next poll.
            let peer_alive = Arc::strong_count(&self.inner) > 1;

            let Inner {
                buffer,
                cursors,
                reader,
                wakers,
            } = &mut *inner;
            let (cursor, other_cursor, waker, other_waker) = if self.is_left {
                (&mut cursors.0, &mut cursors.1, &mut wakers.0, &mut wakers.1)
            } else {
                (&mut cursors.1, &mut cursors.0, &mut wakers.1, &mut wakers.0)
            };

            if *cursor == buffer.len() {
                // This fork has consumed everything buffered: go to the source.
                if peer_alive {
                    *waker = Some(cx.waker().clone());
                    // The underlying reader only keeps the waker of its most
                    // recent poll, so polling it here may replace the other
                    // fork's registration. Wake the other fork so it polls
                    // again instead of waiting for a wake-up that never comes.
                    // NOTE(review): when both forks are parked here they wake
                    // each other on every poll; a combined waker would avoid
                    // that, but needs a change to `Inner`.
                    if let Some(peer) = other_waker.take() {
                        peer.wake();
                    }
                } else {
                    // Whatever is left in the buffer was only waiting for the
                    // dropped fork; this fork has already consumed all of it.
                    buffer.clear();
                    *cursor = 0;
                    *other_cursor = 0;
                    *waker = None;
                    *other_waker = None;
                }

                let already_filled = buf.filled().len();
                // The reader cannot be pinned safely through the `Mutex`, so it
                // is projected manually.
                // SAFETY: `reader` lives inside the `Arc<Mutex<Inner<R>>>`
                // allocation and is only ever accessed in place through this
                // projection; nothing in this type moves it out before it is
                // dropped together with the allocation.
                let reader = unsafe { Pin::new_unchecked(reader) };
                ready!(reader.poll_read(cx, buf))?;

                if peer_alive {
                    // Keep a copy of the fresh bytes for the other fork.
                    buffer.extend(&buf.filled()[already_filled..]);
                    *cursor = buffer.len();
                }
            } else {
                // Serving from the buffer never parks this task.
                *waker = None;

                let len = min(buffer.len() - *cursor, buf.remaining());
                let end = *cursor + len;

                // `cursor..end` may straddle the ring buffer's wrap-around, so
                // copy the matching part of each half in bulk.
                let (front, back) = buffer.as_slices();
                let front_from = min(*cursor, front.len());
                let front_to = min(end, front.len());
                buf.put_slice(&front[front_from..front_to]);
                let back_from = (*cursor).saturating_sub(front.len());
                let back_to = end.saturating_sub(front.len());
                buf.put_slice(&back[back_from..back_to]);
                *cursor = end;

                // Release the prefix nobody needs any more: what both forks
                // have consumed, or everything this fork has consumed when it
                // is the only one left.
                let release_len = if peer_alive {
                    min(*cursor, *other_cursor)
                } else {
                    *cursor
                };
                if release_len > 0 {
                    buffer.drain(..release_len);
                    *cursor -= release_len;
                    *other_cursor = other_cursor.saturating_sub(release_len);
                }
            }
            Poll::Ready(Ok(()))
        }
    }