Tags: scala, unit-testing, apache-spark, scalatest

Assert RDD is not sorted


I have a method called split that accepts an RDD[T] and a splitSize and returns an Array[RDD[T]].
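
For reference, a minimal sketch of the assumed signature (the actual implementation isn't shown here):

    import org.apache.spark.rdd.RDD
    import scala.reflect.ClassTag

    object RDDUtils {
      // Hypothetical signature, inferred from the description above
      def split[T: ClassTag](rdd: RDD[T], splitSize: Int): Array[RDD[T]] = ???
    }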

Now, one of the test cases I'm writing for it should verify that the function also randomly shuffles the RDD.

So I create a sorted RDD and then inspect the results:

  it should "randomize shuffle" in {
    val inputRDD = sc.parallelize(0 until 16)
    val result = RDDUtils.split(inputRDD, 2)

    result.foreach(rdd => {
      rdd.collect.foreach(println)
    })

    // Assert result is not sorted
  }

If the results are:

0 1 2 3 .. 15

Then it's not working as expected.

A good result would look something like:

11 3 9 14 ... 1 6

How can I assert that the output Array[RDD[T]] is not sorted?


Solution

  • You could try something like this:

    val resultOrder = result.sortBy(....)
    assert(!resultOrder.sameElements(result))
    

    or

    val resultOrder = result.sortBy(....)
    assert(resultOrder.toList != result.toList)
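
    Both comparisons can be illustrated without Spark on a plain Array (the values here are just for demonstration):

    val original = Array(3, 1, 2)
    val ordered  = original.sorted                // Array(1, 2, 3)
    assert(!ordered.sameElements(original))       // element-by-element comparison
    assert(ordered.toList != original.toList)     // structural List equality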
    

    The key is knowing how to sort the Array. For an Integer data type this is easy, but for a complex data type you may need an implicit Ordering for your type, e.g.:

    // For a type T that already defines < (e.g. it extends Ordered[T]):
    implicit val ordering: Ordering[T] =
        Ordering.fromLessThan[T]((sa: T, sb: T) => sa < sb)
    
    // OR, for a custom class, compare on one of its fields:
    implicit val ordering: Ordering[MyClass] =
        Ordering.fromLessThan[MyClass]((sa: MyClass, sb: MyClass) => sa.field1 < sb.field1)
    

    The exact code will depend on your data type.
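
    For instance, a self-contained sketch using a hypothetical case class (Person is not from the question) and Ordering.by, which is often more concise than fromLessThan:

    case class Person(name: String, age: Int)

    // Order Person values by their age field
    implicit val personOrdering: Ordering[Person] =
      Ordering.by((p: Person) => p.age)

    val people = Array(Person("Bo", 42), Person("Al", 7))
    assert(people.sorted.map(_.name).sameElements(Array("Al", "Bo")))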

    As a full example of this approach:

    package tests
    
    import org.apache.log4j.{Level, Logger}
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SparkSession
    
    object SortArrayRDD {
    
      val spark = SparkSession
        .builder()
        .appName("SortArrayRDD")
        .master("local[*]")
        .config("spark.sql.shuffle.partitions","4") //Change to a more reasonable default number of partitions for our data
        .config("spark.app.id","SortArrayRDD") // To silence Metrics warning
        .getOrCreate()
    
      val sc = spark.sparkContext
    
      def main(args: Array[String]): Unit = {
        try {
    
          Logger.getRootLogger.setLevel(Level.ERROR)
    
          val arrRDD: Array[RDD[Int]] = Array(
            sc.parallelize(List(2, 3)), sc.parallelize(List(10, 11)),
            sc.parallelize(List(6, 7)), sc.parallelize(List(8, 9)),
            sc.parallelize(List(4, 5)), sc.parallelize(List(0, 1)),
            sc.parallelize(List(12, 13)), sc.parallelize(List(14, 15)))
    
          // Order RDDs by the sum of their elements; note that sum() is a Spark action
          implicit val ordering: Ordering[RDD[Int]] =
            Ordering.fromLessThan[RDD[Int]]((sa: RDD[Int], sb: RDD[Int]) => sa.sum() < sb.sum())
    
          // Sort once and reuse the result for both printing and the assertion
          val resultOrder = arrRDD.sorted
          resultOrder.foreach(rdd => println(rdd.collect().mkString(",")))
    
          // sorted returns a new array, so compare against the original ordering
          assert(!resultOrder.sameElements(arrRDD))
          println("It's unordered")
        } finally {
          sc.stop()
        }
      }
    }
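
    Note that ordering RDDs by sum() runs a Spark job for every comparison during the sort, so this is only practical for small test fixtures. To tie this back to the original test, one option (a sketch, assuming split returns Array[RDD[Int]] as in the question) is to flatten the collected output and compare it against its own sorted form; keep in mind that a random shuffle can, with small probability, still produce sorted output:

    val result = RDDUtils.split(inputRDD, 2)
    val flattened = result.flatMap(_.collect())

    // If the flattened output equals its sorted form, split did not shuffle
    assert(!flattened.sameElements(flattened.sorted))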