scala, apache-spark, neighbours

How to find grid neighbours (x, y as integers), group them, and calculate the mean of their values in Spark


I'm struggling to find a way to calculate the average value of each cell's neighbours from a data set that looks like this:

+------+------+---------+
|     X|     Y|  value  |
+------+------+---------+
|     1|     5|   1     |
|     1|     8|   1     |
|     1|     6|   6     |
|     2|     8|   5     |
|     2|     6|   3     |
+------+------+---------+

For example:

The neighbours of (1, 5) would be (1,6) and (2,6), so I need to find the mean of all their values, and the answer here would be (1 + 6 + 3) / 3 = 3.33.

The neighbour of (1, 8) would be (2, 8), and the mean of their values would be (1 + 5) / 2 = 3.
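In other words, two cells count as neighbours here when each coordinate differs by at most one (diagonals included). A minimal sketch of that check, where the name isNeighbour is just for illustration:

    // (x2, y2) neighbours (x1, y1) when it lies in the surrounding 3x3 block
    // but is not the cell itself, i.e. Chebyshev distance exactly 1
    def isNeighbour(x1: Int, y1: Int, x2: Int, y2: Int): Boolean =
      math.max(math.abs(x1 - x2), math.abs(y1 - y2)) == 1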

I'm hoping for my solution to look something like this (I just concatenate the coordinates as strings here for the key):

+-------------------+------+
| neighbour_values  | mean |
+-------------------+------+
| (1,5)_(1,6)_(2,6) | 3.33 |
| (1,8)_(2,8)       | 3    |
+-------------------+------+

I've tried it with column concatenation but that didn't seem to get far. One solution I'm considering is to iterate through the table twice: once for each element, and again over the other values to check whether each one is a neighbour or not. Unfortunately, I'm fairly new to Spark and I can't find any information on how to do this.

ANY help is VERY much appreciated! Thank you! :))


Solution

  • The answer depends on whether you are concerned only with grouping by directly adjacent neighbours. That scenario can lead to ambiguity if, say, there is a contiguous block more than two items wide or tall: a cell can then be adjacent to two cells that are not adjacent to each other, so pairwise groupings would overlap. The approach below therefore assumes that all items in a contiguous set of coordinates are bunched into a single group, and that each original record belongs to exactly one grouping.

    This assumption of partitioning the coordinates into disjoint groups lends itself to the union-find algorithm.

    Since union-find is recursive, this approach collects the original elements into memory and creates a UDF based on those values. Note that this can be slow and/or require a lot of memory for large datasets.

    // imports (this assumes a SparkSession named `spark` is in scope)
    import org.apache.spark.sql.functions._
    import spark.implicits._

    // create example DF
    val df = Seq((1, 5, 1), (1, 8, 1), (1, 6, 6), (2, 8, 5), (2, 6, 3)).toDF("x", "y", "value")
    
    // collect all coordinates into in-memory collections
    val coordinates = df.select("x", "y").collect().map(r => (r.getInt(0), r.getInt(1)))
    val coordSet = coordinates.toSet
    
    type K = (Int, Int)

    // map each coordinate to the smallest of its already-present neighbours
    // among (x-1, y-1), (x, y-1) and (x-1, y); coordinates with none are roots
    val directParent: Map[K, Option[K]] = coordinates.map { case (x: Int, y: Int) =>
      val possibleParents = coordSet.intersect(Set((x - 1, y - 1), (x, y - 1), (x - 1, y)))
      val parent = if (possibleParents.isEmpty) None else Some(possibleParents.min)
      ((x, y), parent)
    }.toMap
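
    // for the example data this works out to:
    //   (1,5) -> None          (no earlier neighbour: a root)
    //   (1,8) -> None          (a root)
    //   (1,6) -> Some((1,5))
    //   (2,8) -> Some((1,8))
    //   (2,6) -> Some((1,5))   (min of {(1,5), (1,6)})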
    
    // skip unionFind if only concerned with direct neighbors
    // follow parent links until reaching a root; the root serves as the
    // canonical (representative) element of a connected group
    def unionFind(key: K, map: Map[K, Option[K]]): K =
      map.get(key).flatten match {
        case Some(parent) => unionFind(parent, map)
        case None         => key  // no entry or no parent: key is its own root
      }
    
    // resolve each coordinate to its group's canonical (root) element
    val canonicalUDF = udf((x: Int, y: Int) => unionFind((x, y), directParent))
    
    // group using the canonical element
    // create column "neighbors" based on x, y values in each group
    val avgDF = df.groupBy(canonicalUDF($"x", $"y").alias("canonical")).agg(
      concat_ws("_", collect_list(concat(lit("("), $"x", lit(","), $"y", lit(")")))).alias("neighbors"),
      avg($"value")).drop("canonical")
    

    Result:

    avgDF.show(10, false)
    +-----------------+------------------+
    |neighbors        |avg(value)        |
    +-----------------+------------------+
    |(1,8)_(2,8)      |3.0               |
    |(1,5)_(1,6)_(2,6)|3.3333333333333335|
    +-----------------+------------------+
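
    As the comment in the code notes, you can skip unionFind if you only need each cell averaged together with its direct neighbours (one row per cell rather than one per connected group). In that case a join keeps everything distributed and avoids collecting coordinates to the driver. A minimal sketch under that assumption, where the names shiftDF, neighbourhoods, perCellAvg and the cx/cy/nx/ny columns are illustrative:

    import org.apache.spark.sql.functions._
    import spark.implicits._

    // the nine (dx, dy) offsets covering a cell's 3x3 neighbourhood, itself included
    val offsets = Seq(-1, 0, 1)
    val shiftDF = (for { dx <- offsets; dy <- offsets } yield (dx, dy)).toDF("dx", "dy")

    // pair every cell (cx, cy) with each coordinate (nx, ny) of its neighbourhood
    val neighbourhoods = df.select($"x", $"y")
      .crossJoin(shiftDF)
      .select($"x".alias("cx"), $"y".alias("cy"),
              ($"x" + $"dx").alias("nx"), ($"y" + $"dy").alias("ny"))

    // keep only neighbourhood coordinates that actually exist in the data,
    // then average the values per original cell
    val perCellAvg = neighbourhoods
      .join(df, $"nx" === df("x") && $"ny" === df("y"))
      .groupBy($"cx", $"cy")
      .agg(avg($"value").alias("neighbourhood_mean"))

    On the example data this reproduces the per-cell means from the question, e.g. (1,5) -> 3.33 and (1,8) -> 3.0.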