
Unexpected output from least (source data includes nulls)


Inspired by this answer, I want to find the row-wise minimum across several date columns and return the name of the column that holds it.

I'm getting unexpected results when a row contains NULLs, which I thought least skipped. Specifically, rows 2-5 in this toy example come out wrong:

import datetime as dt

from pyspark.sql import Row, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructField, StructType, DateType

spark = SparkSession.builder.getOrCreate()

schema = StructType([
    StructField("date1", DateType(), True),
    StructField("date2", DateType(), True),
    StructField("date3", DateType(), True)
])
row1 = Row(dt.date(2024, 1, 1), dt.date(2024, 1, 2), dt.date(2024, 1, 3))
row2 = Row(None, None, dt.date(2024, 1, 3))
row3 = Row(None, dt.date(2024, 1, 1), dt.date(2024, 1, 2))
row4 = Row(None, None, None)
row5 = Row(dt.date(2024, 1, 1), None, None)

df = spark.createDataFrame([row1, row2, row3, row4, row5], schema)

def row_min(*cols):
    cols_ = [F.struct(F.col(c).alias("value"), F.lit(c).alias("col")) for c in cols]
    return F.least(*cols_)

df.withColumn("output", row_min('date1', 'date2', 'date3').col).show()

returns

+----------+----------+----------+------+
|     date1|     date2|     date3|output|
+----------+----------+----------+------+
|2024-01-01|2024-01-02|2024-01-03| date1|
|      NULL|      NULL|2024-01-03| date1|
|      NULL|2024-01-01|2024-01-02| date1|
|      NULL|      NULL|      NULL| date1|
|2024-01-01|      NULL|      NULL| date2|
+----------+----------+----------+------+

but the desired output is:

+----------+----------+----------+------+
|     date1|     date2|     date3|output|
+----------+----------+----------+------+
|2024-01-01|2024-01-02|2024-01-03| date1|
|      NULL|      NULL|2024-01-03| date3|
|      NULL|2024-01-01|2024-01-02| date2|
|      NULL|      NULL|      NULL|  NULL|
|2024-01-01|      NULL|      NULL| date1|
+----------+----------+----------+------+

Solution

  • You are comparing struct<value:date,col:string> values in which the value field may be NULL. The least function skips an argument only when the whole struct is NULL, not when one of its fields is.

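    Spark compares structs field by field, and by default it treats a NULL field as smaller than any non-NULL value (the same NULLS FIRST default as an ascending sort). That is why least picks a struct whose value is NULL. One quick fix is to negate the date, for example by subtracting it from a fixed date to get an interval, (F.expr("date'1970'") - F.col(c)).alias("value"), and then apply the greatest function instead. A NULL value now always loses, and the largest negated value belongs to the earliest date. If every date in a row is NULL, greatest still returns a struct whose value is NULL, which the when/otherwise below maps to NULL.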

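    def row_min(*cols):
        # Negate each date by subtracting it from a fixed date, so that
        # greatest() picks the earliest date and NULL fields never win.
        cols_ = [F.struct((F.expr("date'1970'") - F.col(c)).alias("value"), F.lit(c).alias("col")) for c in cols]
        return F.greatest(*cols_)

    row_least = row_min('date1', 'date2', 'date3')
    # If every date is NULL, greatest() still returns a struct (with a NULL value), so map that case to NULL.
    df.withColumn("output", F.when(F.isnull(row_least.value), None).otherwise(row_least.col)).show()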
    +----------+----------+----------+------+
    |     date1|     date2|     date3|output|
    +----------+----------+----------+------+
    |2024-01-01|2024-01-02|2024-01-03| date1|
    |      null|      null|2024-01-03| date3|
    |      null|2024-01-01|2024-01-02| date2|
    |      null|      null|      null|  null|
    |2024-01-01|      null|      null| date1|
    +----------+----------+----------+------+