scala, apache-spark, apache-spark-sql, apache-spark-ml

Spark, Scala, DataFrame: create feature vectors


I have a DataFrame that looks like the following:

userID, category, frequency
1,cat1,1
1,cat2,3
1,cat9,5
2,cat4,6
2,cat9,2
2,cat10,1
3,cat1,5
3,cat7,16
3,cat8,2

The number of distinct categories is 10, and I would like to create a feature vector for each userID and fill the missing categories with zeros.

So the output would be something like:

userID,feature
1,[1,3,0,0,0,0,0,0,5,0]
2,[0,0,0,6,0,0,0,0,2,1]
3,[5,0,0,0,0,0,16,2,0,0]

This is just an illustrative example; in reality I have about 200,000 unique userIDs and 300 unique categories.

What is the most efficient way to create the features DataFrame?


Solution

  • Suppose:

    val cs: SparkContext   // SparkContext, used below to broadcast the category names
    val sc: SQLContext     // SQLContext, used below for its toDF implicits
    val cats: DataFrame    // the input DataFrame with columns userId, category, frequency
    

    Here userId and frequency are bigint columns, which correspond to scala.Long.
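    For an end-to-end test, cats can also be built from the sample data directly. A minimal sketch, assuming the SQLContext sc declared above (in practice cats would come from your real data source):

    import sc.implicits._

    // hypothetical construction of the question's sample rows
    val cats = Seq(
      (1L, "cat1", 1L), (1L, "cat2", 3L), (1L, "cat9", 5L),
      (2L, "cat4", 6L), (2L, "cat9", 2L), (2L, "cat10", 1L),
      (3L, "cat1", 5L), (3L, "cat7", 16L), (3L, "cat8", 2L)
    ).toDF("userId", "category", "frequency")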

    First we create an intermediate mapping RDD that groups each user's rows into a category-to-frequency map:

    val catMaps = cats.rdd
      .groupBy(_.getAs[Long]("userId"))
      .map { case (id, rows) => id -> rows
        .map { row => row.getAs[String]("category") -> row.getAs[Long]("frequency") }
        .toMap
      }
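    For the sample data, catMaps holds one category-to-frequency map per user. A sketch of its contents (RDD element order is not guaranteed):

    catMaps.collect().foreach(println)
    // (1,Map(cat1 -> 1, cat2 -> 3, cat9 -> 5))
    // (2,Map(cat4 -> 6, cat9 -> 2, cat10 -> 1))
    // (3,Map(cat1 -> 5, cat7 -> 16, cat8 -> 2))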
    

    Then we collect all categories present in the data, in lexicographic order:

    val catNames = cs.broadcast(catMaps.map(_._2.keySet).reduce(_ union _).toArray.sorted)
    

    Or we create it manually, which gives the numeric cat1..cat10 order used in the question's expected output:

    val catNames = cs.broadcast((1 to 10).map(n => s"cat$n").toArray)
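    Note that the two variants order the features differently: lexicographic sorting puts cat10 before cat2. A quick sketch to compare the orderings:

    println((1 to 10).map(n => s"cat$n").sorted.mkString(", "))
    // cat1, cat10, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9
    println((1 to 10).map(n => s"cat$n").mkString(", "))
    // cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10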
    

    Finally, we transform the maps into arrays, filling in zeros for the categories a user does not have:

    import sc.implicits._
    val catArrays = catMaps
      .map { case (id, catMap) => id -> catNames.value.map(catMap.getOrElse(_, 0L)) }
      .toDF("userId", "feature")
    

    Now catArrays.show() prints something like:

    +------+--------------------+
    |userId|             feature|
    +------+--------------------+
    |     2|[0, 1, 0, 6, 0, 0...|
    |     1|[1, 0, 3, 0, 0, 0...|
    |     3|[5, 0, 0, 0, 16, ...|
    +------+--------------------+
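    If you need actual MLlib feature vectors rather than plain arrays (the question is tagged apache-spark-ml), the same mapping can emit dense vectors instead. A sketch using org.apache.spark.mllib.linalg.Vectors; the name catVectors is mine:

    import org.apache.spark.mllib.linalg.Vectors

    // Vectors.dense expects Array[Double], hence the .toDouble
    val catVectors = catMaps
      .map { case (id, catMap) =>
        id -> Vectors.dense(catNames.value.map(catMap.getOrElse(_, 0L).toDouble))
      }
      .toDF("userId", "feature")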
    

    This may not be the most elegant solution for DataFrames, as I am only barely familiar with this area of Spark.

    Note that you can create catNames manually to include zero columns for categories that never occur in the data (cat3, cat5, ...).

    Also note that the catMaps RDD is otherwise computed twice (once to collect catNames and once to build catArrays), so you might want to .persist() it.
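    For example (a sketch; MEMORY_AND_DISK is one reasonable storage level, not a prescription):

    import org.apache.spark.storage.StorageLevel

    // cache catMaps right after defining it, before either of its two uses
    catMaps.persist(StorageLevel.MEMORY_AND_DISK)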