
How to store a list of closures returning a Future and share it between threads in Rust?


I'm trying to write a multithreaded TCP server with tokio that can handle multiple connections at the same time.

I want to structure it in an event-driven way, where one or more closures can be attached to a specific event (new connection, message received, client disconnected, etc.).

For example:

    server.on_message(|msg: String, stream: &mut TcpStream| {
        async move {
            println!("Received {:?}", msg);
            stream.write_all(b"Hello\n").await;
        }
    }).await;

Every connection gets its own task (spawned with tokio::spawn), and that task needs read access to a Vec of callbacks:

    pub async fn run(&mut self) {
        let listener = TcpListener::bind("127.0.0.1:9090").await.unwrap();

        loop {
            let (socket, _) = listener.accept().await.unwrap();

            // Cloned per connection and moved into the task, so it must be
            // Send and 'static (typically an Arc around the callback list).
            let cb = self.on_message.clone();
            tokio::spawn(async move {
                Self::process(socket, cb).await;
            });
        }
    }
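For self.on_message.clone() to be moved into a spawned task, on_message is most likely an Arc around the callback list. A plain Vec<Box<dyn Fn ...>> cannot be cloned at all, because Box<dyn Fn> does not implement Clone. The sketch below shows that sharing pattern under some assumptions. It uses std::thread::spawn and a tiny hand-written block_on executor as stand-ins for tokio::spawn and the tokio runtime, so it runs with the standard library alone. run_workers, block_on, ThreadWaker and the "connection" messages are hypothetical names used only for illustration. thread::spawn imposes the same Send requirement as tokio::spawn, which is why the callback type is Send + Sync.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread;

type BoxFut = Pin<Box<dyn Future<Output = ()> + Send>>;
// Send + Sync: an Arc of these can only cross threads if the callbacks are both.
type Callback = dyn Fn(String) -> BoxFut + Send + Sync;

// Minimal executor (hypothetical stand-in for the tokio runtime):
// polls the future and parks the thread while it is pending.
struct ThreadWaker(thread::Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

// Spawns `n` hypothetical "connection" workers that share one callback list.
// Returns the total number of callback invocations.
fn run_workers(n: usize) -> usize {
    let calls = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&calls);

    let mut list: Vec<Box<Callback>> = Vec::new();
    list.push(Box::new(move |msg: String| -> BoxFut {
        let counter = Arc::clone(&counter);
        Box::pin(async move {
            println!("Received {msg:?}");
            counter.fetch_add(1, Ordering::SeqCst);
        })
    }));

    // Freeze the list behind an Arc: workers only need read access.
    let list = Arc::new(list);

    let handles: Vec<_> = (0..n)
        .map(|id| {
            // Cloning the Arc bumps a reference count; the Vec itself is not copied.
            let cb = Arc::clone(&list);
            thread::spawn(move || {
                block_on(async move {
                    for f in cb.iter() {
                        f(format!("hello from connection {id}")).await;
                    }
                })
            })
        })
        .collect();

    for h in handles {
        h.join().unwrap();
    }
    calls.load(Ordering::SeqCst)
}

fn main() {
    println!("total calls: {}", run_workers(3));
}
```

The list is only read after it has been wrapped in the Arc, so no Mutex is needed here. If handlers had to be registered while the server is already running, an Arc<RwLock<...>> around the list would be one option.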

Unfortunately, my understanding of Rust is still very rudimentary, and I keep running in circles trying to find:

  • the right type to store in the Vec
  • the right type for the function argument that takes the callback closure

Whenever I make progress on one of these, I realise I've broken the other. This is the best I could come up with, but it still doesn't work:

    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use tokio::net::TcpStream;

    type Callback<T> = dyn Fn(T, &mut TcpStream) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + 'static;

    pub struct TcpStreamCallbackList<T> {
        callbacks: Vec<Arc<Box<Callback<T>>>>,
    }

    unsafe impl<T> Send for TcpStreamCallbackList<T> {}
    unsafe impl<T> Sync for TcpStreamCallbackList<T> {}

    impl<T> TcpStreamCallbackList<T> {
        pub fn new() -> Self {
            Self { callbacks: Vec::new() }
        }

        pub fn push<G: Send + 'static>(&mut self, fun: impl Fn(T, &mut TcpStream) -> G + Send + 'static)
        where
            G: Future<Output = ()>,
        {
            self.callbacks.push(Arc::new(Box::new(move |val: T, stream: &mut TcpStream| Box::pin(fun(val, stream)))));
        }

        pub async fn call(&self, val: T, stream: &mut TcpStream)
        where
            T: Clone,
        {
            for cb in self.callbacks.iter() {
                let _cb = cb.clone();
                _cb(val.clone(), stream).await; // B O O M
            }
        }
    }

The code above only compiles if I remove the .await on the future returned by the callback in call, which defeats the purpose:

    error[E0277]: `dyn for<'a> Fn(String, &'a mut tokio::net::TcpStream) -> Pin<Box<dyn futures::Future<Output = ()> + std::marker::Send>> + std::marker::Send` cannot be shared between threads safely
       --> src/main.rs:94:26

From what I understand, the problem is that the returned Future is not Send:

    note: required by a bound in `tokio::spawn`
       --> /Users/lukasz/.cargo/registry/src/github.com-1ecc6299db9ec823/tokio-1.25.0/src/task/spawn.rs:163:21
        |
    163 |         T: Future + Send + 'static,
        |                     ^^^^ required by this bound in `tokio::spawn`

I have no idea if my type makes sense and is thread-safe. I also don't know why the compiler thinks the return type is not Send. I'm really stuck here and would appreciate any help.
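A quick way to check whether a given type is Send or Sync is to ask the compiler directly, using an empty generic function whose bound must be satisfied. The std-only sketch below simplifies the question's Callback: it takes a String argument and no &mut TcpStream. The aliases BoxFut, CallbackSendOnly and CallbackSendSync are hypothetical names. The sketch suggests the boxed future is already Send and that the missing piece is Sync on the callback itself. Both a shared reference to a callback and an Arc holding one are Send only if the callback is Sync.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

// Each function compiles only if its type parameter implements the auto trait.
fn assert_send<T: Send + ?Sized>() {}
fn assert_sync<T: Sync + ?Sized>() {}

type BoxFut = Pin<Box<dyn Future<Output = ()> + Send>>;
// Simplified versions of the question's Callback type (String instead of T, no stream).
type CallbackSendOnly = dyn Fn(String) -> BoxFut + Send;
type CallbackSendSync = dyn Fn(String) -> BoxFut + Send + Sync;

fn main() {
    // The boxed future returned by the callback is Send.
    assert_send::<BoxFut>();

    // A Send-only callback is not Sync, so each of these fails with E0277:
    // assert_sync::<CallbackSendOnly>();
    // assert_send::<&CallbackSendOnly>();
    // assert_send::<Arc<Box<CallbackSendOnly>>>();
    assert_send::<Box<CallbackSendOnly>>(); // owning it uniquely is fine

    // Adding Sync makes shared references and Arcs Send as well.
    assert_sync::<CallbackSendSync>();
    assert_send::<&CallbackSendSync>();
    assert_send::<Arc<Box<CallbackSendSync>>>();

    println!("all assertions compiled");
}
```

In the question's call, the slice iterator over the Arcs and the cloned Arc<Box<Callback<T>>> are both kept alive across the .await. Neither is Send while Callback<T> is not Sync, which matches the "cannot be shared between threads safely" wording of the error. For the same reason, the unsafe impl Send/Sync on the list does not help.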


Solution

  • I've put together a slightly simpler version that can be spawned (playground). The key is that Callback needs to be Sync in order for &self to be Send. call holds &self across an .await, and a shared reference &X is Send only if X is Sync. I tried to use the trick mentioned in this comment, but it doesn't appear to work here, and neither does making call take &mut self. I wrote more about Send and Sync in that answer.

    use std::future::Future;
    use std::pin::Pin;
    
    type CallbackFuture<O> = Pin<Box<dyn Future<Output = O> + Send>>;
    type Callback<T> = dyn (Fn(T) -> CallbackFuture<()>) + Send + Sync;
    
    pub struct CallbackList<T> {
        list: Vec<Box<Callback<T>>>,
    }
    
    impl<T> CallbackList<T> {
        pub fn new() -> Self {
            Self { list: Vec::new() }
        }
    
        pub fn push<F>(&mut self, f: F)
        where
            F: Fn(T) -> CallbackFuture<()>,
            F: Send + Sync + 'static,
        {
            self.list.push(Box::new(f))
        }
    
        pub async fn call(&self, t: T)
        where
            T: Clone,
        {
            for f in &self.list {
                f(t.clone()).await;
            }
        }
    }
    
    #[tokio::main]
    async fn main() {
        let mut calls = CallbackList::new();
        calls.push(|i| {
            Box::pin(async move {
                println!("{i}");
            })
        });
    
        calls.push(|i| {
            Box::pin(async move {
                println!("{}", i + 1);
            })
        });
    
        let handle = tokio::spawn(async move {
            calls.call(34).await;
        });
    
        handle.await.unwrap();
    }
    

    I have removed as many trait bounds, 'static bounds, and wrappers as possible, but you may need to add some back depending on how you use it. Right now the callbacks only take T, but it should be possible to split that into T and &mut TcpStream. If you update your question with a main function that uses all the elements, I can change mine to match. If all else fails, you can use (_, Arc<Mutex<TcpStream>>) as T.
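On splitting the argument into T and &mut TcpStream: one common way to express a callback whose future borrows the stream is a higher-ranked bound, for<'a> Fn(T, &'a mut Stream) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>. That bound lets the returned future hold the borrow until it completes. The sketch below assumes this approach and otherwise mirrors the CallbackList above. FakeStream is a hypothetical stand-in for tokio::net::TcpStream, and block_on is a minimal executor standing in for the tokio runtime, so the code runs with std only. The callback signatures should carry over to the real TcpStream in principle.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread;

// A boxed, Send future that may borrow data for the lifetime 'a.
pub type BoxFuture<'a> = Pin<Box<dyn Future<Output = ()> + Send + 'a>>;

// Hypothetical stand-in for tokio::net::TcpStream: it just records writes.
pub struct FakeStream {
    pub written: Vec<u8>,
}

// for<'a>: the callback must accept a borrow of the stream with *any* lifetime,
// and the future it returns may keep that borrow until it finishes.
pub type Callback<T> = dyn for<'a> Fn(T, &'a mut FakeStream) -> BoxFuture<'a> + Send + Sync;

pub struct CallbackList<T> {
    list: Vec<Box<Callback<T>>>,
}

impl<T> CallbackList<T> {
    pub fn new() -> Self {
        Self { list: Vec::new() }
    }

    pub fn push<F>(&mut self, f: F)
    where
        F: for<'a> Fn(T, &'a mut FakeStream) -> BoxFuture<'a> + Send + Sync + 'static,
    {
        self.list.push(Box::new(f));
    }

    pub async fn call(&self, t: T, stream: &mut FakeStream)
    where
        T: Clone,
    {
        for f in &self.list {
            // Reborrow the stream for this callback only; the borrow ends once
            // its future has been awaited, so the next callback can use it.
            f(t.clone(), &mut *stream).await;
        }
    }
}

// Minimal executor (stand-in for the tokio runtime).
struct ThreadWaker(thread::Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

fn main() {
    let mut calls = CallbackList::<String>::new();
    calls.push(|msg, stream| {
        Box::pin(async move {
            stream.written.extend_from_slice(b"Received ");
            stream.written.extend_from_slice(msg.as_bytes());
            stream.written.push(b'\n');
        })
    });
    calls.push(|_msg, stream| {
        Box::pin(async move {
            stream.written.extend_from_slice(b"Hello\n");
        })
    });

    let mut stream = FakeStream { written: Vec::new() };
    block_on(calls.call("hi".to_string(), &mut stream));
    assert_eq!(stream.written, b"Received hi\nHello\n");
    print!("{}", String::from_utf8_lossy(&stream.written));
}
```

The callbacks are Send + Sync and the boxed futures are Send. call(...) should therefore produce a Send future as long as T is Send, which is what tokio::spawn requires. The closures are passed directly to push so that the compiler can infer their higher-ranked signature from the bound.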