scala recursion functional-programming dynamic-programming backtracking

Scala: dynamic programming recursion using iterators


I'm learning how to do dynamic programming in Scala, and I often find myself in a situation where I want to recurse over an array (or some other iterable) of items. When I do this, I tend to write cumbersome functions like this:

def arraySum(array: Array[Int], index: Int, accumulator: Int): Int = {
  if (index == array.length) {
    accumulator
  } else {
    arraySum(array, index + 1, accumulator + array(index))
  }
}
arraySum(Array(1,2,3), 0, 0)

(Ignore for a moment that I could just call sum on the array or do a .reduce(_ + _); I'm trying to learn programming principles.)

But this seems like I'm passing a lot of variables, and what exactly is the point of passing the array to each function call? This seems unclean.

So instead I got the idea to do this with iterators and not worry about passing indexes:

def arraySum(iter: Iterator[Int])(implicit accumulator: Int = 0): Int = {
  try {
    val nextInt = iter.next()
    arraySum(iter)(accumulator + nextInt)
  } catch {
    case nee: NoSuchElementException => accumulator
  }
}
arraySum(Array(1,2,3).toIterator)

This seems like a much cleaner solution. However, this falls apart when you need to use dynamic programming to explore some outcome space and you don't need to call the iterator at every function call. E.g.

def explore(iter: Iterator[Int])(implicit accumulator: Int = 0): Int = {
  if (someCase) {
    explore(iter)(accumulator)
  } else if (someOtherCase){
    val nextInt = iter.next()
    explore(iter)(accumulator + nextInt)
  } else {
    // Some kind of aggregation/selection of explore results
  }
}

My understanding is that the iter iterator here functions as pass by reference, so when this function calls iter.next() that changes the instance of iter that is passed to all other recursive calls of the function. So to get around that, now I'm cloning the iterator at every call of the explore function. E.g.:

def explore(iter: Iterator[Int])(implicit accumulator: Int = 0): Int = {
  if (someCase) {
    explore(iter)(accumulator)
  } else if (someOtherCase){
    val iterClone = iter.toList.toIterator
    explore(iterClone)(accumulator + iterClone.next())
  } else {
    // Some kind of aggregation/selection of explore results
  }
}

But this seems pretty stupid, and the stupidity escalates when I have multiple iterators that may or may not need cloning in multiple else if cases. What is the right way to handle situations like this? How can I elegantly solve these kinds of problems?


Solution

  • Suppose that you want to write a backtracking recursive function that takes some complex data structure as an argument, so that each recursive call receives a slightly modified version of that data structure. There are several ways to do it:

    1. Clone the entire data structure, modify the clone, and pass it to the recursive call. This is very simple, but usually very expensive.
    2. Modify the mutable structure in place, pass it to the recursive call, then revert the modification when backtracking. You have to ensure that every possible call of your recursive function restores the original state of the data structure exactly. This is much more efficient, but hard to implement, because it is very error-prone.
    3. Subdivide the structure into a large immutable part and a small mutable part. For example, you could explicitly pass an index (or a pair of indices) that specifies some slice of an array, along with the array itself, which is never mutated. You then "clone" and save only the mutable part, and restore it when backtracking. When it works, it is both simple and fast, but it doesn't always work, because substructures can be hard to describe with just a few integer indices.
    4. Rely on persistent immutable data structures whenever you can.
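
As a rough illustration of option 2 (modify in place, then revert), here is a minimal sketch. The object name `UndoSketch`, the function `subsets`, and the subset-sum task are all hypothetical and chosen only for this example; the sketch also borrows the explicit index from option 3 to walk the input array, which is never mutated. The point is the pair of lines that add to the shared buffer and then remove from it again.

```scala
import scala.collection.mutable.ArrayBuffer

object UndoSketch {
  // Hypothetical example: collect every subset of `items` that sums to `target`,
  // reusing a single mutable buffer across all recursive calls.
  def subsets(items: Array[Int], target: Int): List[List[Int]] = {
    val chosen = ArrayBuffer.empty[Int]   // shared mutable state
    var found = List.empty[List[Int]]

    def go(index: Int, remaining: Int): Unit = {
      if (index == items.length) {
        // Snapshot the buffer only when a solution is found.
        if (remaining == 0) found = chosen.toList :: found
      } else {
        chosen += items(index)                   // modify in place
        go(index + 1, remaining - items(index))  // explore with the element
        chosen.remove(chosen.length - 1)         // revert exactly what was added
        go(index + 1, remaining)                 // explore without the element
      }
    }

    go(0, target)
    found.reverse
  }
}
```

If the `remove` line is forgotten, or skipped on some code path, every later branch sees a corrupted buffer, which is the kind of error this option makes easy.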

    I'd like to elaborate on the fourth option, because this is the preferred way to do it in Scala and in functional programming in general.

    Here is your original code, which uses the third strategy:

    def arraySum(array: Array[Int], index: Int, accumulator: Int): Int = {
      if (index == array.length) {
        accumulator
      } else {
        arraySum(array, index + 1, accumulator + array(index))
      }
    }
    

    If you used a List instead of an Array, you could rewrite it like this:

    @annotation.tailrec
    def listSum(list: List[Int], acc: Int): Int = list match {
      case Nil => acc
      case h :: t => listSum(t, acc + h)
    }
    

    Here, h :: t is a pattern that deconstructs the list into the head and the tail. Note that you don't need an explicit index any more, because accessing the tail t of the list is a constant-time operation, so that only the relevant remaining sublist is passed to the recursive call of listSum.
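
If the remaining concern is that callers still have to pass an accumulator, one common way around it is to hide the accumulator in a nested helper. This is only a sketch; the names `ListSumSketch` and `loop` are made up for the example.

```scala
import scala.annotation.tailrec

object ListSumSketch {
  // Public entry point: callers pass only the list.
  def listSum(list: List[Int]): Int = {
    // The accumulator lives in a nested helper, so it stays out of the public signature.
    @tailrec
    def loop(rest: List[Int], acc: Int): Int = rest match {
      case Nil    => acc
      case h :: t => loop(t, acc + h) // t is the existing tail; nothing is copied
    }
    loop(list, 0)
  }
}
```

A call then looks like `ListSumSketch.listSum(List(1, 2, 3))`, with no index or accumulator at the call site.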

    There is no backtracking here, but if the recursive method did backtrack, lists would bring another advantage. Taking the tail of a list is a constant-time operation, and the result is still guaranteed to be immutable. You can pass it into the recursive call without caring whether that call modifies it, so there is nothing to undo afterwards. This is the advantage of persistent immutable data structures: related lists share most of their structure while still appearing immutable from the outside, so holding a reference to the tail of a list cannot break anything in the parent list. This would not be the case with a view over a mutable array.
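
To make the backtracking case concrete, here is a minimal sketch of a branching recursion over an immutable List. The task (counting the subsets that sum to a target) and the names `BacktrackSketch` and `countSubsets` are assumptions picked for illustration; they stand in for the `explore` function from the question, whose actual conditions are not given.

```scala
object BacktrackSketch {
  // Hypothetical example: count the subsets of `items` whose elements add up to `target`.
  def countSubsets(items: List[Int], target: Int): Int = items match {
    case Nil =>
      if (target == 0) 1 else 0
    case h :: t =>
      // Both branches receive the same tail `t`. Neither branch can change it,
      // so no cloning is needed before the calls and no undo step after them.
      val withHead = countSubsets(t, target - h)
      val withoutHead = countSubsets(t, target)
      withHead + withoutHead
  }
}
```

Unlike the iterator version, advancing through the input in one branch has no effect on the other, because "advancing" just means handing over a reference to the tail. This version is not tail-recursive, since it combines the results of two recursive calls.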