I have the following dataframe:
test = spark.createDataFrame([('bn', 12452, 221), ('mb', 14521, 330), ('bn', 2, 220), ('mb', 14520, 331)], ['x', 'y', 'z'])
test.show()
# +---+-----+---+
# | x| y| z|
# +---+-----+---+
# | bn|12452|221|
# | mb|14521|330|
# | bn| 2|220|
# | mb|14520|331|
# +---+-----+---+
I need to count rows that satisfy a condition:
from pyspark.sql.functions import col, count
test.groupBy("x").agg(count(col("y") > 12453), count(col("z") > 230)).show()
which gives:
+---+------------------+----------------+
| x|count((y > 12453))|count((z > 230))|
+---+------------------+----------------+
| bn| 2| 2|
| mb| 2| 2|
+---+------------------+----------------+
That's just the total row count per group, not the number of rows that meet each condition.
count doesn't sum True values; it only counts the number of non-null values. A comparison like col('y') > 12453 evaluates to True or False (it is null only when y itself is null), so here every row is non-null and gets counted. To count the True values, you need to convert the condition to 1/0 and then sum:
import pyspark.sql.functions as F
# 1 for rows where the condition holds, 0 otherwise; summing gives the conditional count
cnt_cond = lambda cond: F.sum(F.when(cond, 1).otherwise(0))
test.groupBy('x').agg(
    cnt_cond(F.col('y') > 12453).alias('y_cnt'),
    cnt_cond(F.col('z') > 230).alias('z_cnt')
).show()
+---+-----+-----+
| x|y_cnt|z_cnt|
+---+-----+-----+
| bn| 0| 0|
| mb| 2| 2|
+---+-----+-----+
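Since count ignores nulls, an equivalent trick (a minimal sketch against the same test dataframe) is to drop the otherwise branch: when(cond, True) yields null for non-matching rows, so count only sees the matches:

import pyspark.sql.functions as F
# when without otherwise returns null where the condition is false,
# and count skips nulls, so only matching rows are counted
cnt_cond = lambda cond: F.count(F.when(cond, True))
test.groupBy('x').agg(
    cnt_cond(F.col('y') > 12453).alias('y_cnt'),
    cnt_cond(F.col('z') > 230).alias('z_cnt')
).show()

which produces the same result. On Spark 3.0+ the built-in SQL aggregate count_if expresses this directly, e.g. F.expr('count_if(y > 12453)').alias('y_cnt').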