
How do I extract information about the type in a derive macro?


I am implementing a derive macro to reduce the amount of boilerplate I have to write for similar types.

I want the macro to operate on structs which have the following format:

#[derive(MyTrait)]
struct SomeStruct {
    records: HashMap<Id, Record>
}

Calling the macro should generate an implementation like so:

impl MyTrait for SomeStruct {
    fn foo(&self, id: Id) -> Record { ... }
}
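Here is that expansion written out by hand for a single struct, to pin down the target. It is a sketch under stated assumptions. The question elides `Id`, `Record`, the `MyTrait` definition and the body of `foo`, so `Id = u32`, the small `Record` struct, the trait, and the body (clone the record, panic on a missing id) are all illustrative stand-ins. A hand-written version that compiles makes it easier to compare against what the macro actually emits.

```rust
use std::collections::HashMap;

// Hypothetical concrete types; the question leaves `Id` and `Record` abstract.
type Id = u32;

#[derive(Clone, Debug, PartialEq)]
struct Record {
    name: String,
}

// Hypothetical trait definition; the question does not show the real one.
trait MyTrait {
    fn foo(&self, id: Id) -> Record;
}

struct SomeStruct {
    records: HashMap<Id, Record>,
}

// Hand-written version of what `#[derive(MyTrait)]` is meant to generate.
impl MyTrait for SomeStruct {
    fn foo(&self, id: Id) -> Record {
        // `get` yields `Option<&Record>`; clone it to return an owned value.
        // Assumed behavior: panic if `id` is not in the map.
        self.records.get(&id).cloned().expect("no record for this id")
    }
}

fn main() {
    let mut records = HashMap::new();
    records.insert(7, Record { name: "seven".to_string() });
    let s = SomeStruct { records };
    println!("{}", s.foo(7).name); // prints "seven"
}
```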

So I understand how to generate the code using quote:

#[proc_macro_derive(MyTrait)]
pub fn derive_answer_fn(item: TokenStream) -> TokenStream {
    ...
    let generated = quote!{

        impl MyTrait for #struct_name {
            fn foo(&self, id: #id_type) -> #record_type { ... }
        }

    }
    ...
}

But what is the best way to get #struct_name, #id_type and #record_type from the input token stream?


Solution

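  • One way is to use the venial crate, a lightweight alternative to syn, to parse the TokenStream. Besides venial, the proc-macro crate also needs quote and proc-macro2 as dependencies.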

    use quote::quote;
    
    #[proc_macro_derive(MyTrait)]
    pub fn derive_answer_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
        // Ensure it's deriving for a struct.
        let s = match venial::parse_declaration(proc_macro2::TokenStream::from(item)) {
            Ok(venial::Declaration::Struct(s)) => s,
            Ok(_) => panic!("Can only derive this trait on a struct"),
            Err(_) => panic!("Error parsing into valid Rust"),
        };
    
        let struct_name = s.name;
    
        // Require a struct with named fields (braces, not a tuple or unit struct).
        let fields = s.fields;
        let named_fields = match fields {
            venial::StructFields::Named(named_fields) => named_fields,
            _ => panic!("Expected a struct with named fields"),
        };
    
        let inners: Vec<(venial::NamedField, proc_macro2::Punct)> = named_fields.fields.inner;
        if inners.len() != 1 {
            panic!("Expected exactly one named field");
        }
    
        // Get the name and type of the first field.
        let first_field_name = &inners[0].0.name;
        let first_field_type = &inners[0].0.ty;
    
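        // Extract Id and Record from the type HashMap<Id, Record>. This only accepts
        // exactly six tokens (`HashMap`, `<`, `Id`, `,`, `Record`, `>`), so a path such as
        // `std::collections::HashMap` or a nested argument such as `Vec<Record>` is rejected.
        if first_field_type.tokens.len() != 6 {
            panic!("Expected type T<R, S> for first named field");
        }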
    
        let id = first_field_type.tokens[2].clone();
        let record = first_field_type.tokens[4].clone();
    
        // Implement MyTrait.
        let generated = quote! {
            impl MyTrait for #struct_name {
                // Clone the record out of the map (requires `Record: Clone`).
                fn foo(&self, id: #id) -> #record { self.#first_field_name.get(&id).cloned().unwrap() }
            }
        };
    
        proc_macro::TokenStream::from(generated)
    }
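The fixed six-token check above rejects field types such as `HashMap<Id, Vec<Record>>` (nine tokens, because the nested `>>` arrives as two separate `>` punctuation tokens) or `std::collections::HashMap<Id, Record>`. A sturdier approach is to take the identifier right before the first `<` and split the arguments only at commas that sit at angle-bracket depth zero. The sketch below shows that idea in plain Rust so it runs outside a proc-macro crate. `Tok` is a hypothetical, simplified stand-in for `proc_macro2::TokenTree`, and `split_generic_args` and `lex` are hypothetical helpers, not part of venial. Inside the macro, the same loop would run over `first_field_type.tokens`, matching on `TokenTree::Ident` and `TokenTree::Punct`.

```rust
// Hypothetical simplified stand-in for `proc_macro2::TokenTree`:
// only identifiers and single-character punctuation.
#[derive(Clone, Debug, PartialEq)]
enum Tok {
    Ident(String),
    Punct(char),
}

// Hypothetical helper: splits `path::Name<A, B, ...>` into `Name` and its
// top-level generic arguments, tracking `<`/`>` depth so nested generics
// such as `Vec<Record>` stay intact. Returns `None` for other shapes.
// Limitation: a `->` inside the arguments (e.g. `Box<dyn Fn() -> T>`)
// would be miscounted by this simple depth counter.
fn split_generic_args(tokens: &[Tok]) -> Option<(String, Vec<Vec<Tok>>)> {
    let open = tokens.iter().position(|t| *t == Tok::Punct('<'))?;
    // The identifier right before `<` is the type name; any path prefix is ignored.
    let name = match &tokens[open.checked_sub(1)?] {
        Tok::Ident(s) => s.clone(),
        _ => return None,
    };
    if *tokens.last()? != Tok::Punct('>') {
        return None;
    }

    let mut args = Vec::new();
    let mut current = Vec::new();
    let mut depth = 0usize;
    for tok in &tokens[open + 1..tokens.len() - 1] {
        match tok {
            Tok::Punct('<') => depth += 1,
            // More `>` than `<`: malformed, give up.
            Tok::Punct('>') => depth = depth.checked_sub(1)?,
            // A comma at depth 0 separates two generic arguments.
            Tok::Punct(',') if depth == 0 => {
                args.push(std::mem::take(&mut current));
                continue;
            }
            _ => {}
        }
        current.push(tok.clone());
    }
    if depth != 0 {
        return None;
    }
    // Allow a trailing comma, as Rust does.
    if !current.is_empty() {
        args.push(current);
    }
    // Reject `Name<>` and empty arguments such as `Name<, A>`.
    if args.is_empty() || args.iter().any(|a| a.is_empty()) {
        return None;
    }
    Some((name, args))
}

// Tiny demo lexer: runs of alphanumerics/underscores become `Ident`,
// every other non-whitespace character becomes a one-character `Punct`.
fn lex(src: &str) -> Vec<Tok> {
    let mut toks = Vec::new();
    let mut ident = String::new();
    for c in src.chars() {
        if c.is_alphanumeric() || c == '_' {
            ident.push(c);
            continue;
        }
        if !ident.is_empty() {
            toks.push(Tok::Ident(std::mem::take(&mut ident)));
        }
        if !c.is_whitespace() {
            toks.push(Tok::Punct(c));
        }
    }
    if !ident.is_empty() {
        toks.push(Tok::Ident(ident));
    }
    toks
}

fn main() {
    let (name, args) =
        split_generic_args(&lex("HashMap<Id, Vec<Record>>")).expect("generic type");
    println!("{name}: {} arguments", args.len()); // prints "HashMap: 2 arguments"
}
```

For anything beyond simple cases, parsing the field type with syn (`syn::Type::Path` and its `PathArguments::AngleBracketed` arguments) avoids hand-rolled token counting altogether.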