I have this error in my Scala test:
StructType(StructField(a,StringType,true), StructField(b,StringType,true), StructField(c,StringType,true), StructField(d,StringType,true), StructField(e,StringType,true), StructField(f,StringType,true), StructField(NewColumn,StringType,false)) did not equal StructType(StructField(a,StringType,true), StructField(b,StringType,true), StructField(c,StringType,true), StructField(d,StringType,true), StructField(e,StringType,true), StructField(f,StringType,true), StructField(NewColumn,StringType,true))
ScalaTestFailureLocation: com.holdenkarau.spark.testing.TestSuite$class at (TestSuite.scala:13)
Expected :StructType(StructField(a,StringType,true), StructField(b,StringType,true), StructField(c,StringType,true), StructField(d,StringType,true), StructField(e,StringType,true), StructField(f,StringType,true), StructField(NewColumn,StringType,true))
Actual :StructType(StructField(a,StringType,true), StructField(b,StringType,true), StructField(c,StringType,true), StructField(d,StringType,true), StructField(e,StringType,true), StructField(f,StringType,true), StructField(NewColumn,StringType,false))
The last StructField is false when it should be true, and I don't know why. This true means that the field accepts null values.
And this is my test:
val schema1 = Array("a", "b", "c", "d", "e", "f")
val df = List(("a1", "b1", "c1", "d1", "e1", "f1"),
("a2", "b2", "c2", "d2", "e2", "f2"))
.toDF(schema1: _*)
val schema2 = Array("a", "b", "c", "d", "e", "f", "NewColumn")
val dfExpected = List(("a1", "b1", "c1", "d1", "e1", "f1", "a1_b1_c1_d1_e1_f1"),
("a2", "b2", "c2", "d2", "e2", "f2", "a2_b2_c2_d2_e2_f2")).toDF(schema2: _*)
val transformer = KeyContract("NewColumn", schema1)
val newDf = transformer(df)
newDf.columns should contain ("NewColumn")
assertDataFrameEquals(newDf, dfExpected)
And this is KeyContract:
case class KeyContract(tempColumn: String, columns: Seq[String],
unsigned: Boolean = true) extends Transformer {
override def apply(input: DataFrame): DataFrame = {
import org.apache.spark.sql.functions._
val inputModif = columns.foldLeft(input) { (tmpDf, columnName) =>
tmpDf.withColumn(columnName, when(col(columnName).isNull,
lit("")).otherwise(col(columnName)))
}
inputModif.withColumn(tempColumn, concat_ws("_", columns.map(col): _*))
}
}
Thanks in advance!!
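This happens because concat_ws never returns null (with a non-null separator it simply skips null inputs), so Spark marks the resulting field as not nullable. The columns that toDF creates from tuples of String, on the other hand, are all nullable, which is why the expected schema has true for NewColumn.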
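You can see both effects by inspecting a schema directly. The following is a minimal sketch, assuming a local SparkSession (the local[*] master, the app name and the NullabilityCheck object are arbitrary illustration choices). It builds a two-column DataFrame with toDF and then adds a concat_ws column:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, concat_ws}
import org.apache.spark.sql.types.StructType

object NullabilityCheck {
  // Builds a tiny DataFrame and returns its schema after adding a concat_ws column
  def keyedSchema(spark: SparkSession): StructType = {
    import spark.implicits._
    // String is a reference type, so toDF marks both tuple columns as nullable
    val df = Seq[(String, String)](("a1", "b1"), ("a2", null)).toDF("a", "b")
    // With a literal (non-null) separator, concat_ws skips null inputs and never
    // returns null, so Spark marks the new column as non-nullable
    df.withColumn("key", concat_ws("_", col("a"), col("b"))).schema
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("nullability-check").getOrCreate()
    try keyedSchema(spark).fields.foreach(f => println(s"${f.name}: nullable = ${f.nullable}"))
    finally spark.stop()
  }
}
```

The second row deliberately contains a null in column b. Its key value is simply "a2" rather than null, which is why Spark can safely declare the column non-nullable.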
If you want to use a second DataFrame as the reference, you'll have to build it from an explicit schema and Rows, so that you can set the nullability of each field yourself:
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

val spark: SparkSession = SparkSession.builder.getOrCreate()

// schema2 is the column-name array from the test; only NewColumn is declared non-nullable,
// matching what concat_ws produces in KeyContract
val dfExpected = spark.createDataFrame(spark.sparkContext.parallelize(List(
  Row("a1", "b1", "c1", "d1", "e1", "f1", "a1_b1_c1_d1_e1_f1"),
  Row("a2", "b2", "c2", "d2", "e2", "f2", "a2_b2_c2_d2_e2_f2")
)), StructType(schema2.map { c => StructField(c, StringType, nullable = c != "NewColumn") }))
This way the last column won't be nullable:
dfExpected.printSchema
root
|-- a: string (nullable = true)
|-- b: string (nullable = true)
|-- c: string (nullable = true)
|-- d: string (nullable = true)
|-- e: string (nullable = true)
|-- f: string (nullable = true)
|-- NewColumn: string (nullable = false)
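If you would rather keep building the expected DataFrame with toDF, another option (not part of the answer above) is to relax the nullability of the actual result before comparing. This is a minimal sketch. NullableHelpers and setNullableForAll are hypothetical names, and only top-level fields are changed:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

object NullableHelpers {
  // Hypothetical helper: rebuilds the DataFrame with every top-level field set to the
  // given nullability. Fields nested inside structs, arrays or maps are left unchanged.
  def setNullableForAll(df: DataFrame, nullable: Boolean = true): DataFrame = {
    val relaxed = StructType(df.schema.fields.map(_.copy(nullable = nullable)))
    // Re-apply the modified schema to the same rows
    df.sparkSession.createDataFrame(df.rdd, relaxed)
  }
}
```

In the test this would read assertDataFrameEquals(NullableHelpers.setNullableForAll(newDf), dfExpected). The trade-off is that the test no longer checks nullability at all. Building the expected DataFrame with an explicit schema, as shown above, keeps that check.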