Tags: scala, lambda, closures, scala-3, dotty

Finding lambda captured values (or their classes) in Scala 3


I'm looking for a way to find the values (or their classes) that are captured by a lambda in Scala 3, for serialization in the style of Spark (I don't need Scala 2 support):

    val a = "abc"
    val f = () => a + "xyz"
    serialize(f) // should detect `a` (a String) as a captured value

Doing this at runtime is fairly easy (iterating over f.getClass.getDeclaredFields), but I would like to do it at compile time.
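For reference, the runtime approach can be sketched as below. This is a minimal sketch assuming the JVM backend, where lambdas are compiled through `LambdaMetafactory` and each captured value is stored in a synthetic instance field of the generated class (the field names, such as `arg$1`, are a JDK implementation detail). `makeLambda` and `capturedFields` are hypothetical helpers. Note that if the lambda refers to a member of a class rather than a local value, it may capture the enclosing instance (`this`) instead of the value itself, which is the classic Spark serialization pitfall.

```scala
import java.lang.reflect.Modifier

object CapturedFields {
  // Hypothetical example: a lambda that closes over the parameter `a`.
  def makeLambda(a: String): () => String = () => a + "xyz"

  // Lists the non-static fields declared by the lambda's runtime class.
  // On the JVM these hold the captured values, so their types are the
  // classes of the captured values.
  def capturedFields(f: AnyRef): List[(String, Class[?])] =
    f.getClass.getDeclaredFields.toList
      .filterNot(field => Modifier.isStatic(field.getModifiers))
      .map(field => (field.getName, field.getType))

  def main(args: Array[String]): Unit =
    capturedFields(makeLambda("abc")).foreach { case (name, cls) =>
      println(s"$name: ${cls.getName}")
    }
}
```

The drawback the question points out remains: this only works once the closure object exists, i.e. at runtime.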

I tried to inspect the type of the lambda in a macro, but it is seen as a plain scala.Function0 without any useful information.

I wonder whether I could inspect the tree instead, but I would really like to avoid that: I feel I would have to replicate compiler internals to catch all the edge cases.


Solution

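  • Try the following macro. It looks up the definition of the val holding the lambda, traverses its right-hand side, and for every identifier reports whether it is defined inside the lambda (the val is among the identifier's owners) and whether it is defined in the current file: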

    import scala.quoted.*
    
    inline def serialize(x: Any): Unit = ${serializeImpl('x)}
    
    def serializeImpl(x: Expr[Any])(using Quotes): Expr[Unit] = {
      import quotes.reflect.*
    
      // The symbol itself followed by its chain of owners, up to the root
      def owners(s: Symbol): List[Symbol] =
        s :: List.unfold(s)(s1 => Option.when(s1.maybeOwner != Symbol.noSymbol)((s1.maybeOwner, s1.maybeOwner)))
    
      // `x` refers to the val holding the lambda (e.g. `f`): strip the
      // Inlined wrapper and fetch that val's definition
      val symbol = x.asTerm.underlying.symbol
      val rhs = symbol.tree match {
        case ValDef(_, _, Some(rhs)) => rhs
        case _ => report.errorAndAbort(s"expected a reference to a val with a right-hand side, got ${x.show}")
      }
    
      val traverser = new TreeTraverser {
        override def traverseTree(tree: Tree)(owner: Symbol): Unit = {
          tree match {
            case Ident(name) =>
              val symbol1 = tree.symbol
              // pos may be None (e.g. for symbols loaded from class files), so avoid .get
              val inCurrentFile = symbol1.pos.exists(_.sourceFile == SourceFile.current)
              println(s"identifier: $name, defined inside lambda: ${owners(symbol1).contains(symbol)}, defined in current file: $inCurrentFile")
    
            case _ =>
          }
    
          super.traverseTree(tree)(owner)
        }
      }
    
      traverser.traverseTree(rhs)(symbol)
    
      '{()}
    }
    

    Usage (the macro has to be defined in a different file from the one where it is called):

    // App1.scala (a separate file)
    object App1 {
      val b = "bbb"
    }
    
    // App.scala
    import App1.b
    
    object App {
      val a = "abc"
      val f = () => { val x = "uvw"; a + b + x + "xyz"}
      serialize(f)
    }
    
    //scalac: identifier: a, defined inside lambda: false, defined in current file: true
    //scalac: identifier: b, defined inside lambda: false, defined in current file: false
    //scalac: identifier: x, defined inside lambda: true, defined in current file: true
    //scalac: identifier: $anonfun, defined inside lambda: true, defined in current file: true