Tags: python, apache-spark, pyspark, global-variables, user-defined-functions

Access global variable from UDF (User Defined Function) in python in spark


I am trying to modify a global variable from inside a pyspark.sql.functions.udf function in Python, but the change is not reflected in the global variable.

Here is a reproducible example with its output:

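from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

spark = SparkSession.builder.getOrCreate()

counter = 0

schema2 = StructType([
    StructField("id", IntegerType(), True),
    StructField("name", StringType(), True)
])

data2 = [(1, "A"), (2, "B")]

df = spark.createDataFrame(data = data2, schema = schema2)

def myFunc(column):
    global counter
    counter = counter + 1
    return column + 5

myFuncUDF = udf(myFunc, IntegerType())

# display() is a Databricks notebook helper; outside Databricks use .show()
display(df.withColumn('id1', myFuncUDF(df.id)))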

Output:

id  name  id1
1   A     6
2   B     7

When I print the counter variable, it remains 0.

How can I access a global variable inside a UDF and modify it on every call to the UDF? Or is this not possible?


Solution

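  • A UDF runs on the executors, and each executor works on a serialized copy of the function together with the globals it references, so incrementing counter there never changes the driver's counter. Spark's mechanism for sending values from tasks back to the driver is an accumulator. PySpark's built-in accumulators only cover int, float and complex values, so to store the ids we can create a custom set accumulator.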

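    from pyspark import AccumulatorParam

    class SetAccumulator(AccumulatorParam):
        def zero(self, init_value: set) -> set:
            # Each task starts from an empty set; the driver merges the task sets.
            return set()

        def addInPlace(self, v1: set, v2: set) -> set:
            # Merge two partial results by set union.
            return v1.union(v2)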
    

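    Next, initialise the set accumulator and add to it from the UDF. The UDF runs inside tasks on the executors while the DataFrame is transformed, and each task's additions are sent back to the driver. Because transformations are lazy, acc.value can only be read on the driver after an action has run.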

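    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType, StringType, StructField, StructType

    spark = SparkSession.builder.getOrCreate()

    # Accumulator initialisation: starts as an empty set on the driver
    acc = spark.sparkContext.accumulator(set(), SetAccumulator())

    schema2 = StructType([
        StructField("id", IntegerType(), True),
        StructField("name", StringType(), True)
    ])

    data2 = [(1, "A"), (2, "B")]

    df = spark.createDataFrame(data = data2, schema = schema2)

    # acc is read as a global inside the UDF; acc.add() does not rebind it,
    # so no "global" statement is needed
    def myFunc(column):
        acc.add({column})
        return column + 5

    myFuncUDF = udf(myFunc, IntegerType())

    # Run an action so the UDF actually executes, then read the result on the driver
    df.withColumn('id1', myFuncUDF(df.id)).show()
    print(acc.value)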
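The way an accumulator combines per-task values can be sketched in plain Python, which also suggests why a set suits this job better than the question's counter. This is a simulation under stated assumptions, not Spark's implementation: SetParam is a stand-in for the SetAccumulator above that avoids importing pyspark, while CountParam, run_task, simulate, the retried parameter, and the one-row-per-partition layout are hypothetical. Each simulated task starts from zero(), adds one update per row, and the driver folds each task's result into its own value with addInPlace(). Spark's documentation warns that accumulator updates made inside transformations (a UDF is one) may be applied more than once if a task or stage is re-executed; retried imitates that.

```python
class SetParam:
    """Duck-typed stand-in for the answer's SetAccumulator (no pyspark needed)."""

    def zero(self, value: set) -> set:
        # Every task starts from an empty set, whatever the driver holds.
        return set()

    def addInPlace(self, v1: set, v2: set) -> set:
        # Set union: adding the same ids twice changes nothing.
        v1 |= v2
        return v1


class CountParam:
    """Hypothetical integer-counter param, for comparison with the question's counter."""

    def zero(self, value: int) -> int:
        return 0

    def addInPlace(self, v1: int, v2: int) -> int:
        return v1 + v2


def run_task(param, driver_value, rows, udf_update):
    """Simulate one task: start from zero() and add one update per row."""
    local = param.zero(driver_value)  # both params above ignore the argument
    for row in rows:
        local = param.addInPlace(local, udf_update(row))
    return local


def simulate(param, initial, partitions, udf_update, retried=()):
    """Simulate the driver folding each task's local value into its own.

    Partition indices listed in `retried` report their update twice,
    imitating a task that is re-executed inside a transformation.
    """
    driver_value = initial
    for i, rows in enumerate(partitions):
        for _ in range(2 if i in retried else 1):
            update = run_task(param, driver_value, rows, udf_update)
            driver_value = param.addInPlace(driver_value, update)
    return driver_value


partitions = [[1], [2]]  # hypothetical: the two ids of df, one row per partition
print(sorted(simulate(SetParam(), set(), partitions, lambda row: {row}, retried={1})))  # [1, 2]
print(simulate(CountParam(), 0, partitions, lambda row: 1))  # 2
print(simulate(CountParam(), 0, partitions, lambda row: 1, retried={1}))  # 3
```

With one partition reported twice, the set result is still {1, 2} because union is idempotent, while the counter reports 3 for two rows. That is one argument for collecting ids in a set, or for treating a call counter updated from a UDF as approximate.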