
Covariance of Box<dyn FnOnce(T)> in Rust


I have a function that expects a short-lived object. I would expect to always be able to pass it a long-lived object instead. But I get a strange error when I try to encode that:

type F<'arg> = Box<dyn FnOnce(&'arg ())>;
fn contravar<'small, 'large: 'small>(f: F<'small>) -> F<'large> {
    f
}


Specifically, the compiler reports:

error: lifetime may not live long enough
 --> src/lib.rs:3:5
  |
2 | fn contravar<'small, 'large: 'small>(f: F<'small>) -> F<'large> {
  |              ------  ------ lifetime `'large` defined here
  |              |
  |              lifetime `'small` defined here
3 |     f
  |     ^ function was supposed to return data with lifetime `'large` but it is returning data with lifetime `'small`
  |
  = help: consider adding the following bound: `'small: 'large`

It seems like F is invariant in its argument, but I would have guessed that it's contravariant. Am I missing something? Is there a way to make F<'arg> actually contravariant in 'arg?

Edit: it looks like the "problem" is that Rust treats all generic traits the same way (including Fn/FnMut/FnOnce). In my opinion, those three are special and should be treated specially, especially given that they are the only way to refer to closures. For that reason I opened an issue.


Solution

  • The Rust Reference's page on Subtyping and Variance documents that, as of Rust 1.63.0, fn(T) -> () is contravariant over T and that dyn Trait<T> + 'a is invariant over T.

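    FnOnce, FnMut and Fn are traits (FnOnce(&'a ()) is sugar for the FnOnce trait with the argument tuple (&'a (),) as its generic parameter), so dyn FnOnce(&'a ()) is unfortunately invariant over &'a (), and therefore over 'a.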

    // Compiles
    pub fn contravariant<'a, 'b: 'a>(x: fn(&'a ())) -> fn(&'b ()) { x }
    
    // Doesn't compile
    pub fn contravariant2<'a, 'b: 'a>(x: Box<dyn FnOnce(&'a ())>) -> Box<dyn FnOnce(&'b ())> { x }
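
    If losing the implicit 'static object bound is acceptable, a different, safe technique (not part of the original answer) is to re-box instead of reinterpret: wrap the old box in a new closure that accepts the longer-lived reference and forwards it, relying on the ordinary coercion from &'large () to &'small (). This is a minimal sketch; the helper name rebox is hypothetical. It costs an extra allocation and indirection, and the result can only be bounded by 'small (the new closure captures a value whose type mentions 'small), so it is not a drop-in replacement for F<'large>.

    ```rust
    // Hypothetical helper, not part of the original answer: re-box `f` behind a
    // new closure instead of reinterpreting it. The object bound changes from the
    // implicit `'static` to `'small`, because the new closure captures `f`, whose
    // type mentions `'small`.
    pub fn rebox<'small, 'large: 'small>(
        f: Box<dyn FnOnce(&'small ()) + 'small>,
    ) -> Box<dyn FnOnce(&'large ()) + 'small> {
        // `x: &'large ()` coerces to `&'small ()` when it is forwarded to `f`.
        Box::new(move |x: &'large ()| f(x))
    }

    fn main() {
        let unit = ();
        let g = rebox(Box::new(|_x: &()| println!("called")));
        g(&unit);
    }
    ```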
    

    Is there a way to wrap FnOnce somehow to convince the compiler of the correct variance?

    Here's what I could come up with using unsafe code. Note that I'm not making any guarantees as to whether this is sound or not. I don't know of any way to do this without unsafe code.

    use std::marker::PhantomData;
    
    trait Erased {}
    
    impl<T> Erased for T {}
    
    pub struct VariantBoxedFnOnce<Arg, Output> {
        boxed_real_fn: Box<dyn Erased + 'static>,
        _phantom_fn: PhantomData<fn(Arg) -> Output>,
    }
    
    impl<Arg, Output> VariantBoxedFnOnce<Arg, Output> {
        pub fn new(real_fn: Box<dyn FnOnce(Arg) -> Output>) -> Self {
            let boxed_real_fn: Box<dyn Erased + '_> = Box::new(real_fn);
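            let boxed_real_fn: Box<dyn Erased + 'static> = unsafe {
                // Extend only the trait object's lifetime bound from '_ to 'static;
                // the fat pointer's layout is unchanged. transmute is used rather
                // than raw-pointer `as` casts, which newer compilers may reject
                // when they extend a trait object's lifetime bound.
                std::mem::transmute(boxed_real_fn)
            };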
            Self {
                boxed_real_fn,
                _phantom_fn: PhantomData,
            }
        }
    
        pub fn call_once(self, arg: Arg) -> Output {
            let boxed_real_fn: Box<Box<dyn FnOnce(Arg) -> Output>> = unsafe {
                // Based on Box<dyn Any>::downcast()
                Box::from_raw(Box::into_raw(self.boxed_real_fn) as *mut Box<dyn FnOnce(Arg) -> Output>)
            };
            boxed_real_fn(arg)
        }
    }
    
    pub fn contravariant<'a, 'b: 'a>(x: VariantBoxedFnOnce<&'a (), ()>) -> VariantBoxedFnOnce<&'b (), ()> { x }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        fn foo(_x: &()) {}
    
        #[test]
        pub fn check_fn_does_not_require_static() {
            let f = VariantBoxedFnOnce::new(Box::new(foo));
            let x = ();
            f.call_once(&x);
        }
    
        #[test]
        pub fn check_fn_arg_is_contravariant() {
            let f = VariantBoxedFnOnce::new(Box::new(foo));
            let g = contravariant(f);
            let x = ();
            g.call_once(&x);
        }
    }
    

    Here, VariantBoxedFnOnce is limited to functions taking one argument.

    The trick is to store a type-erased version of the Box<dyn FnOnce(Arg) -> Output> such that Arg disappears, because we don't want the variance of VariantBoxedFnOnce<Arg, Output> to depend on Box<dyn FnOnce(Arg) -> Output> (which is invariant over Arg). However, there's also a PhantomData<fn(Arg) -> Output> field to provide the proper contravariance over Arg (and covariance over Output).
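
    The variance that PhantomData<fn(Arg) -> Output> provides can be checked in isolation. In this sketch the names Marker, widen_arg and shorten_output are hypothetical; both conversions type-check because the struct inherits the variance of a function pointer. Replacing the field with PhantomData<Box<dyn FnOnce(Arg) -> Output>> should make both functions fail to compile, just like contravariant2 above.

    ```rust
    use std::marker::PhantomData;

    // Hypothetical marker type: its only field is `PhantomData<fn(Arg) -> Output>`,
    // so the struct is contravariant over `Arg` and covariant over `Output`.
    pub struct Marker<Arg, Output> {
        _phantom: PhantomData<fn(Arg) -> Output>,
    }

    impl<Arg, Output> Marker<Arg, Output> {
        pub fn new() -> Self {
            Marker { _phantom: PhantomData }
        }
    }

    // Contravariant over `Arg`: something that accepts `&'a ()` may be used
    // where something accepting the longer-lived `&'b ()` is expected.
    pub fn widen_arg<'a, 'b: 'a>(m: Marker<&'a (), ()>) -> Marker<&'b (), ()> {
        m
    }

    // Covariant over `Output`: something that returns `&'b ()` may be used
    // where something returning the shorter-lived `&'a ()` is expected.
    pub fn shorten_output<'a, 'b: 'a>(m: Marker<(), &'b ()>) -> Marker<(), &'a ()> {
        m
    }

    fn main() {
        let _m: Marker<&'static (), ()> = widen_arg(Marker::new());
        let _n = shorten_output(Marker::<(), &'static ()>::new());
        println!("both conversions type-checked");
    }
    ```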

    We can't use Any as our erased type, because only 'static types implement Any, and in VariantBoxedFnOnce::new() there is a step where we hold a Box<dyn Erased + '_> whose lifetime '_ is not guaranteed to be 'static. We then immediately "transmute" it into a Box<dyn Erased + 'static> to avoid a redundant lifetime parameter on VariantBoxedFnOnce, but that 'static is a lie (hence the unsafe code). call_once "downcasts" the erased value back to the "original" Box<dyn FnOnce(Arg) -> Output>, except that Arg and Output may differ from the original ones because of variance.
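
    For comparison, when Arg and Output are themselves 'static, the same erase-then-recover round trip can be done safely with Any, because Box<dyn Any>::downcast checks the type at run time. This is a minimal sketch with u32 standing in for both types (erase, recover_and_call and BoxedFn are hypothetical names). It is precisely the 'static requirement of Any that rules this out for Arg = &'a (), which is why the answer uses its own Erased trait and an unchecked cast instead.

    ```rust
    use std::any::Any;

    // Hypothetical alias. With `Arg = u32` and `Output = u32`, the boxed closure
    // type is `'static`, so it can be erased behind `dyn Any`.
    type BoxedFn = Box<dyn FnOnce(u32) -> u32>;

    pub fn erase(f: BoxedFn) -> Box<dyn Any> {
        // Box the box, mirroring `Box::new(real_fn)` in `VariantBoxedFnOnce::new`.
        Box::new(f)
    }

    pub fn recover_and_call(erased: Box<dyn Any>, arg: u32) -> Option<u32> {
        // `downcast` compares `TypeId`s at run time and returns `Err` on a
        // mismatch, so no `unsafe` is needed here.
        let f: Box<BoxedFn> = erased.downcast::<BoxedFn>().ok()?;
        Some((*f)(arg))
    }

    fn main() {
        let erased = erase(Box::new(|x: u32| x * 2));
        println!("{:?}", recover_and_call(erased, 21));
    }
    ```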