
How to parse other attributes in a custom Rust proc_macro attribute?


I am writing a proc_macro attribute that adds fields to a struct and also implements my trait (and others, by adding a #[derive(...)]) for the expanded struct.

Here is a simplified version of what I want to do:

#[foo("some value")]
#[derive(Debug)]
struct A {
    bar: u32,
}

After expansion:

#[derive(Debug, Default, Serialize)]
struct A {
    foo: u64,
    bar: u32,
}

impl FooTrait for A {
    ...
}

How can I parse the existing derive attribute so that its Debug trait is combined with all the other traits that the foo proc_macro adds?


Solution

  • With syn, inspect and change ItemStruct::attrs:

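    use std::collections::HashSet;
    
    use proc_macro2::{Delimiter, TokenTree};
    use quote::ToTokens;
    use syn::parse::Parser;
    use syn::punctuated::Punctuated;
    
    // Written against syn 1.x, where `Attribute` has public `path` and `tokens`
    // fields. Storing `syn::Path` in a HashSet needs syn's "extra-traits"
    // feature, and `ItemStruct` needs the "full" feature.
    #[proc_macro_attribute]
    pub fn foo(
        _attr: proc_macro::TokenStream, // the `("some value")` argument, unused here
        input: proc_macro::TokenStream,
    ) -> proc_macro::TokenStream {
        let mut input = syn::parse_macro_input!(input as syn::ItemStruct);
    
        // Split the attributes into `#[derive(...)]` and everything else.
        // Partitioning avoids removing elements while iterating by index,
        // which would skip attributes and eventually index out of bounds.
        let (derives, others): (Vec<syn::Attribute>, Vec<syn::Attribute>) = input
            .attrs
            .drain(..)
            .partition(|attr| attr.path.is_ident("derive"));
        input.attrs = others;
    
        let mut all_derived_traits = HashSet::new();
        for derive in derives {
            // The tokens after `derive` must be exactly one parenthesized group.
            let mut tokens = derive.tokens.clone().into_iter();
            match [tokens.next(), tokens.next()] {
                [Some(TokenTree::Group(group)), None]
                    if group.delimiter() == Delimiter::Parenthesis =>
                {
                    // Parse `Debug, Clone, ...` as a comma-separated list of paths.
                    match Punctuated::<syn::Path, syn::Token![,]>::parse_terminated
                        .parse2(group.stream())
                    {
                        Ok(derived_traits) => all_derived_traits.extend(derived_traits),
                        Err(e) => return e.into_compile_error().into(),
                    }
                }
                _ => {
                    return syn::Error::new_spanned(derive, "malformed derive")
                        .into_compile_error()
                        .into()
                }
            }
        }
    
        // Add the traits this macro wants. `Serialize` must be in scope where
        // #[foo] is used (e.g. `use serde::Serialize;`).
        all_derived_traits.extend([
            syn::parse_quote!(Default),
            syn::parse_quote!(Serialize),
        ]);
    
        // Emit a single `#[derive(...)]` containing every collected trait.
        let all_derived_traits = all_derived_traits.into_iter();
        input.attrs.push(syn::parse_quote! {
            #[derive( #(#all_derived_traits),* )]
        });
    
        input.into_token_stream().into()
    }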
    

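    First we collect the traits from all #[derive(...)] attributes into a HashSet, so we don't derive a trait that is already listed (deriving it twice would produce conflicting implementations). This deduplication is not perfect: paths are compared syntactically, so std::default::Default and Default are treated as different traits. For most uses that should be enough. Then we add our own traits and emit a single new #[derive(...)]. Note that multiple #[derive(...)] lines are merged into one and the traits may be reordered (HashSet iteration order is unspecified), but that does not matter.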
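To illustrate the deduplication step in isolation, here is a hypothetical, standard-library-only sketch that works on plain strings instead of syn::Path. The names last_segment, merge_derives and render_derive are made up for this example. It deliberately differs from the HashSet version above in two assumed ways: it keeps the first-seen order, and it compares only the last path segment, so std::default::Default and Default count as the same trait.

```rust
use std::collections::HashSet;

/// Hypothetical helper: the final `::` segment of a trait path,
/// e.g. "std::default::Default" -> "Default".
fn last_segment(path: &str) -> &str {
    // `rsplit` always yields at least one piece, so `next()` is never `None`.
    path.rsplit("::").next().unwrap_or(path)
}

/// Hypothetical helper: merge the contents of several `derive(...)` lists
/// (e.g. "Debug, Clone") with extra traits, keeping first-seen order and
/// skipping any trait whose last path segment was already seen.
fn merge_derives<'a>(existing: &[&'a str], extra: &[&'a str]) -> Vec<String> {
    let mut seen = HashSet::new();
    let mut merged = Vec::new();
    let candidates = existing
        .iter()
        .copied()
        .flat_map(|list| list.split(','))
        .chain(extra.iter().copied());
    for path in candidates {
        let path = path.trim();
        if path.is_empty() {
            continue; // tolerate a trailing comma such as "Debug,"
        }
        // Heuristic: compare only the last segment, so `Default` and
        // `std::default::Default` count as the same trait.
        if seen.insert(last_segment(path).to_string()) {
            merged.push(path.to_string());
        }
    }
    merged
}

/// Hypothetical helper: render the merged list as one attribute.
fn render_derive(traits: &[String]) -> String {
    format!("#[derive({})]", traits.join(", "))
}
```

For example, merging the lists "Debug, Clone" and "std::default::Default" with the extras Default, Serialize and Debug gives #[derive(Debug, Clone, std::default::Default, Serialize)]. The last-segment comparison is only a heuristic: it would also merge two different traits that happen to share a name. In the real macro the same idea could be applied to syn::Path values by comparing path.segments.last().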