Tags: scala, apache-spark, apache-spark-sql, aggregate-functions, user-defined-functions

How to extend built-in aggregate function in Spark SQL (using Scala)?


Basically the end goal would be to create something like dollarSum which would return the same values as ROUND(SUM(col), 2).

I'm using Databricks runtime 10.4 LTS ML, which apparently corresponds to Spark 3.2.1 and Scala 2.12.

I am able to follow the tutorial / example code for UDAFs, and used it to create something analogous to the built-in EVERY function. But that seems to be more like ImperativeAggregate, whereas what I want might be more like DeclarativeAggregate, cf. the comments in the Spark source code.

Overall I haven't been able to find any documentation online on how to extend built-in aggregate functions in a simple way, where you only modify the "finish" or "evaluate" step, and even then just by adding extra behavior.

What I have tried so far: I have tried at least four things, and none of them work.

Attempt 1:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{sum, round}

object dollarSum extends Aggregator[Double, Double, Double] {

def zero: Double = sum.zero

def reduce(buffer: Double, row: Double): Double = sum.reduce

def merge(buffer1: Double, buffer2: Double): Double = sum.merge

def finish(reduction: Double): Double = {
    sum.finish(reduction)
    round(reduction, 2)
}

def bufferEncoder: Encoder[Double] = sum.bufferEncoder
def outputEncoder: Encoder[Double] = sum.outputEncoder
}
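
For context, writing the Aggregator out by hand (as in the UDAF tutorial) does work, since functions.sum is a Column-level function and has no zero / reduce / merge / finish members to delegate to; but that amounts to re-implementing SUM rather than extending it, which is what I'd like to avoid. Roughly (names are just placeholders):

import org.apache.spark.sql.{Encoder, Encoders, functions}
import org.apache.spark.sql.expressions.Aggregator

// hand-written rounded sum; Spark SQL's ROUND uses HALF_UP, so match that here
object dollarSumByHand extends Aggregator[Double, Double, Double] {
  def zero: Double = 0.0
  def reduce(buffer: Double, row: Double): Double = buffer + row
  def merge(b1: Double, b2: Double): Double = b1 + b2
  def finish(reduction: Double): Double =
    BigDecimal(reduction).setScale(2, BigDecimal.RoundingMode.HALF_UP).toDouble
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

spark.udf.register("dollar_sum", functions.udaf(dollarSumByHand))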

Attempt 2: I tried to copy-paste-modify code from here. This seems to fail because most of the attributes and methods of the built-in Sum class appear to be private (probably because the developers didn't want people like me, who don't know what they're doing, to break the code). But I don't know what public interface / API I could use instead to get what I want.

import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.functions.round
import org.apache.spark.sql.catalyst.expressions.EvalMode
import org.apache.spark.sql.types.DecimalType

trait dollarSum extends Sum {

  override lazy val evaluateExpression: Expression = {
    Sum.resultType match {
      case d: DecimalType =>
        val checkOverflowInSum =
          CheckOverflowInSum(Sum.sum, d, evalMode != EvalMode.ANSI, getContextOrNull())
        If(isEmpty, Literal.create(null, Sum.resultType), checkOverflowInSum)
      case _ if shouldTrackIsEmpty =>
        If(isEmpty, Literal.create(null, Sum.resultType), Sum.sum)
      case _ => round(Sum.sum, 2)
    }
  }

}

This would probably still fail due to some other missing imports, but again I wasn't able to get that far in debugging due to trying to access private methods and attributes that probably shouldn't be accessed.

Attempt 3: The source code for try_sum in the same file seemed closer to using a "public API" for sum, so I tried copy-paste-modifying that instead. But ExpressionBuilder also seems like it's a private class, so this fails too.

import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression

object DollarSumExpressionBuilder extends ExpressionBuilder {
  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
    val numArgs = expressions.length
    if (numArgs == 1) {
      round(Sum(expressions.head),2)
    } else {
      throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(1, 2), numArgs)
    }
  }
}

Then the idea would be that, if that worked, I would try registering the function the same way that TRY_SUM is registered with Spark SQL in the source code, cf. here. But I got an error about ExpressionBuilder not existing, which seems to indicate that it is also private to the package and thus not a public interface I could use to extend SUM.

Also, it's not clear to me what the return type of the SUM constructor is; I think it might be AggregateExpression, inheriting from Expression. And I'm not certain what the input type of round is; it seems like it might be org.apache.spark.sql.Column, and if so, I'm not sure how to convert from Expression to Column.

E.g. whether in the above

round(org.apache.spark.sql.Column(Sum(expressions.head)), 2)

or

round(org.apache.spark.sql.functions.col(Sum(expressions.head)), 2)

would be able to achieve the desired type conversion (seemingly neither works).
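
If I understand correctly, Column is just a thin wrapper around a catalyst Expression, so at the Expression level the counterpart of functions.round would be the Round expression rather than a conversion to Column. My guess at what that would look like, assuming the single-argument Sum constructor still applies:

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Round}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum

// guess: build the equivalent of ROUND(SUM(col), 2) entirely at the Expression level
def dollarSumExpr(child: Expression): Expression =
  Round(Sum(child).toAggregateExpression(), Literal(2))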

Attempt 4: Along the same lines, since I don't know which types are needed, how to convert between them, or what the public interface for SUM is, I tried using org.apache.spark.sql.functions.sum as the "public interface" for SUM instead, but this also didn't work.

Specifically

import org.apache.spark.sql.functions.{round, sum}
import org.apache.spark.sql.Column

// originally I had `expression: org.apache.spark.sql.catalyst.expressions.Expression` but that didn't work
def dollarSum(expression: Column): Column = {round(sum(expression), 2)}

actually doesn't throw any errors, but when I then try to register the resulting object as a(n aggregate) function, it fails. Specifically,

spark.udf.register("dollar_sum", functions.udaf(dollarSum))

doesn't work, nor does

spark.udf.register("dollar_sum", functions.udf(dollarSum))

Solution

  • Wow, lots of fun stuff in this question and awfully familiar: Quality's agg_expr was my journey into that space.

    To build a custom expression you may need to put code into the org.apache.spark.sql package, e.g. registerFunction. Using createOrReplaceTempFunction on the SparkSession's FunctionRegistry (e.g. SparkSession.getActiveSession.get.sessionState.functionRegistry) you can use the function within a session. If you need it in hive views etc. you must use SparkSessionExtensions for scope and FunctionRegistry.builtin.registerFunction.

    The actual registration ExpressionBuilder is just an alias for Seq[Expression] => Expression, representing the parameters passed in to construct your expression.

    So, depending on Spark version (the internal API changes a lot):

        import org.apache.spark.sql.SparkSession
        import org.apache.spark.sql.catalyst.expressions.{Round, Literal, EvalMode}
        import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
    
        // the active session gives access to the session-scoped function registry
        val sparkSession = SparkSession.getActiveSession.get
    
        sparkSession.sessionState.functionRegistry.
          createOrReplaceTempFunction("dollarSum", exps => Round(
            Sum(exps.head, EvalMode.TRY).toAggregateExpression(), Literal(2)), "built-in")
    
        val seq = Seq(1.245, 242.535, 65656.234425, 2343.666)
        import sparkSession.implicits._
    
        seq.toDF("amount")//.selectExpr("round(sum(amount), 2)").show
          .selectExpr("dollarSum(amount)").show
    

    NB/FYI: An obvious idea with Quality would be to use a lambda:

        import com.sparkutils.quality.{LambdaFunction, Id, registerLambdaFunctions, registerQualityFunctions}
        registerQualityFunctions()
        registerLambdaFunctions(Seq(
          LambdaFunction("dollarSum", "a -> round(sum(a), 2)", Id(1,1))
        ))
    

    this however fails as Spark LambdaFunctions and AggregateFunctions don't readily mix. The direct FunctionRegistry route doesn't involve a LambdaFunction and so works correctly.

    Extra info, per the questions in the comments...

    Why "built-in", it's used to specify sources, you can't create the function unless it's a valid source (from ExpressionInfo):

    private static final Set<String> validSources =
                new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", "java_udf"));
    

    as such only built-in is close. The name refers to the static FunctionRegistry.builtin instance, which houses all the normal Spark SQL functions, and is what you need to use if you want to use the function in create view etc.
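
    A rough sketch of that builtin route (the ExpressionInfo arguments below are purely illustrative, and Sum's constructor arguments vary by Spark version):

        import org.apache.spark.sql.catalyst.FunctionIdentifier
        import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
        import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal, Round}
        import org.apache.spark.sql.catalyst.expressions.aggregate.Sum

        // register against the global built-in registry so the function also
        // resolves inside views; the class name in ExpressionInfo is a placeholder
        FunctionRegistry.builtin.registerFunction(
          FunctionIdentifier("dollarSum"),
          new ExpressionInfo("com.example.DollarSum", "dollarSum"),
          (exps: Seq[Expression]) => Round(Sum(exps.head).toAggregateExpression(), Literal(2)))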

    Re the builder: as I wrote above, it's a function that takes expressions and returns an expression, i.e. the constructor. You will need to call createOrReplaceTempFunction (or the other routes I mention above) to actually register it, but it's just a name and Seq[Expression] => Expression pair, so it's easy enough to manage differently. As Spark's interface for this changes every couple of releases, the actual call in Quality is made in different Spark compatibility layers, e.g. 10.4 LTS or OSS 2.4; the functions themselves are however managed here and below.

    In order to provide some useful errors on parameter handling I also specify parameter combinations handled here.
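
    For example, a builder can fail fast on bad arity itself, without relying on Spark's internal error helpers (just a sketch):

        import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Round}
        import org.apache.spark.sql.catalyst.expressions.aggregate.Sum

        // validate the argument count in the builder so callers get a readable error
        val dollarSumBuilder: Seq[Expression] => Expression = {
          case Seq(child) => Round(Sum(child).toAggregateExpression(), Literal(2))
          case exps => throw new IllegalArgumentException(
            s"dollarSum expects exactly 1 argument, got ${exps.length}")
        }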

    Now, in order to build more complicated logic you will have to understand each of the Spark Expressions themselves, and many of them change each release. Worse, since you are using Databricks, the OSS version advertised only holds for the public interfaces, which means you must sometimes guess or use reflection to figure out what the Expressions actually look like on Databricks. Typically this is just backports of future releases, but not always; there have been traits that were swapped for abstract classes, leading to hideous workarounds like this where I have to shim the types to compile correctly under OSS with a target of DBR 9.1. Caveat emptor.

    That said, there is the odd surprise and risk lying in wait, e.g. a DBR version backports a fix or feature that breaks the interface without bumping the version. So you are calmly and happily using your Sum code on 10.4, but overnight 10.4 stops working and your DECIMAL sums are clearly suffering overflow of some kind. Every other user of 10.4 gets a nice performance bump, but you get broken math... So be prepared to test continuously and be able to make fixes quickly; this is the price of using internals.

    To be really clear: I deeply appreciate the Databricks product and team, this issue is not one of their making; it'd be yours (and clearly mine) for using internal APIs.

    The core Spark team has also openly stated they don't approve of such usage wrt. Frameless (3.2.0 to 3.2.1 changed the internal Encoder API, breaking Frameless users). Clearly they too should be free to innovate and re-organise internal APIs. The performance and flexibility of using them, though...