Tags: recursion, rust, stack, lifetime, unsafe

Stack of references in unsafe Rust, but ensuring that the unsafeness does not leak out of the stack?


I'm implementing some recursive code, where function instances deeper down in the call stack may need to refer to data from prior frames. However, I only have non-mut access to those data, so I receive those data as references. As such, I would need to keep references to those data in a stack data structure that can be accessed from the deeper instances.

To illustrate:

// I would like to implement this RefStack class properly, without per-item memory allocations
struct RefStack<T: ?Sized> {
    content: Vec<&T>,
}
impl<T: ?Sized> RefStack<T> {
    fn new() -> Self { Self{ content: Vec::new() } }
    fn get(&self, index: usize) -> &T { self.content[index] }
    fn len(&self) -> usize { self.content.len() }
    fn with_element<F: FnOnce(&mut Self)>(&mut self, el: &T, f: F) {
        self.content.push(el);
        f(self);
        self.content.pop();
    }
}

// This is just an example demonstrating how I would need to use the RefStack class
fn do_recursion(n: usize, node: &LinkedListNode, st: &mut RefStack<str>) {
    // get references to one or more items in the stack
    // the references should be allowed to live until the end of this function, but shouldn't prevent me from calling with_element() later
    let tmp: &str = st.get(rng.gen_range(0, st.len()));
    // do stuff with those references (println is just an example)
    println!("Item: {}", tmp);
    // recurse deeper if necessary
    if n > 0 {
        let (head, tail): (_, &LinkedListNode) = node.get_parts();
        manager.get_str(head, |s: &str| // the actual string is a local variable somewhere in the implementation details of get_str()
            st.with_element(s, |st| do_recursion(n - 1, tail, st))
        );
    }
    // do more stuff with those references (println is just an example)
    println!("Item: {}", tmp);
}

fn main() {
    do_recursion(100, list /* gotten from somewhere else */, &mut RefStack::new());
}

In the example above, I'm concerned about how to implement RefStack without any per-item memory allocations. The occasional allocations by the Vec are acceptable - those are few and far between. The LinkedListNode is just an example - in practice it's some complicated graph data structure, but the same thing applies - I only have a non-mut reference to it, and the closure given to manager.get_str() only provides a non-mut str. Note that the non-mut str passed into the closure may only be constructed in the get_str() implementation, so we cannot assume that all the &str have the same lifetime.

I'm fairly certain that RefStack can't be implemented in safe Rust without copying out the str into owned Strings, so my question is how this can be done in unsafe Rust. It feels like I might be able to get a solution such that:

  • The unsafeness is confined to the implementation of RefStack
  • The reference returned by st.get() should live at least as long as the current instance of the do_recursion function (in particular, it should be able to live past the call to st.with_element(), and this is logically safe since the &T that is returned by st.get() isn't referring to any memory owned by the RefStack anyway)

How can such a struct be implemented in (unsafe) Rust?

It feels like I could just cast the element references to pointers and store them as pointers, but I would still have difficulty expressing the requirement in the second bullet point above when casting them back to references. Or is there a better way (or, by any chance, is such a struct implementable in safe Rust, or already available in some library)?


Solution

  • Based on rodrigo's answer, I implemented this slightly simpler version:

    struct RefStack<'a, T: ?Sized + 'static> {
        content: Vec<&'a T>,
    }
    
    impl<'a, T: ?Sized + 'static> RefStack<'a, T> {
        fn new() -> Self {
            RefStack {
                content: Vec::new(),
            }
        }
    
        fn get(&self, index: usize) -> &'a T {
            self.content[index]
        }
    
        fn len(&self) -> usize {
            self.content.len()
        }
    
        fn with_element<'t, F>(&mut self, el: &'t T, f: F)
        where
            F: FnOnce(&mut RefStack<'t, T>),
            'a: 't,
        {
            let mut st = RefStack {
                content: std::mem::take(&mut self.content),
            };
            st.content.push(el);
            f(&mut st);
            st.content.pop();
            self.content = unsafe { std::mem::transmute(st.content) };
        }
    }
    

    The only difference from rodrigo's solution is that the vector is represented as a vector of references instead of pointers, so we don't need the PhantomData and the unsafe code to access an element.
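For comparison, a pointer-based variant along the lines described above might look roughly like the sketch below. This is a reconstruction under assumptions, not rodrigo's actual code: the name PtrStack, the _marker field and the small main function are made up for illustration. It shows where the PhantomData and the extra unsafe block in get() come from.

```rust
use std::marker::PhantomData;

// Hypothetical pointer-based variant (names are assumptions, not from the original answer).
struct PtrStack<'a, T: ?Sized + 'static> {
    content: Vec<*const T>,
    // Raw pointers carry no lifetime, so the marker ties the stack to 'a.
    _marker: PhantomData<&'a T>,
}

impl<'a, T: ?Sized + 'static> PtrStack<'a, T> {
    fn new() -> Self {
        PtrStack { content: Vec::new(), _marker: PhantomData }
    }

    fn get(&self, index: usize) -> &'a T {
        // Assumed invariant: every stored pointer was created from a
        // reference that is valid for at least 'a.
        unsafe { &*self.content[index] }
    }

    fn len(&self) -> usize {
        self.content.len()
    }

    fn with_element<'t, F>(&mut self, el: &'t T, f: F)
    where
        F: FnOnce(&mut PtrStack<'t, T>),
        'a: 't,
    {
        // Move the pointers into a stack that only promises the shorter lifetime 't.
        let mut st: PtrStack<'t, T> = PtrStack {
            content: std::mem::take(&mut self.content),
            _marker: PhantomData,
        };
        st.content.push(el as *const T);
        f(&mut st);
        st.content.pop();
        // No transmute here: the vector's type does not mention any lifetime.
        self.content = st.content;
    }
}

fn main() {
    let outer = String::from("outer");
    let mut st: PtrStack<str> = PtrStack::new();
    st.with_element(outer.as_str(), |st| {
        println!("depth {}: {}", st.len(), st.get(0));
    });
}
```

In this form the lifetime bookkeeping lives entirely in the marker, which is why the reference-based version can drop both the marker and the unsafe dereference.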

    When a new element is pushed to the stack in with_element(), we require that its lifetime 't is no longer than the lifetime 'a of the existing elements, which is what the 'a: 't bound expresses. We then create a new stack with the shorter lifetime, which is possible in safe code because the data the references in the vector point to lives for the longer lifetime 'a anyway. We then push the new element with lifetime 't to the new vector, again in safe code, and only after we have removed that element again do we move the vector back to its original place. This requires unsafe code since we are extending the lifetime of the references in the vector from 't to 'a this time. We know this is safe as long as the vector is back to its original state at that point, but the compiler doesn't know this. Note that this relies on the closure not replacing the contents of the stack it is handed, for example by swapping it with another RefStack via std::mem::swap.
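The safe half of that argument can be seen in isolation. The sketch below uses two hypothetical helpers, shorten and push_and_pop, which are not part of the answer. It shows that a Vec of long-lived references can be handed over as a Vec of shorter-lived ones without any unsafe code, because Vec<&'a str> is covariant in 'a.

```rust
// Shortening the lifetime of the stored references is an ordinary coercion.
fn shorten<'a: 't, 't>(v: Vec<&'a str>) -> Vec<&'t str> {
    v
}

// Pushes a reference to a local String and pops it again,
// returning the depth reached while the local was on the stack.
fn push_and_pop<'a>(v: Vec<&'a str>) -> usize {
    let short = String::from("short-lived");
    let mut w: Vec<&str> = shorten(v);
    w.push(short.as_str());
    let depth = w.len();
    w.pop();
    // Handing `w` back as Vec<&'a str> is the step the compiler rejects;
    // that direction is what the transmute in with_element() is for.
    depth
}

fn main() {
    let long = String::from("long-lived");
    println!("{}", push_and_pop(vec![long.as_str()]));
}
```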

    I feel this version represents the intent better than rodrigo's almost identical version. The type of the vector is always "correct", in that it describes that the elements are actually references, not raw pointers, and it always assigns the correct lifetime to the vector. And we use unsafe code exactly in the place where something potentially unsafe happens – when extending the lifetime of the references in the vector.
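To check that this meets the requirement from the question, here is a minimal usage sketch. The do_recursion driver and its log parameter are hypothetical stand-ins for the real graph traversal, and the closure is assumed to use only get(), len() and with_element() on the stack it receives. The point is that the reference obtained from get() before the call to with_element() is still usable afterwards.

```rust
struct RefStack<'a, T: ?Sized + 'static> {
    content: Vec<&'a T>,
}

impl<'a, T: ?Sized + 'static> RefStack<'a, T> {
    fn new() -> Self {
        RefStack { content: Vec::new() }
    }

    fn get(&self, index: usize) -> &'a T {
        self.content[index]
    }

    fn len(&self) -> usize {
        self.content.len()
    }

    fn with_element<'t, F>(&mut self, el: &'t T, f: F)
    where
        F: FnOnce(&mut RefStack<'t, T>),
        'a: 't,
    {
        let mut st = RefStack {
            content: std::mem::take(&mut self.content),
        };
        st.content.push(el);
        f(&mut st);
        st.content.pop();
        self.content = unsafe { std::mem::transmute(st.content) };
    }
}

// Hypothetical driver: each frame owns a String and exposes it to deeper frames.
fn do_recursion(n: usize, st: &mut RefStack<'_, str>, log: &mut Vec<String>) {
    let local = format!("frame {}", n);
    // This reference is not tied to the borrow of `st`, only to the stack's lifetime.
    let top: Option<&str> = if st.len() > 0 {
        Some(st.get(st.len() - 1))
    } else {
        None
    };
    if n > 0 {
        st.with_element(local.as_str(), |st| do_recursion(n - 1, st, log));
    }
    // `top` is still valid after with_element() has returned.
    if let Some(t) = top {
        log.push(format!("{} sees {}", local, t));
    }
}

fn main() {
    let mut log = Vec::new();
    do_recursion(2, &mut RefStack::new(), &mut log);
    for line in &log {
        println!("{}", line);
    }
}
```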