Tags: macros, rust, rust-macros

How to match Rust's `if` expressions in a macro?


I'm trying to write a macro that rewrites certain Rust control flow, but I'm having difficulty matching an `if` expression. The problem is that the predicate is an expression, but an `expr` fragment may only be followed by `=>`, `,`, or `;`, so it cannot be followed by a block or `{`.

The best I've got is to use tt:

macro_rules! branch {
    (
        if $pred:tt 
            $r1:block
        else
            $r2:block
    ) => {
        if $pred { 
            $r1
        } else {
            $r2
        }
    };
}

This works fine with single-token or grouped (parenthesized) predicates:

branch! {
    if (foo == bar) {
        1
    } else {
        2
    }
}

But it fails if the predicate is not grouped:

branch! {
    if foo == bar {
        1
    } else {
        2
    }
}
error: no rules expected the token `==`

I also tried to use a repeating pattern of tt in the predicate:

macro_rules! branch {
    (
        if $($pred:tt)+
            $r1:block
        else
            $r2:block
    ) => {
        if $($pred)+ { 
            $r1
        } else {
            $r2
        }
    };
}

But this produces an error, because it's now ambiguous whether the following block should be matched by the `tt` repetition or by `$r1`:

error: local ambiguity: multiple parsing options: built-in NTs tt ('pred') or block ('r1').

Is there a way to do this, or am I stuck with inventing special syntax to use in the macro?
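For comparison, here is a minimal sketch of the "special syntax" fallback. Because `=>` is one of the few tokens allowed to follow an `expr` fragment, the predicate can be matched as `$pred:expr` if the caller terminates it with `=>`. The `=>` separator and the `pick` function are assumptions made for this example, not part of the question.

```rust
// Sketch of the "special syntax" route: an `expr` fragment may only be
// followed by `=>`, `,` or `;`, so an (assumed) `=>` separator ends the predicate.
macro_rules! branch {
    (if $pred:expr => $r1:block else $r2:block) => {
        if $pred { $r1 } else { $r2 }
    };
}

// Hypothetical helper: the predicate no longer needs parentheses,
// but the caller has to write `=>` after it.
fn pick(foo: i32, bar: i32) -> i32 {
    branch!(if foo == bar => { 1 } else { 2 })
}

fn main() {
    println!("{} {}", pick(3, 3), pick(3, 4));
}
```

The trade-off is exactly the one the question raises: it compiles with an ungrouped predicate, but the input is no longer plain `if` syntax.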


Solution

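  • You could use a TT muncher to parse the predicate: consume the input one token tree at a time, appending each one to an accumulated predicate, until what remains has the shape `{ ... } else { ... }`: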

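    macro_rules! branch {
        // Entry point: strip the leading `if` and hand everything else
        // to the parser with an empty predicate accumulator.
        {
            if $($rest:tt)*
        } => {
            branch_parser! {
                predicate = ()
                rest = ($($rest)*)
            }
        };
    }
    
    macro_rules! branch_parser {
        // Base case: the remaining input is exactly `{ ... } else { ... }`,
        // so everything accumulated so far is the predicate.
        {
            predicate = ($($predicate:tt)*)
            rest = ({ $($then:tt)* } else { $($else:tt)* })
        } => {
            println!("predicate: {}", stringify!($($predicate)*));
            println!("then: {}", stringify!($($then)*));
            println!("else: {}", stringify!($($else)*));
        };
    
        // Recursive case: move one token tree from the front of `rest`
        // to the end of `predicate`, then try again.
        {
            predicate = ($($predicate:tt)*)
            rest = ($next:tt $($rest:tt)*)
        } => {
            branch_parser! {
                predicate = ($($predicate)* $next)
                rest = ($($rest)*)
            }
        };
    }
    
    fn main() {
        // `foo` and `bar` are never evaluated here; they are only stringified.
        branch! {
            if foo == bar {
                1
            } else {
                2
            }
        }
    }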
    

    Output:

    predicate: foo == bar
    then: 1
    else: 2
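The macros above only print the pieces they matched. Below is a minimal sketch of the same TT-muncher technique that instead re-emits a real `if` expression, assuming the same input shape (`if <pred> { ... } else { ... }`) and additionally handling `else if` chains. The `@pred` internal-rule marker and the `classify` function are illustrative names chosen for this sketch, not part of the answer.

```rust
// Sketch: a TT muncher that rebuilds the `if` expression instead of printing it.
// `@pred` is an internal-rule marker chosen for this example.
macro_rules! branch {
    // Entry point: drop the leading `if` and start with an empty predicate.
    (if $($rest:tt)*) => {
        branch!(@pred () $($rest)*)
    };
    // Base case: what remains is `{ then } else { else }`.
    (@pred ($($pred:tt)*) { $($then:tt)* } else { $($els:tt)* }) => {
        if $($pred)* { $($then)* } else { $($els)* }
    };
    // `else if` chain: emit this branch, then parse the rest as a new `if`.
    (@pred ($($pred:tt)*) { $($then:tt)* } else if $($rest:tt)*) => {
        if $($pred)* { $($then)* } else { branch!(if $($rest)*) }
    };
    // Otherwise: move one token tree from the input onto the predicate.
    (@pred ($($pred:tt)*) $next:tt $($rest:tt)*) => {
        branch!(@pred ($($pred)* $next) $($rest)*)
    };
}

// Hypothetical example function using ungrouped predicates and `else if`.
fn classify(x: i32) -> &'static str {
    branch!(if x < 0 { "negative" } else if x == 0 { "zero" } else { "positive" })
}

fn main() {
    let (foo, bar) = (1, 1);
    let n = branch!(if foo == bar { 1 } else { 2 });
    println!("{n} {} {}", classify(-3), classify(0));
}
```

Each munched token tree costs one level of macro recursion, so a very long predicate may hit the default recursion limit of 128; `#![recursion_limit = "256"]` at the crate root raises it.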