scala, apache-spark, sharding, data-partitioning

Spark Partition Dataset By Column Value


(I am new to Spark.) I need to store a large number of rows of data and then handle updates to that data. We have unique IDs (DB PKs) for those rows, and we would like to shard the data set by uniqueID % numShards, to make equal-sized, addressable partitions. Since the PKs (unique IDs) are present both in the data and in the update files, it will be easy to determine which partition will be updated. We intend to shard the data and the updates by the same criteria, and periodically rewrite "shard S + all updates accumulated for shard S => new shard S". (We know how to combine shard S + updates = new shard S.)

If this is our design, we need to (1) shard a DataFrame by one of its columns (say: column K) into |range(K)| partitions where it is guaranteed that all rows in a partition have the same value in column K and (2) be able to find the Parquet file that corresponds to column_K=k, knowing k = row.uniqueID % numShards.
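To make requirement (2) concrete: if the data is written out with Hive-style partition directories (the layout a partitioned Parquet write produces), the shard for a given key is addressable purely by path. A minimal sketch, where basePath and numShards are illustrative placeholders:

    // Sketch only: basePath and numShards are illustrative placeholders.
    val numShards = 1024
    val basePath  = "parquet/table_name"

    // Given an update row's primary key, locate the directory of the shard it belongs to.
    def shardPath(uniqueId: Long): String = {
      val k = uniqueId % numShards              // same criterion used to shard the data
      s"$basePath/SHARD_ID=$k"                  // Hive-style directory: one per distinct k
    }

    // e.g. shardPath(1025L) == "parquet/table_name/SHARD_ID=1"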

Is this a good design, or does Spark offer something out of the box that makes our task much easier?

Which Spark class/method should we use for partitioning our data? We are looking at RangePartitioner, but its constructor asks for the number of partitions. We want to specify "use column_K for partitioning, and make one partition for each distinct value k in range(K)", because we have already created column_K = uniqueID % numShards. Which partitioner is appropriate for splitting on the value of one column of a DataFrame? Do we need to create a custom partitioner, or use partitionBy, or repartitionByRange, or...?
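For context (the right choice depends on the goal), Spark has two different notions of "partitioning by a column": repartition(col) / repartitionByRange(col) control how rows are shuffled across in-memory task partitions, while DataFrameWriter.partitionBy controls the on-disk directory layout. A rough sketch with placeholder names (df, SHARD_ID):

    import org.apache.spark.sql.functions.col

    // (a) In-memory layout: hash-shuffle rows so equal SHARD_ID values land in the same
    //     task partition; the partition count defaults to spark.sql.shuffle.partitions,
    //     so several distinct values may share a partition.
    val shuffled = df.repartition(col("SHARD_ID"))

    // (b) On-disk layout: the writer groups rows into one directory per distinct
    //     SHARD_ID value (SHARD_ID=0/, SHARD_ID=1/, ...), independent of (a).
    df.write.partitionBy("SHARD_ID").parquet("parquet/table_name")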

This is what we have so far:

import org.apache.spark.sql.functions._

val df = spark.read
  .option("fetchsize", 1000)
  .option("driver", "oracle.jdbc.driver.OracleDriver")
  .jdbc(jdbc_url, "SCHEMA.TABLE_NAME", partitions, props)
  .withColumn("SHARD_ID", col("TABLE_PK") % 1024)

df.write
  .parquet("parquet/table_name")

Now we need to specify that this DataFrame should be partitioned by SHARD_ID before it is written out as Parquet files.


Solution

  • This works:

    spark.read
      .option("fetchsize", 1000)
      .option("driver", "oracle.jdbc.driver.OracleDriver")
      .jdbc(jdbc.getString("url"), "SCHEMA.TABLE_NAME", partitions, props)
      .withColumn("SHARD_ID", col("TABLE_PK") % 1024)
      .write
      .partitionBy("SHARD_ID")
      .parquet("parquet/table_name")
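  • As a side note (not part of the original answer; names and the shard number are placeholders): with that layout, an individual shard can be read back either by pointing at its directory or by filtering on SHARD_ID, which Parquet partition pruning turns into a scan of just that one directory:

    import org.apache.spark.sql.functions.col

    val k = 42

    // Option 1: read the shard's directory directly
    // (read this way, SHARD_ID is not inferred as a column).
    val shardByPath = spark.read.parquet(s"parquet/table_name/SHARD_ID=$k")

    // Option 2: read the whole table and filter; partition pruning limits the
    // scan to the SHARD_ID=42 directory.
    val shardByFilter = spark.read
      .parquet("parquet/table_name")
      .filter(col("SHARD_ID") === k)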