Search code examples
rust · async-await · channel · rust-tokio

How to send().await in a channel inside a tokio::spawn task?


The code below is very simple. I'm wondering how Rust developers solve this channels/spawn/mutex/async/await issue.

The issue is within the broadcast() function. I'm getting an error and I know the reason: the guard returned by clients.lock() cannot be held across the .await here (see https://tokio.rs/tokio/tutorial/shared-state#holding-a-mutexguard-across-an-await).

How can I call send().await and, at the same time, do this waiting inside a spawned task?

I need to do this in a spawned task because calling the broadcast() function from my handlers must not wait at all.

I know I can use Tokio's asynchronous mutex (https://tokio.rs/tokio/tutorial/shared-state#use-tokios-asynchronous-mutex), but the tutorial also says:

an asynchronous mutex is more expensive than an ordinary mutex, and it is typically better to use one of the two other approaches.

For this reason I'm wondering if there is a better way I don't see yet.

Is there a better way than this shared state method?

REPL: https://www.rustexplorer.com/b/pk8weh

use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};
use tokio::sync::mpsc;

/// Identifier of a team; the outer key of `Broadcaster::clients`.
type TeamId = String;
/// Identifier of a player; the inner key of `Broadcaster::clients`.
type PlayerId = String;

/// Registry of the client connections that messages are fanned out to.
///
/// Connections are grouped by team and then by player; each player maps to a
/// `Vec`, so one player can hold several connections at once.
#[derive(Default)]
pub struct Broadcaster {
    /// Lock-protected map: team id -> player id -> connections.
    /// Wrapped in `Arc` so a clone of the handle can be moved into a spawned task.
    clients: Arc<Mutex<HashMap<TeamId, HashMap<PlayerId, Vec<Connection>>>>>,
}

/// One registered client connection together with the channel used to reach it.
pub struct Connection {
    /// Session this connection was registered under.
    pub session_id: String,
    /// Player this connection was registered for.
    pub player_id: String,
    /// Sending half of the client's `mpsc` channel.
    pub sender: mpsc::Sender<Arc<Message>>,
}

/// Payload delivered to clients; passed through the channels as `Arc<Message>`.
pub struct Message {
    /// Team whose connections should receive this message.
    pub team_id: TeamId,
    /// Session identifier carried with the message.
    pub session_id: String,
    /// Message text.
    pub message: String,
}

#[tokio::main]
async fn main() {
    let broadcaster = Arc::new(Broadcaster::default());

    for i in 0..10 {
        let broadcaster = broadcaster.clone();

        tokio::spawn(async move {
            let mut rx = broadcaster
                .add_client("1", &format!("session_{i}"), &format!("player_{i}"))
                .await;

            while let Some(_) = rx.recv().await {
                println!("GOT message in team 1 - i: {i}");
            }
        });

        println!("added team 1 - client {}", i);
    }

    for i in 0..5 {
        let broadcaster = broadcaster.clone();

        tokio::spawn(async move {
            let mut rx = broadcaster
                .add_client("2", &format!("session_{i}"), &format!("player_{i}"))
                .await;

            while let Some(_) = rx.recv().await {
                println!("GOT message in team 2 - i: {i}");
            }
        });

        println!("added team 2 - client {}", i);
    }

    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    dbg!(broadcaster.clients.lock().unwrap().len());

    for _ in 0..50 {
        broadcaster
            .broadcast(Message {
                team_id: "1".to_string(),
                session_id: "1".to_string(),
                message: "fake one".to_string(),
            })
            .await;
    }
}

impl Broadcaster {
    /// Registers a new connection for `player_id` within `team_id` and returns
    /// the receiving half of that connection's channel.
    ///
    /// The team and player entries are created on first use, and the new
    /// connection is appended to the player's list. The channel is created
    /// with a capacity of 10.
    ///
    /// # Panics
    ///
    /// Panics if locking the `clients` mutex fails (the `unwrap()` on `lock()`).
    pub async fn add_client(
        &self,
        team_id: &str,
        session_id: &str,
        player_id: &str,
    ) -> mpsc::Receiver<Arc<Message>> {
        let (tx, rx) = mpsc::channel::<Arc<Message>>(10);

        // No `.await` follows in this function, so the guard is never held
        // across an await point here.
        let mut clients = self.clients.lock().unwrap();

        // Make sure the team entry exists before borrowing it mutably.
        if !clients.contains_key(team_id) {
            clients.insert(team_id.to_string(), HashMap::new());
        }

        // Cannot fail: the key was inserted just above if it was missing.
        let players = clients.get_mut(team_id).unwrap();

        // Same pattern one level down, for the player entry.
        if !players.contains_key(player_id) {
            players.insert(player_id.to_string().into(), Vec::new());
        }

        // Cannot fail: the key was inserted just above if it was missing.
        let connections = players.get_mut(player_id).unwrap();

        let connection = Connection {
            session_id: session_id.to_string(),
            player_id: player_id.to_string(),
            sender: tx,
        };

        connections.push(connection);

        rx
    }

