
Pyspark Group By Date Range


I have a sample PySpark DataFrame that can be created like this:

sample_df = spark.createDataFrame([
    ('2020-01-01', '2021-01-01', 1),
    ('2020-02-01', '2021-02-01', 1),
    ('2021-01-15', '2022-01-15', 2),
    ('2022-01-15', '2023-01-15', 2),
    ('2022-02-01', '2023-02-01', 3),
    ('2022-03-01', '2023-03-01', 3),
    ('2023-03-01', '2024-03-01', 4),
  ], ['item_date', 'max_window', 'expected_grouping_index'])

After sorting by item_date, the first item starts a grouping. Any following item whose item_date is less than or equal to the first item's max_window gets the same grouping_index. (max_window is always item_date plus the same number of days for the entire df, about 365 days in this example.)

If an item does not fall inside the current grouping, it starts a new grouping with another arbitrary grouping_index, and all following items are then assessed against that item's max_window. And so on.

The grouping_index is just a means to an end: I eventually want to keep only the first row in each group.

How can I achieve this without a UDF or converting to a pandas df?
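
For reference, the rule described above can be pinned down in plain Python. This is only a sketch of the intended behaviour, not the Spark solution being asked for; the function name assign_groups is hypothetical, and it assumes the dates are ISO-formatted strings as in the sample data.

```python
from datetime import date

def assign_groups(rows):
    """Return one grouping index per row, in item_date order.

    rows: list of (item_date, max_window) ISO date strings.
    """
    ordered = sorted(rows, key=lambda r: r[0])  # ISO strings sort chronologically
    groups = []
    current_limit = None  # max_window of the row that opened the current group
    index = 0
    for item_date, max_window in ordered:
        item = date.fromisoformat(item_date)
        if current_limit is None or item > current_limit:
            # Outside the current window: this row opens a new group.
            index += 1
            current_limit = date.fromisoformat(max_window)
        groups.append(index)
    return groups

rows = [
    ('2020-01-01', '2021-01-01'),
    ('2020-02-01', '2021-02-01'),
    ('2021-01-15', '2022-01-15'),
    ('2022-01-15', '2023-01-15'),
    ('2022-02-01', '2023-02-01'),
    ('2022-03-01', '2023-03-01'),
    ('2023-03-01', '2024-03-01'),
]
print(assign_groups(rows))  # [1, 1, 2, 2, 3, 3, 4]
```

The result matches the expected_grouping_index column of the sample data.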


Solution

  • The code below is inspired by the following answer. The key idea is a custom accumulator function passed to aggregate: it carries a running day count together with a group counter, so the grouping index can be computed with built-in functions only.

    https://stackoverflow.com/a/64957835/3238085

    from pyspark.sql import SparkSession, Window
    from pyspark.sql.functions import aggregate, col, datediff, lag
    import pyspark.sql.functions as F
    
    # aggregate() with a Python lambda requires Spark 3.1 or later.
    spark = SparkSession.builder.master('local').getOrCreate()
    
    sample_df = spark.createDataFrame([
        ('2020-01-01', '2021-01-01', 1),
        ('2020-02-01', '2021-02-01', 1),
        ('2021-01-15', '2022-01-15', 2),
        ('2022-01-15', '2023-01-15', 2),
        ('2022-02-01', '2023-02-01', 3),
        ('2022-03-01', '2023-03-01', 3),
        ('2023-03-01', '2024-03-01', 4),
    ], ['item_date', 'max_window', 'expected_grouping_index'])
    
    sample_df.show(100, truncate=False)
    
    sample_df = sample_df.withColumn("item_date", col("item_date").cast("date"))
    
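    # A window with no partitionBy pulls every row into a single partition.
    windowSpec = Window.orderBy("item_date")
    
    # Days elapsed since the previous row (NULL for the first row).
    sample_df = sample_df.withColumn("datedifference",
                                     datediff(col("item_date"), lag(col("item_date"), 1).over(windowSpec)))
    
    # Running list of all day differences up to the current row; collect_list skips the NULL.
    w = (windowSpec.rowsBetween(Window.unboundedPreceding, Window.currentRow))
    sample_df = sample_df.withColumn('datedifferencelist', F.collect_list('datedifference').over(w))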
    
    # expr = "AGGREGATE(datedifferencelist, (0,0), (acc, el) -> IF(acc[0] + el < 365, (acc[0] + el, acc[1]),  (el, acc[1]+1) ))"
    # sample_df = sample_df.withColumn('cumsum', F.expr(expr))
    
    # The accumulator is [days since the current group's first row, group counter].
    initial_value = F.array(F.lit(0), F.lit(0))
    
    # Stay in the group while the running total is within 365 days; otherwise reset and start a new group.
    sample_df = sample_df.withColumn('cumsum', aggregate(
        "datedifferencelist", initial_value,
        lambda acc, x: F.when((acc.getItem(0) + x) <= 365,
                              F.array(acc.getItem(0) + x, acc.getItem(1)))
                        .otherwise(F.array(F.lit(0), acc.getItem(1) + 1))))
    
    sample_df = sample_df.withColumn('group_id', F.col('cumsum').getItem(1))
    
    sample_df.show(truncate=False)
    

    Output:

    +----------+----------+-----------------------+--------------+---------------------------+--------+--------+
    |item_date |max_window|expected_grouping_index|datedifference|datedifferencelist         |cumsum  |group_id|
    +----------+----------+-----------------------+--------------+---------------------------+--------+--------+
    |2020-01-01|2021-01-01|1                      |NULL          |[]                         |[0, 0]  |0       |
    |2020-02-01|2021-02-01|1                      |31            |[31]                       |[31, 0] |0       |
    |2021-01-15|2022-01-15|2                      |349           |[31, 349]                  |[0, 1]  |1       |
    |2022-01-15|2023-01-15|2                      |365           |[31, 349, 365]             |[365, 1]|1       |
    |2022-02-01|2023-02-01|3                      |17            |[31, 349, 365, 17]         |[0, 2]  |2       |
    |2022-03-01|2023-03-01|3                      |28            |[31, 349, 365, 17, 28]     |[28, 2] |2       |
    |2023-03-01|2024-03-01|4                      |365           |[31, 349, 365, 17, 28, 365]|[0, 3]  |3       |
    +----------+----------+-----------------------+--------------+---------------------------+--------+--------+
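
To see why the accumulator produces these values, the same fold can be reproduced in plain Python with functools.reduce. This is a minimal sketch for illustration only: the names step and WINDOW_DAYS are hypothetical, and the day differences are copied from the datedifference column above. The last lines also sketch one way to pick out the first row of each group, which was the stated end goal.

```python
from functools import reduce

WINDOW_DAYS = 365  # same threshold as in the Spark lambda above

def step(acc, diff):
    """Mirror of the Spark lambda: acc is (days since group start, group id)."""
    elapsed, group_id = acc
    if elapsed + diff <= WINDOW_DAYS:
        return (elapsed + diff, group_id)  # still inside the current group
    return (0, group_id + 1)               # reset the counter, start a new group

# Non-null values of the datedifference column, in item_date order.
diffs = [31, 349, 365, 17, 28, 365]

# Row n folds over the first n differences, like collect_list over the running window.
cumsum = [reduce(step, diffs[:n], (0, 0)) for n in range(len(diffs) + 1)]
group_ids = [group_id for _, group_id in cumsum]

# A row is the first of its group when its group id differs from the previous row's.
first_row_positions = [
    i for i, g in enumerate(group_ids) if i == 0 or g != group_ids[i - 1]
]

print(cumsum)
print(group_ids)            # [0, 0, 1, 1, 2, 2, 3]
print(first_row_positions)  # [0, 2, 4, 6]
```

Note that this rule measures a fixed 365 days from the first row of each group, whereas max_window in the sample data spans 366 days when the interval crosses a leap day, so the two can differ for an item that falls exactly on that boundary.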