Tags: apache-spark, pyspark, dynamic, window-functions, lag

Dynamically update a Spark dataframe column using lag and window functions


I would like to generate the dataframe below:

[image: expected dataframe with columns id, dt, col_lag, adstock]

Here, I am calculating the "adstock" based on the column "col_lag" and an engagement factor of 0.9, as below:

from pyspark.sql import Window
from pyspark.sql.functions import col, lag, lit

# window specification: rows of each id, ordered by week
windowSpec = Window.partitionBy("id").orderBy("dt")

# create the column if it does not exist
if 'adstock' not in df.columns:
    df = df.withColumn("adstock", lit(0))

df = df.withColumn("adstock", col('col_lag') + lit(0.9) * lag("adstock", 1).over(windowSpec))
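
For clarity, the recurrence I am after is adstock(t) = col_lag(t) + 0.9 * adstock(t-1), with the first row's adstock equal to its col_lag. A small plain-Python illustration of the values I expect for one id with col_lag values 1, 2, 4, 0 (illustration only, not the Spark code I want):

# Illustration only: the row-to-row recurrence the column should follow.
col_lag = [1, 2, 4, 0]  # values for one id, ordered by dt

adstock = []
for i, v in enumerate(col_lag):
    prev = adstock[i - 1] if i > 0 else 0.0
    adstock.append(v + 0.9 * prev)

print(adstock)  # [1.0, 2.9, 6.61, 5.949] up to floating-point rounding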

When I run the above, the code somehow stops generating values after two or three rows and gives something like this:

[image: output where adstock becomes null after the first few rows]

I have around 125,000 ids and weekly data from 2020-01-24 to the current week. I have tried various approaches, such as rowsBetween(Window.unboundedPreceding, 1) or creating another column, but have not been successful.

I would appreciate any suggestions in this regard.


Solution

  • Spark does not do calculations going from row to row, so it cannot access the result of the previous row within the current calculation. To work around this, you can move all the values for the same id into one row and build the calculation logic from there. The higher-order function aggregate allows a kind of loop with access to the previous value.
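
    Conceptually, aggregate is a fold over an array. Here is a rough pure-Python analogue of the Spark script below, just to show the shape of the computation (the names rows and step are for illustration; functools.reduce plays the role of aggregate, and the seed tuple plays the role of the dummy struct):

    from functools import reduce

    rows = [('2020-10-07', 1), ('2020-10-14', 2), ('2020-10-21', 4), ('2020-10-28', 0)]

    def step(acc, row):
        dt, col_lag = row
        prev_adstock = acc[-1][2]  # adstock of the previous row
        return acc + [(dt, col_lag, col_lag + 0.9 * prev_adstock)]

    result = reduce(step, rows, [(None, 0, 0.0)])  # seed row, dropped afterwards
    print(result[1:])
    # [('2020-10-07', 1, 1.0), ('2020-10-14', 2, 2.9),
    #  ('2020-10-21', 4, 6.61), ('2020-10-28', 0, 5.949)]  (floats shown rounded)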

    Input:

    from pyspark.sql import functions as F
    df = spark.createDataFrame(
        [(1, '2020-10-07', 1),
         (1, '2020-10-14', 2),
         (1, '2020-10-21', 4),
         (1, '2020-10-28', 0),
         (2, '2021-09-08', 1),
         (2, '2021-09-15', 2),
         (2, '2021-09-22', 0),
         (2, '2021-09-29', 0)],
        ['id', 'dt', 'col_lag'])
    

    Script:

    df = df.groupBy("id").agg(
        F.aggregate(
            # all (dt, col_lag) rows of the id, collected and sorted chronologically
            F.array_sort(F.collect_list(F.struct("dt", "col_lag"))),
            # seed: a dummy row with adstock = 0.0, so element_at(acc, -1) is always defined
            F.expr("array(struct(string(null) dt, 0L col_lag, 0D adstock))"),
            # fold step: append the current row, computing its adstock from the adstock
            # of the last row already accumulated; array_union also deduplicates, which
            # is harmless here because every struct carries a distinct dt
            lambda acc, x: F.array_union(
                acc,
                F.array(x.withField(
                    'adstock',
                    x["col_lag"] + F.lit(0.9) * F.element_at(acc, -1)['adstock']
                ))
            )
        ).alias("a")
    )
    # drop the dummy seed row (slice starts at position 2) and explode the array back into rows
    df = df.selectExpr("id", "inline(slice(a, 2, size(a)))")

    df.show()
    # +---+----------+-------+------------------+
    # | id|        dt|col_lag|           adstock|
    # +---+----------+-------+------------------+
    # |  1|2020-10-07|      1|               1.0|
    # |  1|2020-10-14|      2|               2.9|
    # |  1|2020-10-21|      4| 6.609999999999999|
    # |  1|2020-10-28|      0|             5.949|
    # |  2|2021-09-08|      1|               1.0|
    # |  2|2021-09-15|      2|               2.9|
    # |  2|2021-09-22|      0|              2.61|
    # |  2|2021-09-29|      0|2.3489999999999998|
    # +---+----------+-------+------------------+
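
    A note on scale: this collects one array per id, so each id's full history sits in a single row during the aggregation. With weekly data since 2020-01-24 that is only a couple of hundred elements per id, which should be unproblematic even for ~125,000 ids.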
    

    A thorough explanation of the script is provided in this answer.