Search code examples
python graph-algorithm union-find

Union-find algorithm does not return expected result


I implemented the following union-find algorithm using this example:

import numpy as np


class UnionFind(object):
    """Disjoint-set forest over the integer labels that occur in ``edges``.

    ``data[i]`` holds the parent of label ``i``; a label that is its own
    parent is the representative (root) of its set.
    """

    def __init__(self, edges):
        """Keep ``edges`` and start with every label 0..max(edges) as a root."""
        self.edges = edges
        # One slot for each label up to and including the largest one seen.
        self.n_edges = np.max(edges) + 1
        self.data = [*range(self.n_edges)]

    def find(self, i):
        """Return the root of ``i`` and point every node on the path at it."""
        # First pass: walk up to the root.
        root = i
        while self.data[root] != root:
            root = self.data[root]

        # Second pass: full path compression.
        node = i
        while node != root:
            above = self.data[node]
            self.data[node] = root
            node = above
        return root

    def union(self, i, j):
        """Merge the sets of ``i`` and ``j``; ``j``'s root becomes the parent."""
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i == root_j:
            return
        self.data[root_i] = root_j

    def run(self):
        """Union every stored edge, then print ``label root`` for each label."""
        for a, b in self.edges:
            self.union(a, b)

        labels = {label: self.find(label) for label in range(self.n_edges)}

        for label, root in labels.items():
            print(label, root)


if __name__ == '__main__':
    # Pairs of equivalent labels. (Python comments start with '#'; the
    # original '//' marker made this line a syntax error.)
    edges = [(1, 1), (2, 2), (2, 3), (3, 3), (4, 2), (4, 4)]
    uf = UnionFind(edges)
    uf.run()

I would expect the result to be

0 0 
1 1
2 2
3 2
4 2

but the algorithm above returns

0 0 
1 1
2 3
3 3
4 3

That is, I would like the smallest label to be the parent.

Is there someone who can point out why this is the case and what I can do to get the expected result?


Solution

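  • You want Union-Find by Rank (note that the rank, not the label value, decides which root wins, so the smallest label ends up as the root for these edges but not for every input)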

    Code

    Source

    class UF:
        """An implementation of the union-find (disjoint-set) data structure.

        It uses weighted quick union by rank with path compression.
        Items are the integers ``0 .. N-1``.
        """

        def __init__(self, N):
            """Initialize a union find object with N items, each in its own set.

            Args:
                N: Number of items in the union find object.
            """

            self._id = list(range(N))   # _id[p] is the parent of p; a root is its own parent
            self._count = N             # number of disjoint sets; decremented on every merge
            self._rank = [0] * N        # rank of the tree rooted at each item

        def find(self, p):
            """Find the set identifier (root) for the item p."""

            parent = self._id
            while p != parent[p]:
                # Path compression using halving: link p to its grandparent,
                # then step up. These must be two statements: the chained
                # form ``p = parent[p] = parent[parent[p]]`` assigns targets
                # left to right, so it rebinds p first and then turns the
                # grandparent into a root, splitting the set.
                parent[p] = parent[parent[p]]
                p = parent[p]
            return p

        def count(self):
            """Return the number of disjoint sets (N minus the merges done)."""

            return self._count

        def connected(self, p, q):
            """Check if the items p and q are on the same set or not."""

            return self.find(p) == self.find(q)

        def union(self, p, q):
            """Combine sets containing p and q into a single set.

            The root with the lower rank is attached below the other one; on
            a tie the root of p's set stays the root and its rank grows by one.
            """

            parent = self._id
            rank = self._rank

            i = self.find(p)
            j = self.find(q)
            if i == j:
                return  # already in the same set (also covers self-edges)

            self._count -= 1
            if rank[i] < rank[j]:
                parent[i] = j
            elif rank[i] > rank[j]:
                parent[j] = i
            else:
                parent[j] = i
                rank[i] += 1

        def show(self):
            """Print one line per item: its stored parent, then its set root.

            NOTE(review): the usage example calls ``show()`` although the
            original listing never defined it. This column format was chosen
            because it reproduces the sample output given for that example;
            confirm it is the intended one.
            """
            for p in range(len(self._id)):
                stored = self._id[p]    # read before find(), which may rewrite links
                print(stored, self.find(p))

        def __str__(self):
            """String representation of the union find object."""
            return " ".join([str(x) for x in self._id])

        def __repr__(self):
            """Representation of the union find object."""
            return "UF(" + str(self) + ")"
    

    Example

    Using your example edges.

    N = 5
    edges = [(1, 1), (2, 2), (2, 3), (3, 3), (4, 2), (4, 4)]

    # Merge the endpoints of every edge; a self-edge such as (1, 1) joins an
    # item with itself and therefore changes nothing.
    uf = UF(N)
    for a, b in edges:
        uf.union(a, b)

    uf.show()
    

    Output

    0 0
    1 1
    2 2
    2 2
    2 2
    

    Comments

    It is not common to show the self edges as edges in undirected graphs.

    Thus, rather than

    edges = [(1, 1), (2, 2), (2, 3), (3, 3), (4, 2), (4, 4)]
    

    It's more common to list just the non-self edges, i.e.:

    edges = [(2, 3), (4, 2)]
    

    The same output is produced by the above code in either case.

    Since self-edges are usually not listed, a vertex with no other edges never appears in the edge list, so you cannot reliably get the number of vertices from

    self.n_edges = np.max(edges) + 1  # not normally correct 
    

    The number of vertices is normally specified explicitly.