Tags: scala, scala-macros, scala-3

How to get Scala annotations that are given to an argument of a method


Consider the following annotation:

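import scala.annotation.{ConstantAnnotation, compileTimeOnly}
import scala.annotation.meta.{field, param}

// more meta-annotations are OK
@field
@param
@compileTimeOnly("Only for code generation")
case class Annot(value: String) extends ConstantAnnotation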

Now consider three uses of it:

case class A(x: Int, @Annot("z") y: String)
object A:
  def f1(x: Int, y: String @Annot("z")): A = ???
  def f2(x: Int, @Annot("z") y: String): A = ???

I would like to use Scala 3 macros to find each of these annotations.

  1. Case class: Symbol.caseFields gives me the list of parameters, and calling annotations on each of those parameter symbols gives me what I am looking for.
  2. Annotated type: each param is a ValDef. If param.tpt.tpe matches AnnotatedType(tpe, t), then t is the annotation I am looking for.
  3. Annotated method argument: I have found nothing that works!

Any idea how I can get the annotations attached to a method argument? When I print the terms/symbols/trees/..., I cannot even see "z" in this case.


Solution

  • You can write a macro like this:

    import scala.annotation.experimental
    import scala.quoted.*
    
    // Call sites must be in a different file from this definition
    inline def printAnnotations[A]: Unit = ${printAnnotationsImpl[A]}
    
    @experimental // because .typeRef is @experimental
    def printAnnotationsImpl[A](using Quotes, Type[A]): Expr[Unit] = {
      import quotes.reflect.*
    
      val symbol = TypeTree.of[A].symbol
      println(symbol.caseFields.map(_.annotations)) // Case Class
    
      val companion = symbol.companionModule
      val methods = companion.declaredMethods
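      // Annotated method argument: the annotation is attached to the parameter symbol
      println(methods.map(_.paramSymss.map(_.map(_.annotations))))
    
      // Annotated Type: the annotation is part of the parameter's type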
      println(methods.collect(method => companion.typeRef.memberType(method) match {
        case lt: LambdaType => lt.paramTypes.collect {
          case at: AnnotatedType => at.annotation
        }
      }))
    
      '{()}
    }
    

    Then, for

    case class A(x: Int, @Annot("z1") y: String)
    object A:
      def f1(x: Int, y: String @Annot("z2")): A = ???
      def f2(x: Int, @Annot("z3") y: String): A = ???
    

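    printAnnotations[A] will print three lines: the case-field annotations (z1), the annotations on method parameter symbols (z3, in the entry for f2) and the annotations inside method parameter types (z2, in the entry for f1). The other entries come from the remaining companion members, including compiler-generated ones such as apply and unapply: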

    List(List(), List(Apply(Select(New(Ident(Annot)),<init>),List(Literal(Constant(z1))))))
    
    List(List(List(List(), List())), List(List(List())), List(), List(List(List(), List())), List(List(List(), List(Apply(Select(New(Ident(Annot)),<init>),List(Literal(Constant(z3))))))), List(List()), List(List(List())))
    
    List(List(), List(), List(Apply(Select(New(Ident(Annot)),<init>),List(Literal(Constant(z2))))), List(), List(), List())
    

    Scala version used: 3.1.3
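If you only need the f2 case (an annotation written on a method parameter), a narrower variant of the same idea is to read `paramSymss` and pull out the string literal arguments of each annotation. This is a sketch, not part of the original answer: the names `paramAnnotations`, `paramAnnotationsImpl` and `stringArgs` are hypothetical, and as with any Scala 3 macro, the call site has to live in a different file from the definition.

```scala
import scala.quoted.*

// Hypothetical helper: for every method declared in the companion of T, returns
// (method name, parameter name, string literal arguments of the annotation)
// for annotations attached to the parameter symbol, i.e. the `f2` style.
inline def paramAnnotations[T]: List[(String, String, List[String])] =
  ${ paramAnnotationsImpl[T] }

def paramAnnotationsImpl[T: Type](using Quotes): Expr[List[(String, String, List[String])]] =
  import quotes.reflect.*

  // Collects string literal arguments, e.g. "z3" from `new Annot("z3")`.
  def stringArgs(annot: Term): List[String] = annot match
    case Apply(_, args) => args.collect { case Literal(StringConstant(s)) => s }
    case _              => Nil

  val companion = TypeRepr.of[T].typeSymbol.companionModule
  val found =
    for
      method <- companion.declaredMethods
      param  <- method.paramSymss.flatten // also includes type parameters, if any
      annot  <- param.annotations
    yield (method.name, param.name, stringArgs(annot))

  // Lift the collected strings into code at the call site.
  Expr(found)
```

Based on the output above, calling `paramAnnotations[A]` from another file with the question's definitions should give a list containing `("f2", "y", List("z3"))`. Returning plain strings instead of printing keeps the result usable by generated code.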