I have a table of field values and dates stored as a PySpark DataFrame. What is the most sensible way in PySpark to add an additional column that contains the date difference between each row and the most recent preceding row with the same field value? For a simple difference between consecutive rows, the obvious answer is to apply a lag function, but I am unsure of the most sensible way to restrict the comparison to rows with the same field value.
Example input:
1111| 23/May/2024
2222| 20/May/2024
3333| 19/May/2024
1111| 16/May/2024
4444| 12/May/2024
1111| 07/May/2024
2222| 01/May/2024
Desired Output:
1111| 23/May/2024| 7
2222| 20/May/2024| 19
3333| 19/May/2024| default
1111| 16/May/2024| 9
4444| 12/May/2024| default
1111| 07/May/2024| default
2222| 01/May/2024| default
You can achieve the desired result by applying the lag function over a window partitioned by the field value and ordered by date, as shown below.
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, datediff, when, col, lit

# Reuse the active SparkSession (already available as `spark` in a shell/notebook)
spark = SparkSession.builder.getOrCreate()

data = [
    (1111, "2024-05-23"),
    (2222, "2024-05-20"),
    (3333, "2024-05-19"),
    (1111, "2024-05-16"),
    (4444, "2024-05-12"),
    (1111, "2024-05-07"),
    (2222, "2024-05-01"),
]

# Create DataFrame
df = spark.createDataFrame(data, ["id", "start"])
df.show()
+----+----------+
| id| start|
+----+----------+
|1111|2024-05-23|
|2222|2024-05-20|
|3333|2024-05-19|
|1111|2024-05-16|
|4444|2024-05-12|
|1111|2024-05-07|
|2222|2024-05-01|
+----+----------+
df = (
    # For each id, take the date of the previous (next-oldest) row in that group
    df.withColumn(
        "previous_start", lag("start").over(Window.partitionBy("id").orderBy("start"))
    )
    # Difference in days to that previous row, or "default" when there is none
    .withColumn(
        "datediff",
        when(
            col("previous_start").isNotNull(),
            datediff(col("start"), col("previous_start")),
        ).otherwise(lit("default")),
    )
    .drop("previous_start")
    .orderBy(col("start").desc())
)
df.show()
+----+----------+--------+
| id| start|datediff|
+----+----------+--------+
|1111|2024-05-23| 7|
|2222|2024-05-20| 19|
|3333|2024-05-19| default|
|1111|2024-05-16| 9|
|4444|2024-05-12| default|
|1111|2024-05-07| default|
|2222|2024-05-01| default|
+----+----------+--------+
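Note that this example stores the dates as ISO-formatted strings, which happen to sort and subtract correctly. If your column instead holds strings in the dd/MMM/yyyy format from the question, one option (a minimal sketch, assuming the same column name "start" as above) is to parse them into a proper date with to_date before applying the window logic:

from pyspark.sql.functions import to_date, col

# Parse strings like "23/May/2024" into a DateType column
df = df.withColumn("start", to_date(col("start"), "dd/MMM/yyyy"))

Also be aware that mixing the integer day difference with the string "default" turns the datediff column into a string column; if you would rather keep it numeric, drop the when/otherwise and leave the value as null for the first row of each id.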