Search code examples
scalarecursiontail-recursiontopological-sort

Manually transforming tree recursion into tail recursion using a stack


I'm implementing a variation on topological sort (on top of scala-graph) that returns all topological orderings instead of just one. I have a tree recursive implementation that i want to make tail recursive. I don't want to use trampolines, instead, i want to mimic the call stack as described in this answer.

This is the tree recursive version of my algorithm:

import scalax.collection.Graph
import scalax.collection.GraphPredef._
import scalax.collection.GraphEdge._
import scala.collection.Set

def allTopologicalSorts[T](graph: Graph[T, DiEdge]): Unit = {
  val indegree: Map[graph.NodeT, Int] = graph.nodes.map(node => (node, node.inDegree)).toMap

  def isSource(node: graph.NodeT): Boolean = indegree.get(node).get == 0

  def getSources(): Set[graph.NodeT] = graph.nodes.filter(node => isSource(node))

  processSources(getSources(), indegree, List[graph.NodeT](), 0)

  def processSources(sources: Set[graph.NodeT], indegrees: Map[graph.NodeT, Int], topOrder: List[graph.NodeT], cnt: Int): Unit = {
    if (sources.nonEmpty) {
      // `sources` contain all the nodes we can pick
      // --> generate all possibilities
      for (src <- sources) {
        val newTopOrder = src :: topOrder
        var newSources = sources - src

        // Decrease the in-degree of all adjacent nodes
        var newIndegrees = indegrees
        for (adjacent <- src.diSuccessors) {
          val newIndeg = newIndegrees.get(adjacent).get - 1
          newIndegrees = newIndegrees.updated(adjacent, newIndeg)
          // If in-degree becomes zero, add to sources
          if (newIndeg == 0) {
            newSources = newSources + adjacent
          }
        }

        processSources(newSources, newIndegrees, newTopOrder, cnt + 1)
      }
    }
    else if (cnt != graph.nodes.size) {
      println("There is a cycle in the graph.")
    }
    else {
      println(topOrder.reverse)
    }
  }
}

And we can run the algorithm as follows

val graph: Graph[Int, DiEdge] = Graph(2 ~> 4, 2 ~> 7, 4 ~> 5)
allTopologicalSorts(graph)

Which correctly returns

  • List(2, 7, 4, 5)
  • List(2, 4, 7, 5)
  • List(2, 4, 5, 7)

Now, i tried implementing a tail recursive version by manually keeping a stack

import scalax.collection.Graph
import scalax.collection.GraphPredef._
import scalax.collection.GraphEdge._
import scala.collection.Set

def allTopologicalSorts[T](graph: Graph[T, DiEdge]): Unit = { 
  val indegree: Map[graph.NodeT, Int] = graph.nodes.map(node => (node, node.inDegree)).toMap

  def isSource(node: graph.NodeT): Boolean = indegree.get(node).get == 0

  def getSources(): Set[graph.NodeT] = graph.nodes.filter(node => isSource(node))

  def processSources(sources: Set[graph.NodeT], indegrees: Map[graph.NodeT, Int]): Unit = {
    type Order = List[graph.NodeT]
    case class Frame(sources: List[graph.NodeT], indegrees: Map[graph.NodeT, Int], topOrder: Order, cnt: Int)

    def step(stack: List[Frame]): Unit = {
      stack match {
        case Frame(src :: rest, indegrees, topOrder, cnt) :: tail => {
          val onBacktrackingFrame = Frame(rest, indegrees, topOrder, cnt)

          // Process src now and remember to do the rest later
          val newTopOrder = src :: topOrder
          var newSources = rest

          // Decrease the in-degree of all adjacent nodes
          var newIndegrees = indegrees
          for (adjacent <- src.diSuccessors) {
            val newIndeg = newIndegrees.get(adjacent).get - 1
            newIndegrees = newIndegrees.updated(adjacent, newIndeg)
            // If in-degree becomes zero, add to sources
            if (newIndeg == 0) {
              newSources = adjacent :: newSources
            }
          }

          val recursionFrame = Frame(newSources, newIndegrees, newTopOrder, cnt + 1)
          step(recursionFrame :: onBacktrackingFrame :: tail)
        }
        case Frame(Nil, indegrees, topOrder, cnt) :: tail => {
          println(topOrder.reverse)
          step(tail)
        }
        case Nil =>
      }
    }

    step(List(Frame(sources.toList, indegrees, List[graph.NodeT](), 0)))
  }

  processSources(getSources(), indegree)
}

However, this does not work as it results in

  • List(2, 4, 5, 7)
  • List(2, 4, 5)
  • List(2, 4, 7)
  • List(2, 4)
  • List(2, 7)
  • List(2)
  • List()

There's something off with the stack, but i couldn't find the problem.

Connected question: Tail recursive algorithm for generating all topological orderings in a graph


Solution

  • This solution is tail-recursive AFAICT and works when I run it, though I changed parts of it back to the first version, notably changing some types from List to Set, in order to keep the changes from the original small (I believe changing it back to List again should be relatively straight-forward):

    def allTopologicalSortsNew[T](graph: Graph[T, DiEdge]): Unit = { 
      type Order = List[graph.NodeT]
      case class Frame(sources: Set[graph.NodeT], indegrees: Map[graph.NodeT, Int], topOrder: Order, cnt: Int)
      val indegree: Map[graph.NodeT, Int] = graph.nodes.map(node => (node, node.inDegree)).toMap
    
      def isSource(node: graph.NodeT): Boolean = indegree.get(node).get == 0
    
      def getSources(): Set[graph.NodeT] = graph.nodes.filter(node => isSource(node))
    
      def processSources(initialSources: Set[graph.NodeT], initialIndegrees: Map[graph.NodeT, Int]): Unit = {
    
        def step(stack: List[Frame]): Unit = {
          stack match {
            case Frame(sources, indegrees, topOrder, cnt) :: tail if !sources.isEmpty => {
    
              val futureFrames = for (src <- sources) yield {
                val newTopOrder = src :: topOrder
                var newSources = sources - src
    
                // Decrease the in-degree of all adjacent nodes
                var newIndegrees = indegrees
                for (adjacent <- src.diSuccessors) {
                  val newIndeg = newIndegrees.get(adjacent).get - 1
                  newIndegrees = newIndegrees.updated(adjacent, newIndeg)
                  // If in-degree becomes zero, add to sources
                  if (newIndeg == 0) {
                    newSources = newSources + adjacent
                  }
                }
    
                Frame(newSources, newIndegrees, newTopOrder, cnt + 1)
              }
    
              step(futureFrames.toList ::: tail)
            }
            case Frame(sources, indegrees, topOrder, cnt) :: tail if sources.isEmpty => {
              println(topOrder.reverse)
              step(tail)
            }
            case Nil =>
          }
        }
    
        step(List(Frame(initialSources, initialIndegrees, List[graph.NodeT](), 0)))
      }
    
      processSources(getSources(), indegree)
    }