Tags: scala, apache-spark, apache-spark-sql

Adding a new column of type Array[String] to a DataFrame based on a condition (Spark, Scala)


I have the following DataFrame:

+----+----+
|colA|colB|
+----+----+
|  A1|  B1|
|  A2|  B2|
|  A3|  B3|
+----+----+

Both colA and colB are of type String.

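I also have a Map[String, Array[String]]. I want to add a new column, colC, containing the map value that corresponds to each value in colB (the reference column). When a colB value has no entry in the map, colC should contain just that colB value wrapped in an array.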

The map I have is:

Map[String, Array[String]]("B1" -> Array("C1", "C2"), "B2" -> Array("C3", "C4"))

So I expect something like this:

+----+----+--------+
|colA|colB|    colC|
+----+----+--------+
|  A1|  B1|[C1, C2]|
|  A2|  B2|[C3, C4]|
|  A3|  B3|    [B3]|
+----+----+--------+
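The rule behind the expected output can be stated on plain Scala collections first: look colB up in the map and fall back to a one-element array holding colB itself. The sketch below is an illustration only: it models the rows as (colA, colB) tuples and does not use Spark at all, and the object name `ExpectedColC` and helper `colC` are hypothetical. It simply makes the desired semantics easy to check before translating them into DataFrame operations.

```scala
object ExpectedColC {
  // The lookup table from the question
  val map: Map[String, Array[String]] =
    Map("B1" -> Array("C1", "C2"), "B2" -> Array("C3", "C4"))

  // Rows modelled as (colA, colB) tuples, for illustration only (no Spark here)
  val rows: Seq[(String, String)] = Seq(("A1", "B1"), ("A2", "B2"), ("A3", "B3"))

  // colC = the map value for colB, or Array(colB) when colB is not a key
  def colC(colB: String): Seq[String] = map.getOrElse(colB, Array(colB)).toSeq

  val result: Seq[(String, String, Seq[String])] =
    rows.map { case (a, b) => (a, b, colC(b)) }

  def main(args: Array[String]): Unit =
    result.foreach { case (a, b, c) => println(s"$a $b ${c.mkString("[", ", ", "]")}") }
}
```

The per-row function here is exactly what the Spark code below tries, and fails, to express with `when(map.contains(...))`.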
Here is what I tried:

import spark.implicits._
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val colNames = List("colA", "colB")
val test_schema = StructType(colNames.map(fieldName => StructField(fieldName, StringType, true)))
val data = Seq(Seq("A1", "B1"), Seq("A2", "B2"), Seq("A3", "B3"))
val rowsRDD = spark.sparkContext.parallelize(data).map(Row.fromSeq)
var df = spark.createDataFrame(rowsRDD, test_schema)

val map = Map[String, Array[String]]("B1" -> Array("C1", "C2"), "B2" -> Array("C3", "C4"))

// Does not compile: map.contains and map(...) are plain Scala Map methods,
// so they expect a String key, not a Column
df = df.withColumn("colC", when(map.contains($"colB"), map($"colB")).otherwise(Array($"colB")))

I got this error:

:60: error: type mismatch;
 found   : org.apache.spark.sql.Column
 required: String

How can I add a column of a complex data type, such as a Map or an Array, to a Spark (Scala) DataFrame?
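The compile error arises because `map.contains` and `map(...)` are ordinary Scala `Map` methods: they run once on the driver while the query is being built, so they receive a `Column` object rather than each row's String. One way to apply the lookup per row is to wrap it in a UDF. This is a minimal sketch, assuming a local `SparkSession` and the sample data from the question; the object name `UdfLookup` and the helpers `withColC` and `lookupOrSelf` are hypothetical. Note that a UDF is opaque to Spark's optimizer, which is one reason the join-based answer can be preferable.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.udf

object UdfLookup {
  // Builds the question's DataFrame and adds colC through a UDF (hypothetical helper)
  def withColC(spark: SparkSession): DataFrame = {
    import spark.implicits._

    val df = Seq(("A1", "B1"), ("A2", "B2"), ("A3", "B3")).toDF("colA", "colB")
    val map = Map("B1" -> Array("C1", "C2"), "B2" -> Array("C3", "C4"))

    // The closure captures `map` and runs once per row on the executors,
    // so it sees each row's String value instead of a Column object.
    // Returning Seq[String] gives colC the Spark type array<string>.
    val lookupOrSelf = udf { (b: String) =>
      val cs: Seq[String] = map.getOrElse(b, Array(b)).toSeq
      cs
    }

    df.withColumn("colC", lookupOrSelf($"colB"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udf-lookup").getOrCreate()
    withColC(spark).show()
    spark.stop()
  }
}
```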


Solution

  • Although you can use maps directly (e.g. via Quality's map support), it is often easier (and faster) to give Spark a Dataset representing the map and join against it, e.g.:

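    val map = Map[String, Array[String]]("B1" -> Array("C1", "C2"), "B2" -> Array("C3", "C4"))
    // one row per map entry: key (String), value (array<string>)
    val mapAsDF = map.toSeq.toDF("key", "value")
    // left outer join keeps rows whose colB has no map entry (key and value are NULL)
    df = df.join(mapAsDF, $"colB" === $"key", "left_outer")
    df.show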
    

    yielding:

    +----+----+----+--------+
    |colA|colB| key|   value|
    +----+----+----+--------+
    |  A1|  B1|  B1|[C1, C2]|
    |  A2|  B2|  B2|[C3, C4]|
    |  A3|  B3|NULL|    NULL|
    +----+----+----+--------+
    

    Now you can use a simple conditional (when/otherwise) to fill in colC for the rows without a match:

    df = df.select($"colA", $"colB", when(isnull($"value"), array($"colB")).otherwise($"value").as("colC"))
    df.show
    

    yielding:

    +----+----+--------+
    |colA|colB|    colC|
    +----+----+--------+
    |  A1|  B1|[C1, C2]|
    |  A2|  B2|[C3, C4]|
    |  A3|  B3|    [B3]|
    +----+----+--------+
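The two steps above (join, then fix up colC) can also be chained into a single expression. This is a sketch under the assumption that the map is small: `broadcast` hints Spark to ship `mapAsDF` to every executor instead of shuffling `df`, and `coalesce` returns the first non-null of the joined `value` and `array($"colB")`, which is equivalent to the `when(isnull(...))` expression above. The object name `JoinLookup` and method `withColC` are hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{array, broadcast, coalesce}

object JoinLookup {
  def withColC(spark: SparkSession): DataFrame = {
    import spark.implicits._

    val df = Seq(("A1", "B1"), ("A2", "B2"), ("A3", "B3")).toDF("colA", "colB")
    val map = Map("B1" -> Array("C1", "C2"), "B2" -> Array("C3", "C4"))
    val mapAsDF = map.toSeq.toDF("key", "value")

    // Broadcast hint: assumes the lookup table fits comfortably in executor memory
    df.join(broadcast(mapAsDF), $"colB" === $"key", "left_outer")
      // First non-null wins: the mapped array, else a one-element array of colB
      .select($"colA", $"colB", coalesce($"value", array($"colB")).as("colC"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("join-lookup").getOrCreate()
    withColC(spark).show()
    spark.stop()
  }
}
```

Selecting only colA, colB and colC also drops the helper `key` and `value` columns that the join introduces.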