I have a DataFrame like this:
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
have = spark.createDataFrame(
[('a', 'r1', '1'),
('b', 'r1', '2'),
('c', 'r1', '3'),
('d', 's3', '4'),
('e', 's3', '5'),
('f', 's4', '6'),
('g', 'r1', '7')],
['id', 'group_col', 'order_col'])
I want to create a group ID column based on group_col that changes only when the group changes. So when the r1 group comes up again, it should get a different group ID than the first run of r1 rows.
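To pin down the semantics, here is a minimal pure-Python sketch (no Spark, so it runs anywhere) of a run-length ID: sort by order_col, then start a new ID whenever group_col differs from the previous row. The `rows` list just mirrors the `have` data, and the `rleid` helper name is hypothetical.

```python
from itertools import groupby

# Same rows as the `have` DataFrame: (id, group_col, order_col)
rows = [('a', 'r1', '1'), ('b', 'r1', '2'), ('c', 'r1', '3'),
        ('d', 's3', '4'), ('e', 's3', '5'), ('f', 's4', '6'),
        ('g', 'r1', '7')]

def rleid(values):
    """Return a 1-based run ID per element; adjacent equal values share an ID."""
    ids = []
    # groupby only merges *adjacent* equal keys, which is exactly a run
    for run_id, (_, run) in enumerate(groupby(values), start=1):
        ids.extend([run_id] * len(list(run)))
    return ids

# Order numerically by order_col before computing runs
ordered = sorted(rows, key=lambda r: int(r[2]))
result = list(zip([r[0] for r in ordered], rleid([r[1] for r in ordered])))
print(result)
# [('a', 1), ('b', 1), ('c', 1), ('d', 2), ('e', 2), ('f', 3), ('g', 4)]
```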
want = spark.createDataFrame(
[('a', 'r1', '1', '1'),
('b', 'r1', '2', '1'),
('c', 'r1', '3', '1'),
('d', 's3', '4', '2'),
('e', 's3', '5', '2'),
('f', 's4', '6', '3'),
('g', 'r1', '7', '4')],
['id', 'group_col', 'order_col', 'rleid'])
want.show()
+---+---------+---------+-----+
| id|group_col|order_col|rleid|
+---+---------+---------+-----+
|  a|       r1|        1|    1|
|  b|       r1|        2|    1|
|  c|       r1|        3|    1|
|  d|       s3|        4|    2|
|  e|       s3|        5|    2|
|  f|       s4|        6|    3|
|  g|       r1|        7|    4|
+---+---------+---------+-----+
The group IDs don't have to be contiguous; I just need each group (run) to get a unique ID.
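Since contiguity is not required, one alternative labelling (an illustration, not something the question prescribes) is to tag each run with the order_col value of its first row; that is unique per run as long as order_col itself is unique. A pure-Python sketch, where `first_order_labels` is a hypothetical helper name and order_col is treated as an integer:

```python
# (id, group_col, order_col) rows, already sorted by order_col
rows = [('a', 'r1', 1), ('b', 'r1', 2), ('c', 'r1', 3),
        ('d', 's3', 4), ('e', 's3', 5), ('f', 's4', 6),
        ('g', 'r1', 7)]

def first_order_labels(rows):
    """Label each row with the order_col value at which its run started."""
    labels = []
    prev_group = None
    run_start = None
    for _, group, order in rows:
        if group != prev_group:  # boundary: a new run begins here
            run_start = order
            prev_group = group
        labels.append(run_start)
    return labels

print(first_order_labels(rows))  # [1, 1, 1, 4, 4, 6, 7]
```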
Basically I want something like the rleid function in the data.table R package. The equivalent R code would be:
library(data.table)
df <- data.table(
id = letters[1:7],
group_col = c("r1", "r1", "r1", "s3", "s3", "s4", "r1"),
order_col = c(1:7)
)
setorder(df, order_col)
df[, `:=` (rleid = rleid(group_col))]
df
   id group_col order_col rleid
1:  a        r1         1     1
2:  b        r1         2     1
3:  c        r1         3     1
4:  d        s3         4     2
5:  e        s3         5     2
6:  f        s4         6     3
7:  g        r1         7     4
I have tried rank() and dense_rank() over group_col, e.g.:
df = have.withColumn("rleid", F.dense_rank().over(Window.orderBy('group_col')))
df.show()
+---+---------+---------+-----+
| id|group_col|order_col|rleid|
+---+---------+---------+-----+
|  a|       r1|        1|    1|
|  b|       r1|        2|    1|
|  c|       r1|        3|    1|
|  g|       r1|        7|    1|
|  d|       s3|        4|    2|
|  e|       s3|        5|    2|
|  f|       s4|        6|    3|
+---+---------+---------+-----+
This doesn't give me what I want because id=g should have rleid=4.
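The underlying reason is that dense_rank() over Window.orderBy('group_col') ranks rows by the value of group_col alone; order_col never enters the calculation, so every r1 row gets the same rank. A pure-Python sketch of that ranking (illustrative only, not Spark's implementation):

```python
rows = [('a', 'r1', '1'), ('b', 'r1', '2'), ('c', 'r1', '3'),
        ('d', 's3', '4'), ('e', 's3', '5'), ('f', 's4', '6'),
        ('g', 'r1', '7')]

# dense_rank: position of each value among the sorted distinct values (1-based, no gaps)
distinct = sorted({group for _, group, _ in rows})
dense_rank = {group: i for i, group in enumerate(distinct, start=1)}

ranks = {row_id: dense_rank[group] for row_id, group, _ in rows}
print(ranks['g'])  # 1, not 4: row order is ignored
```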
I have also tried array_sort() based on this answer, but unfortunately that didn't work either:
df = (have
      .withColumn("rank", F.array_sort(F.collect_set('group_col').over(
          Window.orderBy('order_col').rowsBetween(Window.unboundedPreceding, Window.currentRow))))
      .withColumn('rleid', F.expr("array_position(rank, group_col)")))
df.show()
+---+---------+---------+------------+-----+
| id|group_col|order_col|        rank|rleid|
+---+---------+---------+------------+-----+
|  a|       r1|        1|        [r1]|    1|
|  b|       r1|        2|        [r1]|    1|
|  c|       r1|        3|        [r1]|    1|
|  d|       s3|        4|    [r1, s3]|    2|
|  e|       s3|        5|    [r1, s3]|    2|
|  f|       s4|        6|[r1, s3, s4]|    3|
|  g|       r1|        7|[r1, s3, s4]|    1|
+---+---------+---------+------------+-----+
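This fails for a similar reason: the running collect_set only records which distinct values have been seen so far, and array_sort orders them alphabetically, so array_position returns the value's alphabetical slot rather than a run counter. r1 sorts first, so it always maps to 1. A pure-Python re-creation of those three steps (illustrative only):

```python
rows = [('a', 'r1', '1'), ('b', 'r1', '2'), ('c', 'r1', '3'),
        ('d', 's3', '4'), ('e', 's3', '5'), ('f', 's4', '6'),
        ('g', 'r1', '7')]

seen = set()
positions = {}
for row_id, group, _ in sorted(rows, key=lambda r: int(r[2])):
    seen.add(group)                            # running collect_set up to the current row
    rank = sorted(seen)                        # array_sort
    positions[row_id] = rank.index(group) + 1  # array_position is 1-based
print(positions)
# {'a': 1, 'b': 1, 'c': 1, 'd': 2, 'e': 2, 'f': 3, 'g': 1}
```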
Compare each row's group_col with the previous row's value to flag the boundary rows where the group changes, then take a cumulative sum of that flag so that every run gets its own ID:
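W = Window.orderBy('order_col')
# True where group_col differs from the previous row (null on the first row)
cond = F.lag('group_col').over(W) != F.col('group_col')
# Running count of boundaries; coalesce turns the first row's null sum into 0
want = have.withColumn('rleid', F.coalesce(F.sum(cond.cast('int')).over(W), F.lit(0)))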
want.show()
+---+---------+---------+-----+
| id|group_col|order_col|rleid|
+---+---------+---------+-----+
|  a|       r1|        1|    0|
|  b|       r1|        2|    0|
|  c|       r1|        3|    0|
|  d|       s3|        4|    1|
|  e|       s3|        5|    1|
|  f|       s4|        6|    2|
|  g|       r1|        7|    3|
+---+---------+---------+-----+
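The same flag-then-running-sum logic, sketched in plain Python so each step is visible (it mirrors the window expressions above but is not Spark code). Adding 1 to the result, which in Spark means appending `+ 1` to the `rleid` expression, makes the IDs start at 1 and reproduces the `want` column exactly.

```python
from itertools import accumulate

rows = [('a', 'r1', '1'), ('b', 'r1', '2'), ('c', 'r1', '3'),
        ('d', 's3', '4'), ('e', 's3', '5'), ('f', 's4', '6'),
        ('g', 'r1', '7')]
# Sort numerically by order_col (Spark sorts the string column, which agrees here)
ordered = sorted(rows, key=lambda r: int(r[2]))
groups = [group for _, group, _ in ordered]

# lag(group_col) != group_col; the first row has no previous value, so flag it 0
# (in Spark it is null there, and coalesce turns the resulting null sum into 0)
flags = [0] + [int(prev != cur) for prev, cur in zip(groups, groups[1:])]

# Running sum of the flags, i.e. sum(...).over(W)
rleid_from_0 = list(accumulate(flags))
rleid_from_1 = [x + 1 for x in rleid_from_0]  # matches `want`

print(rleid_from_0)  # [0, 0, 0, 1, 1, 2, 3]
print(rleid_from_1)  # [1, 1, 1, 2, 2, 3, 4]
```

Two practical notes on the Spark version: a window with orderBy but no partitionBy moves all rows into a single partition, which is fine for small data but can be slow at scale; and order_col is a string in `have`, so it sorts lexicographically ('10' before '2'), so casting it to an integer is safer once there are more than nine rows.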