]> git.feebdaed.xyz Git - 0xmirror/tokio.git/commitdiff
task: add `try_join_next` and `try_join_next_with_id` on `JoinQueue` (#7636)
authorNikolai Kuklin <nickkuklin@gmail.com>
Thu, 25 Sep 2025 12:31:31 +0000 (14:31 +0200)
committerGitHub <noreply@github.com>
Thu, 25 Sep 2025 12:31:31 +0000 (20:31 +0800)
tokio-util/src/task/join_queue.rs
tokio-util/tests/task_join_queue.rs

index 744b9cfb8b5692dad89a97c0a6663ea12204250a..9baf78d2e5a4fef3433de1aebd5812471c9c451a 100644 (file)
@@ -183,6 +183,59 @@ impl<T> JoinQueue<T> {
         std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
     }
 
+    /// Tries to poll an `AbortOnDropHandle` without blocking or yielding.
+    ///
+    /// Note that on success the handle will panic on subsequent polls
+    /// since it becomes consumed.
+    fn try_poll_handle(jh: &mut AbortOnDropHandle<T>) -> Option<Result<T, JoinError>> {
+        let waker = futures_util::task::noop_waker();
+        let mut cx = Context::from_waker(&waker);
+
+        // Since this function is not async and cannot be forced to yield, we should
+        // disable budgeting when we want to check for the `JoinHandle` readiness.
+        let jh = std::pin::pin!(tokio::task::coop::unconstrained(jh));
+        if let Poll::Ready(res) = jh.poll(&mut cx) {
+            Some(res)
+        } else {
+            None
+        }
+    }
+
+    /// Tries to join the next task in FIFO order if it has completed.
+    ///
+    /// Returns `None` if the queue is empty or if the next task is not yet ready.
+    pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
+        let jh = self.0.front_mut()?;
+        let res = Self::try_poll_handle(jh)?;
+        // Use `detach` to avoid calling `abort` on a task that has already completed.
+        // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
+        // we only need to drop the `JoinHandle` for cleanup.
+        drop(self.0.pop_front().unwrap().detach());
+        Some(res)
+    }
+
+    /// Tries to join the next task in FIFO order if it has completed and return its output,
+    /// along with its [task ID].
+    ///
+    /// Returns `None` if the queue is empty or if the next task is not yet ready.
+    ///
+    /// When this method returns an error, then the id of the task that failed can be accessed
+    /// using the [`JoinError::id`] method.
+    ///
+    /// [task ID]: tokio::task::Id
+    /// [`JoinError::id`]: fn@tokio::task::JoinError::id
+    pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
+        let jh = self.0.front_mut()?;
+        let res = Self::try_poll_handle(jh)?;
+        // Use `detach` to avoid calling `abort` on a task that has already completed.
+        // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
+        // we only need to drop the `JoinHandle` for cleanup.
+        let jh = self.0.pop_front().unwrap().detach();
+        let id = jh.id();
+        drop(jh);
+        Some(res.map(|output| (id, output)))
+    }
+
     /// Aborts all tasks and waits for them to finish shutting down.
     ///
     /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
