
Nested condition on simple data


I have a DataFrame with three columns: two of boolean type and one of string type.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, BooleanType, StringType
from pyspark.sql.functions import col, lit, nvl, trim  # nvl requires Spark 3.5+

# Create a Spark session
spark = SparkSession.builder \
    .appName("Condition Test") \
    .getOrCreate()

# Sample data
data = [
    (True, 'CA', None),
    (True, 'US', None),
    (False, 'CA', None)
]

# Define schema for the dataframe
schema = StructType([
    StructField("is_flag", BooleanType(), nullable=False),
    StructField("country", StringType(), nullable=False),
    StructField("rule", BooleanType(), nullable=True)  
])

# Create DataFrame
df = spark.createDataFrame(data, schema=schema)

# Show initial dataframe
df.show(truncate=False)

condition = (
    (~col("is_flag")) |
    ((col("is_flag")) & (trim(col("country")) != 'CA') & nvl(col("rule"),lit(False)) != True)
)

df = df.filter(condition)

# show filtered dataframe
df.show(truncate=False)

The code above returns the following:

+-------+-------+----+
|is_flag|country|rule|
+-------+-------+----+
|true   |CA     |NULL|
|true   |US     |NULL|
|false  |CA     |NULL|
+-------+-------+----+

However, since I explicitly specify ((col("is_flag")) & (trim(col("country")) != 'CA') & nvl(col("rule"),lit(False)) != True), i.e. trim(col("country")) != 'CA' when is_flag is true, I don't expect the first record. I need results like the following:

+-------+-------+----+
|is_flag|country|rule|
+-------+-------+----+
|true   |US     |NULL|
|false  |CA     |NULL|
+-------+-------+----+

Question: why does the code above also return the first record |true |CA |NULL|, even though we explicitly require country != 'CA' when is_flag is true?

However, the same condition applied via SQL returns the expected result:

select *
from df
where (
       not is_flag or 
       (is_flag and trim(country) != 'CA' and nvl(rule,False) != True)
      )
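In SQL, AND has lower precedence than comparison operators such as !=, so each comparison is evaluated before the conjunction. A minimal sketch can make the difference visible using Python's built-in sqlite3 module as a stand-in for Spark SQL. Assumptions: SQLite has no nvl, so coalesce is used instead, and the booleans are stored as 1/0. SQLite's bitwise & binds more tightly than !=, just like Python's &, so the second query reproduces the PySpark result.

```python
import sqlite3

# In-memory SQLite database used as a stand-in for Spark SQL (assumption:
# SQLite has no nvl, so coalesce is used, and booleans are stored as 1/0).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE df (is_flag INTEGER NOT NULL, country TEXT NOT NULL, rule INTEGER)")
conn.executemany(
    "INSERT INTO df VALUES (?, ?, ?)",
    [(1, "CA", None), (1, "US", None), (0, "CA", None)],
)

# AND binds more loosely than !=, so each comparison is evaluated first.
with_and = conn.execute(
    """
    SELECT is_flag, country, rule FROM df
    WHERE NOT is_flag OR (is_flag AND trim(country) != 'CA' AND coalesce(rule, 0) != 1)
    ORDER BY rowid
    """
).fetchall()

# Bitwise & binds more tightly than !=, so this groups as
# (is_flag & (trim(country) != 'CA') & coalesce(rule, 0)) != 1.
with_amp = conn.execute(
    """
    SELECT is_flag, country, rule FROM df
    WHERE NOT is_flag OR (is_flag & (trim(country) != 'CA') & coalesce(rule, 0) != 1)
    ORDER BY rowid
    """
).fetchall()

print(with_and)       # [(1, 'US', None), (0, 'CA', None)]
print(len(with_amp))  # 3
```

The AND version drops the (1, 'CA') row, while the & version keeps all three rows, mirroring the PySpark output.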

Solution

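  • The condition is wrong because it doesn't account for operator precedence. In Python, & (which PySpark overloads as a logical AND on Column objects) binds more tightly than comparison operators such as !=, so the second branch is parsed as (col("is_flag") & (trim(col("country")) != 'CA') & nvl(col("rule"), lit(False))) != True. For the first row that evaluates to (true & false & false) != true, which is true, so the row is kept. In SQL, AND has lower precedence than !=, which is why the SQL version behaves as expected.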

    Here's the updated condition with parentheses:

    condition = (
        (~col("is_flag")) |
        ((col("is_flag")) & (trim(col("country")) != 'CA') & (nvl(col("rule"),lit(False)) != True))
    )
    

    Output:

    +-------+-------+----+
    |is_flag|country|rule|
    +-------+-------+----+
    |true   |US     |NULL|
    |false  |CA     |NULL|
    +-------+-------+----+
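Because Python's bool type also overloads & and follows the same grammar, the precedence issue can be reproduced without Spark. The minimal sketch below (pure Python; the helper names top_level, nvl, buggy and fixed are hypothetical) first uses the standard ast module to show which operator ends up outermost, then evaluates both conditions on the question's three rows. It uses not instead of ~, because ~True on a Python bool is the integer -2 rather than False.

```python
import ast


def top_level(expr: str) -> str:
    """Return the class name of the outermost AST node of a Python expression."""
    return type(ast.parse(expr, mode="eval").body).__name__


# Without parentheses the outermost operation is the comparison (!=) ...
print(top_level("a & b & c != True"))    # Compare
# ... with parentheses it is the & chain, as intended.
print(top_level("a & b & (c != True)"))  # BinOp

# Same rows as the question; None plays the role of SQL NULL.
ROWS = [(True, "CA", None), (True, "US", None), (False, "CA", None)]


def nvl(value, default):
    """Hypothetical stand-in for Spark's nvl: replace None with a default."""
    return default if value is None else value


def buggy(is_flag, country, rule):
    # Parsed as (is_flag & (...) & nvl(rule, False)) != True.
    # `not` is used instead of `~`, since ~True on a bool is -2.
    return (not is_flag) or (is_flag & (country.strip() != "CA") & nvl(rule, False) != True)


def fixed(is_flag, country, rule):
    # The extra parentheses make != apply only to nvl(rule, False).
    return (not is_flag) or (is_flag & (country.strip() != "CA") & (nvl(rule, False) != True))


buggy_rows = [r for r in ROWS if buggy(*r)]
fixed_rows = [r for r in ROWS if fixed(*r)]
print(buggy_rows)  # all three rows
print(fixed_rows)  # [(True, 'US', None), (False, 'CA', None)]
```

In PySpark itself, ~nvl(col("rule"), lit(False)) (or ~coalesce(col("rule"), lit(False)) on Spark versions before 3.5) states the same test without comparing against True. Unary ~ binds more tightly than &, so that form needs no extra parentheses.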