I'm trying to create an UDAF on Spark (2.0.1, Scala 2.11) as below. This is to essentially aggregates tuples and output a Map
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}
class mySumToMap[K, V](keyType: DataType, valueType: DataType) extends UserDefinedAggregateFunction {
override def inputSchema = new StructType()
.add("a_key", keyType)
.add("a_value", valueType)
override def bufferSchema = new StructType()
.add("buffer_map", MapType(keyType, valueType))
override def dataType = MapType(keyType, valueType)
override def deterministic = true
override def initialize(buffer: MutableAggregationBuffer) = {
buffer(0) = Map[K, V]()
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// input :: 0 = a_key (k), 1 = a_value
if ( !(input.isNullAt(0)) ) {
val a_map = buffer(0).asInstanceOf[Map[K, V]]
val k = input.getAs[K](0) // get the value of position 0 of the input as string (a_key)
// I've split these on purpose to show that return values are all of type V
val new_v1: V = a_map.getOrElse(k, 0.asInstanceOf[V])
val new_v2: V = input.getAs[V](1)
val new_v: V = new_v1 + new_v2
buffer(0) = if (new_v != 0) a_map + (k -> new_v) else a_map - k
}
}
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val map1: Map[K, V] = buffer1(0).asInstanceOf[Map[K, V]]
val map2: Map[K, V] = buffer2(0).asInstanceOf[Map[K, V]]
buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }
}
override def evaluate(buffer: Row) = buffer(0).asInstanceOf[Map[K, V]]
}
But when I compile this, I see the below error:
<console>:74: error: type mismatch;
found : V
required: String
val new_v: V = new_v1 + new_v2
^
<console>:84: error: type mismatch;
found : V
required: String
buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }
What am I doing wrong?
EDIT: For the folks marking this as a duplicate of Spark UDAF - using generics as input type? - this is not a duplicate of that issue, as that one does not deal with Map datatype. The above code is very specific and complete regarding the problem faced using the Map datatype.
Restrict types to the ones having Numeric[_]
class mySumToMap[K, V: Numeric](keyType: DataType, valueType: DataType)
extends UserDefinedAggregateFunction {
...
use Implicitly
to get it on runtime:
val n = implicitly[Numeric[V]]
and use its plus
method in place of +
an zero
in place of 0
buffer1(0) = map1 ++ map2.map{
case (k,v) => k -> n.plus(v, map1.getOrElse(k, n.zero))
}
To support wider set of types you can use cats
Monoid
:
import cats._
import cats.implicits._
and adjust the code:
class mySumToMap[K, V: Monoid](keyType: DataType, valueType: DataType)
extends UserDefinedAggregateFunction {
...
and later:
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val map1: Map[K, V] = buffer1.getMap[K, V](0)
val map2: Map[K, V] = buffer2.getMap[K, V](0)
val m = implicitly[Monoid[Map[K, V]]]
buffer1(0) = m.combine(map1, map2)
}