apache-spark pyspark group

PySpark: order by one column, but generate a group ID based on another column


I have a DataFrame like this:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

have = spark.createDataFrame(
    [('a', 'r1', '1'),
    ('b', 'r1', '2'),
    ('c', 'r1', '3'),
    ('d', 's3', '4'),
    ('e', 's3', '5'),
    ('f', 's4', '6'),
    ('g', 'r1', '7')],
    ['id', 'group_col', 'order_col'])

I want to create a group ID column based on group_col that changes only when the value of group_col changes from one row to the next (with rows ordered by order_col). So when the r1 group comes up again, it should get a different group ID from the first run of r1 rows.

want = spark.createDataFrame(
    [('a', 'r1', '1', '1'),
    ('b', 'r1', '2', '1'),
    ('c', 'r1', '3', '1'),
    ('d', 's3', '4', '2'),
    ('e', 's3', '5', '2'),
    ('f', 's4', '6', '3'),
    ('g', 'r1', '7', '4')],
    ['id', 'group_col', 'order_col', 'rleid'])

want.show()
+---+---------+---------+-----+
| id|group_col|order_col|rleid|
+---+---------+---------+-----+
|  a|       r1|        1|    1|
|  b|       r1|        2|    1|
|  c|       r1|        3|    1|
|  d|       s3|        4|    2|
|  e|       s3|        5|    2|
|  f|       s4|        6|    3|
|  g|       r1|        7|    4|
+---+---------+---------+-----+

The group IDs don't have to be contiguous; I just need each run of a group to get a unique ID.

Basically, I want something like the rleid() function from R's data.table package. The equivalent R code would be:

library(data.table)
df <- data.table(
    id = letters[1:7],
    group_col = c("r1", "r1", "r1", "s3", "s3", "s4", "r1"),
    order_col = c(1:7)
)
setorder(df, order_col)
df[, `:=` (rleid = rleid(group_col))]
df
   id group_col order_col rleid
1:  a        r1         1     1
2:  b        r1         2     1
3:  c        r1         3     1
4:  d        s3         4     2
5:  e        s3         5     2
6:  f        s4         6     3
7:  g        r1         7     4
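To pin down the target semantics independently of Spark, here is a minimal plain-Python sketch of what rleid() computes, assuming the values are already sorted by order_col. The function name rleid is only borrowed for illustration; it relies on itertools.groupby, which groups consecutive equal elements, so every new run gets the next id.

```python
from itertools import groupby


def rleid(values):
    """Return a run-length id per element: the id goes up by 1
    every time a value differs from the one before it."""
    ids = []
    # groupby() yields one (key, run) pair per run of consecutive equal values
    for run_id, (_, run) in enumerate(groupby(values), start=1):
        ids.extend([run_id] * len(list(run)))
    return ids


# group_col for rows a..g, already in order_col order
group_col = ["r1", "r1", "r1", "s3", "s3", "s4", "r1"]
print(rleid(group_col))  # [1, 1, 1, 2, 2, 3, 4]
```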

I have tried rank() and dense_rank() over group_col. For example:

df = have.withColumn("rleid", F.dense_rank().over(Window.orderBy('group_col')))
df.show()
+---+---------+---------+-----+
| id|group_col|order_col|rleid|
+---+---------+---------+-----+
|  a|       r1|        1|    1|
|  b|       r1|        2|    1|
|  c|       r1|        3|    1|
|  g|       r1|        7|    1|
|  d|       s3|        4|    2|
|  e|       s3|        5|    2|
|  f|       s4|        6|    3|
+---+---------+---------+-----+

This doesn't give me what I want because id=g should have rleid=4.
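The reason is that dense_rank() over Window.orderBy('group_col') ranks rows by the value of group_col, so every r1 row gets the same rank no matter where it appears in order_col. The plain-Python model below (a hypothetical helper, not Spark's implementation) reproduces that behaviour:

```python
def dense_rank(values):
    """Model dense_rank() over an ordering by the value itself:
    equal values always get the same rank, wherever they occur."""
    rank_of = {v: r for r, v in enumerate(sorted(set(values)), start=1)}
    return [rank_of[v] for v in values]


# group_col for rows a..g in order_col order
group_col = ["r1", "r1", "r1", "s3", "s3", "s4", "r1"]
print(dense_rank(group_col))  # [1, 1, 1, 2, 2, 3, 1]  -> row g gets 1, not 4
```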

I have also tried array_sort() based on this answer, but unfortunately this didn't work either.

w = Window.orderBy('order_col').rowsBetween(Window.unboundedPreceding, Window.currentRow)
df = (have
    .withColumn("rank", F.array_sort(F.collect_set('group_col').over(w)))
    .withColumn('rleid', F.expr("array_position(rank, group_col)")))

df.show()
+---+---------+---------+------------+-----+
| id|group_col|order_col|        rank|rleid|
+---+---------+---------+------------+-----+
|  a|       r1|        1|        [r1]|    1|
|  b|       r1|        2|        [r1]|    1|
|  c|       r1|        3|        [r1]|    1|
|  d|       s3|        4|    [r1, s3]|    2|
|  e|       s3|        5|    [r1, s3]|    2|
|  f|       s4|        6|[r1, s3, s4]|    3|
|  g|       r1|        7|[r1, s3, s4]|    1|
+---+---------+---------+------------+-----+
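This approach fails for a similar reason: the running collect_set only records which distinct values have been seen so far, and array_sort orders them alphabetically, so the position of r1 stays 1 after it reappears. A rough plain-Python model of that logic (hypothetical helper, not Spark code):

```python
def running_set_position(values):
    """Model array_position(array_sort(collect_set(v)), v) over a window
    from the first row up to the current row."""
    seen = set()
    out = []
    for v in values:
        seen.add(v)  # collect_set keeps only distinct values
        out.append(sorted(seen).index(v) + 1)  # array_position is 1-based
    return out


group_col = ["r1", "r1", "r1", "s3", "s3", "s4", "r1"]
print(running_set_position(group_col))  # [1, 1, 1, 2, 2, 3, 1]
```

Because of the alphabetical sort, the ids need not even follow first-appearance order: for ["s3", "r1"] both rows get position 1.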

Solution

  • Compare each row's group_col with the previous row's (ordered by order_col) to flag the boundaries where the group changes, then take a cumulative sum over those flags so that each run gets a unique id:

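    # Order rows by order_col; cast to int because the column holds strings,
    # which would otherwise sort lexicographically ('10' before '2').
    W = Window.orderBy(F.col('order_col').cast('int'))
    # True where group_col differs from the previous row; null on the first row.
    cond = F.lag('group_col').over(W) != F.col('group_col')
    # Running sum of the change flags. sum() skips nulls, and coalesce turns the
    # first row's all-null sum into 0, so ids start at 0.
    want = have.withColumn('rleid', F.coalesce(F.sum(cond.cast('int')).over(W), F.lit(0)))

    want.show()
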
    +---+---------+---------+-----+
    | id|group_col|order_col|rleid|
    +---+---------+---------+-----+
    |  a|       r1|        1|    0|
    |  b|       r1|        2|    0|
    |  c|       r1|        3|    0|
    |  d|       s3|        4|    1|
    |  e|       s3|        5|    1|
    |  f|       s4|        6|    2|
    |  g|       r1|        7|    3|
    +---+---------+---------+-----+
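The window expression can be mirrored in plain Python to check the logic: flag a boundary whenever a value differs from its predecessor (the first row has no predecessor, like the null from lag), then keep a running total. This is a sketch of the idea only; rleid_lag_cumsum and its start parameter are hypothetical names, not part of PySpark.

```python
def rleid_lag_cumsum(values, start=0):
    """Plain-Python analogue of the lag + cumulative-sum window solution."""
    ids = []
    total = start
    prev = None
    for i, v in enumerate(values):
        # lag() is null on the first row, which contributes nothing to the sum
        if i > 0 and v != prev:
            total += 1  # boundary: group changed relative to the previous row
        ids.append(total)
        prev = v
    return ids


group_col = ["r1", "r1", "r1", "s3", "s3", "s4", "r1"]
print(rleid_lag_cumsum(group_col))           # [0, 0, 0, 1, 1, 2, 3]
print(rleid_lag_cumsum(group_col, start=1))  # [1, 1, 1, 2, 2, 3, 4]
```

If 1-based ids matching the R output are preferred, adding 1 to the Spark column (for example `F.coalesce(...) + 1`) should give them. Also note that a Window with orderBy but no partitionBy moves all rows into a single partition; if the data has a natural key to partition by, adding partitionBy on it would likely scale better.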