Disjoint-Set (union-find) data structure in MATLAB


Are there any implementations of the union-find algorithm in MATLAB?

If not, is it possible to implement it using a classdef or any other method?

I searched online extensively and couldn't find any implementation for MATLAB. This is the last place I could think to ask, so any help would be appreciated!

Thanks in advance


Solution

  • A disjoint set can be implemented with just a vector: a(u) is the parent of node u, and a node is the root of its set when it is its own parent, i.e. a(u) == u.
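
    As a minimal sketch of that idea without any class (the variable name parent and the five-node example are illustrative assumptions), a plain vector is enough to store the forest and to look up the root of a node:

    ```matlab
    % Parent vector for 5 nodes: initially every node is its own root.
    parent = 1:5;
    parent(2) = 1;   % attach node 2 under node 1
    parent(3) = 2;   % attach node 3 under node 2

    % Follow parent links from node 3 until a self-parented node is reached.
    u = 3;
    while parent(u) ~= u
        u = parent(u);
    end
    disp(u)   % u is now 1, the root of the tree containing node 3
    ```

    This version never shortens the chains it walks, so lookups can get slow on long chains; that is the problem path halving and union by size are meant to address.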

    My implementation, using path halving and union by size to limit the height of the tree (not very well tested):

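    classdef DJSet < handle
        properties
            N;     % number of elements
            root;  % root(u) is the parent of u; u is a root when root(u) == u
            size;  % size(r) is the element count of the tree rooted at r
        end

        methods
            function obj = DJSet(n)
                % Start with n singleton sets: every node is its own root.
                obj.N = n;
                obj.root = 1:n;
                obj.size = ones(1,n);
            end

            function root = find(obj, u)
                % Return the representative of u's set, halving the path on the way up.
                while obj.root(u) ~= u
                    obj.root(u) = obj.root(obj.root(u));  % point u at its grandparent
                    u = obj.root(u);
                end
                root = u;
            end

            function union(obj, u, v)
                % Merge the sets containing u and v (union by size).
                root_u = obj.find(u);
                root_v = obj.find(v);
                if root_u == root_v
                    return;  % already in the same set
                end
                % Attach the smaller tree under the root of the larger one.
                if obj.size(root_u) < obj.size(root_v)
                    obj.root(root_u) = root_v;
                    obj.size(root_v) = obj.size(root_v) + obj.size(root_u);
                else
                    obj.root(root_v) = root_u;
                    obj.size(root_u) = obj.size(root_u) + obj.size(root_v);
                end
            end

            function res = is_connected(obj, u, v)
                % True when u and v belong to the same set.
                root_u = obj.find(u);
                root_v = obj.find(v);
                res = root_u == root_v;
            end
        end
    end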
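
    To see what the path-halving step in find does, here is a small hand-built trace on a plain vector. The chain 5 -> 4 -> 3 -> 2 -> 1 and the variable names are made up for illustration; only the loop body mirrors the one in find:

    ```matlab
    % A worst-case chain: each node's parent is the previous node, 1 is the root.
    parent = [1 1 2 3 4];

    % Find the root of node 5 with path halving.
    u = 5;
    while parent(u) ~= u
        parent(u) = parent(parent(u));  % re-point u at its grandparent
        u = parent(u);                  % then jump to that grandparent
    end

    disp(u)        % 1
    disp(parent)   % 1 1 1 3 3
    ```

    Every other node on the visited path now points to its former grandparent, so the path from 5 to the root drops from four links to two, and later lookups along it are shorter.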
    

    Test script and its output:

    dj=DJSet(10);
    edges = [1 2; 3 4; 5 6; 7 8; 2 3; 1 3; 6 7];
    
    for i = 1:size(edges,1)
        dj.union(edges(i,1), edges(i,2));
    end
    
    for j = 2:10
        fprintf('%d and %d connection is %d\n', 1, j, dj.is_connected(1, j));
    end
    
    dj.union(3,6);
    
    fprintf('#####\nAfter connecting 3 and 6\n')
    for j = 2:10
        fprintf('%d and %d connection is %d\n', 1, j, dj.is_connected(1, j));
    end
    
    >> test
    1 and 2 connection is 1
    1 and 3 connection is 1
    1 and 4 connection is 1
    1 and 5 connection is 0
    1 and 6 connection is 0
    1 and 7 connection is 0
    1 and 8 connection is 0
    1 and 9 connection is 0
    1 and 10 connection is 0
    #####
    After connecting 3 and 6
    1 and 2 connection is 1
    1 and 3 connection is 1
    1 and 4 connection is 1
    1 and 5 connection is 1
    1 and 6 connection is 1
    1 and 7 connection is 1
    1 and 8 connection is 1
    1 and 9 connection is 0
    1 and 10 connection is 0