Tags: scala, apache-spark, dataframe, user-defined-aggregate

Can every Spark UDAF be used with Window?


I always thought that Spark does not allow defining user-defined window functions. I just tested the "Geometric Mean" UDAF example from here (https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html) as a window function, and it seems to work just fine, e.g.:

import org.apache.spark.sql.expressions.Window
import spark.implicits._   // spark is the SparkSession (available in spark-shell)

val geomMean = new GeometricMean

(1 to 10).map(i => (i, i.toDouble))
  .toDF("i", "x")
  .withColumn("geom_mean", geomMean($"x").over(Window.orderBy($"i").rowsBetween(-1, 1)))
  .show()

+---+----+------------------+
|  i|   x|         geom_mean|
+---+----+------------------+
|  1| 1.0|1.4142135623730951|
|  2| 2.0|1.8171205928321397|
|  3| 3.0|2.8844991406148166|
|  4| 4.0|3.9148676411688634|
|  5| 5.0|  4.93242414866094|
|  6| 6.0| 5.943921952763129|
|  7| 7.0| 6.952053289772898|
|  8| 8.0| 7.958114415792783|
|  9| 9.0| 8.962809493114328|
| 10|10.0| 9.486832980505138|
+---+----+------------------+

I've never seen the Spark docs mention using a UDAF as a window function. Is this allowed, i.e. are the results correct? I'm using Spark 2.1, by the way.

EDIT:

What confuses me is that in a standard aggregation (i.e. one following a groupBy), data is only ever added to the buffers, i.e. they always grow and never shrink. With a window function (especially in conjunction with rowsBetween()), data also needs to be removed from the buffer, as "old" elements drop out of the window while it moves along the rows defined by the ordering. I think of window functions as moving along the ordering while carrying a state, so I assumed there would have to be something like a "remove" method to implement.
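
For reference, this is roughly what the GeometricMean UDAF from the linked Databricks page looks like (a sketch, not copied verbatim): the UserDefinedAggregateFunction API only exposes initialize, update, merge and evaluate, i.e. there is no "remove"/"retract" hook for rows leaving the frame.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class GeometricMean extends UserDefinedAggregateFunction {
  // single Double input column
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)

  // buffer: row count and running product
  def bufferSchema: StructType = StructType(
    StructField("count", LongType) :: StructField("product", DoubleType) :: Nil)

  def dataType: DataType = DoubleType
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 1.0
  }

  // values are only ever added -- there is no "remove" counterpart in the API
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + 1
    buffer(1) = buffer.getDouble(1) * input.getDouble(0)
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getDouble(1) * buffer2.getDouble(1)
  }

  def evaluate(buffer: Row): Any =
    math.pow(buffer.getDouble(1), 1.0 / buffer.getLong(0))
}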


Solution

  • I am not sure what exactly your question is.

    Can every Spark UDAF be used with Window?

    Yes

    Here is my personal experience on this topic:

    Lately I have been working a lot with Spark window functions and UDAFs (Spark 2.0.1), and I can confirm that they work very well together. The results are correct (assuming your UDAF is correctly written). UDAFs are a bit of a pain to write, but once you get the hang of it, the next ones go quickly.

    I didn't test all of them, but the built-in aggregation functions from org.apache.spark.sql.functions._ also worked for me over windows. Search for "Aggregate" in functions. I was working mostly with classical aggregators like sum, count, avg and stddev, and they all returned correct values.
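
    For example, a minimal sketch using built-in aggregates over the same kind of sliding frame (assuming a DataFrame df with columns i and x, as in your question):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{avg, col, stddev, sum}

    // same frame as in the question: previous row, current row, next row
    val w = Window.orderBy(col("i")).rowsBetween(-1, 1)

    df.withColumn("avg_x", avg(col("x")).over(w))
      .withColumn("stddev_x", stddev(col("x")).over(w))
      .withColumn("sum_x", sum(col("x")).over(w))
      .show()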