Tags: python, algorithm, graph, kruskals-algorithm

Kruskal's algorithm including a specific edge


I'm trying to solve the following question, in which I have to find the lists of critical edges and pseudocritical edges of a graph's minimum spanning trees (MSTs). From my understanding of the problem, critical edges are edges that appear in every minimum spanning tree, and pseudocritical edges are edges that appear in some, but not all, minimum spanning trees. Here's my solution to the problem:

from typing import List

class UnionFind:
    def __init__(self, n):
        self.parent = [i for i in range(n)]
        self.size = [1 for _ in range(n)]
        self.cnt = n  # number of connected components
    
    def find(self, node):
        while node != self.parent[node]:
            self.parent[node] = self.parent[self.parent[node]]
            node = self.parent[node]
        return node
    
    def union_(self, node1, node2):
        root1 = self.find(node1)
        root2 = self.find(node2)
        if root1 == root2:
            return
        # union by size: attach the smaller tree under the larger root
        if self.size[root1] > self.size[root2]:
            self.parent[root2] = root1
            self.size[root1] += self.size[root2]
        else:
            self.parent[root1] = root2
            self.size[root2] += self.size[root1]
        
        self.cnt -= 1

class Solution:
    def kruskal(self, num_nodes, edges):
        result = 0
        edges.sort(key = lambda x: x[2])
        uf = UnionFind(num_nodes)
        for a, b, w in edges:
            if uf.find(a) != uf.find(b):
                uf.union_(a, b)
                result += w
        return (result, uf.cnt)
    
    def kruskal_included(self, num_nodes, edges, edge):
        result = edge[2]
        edges.sort(key = lambda x: x[2])
        uf = UnionFind(num_nodes)
        uf.union_(edge[0], edge[1])  # force this edge into the tree before the others
        for a, b, w in edges:
            if uf.find(a) != uf.find(b):
                uf.union_(a, b)
                result += w
        return result

    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        min_cost, _ = self.kruskal(n, edges)
        first_list = []
        second_list = []
        for i in range(len(edges)):
            new_edges = [edges[j] for j in range(len(edges)) if j != i]
            new_cost, cnt = self.kruskal(n, new_edges)
            if cnt > 1 or new_cost > min_cost:
                first_list.append(i)
            elif self.kruskal_included(n, edges, edges[i]) == min_cost:
                second_list.append(i)
        return [first_list, second_list]

My approach is as follows. First, I run Kruskal's algorithm on the original edge list to get the weight of the minimum spanning tree of the original graph. Then I loop through the edge list and run Kruskal's algorithm on the graph without the current edge. If the graph becomes disconnected or the weight of the minimum spanning tree increases, the edge is critical. Otherwise, I run a version of the algorithm that forces the current edge into the tree first; if the resulting spanning tree still has the minimum weight, the edge is pseudocritical. However, the solution fails one of the test cases and I'm not sure where it goes wrong. Can anyone help me out?

EDIT: I get an incorrect result for the following test case: n = 6, edges = [[0,1,1],[1,2,1],[0,2,1],[2,3,4],[3,4,2],[3,5,2],[4,5,2]]. My output is critical edges = [] and pseudocritical edges = [0,1,2,3,4,5,6], but the expected results are critical edges = [3] and pseudocritical edges = [0,1,2,4,5,6].
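To double-check the expected output independently of Kruskal's algorithm, here is a brute-force sketch that applies the definitions above directly: it enumerates every spanning tree, keeps the ones with minimum total weight, and marks an edge critical if it is in all of them and pseudocritical if it is in some but not all. The helper names is_spanning_tree and classify_brute_force are hypothetical, not part of the problem. It is exponential in the number of edges, so it is only meant for tiny graphs like this one, and it assumes the graph is connected.

```python
from itertools import combinations
from typing import List, Optional, Set


def is_spanning_tree(n: int, tree_edges: List[List[int]]) -> bool:
    """Return True if these n - 1 edges connect all n nodes without a cycle."""
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b, _ in tree_edges:
        ra, rb = find(a), find(b)
        if ra == rb:
            return False  # this edge would close a cycle
        parent[ra] = rb
    # n - 1 edges with no cycle on n nodes always form a spanning tree
    return True


def classify_brute_force(n: int, edges: List[List[int]]) -> List[List[int]]:
    """Classify edge indices by enumerating every spanning tree (assumes a connected graph)."""
    best: Optional[int] = None
    msts: List[Set[int]] = []
    for combo in combinations(range(len(edges)), n - 1):
        if not is_spanning_tree(n, [edges[i] for i in combo]):
            continue
        cost = sum(edges[i][2] for i in combo)
        if best is None or cost < best:
            best, msts = cost, [set(combo)]  # strictly cheaper: restart the MST list
        elif cost == best:
            msts.append(set(combo))  # another MST with the same total weight
    critical = [i for i in range(len(edges)) if all(i in t for t in msts)]
    pseudo = [i for i in range(len(edges))
              if i not in critical and any(i in t for t in msts)]
    return [critical, pseudo]


edges = [[0, 1, 1], [1, 2, 1], [0, 2, 1], [2, 3, 4], [3, 4, 2], [3, 5, 2], [4, 5, 2]]
print(classify_brute_force(6, edges))  # [[3], [0, 1, 2, 4, 5, 6]]
```

Edge 3 ([2,3,4]) is the only link between the two triangles, so every spanning tree needs it, while each triangle edge can be swapped for another edge of equal weight.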


Solution

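  • The issue is that your Kruskal methods sort the edges list in place. The very first call, self.kruskal(n, edges), reorders the same list that findCriticalAndPseudoCriticalEdges then iterates over, so the indices you append to the result lists refer to positions in the weight-sorted list instead of the original input. You can get a sorted copy of the edges with the sorted function instead of calling .sort on the list: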

    edges = sorted(edges, key=lambda x: x[2])  # rebinds the local name to a new sorted list
    

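    Alternatively, pass a copy of the list when calling the methods, e.g. self.kruskal(n, [*edges]) and self.kruskal_included(n, [*edges], edges[i]); edges[:] and list(edges) work too. The call with new_edges needs no copy, since that list is built fresh on every iteration.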
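For reference, here is a sketch of the question's code with the sorted fix from this answer applied, so it can be run outside LeetCode (List and Tuple are imported from typing). The union-by-size bookkeeping and the type hints are small tidy-ups that do not change the result; the method names are kept from the question, and this is not an official reference solution.

```python
from typing import List, Tuple


class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n
        self.cnt = n  # number of connected components

    def find(self, node: int) -> int:
        while node != self.parent[node]:
            self.parent[node] = self.parent[self.parent[node]]  # path halving
            node = self.parent[node]
        return node

    def union_(self, node1: int, node2: int) -> None:
        root1, root2 = self.find(node1), self.find(node2)
        if root1 == root2:
            return
        if self.size[root1] < self.size[root2]:
            root1, root2 = root2, root1
        self.parent[root2] = root1  # attach the smaller tree under the larger root
        self.size[root1] += self.size[root2]
        self.cnt -= 1


class Solution:
    def kruskal(self, num_nodes: int, edges: List[List[int]]) -> Tuple[int, int]:
        # sorted() builds a new list, so the caller's edges keep their original order
        edges = sorted(edges, key=lambda x: x[2])
        result = 0
        uf = UnionFind(num_nodes)
        for a, b, w in edges:
            if uf.find(a) != uf.find(b):
                uf.union_(a, b)
                result += w
        return result, uf.cnt

    def kruskal_included(self, num_nodes: int, edges: List[List[int]], edge: List[int]) -> int:
        edges = sorted(edges, key=lambda x: x[2])  # same fix as in kruskal
        uf = UnionFind(num_nodes)
        uf.union_(edge[0], edge[1])  # force this edge into the tree before the others
        result = edge[2]
        for a, b, w in edges:
            if uf.find(a) != uf.find(b):
                uf.union_(a, b)
                result += w
        return result

    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        min_cost, _ = self.kruskal(n, edges)
        critical, pseudo = [], []
        for i in range(len(edges)):
            new_edges = [edges[j] for j in range(len(edges)) if j != i]
            new_cost, cnt = self.kruskal(n, new_edges)
            if cnt > 1 or new_cost > min_cost:
                # dropping edge i disconnects the graph or makes the MST heavier
                critical.append(i)
            elif self.kruskal_included(n, edges, edges[i]) == min_cost:
                # a spanning tree with edge i forced in still reaches the minimum weight
                pseudo.append(i)
        return [critical, pseudo]


edges = [[0, 1, 1], [1, 2, 1], [0, 2, 1], [2, 3, 4], [3, 4, 2], [3, 5, 2], [4, 5, 2]]
print(Solution().findCriticalAndPseudoCriticalEdges(6, edges))  # [[3], [0, 1, 2, 4, 5, 6]]
```

With the copy in place, edges keeps its input order after the call, so index 3 still refers to [2,3,4]. Each edge triggers up to two Kruskal runs that each sort the edge list, so the whole approach is roughly O(E² log E) for E edges.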