    // HERE THE ISSUE!
    //
    // The `std::sync::Mutex` guard taken inside the spawned task below is still
    // alive at `send(..).await`, i.e. it is held across an await point. That is
    // the problem this question is about.

    /// Sends `message` to every connection of `message.team_id` from a spawned
    /// task, so that the caller does not wait for the individual sends.
    ///
    /// NOTE(review): inside the spawned task, the `unwrap()` on the team lookup
    /// panics when no client was registered for that team, and the result of
    /// each `send` is discarded.
    pub async fn broadcast(&self, message: Message) {
        // Own handle to the registry, to be moved into the spawned task.
        let clients = self.clients.clone();

        // Shared so that each connection receives a cheap `Arc` clone.
        let message = Arc::new(message);

        // I need this to not wait for broadcast()

        tokio::spawn(async move {
            // This guard lives until the end of the task body, spanning the
            // `.await` in the loop below.
            let clients = clients.lock().unwrap();

            for connections in clients.get(&message.team_id).unwrap().values() {
                for connection in connections {
                    connection.sender.send(message.clone()).await;
                }
            }
        });
    }
}

Solution

  • I have redesigned your approach a little. First of all, let's use the broadcast channel that tokio already provides.

    /// Cloneable handle around the sending half of a tokio broadcast channel.
    #[derive(Clone)]
    pub struct Broadcaster {
        /// Sender that messages are published on and that listeners subscribe to.
        sender: broadcast::Sender<Arc<Message>>,
    }
    

    Let's move the filtering logic to the end receiver. Let's call it Listener.

    /// Receiving end for one client, together with the ids its messages are
    /// matched against.
    pub struct Listener {
        // More fields could be added to filter for the exact receiver more precisely.
        /// Session id a message must carry to be accepted by this listener.
        session_id: String,
        /// Team id a message must carry to be accepted by this listener.
        team_id: String,
        /// This listener's own subscription to the broadcast channel.
        receiver: broadcast::Receiver<Arc<Message>>,
    }
    

    The listener will have a single API.

    impl Listener {
        async fn recv(&mut self) -> Option<Arc<Message>> {
            match self.receiver.recv().await {
                Ok(msg) if msg.team_id == self.team_id && msg.session_id == self.session_id => {
                    Some(msg)
                }
                // need to handle RecvError::Closed,
                Err(_) | Ok(_) => None,
            }
        }
    }
    

    Basically, here we receive the message and then check if this message should be processed by our listener.

    Then a simple implementation of Broadcaster:

    impl Broadcaster {
        /// Builds a broadcaster on top of a fresh broadcast channel created with
        /// a capacity of 100.
        pub fn new() -> Self {
            // Only the sender is kept; every listener subscribes for its own receiver.
            let (sender, _initial_rx) = broadcast::channel(100);

            Self { sender }
        }

        /// Creates a `Listener` subscribed to this broadcaster, tagged with the
        /// given `team_id` and `session_id`.
        ///
        /// `_player_id` is accepted but not used.
        pub async fn add_client(&self, team_id: &str, session_id: &str, _player_id: &str) -> Listener {
            let receiver = self.sender.subscribe();

            Listener {
                session_id: String::from(session_id),
                team_id: String::from(team_id),
                receiver,
            }
        }

        /// Publishes `message` on the broadcast channel.
        pub async fn broadcast(&self, message: Message) {
            let shared = Arc::new(message);
            // `send` is a non-blocking operation.
            // Its result is dropped here; real code needs to handle it.
            let _ = self.sender.send(shared);
        }
    }
    

    Full code with main:

    use std::sync::Arc;
    use tokio::sync::broadcast;
    
    /// Identifier of a team (the type of `Message::team_id`).
    type TeamId = String;
    /// Identifier of a player.
    type PlayerId = String;
    
    /// Cloneable handle around the sending half of a tokio broadcast channel.
    #[derive(Clone)]
    pub struct Broadcaster {
        /// Sender that messages are published on and that listeners subscribe to.
        sender: broadcast::Sender<Arc<Message>>,
    }
    
    /// Per-client connection record holding a broadcast sender.
    ///
    /// NOTE(review): nothing else in this listing refers to `Connection`; it
    /// looks like a leftover from the original design — confirm before removing.
    pub struct Connection {
        /// Session the connection belongs to.
        pub session_id: String,
        /// Player the connection belongs to.
        pub player_id: String,
        /// Broadcast sender associated with the connection.
        pub sender: broadcast::Sender<Arc<Message>>,
    }
    
    /// Payload published on the broadcast channel as `Arc<Message>`.
    pub struct Message {
        /// Team the message is addressed to; compared by `Listener::recv`.
        pub team_id: TeamId,
        /// Session the message is addressed to; compared by `Listener::recv`.
        pub session_id: String,
        /// Message text.
        pub message: String,
    }
    
    /// Receiving end for one client, together with the ids its messages are
    /// matched against.
    pub struct Listener {
        /// Session id a message must carry to be accepted by this listener.
        session_id: String,
        /// Team id a message must carry to be accepted by this listener.
        team_id: String,
        /// This listener's own subscription to the broadcast channel.
        receiver: broadcast::Receiver<Arc<Message>>,
    }
    
    impl Listener {
        async fn recv(&mut self) -> Option<Arc<Message>> {
            match self.receiver.recv().await {
                Ok(msg) if msg.team_id == self.team_id && msg.session_id == self.session_id => {
                    Some(msg)
                }
                // need to handle RecvError::Closed,
                Err(_) | Ok(_) => None,
            }
        }
    }
    
    /// Demo driver: subscribes ten listeners for team "1" (session "1") and
    /// five for team "2" (sessions `session_{i}`), each polling in its own
    /// task, then publishes fifty messages for team "1", session "1".
    #[tokio::main]
    async fn main() {
        let broadcaster = Broadcaster::new();

        // Team "1": one listener task per client, all on session "1".
        for i in 0..10 {
            let handle = broadcaster.clone();

            tokio::spawn(async move {
                let player = format!("player_{i}");
                let mut rx = handle.add_client("1", "1", &player).await;

                // Print once per accepted message until `recv` yields `None`.
                while rx.recv().await.is_some() {
                    println!("GOT message in team 1 - i: {i}");
                }
            });

            println!("added team 1 - client {}", i);
        }

        // Team "2": five listeners, each with its own session id.
        for i in 0..5 {
            let handle = broadcaster.clone();

            tokio::spawn(async move {
                let session = format!("session_{i}");
                let player = format!("player_{i}");
                let mut rx = handle.add_client("2", &session, &player).await;

                while rx.recv().await.is_some() {
                    println!("GOT message in team 2 - i: {i}");
                }
            });

            println!("added team 2 - client {}", i);
        }

        // Pause briefly so the spawned tasks get a chance to subscribe first.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        dbg!(broadcaster.sender.receiver_count());

        for _ in 0..50 {
            let message = Message {
                team_id: "1".to_string(),
                session_id: "1".to_string(),
                message: "fake one".to_string(),
            };

            broadcaster.broadcast(message).await;
        }
    }
    
    impl Broadcaster {
        /// Builds a broadcaster on top of a fresh broadcast channel created with
        /// a capacity of 100.
        pub fn new() -> Self {
            // Only the sender is kept; every listener subscribes for its own receiver.
            let (sender, _initial_rx) = broadcast::channel(100);

            Self { sender }
        }

        /// Creates a `Listener` subscribed to this broadcaster, tagged with the
        /// given `team_id` and `session_id`.
        ///
        /// `_player_id` is accepted but not used.
        pub async fn add_client(&self, team_id: &str, session_id: &str, _player_id: &str) -> Listener {
            let receiver = self.sender.subscribe();

            Listener {
                session_id: String::from(session_id),
                team_id: String::from(team_id),
                receiver,
            }
        }

        /// Publishes `message` on the broadcast channel.
        pub async fn broadcast(&self, message: Message) {
            let shared = Arc::new(message);
            // The result of `send` is dropped here; real code needs to handle it.
            let _ = self.sender.send(shared);
        }
    }