Search code examples
pythonalgorithmrecursiondepth-first-search

Tarjan algorithm, recursion mistakes


I'm trying to implement the Tarjan's algorithm (to find strongly connected components in a graph).

I'm stuck in the dfs part of the algorithm where the components counter does not update itself properly. I think it's a problem with my recursion method but I'm not able able to fix it.

Here is the code:

def dfs_scc(graph, node, components_c, nodes_c, connected_components, visited_nodes):
    nodes_c+=1
    connected_components[node]=-nodes_c
    visited_nodes.append(node)
    last=nodes_c
    for adj in graph.get_adj(node):
        if (connected_components[adj[1]]==0):
            b=dfs_scc(graph, adj[1], components_c, nodes_c, connected_components, visited_nodes)
            last=min(last, b)
        elif (connected_components[adj[1]]<0): last=min(last, -connected_components[adj[1]])
    if (last==-connected_components[node]):
        components_c+=1
        print('VISITED NODE QUEUE: ', list(visited_nodes), '; COMPONENTS COUNTER: ', components_c)
        while(visited_nodes[-1]!=node):
            w=visited_nodes.pop()
            connected_components[w]=components_c
        w=visited_nodes.pop()
        connected_components[w]=components_c
    return last

And here the output:

VISITED NODE QUEUE: [0, 1, 6, 2, 4, 5, 7] ; COMPONENTS COUNTER: 1
VISITED NODE QUEUE: [0, 1, 6, 2, 4, 5] ; COMPONENTS COUNTER: 1
VISITED NODE QUEUE: [0, 1, 6] ; COMPONENTS COUNTER: 1 
VISITED NODE QUEUE: [3] ; COMPONENTS COUNTER: 1
----------------------
CONNECTED COMPONENT: {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1}

As you can see the queue of the visited nodes remove elements in the right order, at first recursion the node 7 is popped out course is just it's in that components; at next recursion the node 2, 4 and 5 are removed belonging to the same component and so on.

Instead in the final line of the output I've a dictionary (node : component) in which the component value is always the same.

Have any idea?

As asked here is the entire code:

class Graph(object):
    def __init__(self, graph=None):
        if graph == None: graph = {}
        self.__graph = graph

    def get_nodes(self): return list(self.__graph.keys())

    def get_edges(self): return self.__generate_edges()

    def __generate_edges(self):
        edges = []
        for node in self.__graph:
            for adj in self.__graph[node]:
                if node!=adj: edges.append((node, adj))
        return edges

    def add_node(self, node):
        if node not in self.__graph: self.__graph[node] = []

    def add_edge(self, edge):
        if edge[0] in self.__graph: self.__graph[edge[0]].append(edge[1])
        else:self.__graph[edge[0]] = [edge[1]]

    def get_adj(self, node):
        adj=[]
        for edge in self.__generate_edges():
            if node==edge[0] and edge[0]!=edge[1]: adj.append(edge)
        return adj

def scc(graph):
    #connected_components : {npde0: components, node1: components, node2: components, node3 : components, ...}
    connected_components={graph.get_nodes()[i]: 0 for i in range(len(graph.get_nodes()))}
    components_c=nodes_c=0
    visited_nodes=deque()
    for node in graph.get_nodes():
        if (connected_components[node]==0):
            dfs_scc(graph, node, components_c, nodes_c, connected_components, visited_nodes)
    return connected_components


def dfs_scc(graph, node, components_c, nodes_c, connected_components, visited_nodes):
    nodes_c+=1
    connected_components[node]=-nodes_c
    visited_nodes.append(node)
    last=nodes_c
    for adj in graph.get_adj(node):
        if (connected_components[adj[1]]==0):
            b=dfs_scc(graph, adj[1], components_c, nodes_c, connected_components, visited_nodes)
            last=min(last, b)
        elif (connected_components[adj[1]]<0):
            last=min(last, -connected_components[adj[1]])
    if (last==-connected_components[node]):
        components_c+=1
        print('VISITED NODE QUEUE: ', list(visited_nodes), '; COMPONENTS COUNTER: ', components_c)
        while(visited_nodes[-1]!=node):
            w=visited_nodes.pop()
            connected_components[w]=components_c
        w=visited_nodes.pop()
        connected_components[w]=components_c
    return last


def main():
    g={0: [1, 2], 1: [6], 2: [4], 3: [], 4: [5], 5: [2, 7], 6: [0], 7: []}
    graph=Graph(g)
    # nodes=random.randint(1, 10)
    # for i in range(nodes): graph.add_node(i)
    # for i in range(0, nodes):
    #     for j in range(0, nodes):
    #         graph.add_edge((i, j))
    cc=scc(graph)
    print("CONNECTED COMPONENTS: ", cc)

Solution

  • The issue is that the modification that a function execution brings to components_c, nodes_c, must carry back to the caller's variables with the same name, but that is not happening, because these variables are local to their own function execution context. The caller's variables with those names will not be modified by the recursive calls it makes, but they should.

    You can solve this in different ways. One way is to make dfs_scc a function that is defined within scc, and to only define the two variables mentioned above in the scope of scc. Then dfs_scc can reference those variables via the nonlocal keyword instead of getting them as arguments, and so modify them in a way that will be seen by all execution contexts in the recursion tree.

    Here is how that looks:

    def scc(graph):
        components_c=nodes_c=0
    
        # define the recursive function with the scope where the above variables are defined
        def dfs_scc(graph, node, connected_components, visited_nodes):
            nonlocal components_c, nodes_c # reference those variables
    
            nodes_c+=1
            connected_components[node]=-nodes_c
            visited_nodes.append(node)
            last=nodes_c
            for adj in graph.get_adj(node):
                if (connected_components[adj[1]]==0):
                    b=dfs_scc(graph, adj[1], connected_components, visited_nodes)
                    last=min(last, b)
                elif (connected_components[adj[1]]<0):
                    last=min(last, -connected_components[adj[1]])
            if (last==-connected_components[node]):
                components_c+=1
                print('VISITED NODE QUEUE: ', list(visited_nodes), '; COMPONENTS COUNTER: ', components_c)
                while(visited_nodes[-1]!=node):
                    w=visited_nodes.pop()
                    connected_components[w]=components_c
                w=visited_nodes.pop()
                connected_components[w]=components_c
            return last
        #connected_components : {npde0: components, node1: components, node2: components, node3 : components, ...}
        connected_components={graph.get_nodes()[i]: 0 for i in range(len(graph.get_nodes()))}
        visited_nodes=deque()
        for node in graph.get_nodes():
            if (connected_components[node]==0):
                dfs_scc(graph, node, connected_components, visited_nodes)
        return connected_components