Tags: apache-spark, pyspark, apache-spark-sql, null, aggregate

Return null in SUM if some values are null


I have a case where the column that needs to be summed per group may contain null values.

If a group contains a null, I want the sum for that group to be null. But by default PySpark seems to ignore the null rows and sum up the remaining non-null values.

For example:

[image: sample input DataFrame]

from pyspark.sql import functions as f

dataframe = dataframe.groupBy('product') \
                     .agg(f.sum('price'))
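
A minimal reproducible setup along these lines (the exact input was shown only as an image, so the values here are an assumption chosen to be consistent with the expected output below):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Assumed sample data: product C has one null price,
# so its group sum should come out as null.
dataframe = spark.createDataFrame(
    [('A', 100), ('A', 150), ('B', 200), ('C', 300), ('C', None)],
    ['product', 'price'],
)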

Expected output is:

[image: expected output]

But I am getting:

[image: actual output]


Solution

  • The sum function returns NULL only when all values in the group are null; otherwise nulls are simply ignored.

    You can use conditional aggregation instead: if count(price) == count(*), the group contains no nulls and sum(price) is returned; otherwise the when condition fails and, with no otherwise clause, the result defaults to NULL:

    from pyspark.sql import functions as F
    
    df.groupby("product").agg(
        F.when(F.count("price") == F.count("*"), F.sum("price")).alias("sum_price")
    ).show()
    
    #+-------+---------+
    #|product|sum_price|
    #+-------+---------+
    #|      B|      200|
    #|      C|     null|
    #|      A|      250|
    #+-------+---------+
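
    The same condition can also be written as a single SQL expression passed through F.expr (a sketch equivalent to the snippet above):

    df.groupby("product").agg(
        # conditional aggregation expressed directly in SQL
        F.expr("CASE WHEN count(price) = count(*) THEN sum(price) END").alias("sum_price")
    ).show()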
    

    Since Spark 3.0, one can also use the any aggregate function:

    df.groupby("product").agg(
        F.when(~F.expr("any(price is null)"), F.sum("price")).alias("sum_price")
    ).show()
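
    This produces the same result as the conditional count approach above: whenever a group contains at least one null price, any(price is null) is true, the negated condition fails, and the sum for that group is NULL.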