
Cumulative product in Spark


I'm trying to implement a cumulative product in Spark Scala, but I really don't know how to do it. I have the following dataframe:

Input data:
+--+--+--------+----+
|A |B | date   | val|
+--+--+--------+----+
|rr|gg|20171103| 2  |
|hh|jj|20171103| 3  |
|rr|gg|20171104| 4  |
|hh|jj|20171104| 5  |
|rr|gg|20171105| 6  |
|hh|jj|20171105| 7  |
+--+--+--------+----+

And I would like to have the following output:

Output data:
+--+--+--------+-----+
|A |B | date   | val |
+--+--+--------+-----+
|rr|gg|20171105| 48  | // 2 * 4 * 6
|hh|jj|20171105| 105 | // 3 * 5 * 7
+--+--+--------+-----+

Solution

  • As long as the numbers are strictly positive, as in your example (0 can be handled as well, if present, using coalesce; see the sketch after the rounding example below), the simplest solution is to compute the sum of logarithms and take the exponential:

    import org.apache.spark.sql.functions.{exp, log, max, round, sum}
    import spark.implicits._  // needed outside spark-shell for toDF and the $ syntax
    
    val df = Seq(
      ("rr", "gg", "20171103", 2), ("hh", "jj", "20171103", 3), 
      ("rr", "gg", "20171104", 4), ("hh", "jj", "20171104", 5), 
      ("rr", "gg", "20171105", 6), ("hh", "jj", "20171105", 7)
    ).toDF("A", "B", "date", "val")
    
    val result = df
      .groupBy("A", "B")
      .agg(
        max($"date").as("date"),          // keep the latest date per group
        exp(sum(log($"val"))).as("val"))  // product computed as exp(sum(log))
    

    Since this uses floating-point arithmetic, the result won't be exact:

    result.show
    
    +---+---+--------+------------------+
    |  A|  B|    date|               val|
    +---+---+--------+------------------+
    | hh| jj|20171105|104.99999999999997|
    | rr| gg|20171105|47.999999999999986|
    +---+---+--------+------------------+
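
    To see where the error comes from, here is the same round trip through logarithms in plain Scala (a standalone illustration, independent of Spark):

    // Summing logarithms and exponentiating accumulates tiny floating-point
    // errors, so the result is only approximately 48.
    val viaLogs = math.exp(math.log(2.0) + math.log(4.0) + math.log(6.0))
    println(viaLogs)          // close to, but not necessarily exactly, 48.0
    println(2.0 * 4.0 * 6.0)  // exactly 48.0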
    

    but after rounding it should be good enough for the majority of applications:

    result.withColumn("val", round($"val")).show
    
    +---+---+--------+-----+
    |  A|  B|    date|  val|
    +---+---+--------+-----+
    | hh| jj|20171105|105.0|
    | rr| gg|20171105| 48.0|
    +---+---+--------+-----+
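
    As for zeros, here is one possible sketch of the coalesce idea mentioned above (the name resultWithZeros is just illustrative); it flags groups that contain a zero and forces their product to 0, falling back to the log/exp trick otherwise:

    import org.apache.spark.sql.functions.{coalesce, lit, when}

    val resultWithZeros = df
      .groupBy("A", "B")
      .agg(
        max($"date").as("date"),
        coalesce(
          when(sum(when($"val" === 0, 1)) > 0, lit(0.0)),  // any zero => product is 0
          exp(sum(log($"val")))                            // otherwise the log/exp trick
        ).as("val"))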
    

    If that's not enough, you can define a UserDefinedAggregateFunction or Aggregator (How to define and use a User-Defined Aggregate Function in Spark SQL?) or use the functional API with reduceGroups; a minimal Aggregator sketch is also included at the end:

    import scala.math.Ordering
    
    case class Record(A: String, B: String, date: String, value: Long)
    
    df.withColumnRenamed("val", "value").as[Record]
      .groupByKey(x => (x.A, x.B))                   // group by (A, B)
      .reduceGroups((x, y) => x.copy(
        date = Ordering[String].max(x.date, y.date), // keep the latest date
        value = x.value * y.value))                  // exact product on Longs
      .toDF("key", "value")
      .select($"value.*")
      .show
    
    +---+---+--------+-----+
    |  A|  B|    date|value|
    +---+---+--------+-----+
    | hh| jj|20171105|  105|
    | rr| gg|20171105|   48|
    +---+---+--------+-----+
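
    For completeness, a minimal sketch of the Aggregator route (this assumes Spark 3.0+ for functions.udaf; ProductAgg is an illustrative name, not part of the Spark API):

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.udaf

    // Exact product over Longs (no overflow handling).
    object ProductAgg extends Aggregator[Long, Long, Long] {
      def zero: Long = 1L
      def reduce(acc: Long, x: Long): Long = acc * x
      def merge(a: Long, b: Long): Long = a * b
      def finish(acc: Long): Long = acc
      def bufferEncoder: Encoder[Long] = Encoders.scalaLong
      def outputEncoder: Encoder[Long] = Encoders.scalaLong
    }

    val product = udaf(ProductAgg)

    df.groupBy("A", "B")
      .agg(max($"date").as("date"), product($"val".cast("long")).as("val"))  // cast matches the Long input type
      .show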