scala, apache-spark, similarity, rdd, set-intersection

Spark: How to efficiently have intersections preserving duplicates (in Scala)?


I have 2 RDDs, each of which is a set of strings containing duplicates (i.e. a multiset/bag). I want to find the intersection of the two sets, preserving duplicates. Example:

RDD1 : a, b, b, c, c, c, c

RDD2 : a, a, b, c, c

The intersection I want is a, b, c, c, i.e. the intersection should contain each element the minimum of the number of times it appears in each of the two sets.

The default intersection transformation does not preserve duplicates AFAIK. Is there a way to efficiently compute the intersection using some other transformations and/or the intersection transformation? I'm trying to avoid doing it algorithmically, which is unlikely to be as efficient as doing it the Spark way. (For the interested, I'm trying to compute Jaccard bag similarity for a set of files).
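(For context: with bag semantics, the Jaccard similarity is the size of the bag intersection divided by the size of the bag union. Assuming the convention where the bag union sums the counts, the example above gives 4 / (7 + 5) = 1/3; under the element-wise-max convention it would be 4 / 8 = 1/2.)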


Solution

  • Borrowing a little from the implementation of intersection, you could do something like:

    val rdd1 = sc.parallelize(Seq("a", "b", "b", "c", "c", "c", "c"))
    val rdd2 = sc.parallelize(Seq("a", "a", "b", "c", "c"))
    
    // Pair each element with a dummy value and cogroup the two RDDs by key.
    val cogrouped = rdd1.map(k => (k, null)).cogroup(rdd2.map(k => (k, null)))
    // Each buffer's size is that key's multiplicity in the corresponding RDD; keep the smaller count.
    val groupSize = cogrouped.map { case (key, (buf1, buf2)) => (key, math.min(buf1.size, buf2.size)) }
    // Re-emit each key that many times; keys missing from either RDD get size 0 and drop out.
    val finalSet = groupSize.flatMap { case (key, size) => List.fill(size)(key) }
    
    finalSet.collect()  // Array(a, b, c, c)
    

    This works because cogroup keeps every occurrence of a key's values within each grouping (in this case, all of your nulls), so the buffer sizes are exactly the per-RDD counts. Also note that this does no more shuffles than the built-in intersection would.
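
    An equivalent (untested) sketch that works with explicit counts instead of buffered nulls: count each key with reduceByKey, join the two count tables (which keeps only keys present in both RDDs), and re-expand each key the minimum number of times. Since the counts are combined map-side, it may shuffle somewhat less data:

    // Count occurrences of each key in each RDD.
    val counts1 = rdd1.map(k => (k, 1)).reduceByKey(_ + _)
    val counts2 = rdd2.map(k => (k, 1)).reduceByKey(_ + _)

    // join keeps only keys that occur in both RDDs; emit each key min(c1, c2) times.
    val bagIntersection = counts1.join(counts2)
      .flatMap { case (key, (c1, c2)) => List.fill(math.min(c1, c2))(key) }

    bagIntersection.collect()  // Array(a, b, c, c), order may vary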