
How can I write a macro to generate the match block for an enum_dispatch enum?


I'm using enum_dispatch to store the possible actions I want to take. Each action is a struct that has to be built from arguments.

#[enum_dispatch]
pub trait ActionRunner {
    fn run(&self);
}

#[enum_dispatch(ActionRunner)]
#[derive(Debug)]
enum ActionsEnum {
    MyAction,
    MyAction2,
}

let action: ActionsEnum = match my_output.action.name.as_str() {
    "MyAction" => {
        let foo: MyAction = serde_json::from_value(args).unwrap();
        foo.into()
    }
    "MyAction2" => {
        let foo: MyAction2 = serde_json::from_value(args).unwrap();
        foo.into()
    }
    _ => panic!("Action not found"),
};

I'm trying to avoid writing that match statement by hand by generating it with a macro. I think I'm close, but I can't work out how to pass the arguments:

#[macro_export]
macro_rules! match_all_variants {
    ($input_str:expr, $action_args:ident, $enum_type:ident, $($enum_variants:ty),* ) => {
        match $input_str {
            $(
                stringify!($enum_variants) => {
                    let foo: $enum_variants = serde_json::from_value($action_args).unwrap();
                    foo.into()
                },
            )*
            // A match on a &str must be exhaustive, so unknown names need a fallback arm.
            _ => panic!("Action not found"),
        }
    };
}

so that in the end I could just write:

let action: ActionsEnum = match_all_variants!("MyAction2", my_args, ActionsEnum, <something>);

action.run();

Does anyone know how to achieve this? Ideally I'd avoid a macro altogether, but I don't see a way to do it without one.
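With the macro as written in the question, the `<something>` part would simply be the comma-separated list of variant types, e.g. `match_all_variants!(name, my_args, ActionsEnum, MyAction, MyAction2)`, as long as the generated `match` also has a `_` arm (a `match` on a `&str` must be exhaustive). The drawback is that the variant list is then written twice: once in the enum and once at every call site. A minimal sketch of one way around that is to let a single `macro_rules!` macro declare the enum and generate the name-to-variant match from the same list. To keep it runnable with the standard library only, it assumes a hypothetical `FromArgs` trait (parsing a `&str`) in place of `serde_json::from_value`, and writes the `ActionRunner` dispatch by hand instead of using enum_dispatch. `actions_enum!`, `from_name` and `FromArgs` are illustrative names, not part of either crate.

```rust
// Hypothetical stand-in for serde_json::from_value: builds a value from raw args.
trait FromArgs: Sized {
    fn from_args(args: &str) -> Result<Self, String>;
}

trait ActionRunner {
    fn run(&self);
}

// Declares the enum AND the name -> variant match from one list of struct names,
// so adding an action only means touching the invocation further down.
macro_rules! actions_enum {
    ($enum_name:ident { $($variant:ident),* $(,)? }) => {
        #[derive(Debug)]
        enum $enum_name {
            $($variant($variant),)*
        }

        // Hand-written dispatch standing in for what enum_dispatch would generate.
        impl ActionRunner for $enum_name {
            fn run(&self) {
                match self {
                    $($enum_name::$variant(inner) => inner.run(),)*
                }
            }
        }

        impl $enum_name {
            // stringify!(MyAction) expands to the literal "MyAction", usable as a pattern.
            fn from_name(name: &str, args: &str) -> Result<Self, String> {
                match name {
                    $(stringify!($variant) => Ok($enum_name::$variant($variant::from_args(args)?)),)*
                    other => Err(format!("unknown action: {}", other)),
                }
            }
        }
    };
}

#[derive(Debug)]
struct MyAction {
    a: i32,
}

impl FromArgs for MyAction {
    fn from_args(args: &str) -> Result<Self, String> {
        args.trim()
            .parse()
            .map(|a| MyAction { a })
            .map_err(|e| format!("bad MyAction args: {}", e))
    }
}

impl ActionRunner for MyAction {
    fn run(&self) {
        println!("running MyAction: a={}", self.a);
    }
}

#[derive(Debug)]
struct MyAction2 {
    t: String,
}

impl FromArgs for MyAction2 {
    fn from_args(args: &str) -> Result<Self, String> {
        Ok(MyAction2 { t: args.to_string() })
    }
}

impl ActionRunner for MyAction2 {
    fn run(&self) {
        println!("running MyAction2: t={:?}", self.t);
    }
}

// The single place where the list of actions lives.
actions_enum!(ActionsEnum { MyAction, MyAction2 });

fn main() {
    for (name, args) in [("MyAction", "123"), ("MyAction2", "hello"), ("Nope", "")] {
        match ActionsEnum::from_name(name, args) {
            Ok(action) => action.run(),
            Err(e) => eprintln!("cannot build action: {}", e),
        }
    }
}
```

With the real crates, the arm body would call `serde_json::from_value::<$variant>(args)` instead of `from_args`, and enum_dispatch could keep generating the dispatch; the point of the sketch is only that one list drives both the enum and the match.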


Solution

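  • Since you already use serde_json for deserialisation, I would let the deserialisation itself build the expected structure. With serde's default (externally tagged) enum representation, a JSON object such as {"MyAction": {"a": 123, "b": 45.67}} deserialises directly into ActionsEnum::MyAction(...), so no match on the action name is needed.

    Below is your example adapted along those lines.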

    // Dependencies: enum_dispatch, serde (with the "derive" feature), serde_json
    use enum_dispatch::enum_dispatch;
    use serde::Deserialize;
    
    #[enum_dispatch]
    pub trait ActionRunner {
        fn run(&self);
    }
    
    #[enum_dispatch(ActionRunner)]
    #[derive(Debug, Deserialize)]
    enum ActionsEnum {
        MyAction(MyAction),
        MyAction2(MyAction2),
    }
    
    #[derive(Debug, Deserialize)]
    struct MyAction {
        a: i32,
        b: f32,
    }
    
    impl ActionRunner for MyAction {
        fn run(&self) {
            println!("• running MyAction: a={} b={}", self.a, self.b);
        }
    }
    
    #[derive(Debug, Deserialize)]
    struct MyAction2 {
        t: String,
    }
    
    impl ActionRunner for MyAction2 {
        fn run(&self) {
            println!("• running MyAction2: t={:?}", self.t);
        }
    }
    
    fn main() {
        let incoming = [
            ("MyAction", r#"{"a": 123, "b": 45.67}"#),
            ("MyAction2", r#"{"t": "hello"}"#),
        ];
        for (action_name, args) in incoming {
            // Wrap the args as {"<name>": <args>}, serde's externally tagged enum layout.
            let full_json = format!("{{{:?}: {}}}", action_name, args);
            match serde_json::from_str::<ActionsEnum>(&full_json) {
                Ok(action) => {
                    println!("Action: {:?}", action);
                    action.run();
                }
                Err(e) => {
                    eprintln!("cannot build action: {}", e);
                }
            }
        }
    }
    /*
    Action: MyAction(MyAction { a: 123, b: 45.67 })
    • running MyAction: a=123 b=45.67
    Action: MyAction2(MyAction2 { t: "hello" })
    • running MyAction2: t="hello"
    */