dataframe, scala, apache-spark, pyspark, apache-spark-sql

Add columns in complex data type nested data frame


I have a DataFrame that contains both simple types and complex structures such as struct, array of struct, and array of array of struct. I also have an expected schema (a StructType at the root). I need to mold the DataFrame to match that schema: any columns that appear in the expected schema but are missing from the DataFrame should be added to the DF with a default value.

For Example:

Expected Schema :-

root
 - struct
   - a: String
   - b: Int
   - c: array of struct
     - e: String
     - f: String

DataFrame Schema :-

root
 - struct
   - a: String
   - b: Int
   - c: array of String
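
For clarity, the nested struct in the expected schema above would roughly correspond to a StructType like the one below (just a sketch for illustration; I'm assuming String/Int map to StringType/IntegerType):

import org.apache.spark.sql.types._

// Sketch of the struct from the expected schema tree
val expectedStruct = StructType(Seq(
  StructField("a", StringType),
  StructField("b", IntegerType),
  StructField("c", ArrayType(StructType(Seq(   // array of struct
    StructField("e", StringType),
    StructField("f", StringType)
  ))))
))

// The same struct as it currently exists in the DataFrame, where `c` is only an array of strings
val existingStruct = StructType(Seq(
  StructField("a", StringType),
  StructField("b", IntegerType),
  StructField("c", ArrayType(StringType))
))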

How can I do this recursively at each level of nesting using Spark?

Here is what I am doing in the code.

// Function to add missing columns recursively
def addMissingColumns(df1: DataFrame, df2: DataFrame, currentPath: String = ""): DataFrame = {
    val df1Schema = df1.schema
    val df2Schema = df2.schema

    val missingColumns = df1Schema.fields.filterNot { field1 =>
        df2Schema.fields.exists { field2 =>
            field1.name == field2.name && field1.dataType == field2.dataType
        }
    }

    val dfWithMissingColumns = missingColumns.foldLeft(df2) { (accDF, field) =>
        val colName = currentPath + field.name
        val dataType = field.dataType

        dataType match {
            case _: StructType =>
                val updatedDF = if (accDF.columns.contains(colName)) {
                    val missingCols = addMissingColumns(df1.selectExpr(s"$colName.*"), accDF.selectExpr(s"$colName.*"), colName + ".")
                    missingCols
                } else {
                    accDF.withColumn(colName, lit(null).cast(dataType)) // You can cast to the appropriate data type
                }
                updatedDF
            case _ =>
                // Handle non-struct types by adding them with null values if missing
                val updatedDF = if (accDF.columns.contains(colName)) {
                    accDF
                } else {
                    accDF.withColumn(field.name, lit(null).cast(dataType))
                }
                updatedDF
        }
    }
    dfWithMissingColumns.printSchema()
    dfWithMissingColumns.show()
    dfWithMissingColumns
}


def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
        .appName("YourAppName")
        .master("local[1]") // Use all available CPU cores
        .getOrCreate()

    // Sample DataFrames
    val df1 = spark.createDataFrame(Seq(
        (1, "Alice", Array(1, 2), (10, "New York", ("dddd","ggg"))),
        (2, "Bob", Array(3, 4), (20, "San Francisco", ("ddd","ggg")))
    )).toDF("id", "name", "numbers", "location")
    df1.printSchema()

    val df2 = spark.createDataFrame(Seq(
        (1, "Alice", Array(1, 2), (10, "New York 2")),
        (2, "Bob", Array(3, 4), (20, "San Francisco 2")),
        (3, "Ankur", Array(3, 4), (20, "Netherlands 2"))
    )).toDF("id", "name", "numbers", "location")

    df2.printSchema()

    addMissingColumns(df1, df2, "").printSchema()
}

The issue with the code is that it returns only the location column. Here is the output schema:

 root
 |-- _1: integer (nullable = true)
 |-- _2: string (nullable = true)
 |-- location._3: struct (nullable = true)
 |    |-- _1: string (nullable = true)
 |    |-- _2: string (nullable = true)

Solution

  • I think you can convert df2 to JSON using the to_json function and then convert it back to columns by passing df1's schema to from_json. With this approach, missing columns are added with default (null) values automatically.

    Logic - df2.select(from_json(to_json(struct($"*")), df1.schema).as("data")).select($"data.*")

    Please check below; it is also quite a bit faster than the recursive logic. :)

    scala> :paste
    // Entering paste mode (ctrl-D to finish)
    
        val df1 = spark.createDataFrame(Seq(
            (1, "Alice", Array(1, 2), (10, "New York", ("dddd","ggg"))),
            (2, "Bob", Array(3, 4), (20, "San Francisco", ("ddd","ggg")))
        )).toDF("id", "name", "numbers", "location")
    
    
    scala> df1.show(false)
    +---+-----+-------+-------------------------------+
    |id |name |numbers|location                       |
    +---+-----+-------+-------------------------------+
    |1  |Alice|[1, 2] |{10, New York, {dddd, ggg}}    |
    |2  |Bob  |[3, 4] |{20, San Francisco, {ddd, ggg}}|
    +---+-----+-------+-------------------------------+
    
    scala> df1.printSchema
    root
     |-- id: integer (nullable = false)
     |-- name: string (nullable = true)
     |-- numbers: array (nullable = true)
     |    |-- element: integer (containsNull = false)
     |-- location: struct (nullable = true)
     |    |-- _1: integer (nullable = false)
     |    |-- _2: string (nullable = true)
     |    |-- _3: struct (nullable = true)
     |    |    |-- _1: string (nullable = true)
     |    |    |-- _2: string (nullable = true)
    
    
    scala> :paste
    // Entering paste mode (ctrl-D to finish)
    
    val df2 = spark.createDataFrame(Seq(
            (1, "Alice", Array(1, 2), (10, "New York 2")),
            (2, "Bob", Array(3, 4), (20, "San Francisco 2")),
            (3, "Ankur", Array(3, 4), (20, "Netherlands 2"))
        )).toDF("id", "name", "numbers", "location")
    
    
    scala> df2.printSchema
    root
     |-- id: integer (nullable = false)
     |-- name: string (nullable = true)
     |-- numbers: array (nullable = true)
     |    |-- element: integer (containsNull = false)
     |-- location: struct (nullable = true)
     |    |-- _1: integer (nullable = false)
     |    |-- _2: string (nullable = true)
    
    
    scala> df2.show(false)
    +---+-----+-------+---------------------+
    |id |name |numbers|location             |
    +---+-----+-------+---------------------+
    |1  |Alice|[1, 2] |{10, New York 2}     |
    |2  |Bob  |[3, 4] |{20, San Francisco 2}|
    |3  |Ankur|[3, 4] |{20, Netherlands 2}  |
    +---+-----+-------+---------------------+
    
    scala> val updatedDF = df2.select(
      from_json(
        to_json(
          struct($"*")
        ), 
        df1.schema
      )
      .as("data")
    )
    .select($"data.*")
    
    scala> updatedDF.printSchema
    root
     |-- id: integer (nullable = true)
     |-- name: string (nullable = true)
     |-- numbers: array (nullable = true)
     |    |-- element: integer (containsNull = true)
     |-- location: struct (nullable = true)
     |    |-- _1: integer (nullable = true)
     |    |-- _2: string (nullable = true)
     |    |-- _3: struct (nullable = true)
     |    |    |-- _1: string (nullable = true)
     |    |    |-- _2: string (nullable = true)
    
    scala> updatedDF.show(false)
    +---+-----+-------+---------------------------+
    |id |name |numbers|location                   |
    +---+-----+-------+---------------------------+
    |1  |Alice|[1, 2] |{10, New York 2, null}     |
    |2  |Bob  |[3, 4] |{20, San Francisco 2, null}|
    |3  |Ankur|[3, 4] |{20, Netherlands 2, null}  |
    +---+-----+-------+---------------------------+
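
    One thing to note from the printSchema output above is that every field comes back as nullable after the round trip through JSON. If you need to apply this to several DataFrames, the same logic can be wrapped in a small helper that takes the target schema directly (a minimal sketch; conformToSchema is just a hypothetical name, not a Spark API):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, from_json, struct, to_json}
    import org.apache.spark.sql.types.StructType

    // Serialize each row to JSON, then re-parse it with the target schema so that
    // fields missing at any nesting level come back as null.
    def conformToSchema(df: DataFrame, targetSchema: StructType): DataFrame =
      df.select(from_json(to_json(struct(col("*"))), targetSchema).as("data"))
        .select("data.*")

    // Usage with the DataFrames above:
    // val updatedDF = conformToSchema(df2, df1.schema)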