Tags: scala, apache-spark, reflection, anonymous-function, scala-reflect

Determine the function signature of an anonymous function in Scala


The context is registering a UserDefinedFunction (UDF) in Spark, where the UDF is an anonymous function obtained via reflection. Since the function signature is only determined at runtime, I was wondering whether this is possible.

Say the function impl() returns an anonymous function:

trait Base {}
class A extends Base {
  def impl(): Function1[Int, String] = new Function1[Int, String] {
    def apply(x: Int): String = "ab" + x.toString
  }
}
val classes = reflections.getSubTypesOf(classOf[Base]).asScala.toList // reflections: org.reflections.Reflections; needs import scala.jdk.CollectionConverters._

and I obtain the anonymous function in another place:

val clazz = classes(0)
val instance = clazz.newInstance()
val impl = clazz.getDeclaredMethod("impl").invoke(instance)

Now, impl holds the anonymous function but I do not know its signature, and I'd like to ask whether we can convert it into a correct function instance:

impl.asInstanceOf[Function1[Int, String]]   // How to determine the function signature of the anonymous function, in this case Function1[Int, String]?

Since Scala does not support generic function values, I first considered getting the runtime type of the function:

import scala.reflect.runtime.universe.{TypeTag, typeTag}
def getTypeTag[T: TypeTag](obj: T) = typeTag[T]
val typeList = getTypeTag(impl).tpe.typeArgs

I hoped it would return List(Int, String), but I fail to recover the correct function type via reflection.

Update: if the classes are defined as follows:

trait Base {}
class A extends Base {
  def impl(x: Int): String = {
    "ab" + x.toString
  }
}

where impl is the method itself and we do not know its signature, can impl still be registered as a UDF?


Solution

  • The context is registering a UserDefinedFunction (UDF) in Spark, where the UDF is an anonymous function obtained via reflection. Since the function signature is only determined at runtime, I was wondering whether this is possible.

    Normally you register a UDF as follows

    import org.apache.spark.sql.SparkSession
    
    object App {
      val spark = SparkSession.builder
        .master("local")
        .appName("Spark app")
        .getOrCreate()
    
      def impl(): Int => String = x => "ab" + x.toString
    
      spark.udf.register("foo", impl())
    
      def main(args: Array[String]): Unit = {
        spark.sql("""SELECT foo(10)""").show()
        //+-------+
        //|foo(10)|
        //+-------+
        //|   ab10|
        //+-------+
      }
    }
    

    The signature of register is

    def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction
    

    aka

    def register[RT, A1](name: String, func: Function1[A1, RT])(implicit
      ttag:  TypeTag[RT],
      ttag1: TypeTag[A1]
    ): UserDefinedFunction
    

    What TypeTag normally does is persist type information from compile time to runtime.

    So in order to call register you either have to know types at compile time or have to know how to construct type tags at runtime.
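
    For instance, here is a small check of mine (a sketch assuming the impl() from the snippet above, whose static result type is Int => String); the tag records the static type known at the call site:

    import scala.reflect.runtime.universe.{TypeTag, typeTag}

    def getTypeTag[T: TypeTag](obj: T) = typeTag[T]

    // T is inferred as Int => String at compile time, so the tag carries the full type
    getTypeTag(impl()).tpe           // Int => String
    getTypeTag(impl()).tpe.typeArgs  // List(Int, String)

    // with the types known statically you could also supply the tags explicitly yourself:
    // spark.udf.register("foo", impl())(typeTag[String], typeTag[Int])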

    If you don't have access to how impl() is constructed at runtime, and you don't have (at least at runtime) any information about the types/type tags, then unfortunately this type information is irreversibly lost because of type erasure (Function1[Int, String] is just Function1[_, _] at runtime):

    def impl(): Any = (x: Int) => "ab" + x.toString
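
    A quick way to see the erasure (my own sketch, reusing the getTypeTag helper from the question): once the value is only statically typed as Any, the argument and result types cannot be recovered from it:

    import scala.reflect.runtime.universe.{TypeTag, typeTag}

    def getTypeTag[T: TypeTag](obj: T) = typeTag[T] // as in the question

    val f = impl()                   // static type is Any
    getTypeTag(f).tpe                // Any, not Int => String
    getTypeTag(f).tpe.typeArgs       // List() - Int and String are gone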
    

    But it's possible that you do have access to how impl() is constructed at runtime and you do know (at least at runtime) the information about the types/type tags. So I assume that you don't have the types Int, String statically and you can't call typeTag[Int], typeTag[String] (as I do below), but you somehow have runtime objects of Type/TypeTag:

    import org.apache.spark.sql.catalyst.ScalaReflection.universe._
    
    def impl(): Any = (x: Int) => "ab" + x.toString
    val ttag1 = typeTag[Int]    // actual definition is probably different
    val ttag  = typeTag[String] // actual definition is probably different
    

    In that case you can call register, resolving the implicits explicitly:

    spark.udf.register("foo", impl().asInstanceOf[Function1[_,_]])(ttag.asInstanceOf[TypeTag[_]], ttag1.asInstanceOf[TypeTag[_]])
    

    Well, this doesn't compile because of the existential types, but you can trick the compiler by introducing abstract types A and B:

    type A
    type B
    spark.udf.register("foo", impl().asInstanceOf[A => B])(ttag.asInstanceOf[TypeTag[B]], ttag1.asInstanceOf[TypeTag[A]])
    

    https://gist.github.com/DmytroMitin/0b3660d646f74fb109665bad41b3ae9f

    Alternatively you can use runtime compilation (creating a new compile time inside the runtime)

    import org.apache.spark.sql.catalyst.ScalaReflection
    import ScalaReflection.universe._
    import scala.tools.reflect.ToolBox // libraryDependencies += scalaOrganization.value % "scala-compiler" % scalaVersion.value
    
    val rm = ScalaReflection.mirror
    val tb = rm.mkToolBox()
    tb.eval(q"""App.spark.udf.register("foo", App.impl().asInstanceOf[$ttag1 => $ttag])""")
    

    https://gist.github.com/DmytroMitin/5b5dd4d7db0d0eebb51dd8c16735e0fb

    You should provide some code showing how you construct impl(), and we'll see whether it's possible to restore the types.

    See also:
    Spark registered a Scala object all of the methods as a UDF
    scala cast object based on reflection symbol


    Update. After you get val impl = clazz.getDeclaredMethod("impl").invoke(instance) it's too late to restore the function types (you can check that typeList is empty). The place where the function type (or type tag) should be captured is somewhere not too far from class A: maybe inside A or maybe outside A, but where Int and String are not lost yet. What TypeTag can do is persist type information from compile time to runtime; it can't restore type information at runtime once it's lost.

    import org.apache.spark.sql.catalyst.ScalaReflection
    import ScalaReflection.universe._
    import org.apache.spark.sql.SparkSession
    import org.reflections.Reflections
    import scala.jdk.CollectionConverters._
    import scala.reflect.api
    
    object App {
      def getType[T: TypeTag](obj: T) = typeOf[T]
    
      trait Base
      class A extends Base {
        def impl(): Int => String = x => "ab" + x.toString 
    
           // NotSerializableException
        //def impl(): Function1[Int, String] = new Function1[Int, String] {
        //  def apply(x: Int): String = "ab" + x.toString
        //}
    
        val tpe = getType(impl())
      }
    
      val reflections = new Reflections()
      val classes: List[Class[_ <: Base]] = reflections.getSubTypesOf(classOf[Base]).asScala.toList
    
      val clazz = classes(0)
      val instance = clazz.newInstance()
      val impl = clazz.getDeclaredMethod("impl").invoke(instance)
      val functionType = clazz.getDeclaredMethod("tpe").invoke(instance).asInstanceOf[Type]
      val List(argType, returnType) = functionType.typeArgs
    
      val spark = SparkSession.builder()
        .master("local")
        .appName("Spark app")
        .getOrCreate()
    
      val rm = ScalaReflection.mirror
    
      // (*)
      def typeToTypeTag[T](tpe: Type): TypeTag[T] =
        TypeTag(rm, new api.TypeCreator {
          def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
            tpe.asInstanceOf[U#Type]
        })
    
    //  type X
    //  type Y
    //  spark.udf.register("foo", impl.asInstanceOf[X => Y])(
    //    typeToTypeTag[Y](returnType),
    //    typeToTypeTag[X](argType)
    //  )
    
      impl match {
        case impl: Function1[x, y] => spark.udf.register("foo", impl)(
          typeToTypeTag[y](returnType),
          typeToTypeTag[x](argType)
        )
      }
    
      def main(args: Array[String]): Unit = {
        spark.sql("""SELECT foo(10)""").show()
      }
    
    }
    

    https://gist.github.com/DmytroMitin/2ebfae922f8a467d01b6ef18c8b8e5ad

    (*) Get a TypeTag from a Type?

    Now spark.sql("""SELECT foo(10)""").show() throws java.io.NotSerializableException but I guess it's not related to reflection.
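
    One way to narrow that down (a debugging sketch of mine, not part of the original answer; the helper isJavaSerializable is made up) is to check whether the reflected function survives plain Java serialization, which is roughly what Spark does with the UDF closure:

    import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

    // true if obj can be Java-serialized, false if writing it throws NotSerializableException
    def isJavaSerializable(obj: AnyRef): Boolean =
      try {
        new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(obj)
        true
      } catch {
        case _: NotSerializableException => false
      }

    // if this prints false, the function (or something it captures, e.g. the instance of A)
    // is the problem rather than the reflection machinery
    println(isJavaSerializable(impl))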

    Alternatively you can use runtime compilation (instead of manual resolution of implicits and construction of type tags from types)

    import scala.tools.reflect.ToolBox
    
    val rm = ScalaReflection.mirror
    val tb = rm.mkToolBox()
    tb.eval(q"""App.spark.udf.register("foo", App.impl.asInstanceOf[$functionType])""")
    

    https://gist.github.com/DmytroMitin/ba469faeca2230890845e1532b36e2a1

    One more option is to look up the return type of the method impl() as soon as we get class A (outside A):

    class A extends Base {
      def impl(): Int => String = x => "ab" + x.toString
    }
    
    // ...
    val functionType = rm.classSymbol(clazz).typeSignature.decl(TermName("impl")).asMethod.returnType
    val List(argType, returnType) = functionType.typeArgs
    

    https://gist.github.com/DmytroMitin/3bd2c19d158f8241a80952c397ee5e09


    Update 2. If the methods are defined as follows:

    class A extends Base {
      def impl(x: Int): String = {
        "ab" + x.toString
      }
    }
    

    then runtime compilation normally should be

    val rm = ScalaReflection.mirror
    val classSymbol = rm.classSymbol(clazz)
    val methodSymbol = classSymbol.typeSignature.decl(TermName("impl")).asMethod // the symbol spliced into the quasiquote below
    val tb = rm.mkToolBox()

    tb.eval(q"""App.spark.udf.register("foo", (new $classSymbol).$methodSymbol(_))""")
    

    or

    tb.eval(q"""App.spark.udf.register("foo", (new $classSymbol).impl(_))""")
    

    but now with Spark it produces ClassCastException: cannot assign instance of java.lang.invoke.SerializedLambda to field org.apache.spark.sql.catalyst.expressions.ScalaUDF.f of type scala.Function1 in instance of org.apache.spark.sql.catalyst.expressions.ScalaUDF, similarly to Spark registered a Scala object all of the methods as a UDF.

    https://gist.github.com/DmytroMitin/b0f110f4cf15e2dfd4add70f7124a7b6

    But ordinary Scala runtime reflection seems to work

    val rm = ScalaReflection.mirror
    val classSymbol = rm.classSymbol(clazz)
    val methodSymbol = classSymbol.typeSignature.decl(TermName("impl")).asMethod
    val returnType = methodSymbol.returnType
    val argType = methodSymbol.paramLists.head.head.typeSignature
    
    val constructorSymbol = classSymbol.typeSignature.decl(termNames.CONSTRUCTOR).asMethod
    val instance = rm.reflectClass(classSymbol).reflectConstructor(constructorSymbol)()
    val impl: Any => Any = rm.reflect(instance).reflectMethod(methodSymbol)(_)
    
    def typeToTypeTag[T](tpe: Type): TypeTag[T] =
      TypeTag(rm, new api.TypeCreator {
        def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
          tpe.asInstanceOf[U#Type]
      })
    
    impl match {
      case impl: Function1[x, y] => spark.udf.register("foo", impl)(
        typeToTypeTag[y](returnType),
        typeToTypeTag[x](argType)
      )
    }
    

    https://gist.github.com/DmytroMitin/763751096fe9cdb2e0d18ae4b9290a54


    Update 3. One more approach is to use compile-time reflection (macros) rather than runtime reflection if you have enough information at compile time (e.g. if all the classes are known at compile time)

    import scala.collection.mutable
    import scala.language.experimental.macros
    import scala.reflect.macros.blackbox
    
    object Macros {
      def registerMethod[A](): Unit = macro registerMethodImpl[A]
    
      def registerMethodImpl[A: c.WeakTypeTag](c: blackbox.Context)(): c.Tree = {
        import c.universe._
        val A = weakTypeOf[A]
    
        var children = mutable.Seq[Type]()
    
        val traverser = new Traverser {
          override def traverse(tree: Tree): Unit = {
            tree match {
              case _: ClassDef =>
                val tpe = tree.symbol.asClass.toType
                if (tpe <:< A && !(tpe =:= A)) children :+= tpe
              case _ =>
            }
    
            super.traverse(tree)
          }
        }
    
        c.enclosingRun.units.foreach(unit => traverser.traverse(unit.body))
    
        val calls = children.map(tpe =>
          q"""spark.udf.register("foo", (new $tpe).impl(_))"""
        )
    
        q"..$calls"
      }
    }
    
    // in a different subproject
    
    import org.apache.spark.sql.SparkSession
    
    object App {
      trait Base
    
      class A extends Base {
        def impl(x: Int): String = "ab" + x.toString
      }
    
      val spark = SparkSession.builder()
        .master("local")
        .appName("Spark app")
        .getOrCreate()
    
      Macros.registerMethod[Base]()
    
      def main(args: Array[String]): Unit = {
        spark.sql("""SELECT foo(10)""").show()
      }
    }
    

    https://gist.github.com/DmytroMitin/6623f1f900330f8341f209e1347a0007

    See also: Shapeless - How to derive LabelledGeneric for Coproduct (KnownSubclasses)


    Update 4. If we replace val clazz = classes.head with classes.foreach(clazz => ...), then the issues with NotSerializableException can be fixed by inlining the reflection code (generating it at the call site with a macro):

    import scala.language.experimental.macros
    import scala.reflect.macros.blackbox
    
    object Macros {
      def registerMethod(clazz: Class[_]): Unit = macro registerMethodImpl
    
      def registerMethodImpl(c: blackbox.Context)(clazz: c.Tree): c.Tree = {
        import c.universe._
    
        val ScalaReflection = q"_root_.org.apache.spark.sql.catalyst.ScalaReflection"
        val rm = q"$ScalaReflection.mirror"
        val ru = q"$ScalaReflection.universe"
        val classSymbol = q"$rm.classSymbol($clazz)"
        val methodSymbol = q"""$classSymbol.typeSignature.decl($ru.TermName("impl")).asMethod"""
        val returnType = q"$methodSymbol.returnType"
        val argType = q"$methodSymbol.paramLists.head.head.typeSignature"
    
        val constructorSymbol = q"$classSymbol.typeSignature.decl($ru.termNames.CONSTRUCTOR).asMethod"
        val instance = q"$rm.reflectClass($classSymbol).reflectConstructor($constructorSymbol).apply()"
        val impl1 = q"(x: Any) => $rm.reflect($instance).reflectMethod($methodSymbol).apply(x)"
        val api = q"_root_.scala.reflect.api"
    
        def typeToTypeTag(T: Tree, tpe: Tree): Tree =
          q"""
            $ru.TypeTag[$T]($rm, new $api.TypeCreator {
              override def apply[U <: $api.Universe with _root_.scala.Singleton](m: $api.Mirror[U]) =
                $tpe.asInstanceOf[U#Type]
            })
          """
    
        val impl2 = TermName(c.freshName("impl2"))
        val x = TypeName(c.freshName("x"))
        val y = TypeName(c.freshName("y"))
        q"""
          $impl1 match {
            case $impl2: _root_.scala.Function1[$x, $y] => spark.udf.register("foo", $impl2)(
              ${typeToTypeTag(tq"$y", returnType)},
              ${typeToTypeTag(tq"$x", argType)}
            )
          }
        """
      }
    }
    
    // in a different subproject
    
    import org.apache.spark.sql.SparkSession
    import org.reflections.Reflections
    import scala.jdk.CollectionConverters._
    
    trait Base
    class A extends Base /*with Serializable*/ {
      def impl(x: Int): String = "ab" + x.toString
    }
    
    object App {
      val spark: SparkSession = SparkSession.builder()
        .master("local")
        .appName("Spark app")
        .getOrCreate()
    
      val reflections = new Reflections()
      val classes: List[Class[_ <: Base]] = reflections.getSubTypesOf(classOf[Base]).asScala.toList
    
      classes.foreach(clazz =>
        Macros.registerMethod(clazz)
      )
    
      def main(args: Array[String]): Unit = {
        spark.sql("""SELECT foo(10)""").show()
      }
    }
    

    https://gist.github.com/DmytroMitin/c926158a9ff94a6539097c603bbedf6a