I have a case where I may have null values in the column that needs to be summed up in a group.
If I encounter a null in a group, I want the sum of that group to be null. But PySpark by default seems to ignore the null rows and sum up the remaining non-null values.
For example:
from pyspark.sql import functions as f
# Group by product only; grouping by price as well would put each price in its own group
dataframe = dataframe.groupBy('product') \
.agg(f.sum('price'))
Expected output is:
But I am getting:
The sum function returns NULL only if every value in the column is null; otherwise, nulls are simply ignored.
You can use conditional aggregation: if count(price) == count(*), the group contains no nulls and sum(price) is returned. Otherwise null is returned, because a when without an otherwise clause yields null:
from pyspark.sql import functions as F
df.groupby("product").agg(
F.when(F.count("price") == F.count("*"), F.sum("price")).alias("sum_price")
).show()
#+-------+---------+
#|product|sum_price|
#+-------+---------+
#| B| 200|
#| C| null|
#| A| 250|
#+-------+---------+
In Spark 3.0+, you can also use the any function:
df.groupby("product").agg(
F.when(~F.expr("any(price is null)"), F.sum("price")).alias("sum_price")
).show()