Search code examples
Tags: rust, types, macros, type-conversion, traits

How to check whether `T` (or `Any`) implements a trait in Rust


I'm just starting out with Rust and need some help understanding type checking in the language. I'm trying to introduce an activity logger to capture and log the execution duration and result of a function. Here's a simplified implementation:

/// A named, timed unit of work that logs when it starts and how it ended.
struct Activity {
    /// Label included in every log line.
    name: String,
    /// Moment `begin` was called; durations are measured from here.
    start_time: std::time::Instant,
}

impl Activity {
    /// Logs the start of `name` and starts the clock.
    fn begin(name: &str) -> Self {
        println!("Activity started: {}", name);
        Self {
            name: String::from(name),
            start_time: std::time::Instant::now(),
        }
    }

    /// Logs a successful end together with the elapsed time in milliseconds.
    fn end_ok(&self) {
        let elapsed_ms = self.start_time.elapsed().as_millis();
        println!(
            "Activity ended: {}, Status: Success, Duration: {}",
            self.name, elapsed_ms
        );
    }

    /// Logs a failed end, the elapsed milliseconds, and `error` rendered as text.
    fn end_error<T: ToString>(&self, error: &T) {
        let elapsed_ms = self.start_time.elapsed().as_millis();
        let message = error.to_string();
        println!(
            "Activity ended: {}, Status: Fail, Duration: {}, Message: {}",
            self.name, elapsed_ms, message
        );
    }
}

/// Result of an operation that succeeds with no value or fails with a message.
type MyResult = Result<(), String>;

/// Calls `test2`, converting its `String` error into a boxed `std::error::Error`.
fn test1() -> Result<(), Box<dyn std::error::Error>> {
    test2().map_err(Into::into)
}

/// Always fails with the message `"oops"`.
fn test2() -> MyResult {
    Err(String::from("oops"))
}

fn main() {
    let activity = Activity::begin("TestActivity");

    match test1() {
        Ok(_) => activity.end_ok(),
        Err(error) => activity.end_error(&error),
    }
}

While this works, it feels a bit messy because I have to inspect each result manually and build a chain of result handling. I'd prefer to use the `?` operator, but that requires wrapping the result handling. I tried using a `proc_macro_attribute` to do this wrapping, but I ran into the problem of determining the function's return type.

What I want is to recognize what the function body returns. If it is a `Result<_, _>` holding an `Err`, I want to call `end_error(error)`; otherwise, I want to call `end_ok()`. Here's my macro attempt:

extern crate proc_macro;

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn, LitStr, ReturnType};

/// Attribute macro (the question's attempt) meant to wrap a function so its
/// start, outcome, and duration are logged through `Activity`.
///
/// `attr` is parsed as a string literal and used as the activity name;
/// `item` is parsed as the function the attribute is applied to.
///
/// NOTE(review): this is illustrative pseudo-code, not a working macro. The
/// branch marked "C# like pseudo code" uses an `is` type test that does not
/// exist in Rust, so the code generated for functions with a return type will
/// not compile. Both branches also emit `fn #fn_name()` with no parameters,
/// visibility, or attributes, so those parts of the original signature
/// (`input_fn.sig.inputs`, `input_fn.vis`, `input_fn.attrs`) are dropped.
#[proc_macro_attribute]
pub fn activity_logger(attr: TokenStream, item: TokenStream) -> TokenStream {
    // Activity name, e.g. `#[activity_logger("TestActivity")]`.
    let logger_name = parse_macro_input!(attr as LitStr).value();
    let input_fn = parse_macro_input!(item as ItemFn);
    let fn_name = &input_fn.sig.ident;
    let fn_block = &input_fn.block;
    // `ReturnType::Default` when the signature has no `-> Ty`.
    let fn_return_type = &input_fn.sig.output;
    
    let output = match fn_return_type {
        // The function declares a return type: run the body in an immediately
        // invoked closure so that `?` / `return` inside it produce a value that
        // can be inspected before it is returned.
        ReturnType::Type(_, _) => quote! {
            fn #fn_name() #fn_return_type {
                let activity = Activity::begin(#logger_name);

                let result = (|| #fn_block)();
                
                // C# like pseudo code start
                // NOTE(review): besides the invalid `is` test, `end_error`
                // takes `&T`, so this would need `&error`. Also, `match result`
                // moves an `Err` payload out of `result`, so `result` could not
                // be returned below. Matching on `&result` avoids both problems.
                if result is Result<_, _> {
                    match result {
                        Ok(_) => activity.end_ok(),
                        Err(error) => activity.end_error(error),
                    }
                } else {
                    activity.end_ok();
                }
                // C# like pseudo code end

                result
            }
        },
        // No return type: run the body inline, then log success.
        // NOTE(review): a `return` inside `#fn_block` exits the generated
        // function before `activity.end_ok()` runs, so that call is skipped.
        ReturnType::Default => quote! {
            fn #fn_name() {
                let activity = Activity::begin(#logger_name);

                #fn_block
                
                activity.end_ok();
            }
        },
    };

    output.into()
}

Is there any way to determine at runtime whether a value implements the `Try` trait or is a `Result<_, _>`? Or can a macro help with this? I'd appreciate any advice!

The only way I've found to identify the return type is by extracting the type name, either via `ReturnType::Type(_, ty)` or with `std::any::type_name`. That's not ideal, because the result type can be renamed (for example through an alias such as `type MyResult = Result<(), String>;`).

I tried to resolve it with traits:

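  • but Rust rejects this code because the two impls overlap: the blanket `impl<T> IsResult for T` also covers `Result<T, E>`, which is a conflicting-implementations error (E0119):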
    /// Attempted compile-time flag answering "is this type a `Result`?".
    trait IsResult {
        /// Intended to be `true` for `Result<T, E>` and `false` for everything else.
        const IS_RESULT: bool;
    }
    
    // Impl for every `Result<T, E>`.
    impl<T, E> IsResult for Result<T, E> {
        const IS_RESULT: bool = true;
    }
    
    // Blanket impl for every `T`. `T` can itself be `Result<_, _>`, so this
    // overlaps with the impl above and the compiler rejects the pair as
    // conflicting implementations; it does not pick the "more specific" one.
    impl<T> IsResult for T {
        const IS_RESULT: bool = false;
    }
    
  • but Rust's traits aren't like interfaces in other languages, and this approach didn't work either (both calls below return `false`):
    /// Base trait; the blanket impl below implements it for every sized type.
    trait Loggable {
        fn can_log(&self) -> bool;
    }
    
    /// Sub-trait intended to "override" `can_log` for `Result`.
    /// NOTE: this declares a second, separate method that merely shares the
    /// name of `Loggable::can_log`; it does not override it.
    trait LoggableResult: Loggable {
        fn can_log(&self) -> bool;
    }
    
    // Blanket impl: every sized `T`, including `Result<T, E>`, gets
    // `Loggable::can_log` returning `false`.
    impl<T> Loggable for T {
        fn can_log(&self) -> bool {
            false
        }
    }
    
    // Implements only `LoggableResult::can_log`; nothing below calls it.
    impl<T, E> LoggableResult for Result<T, E> {
        fn can_log(&self) -> bool {
            true
        }
    }
    
    /// Calls `can_log` through a `dyn Loggable` trait object. The object only
    /// carries `Loggable`'s methods, so this always runs the blanket impl.
    fn can_log(value: &dyn Loggable) -> bool {
        value.can_log()
    }
    
    // Demonstrates the problem: both the integer and the `Result` yield `false`.
    fn test() {
        let test1_result = 123;
        let test1 = can_log(&test1_result); // false
        let test2_result: Result<(), String> = Result::Err("oops".to_string());
        let test2 = can_log(&test2_result); // false
    }
    

Solution

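  • This is actually possible with autoref-based specialization (see https://github.com/dtolnay/case-studies/blob/master/autoref-specialization/README.md). Note that the check happens at compile time, during method resolution, not at runtime. Here is a declarative macro that demonstrates it: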

    /// Chooses how an activity is closed based on the type of value a wrapped
    /// function produced.
    trait TypeDetector {
        /// Reports the outcome represented by `self` to `activity`.
        fn type_detect(&self, activity: &Activity);
    }

    /// `Result` values: `Err` ends the activity with the error's text, and
    /// `Ok` ends it successfully. Applies only when the error type is `ToString`.
    impl<T, E> TypeDetector for Result<T, E>
    where
        E: ToString,
    {
        fn type_detect(&self, activity: &Activity) {
            if let Err(error) = self {
                activity.end_error(error);
            } else {
                activity.end_ok();
            }
        }
    }

    /// Fallback used, via auto-referencing, for values with no impl of their
    /// own: the activity always ends successfully.
    impl<T> TypeDetector for &T {
        fn type_detect(&self, activity: &Activity) {
            activity.end_ok()
        }
    }
    
    /// Defines a free function whose every call is timed and logged as an `Activity`.
    ///
    /// The original body is placed in a nested function with the same
    /// parameters and return type. It is therefore type-checked against `$out`
    /// exactly as it would be without the macro, so `?`, `return`, and tail
    /// expressions whose type comes only from the signature (`"x".into()`,
    /// `Default::default()`, `.collect()`) all work. The returned value is then
    /// passed to `TypeDetector::type_detect`. That picks the `Result` impl when
    /// the value is a `Result` with a `ToString` error, and the fallback impl
    /// for references otherwise.
    ///
    /// `Activity` and the `TypeDetector` trait must be in scope where the macro
    /// is invoked. Parameters must have the simple form `name: Type`; `mut`
    /// bindings, patterns, generics, and attributes are not supported.
    macro_rules! activity_logger {
        ($log_name:expr, $vis:vis fn $name:ident($($arg:ident : $arg_ty:ty),* $(,)?) -> $out:ty {$($body:tt)*}) => {
            $vis fn $name($($arg: $arg_ty),*) -> $out {
                // Running the body through an untyped closure (`(|| { .. })()`)
                // leaves `result` with an uninferred type when `type_detect` is
                // resolved, which fails with E0282 for bodies such as
                // `"x".into()`. A nested fn gives `result` the concrete type
                // `$out` up front.
                fn __activity_logger_body($($arg: $arg_ty),*) -> $out {
                    $($body)*
                }

                let activity = Activity::begin($log_name);
                let result = __activity_logger_body($($arg),*);
                // Autoref specialization: resolves to the `Result` impl when it
                // applies, otherwise to the fallback impl for references.
                (&result).type_detect(&activity);
                result
            }
        };
    }
    

    The magic is in this line:

    (&result).type_detect(&activity);
    

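    If `result` is a `Result` whose error type implements `ToString`, method resolution finds the `TypeDetector` implementation for `Result`, because its `&self` receiver matches `&result` exactly. Otherwise, the compiler auto-references the receiver and looks for `type_detect` on `&&result`, which matches the other implementation (`impl<A> TypeDetector for &A`). As a consequence, a `Result` whose error type does not implement `ToString` also falls through to the `&A` implementation and is logged as a success.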

    Here is a full example (playground):

    /// A named, timed unit of work that logs when it starts and how it ended.
    struct Activity {
        /// Label included in every log line.
        name: String,
        /// Moment `begin` was called; durations are measured from here.
        start_time: std::time::Instant,
    }

    impl Activity {
        /// Logs the start of `name` and starts the clock.
        fn begin(name: &str) -> Self {
            println!("Activity started: {}", name);
            Self {
                name: String::from(name),
                start_time: std::time::Instant::now(),
            }
        }

        /// Logs a successful end together with the elapsed time in milliseconds.
        fn end_ok(&self) {
            let elapsed_ms = self.start_time.elapsed().as_millis();
            println!(
                "Activity ended: {}, Status: Success, Duration: {}",
                self.name, elapsed_ms
            );
        }

        /// Logs a failed end, the elapsed milliseconds, and `error` rendered as text.
        fn end_error<T: ToString>(&self, error: &T) {
            let elapsed_ms = self.start_time.elapsed().as_millis();
            let message = error.to_string();
            println!(
                "Activity ended: {}, Status: Fail, Duration: {}, Message: {}",
                self.name, elapsed_ms, message
            );
        }
    }
    
    /// Result of an operation that succeeds with no value or fails with a message.
    type MyResult = std::result::Result<(), String>;
    
    /// Chooses how an activity is closed based on the type of value a wrapped
    /// function produced.
    trait TypeDetector {
        /// Reports the outcome represented by `self` to `activity`.
        fn type_detect(&self, activity: &Activity);
    }

    /// `Result` values: `Err` ends the activity with the error's text, and
    /// `Ok` ends it successfully. Applies only when the error type is `ToString`.
    impl<T, E> TypeDetector for Result<T, E>
    where
        E: ToString,
    {
        fn type_detect(&self, activity: &Activity) {
            if let Err(error) = self {
                activity.end_error(error);
            } else {
                activity.end_ok();
            }
        }
    }

    /// Fallback used, via auto-referencing, for values with no impl of their
    /// own: the activity always ends successfully.
    impl<T> TypeDetector for &T {
        fn type_detect(&self, activity: &Activity) {
            activity.end_ok()
        }
    }
    
    /// Defines a free function whose every call is timed and logged as an `Activity`.
    ///
    /// The original body is placed in a nested function with the same
    /// parameters and return type. It is therefore type-checked against `$out`
    /// exactly as it would be without the macro, so `?`, `return`, and tail
    /// expressions whose type comes only from the signature (`"x".into()`,
    /// `Default::default()`, `.collect()`) all work. The returned value is then
    /// passed to `TypeDetector::type_detect`. That picks the `Result` impl when
    /// the value is a `Result` with a `ToString` error, and the fallback impl
    /// for references otherwise.
    ///
    /// `Activity` and the `TypeDetector` trait must be in scope where the macro
    /// is invoked. Parameters must have the simple form `name: Type`; `mut`
    /// bindings, patterns, generics, and attributes are not supported.
    macro_rules! activity_logger {
        ($log_name:expr, $vis:vis fn $name:ident($($arg:ident : $arg_ty:ty),* $(,)?) -> $out:ty {$($body:tt)*}) => {
            $vis fn $name($($arg: $arg_ty),*) -> $out {
                // Running the body through an untyped closure (`(|| { .. })()`)
                // leaves `result` with an uninferred type when `type_detect` is
                // resolved, which fails with E0282 for bodies such as
                // `"x".into()`. A nested fn gives `result` the concrete type
                // `$out` up front.
                fn __activity_logger_body($($arg: $arg_ty),*) -> $out {
                    $($body)*
                }

                let activity = Activity::begin($log_name);
                let result = __activity_logger_body($($arg),*);
                // Autoref specialization: resolves to the `Result` impl when it
                // applies, otherwise to the fallback impl for references.
                (&result).type_detect(&activity);
                result
            }
        };
    }
    
    // `test1` propagates `test2`'s error with `?`; each function is logged
    // as its own activity.
    activity_logger! {
        "TestActivity",
        fn test1() -> Result<(), Box<dyn std::error::Error>> {
            test2()?;
            Ok(())
        }
    }

    // Always returns `Err`, so its activity is closed through `end_error`.
    activity_logger! {
        "test2",
        fn test2() -> MyResult {
            Err(String::from("oops"))
        }
    }

    // Returns a plain `u8`, so its activity is closed through `end_ok`.
    activity_logger! {
        "test3",
        fn test3() -> u8 {
            3
        }
    }
    
    /// Runs the logged examples. Each wrapped function already reports its own
    /// outcome, so the returned values are intentionally discarded.
    fn main() {
        // `test1` returns a `Result`, which is `#[must_use]`. Binding it to `_`
        // makes the discard explicit and avoids the `unused_must_use` warning.
        let _ = test1();
        test3();
    }