rust, rust-tokio

How can I make a method stream-agnostic?


I have the following method:

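use std::io;
use tokio::net::TcpStream;

async fn transfer_all(stream: &mut TcpStream) -> Result<Vec<Vec<u8>>, Box<dyn std::error::Error>> {
  let mut packets: Vec<Vec<u8>> = Vec::new();
  let mut header = true; // true while the next read should be a 5-byte header
  let mut length: usize = 0;

  let mut packet: Vec<u8> = Vec::new();
  loop {
    // Wait until the socket becomes readable (this can be a false positive).
    stream.readable().await?;

    if header {
      length = 5; // 1 tag byte + 4-byte big-endian length field
      packet.clear();
      packet.shrink_to_fit();
      packet.reserve(length);
    }

    let mut buf: Vec<u8> = vec![0u8; length];
    match stream.try_read(&mut buf) {
      Ok(0) => {
        // The peer closed the connection.
        break;
      }
      // Assumes each read fills buf completely; short reads are not handled.
      Ok(_n) => {
        if header {
          // The length field counts its own 4 bytes, so the body is length - 4.
          length = u32::from_be_bytes(buf[1..5].try_into()?) as usize - 4;
          header = false;
          packet.append(&mut buf);
          packet.reserve(length);
          continue;
        }
        packet.append(&mut buf);
        packets.push(packet.clone());
        header = true;
      }
      Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
        // Nothing more is buffered right now, so stop.
        break;
      }
      Err(e) => {
        return Err(e.into());
      }
    }
  }

  Ok(packets)
}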

It works with TcpStream, but I also need it to work with UnixStream. Since this is a fairly convoluted state machine, I'd rather not maintain two implementations.

It was suggested that I use async fn transfer_all<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Vec<Vec<u8>>, Box<dyn std::error::Error>> and replace match stream.try_read(&mut buf) { with match stream.read(&mut buf).await {, but then the loop hangs once there is no more data to read: read().await keeps waiting for new bytes instead of returning WouldBlock, so the WouldBlock arm that ends the loop is never reached. How can I make this method work with both TcpStream and UnixStream?


Solution

  • Since tokio's TcpStream and UnixStream both have an inherent try_read method with the same signature, you can define your own trait that forwards to it and implement that trait for both types:

    use std::io::Result;
    use tokio::net::{TcpStream, UnixStream};

    trait TryRead {
        // Named differently from the inherent try_read methods to avoid confusion.
        fn do_try_read(&self, buf: &mut [u8]) -> Result<usize>;
    }
    
    impl TryRead for TcpStream {
        fn do_try_read(&self, buf: &mut [u8]) -> Result<usize> {
            self.try_read(buf) // forwards to TcpStream::try_read
        }
    }
    
    impl TryRead for UnixStream {
        fn do_try_read(&self, buf: &mut [u8]) -> Result<usize> {
            self.try_read(buf) // forwards to UnixStream::try_read
        }
    }
    

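    Then take S: AsyncRead + TryRead + Unpin (only TryRead is actually used by the loop) and replace try_read with do_try_read. Note that the loop also calls stream.readable().await, which is likewise an inherent method of both types and not part of AsyncRead, so the trait needs a second method that forwards to readable() as well (for example an async fn, which traits support since Rust 1.75).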
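
To try the call-site change without a network or a runtime, here is a std-only sketch. MockStream is a hypothetical in-memory stand-in that returns WouldBlock once its data is exhausted, the way try_read does when a socket has nothing buffered. The readable().await step is left out because it needs tokio. The frame layout (1 tag byte, then a 4-byte big-endian length that counts itself) is inferred from the question's header parsing, and read_available is a hypothetical name.

```rust
use std::cell::RefCell;
use std::convert::TryInto;
use std::io::{self, ErrorKind, Result};

trait TryRead {
    fn do_try_read(&self, buf: &mut [u8]) -> Result<usize>;
}

// Hypothetical in-memory stand-in for TcpStream/UnixStream: it hands out its
// bytes and then reports WouldBlock, like a socket with nothing buffered.
struct MockStream {
    data: RefCell<Vec<u8>>,
}

impl TryRead for MockStream {
    fn do_try_read(&self, buf: &mut [u8]) -> Result<usize> {
        let mut data = self.data.borrow_mut();
        if data.is_empty() {
            return Err(io::Error::new(ErrorKind::WouldBlock, "nothing buffered"));
        }
        let n = buf.len().min(data.len());
        buf[..n].copy_from_slice(&data[..n]);
        data.drain(..n); // drop the bytes that were handed out
        Ok(n)
    }
}

// Generic over any TryRead: first collects everything currently available,
// then splits it into frames of tag byte + length field + body.
fn read_available<S: TryRead>(stream: &S) -> Result<Vec<Vec<u8>>> {
    let mut bytes = Vec::new();
    let mut chunk = [0u8; 4096];
    loop {
        match stream.do_try_read(&mut chunk) {
            Ok(0) => break, // peer closed the connection
            Ok(n) => bytes.extend_from_slice(&chunk[..n]),
            Err(e) if e.kind() == ErrorKind::WouldBlock => break, // drained for now
            Err(e) => return Err(e),
        }
    }

    let mut packets = Vec::new();
    let mut rest = &bytes[..];
    while rest.len() >= 5 {
        let len = u32::from_be_bytes(rest[1..5].try_into().unwrap()) as usize;
        let total = 1 + len; // tag byte + (length field + body)
        if rest.len() < total {
            break; // incomplete frame; a real implementation would keep it for later
        }
        packets.push(rest[..total].to_vec());
        rest = &rest[total..];
    }
    Ok(packets)
}
```

Reading into one buffer before parsing is a deliberate choice here: it does not depend on each try_read returning exactly the requested number of bytes, which the original loop assumes.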

    Or use a macro to avoid repeating the impl for each type:

    use std::io::Result;
    use tokio::net::{TcpStream, UnixStream};

    trait TryRead {
        // Named differently from the inherent try_read methods to avoid confusion.
        fn do_try_read(&self, buf: &mut [u8]) -> Result<usize>;
    }
    
    // Generates the same forwarding impl for each stream type.
    macro_rules! make_try_read {
        ($typ:ty) => {
            impl TryRead for $typ {
                fn do_try_read(&self, buf: &mut [u8]) -> Result<usize> {
                    self.try_read(buf)
                }
            }
        };
    }
    
    make_try_read!(TcpStream);
    make_try_read!(UnixStream);
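
The macro itself is not tied to tokio. As a std-only analogue that runs without a runtime, the sketch below applies the same macro to std::net::TcpStream and std::os::unix::net::UnixStream. std has no inherent try_read, so the generated body instead reads through a shared reference (Read is implemented for &TcpStream and &UnixStream); with set_nonblocking(true), an empty socket then returns WouldBlock much like tokio's try_read. read_now is a hypothetical helper, and the sketch is Unix-only because of UnixStream.

```rust
use std::io::{ErrorKind, Read, Result};
use std::net::TcpStream;
use std::os::unix::net::UnixStream;

trait TryRead {
    fn do_try_read(&self, buf: &mut [u8]) -> Result<usize>;
}

macro_rules! make_try_read {
    ($typ:ty) => {
        impl TryRead for $typ {
            fn do_try_read(&self, buf: &mut [u8]) -> Result<usize> {
                // Read is implemented for &TcpStream and &UnixStream, so a
                // shared reference is enough; in non-blocking mode an empty
                // socket yields ErrorKind::WouldBlock.
                let mut reader: &$typ = self;
                reader.read(buf)
            }
        }
    };
}

make_try_read!(TcpStream);
make_try_read!(UnixStream);

// Hypothetical helper: one read through the trait, mapping WouldBlock to None.
fn read_now<S: TryRead>(stream: &S, buf: &mut [u8]) -> Result<Option<usize>> {
    match stream.do_try_read(buf) {
        Ok(n) => Ok(Some(n)),
        Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
        Err(e) => Err(e),
    }
}
```

A connected pair from UnixStream::pair() is enough to exercise it: after set_nonblocking(true) on one end, read_now returns Ok(None) until the other end writes, and Ok(Some(0)) once the other end is dropped.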