Regular Scala collections have a nifty collect method which lets me do a filter-map operation in one pass using a partial function. Is there an equivalent operation on Spark Datasets?
I'd like it for two reasons:
- it is syntactically neat
- it reduces filter-map style operations to a single pass (although in Spark I am guessing there are optimizations which spot these things for you)

Here is an example to show what I mean. Suppose I have a sequence of options and I want to extract and double just the defined integers (those in a Some):
val input = Seq(Some(3), None, Some(-1), None, Some(4), Some(5))
Method 1 - collect
input.collect {
case Some(value) => value * 2
}
// List(6, -2, 8, 10)
The collect makes this quite neat syntactically and does one pass.
Method 2 - filter-map
input.filter(_.isDefined).map(_.get * 2)
I can carry this kind of pattern over to Spark because Datasets and DataFrames have analogous methods.

But I don't like this so much because isDefined and get seem like code smells to me. There's an implicit assumption that map is receiving only Somes. The compiler can't verify this. In a bigger example, that assumption would be harder for a developer to spot, and the developer might, for example, rearrange the filter and map without the compiler catching the mistake.
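For instance, dropping or misplacing the filter compiles cleanly and only fails at runtime (a quick illustrative snippet, using the same input as above):
// No compile error here, but this throws
// java.util.NoSuchElementException: None.get as soon as a None is reached
input.map(_.get * 2)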
Method 3 - fold* operations
input.foldRight[List[Int]](Nil) {
case (nextOpt, acc) => nextOpt match {
case Some(next) => next*2 :: acc
case None => acc
}
}
I haven't used Spark enough to know whether fold has an equivalent, so this might be a bit tangential.
Anyway, the pattern match, the fold boilerplate and the rebuilding of the list all get jumbled together, and it's hard to read.
So overall I find the collect syntax the nicest and I'm hoping Spark has something like this.
The collect method defined over RDDs and Datasets is used to materialize the data in the driver program.
Although Spark does not have something akin to the Collections API collect method, your intuition is right: since both operations are evaluated lazily, the engine has the opportunity to optimize them and chain them, so a filter followed by a map is still executed in a single pass over the data.
For the use case you mentioned in particular, I would suggest you consider flatMap, which works on both RDDs and Datasets:
// Assumes the usual spark-shell environment
// sc: SparkContext, spark: SparkSession
val collection = Seq(Some(1), None, Some(2), None, Some(3))
val rdd = sc.parallelize(collection)
val dataset = spark.createDataset(rdd)
// Both operations will yield `Array(2, 4, 6)`
rdd.flatMap(_.map(_ * 2)).collect
dataset.flatMap(_.map(_ * 2)).collect
// You can also express the operation in terms of a for-comprehension
(for (option <- rdd; n <- option) yield n * 2).collect
(for (option <- dataset; n <- option) yield n * 2).collect
// The same approach is valid for traditional collections as well
collection.flatMap(_.map(_ * 2))
for (option <- collection; n <- option) yield n * 2
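To see the chaining mentioned above, you can inspect the lineage and the physical plan (a small illustrative check, reusing the rdd and dataset defined above):
// filter and map are narrow transformations, so they are pipelined into the
// same stage and executed in a single pass over each partition
val chained = rdd.filter(_.isDefined).map(_.get * 2)
println(chained.toDebugString)  // one stage, no shuffle between the two steps
// For the Dataset API the physical plan tells a similar story
dataset.filter(_.isDefined).map(_.get * 2).explain()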
EDIT
As correctly pointed out in another question, RDDs actually have a collect method that transforms an RDD by applying a partial function, just like it happens in normal collections. Note that this overload is a transformation returning a new RDD; it is distinct from the zero-argument collect() used to bring results to the driver, about which the Spark documentation warns that it "should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory."
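For reference, here is a minimal sketch of that overload in use, reusing the rdd defined above (the trailing zero-argument collect() is the separate action that materializes the result):
// collect(PartialFunction) is a transformation: it keeps only the values the
// partial function is defined at and returns a new RDD
val doubled = rdd.collect { case Some(n) => n * 2 }
doubled.collect()  // Array(2, 4, 6)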