
Replace groupByKey() with reduceByKey()

This is a follow-up question from here. I am trying to implement k-means based on this implementation. It works great, but I would like to replace groupByKey() with reduceByKey(), and I am not sure how (I am not worried about performance for now). Here is the relevant minified code:

val data = sc.textFile("dense.txt").map(
        t => (t.split("#")(0), parseVector(t.split("#")(1)))).cache()

val read_mean_centroids = sc.textFile("centroids.txt").map(
        t => (t.split("#")(0), parseVector(t.split("#")(1))))
var centroids = read_mean_centroids.takeSample(false, K, 42).map(x => x._2)
do {
    val closest = data.map(p => (closestPoint(p._2, centroids), p._2))
    val pointsGroup = closest.groupByKey() // <-- THE VICTIM :)
    val newCentroids = pointsGroup.mapValues(ps => average(ps.toSeq)).collectAsMap()
    // ... (rest of the loop body and the `while` condition omitted)

Notice that println(newCentroids) will give:

Map(23 -> (-6.269305E-4, -0.0011746404, -4.08004E-5), 8 -> (-5.108732E-4, 7.336348E-4, -3.707591E-4), 17 -> (-0.0016383086, -0.0016974678, 1.45..

and println(closest):

MapPartitionsRDD[6] at map at kmeans.scala:75

Relevant question: Using reduceByKey in Apache Spark (Scala).

Some documentation:

def reduceByKey(func: (V, V) ⇒ V): RDD[(K, V)]

Merge the values for each key using an associative reduce function.

def reduceByKey(func: (V, V) ⇒ V, numPartitions: Int): RDD[(K, V)]

Merge the values for each key using an associative reduce function.

def reduceByKey(partitioner: Partitioner, func: (V, V) ⇒ V): RDD[(K, V)]

Merge the values for each key using an associative reduce function.

def groupByKey(): RDD[(K, Iterable[V])]

Group the values for each key in the RDD into a single sequence.
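reduceByKey() requires a function that is associative and returns the same type as the values it merges, so you cannot reduce directly to an average: averaging pairwise gives avg(avg(1, 3), 5) = 3.5, while the true mean of 1, 3 and 5 is 3. A common workaround is to map each point to a (point, 1L) pair, reduce by summing both the vectors and the counts, and divide at the end. The sketch below shows that pattern on plain Scala collections so it runs without Spark. Assumptions: points are `Array[Double]`, and `reduceByKeyLocal` is a hypothetical helper that only mimics `RDD.reduceByKey`; the comment shows the corresponding RDD chain, which additionally assumes your vector type supports `+` and `/`.

```scala
// Spark-free sketch of the reduceByKey() pattern for the k-means centroid update.
// Assumption: points are Array[Double]; reduceByKeyLocal is a hypothetical
// helper that mimics RDD.reduceByKey on an in-memory Seq of (key, value) pairs.
object ReduceByKeyCentroids {

  // Element-wise vector addition (assumes both arrays have the same length).
  def addVec(a: Array[Double], b: Array[Double]): Array[Double] =
    a.zip(b).map { case (x, y) => x + y }

  // Mimics reduceByKey: merge all values sharing a key with an associative function.
  def reduceByKeyLocal[K, V](pairs: Seq[(K, V)])(func: (V, V) => V): Map[K, V] =
    pairs.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).reduce(func) }

  // Same shape as the (assumed) Spark pipeline:
  //   closest.mapValues(p => (p, 1L))
  //          .reduceByKey { case ((s1, c1), (s2, c2)) => (s1 + s2, c1 + c2) }
  //          .mapValues { case (sum, n) => sum / n }
  def newCentroids(closest: Seq[(Int, Array[Double])]): Map[Int, Array[Double]] = {
    // Pair every point with a count of 1 so sums and counts can be reduced together.
    val withCounts = closest.map { case (k, p) => (k, (p, 1L)) }
    // Associative merge: add the vectors and add the counts.
    val sums = reduceByKeyLocal(withCounts) { (a, b) =>
      (addVec(a._1, b._1), a._2 + b._2)
    }
    // Divide each summed vector by the number of points in its cluster.
    sums.map { case (k, (sum, n)) => k -> sum.map(_ / n) }
  }

  def main(args: Array[String]): Unit = {
    val closest = Seq(0 -> Array(1.0, 2.0), 0 -> Array(3.0, 4.0), 1 -> Array(5.0, 5.0))
    newCentroids(closest).toSeq.sortBy(_._1).foreach { case (k, c) =>
      println(s"$k -> ${c.mkString("(", ", ", ")")}")
    }
  }
}
```

Running `main` prints `0 -> (2.0, 3.0)` and `1 -> (5.0, 5.0)`, the same averages that `groupByKey()` followed by `average` would produce.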


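  • You could use aggregateByKey() to compute newCentroids. It is a bit more natural than reduceByKey() here, because the accumulated value (a running vector sum plus a point count) has a different type from the values being merged: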

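    val newCentroids = closest.aggregateByKey((Vector.zeros(dim), 0L))( // zero value: (vector sum, point count)
      (agg, v) => (agg._1 += v, agg._2 + 1L),                          // seqOp: add one point within a partition
      (agg1, agg2) => (agg1._1 += agg2._1, agg1._2 + agg2._2)          // combOp: merge partial results across partitions
    ).mapValues(agg => agg._1/agg._2).collectAsMap                     // divide each sum by its count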

    For this to work you will need to compute the dimensionality of your data, i.e. dim, but you only need to do this once. You could probably use something like val dim = data.first._2.length.
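To make the roles of the two functions concrete, here is a Spark-free sketch of what aggregateByKey computes for this update: each key starts from a fresh zero value in every partition, seqOp folds points in within a partition, and combOp merges the per-partition results. Assumptions: points are `Array[Double]`, "partitions" are modelled as a `Seq` of `Seq`s, and `aggregateByKeyLocal` is a hypothetical helper, not a Spark API.

```scala
// Spark-free sketch of the aggregateByKey() centroid update.
// Assumptions: points are Array[Double]; partitions are modelled as Seq[Seq[...]];
// aggregateByKeyLocal is a hypothetical helper, not part of Spark.
object AggregateByKeyCentroids {

  type Acc = (Array[Double], Long) // (running vector sum, point count)

  // seqOp: fold one point into a partition-local accumulator.
  def seqOp(agg: Acc, v: Array[Double]): Acc =
    (agg._1.zip(v).map { case (x, y) => x + y }, agg._2 + 1L)

  // combOp: merge two accumulators coming from different partitions.
  def combOp(a: Acc, b: Acc): Acc =
    (a._1.zip(b._1).map { case (x, y) => x + y }, a._2 + b._2)

  // Mimics aggregateByKey: a fresh zero value per key per partition,
  // seqOp inside each partition, combOp across partitions.
  def aggregateByKeyLocal[K](partitions: Seq[Seq[(K, Array[Double])]], dim: Int): Map[K, Acc] = {
    val perPartition: Seq[Map[K, Acc]] = partitions.map { part =>
      part.groupBy(_._1).map { case (k, kvs) =>
        k -> kvs.map(_._2).foldLeft((Array.fill(dim)(0.0), 0L))(seqOp)
      }
    }
    perPartition.flatMap(_.toSeq).groupBy(_._1).map { case (k, accs) =>
      k -> accs.map(_._2).reduce(combOp)
    }
  }

  def newCentroids[K](partitions: Seq[Seq[(K, Array[Double])]]): Map[K, Array[Double]] = {
    // Like `val dim = data.first._2.length`: read the dimensionality once
    // (throws on empty input, as data.first does on an empty RDD).
    val dim = partitions.flatten.head._2.length
    aggregateByKeyLocal(partitions, dim).map { case (k, (sum, n)) => k -> sum.map(_ / n) }
  }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(
      Seq(0 -> Array(1.0, 2.0), 1 -> Array(5.0, 5.0)),
      Seq(0 -> Array(3.0, 4.0))
    )
    newCentroids(partitions).toSeq.sortBy(_._1).foreach { case (k, c) =>
      println(s"$k -> ${c.mkString("(", ", ", ")")}")
    }
  }
}
```

This sketch allocates new arrays in seqOp and combOp for clarity. The answer's version mutates the accumulator with `+=` instead, which the aggregateByKey documentation explicitly allows: both functions may modify and return their first argument to avoid memory allocation.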