I'm fairly new to PySpark, but I am trying to use best practices in my code. I have a PySpark dataframe and I would like to lag multiple columns, replacing the original values with the lagged values. Example:
ID  date        value1  value2  value3
1   2021-12-23  1.1     4.0     2.2
2   2021-12-21  2.4     1.6     11.9
1   2021-12-24  5.4     3.2     7.8
2   2021-12-22  4.2     1.4     9.0
1   2021-12-26  2.3     5.2     7.6
...
I'd like to take all values according to ID, order them by date, and then lag the values by some amount. The code I have so far:
from pyspark.sql import functions as F, Window
window = Window.partitionBy(F.col("ID")).orderBy(F.col("date"))
valueColumns = ['value1', 'value2', 'value3']
df = F.lag(valueColumns, offset=shiftAmount).over(window)
My desired output would be:
ID  date        value1  value2  value3
1   2021-12-23  Null    Null    Null
2   2021-12-21  Null    Null    Null
1   2021-12-24  1.1     4.0     2.2
2   2021-12-22  2.4     1.6     11.9
1   2021-12-26  5.4     3.2     7.8
...
The problem I'm having is that, from what I can find, F.lag only accepts a single column. I'm looking for suggestions on how best to accomplish this. I suppose I could use a for loop to just append shifted columns or something, but this seems pretty inelegant. Thanks!
A simple list comprehension on column names should do the job:
df = df.select(
    "ID", "date",
    *[F.lag(c, offset=shiftAmount).over(window).alias(c) for c in valueColumns]
)
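For completeness, here is a self-contained sketch of how this plays out on the sample data from the question, assuming a shift of 1; the variable names spark, shiftAmount, and lagged are illustrative, not part of your code:

from pyspark.sql import SparkSession, functions as F, Window

spark = SparkSession.builder.getOrCreate()

# Sample data from the question
df = spark.createDataFrame(
    [
        (1, "2021-12-23", 1.1, 4.0, 2.2),
        (2, "2021-12-21", 2.4, 1.6, 11.9),
        (1, "2021-12-24", 5.4, 3.2, 7.8),
        (2, "2021-12-22", 4.2, 1.4, 9.0),
        (1, "2021-12-26", 2.3, 5.2, 7.6),
    ],
    ["ID", "date", "value1", "value2", "value3"],
)

shiftAmount = 1  # assumed lag of one row
valueColumns = ["value1", "value2", "value3"]
window = Window.partitionBy("ID").orderBy("date")

# Replace every value column with its lagged version in a single projection
lagged = df.select(
    "ID", "date",
    *[F.lag(c, offset=shiftAmount).over(window).alias(c) for c in valueColumns]
)
lagged.show()

Doing this in one select also keeps the query plan to a single projection, whereas calling withColumn inside a for loop adds a projection per column and makes the plan grow with the number of columns.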