Tags: dataframe, pyspark, apache-spark-sql

Aggregate (sum) consecutive rows where the number of consecutive rows is defined in a dataframe column


Every "id" shares the same "range" value, and I have to execute the following aggregation:

  • group on column "id" over a dynamic range of consecutive rows (given by column "range")
  • sum column "amount" over that range of consecutive rows
  • order by "id", "row_num" ascending

Initial DataFrame:
id  row_num  range  amount
a   1        2      2
a   2        2      10
a   3        2      5
a   4        2      5
a   5        2      1
b   1        3      10
b   2        3      2
b   3        3      10
b   4        3      6
b   5        3      10
b   6        3      4

Target result:

id  row_num  range  amount  sum
a   1        2      2       12
a   2        2      10      15
a   3        2      5       10
a   4        2      5       6
a   5        2      1       1
b   1        3      10      22
b   2        3      2       18
b   3        3      10      26
b   4        3      6       20
b   5        3      10      14
b   6        3      4       4
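
In other words, each row's "sum" is the total of the next "range" amounts (starting at the current row), clipped at the end of that id's partition. A plain-Python sketch of that rule against the sample data (the helper names here are illustrative, not part of the question):

```python
# Forward-looking slice-and-sum per id:
# sum = amounts[row_num - 1 : row_num - 1 + range]
data = [
    ('a', 1, 2, 2), ('a', 2, 2, 10), ('a', 3, 2, 5), ('a', 4, 2, 5),
    ('a', 5, 2, 1), ('b', 1, 3, 10), ('b', 2, 3, 2), ('b', 3, 3, 10),
    ('b', 4, 3, 6), ('b', 5, 3, 10), ('b', 6, 3, 4),
]

# Collect amounts per id, preserving row_num order.
amounts_by_id = {}
for id_, row_num, range_, amount in data:
    amounts_by_id.setdefault(id_, []).append(amount)

# For each row, sum the slice starting at its position, "range" elements long;
# Python slicing silently clips at the end of the list, matching the target.
sums = [
    sum(amounts_by_id[id_][row_num - 1 : row_num - 1 + range_])
    for id_, row_num, range_, _ in data
]
print(sums)  # [12, 15, 10, 6, 1, 22, 18, 26, 20, 14, 4]
```

This reproduces the "sum" column above and is what the windowed solution needs to express in Spark.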

I've tried Window.rowsBetween, but its frame bounds must be literal integers, so I cannot plug in the "range" column value dynamically.

I don't want to use "case...when" or similar hardcoded syntax, as the sample DataFrame is a simplified version of the real one.

I'm pretty sure it can be solved with "advanced" windowing / partitioning.

Any tips?


Solution

  • One possible approach is to collect the amounts per id, slice each collected array based on ('row_num', 'range'), and then sum the slice:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    from pyspark.sql import types as T
    from pyspark.sql import Window
    
    spark = SparkSession.builder.getOrCreate()
    
    data = [
        ('a', 1, 2, 2),
        ('a', 2, 2, 10),
        ('a', 3, 2, 5),
        ('a', 4, 2, 5),
        ('a', 5, 2, 1),
        ('b', 1, 3, 10),
        ('b', 2, 3, 2),
        ('b', 3, 3, 10),
        ('b', 4, 3, 6),
        ('b', 5, 3, 10),
        ('b', 6, 3, 4)
    ]
    
    schema = T.StructType([
        T.StructField('id', T.StringType(), True),
        T.StructField('row_num', T.IntegerType(), True),
        T.StructField('range', T.IntegerType(), True),
        T.StructField('amount', T.IntegerType(), True)
    ])
    
    df = spark.createDataFrame(data, schema)
    
    # Collect every "amount" of the partition into one array per row,
    # ordered by "row_num".
    w = (
        Window
        .partitionBy('id')
        .orderBy('row_num')
        .rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    )
    amounts = F.collect_list('amount').over(w)
    
    # Take "range" elements starting at the current row's position;
    # slice is 1-based, which matches "row_num". Column objects are used
    # explicitly since older Spark versions don't accept column names here.
    row_amounts = F.slice(amounts, start=F.col('row_num'), length=F.col('range'))
    
    # Fold the sliced array into a single sum (F.aggregate requires Spark 3.1+).
    sum_amounts = F.aggregate(
        col=row_amounts,
        initialValue=F.lit(0),
        merge=lambda x, y: x + y
    )
    
    df = df.withColumn('sum', sum_amounts)
    df.show(20, False)
    
    # +---+-------+-----+------+---+
    # |id |row_num|range|amount|sum|
    # +---+-------+-----+------+---+
    # |a  |1      |2    |2     |12 |
    # |a  |2      |2    |10    |15 |
    # |a  |3      |2    |5     |10 |
    # |a  |4      |2    |5     |6  |
    # |a  |5      |2    |1     |1  |
    # |b  |1      |3    |10    |22 |
    # |b  |2      |3    |2     |18 |
    # |b  |3      |3    |10    |26 |
    # |b  |4      |3    |6     |20 |
    # |b  |5      |3    |10    |14 |
    # |b  |6      |3    |4     |4  |
    # +---+-------+-----+------+---+
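
For intuition: F.aggregate folds the sliced array with the given merge function, much like functools.reduce in plain Python. A quick illustration of the same fold applied to the slice for id 'b', row_num 1 (values taken from the sample data):

```python
from functools import reduce

# Plain-Python analogue of
# F.aggregate(col, initialValue=F.lit(0), merge=lambda x, y: x + y)
# applied to the 3-element slice for id 'b', row_num 1.
slice_b1 = [10, 2, 10]
total = reduce(lambda x, y: x + y, slice_b1, 0)
print(total)  # 22
```

This is why the merge lambda only needs to know how to combine two values; the fold over the array produces the per-row "sum".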