Tags: apache-spark, join, rdd

Spark: RDD Left Outer Join Optimization for Duplicate Keys


THE SCENARIO

I'm trying to write a Spark program that efficiently performs a left outer join between two RDDs. One caveat is that these RDDs can have duplicate keys, which apparently causes the whole program to be inefficient.

What I'm trying to achieve is simple:

  • Given two RDDs: rdd1 and rdd2 (both have the same structure: (k, v))
  • Using rdd1 and rdd2, generate another RDD rdd3 that has the structure: (k1, v1, List(v2..))
  • k1 and v1 come from rdd1 (same values, so rdd1 and rdd3 have the same length)
  • List(v2..) is a list whose values come from the values of rdd2
  • A value v from rdd2 is added to the list in an rdd3 tuple only when its key k matches the k from rdd1

MY ATTEMPT

My approach was to use a left outer join. So, I came up with something like this:

// leftOuterJoin yields (k, (v1, Option[v2])), so flatten the Option into the array
rdd1.leftOuterJoin(rdd2)
    .map { case (k, (v1, v2)) => ((k, v1), Array(v2).flatten) }
    .reduceByKey(_ ++ _)

This actually produces the result that I'm trying to achieve. But when I use huge data, the program becomes very slow.

AN EXAMPLE

In case my idea is not clear yet, here is an example:

Given two RDDs that have the following data:

rdd1:

key | value
-----------
 1  |  a
 1  |  b
 1  |  c
 2  |  a
 2  |  b
 3  |  c

rdd2:

key | value
-----------
 1  |  v
 1  |  w
 1  |  x
 1  |  y
 1  |  z
 2  |  v
 2  |  w
 2  |  x
 3  |  y
 4  |  z

The resulting rdd3 should be

key | value | list
------------------------
1   |   a   |  v,w,x,y,z
1   |   b   |  v,w,x,y,z
1   |   c   |  v,w,x,y,z
2   |   a   |  v,w,x
2   |   b   |  v,w,x
3   |   c   |  y

Solution

  • First of all, don't use:

    map { ... => (..., Array(...)) }.reduceByKey(_ ++ _)
    

    That's pretty much as inefficient as it gets. To group values like this using RDDs you should really go with groupByKey.
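
    For illustration, a groupByKey-based rewrite of the original attempt could look like the sketch below; it assumes the same rdd1 and rdd2 as in the question and handles the Option produced by leftOuterJoin:

    // sketch only: group with groupByKey instead of reduceByKey(_ ++ _)
    rdd1.leftOuterJoin(rdd2)                        // (k, (v1, Option[v2]))
      .map { case (k, (v1, v2)) => ((k, v1), v2) }
      .groupByKey()                                 // ((k, v1), Iterable[Option[v2]])
      .mapValues(_.flatten.toList)                  // ((k, v1), List(v2..))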

    Additionally, joining just to groupByKey afterwards is pretty wasteful: you are doing the same job (grouping by key) on the right-hand side twice. It makes more sense to use cogroup directly (that's how RDD joins work) and flatMap:

    val rdd1 = sc.parallelize(Seq(
      (1, "a"), (1, "b"), (1, "c"), (2, "a"), (2, "b"),(3, "c")
    ))
    
    val rdd2 = sc.parallelize(Seq(
      (1, "v"), (1, "w"), (1, "x"), (1, "y"), (1, "z"), (2, "v"),
      (2, "w"), (2, "x"), (3, "y"),(4, "z")
    ))
    
    val rdd = rdd1
      .cogroup(rdd2)                                          // (k, (values from rdd1, values from rdd2))
      .flatMapValues { case (left, right) => left.map((_, right)) }
      .map { case (k, (v1, vs)) => ((k, v1), vs) }            // ((k, v1), Iterable(v2..))
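
    With the toy data above, the (small) result can be collected and inspected locally, for example:

    rdd.mapValues(_.toList).collect().foreach(println)
    // prints rows like ((1,a),List(v, w, x, y, z)); row order is not deterministic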
    

    You can also use the Dataset API, which tends to be more efficient in such cases:

    import org.apache.spark.sql.functions.collect_list
    import spark.implicits._   // for rdd.toDF and the $"..." column syntax

    val df1 = rdd1.toDF("k", "v")
    val df2 = rdd2.toDF("k", "v")

    df2.groupBy("k")
      .agg(collect_list("v").as("list"))
      .join(df1, Seq("k"), "rightouter")
      .show
    

    The result:

    +---+---------------+---+                 
    |  k|           list|  v|
    +---+---------------+---+
    |  1|[v, w, x, y, z]|  a|
    |  1|[v, w, x, y, z]|  b|
    |  1|[v, w, x, y, z]|  c|
    |  3|            [y]|  c|
    |  2|      [v, w, x]|  a|
    |  2|      [v, w, x]|  b|
    +---+---------------+---+
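
    If you want the columns in the same (key, value, list) order as the target rdd3, you can add a select to the same pipeline:

    df2.groupBy("k")
      .agg(collect_list("v").as("list"))
      .join(df1, Seq("k"), "rightouter")
      .select("k", "v", "list")
      .show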
    

    If the intersection of the key sets is small, you can try to optimize the process by applying a filter first, using a Bloom filter built from df1's keys:

    import org.apache.spark.sql.functions.udf

    val should_keep = {
      // Bloom filter over df1's keys; 0.005 is the expected false-positive rate
      val f = df1.stat.bloomFilter("k", df1.count, 0.005)
      udf((x: Any) => f.mightContain(x))
    }

    df2.where(should_keep($"k")).groupBy("k")
      .agg(collect_list("v").as("list"))
      .join(df1, Seq("k"), "rightouter")
      .show
    
    +---+---------------+---+
    |  k|           list|  v|
    +---+---------------+---+
    |  1|[v, w, x, y, z]|  a|
    |  1|[v, w, x, y, z]|  b|
    |  1|[v, w, x, y, z]|  c|
    |  3|            [y]|  c|
    |  2|      [v, w, x]|  a|
    |  2|      [v, w, x]|  b|
    +---+---------------+---+
    

    When using the Dataset API, be sure to adjust spark.sql.shuffle.partitions to reflect the amount of data you process.
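
    For example, assuming spark is your SparkSession (as in spark-shell), with a purely illustrative value:

    // tune shuffle parallelism before running the aggregation and join
    spark.conf.set("spark.sql.shuffle.partitions", "400")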

    Note:

    None of this will help you if the number of duplicate keys in rdd2 is large. In such a case the overall problem formulation is impossible to defend, and you should try to reformulate it, taking into account the requirements of the downstream process.