Tags: pyspark, apache-spark-sql, flatmap

Join an unfolded MapType column in Spark back with the original DataFrame


I have a DataFrame in (py)Spark where one of the columns is of type 'map'. I want to flatten or split that column into multiple columns, which should be added to the original DataFrame. I'm able to unfold the column with flatMap; however, I lose the key needed to join the new DataFrame (from the unfolded column) with the original one.

My schema is like this:

root
 |-- key: string (nullable = true)
 |-- metric: map (nullable = false)
 |    |-- key: string
 |    |-- value: float (valueContainsNull = true)

As you can see, the 'metric' column is a map field; this is the column I want to flatten. Before flattening, it looks like this:

+----+---------------------------------------------------+
|key |metric                                             |
+----+---------------------------------------------------+
|123k|Map(metric1 -> 1.3, metric2 -> 6.3, metric3 -> 7.6)|
|d23d|Map(metric1 -> 1.5, metric2 -> 2.0, metric3 -> 2.2)|
|as3d|Map(metric1 -> 2.2, metric2 -> 4.3, metric3 -> 9.0)|
+----+---------------------------------------------------+

To convert that field to columns, I do:

df2.select('metric').rdd.flatMap(lambda x: x).toDF().show()

which gives

+------------------+-----------------+-----------------+
|           metric1|          metric2|          metric3|
+------------------+-----------------+-----------------+
|1.2999999523162842|6.300000190734863|7.599999904632568|
|               1.5|              2.0|2.200000047683716|
| 2.200000047683716|4.300000190734863|              9.0|
+------------------+-----------------+-----------------+

However, I don't see the key anymore, so I don't know how to add this data back to the original DataFrame.
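
The key disappears because each row of df2.select('metric') is a Row with a single map field: iterating a Row yields its values, so flatMap(lambda x: x) emits the bare dicts, and toDF() then rebuilds the columns from the dict keys alone. A quick way to see this (the printed dict is illustrative, built from the values in the table above):

    df2.select('metric').rdd.flatMap(lambda x: x).take(1)
    # [{'metric1': 1.2999999523162842, 'metric2': 6.300000190734863, 'metric3': 7.599999904632568}]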

What I want is:

+----+-------+-------+-------+
| key|metric1|metric2|metric3|
+----+-------+-------+-------+
|123k|    1.3|    6.3|    7.6|
|d23d|    1.5|    2.0|    2.2|
|as3d|    2.2|    4.3|    9.0|
+----+-------+-------+-------+

My question thus is: how can I get df2 back to df, given that I originally don't know df and only have df2?
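
One way to get there, sketched with explode and pivot (explode on a map column yields one row per entry, and pivot turns the entry names back into columns; this assumes Spark >= 1.6 for pivot, and the 'name'/'value' aliases are illustrative):

    from pyspark.sql.functions import explode, first

    # One row per map entry: (key, name, value)
    exploded = df2.select('key', explode('metric').alias('name', 'value'))
    # Pivot the metric names back into columns, one row per key again
    flat = exploded.groupBy('key').pivot('name').agg(first('value'))
    flat.show()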

To make df2:

from pyspark.sql.types import StructType, StructField, StringType, FloatType

# Rebuild the flat DataFrame df
rdd = sc.parallelize([('123k', 1.3, 6.3, 7.6),
                      ('d23d', 1.5, 2.0, 2.2),
                      ('as3d', 2.2, 4.3, 9.0)])
schema = StructType([StructField('key', StringType(), True),
                     StructField('metric1', FloatType(), True),
                     StructField('metric2', FloatType(), True),
                     StructField('metric3', FloatType(), True)])
df = sqlContext.createDataFrame(rdd, schema)


from pyspark.sql.functions import lit, col, create_map
from itertools import chain

# Pack the metric columns back into a single map column; create_map takes an
# alternating sequence of key and value expressions
metric = create_map(list(chain(*(
    (lit(name), col(name)) for name in df.columns if "metric" in name
)))).alias("metric")

df2 = df.select("key", metric)
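
A quick sanity check that this df2 matches the schema and table shown at the top:

    df2.printSchema()
    df2.show(truncate=False)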

Solution

  • I can select a certain key from a MapType column by doing:

    df.select('maptypecolumn.key')


    In my example I did it as follows:

    # Get the metric names from the unfolded column, then add one column per name
    columns = df2.select('metric').rdd.flatMap(lambda x: x).toDF().columns
    for i in columns:
        df2 = df2.withColumn(i, df2.metric[i])
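
    For reference, the same idea should also work without the RDD round-trip:
    collect the metric names via map_keys and select each entry directly. A
    sketch, assuming Spark >= 2.3 (where map_keys was added) and a small,
    consistent set of names across rows:

    from pyspark.sql.functions import explode, map_keys

    # Collect the distinct metric names to the driver (assumes the set is small)
    names = [r['name'] for r in
             df2.select(explode(map_keys('metric')).alias('name')).distinct().collect()]

    # Select the key plus one column per metric name; df2.metric[n] reads the map entry
    result = df2.select('key', *[df2.metric[n].alias(n) for n in sorted(names)])
    result.show()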