Tags: scala, sorting, apache-spark, ranking, bigdata

Sorting and ranking in Apache Spark with Scala?


I want to do ranking in Spark, as follows:

Input:

5.6
5.6
5.6
6.2
8.1
5.5
5.5

Ranks:

1
1
1
2
3
0
0

Output:

Rank Input 
0     5.5
0     5.5
1     5.6
1     5.6
1     5.6
2     6.2
3     8.1

I want to know how I can sort these values in Spark and also get the ranking shown above. The requirements are:

  1. ranking starts at 0, not 1
  2. this is a sample case for millions of records, and one partition may be very large - I would appreciate a recommendation on how to rank using an internal sorting method

I want to do this in Scala. Can someone help me write the code for this?


Solution

  • If you expect only a small number of distinct values, you could first get them all, collect them as a sorted array, and broadcast it. Below is a quick-and-dirty example; notice that the output is not guaranteed to be sorted (there are probably better approaches, but this is the first thing that comes to mind):

    // Case 1: the number of distinct values (k) is small enough to fit
    // in the driver and on every executor.
    val rdd = sc.parallelize(List(1, 1, 44, 4, 1, 33, 44, 1, 2))
    
    // Collect the distinct values, sort them, and broadcast the sorted array.
    val distincts = rdd.distinct.collect.sortBy(x => x)
    val broadcast = sc.broadcast(distincts)
    
    // The rank of a value is its position in the sorted array of distincts.
    val sdd = rdd.map { i =>
      (broadcast.value.indexOf(i), i)
    }
    
    sdd.collect()
    
    // Array[(Int, Int)] = Array((0,1), (0,1), (4,44), (2,4), (0,1), (3,33), (4,44), (0,1), (1,2))
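
    If the distinct values fit in memory, a small variant of the same idea (just a sketch, not part of the original answer) is to broadcast a value-to-rank Map built from the sorted distincts, so each lookup is a constant-time map lookup instead of a linear indexOf:

    // Variant of case 1: broadcast a value -> rank Map (reuses rdd from above).
    val rankMap = sc.broadcast(rdd.distinct.collect.sorted.zipWithIndex.toMap)
    rdd.map(i => (rankMap.value(i), i)).collect()
    
    // Array[(Int, Int)] = Array((0,1), (0,1), (4,44), (2,4), (0,1), (3,33), (4,44), (0,1), (1,2))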
    

    In the second approach I do the sorting with Spark itself; the RDD documentation describes how zipWithIndex and keyBy work.

    // Case 2: k is big and the distinct values don't fit in the driver.
    val rdd = sc.parallelize(List(1, 1, 44, 4, 1, 33, 44, 1, 2))
    
    // Sort the distinct values and pair each one with its rank (a Long index).
    val distincts = rdd.distinct.sortBy(x => x).zipWithIndex
    
    // distincts is already a pair RDD of (value, rank), so it can be joined directly.
    rdd.keyBy(x => x)
      .join(distincts)
      .map { case (value, (_, index)) => (index, value) }
      .collect()
    
    // res15: Array[(Long, Int)] = Array((3,33), (2,4), (0,1), (0,1), (0,1), (0,1), (4,44), (4,44), (1,2))
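
    If you use the DataFrame API, a dense_rank-based sketch is also possible (this is not part of the original answer and assumes a SparkSession named spark is available). Note that a window without partitionBy moves all rows to a single partition, so it may not satisfy requirement 2 for very large data:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.dense_rank
    import spark.implicits._
    
    // dense_rank starts at 1, so subtract 1 to get the 0-based ranks from the question.
    val df = Seq(5.6, 5.6, 5.6, 6.2, 8.1, 5.5, 5.5).toDF("input")
    df.withColumn("rank", dense_rank().over(Window.orderBy("input")) - 1).show()
    
    // 5.5 -> 0, 5.6 -> 1, 6.2 -> 2, 8.1 -> 3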
    

    By the way, I use collect here only for visualization purposes; in a real application you shouldn't call it unless you are sure the result fits in the driver's memory.
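
    For instance (a minimal sketch reusing the sdd RDD from case 1; the output path is just a placeholder), you could inspect a handful of records with take, or write the full ranked result out instead of collecting it:

    // Look at a few ranked records without pulling everything to the driver.
    sdd.take(5).foreach(println)
    
    // Or persist the full result instead of collecting it.
    sdd.map { case (rank, value) => s"$rank\t$value" }.saveAsTextFile("/tmp/ranked-output")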