pyspark

Filter data using multiple thresholds across single column summing other column


I have a dataframe with two columns (speed and distance) and a list of speed thresholds. For each speed threshold, I want to sum the distance travelled at or above that threshold. In reality I have a large dataset to apply this to, and I have been unable to come up with a quick solution.

I have included an example dataframe below. The expected outcome for that dataframe is listed below, ideally returned as a dataframe with a threshold column and a distance-covered column.

The solution should generalise well to a large dataframe.

  • >= 7 = 267, >= 8 = 240, >= 9 = 220, >= 10 = 188, >= 11 = 181, >= 12 = 173

data = [(7,4), (8,5), (9,6), (10,7), (11,8), (7,9), (9,10), (14,11), (4,12), (16,13), (7,14), (8,15), (9,16), (2,17), (12,18), (14,19), (16,20), (24,21), (25,22), (6,23), (27,24), (28,25)]
columns = ["speed", "distance"]
thresholds = [7, 8, 9, 10, 11, 12]
df = spark.createDataFrame(data = data, schema = columns)
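As a sanity check on the expected numbers, the same sums can be computed in plain Python over the sample data, with no Spark needed:

```python
# Sanity check of the expected totals using plain Python over the sample data.
data = [(7, 4), (8, 5), (9, 6), (10, 7), (11, 8), (7, 9), (9, 10), (14, 11),
        (4, 12), (16, 13), (7, 14), (8, 15), (9, 16), (2, 17), (12, 18),
        (14, 19), (16, 20), (24, 21), (25, 22), (6, 23), (27, 24), (28, 25)]
thresholds = [7, 8, 9, 10, 11, 12]

# For each threshold, sum distance over rows whose speed is at or above it.
totals = {t: sum(d for s, d in data if s >= t) for t in thresholds}
print(totals)  # {7: 267, 8: 240, 9: 220, 10: 188, 11: 181, 12: 173}
```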

Solution

  • You can do this in a single aggregation, building one case-when column expression per threshold with a list comprehension.

    I would prefer this for large datasets over, say, cross-joining the dataframe with the thresholds.

    Solution:

    import pyspark.sql.functions as F
    
    def sum_if_greater_than_by(sum_col, by_col, threshold):
        # Sum `sum_col` over rows where `by_col` is at or above `threshold`.
        cond_expr = F.when(F.col(by_col) >= threshold, F.col(sum_col)).otherwise(F.lit(0))
        return F.sum(cond_expr)
    
    df_agg = df.groupBy().agg(
        *[
            sum_if_greater_than_by("distance", "speed", threshold).alias(f">={threshold}")
            for threshold in thresholds
        ]
    )
    
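    For comparison, the cross-join approach mentioned above could be sketched as follows (a self-contained sketch repeating the question's setup; the row count is multiplied by the number of thresholds before the aggregation, which is why the case-when aggregation is preferred for large inputs):

    ```python
    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").appName("xjoin").getOrCreate()

    data = [(7, 4), (8, 5), (9, 6), (10, 7), (11, 8), (7, 9), (9, 10), (14, 11),
            (4, 12), (16, 13), (7, 14), (8, 15), (9, 16), (2, 17), (12, 18),
            (14, 19), (16, 20), (24, 21), (25, 22), (6, 23), (27, 24), (28, 25)]
    df = spark.createDataFrame(data, ["speed", "distance"])
    thresholds_df = spark.createDataFrame([(t,) for t in [7, 8, 9, 10, 11, 12]],
                                          ["threshold"])

    # Pair every row with every threshold, keep rows at or above the
    # threshold, then sum distance per threshold.
    df_long = (
        df.crossJoin(thresholds_df)
          .where(F.col("speed") >= F.col("threshold"))
          .groupBy("threshold")
          .agg(F.sum("distance").alias("distance"))
          .orderBy("threshold")
    )
    df_long.show()
    ```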

    Result:

    The result is a single-row, wide dataframe:

    df_agg.show()
    
    +---+---+---+----+----+----+
    |>=7|>=8|>=9|>=10|>=11|>=12|
    +---+---+---+----+----+----+
    |267|240|220| 188| 181| 173|
    +---+---+---+----+----+----+
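    If the threshold/distance layout from the question is wanted instead of this wide single row, the aggregate can be unpivoted with the SQL `stack` generator. A self-contained sketch (repeating the setup and the answer's aggregation, with the case-when logic inlined):

    ```python
    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").appName("unpivot").getOrCreate()

    data = [(7, 4), (8, 5), (9, 6), (10, 7), (11, 8), (7, 9), (9, 10), (14, 11),
            (4, 12), (16, 13), (7, 14), (8, 15), (9, 16), (2, 17), (12, 18),
            (14, 19), (16, 20), (24, 21), (25, 22), (6, 23), (27, 24), (28, 25)]
    thresholds = [7, 8, 9, 10, 11, 12]
    df = spark.createDataFrame(data, ["speed", "distance"])

    # Same case-when aggregation as the answer, one column per threshold.
    df_agg = df.groupBy().agg(
        *[
            F.sum(F.when(F.col("speed") >= t, F.col("distance")).otherwise(0)).alias(f">={t}")
            for t in thresholds
        ]
    )

    # stack(6, 7, `>=7`, 8, `>=8`, ...) emits one (threshold, distance)
    # row per (literal, column) pair.
    stack_expr = "stack({n}, {pairs}) as (threshold, distance)".format(
        n=len(thresholds),
        pairs=", ".join(f"{t}, `>={t}`" for t in thresholds),
    )
    df_result = df_agg.selectExpr(stack_expr)
    df_result.show()
    ```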