Spark/Scala repeated calls to withColumn() using the same function on multiple columns


I currently have code that repeatedly applies the same procedure to multiple DataFrame columns via chains of .withColumn calls, and I want to create a function to streamline the procedure. In my case, I am computing cumulative sums over columns, aggregated by key:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

val newDF = oldDF
  .withColumn("cumA", sum("A").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumB", sum("B").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumC", sum("C").over(Window.partitionBy("ID").orderBy("time")))
  //.withColumn(...)

What I would like is either something like:

def createCumulativeColumns(cols: Array[String], df: DataFrame): DataFrame = {
  // Implement the above cumulative sums, partitioning, and ordering
}

or better yet:

def withColumns(cols: Array[String], df: DataFrame, f: String => Column): DataFrame = {
  // Implement a udf/arbitrary function on all the specified columns
}

Solution

  • You can use select with varargs including *:

    import spark.implicits._
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.sum
    
    df.select($"*" +: Seq("A", "B", "C").map(c => 
      sum(c).over(Window.partitionBy("ID").orderBy("time")).alias(s"cum$c")
    ): _*)
    

    This:

    • Maps column names to window expressions with Seq("A", ...).map(...)
    • Prepends all pre-existing columns with $"*" +: ....
    • Unpacks the combined sequence with ... : _* (see the standalone sketch below).
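
    If the +: and : _* mechanics are unfamiliar, here is a minimal, Spark-free sketch of both (the names are hypothetical):

    // +: prepends a single element to a sequence
    val all = "x" +: Seq("a", "b")             // Seq("x", "a", "b")
    
    // : _* expands a sequence into a varargs argument list,
    // just as select(cols: Column*) accepts in Spark
    def joinAll(parts: String*): String = parts.mkString(",")
    joinAll(all: _*)                           // "x,a,b"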

    The select version above can be generalized as:

    import org.apache.spark.sql.{Column, DataFrame}
    
    /**
     * @param cols a sequence of column names to transform
     * @param df an input DataFrame
     * @param f a function to be applied to each column name in cols
     */
    def withColumns(cols: Seq[String], df: DataFrame, f: String => Column): DataFrame =
      df.select($"*" +: cols.map(c => f(c)): _*)
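
    For instance, the cumulative sums from the question could then be written as (a sketch assuming oldDF from the question is in scope):

    import spark.implicits._
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.sum
    
    val w = Window.partitionBy("ID").orderBy("time")
    
    // One select producing all three cumulative-sum columns
    val newDF = withColumns(Seq("A", "B", "C"), oldDF, c => sum(c).over(w).alias(s"cum$c"))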
    

    If you find the withColumn syntax more readable, you can use foldLeft instead:

    Seq("A", "B", "C").foldLeft(df)((df, c) =>
      df.withColumn(s"cum$c",  sum(c).over(Window.partitionBy("ID").orderBy("time")))
    )
    

    which can, for example, be generalized to:

    /**
     * @param cols a sequence of column names to transform
     * @param df an input DataFrame
     * @param f a function to be applied to each column name in cols
     * @param name a function mapping an input column name to an output column name
     */
    def withColumns(cols: Seq[String], df: DataFrame,
        f: String => Column, name: String => String = identity): DataFrame =
      cols.foldLeft(df)((acc, c) => acc.withColumn(name(c), f(c)))
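
    A usage sketch mirroring the original example, with a name function that adds the cum prefix (assumes the same oldDF, imports, and window definition as above):

    val w = Window.partitionBy("ID").orderBy("time")
    
    val newDF = withColumns(
      Seq("A", "B", "C"),
      oldDF,
      c => sum(c).over(w),   // the per-column transformation
      c => s"cum$c"          // maps "A" -> "cumA", etc.
    )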