
How to cancel a long-running query when using rust-sqlx/tokio


I am migrating from rusqlite, where I used get_interrupt_handle to abort a query immediately from another thread (for example, when the user changed the filter parameters).

Here's an example of my current code. The best I can do is add an interrupt check before every await, but that doesn't help if the initial query takes ages to return its first result.

use std::error::Error;
use std::thread;

use futures::TryStreamExt;
use sqlx::{Connection, SqliteConnection};

struct Query {
    title: String,
}

fn start_async(requests: crossbeam::channel::Receiver<Query>) {
    thread::spawn(move || {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        if let Err(e) = runtime.block_on(run_db_thread(requests)) {
            eprintln!("database thread failed: {e}");
        }
    });
}

async fn run_db_thread(
    requests: crossbeam::channel::Receiver<Query>,
) -> Result<(), Box<dyn Error>> {
    let mut connection = SqliteConnection::connect("test.sqlite").await?;
    // Blocking receive: a new request is only noticed between queries.
    // The loop ends once every sender has been dropped.
    while let Ok(query) = requests.recv() {
        do_query(&mut connection, &query).await?;
    }
    Ok(())
}

async fn do_query(
    connection: &mut SqliteConnection,
    query: &Query,
) -> Result<(), Box<dyn Error>> {
    // Assumes all three columns are non-NULL text.
    let mut stream = sqlx::query_as::<_, (String, String, String)>(
        "SELECT title, authors, series FROM Books WHERE title LIKE ?",
    )
    .bind(format!("%{}%", query.title))
    .fetch(&mut *connection);
    while let Some((title, authors, series)) = stream.try_next().await? {
        println!("{} {} {}", title, authors, series);
    }
    Ok(())
}

Is there a way to interrupt a running sqlx execution when a new Query arrives in the channel? I'd be happy to send a signal separately if necessary.


Solution

  • All futures are inherently cancellable — this is one of the benefits (and hazards) of async over blocking multithreaded code. You simply drop the future instead of polling it further.

    The first thing you will need to do is change from a blocking channel to an async channel, so that checking the channel can be interleaved with running the query. Then you can use any of the future-combining tools to decide whether or not to keep running the query. I've decided to do this with tokio::select!, which polls several futures and runs the code for whichever one completes first.

    (There might be a better tool for this; I am familiar with how async works but haven't written a lot of complex async code.)

    use futures::future::OptionFuture;
    use std::time::Duration;
    use tokio::sync::mpsc;
    
    async fn run_db_thread(mut requests: mpsc::Receiver<Query>) {
        // This variable holds the current query being run, if there is one
        let mut current_query_future = OptionFuture::default();
    
        loop {
            tokio::select! {
                // If we receive a new query, replace the current query with it.
                Some(query) = requests.recv() => {
                    println!("Starting new query {query:?}");
                    current_query_future = OptionFuture::from(Some(Box::pin(async move {
                        let answer = do_query(&query).await;
                        println!("Finished query {query:?} => {answer:?}");
                        answer
                    })));
                    // Note that we did not `.await` the new future.
                }
    
                // Poll the current query future, and check if it is done yet.
                Some(_answer) = &mut current_query_future => {
                    // Stop polling the completed future.
                    current_query_future = None.into();
                }
    
                // We get here if both of the above branches saw None, which means that the
                // channel is closed, *and* there is no query to run.
                else => {
                    println!("Channel closed; run_db_thread() exiting");
                    break;
                }
            }
        }
    }
    
    /// Example to drive the loop.
    #[tokio::main]
    async fn main() {
        let (sender, receiver) = mpsc::channel(1);
        tokio::spawn(run_db_thread(receiver));
    
        for (i, delay) in [1000, 1000, 1, 1, 1, 1000, 1000].into_iter().enumerate() {
            sender.send(Query(i)).await.unwrap();
            tokio::time::sleep(Duration::from_millis(delay)).await;
        }
        println!("main() exiting");
    }
    
    // Skeleton data types to make the example compile
    #[derive(Debug)]
    #[allow(dead_code)]
    struct Query(usize);
    #[derive(Debug)]
    struct Answer;
    async fn do_query(_q: &Query) -> Answer {
        tokio::time::sleep(Duration::from_millis(100)).await;
        Answer
    }
    

    This example code prints:

    Starting new query Query(0)
    Finished query Query(0) => Answer
    Starting new query Query(1)
    Finished query Query(1) => Answer
    Starting new query Query(2)
    Starting new query Query(3)
    Starting new query Query(4)
    Starting new query Query(5)
    Finished query Query(5) => Answer
    Starting new query Query(6)
    Finished query Query(6) => Answer
    main() exiting
    

    That is, queries 0, 1, 5, and 6 completed, but queries 2, 3, and 4 were cancelled by a new query arriving before they could complete.
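
The claim at the start of this answer, that a future is cancelled simply by dropping it, can be checked without any async runtime. The following is a minimal sketch using only the standard library: `NeverReady`, `CleanupGuard`, `NoopWake`, and `poll_once_then_drop` are hypothetical names invented for this illustration and are not part of sqlx or tokio. It polls a stand-in "query" once by hand, drops it, and reports that the cleanup ran while the code after the `.await` did not.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Waker that does nothing; enough for polling a future by hand.
struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Hypothetical stand-in for a slow query: it never becomes ready.
struct NeverReady;
impl Future for NeverReady {
    type Output = ();
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
        Poll::Pending
    }
}

/// Sets a flag when dropped, standing in for whatever cleanup a real
/// query future performs when it is dropped.
struct CleanupGuard(Arc<AtomicBool>);
impl Drop for CleanupGuard {
    fn drop(&mut self) {
        self.0.store(true, Ordering::SeqCst);
    }
}

/// Polls a "query" once, then drops it.
/// Returns (was pending, cleanup ran, code after the await ran).
fn poll_once_then_drop() -> (bool, bool, bool) {
    let cleaned_up = Arc::new(AtomicBool::new(false));
    let finished = Arc::new(AtomicBool::new(false));

    let guard = CleanupGuard(cleaned_up.clone());
    let finished_flag = finished.clone();
    let mut query = Box::pin(async move {
        let _guard = guard; // dropped whenever the future is dropped
        NeverReady.await; // the "long-running query"
        finished_flag.store(true, Ordering::SeqCst); // never reached
    });

    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let was_pending = query.as_mut().poll(&mut cx).is_pending();

    // Cancellation: stop polling and drop the future.
    drop(query);

    (
        was_pending,
        cleaned_up.load(Ordering::SeqCst),
        finished.load(Ordering::SeqCst),
    )
}

fn main() {
    let (was_pending, cleaned_up, finished) = poll_once_then_drop();
    // prints "pending=true cleaned_up=true finished=false"
    println!("pending={was_pending} cleaned_up={cleaned_up} finished={finished}");
}
```

This only demonstrates the language-level mechanism that the select! loop above relies on when it overwrites current_query_future. What a database driver does with a statement that is already executing when its future is dropped is up to that driver, so it is worth checking the sqlx documentation for the backend in use.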