
Reflection to call method that had its name changed in an upgrade?


My code compiled with Spark 3.1.2:

private def work(plan: LogicalPlan): LogicalPlan = {
  val result = plan.transformDown {
    // irrelevant details
  }
  result
}

When run with Spark 3.3.0, I run into:

java.lang.NoSuchMethodError: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDown(Lscala/PartialFunction;)Lorg/apache/spark/sql/catalyst/plans/logical/LogicalPlan;

which makes sense because transformDown doesn't exist anymore in Spark 3.3.0 and seems to now be transformDownWithPruning.

I want to make this work via reflection with the logic:

if (sparkVersion == "3.1.2") plan.transformDown
else if (sparkVersion == "3.3.0") plan.transformDownWithPruning

I know you can call methods by exact name via reflection, but is there any way to pick the method to call based on its name containing a string? In this case it would be any method whose name contains "transformDown". I've been writing the code below, something like this:

private def transformWithReflection(plan: LogicalPlan) = {
    val runtime = scala.reflect.runtime.universe
    val mirror = runtime.runtimeMirror(getClass.getClassLoader)
    val instanceMirror = mirror.reflect(plan)

    //We target the transformDown method
    val transformMethodAlternatives = runtime
      .typeOf[LogicalPlan]
      .decl(runtime.TermName("transformDown")) // this looks for exact name right?
      .asTerm
      .alternatives
    
    ...
    // call reflected method
}

Or maybe I can get a list of all the methods under this class and filter them by "contains transformDown" which should only be 1 per list and then call that?


Solution

  • I guess there is some misunderstanding.

    The class org.apache.spark.sql.catalyst.plans.logical.LogicalPlan itself doesn't declare a method transformDown in either Spark 3.1.2 or 3.3.0:

    runtime
      .typeOf[LogicalPlan]
      .decls
      .foreach(println)
    
    // 3.1.2
    constructor LogicalPlan
    method metadataOutput
    method isStreaming
    method verboseStringWithSuffix
    method maxRows
    method maxRowsPerPartition
    lazy value resolved
    method statePrefix
    method childrenResolved
    method resolve
    lazy value childAttributes
    lazy value childMetadataAttributes
    lazy value outputAttributes
    lazy value outputMetadataAttributes
    method resolveChildren
    method resolve
    method resolveQuoted
    method refresh
    method outputOrdering
    method sameOutput
    
    // 3.3.0
    constructor LogicalPlan
    method metadataOutput
    method isStreaming
    lazy value _isStreaming
    method verboseStringWithSuffix
    method maxRows
    method maxRowsPerPartition
    lazy value resolved
    method statePrefix
    method childrenResolved
    method resolve
    lazy value childAttributes
    lazy value childMetadataAttributes
    lazy value outputAttributes
    lazy value outputMetadataAttributes
    method resolveChildren
    method resolve
    method resolveQuoted
    method refresh
    method outputOrdering
    method sameOutput
    

    https://www.diffchecker.com/gvueXinY/

    transformDown is an inherited member, and it exists in both 3.1.2 and 3.3.0:

    runtime
      .typeOf[LogicalPlan]
      .members
      .foreach(println)
    
    // 3.1.2
    method sameOutput
    method outputOrdering
    method refresh
    method resolveQuoted
    method resolve
    method resolveChildren
    lazy value outputMetadataAttributes
    lazy value outputAttributes
    lazy value childMetadataAttributes
    lazy value childAttributes
    method resolve
    method childrenResolved
    method statePrefix
    lazy value resolved
    method maxRowsPerPartition
    method maxRows
    method verboseStringWithSuffix
    method isStreaming
    method metadataOutput
    constructor LogicalPlan
    method initializeForcefully
    method initializeLogIfNecessary$default$2
    method initializeLogIfNecessary
    method initializeLogIfNecessary
    method isTraceEnabled
    method logError
    method logWarning
    method logTrace
    method logDebug
    method logInfo
    method logError
    method logWarning
    method logTrace
    method logDebug
    method logInfo
    method log
    method logName
    method $init$
    lazy value validConstraints
    lazy value constraints
    method constructIsNotNullConstraints
    method inferAdditionalConstraints
    method invalidateStatsCache
    variable statsCache
    variable statsCache
    method stats
    method clone
    method transformAllExpressions
    method transformUp
    method transformDown                            // <--- HERE !!!
    method assertNotAnalysisRule
    method resolveExpressions
    method transformUpWithNewOutput
    method resolveOperatorsUpWithNewOutput
    method resolveOperatorsDown
    method resolveOperatorsUp
    method resolveOperators
    method analyzed
    method setAnalyzed
    lazy value allAttributes
    method semanticHash
    method sameResult
    method doCanonicalize
    lazy value canonicalized
    method isCanonicalizedPlan
    method innerChildren
    method collectWithSubqueries
    method subqueriesAll
    method subqueries
    method formattedNodeName
    method verboseStringWithOperatorId
    method simpleStringWithNodeId
    method verboseString
    method simpleString
    method printSchema
    method schemaString
    lazy value schema
    method transformUpWithNewOutput$default$3
    method transformUpWithNewOutput$default$2
    method expressions
    method mapExpressions
    method transformExpressionsUp
    method transformExpressionsDown
    method transformExpressions
    method missingInput
    lazy value references
    method producedAttributes
    method inputSet
    lazy value outputSet
    method conf
    method jsonFields
    method prettyJson
    method toJSON
    method asCode
    method generateTreeString$default$9
    method generateTreeString$default$6
    method generateTreeString$default$5
    method generateTreeString
    method p
    method apply
    method numberedTreeString
    method treeString
    method treeString$default$4
    method treeString$default$3
    method treeString$default$2
    method treeString
    method treeString
    method toString
    method argString
    method stringArgs
    method nodeName
    method makeCopy
    method otherCopyArgs
    method mapChildren
    method transform
    method withNewChildren
    method mapProductIterator
    method collectFirst
    method collectLeaves
    method collect
    method flatMap
    method map
    method foreachUp
    method foreach
    method find
    method fastEquals
    method hashCode
    lazy value containsChild
    method unsetTagValue
    method getTagValue
    method setTagValue
    method copyTagsFrom
    value origin
    method productPrefix
    method productIterator
    method synchronized
    method ##
    method !=
    method ==
    method ne
    method eq
    method notifyAll
    method notify
    method getClass
    method equals
    method wait
    method wait
    method wait
    method finalize
    method asInstanceOf
    method isInstanceOf
    method output
    method children
    method productArity
    method productElement
    method canEqual
    
    // 3.3.0
    method sameOutput
    method outputOrdering
    method refresh
    method resolveQuoted
    method resolve
    method resolveChildren
    lazy value outputMetadataAttributes
    lazy value outputAttributes
    lazy value childMetadataAttributes
    lazy value childAttributes
    method resolve
    method childrenResolved
    method statePrefix
    lazy value resolved
    method maxRowsPerPartition
    method maxRows
    method verboseStringWithSuffix
    lazy value _isStreaming
    method isStreaming
    method metadataOutput
    constructor LogicalPlan
    method initializeForcefully
    method initializeLogIfNecessary$default$2
    method initializeLogIfNecessary
    method initializeLogIfNecessary
    method isTraceEnabled
    method logError
    method logWarning
    method logTrace
    method logDebug
    method logInfo
    method logError
    method logWarning
    method logTrace
    method logDebug
    method logInfo
    method log
    method logName
    method $init$
    lazy value validConstraints
    lazy value constraints
    method constructIsNotNullConstraints
    method inferAdditionalConstraints
    lazy value distinctKeys
    method invalidateStatsCache
    variable statsCache
    variable statsCache
    method stats
    method clone
    method transformAllExpressionsWithPruning$default$2
    method transformAllExpressionsWithPruning
    method transformUpWithPruning$default$2
    method transformUpWithPruning
    method transformDownWithPruning$default$2
    method transformDownWithPruning
    method assertNotAnalysisRule
    method resolveExpressionsWithPruning$default$2
    method resolveExpressionsWithPruning
    method resolveExpressions
    method updateOuterReferencesInSubquery
    method transformUpWithNewOutput
    method resolveOperatorsUpWithNewOutput
    method resolveOperatorsDownWithPruning$default$2
    method resolveOperatorsDownWithPruning
    method resolveOperatorsDown
    method resolveOperatorsUpWithPruning$default$2
    method resolveOperatorsUpWithPruning
    method resolveOperatorsUp
    method resolveOperatorsWithPruning$default$2
    method resolveOperatorsWithPruning
    method resolveOperators
    method analyzed
    method setAnalyzed
    lazy value allAttributes
    method semanticHash
    method sameResult
    method doCanonicalize
    lazy value canonicalized
    method isCanonicalizedPlan
    method innerChildren
    method collectWithSubqueries
    method transformDownWithSubqueriesAndPruning$default$2
    method transformDownWithSubqueriesAndPruning
    method transformDownWithSubqueries
    method transformUpWithSubqueries
    method transformWithSubqueries
    method subqueriesAll
    lazy value subqueries
    method formattedNodeName
    method verboseStringWithOperatorId
    method simpleStringWithNodeId
    method verboseString
    method simpleString
    method printSchema
    method schemaString
    lazy value schema
    method rewriteAttrs
    method transformUpWithNewOutput$default$3
    method transformUpWithNewOutput$default$2
    method expressions
    method transformAllExpressions
    method mapExpressions
    method transformExpressionsUpWithPruning$default$2
    method transformExpressionsUpWithPruning
    method transformExpressionsUp
    method transformExpressionsDownWithPruning$default$2
    method transformExpressionsDownWithPruning
    method transformExpressionsDown
    method transformExpressionsWithPruning$default$2
    method transformExpressionsWithPruning
    method transformExpressions
    method missingInput
    lazy value deterministic
    lazy value references
    method producedAttributes
    method inputSet
    lazy value treePatternBits
    lazy value outputSet
    method conf
    method jsonFields
    method prettyJson
    method toJSON
    method asCode
    method generateTreeString$default$9
    method generateTreeString$default$6
    method generateTreeString$default$5
    method generateTreeString
    method p
    method apply
    method numberedTreeString
    method treeString
    method treeString$default$4
    method treeString$default$3
    method treeString$default$2
    method treeString
    method treeString
    method toString
    method argString
    method stringArgs
    method nodeName
    method makeCopy
    method otherCopyArgs
    method mapChildren
    method transformUpWithBeforeAndAfterRuleOnChildren$default$2
    method transformUpWithBeforeAndAfterRuleOnChildren
    method transformUp
    method transformDown                       // <--- HERE !!!
    method transformWithPruning$default$2
    method transformWithPruning
    method transform
    method legacyWithNewChildren
    method withNewChildren
    method mapProductIterator
    method collectFirst
    method collectLeaves
    method collect
    method flatMap
    method map
    method foreachUp
    method foreach
    method exists
    method find
    method fastEquals
    method hashCode
    lazy value containsChild
    method unsetTagValue
    method getTagValue
    method setTagValue
    method copyTagsFrom
    method isRuleIneffective
    method markRuleAsIneffective
    value nodePatterns
    method getDefaultTreePatternBits
    value origin
    method containsAnyPattern
    method containsAllPatterns
    method containsPattern
    method productPrefix
    method productIterator
    method synchronized
    method ##
    method !=
    method ==
    method ne
    method eq
    method notifyAll
    method notify
    method getClass
    method equals
    method wait
    method wait
    method wait
    method finalize
    method asInstanceOf
    method isInstanceOf
    method output
    method withNewChildrenInternal
    method children
    method productArity
    method productElement
    method canEqual
    

    https://www.diffchecker.com/gKx1ZcYM/
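
    The difference between .decls and .members shown in the listings above can be reproduced on a small hierarchy. This is only a sketch: Base, Derived, inheritedOp and ownOp are hypothetical stand-ins (not Spark classes), chosen so the snippet runs with nothing but scala-reflect on the classpath.

    ```scala
    import scala.reflect.runtime.universe._

    object DeclVsMember {
      // Hypothetical stand-ins for TreeNode and LogicalPlan; these are not Spark classes.
      class Base {
        def inheritedOp(x: Int): Int = x + 1
      }
      class Derived extends Base {
        def ownOp(x: Int): Int = x * 2
      }

      private val tpe = typeOf[Derived]

      // decl only sees what Derived itself declares, so this lookup yields NoSymbol.
      val declared: Symbol = tpe.decl(TermName("inheritedOp"))

      // member also searches the base classes and finds Base.inheritedOp.
      val inherited: Symbol = tpe.member(TermName("inheritedOp"))

      // Invoke the inherited method reflectively on a Derived instance.
      def callInherited(x: Int): Int = {
        val mirror = runtimeMirror(getClass.getClassLoader)
        mirror
          .reflect(new Derived)
          .reflectMethod(inherited.asMethod)
          .apply(x)
          .asInstanceOf[Int]
      }
    }
    ```

    Calling .asTerm or .asMethod on the result of the decl lookup would throw here, because NoSymbol is neither a term nor a method.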

    What changed between 3.1.2 and 3.3.0 is the signature of the method as seen from LogicalPlan (since the method is inherited, you need .member rather than .decl to look it up):

    println(
      runtime
        .typeOf[LogicalPlan]
        .member(runtime.TermName("transformDown"))
        .typeSignature
    )
    
    // 3.1.2
    (rule: PartialFunction[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan,org.apache.spark.sql.catalyst.plans.logical.LogicalPlan])org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    
    // 3.3.0
    (rule: PartialFunction[BaseType,BaseType])BaseType
    

    The method is inherited from the class org.apache.spark.sql.catalyst.trees.TreeNode

    https://github.com/apache/spark/blob/v3.1.2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L316

    // 3.1.2
    def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
      ...
    

    https://github.com/apache/spark/blob/v3.3.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L559-L577

    // 3.3.0
    def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
      transformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
    }
    
    ...
    
    def transformDownWithPruning(cond: TreePatternBits => Boolean,
      ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
    : BaseType = {
      ...
    

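    What should work in both 3.1.2 and 3.3.0 is static upcasting (with type ascription) to TreeNode. The call site is then compiled against TreeNode.transformDown, which is declared the same way in both versions, rather than against the LogicalPlan-returning descriptor that the NoSuchMethodError complains about: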

    def work(plan: LogicalPlan): LogicalPlan = {
      val result: LogicalPlan = (plan: TreeNode[LogicalPlan]).transformDown {
        case x => x
      }
      result
    }
    

    Runtime reflection should also work in both 3.1.2 and 3.3.0:

    import org.apache.spark.sql.catalyst.ScalaReflection
    import ScalaReflection.universe._
    
    def work(plan: LogicalPlan): LogicalPlan = {
      val rm = ScalaReflection.mirror
    
      // transformDown is inherited from TreeNode, hence .member and not .decl;
      // typeOf and TermName come from the ScalaReflection.universe._ import
      val method = typeOf[LogicalPlan]
        .member(TermName("transformDown"))
        .asMethod
    
      val rule: PartialFunction[LogicalPlan, LogicalPlan] = {
        case x => x
      }
    
      rm.reflect(plan).reflectMethod(method).apply(rule).asInstanceOf[LogicalPlan]
    }
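
    The same kind of call can also be sketched with plain Java reflection, which looks a method up by name and parameter types only, so a changed return type does not affect the lookup. Everything below is an illustration under assumptions: Node and callByName are hypothetical names, and Node is a toy class, not a Spark plan.

    ```scala
    object JavaReflectionCall {
      // Hypothetical stand-in for a plan node; not a Spark class.
      class Node(val value: Int) {
        def transformDown(rule: PartialFunction[Node, Node]): Node =
          if (rule.isDefinedAt(this)) rule(this) else this
      }

      // Looks up a public one-argument method taking a PartialFunction and invokes it.
      // getMethod matches on name and parameter types; the return type is not part of the lookup.
      def callByName(target: AnyRef, methodName: String, rule: PartialFunction[_, _]): AnyRef =
        target.getClass
          .getMethod(methodName, classOf[PartialFunction[_, _]])
          .invoke(target, rule)

      val rule: PartialFunction[Node, Node] = {
        case n if n.value < 10 => new Node(n.value + 1)
      }

      // The result comes back as AnyRef and has to be cast by the caller.
      val out: Node = callByName(new Node(1), "transformDown", rule).asInstanceOf[Node]
    }
    ```

    A failure inside the invoked method surfaces as java.lang.reflect.InvocationTargetException wrapping the original exception, so callers may want to unwrap getCause.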
    

    In principle, if you do need to call differently named methods depending on the version (which doesn't seem to be your case), you can select the method by name, for example with runtime reflection. The three lookups below are alternatives:

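    val name =
      if (sparkVersion == "3.1.2") "transformDown"
      else if (sparkVersion == "3.3.0") "transformDownWithPruning"
      else ???
    
    // 1. exact-name lookup (includes inherited members)
    val method = typeOf[LogicalPlan]
      .member(TermName(name))
      .asMethod
    
    // 2. the same lookup, written as a search over all members
    val method1 = typeOf[LogicalPlan]
      .members
      .find(_.name == TermName(name))
      .get
      .asMethod
    
    // 3. first member whose name starts with the given string;
    //    this can also match synthetic members such as name$default$2
    val method2 = typeOf[LogicalPlan]
      .members
      .filter(_.name.toString.startsWith(name))
      .head
      .asMethod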
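
    For the "name contains a string" variant asked about in the question, a slightly more defensive sketch filters out compiler-generated default getters and refuses to guess when more than one method matches. Plan, methodsContaining and uniqueMethod are hypothetical names used only for this illustration; Plan is a toy class, not a Spark type.

    ```scala
    import scala.reflect.runtime.universe._

    object NameLookup {
      // Hypothetical class, not a Spark type; it mimics a method that gained a
      // "WithPruning" sibling with a default argument.
      class Plan {
        def transformDown(f: Int => Int): Int = f(1)
        def transformDownWithPruning(limit: Int, id: Int = 0)(f: Int => Int): Int = f(limit + id)
      }

      // All methods of `tpe` whose name contains `fragment`, sorted by name.
      // Names containing '$' are dropped, which excludes default getters
      // such as transformDownWithPruning$default$2.
      def methodsContaining(tpe: Type, fragment: String): List[MethodSymbol] =
        tpe.members.toList
          .filter(_.isMethod)
          .map(_.asMethod)
          .filter { m =>
            val n = m.name.toString
            n.contains(fragment) && !n.contains("$")
          }
          .sortBy(_.name.toString)

      // Insist on exactly one match instead of silently taking .head.
      def uniqueMethod(tpe: Type, fragment: String): Either[String, MethodSymbol] =
        methodsContaining(tpe, fragment) match {
          case single :: Nil => Right(single)
          case Nil           => Left("no method name contains " + fragment)
          case many =>
            val names = many.map(_.name.toString).mkString(", ")
            Left("ambiguous match: " + names)
        }
    }
    ```

    With this toy class, the fragment "transformDown" is ambiguous because it matches both methods, while "WithPruning" resolves to a single one. The 3.3.0 member listing above suggests the same ambiguity for LogicalPlan, since several members there contain "transformDown".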
    

    or runtime compilation

    // libraryDependencies += scalaOrganization.value % "scala-compiler" % scalaVersion.value exclude("org.scala-lang.modules", "scala-xml_2.12")
    import scala.tools.reflect.ToolBox
    
    val tb = ScalaReflection.mirror.mkToolBox()
    
    tb.eval(
      q"""
        (_: ${typeOf[LogicalPlan]}).${TermName(name)} {
          case x => x
        }
      """
    ).asInstanceOf[LogicalPlan => LogicalPlan].apply(plan)
    

    NoSuchMethodError: scala.tools.nsc.Settings.usejavacp()Lscala/tools/nsc/settings/AbsSettings$AbsSetting;