apache-spark apache-spark-sql apache-spark-ml apache-spark-dataset spark-csv

Spark schema from case class with correct nullability


For a custom Estimator's transformSchema method, I need to compare the schema of an input DataFrame to the schema defined in a case class. Usually this can be done as described in "Generate a Spark StructType / Schema from a case class", as outlined below. However, the wrong nullability is used:

The real schema of the DataFrame inferred by spark.read.csv().as[MySchema] might look like:

root
 |-- CUSTOMER_ID: integer (nullable = false)

And the case class:

case class MySchema(CUSTOMER_ID: Int)

To compare I use:

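import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType
val rawSchema = ScalaReflection.schemaFor[MySchema].dataType.asInstanceOf[StructType]
if (!rawSchema.equals(rawDf.schema)) { /* handle the schema mismatch */ }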

Unfortunately, the equality check always yields false, because the schema derived manually from the case class sets nullable to true (because a java.lang.Integer might actually be null):

root
 |-- CUSTOMER_ID: integer (nullable = true)

How can I specify nullable = false when creating the schema?


Solution

  • Arguably, you're mixing things that don't really belong in the same space. ML Pipelines are inherently dynamic, and introducing statically typed objects doesn't really change that.

    Moreover, the schema for a class defined as:

    case class MySchema(CUSTOMER_ID: Int)
    

    will have a non-nullable CUSTOMER_ID field, because scala.Int is not the same as java.lang.Integer:

    scala> import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
    import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
    
    scala> case class MySchema(CUSTOMER_ID: Int)
    defined class MySchema
    
    scala> schemaFor[MySchema].dataType
    res0: org.apache.spark.sql.types.DataType = StructType(StructField(CUSTOMER_ID,IntegerType,false))
    

    That being said, if you want nullable fields, use Option[Int]:

    case class MySchema(CUSTOMER_ID: Option[Int])
    

    and if you want non-nullable fields, use Int as above.
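As a quick check of the Int vs Option[Int] distinction, the sketch below derives both schemas with the same schemaFor call used in the REPL session above and inspects the field's nullable flag. The case class names WithInt and WithOption are hypothetical, chosen only for illustration, and it assumes spark-catalyst is on the classpath (for example inside spark-shell).

```scala
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import org.apache.spark.sql.types.StructType

// Hypothetical case classes: same field name, different Scala field type.
case class WithInt(CUSTOMER_ID: Int)
case class WithOption(CUSTOMER_ID: Option[Int])

// Derive a StructType for each class via reflection.
val intSchema = schemaFor[WithInt].dataType.asInstanceOf[StructType]
val optionSchema = schemaFor[WithOption].dataType.asInstanceOf[StructType]

// StructType(name) looks up a single StructField by its name.
println(intSchema("CUSTOMER_ID").nullable)    // false
println(optionSchema("CUSTOMER_ID").nullable) // true
```

Both fields map to IntegerType; only the nullable flag differs, which is exactly what the equality check in the question trips over.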

    Another problem here is that for CSV every field is nullable by definition, and this state is "inherited" by the encoded Dataset. So in practice:

    spark.read.csv(...)
    

    will always result in:

    root
     |-- CUSTOMER_ID: integer (nullable = true)
    

    and this is why you get a schema mismatch. Unfortunately, it is not possible to override the nullable flag for sources that don't enforce nullability constraints, such as CSV or JSON.
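If the check inside transformSchema only needs to validate column names and types, one possible workaround is to compare schemas after relaxing every nullable flag. The helpers relaxNullability and sameIgnoringNullability below are hypothetical (not Spark APIs); this is a minimal sketch assuming nullability can safely be ignored for your validation.

```scala
import org.apache.spark.sql.types._

// Hypothetical helper (not a Spark API): returns a copy of `dt` in which
// every struct field, array element and map value is marked nullable.
def relaxNullability(dt: DataType): DataType = dt match {
  case StructType(fields) =>
    StructType(fields.map(f => f.copy(dataType = relaxNullability(f.dataType), nullable = true)))
  case ArrayType(elementType, _) =>
    ArrayType(relaxNullability(elementType), containsNull = true)
  case MapType(keyType, valueType, _) =>
    MapType(relaxNullability(keyType), relaxNullability(valueType), valueContainsNull = true)
  case other => other
}

// Hypothetical helper: compares two schemas while ignoring all nullable flags.
def sameIgnoringNullability(a: StructType, b: StructType): Boolean =
  relaxNullability(a) == relaxNullability(b)

// Schemas mirroring the question: derived from the case class vs. read from CSV.
val fromCaseClass = StructType(Seq(StructField("CUSTOMER_ID", IntegerType, nullable = false)))
val fromCsv = StructType(Seq(StructField("CUSTOMER_ID", IntegerType, nullable = true)))

println(fromCaseClass == fromCsv)                        // false
println(sameIgnoringNullability(fromCaseClass, fromCsv)) // true
```

Recursing into ArrayType and MapType is a design choice: it also relaxes containsNull and valueContainsNull, so nested types (for example from JSON) compare equal as well.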

    If a non-nullable schema is a hard requirement, you could try:

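    import org.apache.spark.sql.types.StructType
    import spark.implicits._  // provides the Encoder required by .as[MySchema]
    
    spark.createDataFrame(
      spark.read.csv(...).rdd,  // (...) must yield typed columns, e.g. via .schema(...) or inferSchema
      schemaFor[MySchema].dataType.asInstanceOf[StructType]  // non-nullable schema from the case class
    ).as[MySchema]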
    

    This approach is valid only if you know that the data is actually null-free. Any null value will lead to a runtime exception.
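Since any null in a column declared non-nullable is a problem, one option is to verify the data before applying the strict schema. The sketch below uses a hypothetical withStrictSchema helper (not part of Spark) that counts nulls in every non-nullable column in a single aggregation and only then calls createDataFrame. It assumes the DataFrame's columns already appear in the same order and with the same types as the target schema, and it starts a local SparkSession purely for demonstration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, sum, when}
import org.apache.spark.sql.types._

// Hypothetical helper (not a Spark API): applies `schema` via createDataFrame only
// after checking that the columns it declares non-nullable contain no nulls.
// Assumes df's columns match `schema` in order and type.
def withStrictSchema(spark: SparkSession, df: DataFrame, schema: StructType): DataFrame = {
  val required = schema.fields.filterNot(_.nullable).map(_.name)
  if (required.nonEmpty) {
    // One pass over the data: count nulls in every required column.
    val counts = df.select(required.map(c => sum(when(col(c).isNull, 1).otherwise(0)).as(c)): _*).head()
    required.zipWithIndex.foreach { case (c, i) =>
      // sum over an empty DataFrame is null, which means zero nulls here.
      val n = if (counts.isNullAt(i)) 0L else counts.getLong(i)
      if (n > 0) throw new IllegalArgumentException(s"Column $c contains $n null value(s)")
    }
  }
  spark.createDataFrame(df.rdd, schema)
}

// Demonstration with a local session (assumed setup, for illustration only).
val spark = SparkSession.builder().master("local[1]").appName("strict-schema-sketch").getOrCreate()
import spark.implicits._

val strictSchema = StructType(Seq(StructField("CUSTOMER_ID", IntegerType, nullable = false)))
val clean = Seq(Option(1), Option(2)).toDF("CUSTOMER_ID") // nullable = true, but no nulls
withStrictSchema(spark, clean, strictSchema).printSchema() // CUSTOMER_ID: integer (nullable = false)
```

The extra aggregation costs one full pass over the data, which is the price of failing early with a clear message rather than later in the pipeline.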