Pyspark Group By Date Range

I have a sample PySpark dataframe that can be created like this:

sample_df = spark.createDataFrame([
    ('2020-01-01', '2021-01-01', 1),
    ('2020-02-01', '2021-02-01', 1),
    ('2021-01-15', '2022-01-15', 2),
    ('2022-01-15', '2023-01-15', 2),
    ('2022-02-01', '2023-02-01', 3),
    ('2022-03-01', '2023-03-01', 3),
    ('2023-03-01', '2024-03-01', 4),
  ], ['item_date', 'max_window', 'expected_grouping_index'])

After sorting by item_date, the first item starts a grouping. Any following item whose item_date is less than or equal to the first item's max_window is given the same grouping_index. (max_window is always item_date plus the same number of days for the entire df, about 365 days in this example.)

If an item does not fall inside the grouping, it starts a new grouping and is given another arbitrary grouping_index. All following items are then assessed against that item's max_window, and so on.

The grouping_index is just a means to an end; I eventually want to keep only the first row in each group.
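
For reference, the intended logic can be sketched in plain Python. This is only a sketch of the semantics described above, not the Spark solution being asked for; the function name assign_groups and the (item_date, max_window) tuple layout are illustrative assumptions.

```python
from datetime import date

# Rows as (item_date, max_window), already sorted by item_date.
rows = [
    (date(2020, 1, 1), date(2021, 1, 1)),
    (date(2020, 2, 1), date(2021, 2, 1)),
    (date(2021, 1, 15), date(2022, 1, 15)),
    (date(2022, 1, 15), date(2023, 1, 15)),
    (date(2022, 2, 1), date(2023, 2, 1)),
    (date(2022, 3, 1), date(2023, 3, 1)),
    (date(2023, 3, 1), date(2024, 3, 1)),
]


def assign_groups(sorted_rows):
    """Return one grouping index per row, starting at 1."""
    groups = []
    current_limit = None  # max_window of the row that opened the current group
    index = 0
    for item_date, max_window in sorted_rows:
        # A row past the current limit opens a new group and sets a new limit.
        if current_limit is None or item_date > current_limit:
            index += 1
            current_limit = max_window
        groups.append(index)
    return groups


groups = assign_groups(rows)

# Keep only the row that opened each group.
first_rows = [
    row for i, row in enumerate(rows)
    if i == 0 or groups[i] != groups[i - 1]
]
```

Here groups lines up with the expected_grouping_index column, and first_rows holds the rows I ultimately want to keep.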

How can I achieve this without a UDF or converting to a pandas df?


  • I used an earlier answer as inspiration to write the code below. The key is an aggregate call whose accumulator is a two-element array holding the days elapsed in the current group and the group index, which lets the grouping index be computed row by row.

    from pyspark.sql import SparkSession, Window
    import pyspark.sql.functions as F

    spark = SparkSession.builder.master("local").getOrCreate()

    sample_df = spark.createDataFrame([
        ('2020-01-01', '2021-01-01', 1),
        ('2020-02-01', '2021-02-01', 1),
        ('2021-01-15', '2022-01-15', 2),
        ('2022-01-15', '2023-01-15', 2),
        ('2022-02-01', '2023-02-01', 3),
        ('2022-03-01', '2023-03-01', 3),
        ('2023-03-01', '2024-03-01', 4),
    ], ['item_date', 'max_window', 'expected_grouping_index'])
    sample_df = sample_df.withColumn("item_date", F.col("item_date").cast("date"))

    # Days elapsed since the previous row (NULL for the first row).
    # Note: a window without partitionBy moves all rows into a single partition.
    windowSpec = Window.orderBy("item_date")
    sample_df = sample_df.withColumn(
        "datedifference",
        F.datediff(F.col("item_date"), F.lag(F.col("item_date"), 1).over(windowSpec)))

    # Running list of all differences up to the current row (collect_list skips NULLs).
    w = windowSpec.rowsBetween(Window.unboundedPreceding, Window.currentRow)
    sample_df = sample_df.withColumn('datedifferencelist', F.collect_list('datedifference').over(w))

    # expr = "AGGREGATE(datedifferencelist, (0,0), (acc, el) -> IF(acc[0] + el < 365, (acc[0] + el, acc[1]),  (el, acc[1]+1) ))"
    # sample_df = sample_df.withColumn('cumsum', F.expr(expr))

    # Fold the list with an accumulator [days_in_current_group, group_id]:
    # keep adding while the total stays within 365 days, otherwise reset the
    # day count and start a new group. F.aggregate requires Spark 3.1+.
    initial_value = F.array(F.lit(0), F.lit(0))
    sample_df = sample_df.withColumn('cumsum', F.aggregate(
        "datedifferencelist", initial_value,
        lambda acc, x: F.when(acc.getItem(0) + x <= 365,
                              F.array(acc.getItem(0) + x, acc.getItem(1)))
                        .otherwise(F.array(F.lit(0), acc.getItem(1) + 1))))
    sample_df = sample_df.withColumn('group_id', F.col('cumsum').getItem(1))
    sample_df.show(truncate=False)

    Output:

    |item_date |max_window|expected_grouping_index|datedifference|datedifferencelist         |cumsum  |group_id|
    |2020-01-01|2021-01-01|1                      |NULL          |[]                         |[0, 0]  |0       |
    |2020-02-01|2021-02-01|1                      |31            |[31]                       |[31, 0] |0       |
    |2021-01-15|2022-01-15|2                      |349           |[31, 349]                  |[0, 1]  |1       |
    |2022-01-15|2023-01-15|2                      |365           |[31, 349, 365]             |[365, 1]|1       |
    |2022-02-01|2023-02-01|3                      |17            |[31, 349, 365, 17]         |[0, 2]  |2       |
    |2022-03-01|2023-03-01|3                      |28            |[31, 349, 365, 17, 28]     |[28, 2] |2       |
    |2023-03-01|2024-03-01|4                      |365           |[31, 349, 365, 17, 28, 365]|[0, 3]  |3       |
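
To see why the accumulator produces these values, the same fold can be traced in plain Python with functools.reduce. This is a minimal sketch outside Spark, using the datedifference values from the output above; the names step, fold and DAY_LIMIT are illustrative and not part of any Spark API.

```python
from functools import reduce

DAY_LIMIT = 365  # same threshold as in the Spark expression


def step(acc, diff):
    """One fold step: acc is (days_in_current_group, group_id)."""
    days_in_group, group_id = acc
    if days_in_group + diff <= DAY_LIMIT:
        # Still inside the window opened by the group's first row.
        return (days_in_group + diff, group_id)
    # Outside the window: this row opens a new group, so the day count resets.
    return (0, group_id + 1)


def fold(diffs):
    """Fold a list of day differences, starting from (0, 0)."""
    return reduce(step, diffs, (0, 0))


# datedifference values for rows 2..7 (the first row has NULL, which collect_list skips).
date_diffs = [31, 349, 365, 17, 28, 365]

# Row n sees the first n differences, like the running collect_list window.
cumsum = [fold(date_diffs[:n]) for n in range(len(date_diffs) + 1)]
group_ids = [group_id for _, group_id in cumsum]

for row in cumsum:
    print(row)
```

Because the day count resets to 0 whenever a new group opens, the first element of the accumulator is always the number of days since the first row of the current group, which is what gets compared against the 365-day limit. Note that this uses a fixed day count rather than the max_window column, so the two can differ by a day when the window spans a leap day.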