index 6b23aa2fd49ce510a23504d291c75c33a36bedb1..cf29260bbff88e9e0eee3f1f0eac411284c9ffd2 100644 (file)
@@ -192,12 +192,12 @@ async fn test_join_queue_join_next_with_id() {
 
     let (send, recv) = tokio::sync::watch::channel(());
 
-    let mut set = JoinQueue::new();
+    let mut queue = JoinQueue::new();
     let mut spawned = Vec::with_capacity(TASK_NUM as usize);
 
     for _ in 0..TASK_NUM {
         let mut recv = recv.clone();
-        let handle = set.spawn(async move { recv.changed().await.unwrap() });
+        let handle = queue.spawn(async move { recv.changed().await.unwrap() });
 
         spawned.push(handle.id());
     }
@@ -208,7 +208,7 @@ async fn test_join_queue_join_next_with_id() {
 
     let mut count = 0;
     let mut joined = Vec::with_capacity(TASK_NUM as usize);
-    while let Some(res) = set.join_next_with_id().await {
+    while let Some(res) = queue.join_next_with_id().await {
         match res {
             Ok((id, ())) => {
                 count += 1;
@@ -221,3 +221,141 @@ async fn test_join_queue_join_next_with_id() {
     assert_eq!(count, TASK_NUM);
     assert_eq!(joined, spawned);
 }
+
+#[tokio::test]
+async fn test_join_queue_try_join_next() {
+    let mut queue = JoinQueue::new();
+    let (tx1, rx1) = oneshot::channel::<()>();
+    queue.spawn(async {
+        let _ = rx1.await;
+    });
+    let (tx2, rx2) = oneshot::channel::<()>();
+    queue.spawn(async {
+        let _ = rx2.await;
+    });
+    let (tx3, rx3) = oneshot::channel::<()>();
+    queue.spawn(async {
+        let _ = rx3.await;
+    });
+
+    // This function also checks that calling `queue.try_join_next()` repeatedly when
+    // no task is ready is idempotent, i.e. that it does not change the queue state.
+    fn check_try_join_next_is_noop(queue: &mut JoinQueue<()>) {
+        let len = queue.len();
+        for _ in 0..5 {
+            assert!(queue.try_join_next().is_none());
+            assert_eq!(queue.len(), len);
+        }
+    }
+
+    assert_eq!(queue.len(), 3);
+    check_try_join_next_is_noop(&mut queue);
+
+    tx1.send(()).unwrap();
+    tokio::task::yield_now().await;
+
+    assert_eq!(queue.len(), 3);
+    assert!(queue.try_join_next().is_some());
+    assert_eq!(queue.len(), 2);
+    check_try_join_next_is_noop(&mut queue);
+
+    tx3.send(()).unwrap();
+    tokio::task::yield_now().await;
+
+    assert_eq!(queue.len(), 2);
+    check_try_join_next_is_noop(&mut queue);
+
+    tx2.send(()).unwrap();
+    tokio::task::yield_now().await;
+
+    assert_eq!(queue.len(), 2);
+    assert!(queue.try_join_next().is_some());
+    assert_eq!(queue.len(), 1);
+    assert!(queue.try_join_next().is_some());
+    assert!(queue.is_empty());
+    check_try_join_next_is_noop(&mut queue);
+}
+
+#[tokio::test]
+async fn test_join_queue_try_join_next_disabled_coop() {
+    // This number is large enough to trigger coop. Without using `tokio::task::coop::unconstrained`
+    // inside `try_join_next` this test fails on `assert!(coop_count == 0)`.
+    const TASK_NUM: u32 = 1000;
+
+    let sem: std::sync::Arc<tokio::sync::Semaphore> =
+        std::sync::Arc::new(tokio::sync::Semaphore::new(0));
+
+    let mut queue = JoinQueue::new();
+
+    for _ in 0..TASK_NUM {
+        let sem = sem.clone();
+        queue.spawn(async move {
+            sem.add_permits(1);
+        });
+    }
+
+    let _ = sem.acquire_many(TASK_NUM).await.unwrap();
+
+    let mut count = 0;
+    let mut coop_count = 0;
+    while !queue.is_empty() {
+        match queue.try_join_next() {
+            Some(Ok(())) => count += 1,
+            Some(Err(err)) => panic!("failed: {err}"),
+            None => {
+                coop_count += 1;
+                tokio::task::yield_now().await;
+            }
+        }
+    }
+    assert_eq!(coop_count, 0);
+    assert_eq!(count, TASK_NUM);
+}
+
+#[tokio::test]
+async fn test_join_queue_try_join_next_with_id_disabled_coop() {
+    // Note that this number is large enough to trigger coop as in
+    // `test_join_queue_try_join_next_coop` test. Without using
+    // `tokio::task::coop::unconstrained` inside `try_join_next_with_id`
+    // this test fails on `assert_eq!(count, TASK_NUM)`.
+    const TASK_NUM: u32 = 1000;
+
+    let (send, recv) = tokio::sync::watch::channel(());
+
+    let mut queue = JoinQueue::new();
+    let mut spawned = Vec::with_capacity(TASK_NUM as usize);
+
+    for _ in 0..TASK_NUM {
+        let mut recv = recv.clone();
+        let handle = queue.spawn(async move { recv.changed().await.unwrap() });
+
+        spawned.push(handle.id());
+    }
+    drop(recv);
+
+    assert!(queue.try_join_next_with_id().is_none());
+
+    send.send_replace(());
+    send.closed().await;
+
+    let mut count = 0;
+    let mut coop_count = 0;
+    let mut joined = Vec::with_capacity(TASK_NUM as usize);
+    while !queue.is_empty() {
+        match queue.try_join_next_with_id() {
+            Some(Ok((id, ()))) => {
+                count += 1;
+                joined.push(id);
+            }
+            Some(Err(err)) => panic!("failed: {err}"),
+            None => {
+                coop_count += 1;
+                tokio::task::yield_now().await;
+            }
+        }
+    }
+
+    assert_eq!(coop_count, 0);
+    assert_eq!(count, TASK_NUM);
+    assert_eq!(joined, spawned);
+}