Tags: scala, apache-spark, dataframe, scalatest

Error in StructField(a,StringType,false). It is false and should be true


I have this error in my Scala test:

StructType(StructField(a,StringType,true), StructField(b,StringType,true), StructField(c,StringType,true), StructField(d,StringType,true), StructField(e,StringType,true), StructField(f,StringType,true), StructField(NewColumn,StringType,false)) did not equal StructType(StructField(a,StringType,true), StructField(b,StringType,true), StructField(c,StringType,true), StructField(d,StringType,true), StructField(e,StringType,true), StructField(f,StringType,true), StructField(NewColumn,StringType,true))

ScalaTestFailureLocation: com.holdenkarau.spark.testing.TestSuite$class at (TestSuite.scala:13)

Expected :StructType(StructField(a,StringType,true), StructField(b,StringType,true), StructField(c,StringType,true), StructField(d,StringType,true), StructField(e,StringType,true), StructField(f,StringType,true), StructField(NewColumn,StringType,true))

Actual   :StructType(StructField(a,StringType,true), StructField(b,StringType,true), StructField(c,StringType,true), StructField(d,StringType,true), StructField(e,StringType,true), StructField(f,StringType,true), StructField(NewColumn,StringType,false))

The last StructField is false when it should be true, and I do not know why. This true means that the schema accepts null values (nullable).

And this is my test:

import spark.implicits._ // needed for toDF

val schema1 = Array("a", "b", "c", "d", "e", "f")
val df = List(("a1", "b1", "c1", "d1", "e1", "f1"),
  ("a2", "b2", "c2", "d2", "e2", "f2"))
  .toDF(schema1: _*)

val schema2 = Array("a", "b", "c", "d", "e", "f", "NewColumn")

val dfExpected = List(("a1", "b1", "c1", "d1", "e1", "f1", "a1_b1_c1_d1_e1_f1"),
  ("a2", "b2", "c2", "d2", "e2", "f2", "a2_b2_c2_d2_e2_f2")).toDF(schema2: _*)

val transformer = KeyContract("NewColumn", schema1)
val newDf = transformer(df)
newDf.columns should contain ("NewColumn")
assertDataFrameEquals(newDf, dfExpected)

And this is KeyContract:

case class KeyContract(tempColumn: String, columns: Seq[String],
                       unsigned: Boolean = true) extends Transformer {

  override def apply(input: DataFrame): DataFrame = {
    import org.apache.spark.sql.functions._

    val inputModif = columns.foldLeft(input) { (tmpDf, columnName) =>
      tmpDf.withColumn(columnName, when(col(columnName).isNull,
        lit("")).otherwise(col(columnName)))
    }

    inputModif.withColumn(tempColumn, concat_ws("_", columns.map(col): _*))
  }
}

Thanks in advance!!


Solution

  • This happens because concat_ws never returns null, so Spark marks the resulting field as not nullable.

    If you want to use a second DataFrame as a reference, you'll have to define the schema explicitly and build the rows with Row:

    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.types._
    
    val spark: SparkSession = SparkSession.builder.getOrCreate()
    
    val dfExpected = spark.createDataFrame(spark.sparkContext.parallelize(List(
      Row("a1", "b1", "c1", "d1", "e1", "f1", "a1_b1_c1_d1_e1_f1"),
      Row("a2", "b2", "c2", "d2", "e2", "f2", "a2_b2_c2_d2_e2_f2")
    )), StructType(schema2.map { c => StructField(c, StringType, c != "NewColumn") }))
    

    This way the last column won't be nullable:

    dfExpected.printSchema
    root
     |-- a: string (nullable = true)
     |-- b: string (nullable = true)
     |-- c: string (nullable = true)
     |-- d: string (nullable = true)
     |-- e: string (nullable = true)
     |-- f: string (nullable = true)
     |-- NewColumn: string (nullable = false)
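
    To see why the flag flips, it helps to contrast concat_ws with plain concat: concat returns null if any input is null, so its output column is nullable whenever some input column is; concat_ws skips nulls and always returns a string, so its output column is never nullable. The sketch below models just that rule in plain Scala (no Spark needed); Field, concatNullable, and concatWsNullable are illustrative stand-ins, not Spark API:

    ```scala
    // Illustrative model of Spark's nullability rules for concat vs concat_ws.
    // concat propagates null, so its output is nullable if any input is;
    // concat_ws skips null inputs and always produces a string, so its
    // output is non-nullable regardless of the inputs.
    case class Field(name: String, nullable: Boolean)

    def concatNullable(inputs: Seq[Field]): Boolean = inputs.exists(_.nullable)
    def concatWsNullable(inputs: Seq[Field]): Boolean = false

    val cols = Seq(Field("a", nullable = true), Field("b", nullable = true))
    println(concatNullable(cols))   // true
    println(concatWsNullable(cols)) // false
    ```

    This is why the "NewColumn" produced by KeyContract comes out with nullable = false even though every source column is nullable.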