Tags: rust, borrow-checker, ownership

Rust: Safely terminate running threads inside a ThreadPool?


I followed the official Rust documentation that shows you how to write a multi-threaded HTTP server. They walk you through constructing a thread pool that's backed by individual workers, where each worker has its own thread. The thread pool will execute Job instances and pass them along to workers through the MPSC implementation.

This is what the code looks like, per the docs, and the full guide is here.

struct Worker {
    id: usize,
    thread: thread::JoinHandle<()>
}

impl Worker {
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            let job = receiver.lock().unwrap().recv().unwrap();
            println!("Receiver {} received a job.", id);
            job();
        });
        Worker { id, thread }
    }

    fn get_thread(&self) -> &thread::JoinHandle<()> {
        return &self.thread;
    }
}

I want to add an additional function: shutdown() to my Worker implementation, i.e.

    fn shutdown(&mut self) {
        &self.thread.join();
    }

I would call this from the main thread, thereby waiting on all the threads to complete before the program ends execution. However, the shutdown function above gives me this error:

error: src/main.rs:31: cannot move out of `self.thread` which is behind a mutable reference
error: src/main.rs:31: move occurs because `self.thread` has type `JoinHandle<()>`, which does not implement the `Copy` trait

I declared self as &mut self since I'll be mutating self.thread. What am I overlooking?

I think it's telling me that by calling &self.thread.join(), I'm going to destroy the thread::JoinHandle<()> that belongs to this worker instance, and I'm not allowed to do that? Then how would I kill this thread?

Secondly, I recognize that my thread has an infinite loop, and so technically calling join on this thread will just cause the main thread to hang. How should I go about breaking out of the thread? Pass a shared reference to my Worker and have the worker check its value, breaking out from the loop?


Solution

  • So first of all, I am guessing you have a ThreadPool struct. Implement the Drop trait for that struct so that all threads are joined when the ThreadPool itself is dropped. Shutdown has to be coordinated by the pool; you can't really do it on a per-worker basis.

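    impl Drop for ThreadPool {
        fn drop(&mut self) {
            // Tell every worker to stop first (Message and sender are
            // defined below); otherwise join() would block forever on a
            // worker that is still looping.
            for _ in &self.workers {
                self.sender.send(Message::Terminate).unwrap();
            }

            for worker in &mut self.workers {
                println!("Shutting down worker {}", worker.id);

                // take() moves the JoinHandle out of the Option so that
                // join(), which consumes the handle, can be called.
                if let Some(thread) = worker.thread.take() {
                    thread.join().unwrap();
                }
            }
        }
    }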
    

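    The take() call is important here. join() takes the JoinHandle by value, so it cannot be called on a field behind &mut self, which is exactly the error you are seeing. If you declare the field as thread: Option<thread::JoinHandle<()>>, take() moves the handle out of the Some variant and leaves None behind. If the thread is already None, the worker has been joined and we don't need to call join again.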

    Shutting down a thread pool in any language requires some central signal to be acknowledged by all threads, not just one. If this were C++, you might check an atomic bool at the top of your thread's work function and return if it's true. While you could do that in Rust, I will assume you want to follow a method closer to the project presented in the book itself, which is messages.
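    For comparison, the atomic-flag approach mentioned above could look roughly like this in Rust. This is only a sketch under assumptions: run_until_stopped is a hypothetical helper, and the sleep stands in for real work.

```rust
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

// Hypothetical helper: spawns `n` threads that loop until `stop` is set,
// then sets the flag, joins them all, and returns how many exited cleanly.
fn run_until_stopped(n: usize) -> usize {
    let stop = Arc::new(AtomicBool::new(false));
    let exited = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..n)
        .map(|_| {
            let stop = Arc::clone(&stop);
            let exited = Arc::clone(&exited);
            thread::spawn(move || {
                // Check the flag at the top of every iteration.
                while !stop.load(Ordering::Acquire) {
                    // Stand-in for one unit of work.
                    thread::sleep(Duration::from_millis(1));
                }
                exited.fetch_add(1, Ordering::SeqCst);
            })
        })
        .collect();

    // One central signal, observed by every thread.
    stop.store(true, Ordering::Release);
    for handle in handles {
        handle.join().expect("worker thread panicked");
    }
    exited.load(Ordering::SeqCst)
}

fn main() {
    println!("{} threads stopped", run_until_stopped(4));
}
```

    The catch is that a worker blocked in recv() never gets back to the top of its loop to check the flag, so a flag alone is not enough for a channel-driven pool. That is one reason sending a message fits this design better.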

    Instead of a job, you want to use a message type in your channel, such that you have two variants, a job message and a terminate message.

    type Job = Box<dyn FnOnce() + Send + 'static>;
    enum Message {
        NewJob(Job),
        Terminate,
    }
    

    And your ThreadPool would be

    pub struct ThreadPool {
        workers: Vec<Worker>,
        sender: mpsc::Sender<Message>,
    }
    

    Then, your worker function would look more like this, similar to this section of the book

    impl Worker {
        fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Worker {
            let thread = thread::spawn(move || loop {
                let message = receiver.lock().unwrap().recv().unwrap();
    
                match message {
                    Message::NewJob(job) => {
                        println!("Worker {} got a job; executing.", id);
    
                        job();
                    }
                    Message::Terminate => {
                        println!("Worker {} was told to terminate.", id);
    
                        break;
                    }
                }
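            });

            // The handle is stored in an Option so that Drop can take() it.
            Worker { id, thread: Some(thread) }
        }
    }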
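    Putting the pieces together, a compact end-to-end sketch might look like the following. It uses only the standard library; the worker id and the println! calls are left out for brevity, and run_jobs is a hypothetical helper added just to exercise the pool.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;

type Job = Box<dyn FnOnce() + Send + 'static>;

enum Message {
    NewJob(Job),
    Terminate,
}

struct Worker {
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    fn new(receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            // The lock guard is dropped at the end of this statement,
            // so other workers can receive while this one runs its job.
            let message = receiver.lock().unwrap().recv().unwrap();
            match message {
                Message::NewJob(job) => job(),
                Message::Terminate => break,
            }
        });
        Worker { thread: Some(thread) }
    }
}

pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: mpsc::Sender<Message>,
}

impl ThreadPool {
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let workers: Vec<Worker> = (0..size)
            .map(|_| Worker::new(Arc::clone(&receiver)))
            .collect();
        ThreadPool { workers, sender }
    }

    pub fn execute<F: FnOnce() + Send + 'static>(&self, f: F) {
        self.sender.send(Message::NewJob(Box::new(f))).unwrap();
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // One Terminate per worker, queued behind any pending jobs.
        for _ in &self.workers {
            self.sender.send(Message::Terminate).unwrap();
        }
        for worker in &mut self.workers {
            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
    }
}

// Hypothetical helper: runs `jobs` counter increments on a pool of `size`
// workers and returns the counter after the pool has been dropped.
fn run_jobs(size: usize, jobs: usize) -> usize {
    let counter = Arc::new(AtomicUsize::new(0));
    let pool = ThreadPool::new(size);
    for _ in 0..jobs {
        let counter = Arc::clone(&counter);
        pool.execute(move || {
            counter.fetch_add(1, Ordering::SeqCst);
        });
    }
    drop(pool); // blocks until every worker has exited
    counter.load(Ordering::SeqCst)
}

fn main() {
    println!("jobs completed: {}", run_jobs(4, 10));
}
```

    Because the channel is first-in, first-out, every job queued before the pool is dropped is picked up before any worker sees a Terminate message, so dropping the pool drains the queue rather than discarding work.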
    

    Now, you also mentioned your program hanging, with no way to shut it down externally. It's probably easier to use a signal-handling crate than to implement this yourself at first. One option is the signal-hook crate:

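    // Assumes signal-hook 0.3 (older releases export SIGINT from the crate
    // root) and a ThreadPool::new constructor like the one in the book.
    use signal_hook::{consts::SIGINT, iterator::Signals};
    use std::error::Error;

    fn main() -> Result<(), Box<dyn Error>> {
        let mut signals = Signals::new(&[SIGINT])?;
        let pool = ThreadPool::new(4);

        // ... hand jobs to `pool` here, e.g. from a thread that accepts
        // incoming connections ...

        // Block the main thread until the first SIGINT (ctrl+c) arrives.
        if let Some(sig) = signals.forever().next() {
            println!("Received signal {}; shutting down.", sig);
        }

        // Dropping the pool runs its Drop impl, which stops and joins
        // every worker before main returns.
        drop(pool);
        Ok(())
    }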
    

    This shuts the pool down and ends the program when you press ctrl+c on your keyboard.