
Using max/min on columns with null values


I have the following dataframe df -

+---+-------------------+----+
| id|               date|site|
+---+-------------------+----+
|100|2020-03-24 00:00:00|   a|
|100|2019-08-30 00:00:00|   a|
|100|2020-03-24 00:00:00|   b|
|101|2019-12-20 00:00:00|NULL|
|101|2019-12-20 00:00:00|   a|
|102|2019-04-14 00:00:00|NULL| 
|103|2019-09-28 00:00:00|   c|
+---+-------------------+----+

where date is TimestampType and site is a string.
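
For reference, the dataframe can be recreated with something like this (the rows are transcribed from the table above; the SparkSession setup is an assumption) -

from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, TimestampType, StringType

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        (100, datetime(2020, 3, 24), 'a'),
        (100, datetime(2019, 8, 30), 'a'),
        (100, datetime(2020, 3, 24), 'b'),
        (101, datetime(2019, 12, 20), None),
        (101, datetime(2019, 12, 20), 'a'),
        (102, datetime(2019, 4, 14), None),
        (103, datetime(2019, 9, 28), 'c'),
    ],
    StructType([
        StructField('id', IntegerType()),
        StructField('date', TimestampType()),
        StructField('site', StringType()),
    ]),
)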

My goal is to remove duplicate rows with the same id, with the following logic: for each id keep only the row with the latest date (dates cannot be null). If two or more rows share that latest date, choose an arbitrary row with a non-null site (sites can be null), so the easiest way to do that seems to be taking the max or min value.

I've managed to keep only the latest date with the following code (taken from here) -

from pyspark.sql import functions as f, Window

# keep only the rows whose date equals the per-id maximum date
w = Window.partitionBy('id')
df2 = df.withColumn('maxDate', f.max('date').over(w)) \
    .where(f.col('date') == f.col('maxDate')) \
    .drop('maxDate')

This results in df2 -

+---+-------------------+----+
| id|               date|site|
+---+-------------------+----+
|100|2020-03-24 00:00:00|   a|
|100|2020-03-24 00:00:00|   b|
|101|2019-12-20 00:00:00|NULL|
|101|2019-12-20 00:00:00|   a|
|102|2019-04-14 00:00:00|NULL|
|103|2019-09-28 00:00:00|   c|
+---+-------------------+----+

Then I tried the same approach for the site -

# same idea for site: keep only the rows whose site equals the per-id maximum site
w = Window.partitionBy('id')
df3 = df2.withColumn('maxSite', f.max('site').over(w)) \
    .where(f.col('site') == f.col('maxSite')) \
    .drop('maxSite')

But the result is df3 -

+---+-------------------+----+
| id|               date|site|
+---+-------------------+----+
|100|2020-03-24 00:00:00|   b|
|101|2019-12-20 00:00:00|   a|
|103|2019-09-28 00:00:00|   c|
+---+-------------------+----+

id 102 is missing. It looks like max (and min) skip null values, and since id 102 has only a single row with a null site, max('site') over that group returns null; the comparison site == maxSite is then null == null, which does not evaluate to true, so the where clause drops the row.
I managed to overcome this by doing a left anti join between the first dataframe and the last one, so I get all the ids that have only null sites, and then used union to combine the results -

# rows of df whose id does not appear in df3, i.e. the ids that only had null sites
df_left_anti = df.join(df3, df['id'] == df3['id'], 'left_anti')
df_all = df3.union(df_left_anti).orderBy('id')
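
As an aside, Spark also has a null-safe equality, Column.eqNullSafe (available since Spark 2.3), for which NULL <=> NULL is true; swapping it in for the plain == would keep the all-null group without the extra join. A rough sketch of that variant -

# same max-over-window step, but the null-safe comparison keeps id 102,
# whose maxSite is null because all of its sites are null
w = Window.partitionBy('id')
df_alt = df2.withColumn('maxSite', f.max('site').over(w)) \
    .where(f.col('site').eqNullSafe(f.col('maxSite'))) \
    .drop('maxSite')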

My question is - is the left anti join + union approach the right and efficient way of doing it? Can I use max/min with nulls?


Solution

  • Instead of using a max, I would simply use a row_number. You compute the row_number for each id, ordering by date and site. If you order ascending (the default direction), it is equivalent to a "min"; here I order desc to be equivalent to a max. You can make sure null sites are not prioritized by using the method desc_nulls_last instead of plain desc.

    Then, you just keep the first row (there will always be a first row for a given id).

    from pyspark.sql import functions as F, Window as W
    
    df.withColumn(
        "rnk",
        F.row_number().over(
            W.partitionBy("id").orderBy(
                F.col("date").desc(),
                F.col("site").desc_nulls_last(),
            )
        ),
    ).where(F.col("rnk") == 1).drop("rnk").show()
    
    +---+----------+----+
    | id|      date|site|
    +---+----------+----+
    |100|2020-03-24|   b|
    |101|2019-12-20|   a|
    |102|2019-04-14|null|
    |103|2019-09-28|   c|
    +---+----------+----+
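
    As for the second part of the question (can max/min be used with nulls): Spark's max and min aggregates simply ignore null values, and when every value in a group is null they return null, which is why the == filter in the original attempt then drops the row. A quick illustration on the original df -

    from pyspark.sql import functions as F

    # max ignores nulls: for id 101 the null site is skipped and 'a' wins;
    # for id 102, where every site is null, the aggregate returns null
    df.groupBy("id").agg(F.max("site").alias("max_site")).orderBy("id").show()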