I have an RDD like this:
JavaPairRDD<String, String>
that has a lot of entries, and some keys are repeated many times. When I apply either groupByKey
or combineByKey
, it generates another
JavaPairRDD<String, Iterable<String>>
Here is the problem: for some keys, the number of values is huge (because those keys are skewed). This causes issues in downstream consumption and can even produce memory errors.
My question is: how can I limit the number of values aggregated per key? I want to group by key, but the value list should not exceed a limit of X values; any overflowing values should go into a new entry. Is there a way to do this?
This can be solved using flatMap
. I'm not sure how to write it in Java, but hopefully you can convert the Scala code. Code with example input:
val rdd = spark.sparkContext.parallelize(Seq((1, Iterable(1,2,3,4,5)), (2, Iterable(6,7,8)), (3, Iterable(1,3,6,8,4,2,7,8,3))))
val maxLength = 3
val res = rdd.flatMap{ case(id, vals) =>
vals.grouped(maxLength).map(v => (id, v))
}
The idea is to split the list into a list of lists where each inner list is at most the maximum length. Since flatMap
is used here, the list of lists is flattened back into a simple list, which is the result you want. Using a max length of 3 and printing res
gives:
(1,List(1, 2, 3))
(1,List(4, 5))
(2,List(6, 7, 8))
(3,List(1, 3, 6))
(3,List(8, 4, 2))
(3,List(7, 8, 3))
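For reference, here is one possible way to express the same idea in Java. This is an untested sketch, not a verified Spark program: the core is a plain-Java chunking helper that mirrors Scala's grouped(maxLength), and the Spark wiring with flatMapToPair is shown only as a comment, since the exact RDD types depend on your setup.

```java
import java.util.ArrayList;
import java.util.List;

public class ChunkValues {
    // Split an iterable into consecutive chunks of at most maxLength elements,
    // mirroring Scala's grouped(maxLength).
    static <T> List<List<T>> grouped(Iterable<T> values, int maxLength) {
        List<List<T>> chunks = new ArrayList<>();
        List<T> current = new ArrayList<>(maxLength);
        for (T v : values) {
            current.add(v);
            if (current.size() == maxLength) {
                chunks.add(current);
                current = new ArrayList<>(maxLength);
            }
        }
        if (!current.isEmpty()) {
            chunks.add(current); // keep the final, possibly shorter, chunk
        }
        return chunks;
    }

    // Hypothetical Spark wiring (not compiled here), applied to the grouped RDD:
    //
    // JavaPairRDD<String, List<String>> res = rdd.flatMapToPair(pair ->
    //     grouped(pair._2(), maxLength).stream()
    //         .map(chunk -> new Tuple2<>(pair._1(), chunk))
    //         .iterator());

    public static void main(String[] args) {
        // Same skewed key as in the Scala example: 9 values, max length 3.
        List<Integer> vals = List.of(1, 3, 6, 8, 4, 2, 7, 8, 3);
        System.out.println(grouped(vals, 3));
        // → [[1, 3, 6], [8, 4, 2], [7, 8, 3]]
    }
}
```

Because flatMapToPair emits one output pair per chunk, a key with N values becomes roughly N / maxLength separate (key, chunk) entries, which is exactly the overflow-to-a-new-line behaviour asked for.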