Tags: rust, tcp, byte

Reuse a BytesMut buffer while reading over TCP


I am reading different messages over TCP with AsyncReadExt. I would like to avoid copying data in memory and instead reuse the buffers that are created while reading.

The code below is an example of how I read from TCP. First I read the header, which contains the total number of bytes to read. Then I loop, filling a buffer of a fixed size; when the buffer is full, I create a message and send it to another thread. For this message I have to call .to_vec(), because the bytes field of the message is of type Vec<u8>. Once the message has been created, I reset the buffer and continue reading.

Edit: Small example.

Client:

I need to send a &[&[u8]] to avoid copying certain variables that I have in my real application. That is why I use the algorithm below to send all the data.

use std::error::Error;

use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;

#[tokio::main]
pub async fn main() -> Result<(), Box<dyn Error>> {
    let mut stream = TcpStream::connect("127.0.0.1:8080").await?;
    println!("[client] connected to server: {:?}", stream.peer_addr()?);

    let mut response: u32 = 999999;

    // Example of a message
    // Headers -> &[1, 2, 3, 4]
    // Body -> random_bytes
    let random_bytes = vec![1; 1048576 * 5];
    let data: &[&[u8]] = &[&[1, 2, 3, 4], &random_bytes];

    // Send buffer capacity -> 1MB
    let send_buffer_capacity: usize = 1048576;

    // Header
    let mut len = 0;
    for slice in data {
        len += slice.len() as u32;
    }

    //Send total len
    stream.write_u32(len).await.unwrap();

    //Send all data -> headers + body
    for slice in data {
        let iterations = slice.len() / send_buffer_capacity;

        if iterations > 0 {
            for i in 0..iterations {
                let index = i as usize;

                stream
                    .write_all(
                        &slice[send_buffer_capacity * index..send_buffer_capacity * (index + 1)],
                    )
                    .await
                    .unwrap();
            }

            let iter = iterations as usize;
            stream
                .write_all(&slice[send_buffer_capacity * iter..slice.len()])
                .await
                .unwrap();
        } else {
            stream.write_all(slice).await.unwrap();
        }
    }
    stream.flush().await.unwrap();

    // read response from server
    response = stream.read_u32().await.unwrap();

    println!("Server response: {:?}", response);

    Ok(())
}
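
The manual index arithmetic in the send loop above can also be written with slice::chunks. The following is only a std-only sketch: plan_chunks is a hypothetical helper, not part of the original client, and it only computes the length header and the list of borrowed chunks; in the real client each chunk would be passed to stream.write_all(chunk).await. No bytes are copied, since every chunk is a sub-slice of the input.

```rust
/// Hypothetical helper: computes the u32 length header and the borrowed
/// chunks (each at most `chunk_size` bytes) for a list of byte slices.
/// Panics if `chunk_size` is 0, because `slice::chunks` does.
fn plan_chunks<'a>(parts: &[&'a [u8]], chunk_size: usize) -> (u32, Vec<&'a [u8]>) {
    // Total payload length, as sent in the header by the client above.
    let total: usize = parts.iter().map(|p| p.len()).sum();
    // Split every part into sub-slices without copying any data.
    let chunks = parts
        .iter()
        .flat_map(|&p| p.chunks(chunk_size))
        .collect();
    (total as u32, chunks)
}

fn main() {
    let body = vec![1u8; 10];
    let parts: &[&[u8]] = &[&[1, 2, 3, 4], &body];
    let (len, chunks) = plan_chunks(parts, 4);
    let sizes: Vec<usize> = chunks.iter().map(|c| c.len()).collect();
    println!("header = {}, chunk sizes = {:?}", len, sizes);
}
```

Note that write_all already loops until the whole slice it is given has been written, so splitting into 1 MB pieces on the sending side is optional unless the application needs it for other reasons.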

Server:

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;

use std::error::Error;
use bytes::BytesMut;
use std::fmt::Debug;

pub struct Message {
    pub id: u32,
    pub sender_id: u32,
    pub op_id: u32,
    pub chunk_id: u32,
    pub last_chunk: bool,
    pub all_mess_len: u32,
    pub bytes: Vec<u8>,
}

impl Message {
    pub fn new(
        id: u32,
        sender_id: u32,
        op_id: u32,
        chunk_id: u32,
        last_chunk: bool,
        all_mess_len: u32,
        bytes: Vec<u8>,
    ) -> Message {
        Message {
            id: id,
            sender_id: sender_id,
            op_id: op_id,
            chunk_id: chunk_id,
            last_chunk: last_chunk,
            all_mess_len: all_mess_len,
            bytes: bytes,
        }
    }
}

impl Debug for Message {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Message")
            .field("id", &self.id)
            .field("sender_id", &self.sender_id)
            .field("op_id", &self.op_id)
            .field("chunk_id", &self.chunk_id)
            .field("last_chunk", &self.last_chunk)
            .field("all_mess_len", &self.all_mess_len)
            .field("bytes", &self.bytes.len())
            .finish()
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let addr = "127.0.0.1:8080".to_string();

    let listener = TcpListener::bind(&addr).await.unwrap();

    loop {
        match listener.accept().await {
            Ok((mut stream, addr)) => {
                println!("accepted a socket: {:?}", addr);

                tokio::spawn(async move {

                    let send_buffer_capacity = 1048576;

                    let mut n_bytes_read = 0;
                    let mut chunk_id = 0;
                    let mut last_chunk = false;
    
                    // Get header
                    let total_bytes = stream.read_u32().await.unwrap();
        
                    let mut buffer = BytesMut::with_capacity(send_buffer_capacity);
    
                    let mut bytes_per_chunk = 0;
    
                    loop {
                        match stream.read_buf(&mut buffer).await {
                            Ok(0) => {
                                // 0 bytes read means the peer closed the connection;
                                // `continue` here would spin forever.
                                break;
                            }
                            Ok(n) => {
                                bytes_per_chunk += n;
                                n_bytes_read += n;
    
                                if n_bytes_read == total_bytes.try_into().unwrap() {
                                    last_chunk = true;
                                }
    
                                if bytes_per_chunk == send_buffer_capacity || last_chunk {
                                    let message = Message::new(
                                        0,
                                        0,
                                        0,
                                        chunk_id,
                                        last_chunk,
                                        total_bytes,
                                        buffer.to_vec(),
                                    );
    
                                    println!("message: {:?}", message);
                                    //Send message to queue
                                    //queue.push(message);
    
                                    chunk_id += 1;
    
                                    bytes_per_chunk = 0;
                                    buffer = BytesMut::with_capacity(send_buffer_capacity);
                                }
    
                                if last_chunk {
                                    break;
                                }
                            }
                            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                                //debug!("Err: TCP -> SV (Write))");
                                continue;
                            }
                            Err(e) => {
                                // Stop reading on a real I/O error instead of looping again.
                                println!("error: {:?}", e);
                                break;
                            }
                        }
                    }

                    stream
                        .write_u32(1)
                        .await
                        .unwrap();
                    stream.flush().await.unwrap();

                });
            },
            Err(err) => {
                println!("No stream: {:?}", err);
            }
            
        }
    }

    Ok(())
}

If I change the type of the bytes field in the message to BytesMut, I get an error on this line:

match stream.read_buf(&mut buffer).await {

It is the following:

let mut buffer = BytesMut::with_capacity(config.send_buffer_capacity);
    |         ---------- move occurs because `buffer` has type `BytesMut`, which does not implement the `Copy` trait
...
402 |     loop {
    |     ---- inside of this loop
...
405 |             buffer = BytesMut::with_capacity(config.send_buffer_capacity);
    |             ------ this reinitialization might get skipped
...
409 |         match stream.read_buf(&mut buffer).await {
    |                               ^^^^^^^^^^^ value borrowed here after move
...
430 |                         buffer,
    |                         ------ value moved here, in previous iteration of loop

If I call .clone() on the buffer, it does the same as .to_vec(): it makes a copy of the memory.

Otherwise I would have to reset the buffer outside the condition, and then nothing would work.

I have tried to use the read_exact() method, but it never writes to the buffer.

Is there a way to keep the buffer in memory and pass just a reference to it in the message, while still using the buffer to continue reading?
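
One way around the move error shown above, without copying, is to swap the filled buffer out with std::mem::replace, so that `buffer` is never left in a moved-from state. The sketch below is std-only and uses assumptions that are not in the original code: a hypothetical Chunk struct stands in for Message, Vec<u8> stands in for BytesMut, and a blocking std::io::Read stands in for the TCP stream. The same mem::replace pattern should apply to a BytesMut, and the bytes crate also provides BytesMut::split(), which hands out the filled part while the original handle keeps its remaining capacity.

```rust
use std::io::Read;
use std::mem;

/// Hypothetical stand-in for `Message`: it owns its bytes.
struct Chunk {
    chunk_id: u32,
    bytes: Vec<u8>,
}

/// Reads `src` to the end in pieces of at most `cap` bytes.
/// Each filled buffer is moved into a `Chunk`; nothing is copied.
fn read_chunks<R: Read>(mut src: R, cap: usize) -> std::io::Result<Vec<Chunk>> {
    let mut out = Vec::new();
    let mut buffer = vec![0u8; cap];
    let mut filled = 0;
    let mut chunk_id = 0;
    loop {
        // Read into the part of the buffer that is still free.
        let n = src.read(&mut buffer[filled..])?;
        filled += n;
        let eof = n == 0;
        if filled == cap || (eof && filled > 0) {
            buffer.truncate(filled);
            // Move the filled buffer out and leave a fresh one in its place,
            // so `buffer` stays usable in the next iteration.
            let bytes = mem::replace(&mut buffer, vec![0u8; cap]);
            out.push(Chunk { chunk_id, bytes });
            chunk_id += 1;
            filled = 0;
        }
        if eof {
            break;
        }
    }
    Ok(out)
}

fn main() -> std::io::Result<()> {
    let data: Vec<u8> = (0..10).collect();
    for c in read_chunks(&data[..], 4)? {
        println!("chunk {} -> {:?}", c.chunk_id, c.bytes);
    }
    Ok(())
}
```

This avoids the copy made by .to_vec() or .clone(), but it still allocates a new buffer for every chunk; reusing one allocation would require the receiver to hand the buffer back after processing it.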


Solution

  • I found a solution that uses read_exact(). The problem was in how the vector was initialized. With Vec::with_capacity(n) it doesn't work, because the length of the vector is 0 (only its capacity is n), so the slice passed to read_exact() is empty and nothing is written to it. If you initialize it this way, it works:

    let mut buffer: Vec<u8> = vec![0; send_buffer_capacity as usize];
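
The difference between the two initializations can be reproduced with std alone. This is a minimal sketch, assuming that tokio's read_exact treats an empty slice the same way as std::io::Read::read_exact does (which matches the behavior described above); fill is a hypothetical helper and a byte slice stands in for the TCP stream.

```rust
use std::io::Read;

/// Hypothetical helper: fills `buf` from `src` with `read_exact` and
/// returns how many bytes were consumed from `src`.
fn fill(mut src: &[u8], buf: &mut Vec<u8>) -> std::io::Result<usize> {
    let before = src.len();
    // `read_exact` fills the slice it is given, i.e. `buf.len()` bytes.
    // The capacity of the Vec plays no role here.
    src.read_exact(&mut buf[..])?;
    Ok(before - src.len())
}

fn main() -> std::io::Result<()> {
    let src = [1u8, 2, 3, 4, 5, 6];

    // Capacity is reserved, but the length is 0: the slice is empty.
    let mut reserved: Vec<u8> = Vec::with_capacity(4);
    // Length is 4: the slice has room for 4 bytes.
    let mut zeroed: Vec<u8> = vec![0; 4];

    println!("with_capacity: consumed {}", fill(&src, &mut reserved)?);
    println!("vec![0; 4]:    consumed {}", fill(&src, &mut zeroed)?);
    Ok(())
}
```

With the reserved-only vector the call succeeds but consumes nothing, which is why the buffer appeared never to be written.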
    

    Here is the whole solution:

    tokio::spawn(async move {
                    let send_buffer_capacity = 1048576;
                    
                    let mut chunk_id = 0;
                    let mut last_chunk = false;
    
                    // Get header
                    let total_bytes = stream.read_u32().await.unwrap();
    
                    let mut iterations = total_bytes / send_buffer_capacity;
                    let last_buffer_capacity = total_bytes % send_buffer_capacity;
    
                    if last_buffer_capacity > 0 {
                        iterations += 1;
                    }
    
                    for i in 0..iterations {
    
                        let mut buffer: Vec<u8> = vec![0; send_buffer_capacity as usize];
    
                        if last_buffer_capacity > 0 && i == (iterations - 1)  {
                            buffer = vec![0; last_buffer_capacity as usize];
                        }
    
                        match stream.read_exact(&mut buffer).await {
                            Ok(_) => {
                                if i == (iterations - 1) {
                                    last_chunk = true;
                                }
    
                                let message = Message::new(
                                    0,
                                    0,
                                    0,
                                    chunk_id,
                                    last_chunk,
                                    total_bytes,
                                    buffer,
                                );
    
                                println!("message: {:?}", message);
                                //Send message to queue
                                //queue.push(message);
    
                                chunk_id += 1;
    
                                if last_chunk {
                                    break;
                                }
                            }
                            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                                //debug!("Err: TCP -> SV (Write))");
                                continue;
                            }
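                            Err(e) => {
                                // read_exact fails with UnexpectedEof if the peer closes early;
                                // stop instead of trying to read the next chunk.
                                println!("error: {:?}", e);
                                break;
                            }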
                        }
                    }
    
                    stream.write_u32(1).await.unwrap();
                    stream.flush().await.unwrap();
                });
    

    Another question that arises: in this case no additional memory copy is made, right? If the message is sent to another thread through a queue/channel, is only a reference to the message sent, so that the receiver ends up with the same buffer that was filled here?

    What is the difference between allocating the memory up front and using a BytesMut? In terms of memory I understand that you end up using the same amount, because with read_exact() you already know that the buffer is going to be filled. Is this true? Is it better to use read_buf() or read_exact()?
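
On the first question, a small std-only check may help. Sending a Vec<u8> (or a struct that owns one) through a channel moves ownership of the value rather than sending a reference: only the pointer, length and capacity are moved, and the heap data stays where it is. The sketch below uses std::sync::mpsc and a plain thread as assumptions in place of the real queue, and send_and_compare is a hypothetical helper that compares the data address on both sides.

```rust
use std::sync::mpsc;
use std::thread;

/// Hypothetical helper: sends `buf` to another thread and returns the
/// address of its heap data as seen by the sender and by the receiver.
fn send_and_compare(buf: Vec<u8>) -> (usize, usize) {
    let before = buf.as_ptr() as usize;
    let (tx, rx) = mpsc::channel::<Vec<u8>>();

    let handle = thread::spawn(move || {
        // The receiver now owns the Vec, including its heap allocation.
        let received = rx.recv().expect("sender dropped");
        received.as_ptr() as usize
    });

    // `send` moves the Vec into the channel; the bytes are not copied.
    tx.send(buf).expect("receiver dropped");
    let after = handle.join().expect("worker panicked");
    (before, after)
}

fn main() {
    let (before, after) = send_and_compare(vec![0u8; 1024 * 1024]);
    println!("same allocation on both sides: {}", before == after);
}
```

Because the sender gives up ownership, it cannot keep reading into the same Vec afterwards, which is why the solution above creates a new buffer for every chunk.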