apache-spark apache-spark-sql apache-spark-mllib apache-spark-ml

Spark, DataFrame: apply transformer/estimator on groups


I have a DataFrame that looks like the following:

+-----------+-----+------------+
|     userID|group|    features|
+-----------+-----+------------+
|12462563356|    1|  [5.0,43.0]|
|12462563701|    2|   [1.0,8.0]|
|12462563701|    1|  [2.0,12.0]|
|12462564356|    1|   [1.0,1.0]|
|12462565487|    3|   [2.0,3.0]|
|12462565698|    2|   [1.0,1.0]|
|12462565698|    1|   [1.0,1.0]|
|12462566081|    2|   [1.0,2.0]|
|12462566081|    1|  [1.0,15.0]|
|12462566225|    2|   [1.0,1.0]|
|12462566225|    1|  [9.0,85.0]|
|12462566526|    2|   [1.0,1.0]|
|12462566526|    1|  [3.0,79.0]|
|12462567006|    2| [11.0,15.0]|
|12462567006|    1| [10.0,15.0]|
|12462567006|    3| [10.0,15.0]|
|12462586595|    2|  [2.0,42.0]|
|12462586595|    3|  [2.0,16.0]|
|12462589343|    3|   [1.0,1.0]|
+-----------+-----+------------+

The column types are: userID: Long, group: Int, and features: Vector.

This is already a grouped DataFrame, i.e. a userID appears in a particular group at most once.

My goal is to scale the features column per group.

Is there a way to apply a feature transformer (in my case, a StandardScaler) per group instead of to the full DataFrame?

P.S. Using ML is not mandatory, so a solution based on MLlib is fine too.


Solution

  • Compute statistics

    Spark >= 3.0

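    Summarizer now supports standard deviations, so you can request them directly:

    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.ml.stat.Summarizer

    // Map from group to (mean, std). Note the tuple order: the transformation
    // code further down destructures (stdev, mean), so swap the names there.
    val summary = data
      .groupBy($"group")
      .agg(Summarizer.metrics("mean", "std").summary($"features").alias("stats"))
      .as[(Int, (Vector, Vector))]
      .collect.toMap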
    

    Spark >= 2.3

    In Spark 2.3 or later you could also use Summarizer:

    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.ml.stat.Summarizer
    
    // Map from group to (mean, variance)
    val summaryVar = data
      .groupBy($"group")
      .agg(Summarizer.metrics("mean", "variance").summary($"features").alias("stats"))
      .as[(Int, (Vector, Vector))]
      .collect.toMap
    

    and adjust the downstream code to take the square root of the variances wherever it expects standard deviations.
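
    One way to make that adjustment is to convert the collected variances to standard deviations once, up front. The following is only a minimal sketch: `withStd` is a hypothetical helper name, not part of Spark, and it assumes the map has the (mean, variance) shape produced by the Summarizer snippet above.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Hypothetical helper (not a Spark API): replaces each group's variance
// vector with its element-wise square root, i.e. the standard deviation.
// Assumes the input maps group -> (mean, variance).
def withStd(stats: Map[Int, (Vector, Vector)]): Map[Int, (Vector, Vector)] =
  stats.map { case (group, (mean, variance)) =>
    (group, (mean, Vectors.dense(variance.toArray.map(math.sqrt))))
  }
```

    With the map from the snippet above, `withStd(summaryVar)` would give a map from group to (mean, std), which has the same shape as the result of the Spark >= 3.0 version.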

    Spark < 2.0 (also Spark < 2.3, with adjustments for conversions between ml and mllib Vectors)

    You can compute statistics by group using almost the same code as the default StandardScaler:

    import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.sql.Row
    
    // Compute per-group multivariate statistics
    val summary = data.select($"group", $"features")
        .rdd
        .map {
             case Row(group: Int, features: Vector) => (group, features)
        }
        .aggregateByKey(new MultivariateOnlineSummarizer)( // zero value: an empty summarizer
             (agg, v) => agg.add(v),            // seqOp: add one sample vector to the summarizer
             (agg1, agg2) => agg1.merge(agg2))  // combOp: merge two partial summarizers
        .mapValues {
          s => (
             s.variance.toArray.map(math.sqrt(_)), // standard deviation: square root of each variance
             s.mean.toArray                        // mean of each feature
          )
        }.collectAsMap
    

  • Transformation

    If the expected number of groups is relatively low, you can broadcast the statistics:

    val summaryBd = sc.broadcast(summary)
    

    and transform your data:

    val scaledRows = data.rdd.map { case Row(userID, group: Int, features: Vector) =>
      // Look up the (stdev, mean) pair computed for this row's group
      val (stdev, mean) = summaryBd.value(group)
      val vs = features.toArray.clone()
      for (i <- 0 until vs.size) {
        // Standardize; a feature with zero deviation is mapped to 0.0
        vs(i) = if (stdev(i) == 0.0) 0.0 else (vs(i) - mean(i)) * (1 / stdev(i))
      }
      Row(userID, group, Vectors.dense(vs))
    }
    val scaledDf = sqlContext.createDataFrame(scaledRows, data.schema)
    

    Otherwise, you can simply join the statistics back onto the data by group. It shouldn't be hard to wrap this as an ML Transformer with the group column as a Param.
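
    The join variant could look roughly like the sketch below. It is an illustration under assumptions, not a definitive implementation: it assumes Spark >= 3.0 (for the "std" metric), ml Vectors, and the column names from the question (group, features). `scalePerGroup` and `standardize` are hypothetical names introduced here.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical helper: standardizes "features" within each "group" by
// joining per-group statistics instead of collecting and broadcasting them.
def scalePerGroup(data: DataFrame): DataFrame = {
  // Per-group statistics, kept distributed as a DataFrame with a struct column
  val stats = data
    .groupBy(col("group"))
    .agg(Summarizer.metrics("mean", "std").summary(col("features")).alias("stats"))

  // Element-wise (x - mean) / std; a feature with zero deviation maps to 0.0,
  // mirroring the broadcast-based code above
  val standardize = udf { (features: Vector, mean: Vector, std: Vector) =>
    Vectors.dense(Array.tabulate(features.size) { i =>
      if (std(i) == 0.0) 0.0 else (features(i) - mean(i)) / std(i)
    })
  }

  data
    .join(stats, Seq("group"))
    .withColumn("features",
      standardize(col("features"), col("stats.mean"), col("stats.std")))
    .drop("stats")
}
```

    Calling `scalePerGroup(data)` returns the original columns with features replaced by the scaled vectors; after the join, group becomes the first column. The body of this function is also roughly what the transform method of a custom Transformer would contain.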