Tags: scala, apache-spark, dataframe, apache-spark-sql, window-functions

Skip the current row COUNT and sum up the other COUNTS for current key with Spark Dataframe


My input:

val df = sc.parallelize(Seq(
  ("0", "car1", "success"),
  ("0", "car1", "success"),
  ("0", "car3", "success"),
  ("0", "car2", "success"),
  ("1", "car1", "success"),
  ("1", "car2", "success"),
  ("0", "car3", "success")
)).toDF("id", "item", "status")

My intermediary group by output looks like this:

val df2 = df.groupBy("id", "item").agg(count("item").alias("occurences"))
+---+----+----------+
| id|item|occurences|
+---+----+----------+
|  0|car3|         2|
|  0|car2|         1|
|  0|car1|         2|
|  1|car2|         1|
|  1|car1|         1|
+---+----+----------+

The output I would like: for each row, the sum of occurrences of the same id's other items, i.e. skipping the occurrences value of the current row's item.

For example, in the output table below, for id "0" car3 appeared 2 times, car2 appeared 1 time, and car1 appeared 2 times.

So for id "0", the sum of other occurrences for its "car3" item is car2 (1) + car1 (2) = 3.
For the same id "0", the sum of other occurrences for its "car2" item is car3 (2) + car1 (2) = 4.

This continues for the rest. Sample output:

+---+----+----------+----------------------+
| id|item|occurences| other_occurences_sum |
+---+----+----------+----------------------+
|  0|car3|         2|          3           |<- (car2+car1) for id 0
|  0|car2|         1|          4           |<- (car3+car1) for id 0  
|  0|car1|         2|          3           |<- (car3+car2) for id 0
|  1|car2|         1|          1           |<- (car1) for id 1
|  1|car1|         1|          1           |<- (car2) for id 1
+---+----+----------+----------------------+

Solution

  • That's a perfect target for a window function.

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.sum
    // the $"..." column syntax below assumes spark.implicits._ is in scope
    // (it is imported automatically in spark-shell)
    
    val w = Window.partitionBy("id")
    
    df2.withColumn(
      "other_occurences_sum", sum($"occurences").over(w) - $"occurences"
    ).show
    // +---+----+----------+--------------------+     
    // | id|item|occurences|other_occurences_sum|
    // +---+----+----------+--------------------+
    // |  0|car3|         2|                   3|
    // |  0|car2|         1|                   4|
    // |  0|car1|         2|                   3|
    // |  1|car2|         1|                   1|
    // |  1|car1|         1|                   1|
    // +---+----+----------+--------------------+
    

    where sum($"occurences").over(w) is the sum of all occurrences for the current id; subtracting the row's own occurences leaves the sum of the others. A join against a per-id aggregate is also valid:

    df2.join(
      df2.groupBy("id").agg(sum($"occurences") as "total"), Seq("id")
    ).select(
        $"*", ($"total" - $"occurences") as "other_occurences_sum"
    ).show
    
    // +---+----+----------+--------------------+
    // | id|item|occurences|other_occurences_sum|
    // +---+----+----------+--------------------+
    // |  0|car3|         2|                   3|
    // |  0|car2|         1|                   4|
    // |  0|car1|         2|                   3|
    // |  1|car2|         1|                   1|
    // |  1|car1|         1|                   1|
    // +---+----+----------+--------------------+
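For intuition, the same "group total minus the current row" arithmetic can be sketched with plain Scala collections, no Spark required. The case class and values below are hypothetical, simply mirroring df2:

```scala
// Hypothetical in-memory mirror of df2: (id, item, occurences)
case class Occ(id: String, item: String, occurences: Long)

val rows = Seq(
  Occ("0", "car3", 2), Occ("0", "car2", 1), Occ("0", "car1", 2),
  Occ("1", "car2", 1), Occ("1", "car1", 1)
)

// Per-id totals: the analogue of sum($"occurences").over(Window.partitionBy("id"))
val totals: Map[String, Long] =
  rows.groupBy(_.id).map { case (id, rs) => id -> rs.map(_.occurences).sum }

// other_occurences_sum = group total minus the row's own count
val result = rows.map(r => (r.id, r.item, r.occurences, totals(r.id) - r.occurences))
```

Both Spark variants compute exactly this: one pass to get the per-id total, then a per-row subtraction.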