Tags: python, dataframe, pyspark, apache-spark-sql, conditional-statements

PySpark 3.3.0 - Aggregate sum with condition to avoid self join


Given the following dataframe structure:

+----------+-----+-------+
|  endPoint|count|outcome|
+----------+-----+-------+
|  getBooks|    3|success|
|  getBooks|    1|failure|
|getClasses|    0|success|
|getClasses|    4|failure|
+----------+-----+-------+

I'm trying to aggregate the data to get a failure rate per endpoint. The resulting dataframe should look like this:

+----------+-----------+
|  endPoint|failureRate|
+----------+-----------+
|  getBooks|       0.25|
|getClasses|          1|
+----------+-----------+

I can currently do this by creating a second dataframe that keeps only the failure rows (filtering out the successes), joining the two dataframes back together, and adding a new column that divides the sum of the failed count for each endpoint by the sum of the total count.

I'm trying to find a way to avoid creating a separate dataframe and then joining it back, as that seems expensive and unnecessary. Is there a way to sum a column conditionally? I've been playing around with the syntax but am getting stuck.

If I could do something like this:

df.groupBy("endPoint").sum("count").when(outcome = "failure")

that would be ideal, but I can't get it to work and wonder if I'm missing something fundamental here.


Solution

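  • You can use a when() inside the sum aggregate: when() returns the count for matching rows and null for all others, and sum() ignores nulls. The example below uses a dataframe named data_sdf whose grouping column is end_point (endPoint in the question).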

    from pyspark.sql import functions as func

    # when() without otherwise() yields null for non-matching rows; sum() skips nulls
    data_sdf. \
        groupBy('end_point'). \
        agg(func.sum(func.when(func.col('outcome') == 'failure', func.col('count'))).alias('failure_count'),
            func.sum(func.when(func.col('outcome') == 'success', func.col('count'))).alias('success_count')
            ). \
        withColumn('failure_rate',
                   func.col('failure_count') / (func.col('failure_count') + func.col('success_count'))
                   ). \
        show()
    
    # +----------+-------------+-------------+------------+
    # | end_point|failure_count|success_count|failure_rate|
    # +----------+-------------+-------------+------------+
    # |getClasses|            4|            0|         1.0|
    # |  getBooks|            1|            3|        0.25|
    # +----------+-------------+-------------+------------+