Tags: dataframe, scala, apache-spark, rdd

Problem creating a DataFrame from a dataset with nested sequences in Scala Spark


I am trying to create a DataFrame from a Seq that contains nested sequences, but I am getting a scala.MatchError.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val data = Seq(
  ("Java",   Seq(Seq(true, 5L), 0, 7.5)),
  ("Python", Seq(Seq(true, 10L), 1, 8.5)),
  ("Scala",  Seq(Seq(false, 8L), 2, 9.0))
)

val rdd = spark.sparkContext.parallelize(data).map {
  case (l, s) => Row(l, Row.fromSeq(s.map {
    case (s: Seq[Any], i, d) => Row(Row.fromSeq(s), i, d)
  }))
}

val schema = StructType(Seq(
  StructField("language", StringType, true),
  StructField("stats", StructType(Seq(
    StructField("users", StructType(Seq(
      StructField("active", BooleanType, true),
      StructField("level", LongType, true)
    ))),
    StructField("difficulty", IntegerType, true),
    StructField("average_review", DoubleType, true)
  )))
))

val ds = spark.createDataFrame(rdd, schema)
ds.show()

The pipeline breaks when I call ds.show(), with the error scala.MatchError: List(true, 5) (of class scala.collection.immutable.$colon$colon)

I suspect the issue is in the spark.sparkContext.parallelize(data).map { ... } part of the code, but I can't figure out what the problem is.
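
For reference, the same error can be reproduced on the inner map alone, outside Spark (a minimal sketch assuming the same stats shape as above):

val stats: Seq[Any] = Seq(Seq(true, 5L), 0, 7.5)
stats.map {
  case (s: Seq[Any], i, d) => (s, i, d)
}
// scala.MatchError: List(true, 5) (of class scala.collection.immutable.$colon$colon)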


Solution

  • The scala.MatchError comes from the inner s.map { case (s: Seq[Any], i, d) => ... }: that pattern expects a Tuple3, but stats is a Seq[Any] whose elements are a single Seq, an Int, and a Double, so the first element List(true, 5) matches no case. Match on the Seq structure instead. Try

    import org.apache.spark.rdd.RDD

    // match the nested Seq structure directly and rebuild it as nested Rows
    val rdd: RDD[Row] = spark.sparkContext.parallelize(data).map {
      case (language, Seq(Seq(active, level), difficulty, average_review)) =>
        Row(language, Row(Row(active, level), difficulty, average_review))
    }
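
    // assuming the same schema as in the question:
    spark.createDataFrame(rdd, schema).show()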
    
    //+--------+--------------------+
    //|language|               stats|
    //+--------+--------------------+
    //|    Java| {{true, 5}, 0, 7.5}|
    //|  Python|{{true, 10}, 1, 8.5}|
    //|   Scala|{{false, 8}, 2, 9.0}|
    //+--------+--------------------+ 
    

    or

    // bind the inner Seq as a whole and convert it with Row.fromSeq
    val rdd: RDD[Row] = spark.sparkContext.parallelize(data).map {
      case (language, Seq(users: Seq[Any], difficulty, average_review)) =>
        Row(language, Row(Row.fromSeq(users), difficulty, average_review))
    }
    

    or

    // generic variant: turn every nested Seq into a Row, pass other values through
    val rdd: RDD[Row] = spark.sparkContext.parallelize(data).map {
      case (language, stats) =>
        Row(language, Row.fromSeq(stats.map {
          case users: Seq[Any] => Row.fromSeq(users)
          case x => x
        }))
    }
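
    This variant does not hard-code the position of the nested Seq inside stats, so it keeps working if more nested sequences are added, as long as the schema is updated to match.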
    

    or

    import org.apache.spark.sql.Dataset
    import spark.implicits._

    // model the structure with case classes and let Spark derive the schema
    case class Data(language: String, stats: Stats)
    case class Stats(users: Users, difficulty: Int, average_review: Double)
    case class Users(active: Boolean, level: Long)

    val data1 = data.map {
      case (language, Seq(Seq(active: Boolean, level: Long), difficulty: Int, average_review: Double)) =>
        Data(language, Stats(Users(active, level), difficulty, average_review))
    }

    val ds: Dataset[Data] = data1.toDS()
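
    The nested schema is now inferred from the case classes, so the hand-written StructType is not needed here. A quick check (assuming the same SparkSession):

    ds.printSchema()
    ds.show()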