Tags: apache-spark, pyspark

Group by and sum only values larger than another column?


If I have, for example, a table as shown below:

ID  Threshold  Value
1   2          1
1   2          2
1   2          3
1   2          4
2   4          1
2   4          3
2   4          5

How could I use Spark to obtain the following?

ID  Threshold  total_above_threshold
1   2          7
2   4          5

I have thought of a workaround where I create an additional column that flags values at or below the threshold so they can be excluded, and then aggregate only the remaining values (sketched below). But does Spark provide a better way (e.g. a window function?) that doesn't require creating additional columns?
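
For illustration, the workaround might look roughly like this (a sketch only; it assumes the table is already loaded as a DataFrame named df with columns ID, Threshold and Value, and the column/variable names here are just placeholders):

    from pyspark.sql import functions as F
    
    # Keep the value only when it is above the threshold (null otherwise)
    flagged_df = df.withColumn(
        "above_threshold",
        F.when(F.col("Value") > F.col("Threshold"), F.col("Value"))
    )
    
    # F.sum ignores nulls, so only the flagged values are aggregated
    workaround_df = (
        flagged_df
        .groupBy("ID", "Threshold")
        .agg(F.sum("above_threshold").alias("total_above_threshold"))
    )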


Solution

  • No window function is needed; you can simply filter on the threshold and then use groupBy, as follows:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, sum
    
    spark = SparkSession.builder.getOrCreate()
    
    # Rows from the example table in the question: (ID, Threshold, Value)
    data = [
        (1, 2, 1),
        (1, 2, 2),
        (1, 2, 3),
        (1, 2, 4),
        (2, 4, 1),
        (2, 4, 3),
        (2, 4, 5)
    ]
    
    df = spark.createDataFrame(data, ["ID", "Threshold", "Value"])
    
    # Keep only rows where Value exceeds Threshold, then sum the remaining values per (ID, Threshold) group
    result_df = (
        df
        .filter(col("Value") > col("Threshold"))
        .groupBy("ID", "Threshold")
        .agg(sum(col("Value")).alias("total_above_threshold"))
    )
    
    result_df.show()
    
    # +---+---------+---------------------+
    # | ID|Threshold|total_above_threshold|
    # +---+---------+---------------------+
    # |  1|        2|                    7|
    # |  2|        4|                    5|
    # +---+---------+---------------------+
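
  • As a side note: the filter above drops any (ID, Threshold) group in which no value exceeds the threshold. If you need to keep such groups, a conditional sum preserves every group and returns a null total for them. This is a sketch, not part of the original answer; the variable name all_groups_df is just illustrative:

    from pyspark.sql.functions import when
    
    all_groups_df = (
        df
        .groupBy("ID", "Threshold")
        .agg(
            # Sum only the values above the threshold; groups with none get a null total
            sum(when(col("Value") > col("Threshold"), col("Value"))).alias("total_above_threshold")
        )
    )
    
    all_groups_df.show()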