Tags: scala, apache-spark, apache-spark-sql, apache-spark-dataset, apache-spark-2.0

Apache Spark join with dynamic re-partitioning


I'm trying to do a fairly straightforward join on two tables: load both tables, join them, and update some columns. Nothing complicated, but it keeps throwing an exception.

I noticed that the job gets stuck on the last task (199/200) and eventually crashes. I suspect the data is skewed, causing all of the data to be loaded into the last partition (199).

SELECT COUNT(DISTINCT report_audit) FROM ReportDs returns 1.5 million, while SELECT COUNT(*) FROM ReportDs returns 57 million.
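Those two counts average out to 57,000,000 / 1,500,000 = 38 rows per report_audit value, but an average alone cannot show skew: what matters is whether a few keys hold a large share of the rows. In Spark the per-key counts could be obtained with something like reportDs.groupBy("report_audit").count(). The sketch below assumes those counts have already been collected into a plain Scala Seq[Long] (reasonable for a sample or for the heaviest keys) and compares an even and a skewed distribution that have the same average. The object name, function name, and sample numbers are hypothetical.

```scala
object KeySkewCheck {
  // Summary of how rows are spread across distinct join-key values.
  final case class SkewSummary(rows: Long, distinctKeys: Int, avgRowsPerKey: Double, topKShare: Double)

  // keyCounts: one entry per distinct key, holding the number of rows with that key.
  // topK: how many of the heaviest keys to include in topKShare.
  def summarize(keyCounts: Seq[Long], topK: Int): SkewSummary = {
    require(keyCounts.nonEmpty, "keyCounts must not be empty")
    val rows = keyCounts.sum
    // Rows held by the topK largest keys, later expressed as a fraction of all rows.
    val topRows = keyCounts.sorted(Ordering[Long].reverse).take(topK).sum
    SkewSummary(rows, keyCounts.size, rows.toDouble / keyCounts.size, topRows.toDouble / rows)
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical data: both samples have 1,000 keys and 38,000 rows (38 rows per key on average).
    val even = Seq.fill(1000)(38L)
    val skewed = 37001L +: Seq.fill(999)(1L) // one hot key, 999 keys with a single row
    for ((name, counts) <- Seq("even" -> even, "skewed" -> skewed)) {
      val s = summarize(counts, topK = 1)
      println(f"$name%-6s avg=${s.avgRowsPerKey}%.1f top-1 share=${s.topKShare}%.3f")
    }
  }
}
```

Both samples report the same average, but in the skewed one a single key carries about 97% of the rows, which is the situation where one join task ends up doing almost all of the work.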

Cluster details: 40 CPU cores, 160 GB of memory.

Here is my sample code:

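import java.util.Properties

import org.apache.log4j.{Level, LogManager}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.functions._

...
def main(args: Array[String]): Unit = {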

  val log = LogManager.getRootLogger
  log.setLevel(Level.INFO)

  val conf = new SparkConf().setAppName("ExampleJob")
                          //.setMaster("local[*]")
                          //.set("spark.sql.shuffle.partitions", "3000")
                          //.set("spark.sql.crossJoin.enabled", "true")
                          // In Spark 2.x these two fractions are deprecated and only read
                          // when spark.memory.useLegacyMode=true; otherwise they are ignored.
                          .set("spark.storage.memoryFraction", "0.02")
                          .set("spark.shuffle.memoryFraction", "0.8")
                          .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                          .set("spark.default.parallelism", (CPU * 3).toString)


  val sparkSession = SparkSession.builder()
                                 .config(conf)
                                 .getOrCreate()


  val reportOpts = Map(
              "url"     -> s"jdbc:postgresql://$DB_HOST:$DB_PORT/$DATABASE",
              "driver"  -> "org.postgresql.Driver",
              "dbtable" -> "REPORT_TBL",
              "user"    -> DB_USER,
              "password"-> DB_PASSWORD,
              "partitionColumn" -> RPT_NUM_PARTITION,
              "lowerBound" -> RPT_LOWER_BOUND,
              "upperBound" -> RPT_UPPER_BOUND,
              "numPartitions" -> "200"
            )


  val accountOpts = Map(
                "url"     -> s"jdbc:postgresql://$DB_HOST:$DB_PORT/$DATABASE",
                "driver"  -> "org.postgresql.Driver",
                "dbtable" -> ACCOUNT_TBL,
                "user"    -> DB_USER,
                "password"-> DB_PASSWORD,
                "partitionColumn" -> ACCT_NUM_PARTITION,
                "lowerBound" -> ACCT_LOWER_BOUND,
                "upperBound" -> ACCT_UPPER_BOUND,
                "numPartitions" -> "200"
              )

  val sc = sparkSession.sparkContext

  import sparkSession.implicits._

  val reportDs = sparkSession.read.format("jdbc").options(reportOpts).load().cache().alias("a")

  val accountDs = sparkSession.read.format("jdbc").options(accountOpts).load().cache().alias("c")

  // Column `+` is numeric addition in Spark SQL: string operands are cast to double and ":"
  // becomes null, so $"report_id_ref_1" + ":" + $"report_id_ref_2" evaluates to null.
  // concat(...) with lit(":") performs the intended string concatenation.
  val reportData = reportDs.join(accountDs, reportDs("report_audit") === accountDs("reference_id"))
    .withColumn("report_name", when($"report_id" === "xxxx-xxx-asd", $"report_id_ref_1")
                                 .when($"report_id" === "demoasd-asdad-asda", $"report_id_ref_2")
                                 .otherwise(concat($"report_id_ref_1", lit(":"), $"report_id_ref_2")))
    .withColumn("report_version", when($"report_id" === "xxxx-xxx-asd", $"report_version_1")
                                    .when($"report_id" === "demoasd-asdad-asda", $"report_version_2")
                                    .otherwise($"report_version_3"))
    .withColumn("status", when($"report_id" === "xxxx-xxx-asd", $"report_status")
                            .when($"report_id" === "demoasd-asdad-asda", $"report_status_1")
                            .otherwise($"report_id"))
    .select("...")

  val prop = new Properties()
  prop.setProperty("user", DB_USER)
  prop.setProperty("password", DB_PASSWORD)
  prop.setProperty("driver", "org.postgresql.Driver")


  reportData.write
                  .mode(SaveMode.Append)
                  .jdbc(s"jdbc:postgresql://$DB_HOST:$DB_PORT/$DATABASE", "cust_report_data", prop)


  sparkSession.stop()
}

I think there should be an elegant way to handle this kind of data skew.


Solution

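  • Your values for partitionColumn, lowerBound, and upperBound could cause this exact behavior if they aren't set correctly. For instance, if lowerBound == upperBound, then all of the data would be loaded into a single partition, regardless of numPartitions.

    Together with numPartitions, these options determine which records (and how many) are loaded into each DataFrame partition from your SQL database. Note that lowerBound and upperBound only set the partition stride; they do not filter rows. Rows whose partitionColumn value is below lowerBound (or NULL) all go to the first partition, and rows at or above the start of the last stride, including everything above upperBound, all go to the last partition. If upperBound is much smaller than the column's real maximum, most of the table lands in that last partition, which would be consistent with a job that stalls on task 199/200.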
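To make the mapping from lowerBound, upperBound, and numPartitions to partitions concrete, here is a small pure-Scala model. It is a simplified sketch based on my reading of the range arithmetic in Spark 2.x's JDBC source (JDBCRelation.columnPartition), not Spark's own code: the range is split into numPartitions strides of (upperBound / n - lowerBound / n), the first partition is open below, and the last is open above. The object and method names are hypothetical, and details such as the equal-bounds shortcut may differ between Spark versions.

```scala
object JdbcPartitionModel {
  // Returns (effective number of partitions, stride) for the given options.
  def plan(lowerBound: Long, upperBound: Long, numPartitions: Int): (Int, Long) = {
    if (numPartitions <= 1 || lowerBound == upperBound) (1, 0L) // single partition, no WHERE clause
    else {
      require(lowerBound < upperBound, "lowerBound must not exceed upperBound")
      // Use fewer partitions when the range is narrower than numPartitions.
      val n =
        if (upperBound - lowerBound >= numPartitions) numPartitions
        else (upperBound - lowerBound).toInt
      // Divide before subtracting to avoid overflow on large bounds.
      (n, upperBound / n - lowerBound / n)
    }
  }

  // Index of the partition whose range would select value `v` (overflow at Long extremes ignored).
  def partitionOf(v: Long, lowerBound: Long, upperBound: Long, numPartitions: Int): Int = {
    val (n, stride) = plan(lowerBound, upperBound, numPartitions)
    if (n == 1) 0
    else if (stride == 0L) { if (v < lowerBound) 0 else n - 1 } // all boundaries collapse to lowerBound
    else {
      val i = Math.floorDiv(v - lowerBound, stride)
      math.min(n - 1L, math.max(0L, i)).toInt // first partition open below, last open above
    }
  }

  // WHERE clause per partition (None when everything is read as one partition).
  def predicates(column: String, lowerBound: Long, upperBound: Long, numPartitions: Int): Seq[Option[String]] = {
    val (n, stride) = plan(lowerBound, upperBound, numPartitions)
    (0 until n).map { i =>
      val lo = lowerBound + i * stride
      val hi = lo + stride
      if (n == 1) None
      else if (i == 0) Some(s"$column < $hi OR $column IS NULL")
      else if (i == n - 1) Some(s"$column >= $lo")
      else Some(s"$column >= $lo AND $column < $hi")
    }
  }

  // Number of sample values that each partition index would receive.
  def rowsPerPartition(values: Seq[Long], lowerBound: Long, upperBound: Long, numPartitions: Int): Map[Int, Int] =
    values.groupBy(partitionOf(_, lowerBound, upperBound, numPartitions)).map { case (p, vs) => p -> vs.size }

  def main(args: Array[String]): Unit = {
    val ids = 0L until 1000L // hypothetical partitionColumn values 0..999
    println(rowsPerPartition(ids, 0, 1000, 10).toSeq.sorted) // bounds cover the data: 100 rows each
    println(rowsPerPartition(ids, 0, 100, 10).toSeq.sorted)  // upperBound too low: 910 rows in partition 9
    println(rowsPerPartition(ids, 500, 500, 10))              // equal bounds: one partition holds all rows
    predicates("id", 0, 100, 4).foreach(println)
  }
}
```

In this model, bounds that match the data spread 1,000 rows evenly over ten partitions, while an upperBound of 100 sends 910 of them to the last partition. The practical takeaway is to set lowerBound and upperBound close to the actual minimum and maximum of partitionColumn, for example by querying those values before the load.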