Tags: rust, error-handling, functional-programming

How to properly propagate an error through a non-error function?


Let's say I want to use a function F : (A -> B) -> C that I cannot modify. However, I want to pass it a function g : A -> Result<B, Err> instead of a plain A -> B. I would like a function G : (A -> Result<B, Err>) -> Result<C, Err> that "propagates" any error produced by calling g.

Using exceptions, I could do something like this:

// Pseudocode: Rust has no try/catch, this only illustrates the idea
fn transform_to_err(F : (A -> B) -> C, g : A -> Result<B, Err>) -> Result<C, Err> {

    fn g_panic (a : A) -> B {
        match g(a) {
            Ok(b) => b,
            Err(_) => panic!("g_panic error"),
        }
    }

    try {
        Ok(F(g_panic))
    } catch exception {
        Err
    }
}

Is there a way to do something like this while handling the results properly? The exception-based approach also throws away the typed (algebraic) error value, which makes it pretty underwhelming.

Here is an example where sum10 plays the role of F and sqrt_err the role of g:

use std::f64::consts::E;

/// Suppose this function is external and we cannot change it
fn sum10(fun: impl Fn(f64) -> f64) -> f64 {
    let mut sum = 0.;
    for i in 0..10 {
        sum += E.powf(fun(i as f64));
    }
    sum
}

/// We want to use this function with sum10 but it returns a Result
fn sqrt_err(x: f64) -> Result<f64, String> {
    if x < 0.0 {
        Err("Negative value".to_string())
    } else {
        Ok(x.sqrt())
    }
}

What I want is to call sum10 with sqrt_err as the argument, and if sqrt_err returns the Err variant at any point during that call, get that error back instead of a sum.

// So what I want would be equivalent to:
fn sum10_error(fun: impl Fn(f64) -> Result<f64, String>) -> Result<f64, String> {
    let mut sum = 0.;
    for i in 0..10 {
        sum += E.powf(fun(i as f64)?);
    }
    Ok(sum)
}
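
To make the intended behaviour concrete, here is a small sketch that exercises sum10_error on both a succeeding and a failing input. shifted_sqrt is a hypothetical helper introduced only for this illustration (it is not part of the question); it shifts the argument so that the very first call already hits the error branch.

```rust
use std::f64::consts::E;

fn sqrt_err(x: f64) -> Result<f64, String> {
    if x < 0.0 {
        Err("Negative value".to_string())
    } else {
        Ok(x.sqrt())
    }
}

/// The hand-written, error-aware version of sum10 from the question.
fn sum10_error(fun: impl Fn(f64) -> Result<f64, String>) -> Result<f64, String> {
    let mut sum = 0.;
    for i in 0..10 {
        // `?` returns early with the first error produced by `fun`
        sum += E.powf(fun(i as f64)?);
    }
    Ok(sum)
}

/// Hypothetical helper (an assumption for this example): fails for x < 5.
fn shifted_sqrt(x: f64) -> Result<f64, String> {
    sqrt_err(x - 5.0)
}
```

With these definitions, sum10_error(sqrt_err) yields an Ok value because every input in 0..10 is non-negative, while sum10_error(shifted_sqrt) stops at the first call and yields Err("Negative value"). The goal is to get the same behaviour without rewriting sum10.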

Solution

  • The other answer is incorrect. It is possible to do exactly what is asked. It is just not idiomatic.


    use std::f64::consts::E;
    
    fn sum10(fun: impl Fn(f64) -> f64) -> f64 {
        let mut sum = 0.;
        for i in 0..10 {
            sum += E.powf(fun(i as f64));
        }
        sum
    }
    
    fn sqrt_err(x: f64) -> Result<f64, String> {
        if x < 0.0 {
            Err("Negative value".to_string())
        } else {
            Ok(x.sqrt())
        }
    }
    
    fn transform_to_err<'a, F, G, A, B, C, Error>(f: F, g: G) -> Result<C, Error>
    where
        F: FnOnce(Box<dyn Fn(A) -> B + 'a>) -> C + std::panic::UnwindSafe,
        G: Fn(A) -> Result<B, Error> + std::panic::UnwindSafe + 'a,
        Error: Send + std::any::Any,
    {
        // Wrap the error in a private type so that only panics raised by this
        // function are converted back into errors; any other panic is re-raised
        struct ErrorWrapper<E>(E);
        
        // Wrap g such that it panics instead of returning an error
        let g_panic = move |a: A| {
            match g(a) {
                Ok(ret) => ret,
                Err(error) => {
                    // resume_unwind, unlike panic!, does not run the panic hook
                    std::panic::resume_unwind(Box::new(ErrorWrapper(error)));
                }
            }
        };
        // Catch the unwind; recover our wrapped error, re-raise unrelated panics
        std::panic::catch_unwind(move || f(Box::new(g_panic))).map_err(|error| {
            match error.downcast::<ErrorWrapper<Error>>() {
                Ok(error) => error.0,
                Err(error) => std::panic::resume_unwind(error),
            }
        })
    }
    
    fn main() {
        let ret = transform_to_err(sum10, sqrt_err);
        println!("{:?}", ret);
    }
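
The main function above only exercises the success path, since sqrt_err never fails for the inputs 0 through 9. As a hedged sketch of the failure path, the block below repeats the answer's definitions and adds shifted_sqrt, a hypothetical helper that is not part of the answer and exists only to force an error on the first call.

```rust
use std::f64::consts::E;
use std::panic::{catch_unwind, resume_unwind, UnwindSafe};

fn sum10(fun: impl Fn(f64) -> f64) -> f64 {
    let mut sum = 0.;
    for i in 0..10 {
        sum += E.powf(fun(i as f64));
    }
    sum
}

fn sqrt_err(x: f64) -> Result<f64, String> {
    if x < 0.0 {
        Err("Negative value".to_string())
    } else {
        Ok(x.sqrt())
    }
}

/// Hypothetical helper (an assumption for this example): fails for x < 5.
fn shifted_sqrt(x: f64) -> Result<f64, String> {
    sqrt_err(x - 5.0)
}

/// Same technique as in the answer: turn an Err into an unwind, then catch it.
fn transform_to_err<'a, F, G, A, B, C, Error>(f: F, g: G) -> Result<C, Error>
where
    F: FnOnce(Box<dyn Fn(A) -> B + 'a>) -> C + UnwindSafe,
    G: Fn(A) -> Result<B, Error> + UnwindSafe + 'a,
    Error: Send + std::any::Any,
{
    // Private payload type, so foreign panics are never mistaken for our errors
    struct ErrorWrapper<E>(E);

    let g_panic = move |a: A| match g(a) {
        Ok(ret) => ret,
        // Start unwinding with the error as payload (no panic hook is run)
        Err(error) => resume_unwind(Box::new(ErrorWrapper(error))),
    };

    catch_unwind(move || f(Box::new(g_panic))).map_err(|payload| {
        match payload.downcast::<ErrorWrapper<Error>>() {
            Ok(wrapped) => wrapped.0,
            // Not ours: keep propagating the original panic
            Err(other) => resume_unwind(other),
        }
    })
}
```

Under these definitions, transform_to_err(sum10, shifted_sqrt) returns Err("Negative value") and transform_to_err(sum10, sqrt_err) returns an Ok value. Note that this relies on unwinding: catch_unwind only catches panics that unwind, so a build configured with panic = "abort" would terminate the process instead of returning the error.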