rust union-find

Union-Find implementation does not update parent tags


I'm trying to create some sets of Strings and then merge some of these sets so that they have the same tag (of type usize). Once I initialize the map, I start adding strings:

self.clusters.make_set("a");
self.clusters.make_set("b");

When I call self.clusters.find("a") and self.clusters.find("b"), different values are returned, which is fine because I haven't merged the sets yet. Then I call the following method to merge the two sets:

let _ = self.clusters.union("a", "b");

If I call self.clusters.find("a") and self.clusters.find("b") now, I get the same value. However, when I call the finalize() method and try to iterate through the map, the original tags are returned, as if I never merged the sets.

self.clusters.finalize();

for (address, tag) in &self.clusters.map {
    let line = format!("{};{}\n", address, self.clusters.parent[*tag]);
    self.clusterizer_writer.write_all(line.as_bytes()).unwrap();
}

// to output all keys with the same tag as a list. 
let a: Vec<(usize, Vec<String>)> = {
    let mut x = HashMap::new();
    for (k, v) in self.clusters.map.clone() {
        x.entry(v).or_insert_with(Vec::new).push(k)
    }
    x.into_iter().collect()
};

I can't figure out why this is the case, but I'm relatively new to Rust; maybe it's an issue with pointers?

Instead of "a" and "b", I'm actually using something like utils::arr_to_hex(&input.outpoint.txid) of type String.

This is the Rust implementation of the Union-Find algorithm that I am using:

/// Tarjan's Union-Find data structure.
#[derive(RustcDecodable, RustcEncodable)]
pub struct DisjointSet<T: Clone + Hash + Eq> {
    set_size: usize,
    parent: Vec<usize>,
    rank: Vec<usize>,
    map: HashMap<T, usize>, // Each T entry is mapped onto a usize tag.
}

impl<T> DisjointSet<T>
where
    T: Clone + Hash + Eq,
{
    pub fn new() -> Self {
        const CAPACITY: usize = 1000000;
        DisjointSet {
            set_size: 0,
            parent: Vec::with_capacity(CAPACITY),
            rank: Vec::with_capacity(CAPACITY),
            map: HashMap::with_capacity(CAPACITY),
        }
    }

    pub fn make_set(&mut self, x: T) {
        if self.map.contains_key(&x) {
            return;
        }

        let len = &mut self.set_size;
        self.map.insert(x, *len);
        self.parent.push(*len);
        self.rank.push(0);

        *len += 1;
    }

    /// Returns Some(num), where num is the tag of the subset that contains x.
    /// If x is not in the data structure, it returns None.
    pub fn find(&mut self, x: T) -> Option<usize> {
        let pos: usize;
        match self.map.get(&x) {
            Some(p) => {
                pos = *p;
            }
            None => return None,
        }

        let ret = DisjointSet::<T>::find_internal(&mut self.parent, pos);
        Some(ret)
    }

    /// Implements path compression.
    fn find_internal(p: &mut Vec<usize>, n: usize) -> usize {
        if p[n] != n {
            let parent = p[n];
            p[n] = DisjointSet::<T>::find_internal(p, parent);
            p[n]
        } else {
            n
        }
    }

    /// Unions the subsets to which x and y belong.
    /// If it returns Ok(tag), tag is the usize tag of the unified subset.
    /// If it returns Err(()), at least one of x and y is not in the disjoint-set.
    pub fn union(&mut self, x: T, y: T) -> Result<usize, ()> {
        let x_root;
        let y_root;
        let x_rank;
        let y_rank;
        match self.find(x) {
            Some(x_r) => {
                x_root = x_r;
                x_rank = self.rank[x_root];
            }
            None => {
                return Err(());
            }
        }

        match self.find(y) {
            Some(y_r) => {
                y_root = y_r;
                y_rank = self.rank[y_root];
            }
            None => {
                return Err(());
            }
        }

        // Implements union-by-rank optimization.
        if x_root == y_root {
            return Ok(x_root);
        }

        if x_rank > y_rank {
            self.parent[y_root] = x_root;
            return Ok(x_root);
        } else {
            self.parent[x_root] = y_root;
            if x_rank == y_rank {
                self.rank[y_root] += 1;
            }
            return Ok(y_root);
        }
    }

    /// Forces all laziness, updating every tag.
    pub fn finalize(&mut self) {
        for i in 0..self.set_size {
            DisjointSet::<T>::find_internal(&mut self.parent, i);
        }
    }
}

Solution

  • I think you're just not extracting the information out of your DisjointSet struct correctly.

    I got nerd-sniped by this and implemented union-find myself. First, a basic usize implementation:

    pub struct UnionFinderImpl {
        parent: Vec<usize>,
    }
    

    Then with a wrapper for more generic types:

    pub struct UnionFinder<T: Hash> {
        rev: Vec<Rc<T>>,
        fwd: HashMap<Rc<T>, usize>,
        uf: UnionFinderImpl,
    }
    

    Both structs implement a groups() method that returns a Vec<Vec<>> of groups. Clone isn't required because I used Rc.
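The two structs above could be fleshed out roughly as follows. This is only a sketch of the approach described, not the answer's actual code: apart from the field names and `groups()`, every method here (`new`, `add`, `find`, `union`, `index_of`) is a hypothetical name, and the `Eq` bound is added because `HashMap` keys need it. The point it illustrates is that `groups()` buckets elements by the root returned from `find`, never by the index an element was first given.

```rust
use std::collections::HashMap;
use std::hash::Hash;
use std::rc::Rc;

/// Index-based union-find. Only the `parent` field comes from the answer.
pub struct UnionFinderImpl {
    parent: Vec<usize>,
}

impl UnionFinderImpl {
    pub fn new() -> Self {
        UnionFinderImpl { parent: Vec::new() }
    }

    /// Adds a singleton set and returns its index (hypothetical helper).
    pub fn add(&mut self) -> usize {
        let id = self.parent.len();
        self.parent.push(id);
        id
    }

    /// Finds the root of `n`, compressing the path on the way back.
    pub fn find(&mut self, n: usize) -> usize {
        let p = self.parent[n];
        if p != n {
            let root = self.find(p);
            self.parent[n] = root;
        }
        self.parent[n]
    }

    /// Merges the sets containing `a` and `b` (no rank heuristic, for brevity).
    pub fn union(&mut self, a: usize, b: usize) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra != rb {
            self.parent[ra] = rb;
        }
    }

    /// Buckets every index under its root, so merged sets end up together.
    pub fn groups(&mut self) -> Vec<Vec<usize>> {
        let mut by_root: HashMap<usize, Vec<usize>> = HashMap::new();
        for i in 0..self.parent.len() {
            let root = self.find(i);
            by_root.entry(root).or_default().push(i);
        }
        by_root.into_values().collect()
    }
}

/// Wrapper for arbitrary hashable items. `Rc` lets `rev` and `fwd` share
/// each item, so `T: Clone` is not needed.
pub struct UnionFinder<T: Hash + Eq> {
    rev: Vec<Rc<T>>,
    fwd: HashMap<Rc<T>, usize>,
    uf: UnionFinderImpl,
}

impl<T: Hash + Eq> UnionFinder<T> {
    pub fn new() -> Self {
        UnionFinder { rev: Vec::new(), fwd: HashMap::new(), uf: UnionFinderImpl::new() }
    }

    /// Returns the index for `item`, registering it first if it is new.
    fn index_of(&mut self, item: T) -> usize {
        if let Some(&i) = self.fwd.get(&item) {
            return i;
        }
        let rc = Rc::new(item);
        let i = self.uf.add();
        self.rev.push(Rc::clone(&rc));
        self.fwd.insert(rc, i);
        i
    }

    /// Merges the sets containing `a` and `b`, creating them if necessary.
    pub fn union(&mut self, a: T, b: T) {
        let ia = self.index_of(a);
        let ib = self.index_of(b);
        self.uf.union(ia, ib);
    }

    /// Translates the index groups back into the original items.
    pub fn groups(&mut self) -> Vec<Vec<Rc<T>>> {
        let rev = &self.rev;
        self.uf
            .groups()
            .into_iter()
            .map(|g| g.into_iter().map(|i| Rc::clone(&rev[i])).collect())
            .collect()
    }
}
```

With this sketch, calling `union("a".to_string(), "b".to_string())` and then `groups()` should yield one group holding both strings. The order of groups and of members within a group is unspecified because it comes from a `HashMap`.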

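Applied to the DisjointSet in the question, the likely culprit is the grouping block: `map` stores the tag assigned in `make_set` and is never rewritten, while `union` and `finalize` only touch `parent`. The first loop looks the tag up through `parent`, but the grouping block keys on the raw map value `v`. A minimal sketch of a grouping that resolves each tag to its root first is shown below; `root_of` and `groups_by_root` are hypothetical free functions, and `map` and `parent` stand in for the struct's fields.

```rust
use std::collections::HashMap;

/// Follows parent links from `tag` up to the root without mutating anything,
/// so it gives the right answer even if `finalize()` was not called.
fn root_of(parent: &[usize], mut tag: usize) -> usize {
    while parent[tag] != tag {
        tag = parent[tag];
    }
    tag
}

/// Groups keys by the root of their set rather than by their original tag.
fn groups_by_root(
    map: &HashMap<String, usize>,
    parent: &[usize],
) -> Vec<(usize, Vec<String>)> {
    let mut grouped: HashMap<usize, Vec<String>> = HashMap::new();
    for (key, &tag) in map {
        grouped
            .entry(root_of(parent, tag))
            .or_default()
            .push(key.clone());
    }
    grouped.into_iter().collect()
}
```

Inside the original code, the equivalent one-line change after `finalize()` would be to key the entry on `self.clusters.parent[v]` instead of `v`.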