I'm implementing a variation on topological sort (on top of scala-graph) that returns all topological orderings instead of just one. I have a tree recursive implementation that i want to make tail recursive. I don't want to use trampolines, instead, i want to mimic the call stack as described in this answer.
This is the tree recursive version of my algorithm:
import scalax.collection.Graph
import scalax.collection.GraphPredef._
import scalax.collection.GraphEdge._
import scala.collection.Set
def allTopologicalSorts[T](graph: Graph[T, DiEdge]): Unit = {
val indegree: Map[graph.NodeT, Int] = graph.nodes.map(node => (node, node.inDegree)).toMap
def isSource(node: graph.NodeT): Boolean = indegree.get(node).get == 0
def getSources(): Set[graph.NodeT] = graph.nodes.filter(node => isSource(node))
processSources(getSources(), indegree, List[graph.NodeT](), 0)
def processSources(sources: Set[graph.NodeT], indegrees: Map[graph.NodeT, Int], topOrder: List[graph.NodeT], cnt: Int): Unit = {
if (sources.nonEmpty) {
// `sources` contain all the nodes we can pick
// --> generate all possibilities
for (src <- sources) {
val newTopOrder = src :: topOrder
var newSources = sources - src
// Decrease the in-degree of all adjacent nodes
var newIndegrees = indegrees
for (adjacent <- src.diSuccessors) {
val newIndeg = newIndegrees.get(adjacent).get - 1
newIndegrees = newIndegrees.updated(adjacent, newIndeg)
// If in-degree becomes zero, add to sources
if (newIndeg == 0) {
newSources = newSources + adjacent
}
}
processSources(newSources, newIndegrees, newTopOrder, cnt + 1)
}
}
else if (cnt != graph.nodes.size) {
println("There is a cycle in the graph.")
}
else {
println(topOrder.reverse)
}
}
}
And we can run the algorithm as follows
val graph: Graph[Int, DiEdge] = Graph(2 ~> 4, 2 ~> 7, 4 ~> 5)
allTopologicalSorts(graph)
Which correctly returns
Now, i tried implementing a tail recursive version by manually keeping a stack
import scalax.collection.Graph
import scalax.collection.GraphPredef._
import scalax.collection.GraphEdge._
import scala.collection.Set
def allTopologicalSorts[T](graph: Graph[T, DiEdge]): Unit = {
val indegree: Map[graph.NodeT, Int] = graph.nodes.map(node => (node, node.inDegree)).toMap
def isSource(node: graph.NodeT): Boolean = indegree.get(node).get == 0
def getSources(): Set[graph.NodeT] = graph.nodes.filter(node => isSource(node))
def processSources(sources: Set[graph.NodeT], indegrees: Map[graph.NodeT, Int]): Unit = {
type Order = List[graph.NodeT]
case class Frame(sources: List[graph.NodeT], indegrees: Map[graph.NodeT, Int], topOrder: Order, cnt: Int)
def step(stack: List[Frame]): Unit = {
stack match {
case Frame(src :: rest, indegrees, topOrder, cnt) :: tail => {
val onBacktrackingFrame = Frame(rest, indegrees, topOrder, cnt)
// Process src now and remember to do the rest later
val newTopOrder = src :: topOrder
var newSources = rest
// Decrease the in-degree of all adjacent nodes
var newIndegrees = indegrees
for (adjacent <- src.diSuccessors) {
val newIndeg = newIndegrees.get(adjacent).get - 1
newIndegrees = newIndegrees.updated(adjacent, newIndeg)
// If in-degree becomes zero, add to sources
if (newIndeg == 0) {
newSources = adjacent :: newSources
}
}
val recursionFrame = Frame(newSources, newIndegrees, newTopOrder, cnt + 1)
step(recursionFrame :: onBacktrackingFrame :: tail)
}
case Frame(Nil, indegrees, topOrder, cnt) :: tail => {
println(topOrder.reverse)
step(tail)
}
case Nil =>
}
}
step(List(Frame(sources.toList, indegrees, List[graph.NodeT](), 0)))
}
processSources(getSources(), indegree)
}
However, this does not work as it results in
There's something off with the stack, but i couldn't find the problem.
Connected question: Tail recursive algorithm for generating all topological orderings in a graph
This solution is tail-recursive AFAICT and works when I run it, though I changed parts of it back to the first version, notably changing some types from List
to Set
, in order to keep the changes from the original small (I believe changing it back to List
again should be relatively straight-forward):
def allTopologicalSortsNew[T](graph: Graph[T, DiEdge]): Unit = {
type Order = List[graph.NodeT]
case class Frame(sources: Set[graph.NodeT], indegrees: Map[graph.NodeT, Int], topOrder: Order, cnt: Int)
val indegree: Map[graph.NodeT, Int] = graph.nodes.map(node => (node, node.inDegree)).toMap
def isSource(node: graph.NodeT): Boolean = indegree.get(node).get == 0
def getSources(): Set[graph.NodeT] = graph.nodes.filter(node => isSource(node))
def processSources(initialSources: Set[graph.NodeT], initialIndegrees: Map[graph.NodeT, Int]): Unit = {
def step(stack: List[Frame]): Unit = {
stack match {
case Frame(sources, indegrees, topOrder, cnt) :: tail if !sources.isEmpty => {
val futureFrames = for (src <- sources) yield {
val newTopOrder = src :: topOrder
var newSources = sources - src
// Decrease the in-degree of all adjacent nodes
var newIndegrees = indegrees
for (adjacent <- src.diSuccessors) {
val newIndeg = newIndegrees.get(adjacent).get - 1
newIndegrees = newIndegrees.updated(adjacent, newIndeg)
// If in-degree becomes zero, add to sources
if (newIndeg == 0) {
newSources = newSources + adjacent
}
}
Frame(newSources, newIndegrees, newTopOrder, cnt + 1)
}
step(futureFrames.toList ::: tail)
}
case Frame(sources, indegrees, topOrder, cnt) :: tail if sources.isEmpty => {
println(topOrder.reverse)
step(tail)
}
case Nil =>
}
}
step(List(Frame(initialSources, initialIndegrees, List[graph.NodeT](), 0)))
}
processSources(getSources(), indegree)
}