Tags: scala, apache-spark, apache-spark-sql, grouping, ranking

Spark Dataframe: Group and rank rows on a certain column value


I am trying to rank rows into groups based on the "ID" column, whose numbering runs from 1 up to some maximum and then resets to 1.

So the first three rows have continuous numbering on "ID" and should be grouped together with group rank = 1. Rows four and five form another group, with group rank = 2.

The rows are sorted by the "rownum" column. I am aware of the row_number window function, but I don't think I can apply it to this use case since there is no fixed window. I can only think of looping through each row of the dataframe, but I am not sure how I would update a column when the numbering resets to 1.

    import spark.implicits._

    val df = Seq(
      (1, 1), (2, 2), (3, 3),
      (4, 1), (5, 2),
      (6, 1), (7, 1), (8, 2)
    ).toDF("rownum", "ID")
    df.show()

    +------+---+
    |rownum| ID|
    +------+---+
    |     1|  1|
    |     2|  2|
    |     3|  3|
    |     4|  1|
    |     5|  2|
    |     6|  1|
    |     7|  1|
    |     8|  2|
    +------+---+

Expected result: rows 1-3 get group rank 1, rows 4-5 get rank 2, row 6 gets rank 3, and rows 7-8 get rank 4 (i.e. 1, 1, 1, 2, 2, 3, 4, 4).


Solution

  • You can do it with two window functions: the first flags whether the ID increased (i.e. the current run continues), the second computes a running sum over the reset flags:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{lag, lit, sum, when}

    df
      // flag rows where the ID increases, i.e. the current run continues
      .withColumn("increase", $"ID" > lag($"ID", 1).over(Window.orderBy($"rownum")))
      // running sum of resets; lag() is null on the first row, so it counts as a reset too
      .withColumn("group_rank_of_ID", sum(when($"increase", lit(0)).otherwise(lit(1))).over(Window.orderBy($"rownum")))
      .drop($"increase")
      .show()
    

    gives:

    +------+---+----------------+
    |rownum| ID|group_rank_of_ID|
    +------+---+----------------+
    |     1|  1|               1|
    |     2|  2|               1|
    |     3|  3|               1|
    |     4|  1|               2|
    |     5|  2|               2|
    |     6|  1|               3|
    |     7|  1|               4|
    |     8|  2|               4|
    +------+---+----------------+
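
    A note on this approach: Window.orderBy without a partitionBy moves all rows into a single partition, which is fine for a small dataframe like this one but can hurt performance at scale. Also, if every run restarts at ID = 1 (as it does in the sample data), a slightly simpler variant can flag the reset rows directly instead of comparing against the previous value. This is just a sketch under that assumption:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{lit, sum, when}

    df
      // assumes every new run starts with ID = 1, as in the sample data;
      // a running count of those reset rows is exactly the group rank
      .withColumn("group_rank_of_ID", sum(when($"ID" === 1, lit(1)).otherwise(lit(0))).over(Window.orderBy($"rownum")))
      .show()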