
Select values from MapType Column in UDF PySpark


I am trying to extract a value from a MapType column of a PySpark DataFrame inside a UDF.

Below is the PySpark dataframe:

+-----------+------------+-------------+
|CUSTOMER_ID|col_a       |col_b        |
+-----------+------------+-------------+
|    100    |{0.0 -> 1.0}| {0.2 -> 1.0}|
|    101    |{0.0 -> 1.0}| {0.2 -> 1.0}|
|    102    |{0.0 -> 1.0}| {0.2 -> 1.0}|
|    103    |{0.0 -> 1.0}| {0.2 -> 1.0}|
|    104    |{0.0 -> 1.0}| {0.2 -> 1.0}|
|    105    |{0.0 -> 1.0}| {0.2 -> 1.0}|
+-----------+------------+-------------+
df.printSchema()

# root
#  |-- CUSTOMER_ID: integer (nullable = true)
#  |-- col_a: map (nullable = true)
#  |    |-- key: float
#  |    |-- value: float (valueContainsNull = true)
#  |-- col_b: map (nullable = true)
#  |    |-- key: float
#  |    |-- value: float (valueContainsNull = true)

Below is the UDF:

from pyspark.sql import functions as F, types as T

@F.udf(T.FloatType())
def test(col):
    return col[1]

Below is the code:

df_temp = df_temp.withColumn('test', test(F.col('col_a')))

I am not getting the value from the col_a column when I pass it to the UDF. Can anyone explain this?


Solution

  • The notation col[1] extracts a value from a map-type column when:

    • col is a column expression
    • 1 is an existing key in the map

    In your case, the map has no key equal to 1 (its only key is 0.0), which is why the lookup does not return a value.

    from pyspark.sql import functions as F
    df = spark.createDataFrame([(100, {0.0: 1.0},)], ['CUSTOMER_ID', 'col_a'])
    df.show()
    # +-----------+------------+
    # |CUSTOMER_ID|       col_a|
    # +-----------+------------+
    # |        100|{0.0 -> 1.0}|
    # +-----------+------------+
    
    df = df.withColumn('col_a_0', F.col('col_a')[0])
    df = df.withColumn('col_a_1', F.col('col_a')[1])
    
    df.show()
    # +-----------+------------+-------+-------+
    # |CUSTOMER_ID|       col_a|col_a_0|col_a_1|
    # +-----------+------------+-------+-------+
    # |        100|{0.0 -> 1.0}|    1.0|   null|
    # +-----------+------------+-------+-------+
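
If the lookup really has to happen inside a UDF, keep in mind that a MapType value reaches the Python function as a plain dict, so col[1] there is an ordinary dictionary lookup and raises KeyError when the key is missing. Below is a minimal sketch of a more forgiving lookup. The names map_value and make_map_value_udf are made up for illustration and are not part of PySpark, and the choice to return None for a missing key or a null map is an assumption about what you want.

```python
def map_value(m, key, default=None):
    """Look up `key` in a map value as a Python UDF sees it (a dict).

    Returns `default` when the map itself is null (None) or the key is absent,
    instead of raising KeyError like m[key] would.
    """
    if m is None:
        return default
    return m.get(key, default)


def make_map_value_udf(key):
    """Hypothetical helper: wrap map_value as a PySpark UDF for one fixed key.

    pyspark is imported lazily so that map_value can be tried without Spark.
    """
    from pyspark.sql import functions as F, types as T
    return F.udf(lambda m: map_value(m, key), T.FloatType())


# Plain-Python check of the lookup logic, using the map from the question
print(map_value({0.0: 1.0}, 0.0))  # key present
print(map_value({0.0: 1.0}, 1))    # key absent, falls back to the default
```

With a DataFrame like the one above, this could be applied as df.withColumn('test', make_map_value_udf(0.0)(F.col('col_a'))). When the key is known up front, the plain column expression F.col('col_a')[0] shown earlier avoids the UDF altogether.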