scala, apache-spark-sql, window-functions, apache-spark-ml

Spark ML Transformer - aggregate over a window using rangeBetween


I would like to create a custom Spark ML Transformer that applies an aggregation function over a rolling window, using the over(window) construct. I would like to be able to use this transformer in a Spark ML Pipeline.

I would like to achieve something that can be done quite easily with withColumn, as shown in this answer:

Spark Window Functions - rangeBetween dates

for example:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum}
val w = Window.orderBy(col("unixTimeMS")).rangeBetween(0, 700)
val df_new = df.withColumn("cts", sum("someColumnName").over(w))

Where

  • df is my DataFrame
  • unixTimeMS is Unix time in milliseconds
  • someColumnName is the column I want to aggregate. In this example I sum it over the rows within the window.
  • the window w includes the current transaction and all transactions up to 700 ms after the current transaction.
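To make the frame semantics concrete, here is a minimal plain-Scala sketch (no Spark involved) of what rangeBetween(0, 700) selects for each row. The Tx case class, the rangeSum helper, and the sample values are hypothetical and exist only for illustration; the point is that the bounds are offsets on the value of the ordering column, both inclusive, not a count of rows.

```scala
// Hypothetical row type for illustration; not part of the original code.
final case class Tx(unixTimeMS: Long, amount: Double)

// For each row, sum `amount` over all rows whose unixTimeMS lies in
// [current + lower, current + upper], both bounds inclusive.
// This mimics a value-based (range) frame; it is a quadratic-time sketch only.
def rangeSum(rows: Seq[Tx], lower: Long, upper: Long): Seq[(Tx, Double)] =
  rows.sortBy(_.unixTimeMS).map { cur =>
    val total = rows
      .filter { r =>
        r.unixTimeMS >= cur.unixTimeMS + lower &&
        r.unixTimeMS <= cur.unixTimeMS + upper
      }
      .map(_.amount)
      .sum
    (cur, total)
  }

// Made-up sample data.
val txs = Seq(Tx(0L, 1.0), Tx(500L, 2.0), Tx(700L, 4.0), Tx(1500L, 8.0))
val sums = rangeSum(txs, 0L, 700L).map(_._2)
```

With this sample, the row at 0 ms picks up the rows at 0, 500, and 700 ms, while the row at 1500 ms only sees itself, regardless of how many rows precede it.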

Is it possible to put such a window aggregation into a Spark ML Transformer?

I was able to achieve something similar with the Spark ML SQLTransformer:

val query = """SELECT *,
              sum(someColumnName) over (order by unixTimeMS) as cts
              FROM __THIS__"""

new SQLTransformer().setStatement(query)

But I can't figure out how to express rangeBetween in SQL so that the window covers a period of time rather than a number of rows. I need a specific time span relative to the unixTimeMS of the current row.

I understand that a UnaryTransformer is not the way to do it because I need to compute an aggregate. Do I need to define a UDAF (user-defined aggregate function) and use it in SQLTransformer? I wasn't able to find any example of a UDAF containing a window function.
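For comparison, a window aggregation can also be wrapped by extending Transformer directly rather than UnaryTransformer, without any UDAF. The following is only a minimal sketch under stated assumptions: it targets the Spark 2.x or later API (transform takes a Dataset[_]), the class name WindowSumTransformer is hypothetical, the column names and the 0 to 700 bounds are hard-coded from the example above, and someColumnName is assumed to be of DoubleType so that the sum is a DoubleType column.

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Hypothetical transformer for illustration; names and bounds are assumptions.
class WindowSumTransformer(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("windowSum"))

  // Frame: the current row plus every row whose unixTimeMS is at most
  // 700 units (milliseconds here) greater than the current row's value.
  private def window =
    Window.orderBy(col("unixTimeMS")).rangeBetween(0L, 700L)

  // Append the windowed sum as a new column named "cts".
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn("cts", sum(col("someColumnName")).over(window))

  // Assumption: someColumnName is DoubleType, so the sum is DoubleType too.
  override def transformSchema(schema: StructType): StructType = {
    require(schema.fieldNames.contains("someColumnName"),
      "input column someColumnName is missing")
    schema.add(StructField("cts", DoubleType, nullable = true))
  }

  override def copy(extra: ParamMap): WindowSumTransformer = defaultCopy(extra)
}
```

An instance can then be passed to Pipeline.setStages like any other stage. Note that a window with orderBy but no partitionBy makes Spark move all rows to a single partition, so a partition column is worth adding if the data allows it.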


Solution

  • I am answering my own question for future reference. I ended up using SQLTransformer with a window function that uses range between, as in this example:

    val query = """SELECT *,
    sum(dollars) over (
          partition by Numerocarte
          order by UnixTime
          range between 1000 preceding and 200 following) as cts
          FROM __THIS__"""
    

    Here, 1000 and 200 in range between are expressed in the units of the order by column (UnixTime).
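As an end-to-end illustration, the query can be plugged into a Pipeline roughly as follows. This is a sketch rather than the exact code I ran: the local SparkSession settings, the sample rows, and the value names (windowSum, pipeline, result) are made up for the example, and only the column names come from the query above.

```scala
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature.SQLTransformer
import org.apache.spark.sql.SparkSession

// Local session for the sketch; in a real job the session already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("range-window-sketch")
  .getOrCreate()
import spark.implicits._

// Made-up sample data; column names follow the query above.
val df = Seq(
  ("A", 1000L, 10.0),
  ("A", 1500L, 20.0),
  ("A", 3000L, 40.0),
  ("B", 1200L, 5.0)
).toDF("Numerocarte", "UnixTime", "dollars")

// Per card, sum dollars over rows whose UnixTime lies between
// (current - 1000) and (current + 200), both bounds inclusive.
val query = """SELECT *,
  sum(dollars) OVER (
    PARTITION BY Numerocarte
    ORDER BY UnixTime
    RANGE BETWEEN 1000 PRECEDING AND 200 FOLLOWING) AS cts
  FROM __THIS__"""

val windowSum = new SQLTransformer().setStatement(query)

// SQLTransformer is an ordinary pipeline stage, so it can be combined
// with other feature transformers and estimators.
val pipeline = new Pipeline().setStages(Array[PipelineStage](windowSum))
val result = pipeline.fit(df).transform(df)

result.orderBy("Numerocarte", "UnixTime").show()
```

For card A at UnixTime 1500, the frame spans 500 to 1700, so it covers the rows at 1000 and 1500 and cts is 30.0; the row at 3000 is too far from the others and only sums itself.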