I have a dataframe with two columns (speed and distance) and a list of speed thresholds. For each threshold, I want to sum the distance travelled at or above that speed. In reality I am applying this to a large dataset, and I have not been able to come up with a quick solution.
I have included an example dataframe below. The expected outcome for that dataframe is shown next, ideally returned as a dataframe with a threshold column and a distance-covered column:
>= 7: 267, >= 8: 240, >= 9: 220, >= 10: 188, >= 11: 181, >= 12: 173
The solution should generalise well to a large dataframe.
data = [(7,4), (8,5), (9,6), (10,7), (11,8), (7,9), (9,10), (14,11), (4,12), (16,13), (7,14), (8,15), (9,16), (2,17), (12,18), (14,19), (16,20), (24,21), (25,22), (6,23), (27,24), (28,25)]
columns = ["speed", "distance"]
thresholds = [7, 8, 9, 10, 11, 12]
df = spark.createDataFrame(data = data, schema = columns)
You can do this with conditional (when/otherwise) column expressions, building one aggregate per threshold via a list comprehension and a single global aggregation.
I would prefer this for large datasets over, say, cross-joining the dataframe with the thresholds; a rough sketch of that alternative is included at the end for comparison.
Solution:
import pyspark.sql.functions as F

def sum_if_greater_than_by(sum_col, by_col, threshold):
    # Sum `sum_col` over the rows where `by_col` is at or above the threshold.
    cond_expr = F.when(F.col(by_col) >= threshold, F.col(sum_col)).otherwise(F.lit(0))
    return F.sum(cond_expr)

df_agg = df.groupBy().agg(*[sum_if_greater_than_by('distance', 'speed', threshold).alias(f">={threshold}") for threshold in thresholds])
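A note on the otherwise(F.lit(0)): it makes a threshold that no rows reach sum to 0 rather than null. Dropping the otherwise clause would also work, since sum ignores nulls, but a threshold with no matching rows would then come back as null.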
Result:
The aggregation returns a single-row dataframe with one column per threshold:
df_agg.show()
+---+---+---+----+----+----+
|>=7|>=8|>=9|>=10|>=11|>=12|
+---+---+---+----+----+----+
|267|240|220| 188| 181| 173|
+---+---+---+----+----+----+
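If you want the result in the (threshold, distance covered) layout described in the question, you can unpivot the single-row dataframe. The sketch below uses the SQL stack() function; the column names are just one possible choice.

# Unpivot the one-row wide result into one row per threshold.
stack_expr = ", ".join(f"{t}, `>={t}`" for t in thresholds)
df_long = df_agg.selectExpr(
    f"stack({len(thresholds)}, {stack_expr}) as (threshold, distance)"
)
df_long.show()

For comparison, the cross-join alternative mentioned above might look roughly like the following sketch. It gives the same totals (already in long format), but it logically pairs every row with every threshold before aggregating, which is why the conditional-sum approach tends to scale better on large dataframes.

# Sketch of the cross-join alternative, for comparison only.
thresholds_df = spark.createDataFrame([(t,) for t in thresholds], ["threshold"])
df_cross = (
    df.crossJoin(thresholds_df)
      .where(F.col("speed") >= F.col("threshold"))
      .groupBy("threshold")
      .agg(F.sum("distance").alias("distance"))
      .orderBy("threshold")
)
df_cross.show()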