Find the set of keys in a Scala map whose values overlap


I'm working with a map in Scala where the key is a basket ID and the value is the set of item IDs contained in that basket. The goal is to ingest this map and compute, for each basket, the set of other basket IDs that share at least one item with it.

Say the input map is:

val basket = Map("b1" -> Set("i1", "i2", "i3"), "b2" -> Set("i2", "i4"), "b3" -> Set("i3", "i5"), "b4" -> Set("i6"))

Is it possible to perform this computation in Spark so that I get the intersecting-basket information back? For example, for the input above I'd expect:

val intersects = Map("b1" -> Set("b2", "b3"), "b2" -> Set("b1"), "b3" -> Set("b1"), "b4" -> Set.empty[String])

Thanks!


Solution

  • Something like this, using plain Scala collections (no Spark needed for this part):

    val basket = Map("b1" -> Set("i1", "i2", "i3"), "b2" -> Set("i2", "i4"), "b3" -> Set("i3", "i5"), "b4" -> Set("i6"))
    
    // Returns the key of every basket in `map` whose item set shares at least one item with `set`
    def intersectKeys(set: Set[String], map: Map[String, Set[String]]): Set[String] =
      map.collect { case (k, v) if set.intersect(v).nonEmpty => k }.toSet
    
    // Every non-empty basket intersects itself, so its own key is always returned; subtract it back out
    val intersects = basket.map { case (k, v) => (k, intersectKeys(v, basket) - k) }
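The approach above compares every basket with every other basket, so it performs roughly n² set intersections for n baskets. A common alternative is to first invert the map into an item → baskets index and then, for each basket, union the baskets listed under its items. The sketch below assumes plain Scala collections; the names `overlapping` and `basketsByItem` are hypothetical and not part of the original answer.

```scala
// Hypothetical helper: inverted-index approach (item -> baskets containing it),
// as an alternative to pairwise comparison of every basket with every other.
def overlapping(baskets: Map[String, Set[String]]): Map[String, Set[String]] = {
  // Step 1: emit (item, basket) pairs and group them by item
  val basketsByItem: Map[String, Set[String]] =
    baskets.toSeq
      .flatMap { case (b, items) => items.map(i => (i, b)) }
      .groupBy(_._1)
      .map { case (item, pairs) => item -> pairs.map(_._2).toSet }

  // Step 2: for each basket, gather every basket that shares one of its items,
  // then drop the basket itself (it always appears under its own items)
  baskets.map { case (b, items) =>
    b -> (items.flatMap(i => basketsByItem(i)) - b)
  }
}

val basket = Map("b1" -> Set("i1", "i2", "i3"), "b2" -> Set("i2", "i4"), "b3" -> Set("i3", "i5"), "b4" -> Set("i6"))
val intersects = overlapping(basket)
```

The work here grows with the number of (item, basket) pairs plus the size of the output rather than with every pair of baskets, though a very common item still fans out to many baskets. The two steps correspond fairly directly to Spark pair-RDD operations (for example `flatMap` to emit `(item, basket)` pairs, then `groupByKey` or `aggregateByKey`), which may be a starting point if the data is too large for a driver-side `Map`; that Spark translation is not shown or tested here.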