
PySpark: Filtering a Lag for Date Differences


I have a table of field values and dates stored as a PySpark DataFrame. What is the most sensible way in PySpark to add a column containing the date difference between each row and the most recent earlier row with the same field value? For a simple difference of consecutive rows the answer is obviously a "lag" function, but I am unsure of the cleanest way to restrict that lag to rows that share the field value.

Example input:

1111| 23/May/2024
2222| 20/May/2024
3333| 19/May/2024
1111| 16/May/2024
4444| 12/May/2024
1111| 07/May/2024
2222| 01/May/2024

Desired Output:

1111| 23/May/2024| 7
2222| 20/May/2024| 19
3333| 19/May/2024| default
1111| 16/May/2024| 9
4444| 12/May/2024| default
1111| 07/May/2024| default
2222| 01/May/2024| default

Solution

  • You can achieve the desired result with the lag function over a window partitioned by the field value, as below.

    from pyspark.sql import SparkSession
    from pyspark.sql.window import Window
    from pyspark.sql.functions import lag, datediff, when, col, lit
    
    # Reuse an existing session or create one
    spark = SparkSession.builder.getOrCreate()
    
    data = [
        (1111, "2024-05-23"),
        (2222, "2024-05-20"),
        (3333, "2024-05-19"),
        (1111, "2024-05-16"),
        (4444, "2024-05-12"),
        (1111, "2024-05-07"),
        (2222, "2024-05-01"),
    ]
    # Create DataFrame; dates are ISO yyyy-MM-dd strings, which sort chronologically
    df = spark.createDataFrame(data, ["id", "start"])
    df.show()
    
    +----+----------+
    |  id|     start|
    +----+----------+
    |1111|2024-05-23|
    |2222|2024-05-20|
    |3333|2024-05-19|
    |1111|2024-05-16|
    |4444|2024-05-12|
    |1111|2024-05-07|
    |2222|2024-05-01|
    +----+----------+
    
    # Partition by id so lag only looks at rows with the same field value,
    # ordered chronologically by start date
    df = (
        df.withColumn(
            "previous_start", lag("start").over(Window.partitionBy("id").orderBy("start"))
        )
        .withColumn(
            "datediff",
            # Days since the previous row for this id; "default" where none exists
            when(
                col("previous_start").isNotNull(),
                datediff(col("start"), col("previous_start")),
            ).otherwise(lit("default")),
        )
        .drop("previous_start")
        .orderBy(col("start").desc())
    )
    df.show()
    
    +----+----------+--------+
    |  id|     start|datediff|
    +----+----------+--------+
    |1111|2024-05-23|       7|
    |2222|2024-05-20|      19|
    |3333|2024-05-19| default|
    |1111|2024-05-16|       9|
    |4444|2024-05-12| default|
    |1111|2024-05-07| default|
    |2222|2024-05-01| default|
    +----+----------+--------+
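
  • If the dates arrive as strings in the dd/MMM/yyyy form shown in the question (e.g. 23/May/2024), you would first parse them with to_date so that the window ordering and datediff operate on real dates rather than on strings. A minimal sketch, reusing the imports and spark session from above (the column names start_str and days_since_prev are illustrative); note that it keeps the difference as a nullable integer, whereas otherwise(lit("default")) forces the whole column to string type.

    from pyspark.sql.functions import to_date
    
    raw = spark.createDataFrame(
        [(1111, "23/May/2024"), (1111, "16/May/2024"), (2222, "20/May/2024")],
        ["id", "start_str"],
    )
    
    # Spark 3.x datetime pattern "dd/MMM/yyyy" parses strings like "23/May/2024"
    parsed = raw.withColumn("start", to_date(col("start_str"), "dd/MMM/yyyy"))
    
    # Same windowing idea as above; the first row per id gets null rather than
    # "default", so the day count remains an integer column
    w = Window.partitionBy("id").orderBy("start")
    parsed.withColumn(
        "days_since_prev", datediff(col("start"), lag("start").over(w))
    ).show()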