Tags: apache-spark, pyspark, apache-spark-sql, count, distinct

Count distinct values with conditions


I have a dataframe as below:

+-----------+------------+-------------+-----------+
| id_doctor | id_patient | consumption | type_drug |
+-----------+------------+-------------+-----------+
| d1        | p1         |        12.0 | bhd       |
| d1        | p2         |        10.0 | lsd       |
| d1        | p1         |         6.0 | bhd       |
| d1        | p1         |        14.0 | carboxyl  |
| d2        | p1         |        12.0 | bhd       |
| d2        | p1         |        13.0 | bhd       |
| d2        | p2         |        12.0 | lsd       |
| d2        | p1         |         6.0 | bhd       |
| d2        | p2         |        12.0 | bhd       |
+-----------+------------+-------------+-----------+

For each doctor, I want to count the distinct patients that take "bhd" with a consumption < 16.0.

I tried the following query, but it doesn't work:

dataframe.groupBy(col("id_doctor")).agg(
    countDistinct(col("id_patient")).where(
        col("type_drug") == "bhd" & col("consumption") < 16.0
    )
)

Solution

  • The query above fails because countDistinct returns a Column, which has no .where method, and because & binds tighter than == and <, so each condition needs its own parentheses. Just use where on your dataframe instead. Note that this version drops any id_doctor whose count is 0:

    from pyspark.sql.functions import col, countDistinct

    # filter first, then aggregate
    dataframe.where(
        (col("type_drug") == "bhd") & (col("consumption") < 16.0)
    ).groupBy(
        col("id_doctor")
    ).agg(
        countDistinct(col("id_patient"))
    )
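
    Against the sample data above, d1 has only p1 under the condition while d2 has both p1 and p2, so a .show() should print something like this (Spark's auto-generated column name may differ across versions):

    +---------+--------------------------+
    |id_doctor|count(DISTINCT id_patient)|
    +---------+--------------------------+
    |       d1|                         1|
    |       d2|                         2|
    +---------+--------------------------+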
    

    Using the syntax below, you can keep all the doctors, even those with no row matching the condition (they get a count of 0):

    from pyspark.sql import functions as F

    dataframe.withColumn(
        "fg",
        # id_patient when the condition holds, NULL otherwise
        F.when(
            (F.col("type_drug") == "bhd")
            & (F.col("consumption") < 16.0),
            F.col("id_patient")
        )
    ).groupBy(
        F.col("id_doctor")
    ).agg(
        # countDistinct ignores NULLs, so doctors without a
        # matching row are kept with a count of 0
        F.countDistinct(F.col("fg"))
    )