Tags: scala, apache-spark, spark3

How to create a map column to count occurrences without a UDAF


I would like to create a Map column which counts the number of occurrences.

For instance:

+---+----+
|  b|   a|
+---+----+
|  1|   b|
|  2|null|
|  1|   a|
|  1|   a|
+---+----+

would result in

+---+--------------------+
|  b|                 res|
+---+--------------------+
|  1|[a -> 2.0, b -> 1.0]|
|  2|                  []|
+---+--------------------+

So far, in Spark 2.4.6, I have been able to do it using a UDAF.

While bumping to Spark 3, I was wondering whether I could get rid of this UDAF (I tried the new aggregate function, without success).

Is there an efficient way to do it? (I can easily benchmark the efficiency part.)


Solution
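For reference, the sample DataFrame df used below can be built like this (a sketch, assuming a local SparkSession; the app name is arbitrary):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("map-count")
  .getOrCreate()
import spark.implicits._

// Column "a" is nullable, so model it as Option[String]
val df = Seq(
  (1, Some("b")),
  (2, None),
  (1, Some("a")),
  (1, Some("a"))
).toDF("b", "a")
```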

  • Here is a Spark 3 solution:

    import org.apache.spark.sql.functions._
    
    df.groupBy($"b", $"a").count()                        // count occurrences of each (b, a) pair
      .groupBy($"b")
      .agg(
        map_from_entries(
          collect_list(
            // null keys yield null entries, which collect_list drops
            when($"a".isNotNull, struct($"a", $"count"))
          )
        ).as("res")
      )
      .show()
    

    gives:

    +---+----------------+
    |  b|             res|
    +---+----------------+
    |  1|[b -> 1, a -> 2]|
    |  2|              []|
    +---+----------------+
    

    Here is a solution using an Aggregator:

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.Encoder
    
    val countOcc = new Aggregator[String, Map[String,Int], Map[String,Int]] with Serializable {
        // Note: a default value set with withDefaultValue would be lost when the
        // buffer round-trips through the encoder, so use getOrElse for lookups.
        def zero: Map[String,Int] = Map.empty
        def reduce(b: Map[String,Int], a: String): Map[String,Int] =
          if (a != null) b + (a -> (b.getOrElse(a, 0) + 1)) else b
        def merge(b1: Map[String,Int], b2: Map[String,Int]): Map[String,Int] = {
          val keys = b1.keySet.union(b2.keySet)
          keys.map(k => k -> (b1.getOrElse(k, 0) + b2.getOrElse(k, 0))).toMap
        }
        def finish(b: Map[String,Int]): Map[String,Int] = b
        def bufferEncoder: Encoder[Map[String,Int]] = ExpressionEncoder[Map[String,Int]]
        def outputEncoder: Encoder[Map[String,Int]] = ExpressionEncoder[Map[String,Int]]
    }
    
    val countOccUDAF = udaf(countOcc)
    
    df
      .groupBy($"b")
      .agg(countOccUDAF($"a").as("res"))
      .show()
    

    gives:

    +---+----------------+
    |  b|             res|
    +---+----------------+
    |  1|[b -> 1, a -> 2]|
    |  2|              []|
    +---+----------------+