scala, apache-spark

How to read input JSON using a schema file in Scala and populate a default value when a column is not found?


Input DataFrame

    val input_json = """[{"orderid":"111","customers":{"customerId":"123"},"Offers":[{"Offerid":"1"},{"Offerid":"2"}]}]"""
    val inputdataRdd = spark.sparkContext.parallelize(input_json :: Nil)
    val inputdataRdddf = spark.read.json(inputdataRdd)
    inputdataRdddf.show()
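
For reference, spark.read.json infers every leaf value as a string here, so the input DataFrame's structure should look roughly like this (a sketch; exact nullability flags may vary):

    inputdataRdddf.printSchema()
    // root
    //  |-- Offers: array (nullable = true)
    //  |    |-- element: struct (containsNull = true)
    //  |    |    |-- Offerid: string (nullable = true)
    //  |-- customers: struct (nullable = true)
    //  |    |-- customerId: string (nullable = true)
    //  |-- orderid: string (nullable = true)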

Schema DataFrame

    val schema_json = """[{"orders":{"order_id":{"path":"orderid","type":"string","nullable":false},"customer_id":{"path":"customers.customerId","type":"int","nullable":false,"default_value":"null"},"offer_id":{"path":"Offers.Offerid","type":"string","nullable":false},"eligible":{"path":"eligible.eligiblestatus","type":"string","nullable":true,"default_value":"not eligible"}},"products":{"product_id":{"path":"product_id","type":"string","nullable":false},"product_name":{"path":"products.productname","type":"string","nullable":false}}}]"""
    val schemaRdd = spark.sparkContext.parallelize(schema_json :: Nil)
    val schemaRdddf = spark.read.json(schemaRdd)
    schemaRdddf.show()


Using the schema DataFrame, I want to read all the columns from the input DataFrame:

  1. If the nullable key is true, I want to populate the column with the default value (when the column is not present or has no data). In the example above, eligible.eligiblestatus is not present, so it should be populated with a default value.
  2. I also want to change the data type of each column based on the type key defined in the schema JSON, e.g. customer_id is of type int in the schema JSON but arrives as a string in the input DataFrame, so it should be cast to integer.
  3. The final column name should be taken from the key in the schema JSON, e.g. order_id is the key for the orderid attribute.

Final DF should have columns like:

order_id: string, customer_id: int, offer_id: string (array type cast to string), eligible: string
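
For concreteness, the hand-written equivalent of what should be generated from the schema JSON would be something along these lines (a sketch only; the point of the question is to build these expressions automatically, and the 'not eligible' literal is the default from the schema):

    inputdataRdddf.selectExpr(
      "CAST(orderid AS string) AS order_id",
      "CAST(customers.customerId AS int) AS customer_id",
      "CAST(Offers.Offerid AS string) AS offer_id",   // array of Offerid values, cast to string
      "CAST('not eligible' AS string) AS eligible"    // column missing in the input, so fall back to the default
    ).show(false)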



Solution

  • Please find the code below.

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types.{ArrayType, StructType}
    import spark.implicits._ // needed for toDS and .as[Seq[String]]
    
    val input =
      """[{"orderid":"111","customers":{"customerId":"123"},"Offers":[{"Offerid":"1"},{"Offerid":"2"}]}]"""
    val schema =
      """[{"orders":{"order_id":{"path":"orderid","type":"string","nullable":false},"customer_id":{"path":"customers.customerId","type":"int","nullable":false,"default_value":"null"},"offer_id":{"path":"Offers.Offerid","type":"string","nullable":false},"eligible":{"path":"eligible.eligiblestatus","type":"string","nullable":true,"default_value":"not eligible"}},"products":{"product_id":{"path":"product_id","type":"string","nullable":false},"product_name":{"path":"products.productname","type":"string","nullable":false}}}]"""
    
    val jsonSchema = """
          map<
            string, 
            struct<
                path:string,
                type:string,
                nullable:string,
                default_value:string  
            >
          >
        """
    
    val inputDF = spark.read.json(Seq(input).toDS)
    // Uses the DFHelpers implicit class defined below, so it must already be in scope.
    val inputFieldMap = typedLit(inputDF.fields.map(f => f -> f).toMap)
    val schemaColumns = Seq("orders") // Top-level schema sections to process; this could also be derived instead of hard-coded.
    val schemaDF = spark.read.json(Seq(schema).toDS).selectExpr(schemaColumns: _*)
    

Helper class to extract nested column paths from a DataFrame; it needs to be in scope before inputDF.fields above is evaluated.

    implicit class DFHelpers(df: DataFrame) {
      // All leaf column paths of the DataFrame, dot-separated (e.g. "customers.customerId").
      def fields: Seq[String] =
        this.fields(df.schema)
      def fields(
          schema: StructType = df.schema,
          root: String = "",
          sep: String = "."
      ): Seq[String] = {
        schema.fields.flatMap { column =>
          column match {
            // Nested struct: recurse, appending the current column name to the prefix.
            case _ if column.dataType.isInstanceOf[StructType] =>
              fields(
                column.dataType.asInstanceOf[StructType],
                s"${root}${sep}${column.name}".stripPrefix(sep)
              )
            // Array of structs: recurse into the element type.
            case _ if column.dataType.isInstanceOf[ArrayType] =>
              column.dataType
                .asInstanceOf[ArrayType]
                .productIterator
                .filter(_.isInstanceOf[StructType])
                .map(_.asInstanceOf[StructType])
                .flatMap(f => fields(f, s"${root}${sep}${column.name}".stripPrefix(sep)))
            // Leaf column: emit its full dotted path.
            case _ => Seq(s"${root}${sep}${column.name}".stripPrefix(sep))
          }
        }.toList
      }
    }
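
For the sample input above, the helper flattens nested structs and arrays of structs into dotted paths, so inputDF.fields should return something roughly like this (ordering follows the alphabetically sorted schema that spark.read.json infers):

    inputDF.fields
    // List(Offers.Offerid, customers.customerId, orderid)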
    
    // For every entry in the schema map: if the input actually contains the column named in
    // e.value.path, emit a CAST expression (wrapped in NVL with the default when nullable is true);
    // if the column is missing but a default_value is defined, emit a literal default; otherwise
    // emit NULL, which the outer filter() drops so that the column is removed.
    val schemaExprs = schemaDF.columns.map { columnName =>
      s"""
        filter(
           transform(
              map_entries(${columnName}),
              e ->
                 CASE
                    WHEN fields[e.value.path] IS NOT NULL THEN
                       IF(
                          e.value.nullable == 'true',
                          CONCAT("CAST( NVL(",e.value.path,",'",e.value.default_value,"') AS ",e.value.type," )"," AS ",e.key),
                          CONCAT("CAST( ",e.value.path," AS ",e.value.type," )"," AS ",e.key)
                       )
                    WHEN e.value.default_value IS NOT NULL THEN
                       CONCAT("CAST( '",e.value.default_value,"' AS ",e.value.type," )"," AS ",e.key)
                 END
           ),
           f -> f IS NOT NULL
        ) AS ${columnName}
      """
    }
    
    val columns = schemaDF
      .selectExpr(
        schemaDF.columns.map(c =>
          s"from_json(to_json(${c}), '${jsonSchema}') AS ${c}"
        ): _*
      )
      // Attach the map of input column paths so that the expressions above can check whether each
      // schema path exists in inputDF. Missing paths fall back to their default value; entries with
      // no default are dropped entirely.
      .withColumn("fields", inputFieldMap)
      .selectExpr(schemaExprs: _*)
      .select(schemaDF.columns.map(col(_)).reduce(array_union).as("columns"))
      .as[Seq[String]]
      .collect()
      .flatten
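
With the sample data, the collected select expressions should come out roughly like this (shown for illustration; exact spacing may differ):

    columns.foreach(println)
    // CAST( customers.customerId AS int ) AS customer_id
    // CAST( 'not eligible' AS string ) AS eligible
    // CAST( Offers.Offerid AS string ) AS offer_id
    // CAST( orderid AS string ) AS order_id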
    
    inputDF.selectExpr(columns: _*).show(false)
    
    +-----------+------------+--------+--------+
    |customer_id|eligible    |offer_id|order_id|
    +-----------+------------+--------+--------+
    |123        |not eligible|[1, 2]  |111     |
    +-----------+------------+--------+--------+
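
If the products section should be handled as well, schemaColumns does not have to be hard-coded; one option (a sketch, simply taking the top-level keys of the parsed schema JSON) is:

    val schemaColumns = spark.read.json(Seq(schema).toDS).columns.toSeq // Seq("orders", "products")

With the sample input, the products paths do not exist and have no default values, so those columns would simply be dropped by the filter step.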