Given a table like the following:
+--+------------------+-----------+
|id| diagnosis_age| diagnosis|
+--+------------------+-----------+
| 1|2.1843037179180302| 315.320000|
| 1| 2.80033330216659| 315.320000|
| 1| 2.8222365762732| 315.320000|
| 1| 5.64822705794013| 325.320000|
| 1| 5.686557787521759| 335.320000|
| 2| 5.70572315231258| 315.320000|
| 2| 5.724888517103389| 315.320000|
| 3| 5.744053881894209| 315.320000|
| 3|5.7604813374292005| 315.320000|
| 3| 5.77993740687426| 315.320000|
+--+------------------+-----------+
I'm trying to reduce the number of records per id by keeping only the diagnoses with the least diagnosis_age for each id. In SQL you would join the table to itself, something like:
SELECT a.id, a.diagnosis_age, a.diagnosis
FROM tbl1 a
INNER JOIN
(SELECT id, MIN(diagnosis_age) AS min_diagnosis_age
FROM tbl1
GROUP BY id) b
ON b.id = a.id
WHERE b.min_diagnosis_age = a.diagnosis_age
If it were an RDD you could do something like:
rdd.map(lambda x: (x["id"], [(x["diagnosis_age"], x["diagnosis"])]))\
.reduceByKey(lambda x, y: x + y)\
.map(lambda x: (x[0], [i for i in x[1] if i[0] == min(x[1])[0]]))
How would you achieve the same using only Spark DataFrame operations, if this is possible? Specifically, no SQL or RDD operations.
Thanks.
You can use a window with the first function, and then filter out all the other rows.
from pyspark.sql import functions as F
from pyspark.sql.window import Window
w = Window.partitionBy("id").orderBy("diagnosis_age")
df.withColumn("least_age", F.first("diagnosis_age").over(w))\
.filter("diagnosis_age=least_age").drop("least_age").show()
+---+------------------+---------+
| id| diagnosis_age|diagnosis|
+---+------------------+---------+
| 1|2.1843037179180302| 315.32|
| 3| 5.744053881894209| 315.32|
| 2| 5.70572315231258| 315.32|
+---+------------------+---------+
You can also do this without a window function, using groupBy with min and first:
from pyspark.sql import functions as F
df.orderBy("diagnosis_age").groupBy("id")\
.agg(F.min("diagnosis_age").alias("diagnosis_age"), F.first("diagnosis").alias("diagnosis"))\
.show()
+---+------------------+---------+
| id| diagnosis_age|diagnosis|
+---+------------------+---------+
| 1|2.1843037179180302| 315.32|
| 3| 5.744053881894209| 315.32|
| 2| 5.70572315231258| 315.32|
+---+------------------+---------+
Note that I am ordering by diagnosis_age before the groupBy to handle those cases where your required diagnosis value does not appear in the first row of the group. However, if your data is already ordered by diagnosis_age, you can use the above code without the orderBy.