Tags: scala, quicksort, implicit

Scala: comparison incorrectly evaluated using Ordering trait


I have the following implementation of SampleSort:

import scala.reflect.ClassTag
import ca.vgorcinschi.ArrayOps

import Ordered._

//noinspection SpellCheckingInspection
class SampleSort[T: ClassTag : Ordering](val sampleSize: Int = 30) extends QuickSort[T] {

  import SearchTree._

  override def sort(a: Array[T]): Array[T] = {
    require(a != null, "Passed-in array should not be null")
    sortHelper(a)
  }

  private def sortHelper(a: Array[T]): Array[T] = {
    // if the array is no longer than the sample size, sort it with plain Quicksort
    if (a.length <= sampleSize) return super.sort(a)

    /*
      just the indices for the sample array.
      also required later for figuring out the nonPartitionedRemainder of the array
     */
    val sampleArrayIndices: Array[Int] = a.subArrayOfSize(sampleSize)
    val sampleArray: Array[T] = sampleArrayIndices map (a(_))

    val sortedSampleArray: Array[T] = sort(sampleArray, 0, sampleArray.length - 1)
    val searchTree: SearchTree = buildTree(sortedSampleArray, sampleSize / 2)
    val nonPartitionedRemainder = a.slice(0, sampleArrayIndices.head) ++ a.slice(sampleArrayIndices.last + 1, a.length)
    val finalTree = (searchTree /: nonPartitionedRemainder) (_ nest _)
    finalTree.arrays() flatMap sort
  }

  private class SearchTree(lt: Array[T], median: Array[T], gt: Array[T]) {
    // here median is guaranteed to be non-null and non-empty, based on the partitioning in sortHelper
    private val pivot: T = median.head

    def nest(value: T): SearchTree = {
      if (value < pivot) SearchTree(lt :+ value, median, gt)
      if (value > pivot) SearchTree(lt, median, gt :+ value)
      else SearchTree(lt, median :+ value, gt)
    }

    def arrays(): Array[Array[T]] = Array(lt, median, gt)
  }

  private object SearchTree {
    def buildTree(sample: Array[T], pivot: Int): SearchTree = {
      //do not look beyond pivot since sample is guaranteed to be partitioned
      val lt = sample.takeWhile(_ < sample(pivot))
      //only look from pivot and up
      val medianAndGt: (Array[T], Array[T]) = sample.slice(lt.length, sample.length) partition (_ == sample(pivot))
      SearchTree(lt, medianAndGt._1, medianAndGt._2)
    }

    def apply(lt: Array[T], median: Array[T], gt: Array[T]): SearchTree = new SearchTree(lt, median, gt)
  }

}

Briefly, this is what the code does:

  1. Sort a sample of the passed-in array
  2. Put the sample values that are lt, eq or gt relative to the pivot into the corresponding buckets
  3. Distribute each element of the unsorted remainder of the array into one of these buckets
  4. Repeat recursively

This currently fails in the SearchTree.nest method (step 3 above): all the values end up in the median (eq) bucket.

However, similar comparisons work inside the buildTree method of the SearchTree companion object, using the same operators from import Ordered._!

I am not sure what I am missing here. I would appreciate any help or advice on this matter.


Solution

  • An else is missing before if (value > pivot). As written, the body of nest says:

    1. if value < pivot, build a new SearchTree and throw it away;

    2. if value > pivot ...

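    A block evaluates to its last expression, so only the second if/else determines the result. When value < pivot holds, the tree built by the first if is discarded, value > pivot is false, and you get the else branch of the second if: the value lands in median. The comparisons themselves are evaluated correctly; writing else if (value > pivot) fixes the method.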
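
    To make the difference concrete, here is a minimal sketch that puts both versions side by side. It is not the asker's class: Buckets, nestBuggy and the use of Vector instead of Array (which avoids the ClassTag) are assumptions made only to keep the example self-contained.

```scala
import scala.math.Ordered._

// Hypothetical, simplified stand-in for the question's SearchTree.
// Vector replaces Array so that no ClassTag is required.
final case class Buckets[T: Ordering](lt: Vector[T], median: Vector[T], gt: Vector[T]) {
  // As in the question, median is assumed to be non-empty.
  private val pivot: T = median.head

  // Same shape as the original nest: the first `if` has no `else`,
  // so its result is discarded and only the second if/else is returned.
  def nestBuggy(value: T): Buckets[T] = {
    if (value < pivot) Buckets(lt :+ value, median, gt)
    if (value > pivot) Buckets(lt, median, gt :+ value)
    else Buckets(lt, median :+ value, gt)
  }

  // Corrected version: one if / else if / else chain, a single expression.
  def nest(value: T): Buckets[T] =
    if (value < pivot) Buckets(lt :+ value, median, gt)
    else if (value > pivot) Buckets(lt, median, gt :+ value)
    else Buckets(lt, median :+ value, gt)
}
```

    With val t = Buckets(Vector(1), Vector(5), Vector(9)), calling t.nestBuggy(3) appends 3 to median, while t.nest(3) appends it to lt. Values greater than the pivot go to gt in both versions, because the second if handles them.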