
Shared memory in Rust


Environment:

macOS Sonoma 14.0 (M1 Mac), Rust 1.65.0

What I want to do: I want to share a Vec whose elements are [u8; 128] arrays between multiple threads. While it is shared, I need to be able to:

  1. read the entire Vec;
  2. overwrite a specific [u8; 128] element in the Vec;
  3. insert new [u8; 128] elements into the Vec.

Below is the code I wrote. Reading works, but writes are not reflected. If I run this code and then run the following command once on the same machine,


    nc -v localhost 50051


the server prints (abbreviated):

    [[0u8; 128],[1u8; 128],[2u8; 128]]

This is correct so far, but a second run prints exactly the same data as the first. Since the first run updates the data, I expect the second run to print the second element filled with 3s, as shown below:


    [[0u8; 128],[3u8; 128],[2u8; 128]]

I suspect I am misusing Arc, and that a clone of SharedData is being passed around instead of a reference to the one shared instance, but I don't know how to confirm this. How can I fix the code so that it works as intended?

main.rs:

use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
use tokio_task_pool::Pool;

struct SharedData {
    data: Arc<RwLock<Vec<[u8; 128]>>>
}

impl SharedData {
    fn new(data: RwLock<Vec<[u8; 128]>>) -> Self {
        Self {
            data: Arc::new(data)
        }
    }

    fn update(&self, index: usize, update_data: [u8; 128]) {
        let read_guard_for_array = self.data.read().unwrap();
        let write_lock = RwLock::new((*read_guard_for_array)[index]);
        let mut write_guard_for_item = write_lock.write().unwrap();
        *write_guard_for_item = update_data;
    }
}

fn socket_to_async_tcplistener(s: socket2::Socket) -> std::io::Result<tokio::net::TcpListener> {
    std::net::TcpListener::from(s).try_into()
}

async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
    let read_guard = db_arc.data.read().unwrap();
    println!("In process() read: {:?}", *read_guard);
    db_arc.update(1, [3u8; 128]);
}

async fn serve(_: usize, tcplistener_arc: Arc<tokio::net::TcpListener>, db_arc: Arc<SharedData>) {
    let task_pool_capacity = 10;

    let task_pool = Pool::bounded(task_pool_capacity)
        .with_spawn_timeout(Duration::from_secs(300))
        .with_run_timeout(Duration::from_secs(300));
    
    loop {
        let (stream, _) = tcplistener_arc.as_ref().accept().await.unwrap();
        let db_arc_clone = db_arc.clone();

        task_pool.spawn(async move {
            process(stream, db_arc_clone).await;
        }).await.unwrap();
    }
}

#[tokio::main]
async fn main() {
    let addr: std::net::SocketAddr = "0.0.0.0:50051".parse().unwrap();
    let soc2 = socket2::Socket::new(
        match addr {
            SocketAddr::V4(_) => socket2::Domain::IPV4,
            SocketAddr::V6(_) => socket2::Domain::IPV6,
        },
        socket2::Type::STREAM,
        Some(socket2::Protocol::TCP)
    ).unwrap();
    
    soc2.set_reuse_address(true).unwrap();
    soc2.set_reuse_port(true).unwrap();
    soc2.set_nonblocking(true).unwrap();
    soc2.bind(&addr.into()).unwrap();
    soc2.listen(8192).unwrap();

    let tcp_listener = Arc::new(socket_to_async_tcplistener(soc2).unwrap());

    let mut vec = vec![
        [0u8; 128],
        [1u8; 128],
        [2u8; 128],
    ];

    let share_db = Arc::new(SharedData::new(RwLock::new(vec)));
    let mut handlers = Vec::new();
    for i in 0..num_cpus::get() - 1 {
        let cloned_listener = Arc::clone(&tcp_listener);
        let db_arc = share_db.clone();

        let h = std::thread::spawn(move || {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap()
                .block_on(serve(i, cloned_listener, db_arc));
        });
        handlers.push(h);
    }

    for h in handlers {
        h.join().unwrap();
    }
}

Cargo.toml:

[package]
name = "tokio-test"
version = "0.1.0"
edition = "2021"

[dependencies]
log = "0.4.20"
env_logger = "0.10.0"
tokio = { version = "1.34.0", features = ["full"] }
tokio-stream = { version = "0.1.14", features = ["net"] }
serde = { version = "1.0.193", features = ["derive"] }
serde_yaml = "0.9.27"
serde_derive = "1.0.193"
mio = {version="0.8.9", features=["net", "os-poll", "os-ext"]}
num_cpus = "1.16.0"
socket2 = { version="0.5.5", features = ["all"]}
array-macro = "2.1.8"
tokio-task-pool = "0.1.5"
argparse = "0.2.2"

Solution

  • I haven't looked at the entire code, but there are a few errors.

    fn update()

        fn update(&self, index: usize, update_data: [u8; 128]) {
            let read_guard_for_array = self.data.read().unwrap();
            let write_lock = RwLock::new((*read_guard_for_array)[index]);
            let mut write_guard_for_item = write_lock.write().unwrap();
            *write_guard_for_item = update_data;
        }
    

    That's not how you use a RwLock:

    • if you want to modify the data, instead of using self.data.read(), use self.data.write();
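    • I'm not sure what you intend to do with the second RwLock, but it is useless: RwLock::new((*read_guard_for_array)[index]) copies the element (a [u8; 128] is Copy) into a brand-new local RwLock, so the write only modifies that copy, which is dropped when update() returns. The Vec inside self.data is never changed, which is why the second run prints the same data.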
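To see why the write never reaches the Vec, here is a minimal std-only sketch that repeats the same steps as the question's update() on a bare RwLock<Vec<[u8; 128]>>. The function name update_like_question is made up for illustration.

```rust
use std::sync::RwLock;

/// Same steps as the question's update(), applied to a bare RwLock<Vec<...>>.
fn update_like_question(v: &RwLock<Vec<[u8; 128]>>, index: usize, update_data: [u8; 128]) {
    let read_guard = v.read().unwrap();
    // [u8; 128] is Copy: indexing copies the element out of the Vec,
    // and RwLock::new takes ownership of that copy.
    let write_lock = RwLock::new(read_guard[index]);
    let mut item = write_lock.write().unwrap();
    *item = update_data; // changes only the local copy
} // the copy is dropped here; the Vec was never modified

fn main() {
    let v = RwLock::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]);
    update_like_question(&v, 1, [3u8; 128]);
    // Still 1: the write went to a temporary copy, not to v[1].
    println!("{}", v.read().unwrap()[1][0]);
}
```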

    Instead, do something like this (the guard binding must be mut, because assigning through it requires mutable access):

        fn update(&self, index: usize, update_data: [u8; 128]) {
            let mut write_guard_for_array = self.data.write().unwrap();
            write_guard_for_array[index] = update_data;
        }
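A minimal sketch of a SharedData that covers all three requirements from the question (read everything, overwrite one element, insert), assuming a plain std::sync::RwLock. The method names snapshot and insert, and a new() that takes a Vec, are hypothetical. Since the question already wraps SharedData in an Arc, this sketch drops the inner Arc around the RwLock; that inner Arc is redundant, though not harmful.

```rust
use std::sync::{Arc, RwLock};

struct SharedData {
    // SharedData itself is shared via Arc<SharedData>, so no inner Arc here.
    data: RwLock<Vec<[u8; 128]>>,
}

impl SharedData {
    fn new(data: Vec<[u8; 128]>) -> Self {
        Self { data: RwLock::new(data) }
    }

    /// Requirement 1: read the whole Vec. Returns a copy, so no lock is held afterwards.
    fn snapshot(&self) -> Vec<[u8; 128]> {
        self.data.read().unwrap().to_vec()
    }

    /// Requirement 2: overwrite one element in place under the write lock.
    fn update(&self, index: usize, update_data: [u8; 128]) {
        let mut guard = self.data.write().unwrap();
        guard[index] = update_data;
    }

    /// Requirement 3: insert a new element at `index`, shifting later ones right.
    fn insert(&self, index: usize, item: [u8; 128]) {
        self.data.write().unwrap().insert(index, item);
    }
}

fn main() {
    let db = Arc::new(SharedData::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]));
    db.update(1, [3u8; 128]);
    db.insert(3, [4u8; 128]);
    // Print the first byte of each element: [0, 3, 2, 4]
    let first_bytes: Vec<u8> = db.snapshot().iter().map(|e| e[0]).collect();
    println!("{:?}", first_bytes);
}
```

Returning a copy from snapshot() means a caller never holds the read lock while calling update(), which sidesteps the read-then-write deadlock in process().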
    

    fn process()

    async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
        let read_guard = db_arc.data.read().unwrap();
        println!("In process() read: {:?}", *read_guard);
        db_arc.update(1, [3u8; 128]);
    }
    

    Generally, you probably shouldn't access db_arc.data directly from outside SharedData. Beyond that, once you fix update(), this function is going to deadlock:

    1. You acquire db_arc.data.read(). By definition of a RwLock, this means that nobody can modify the contents of db_arc.data until the read lock is released.
    2. The read lock is released only when read_guard goes out of scope, at the end of the function.
    3. Before that, you call update(), which tries to acquire data.write(). It cannot get the write lock until the read lock is released, but the read lock is held by the very task that is waiting, so it waits forever.
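The conflict can be observed without actually deadlocking by using RwLock::try_write(), which returns Err(TryLockError::WouldBlock) instead of waiting. A minimal std-only sketch; can_write is a hypothetical helper:

```rust
use std::sync::{RwLock, TryLockError};

/// Hypothetical helper: returns true if a write lock could be taken right now.
fn can_write(lock: &RwLock<Vec<[u8; 128]>>) -> bool {
    match lock.try_write() {
        Ok(_guard) => true, // the guard is dropped at the end of this arm
        Err(TryLockError::WouldBlock) => false,
        Err(TryLockError::Poisoned(_)) => panic!("lock poisoned"),
    }
}

fn main() {
    let lock = RwLock::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]);

    let read_guard = lock.read().unwrap();
    println!("read {} elements", read_guard.len());
    // While read_guard is alive, no writer can get in. A blocking
    // lock.write() here would wait forever (or panic) instead of returning.
    println!("can write while reading: {}", can_write(&lock)); // false

    drop(read_guard); // release the read lock explicitly
    println!("can write after drop: {}", can_write(&lock)); // true
}
```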

    You probably want something along the lines of:

    async fn process(mut stream: tokio::net::TcpStream, db_arc: Arc<SharedData>) {
        {
            let read_guard = db_arc.data.read().unwrap();
            println!("In process() read: {:?}", *read_guard);
        } // End of scope, `read_guard` is released.
        db_arc.update(1, [3u8; 128]);
    }
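A std-only sketch of the same read, release, then write pattern, using two OS threads to stand in for two nc connections handled one after the other. handle_request is a hypothetical synchronous stand-in for process(); it shows that every thread sees the same Vec through the Arc, so the second request observes the first request's update.

```rust
use std::sync::{Arc, RwLock};
use std::thread;

/// Hypothetical stand-in for process(): read, release the read lock, then update.
/// Returns the first byte of each element as seen by this "request".
fn handle_request(db: &RwLock<Vec<[u8; 128]>>) -> Vec<u8> {
    let seen: Vec<u8> = {
        let read_guard = db.read().unwrap();
        read_guard.iter().map(|e| e[0]).collect()
    }; // read_guard is released here
    {
        let mut write_guard = db.write().unwrap();
        write_guard[1] = [3u8; 128];
    } // write_guard is released here
    seen
}

fn main() {
    let db = Arc::new(RwLock::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]]));

    // Two "connections", each handled on its own OS thread, one after the other.
    for run in 1..=2 {
        let db_clone = Arc::clone(&db);
        let seen = thread::spawn(move || handle_request(&db_clone)).join().unwrap();
        println!("run {run}: {seen:?}");
    }
    // run 1: [0, 1, 2]
    // run 2: [0, 3, 2]
}
```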
    

    tokio + threads

    You're mixing threads and tokio: main() spawns several std::threads, each running its own current-thread tokio runtime. It's theoretically possible, but risky. Both approaches are valid, but I suggest picking one or the other. Typically, pick tokio if you have lots of I/O (e.g. network requests or disk access), or threads if you have lots of CPU-bound work.
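On the suspicion in the question that a clone of SharedData is being passed around: calling .clone() on an Arc<SharedData> copies the pointer and increments the reference count; it does not copy SharedData or the Vec inside it. Arc::ptr_eq and Arc::strong_count can confirm this. A minimal sketch, with SharedData reduced to the question's data field and same_instance a hypothetical helper:

```rust
use std::sync::{Arc, RwLock};

// Reduced version of the question's SharedData: only the shared field.
struct SharedData {
    data: Arc<RwLock<Vec<[u8; 128]>>>,
}

/// Hypothetical helper: true if both handles point at the same SharedData.
fn same_instance(a: &Arc<SharedData>, b: &Arc<SharedData>) -> bool {
    Arc::ptr_eq(a, b)
}

fn main() {
    let share_db = Arc::new(SharedData {
        data: Arc::new(RwLock::new(vec![[0u8; 128], [1u8; 128], [2u8; 128]])),
    });

    // Same call as in the question's main(): clones the Arc, not SharedData.
    let db_arc = share_db.clone();

    println!("same instance: {}", same_instance(&share_db, &db_arc)); // true
    println!("strong count: {}", Arc::strong_count(&share_db)); // 2
    // Only one SharedData exists, so both handles reach the same lock too.
    println!("same lock: {}", Arc::ptr_eq(&share_db.data, &db_arc.data)); // true
}
```

So the Arc usage in the question is fine; the lost write comes from update() modifying a copy of the element.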