python, apache-spark, pyspark, databricks

pyspark where clause can work on a column that doesn't exist


I noticed by accident some weird behavior in PySpark. Basically, it can execute the where function on a column that doesn't exist in the DataFrame:

from pyspark.sql.functions import col

print(spark.version)

df = spark.read.format("csv").option("header", True).load("abfss://some_abfs_path/df.csv")
print(type(df), len(df.columns), df.count())

c = df.columns[0]  # A column name before renaming
df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])  # Add suffix to column names

print(c in df.columns)

try:
    df.select(c)
except Exception:
    print("SO THIS DOESN'T WORK, WHICH MAKES SENSE.")

# BUT WHY DOES THIS WORK:
print(df.where(col(c).isNotNull()).count())
# IT'S USING c AS f"{c}_new"
print(df.where(col(f"{c}_new").isNotNull()).count())

Outputs:

3.1.2
<class 'pyspark.sql.dataframe.DataFrame'> 102 1226791
False
SO THIS DOESN'T WORK, WHICH MAKES SENSE.
1226791
1226791

As you can see, the weird part is that even though column c no longer exists in df after the renaming, it can still be used in the where function.

My intuition is that PySpark compiles the where clause before the select renaming under the hood. But that would be a horrible design, and it doesn't explain why both the old and the new column names work.

Would appreciate any insights, thanks.

I'm running things on Azure Databricks.


Solution

  • When in doubt, use df.explain() to figure out what's going on under the hood. This will confirm your intuition:

    Spark context available as 'sc' (master = local[*], app id = local-1709748307134).
    SparkSession available as 'spark'.
    >>> df = spark.read.option("header", True).option("inferSchema", True).csv("taxi.csv")
    >>> c = df.columns[0]
    >>> from pyspark.sql.functions import *
    >>> df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns]) 
    >>> df.explain()
    == Physical Plan ==
    *(1) Project [VendorID#17 AS VendorID_new#51, tpep_pickup_datetime#18 AS tpep_pickup_datetime_new#52, tpep_dropoff_datetime#19 AS tpep_dropoff_datetime_new#53, passenger_count#20 AS passenger_count_new#54, trip_distance#21 AS trip_distance_new#55, RatecodeID#22 AS RatecodeID_new#56, store_and_fwd_flag#23 AS store_and_fwd_flag_new#57, PULocationID#24 AS PULocationID_new#58, DOLocationID#25 AS DOLocationID_new#59, payment_type#26 AS payment_type_new#60, fare_amount#27 AS fare_amount_new#61, extra#28 AS extra_new#62, mta_tax#29 AS mta_tax_new#63, tip_amount#30 AS tip_amount_new#64, tolls_amount#31 AS tolls_amount_new#65, improvement_surcharge#32 AS improvement_surcharge_new#66, total_amount#33 AS total_amount_new#67]
    +- FileScan csv [VendorID#17,tpep_pickup_datetime#18,tpep_dropoff_datetime#19,passenger_count#20,trip_distance#21,RatecodeID#22,store_and_fwd_flag#23,PULocationID#24,DOLocationID#25,payment_type#26,fare_amount#27,extra#28,mta_tax#29,tip_amount#30,tolls_amount#31,improvement_surcharge#32,total_amount#33] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/charlie/taxi.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:string,tpep_dropoff_datetime:string,passenger_count:int,...
    
    
    >>> df = df.where(col(c).isNotNull())
    >>> df.explain()
    == Physical Plan ==
    *(1) Project [VendorID#17 AS VendorID_new#51, tpep_pickup_datetime#18 AS tpep_pickup_datetime_new#52, tpep_dropoff_datetime#19 AS tpep_dropoff_datetime_new#53, passenger_count#20 AS passenger_count_new#54, trip_distance#21 AS trip_distance_new#55, RatecodeID#22 AS RatecodeID_new#56, store_and_fwd_flag#23 AS store_and_fwd_flag_new#57, PULocationID#24 AS PULocationID_new#58, DOLocationID#25 AS DOLocationID_new#59, payment_type#26 AS payment_type_new#60, fare_amount#27 AS fare_amount_new#61, extra#28 AS extra_new#62, mta_tax#29 AS mta_tax_new#63, tip_amount#30 AS tip_amount_new#64, tolls_amount#31 AS tolls_amount_new#65, improvement_surcharge#32 AS improvement_surcharge_new#66, total_amount#33 AS total_amount_new#67]
    +- *(1) Filter isnotnull(VendorID#17)
       +- FileScan csv [VendorID#17,tpep_pickup_datetime#18,tpep_dropoff_datetime#19,passenger_count#20,trip_distance#21,RatecodeID#22,store_and_fwd_flag#23,PULocationID#24,DOLocationID#25,payment_type#26,fare_amount#27,extra#28,mta_tax#29,tip_amount#30,tolls_amount#31,improvement_surcharge#32,total_amount#33] Batched: false, DataFilters: [isnotnull(VendorID#17)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/charlie/taxi.csv], PartitionFilters: [], PushedFilters: [IsNotNull(VendorID)], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:string,tpep_dropoff_datetime:string,passenger_count:int,...
    

    From bottom to top: FileScan to read the data, Filter to discard unneeded data, Project to apply the alias. It's a sensible way for Spark to construct its DAG - discard data as eagerly as possible so you don't waste time operating on it - but as you've noticed, it can lead to unexpected behavior. If you'd like to avoid this, use df.checkpoint() to materialize the DataFrame prior to your df.where() statement - this will give you the expected error when you attempt to reference the old column name:

    >>> from pyspark.sql.functions import *
    >>> spark.sparkContext.setCheckpointDir("file:/tmp/")
    >>> df = spark.read.option("header", True).option("inferSchema", True).csv("taxi.csv")
    >>> c = df.columns[0]
    >>> df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns]) 
    >>> df = df.checkpoint()
    >>> df.explain()
    == Physical Plan ==
    *(1) Scan ExistingRDD[VendorID_new#51,tpep_pickup_datetime_new#52,tpep_dropoff_datetime_new#53,passenger_count_new#54,trip_distance_new#55,RatecodeID_new#56,store_and_fwd_flag_new#57,PULocationID_new#58,DOLocationID_new#59,payment_type_new#60,fare_amount_new#61,extra_new#62,mta_tax_new#63,tip_amount_new#64,tolls_amount_new#65,improvement_surcharge_new#66,total_amount_new#67]
    
    
    >>> df = df.where(col(c).isNotNull())
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/opt/homebrew/opt/apache-spark/libexec/python/pyspark/sql/dataframe.py", line 3325, in filter
        jdf = self._jdf.filter(condition._jc)
      File "/opt/homebrew/opt/apache-spark/libexec/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1322, in __call__
      File "/opt/homebrew/opt/apache-spark/libexec/python/pyspark/errors/exceptions/captured.py", line 185, in deco
        raise converted from None
    pyspark.errors.exceptions.captured.AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `VendorID` cannot be resolved. Did you mean one of the following? [`VendorID_new`, `extra_new`, `RatecodeID_new`, `mta_tax_new`, `DOLocationID_new`].;
    'Filter isnotnull('VendorID)
    +- LogicalRDD [VendorID_new#51, tpep_pickup_datetime_new#52, tpep_dropoff_datetime_new#53, passenger_count_new#54, trip_distance_new#55, RatecodeID_new#56, store_and_fwd_flag_new#57, PULocationID_new#58, DOLocationID_new#59, payment_type_new#60, fare_amount_new#61, extra_new#62, mta_tax_new#63, tip_amount_new#64, tolls_amount_new#65, improvement_surcharge_new#66, total_amount_new#67], false
    
    >>>
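
    As a lighter-weight alternative (a sketch, not part of the original answer): instead of building the predicate from the unbound col(c), reference the column through the DataFrame itself with df[c]. In classic (non-Connect) PySpark, df[c] resolves the name against that specific DataFrame's schema eagerly, so a stale name should fail immediately instead of being resolved against the pre-rename plan:

    from pyspark.sql.functions import col
    from pyspark.errors import AnalysisException  # pyspark.sql.utils.AnalysisException on older Spark versions

    df = spark.read.option("header", True).option("inferSchema", True).csv("taxi.csv")
    c = df.columns[0]
    df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])

    try:
        # df[c] looks the name up in df's current schema right away,
        # so the old name should raise here, before any action runs.
        df.where(df[c].isNotNull())
    except AnalysisException as e:
        print("old name rejected:", e)

    # The renamed column resolves fine against the same DataFrame:
    print(df.where(df[f"{c}_new"].isNotNull()).count())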