Tags: scala, apache-spark, apache-spark-dataset

Spark Dataset equivalent for Scala's "collect" taking a partial function


Regular Scala collections have a nifty collect method which lets me do a filter-map operation in a single pass using a partial function. Is there an equivalent operation on Spark Datasets?


I'd like it for two reasons:

  • syntactic simplicity
  • it reduces filter-map style operations to a single pass (although in Spark I am guessing there are optimizations that spot these things for you)

Here is an example to show what I mean. Suppose I have a sequence of options and I want to extract and double just the defined integers (those in a Some):

val input = Seq(Some(3), None, Some(-1), None, Some(4), Some(5)) 

Method 1 - collect

input.collect {
  case Some(value) => value * 2
} 
// List(6, -2, 8, 10)

The collect makes this quite neat syntactically and does everything in one pass.

Method 2 - filter-map

input.filter(_.isDefined).map(_.get * 2)

I can carry this kind of pattern over to Spark because Datasets and DataFrames have analogous methods.

But I don't like this so much because isDefined and get seem like code smells to me. There's an implicit assumption that map only ever receives Somes, and the compiler can't verify this. In a bigger example that assumption would be harder for a developer to spot, and a developer might, for example, swap the filter and the map around without getting a compile error.
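
To make that concrete, here is a quick sketch of the kind of runtime failure I mean if the map ends up running before (or without) the filter; everything compiles, but .get blows up on the first None (the Try wrapper and the reordered expression are purely illustrative):

import scala.util.Try

// The compiler is happy, but the guarantee that only Somes reach the map is gone
val reordered = Try(input.map(_.get * 2).filter(_ > 0))
// reordered == Failure(java.util.NoSuchElementException: None.get)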

Method 3 - fold* operations

input.foldRight[List[Int]](Nil) {
  case (Some(next), acc) => next * 2 :: acc
  case (None, acc) => acc
}

I haven't used Spark enough to know whether fold has an equivalent, so this might be a bit tangential.

Anyway, the pattern match, the fold boilerplate, and the rebuilding of the list all get jumbled together, and it's hard to read.


So overall I find the collect syntax the nicest, and I'm hoping Spark has something like this.


Solution

  • The collect method defined over RDDs and Datasets is used to materialize the data in the driver program.

    Although Spark does not have something directly akin to the Collections API's collect method, your intuition is right: since both operations are evaluated lazily, the engine has the opportunity to optimize and chain them so that they are performed with maximum locality (a sketch at the end of this answer illustrates this).

    For the use case you mentioned in particular, I would suggest you take flatMap into consideration, which works on both RDDs and Datasets:

    // Assumes the usual spark-shell environment
    // sc: SparkContext, spark: SparkSession
    val collection = Seq(Some(1), None, Some(2), None, Some(3))
    val rdd = sc.parallelize(collection)
    val dataset = spark.createDataset(rdd)
    
    // Both operations will yield `Array(2, 4, 6)`
    rdd.flatMap(_.map(_ * 2)).collect
    dataset.flatMap(_.map(_ * 2)).collect
    
    // You can also express the operation in terms of a for-comprehension
    (for (option <- rdd; n <- option) yield n * 2).collect
    (for (option <- dataset; n <- option) yield n * 2).collect
    
    // The same approach is valid for traditional collections as well
    collection.flatMap(_.map(_ * 2))
    for (option <- collection; n <- option) yield n * 2
    

    EDIT

    As correctly pointed out in another question, RDDs actually do have a collect overload that transforms an RDD by applying a partial function, just as ordinary collections do. Note that this is distinct from the zero-argument collect used above: the partial-function version is a lazy transformation that returns another RDD, whereas the zero-argument version is an action, and, as the Spark documentation points out, "this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory."
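
    For completeness, here is a minimal sketch of that partial-function overload (assuming the same spark-shell environment as above, with sc in scope; the names below are just illustrative). It filters and maps in one go and stays distributed:

    // Lazy transformation: returns another RDD, nothing is sent to the driver yet
    val doubledRdd = sc
      .parallelize(Seq(Some(3), None, Some(-1), None, Some(4), Some(5)))
      .collect { case Some(n) => n * 2 }

    // Only the zero-argument collect() action materializes the result in the driver
    doubledRdd.collect()  // Array(6, -2, 8, 10)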
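
    Finally, a minimal sketch of the earlier point about laziness (same spark-shell assumptions, illustrative names): filter and map are both lazy, narrow transformations, so Spark pipelines them and evaluates the whole chain in a single pass once an action is invoked:

    val opts = sc.parallelize(Seq(Some(3), None, Some(-1), None, Some(4), Some(5)))

    val doubled = opts
      .filter(_.isDefined)  // lazy: no job has run yet
      .map(_.get * 2)       // still lazy

    // The lineage shows filter and map chained together, with no shuffle in between
    println(doubled.toDebugString)

    doubled.collect()       // the action triggers a single pass: Array(6, -2, 8, 10)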