Tags: apache-spark, apache-spark-sql, apache-spark-dataset

Spark dynamic DAG is a lot slower and different from hard-coded DAG


I have an operation in Spark that should be performed on several columns of a DataFrame. Generally, there are two ways to specify such operations:

  • hard-code the calls:

    handleBias("bar", df)
      .join(handleBias("baz", df), df.columns)
      .drop(columnsToDrop: _*).show

  • generate them dynamically from a list of column names:

    var res = df
    for (col <- columnsToDrop ++ columnsToCode) {
      // each iteration passes the previous result into the next handleBias call
      res = handleBias(col, res)
    }
    res.drop(columnsToDrop: _*).show
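The loop above threads each result into the next handleBias call, which makes it a left fold. The plain-Scala sketch below (no Spark) uses a hypothetical `step` function in place of handleBias to show that the var-based loop and `foldLeft` make the same sequence of calls. It also shows that concatenating the two column lists visits a column twice when it appears in both, as "baz" does in the answer's sample data; with the question's handleBias that would add a second set of pre_/pre2_ columns with the same names. `distinct` removes the repeat while keeping the order.

```scala
// Hypothetical stand-in for handleBias: instead of a DataFrame it threads
// the list of column names that have been processed so far.
def step(acc: Vector[String], colName: String): Vector[String] = acc :+ colName

// Sample column lists, taken from the answer's example.
val columnsToDrop = Seq("baz")
val columnsToCode = Seq("bar", "baz")
val all = columnsToDrop ++ columnsToCode // Seq(baz, bar, baz): "baz" appears twice

// The question's loop: each iteration feeds the previous result into the next call.
var res = Vector.empty[String]
for (c <- all) res = step(res, c)

// The same computation written as a left fold.
val folded = all.foldLeft(Vector.empty[String])(step)

// Deduplicating first means each column is handled only once, in the original order.
val deduped = all.distinct.foldLeft(Vector.empty[String])(step)

println(s"loop=$res fold=$folded deduped=$deduped")
```

Note that the answer below deduplicates with `.toSet` instead; for more than four elements a Scala `Set` does not guarantee iteration order, so `distinct` may be preferable if the order of the generated columns matters.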

The problem is that the dynamically generated DAG is different, and its runtime grows far faster with the number of columns than that of the hard-coded version.

I am curious how to combine the elegance of the dynamic construction with quick execution times.

Here is a comparison of the DAGs for the example code:

[image: complexity comparison]

For around 80 columns, the hard-coded variant results in a rather compact graph:

[image: hardCoded]

The dynamically constructed query, by contrast, produces a very large DAG that is probably less parallelizable and is much slower:

[image: hugeMessDynamic]

I used a current version of Spark (2.0.2) with DataFrames and Spark SQL.

Code to complete the minimal example:

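import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

def handleBias(col: String, df: DataFrame, target: String = "FOO"): DataFrame = {
  // share of all target == 1 rows that fall into each value of `col`
  // (note: .count is an action and triggers its own Spark job on every call)
  val pre1_1 = df
    .filter(df(target) === 1)
    .groupBy(col, target)
    .agg((count("*") / df.filter(df(target) === 1).count).alias("pre_" + col))
    .drop(target)

  // mean of the target per value of `col`
  val pre2_1 = df
    .groupBy(col)
    .agg(mean(target).alias("pre2_" + col))

  // attach both statistics to every row; nulls in numeric columns become 0
  df
    .join(pre1_1, Seq(col), "left")
    .join(pre2_1, Seq(col), "left")
    .na.fill(0)
}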

Edit

Running your task with foldLeft generates a linear DAG:

[image: foldleft]

Hard-coding the function for all the columns results in:

[image: hardcoded]

Both are a lot better than my original DAGs, but the hard-coded variant still looks better to me. Building a SQL statement by string concatenation could let me generate the hard-coded execution graph dynamically, but that seems rather ugly. Do you see any other option?
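One option that avoids SQL strings is to generate the hard-coded shape itself from a list: `map` each column name to `handleBias(c, df)` on the original df, then `reduce` the results with the same join on `df.columns`. This should yield the same left-to-right chain of joins as writing the calls out by hand. The following is a minimal sketch rather than a tested fix for the 80-column case; it assumes a local SparkSession, reuses the question's join-based handleBias with the default target changed to "foo", and borrows the sample data from the answer below. The names `joinCols` and `res` are illustrative.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

// Assumption: a local session; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("hard-coded-shape").getOrCreate()
import spark.implicits._

// handleBias from the question; only the default target is changed to "foo"
// so that it matches the sample data below.
def handleBias(col: String, df: DataFrame, target: String = "foo"): DataFrame = {
  // .count is an action, so each call runs a separate Spark job
  val pre1_1 = df
    .filter(df(target) === 1)
    .groupBy(col, target)
    .agg((count("*") / df.filter(df(target) === 1).count).alias("pre_" + col))
    .drop(target)

  val pre2_1 = df
    .groupBy(col)
    .agg(mean(target).alias("pre2_" + col))

  df
    .join(pre1_1, Seq(col), "left")
    .join(pre2_1, Seq(col), "left")
    .na.fill(0)
}

// Sample data borrowed from the answer below.
val df = Seq((1, "first", "A"), (1, "second", "A"), (2, "noValidFormat", "B"),
  (1, "lastAssumingSameDate", "C")).toDF("foo", "bar", "baz")
val columnsToCode = Seq("bar", "baz")
val joinCols = df.columns.toSeq

// Every handleBias call reads the ORIGINAL df, as in the hard-coded variant;
// the partial results are then joined left to right on the original columns.
// reduce throws on an empty list, so columnsToCode must not be empty.
val res = columnsToCode
  .map(c => handleBias(c, df))
  .reduce((left, right) => left.join(right, joinCols))

res.show()
```

Comparing `res.explain()` with the explain output of the hand-written version is a quick way to check that both produce the same plan.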


Solution

  • Edit 1: Removed one window function from handleBias and replaced it with a broadcast join.

    Edit 2: Changed the replacement strategy for null values.

    I have some suggestions that can improve your code. First, I would implement the "handleBias" function with window functions and "withColumn" calls, avoiding the joins:

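    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.expressions.Window
    
    // Expects df to already contain "cnt_foo_eq_1" (added by the broadcast join below).
    def handleBias(df: DataFrame, colName: String, target: String = "foo") = {
      val w1 = Window.partitionBy(colName)          // rows sharing a value of colName
      val w2 = Window.partitionBy(colName, target)  // rows sharing colName and target
      val result = df
        // size of the (colName, target) group
        .withColumn("cnt_group", count("*").over(w2))
        // mean of the target per value of colName
        .withColumn("pre2_" + colName, mean(target).over(w1))
        // cnt_foo_eq_1 is null unless target == 1, so min picks the target == 1 ratio; 0 if none
        .withColumn("pre_" + colName, coalesce(min(col("cnt_group") / col("cnt_foo_eq_1")).over(w1), lit(0D)))
        .drop("cnt_group")
      result
    }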
    

    Then, to call it for multiple columns, I would recommend foldLeft, which is the "functional" approach to this kind of problem:

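    // toDF needs the session implicits (already imported in spark-shell)
    import spark.implicits._
    
    val df = Seq((1, "first", "A"), (1, "second", "A"), (2, "noValidFormat", "B"), (1, "lastAssumingSameDate", "C")).toDF("foo", "bar", "baz")
    
    val columnsToDrop = Seq("baz")
    val columnsToCode = Seq("bar", "baz")
    val target = "foo"
    
    // Count the rows with target == 1 once and attach the count to every matching row
    // (rows with another target value get null).
    val targetCounts = df.filter(df(target) === 1).groupBy(target)
      .agg(count(target).as("cnt_foo_eq_1"))
    val newDF = df.join(broadcast(targetCounts), Seq(target), "left")
    
    // Start from newDF, because handleBias reads cnt_foo_eq_1.
    val result = (columnsToDrop ++ columnsToCode).toSet.foldLeft(newDF) {
      (currentDF, colName) => handleBias(currentDF, colName)
    }
    
    // Drop the helper count column together with the other columns.
    result.drop((columnsToDrop :+ "cnt_foo_eq_1"): _*).show()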
    
    +---+--------------------+------------------+--------+------------------+--------+
    |foo|                 bar|           pre_baz|pre2_baz|           pre_bar|pre2_bar|
    +---+--------------------+------------------+--------+------------------+--------+
    |  2|       noValidFormat|               0.0|     2.0|               0.0|     2.0|
    |  1|lastAssumingSameDate|0.3333333333333333|     1.0|0.3333333333333333|     1.0|
    |  1|              second|0.6666666666666666|     1.0|0.3333333333333333|     1.0|
    |  1|               first|0.6666666666666666|     1.0|0.3333333333333333|     1.0|
    +---+--------------------+------------------+--------+------------------+--------+
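As a cross-check of what the window expressions compute, the plain-Scala sketch below (no Spark) reproduces the pre_ and pre2_ values in the table from the four sample rows. For each value of a column, pre_ is the number of rows with that value and foo == 1 divided by the total number of rows with foo == 1 (0 if the value never occurs with foo == 1), and pre2_ is the mean of foo for that value. The names `rows`, `cntFooEq1` and `stats` are illustrative only.

```scala
// Sample rows (foo, bar, baz) from the example above.
val rows = Seq((1, "first", "A"), (1, "second", "A"), (2, "noValidFormat", "B"),
  (1, "lastAssumingSameDate", "C"))

// Plays the role of the broadcast-joined cnt_foo_eq_1: rows whose target equals 1.
val cntFooEq1 = rows.count(_._1 == 1).toDouble

// For each distinct value of the chosen column, compute (pre_, pre2_).
def stats(key: ((Int, String, String)) => String): Map[String, (Double, Double)] =
  rows.groupBy(key).map { case (value, group) =>
    val targets = group.map(_._1)
    // cnt_group / cnt_foo_eq_1 for the target == 1 group; the count is 0 when
    // no row has target == 1, matching coalesce(..., lit(0D))
    val pre = targets.count(_ == 1) / cntFooEq1
    // mean(target) over all rows with this value
    val pre2 = targets.sum.toDouble / targets.size
    value -> (pre, pre2)
  }

val bazStats = stats(_._3)
val barStats = stats(_._2)

println(bazStats)
println(barStats)
```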
    

    I'm not sure it will improve your DAG much, but at least it makes the code cleaner and more readable.

    Reference: