
tokio::select! but for a Vec of futures


I have a Vec of futures which I want to execute concurrently (but not necessarily in parallel). Basically, I'm looking for some kind of select function that is similar to tokio::select! but takes a collection of futures, or, alternatively, a function that is similar to futures::future::join_all but returns as soon as the first future is done.

An additional requirement is that once a future has finished, I might want to add a new future to the Vec.

With such a function, my code would roughly look like this:

use std::future::Future;
use std::time::Duration;
use tokio::time::sleep;

async fn wait(millis: u64) -> u64 {
    sleep(Duration::from_millis(millis)).await;
    millis
}

// This pseudo-implementation simply removes the last
// future and awaits it. I'm looking for something that
// instead polls all futures until one is finished, then
// removes that future from the Vec and returns it.
async fn select<F, O>(futures: &mut Vec<F>) -> O
where
    F: Future<Output=O>
{
    let future = futures.pop().unwrap();
    future.await
}

#[tokio::main]
async fn main() {
    let mut futures = vec![
        wait(500),
        wait(300),
        wait(100),
        wait(200),
    ];
    while !futures.is_empty() {
        let finished = select(&mut futures).await;
        println!("Waited {}ms", finished);
        if some_condition() {
            futures.push(wait(200));
        }
    }
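}

// Placeholder for whatever decides whether more work should be added
fn some_condition() -> bool {
    false
}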


Solution

  • This is exactly what futures::stream::FuturesUnordered is for (which I've found by looking through the source of StreamExt::for_each_concurrent):

    use futures::{stream::FuturesUnordered, StreamExt};
    use std::time::Duration;
    use tokio::time::{sleep, Instant};
    
    async fn wait(millis: u64) -> u64 {
        sleep(Duration::from_millis(millis)).await;
        millis
    }
    
    #[tokio::main]
    async fn main() {
        let mut futures = FuturesUnordered::new();
        futures.push(wait(500));
        futures.push(wait(300));
        futures.push(wait(100));
        futures.push(wait(200));
        
        let start_time = Instant::now();
    
        let mut num_added = 0;
        // Resolves with the output of whichever future finishes next, or None once the set is empty
        while let Some(wait_time) = futures.next().await {
            println!("Waited {}ms", wait_time);
            if num_added < 3 {
                num_added += 1;
                futures.push(wait(200));
            }
        }
        
        println!("Completed all work in {}ms", start_time.elapsed().as_millis());
    }
    


    A word of caution if you're using Tokio: as @Bryan Larsen has pointed out in a comment, combining FuturesUnordered with Tokio can cause performance problems. This article contains more details and says that the issue should be fixed in recent versions of the futures crate (0.3.19 and later). Nevertheless, Tokio users are better off using Tokio's JoinSet, which runs each future as its own spawned task. The same example as above then looks like this:

    use std::time::Duration;
    use tokio::task::JoinSet;
    use tokio::time::{sleep, Instant};
    
    async fn wait(millis: u64) -> u64 {
        sleep(Duration::from_millis(millis)).await;
        millis
    }
    
    #[tokio::main]
    async fn main() {
        let mut futures = JoinSet::new();
        futures.spawn(wait(500));
        futures.spawn(wait(300));
        futures.spawn(wait(100));
        futures.spawn(wait(200));
    
        let start_time = Instant::now();
    
        let mut num_added = 0;
        while let Some(result) = futures.join_next().await {
            // join_next yields Err(JoinError) if the task panicked or was aborted
            let wait_time = result.unwrap();
            println!("Waited {}ms", wait_time);
            if num_added < 3 {
                num_added += 1;
                futures.spawn(wait(200));
            }
        }
    
        println!(
            "Completed all work in {}ms",
            start_time.elapsed().as_millis()
        );
    }
    
