rust · concurrency · mutex · smart-pointers

Creating Multiple Mutexes in Rust for Thread Synchronization


I'm not familiar with Rust. The program I'm trying to write will have the value of n determined at runtime.

I want the program to have the following behavior: n threads will be created, each interacting with a user and storing data (these are slave nodes). The latest data from all of these threads will be consolidated and output by a single thread (the master node).

To achieve this, I'll be using Arc and Mutex to share variables between the individual slaves and the master. I want to declare n mutexes, but I'm not sure how to do that.

One of the challenges I'm facing is that because n is determined dynamically, I need to create a vector containing n Arcs.

If this approach is incorrect, I'm open to hearing about alternative solutions.


use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;


fn main() {
    let n = 10;

    let mut shared_list = Vec::with_capacity(n);
    let mut thread_list = Vec::with_capacity(n);
    for _ in 0..n {
        shared_list.push(Arc::new(Mutex::new(0)));
    }

    let print_thread = thread::spawn(move || {
        loop {
            
            let mut sum = 0;
            for i in 0..n {
                let data = shared_list[i].lock().unwrap();
                sum += *data;
            }
            println!("shared_data = {}", sum);
    
            if sum == 10 * 100 {
                break;
            }
        }
    });
    thread_list.push(print_thread);

    for i in 0..n {
        let shared_data = shared_list[i].clone();
        let slave_thread = thread::spawn(move || {
            loop {
                let mut data = shared_data.lock().unwrap();
                thread::sleep(Duration::from_millis(3));
                *data += 1;
                println!("Thread {}: data = {}", i, *data);
                
                if *data == 100 {
                    break;
                }
            }
        });
        thread_list.push(slave_thread);
    }
    
    
    for handle in thread_list {
        handle.join().unwrap();
    }
}

Update: The program I'm writing should have n threads that continuously receive matrices from the user. Additionally, there should be a separate thread that collects the most recently saved matrix from each individual thread, independently of their input operations, and performs further operations on them. I understand that this is a very specific scenario.

I will share a toy program that is closer to the original scenario.

use rand::Rng;
use ndarray::Array2;
use std::thread;
use std::time::Duration;
use std::sync::{Arc, Mutex};

fn main() {
    let n = 10;
    let mut thread_handles = Vec::new();
    let mut shared_list = Vec::with_capacity(n);

    for _ in 0..n {
        shared_list.push(Arc::new(Mutex::new(Array2::zeros((0, 0)))));
    }
    
    let shared_data_clone = shared_list.clone();
    let output_thread = thread::spawn(move || {
        let mut loop_t = 0;

        loop {
            for i in 0..n {
                let data = shared_data_clone[i].lock().unwrap();
                
                // something...
            }
            loop_t += 1;
        }
    });
    thread_handles.push(output_thread);
    

    for i in 0..n {
        let shared_data_clone = shared_list[i].clone();
        let handle = thread::spawn(move || {
            let mut loop_t = 0;
            loop {
                let mut rng = rand::thread_rng();
                let m = rng.gen::<u32>() % 9 + 1; 

                let matrix = Array2::from_shape_fn((m as usize, 3), |_| rng.gen_range(1..10)); 
                thread::sleep(Duration::from_millis(3));

                loop_t += 1;

                let mut data = shared_data_clone.lock().unwrap();
                *data = matrix;

                let sum: i32 = data.iter().sum();
                //println!("Thread {} generated matrix:{}", i, sum);
            }
        });
        thread_handles.push(handle);
    }

    // Join all the handles so `main` doesn't return immediately and
    // kill the worker threads while they are still looping.
    for handle in thread_handles {
        handle.join().unwrap();
    }
}

Solution

  • First, let me fix your code in the simplest way possible. Then, I'll talk about why your solution is non-ideal and what better alternatives would be.


    The problem you have at hand is the error "error[E0382]: borrow of moved value: shared_list".

    The cause is simple: you move the shared_list variable into the primary thread and then try to use it again later to spawn the other threads.

    The fix is to swap the two steps: create the worker threads first, and only then move shared_list into the primary thread.

    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;
    
    fn main() {
        let n = 10;
    
        let mut shared_list = Vec::with_capacity(n);
        let mut thread_list = Vec::with_capacity(n);
        for _ in 0..n {
            shared_list.push(Arc::new(Mutex::new(0)));
        }
    
        for i in 0..n {
            let shared_data = shared_list[i].clone();
            let slave_thread = thread::spawn(move || loop {
                let mut data = shared_data.lock().unwrap();
                thread::sleep(Duration::from_millis(3));
                *data += 1;
                println!("Thread {}: data = {}", i, *data);
    
                if *data == 100 {
                    break;
                }
            });
            thread_list.push(slave_thread);
        }
    
        let print_thread = thread::spawn(move || loop {
            let mut sum = 0;
            for i in 0..n {
                let data = shared_list[i].lock().unwrap();
                sum += *data;
            }
            println!("shared_data = {}", sum);
    
            if sum == 10 * 100 {
                break;
            }
        });
        thread_list.push(print_thread);
    
        for handle in thread_list {
            handle.join().unwrap();
        }
    }
    
    Thread 0: data = 1
    Thread 1: data = 1
    Thread 3: data = 1
    Thread 2: data = 1
    Thread 6: data = 1
    Thread 8: data = 1
    Thread 5: data = 1
    Thread 4: data = 1
    ...
    Thread 6: data = 100
    Thread 9: data = 100
    Thread 5: data = 100
    Thread 0: data = 100
    Thread 2: data = 100
    Thread 1: data = 100
    Thread 4: data = 100
    shared_data = 921
    shared_data = 1000
    

    Now let's talk about the problems.

    Note that shared_data = does not get printed a single time before all the threads are done. That's because all of your threads hold their mutexes continuously, even while sleeping; therefore your primary thread is perpetually and desperately waiting for the next mutex to be released.
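
    To make the fix concrete, here is a minimal sketch (not part of your original code; the larger examples below apply the same idea) of sleeping before taking the lock and holding it as briefly as possible:

    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    fn main() {
        let shared = Arc::new(Mutex::new(0));

        let worker = {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                for _ in 0..5 {
                    // Sleep *before* taking the lock, so the mutex stays
                    // free while this thread is idle.
                    thread::sleep(Duration::from_millis(3));
                    *shared.lock().unwrap() += 1;
                    // The temporary guard is dropped immediately,
                    // releasing the lock before the next sleep.
                }
            })
        };

        // The main thread can now observe intermediate values instead of
        // being blocked until the worker is completely done.
        for _ in 0..5 {
            thread::sleep(Duration::from_millis(2));
            println!("observed = {}", *shared.lock().unwrap());
        }

        worker.join().unwrap();
    }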

    There are more problems, though:

    • Mutex<i32> is an anti-pattern; use AtomicI32 instead. Although I guess your real program might hold more complex data that actually requires a Mutex.
    • Arc is not needed at all, at least in the code you show - you can use references instead. Note that this will require std::thread::scope, because references to the outside only work in scoped threads. Otherwise the compiler can't guarantee that the references get dropped before the actual data in main does.
    • Your primary thread never waits. If it's your intention to compute the intermediate result with 100% CPU power over and over again, then you achieved that - if not, you should add a sleep somewhere.
    • All of your data points are locked individually - that means by the time the primary thread finishes its iteration, the earlier data has most likely already changed, and your result is stale. Would it be beneficial to have only a single piece of data that all threads modify? If not, no problem - it might still make sense, though, to propagate the combined data to a single point instead of having one thread iterate over the vector continuously. This point depends heavily on the use case, though, so your layout may well make sense for your situation.
    • Don't use for i in 0..n { data[i] } to loop over a vector - it's slow and C-style. The [] operator in Rust has overhead because it performs bounds checking. Iterate over the vector directly, like for elem in data {}, or use iterators. Iterators would make even more sense in your case, because they offer a .sum() method.
    • Don't even spawn your primary thread - use the main thread directly. It otherwise just sits idle, waiting for all the other threads.

    Here is some inspiration. For readability I reduced the number of iterations to 2:

    • scope and Atomics:
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::thread;
    use std::time::Duration;
    
    fn main() {
        let n = 10;
    
        let mut shared_list = Vec::new();
        for _ in 0..n {
            shared_list.push(AtomicI32::new(0));
        }
    
        std::thread::scope(|s| {
            for (i, shared_data) in shared_list.iter().enumerate() {
                s.spawn(move || loop {
                    thread::sleep(Duration::from_millis(3));
                    let data = shared_data.fetch_add(1, Ordering::Relaxed) + 1;
                    println!("Thread {}: data = {}", i, data);
    
                    if data == 2 {
                        break;
                    }
                });
            }
    
            loop {
                thread::sleep(Duration::from_millis(2));
                let sum: i32 = shared_list
                    .iter()
                    .map(|data| data.load(Ordering::Relaxed))
                    .sum();
    
                println!("shared_data = {}", sum);
    
                if sum == 10 * 2 {
                    break;
                }
            }
    
            // No need for joining - The threads get joined automatically
            // at the end of the scope.
        })
    }
    
    shared_data = 0
    Thread 0: data = 1
    Thread 1: data = 1
    Thread 2: data = 1
    Thread 4: data = 1
    Thread 7: data = 1
    Thread 8: data = 1
    Thread 3: data = 1
    Thread 9: data = 1
    Thread 5: data = 1
    Thread 6: data = 1
    shared_data = 10
    Thread 0: data = 2
    Thread 1: data = 2
    Thread 2: data = 2
    Thread 4: data = 2
    Thread 7: data = 2
    Thread 8: data = 2
    Thread 3: data = 2
    shared_data = 17
    Thread 5: data = 2
    Thread 9: data = 2
    Thread 6: data = 2
    shared_data = 20
    
    • Arc<Mutex<>>, but done properly (in my opinion):
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;
    
    fn main() {
        let n = 10;
    
        std::thread::scope(|s| {
            let mut shared_list = vec![];
            for i in 0..n {
                let shared_data = Arc::new(Mutex::new(0));
                shared_list.push(Arc::clone(&shared_data));
    
                s.spawn(move || loop {
                    thread::sleep(Duration::from_millis(3));
                    // First sleep, then lock. Otherwise the lock is held during
                    // sleep, blocking whoever waits for the lock.
                    let new_data = {
                        let mut data = shared_data.lock().unwrap();
                        *data += 1;
                        *data
                        // Release the lock here again, so it isn't held
                        // during `println`.
                    };
    
                    println!("Thread {}: data = {}", i, new_data);
    
                    if new_data == 2 {
                        break;
                    }
                });
            }
    
            loop {
                thread::sleep(Duration::from_millis(2));
                let sum = shared_list
                    .iter()
                    .map(|data| *data.lock().unwrap())
                    .sum::<i32>();
    
                println!("shared_data = {}", sum);
    
                if sum == 10 * 2 {
                    break;
                }
            }
    
            // No need for joining - The threads get joined automatically
            // at the end of the scope.
        })
    }
    
    shared_data = 0
    Thread 0: data = 1
    Thread 2: data = 1
    Thread 1: data = 1
    Thread 3: data = 1
    Thread 4: data = 1
    Thread 5: data = 1
    Thread 6: data = 1
    Thread 8: data = 1
    Thread 9: data = 1
    Thread 7: data = 1
    shared_data = 10
    Thread 2: data = 2
    Thread 0: data = 2
    Thread 1: data = 2
    Thread 3: data = 2
    Thread 4: data = 2
    Thread 5: data = 2
    Thread 7: data = 2
    shared_data = 17
    Thread 8: data = 2
    Thread 9: data = 2
    Thread 6: data = 2
    shared_data = 20
    

    There are probably way more ways to solve this, like incorporating mpsc or using only one global variable instead of a vector of many. But I think this post is long enough :) Maybe this gave you a couple of ideas to continue your quest.
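
    For the curious, here is a minimal sketch of what the mpsc route could look like (this is not from the code above; the message format and iteration counts are purely illustrative). The workers send (thread id, increment) messages, and the main thread consolidates them:

    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    fn main() {
        let n = 10;
        let num_iters = 2;

        // One channel; every worker gets a clone of the sender, and the
        // main thread keeps the receiver and plays the consolidating role.
        let (tx, rx) = mpsc::channel();

        for i in 0..n {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..num_iters {
                    thread::sleep(Duration::from_millis(3));
                    // Send a message instead of mutating shared state.
                    tx.send((i, 1)).unwrap();
                }
                // This thread's clone of `tx` is dropped when it ends.
            });
        }
        // Drop the original sender; once all clones are gone, `recv`
        // returns `Err` and the loop below terminates.
        drop(tx);

        let mut sum = 0;
        // `recv` blocks until a message arrives - no busy-waiting needed.
        while let Ok((i, diff)) = rx.recv() {
            sum += diff;
            println!("Thread {} sent {}, shared_data = {}", i, diff, sum);
        }
    }

    Because recv blocks until a message arrives, the consolidating thread never busy-waits - an effect similar to the Condvar approach in the edit below.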


    EDIT: Here is an example of having only one global shared variable for statistics data.

    The principle is that the printing thread doesn't actively query the data from the worker threads; instead, each thread's data is private and not shared with anyone. The threads themselves push their changes to the printing thread.

    This has the advantage that you can use a Condvar to keep the printing thread dormant until a change happens, making its response quick and snappy.

    While I do like std::thread::scope, I chose not to use it here to stay closer to your original code. Normally I do use it wherever I can, though, because it's just a really nice concept.

    Like this:

    use std::sync::{Arc, Condvar, Mutex};
    use std::thread;
    use std::time::Duration;
    
    struct InnerAggregateData {
        // The actual global data
        sum: i32,
        // For notifying the watcher thread of changes
        data_changed: bool,
    }
    
    struct AggregateData {
        inner: Mutex<InnerAggregateData>,
        condvar: Condvar,
    }
    impl AggregateData {
        pub fn new() -> Self {
            Self {
                inner: Mutex::new(InnerAggregateData {
                    sum: 0,
                    data_changed: true,
                }),
                condvar: Condvar::new(),
            }
        }
    }
    
    fn main() {
        let n = 5;
        let num_iters = 3;
    
        let mut threads = vec![];
    
        let aggregate_data = Arc::new(AggregateData::new());
    
        for i in 0..n {
            let aggregate_data = Arc::clone(&aggregate_data);
            threads.push(thread::spawn(move || {
                let mut local_data = 0;
                loop {
                    thread::sleep(Duration::from_millis(200));
    
                    // Store previous local data
                    let previous_data = local_data;
    
                    // Update local data
                    local_data += 1;
                    println!("Thread {}: data = {}", i, local_data);
    
                    // Compute difference
                    let diff = local_data - previous_data;
    
                    // Update global data
                    let mut global_data = aggregate_data.inner.lock().unwrap();
                    global_data.sum += diff;
    
                    // Signal global data that it changed
                    global_data.data_changed = true;
                    // Important: trigger the condvar while the lock is still held.
                    // Otherwise there would be a race condition that could cause
                    // a deadlock.
                    aggregate_data.condvar.notify_all();
    
                    if local_data == num_iters {
                        break;
                    }
                }
            }));
        }
    
        threads.push(thread::spawn(move || loop {
            let locked_data = aggregate_data.inner.lock().unwrap();
            let mut locked_data = aggregate_data
                .condvar
                .wait_while(locked_data, |data| !data.data_changed)
                .unwrap();
    
            // Reset the `data_changed` variable
            locked_data.data_changed = false;
    
            // Extract data from lock
            let sum = locked_data.sum;
    
            // IMPORTANT: drop lock as soon as we have all the data, otherwise we
            // will block all the threads that try to write into the data.
            drop(locked_data);
    
            // Further processing without a locked mutex. This is allowed to take
            // long, and could even contain a `sleep` without blocking any other
            // threads.
            println!("shared_data = {}", sum);
    
            if sum == n * num_iters {
                break;
            }
        }));
    
        for handle in threads {
            handle.join().unwrap();
        }
    }
    
    shared_data = 0
    Thread 0: data = 1
    Thread 1: data = 1
    shared_data = 1
    Thread 4: data = 1
    shared_data = 3
    Thread 3: data = 1
    Thread 2: data = 1
    shared_data = 4
    shared_data = 5
    ... sleep ...
    Thread 1: data = 2
    Thread 0: data = 2
    shared_data = 6
    shared_data = 7
    Thread 4: data = 2
    Thread 2: data = 2
    shared_data = 8
    Thread 3: data = 2
    shared_data = 9
    shared_data = 10
    ... sleep ...
    Thread 1: data = 3
    Thread 0: data = 3
    shared_data = 12
    Thread 4: data = 3
    shared_data = 13
    Thread 2: data = 3
    Thread 3: data = 3
    shared_data = 14
    shared_data = 15
    

    Note that the shared_data = output reacts quickly to changes, but does not cause a jam. It is still capable of skipping a value if another change occurred almost simultaneously; for example, in my output it never printed shared_data = 2.