Tags: python, apache-spark, dataframe, pyspark, broadcast

Accessing broadcast Dictionary from within Dataframe methods in pyspark


I have broadcast a dictionary that I would like to use for mapping column values in my DataFrame. Let's say I call the withColumn() method for that.

I can only get it to work with a UDF, but not directly:

from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, udf
from pyspark.sql.types import StringType

sc = SparkContext()
ss = SparkSession(sc)
df = ss.createDataFrame(["a", "b"], StringType()).toDF("key")
# +---+                                                                           
# |key|
# +---+
# |  a|
# |  b|
# +---+
thedict={"a":"A","b":"B","c":"C"}
thedict_bc=sc.broadcast(thedict)

Referencing it with a literal or using a UDF works fine:

df.withColumn('upper',lit(thedict_bc.value.get('c',"--"))).show()
# +---+-----+
# |key|upper|
# +---+-----+
# |  a|    C|
# |  b|    C|
# +---+-----+
df.withColumn('upper',udf(lambda x : thedict_bc.value.get(x,"--"), StringType())('key')).show()
# +---+-----+
# |key|upper|
# +---+-----+
# |  a|    A|
# |  b|    B|
# +---+-----+

However, accessing the dictionary directly in the column expression doesn't:

df.withColumn('upper',lit(thedict_bc.value.get(col('key'),"--"))).show()
# +---+-----+
# |key|upper|
# +---+-----+
# |  a|   --|
# |  b|   --|
# +---+-----+
df.withColumn('upper',lit(thedict_bc.value.get(df.key,"--"))).show()
# +---+-----+
# |key|upper|
# +---+-----+
# |  a|   --|
# |  b|   --|
# +---+-----+
df.withColumn('upper',lit(thedict_bc.value.get(df.key.cast("string"),"--"))).show()
# +---+-----+
# |key|upper|
# +---+-----+
# |  a|   --|
# |  b|   --|
# +---+-----+

Am I missing something obvious?


Solution

  • TL;DR: You're mixing up things that belong to completely different contexts: symbolic SQL expressions (lit, col, etc.) and plain Python code.

    The following line:

    thedict_bc.value.get(col('key'), "--")
    

    is executed in Python on the driver and is literally a local dictionary lookup. thedict doesn't contain the key col('key') (a Column object, not a string; no column expansion is involved), so you always get the default value.
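
    To see why, note that col('key') is only a symbolic Column expression on the driver, never an actual row value, so the plain Python dict lookup can never match it (a quick driver-side check, reusing the imports from the question):

    k = col('key')
    print(type(k))  # <class 'pyspark.sql.column.Column'> -- a symbolic expression,
                    # not one of the string keys "a"/"b"/"c", so get() falls back to "--"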

    Personally I would use a simple join:

    lookup = sc.parallelize(thedict.items()).toDF(["key", "upper"])
    df.join(lookup, ["key"], "left").na.fill("--", ["upper"]).show()
    
    +---+-----+                                                                     
    |key|upper|
    +---+-----+
    |  b|    B|
    |  a|    A|
    +---+-----+
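
    Since the lookup table here is tiny, you could also add a broadcast hint so Spark performs a broadcast hash join instead of shuffling df (a small variation on the join above, using pyspark.sql.functions.broadcast):

    from pyspark.sql.functions import broadcast

    # Ship the small lookup table to every executor; df itself is not shuffled.
    df.join(broadcast(lookup), ["key"], "left").na.fill("--", ["upper"]).show()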
    

    but a udf (as you've already established) or a literal map would work just as well:

    from pyspark.sql.functions import coalesce, col, create_map, lit
    from itertools import chain
    
    thedict_col = create_map(*chain.from_iterable(
        (lit(k), lit(v)) for k, v in thedict.items()
    ))
    
    df.withColumn('upper', coalesce(thedict_col[col("key")], lit("--"))).show()
    
    +---+-----+
    |key|upper|
    +---+-----+
    |  a|    A|
    |  b|    B|
    +---+-----+
    

    Notes:

    • Of course, if you just want to convert to upper case, use pyspark.sql.functions.upper.
    • Passing some_broadcast.value as an argument to the function defeats the broadcast entirely: the value is substituted locally on the driver and shipped with the closure, so broadcasting isn't utilized. value should be called inside the function body, so that it is evaluated in the executor context (see the sketch below).
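
    A minimal sketch of that last point, reusing thedict_bc and the imports from the question (the helper names are just for illustration; both versions return the same result, but only the second actually uses the broadcast):

    # Problematic: .value is dereferenced on the driver, so the plain dict is
    # captured in the closure and shipped with every task; the broadcast is bypassed.
    local_dict = thedict_bc.value
    shipped_udf = udf(lambda x: local_dict.get(x, "--"), StringType())

    # Better: .value is called inside the function body, so the lookup happens
    # on the executors against the broadcast copy.
    broadcast_udf = udf(lambda x: thedict_bc.value.get(x, "--"), StringType())

    df.withColumn('upper', broadcast_udf('key')).show()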