Tags: python, pyspark, group-by, window-functions

PySpark Rolling Sum based on ID, timestamp and condition


I have the following PySpark DataFrame:

id  timestamp   col1
1   2022-01-01   0
1   2022-01-02   1
1   2022-01-03   1
1   2022-01-04   0
2   2022-01-01   1
2   2022-01-02   0
2   2022-01-03   1

I would like to add a column with the cumulative sum of col1 for each id, ordered by timestamp, to obtain something like this:

id  timestamp   col1  cum_sum
1   2022-01-01   0      0
1   2022-01-02   1      1
1   2022-01-03   1      2
1   2022-01-04   0      2
2   2022-01-01   1      1
2   2022-01-02   0      1
2   2022-01-03   1      2

A window function can probably work here, but I am not sure how to count only the rows where col1 is equal to 1.


Solution

  • You indeed need a window function and a sum; the orderBy on the window is what makes it 'rolling'. Since col1 only contains 0s and 1s, summing it directly is the same as counting the rows where it equals 1.

    import pyspark.sql.functions as F
    from pyspark.sql import Window
    
    # Running (cumulative) sum of col1 within each id, ordered by timestamp
    w = Window.partitionBy('id').orderBy('timestamp')
    df.withColumn('cum_sum', F.sum('col1').over(w)).show()
    
    +---+----------+----+-------+
    | id| timestamp|col1|cum_sum|
    +---+----------+----+-------+
    |  1|2022-01-01|   0|      0|
    |  1|2022-01-02|   1|      1|
    |  1|2022-01-03|   1|      2|
    |  1|2022-01-04|   0|      2|
    |  2|2022-01-01|   1|      1|
    |  2|2022-01-02|   0|      1|
    |  2|2022-01-03|   1|      2|
    +---+----------+----+-------+
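
  • If col1 could contain values other than 0 and 1, summing it directly would no longer count only the rows where it equals 1. In that case you can wrap the column in a condition before summing. A minimal sketch, reusing the same window w from above:

    # Add 1 to the running total only for rows where col1 equals 1, ignoring other values
    df.withColumn(
        'cum_sum',
        F.sum(F.when(F.col('col1') == 1, 1).otherwise(0)).over(w)
    ).show()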