How to split an RDD into different RDD's based on a value and give every part to a function


I have an RDD in which every element is an instance of a case class like this:

    case class Element(target: Boolean, data: String)

Now I need to split the RDD by the value of data, a String that takes a discrete set of values. Then I need to run the following function on every part:

    def f(elements: RDD[Element]): Double

I have tried to make a pair RDD like this:

    val test = elementsRDD.map(e => (e.data, e))

That gives me (key, value) pairs, but I don't know what to do next. groupBy gives back an Iterable of the values for each key rather than an RDD, so I can't pass a group to f.
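If f can be rewritten to work on a local collection instead of an RDD, the pair RDD above is enough. groupByKey turns it into an RDD[(String, Iterable[Element])] and mapValues applies the function to each group. The sketch below assumes that. The names GroupLocally, fLocal and perKey are hypothetical, and fLocal is only a placeholder for whatever f computes. The main caveat is that every group is materialized on a single executor, so each group must fit in that executor's memory.

```scala
import org.apache.spark.rdd.RDD

case class Element(target: Boolean, data: String)

object GroupLocally {
  // Hypothetical local counterpart of f: it receives one group as an in-memory
  // Iterable instead of an RDD. As a placeholder it returns the fraction of
  // elements whose target is true.
  def fLocal(elements: Iterable[Element]): Double =
    elements.count(_.target).toDouble / elements.size

  // groupByKey turns the pair RDD into RDD[(String, Iterable[Element])];
  // mapValues then applies fLocal to every group on the executors.
  // Each group must fit in the memory of a single executor.
  def perKey(rdd: RDD[Element]): Map[String, Double] =
    rdd.map(e => (e.data, e))
      .groupByKey()
      .mapValues(fLocal)
      .collect()
      .toMap
}
```

This needs only one shuffle. It is a good fit when the groups are small and many.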

I could also filter the RDD once for each possible value of data and run f on each result. But I don't know in advance which values data can take. It also doesn't seem efficient to first scan all the data to find the distinct values and then filter it once more for every value.
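For reference, the filter-per-value approach can be sketched as follows. It passes f a genuine RDD[Element] for each value. The sketch assumes an existing SparkContext and a number of distinct data values small enough to collect to the driver. The names SplitByFilter and perKey are hypothetical. Caching the input does not remove the extra passes, but each pass then reads from memory instead of recomputing the RDD's lineage.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

case class Element(target: Boolean, data: String)

object SplitByFilter {
  // Applies f to one sub-RDD per distinct value of `data`.
  // Cost: one job to find the keys, then one filtered pass per key (plus whatever f runs).
  def perKey(rdd: RDD[Element], f: RDD[Element] => Double): Map[String, Double] = {
    // Keep the input in memory (spilling to disk if needed) across the repeated filters.
    val cached = rdd.persist(StorageLevel.MEMORY_AND_DISK)
    // Discover the possible values of `data` at runtime; assumes they fit on the driver.
    val keys = cached.map(_.data).distinct().collect()
    val results = keys.map { k =>
      // Each filter yields a real RDD[Element] that can be handed to f unchanged.
      k -> f(cached.filter(_.data == k))
    }.toMap
    cached.unpersist()
    results
  }
}
```

Because f returns a Double, each call forces its own Spark job. With k distinct values this runs about k + 1 jobs, which is why an aggregation is preferable whenever f can be expressed that way.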

So is there a way it can be done efficiently?


Solution

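  • If f only depends on how many elements of each group have target true and how many have target false, you don't need separate RDDs at all. The entropy of target, which the code below computes, is one such function. Aggregate by data, counting each of the two values the Boolean can take. The rest is a simple computation on those two counts.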

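    import org.apache.spark.rdd.RDD

    case class Element(target: Boolean, data: String)

    // sc is the SparkContext (predefined in spark-shell)
    val rdd: RDD[Element] = sc.parallelize(
      Seq(Element(true, "a"), Element(false, "a"), Element(true, "a"),
        Element(false, "b"), Element(false, "b"), Element(true, "b")))

    // counts is an RDD[(String, (Int, Int))]: for each data value, the first element
    // of the tuple is the number of "true"s and the second the number of "false"s
    val counts = rdd.map(e => (e.data, e.target)).aggregateByKey((0, 0))(
      { case ((t, f), target) => if (target) (t + 1, f) else (t, f + 1) },
      { case ((t1, f1), (t2, f2)) => (t1 + t2, f1 + f2) })

    // entropy is an RDD[(String, Double)] holding -pT*log2(pT) - pF*log2(pF) per key
    val entropy = counts.mapValues { case (t, f) =>
      val total = (t + f).toDouble
      // p * log2(p), taking 0 * log2(0) as 0 so single-class groups give 0 instead of NaN
      def plog2p(p: Double): Double = if (p == 0.0) 0.0 else p * math.log(p) / math.log(2)
      -plog2p(t / total) - plog2p(f / total)
    }

    // collect() brings the results to the driver; rdd.foreach(println) runs on the
    // executors, so on a cluster its output would end up in their logs
    entropy.collect().foreach(println)
    // Both keys are a 2-to-1 split, so each line shows an entropy of about 0.918
    // (the order of the two lines is not guaranteed)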
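For reference, this is the binary entropy of target within one group, written in terms of the counts t and f produced by the aggregation. Both terms carry a minus sign, and a zero probability contributes nothing. Both keys in the sample data are a 2-to-1 split ("a" has t = 2, f = 1 and "b" has t = 1, f = 2), so by symmetry they share one value.

```latex
H(p_T, p_F) = -p_T \log_2 p_T - p_F \log_2 p_F,
\qquad p_T = \frac{t}{t+f}, \quad p_F = \frac{f}{t+f}, \quad 0 \log_2 0 := 0

% key "a": t = 2, f = 1
H\left(\tfrac{2}{3}, \tfrac{1}{3}\right)
  = -\tfrac{2}{3}\log_2\tfrac{2}{3} - \tfrac{1}{3}\log_2\tfrac{1}{3}
  = \log_2 3 - \tfrac{2}{3} \approx 0.918
```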