Tags: java, string, apache-spark, split, rdd

Is there a way to split an RDD by rows?


I have a dataset with 20,000 rows in a JavaRDD. Now I want to save it as several files of exactly the same size (e.g. 70 rows per file).

I tried the code below, but because the row count is not exactly divisible, some of the resulting files end up with 69, 70, or 71 rows. The problem is that all files need to have the same number of rows, except the last one (which can have fewer).

Help is appreciated! Thanks in advance!

myString.repartition(286).saveAsTextFile(outputPath);
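
To make the problem concrete, here is a rough sketch (in Scala for brevity; my real data is a JavaRDD of strings, so the 20,000-row RDD below is just a stand-in) of how the uneven partition sizes show up:

val myString = sc.parallelize(1 to 20000).map(_.toString)   // stand-in for the real data
myString
  .repartition(286)
  .mapPartitionsWithIndex{ case (i, rows) => Iterator((i, rows.size)) }
  .collect()
  .take(5)
  .foreach(println)   // e.g. (0,70), (1,69), (2,71), ... -- sizes are only roughly equal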


Solution

  • Unfortunately this is a Scala answer, but it works.

    First define a custom partitioner:

    class IndexPartitioner[V](n_per_part: Int, rdd: org.apache.spark.rdd.RDD[_ <: Product2[Long, V]], do_cache: Boolean = true) extends org.apache.spark.Partitioner {
    
        // Highest index in the RDD; caching first avoids recomputing the
        // upstream RDD when it is partitioned afterwards.
        val max = {
            if (do_cache) rdd.cache()
            rdd.map(_._1).max
        }
    
        // Indices run from 0 to max, so there are (max + 1) records in total.
        override def numPartitions: Int = math.ceil((max + 1).toDouble / n_per_part).toInt
    
        // Every block of n_per_part consecutive indices maps to one partition.
        override def getPartition(key: Any): Int = key match {
            case k: Long => (k / n_per_part).toInt
            case _       => (key.hashCode / n_per_part).toInt  // fallback for non-Long keys
        }
    }
    

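    As a quick sanity check of the index-to-partition mapping (a sketch; the small indexed RDD below is built only for this illustration), every block of 70 consecutive indices should land in its own partition:

    val check_rdd = sc.parallelize(0L until 1000L).map(i => (i, i))
    val check_part = new IndexPartitioner(70, check_rdd)
    check_part.numPartitions       // 15  (= ceil(1000 / 70))
    check_part.getPartition(69L)   // 0   -- indices 0..69 go to partition 0
    check_part.getPartition(70L)   // 1   -- indices 70..139 go to partition 1
    check_part.getPartition(999L)  // 14  -- the last, smaller partition
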
    Create an RDD of random strings and index it:

    // 1000 random 5-letter strings
    val rdd = sc.parallelize(Array.tabulate(1000)(_ => scala.util.Random.alphanumeric.filter(_.isLetter).take(5).mkString))
    // key each row by its position: RDD[(Long, String)]
    val rdd_idx = rdd.zipWithIndex.map(_.swap)
    

    Create the partitioner and apply it:

    val partitioner = new IndexPartitioner(70, rdd_idx)
    // drop the index keys again after partitioning: RDD[String]
    val rdd_part = rdd_idx.partitionBy(partitioner).values
    

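    As a quick check, the new RDD should report ceil(1000 / 70) = 15 partitions:

    rdd_part.getNumPartitions   // 15
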
    Check partition sizes:

    // .toDF needs the SQL implicits in scope (imported in the write step below)
    rdd_part
      .mapPartitionsWithIndex{ case (i, rows) => Iterator((i, rows.size)) }
      .toDF("partition_number", "number_of_records")
      .show
    
    /**
    +----------------+-----------------+
    |partition_number|number_of_records|
    +----------------+-----------------+
    |               0|               70|
    |               1|               70|
    |               2|               70|
    |               3|               70|
    |               4|               70|
    |               5|               70|
    |               6|               70|
    |               7|               70|
    |               8|               70|
    |               9|               70|
    |              10|               70|
    |              11|               70|
    |              12|               70|
    |              13|               70|
    |              14|               20|
    +----------------+-----------------+
    */
    

    One file for each partition:

    import sqlContext.implicits._
    // the CSV writer produces one part file per RDD partition
    rdd_part.toDF.write.format("com.databricks.spark.csv").save("/tmp/idx_part_test/")
    

    15 part files, plus 1 for the "_SUCCESS" marker:

    XXX$ ls /tmp/idx_part_test/ | wc -l
    16
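
    Since rdd_part is already an RDD of plain strings, the DataFrame step can also be skipped and the partitions written out directly with saveAsTextFile, matching the call from the original question (a minimal sketch; the output path is just an example):

    rdd_part.saveAsTextFile("/tmp/idx_part_test_txt/")
    // again one part file per partition: 14 files with 70 rows and one with 20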