Tags: dataframe, filter, pyspark, conditional-statements

PySpark: filter a DataFrame using the min value for each id


Given a table like the following:

+--+------------------+-----------+
|id|     diagnosis_age|  diagnosis|
+--+------------------+-----------+
| 1|2.1843037179180302| 315.320000|
| 1|  2.80033330216659| 315.320000|
| 1|   2.8222365762732| 315.320000|
| 1|  5.64822705794013| 325.320000|
| 1| 5.686557787521759| 335.320000|
| 2|  5.70572315231258| 315.320000|
| 2| 5.724888517103389| 315.320000|
| 3| 5.744053881894209| 315.320000|
| 3|5.7604813374292005| 315.320000|
| 3|  5.77993740687426| 315.320000|
+--+------------------+-----------+

I'm trying to reduce the number of records per id by keeping only the diagnoses with the lowest diagnosis age for each id. In SQL you would join the table to itself, something like:

SELECT a.id, a.diagnosis_age, a.diagnosis
FROM tbl1 a
INNER JOIN (
    SELECT id, MIN(diagnosis_age) AS min_diagnosis_age
    FROM tbl1
    GROUP BY id
) b
    ON b.id = a.id
WHERE b.min_diagnosis_age = a.diagnosis_age

If it were an RDD, you could do something like:

rdd.map(lambda x: (x["id"], [(x["diagnosis_age"], x["diagnosis"])]))\
.reduceByKey(lambda x, y: x + y)\
.map(lambda x: (x[0], [i for i in x[1] if i[0] == min(x[1])[0]]))
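
To make the intended result concrete, here is a rough plain-Python sketch of what that RDD pipeline computes, using a few rows from the table above. The helper name keep_min_age and the tuple layout are made up for illustration only; this is not Spark code.

```python
from collections import defaultdict

# A few rows from the table above as (id, diagnosis_age, diagnosis) tuples.
rows = [
    (1, 2.1843037179180302, 315.32),
    (1, 2.80033330216659, 315.32),
    (1, 5.64822705794013, 325.32),
    (2, 5.70572315231258, 315.32),
    (2, 5.724888517103389, 315.32),
    (3, 5.744053881894209, 315.32),
    (3, 5.7604813374292005, 315.32),
]


def keep_min_age(rows):
    """Group (age, diagnosis) pairs by id and keep only the pairs with the minimum age."""
    grouped = defaultdict(list)
    for id_, age, diagnosis in rows:
        grouped[id_].append((age, diagnosis))

    result = {}
    for id_, entries in grouped.items():
        min_age = min(age for age, _ in entries)
        # Keep every entry that shares the minimum age (ties are preserved).
        result[id_] = [entry for entry in entries if entry[0] == min_age]
    return result


min_rows = keep_min_age(rows)
print(min_rows)
```

Each id ends up with only the entries whose age equals the smallest age seen for that id, which is the result I would like to get from DataFrame operations.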

How would you achieve the same using only Spark DataFrame operations, if that is possible? Specifically, no SQL or RDD operations.

thanks


Solution

  • You can use a window with the first function, and then filter out all other rows.

    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    # Within each id, rows are ordered by age, so the first value is the minimum.
    w = Window.partitionBy("id").orderBy("diagnosis_age")
    df.withColumn("least_age", F.first("diagnosis_age").over(w))\
      .filter("diagnosis_age = least_age").drop("least_age").show()
    
    +---+------------------+---------+
    | id|     diagnosis_age|diagnosis|
    +---+------------------+---------+
    |  1|2.1843037179180302|   315.32|
    |  3| 5.744053881894209|   315.32|
    |  2|  5.70572315231258|   315.32|
    +---+------------------+---------+
    

    You can also do this without a window function, using groupBy with min and first:

    from pyspark.sql import functions as F

    df.orderBy("diagnosis_age").groupBy("id")\
      .agg(F.min("diagnosis_age").alias("diagnosis_age"), F.first("diagnosis").alias("diagnosis"))\
      .show()

    +---+------------------+---------+
    | id|     diagnosis_age|diagnosis|
    +---+------------------+---------+
    |  1|2.1843037179180302|   315.32|
    |  3| 5.744053881894209|   315.32|
    |  2|  5.70572315231258|   315.32|
    +---+------------------+---------+
    

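    Note that I am ordering by diagnosis_age before the groupBy to handle cases where the required diagnosis value does not appear in the first row of the group. If your data is already ordered by diagnosis_age, you can use the code above without the orderBy. Be aware, though, that F.first is documented as non-deterministic: its result depends on row order, which may change after a shuffle, so the preceding orderBy is not guaranteed to carry through the groupBy. Also, unlike the window version, this returns exactly one row per id even when several rows share the minimum age.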