Tags: rust, reference, undefined-behavior, unsafe

Rust: Modifying the referent of a reference with a function; does this contain UB?


Recently, I wrote the following:

use std::ptr;

fn modify_mut_ret<T,R,F> (ptr: &mut T, f: F) -> R
  where F: FnOnce(T) -> (T,R)
{
   unsafe {
      let (t,r) = f(ptr::read(ptr));
      ptr::write(ptr,t);
      r
   }
}

This is a simple utility, so I expected to find it in the standard library, but I couldn't (at least not in std::mem). If we assume, for example, T: Default, we can implement it safely at the cost of constructing and dropping an extra default value:

use std::mem;

#[inline]
fn modify_mut_ret<T,R,F>(ptr: &mut T, f: F) -> R
  where F: FnOnce(T) -> (T,R),
        T: Default
{
    let mut t = T::default();
    mem::swap(ptr, &mut t);
    let (t,r) = f(t);
    *ptr = t;
    r
}
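
As a side note, the swap-with-default step is what std::mem::take does, so the safe version can be written a little more compactly. The sketch below assumes the same contract as the function above; the names modify_mut_ret_take and push_and_get_old_len are made up for illustration.

```rust
use std::mem;

// Illustrative name; same contract as the safe `modify_mut_ret` above.
fn modify_mut_ret_take<T, R, F>(ptr: &mut T, f: F) -> R
where
    F: FnOnce(T) -> (T, R),
    T: Default,
{
    // `mem::take` moves the value out and leaves `T::default()` behind,
    // so `*ptr` holds a valid value even if `f` panics.
    let (t, r) = f(mem::take(ptr));
    *ptr = t; // the temporary default value is dropped here
    r
}

// Hypothetical usage: push an element and return the previous length.
fn push_and_get_old_len(v: &mut Vec<i32>, x: i32) -> usize {
    modify_mut_ret_take(v, |mut inner| {
        let old_len = inner.len();
        inner.push(x);
        (inner, old_len)
    })
}
```

Because the reference always points at a valid T, a panic inside the closure only leaves the default value behind; no value is ever duplicated.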

I don't think the first implementation contains any undefined behavior: there is no alignment issue, and ptr::write removes one of the two owners that ptr::read created. However, I'm uneasy that std doesn't seem to contain a function with this behavior. Have I got anything wrong, or have I forgotten something? Does the unsafe code above contain any UB?


Solution

  • This code has exactly one source of UB: the closure can panic, so the function can exit early by unwinding. Let's take a closer look at it (I moved some things around to make it easier to take apart):

    fn modify_mut_ret<T, R, F: FnOnce(T) -> (T, R)>(x: &mut T, f: F) -> R {
       unsafe {
          let old_val = ptr::read(x); // Copied from original value, two copies of the
                                      // same non-Copy object exist now
          let (t, r) = f(old_val); // Supplied one copy to the closure
          ptr::write(x, t); // Erased the second copy by writing without dropping it
          r
       }
    }
    

    If the closure returns normally, the outer function proceeds as usual and only one copy of the old value of x ever exists. The closure owns that copy and may drop it, return it, or store it for later in an Rc<RefCell<...>>, an Arc<RwLock<...>>, or a global variable.

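    If it panics, however, the ptr::write is never reached even though the ptr::read has already happened, so two copies of the old value of x exist: the one that was moved into the closure and the one still sitting behind x. Code that calls modify_mut_ret and catches the panic with std::panic::catch_unwind can then observe both. Even without catch_unwind, the owner of x is dropped during unwinding, so the same value can be dropped twice.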
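
    The duplication can be made visible with a drop counter. This is only a sketch: the Noisy type, the DROPS counter, and the function names are invented for the demonstration, and Noisy is deliberately a zero-sized type with no heap data, so the second drop merely bumps a counter instead of freeing memory twice.

    ```rust
    use std::panic::{catch_unwind, AssertUnwindSafe};
    use std::ptr;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Hypothetical counter of how many times a `Noisy` has been dropped.
    static DROPS: AtomicUsize = AtomicUsize::new(0);

    // Zero-sized, owns nothing: dropping it twice only increments the counter.
    struct Noisy;

    impl Drop for Noisy {
        fn drop(&mut self) {
            DROPS.fetch_add(1, Ordering::SeqCst);
        }
    }

    // The original, panic-unsafe version from the question.
    fn modify_mut_ret_unsound<T, R, F>(x: &mut T, f: F) -> R
    where
        F: FnOnce(T) -> (T, R),
    {
        unsafe {
            let (t, r) = f(ptr::read(x));
            ptr::write(x, t);
            r
        }
    }

    // Creates ONE `Noisy`, lets the closure panic, and reports the drop count.
    fn count_drops_after_panic() -> usize {
        DROPS.store(0, Ordering::SeqCst);
        let mut value = Noisy;
        let result = catch_unwind(AssertUnwindSafe(|| {
            modify_mut_ret_unsound(&mut value, |_v: Noisy| -> (Noisy, ()) {
                // Unwinding drops `_v`, the copy made by `ptr::read`.
                panic!("closure failed")
            })
        }));
        assert!(result.is_err());
        drop(value); // the copy still behind the reference is dropped as well
        DROPS.load(Ordering::SeqCst)
    }
    ```

    With the default unwinding panic strategy, count_drops_after_panic() returns 2 although only one Noisy was ever created. With a Box, Vec, or String in its place, that second drop would be a double free.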

    What you need to do is handle the panic by aborting the process:

    use std::{ptr, panic::{catch_unwind, AssertUnwindSafe}};
    
    fn modify_mut_ret<T, R, F>(x: &mut T, f: F) -> R
    where F: FnOnce(T) -> (T, R) {
        unsafe {
            let old_val = ptr::read(x);
            let (t, r) = catch_unwind(AssertUnwindSafe(|| f(old_val)))
                .unwrap_or_else(|_| std::process::abort());
            ptr::write(x, t); // Erased the second copy by writing without dropping it
            r
        }
    }
    

    This way, a panic in the closure can never unwind out of the function: the panic is caught and the process is aborted immediately, before any other code can observe the duplicated value.

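    The AssertUnwindSafe wrapper is needed because catch_unwind requires its closure to be UnwindSafe, which a generic F is not known to be. Asserting it is fine here: the bound exists to keep code from observing logically invalid values left behind by a panic, and since we always abort after a panic, nothing can observe them. See UnwindSafe's documentation for more on that.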
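
    A different way to get the same abort-on-panic behavior, swapping catch_unwind for a drop guard, is sketched below. The AbortOnDrop type and the function names are hypothetical; the idea is that the guard's destructor only runs if the closure unwinds through this frame, and it is defused with mem::forget on the normal path.

    ```rust
    use std::{mem, process, ptr};

    // Hypothetical guard: if it is ever dropped, the process aborts.
    struct AbortOnDrop;

    impl Drop for AbortOnDrop {
        fn drop(&mut self) {
            process::abort();
        }
    }

    fn modify_mut_ret_guarded<T, R, F>(x: &mut T, f: F) -> R
    where
        F: FnOnce(T) -> (T, R),
    {
        // Armed before the value is duplicated.
        let guard = AbortOnDrop;
        unsafe {
            let old_val = ptr::read(x);
            // If `f` panics, unwinding drops `guard` and aborts before the
            // caller can touch the stale copy behind `x`.
            let (t, r) = f(old_val);
            ptr::write(x, t);
            // Normal path: defuse the guard without running its destructor.
            mem::forget(guard);
            r
        }
    }

    // Hypothetical usage: append to a String and return its previous length.
    fn append_and_get_old_len(s: &mut String, suffix: &str) -> usize {
        modify_mut_ret_guarded(s, |mut owned| {
            let old_len = owned.len();
            owned.push_str(suffix);
            (owned, old_len)
        })
    }
    ```

    This variant needs no UnwindSafe assertion because nothing is caught; the trade-off is the same as before, in that a panicking closure takes the whole process down.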