
Can I randomly sample from a HashSet efficiently?


I have a std::collections::HashSet, and I want to sample and remove a uniformly random element.

Currently, I pick a random index with rand's gen_range, iterate over the HashSet up to that index to get the element, and then remove that element. This works, but it is not efficient: walking the iterator to the chosen index takes linear time per removal. Is there an efficient way to randomly sample an element?

Here's a stripped-down version of what my code looks like:

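use std::collections::HashSet;

use rand::{thread_rng, Rng};

fn main() {
    let mut hash_set: HashSet<u32> = HashSet::new();

    // ... Fill up hash_set (example data) ...
    hash_set.extend(0..100);

    // rand 0.8 syntax: gen_range takes a range. This panics if hash_set is empty.
    let index = thread_rng().gen_range(0..hash_set.len());
    // iter().nth(index) walks the set element by element, so this step is O(n).
    let element = hash_set.iter().nth(index).unwrap().clone();
    hash_set.remove(&element);

    // ... Use element ...
}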

Solution

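  • As of 2023, the hashbrown crate, which provides the implementation behind the Rust Standard Library's HashMap and HashSet, exposes a RawTable API (behind its "raw" feature) that gives unsafe, lower-level access to the table's internal state. The standard library does not expose this API, so the code below uses hashbrown::HashSet directly rather than std::collections::HashSet. With it, we can implement random removal directly: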

    use hashbrown::HashSet;
    use rand::prelude::*;
    use std::hash::Hash;
    
    fn remove_random<T, R>(set: &mut HashSet<T>, rng: &mut R) -> Option<T>
    where
        R: Rng,
        T: Eq + Hash,
    {
        if set.is_empty() {
            return None;
        }
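        // If the load factor is under 25%, shrink to fit.
        // We need a high load factor so that sampling succeeds in a reasonable time:
        // with n full buckets out of B, each draw succeeds with probability n / B,
        // so the expected number of draws is B / n. For large tables, where capacity
        // is 7/8 of the bucket count, a 25% floor keeps this below about 5 draws.
        // The table never shrinks on its own when elements are removed, so we do it here.
        // Insertions alone only lower the load factor to about 50% (right after the
        // table grows), so the 25% threshold is reached only through removals.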
        if set.capacity() >= 8 && set.len() < set.capacity() / 4 {
            set.shrink_to_fit();
        }
        let raw_table = set.raw_table_mut();
        let num_buckets = raw_table.buckets();
        // Perform rejection sampling: Pick a random bucket, check if it's full,
        // repeat until a full bucket is found.
        loop {
            let bucket_index = rng.gen_range(0..num_buckets);
            // Safety: bucket_index is less than the number of buckets, and num_buckets
            // stays valid because we return the first time we modify the table.
            // The table is allocated because the set is non-empty (checked above),
            // and remove is only called on a bucket we have just confirmed is full.
            unsafe {
                if raw_table.is_bucket_full(bucket_index) {
                    let bucket = raw_table.bucket(bucket_index);
                    let ((element, ()), _insert_slot) = raw_table.remove(bucket);
                    return Some(element);
                }
            }
        }
    }
    

    This code requires the "raw" feature of the hashbrown crate, enabled in Cargo.toml, for example:

    [dependencies]
    hashbrown = { version = "^0.14.2", features = ["raw"] }
    rand = "^0.8.5"
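To see why the rejection loop above returns a uniformly random element, and why the load-factor check matters, here is a std-only toy model of the same idea. It is a sketch under stated assumptions, not hashbrown's implementation: the names Buckets and XorShift64 are hypothetical, the bucket array is modelled as a plain Vec<Option<T>> with no hashing, compacting the array stands in for shrink_to_fit, and XorShift64 replaces the rand crate so the example runs without dependencies.

```rust
// Toy model of rejection sampling over a bucket array (hypothetical names).
// Buckets models a hash table's buckets as Vec<Option<T>>; XorShift64 is a
// minimal stand-in for the rand crate, not suitable for real use.

struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Index in 0..n; the modulo adds a negligible bias for small n.
    fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

struct Buckets<T> {
    slots: Vec<Option<T>>,
    len: usize, // number of full (Some) slots
}

impl<T> Buckets<T> {
    fn new(slots: Vec<Option<T>>) -> Self {
        let len = slots.iter().filter(|s| s.is_some()).count();
        Buckets { slots, len }
    }

    fn remove_random(&mut self, rng: &mut XorShift64) -> Option<T> {
        if self.len == 0 {
            return None;
        }
        // Analogue of shrink_to_fit: if fewer than 25% of the slots are full,
        // drop the empty slots and re-pad to a power of two.
        if self.slots.len() >= 8 && self.len < self.slots.len() / 4 {
            self.slots.retain(|s| s.is_some());
            self.slots.resize_with(self.len.next_power_of_two(), || None);
        }
        // Rejection sampling: draw slots uniformly until a full one turns up.
        // Empty slots are simply redrawn, so every full slot is equally likely.
        loop {
            let i = rng.below(self.slots.len());
            if let Some(element) = self.slots[i].take() {
                self.len -= 1;
                return Some(element);
            }
        }
    }
}

fn main() {
    // 16 slots, 5 of them full (a load factor of about 31%).
    let mut slots: Vec<Option<u32>> = vec![None; 16];
    for (slot, value) in [1usize, 4, 7, 10, 15].iter().zip(10u32..) {
        slots[*slot] = Some(value);
    }
    let mut set = Buckets::new(slots);
    let mut rng = XorShift64(0x9E37_79B9_7F4A_7C15);

    let mut removed = Vec::new();
    while let Some(x) = set.remove_random(&mut rng) {
        removed.push(x);
    }
    removed.sort();
    println!("{:?}", removed); // prints [10, 11, 12, 13, 14]
}
```

The trade-off is the same as in the hashbrown version: each draw costs O(1), and the expected number of draws is slots.len() / len, which is why both versions restore a reasonable load factor before sampling.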