Search code examples
Tags: scala, macros, annotations, scala-macros, scalameta

Macro annotation to override toString of Scala function


How can I write a macro annotation that is used like `@named("+2") _ + 2` and produces:

// Desired expansion: an anonymous Int => Int whose toString returns the annotation's name.
new (Int => Int) {
  override def toString(): String = "+2"
  def apply(x: Int): Int = x + 2
}

Solution

  • The correct syntax is `((_: Int) + 2): @named("+2")`. Unfortunately, macro annotations placed on expressions are not expanded (they only expand on definitions such as classes, objects and methods).

    The simplest approach is to use

    /** Factory for functions that carry a human-readable name. */
    object Named {
      /** Wraps `applyFunc` in a function that behaves identically but whose
        * `toString` returns `name`.
        *
        * @param name      text returned by the resulting function's `toString`
        * @param applyFunc function every call is delegated to
        * @return a `T => R` forwarding to `applyFunc`
        */
      def build[T, R](name: String)(applyFunc: T => R): T => R =
        new Function1[T, R] {
          def apply(arg: T): R = applyFunc(arg)

          override def toString(): String = name
        }
    }
    

    without any macros.

    Otherwise, you can use Scalameta to expand annotations on expressions by generating sources:

    build.sbt (see the sbt documentation on generating sources)

    // Build for the Scalameta source-generation demo:
    //   annotations - defines @named (used only by `in`)
    //   helpers     - TypedFunctionN wrappers referenced by the generated code
    //   in          - hand-written sources that use @named
    //   out         - compiles the sources generated from `in` by project/Generator.scala
    ThisBuild / name := "scalametademo"

    lazy val commonSettings = Seq(
      scalaVersion := "2.13.1",
    )

    lazy val annotations = project
      .settings(
        commonSettings,
      )

    lazy val helpers = project
      .settings(
        commonSettings,
      )

    lazy val in = project
      .dependsOn(annotations)
      .settings(
        commonSettings,
      )

    // `out` depends only on `helpers`: the generator strips @named and its import,
    // so the generated sources never reference `annotations`.
    lazy val out = project
      .dependsOn(helpers)
      .settings(
        // Unified slash syntax, consistent with `ThisBuild / name` above
        // (the `key.in(...)` / `key in Config` forms are deprecated since sbt 1.5).
        Compile / sourceGenerators += Def.task {
          Generator.gen(
            inputDir  = (in / Compile / sourceDirectory).value,
            outputDir = (Compile / sourceManaged).value
          )
        }.taskValue,

        commonSettings,
      )
    

    project/build.sbt

    // Scalameta on the meta-build classpath, so project/Transformer.scala can import scala.meta._
    libraryDependencies ++= Seq("org.scalameta" %% "scalameta" % "4.3.0")
    

    project/Generator.scala

    import sbt._
    
    object Generator {
      /** Runs [[Transformer.transform]] over every `.scala` file under `inputDir` and
        * writes each result to the same relative path under `outputDir`.
        *
        * @param inputDir  root of the sources to transform
        * @param outputDir root the transformed sources are written to
        * @return the generated files (what an sbt source generator must return)
        */
      def gen(inputDir: File, outputDir: File): Seq[File] = {
        val finder: PathFinder = inputDir ** "*.scala"
        val inputRoot = inputDir.toPath
        val outputRoot = outputDir.toPath

        for (inputFile <- finder.get) yield {
          // Relativize filesystem paths rather than URI strings: URI.toString
          // percent-encodes characters such as spaces, which would otherwise end up
          // literally (e.g. "%20") in the generated file's path.
          val relativePath = inputRoot.relativize(inputFile.toPath)
          val outputFile = outputRoot.resolve(relativePath).toFile
          val outputStr = Transformer.transform(IO.read(inputFile))
          IO.write(outputFile, outputStr)
          outputFile
        }
      }
    }
    

    project/Transformer.scala

    import scala.meta._
    
    /** Source-to-source rewriter invoked by Generator: expands `lambda: @named("...")`
      * expressions into anonymous function classes whose `toString` returns the name.
      */
    object Transformer {
      // Extracts the argument of `@named(<literal>)`: exactly one argument list holding
      // exactly one literal. Undefined for every other modifier.
      val getNamedAnnotationParam: PartialFunction[Mod, Lit] = {
        case mod"@named(...${List(List(s: Lit))})" => s
      }
    
      // True for modifiers that getNamedAnnotationParam accepts.
      val isNamedAnnotated: Mod => Boolean = getNamedAnnotationParam.lift(_).isDefined
    
      /** Parses `input` as a whole Scala source file, rewrites it and prints the result.
        * NOTE(review): `.get` on the parse result presumably throws on a syntax error.
        */
      def transform(input: String): String = transform(input.parse[Source].get).toString
    
      /** Rewrites a parsed tree: removes `annotations` imports and expands @named lambdas. */
      def transform(input: Tree): Tree = input.transform {
        // Drop package-level `import annotations.{...}` (exactly one importer): the output
        // project does not depend on `annotations`.
        // NOTE(review): other forms, e.g. `import com.example.annotations.named`, are kept
        // and would fail to compile in `out`.
        // The sample output shows expressions inside the package are still rewritten, so
        // `transform` evidently also visits the children of the returned tree.
        case q"package $eref { ..$stats }" =>
          val stats1 = stats.filter {
            case q"import ..${List(importer"annotations.{..$importeesnel}")}" => false
            case _ => true
          }
    
          q"package $eref { ..$stats1 }"
    
        case q"$expr: ..@$annotsnel" if annotsnel.exists(isNamedAnnotated) =>
          // Other (non-@named) annotations are preserved on the result.
          val annotsnel1 = annotsnel.filterNot(isNamedAnnotated)
          // `.head` is safe: the guard guarantees at least one @named annotation matches.
          val name = annotsnel.collect(getNamedAnnotationParam).head
    
          val expr1 = expr match {
            case q"(..$params) => $expr2" =>
              // The pattern-bound `name` inside these closures shadows the outer `name`;
              // the quasiquote further down uses the outer (annotation literal) one.
              // Untyped lambda parameters are given type scala.Any.
              // NOTE(review): this widens the parameter type instead of inferring it.
              val params1 = params.map {
                case param"..$mods $name: ${Some(tpe)} = $expropt" => 
                  param"..$mods $name: ${Some(tpe)} = $expropt"
                case param"..$mods $name: ${None} = $expropt" => 
                  param"..$mods $name: scala.Any = $expropt"
              }
    
              // `tpeopt.get` is safe: every parameter in params1 has a declared type.
              val domain = params1.map {
                case param"..$mods $name: $tpeopt = $expropt" => tpeopt.get
              }
    
              // TypedFunctionN captures the original lambda so its result type can be
              // written as `typed.CoDomain` without knowing it syntactically.
              // NOTE(review): helpers define only TypedFunction0..TypedFunction3; a lambda
              // with more parameters produces a reference to a nonexistent class.
              q"""
                   val typed = com.example.helpers.${Term.Name("TypedFunction" + params.length)}($expr)
    
                   new ((..$domain) => typed.CoDomain) {
                     override def toString(): String = $name
                     def apply(..$params1): typed.CoDomain = $expr2
                   }
                 """
    
            // NOTE(review): non-lambda expressions are returned unchanged, but @named is
            // still removed below, so the name is silently dropped.
            case e => e
          }
    
          if (annotsnel1.nonEmpty)
            q"$expr1: ..@$annotsnel1"
          else q"$expr1"
      }
    }
    

    annotations/src/main/scala/com/example/annotations/named.scala

    package com.example.annotations
    
    import scala.annotation.StaticAnnotation
    
    /** Marks a function literal for rewriting by the source generator: the annotated
      * lambda is replaced by a function whose `toString` returns `name`.
      *
      * @param name text the generated function's `toString` returns
      */
    class named(name: String) extends StaticAnnotation {
      // No members: the generator only inspects this annotation syntactically and
      // removes it from the generated sources.
    }
    

    helpers/src/main/scala/com/example/helpers/TypedFunctions.scala

    package com.example.helpers
    
    /** Captures a function value so that its result type can be named through the
      * path-dependent alias `CoDomain` (e.g. `typed.CoDomain`) without writing it out.
      *
      * @tparam R the wrapped function's result type
      */
    sealed trait TypedFunctions[R] {
      type CoDomain = R
    }

    /** Wraps a nullary function. */
    case class TypedFunction0[R](f: Function0[R]) extends TypedFunctions[R]

    /** Wraps a unary function. */
    case class TypedFunction1[A, R](f: Function1[A, R]) extends TypedFunctions[R]

    /** Wraps a binary function. */
    case class TypedFunction2[A1, A2, R](f: Function2[A1, A2, R]) extends TypedFunctions[R]

    /** Wraps a ternary function. */
    case class TypedFunction3[A1, A2, A3, R](f: Function3[A1, A2, A3, R]) extends TypedFunctions[R]
    

    in/src/main/scala/com/example/App.scala

    package com.example
    
    import annotations.named
    
    /** Input for the source generator: each `(lambda): @named("...")` expression is
      * rewritten, in the generated copy of this file, into a function whose `toString`
      * returns the given name. The values are not bound; they only demonstrate the expansion.
      */
    object App {
      (((x: Int) => x + 2): @named("+2"))
    
      (((x: Int, y: Int) => x + y): @named("+"))
    }
    

    out/target/scala-2.13/src_managed/main/scala/com/example/App.scala (generated by running sbt "; project out; clean; compile")

    package com.example
    // Generated from in/src/main/scala/com/example/App.scala; edit that file instead.
    object App {
      // Expansion of `((x: Int) => x + 2): @named("+2")`
      {
        val typed = com.example.helpers.TypedFunction1 { (x: Int) => x + 2 }
        new (Int => typed.CoDomain) {
          override def toString(): String = "+2"
          def apply(x: Int): typed.CoDomain = x + 2
        }
      }
      // Expansion of `((x: Int, y: Int) => x + y): @named("+")`
      {
        val typed = com.example.helpers.TypedFunction2 { (x: Int, y: Int) => x + y }
        new ((Int, Int) => typed.CoDomain) {
          override def toString(): String = "+"
          def apply(x: Int, y: Int): typed.CoDomain = x + y
        }
      }
    }
    

Another example is the question "How to merge multiple imports in Scala?".