
Pinning trouble with tokio Stream


I would like to process incoming WebSocket messages one way during the first 2 seconds and another way after 2 seconds have elapsed.

It's tricky because there is only one read half (which obviously can't be cloned), and it doesn't like being passed to functions either.

My plan was to select! between processing messages and the timer. Once the timer completes and ends the first select!, I would run phase 2 by passing a mutable borrow of read to a different processing function.

It turns out I can't pass read to a function at all because of pinning.

use std::time::Duration;

use futures_util::{Stream, StreamExt};
use tokio_tungstenite::connect_async; 

async fn wait_2_seconds() {
    tokio::time::sleep(Duration::from_secs(2)).await;
}

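async fn process_messages(read: &mut impl Stream) {
    // Does not compile: StreamExt::next requires Self: Unpin, which a bare
    // `impl Stream` does not promise, and without an Item type `m.unwrap()` is unresolved.
    while let Some(m) = read.next().await {
        let data = m.unwrap().into_data();
        println!("{data:?}");
    }
}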

#[tokio::main]
async fn main() {
    let url = url::Url::parse("wss://127.0.0.1:12345").unwrap();

    let (ws_stream, _) = connect_async(url).await.expect("Failed to connect");

    // don't plan on sending anything to ws server so discard write half
    let (_, mut read) = ws_stream.split();

    tokio::select!{
        _ = process_messages(&mut read) => {}, 
        _ = wait_2_seconds() => {}, 
    };
    
    println!("phase 1 complete");
}

So I'm unsure how to pass read (or a mutable borrow of it) to a function.

The error message suggests using Box::pin, but then I realised I don't even know how to use Box::pin in this situation. I tried changing the process_messages parameter type to Box<Pin<&mut impl Stream>> and realised I needed help.
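
Some background on that error: StreamExt::next takes &mut self and requires Self: Unpin, and a generic impl Stream parameter makes no promise about Unpin, so the compiler suggests pinning. The std-only sketch below is a hedged illustration of the same trade-off using futures rather than tokio streams: a function that only receives &mut F needs an F: Unpin bound, while a function that receives Pin<&mut F> accepts any value the caller has already pinned. NoopWaker, poll_once_unpin and poll_once_pinned are hypothetical helpers written for this example.

```rust
use std::future::{self, Future};
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// A waker that does nothing: enough for polling futures that are ready at once.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

// Like StreamExt::next, this only gets `&mut F`, so it needs `F: Unpin`
// in order to build a `Pin<&mut F>` with the safe `Pin::new`.
fn poll_once_unpin<F: Future + Unpin>(fut: &mut F, cx: &mut Context<'_>) -> Poll<F::Output> {
    Pin::new(fut).poll(cx)
}

// Taking `Pin<&mut F>` moves the pinning job to the caller and works for any
// `F`, including `!Unpin` async blocks.
fn poll_once_pinned<F: Future>(fut: Pin<&mut F>, cx: &mut Context<'_>) -> Poll<F::Output> {
    fut.poll(cx)
}

fn main() {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);

    // `Ready<T>` is `Unpin`, so a plain `&mut` is enough.
    let mut ready = future::ready(1);
    println!("{:?}", poll_once_unpin(&mut ready, &mut cx));

    // An async block is `!Unpin`, so `poll_once_unpin(&mut async { 2 }, &mut cx)`
    // does not compile. Pinning it first and passing `Pin<&mut _>` does:
    let mut fut = pin!(async { 2 });
    println!("{:?}", poll_once_pinned(fut.as_mut(), &mut cx));
}
```

Running main prints Ready(1) and then Ready(2).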


Solution

  • Just pin the read in main(). You can either Box::pin() it or, better, tokio::pin!() it (or use futures::pin_mut!(), or std::pin::pin!(), which has been stable since Rust 1.68). You also need to specify the Item type of the stream. Then take Pin<&mut impl Stream<Item = ...>> in process_messages():

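    use std::pin::Pin;
    use std::time::Duration;
    
    use futures_util::{Stream, StreamExt};
    use tokio_tungstenite::connect_async;
    use tokio_tungstenite::tungstenite::error::Error;
    use tokio_tungstenite::tungstenite::protocol::Message;
    
    async fn wait_2_seconds() {
        tokio::time::sleep(Duration::from_secs(2)).await;
    }
    
    // Takes the stream already pinned by the caller, so no Unpin bound is needed.
    async fn process_messages(mut read: Pin<&mut impl Stream<Item = Result<Message, Error>>>) {
        while let Some(m) = read.next().await {
            let data = m.unwrap().into_data();
            println!("{data:?}");
        }
    }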
    
    #[tokio::main]
    async fn main() {
        let url = url::Url::parse("wss://127.0.0.1:12345").unwrap();
    
        let (ws_stream, _) = connect_async(url).await.expect("Failed to connect");
    
        // don't plan on sending anything to ws server so discard write half
        let (_, read) = ws_stream.split();
        // Alternatively `let mut read = Box::pin(read);`, then pass `read.as_mut()` in both phases.
        tokio::pin!(read);
    
        tokio::select! {
            _ = process_messages(read.as_mut()) => {},
            _ = wait_2_seconds() => {},
        };
    
        println!("phase 1 complete");
    
        // Process after 2 seconds.
        process_messages(read).await;
    }
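
The pattern in this answer (pin once in main(), reborrow with as_mut() for phase 1, then hand the pin itself to phase 2) does not depend on tokio. Below is a minimal std-only sketch, assuming a hypothetical !Unpin Messages type in place of the WebSocket read half and a fixed message count in place of the 2 s timer; process and run_two_phases are likewise illustrative names, not part of any library.

```rust
use std::marker::PhantomPinned;
use std::pin::{pin, Pin};

// Hypothetical stand-in for the WebSocket read half: a source of numbered
// messages that is `!Unpin` (thanks to PhantomPinned), like many async types.
struct Messages {
    next: u32,
    end: u32,
    _pinned: PhantomPinned,
}

impl Messages {
    fn new(end: u32) -> Self {
        Messages { next: 0, end, _pinned: PhantomPinned }
    }

    // Pull the next message through a pinned reference. `Pin::set` replaces the
    // value in place without moving it, so no `unsafe` is needed.
    fn pop(mut self: Pin<&mut Self>) -> Option<u32> {
        let (next, end) = (self.next, self.end);
        if next >= end {
            return None;
        }
        self.set(Messages { next: next + 1, end, _pinned: PhantomPinned });
        Some(next)
    }
}

// Mirrors process_messages: it takes `Pin<&mut _>`, so the caller keeps
// ownership and can hand out the same pinned value again later.
fn process(mut src: Pin<&mut Messages>, limit: usize, out: &mut Vec<u32>) {
    for _ in 0..limit {
        // `as_mut` reborrows the pin so `src` stays usable in the next iteration.
        match src.as_mut().pop() {
            Some(m) => out.push(m),
            None => break,
        }
    }
}

// Phase 1 handles at most `phase1_limit` messages (standing in for "until the
// 2 s timer fires"); phase 2 drains whatever is left.
fn run_two_phases(total: u32, phase1_limit: usize) -> (Vec<u32>, Vec<u32>) {
    // Pin once, like `tokio::pin!(read)` in main().
    let mut src = pin!(Messages::new(total));
    let (mut phase1, mut phase2) = (Vec::new(), Vec::new());
    process(src.as_mut(), phase1_limit, &mut phase1); // like `read.as_mut()`
    process(src, usize::MAX, &mut phase2); // the last use can move the pin itself
    (phase1, phase2)
}

fn main() {
    let (p1, p2) = run_two_phases(5, 2);
    println!("phase 1: {p1:?}, phase 2: {p2:?}");
}
```

Running main prints phase 1: [0, 1], phase 2: [2, 3, 4]. Phase 2 resumes exactly where phase 1 stopped because both phases operate on the same pinned value.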