Tags: scala, apache-spark, apache-spark-sql

Create new columns from values of other columns in Scala Spark


I have an input dataframe:

inputDF=

+--------------------------+-----------------------------+
| info (String)            |   chars (Seq[String])       |
+--------------------------+-----------------------------+
|weight=100,height=70      | [weight,height]             |
+--------------------------+-----------------------------+
|weight=92,skinCol=white   | [weight,skinCol]            |
+--------------------------+-----------------------------+
|hairCol=gray,skinCol=white| [hairCol,skinCol]           |
+--------------------------+-----------------------------+

How do I get the following dataframe as output? I do not know in advance which strings the chars column contains.

outputDF=

+--------------------------+-----------------------------+-------+-------+-------+-------+
| info (String)            |   chars (Seq[String])       | weight|height |skinCol|hairCol|
+--------------------------+-----------------------------+-------+-------+-------+-------+
|weight=100,height=70      | [weight,height]             |  100  | 70    | null  |null   |
+--------------------------+-----------------------------+-------+-------+-------+-------+
|weight=92,skinCol=white   | [weight,skinCol]            |  92   |null   |white  |null   |
+--------------------------+-----------------------------+-------+-------+-------+-------+
|hairCol=gray,skinCol=white| [hairCol,skinCol]           |null   |null   |white  |gray   |
+--------------------------+-----------------------------+-------+-------+-------+-------+

I would also like to save the following Seq[String] in a variable, without calling .collect() on the dataframe.

val aVariable: Seq[String] = Seq("weight", "height", "skinCol", "hairCol")

Solution

  • Create another dataframe by pivoting on the keys of the info column, then join it back using an id column:

    import org.apache.spark.sql.functions._
    import spark.implicits._
    val data = Seq(
      ("weight=100,height=70", Seq("weight", "height")),
      ("weight=92,skinCol=white", Seq("weight", "skinCol")),
      ("hairCol=gray,skinCol=white", Seq("hairCol", "skinCol"))
    )
    
    val df = spark.sparkContext.parallelize(data).toDF("info", "chars")
      .withColumn("id", monotonically_increasing_id() + 1)
    
    val pivotDf = df
      .withColumn("tmp", split(col("info"), ","))
      .withColumn("tmp", explode(col("tmp")))
      .withColumn("val1", split(col("tmp"), "=")(0))
      .withColumn("val2", split(col("tmp"), "=")(1)).select("id", "val1", "val2")
      .groupBy("id").pivot("val1").agg(first(col("val2")))
    
    df.join(pivotDf, Seq("id"), "left").drop("id").show(false)
    
    
    +--------------------------+------------------+-------+------+-------+------+
    |info                      |chars             |hairCol|height|skinCol|weight|
    +--------------------------+------------------+-------+------+-------+------+
    |weight=100,height=70      |[weight, height]  |null   |70    |null   |100   |
    |hairCol=gray,skinCol=white|[hairCol, skinCol]|gray   |null  |white  |null  |
    |weight=92,skinCol=white   |[weight, skinCol] |null   |null  |white  |92    |
    +--------------------------+------------------+-------+------+-------+------+
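
    The per-row parsing that the pivot relies on can be sketched in plain Scala, independent of Spark (`parseInfo` is a hypothetical helper name used only for illustration):

    ```scala
    // Minimal sketch of the parsing the pivot performs per row:
    // split the info string on "," into pairs, then each pair on "=".
    def parseInfo(info: String): Map[String, String] =
      info.split(",").map { pair =>
        val Array(key, value) = pair.split("=", 2)
        key -> value
      }.toMap

    // parseInfo("weight=100,height=70") == Map("weight" -> "100", "height" -> "70")
    ```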
    

    For your second question, you can get those values in a dataframe like this:

    df.withColumn("tmp", explode(split(col("info"), ",")))
      .withColumn("values", split(col("tmp"), "=")(0)).select("values").distinct().show()
    
    +-------+
    | values|
    +-------+
    | height|
    |hairCol|
    |skinCol|
    | weight|
    +-------+
    

    However, you cannot get them into a Seq variable without using collect: the distinct keys live in the distributed dataframe, so an action is needed to bring them back to the driver.
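
    If a small collect is acceptable after all (the set of keys is bounded), a sketch of materialising them, reusing the df defined above:

    ```scala
    // Sketch only: collect() is an action that brings the distinct keys
    // back to the driver, which is unavoidable for a local Seq.
    val aVariable: Seq[String] = df
      .withColumn("tmp", explode(split(col("info"), ",")))
      .select(split(col("tmp"), "=")(0).as("key"))
      .distinct()
      .collect()
      .map(_.getString(0))
      .toSeq
    ```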