Create a column with a running total in a Spark Dataset


Suppose we have a Spark Dataset with two columns, say Index and Value, sorted by the first column (Index).

((1, 100), (2, 110), (3, 90), ...)

We'd like to have a Dataset with a third column with a running total of the values in the second column (Value).

((1, 100, 100), (2, 110, 210), (3, 90, 300), ...)
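For reference, on a small local collection the desired third column is just a prefix sum, which Scala's `scanLeft` computes in one pass. This sketch only pins down the expected semantics on the sample rows above. The question rules out collecting the data, so the open problem is reproducing this result while the data stays distributed.

```scala
// Local (non-distributed) reference for the running total: a prefix sum.
val rows = Seq((1, 100), (2, 110), (3, 90))

// scanLeft(0)(_ + _) yields 0, 100, 210, 300; tail drops the leading seed 0.
val totals = rows.map(_._2).scanLeft(0)(_ + _).tail

// Re-attach each running total to its (Index, Value) row.
val withTotals = rows.zip(totals).map { case ((index, value), total) => (index, value, total) }
```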

Any suggestions on how to do this efficiently, in one pass through the data? Or are there any canned CDF-type functions that could be used for this?

If need be, the Dataset can be converted to a DataFrame or an RDD to accomplish the task, but it has to remain a distributed data structure. That is, it cannot simply be collected and turned into an array or sequence, and no mutable variables may be used (val only, no var).


Solution

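  • A colleague suggested the following approach, which relies on the RDD.mapPartitionsWithIndex() method. (To my knowledge, the other data structures do not provide this kind of access to their partitions' indices.) It makes two passes: the first computes the sum of each partition, and the second scans each partition starting from the total of all preceding partitions.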

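    val data = sc.parallelize(1 to 5)  // sc is the SparkContext
    // First pass: the sum of each partition, keyed by partition index.
    val partialSums = data.mapPartitionsWithIndex { (i, values) =>
        Iterator((i, values.sum))
    }.collect().toMap  // one entry per partition, not per element
    // Second pass: seed each partition's scan with the total of all earlier partitions.
    val cumSums = data.mapPartitionsWithIndex { (i, values) =>
        val prevSums = (0 until i).map(partialSums).sum
        values.scanLeft(prevSums)(_ + _).drop(1)  // drop(1) removes the seed value
    }
    cumSums.collect()  // Array(1, 3, 6, 10, 15)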
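To check the two-pass logic without a cluster, it can be simulated locally. In this sketch each inner `Seq` is a hypothetical stand-in for one partition (not Spark API), and the rows are `(Index, Value)` pairs as in the question, so only the `Value` field is summed while the `Index` is carried along. It assumes the partitions are already ordered by `Index`, both within and across partitions (in Spark that would mean sorting first, e.g. with `sortBy`), because partition `i`'s offset is the sum of partitions `0` to `i - 1`.

```scala
// Hypothetical stand-in for an RDD: each inner Seq plays the role of one partition,
// already sorted by Index within and across partitions.
val partitions: Seq[Seq[(Int, Int)]] = Seq(
  Seq((1, 100), (2, 110)),
  Seq(),                         // an empty partition contributes a partial sum of 0
  Seq((3, 90), (4, 50), (5, 20))
)

// Pass 1: one partial sum per partition index (mirrors the collect().toMap step).
val partialSums: Map[Int, Int] =
  partitions.zipWithIndex.map { case (rows, i) => i -> rows.map(_._2).sum }.toMap

// Pass 2: offset each partition's local prefix sum by the totals of all earlier partitions.
val cumSums: Seq[Seq[(Int, Int, Int)]] = partitions.zipWithIndex.map { case (rows, i) =>
  val prevSums = (0 until i).map(partialSums).sum
  val totals = rows.map(_._2).scanLeft(prevSums)(_ + _).drop(1)
  rows.zip(totals).map { case ((index, value), total) => (index, value, total) }
}
```

Flattening `cumSums` gives the rows in `Index` order with their running totals. In Spark, the input is evaluated once for each of the two passes, so caching it before the first pass avoids recomputing it.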