Tags: algorithm, scala, nearest-neighbor, kdtree, r-tree

Efficient nearest neighbour search in Scala


Consider this coordinate class with a Euclidean distance method,

case class coord(x: Double, y: Double) {
  def dist(c: coord) = Math.sqrt( Math.pow(x-c.x, 2) + Math.pow(y-c.y, 2) ) 
}

and a grid of coordinates, for instance

val grid = (1 to 25).map {_ => coord(Math.random*5, Math.random*5) }

Then for any given coordinate

val x = coord(Math.random*5, Math.random*5) 

the nearest points to x are

val nearest = grid.sortWith( (p,q) => p.dist(x) < q.dist(x) )

so the first three closest are nearest.take(3).

Is there a way to make these calculations more time-efficient, especially for a grid with one million points?


Solution

  • I'm not sure if this is helpful (or even stupid), but I thought of this:

    You use a sort function to sort ALL elements in the grid and then pick the first k elements. If you consider a sorting algorithm like recursive merge sort, you have something like this:

    1. Split collection in half
    2. Recurse on both halves
    3. Merge both sorted halves

    Maybe you could optimize such a function for your needs. The merge step normally merges all elements from both halves, but you are only interested in the first k elements of the result. So you could merge only until you have k elements and ignore the rest.

    So in the worst case, where k >= n (n being the size of the grid), you would still only have the complexity of merge sort, O(n log n). To be honest, I'm not able to determine the complexity of this solution relative to k (too tired for that at the moment).
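    A rough count of the work done in the merge steps may help with that open question. This is only a sketch: it assumes every merge stops after at most k elements, as described above, and it ignores the cost of splitting the collection and of recomputing distances.

    ```latex
    T(n) = 2\,T(n/2) + O(\min(n, k))
    \quad\Longrightarrow\quad
    T(n) = \underbrace{O(n \log k)}_{\text{subproblems of size } \le k}
         + \underbrace{O(n)}_{\text{subproblems of size } > k}
         = O(n \log k)
    ```

    Subproblems of size at most k behave like an ordinary merge sort, which gives about log k levels of O(n) work each. Above that size, each merge costs only O(k), and the number of merges halves at every level, so those levels add up to O(n) in total.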

    Here is an example implementation of that solution (it's definitely not optimal and not generalized):

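    import scala.annotation.tailrec

    // Returns the (at most) k elements of seq closest to x, nearest first.
    def minK(seq: IndexedSeq[coord], x: coord, k: Int): IndexedSeq[coord] = {

      val dist = (c: coord) => c.dist(x)

      // Merge sort in which every merge keeps at most k elements.
      def sort(seq: IndexedSeq[coord]): IndexedSeq[coord] = seq.size match {
        case 0 | 1 => seq
        case size => {
          val (left, right) = seq.splitAt(size / 2)
          merge(sort(left), sort(right))
        }
      }

      // Merges two sorted sequences, stopping once k elements have been emitted.
      def merge(left: IndexedSeq[coord], right: IndexedSeq[coord]): IndexedSeq[coord] = {

        // lift turns an out-of-range index into None instead of throwing.
        val leftF = left.lift
        val rightF = right.lift

        val builder = IndexedSeq.newBuilder[coord]

        @tailrec
        def loop(leftIndex: Int = 0, rightIndex: Int = 0): Unit = {
          // leftIndex + rightIndex is the number of elements emitted so far.
          if (leftIndex + rightIndex < k) {
            (leftF(leftIndex), rightF(rightIndex)) match {
              case (Some(leftCoord), Some(rightCoord)) => {
                if (dist(leftCoord) < dist(rightCoord)) {
                  builder += leftCoord
                  loop(leftIndex + 1, rightIndex)
                } else {
                  builder += rightCoord
                  loop(leftIndex, rightIndex + 1)
                }
              }
              case (Some(leftCoord), None) => {
                builder += leftCoord
                loop(leftIndex + 1, rightIndex)
              }
              case (None, Some(rightCoord)) => {
                builder += rightCoord
                loop(leftIndex, rightIndex + 1)
              }
              case _ => // both sides exhausted
            }
          }
        }

        loop()

        builder.result()
      }

      // take(k) covers inputs of size 0 or 1, which never pass through merge.
      sort(seq).take(k)
    }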
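
A different way to get the same "only keep k" effect, shown here as a minimal sketch rather than as part of the answer above, is a bounded max-heap: scan the grid once and keep the k closest points seen so far. The names `Coord`, `dist2` and `nearestK` are made up for this example (`Coord` mirrors the question's `coord` class), and squared distance is used only because it orders points the same way as the Euclidean distance.

```scala
import scala.collection.mutable

// Same shape as the question's class, renamed so the sketch stands alone.
final case class Coord(x: Double, y: Double) {
  // Squared distance preserves the ordering and avoids the square root.
  def dist2(c: Coord): Double = {
    val dx = x - c.x
    val dy = y - c.y
    dx * dx + dy * dy
  }
}

// Hypothetical helper: the (at most) k points closest to target, nearest first.
def nearestK(points: Iterable[Coord], target: Coord, k: Int): Vector[Coord] = {
  if (k <= 0) Vector.empty
  else {
    // Max-heap on squared distance: the head is the farthest point kept so far.
    val byDistance = Ordering.fromLessThan[(Double, Coord)](_._1 < _._1)
    val heap = mutable.PriorityQueue.empty[(Double, Coord)](byDistance)

    points.foreach { p =>
      val d = p.dist2(target)
      if (heap.size < k) {
        heap.enqueue((d, p))
      } else if (d < heap.head._1) {
        // p is closer than the current farthest candidate, so swap it in.
        heap.dequeue()
        heap.enqueue((d, p))
      }
    }

    // The heap holds at most k entries; sort just those, nearest first.
    heap.toVector.sortWith(_._1 < _._1).map(_._2)
  }
}
```

With the question's data this would be called as `nearestK(grid, x, 3)` after switching the grid to `Coord`. Each point costs at most one heap update, so the scan is O(n log k) and never sorts the full grid. If many queries run against the same grid, a spatial index such as the k-d tree or R-tree named in the tags is the usual next step, at the cost of building the index first.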