Tags: scala, apache-spark

Aggregator Function in Scala Spark over multiple columns to create a hash


I'm trying to write a custom UDAF/Aggregator in Scala Spark 3.3.x or above to get the concatenated hash of a hash column ordered by an id column.

Here is the example dataframe I'm using:

val df = Seq(
      (1, 10, "hash1"),
      (2, 10, "hash2"),
      (3, 20, "hash3"),
      (4, 20, "hash4")
    )
      .map { case (id, group_id, hash) =>
        (id.toString, group_id.toString, hash)
      }
      .toDF("id", "group_id", "hash")

--

root
 |-- id: string (nullable = true)
 |-- group_id: string (nullable = true)
 |-- hash: string (nullable = true)

--

+---+--------+-----+
|id |group_id|hash |
+---+--------+-----+
|1  |10      |hash1|
|2  |10      |hash2|
|3  |20      |hash3|
|4  |20      |hash4|
+---+--------+-----+

and here is the custom aggregator class I've come up with so far:

import java.security.MessageDigest

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class IdHashPair(id: String, hash: String)

class HashAggregator extends Aggregator[IdHashPair, List[IdHashPair], String] {
  def zero: List[IdHashPair] = List.empty[IdHashPair]

  def reduce(buffer: List[IdHashPair], idHashPair: IdHashPair): List[IdHashPair] =
    buffer :+ idHashPair

  def merge(b1: List[IdHashPair], b2: List[IdHashPair]): List[IdHashPair] = b1 ++ b2

  def finish(reduction: List[IdHashPair]): String = {
    val orderedHashes = reduction.sortBy(_.id).map(_.hash).mkString
    val md = MessageDigest.getInstance("SHA-256")
    val hashBytes = md.digest(orderedHashes.getBytes("UTF-8"))
    hashBytes.map("%02x".format(_)).mkString
  }

  def bufferEncoder: Encoder[List[IdHashPair]] = Encoders.product[List[IdHashPair]]
  def outputEncoder: Encoder[String] = Encoders.STRING
}
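
For reference, the `finish` logic can be exercised outside Spark as a plain function (the helper name `concatHash` is just for illustration):

```scala
import java.security.MessageDigest

case class IdHashPair(id: String, hash: String)

// Sort by id, concatenate the hashes, and hex-encode the SHA-256 digest.
// Note: id is a String, so the sort is lexicographic ("10" < "2").
def concatHash(pairs: List[IdHashPair]): String = {
  val ordered = pairs.sortBy(_.id).map(_.hash).mkString
  val md = MessageDigest.getInstance("SHA-256")
  md.digest(ordered.getBytes("UTF-8")).map("%02x".format(_)).mkString
}
```

Because the buffer is sorted inside `finish`, the result should not depend on the order in which rows arrive at `reduce` or `merge`.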

When I try to use the aggregator (wrapped via `functions.udaf`) as shown below:

val hashAggregator = functions.udaf(new HashAggregator())

df
  .groupBy($"group_id")
  .agg(hashAggregator(struct($"id", $"hash")))

I get the following exception

Exception in thread "main" org.apache.spark.sql.AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `id` cannot be resolved. Did you mean one of the following? [`id`, `hash`, `group_id`].;

This error doesn't make sense to me, since the column `id` clearly exists in the schema — it even appears in the error's own suggestions. I also tried using `Row` as the IN type of `HashAggregator`, but that fails with an error as well.

What am I getting wrong here?


Solution

  • When an `Aggregator` is wrapped with `functions.udaf`, the resulting UDAF expects one column per field of the IN type (`IdHashPair` has two fields), not a single struct column. Pass `$"id"` and `$"hash"` as separate columns:

    val hashAggregator = functions.udaf(new HashAggregator())
    
    val df = Seq(
      (1, 10, "hash1"),
      (2, 10, "hash2"),
      (3, 20, "hash3"),
      (4, 20, "hash4")
    )
      .map { case (id, group_id, hash) =>
        (id.toString, group_id.toString, hash)
      }
      .toDF("id", "group_id", "hash")
    
    df
      .groupBy($"group_id")
      .agg(hashAggregator($"id", $"hash"))
      .show(false)
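
As a side note (not part of the fix above), a similar result can often be achieved with built-in functions alone — collect the `(id, hash)` structs, sort them, concatenate the hashes, and apply `sha2` — a sketch, assuming the lexicographic struct ordering by `id` is acceptable:

```scala
import org.apache.spark.sql.functions._

// sort_array orders the structs by their first field (id),
// getField projects the hash field across the sorted array,
// concat_ws joins the hashes, and sha2 produces the hex digest.
val result = df
  .groupBy($"group_id")
  .agg(
    sha2(
      concat_ws("", sort_array(collect_list(struct($"id", $"hash"))).getField("hash")),
      256
    ).as("group_hash")
  )
```

This avoids a custom aggregator entirely, at the cost of materializing the full list of structs per group.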