I'm currently working on implementing a TCP proxy in Rust that can detect the server name in HTTPS connections and decide which proxy to use based on it. Specifically, I need to copy data from a TcpStream and pass it to tokio_rustls::LazyConfigAcceptor to detect ClientHello Messages.
I'm looking for a way to efficiently fork an `AsyncRead` into two separate readers that can each read all of the data from the underlying reader without blocking each other. Additionally, I want the shared buffer to grow dynamically to accommodate the different amounts of data each reader has consumed so far.
Example
let (reader, reader_copy) = fork(reader);
// ClientHello data will be read by detect_server_name *and* forwarded to the proxy
let server_name = detect_server_name(reader_copy).await;
proxy_map.get(&server_name).forward(reader).await;
Solved by sharing a `VecDeque` buffer between the two forks and tracking a separate read cursor (index) for each of them:
use std::cmp::min;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{ready, Waker};
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
/// State shared by the two halves of a forked reader.
struct Inner<R> {
    /// The wrapped reader that both forks ultimately pull bytes from.
    reader: R,
    /// Bytes already pulled from `reader` that are kept around for the forks.
    buffer: VecDeque<u8>,
    /// Read positions inside `buffer`, as `(left fork, right fork)`.
    cursors: (usize, usize),
    /// Wakers stored by the forks, as `(left fork, right fork)`.
    wakers: (Option<Waker>, Option<Waker>),
}
/// One of the two handles returned by `AsyncForkReader::new`.
///
/// Both handles point at the same shared state; `is_left` tells them apart.
pub struct AsyncForkReader<R> {
    /// State shared with the sibling handle, guarded by a mutex.
    inner: Arc<Mutex<Inner<R>>>,
    /// `true` for the first handle returned by `new`, `false` for the second.
    /// Selects which cursor / waker slot of the shared state this handle uses.
    is_left: bool,
}
impl<R> AsyncForkReader<R> {
    /// Wraps `reader` and returns two handles that share it.
    ///
    /// The shared state starts with an empty buffer, both cursors at `0` and
    /// no stored wakers. The first element of the returned tuple is the
    /// "left" handle (`is_left == true`), the second is the "right" one.
    pub fn new(reader: R) -> (Self, Self) {
        let state = Inner {
            reader,
            buffer: VecDeque::new(),
            cursors: (0, 0),
            wakers: (None, None),
        };
        let shared = Arc::new(Mutex::new(state));

        let left = Self {
            inner: Arc::clone(&shared),
            is_left: true,
        };
        let right = Self {
            inner: shared,
            is_left: false,
        };
        (left, right)
    }
}
impl<R: AsyncRead> AsyncRead for AsyncForkReader<R> {
    /// Reads the next bytes of the stream for this fork.
    ///
    /// If this fork still has unread bytes in the shared buffer, they are
    /// copied into `buf` (as many as fit) and buffer space that both forks
    /// have already consumed is released. Otherwise the wrapped reader is
    /// polled; bytes it produces are handed to the caller and, while the
    /// sibling fork is still alive, also appended to the shared buffer so the
    /// sibling can read them later.
    ///
    /// Once the sibling fork has been dropped nothing is buffered any more:
    /// leftover buffered bytes are drained by this fork and then the wrapped
    /// reader is read directly, so the buffer cannot grow without bound.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by the wrapped reader.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // The handles are not `Clone` and `new` creates exactly two of them,
        // so a strong count of one means the sibling fork has been dropped.
        // A stale `true` is harmless: we buffer once more and clean up on the
        // next call.
        let sibling_alive = Arc::strong_count(&self.inner) > 1;

        let mut guard = self.inner.lock().unwrap();
        let Inner {
            buffer,
            cursors,
            reader,
            wakers,
        } = &mut *guard;
        let (cursor, other_cursor, waker, other_waker) = if self.is_left {
            (&mut cursors.0, &mut cursors.1, &mut wakers.0, &mut wakers.1)
        } else {
            (&mut cursors.1, &mut cursors.0, &mut wakers.1, &mut wakers.0)
        };

        if *cursor == buffer.len() {
            // This fork has consumed everything buffered: ask the wrapped
            // reader for more.
            if sibling_alive {
                *waker = Some(cx.waker().clone());
                // A reader is only required to wake the waker of the most
                // recent poll. Polling below may therefore replace the
                // sibling's registration, so wake the sibling manually to
                // make sure it is never left waiting forever.
                // NOTE(review): while both forks are waiting at the end of
                // the buffer they keep waking each other alternately until
                // data arrives; a combined waker would avoid that.
                if let Some(sibling) = other_waker.take() {
                    sibling.wake();
                }
            } else {
                // Nobody else will ever read the buffered bytes and we have
                // consumed all of them: free the buffer and stop buffering.
                *waker = None;
                *other_waker = None;
                *buffer = VecDeque::new();
                *cursor = 0;
                *other_cursor = 0;
            }

            let already_filled = buf.filled().len();
            // SAFETY: `reader` lives inside the heap allocation owned by the
            // `Arc<Mutex<Inner<R>>>`, and the code here never moves it out of
            // `Inner` or replaces it, so its address stays stable for as long
            // as it exists.
            let polled = unsafe { Pin::new_unchecked(reader) }.poll_read(cx, buf);
            let outcome = ready!(polled);
            // The poll resolved, so this fork is no longer waiting; drop the
            // stored waker to avoid spurious wake-ups from the sibling.
            *waker = None;
            outcome?;

            if sibling_alive {
                // Keep a copy of the fresh bytes for the sibling fork.
                buffer.extend(&buf.filled()[already_filled..]);
                *cursor = buffer.len();
            }
        } else {
            *waker = None;

            // Serve this fork from the buffer. The deque's contents may be
            // split into two slices; copy the requested window from each.
            let len = min(buffer.len() - *cursor, buf.remaining());
            let (front, back) = buffer.as_slices();
            let mut skip = *cursor;
            let mut pending = len;
            for chunk in [front, back] {
                let from = min(skip, chunk.len());
                let to = min(from + pending, chunk.len());
                buf.put_slice(&chunk[from..to]);
                skip -= from;
                pending -= to - from;
            }
            *cursor += len;

            // Release the prefix that no live fork needs any more.
            let release_len = if sibling_alive {
                min(*cursor, *other_cursor)
            } else {
                *cursor
            };
            if release_len > 0 {
                buffer.drain(..release_len);
                *cursor -= release_len;
                *other_cursor = other_cursor.saturating_sub(release_len);
            }
        }
        Poll::Ready(Ok(()))
    }
}