I am trying to import Python files from a Databricks notebook. I am able to import the function, but it keeps giving me an error: NameError: name 'col' is not defined.
I have a file /Workspace/Shared/Common/common.py. The code in that file is:
def cast_datatypes(df_target, df_source):
    cols = df_target.columns
    types = [f.dataType for f in df_target.schema.fields]
    for i, c in enumerate(cols):
        df_source = df_source.withColumn(c, col(c).cast(types[i]))
    df_source = df_source.select(cols)
    return df_source
I am calling this function from a notebook and tried the approaches below.
Approach 1:
import sys
sys.path.append("/Workspace/Shared/Common/common.py")
from common import cast_datatypes
from pyspark.sql.functions import col
<Spark Session declaration>
df_final = cast_datatypes(df_target_A, df_source_A)
Approach 2:
spark.sparkContext.addPyFile("/Workspace/Shared/Common/common.py")
import common as C
from pyspark.sql.functions import col
df_final = C.cast_datatypes(df_target_A, df_source_A)
Both approaches were able to import the function, but failed when using col. The error I am getting is:
File /Workspace/Shared/Common/common.py:13, in cast_datatypes(df_target, df_source)
     11     types = [f.dataType for f in df_target.schema.fields]
     12     for i, c in enumerate(cols):
---> 13         df_source = df_source.withColumn(c, col(c).cast(types[i]))
     14     df_source = df_source.select(cols)
     15     return df_source

NameError: name 'col' is not defined
Do we need to pass everything the imported function uses as arguments? If yes, then how do I modularize my notebooks in Databricks?
What @Reyk said is right: you need to import col inside common.py.
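Applied to your cast_datatypes, the fixed common.py simply gains the import at module level:

# /Workspace/Shared/Common/common.py
from pyspark.sql.functions import col

def cast_datatypes(df_target, df_source):
    # Cast each df_source column to the type of the matching df_target column
    cols = df_target.columns
    types = [f.dataType for f in df_target.schema.fields]
    for i, c in enumerate(cols):
        df_source = df_source.withColumn(c, col(c).cast(types[i]))
    df_source = df_source.select(cols)
    return df_source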
But you still got the error even after that fix, and I got the same error too. After altering common.py, you need to restart the cluster or detach and re-attach the notebook, because Python keeps using the cached copy of the module it already imported.
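On recent Databricks Runtime versions, restarting just the notebook's Python process should also clear the module cache, without a full cluster restart (whether this is available depends on your runtime version):

# Restart only this notebook's Python process; cached imports are cleared
dbutils.library.restartPython()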
Here is my repro. /Workspace/Shared/Common/common.py:

def cal(df):
    re = df.withColumn("tmp", col("trip_distance") + lit(1))
    return re
When I tried the code below, I got the same NameError.
import sys
sys.path.append("/Workspace/Shared/Common/")
df = spark.sql("select * from samples.nyctaxi.trips")
from common import cal
from pyspark.sql.functions import col
display(cal(df))
So I altered common.py as below, adding the missing imports.
from pyspark.sql.functions import col, lit

def cal(df):
    re = df.withColumn("tmp", col("trip_distance") + lit(1))
    return re
I tried to call the function again and got the same error. But if I view the contents of the file, it is updated.
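The mismatch is the module cache: the file on disk now has the import, but the common module object Python already loaded does not. A quick way to see this (a hypothetical check, assuming common was imported before the edit):

import common

# The cached module was built from the old source, so the name `col`
# added by the edit is not in its namespace yet
print(hasattr(common, "col"))   # False until the module is reloaded
print(common.__file__)          # yet this points at the updated file on disk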
So, do one of the two, restart the cluster or detach and re-attach the notebook, and re-run your code.

Output: the tmp column is created.
If you update common.py again, you need to re-attach the notebook or restart once more.
If you don't want to restart or re-attach, you need to reload the module common instead.
After altering the code in common.py, just use the code below.
import sys

sys.path.append("/Workspace/Shared/Common/")
df = spark.sql("select * from samples.nyctaxi.trips")

# Drop the cached module so the next import re-reads the file from disk
if "common" in sys.modules:
    del sys.modules["common"]

from common import cal
display(cal(df))
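Equivalently, you can reload the module in place with importlib instead of deleting it from sys.modules; a minimal sketch, assuming common has already been imported once in this session:

import importlib
import common  # returns the cached module if it was imported before

# Re-execute common.py so the module picks up the latest edits
importlib.reload(common)

from common import cal  # re-bind cal to the reloaded module
display(cal(df))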