Tags: pyspark, filter, apache-spark-sql

find rows where columns mismatch


How can I find rows where two columns aren't equal, the way you can in pandas?

data = [("John", "Doe"), (None, "Doe"), ("John", None), (None, None)]
df = spark.createDataFrame(data, ["first", "last"])

df.show()

+-----+----+
|first|last|
+-----+----+
| John| Doe|
| null| Doe|
| John|null|
| null|null|
+-----+----+

Expected output:

+-----+----+
|first|last|
+-----+----+
| John| Doe|
| null| Doe|
| John|null|
+-----+----+

I tried

df_filtered = df.filter(col("first") != col("last"))

but it only returns the John/Doe row; every row containing a null is dropped.

Solution
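
First, why your attempt fails: any comparison involving null evaluates to null, not false, and filter() keeps a row only when the predicate is true, so the null results get dropped along with the false ones. A quick demonstration, assuming the df from your question and an active SparkSession:

    from pyspark.sql.functions import col

    # null != 'Doe' evaluates to null, not true, so filter()
    # silently drops every row with a null in either column.
    df.filter(col("first") != col("last")).show()
    # +-----+----+
    # |first|last|
    # +-----+----+
    # | John| Doe|
    # +-----+----+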

  • There are multiple ways to get your expected output in PySpark. The snippets below assume from pyspark.sql.functions import col, equal_null, ifnull, lit, nvl (note that equal_null, nvl, and ifnull were added to the Python API in Spark 3.5):

    1. Using equal_null:
      df.filter(~equal_null(col("first"), col("last")))
      
    2. Using nvl:
      df.filter(nvl(col("first"), lit("")) != nvl(col("last"), lit("")))
      
    3. Using ifnull:
      df.filter(ifnull(col("first"), lit("")) != ifnull(col("last"), lit("")))
      

    You can also get there in other ways, as @s.polam did in their answer, using expr or plain Spark SQL; a self-contained sketch follows. For me, the best approach is equal_null because it matches your requirement exactly and doesn't require any additional literals: the empty-string sentinels in the nvl/ifnull approaches would also incorrectly treat a genuinely empty string as equal to a null.
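
    For reference, here is a minimal runnable sketch of the equal_null route alongside the null-safe-equality alternatives that the expr/Spark SQL answers rely on (my own sketch, not necessarily identical to @s.polam's answer). Column.eqNullSafe and the SQL <=> operator are Spark's built-in null-safe equality, so negating either gives the same result without sentinel literals:

      from pyspark.sql import SparkSession
      from pyspark.sql.functions import col, equal_null, expr

      spark = SparkSession.builder.getOrCreate()
      data = [("John", "Doe"), (None, "Doe"), ("John", None), (None, None)]
      df = spark.createDataFrame(data, ["first", "last"])

      # 1. equal_null (Spark 3.5+): true when both sides are equal or both null.
      df.filter(~equal_null(col("first"), col("last"))).show()

      # 2. Column.eqNullSafe: same semantics, available on much older Spark versions.
      df.filter(~col("first").eqNullSafe(col("last"))).show()

      # 3. The SQL null-safe equality operator <=> via expr.
      df.filter(expr("NOT (first <=> last)")).show()

      # All three keep exactly the expected three rows:
      # +-----+----+
      # |first|last|
      # +-----+----+
      # | John| Doe|
      # | null| Doe|
      # | John|null|
      # +-----+----+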