I have a sample PySpark dataframe that can be created like this:
sample_df = spark.createDataFrame([
('2020-01-01', '2021-01-01', 1),
('2020-02-01', '2021-02-01', 1),
('2021-01-15', '2022-01-15', 2),
('2022-01-15', '2023-01-15', 2),
('2022-02-01', '2023-02-01', 3),
('2022-03-01', '2023-03-01', 3),
('2023-03-01', '2024-03-01', 4),
], ['item_date', 'max_window', 'expected_grouping_index'])
After sorting by item_date, I want to assume the first item starts a grouping. Any following item whose item_date is less than or equal to the first item's max_window (which will always be the same number of days added to the item_date for the entire df, about 365 days in this example) will be given the same grouping_index.
If an item does not fall inside the grouping, it starts a new grouping and is given another arbitrary grouping_index. All following items are then assessed against that item's max_window, and so on.
The grouping_index is just a means to an end; I eventually want to keep only the first row in each group.
How can I achieve this without a UDF or converting to a pandas df?
I used the following answer as inspiration to write the code below. Basically, a clever use of a complex accumulator function allows the grouping index to be computed properly.
https://stackoverflow.com/a/64957835/3238085
from pyspark.sql import Window
from pyspark import SparkContext, SQLContext
from pyspark.sql.functions import *
import pyspark.sql.functions as F
spark = SparkContext('local')
sqlContext = SQLContext(spark)
sample_df = sqlContext.createDataFrame([
('2020-01-01', '2021-01-01', 1),
('2020-02-01', '2021-02-01', 1),
('2021-01-15', '2022-01-15', 2),
('2022-01-15', '2023-01-15', 2),
('2022-02-01', '2023-02-01', 3),
('2022-03-01', '2023-03-01', 3),
('2023-03-01', '2024-03-01', 4),
], ['item_date', 'max_window', 'expected_grouping_index'])
sample_df.show(100, truncate=False)
sample_df = sample_df.withColumn("item_date", col("item_date").cast("date"))
windowSpec = Window.orderBy("item_date")
# days between each item and the previous item, in item_date order
sample_df = sample_df.withColumn("datedifference",
                                 datediff(col("item_date"), lag(col("item_date"), 1).over(windowSpec)))
# running list of all consecutive differences up to the current row
w = windowSpec.rowsBetween(Window.unboundedPreceding, Window.currentRow)
sample_df = sample_df.withColumn('datedifferencelist', F.collect_list('datedifference').over(w))
# expr = "AGGREGATE(datedifferencelist, (0,0), (acc, el) -> IF(acc[0] + el < 365, (acc[0] + el, acc[1]), (el, acc[1]+1) ))"
# sample_df = sample_df.withColumn('cumsum', F.expr(expr))
# acc[0]: days since the current group's first item; acc[1]: running group index
initial_value = F.array(F.lit(0), F.lit(0))
sample_df = sample_df.withColumn('cumsum', F.aggregate(
    'datedifferencelist', initial_value,
    lambda acc, x: F.when((acc.getItem(0) + x) <= 365, F.array(acc.getItem(0) + x, acc.getItem(1)))
                    .otherwise(F.array(F.lit(0), acc.getItem(1) + 1))))
sample_df = sample_df.withColumn('group_id', F.col('cumsum').getItem(1))
sample_df.show(truncate=False)
Output:
+----------+----------+-----------------------+--------------+---------------------------+--------+--------+
|item_date |max_window|expected_grouping_index|datedifference|datedifferencelist |cumsum |group_id|
+----------+----------+-----------------------+--------------+---------------------------+--------+--------+
|2020-01-01|2021-01-01|1 |NULL |[] |[0, 0] |0 |
|2020-02-01|2021-02-01|1 |31 |[31] |[31, 0] |0 |
|2021-01-15|2022-01-15|2 |349 |[31, 349] |[0, 1] |1 |
|2022-01-15|2023-01-15|2 |365 |[31, 349, 365] |[365, 1]|1 |
|2022-02-01|2023-02-01|3 |17 |[31, 349, 365, 17] |[0, 2] |2 |
|2022-03-01|2023-03-01|3 |28 |[31, 349, 365, 17, 28] |[28, 2] |2 |
|2023-03-01|2024-03-01|4 |365 |[31, 349, 365, 17, 28, 365]|[0, 3] |3 |
+----------+----------+-----------------------+--------------+---------------------------+--------+--------+
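Since the end goal is to keep only the first row in each group, a minimal way to finish, assuming the group_id column computed above (first_rows_window and rn are just helper names introduced here), is to rank rows within each group by item_date and keep rank 1:
# keep only the earliest row of each group and drop the helper columns
first_rows_window = Window.partitionBy('group_id').orderBy('item_date')
sample_df = (sample_df
             .withColumn('rn', F.row_number().over(first_rows_window))
             .filter(F.col('rn') == 1)
             .drop('rn', 'datedifference', 'datedifferencelist', 'cumsum'))
sample_df.show(truncate=False)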