
Scala Recursive For Comprehension Prepends Empty List Only Once, Why?


Similar to an earlier post, I am working on the Anagrams coursework from "Functional Programming in Scala". I could not figure out the combinations function, but found this incredibly elegant solution elsewhere:

def combinations(occurrences: Occurrences): List[Occurrences] =
  List() :: (for {
    (char, max) <- occurrences
    count <- 1 to max
    rest <- combinations(occurrences filter {case (c, _) => c > char})
  } yield List((char, count)) ++ rest)

I understand how the for comprehension works to create the combinations, but what I do not understand is why the empty list is not prepended to every inner list during each recursive call. It is almost as if the compiler skips the prepend and only evaluates the for expression on the right-hand side.

For example, the input combinations(List(('a', 2), ('b', 2))) returns the expected result set:

res1: List[forcomp.Anagrams.Occurrences] = List(List(), List((a,1)), List((a,1), (b,1)), List((a,1), (b,2)), List((a,2)), List((a,2), (b,1)), List((a,2), (b,2)), List((b,1)), List((b,2)))

With only a single empty list. Looking at the recursive call, I would have expected another empty list for each recursion. Would someone be so kind as to explain how this elegant solution works?


Solution

  • There is nothing producing an empty list inside this for comprehension. Even if

    combinations(occurrences filter {case (c, _) => c > char})
    

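    contained an empty list and bound it to rest in rest <- ... (which it does for the first element of its result), the yield expression List((char, count)) ++ rest prepends a value to it, so the resulting list is non-empty by design.

    So the whole for comprehension returns a List of non-empty Lists, and List() :: then prepends a single empty list to that outer List, once per call. The empty lists produced by the recursive calls are consumed as rest and never reach the result unchanged.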
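One way to see this is to split the function in two. The following is only a sketch: the helper name nonEmptyCombinations and the wrapper object are hypothetical, and the Occurrences alias is assumed to be List[(Char, Int)] as in the coursework.

```scala
object NonEmptyDemo {
  // Assumed alias, matching the Anagrams coursework.
  type Occurrences = List[(Char, Int)]

  // Hypothetical helper: only the for comprehension, without the `List() ::` prefix.
  def nonEmptyCombinations(occurrences: Occurrences): List[Occurrences] =
    for {
      (char, max) <- occurrences
      count <- 1 to max
      // `rest` may be the empty list here, but something is always prepended to it.
      rest <- combinations(occurrences.filter { case (c, _) => c > char })
    } yield (char, count) :: rest

  // The original function: one empty list is prepended to the outer list per call.
  def combinations(occurrences: Occurrences): List[Occurrences] =
    List() :: nonEmptyCombinations(occurrences)

  def main(args: Array[String]): Unit = {
    val input = List(('a', 2), ('b', 2))
    println(nonEmptyCombinations(input).forall(_.nonEmpty)) // true
    println(combinations(input).count(_.isEmpty))           // 1
  }
}
```

The empty list returned by each recursive call only ever appears as rest, so it ends up inside a longer list such as List((a,1)) rather than as a separate entry.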

    This basically builds the solution by induction:

    • if you have an empty list of occurrences, return a list containing only the empty list, because that is the valid solution for this input
    • if you have a non-empty list, then for each (char, maxOccurrences) in it, with rest being the occurrences whose character is greater than char
      • assume that you have a valid solution for combinations(rest)
      • then take that solution and prepend (char, 1) to each of its elements,
      • then take that solution and prepend (char, 2) to each of its elements,
      • ...
      • then take that solution and prepend (char, maxOccurrences) to each of its elements
      • then combine all of these results into one solution
      • all of these are non-empty because you always prepended something
      • so you are missing the empty list, and you add it explicitly to all the other results combined to create a complete solution for the whole input

    Because you have a valid starting point and a valid way of creating the next step from the previous one, you know that you can always create a valid solution.
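The base case and one inductive step can be traced by hand on the input from the question. This is a rough sketch rather than part of the original solution: the names forB, ledByA, ledByB and assembled are made up for illustration, and Occurrences is again assumed to be List[(Char, Int)].

```scala
object InductionStepDemo {
  // Assumed alias, matching the Anagrams coursework.
  type Occurrences = List[(Char, Int)]

  // The function from the question, unchanged apart from formatting.
  def combinations(occurrences: Occurrences): List[Occurrences] =
    List() :: (for {
      (char, max) <- occurrences
      count <- 1 to max
      rest <- combinations(occurrences.filter { case (c, _) => c > char })
    } yield List((char, count)) ++ rest)

  // Smaller subproblem, taken as already solved:
  // List(List(), List((b,1)), List((b,2)))
  val forB: List[Occurrences] = combinations(List(('b', 2)))

  // Step for 'a': prepend ('a', 1) and ('a', 2) to every solution for the rest.
  val ledByA: List[Occurrences] =
    for {
      count <- List(1, 2)
      rest <- forB
    } yield ('a', count) :: rest

  // Step for 'b': nothing is greater than 'b', so its rest is empty and the
  // results are exactly the non-empty entries of forB.
  val ledByB: List[Occurrences] = forB.filter(_.nonEmpty)

  // Combine everything and add the missing empty list exactly once.
  val assembled: List[Occurrences] = List() :: (ledByA ++ ledByB)

  def main(args: Array[String]): Unit = {
    println(combinations(Nil)) // List(List())
    println(assembled == combinations(List(('a', 2), ('b', 2)))) // true
  }
}
```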

    The for comprehension in

      def combinations(occurrences: Occurrences): List[Occurrences] =
        List() :: (for {
          (char, max) <- occurrences
          count <- 1 to max
          rest <- combinations(occurrences filter {case (c, _) => c > char})
        } yield List((char, count)) ++ rest)
    

    is doing the same thing as

    def combinations(occurrences: Occurrences): List[Occurrences] =
      List() :: occurrences.flatMap { case (char, max) =>
        (1 to max).flatMap { count =>
          combinations(occurrences filter {case (c, _) => c > char}).map { rest =>
            (char, count) :: rest
          }
        }
      }
    

    which is the same as

    def combinations(occurrences: Occurrences): List[Occurrences] =
      occurrences.map { case (char, max) =>
        (1 to max).map { count =>
          val newOccurence = (char, count)
          combinations(occurrences filter {case (c, _) => c > char}).map { rest =>
            newOccurence :: rest
          }
        }
      }.flatten.flatten.::(List())
    

    and you can easily compare this to the induction recipe above:

    def combinations(occurrences: Occurrences): List[Occurrences] =
      occurrences.map { case (char, max) =>
        // for every character on the list of occurrences
        (1 to max).map { count =>
          // you construct (char, 1), (char, 2), ... (char, max)
          val newOccurence = (char, count)
          // and for each such occurrence
          combinations(occurrences filter {case (c, _) => c > char}).map { rest =>
            // you prepend it into every result from smaller subproblem
            newOccurence :: rest
          }
        }
      }
      // because you would have a List(List(List(List(...)))) here
      // and you need List(List(...)) you flatten it twice
      .flatten.flatten
      // and since you are missing empty result, you prepend it here
      .::(List())
    

    The solution you posted does exactly the same thing, just in a more compact way: instead of .map(...).flatten, it uses the .flatMap calls that a for comprehension hides.
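The equivalence between the three spellings can be checked on a much smaller, non-recursive example. This is a minimal sketch; the object and value names are invented for illustration and are unrelated to the coursework.

```scala
object DesugarDemo {
  val numbers = List(1, 2)
  val letters = List("a", "b")

  // For comprehension with two generators.
  val viaFor: List[(Int, String)] =
    for {
      n <- numbers
      l <- letters
    } yield (n, l)

  // What it translates to: flatMap for the outer generator, map for the last one.
  val viaFlatMap: List[(Int, String)] =
    numbers.flatMap(n => letters.map(l => (n, l)))

  // The same thing with map followed by an explicit flatten.
  val viaMapFlatten: List[(Int, String)] =
    numbers.map(n => letters.map(l => (n, l))).flatten

  def main(args: Array[String]): Unit = {
    println(viaFor) // List((1,a), (1,b), (2,a), (2,b))
    println(viaFor == viaFlatMap && viaFlatMap == viaMapFlatten) // true
  }
}
```

With three generators, as in combinations, there are two flatMap calls and one final map, which is why the map-only version needs .flatten twice.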