
Perform Lag over multiple columns using PySpark


I'm fairly new to PySpark, but I'm trying to follow best practices in my code. I have a PySpark DataFrame and I would like to lag multiple columns, replacing the original values with the lagged values. Example:

ID     date        value1     value2     value3
1      2021-12-23  1.1        4.0        2.2
2      2021-12-21  2.4        1.6        11.9
1      2021-12-24  5.4        3.2        7.8
2      2021-12-22  4.2        1.4        9.0
1      2021-12-26  2.3        5.2        7.6
.
.
.

I'd like to take all values according to ID, order them by date, then lag the values by some amount. The code I have so far:

from pyspark.sql import functions as F, Window

window = Window.partitionBy(F.col("ID")).orderBy(F.col("date"))

valueColumns = ['value1', 'value2', 'value3']

df = F.lag(valueColumns, offset=shiftAmount).over(window)

My desired output would be:

ID     date        value1     value2     value3
1      2021-12-23  Null       Null       Null
2      2021-12-21  Null       Null       Null
1      2021-12-24  1.1        4.0        2.2
2      2021-12-22  2.4        1.6        11.9
1      2021-12-26  5.4        3.2        7.8
.
.
.

The problem I'm having is that, from what I can find, F.lag only accepts a single column. I'm looking for suggestions on how to best accomplish this. I suppose I could use a for loop to just append shifted columns or something, but this seems pretty inelegant. Thanks!


Solution

  • A simple list comprehension on column names should do the job:

    df = df.select(
        "ID", "date",
        *[F.lag(c, offset=shiftAmount).over(window).alias(c) for c in valueColumns]
    )
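
For context, here is a minimal end-to-end sketch of the above, assuming shiftAmount = 1 and the sample data from the question (column and variable names are taken from the question; nothing else is implied about the real dataset):

    from pyspark.sql import SparkSession, functions as F, Window

    spark = SparkSession.builder.getOrCreate()

    df = spark.createDataFrame(
        [
            (1, "2021-12-23", 1.1, 4.0, 2.2),
            (2, "2021-12-21", 2.4, 1.6, 11.9),
            (1, "2021-12-24", 5.4, 3.2, 7.8),
            (2, "2021-12-22", 4.2, 1.4, 9.0),
            (1, "2021-12-26", 2.3, 5.2, 7.6),
        ],
        ["ID", "date", "value1", "value2", "value3"],
    )

    shiftAmount = 1
    valueColumns = ["value1", "value2", "value3"]
    window = Window.partitionBy("ID").orderBy("date")

    # Replace each value column with its lagged version; the first row of each
    # ID partition becomes null because there is no earlier row to lag from.
    df = df.select(
        "ID", "date",
        *[F.lag(c, offset=shiftAmount).over(window).alias(c) for c in valueColumns]
    )
    df.show()

Because each lagged column is aliased back to its original name, the result keeps the same schema as the input, which avoids looping over repeated withColumn calls.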