Tags: python, data-structures, union-find

Debugging Union Find Algorithm Implementation


I'm attempting to solve LeetCode problem 547. Number of Provinces:

There are n cities. Some of them are connected, while some are not. If city a is connected directly with city b, and city b is connected directly with city c, then city a is connected indirectly with city c.

A province is a group of directly or indirectly connected cities and no other cities outside of the group.

You are given an n x n matrix isConnected where isConnected[i][j] = 1 if the ith city and the jth city are directly connected, and isConnected[i][j] = 0 otherwise.

Return the total number of provinces.

I tried to implement the union-find algorithm for this problem, but despite adding numerous debug prints, I still can't find the bug in my code.

Code:

class Solution(object):
    def findCircleNum(self, m):
        n = len(m)
        parent = [i for i in range(n)]
        rank = [1 for i in range(n)]
        count = n
        for i in range(n):
            for j in range(i+1, n):
                if m[i][j] == 1:
                    p1, p2 = parent[i], parent[j]
                    if p1 != p2:
                        print(parent)
                        # print(rank)
                        count -= 1
                    r1, r2 = rank[p1], rank[p2]
                    if r1 >= r2:
                        parent[j] = p1
                        rank[p1] += 1
                    else:
                        parent[i] = p2
                        rank[p2] += 1
        return count

Test Case:

test_input = [[1,1,0,0,0,0,0,1,0,0,0,0,0,0,0],
              [1,1,0,0,0,0,0,0,0,0,0,0,0,0,0],
              [0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],
              [0,0,0,1,0,1,1,0,0,0,0,0,0,0,0],
              [0,0,0,0,1,0,0,0,0,1,1,0,0,0,0],
              [0,0,0,1,0,1,0,0,0,0,1,0,0,0,0],
              [0,0,0,1,0,0,1,0,1,0,0,0,0,1,0],
              [1,0,0,0,0,0,0,1,1,0,0,0,0,0,0],
              [0,0,0,0,0,0,1,1,1,0,0,0,0,1,0],
              [0,0,0,0,1,0,0,0,0,1,0,1,0,0,1],
              [0,0,0,0,1,1,0,0,0,0,1,1,0,0,0],
              [0,0,0,0,0,0,0,0,0,1,1,1,0,0,0],
              [0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],
              [0,0,0,0,0,0,1,0,1,0,0,0,0,1,0],
              [0,0,0,0,0,0,0,0,0,1,0,0,0,0,1]]

print(Solution().findCircleNum(test_input))

The code above returns 2, but the expected answer is 3.

Could someone help me identify the problem in my code?
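To confirm the expected value independently of union-find, a plain depth-first search that counts connected components can serve as a reference. This is a sketch: count_provinces_dfs is a hypothetical helper name, and DFS is a different technique from the union-find approach the question is about.

```python
def count_provinces_dfs(is_connected):
    """Count connected components with an iterative depth-first search.

    Hypothetical helper used only as a cross-check; it does not use union-find.
    """
    n = len(is_connected)
    seen = [False] * n
    provinces = 0
    for start in range(n):
        if seen[start]:
            continue
        provinces += 1  # every city not yet visited starts a new province
        seen[start] = True
        stack = [start]
        while stack:
            city = stack.pop()
            for other in range(n):
                if is_connected[city][other] == 1 and not seen[other]:
                    seen[other] = True
                    stack.append(other)
    return provinces


print(count_provinces_dfs([[1, 1, 0], [1, 1, 0], [0, 0, 1]]))  # 2
```

Applied to test_input from the question, it returns 3: cities 2 and 12 are isolated, and the remaining 13 cities form a single province.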


Solution

  • There are two issues:

    • The condition p1 != p2 (which in your code means parent[i] != parent[j]) does not reliably tell whether nodes i and j are in different sets. Even when p1 and p2 differ, they may share a common ancestor through one or more further parent links, in which case i and j belong to the same set. The key step in Union-Find is the find operation, which follows parent links up to the root of the tree. Several implementations are possible; here is one:

          def find(i):
              while i != parent[i]:
                  parent[i] = parent[parent[i]]
                  i = parent[i]
              return i
      

      Call it so that you compare the roots of the trees that i and j belong to, not just their immediate parents:

      p1, p2 = find(i), find(j)
      
    • A related error is in the union operation: parent[i] = p2 may relink a node i that is not the root of its tree, so i's former ancestors (and any other nodes below them) are not merged into the other tree. It should be parent[p1] = p2, which links one root to the other. The other branch needs the same correction: parent[j] = p1 should become parent[p2] = p1.

    There are also two points that affect efficiency:

    • The union code does not need to run when the two nodes are already in the same set, so you can move it inside the if p1 != p2 block.

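    • The rank should only be incremented when the two original ranks were equal, because the rank represents (an upper bound on) the height of the tree: attaching a shorter tree under a taller root does not make the result any taller.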
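The two correctness issues can be reproduced on a tiny forest. This sketch is illustrative only: the parent arrays are made-up examples, and find takes the parent list as an argument instead of closing over it as in the answer's version.

```python
def find(parent, i):
    """Path-halving find, as in the answer, but with parent passed explicitly."""
    while i != parent[i]:
        parent[i] = parent[parent[i]]
        i = parent[i]
    return i

# Made-up forest: 1 -> 0 and 2 -> 1, so {0, 1, 2} is one tree rooted at 0; 3 is alone.
parent = [0, 0, 1, 3]

# Issue 1: the immediate parents of 0 and 2 differ...
print(parent[0], parent[2])                # 0 1
# ...but their roots are the same, so they are in the same set.
print(find(parent, 0), find(parent, 2))    # 0 0

# Issue 2: merge 2 with 3 by relinking node 2 itself (like parent[i] = p2).
buggy = [0, 0, 1, 3]
buggy[2] = 3
print(find(buggy, 0), find(buggy, 3))      # 0 3  (nodes 0 and 1 were left behind)

# Fixed: relink the root of 2's tree instead (like parent[p1] = p2).
fixed = [0, 0, 1, 3]
fixed[find(fixed, 2)] = find(fixed, 3)
print(find(fixed, 0), find(fixed, 3))      # 3 3
```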

    Here is your code with these corrections applied (see the comments):

    class Solution(object):
        def findCircleNum(self, m):
            def find(i):  # The path-halving algorithm
                while i != parent[i]:
                    parent[i] = parent[parent[i]]
                    i = parent[i]
                return i
    
            n = len(m)
            parent = [i for i in range(n)]
            rank = [1 for i in range(n)]
            count = n
            for i in range(n):
                for j in range(i+1, n):
                    if m[i][j] == 1:
                        # Need to find the roots of the trees these nodes are in:
                        p1, p2 = find(i), find(j)
                        if p1 != p2:
                            count -= 1
                            # Only need to perform union when nodes were in different sets
                            r1, r2 = rank[p1], rank[p2]
                            if r1 >= r2:
                                parent[p2] = p1  # Attach one root to the other
                                if r1 == r2:  # No increment is needed if ranks are different
                                    rank[p1] += 1
                            else:
                                parent[p1] = p2  # Attach one root to the other
                                # No increment is needed if ranks are different
            return count
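If you need union-find in more than one problem, the corrected loop can also be factored into a small class. The sketch below is a hypothetical refactoring, not part of the original answer: the names UnionFind and find_circle_num are illustrative, and it swaps the path-halving find for full path compression (a two-pass find that points every node on the path directly at the root), which is another of the possible find implementations and yields the same roots.

```python
class UnionFind:
    """Disjoint-set forest with union by rank and full path compression."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [1] * n  # upper bound on the height of each root's tree

    def find(self, i):
        # First pass: walk up to the root.
        root = i
        while root != self.parent[root]:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root.
        while i != root:
            next_i = self.parent[i]
            self.parent[i] = root
            i = next_i
        return root

    def union(self, a, b):
        """Merge the sets of a and b; return True only if they were separate."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.rank[ra] < self.rank[rb]:
            ra, rb = rb, ra  # make ra the root with the larger rank
        self.parent[rb] = ra  # attach root to root
        if self.rank[ra] == self.rank[rb]:
            self.rank[ra] += 1  # height grows only when ranks were equal
        return True


def find_circle_num(is_connected):
    n = len(is_connected)
    uf = UnionFind(n)
    count = n
    for i in range(n):
        for j in range(i + 1, n):
            if is_connected[i][j] == 1 and uf.union(i, j):
                count -= 1  # each real merge removes one province
    return count


print(find_circle_num([[1, 1, 0], [1, 1, 0], [0, 0, 1]]))  # 2
```

Because union reports whether a merge actually happened, count is decremented under the same condition as p1 != p2 in the corrected code. On the test_input matrix from the question, find_circle_num returns 3.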