Tags: arrays, dataframe, apache-spark, struct, apache-spark-sql

Spark - How to add a field to an array of structs


Having this schema:

root
 |-- Elems: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- Elem: integer (nullable = true)
 |    |    |-- Desc: string (nullable = true)

How can we add a new field so that the schema looks like this?

root
 |-- Elems: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- New_field: integer (nullable = true)
 |    |    |-- Elem: integer (nullable = true)
 |    |    |-- Desc: string (nullable = true)

I've already done this with a plain struct (details at the bottom of this post), but I can't get it to work with an array of structs.

Here is the code to reproduce it:

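import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spark.implicits._ // needed for .toDS()

val schema = new StructType()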
    .add("Elems", ArrayType(new StructType()
        .add("Elem", IntegerType)
        .add("Desc", StringType)
    ))

val dataDS = Seq("""
{
  "Elems": [ {"Elem":1, "Desc": "d1"}, {"Elem":2, "Desc": "d2"}, {"Elem":3, "Desc": "d3"} ]
}
""").toDS()

// json(Dataset[String]) replaces the json(RDD[String]) overload deprecated in Spark 2.2
val df = spark.read.schema(schema).json(dataDS)

df.show(false)
+---------------------------+
|Elems                      |
+---------------------------+
|[[1, d1], [2, d2], [3, d3]]|
+---------------------------+

Once we have the DataFrame, the best I have managed is to build a struct of arrays, with one array per field:

val mod_df = df.withColumn("modif_elems",
    struct(
        array(lit("")).as("New_field"),
        col("Elems.Elem"),
        col("Elems.Desc")
    ))

mod_df.show(false)
+---------------------------+-----------------------------+
|Elems                      |modif_elems                  |
+---------------------------+-----------------------------+
|[[1, d1], [2, d2], [3, d3]]|[[], [1, 2, 3], [d1, d2, d3]]|
+---------------------------+-----------------------------+


mod_df.printSchema
root
 |-- Elems: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- Elem: integer (nullable = true)
 |    |    |-- Desc: string (nullable = true)
 |-- modif_elems: struct (nullable = false)
 |    |-- New_field: array (nullable = false)
 |    |    |-- element: string (containsNull = false)
 |    |-- Elem: array (nullable = true)
 |    |    |-- element: integer (containsNull = true)
 |    |-- Desc: array (nullable = true)
 |    |    |-- element: string (containsNull = true)

No data is lost, but the result is a struct of arrays, not the array of structs I want.

Update: a partial workaround is described in PD1 below.


Bonus track: Modifying a struct (not in an array)

The code is almost the same, but since there is no array of structs this time, modifying the struct is easier:

val schema = new StructType()
    .add("Elems", new StructType()
        .add("Elem", IntegerType)
        .add("Desc", StringType)
    )


val dataDS = Seq("""
{
  "Elems": {"Elem":1, "Desc": "d1"}
}
""").toDS()

val df = spark.read.schema(schema).json(dataDS)
df.show(false)
+-------+
|Elems  |
+-------+
|[1, d1]|
+-------+

df.printSchema
root
 |-- Elems: struct (nullable = true)
 |    |-- Elem: integer (nullable = true)
 |    |-- Desc: string (nullable = true)

In this case, adding the field just means building a new struct:

val mod_df = df
    .withColumn("modif_elems", 
                struct(
                    lit("").alias("New_field"),
                    col("Elems.Elem"),
                    col("Elems.Desc")
                    )
               )

mod_df.show
+-------+-----------+
|  Elems|modif_elems|
+-------+-----------+
|[1, d1]|  [, 1, d1]|
+-------+-----------+

mod_df.printSchema
root
 |-- Elems: struct (nullable = true)
 |    |-- Elem: integer (nullable = true)
 |    |-- Desc: string (nullable = true)
 |-- modif_elems: struct (nullable = false)
 |    |-- New_field: string (nullable = false)
 |    |-- Elem: integer (nullable = true)
 |    |-- Desc: string (nullable = true)


PD1:

OK, I have used the arrays_zip Spark SQL function (new in version 2.4.0), and it's nearly what I want, but I can't see how to change the field names (as and alias don't work here):

val mod_df = df.withColumn("modif_elems", 
        arrays_zip(
            array(lit("")).as("New_field"),
            col("Elems.Elem").as("Elem"),
            col("Elems.Desc").alias("Desc")
                    )
        )

mod_df.show(false)
+---------------------------+---------------------------------+
|Elems                      |modif_elems                      |
+---------------------------+---------------------------------+
|[[1, d1], [2, d2], [3, d3]]|[[, 1, d1], [, 2, d2], [, 3, d3]]|
+---------------------------+---------------------------------+

mod_df.printSchema
root
 |-- Elems: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- Elem: integer (nullable = true)
 |    |    |-- Desc: string (nullable = true)
 |-- modif_elems: array (nullable = true)
 |    |-- element: struct (containsNull = false)
 |    |    |-- 0: string (nullable = true)
 |    |    |-- 1: integer (nullable = true)
 |    |    |-- 2: string (nullable = true)

The structs in modif_elems should contain 3 fields named New_field, Elem and Desc, not 0, 1 and 2.


Solution

  • Spark 3.1+

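    Column.withField (added in Spark 3.1) can be applied to each element with transform. Note that withField appends New_field after the existing fields, so it ends up last rather than first.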

    • Scala

      Input:

      import org.apache.spark.sql.functions._
      import spark.implicits._ // for the $"..." column syntax used below

      val df = spark.createDataFrame(Seq((1, "2")))
          .select(
              array(struct(
                  col("_1").as("Elem"),
                  col("_2").as("Desc")
              )).as("Elems")
          )
      df.printSchema()
      // root
      //  |-- Elems: array (nullable = false)
      //  |    |-- element: struct (containsNull = false)
      //  |    |    |-- Elem: integer (nullable = false)
      //  |    |    |-- Desc: string (nullable = true)
      

      Script:

      val df2 = df.withColumn(
          "Elems",
          transform(
              $"Elems",
              x => x.withField("New_field", lit(3))
          )
      )
      df2.printSchema()
      // root
      //  |-- Elems: array (nullable = false)
      //  |    |-- element: struct (containsNull = false)
      //  |    |    |-- Elem: integer (nullable = false)
      //  |    |    |-- Desc: string (nullable = true)
      //  |    |    |-- New_field: integer (nullable = false)
      
    • PySpark

      Input:

      from pyspark.sql import functions as F
      df = spark.createDataFrame([(1, "2",)]) \
          .select(
              F.array(F.struct(
                  F.col("_1").alias("Elem"),
                  F.col("_2").alias("Desc")
              )).alias("Elems")
          )
      df.printSchema()
      # root
      #  |-- Elems: array (nullable = false)
      #  |    |-- element: struct (containsNull = false)
      #  |    |    |-- Elem: long (nullable = true)
      #  |    |    |-- Desc: string (nullable = true)
      

      Script:

      df = df.withColumn(
          "Elems",
          F.transform(
              F.col("Elems"),
              lambda x: x.withField("New_field", F.lit(3))
          )
      )
      df.printSchema()
      # root
      #  |-- Elems: array (nullable = false)
      #  |    |-- element: struct (containsNull = false)
      #  |    |    |-- Elem: long (nullable = true)
      #  |    |    |-- Desc: string (nullable = true)
      #  |    |    |-- New_field: integer (nullable = false)