Tags: scala, apache-spark, spark3

How to port UDAF to Aggregator?


I have a DF looking like this:

time,channel,value
0,foo,5
0,bar,23
100,foo,42
...

I want a DF like this:

time,foo,bar
0,5,23
100,42,...

In Spark 2, I did it with a UDAF like this:

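import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

case class ColumnBuilderUDAF(channels: Seq[String]) extends UserDefinedAggregateFunction {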

  @transient lazy val inputSchema: StructType = StructType {
    StructField("channel", StringType, nullable = false) ::
      StructField("value", DoubleType, nullable = false) ::
      Nil
  }

  @transient lazy val bufferSchema: StructType = StructType {
    channels
      .toList
      .indices
      .map(i => StructField("c%d".format(i), DoubleType, nullable = false))
  }

  @transient lazy val dataType: DataType = bufferSchema

  @transient lazy val deterministic: Boolean = false

  // Every channel slot starts as NaN, meaning "no value seen yet".
  def initialize(buffer: MutableAggregationBuffer): Unit = channels.indices.foreach(buffer(_) = Double.NaN)

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val channel = input.getAs[String](0)
    val p = channels.indexOf(channel)
    if (p >= 0 && p < channels.length) {
      val v = input.getAs[Double](1)
      if (!v.isNaN) {
        buffer(p) = v
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    channels
      .indices
      .foreach { i =>
        val v2 = buffer2.getAs[Double](i)
        if ((!v2.isNaN) && buffer1.getAs[Double](i).isNaN) {
          buffer1(i) = v2
        }
      }

  def evaluate(buffer: Row): Any =
    new GenericRowWithSchema(channels.indices.map(buffer.getAs[Double]).toArray, dataType.asInstanceOf[StructType])
}

which I use like this:

val cb = ColumnBuilderUDAF(Seq("foo", "bar"))
val dfColumnar = df.groupBy($"time").agg(cb($"channel", $"value") as "c")

Then I rename c.c0, c.c1, etc. to foo, bar, etc.
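
The renaming step is not shown above. A minimal sketch of how it could look, assuming the struct column is named c with fields c0, c1, ... as produced by the UDAF (flattenChannels is a hypothetical helper name, not part of Spark):

```scala
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col

// Hypothetical helper: flatten the struct column "c" (fields c0, c1, ...)
// into one top-level column per channel name, keeping "time" as the key.
def flattenChannels(dfColumnar: DataFrame, channels: Seq[String]): DataFrame = {
  val renamed: Seq[Column] = channels.zipWithIndex.map {
    case (name, i) => col(s"c.c$i").as(name)
  }
  dfColumnar.select(col("time") +: renamed: _*)
}
```

Calling flattenChannels(dfColumnar, Seq("foo", "bar")) should give a DF with the columns time, foo, bar.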

In Spark 3, UserDefinedAggregateFunction is deprecated and Aggregator should be used instead, so I began to port it like this:

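import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.Aggregator

case class ColumnBuilder(channels: Seq[String]) extends Aggregator[(String, Double), Array[Double], Row] {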

  lazy val bufferEncoder: Encoder[Array[Double]] = Encoders.javaSerialization[Array[Double]]

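  // Must be a def, not a lazy val: reduce and merge mutate the buffer in place,
  // so every group needs its own fresh array.
  def zero: Array[Double] = Array.fill(channels.length)(Double.NaN)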

  def reduce(b: Array[Double], a: (String, Double)): Array[Double] = {
    val index = channels.indexOf(a._1)
    if (index >= 0 && !a._2.isNaN) b(index) = a._2
    b
  }

  def merge(b1: Array[Double], b2: Array[Double]): Array[Double] = {
    (0 until b1.length.min(b2.length)).foreach(i => if (b1(i).isNaN) b1(i) = b2(i))
    b1
  }

  def finish(reduction: Array[Double]): Row =
    new GenericRowWithSchema(reduction.map(x => x: Any), outputEncoder.schema)

  def outputEncoder: Encoder[Row] = ??? // what goes here?
}

I don't know how to implement the Encoder[Row], as Spark does not seem to have a pre-defined one. If I take the straightforward approach:

  val outputEncoder: Encoder[Row] = new Encoder[Row] {
    val schema: StructType = StructType(channels.map(StructField(_, DoubleType, nullable = false)))

    val clsTag: ClassTag[Row] = classTag[Row]
  }

I get a ClassCastException, because outputEncoder actually has to be an ExpressionEncoder.

So, how do I implement this correctly? Or do I still have to use the deprecated UDAF?
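
One possible way around the Encoder[Row] problem, sketched here under assumptions rather than as a confirmed answer: return a Seq[Double] instead of a Row, so that an ordinary encoder can be used, and wrap the Aggregator with functions.udaf (available since Spark 3.0) to call it on untyped columns. The names ColumnBuilderSeq and ColumnBuilderSeqDemo are hypothetical. ExpressionEncoder comes from Spark's internal catalyst package, so it may change between versions.

```scala
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Same aggregation logic as the port above, but the output is Seq[Double],
// so no Encoder[Row] is required.
case class ColumnBuilderSeq(channels: Seq[String])
    extends Aggregator[(String, Double), Array[Double], Seq[Double]] {

  // A fresh buffer on every call, because reduce and merge mutate it in place.
  def zero: Array[Double] = Array.fill(channels.length)(Double.NaN)

  def reduce(b: Array[Double], a: (String, Double)): Array[Double] = {
    val index = channels.indexOf(a._1)
    if (index >= 0 && !a._2.isNaN) b(index) = a._2
    b
  }

  def merge(b1: Array[Double], b2: Array[Double]): Array[Double] = {
    b1.indices.foreach(i => if (b1(i).isNaN) b1(i) = b2(i))
    b1
  }

  def finish(reduction: Array[Double]): Seq[Double] = reduction.toSeq

  // Assumption: ExpressionEncoder[T]() from the catalyst package (internal API).
  def bufferEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]()
  def outputEncoder: Encoder[Seq[Double]] = ExpressionEncoder[Seq[Double]]()
}

object ColumnBuilderSeqDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("column-builder").getOrCreate()
    import spark.implicits._

    val channels = Seq("foo", "bar")
    val df = Seq((0, "foo", 5.0), (0, "bar", 23.0), (100, "foo", 42.0))
      .toDF("time", "channel", "value")

    // udaf turns the typed Aggregator into a function usable with untyped columns.
    val cb = udaf(ColumnBuilderSeq(channels))
    val aggregated = df.groupBy(col("time")).agg(cb(col("channel"), col("value")).as("c"))

    // "c" is an array column; pull out one element per channel.
    val columns = col("time") +: channels.zipWithIndex.map { case (name, i) => col("c")(i).as(name) }
    aggregated.select(columns: _*).show(false)

    spark.stop()
  }
}
```

Channels that never occur for a given time stay NaN in this sketch, matching the behaviour of the original UDAF.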


Solution

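  • You can do this with groupBy and pivot: pivot creates one column per distinct channel value, and first picks a value for each (time, channel) cell.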

    import spark.implicits._
    import org.apache.spark.sql.functions._
    
    val df = Seq(
      (0, "foo", 5),
      (0, "bar", 23),
      (100, "foo", 42)
    ).toDF("time", "channel", "value")
    
    df.groupBy("time")
      .pivot("channel")
      .agg(first("value"))
      .show(false)
    

    Output:

    +----+----+---+
    |time|bar |foo|
    +----+----+---+
    |100 |null|42 |
    |0   |23  |5  |
    +----+----+---+
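
In the output above, the column order follows the distinct channel values that Spark finds in the data. If the channel list is known up front, as it is with Seq("foo", "bar") in the question, it can be passed to pivot explicitly. The following is a sketch under that assumption (PivotChannels is a hypothetical object name). It also uses first with ignoreNulls = true, which skips null values; note that this does not skip NaN values the way the original UDAF did.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.first

object PivotChannels {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("pivot-channels").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (0, "foo", 5.0),
      (0, "bar", 23.0),
      (100, "foo", 42.0)
    ).toDF("time", "channel", "value")

    // Known channel names: fixes the column order and lets Spark skip
    // the extra pass that computes the distinct pivot values.
    val channels = Seq("foo", "bar")

    val wide = df.groupBy("time")
      .pivot("channel", channels)
      .agg(first("value", ignoreNulls = true))
      .orderBy("time")

    wide.show(false)
    spark.stop()
  }
}
```

The resulting columns are time, foo, bar, in that order. A channel without a value for a given time comes out as null.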