I noticed by accident some weird behavior in PySpark: it can execute the where function on a column that doesn't exist in the DataFrame:
from pyspark.sql.functions import col

print(spark.version)
df = spark.read.format("csv").option("header", True).load("abfss://some_abfs_path/df.csv")
print(type(df), len(df.columns), df.count())

c = df.columns[0]  # A column name before renaming
df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])  # Add suffix to column names
print(c in df.columns)

try:
    df.select(c)
except Exception:
    print("SO THIS DOESN'T WORK, WHICH MAKES SENSE.")

# BUT WHY DOES THIS WORK:
print(df.where(col(c).isNotNull()).count())
# IT'S USING c AS f"{c}_new"
print(df.where(col(f"{c}_new").isNotNull()).count())
Outputs:
3.1.2
<class 'pyspark.sql.dataframe.DataFrame'> 102 1226791
False
SO THIS DOESN'T WORK, WHICH MAKES SENSE.
1226791
1226791
As you can see, the weird part is that even though column c no longer exists in df after the renaming, it can still be used in the where call.
My intuition is that PySpark compiles the where before the select renaming under the hood. But that would be horrible design, and it doesn't explain why both the old and new column names work.
Would appreciate any insights, thanks.
I'm running things on Azure Databricks.
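For reference, the behavior also reproduces outside Databricks with a plain in-memory DataFrame, so it isn't specific to the CSV source. A minimal sketch with made-up data:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

# Hypothetical two-column DataFrame standing in for the CSV data
df = spark.createDataFrame([(1, "a"), (2, "b"), (None, "c")], ["id", "val"])

c = df.columns[0]  # "id"
df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])

print(df.where(col(c).isNotNull()).count())           # 2 -- old name still filters
print(df.where(col(f"{c}_new").isNotNull()).count())  # 2 -- same result with the new name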
When in doubt, use df.explain() to figure out what's going on under the hood. This will confirm your intuition:
Spark context available as 'sc' (master = local[*], app id = local-1709748307134).
SparkSession available as 'spark'.
>>> df = spark.read.option("header", True).option("inferSchema", True).csv("taxi.csv")
>>> c = df.columns[0]
>>> from pyspark.sql.functions import *
>>> df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])
>>> df.explain()
== Physical Plan ==
*(1) Project [VendorID#17 AS VendorID_new#51, tpep_pickup_datetime#18 AS tpep_pickup_datetime_new#52, tpep_dropoff_datetime#19 AS tpep_dropoff_datetime_new#53, passenger_count#20 AS passenger_count_new#54, trip_distance#21 AS trip_distance_new#55, RatecodeID#22 AS RatecodeID_new#56, store_and_fwd_flag#23 AS store_and_fwd_flag_new#57, PULocationID#24 AS PULocationID_new#58, DOLocationID#25 AS DOLocationID_new#59, payment_type#26 AS payment_type_new#60, fare_amount#27 AS fare_amount_new#61, extra#28 AS extra_new#62, mta_tax#29 AS mta_tax_new#63, tip_amount#30 AS tip_amount_new#64, tolls_amount#31 AS tolls_amount_new#65, improvement_surcharge#32 AS improvement_surcharge_new#66, total_amount#33 AS total_amount_new#67]
+- FileScan csv [VendorID#17,tpep_pickup_datetime#18,tpep_dropoff_datetime#19,passenger_count#20,trip_distance#21,RatecodeID#22,store_and_fwd_flag#23,PULocationID#24,DOLocationID#25,payment_type#26,fare_amount#27,extra#28,mta_tax#29,tip_amount#30,tolls_amount#31,improvement_surcharge#32,total_amount#33] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/charlie/taxi.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:string,tpep_dropoff_datetime:string,passenger_count:int,...
>>> df = df.where(col(c).isNotNull())
>>> df.explain()
== Physical Plan ==
*(1) Project [VendorID#17 AS VendorID_new#51, tpep_pickup_datetime#18 AS tpep_pickup_datetime_new#52, tpep_dropoff_datetime#19 AS tpep_dropoff_datetime_new#53, passenger_count#20 AS passenger_count_new#54, trip_distance#21 AS trip_distance_new#55, RatecodeID#22 AS RatecodeID_new#56, store_and_fwd_flag#23 AS store_and_fwd_flag_new#57, PULocationID#24 AS PULocationID_new#58, DOLocationID#25 AS DOLocationID_new#59, payment_type#26 AS payment_type_new#60, fare_amount#27 AS fare_amount_new#61, extra#28 AS extra_new#62, mta_tax#29 AS mta_tax_new#63, tip_amount#30 AS tip_amount_new#64, tolls_amount#31 AS tolls_amount_new#65, improvement_surcharge#32 AS improvement_surcharge_new#66, total_amount#33 AS total_amount_new#67]
+- *(1) Filter isnotnull(VendorID#17)
+- FileScan csv [VendorID#17,tpep_pickup_datetime#18,tpep_dropoff_datetime#19,passenger_count#20,trip_distance#21,RatecodeID#22,store_and_fwd_flag#23,PULocationID#24,DOLocationID#25,payment_type#26,fare_amount#27,extra#28,mta_tax#29,tip_amount#30,tolls_amount#31,improvement_surcharge#32,total_amount#33] Batched: false, DataFilters: [isnotnull(VendorID#17)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/charlie/taxi.csv], PartitionFilters: [], PushedFilters: [IsNotNull(VendorID)], ReadSchema: struct<VendorID:int,tpep_pickup_datetime:string,tpep_dropoff_datetime:string,passenger_count:int,...
From bottom to top: FileScan to read the data, Filter to discard unneeded data, Project to apply the alias. It's a sensible way for Spark to construct its DAG - discard data as eagerly as possible so you don't waste time operating on it - but as you've noticed, it can lead to unexpected behavior. It also explains why both names work: the old name resolves against the attributes below the Project, while the new name resolves against the Project's output. If you'd like to avoid this, use df.checkpoint() to materialize the DataFrame prior to your df.where() statement - this will give you the expected error when you attempt to reference the old column name:
>>> from pyspark.sql.functions import *
>>> spark.sparkContext.setCheckpointDir("file:/tmp/")
>>> df = spark.read.option("header", True).option("inferSchema", True).csv("taxi.csv")
>>> c = df.columns[0]
>>> df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])
>>> df = df.checkpoint()
>>> df.explain()
== Physical Plan ==
*(1) Scan ExistingRDD[VendorID_new#51,tpep_pickup_datetime_new#52,tpep_dropoff_datetime_new#53,passenger_count_new#54,trip_distance_new#55,RatecodeID_new#56,store_and_fwd_flag_new#57,PULocationID_new#58,DOLocationID_new#59,payment_type_new#60,fare_amount_new#61,extra_new#62,mta_tax_new#63,tip_amount_new#64,tolls_amount_new#65,improvement_surcharge_new#66,total_amount_new#67]
>>> df = df.where(col(c).isNotNull())
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/homebrew/opt/apache-spark/libexec/python/pyspark/sql/dataframe.py", line 3325, in filter
jdf = self._jdf.filter(condition._jc)
File "/opt/homebrew/opt/apache-spark/libexec/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1322, in __call__
File "/opt/homebrew/opt/apache-spark/libexec/python/pyspark/errors/exceptions/captured.py", line 185, in deco
raise converted from None
pyspark.errors.exceptions.captured.AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `VendorID` cannot be resolved. Did you mean one of the following? [`VendorID_new`, `extra_new`, `RatecodeID_new`, `mta_tax_new`, `DOLocationID_new`].;
'Filter isnotnull('VendorID)
+- LogicalRDD [VendorID_new#51, tpep_pickup_datetime_new#52, tpep_dropoff_datetime_new#53, passenger_count_new#54, trip_distance_new#55, RatecodeID_new#56, store_and_fwd_flag_new#57, PULocationID_new#58, DOLocationID_new#59, payment_type_new#60, fare_amount_new#61, extra_new#62, mta_tax_new#63, tip_amount_new#64, tolls_amount_new#65, improvement_surcharge_new#66, total_amount_new#67], false
>>>
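As an aside, a cheaper way to fail fast (without writing a checkpoint) is to reference columns through the DataFrame itself rather than the free-standing col(): df[...] is resolved against that DataFrame's schema when the expression is built, so a stale name raises an AnalysisException up front. A sketch along the same lines as the transcript above:

from pyspark.sql.functions import col

df = spark.read.option("header", True).option("inferSchema", True).csv("taxi.csv")
c = df.columns[0]
df = df.select(*[col(x).alias(f"{x}_new") for x in df.columns])

df.where(df[f"{c}_new"].isNotNull()).count()  # fine: the new name is in df's schema
df.where(df[c].isNotNull())                   # df[c] itself raises AnalysisException:
                                              # the old name is no longer in the schema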