Tags: sql, json, scala, apache-spark, apache-spark-sql

Programmatically adding several columns to Spark DataFrame


I'm using Spark with Scala.

I have a DataFrame with 3 columns: ID, Time, RawHexData. I have a user-defined function which takes RawHexData and expands it into X more columns. It is important to state that X is the same for every row (the columns do not vary). However, before I receive the first data, I do not know what the columns are. But once I have the first row, I can deduce them.
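
For concreteness, the expansion function has roughly this shape (the name and exact signature here are only illustrative):

    // illustrative only: expands one raw hex blob into the same set of named values for every row
    def expandRawHexData(raw: String): Map[String, Any] = ???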

I would like a second DataFrame with these columns: ID, Time, RawHexData, NewCol1, ..., NewColX.

The "easiest" method I can think of to do this is: 1. serialize each row into JSON (every data type is serializable here), 2. add my new columns, 3. deserialize a new DataFrame from the altered JSON.
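
Sketched out, that round trip would look something like this (a rough sketch against the Spark 1.x SQLContext API; the spliced-in values are dummies and the real hex parsing is omitted):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    object JsonRoundTrip {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("JsonRoundTrip").setMaster("local[4]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // stand-in for the real (ID, Time, RawHexData) DataFrame
        val dataFrame = sc.parallelize(Seq((1L, "10:00:00", "0a0b"), (2L, "10:00:01", "0c0d")))
          .toDF("ID", "Time", "RawHexData")

        // 1. serialize each row to a JSON string
        val jsonRdd = dataFrame.toJSON

        // 2. splice the new columns into each JSON object
        //    (naive string surgery with dummy values; the real code would parse RawHexData)
        val extended = jsonRdd.map(json => json.stripSuffix("}") + ""","NewCol1":1,"NewCol2":2}""")

        // 3. let Spark re-infer a schema from the altered JSON
        val newDataFrame = sqlContext.read.json(extended)
        newDataFrame.show()
      }
    }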

However, that seems wasteful, as it involves two costly and redundant JSON serialization steps. I am looking for a cleaner pattern.

Using case classes seems like a bad idea, because I don't know the number of columns or the column names in advance.


Solution

  • What you can do to dynamically extend your DataFrame is to operate on the row RDD, which you can obtain by calling dataFrame.rdd. Having a Row instance, you can access the RawHexdata column and parse the contained data. By adding the newly parsed columns to the resulting Row, you've almost solved your problem. The only thing necessary to convert an RDD[Row] back into a DataFrame is to generate the schema data for your new columns. You can do this by collecting a single RawHexdata value on the driver and then extracting the column types.

    The following code illustrates this approach.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    object App {

      case class Person(name: String, age: Int)

      def main(args: Array[String]): Unit = {
        val sparkConf = new SparkConf().setAppName("Test").setMaster("local[4]")
        val sc = new SparkContext(sparkConf)
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val input = sc.parallelize(Seq(Person("a", 1), Person("b", 2)))
        val dataFrame = input.toDF()

        dataFrame.show()

        // create the extended rows RDD: replace the last column (age)
        // with three derived integer columns
        val rowRDD = dataFrame.rdd.map { row =>
          val blob = row(1).asInstanceOf[Int]
          val newColumns: Seq[Any] = Seq(blob, blob * 2, blob * 3)
          Row.fromSeq(row.toSeq.init ++ newColumns)
        }

        val schema = dataFrame.schema

        // we know that the new columns are all integers
        val newColumns = StructType(Seq(
          StructField("1", IntegerType),
          StructField("2", IntegerType),
          StructField("3", IntegerType)
        ))

        // keep the original columns except the one that was expanded,
        // then append the new fields
        val newSchema = StructType(schema.init ++ newColumns)

        val newDataFrame = sqlContext.createDataFrame(rowRDD, newSchema)

        newDataFrame.show()
      }
    }
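
    The example above hard-codes the new columns as integers. To derive them from the data instead, as described above, you can collect a single RawHexdata value on the driver and inspect what your parser returns. A rough sketch against the original (ID, Time, RawHexdata) DataFrame, where parseRawHexData stands in for your own parser:

    import org.apache.spark.sql.types._

    // placeholder for the user-defined parser that expands one hex blob into named values
    def parseRawHexData(raw: String): Seq[(String, Any)] = ???

    // collect one RawHexdata value on the driver and derive the new fields' types from it
    val sample = dataFrame.select("RawHexdata").first().getString(0)
    val newFields = parseRawHexData(sample).map { case (name, value) =>
      val dataType = value match {
        case _: Int    => IntegerType
        case _: Long   => LongType
        case _: Double => DoubleType
        case _         => StringType
      }
      StructField(name, dataType)
    }

    // the original columns stay in place, the derived ones are appended
    val newSchema = StructType(dataFrame.schema ++ newFields)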