scala, apache-spark, dataframe, apache-spark-sql, rdd

Getting an Apache Spark DataFrame into the right format


I am trying to convert some input into the format I want in a Spark DataFrame. The input I have is a Seq of the following case class, with up to 10,000,000 instances (or possibly also the JSON string before I convert it to the case class):

case class Element(paramName: String, value: Int, time: Int)

As a result, I want a DataFrame like this:

| Time | ParamA | ParamB  | ParamC | Param 10,000 |
|------|--------|---------|--------|--------------|
| 1000 | 432432 | 8768768 | null   | 75675678622  |
| 2000 | null   | null    | 734543 | null         |
| ...  | ...    | ...     | ...    | ...          |

So not every parameter has to be defined for every time slot; missing values should be filled with null. There will probably be around 10,000 parameters and around 1,000 time slots.
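
For concreteness, a tiny input like this should produce the first two rows of that table, with every other parameter column left as null (exampleElements is just a hypothetical name, and the numbers are taken from the ParamA/ParamB/ParamC cells above):

// hypothetical miniature input matching the table above
val exampleElements = Seq(
  Element("ParamA", 432432, 1000),
  Element("ParamB", 8768768, 1000),
  Element("ParamC", 734543, 2000)
)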

The way I do it right now seems very inefficient:

import org.apache.spark.rdd.RDD

case class GroupedObjects(time: Int, params: (String, Int)*)

// elements contains the Seq[Element]
val elementsRdd: RDD[Element] = sc.parallelize(elements)
val groupedRDD: RDD[GroupedObjects] = elementsRdd
  .groupBy(element => element.time)
  .map(tuple => GroupedObjects(tuple._1, tuple._2.map(element =>
    (element.paramName, element.value)).toSeq: _*))

// transforming back to JSON strings just to read them in again in the right DataFrame format
val jsonRDD: RDD[String] = groupedRDD.map { obj =>
  "{\"time\":" + obj.time + obj.params.map(tuple => 
     ",\"" + tuple._1 + "\":" + tuple._2).reduce(_ + _) + "}"
}
val df = sqlContext.read.json(jsonRDD).orderBy("time")
df.show(10)

The problem I see here is definitely the conversion back to a String, only to read it in again in the right format. I would be really glad for any help showing me how to get from the input case classes to the desired DataFrame format.
The way I am doing it right now is really slow, and I get a heap-space exception for 10,000,000 input lines.
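
(For what it is worth, I know I can get the case classes into a long-format DataFrame directly, without the JSON round trip, roughly as in the sketch below; but that still leaves me with the pivot to one column per parameter. longDF is just a placeholder name.)

import sqlContext.implicits._

// long format only: columns paramName, value, time; the pivot to wide format is still missing
val longDF = sc.parallelize(elements).toDF()
longDF.show(5)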


Solution

  • You might try to build Row objects and define the RDD schema manually, something like the following example:

    // These extra imports will be required if you don't have them already
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
    
    // elements contains the Seq[Element]
    val elementsRdd = sc.parallelize(elements)
    
    // one output column per distinct parameter name, in a stable order
    val columnNames = elementsRdd.map(_.paramName).distinct().collect().sorted
    
    val pivoted = elementsRdd.groupBy(_.time).map {
      case (time, elemsByTime) =>
        // sum the values if a parameter occurs more than once in the same time slot
        val valuesByColumnName = elemsByTime.groupBy(_.paramName).map {
          case (name, elemsByTimeAndName) => (name, elemsByTimeAndName.map(_.value).sum)
        }
        // null for every parameter that has no value in this time slot
        val allValuesForRow = columnNames.map(valuesByColumnName.getOrElse(_, null))
        (time, allValuesForRow)
    }
    
    val schema = StructType(
      StructField("Time", IntegerType) ::
        columnNames.map(columnName => StructField(columnName, IntegerType, nullable = true)).toList)
    val rowRDD = pivoted.map(p => Row.fromSeq(p._1 :: p._2.toList))
    val df = sqlContext.createDataFrame(rowRDD, schema)
    df.show(10)
    

    I tried this locally with 10,000,000 elements like this:

    val elements = (1 to 10000000).map(i => Element("Param" + (i % 1000).toString, i + 100, i % 10000))
    

    And it completes successfully in a reasonable time.
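
    As a side note, if the Spark version in use is 1.6 or later (where pivot was added to the DataFrame API), the same reshaping can also be expressed with the built-in pivot operator. The sketch below is only a rough alternative under that version assumption; longDF, wideDF and paramNames are placeholder names, and sum("value") reproduces the summing of duplicate (time, parameter) pairs done above:

    import org.apache.spark.sql.functions.sum
    import sqlContext.implicits._

    // long-format DataFrame straight from the case classes: columns paramName, value, time
    val longDF = sc.parallelize(elements).toDF()

    // passing the pivot values explicitly saves Spark an extra pass to discover them
    val paramNames = longDF.select("paramName").distinct().collect().map(_.getString(0)).sorted.toSeq

    val wideDF = longDF
      .groupBy("time")
      .pivot("paramName", paramNames)
      .agg(sum("value"))
      .orderBy("time")

    wideDF.show(10)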