Tags: dataframe, apache-spark, pyspark, count, conditional-statements

PySpark count rows on condition


I have a DataFrame:

test = spark.createDataFrame([('bn', 12452, 221), ('mb', 14521, 330), ('bn', 2, 220), ('mb', 14520, 331)], ['x', 'y', 'z'])
test.show()
# +---+-----+---+
# |  x|    y|  z|
# +---+-----+---+
# | bn|12452|221|
# | mb|14521|330|
# | bn|    2|220|
# | mb|14520|331|
# +---+-----+---+

I need to count the rows based on a condition:

from pyspark.sql.functions import col, count
test.groupBy("x").agg(count(col("y") > 12453), count(col("z") > 230)).show()

which gives

+---+------------------+----------------+
|  x|count((y > 12453))|count((z > 230))|
+---+------------------+----------------+
| bn|                 2|               2|
| mb|                 2|               2|
+---+------------------+----------------+

That is just the total row count per group, not the number of rows that satisfy each condition.


Solution

  • count doesn't sum True values; it only counts the number of non-null values. To count the rows where a condition holds, convert the condition to 1 / 0 with when/otherwise and then sum:

    import pyspark.sql.functions as F
    
    cnt_cond = lambda cond: F.sum(F.when(cond, 1).otherwise(0))
    test.groupBy('x').agg(
        cnt_cond(F.col('y') > 12453).alias('y_cnt'), 
        cnt_cond(F.col('z') > 230).alias('z_cnt')
    ).show()
    # +---+-----+-----+
    # |  x|y_cnt|z_cnt|
    # +---+-----+-----+
    # | bn|    0|    0|
    # | mb|    2|    2|
    # +---+-----+-----+
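
  • An equivalent variant of the same idea (a sketch, assuming the same test DataFrame): because count only counts non-null values, a when without an otherwise leaves non-matching rows as null, so they are simply skipped:

    import pyspark.sql.functions as F

    # when(cond, 1) yields NULL when cond is false; count() ignores NULLs,
    # so only the rows matching the condition are counted
    cnt_cond2 = lambda cond: F.count(F.when(cond, 1))
    test.groupBy('x').agg(
        cnt_cond2(F.col('y') > 12453).alias('y_cnt'),
        cnt_cond2(F.col('z') > 230).alias('z_cnt')
    ).show()
    # gives the same counts as the sum-based version above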