
Efficient way to do cumulative sum on multiple columns in PySpark


I have a table that looks like:

+----+------+-----+-------+
|time|val1  |val2 |  class|
+----+------+-----+-------+
|   1|    3 |    2|      b|
|   2|    3 |    1|      b|
|   1|    2 |    4|      a|
|   2|    2 |    5|      a|
|   3|    1 |    5|      a|
+----+------+-----+-------+

Now I want to compute a cumulative sum over the val1 and val2 columns. So I define a window specification:

from pyspark.sql import Window
from pyspark.sql import functions as F

windowval = (Window.partitionBy('class').orderBy('time')
             .rangeBetween(Window.unboundedPreceding, 0))

new_df = (my_df.withColumn('cum_sum1', F.sum('val1').over(windowval))
               .withColumn('cum_sum2', F.sum('val2').over(windowval)))

But I think Spark will apply the window function twice on the original table, which seems inefficient. Since the problem is pretty straightforward, is there a way to apply the window function only once and compute the cumulative sum on both columns together?


Solution

  • But I think Spark will apply the window function twice on the original table, which seems inefficient.

    Your assumption is incorrect. It is enough to take a look at the optimized logical plan:

    == Optimized Logical Plan ==
    Window [sum(val1#1L) windowspecdefinition(class#3, time#0L ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS cum_sum1#9L, sum(val2#2L) windowspecdefinition(class#3, time#0L ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS cum_sum2#16L], [class#3], [time#0L ASC NULLS FIRST]
    +- LogicalRDD [time#0L, val1#1L, val2#2L, class#3], false
    

    or the physical plan:

    == Physical Plan ==
    Window [sum(val1#1L) windowspecdefinition(class#3, time#0L ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS cum_sum1#9L, sum(val2#2L) windowspecdefinition(class#3, time#0L ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS cum_sum2#16L], [class#3], [time#0L ASC NULLS FIRST]
    +- *(1) Sort [class#3 ASC NULLS FIRST, time#0L ASC NULLS FIRST], false, 0
       +- Exchange hashpartitioning(class#3, 200)
          +- Scan ExistingRDD[time#0L,val1#1L,val2#2L,class#3]
    

    Both plans clearly indicate that the Window is applied only once.
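
    You can reproduce both plans yourself with DataFrame.explain. A minimal sketch, assuming a small DataFrame built to match the table in the question:

    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()

    # Sample data matching the question's table
    my_df = spark.createDataFrame(
        [(1, 3, 2, "b"), (2, 3, 1, "b"),
         (1, 2, 4, "a"), (2, 2, 5, "a"), (3, 1, 5, "a")],
        ["time", "val1", "val2", "class"],
    )

    windowval = (Window.partitionBy("class").orderBy("time")
                 .rangeBetween(Window.unboundedPreceding, 0))

    new_df = (my_df
              .withColumn("cum_sum1", F.sum("val1").over(windowval))
              .withColumn("cum_sum2", F.sum("val2").over(windowval)))

    # Prints the parsed, analyzed, optimized logical, and physical plans;
    # the Window operator appears exactly once.
    new_df.explain(extended=True)

    Because both columns share the same window specification, the optimizer collapses them into a single Window operator, so the partitioning, sorting, and frame evaluation happen in one pass.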