Scala map with dependent variables


In Scala I have a list of functions that each return a value. The order in which the functions are executed is important, since the argument of function n is the output of function n-1.

This suggests using foldLeft, something like:

val base: A
val funcs: Seq[Function[A, A]]

funcs.foldLeft(base)((x, f) => f(x))

(detail: type A is actually a Spark DataFrame).

However, the results of the functions are mutually exclusive, and in the end I want the union of the results of all functions. This suggests using map, something like:

funcs.map(f => f(base)).reduce(_.union(_))

But here each function is applied to base, which is not what I want.

In short: an ordered list of functions of variable length needs to produce a list of results of equal length, where result n-1 is the input to function n (starting from base for n = 0), so that the results can be concatenated.

How can I achieve this?

EDIT example:

case class X(id:Int, value:Int)
val base = spark.createDataset(Seq(X(1, 1), X(2, 2), X(3, 3), X(4, 4), X(5, 5))).toDF

def toA = (x: DataFrame) => x.filter('value.mod(2) === 1).withColumn("value", lit("a"))
def toB = (x: DataFrame) => x.withColumn("value", lit("b"))

val a = toA(base)
val remainder = base.join(a, Seq("id"), "leftanti")
val b = toB(remainder)

a.union(b)

+---+-----+
| id|value|
+---+-----+
|  1|    a|
|  3|    a|
|  5|    a|
|  2|    b|
|  4|    b|
+---+-----+

This should work for an arbitrary number of functions (e.g. toA, toB ... toN), where each time the remainder of the previous result is calculated and passed into the next function. In the end a union is applied to all results.


Solution

  • Seq already has a method, scanLeft, that does this out of the box:

    funcs.scanLeft(base)((acc, f) => f(acc)).tail
    

    scanLeft returns base itself as the first element of its result, so drop that element (the .tail above) if you don't want base to be included.
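
To see what scanLeft keeps and why the first element is dropped, here is a minimal sketch that uses plain Ints in place of DataFrames. The names base and funcs mirror the question; withBase, results and combined are illustrative names only, and _ + _ stands in for _.union(_).

```scala
// Plain Ints stand in for DataFrames; `base` and `funcs` mirror the question.
val base: Int = 7
val funcs: List[Int => Int] = List(_ * 2, _ + 3)

// scanLeft keeps every intermediate accumulator, starting with `base` itself.
val withBase: List[Int] = funcs.scanLeft(base)((acc, f) => f(acc))
// withBase == List(7, 14, 17)

// Dropping the head leaves exactly one result per function.
val results: List[Int] = withBase.tail
// results == List(14, 17)

// Combine the results; with DataFrames this would be _.union(_) instead of _ + _.
val combined: Int = results.reduce(_ + _)
// combined == 31
```

Note that reduce throws on an empty collection, so if funcs may be empty, reduceOption is probably the safer choice.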


    It is also possible using only foldLeft:

    funcs.foldLeft((base, List.empty[A])){ case ((x, list), f) => 
      val res = f(x)
      (res, res :: list) 
    }._2.reverse.reduce(_.union(_))
    

    Or:

    funcs.foldLeft((base, Vector.empty[A])){ case ((x, list), f) => 
      val res = f(x)
      (res, list :+ res) 
    }._2.reduce(_.union(_))
    

    The trick is to carry two things through the fold: the latest result, which feeds the next function, and a Seq that accumulates every result so far.

    Example:

    scala> val base = 7
    base: Int = 7
    
    scala> val funcs: List[Int => Int] = List(_ * 2, _ + 3)
    funcs: List[Int => Int] = List($$Lambda$1772/1298658703@7d46af18, $$Lambda$1773/107346281@5470fb9b)
    
    scala> funcs.foldLeft((base, Vector.empty[Int])){ case ((x, list), f) => 
         |   val res = f(x)
         |   (res, list :+ res) 
         | }._2
    res8: scala.collection.immutable.Vector[Int] = Vector(14, 17)
    
    scala> .reduce(_ + _)
    res9: Int = 31
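
The snippets above feed each function's output into the next function. The EDIT in the question instead passes on the remainder, that is, the rows the previous function did not handle. The same accumulate-inside-the-fold idea seems to cover that case as well. The following is only a sketch on plain collections, not Spark code: Rec, Step and applyToRemainders are hypothetical names, the filterNot on ids plays the role of the question's "leftanti" join, and toA filters on id rather than value because value is a String here.

```scala
// Hypothetical plain-collection stand-in for the question's DataFrame rows.
final case class Rec(id: Int, value: String)

type Step = Seq[Rec] => Seq[Rec]

// Analogues of toA and toB from the question (toA filters on id here).
val toA: Step = _.filter(_.id % 2 == 1).map(_.copy(value = "a"))
val toB: Step = _.map(_.copy(value = "b"))

def applyToRemainders(base: Seq[Rec], steps: Seq[Step]): Seq[Rec] = {
  // Accumulator: (rows not yet handled, outputs collected so far).
  val (_, outputs) = steps.foldLeft((base, Vector.empty[Seq[Rec]])) {
    case ((remaining, acc), step) =>
      val out = step(remaining)
      val handled = out.map(_.id).toSet
      // Plays the role of base.join(a, Seq("id"), "leftanti") in the question.
      val rest = remaining.filterNot(r => handled(r.id))
      (rest, acc :+ out)
  }
  // Concatenate all outputs; with DataFrames this would be reduce(_.union(_)).
  outputs.flatten
}

val base: Seq[Rec] = (1 to 5).map(i => Rec(i, i.toString))
val result: Seq[Rec] = applyToRemainders(base, Seq(toA, toB))
// ids 1, 3, 5 get "a"; the remaining ids 2, 4 get "b"
```

With DataFrames, the remainder step would presumably be the anti-join shown in the question, assuming id identifies a row uniquely.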