
I have a Vec of futures that I want to execute concurrently (but not necessarily in parallel). Basically, I'm looking for a select function that is similar to tokio::select! but takes a collection of futures or, conversely, a function that is similar to futures::future::join_all but returns as soon as the first future is done.

An additional requirement is that once a future has finished, I might want to add a new future to the Vec.

With such a function, my code would roughly look like this:

use std::future::Future;
use std::time::Duration;
use tokio::time::sleep;

async fn wait(millis: u64) -> u64 {
    sleep(Duration::from_millis(millis)).await;
    millis
}

// Placeholder for whatever logic decides whether more work gets queued.
fn some_condition() -> bool {
    false
}

// This pseudo-implementation simply removes the last
// future and awaits it. I'm looking for something that
// instead polls all futures until one is finished, then
// removes that future from the Vec and returns it.
async fn select<F, O>(futures: &mut Vec<F>) -> O
where
    F: Future<Output=O>
{
    let future = futures.pop().unwrap();
    future.await
}

#[tokio::main]
async fn main() {
    let mut futures = vec![
        wait(500),
        wait(300),
        wait(100),
        wait(200),
    ];
    while !futures.is_empty() {
        let finished = select(&mut futures).await;
        println!("Waited {}ms", finished);
        if some_condition() {
            futures.push(wait(200));
        }
    }
}

Florian Brucker

2 Answers


This is exactly what futures::stream::FuturesUnordered is for (I found it by looking through the source of StreamExt::for_each_concurrent):

use futures::{stream::FuturesUnordered, StreamExt};
use std::time::Duration;
use tokio::time::{sleep, Instant};

async fn wait(millis: u64) -> u64 {
    sleep(Duration::from_millis(millis)).await;
    millis
}

#[tokio::main]
async fn main() {
    let mut futures = FuturesUnordered::new();
    futures.push(wait(500));
    futures.push(wait(300));
    futures.push(wait(100));
    futures.push(wait(200));
    
    let start_time = Instant::now();

    let mut num_added = 0;
    while let Some(wait_time) = futures.next().await {
        println!("Waited {}ms", wait_time);
        if num_added < 3 {
            num_added += 1;
            futures.push(wait(200));
        }
    }
    
    println!("Completed all work in {}ms", start_time.elapsed().as_millis());
}

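For intuition about the semantics asked for in the question (poll every future, return the first one that is ready, and remove it from the collection), here is a minimal std-only sketch. It is not how FuturesUnordered is implemented: it re-polls every future on each wake-up, and it assumes the futures are boxed and pinned. The names `VirtualSleep`, `CLOCK_MS`, `boxed_wait`, `block_on` and `run` are hypothetical helpers invented for this illustration; the virtual clock and the toy executor merely stand in for Tokio's timers and runtime so that the example needs no external crates.

```rust
use std::cell::Cell;
use std::future::{poll_fn, Future};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

type BoxFuture<O> = Pin<Box<dyn Future<Output = O>>>;

thread_local! {
    // Hypothetical virtual clock (in ms) standing in for real timers.
    static CLOCK_MS: Cell<u64> = Cell::new(0);
}

fn now_ms() -> u64 {
    CLOCK_MS.with(|c| c.get())
}

// Becomes ready once the virtual clock has reached `deadline_ms`.
struct VirtualSleep {
    deadline_ms: u64,
}

impl Future for VirtualSleep {
    type Output = ();
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
        if now_ms() >= self.deadline_ms {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

// Stand-in for the `wait` function used throughout this thread.
async fn wait(millis: u64) -> u64 {
    let deadline_ms = now_ms() + millis;
    VirtualSleep { deadline_ms }.await;
    millis
}

fn boxed_wait(millis: u64) -> BoxFuture<u64> {
    Box::pin(wait(millis))
}

// Polls all futures; the first one found ready is removed from the Vec
// and its output is returned. Panics if the Vec is empty.
async fn select<O>(futures: &mut Vec<BoxFuture<O>>) -> O {
    assert!(!futures.is_empty(), "select called on an empty Vec");
    poll_fn(|cx: &mut Context<'_>| {
        for i in 0..futures.len() {
            let polled = futures[i].as_mut().poll(cx);
            if let Poll::Ready(output) = polled {
                // Order of the remaining futures is not preserved.
                drop(futures.swap_remove(i));
                return Poll::Ready(output);
            }
        }
        Poll::Pending
    })
    .await
}

struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Toy executor: polls in a loop and advances the virtual clock by 100 ms
// whenever nothing is ready. Not suitable for real I/O or timers.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
            return output;
        }
        CLOCK_MS.with(|c| c.set(c.get() + 100));
    }
}

// Same loop shape as in the question; returns the completion order.
async fn run() -> Vec<u64> {
    let mut futures = vec![boxed_wait(500), boxed_wait(300), boxed_wait(100), boxed_wait(200)];
    let mut order = Vec::new();
    let mut num_added = 0;
    while !futures.is_empty() {
        order.push(select(&mut futures).await);
        if num_added < 3 {
            num_added += 1;
            futures.push(boxed_wait(200));
        }
    }
    order
}

fn main() {
    println!("Completion order: {:?}", block_on(run()));
}
```

With real timers, FuturesUnordered avoids the main cost of this sketch: it only polls the futures that were actually woken instead of walking the whole collection each time.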

A word of caution if you're using Tokio: as @Bryan Larsen has pointed out in a comment, there is a risk of performance problems when combining FuturesUnordered with Tokio. This article contains more details, and says that the issue should be fixed in recent versions of the futures crate (0.3.19 and later). Nevertheless, users of Tokio are better off using Tokio's JoinSet. The same example as above then looks like this:

use std::time::Duration;
use tokio::task::JoinSet;
use tokio::time::{sleep, Instant};

async fn wait(millis: u64) -> u64 {
    sleep(Duration::from_millis(millis)).await;
    millis
}

#[tokio::main]
async fn main() {
    let mut futures = JoinSet::new();
    futures.spawn(wait(500));
    futures.spawn(wait(300));
    futures.spawn(wait(100));
    futures.spawn(wait(200));

    let start_time = Instant::now();

    let mut num_added = 0;
    while let Some(result) = futures.join_next().await {
        let wait_time = result.unwrap();
        println!("Waited {}ms", wait_time);
        if num_added < 3 {
            num_added += 1;
            futures.spawn(wait(200));
        }
    }

    println!(
        "Completed all work in {}ms",
        start_time.elapsed().as_millis()
    );
}


Florian Brucker

Here's a working prototype based on streams and StreamExt::for_each_concurrent, as Martin Gallagher has suggested in a comment:

use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::sleep;

use futures::stream::{self, StreamExt};
use futures::{channel::mpsc, sink::SinkExt};

async fn wait(millis: u64) -> u64 {
    sleep(Duration::from_millis(millis)).await;
    millis
}

// Placeholder for whatever logic decides whether more work gets queued.
fn some_condition() -> bool {
    false
}

#[tokio::main]
async fn main() {
    let (mut sink, futures_stream) = mpsc::unbounded();

    let start_futures = vec![wait(500), wait(300), wait(100), wait(200)];

    let num_futures = RwLock::new(start_futures.len());

    sink.send_all(&mut stream::iter(start_futures.into_iter().map(Ok)))
        .await
        .unwrap();

    let sink_lock = RwLock::new(sink);

    futures_stream
        .for_each_concurrent(None, |fut| async {
            let wait_time = fut.await;
            println!("Waited {}", wait_time);
            if some_condition() {
                println!("Adding new future");
                let mut sink = sink_lock.write().await;
                sink.send(wait(100)).await.unwrap();
            } else {
                let mut num_futures = num_futures.write().await;
                *num_futures -= 1;
            }
            let num_futures = num_futures.read().await;
            if *num_futures == 0 {
                // Close the sink to exit the for_each_concurrent
                sink_lock.write().await.close().await.unwrap();
            }
        })
        .await;
}

While this approach works, it has the drawback that we need to maintain a separate counter of the remaining futures so that we know when to close the sink -- there's no Vec of futures that we could check for emptiness. Closing the sink also requires another lock.

Given that I'm fairly new to Rust I wouldn't be surprised if this approach could be made more elegant.
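The counter drawback is not specific to async code: a channel cannot know on its own that no more items will ever arrive, so something has to track the outstanding work and close the sending side. The following is a synchronous analogue using std::sync::mpsc instead of futures::channel::mpsc, purely to illustrate that bookkeeping; `drain_with_counter` and its parameters are hypothetical names, and plain u64 values stand in for the futures.

```rust
use std::sync::mpsc;

// Processes items from a channel while new items may be queued along the
// way. The receiving loop only ends once the sender is dropped, so the
// number of in-flight items is tracked by hand.
fn drain_with_counter(initial: Vec<u64>, max_added: usize) -> Vec<u64> {
    let (tx, rx) = mpsc::channel();
    let mut in_flight = initial.len();
    for item in initial {
        tx.send(item).unwrap();
    }

    // Wrapped in an Option so the sender can be dropped from inside the loop.
    let mut tx = Some(tx);
    let mut seen = Vec::new();
    let mut added = 0;

    for item in rx.iter() {
        seen.push(item);
        if added < max_added {
            // One item finished and one was queued: the count is unchanged.
            added += 1;
            tx.as_ref().unwrap().send(200).unwrap();
        } else {
            in_flight -= 1;
        }
        if in_flight == 0 {
            // Dropping the last sender closes the channel and ends the loop.
            tx = None;
        }
    }
    seen
}

fn main() {
    let seen = drain_with_counter(vec![500, 300, 100, 200], 3);
    println!("Processed {} items: {:?}", seen.len(), seen);
}
```

Unlike the concurrent version above, this processes items strictly in queue order; the point is only the in-flight counter and the explicit close, which FuturesUnordered and JoinSet make unnecessary because their `next()` / `join_next()` return None once the set is empty.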

Florian Brucker