Tags: scala, apache-spark, apache-spark-sql, apache-spark-dataset

Conditional application of `filter`/`where` to a Spark `Dataset`/`DataFrame`


Hi folks, I have a function that loads a dataset from some S3 locations and returns the interesting data:

private def filterBrowseIndex(
    spark: SparkSession,
    s3BrowseIndex: String,
    mids: Seq[String] = Seq(),
    indices: Seq[String] = Seq()): Dataset[BrowseIndex] = {
  import spark.implicits._

  spark
    .sparkContext.textFile(s3BrowseIndex)
    // split text dataset
    .map(line => line.split("\\s+"))
    // get types for attributes
    .map(BrowseIndex.strAttributesToBrowseIndex)
    // cast it to a dataset (requires implicit conversions)
    .toDS()
    // pick rows for the given marketplaces
    .where($"mid".isin(mids: _*))
    // pick rows for the given indices
    .where($"index".isin(indices: _*))
}

This implementation filters out every row if someone passes mids = Seq() or indices = Seq(), because isin with an empty list matches nothing. I would instead like the semantics to be "apply this where clause only if mids is not empty" (and likewise for indices), so that no filtering happens when the caller passes empty sequences.

Is there a nice functional way to do that?


Solution

  • Raphael Roth's answer is a good choice for the specific problem of applying a filter, if you don't mind the slightly convoluted logic. The general solution, which works for any conditional transformation (not just filtering, and not just doing nothing on one of the decision branches), is to use `transform`, e.g.:

    spark
      .sparkContext.textFile(s3BrowseIndex)
      // split text dataset
      .map(line => line.split("\\s+"))
      // get types for attributes
      .map(BrowseIndex.strAttributesToBrowseIndex)
      // cast it to a dataset (requires implicit conversions)
      .toDS()
      .transform { ds =>
        // pick rows for the given marketplaces
        if (mids.isEmpty) ds
        else ds.where($"mid".isin(mids: _*))
      }
      .transform { ds =>
        // pick rows for the given indices
        if (indices.isEmpty) ds
        else ds.where($"index".isin(indices: _*))
      }
    

    If you are using datasets of a stable type (or dataframes, which are Dataset[Row]), so that every step takes and returns the same type, `transform` can be very useful: you can build a sequence of transformation functions and then apply them in order:

    transformations.foldLeft(ds)(_ transform _)
    

    In many cases, this approach helps with code reuse and testability.
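
To make the foldLeft pattern above concrete, here is a minimal sketch using dataframes. The names `ConditionalFilters`, `whereIn` and `applyAll` are hypothetical helpers introduced for illustration only, and the sample rows are made up; it assumes Spark SQL is on the classpath and runs against a local session.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

object ConditionalFilters {
  // Hypothetical helper: returns a DataFrame => DataFrame step that filters on
  // `column` only when `values` is non-empty, and is the identity otherwise.
  def whereIn(column: String, values: Seq[String]): DataFrame => DataFrame =
    df => if (values.isEmpty) df else df.where(col(column).isin(values: _*))

  // Hypothetical helper: applies each step in order via Dataset.transform.
  def applyAll(df: DataFrame, steps: Seq[DataFrame => DataFrame]): DataFrame =
    steps.foldLeft(df)(_ transform _)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("conditional-filters")
      .getOrCreate()
    import spark.implicits._

    // Made-up sample data with the same column names as in the question.
    val df = Seq(("1", "books"), ("1", "music"), ("2", "books")).toDF("mid", "index")

    // The first step filters on mid; the second is a no-op because its list is empty.
    val steps = Seq(whereIn("mid", Seq("1")), whereIn("index", Seq.empty))
    applyAll(df, steps).show()

    spark.stop()
  }
}
```

Because each step is an ordinary function value, it can be unit-tested on a small in-memory dataframe without touching S3, and the same list of steps can be reused across loaders.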