Tags: vector, rust, reference, borrow-checker, borrowing

How to get multiple mutable references to elements in a Vec?


I have a large nested data structure and would like to pluck out a few parts to pass around for processing. Ultimately I want to send sections to multiple threads to update, but first I'd like to get my feet wet by understanding the simple example below. In C I would just assemble an array of the relevant pointers. It seems doable in Rust, since no interior vector will ever need more than one mutable reference. Here's the sample code.

fn main() {
    let mut data = Data::new(vec![2, 3, 4]);
    // this works
    let slice = data.get_mut_slice(1);
    slice[2] = 5.0;
    println!("{:?}", data);

    // what I would like to do
    // let slices = data.get_mut_slices(vec![0, 1]);
    // slices[0][0] = 2.0;
    // slices[1][0] = 3.0;
    // println!("{:?}", data);
}

#[derive(Debug)]
struct Data {
    data: Vec<Vec<f64>>,
}

impl Data {
    fn new(lengths: Vec<usize>) -> Data {
        Data {
            data: lengths.iter().map(|n| vec![0_f64; *n]).collect(),
        }
    }

    fn get_mut_slice(&mut self, index: usize) -> &mut [f64] {
        &mut self.data[index][..]
    }

    // doesn't work
    // fn get_mut_slices(&mut self, indexes: Vec<usize>) -> Vec<&mut [f64]> {
    //     indexes.iter().map(|i| self.get_mut_slice(*i)).collect()
    // }
}

Solution

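  • This is possible using only safe Rust. The care required is in the index bookkeeping: a mistake there gives you the wrong elements or a panic, not undefined behavior. The trick is to leverage the unsafe code in the standard library that sits behind the safe .iter_mut() method on Vec and the .nth() method of the iterator it returns. Here's a working example with comments explaining the code in context: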

    fn main() {
        let mut data = Data::new(vec![2, 3, 4]);
    
        // this works
        let slice = data.get_mut_slice(1);
        slice[2] = 5.0;
        println!("{:?}", data);
    
        // and now this works too!
        let mut slices = data.get_mut_slices(vec![0, 1]);
        slices[0][0] = 2.0;
        slices[1][0] = 3.0;
        println!("{:?}", data);
    }
    
    #[derive(Debug)]
    struct Data {
        data: Vec<Vec<f64>>,
    }
    
    impl Data {
        fn new(lengths: Vec<usize>) -> Data {
            Data {
                data: lengths.iter().map(|n| vec![0_f64; *n]).collect(),
            }
        }
    
        fn get_mut_slice(&mut self, index: usize) -> &mut [f64] {
            &mut self.data[index][..]
        }
    
        // now works!
        fn get_mut_slices(&mut self, mut indexes: Vec<usize>) -> Vec<&mut [f64]> {
            // sort indexes for easier processing
            indexes.sort();
            let index_len = indexes.len();
    
            // early return for edge case
            if index_len == 0 {
                return Vec::new();
            }
    
            // check that the largest index is in bounds
            // (valid indexes run from 0 to self.data.len() - 1)
            let max_index = indexes[index_len - 1];
            if max_index >= self.data.len() {
                panic!("{} index is out of bounds of data", max_index);
            }
    
            // check that we have no overlapping indexes
            indexes.dedup();
            let uniq_index_len = indexes.len();
            if index_len != uniq_index_len {
                panic!("cannot return aliased mut refs to overlapping indexes");
            }
    
            // leverage the unsafe code that's written in the standard library
            // to safely get multiple unique disjoint mutable references
            // out of the Vec
            let mut mut_slices_iter = self.data.iter_mut();
            let mut mut_slices = Vec::with_capacity(index_len);
            // position of the next element the iterator will yield
            let mut next_index = 0;
            for curr_index in indexes {
                // nth(n) consumes n + 1 items, so skip only the items between
                // the iterator's current position and curr_index
                mut_slices.push(
                    mut_slices_iter
                        .nth(curr_index - next_index)
                        .unwrap()
                        .as_mut_slice(),
                );
                next_index = curr_index + 1;
            }
    
            // return results
            mut_slices
        }
    }
    

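Since the end goal is to update sections from multiple threads, here is a minimal sketch of that next step, assuming Rust 1.63 or later for std::thread::scope. It uses the same iter_mut() idea to collect disjoint &mut references to the selected rows and then hands each one to its own scoped thread. The function name fill_rows_in_parallel and the fill-with-a-constant workload are hypothetical stand-ins for real processing.

```rust
use std::thread;

/// Hypothetical helper: sets every element of the rows listed in `wanted`
/// to `value`, updating each selected row on its own scoped thread.
fn fill_rows_in_parallel(data: &mut [Vec<f64>], wanted: &[usize], value: f64) {
    // iter_mut() yields each row's &mut at most once, so the collected
    // references are guaranteed to be disjoint.
    let rows: Vec<&mut Vec<f64>> = data
        .iter_mut()
        .enumerate()
        .filter(|(i, _)| wanted.contains(i))
        .map(|(_, row)| row)
        .collect();

    // Every thread spawned inside the scope is joined before scope() returns,
    // so the threads may borrow `data` without needing 'static lifetimes.
    thread::scope(|s| {
        for row in rows {
            // Each `row` (a &mut Vec<f64>) is moved into exactly one thread.
            s.spawn(move || {
                for x in row.iter_mut() {
                    *x = value;
                }
            });
        }
    });
}

fn main() {
    let mut data: Vec<Vec<f64>> = vec![vec![0.0; 2], vec![0.0; 3], vec![0.0; 4]];
    fill_rows_in_parallel(&mut data, &[0, 2], 1.0);
    println!("{:?}", data);
}
```

Because the filter walks the rows exactly once, duplicate or out-of-range entries in wanted are silently ignored rather than causing a panic. The trade-off is that contains() makes this O(rows × wanted); a sorted index list or a HashSet would avoid that for large inputs.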


    What I believe I learned is that the Rust compiler demands an iterator in this situation because that is the only way it can know that each mut slice comes from a different vector.

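    The compiler doesn't actually know that. All it knows is the iterator's signature: the iterator returned by iter_mut() yields items of type &'a mut T, where 'a is the lifetime of the borrow of the Vec, not of each individual call to next(). That is why earlier references stay usable while later ones are produced. The underlying implementation uses unsafe Rust, but the method iter_mut() itself is safe because the implementation guarantees to emit each mut ref only once, so all of the mut refs it hands out are unique.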
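To see the lifetime point in action, here is a minimal sketch (the function name bump_first_two is hypothetical). Two references pulled from the same IterMut are alive at the same time, which compiles only because each item's lifetime is tied to the borrow of the Vec rather than to the individual next() call.

```rust
/// Hypothetical example: adds 10 to the first element and 20 to the second.
fn bump_first_two(v: &mut [i32]) {
    let mut it = v.iter_mut();
    // Each item has type &'a mut i32, where 'a is the borrow of `v`,
    // so `a` stays usable after the second call to next().
    let a = it.next().expect("need at least two elements");
    let b = it.next().expect("need at least two elements");
    // Both mutable references are live here at the same time.
    *a += 10;
    *b += 20;
}

fn main() {
    let mut v = vec![1, 2, 3];
    bump_first_two(&mut v);
    println!("{:?}", v);
}
```

By contrast, get_mut_slice(&mut self, ...) ties its result to a fresh mutable borrow of the whole Data, so the compiler cannot let two of its results coexist.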

    Would the compiler complain if another mut_slices_iter was created in the for loop (which could grab the same data twice)?

    Yes. Calling iter_mut() on a Vec mutably borrows it, and overlapping mutable borrows of the same data violate Rust's borrowing rules. So you can't call iter_mut() a second time while the first iterator, or any reference obtained from it (such as the slices already pushed into mut_slices), is still in use.
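A minimal sketch of that rule (the function name two_passes is hypothetical): a second iter_mut() call is accepted once the first iterator and every reference taken from it are no longer used, and rejected while they are. The rejected line is left commented out so the block compiles.

```rust
/// Hypothetical example: writes through two iterators created one after the other.
fn two_passes(data: &mut [Vec<f64>]) {
    let mut first = data.iter_mut();
    let a = first.next().expect("need at least one row");
    // Uncommenting the next line fails with error E0499 (a second mutable
    // borrow of data) because `a` is still used below:
    // let mut second = data.iter_mut();
    a[0] = 1.0;

    // `first` and `a` are never used again, so the first borrow has ended
    // and a new mutable borrow of `data` is allowed.
    let mut second = data.iter_mut();
    let b = second.nth(1).expect("need at least two rows");
    b[0] = 2.0;
}

fn main() {
    let mut data: Vec<Vec<f64>> = vec![vec![0.0; 2], vec![0.0; 3]];
    two_passes(&mut data);
    println!("{:?}", data);
}
```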

    Also am I right that the .nth method will call next() n times so it is ultimately theta(n) on the first axis?

    Not quite. Calling next() n times is only the default implementation of nth. The iterator returned by calling iter_mut() on a Vec (std::slice::IterMut) overrides nth and advances its internal pointer by n directly, so each call takes constant time, about the same as indexing into the Vec. Getting 3 arbitrarily indexed items with .nth() is therefore just as fast on an iterator over 10000 items as on one over 10 items. This only holds for iterators that override nth this way, such as those over contiguous collections like Vec and slices; for most other iterators nth really is linear in n.
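The property of nth that matters when chaining calls on one iterator is that nth(n) consumes n + 1 items: it skips n and returns the next one. Here is a minimal sketch (the function name nth_demo is hypothetical); the remaining-length check relies on IterMut implementing ExactSizeIterator.

```rust
/// Hypothetical example: returns the values found by two chained nth() calls
/// and how many items the iterator still holds afterwards.
fn nth_demo(v: &mut [u32]) -> (u32, u32, usize) {
    let mut it = v.iter_mut();
    // Skips items 0 and 1 and returns item 2; the iterator now starts at item 3.
    let third = *it.nth(2).expect("need at least 3 items");
    // nth(0) behaves like next(): it returns item 3, not item 4.
    let fourth = *it.nth(0).expect("need at least 4 items");
    // ExactSizeIterator::len reports how many items are left.
    (third, fourth, it.len())
}

fn main() {
    let mut v: Vec<u32> = vec![10, 20, 30, 40, 50];
    let (third, fourth, left) = nth_demo(&mut v);
    println!("{} {} {}", third, fourth, left);
}
```

So when walking sorted indexes with a single iterator, the offset passed to nth must be measured from the iterator's current position (one past the previously fetched index), not from the previous index itself.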