Search code examples
Tags: rust, rust-async-std

Rust: how do I perform basic recursive async?


I am doing some quick experiments to learn the Rust language. I have done a few successful async tests; this is my starting point:

use async_std::task;
use futures;
use std::time::SystemTime;

/// Sums the integers 1 through `sum` by splitting the range into `chunks`
/// pieces, summing each piece on its own spawned task, and printing the total
/// followed by the elapsed wall-clock time in milliseconds.
fn main() {
    let now = SystemTime::now();
    task::block_on(async {
        let sum: u64 = 100000000;
        // Must be non-zero; need not be a factor of `sum` (see below).
        let chunks: u64 = 5;
        let chunk_size: u64 = sum / chunks;
        let fs: Vec<_> = (1..=chunks)
            .map(|n| {
                let start = (n - 1) * chunk_size + 1;
                // The last chunk runs all the way to `sum` so the remainder of
                // `sum / chunks` is not silently dropped.
                let end = if n == chunks { sum + 1 } else { n * chunk_size + 1 };
                task::spawn(async move { add_range(start, end) })
            })
            .collect();
        let vals = futures::future::join_all(fs).await;
        // 5000000050000000 for this configuration of inputs
        println!("{}", vals.iter().sum::<u64>());
    });
    let elapsed = now
        .elapsed()
        .expect("system clock moved backwards during the run");
    println!("{}ms", elapsed.as_millis());
}

/// Sums every integer in the half-open range `start..end`.
///
/// Prints the bounds first so the chunking chosen by `main` is visible.
/// An empty range (including `end <= start`) sums to 0.
fn add_range(start: u64, end: u64) -> u64 {
    println!("{}, {}", start, end);
    (start..end).sum()
}

By changing the value of `chunks` you can change how many `task::spawn` calls there are. Now, rather than a flat set of workers, I want the `add_range` function to be recursive and to keep forking off workers based on its inputs. However, while following the compiler errors I have gotten myself quite tangled up:

use async_std::task;
use futures;
use std::future::Future;
use std::pin::Pin;

// Entry point of the failing attempt: runs the top-level `add_range` call to
// completion on the current thread with `task::block_on`.
fn main() {
    // NOTE(review): `add_range` is an `async fn` whose declared return type is
    // itself `Pin<Box<dyn Future<Output = u64>>>`, so `block_on` yields that boxed
    // future, not a `u64`. It would still need to be awaited/blocked on again,
    // and a boxed `dyn Future` does not implement `Display`, so this `println!`
    // cannot compile.
    let pin_box_u64 = task::block_on(add_range(0, 10, 10, 1, 1001));
    println!("{}", pin_box_u64/*how do i get u64 out of this*/)
}

// Failing attempt (as posted): meant to sum `start..end` by recursively splitting
// the range into `chunk_split` sub-ranges, each summed on its own spawned task,
// until a range holds no more than `chunk_size` numbers. `depth` is only printed.
//
// NOTE(review): this does not compile as written:
// - It is an `async fn` whose declared return type is *also* a boxed future, so
//   calling it yields a future whose output is another future, not a `u64`.
// - async_std's `task::spawn` requires the spawned future's output to be
//   `Send + 'static`; the output here, `Pin<Box<dyn Future<Output = u64>>>`,
//   has no `+ Send` bound.
// - The body spawns the future of this very `async fn`, which the compiler must
//   check for `Send` while still defining that type; expect a cycle/recursion
//   error (exact diagnostic not verified here).
async fn add_range(
    depth: u64,
    chunk_split: u64,
    chunk_size: u64,
    start: u64,
    end: u64,
) -> Pin<Box<dyn Future<Output = u64>>> {
    println!("{}, {}, {}", depth, start, end);
    // If the range from start to end is larger than the allowed
    // chunk_size, fork off more workers, dividing
    // the work up further.
    if end - start > chunk_size {
        let mut fs = Vec::new();
        // NOTE(review): integer division rounds down, so `e` below can never
        // exceed `end` and the `if e > end` clamp never fires. When
        // `end - start` is not a multiple of `chunk_split`, the tail of the
        // range is never summed.
        let next_chunk_size = (end - start) / chunk_split;
        for n in 0..chunk_split {
            let s = start + (next_chunk_size * n);
            let mut e = start + (next_chunk_size * (n + 1));
            if e > end {
                e = end;
            }
            // spawn more workers
            fs.push(task::spawn(add_range(depth + 1, chunk_split, chunk_size, s, e)));
        }
        return Box::pin(async move {
            // Join the workers back up and add their results.
            // Each joined value is the `async fn`'s output, i.e. another boxed
            // future, so it would still have to be awaited to obtain a `u64`.
            return futures::future::join_all(fs).await.iter().map(/*how do i get u64s out of here*/).sum::<u64>();
        });
    } else {
        // Otherwise the range is no larger than the allowed chunk_size,
        // so let's do the actual sum for this chunk.
        let mut total: u64 = 0;
        for n in start..end {
            total += n;
        }
        return Box::pin(async move { total });
    }
}

I have played around with this for a while, but I feel like I'm just getting more and more lost in the compiler errors.


Solution

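  • You need to box the returned future; otherwise the future type would contain itself and the compiler can't determine the size of the return type. To do that, make `add_range` a plain `fn` (not `async fn`) that returns `Pin<Box<dyn Future<Output = u64> + Send + 'static>>`. The `Send + 'static` bounds are required because `task::spawn` may move the future to another thread and run it after the caller returns.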

    Additional context can be found here: https://rust-lang.github.io/async-book/07_workarounds/04_recursion.html

    use std::pin::Pin;
    
    use async_std::task;
    use futures::Future;
    use futures::FutureExt;
    
    fn main() {
        // Drive the root of the recursion to completion on this thread: sum the
        // range 1..1001, splitting it 10 ways per level until each chunk holds
        // at most 10 numbers.
        let total: u64 = task::block_on(add_range(0, 10, 10, 1, 1001));
        println!("{}", total);
    }
    
    /// Sums the half-open range `start..end` by recursively splitting it into
    /// `chunk_split` roughly equal sub-ranges, each summed on its own spawned
    /// task, until a sub-range holds no more than `chunk_size` numbers.
    ///
    /// `depth` is the recursion level; it is only used in the debug print.
    ///
    /// This is a plain `fn` returning a pinned, boxed future rather than an
    /// `async fn`. Boxing gives the recursive future a known size, and the
    /// `Send + 'static` bounds are what `task::spawn` requires.
    ///
    /// Edge cases: an empty or reversed range (`end <= start`) sums to 0. A range
    /// that cannot be split further (`chunk_split < 2`, or a single element) is
    /// summed directly instead of recursing forever.
    fn add_range(
        depth: u64,
        chunk_split: u64,
        chunk_size: u64,
        start: u64,
        end: u64,
    ) -> Pin<Box<dyn Future<Output = u64> + Send + 'static>> {
        println!("{}, {}, {}", depth, start, end);
        // `saturating_sub` keeps a reversed range from underflowing; it is then
        // treated as empty.
        let len = end.saturating_sub(start);
        // Only fork when the range is too big AND splitting actually shrinks it.
        // With fewer than two parts, or a single element, a child would receive
        // the whole range again and the recursion would never terminate.
        if len > chunk_size && len > 1 && chunk_split > 1 {
            // Boundary `i` sits at `start + len * i / chunk_split`, so the pieces
            // cover the whole range and the last one ends exactly at `end`. No
            // remainder is lost when `len` is not a multiple of `chunk_split`.
            // The u128 intermediate avoids overflow in `len * i`.
            let boundary = |i: u64| {
                // The quotient never exceeds `len`, so the narrowing cast cannot truncate.
                start + (u128::from(len) * u128::from(i) / u128::from(chunk_split)) as u64
            };
            // Spawn one worker per non-empty sub-range (pieces can be empty when
            // `chunk_split` exceeds `len`).
            let fs: Vec<_> = (0..chunk_split)
                .map(|n| (boundary(n), boundary(n + 1)))
                .filter(|(s, e)| s < e)
                .map(|(s, e)| task::spawn(add_range(depth + 1, chunk_split, chunk_size, s, e)))
                .collect();
            // Join the workers back up and add their partial sums.
            futures::future::join_all(fs)
                .map(|v| v.iter().sum::<u64>())
                .boxed()
        } else {
            // Small enough (or unsplittable): sum this chunk directly.
            futures::future::ready((start..end).sum::<u64>()).boxed()
        }
    }