I have the following method:
/// Reads every length-prefixed packet that is currently available on `stream`.
///
/// Each packet starts with a 5-byte header: one leading byte followed by a
/// big-endian `u32` length. The length counts its own four bytes, so the body
/// that follows the header is `length - 4` bytes long. Every returned
/// `Vec<u8>` holds one complete packet (header followed by body).
///
/// The function returns when the peer closes the connection (a read of zero
/// bytes) or when the socket reports `WouldBlock` while no packet is in
/// progress. If `WouldBlock` is reported in the middle of a packet, the
/// function waits for the rest of that packet instead of discarding the bytes
/// already consumed, so the stream stays aligned on packet boundaries.
///
/// # Errors
///
/// Returns any I/O error reported by the stream other than `WouldBlock`, and
/// an `InvalidData` error when a header declares a length smaller than the
/// length field itself.
async fn transfer_all(stream: &mut TcpStream) -> Result<Vec<Vec<u8>>, Box<dyn std::error::Error>> {
    /// One leading byte plus the four-byte length field.
    const HEADER_LEN: usize = 5;
    /// The declared length includes the length field itself.
    const LENGTH_FIELD_LEN: usize = 4;

    let mut packets: Vec<Vec<u8>> = Vec::new();
    let mut packet: Vec<u8> = Vec::with_capacity(HEADER_LEN);
    // Scratch buffer reused across iterations; only `buf[..n]` is ever consumed.
    let mut buf: Vec<u8> = Vec::new();
    let mut header = true;
    // Bytes still missing from the section (header or body) being read.
    // Invariant: always > 0 when a read is issued, so `Ok(0)` can only mean EOF.
    let mut remaining: usize = HEADER_LEN;

    loop {
        stream.readable().await?;
        buf.resize(remaining, 0);
        match stream.try_read(&mut buf) {
            // End of stream: hand back whatever complete packets were collected.
            Ok(0) => break,
            Ok(n) => {
                // A read may return fewer bytes than requested; keep only what arrived.
                packet.extend_from_slice(&buf[..n]);
                remaining -= n;
                if remaining > 0 {
                    continue;
                }
                if header {
                    let declared = u32::from_be_bytes(pop(&packet[1..HEADER_LEN])) as usize;
                    let body_len = declared.checked_sub(LENGTH_FIELD_LEN).ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::InvalidData,
                            "packet length is smaller than the length field itself",
                        )
                    })?;
                    if body_len > 0 {
                        header = false;
                        remaining = body_len;
                        packet.reserve(body_len);
                        continue;
                    }
                    // A packet without a body is complete as soon as its header is.
                }
                // Move the finished packet out and start over with a fresh buffer.
                packets.push(std::mem::replace(&mut packet, Vec::with_capacity(HEADER_LEN)));
                header = true;
                remaining = HEADER_LEN;
            }
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                // Nothing more to read right now. Stop only on a packet boundary;
                // otherwise loop back and wait until the rest of the packet arrives.
                if packet.is_empty() {
                    break;
                }
            }
            Err(e) => return Err(e.into()),
        }
    }
    Ok(packets)
}
It works with `TcpStream`, but I also need to make it work with `UnixStream`. Since this is a fairly convoluted state machine, I'd rather not have two implementations. It was suggested to me to use `async fn transfer_all<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Vec<Vec<u8>>, Box<dyn std::error::Error>>`
and replace `match stream.try_read(&mut buf) {`
with `match stream.read(&mut buf).await {`,
but this blocks when there's no more data to read. How can I make this method work with both `TcpStream` and `UnixStream`?
Since both `UnixStream` and `TcpStream` have a `try_read` method, you can make your own trait for them:
/// Non-blocking read abstraction over stream types that expose their own
/// `try_read` method.
trait TryRead {
    /// Tries to read into `buf` without waiting, returning the number of
    /// bytes read.
    ///
    /// Deliberately named `do_try_read` rather than `try_read`: overlapping
    /// the name with the streams' own `try_read` makes it hard to work with.
    fn do_try_read(&self, buf: &mut [u8]) -> std::io::Result<usize>;
}
/// Forwards `TryRead::do_try_read` to `TcpStream`'s own `try_read`.
impl TryRead for TcpStream {
    fn do_try_read(&self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.try_read(buf)
    }
}
/// Forwards `TryRead::do_try_read` to `UnixStream`'s own `try_read`.
impl TryRead for UnixStream {
    fn do_try_read(&self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.try_read(buf)
    }
}
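Then you can take an `S: AsyncRead + TryRead + Unpin` and replace `try_read` with `do_try_read`. Note that `readable()` is likewise an inherent method of both stream types rather than part of `AsyncRead`, so if the loop keeps calling it, it needs the same treatment.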
Or, with a macro for reduced repetition:
/// Non-blocking read abstraction over stream types that expose their own
/// `try_read` method.
trait TryRead {
    /// Tries to read into `buf` without waiting, returning the number of
    /// bytes read.
    ///
    /// Deliberately named `do_try_read` rather than `try_read`: overlapping
    /// the name with the streams' own `try_read` makes it hard to work with.
    fn do_try_read(&self, buf: &mut [u8]) -> std::io::Result<usize>;
}
/// Implements `TryRead` for `$typ` by forwarding to the type's own
/// `try_read` method.
///
/// The return type is written as an absolute path so the expansion does not
/// depend on which `Result` alias happens to be in scope at the call site.
macro_rules! make_try_read {
    ($typ:ty) => {
        impl TryRead for $typ {
            fn do_try_read(&self, buf: &mut [u8]) -> ::std::io::Result<usize> {
                self.try_read(buf)
            }
        }
    };
}
// Generate the `TryRead` impl for each stream type; both expose a `try_read` method.
make_try_read!(TcpStream);
make_try_read!(UnixStream);