
Custom literals via Rust macros?


Is it possible in Rust to define a macro that can parse custom literals, e.g. something along the lines of:

vector!(3x + 15y)

To clarify, I would like to be able to get as close to the above syntax as one can (within the realm of what is possible of course).


Solution

  • I'm going to assume that by "custom literal", you specifically mean "a regular Rust literal (excluding raw literals), immediately followed by a custom identifier". This includes:

    • "str"x, a string literal "str" with custom suffix x
    • 123x, a numeric literal 123 with custom suffix x
    • b"bytes"x, a byte string literal b"bytes" with custom suffix x

    If that definition is sufficient for you, then you're in luck: all of the above are valid literal tokens in Rust. According to the Rust reference:

    A suffix is a non-raw identifier immediately (without whitespace) following the primary part of a literal.

    Any kind of literal (string, integer, etc) with any suffix is valid as a token, and can be passed to a macro without producing an error. The macro itself will decide how to interpret such a token and whether to produce an error or not.

    However, suffixes on literal tokens parsed as Rust code are restricted. Any suffixes are rejected on non-numeric literal tokens, and numeric literal tokens are accepted only with suffixes from the list below.

    So Rust explicitly allows macros to support custom literals.

    Now, how would you go about writing such a macro? A declarative macro written with macro_rules! won't do: it can accept a suffixed literal as a single token, but its pattern matching cannot split that token into the number and the suffix. A procedural macro, however, can.
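
To illustrate the limitation, here is a minimal sketch of what a declarative macro sees. The macro name literal_as_text and the helper suffixed_literal_text are made up for this example. The macro accepts 123x without complaint, but only as one opaque token: no macro_rules! pattern can bind 123 and x separately for an arbitrary number.

```rust
// Hypothetical macro: matches any single token tree and turns it into text.
macro_rules! literal_as_text {
    ($t:tt) => {
        stringify!($t)
    };
}

// `123x` is passed as one literal token; the suffix is never split off.
fn suffixed_literal_text() -> &'static str {
    literal_as_text!(123x)
}

fn main() {
    assert_eq!(suffixed_literal_text(), "123x");
}
```

The token is only stringified and never compiled as Rust code, so the restriction on suffixes quoted above does not apply.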

    I won't go into detail about how to write procedural macros, since that is too much for a single Stack Overflow answer. As a starting point, though, here is an example of a procedural macro that does something along the lines of what you asked for. It takes any custom integer literals 123x or 123y in the given expression and transforms them into the function calls x_literal(123) and y_literal(123):

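    // This must live in its own crate with `proc-macro = true` under `[lib]` in
    // Cargo.toml, and with `quote` and `syn` as dependencies. `syn` needs the
    // `visit-mut` feature, and `full` to parse arbitrary expressions.
    extern crate proc_macro;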
    
    use proc_macro::TokenStream;
    use quote::ToTokens;
    use syn::{
        parse_macro_input, parse_quote,
        visit_mut::{self, VisitMut},
        Expr, ExprLit, Lit, LitInt,
    };
    
    
    // actual procedural macro
    #[proc_macro]
    pub fn vector(input: TokenStream) -> TokenStream {
        let mut input = parse_macro_input!(input as Expr);
        LiteralReplacer.visit_expr_mut(&mut input);
        input.into_token_stream().into()
    }
    
    // "visitor" that visits every node in the syntax tree
    // we add our own behavior to replace custom literals with proper Rust code
    struct LiteralReplacer;
    
    impl VisitMut for LiteralReplacer {
        fn visit_expr_mut(&mut self, i: &mut Expr) {
            if let Expr::Lit(ExprLit { lit, .. }) = i {
                match lit {
                    Lit::Int(lit) => {
                        // get literal suffix
                        let suffix = lit.suffix();
                        // get literal without suffix
                        let lit_nosuffix = LitInt::new(lit.base10_digits(), lit.span());
    
                        match suffix {
                            // replace literal expression with new expression
                            "x" => *i = parse_quote! { x_literal(#lit_nosuffix) },
                            "y" => *i = parse_quote! { y_literal(#lit_nosuffix) },
                            _ => (), // other literal suffix we won't modify
                        }
                    }
    
                    _ => (), // other literal type we won't modify
                }
            } else {
                // not a literal, use default visitor method
                visit_mut::visit_expr_mut(self, i)
            }
        }
    }
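
To make the two halves that the visitor works with more concrete, here is a simplified, standard-library-only sketch of the split into digits and suffix. It is not how syn implements suffix() or base10_digits(). The function split_int_literal is hypothetical and handles only decimal digits with _ separators.

```rust
/// Splits a decimal integer literal such as "123x" into digits and suffix.
/// Simplified stand-in: hex, octal and binary literals are not handled.
fn split_int_literal(token: &str) -> Option<(String, &str)> {
    // The suffix starts at the first character that is neither a digit nor `_`.
    let split_at = token
        .find(|c: char| !c.is_ascii_digit() && c != '_')
        .unwrap_or(token.len());
    let (number, suffix) = token.split_at(split_at);
    // Drop the `_` separators so that only the digits remain.
    let digits: String = number.chars().filter(|c| *c != '_').collect();
    if digits.is_empty() {
        None // the token does not start with a digit
    } else {
        Some((digits, suffix))
    }
}

fn main() {
    assert_eq!(split_int_literal("123x"), Some(("123".to_string(), "x")));
    assert_eq!(split_int_literal("42"), Some(("42".to_string(), "")));
}
```

As with the empty suffix in the second assertion, a literal without a suffix falls through to the `_ => ()` arm of the visitor and is left untouched.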
    

    For example, the macro would transform vector!(3x + 4y) into x_literal(3) + y_literal(4).
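
The expanded code only compiles if x_literal and y_literal are in scope at the call site, and the answer does not define them. The following is one possible sketch, assuming a hypothetical Vector type with two i64 components. The expansion of vector!(3x + 4y) is written out by hand so the example runs without the procedural macro crate.

```rust
use std::ops::Add;

/// Hypothetical 2D vector type; not part of the original answer.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Vector {
    x: i64,
    y: i64,
}

impl Add for Vector {
    type Output = Vector;

    // Component-wise addition, so `x_literal(a) + y_literal(b)` works.
    fn add(self, rhs: Vector) -> Vector {
        Vector { x: self.x + rhs.x, y: self.y + rhs.y }
    }
}

/// Assumed meaning of `123x`: a vector along the x axis.
fn x_literal(n: i64) -> Vector {
    Vector { x: n, y: 0 }
}

/// Assumed meaning of `123y`: a vector along the y axis.
fn y_literal(n: i64) -> Vector {
    Vector { x: 0, y: n }
}

fn main() {
    // Hand-written expansion of `vector!(3x + 4y)`.
    let v = x_literal(3) + y_literal(4);
    assert_eq!(v, Vector { x: 3, y: 4 });
}
```

Because the macro only rewrites the literals, everything else in the expression (here the + operator) is resolved through ordinary Rust traits such as Add.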