Search code examples
Tags: rust, websocket, rust-tokio, rust-rocket

How to detect Rust rocket_ws client disconnected from WebSocket


From the rocket_ws documentation (https://api.rocket.rs/v0.5/rocket_ws/) I know I can establish a WebSocket connection with the client using this piece of code:

/// Echo endpoint: every message the client sends is sent straight back.
///
/// The channel future finishes when the client's message stream ends
/// (`next()` yields `None`) or when reading a message fails, in which case
/// the error is propagated with `?`. Send failures are ignored.
#[get("/echo?channel")]
fn echo_channel(ws: ws::WebSocket) -> ws::Channel<'static> {
    use rocket::futures::{SinkExt, StreamExt};

    ws.channel(move |mut duplex| {
        Box::pin(async move {
            while let Some(incoming) = duplex.next().await {
                // Bail out on a read error before attempting to echo.
                let frame = incoming?;
                let _ = duplex.send(frame).await;
            }

            Ok(())
        })
    })
}

But how can I detect that the connection has been closed and the client has disconnected?

This example only shows how to read messages with `stream.next()`. What if I don't expect any messages from the client and just want to send it new values periodically? For example, I might call `let _ = stream.send(ws::Message::Text(json!(reading).to_string())).await;` on a timer created with `let mut interval = interval(Duration::from_secs(10));`.


Solution

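  • To detect when the client disconnects from the WebSocket, listen for the `Close` message sent by the client. A client that drops the connection without a close handshake (network loss, killed process) sends no `Close` frame. In that case `stream.next()` returns `None` or `Some(Err(_))`, and `stream.send(...)` returns an error, so handle those cases as a disconnect too.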

    Code would look something like this:

    Some(Ok(message)) = stream.next() => {
        match message {
          ws::Message::Close(close_frame) => {
              // Handle Close message
              println!("Received Close message: {:?}", close_frame);
              let close_frame = ws::frame::CloseFrame {
                  code: ws::frame::CloseCode::Normal,
                  reason: "Client disconected".to_string().into(),
              };
              let _ = stream.close(Some(close_frame)).await;
              break;
            }
    }
    

    So your whole code for handling websockets in Rust with rocket_ws would look something like this:

    #[get("/ws")]
    pub fn echo_channel(ws: ws::WebSocket) -> rocket_ws::Channel<'static> {
        use rocket::futures::{SinkExt, StreamExt};
        use std::time::Duration;
        use rocket_ws as ws;
        use rocket::tokio as tokio;
    
        ws.channel(move |mut stream: ws::stream::DuplexStream| {
            Box::pin(async move {
                let mut interval = interval(Duration::from_secs(10));
    
                tokio::spawn(async move {
                    loop {
                        tokio::select! {
                            _ = interval.tick() => {
                                // Send message every 10 seconds
                                let reading = get_latest_readings().await.unwrap();
                                let _ = stream.send(ws::Message::Text(json!(reading).to_string())).await;
                                // println!("Sent message");
                            }
                            Some(Ok(message)) = stream.next() => {
                                match message {
                                    ws::Message::Text(text) => {
                                        // Handle Text message
                                        println!("Received Text message: {}", text);
                                    }
                                    ws::Message::Binary(data) => {
                                        // Handle Binary message
                                        println!("Received Binary message: {:?}", data);
                                    }
                                    ws::Message::Close(close_frame) => {
                                        // Handle Close message
                                        println!("Received Close message: {:?}", close_frame);
                                        let close_frame = ws::frame::CloseFrame {
                                            code: ws::frame::CloseCode::Normal,
                                            reason: "Client disconected".to_string().into(),
                                        };
                                        let _ = stream.close(Some(close_frame)).await;
                                        break;
                                    }
                                    ws::Message::Ping(ping_data) => {
                                        // Handle Ping message
                                        println!("Received Ping message: {:?}", ping_data);
                                    }
                                    ws::Message::Pong(pong_data) => {
                                        // Handle Pong message
                                        println!("Received Pong message: {:?}", pong_data);
                                    }
                                    _ => {
                                        println!("Received other message: {:?}", message);
                                    }
                                }
                            }
                            else => {
                                println!("Connection closed");
                                let close_frame = ws::frame::CloseFrame {
                                    code: ws::frame::CloseCode::Normal,
                                    reason: "Client disconected".to_string().into(),
                                };
                                let _ = stream.close(Some(close_frame)).await;
                                // The connection is closed by the client
                                break;
                            }
                        }
                    }
                });
    
                tokio::signal::ctrl_c().await.unwrap();
                Ok(())
            })
        })
    }