Tags: python, azure, pyspark, jupyter-notebook, azure-databricks

Import Python file in Databricks notebook


I am trying to import Python files from a Databricks notebook. I am able to import the function, but it keeps giving me an error: NameError: name 'col' is not defined.

I have a file /Workspace/Shared/Common/common.py. The code in that file is:

def cast_datatypes(df_target, df_source):
    cols = df_target.columns
    types = [f.dataType for f in df_target.schema.fields]
    for i, c in enumerate(cols):
        df_source = df_source.withColumn(c, col(c).cast(types[i]))
    df_source = df_source.select(cols)
    return df_source

I am calling this function from a notebook, and tried the approaches below.

Approach 1:

import sys
sys.path.append("/Workspace/Shared/Common/common.py")
from common import cast_datatypes
from pyspark.sql.functions import col

<Spark Session declaration>

df_final = cast_datatypes(df_target_A, df_source_A)

Approach 2:

spark.sparkContext.addPyFile("/Workspace/Shared/Common/common.py")
import common as C
from pyspark.sql.functions import col

df_final = C.cast_datatypes(df_target_A, df_source_A)

Both approaches were able to import the function, but failed when using 'col'. The error I am getting is:

File /Workspace/Shared/Common/common.py:13, in cast_datatypes(df_target, df_source)
     11 types = [f.dataType for f in df_target.schema.fields]
     12 for i, c in enumerate(cols):
---> 13     df_source = df_source.withColumn(c, col(c).cast(types[i]))
     14 df_source = df_source.select(cols)
     15 return df_source

NameError: name 'col' is not defined

Do we need to pass everything the imported function uses as arguments? If so, how do I modularize my notebooks in Databricks?


Solution

  • What @Reyk said is right: you need to import col inside common.py itself. A function looks up free names such as col in its own module's global namespace, so importing col in the notebook does not make it visible inside common.py. (Note also that sys.path.append expects the directory that contains the module, /Workspace/Shared/Common/, not the path of the .py file.) But even after adding the import you still got the error, and I got the same error too: after altering common.py you need to restart the cluster, or detach and re-attach the notebook, because Python caches the already-imported module.
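
    Applied to your file, a minimal sketch of the corrected common.py (the only change is the import at the top):

    from pyspark.sql.functions import col

    def cast_datatypes(df_target, df_source):
        # cast each column of df_source to the corresponding type in df_target's schema
        cols = df_target.columns
        types = [f.dataType for f in df_target.schema.fields]
        for i, c in enumerate(cols):
            df_source = df_source.withColumn(c, col(c).cast(types[i]))
        df_source = df_source.select(cols)
        return df_source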

    To demonstrate, I started with a simpler example. My initial /Workspace/Shared/Common/common.py (col and lit not yet imported):

    def cal(df):
        re = df.withColumn("tmp", col("trip_distance") + lit(1))
        return re
    

    When I tried the code below, I got the error.

    import sys
    sys.path.append("/Workspace/Shared/Common/")
    df = spark.sql("select * from samples.nyctaxi.trips")
    
    from common import cal
    from pyspark.sql.functions import col
    display(cal(df))
    


    So I altered common.py like below.

    from pyspark.sql.functions import col,lit
    
    def cal(df):
        re = df.withColumn("tmp", col("trip_distance") + lit(1))
        return re
    

    I tried calling the function again and got the same error.


    But if I check the contents of the file, it is updated.


    So, do one of those (restart the cluster, or detach and re-attach the notebook) and re-run your code.


    Output:

    The tmp column is created.


    If you update the file again, you need to re-attach the notebook or restart again.
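
    On recent Databricks Runtimes there is also a lighter option than restarting the whole cluster: restarting only the notebook's Python process with dbutils.library.restartPython() (assuming your runtime version ships this utility). After the restart, the next import re-reads common.py.

    # restart just this notebook's Python process; subsequent imports re-read the files
    dbutils.library.restartPython()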

    If you don't want to restart or re-attach, you need to reload the common module instead.

    After altering the code in common.py, just use the code below.

    import sys
    sys.path.append("/Workspace/Shared/Common/")
    df = spark.sql("select * from samples.nyctaxi.trips")

    # drop the cached module so the next import re-executes common.py
    if "common" in sys.modules:
        del sys.modules["common"]

    from common import cal
    display(cal(df))
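
    Equivalently, if you hold the module object, importlib.reload from the standard library re-executes it in place; a minimal sketch (any name bound with from common import cal must be re-bound after the reload):

    import importlib
    import common

    importlib.reload(common)   # re-executes common.py in place
    from common import cal     # re-bind the freshly defined function
    display(cal(df))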
    
