java, apache-spark, bigdata, rdd

Spark - How to keep max limit on number of values grouped in JavaPairRDD


I have an RDD like this:

JavaPairRDD<String, String> 

that has a lot of entries, and some keys are repeated many times. When I apply either groupByKey or combineByKey, it generates another

JavaPairRDD<String, Iterable<String>>

Here is the problem: for some keys the number of values is very large (because those particular keys are skewed). This causes issues in further downstream consumption and can even produce memory problems.

My question is how to limit the number of values aggregated per key. I want to group by key, but the value list should not grow beyond a limit of X values; any overflowing values should be added as a new entry. For example, with X = 3, a key with 5 values should produce two entries, one with 3 values and one with 2. Is there a way to do this?


Solution

  • This can be solved using flatMap. I'm not sure how to write it in Java, but hopefully you can convert the Scala code (a possible Java equivalent is sketched after the example output below). Code with example input:

    val rdd = spark.sparkContext.parallelize(Seq((1, Iterable(1,2,3,4,5)), (2, Iterable(6,7,8)), (3, Iterable(1,3,6,8,4,2,7,8,3))))
    
    val maxLength = 3
    // Split each key's values into chunks of at most maxLength and emit one (key, chunk) pair per chunk
    val res = rdd.flatMap{ case(id, vals) =>
      vals.grouped(maxLength).map(v => (id, v))
    }
    

    The idea is to split each value list into a list of lists, where each inner list has at most maxLength elements. Since flatMap is used, the list of lists is flattened back into individual (key, chunk) pairs, which is the result you want. Using a max length of 3 and printing res gives:

    (1,List(1, 2, 3))
    (1,List(4, 5))
    (2,List(6, 7, 8))
    (3,List(1, 3, 6))
    (3,List(8, 4, 2))
    (3,List(7, 8, 3))
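
    Since the question asks about Java, here is a minimal, untested sketch of the same idea using flatMapToPair on a JavaPairRDD<String, Iterable<String>> (assuming a Spark 2.x-style API where the pair flat-map function returns an Iterator; the class name, example data, and variable names are purely illustrative):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;

    import scala.Tuple2;

    public class LimitValuesPerKey {
        public static void main(String[] args) {
            JavaSparkContext sc = new JavaSparkContext("local[*]", "limit-values-per-key");

            // Example grouped input, as produced by groupByKey: (key, all values for that key)
            List<Tuple2<String, Iterable<String>>> input = Arrays.asList(
                new Tuple2<String, Iterable<String>>("a", Arrays.asList("1", "2", "3", "4", "5")),
                new Tuple2<String, Iterable<String>>("b", Arrays.asList("6", "7", "8")));
            JavaPairRDD<String, Iterable<String>> grouped = sc.parallelizePairs(input);

            int maxLength = 3;

            // Split each key's value list into chunks of at most maxLength elements
            // and emit one (key, chunk) pair per chunk.
            JavaPairRDD<String, List<String>> limited = grouped.flatMapToPair(pair -> {
                List<Tuple2<String, List<String>>> chunks = new ArrayList<>();
                List<String> chunk = new ArrayList<>(maxLength);
                for (String value : pair._2()) {
                    chunk.add(value);
                    if (chunk.size() == maxLength) {
                        chunks.add(new Tuple2<>(pair._1(), chunk));
                        chunk = new ArrayList<>(maxLength);
                    }
                }
                if (!chunk.isEmpty()) {
                    chunks.add(new Tuple2<>(pair._1(), chunk));
                }
                return chunks.iterator();
            });

            // Prints e.g. (a,[1, 2, 3]), (a,[4, 5]), (b,[6, 7, 8])
            limited.collect().forEach(System.out::println);
            sc.stop();
        }
    }

    Note that the split happens after groupByKey has already gathered every value for a key, so it relieves downstream consumers rather than the grouping step itself.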