Search code examples
Tags: scala, apache-spark, catalyst-optimizer

Rewrite LogicalPlan to push down udf from aggregate


I have defined a UDF named "inc" that increments its input value by one. This is the code of my UDF:

spark.udf.register("inc", (value: Long) => value + 1) // "inc": Long => Long, adds one

This is my test SQL:

val df = spark.sql("select sum(inc(vals)) from data")
df.explain(extended = true) // extended: print the logical plans as well as the physical plan
df.show()

This is the optimized plan for that SQL:

== Optimized Logical Plan ==
Aggregate [sum(inc(vals#4L)) AS sum(inc(vals))#7L]
+- LocalRelation [vals#4L]

I want to rewrite the plan and extract "inc" out of "sum", just as Spark does for Python UDFs. This is the optimized plan I want:

Aggregate [sum(inc_val#6L) AS sum(inc(vals))#7L]
+- Project [inc(vals#4L) AS inc_val#6L]
   +- LocalRelation [vals#4L]

I found that the source file "ExtractPythonUDFs.scala" provides similar functionality for PythonUDF. However, it inserts a new node named "ArrowEvalPython". This is the logical plan for the Python UDF:

== Optimized Logical Plan ==
Aggregate [sum(pythonUDF0#7L) AS sum(inc(vals))#4L]
+- Project [pythonUDF0#7L]
   +- ArrowEvalPython [inc(vals#0L)], [pythonUDF0#7L], 200
      +- Repartition 10, true
         +- RelationV2[vals#0L] parquet file:/tmp/vals.parquet

What I want to insert is just a Project node; I don't want to define a new node type.


This is the test code of my project:

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

object RewritePlanTest {

  /**
   * Diagnostic optimizer rule. When the plan root is a global aggregate (no GROUP BY) with exactly
   * one output expression, it prints every outermost ScalaUDF found in the aggregate's expressions.
   * It also prints whether the first of them is a NamedExpression. The plan is always returned
   * unchanged.
   *
   * @param spark unused; present because `injectOptimizerRule` takes a
   *              `SparkSession => Rule[LogicalPlan]` builder.
   */
  case class UdfRule(spark: SparkSession) extends Rule[LogicalPlan] {

    /** Collects the outermost ScalaUDF calls in `e`; arguments of a collected UDF are not searched. */
    def collectUDFs(e: Expression): Seq[Expression] = e match {
      case udf: ScalaUDF => Seq(udf)
      case _ => e.children.flatMap(collectUDFs)
    }

    override def apply(plan: LogicalPlan): LogicalPlan = plan match {
      case agg@Aggregate(g, a, _) if (g.isEmpty && a.length == 1) =>
        val udfs = agg.expressions.flatMap(collectUDFs)
        println("================")
        udfs.foreach(println)
        // An aggregate may contain no ScalaUDF at all (e.g. `select sum(vals) from data`), so the
        // first element is read with headOption instead of udfs(0), which would throw.
        udfs.headOption.foreach { first =>
          val test = first match {
            case _: NamedExpression => true
            case _ => false
          }
          println(s"cast ScalaUDF to NamedExpression = ${test}")
        }
        println("================")
        agg
      case _ => plan
    }
  }


  /** Registers the "inc" UDF over a small in-memory table and runs a query through [[UdfRule]]. */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Rewrite plan test")
      .withExtensions(e => e.injectOptimizerRule(UdfRule))
      .getOrCreate()

    val input = Seq(100L, 200L, 300L)
    import spark.implicits._
    input.toDF("vals").createOrReplaceTempView("data")

    spark.udf.register("inc", (x: Long) => x + 1)

    val df = spark.sql("select sum(inc(vals)) from data")
    df.explain(true)
    df.show()
    spark.stop()
  }
}

I have extracted the ScalaUDFs from the Aggregate node,

but the argument required by the Project node is a Seq[NamedExpression]:

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)

and casting a ScalaUDF to NamedExpression fails,

so I have no idea how to construct the Project node.

Can someone give me some advice?

Thanks.


Solution

  • OK, I finally found a way to answer this question.

    Although a ScalaUDF can't be cast to NamedExpression, an Alias can, because Alias is itself a NamedExpression.

    So I wrap each ScalaUDF in an Alias and then construct the Project from those aliases.

    import org.apache.log4j.{Level, Logger}
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
    import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpectsInputTypes, ExprId, Expression, NamedExpression, ScalaUDF}
    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, Project, Subquery}
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.types.{AbstractDataType, DataType}
    
    import scala.collection.mutable
    
    object RewritePlanTest {

      /**
       * Optimizer rule that pulls ScalaUDF calls out of a global aggregate (no GROUP BY) with a
       * single output expression. The calls are evaluated in a new Project placed below the
       * Aggregate, so `Aggregate [sum(inc(vals))]` becomes
       * `Aggregate [sum(udf0)]` over `Project [inc(vals) AS udf0]`.
       *
       * @param spark unused; present because `injectOptimizerRule` takes a
       *              `SparkSession => Rule[LogicalPlan]` builder.
       */
      case class UdfRule(spark: SparkSession) extends Rule[LogicalPlan] {

        /** Collects the outermost ScalaUDF calls in `e`; arguments of a collected UDF are not searched. */
        def collectUDFs(e: Expression): Seq[Expression] = e match {
          case udf: ScalaUDF => Seq(udf)
          case _ => e.children.flatMap(collectUDFs)
        }

        // `transform` visits every node, not just the root. The rewrite therefore also fires
        // when the Aggregate sits beneath another operator such as a Limit or a Union.
        override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
          case agg @ Aggregate(g, a, c) if g.isEmpty && a.length == 1 =>
            // distinct: a UDF call that appears several times is computed by a single alias
            val udfs = agg.expressions.flatMap(collectUDFs).distinct
            if (udfs.isEmpty) {
              agg
            } else {
              val aliases = udfs.zipWithIndex.map { case (udf, i) => Alias(udf, s"udf${i}")() }
              val replacements: Map[Expression, Attribute] =
                udfs.zip(aliases.map(_.toAttribute)).toMap
              // Top-down, so an outer UDF is swapped for its attribute before its arguments
              // (which may themselves be UDF calls) are visited.
              val rewritten = agg.transformExpressionsDown {
                case udf: ScalaUDF if replacements.contains(udf) => replacements(udf)
              }
              // Child columns the aggregate still reads outside of any UDF (e.g. `vals` in
              // sum(inc(vals) + vals)) must pass through the Project, or the plan is invalid.
              val passThrough = c.output.filter(rewritten.references.contains)
              val projectList: Seq[NamedExpression] = passThrough ++ aliases
              val newAgg = rewritten.withNewChildren(Seq(Project(projectList, c)))
              println("====== new agg ======")
              println(newAgg)
              newAgg
            }
        }
      }


      /** Registers the "inc" UDF over a small in-memory table and runs a query through [[UdfRule]]. */
      def main(args: Array[String]): Unit = {
        Logger.getLogger("org").setLevel(Level.WARN)

        val spark = SparkSession
          .builder()
          .master("local[*]")
          .appName("Rewrite plan test")
          .withExtensions(e => e.injectOptimizerRule(UdfRule))
          .getOrCreate()

        val input = Seq(100L, 200L, 300L)
        import spark.implicits._
        input.toDF("vals").createOrReplaceTempView("data")

        spark.udf.register("inc", (x: Long) => x + 1)

        val df = spark.sql("select sum(inc(vals)) from data where vals > 100")
        //    val plan = df.queryExecution.analyzed
        //    println(plan)
        df.explain(true)
        df.show()

        spark.stop()

      }
    }
    

    This code outputs the LogicalPlan that I wanted:

    ====== new agg ======
    Aggregate [sum(udf0#9L) AS sum(inc(vals))#7L]
    +- Project [inc(vals#4L) AS udf0#9L]
       +- LocalRelation [vals#4L]