Tags: apache-spark, pyspark, apache-spark-sql

Cumulative product in a PySpark DataFrame


I have the following Spark DataFrame:

+---+---+
|  a|  b|
+---+---+     
|  1|  1|  
|  1|  2|  
|  1|  3|
|  1|  4|
+---+---+  

I want to add another column named "c" that contains the cumulative product of "b" within each group of "a". The resulting DataFrame should look like:

+---+---+---+
|  a|  b|  c|
+---+---+---+     
|  1|  1|  1|
|  1|  2|  2|
|  1|  3|  6|
|  1|  4| 24|
+---+---+---+  

How can this be done?


Solution

  • You have to set an ordering column; in this case I used column 'b'. Since the product should restart for every value of 'a', the window is also partitioned by 'a':

    from functools import reduce
    from operator import mul
    
    from pyspark.sql import functions as F, Window, types
    
    df = spark.createDataFrame([(1, 1), (1, 2), (1, 3), (1, 4), (1, 5)], ['a', 'b'])
    
    order_column = 'b'
    
    # partition by 'a' so each group gets its own running product,
    # and order by 'b' so the window grows one row at a time
    window = Window.partitionBy('a').orderBy(order_column)
    
    # multiply together all the values of 'b' collected so far
    mul_udf = F.udf(lambda x: reduce(mul, x), types.IntegerType())
    
    df = df.withColumn('c', mul_udf(F.collect_list('b').over(window)))
    
    df.show()
    
    +---+---+---+
    |  a|  b|  c|
    +---+---+---+
    |  1|  1|  1|
    |  1|  2|  2|
    |  1|  3|  6|
    |  1|  4| 24|
    |  1|  5|120|
    +---+---+---+
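
  • If you would rather avoid the Python UDF, a common alternative (only a sketch, and it assumes every value of 'b' is strictly positive, since log is undefined otherwise) is to sum the logarithms of 'b' over the same window and exponentiate the result. On Spark 3.2+ the built-in F.product aggregate should also work over the window, e.g. F.product('b').over(window), though I have not tested that here.

    window = Window.partitionBy('a').orderBy('b')
    
    # cumulative product via exp(sum(log(b))); assumes b > 0
    # round and cast because the result comes back as a float
    df = df.withColumn('c', F.round(F.exp(F.sum(F.log('b')).over(window))).cast('long'))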