Search code examples
Tags: scala, macros, scala-macros, scala-3

Can you implement dsinfo in Scala 3? (Can Scala 3 macros get info about their context?)


The dsinfo library lets you access the names of values from the context of where a function is written using Scala 2 macros. The example they give is that if you have something like

val name = myFunction(x, y)

myFunction will actually be passed the name of its val in addition to the other arguments, i.e., myFunction("name", x, y).

This is very useful for DSLs where you'd like named values for error reporting or other kinds of encoding. The only other option seems to be passing the name explicitly as a String, which can lead to unintentional mismatches.

Is this possible with Scala 3 macros, and if so, how do you "climb up" the tree at the macro's use location to find its id?


Solution

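  • In Scala 3 there is no c.macroApplication; there is only Position.ofMacroExpansion, which gives a position rather than a tree. But we can analyze the tree of Symbol.spliceOwner.maybeOwner (the enclosing definition). I assume that scalacOptions += "-Yretain-trees" is switched on, so that the tree of the enclosing definition can be accessed.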

    import scala.annotation.experimental
    import scala.quoted.*
    
    object Macro {
      /** Rewrites its call site into a call of `methodName` (e.g. `"mypackage.TwoArgs.apply"`),
        * passing the name of the enclosing `val`/`def` as an extra first argument.
        *
        * @tparam T result type of the generated call
        * @param methodName fully qualified name `pkg.Obj.method` of the method to call
        */
      inline def makeCallWithName[T](inline methodName: String): T =
        ${makeCallWithNameImpl[T]('methodName)}

      /** Macro implementation of [[makeCallWithName]].
        *
        * Reads the right-hand side of the enclosing definition (needs `-Yretain-trees`),
        * finds the call of `App.twoargs` in it and rebuilds that call as
        * `methodName(<enclosing name>, args...)`.
        *
        * NOTE(review): the call is recognised only by the hard-coded full name `App$.twoargs`,
        * not by position, so if one right-hand side contains several such calls every
        * expansion uses the arguments of the last one found in traversal order.
        */
      @experimental
      def makeCallWithNameImpl[T](methodName: Expr[String])(using Quotes, Type[T]): Expr[T] = {
        import quotes.reflect.*
        println(Position.ofMacroExpansion.sourceCode) // debug output, e.g. Some(twoargs(1, "one"))

        // Split "pkg.Obj.method" into the module path and the member name; reject names
        // that have no module part or no member part instead of passing an empty or
        // truncated path to Symbol.requiredModule.
        val methodNameStr = methodName.valueOrAbort
        val lastDot = methodNameStr.lastIndexOf('.')
        if lastDot <= 0 || lastDot == methodNameStr.length - 1 then
          report.errorAndAbort(s"expected a fully qualified method name such as pkg.Obj.method, got: $methodNameStr")
        val moduleSymbol = Symbol.requiredModule(methodNameStr.substring(0, lastDot))
        val shortMethodName = methodNameStr.substring(lastDot + 1)
        val ident = Ident(TermRef(moduleSymbol.termRef, shortMethodName))

        // The val/def whose right-hand side contains the macro call.
        val enclosingDef = Symbol.spliceOwner.maybeOwner
        val (ownerName, ownerRhs) = enclosingDef.tree match {
          case ValDef(name, _, Some(rhs)) => (name, rhs)
          case DefDef(name, _, _, Some(rhs)) => (name, rhs)
          case t => report.errorAndAbort(s"can't find RHS of ${t.show}")
        }

        val treeAccumulator = new TreeAccumulator[Option[Tree]] {
          override def foldTree(acc: Option[Tree], tree: Tree)(owner: Symbol): Option[Tree] = tree match {
            case Apply(fun, args) if fun.symbol.fullName == "App$.twoargs" =>
              Some(Apply(ident, Literal(StringConstant(ownerName)) :: args))
            case _ => foldOverTree(acc, tree)(owner)
          }
        }
        // The enclosing definition owns its right-hand side; the RHS tree's own symbol
        // (e.g. the method it applies) is not its owner.
        treeAccumulator.foldTree(None, ownerRhs)(enclosingDef)
          .getOrElse(report.errorAndAbort(s"can't find twoargs in RHS: ${ownerRhs.show}"))
          .asExprOf[T]
      }
    }
    

    Usage:

    package mypackage
    /** Value built by the macro expansion: `name` receives the name of the enclosing
      * `val`/`def`, followed by the original `twoargs` arguments `i` and `s`. */
    case class TwoArgs(name : String, i : Int, s : String)
    
    import mypackage.TwoArgs
    
    /** Demo call sites: each trailing comment shows the expected expansion, with the
      * name of the enclosing val/def inserted as the first `TwoArgs` field. */
    object App {
      // Inline wrapper: its body (the macro call) is inlined at every call site of
      // `twoargs`, so the macro sees the caller's enclosing definition.
      inline def twoargs(i: Int, s: String) = 
        Macro.makeCallWithName[TwoArgs]("mypackage.TwoArgs.apply")
    
      def x() = twoargs(1, "one") // TwoArgs("x", 1, "one")
    
      // The enclosing definition here is the local `val y`, not `aMethod`.
      def aMethod() = {
        val y = twoargs(2, "two") // TwoArgs("y", 2, "two")
      }
    
      // The call may be nested inside a larger right-hand-side expression.
      val z = Some(twoargs(3, "three")) // Some(TwoArgs("z", 3, "three"))
    }
    

    dsinfo also handles the name of the called method (twoargs) at the call site (as the template placeholder $macro), but I didn't implement this. I guess the name (if necessary) can be obtained from Position.ofMacroExpansion.sourceCode.


    Update. Here is an implementation that also handles the name of the inline method (e.g. twoargs), using Scalameta + SemanticDB in addition to Scala 3 macros.

    import mypackage.TwoArgs
    
    /** Demo call sites; each trailing comment shows the expected expansion, with the
      * name of the enclosing val/def inserted as the first `TwoArgs` field. */
    object App {
      // Two identical inline wrappers. This version of the macro locates its application
      // by position rather than by method name, so either wrapper may be used.
      inline def twoargs(i: Int, s: String) =
        Macro.makeCallWithName[TwoArgs]("mypackage.TwoArgs.apply")
    
      inline def twoargs1(i: Int, s: String) =
        Macro.makeCallWithName[TwoArgs]("mypackage.TwoArgs.apply")
    
      def x() = twoargs(1, "one") // TwoArgs("x", 1, "one")
    
      // The enclosing definition here is the local `val y`, not `aMethod`.
      def aMethod() = {
        val y = twoargs(2, "two") // TwoArgs("y", 2, "two")
      }
    
      // The call may be nested inside a larger right-hand-side expression.
      val z = Some(twoargs1(3, "three")) // Some(TwoArgs("z", 3, "three"))
    }
    
    package mypackage
    
    /** Value built by the macro expansion: `name` receives the name of the enclosing
      * `val`/`def`, followed by the original call's arguments `i` and `s`. */
    case class TwoArgs(name : String, i : Int, s : String)
    
    import scala.annotation.experimental
    import scala.quoted.*
    
    object Macro {
      /** Rewrites its call site into a call of `methodName` (e.g. `"mypackage.TwoArgs.apply"`),
        * passing the name of the enclosing `val`/`def` as an extra first argument.
        *
        * @tparam T result type of the generated call
        * @param methodName fully qualified name `pkg.Obj.method` of the method to call
        */
      inline def makeCallWithName[T](inline methodName: String): T =
        ${makeCallWithNameImpl[T]('methodName)}

      /** Macro implementation of [[makeCallWithName]].
        *
        * Locates the macro application by its position inside the right-hand side of the
        * enclosing definition (needs `-Yretain-trees`) and rebuilds it as
        * `methodName(<enclosing name>, args...)`. The inline method's symbol is looked up
        * in SemanticDB; the lookup aborts if it fails, and the symbol is otherwise only
        * used in the error message below.
        */
      @experimental
      def makeCallWithNameImpl[T](methodName: Expr[String])(using Quotes, Type[T]): Expr[T] = {
        import quotes.reflect.*

        val position = Position.ofMacroExpansion
        val scalaFile = position.sourceFile.getJPath.getOrElse(
          report.errorAndAbort(s"maybe virtual file, can't find path to position $position")
        )
        val inlineMethodSymbol =
          new SemanticdbInspector(scalaFile)
            .getInlineMethodSymbol(position.start, position.end)
            .getOrElse(report.errorAndAbort(s"can't find Scalameta symbol at position (${position.startLine},${position.startColumn})..(${position.endLine},${position.endColumn})=$position"))

        // Split "pkg.Obj.method" into the module path and the member name; reject names
        // that have no module part or no member part instead of passing an empty or
        // truncated path to Symbol.requiredModule.
        val methodNameStr = methodName.valueOrAbort
        val lastDot = methodNameStr.lastIndexOf('.')
        if lastDot <= 0 || lastDot == methodNameStr.length - 1 then
          report.errorAndAbort(s"expected a fully qualified method name such as pkg.Obj.method, got: $methodNameStr")
        val moduleSymbol = Symbol.requiredModule(methodNameStr.substring(0, lastDot))
        val shortMethodName = methodNameStr.substring(lastDot + 1)
        val ident = Ident(TermRef(moduleSymbol.termRef, shortMethodName))

        // The val/def whose right-hand side contains the macro application.
        val enclosingDef = Symbol.spliceOwner.maybeOwner

        val macroApplication: Option[Tree] = {
          val (ownerName, ownerRhs) = enclosingDef.tree match {
            case ValDef(name, _, Some(rhs)) => (name, rhs)
            case DefDef(name, _, _, Some(rhs)) => (name, rhs)
            case t => report.errorAndAbort(s"can't find RHS of ${t.show}")
          }

          val treeAccumulator = new TreeAccumulator[Option[Tree]] {
            override def foldTree(acc: Option[Tree], tree: Tree)(owner: Symbol): Option[Tree] = tree match {
              // Application already found: stop descending into the remaining trees.
              case _ if acc.isDefined => acc
              // Matching by position (instead of `fun.symbol.fullName == inlineMethodSymbol`)
              // picks out exactly the application being expanded.
              case Apply(_, args) if tree.pos == position =>
                Some(Apply(ident, Literal(StringConstant(ownerName)) :: args))
              case _ => foldOverTree(acc, tree)(owner)
            }
          }
          // The enclosing definition owns its right-hand side; the RHS tree's own symbol
          // (e.g. the method it applies) is not its owner.
          treeAccumulator.foldTree(None, ownerRhs)(enclosingDef)
        }

        val res = macroApplication
          .getOrElse(report.errorAndAbort(s"can't find application of $inlineMethodSymbol in RHS of $enclosingDef"))
        report.info(res.show)
        res.asExprOf[T]
      }
    }
    
    import java.nio.file.{Path, Paths}
    import scala.io
    import scala.io.BufferedSource
    import scala.meta.*
    import scala.meta.interactive.InteractiveSemanticdb
    import scala.meta.internal.semanticdb.{ClassSignature, Locator, Range, SymbolInformation, SymbolOccurrence, TextDocument, TypeRef}
    
    /** Loads the SemanticDB data and the source text of one Scala file so that the symbol
      * applied at a given source range can be looked up.
      *
      * @param scalaFile path of the `.scala` file to inspect
      */
    class SemanticdbInspector(val scalaFile: Path) {
      val scalaFileStr = scalaFile.toString

      // TextDocuments collected by the Locator from `<scalaFile>.semanticdb`.
      // NOTE(review): confirm this path matches where the build writes .semanticdb files.
      var textDocuments: Seq[TextDocument] = Seq()
      Locator(
        Paths.get(scalaFileStr + ".semanticdb")
      )((path, textDocs) => {
        textDocuments ++= textDocs.documents
      })

      // Full source text; it is parsed with Scalameta in getInlineMethodSymbol.
      val bufferedSource: BufferedSource = io.Source.fromFile(scalaFileStr)
      val source = try bufferedSource.mkString finally bufferedSource.close()

      extension (tree: Tree) {
        /** The SemanticDB occurrence whose range equals this tree's line/column range, if any. */
        def occurence: Option[SymbolOccurrence] = {
          val treeRange = Range(tree.pos.startLine, tree.pos.startColumn, tree.pos.endLine, tree.pos.endColumn)
          textDocuments.flatMap(_.occurrences)
            .find(_.range.exists(occurrenceRange => treeRange == occurrenceRange))
        }

        /** Symbol information for the occurrence found at this tree, if any. */
        def info: Option[SymbolInformation] = occurence.flatMap(_.symbol.info)
      }

      extension (symbol: String) {
        /** Symbol information recorded in the loaded documents for this SemanticDB symbol, if any. */
        def info: Option[SymbolInformation] = textDocuments.flatMap(_.symbols).find(_.symbol == symbol)
      }

      /** Finds the first `Term.Apply` whose start and end offsets equal `startOffset` and
        * `endOffset`, and returns the symbol of its applied function in the form produced
        * by `Symbol.fullName` in Scala 3 macros (e.g. `_empty_/App.twoargs().` becomes
        * `App$.twoargs`).
        *
        * @return `None` if no application has that range or its function has no SemanticDB symbol
        */
      def getInlineMethodSymbol(startOffset: Int, endOffset: Int): Option[String] = {
        def translateScalametaToMacro3(symbol: String): String =
          symbol
            .stripPrefix("_empty_/")
            .stripSuffix("().")
            .replace(".", "$.") // object members: `Obj.m` -> `Obj$.m`
            .replace("/", ".")  // package separators: `pkg/` -> `pkg.`

        // `.get` throws if the source does not parse under the Scala 3 dialect.
        dialects.Scala3(source).parse[Source].get.collect {
          case t @ Term.Apply(fun, _) if t.pos.start == startOffset && t.pos.end == endOffset =>
            fun.info.map(_.symbol)
        }.headOption.flatten.map(translateScalametaToMacro3)
      }
    }
    
    // Versions shared by the settings below.
    lazy val scala3V = "3.1.3"
    lazy val scala2V = "2.13.8"
    lazy val scalametaV = "4.5.13"

    // Single-module build: Scala 3 macros plus Scalameta/SemanticDB used at expansion time.
    lazy val root = (project in file("."))
      .settings(
        name := "scala3demo",
        version := "0.1.0-SNAPSHOT",
        scalaVersion := scala3V,
        // Scalameta is consumed through its Scala 2.13 artifact from this Scala 3 project.
        libraryDependencies += ("org.scalameta" %% "scalameta" % scalametaV).cross(CrossVersion.for3Use2_13),
        // SemanticDB artifact published for the exact Scala 2 version `scala2V`.
        libraryDependencies += "org.scalameta" % s"semanticdb-scalac_$scala2V" % scalametaV,
        // Keep trees so the macro can read the tree of its enclosing definition.
        scalacOptions += "-Yretain-trees",
        // Produce .semanticdb files for the sources being compiled.
        semanticdbEnabled := true
      )
    

    By the way, SemanticDB can't be replaced by TASTy here: when a macro in App is being expanded, the file App.scala.semanticdb already exists (it is generated early, in the frontend phase of compilation), but App.tasty does not exist yet (it appears only once App has been compiled, i.e. after the macro has been expanded, in the pickler phase).

    A .scala.semanticdb file will appear even if the .scala file doesn't compile (e.g. if there is an error in macro expansion), but a .tasty file won't.

    scala.meta parent of parent of Defn.Object

    Is it possible to using macro to modify the generated code of structural-typing instance invocation?

    Scala conditional compilation

    Macro annotation to override toString of Scala function

    How to merge multiple imports in scala?

    How to get the type of a variable with scalameta if the decltpe is empty?


    See also https://github.com/lampepfl/dotty-macro-examples/tree/main/accessEnclosingParameters


    Simplified version:

    import scala.quoted.*
    
    /** Expands, at its call site, into a call of the method named by `methodName`
      * (e.g. `"mypackage.TwoArgs.apply"`), with the enclosing `val`/`def` name prepended
      * to the arguments of the macro application.
      */
    inline def makeCallWithName[T](inline methodName: String): T =
      ${ makeCallWithNameImpl[T]('{ methodName }) }
    
    /** Macro implementation of `makeCallWithName`.
      *
      * Finds the macro application by its position inside the tree of the enclosing
      * definition (needs `-Yretain-trees`) and rebuilds it as
      * `methodName(<enclosing val/def name>, args...)`.
      *
      * @param methodName fully qualified name `pkg.Obj.method` of the method to call
      */
    def makeCallWithNameImpl[T](methodName: Expr[String])(using Quotes, Type[T]): Expr[T] = {
      import quotes.reflect.*

      val position = Position.ofMacroExpansion

      // Split "pkg.Obj.method" into the module path and the member name; reject names
      // that have no module part or no member part instead of passing an empty or
      // truncated path to Symbol.requiredModule.
      val methodNameStr = methodName.valueOrAbort
      val lastDot = methodNameStr.lastIndexOf('.')
      if lastDot <= 0 || lastDot == methodNameStr.length - 1 then
        report.errorAndAbort(s"expected a fully qualified method name such as pkg.Obj.method, got: $methodNameStr")
      val moduleSymbol = Symbol.requiredModule(methodNameStr.substring(0, lastDot))
      val shortMethodName = methodNameStr.substring(lastDot + 1)
      val ident = Ident(TermRef(moduleSymbol.termRef, shortMethodName))

      val owner0 = Symbol.spliceOwner.maybeOwner

      val ownerName = owner0.tree match {
        case ValDef(name, _, _) => name
        case DefDef(name, _, _, _) => name
        case t => report.errorAndAbort(s"unexpected tree shape: ${t.show}")
      }

      // Presumably a local dummy (the owner of statements in a class body) has no useful
      // tree of its own, so its owner's tree is searched instead.
      val searchRoot = if owner0.isLocalDummy then owner0.maybeOwner else owner0

      val macroApplication: Option[Tree] = {
        val treeAccumulator = new TreeAccumulator[Option[Tree]] {
          override def foldTree(acc: Option[Tree], tree: Tree)(owner: Symbol): Option[Tree] = tree match {
            // Application already found: stop descending into the remaining trees.
            case _ if acc.isDefined => acc
            case _ if tree.pos == position => Some(tree)
            case _ => foldOverTree(acc, tree)(owner)
          }
        }
        treeAccumulator.foldTree(None, searchRoot.tree)(searchRoot)
      }

      val res = macroApplication.getOrElse(
        report.errorAndAbort("can't find macro application")
      ) match {
        case Apply(_, args) => Apply(ident, Literal(StringConstant(ownerName)) :: args)
        case t => report.errorAndAbort(s"unexpected shape of macro application: ${t.show}")
      }
      report.info(res.show)
      res.asExprOf[T]
    }