
N-Tree Traversal with Scala Causes Stack Overflow


I am attempting to return a list of widgets from an N-tree data structure. In my unit test, if I have roughly 2000 widgets, each with a single dependency, I encounter a stack overflow. I think the for loop is preventing my tree traversal from being tail-recursive. What's a better way of writing this in Scala? Here's my function:

protected def getWidgetTree(key: String) : ListBuffer[Widget] = {
  def traverseTree(accumulator: ListBuffer[Widget], current: Widget) : ListBuffer[Widget] = {
    accumulator.append(current)

    if (!current.hasDependencies) {
      accumulator
    }  else {
      for (dependencyKey <- current.dependencies) {
        if (accumulator.findIndexOf(_.name == dependencyKey) == -1) {
          traverseTree(accumulator, getWidget(dependencyKey))
        }
      }

      accumulator
    }
  }

  traverseTree(ListBuffer[Widget](), getWidget(key))
}

Solution

  • The reason it's not tail-recursive is that the recursive calls are made inside the for loop, so the function may call itself several times and still has work to do after each call returns. To be tail-recursive, a recursive call must be in tail position: it has to be the last thing the function does before returning. After all, the whole point is that it works like a while loop (and can therefore be compiled into one). A single loop iteration can't branch into several further iterations.
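
    As a small illustration of the tail-position rule, here is a minimal sketch. The names `TailPositionSketch`, `sumAcc` and `sumPlain` are made up for this example and are not part of the question's code. The first function is accepted by `@tailrec` because the recursive call is the last thing it does; the second would be rejected if annotated, since the addition still has to run after the recursive call returns:

    ```scala
    import scala.annotation.tailrec

    object TailPositionSketch {
      // Hypothetical helper: sums a list using an accumulator.
      // The recursive call is the final expression, so @tailrec is accepted.
      @tailrec
      def sumAcc(xs: List[Int], acc: Int = 0): Int = xs match {
        case head :: tail => sumAcc(tail, acc + head)
        case Nil          => acc
      }

      // Hypothetical helper: NOT tail-recursive, because `head + ...` runs
      // after the recursive call returns. Annotating it with @tailrec would
      // be a compile error, and a long enough list can overflow the stack.
      def sumPlain(xs: List[Int]): Int = xs match {
        case head :: tail => head + sumPlain(tail)
        case Nil          => 0
      }
    }
    ```

    Annotating a function with `@tailrec` is a cheap way to have the compiler confirm that it really will be turned into a loop.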

    To do a tree traversal like this, you can use a queue to carry forward the nodes that need to be visited.

    Assume we have this tree:

    //        1
    //       / \  
    //      2   5
    //     / \
    //    3   4
    

    Represented with this simple data structure:

    case class Widget(name: String, dependencies: List[String]) {
      def hasDependencies = dependencies.nonEmpty
    }
    

    And we have this map pointing to each node:

    val getWidget = List(
      Widget("1", List("2", "5")),
      Widget("2", List("3", "4")),
      Widget("3", List()),
      Widget("4", List()),
      Widget("5", List()))
      .map { w => w.name -> w }.toMap
    

    Now we can rewrite your method to be tail-recursive:

    import scala.annotation.tailrec

    def getWidgetTree(key: String): List[Widget] = {
      @tailrec
      def traverseTree(queue: List[String], accumulator: List[Widget]): List[Widget] = {
        queue match {
          case currentKey :: queueTail =>        // the queue is not empty
            val current = getWidget(currentKey)  // get the element at the front
            val newQueueItems =                  // drop dependencies already visited or already queued
              current.dependencies.filterNot(dependencyKey =>
                accumulator.exists(_.name == dependencyKey) || queue.contains(dependencyKey))
            // prepending the new items visits dependencies depth-first
            traverseTree(newQueueItems ::: queueTail, current :: accumulator)
          case Nil =>                            // the queue is empty
            accumulator.reverse                  // we're done
        }
      }

      traverseTree(key :: Nil, List[Widget]())
    }
    

    And test it out:

    for (k <- 1 to 5)
      println(getWidgetTree(k.toString).map(_.name))
    

    prints:

    List(1, 2, 3, 4, 5)
    List(2, 3, 4)
    List(3)
    List(4)
    List(5)
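
    For larger inputs, scanning the accumulator on every step gets slow, so one possible variation is to track visited keys in a `Set`. The following is only a sketch under assumptions: `VisitedSetSketch`, `widgetTree`, `lookup` and `chain` are hypothetical names, `lookup` stands in for `getWidget`, and the 2000-widget chain is made-up test data modelled on the scenario in the question:

    ```scala
    import scala.annotation.tailrec

    object VisitedSetSketch {
      case class Widget(name: String, dependencies: List[String])

      // Hypothetical variant: `lookup` plays the role of getWidget.
      def widgetTree(key: String, lookup: String => Widget): List[Widget] = {
        @tailrec
        def loop(pending: List[String], seen: Set[String], acc: List[Widget]): List[Widget] =
          pending match {
            case k :: rest if seen(k) =>   // already visited, skip it
              loop(rest, seen, acc)
            case k :: rest =>              // first visit: record it, then push its dependencies
              val w = lookup(k)
              loop(w.dependencies ::: rest, seen + k, w :: acc)
            case Nil =>                    // nothing left to visit
              acc.reverse
          }

        loop(List(key), Set.empty, Nil)
      }

      // Made-up data: 2000 widgets, each depending on the next one.
      val chain: Map[String, Widget] =
        (1 to 2000).map { i =>
          val deps = if (i < 2000) List((i + 1).toString) else Nil
          i.toString -> Widget(i.toString, deps)
        }.toMap
    }
    ```

    Because every recursive call is in tail position, `VisitedSetSketch.widgetTree("1", VisitedSetSketch.chain)` walks the whole chain in constant stack space and returns all 2000 widgets.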