The code below is very simple. I'm wondering how Rust developers solve this channels/spawn/mutex/async/await issue.
The issue is in the `broadcast()` function. I'm getting a compile error, and I know the reason: `clients.lock()` is not usable here because of https://tokio.rs/tokio/tutorial/shared-state#holding-a-mutexguard-across-an-await.
How can I call `send().await` and, at the same time, do this waiting inside a spawned task?
I need to do this in a spawned task because calling the `broadcast()` function from my handlers must not wait at all.
I know I can use Tokio's asynchronous mutex (https://tokio.rs/tokio/tutorial/shared-state#use-tokios-asynchronous-mutex), but the tutorial also says:
an asynchronous mutex is more expensive than an ordinary mutex, and it is typically better to use one of the two other approaches.
For this reason, I'm wondering whether there is a better way that I don't see yet.
Is there a better approach than this shared-state method?
REPL: https://www.rustexplorer.com/b/pk8weh
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use tokio::sync::mpsc;
/// Team identifier (plain `String` alias for readability).
type TeamId = String;
/// Player identifier (plain `String` alias for readability).
type PlayerId = String;

/// Registry layout: team id -> player id -> every open connection of that player.
type Registry = HashMap<TeamId, HashMap<PlayerId, Vec<Connection>>>;

/// Fan-out hub that routes messages to the connections registered per team.
///
/// The registry sits behind an `Arc<Mutex<..>>` so it can be shared with
/// spawned tasks; `Default` yields an empty registry.
#[derive(Default)]
pub struct Broadcaster {
    clients: Arc<Mutex<Registry>>,
}
/// One live client connection registered with a `Broadcaster`.
///
/// Only `sender` is read by the broadcasting code in this listing; the id
/// fields record who the connection belongs to.
pub struct Connection {
    /// Session this connection was opened under.
    pub session_id: String,
    /// Player that owns this connection.
    pub player_id: String,
    /// Sending half of this connection's bounded `mpsc` channel.
    pub sender: mpsc::Sender<Arc<Message>>,
}
/// A message to be fanned out to every connection of a team.
pub struct Message {
    /// Team whose connections should receive this message.
    pub team_id: TeamId,
    /// Session id attached to the message (not used for routing in this listing).
    pub session_id: String,
    /// Message payload.
    pub message: String,
}
#[tokio::main]
async fn main() {
let broadcaster = Arc::new(Broadcaster::default());
for i in 0..10 {
let broadcaster = broadcaster.clone();
tokio::spawn(async move {
let mut rx = broadcaster
.add_client("1", &format!("session_{i}"), &format!("player_{i}"))
.await;
while let Some(_) = rx.recv().await {
println!("GOT message in team 1 - i: {i}");
}
});
println!("added team 1 - client {}", i);
}
for i in 0..5 {
let broadcaster = broadcaster.clone();
tokio::spawn(async move {
let mut rx = broadcaster
.add_client("2", &format!("session_{i}"), &format!("player_{i}"))
.await;
while let Some(_) = rx.recv().await {
println!("GOT message in team 2 - i: {i}");
}
});
println!("added team 2 - client {}", i);
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
dbg!(broadcaster.clients.lock().unwrap().len());
for _ in 0..50 {
broadcaster
.broadcast(Message {
team_id: "1".to_string(),
session_id: "1".to_string(),
message: "fake one".to_string(),
})
.await;
}
}
impl Broadcaster {
pub async fn add_client(
&self,
team_id: &str,
session_id: &str,
player_id: &str,
) -> mpsc::Receiver<Arc<Message>> {
let (tx, rx) = mpsc::channel::<Arc<Message>>(10);
let mut clients = self.clients.lock().unwrap();
if !clients.contains_key(team_id) {
clients.insert(team_id.to_string(), HashMap::new());
}
let players = clients.get_mut(team_id).unwrap();
if !players.contains_key(player_id) {
players.insert(player_id.to_string().into(), Vec::new());
}
let connections = players.get_mut(player_id).unwrap();
let connection = Connection {
session_id: session_id.to_string(),
player_id: player_id.to_string(),
sender: tx,
};
connections.push(connection);
rx
}
// HERE THE ISSUE!
pub async fn broadcast(&self, message: Message) {
let clients = self.clients.clone();
let message = Arc::new(message);
// I need this to not wait for broadcast()
tokio::spawn(async move {
let clients = clients.lock().unwrap();
for connections in clients.get(&message.team_id).unwrap().values() {
for connection in connections {
connection.sender.send(message.clone()).await;
}
}
});
}
}
I have redesigned your approach a little. First of all, let's use the broadcast channel that Tokio already provides (`tokio::sync::broadcast`).
/// Cloneable publishing handle around one `tokio::sync::broadcast` channel.
///
/// Every clone holds a clone of the same `Sender`, so all of them publish into
/// the same channel.
#[derive(Clone)]
pub struct Broadcaster {
    /// Publishing half; listeners are created from it with `subscribe()`.
    sender: broadcast::Sender<Arc<Message>>,
}
Let's move the filtering logic to the receiving end, which we'll call `Listener`.
/// Receiving end that filters broadcast traffic down to one team/session.
///
/// More fields can be added here to target receivers more precisely.
pub struct Listener {
    /// Session this listener accepts messages for.
    session_id: String,
    /// Team this listener accepts messages for.
    team_id: String,
    /// Subscription to the shared broadcast channel.
    receiver: broadcast::Receiver<Arc<Message>>,
}
The listener exposes a single method, `recv()`:
impl Listener {
async fn recv(&mut self) -> Option<Arc<Message>> {
match self.receiver.recv().await {
Ok(msg) if msg.team_id == self.team_id && msg.session_id == self.session_id => {
Some(msg)
}
// need to handle RecvError::Closed,
Err(_) | Ok(_) => None,
}
}
}
Here we receive a message and then check whether it should be processed by this listener.
Next, a simple implementation of `Broadcaster`:
impl Broadcaster {
    /// Creates a broadcaster whose channel retains at most 100 messages.
    ///
    /// A listener that falls further behind observes `RecvError::Lagged`.
    pub fn new() -> Self {
        Self::with_capacity(100)
    }

    /// Creates a broadcaster backed by a `tokio::sync::broadcast` channel that
    /// retains at most `capacity` messages.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0 or greater than `usize::MAX / 2`
    /// (the limits enforced by `broadcast::channel`).
    pub fn with_capacity(capacity: usize) -> Self {
        // The initial receiver is dropped immediately; listeners obtain their
        // own receivers through `subscribe()` in `add_client`.
        let (sender, _rx) = broadcast::channel::<Arc<Message>>(capacity);
        Self { sender }
    }

    /// Subscribes a new listener that accepts messages for `team_id`/`session_id`.
    ///
    /// The listener only sees messages sent after this call. `_player_id` is
    /// currently ignored.
    pub async fn add_client(&self, team_id: &str, session_id: &str, _player_id: &str) -> Listener {
        Listener {
            session_id: session_id.to_owned(),
            team_id: team_id.to_owned(),
            receiver: self.sender.subscribe(),
        }
    }

    /// Publishes `message` to every current listener.
    ///
    /// This is a non-blocking operation: `broadcast::Sender::send` is
    /// synchronous, so nothing is awaited here.
    pub async fn broadcast(&self, message: Message) {
        let message = Arc::new(message);
        // `send` fails only when there are no active receivers; in that case
        // there is nobody to deliver to, so the message is deliberately discarded.
        let _ = self.sender.send(message);
    }
}

impl Default for Broadcaster {
    /// Equivalent to [`Broadcaster::new`].
    fn default() -> Self {
        Self::new()
    }
}
Full code, including `main`:
use std::sync::Arc;
use tokio::sync::broadcast;
/// Team identifier carried on every `Message` (plain `String` alias).
type TeamId = String;
/// Player identifier (plain `String` alias).
/// NOTE(review): not referenced anywhere else in this listing.
type PlayerId = String;

/// Cloneable publishing handle around one `tokio::sync::broadcast` channel.
///
/// Every clone holds a clone of the same `Sender`, so all of them publish into
/// the same channel.
#[derive(Clone)]
pub struct Broadcaster {
    /// Publishing half; listeners are created from it with `subscribe()`.
    sender: broadcast::Sender<Arc<Message>>,
}
/// Connection record carried over from the original `mpsc`-based design.
///
/// NOTE(review): nothing in this listing constructs or reads a `Connection`;
/// routing is done by `Listener` filtering instead.
pub struct Connection {
    /// Session the connection belongs to.
    pub session_id: String,
    /// Player the connection belongs to.
    pub player_id: String,
    /// Handle to the shared broadcast channel.
    pub sender: broadcast::Sender<Arc<Message>>,
}
/// A message published through the shared broadcast channel.
///
/// Each `Listener` decides from `team_id` and `session_id` whether the message
/// is meant for it.
pub struct Message {
    /// Target team.
    pub team_id: TeamId,
    /// Target session.
    pub session_id: String,
    /// Message payload.
    pub message: String,
}
/// Receiving end that filters broadcast traffic down to one team/session.
pub struct Listener {
    /// Session this listener accepts messages for.
    session_id: String,
    /// Team this listener accepts messages for.
    team_id: String,
    /// Subscription to the shared broadcast channel.
    receiver: broadcast::Receiver<Arc<Message>>,
}
impl Listener {
async fn recv(&mut self) -> Option<Arc<Message>> {
match self.receiver.recv().await {
Ok(msg) if msg.team_id == self.team_id && msg.session_id == self.session_id => {
Some(msg)
}
// need to handle RecvError::Closed,
Err(_) | Ok(_) => None,
}
}
}
#[tokio::main]
async fn main() {
    let broadcaster = Broadcaster::new();

    // Team "1": ten listeners that all share session "1", so they match the
    // team-"1"/session-"1" messages published below.
    for i in 0..10 {
        let handle = broadcaster.clone();
        tokio::spawn(async move {
            let mut listener = handle
                .add_client("1", "1", &format!("player_{i}"))
                .await;
            while listener.recv().await.is_some() {
                println!("GOT message in team 1 - i: {i}");
            }
        });
        println!("added team 1 - client {i}");
    }

    // Team "2": five listeners, each with its own session; the messages below
    // are not addressed to them.
    for i in 0..5 {
        let handle = broadcaster.clone();
        tokio::spawn(async move {
            let mut listener = handle
                .add_client("2", &format!("session_{i}"), &format!("player_{i}"))
                .await;
            while listener.recv().await.is_some() {
                println!("GOT message in team 2 - i: {i}");
            }
        });
        println!("added team 2 - client {i}");
    }

    // Give the spawned tasks a chance to subscribe before counting receivers.
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    dbg!(broadcaster.sender.receiver_count());

    for _ in 0..50 {
        let outgoing = Message {
            team_id: "1".to_string(),
            session_id: "1".to_string(),
            message: "fake one".to_string(),
        };
        broadcaster.broadcast(outgoing).await;
    }
}
impl Broadcaster {
    /// Creates a broadcaster whose channel retains at most 100 messages.
    ///
    /// A listener that falls further behind observes `RecvError::Lagged`.
    pub fn new() -> Self {
        Self::with_capacity(100)
    }

    /// Creates a broadcaster backed by a `tokio::sync::broadcast` channel that
    /// retains at most `capacity` messages.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0 or greater than `usize::MAX / 2`
    /// (the limits enforced by `broadcast::channel`).
    pub fn with_capacity(capacity: usize) -> Self {
        // The initial receiver is dropped immediately; listeners obtain their
        // own receivers through `subscribe()` in `add_client`.
        let (sender, _rx) = broadcast::channel::<Arc<Message>>(capacity);
        Self { sender }
    }

    /// Subscribes a new listener that accepts messages for `team_id`/`session_id`.
    ///
    /// The listener only sees messages sent after this call. `_player_id` is
    /// currently ignored.
    pub async fn add_client(&self, team_id: &str, session_id: &str, _player_id: &str) -> Listener {
        Listener {
            session_id: session_id.to_owned(),
            team_id: team_id.to_owned(),
            receiver: self.sender.subscribe(),
        }
    }

    /// Publishes `message` to every current listener without waiting.
    ///
    /// `broadcast::Sender::send` is synchronous, so nothing is awaited here.
    pub async fn broadcast(&self, message: Message) {
        let message = Arc::new(message);
        // `send` fails only when there are no active receivers; in that case
        // there is nobody to deliver to, so the message is deliberately discarded.
        let _ = self.sender.send(message);
    }
}

impl Default for Broadcaster {
    /// Equivalent to [`Broadcaster::new`].
    fn default() -> Self {
        Self::new()
    }
}