rust, rust-proc-macros

How do I explicitly require a trait for all arguments of a function in a procedural macro?


The short version

The short version is: when writing a procedural macro (wrap) that wraps a function, somewhat like the trace crate does, what's the best way to require that all of the function's arguments implement a specific trait?

The problem

trace expands to calls like println!("{:?}", i), so with:

struct MyType(u32);

#[trace]
fn foo(i: MyType) {
    println!("MyType: {}", i.0);
}

I get:

error[E0277]: `MyType` doesn't implement `Debug`
  --> src/main.rs:10:8
   |
9  | #[trace]
   | ------- in this procedural macro expansion
10 | fn foo(i: MyType) {
   |        ^ `MyType` cannot be formatted using `{:?}`
   |
   = help: the trait `Debug` is not implemented for `MyType`

This is nice and explicit: the error points at the argument declaration and says that MyType should implement Debug.
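To see why the error lands on i, it may help to picture the expansion. The sketch below is a hand-written approximation of what a trace-style attribute could generate for foo, assuming it simply inserts println!("{:?}", i) before the original body; it is not the trace crate's actual output. With Debug derived, it compiles.

```rust
// Deriving Debug is exactly what the original MyType lacked.
#[derive(Debug)]
struct MyType(u32);

// Hand-written approximation of a trace-style expansion (assumed shape).
fn foo(i: MyType) {
    // In a real expansion this `i` is the token copied from the signature,
    // so a missing Debug impl is reported at the argument declaration.
    println!("{:?}", i);
    // Original body, unchanged.
    println!("MyType: {}", i.0);
}

fn main() {
    foo(MyType(42));
}
```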

Now I'm trying to do the same thing by expanding to something like MyTrait::my_trait_method(&i) instead, and I'd like to get essentially the same error message.
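For comparison, the target expansion might look roughly like the hand-expanded sketch below. The question does not give MyTrait's signature, so my_trait_method taking &self and returning a String, and the impl for MyType, are assumptions for illustration only.

```rust
// Hypothetical trait: the signature of my_trait_method is assumed.
trait MyTrait {
    fn my_trait_method(&self) -> String;
}

struct MyType(u32);

impl MyTrait for MyType {
    fn my_trait_method(&self) -> String {
        format!("MyType({})", self.0)
    }
}

// Roughly what `#[wrap] fn foo(i: MyType)` is meant to expand to.
fn foo(i: MyType) {
    // Generated call, fully qualified through the trait. Keeping the original
    // span on this `i` is what should let rustc point a missing impl at the
    // argument, as in the Debug case.
    println!("{}", MyTrait::my_trait_method(&i));
    // Original body, unchanged.
    println!("MyType: {}", i.0);
}

fn main() {
    foo(MyType(7));
}
```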

What I tried

So far the best I could do is add where clauses in an expansion generated through the quote! macro (I'm still figuring out how to use the syn crate, so it's still quite hardcoded). At least I do get an explicit message saying that MyType does not implement MyTrait, but the error is reported on the macro rather than on the argument definition:

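use proc_macro::TokenStream;
use quote::quote;
// syn needs the "full" feature for ItemFn and "extra-traits" for {input:?}.
use syn::{parse_macro_input, punctuated::Punctuated, FnArg, Pat, PatType, Token};

#[proc_macro_attribute]
pub fn wrap(_macro_args: TokenStream, input: TokenStream) -> TokenStream {
    let func = parse_macro_input!(input as syn::ItemFn);
    let orig_name = func.sig.ident.to_string();
    // Collect the pattern (e.g. `i`) of every typed argument.
    let params: Punctuated<Pat, Token![,]> = func
        .sig
        .inputs
        .iter()
        .map(|input| match input {
            FnArg::Typed(PatType { pat: t, .. }) => t.as_ref().clone(),
            // `self` receivers are not supported.
            _ => panic!("Unexpected input for function {orig_name}: {input:?}"),
        })
        .collect();

    // One "{:?}, " placeholder per argument.
    let pattern = "{:?}, ".repeat(params.len());
    // Still hardcoded: signature and bound are written by hand,
    // and the original body is not re-emitted.
    quote! {
        fn foo(i: MyType) where MyType: MyTrait {
            //i.my_trait_method();
            println!(#pattern, #params);
        }
    }
    .into()
}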

Applied to the same function:

#[wrap]
fn foo(i: MyType) {
    println!("MyType: {}", i.0);
}

I get:

error[E0277]: the trait bound `MyType: MyTrait` is not satisfied
  --> src/main.rs:11:1
   |
11 | #[wrap]
   | ^^^^^^^ the trait `MyTrait` is not implemented for `MyType`

Is generating where clauses the way to go (I'll do that properly if that's the case), or is there a better way to require a specific trait for all arguments of a function parsed in a proc_macro?
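Doing the where-clause approach "properly" would mean emitting one predicate per argument type taken from the signature, instead of hardcoding MyType. The hand-expanded sketch below shows that shape for a hypothetical two-argument function; bar, OtherType and the trait method signature are made up for illustration. Because the predicates mention no generic parameters, they are checked where the function is defined, and they compile on stable Rust as long as every bound holds.

```rust
// Hypothetical trait; the method signature is an assumption.
trait MyTrait {
    fn my_trait_method(&self) -> String;
}

struct MyType(u32);
// Hypothetical second argument type, only to show one predicate per argument.
struct OtherType(&'static str);

impl MyTrait for MyType {
    fn my_trait_method(&self) -> String {
        format!("MyType({})", self.0)
    }
}

impl MyTrait for OtherType {
    fn my_trait_method(&self) -> String {
        format!("OtherType({})", self.0)
    }
}

// Hand-expanded sketch: one `Type: MyTrait` predicate per argument type.
fn bar(i: MyType, j: OtherType) -> String
where
    MyType: MyTrait,
    OtherType: MyTrait,
{
    // Generated trait calls, one per argument.
    let traced = format!(
        "{}, {}",
        MyTrait::my_trait_method(&i),
        MyTrait::my_trait_method(&j)
    );
    println!("{}", traced);
    traced
}

fn main() {
    bar(MyType(1), OtherType("x"));
}
```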


Solution

  • While Chayim Friedman's answer is correct, I wanted to add that you can also use identifiers and spans to improve the error message.

    In the code below, instead of emitting the trait name as a hardcoded token, you create an identifier whose span points at the parameters (essentially their original location in your source). The rest of the code is almost unchanged.

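    use proc_macro::TokenStream;
    use quote::quote;
    // `Spanned` provides the `.span()` call on `params` below.
    use syn::spanned::Spanned;
    use syn::{parse_macro_input, punctuated::Punctuated, FnArg, Ident, Pat, PatType, Token};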
    
    #[proc_macro_attribute]
    pub fn wrap(_macro_args: TokenStream, input: TokenStream) -> TokenStream {
        let func = parse_macro_input!(input as syn::ItemFn);
        let orig_name = func.sig.ident.to_string();
        let params: Punctuated<Pat, Token![,]> = func
            .sig
            .inputs
            .iter()
            .map(|input| match input {
                FnArg::Typed(PatType { pat: t, .. }) => t.as_ref().clone(),
                _ => panic!("Unexpected input for function {orig_name}: {input:?}"),
            })
            .collect();
    
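        // Give the trait name the span of the parameters, so an unsatisfied
        // bound is reported at the argument instead of at the attribute.
        let trait_ident = Ident::new("MyTrait", params.span());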
    
        let pattern = "{:?}, ".repeat(params.len());
        quote! {
            fn foo(i: MyType) where MyType: #trait_ident {
                println!(#pattern, #params);
            }
        }
        .into()
    }
    

    Now, with the following code, also based on your example:

    #[derive(Debug)]
    struct MyType(u32);
    
    trait MyTrait {}
    
    // activate if you do not want an error
    // impl MyTrait for MyType {}
    
    #[wrap]
    fn foo(i: MyType) {
        println!("MyType: {}", i.0);
    }
    
    fn main() {}
    

    You will get back this error:

    error[E0277]: the trait bound `MyType: MyTrait` is not satisfied
      --> src/main.rs:12:8
       |
    12 | fn foo(i: MyType) {
       |        ^ the trait `MyTrait` is not implemented for `MyType`
       |
       = help: see issue #48214
    

    Which looks a lot more helpful to me.