I have a std::collections::HashSet, and I want to sample and remove a uniformly random element.
Currently, I randomly sample an index using gen_range from the rand crate, then iterate over the HashSet up to that index to get the element, and then remove the selected element. This works, but it is not efficient: every removal takes linear time, because the iterator has to walk past all the elements before the chosen index. Is there an efficient way to randomly sample an element?
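Here's a stripped-down version of what my code looks like:
```rust
use std::collections::HashSet;

use rand::{thread_rng, Rng};

fn main() {
    let mut hash_set: HashSet<String> = HashSet::new();
    // ... Fill up hash_set ...
    hash_set.extend((0..10).map(|i| i.to_string())); // example contents so the snippet runs

    // Walk the iterator to a random position: linear time per removal.
    let index = thread_rng().gen_range(0..hash_set.len()); // rand 0.8 range syntax
    let element = hash_set.iter().nth(index).unwrap().clone();
    hash_set.remove(&element);
    // ... Use element ...
}
```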
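As of 2023, the hashbrown crate, which provides the implementation behind the Rust Standard Library's HashSet, exposes a RawTable API (behind its "raw" feature) that gives unsafe, lower-level access to the table's internal state. The standard library does not expose this API, so the code below uses hashbrown::HashSet in place of std::collections::HashSet. With it, we can implement random removal directly, using rejection sampling over the table's buckets: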
```rust
use hashbrown::HashSet;
use rand::prelude::*;
use std::hash::Hash;

fn remove_random<T, R>(set: &mut HashSet<T>, rng: &mut R) -> Option<T>
where
    R: Rng,
    T: Eq + Hash,
{
    if set.is_empty() {
        return None;
    }

    // If fewer than a quarter of the buckets are full, shrink to fit.
    // Rejection sampling needs a high load factor to finish in a reasonable time,
    // and the table never shrinks on its own when elements are removed.
    // Growth on insertion leaves roughly 7/16 of the buckets full,
    // so shrinking below 1/4 does not cause repeated resizing.
    // Compare against the bucket count rather than capacity(): capacity() is
    // len() plus the remaining growth budget, and removals that leave tombstones
    // do not restore that budget, so capacity() can shrink along with len().
    if set.len() < set.raw_table_mut().buckets() / 4 {
        set.shrink_to_fit();
    }

    let raw_table = set.raw_table_mut();
    let num_buckets = raw_table.buckets();

    // Rejection sampling: pick a random bucket, check whether it is full,
    // and repeat until a full bucket is found.
    loop {
        let bucket_index = rng.gen_range(0..num_buckets);
        // Safety: bucket_index is less than the number of buckets.
        // We return the first time we modify the table,
        // so raw_table.buckets() never changes inside the loop.
        // The table is allocated because the set is non-empty.
        unsafe {
            if raw_table.is_bucket_full(bucket_index) {
                let bucket = raw_table.bucket(bucket_index);
                let ((element, ()), _insert_slot) = raw_table.remove(bucket);
                return Some(element);
            }
        }
    }
}
```
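Why rejection sampling over buckets is both uniform and cheap: suppose n of the table's B buckets are full. Each loop iteration succeeds with probability p = n/B, so a given full bucket j is returned with probability

$$
\Pr[j] = \sum_{k \ge 0} (1 - p)^k \cdot \frac{1}{B} = \frac{1}{B} \cdot \frac{1}{p} = \frac{1}{n},
$$

which is uniform over the n elements. The number of iterations is geometric with mean

$$
\mathbb{E}[\text{iterations}] = \frac{1}{p} = \frac{B}{n} \le 4 \quad \text{whenever } n \ge \frac{B}{4}.
$$

This bound is what the shrink step protects: without it, removals could drive n/B arbitrarily close to zero, and each call would cost on the order of B/n iterations.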
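This program requires the "raw" feature of the hashbrown crate to be enabled, with a Cargo.toml like the following:
```toml
[dependencies]
hashbrown = { version = "^0.14.2", features = ["raw"] }
rand = "^0.8.5"
```
The caret requirement keeps hashbrown on the 0.14 series. hashbrown 0.15 removed the "raw" feature and the RawTable API, so this code does not build against newer releases.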
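The same rejection-sampling idea can be tried without hashbrown. The following is a std-only toy sketch, not hashbrown's actual layout: Slots and XorShift64 are hypothetical names, the Vec<Option<T>> merely stands in for the bucket array (empty and deleted buckets are both None), and the small xorshift generator replaces rand only so the example has no dependencies.

```rust
/// Hypothetical minimal xorshift64 generator (Marsaglia's 13/7/17 shifts),
/// used only so the sketch needs no external crates. Not for cryptography.
struct XorShift64(u64);

impl XorShift64 {
    fn new(seed: u64) -> Self {
        // A zero state would stay zero forever, so replace it with a fixed constant.
        XorShift64(if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed })
    }

    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Returns an index in 0..n (modulo reduction has a slight bias; fine for a demo).
    fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

/// Hypothetical toy stand-in for a hash table's bucket array.
/// `None` plays the role of an empty or deleted bucket.
struct Slots<T> {
    slots: Vec<Option<T>>,
    len: usize, // number of occupied slots
}

impl<T> Slots<T> {
    fn new(slots: Vec<Option<T>>) -> Self {
        let len = slots.iter().filter(|s| s.is_some()).count();
        Slots { slots, len }
    }

    /// Rejection sampling: pick a uniformly random slot and retry until it is occupied.
    /// Every occupied slot is equally likely; expected probes = slots.len() / len.
    fn remove_random(&mut self, rng: &mut XorShift64) -> Option<T> {
        if self.len == 0 {
            return None; // the loop below would never terminate
        }
        loop {
            let i = rng.below(self.slots.len());
            // take() leaves None behind, like removing the element from its bucket.
            if let Some(value) = self.slots[i].take() {
                self.len -= 1;
                return Some(value);
            }
        }
    }
}
```

For example, Slots::new(vec![None, Some(10), None, Some(20)]) followed by repeated remove_random calls yields 10 and 20 in some random order and then None. As in the hashbrown version, the cost per call depends only on the ratio of slots to occupied slots, which is why keeping that ratio bounded matters.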