scala, apache-spark

How to get the median in Scala Spark when frequencies of values are given?


I'm pretty new to Spark and Scala.

Given the table "tab" (but without the median column):

movie   voted_10   voted_9   voted_8   median
Some    10         15        15        9
Movie   20         16        52        8

How do you calculate the median column? The voted_* columns are frequencies of the ratings 10 down to 1 (i.e. how many people rated the movie 10, how many rated it 9, and so on).
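For example, the first row has 10 + 15 + 15 = 40 votes in total; sorting those ratings, the 20th and 21st values are both 9, so the median is 9.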

I calculated the mean like this:

val ratingsDfChanged = ratingsDf
  .withColumn("mean",
    (col("votes_1") + col("votes_2") * 2 + col("votes_3") * 3 + col("votes_4") * 4 +
     col("votes_5") * 5 + col("votes_6") * 6 + col("votes_7") * 7 + col("votes_8") * 8 +
     col("votes_9") * 9 + col("votes_10") * 10) / $"total_votes")  // :))

It doesn't look pretty, but it works; for the median, though, I have no idea how to do it.

I've tried something like

.withColumn("wartosci", Array.fill($"votes_10")(10))

And then maybe concatenate these arrays and finally calculate the median, but here I get an error, because Array.fill requires an Int and I'm passing it a Column. If you know how to work around that, it would also be welcome.
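As a side note, Spark has a column-level counterpart to Array.fill: array_repeat accepts a Column for the repeat count (Spark 2.4+), and the solution below builds on it. A minimal sketch of the failing line rewritten that way:

.withColumn("wartosci", array_repeat(lit(10), col("votes_10")))  // one element "10" per vote in votes_10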


Solution

  • To avoid hardcoding the score columns, apply a pattern filter to capture the column names, sort the column names by their corresponding scores, and assemble a Map with the column names as keys and the corresponding scores as values.

    import org.apache.spark.sql.functions._
    import spark.implicits._  // spark is the SparkSession (e.g. in spark-shell)

    val df = Seq(
      ("Forrest Gump", 10, 15, 15),
      ("The Matrix", 20, 16, 52)
    ).toDF("movie", "scored_10", "scored_9", "scored_8")
    
    val scoreSorted = df.columns.
      filter(_.matches("scored_\\d+")).
      sortBy(_.split("_")(1).toInt)  // sorted for computing median
    // Array(scored_8, scored_9, scored_10)
    
    val scoreMap = scoreSorted.map(c => (c, c.split("_")(1).toInt)).toMap
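    // scoreMap: Map(scored_8 -> 8, scored_9 -> 9, scored_10 -> 10)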
    

    For the mean, use foldLeft to aggregate the total score, then divide it by the total number of votes:

    df.
      withColumn("votes", scoreSorted.map(col).reduce(_ + _)).
      withColumn("mean", scoreSorted.foldLeft(lit(0.0))((acc, c) =>
        acc + col(c) * scoreMap(c)) / $"votes"
      ).
    show
    // +------------+---------+--------+--------+-----+-----------------+
    // |       movie|scored_10|scored_9|scored_8|votes|             mean|
    // +------------+---------+--------+--------+-----+-----------------+
    // |Forrest Gump|       10|      15|      15|   40|            8.875|
    // |  The Matrix|       20|      16|      52|   88|8.636363636363637|
    // +------------+---------+--------+--------+-----+-----------------+
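
    As a quick sanity check, Forrest Gump's mean is (10*10 + 15*9 + 15*8) / 40 = 355 / 40 = 8.875, matching the output above.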
    

    To compute the median, first expand the data row-wise: for each score column, build an array that repeats the score as many times as it was voted (array_repeat), flatten those arrays and explode them into one row per vote, then aggregate with percentile_approx, with the percentage set to 0.5 (i.e. the midpoint) and the accuracy set to the maximum:

    df.
      withColumn("scores", explode(flatten(array(scoreSorted.map(c =>
        array_repeat(lit(scoreMap(c)), col(c))): _*)
      ))).
      groupBy("movie").agg(
        percentile_approx($"scores", lit(0.5), lit(Int.MaxValue)).as("median")
      ).
    show
    // +------------+------+
    // |       movie|median|
    // +------------+------+
    // |Forrest Gump|     9|
    // |  The Matrix|     8|
    // +------------+------+
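
    For The Matrix, for instance, there are 88 votes in total and the 44th and 45th of the sorted scores are both 8s, hence the median of 8.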
    

    Note that the percentile_approx function is available in the DataFrame API only on Spark 3.1+. Its accuracy parameter takes an integer between 1 and Int.MaxValue -- the higher the value, the more accurate (and more costly) the percentile result. Also note that the sorting of the score columns is needed only for computing the median (i.e. not for the mean), and doing it upfront costs less.
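    On an older Spark release the same aggregation can often be expressed through SQL instead of the typed function; a minimal sketch, assuming the percentile_approx SQL function and array_repeat with a Column count (Spark 2.4+) are available in your version:

    df.
      withColumn("scores", explode(flatten(array(scoreSorted.map(c =>
        array_repeat(lit(scoreMap(c)), col(c))): _*)
      ))).
      groupBy("movie").agg(
        expr("percentile_approx(scores, 0.5, 10000)").as("median")  // 10000 is the default accuracy
      ).
    show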