Tags: rust, concurrency, operating-system, mutex, automatic-ref-counting

How should I implement a concurrent approximate counter in rust?


I am reading chapter 29 of Operating Systems: Three Easy Pieces, which is about concurrent data structures. The first example of a concurrent data structure is the approximate counter. It keeps a global counter protected by a global mutex, plus several local counters, each protected by its own local mutex. When a local counter reaches a threshold, the thread grabs the global mutex and flushes its local count into the global counter.

The chapter shows the code in C. Since I'm practicing Rust, I've been trying to implement the data structure in Rust, but I ran into an error saying that the DerefMut trait is not implemented. It seems that "Arc" can be used only with a few structs. How should I change my code?

The original code in C:

#include <pthread.h>

#define NUMCPUS 4 // assumed value; the excerpt does not define NUMCPUS

typedef struct __counter_t {
    int global; // global count
    pthread_mutex_t glock; // global lock
    int local[NUMCPUS]; // per-CPU count
    pthread_mutex_t llock[NUMCPUS]; // ... and locks
    int threshold; // update frequency
} counter_t;

// init: record threshold, init locks, init values
// of all local counts and global count
void init(counter_t *c, int threshold) {
    c->threshold = threshold;
    c->global = 0;
    pthread_mutex_init(&c->glock, NULL);
    int i;

    for (i = 0; i < NUMCPUS; i++) {
        c->local[i] = 0;
        pthread_mutex_init(&c->llock[i], NULL);
    }
}


// update: usually, just grab local lock and update
// local amount; once local count has risen to 'threshold',
// grab global lock and transfer local values to it
void update(counter_t *c, int threadID, int amt) {
    int cpu = threadID % NUMCPUS;
    pthread_mutex_lock(&c->llock[cpu]);
    c->local[cpu] += amt;
    if (c->local[cpu] >= c->threshold) {
        // transfer to global (assumes amt>0)
        pthread_mutex_lock(&c->glock);
        c->global += c->local[cpu];
        pthread_mutex_unlock(&c->glock);
        c->local[cpu] = 0;
    }
    pthread_mutex_unlock(&c->llock[cpu]);
}

// get: just return global amount (approximate)
int get(counter_t *c) {
    pthread_mutex_lock(&c->glock);
    int val = c->global;
    pthread_mutex_unlock(&c->glock);
    return val; // only approximate!
}
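For comparison, here is a Rust sketch that follows the C code above more literally. The type name ApproxCounter and the constant NUM_CPUS are made up for this sketch, and std::array::from_fn assumes Rust 1.63 or newer. Every method takes &self because the Mutexes supply the mutability, which is what allows the counter to be shared between threads through an Arc:

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Hypothetical stand-in for the C code's NUMCPUS (value assumed).
const NUM_CPUS: usize = 4;

/// Counterpart of the C `counter_t`: one global count behind one lock,
/// plus one local count per CPU slot, each behind its own lock.
pub struct ApproxCounter {
    global: Mutex<i32>,
    local: [Mutex<i32>; NUM_CPUS],
    threshold: i32,
}

impl ApproxCounter {
    // init(): record the threshold and zero all counts.
    pub fn new(threshold: i32) -> Self {
        ApproxCounter {
            global: Mutex::new(0),
            local: std::array::from_fn(|_| Mutex::new(0)),
            threshold,
        }
    }

    // update(): &self is enough; the Mutexes provide interior mutability.
    pub fn update(&self, thread_id: usize, amt: i32) {
        let cpu = thread_id % NUM_CPUS;
        let mut local = self.local[cpu].lock().unwrap();
        *local += amt;
        if *local >= self.threshold {
            // Transfer to global (assumes amt > 0) while still holding the
            // local lock, as the C version does.
            *self.global.lock().unwrap() += *local;
            *local = 0;
        }
    } // the local lock is released here, when `local` is dropped

    // get(): return the global amount (only approximate!).
    pub fn get(&self) -> i32 {
        *self.global.lock().unwrap()
    }
}

fn main() {
    let counter = Arc::new(ApproxCounter::new(5));
    let handles: Vec<_> = (0..NUM_CPUS)
        .map(|id| {
            let c = Arc::clone(&counter);
            thread::spawn(move || {
                for _ in 0..1003 {
                    c.update(id, 1);
                }
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
    // 4 * 1003 = 4012 increments happened, but each slot still holds 3
    // unflushed, so this prints "global = 4000".
    println!("global = {}", counter.get());
}
```

Each thread maps to its own slot here, so the result is deterministic. It also shows the approximation the chapter describes: up to threshold - 1 increments per slot can still be sitting in the local counters when get() is called.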

My code:

use std::fmt;
use std::sync::{Arc, Mutex};
use std::thread;


pub struct Counter {
    value: Mutex<i32>
}

impl Counter {
    pub fn new() -> Self {
        Counter { value: Mutex::new(0)}
    }

    pub fn test_and_increment(&mut self) -> i32 {
        let mut value = self.value.lock().unwrap();
        *value += 1;

        if *value >= 10 {
            let old = *value;
            *value = 0;
            return old;
        }
        else {
            return 0;
        }
    }

    pub fn get(&mut self) -> i32 {
        *(self.value.lock().unwrap())
    }

    pub fn add(&mut self, value: i32) {
        *(self.value.lock().unwrap()) += value;
    }
}


impl fmt::Display for Counter {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", *self.value.lock().unwrap())
    }
}


pub struct ApproximateCounter {
    value: Counter,
    local_counters: [Counter; 4]
}

impl ApproximateCounter {
    pub fn new() -> Self {
        ApproximateCounter {
            value: Counter::new(),
            local_counters: [Counter::new(), Counter::new(), Counter::new(), Counter::new()]
        }
    }

    pub fn increment(&mut self, i: usize) {
        let local_value = self.local_counters[i].test_and_increment();

        if local_value > 0 {
            self.value.add(local_value);
        }
    }

    pub fn get(&mut self) -> i32 {
        self.value.get()
    }
}

fn main() {
    let mut counter = Arc::new(ApproximateCounter::new());
    let mut threads = Vec::new();
    for i in 0..4 {
        let c_counter = counter.clone();
        threads.push(thread::spawn(move || {
            for _ in 0..100 {
                c_counter.increment(i);
            }
        }));
    }
    for thread in threads {
        thread.join();
    }
    println!("{}", counter.get());
}

Error message:

error[E0596]: cannot borrow data in an `Arc` as mutable
  --> src/main.rs:54:21
   |
54 |                     c_counter.increment(i);
   |                     ^^^^^^^^^^^^^^^^^^^^^^ cannot borrow as mutable
   |
   = help: trait `DerefMut` is required to modify through a dereference, but it is not implemented for `Arc<ApproximateCounter>`

Solution

  • ApproximateCounter::increment takes &mut self, but it should take &self instead. Arc only gives you shared (&) access to the value it holds, so you cannot obtain a mutable reference to something inside an Arc. However, Mutex provides interior mutability: it lets you mutate data behind a shared reference. So if you change the increment and test_and_increment methods to take &self instead, your code should work.

    A couple of minor things:

    • Your code won't compile, because inside struct Counter you have a Mutex, but you didn't specify its type. Change it to Mutex<i32>.
    • The get and add methods should take &self as well.
    • When you clone an Arc, prefer Arc::clone(&thing) over thing.clone(). It makes clear that only the reference count is incremented and the data is not deep-copied.
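With these suggestions applied, the program from the question might look like the sketch below. The Display impl for Counter can stay as it was and is omitted here. The join results are unwrapped so that a panicking worker thread is not silently ignored:

```rust
use std::sync::{Arc, Mutex};
use std::thread;

pub struct Counter {
    value: Mutex<i32>,
}

impl Counter {
    pub fn new() -> Self {
        Counter { value: Mutex::new(0) }
    }

    // &self is enough: the Mutex provides interior mutability.
    pub fn test_and_increment(&self) -> i32 {
        let mut value = self.value.lock().unwrap();
        *value += 1;
        if *value >= 10 {
            let old = *value;
            *value = 0;
            old // hand the local total back so the caller can flush it
        } else {
            0
        }
    }

    pub fn get(&self) -> i32 {
        *self.value.lock().unwrap()
    }

    pub fn add(&self, value: i32) {
        *self.value.lock().unwrap() += value;
    }
}

pub struct ApproximateCounter {
    value: Counter,
    local_counters: [Counter; 4],
}

impl ApproximateCounter {
    pub fn new() -> Self {
        ApproximateCounter {
            value: Counter::new(),
            local_counters: [Counter::new(), Counter::new(), Counter::new(), Counter::new()],
        }
    }

    // Now callable through Arc<ApproximateCounter>, which only derefs to &Self.
    pub fn increment(&self, i: usize) {
        let local_value = self.local_counters[i].test_and_increment();
        if local_value > 0 {
            self.value.add(local_value);
        }
    }

    pub fn get(&self) -> i32 {
        self.value.get()
    }
}

fn main() {
    let counter = Arc::new(ApproximateCounter::new());
    let mut threads = Vec::new();
    for i in 0..4 {
        let c_counter = Arc::clone(&counter);
        threads.push(thread::spawn(move || {
            for _ in 0..100 {
                c_counter.increment(i);
            }
        }));
    }
    for thread in threads {
        thread.join().unwrap();
    }
    // Each slot flushes exactly ten times (100 / 10), so this prints 400.
    println!("{}", counter.get());
}
```

One difference from the book's C version: test_and_increment releases the local lock before add takes the global one, so for a brief moment the flushed amount is in neither counter. Since get() is only approximate anyway, that is harmless here.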