Tags: scala, apache-spark, user-defined-functions

Apache Spark - Is there a problem with passing parameters to a custom Aggregator constructor?


I created an Aggregator to be used as a UDAF, which uses three columns of a DataFrame to calculate its result, but also needs two other parameters that are common to every row. Initially, I defined the input type like this (simplifying unnecessary details):

case class In(a: Long, b: Double, c: Double, d: Long, e: Double)
class MyUDAF extends Aggregator[In, Buf, Long] {
   ...
}

and passed those extra parameters using lit from org.apache.spark.sql.functions:

val myudaf = udaf(new MyUDAF, ExpressionEncoder[In]())
val df: DataFrame = ??? // suppose there's an actual DataFrame here
df.withColumn("result", myudaf(col("a"), col("b"), col("c"), lit(100L), lit(10.0)))

It worked perfectly, but I didn't like this approach of passing those two parameters as columns, since I had to keep them inside the reduction buffers (the merge method takes only buffers as parameters). So I decided to include them in the MyUDAF constructor and use it like this:

case class In(a: Long, b: Double, c: Double)
class MyUDAF(d: Long, e: Double) extends Aggregator[In, Buf, Long] {
   ...
}
val myudaf = udaf(new MyUDAF(100L, 10.0), ExpressionEncoder[In]())
val df: DataFrame = ???
df.withColumn("result", myudaf(col("a"), col("b"), col("c")))

This also worked perfectly in local tests. But I'm new to Spark, so I don't know whether this practice can cause problems. Unfortunately, I currently don't have access to more machines to set up a cluster and check for myself whether something goes wrong in a more complex scenario. So the question is: could using data other than what is contained in the input rows and buffers (such as values from the constructor) cause any problems, errors, or side effects? Is my second approach OK?


Solution

  • Apache DataFu-Spark has an example of this in its CountDistinctUpTo UDAF. (Disclosure: I am a member of DataFu and wrote this code.)

    The declaration looks like this:

      /**
       * Counts number of distinct records, but only up to a preset amount -
       * more efficient than an unbounded count
       */
      class CountDistinctUpTo(maxItems: Int) extends Aggregator[String, Set[String], Int] with Serializable {
    

    As you can see, a single value of maxItems is used for all the data that is processed. This runs on a cluster without complications.
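
    In other words, the maxItems constructor parameter is just a field of the Aggregator instance: Spark serializes the instance and ships it to the executors, so the value is available everywhere without having to travel inside the buffer. For completeness, here is a minimal, self-contained sketch in the style of your second approach (not taken from DataFu; the CappedSum class and its logic are made up for illustration):

      import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
      import org.apache.spark.sql.expressions.Aggregator
      import org.apache.spark.sql.functions.{col, udaf}

      // Hypothetical example: sums values, capping each input at the maxValue
      // given to the constructor. maxValue is serialized together with the
      // Aggregator instance, so every executor sees it without it being part
      // of the buffer.
      class CappedSum(maxValue: Double) extends Aggregator[Double, Double, Double] {
        def zero: Double = 0.0
        def reduce(buffer: Double, value: Double): Double = buffer + math.min(value, maxValue)
        def merge(b1: Double, b2: Double): Double = b1 + b2
        def finish(reduction: Double): Double = reduction
        def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
        def outputEncoder: Encoder[Double] = Encoders.scalaDouble
      }

      val spark = SparkSession.builder.master("local[*]").getOrCreate()
      import spark.implicits._

      val cappedSum = udaf(new CappedSum(10.0))
      val df = Seq(1.0, 5.0, 42.0).toDF("x")
      df.agg(cappedSum(col("x"))).show() // 1.0 + 5.0 + 10.0 (42.0 capped) = 16.0

    Just make sure everything the Aggregator references, constructor parameters included, is serializable, since the whole instance is sent to the executors.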