Tags: scala, apache-spark, dataframe, countvectorizer

Scala Spark - split vector column into separate columns in a Spark DataFrame


I have a Spark DataFrame with a column of Vector values. The vectors are all n-dimensional, i.e. they all have the same length. I also have an array of column names Array("f1", "f2", "f3", ..., "fn"), each of which corresponds to one element of the vector. I want to go from

some_columns... | Features
      ...       | [0, 1, 0, ..., 0]

to

some_columns... | f1 | f2 | f3 | ... | fn
      ...       | 0  | 1  | 0  | ... | 0

What is the best way to achieve this? One approach I considered is to create a new DataFrame with createDataFrame(Row(Features), featureNameList) and join it with the original one, but that requires a SparkContext, and I only want to transform the existing DataFrame. I also know about .withColumn("fi", value), but what do I do when n is large?

I'm new to Scala and Spark and couldn't find any good examples of this, although I think it must be a common task. In my particular case, I used CountVectorizer and want to recover each column individually for better readability, instead of having only the vector result.


Solution

  • One way is to convert the vector column to an array<double> and then use getItem to extract the individual elements.

    import org.apache.spark.sql.functions._
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    
    val df = Seq( (1, Vectors.dense(1, 0, 1, 1, 0)) ).toDF("id", "features")
    //df: org.apache.spark.sql.DataFrame = [id: int, features: vector]
    
    df.show(false)
    //+---+---------------------+
    //|id |features             |
    //+---+---------------------+
    //|1  |[1.0,0.0,1.0,1.0,0.0]|
    //+---+---------------------+
    
    // A UDF to convert VectorUDT to ArrayType
    val vecToArray = udf( (xs: Vector) => xs.toArray )
    
    // Add an ArrayType column
    val dfArr = df.withColumn("featuresArr", vecToArray($"features"))
    
    // Names of the columns to be extracted.
    // The size of `elements` must equal the length of the vectors in
    // `features`; out-of-range indices are not checked here.
    val elements = Array("f1", "f2", "f3", "f4", "f5")
    
    // Build one select expression per element, aliased to its column name
    val sqlExpr = elements.zipWithIndex.map { case (alias, idx) => col("featuresArr").getItem(idx).as(alias) }
    
    // Extract Elements from dfArr    
    dfArr.select(sqlExpr : _*).show
    //+---+---+---+---+---+
    //| f1| f2| f3| f4| f5|
    //+---+---+---+---+---+
    //|1.0|0.0|1.0|1.0|0.0|
    //+---+---+---+---+---+
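
  • If you are on Spark 3.0 or later, the hand-rolled UDF can be replaced by the built-in vector_to_array function from org.apache.spark.ml.functions, which does the same Vector-to-array conversion natively (and benefits from Catalyst optimization). A sketch, reusing the df and elements from above:

    ```scala
    import org.apache.spark.ml.functions.vector_to_array
    import org.apache.spark.sql.functions.col

    // Same conversion as the UDF above, but built in (Spark 3.0+)
    val dfArr = df.withColumn("featuresArr", vector_to_array(col("features")))

    // The getItem/select step is unchanged
    val sqlExpr = elements.zipWithIndex.map { case (alias, idx) =>
      col("featuresArr").getItem(idx).as(alias)
    }

    // Prepend col("*") if you also want to keep the original columns
    dfArr.select(col("*") +: sqlExpr : _*).show(false)
    ```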