std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
}
+ /// Tries to poll an `AbortOnDropHandle` without blocking or yielding.
+ ///
+ /// Note that on success the handle will panic on subsequent polls
+ /// since it becomes consumed.
+ fn try_poll_handle(jh: &mut AbortOnDropHandle<T>) -> Option<Result<T, JoinError>> {
+ let waker = futures_util::task::noop_waker();
+ let mut cx = Context::from_waker(&waker);
+
+ // Since this function is not async and cannot be forced to yield, we should
+ // disable budgeting when we want to check for the `JoinHandle` readiness.
+ let jh = std::pin::pin!(tokio::task::coop::unconstrained(jh));
+ if let Poll::Ready(res) = jh.poll(&mut cx) {
+ Some(res)
+ } else {
+ None
+ }
+ }
+
+ /// Tries to join the next task in FIFO order if it has completed.
+ ///
+ /// Returns `None` if the queue is empty or if the next task is not yet ready.
+ pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
+ let jh = self.0.front_mut()?;
+ let res = Self::try_poll_handle(jh)?;
+ // Use `detach` to avoid calling `abort` on a task that has already completed.
+ // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
+ // we only need to drop the `JoinHandle` for cleanup.
+ drop(self.0.pop_front().unwrap().detach());
+ Some(res)
+ }
+
+ /// Tries to join the next task in FIFO order if it has completed and return its output,
+ /// along with its [task ID].
+ ///
+ /// Returns `None` if the queue is empty or if the next task is not yet ready.
+ ///
+ /// When this method returns an error, then the id of the task that failed can be accessed
+ /// using the [`JoinError::id`] method.
+ ///
+ /// [task ID]: tokio::task::Id
+ /// [`JoinError::id`]: fn@tokio::task::JoinError::id
+ pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
+ let jh = self.0.front_mut()?;
+ let res = Self::try_poll_handle(jh)?;
+ // Use `detach` to avoid calling `abort` on a task that has already completed.
+ // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
+ // we only need to drop the `JoinHandle` for cleanup.
+ let jh = self.0.pop_front().unwrap().detach();
+ let id = jh.id();
+ drop(jh);
+ Some(res.map(|output| (id, output)))
+ }
+
/// Aborts all tasks and waits for them to finish shutting down.
///
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
let (send, recv) = tokio::sync::watch::channel(());
- let mut set = JoinQueue::new();
+ let mut queue = JoinQueue::new();
let mut spawned = Vec::with_capacity(TASK_NUM as usize);
for _ in 0..TASK_NUM {
let mut recv = recv.clone();
- let handle = set.spawn(async move { recv.changed().await.unwrap() });
+ let handle = queue.spawn(async move { recv.changed().await.unwrap() });
spawned.push(handle.id());
}
let mut count = 0;
let mut joined = Vec::with_capacity(TASK_NUM as usize);
- while let Some(res) = set.join_next_with_id().await {
+ while let Some(res) = queue.join_next_with_id().await {
match res {
Ok((id, ())) => {
count += 1;
assert_eq!(count, TASK_NUM);
assert_eq!(joined, spawned);
}
+
+#[tokio::test]
+async fn test_join_queue_try_join_next() {
+ let mut queue = JoinQueue::new();
+ let (tx1, rx1) = oneshot::channel::<()>();
+ queue.spawn(async {
+ let _ = rx1.await;
+ });
+ let (tx2, rx2) = oneshot::channel::<()>();
+ queue.spawn(async {
+ let _ = rx2.await;
+ });
+ let (tx3, rx3) = oneshot::channel::<()>();
+ queue.spawn(async {
+ let _ = rx3.await;
+ });
+
+ // This function also checks that calling `queue.try_join_next()` repeatedly when
+ // no task is ready is idempotent, i.e. that it does not change the queue state.
+ fn check_try_join_next_is_noop(queue: &mut JoinQueue<()>) {
+ let len = queue.len();
+ for _ in 0..5 {
+ assert!(queue.try_join_next().is_none());
+ assert_eq!(queue.len(), len);
+ }
+ }
+
+ assert_eq!(queue.len(), 3);
+ check_try_join_next_is_noop(&mut queue);
+
+ tx1.send(()).unwrap();
+ tokio::task::yield_now().await;
+
+ assert_eq!(queue.len(), 3);
+ assert!(queue.try_join_next().is_some());
+ assert_eq!(queue.len(), 2);
+ check_try_join_next_is_noop(&mut queue);
+
+ tx3.send(()).unwrap();
+ tokio::task::yield_now().await;
+
+ assert_eq!(queue.len(), 2);
+ check_try_join_next_is_noop(&mut queue);
+
+ tx2.send(()).unwrap();
+ tokio::task::yield_now().await;
+
+ assert_eq!(queue.len(), 2);
+ assert!(queue.try_join_next().is_some());
+ assert_eq!(queue.len(), 1);
+ assert!(queue.try_join_next().is_some());
+ assert!(queue.is_empty());
+ check_try_join_next_is_noop(&mut queue);
+}
+
+#[tokio::test]
+async fn test_join_queue_try_join_next_disabled_coop() {
+ // This number is large enough to trigger coop. Without using `tokio::task::coop::unconstrained`
+ // inside `try_join_next` this test fails on `assert!(coop_count == 0)`.
+ const TASK_NUM: u32 = 1000;
+
+ let sem: std::sync::Arc<tokio::sync::Semaphore> =
+ std::sync::Arc::new(tokio::sync::Semaphore::new(0));
+
+ let mut queue = JoinQueue::new();
+
+ for _ in 0..TASK_NUM {
+ let sem = sem.clone();
+ queue.spawn(async move {
+ sem.add_permits(1);
+ });
+ }
+
+ let _ = sem.acquire_many(TASK_NUM).await.unwrap();
+
+ let mut count = 0;
+ let mut coop_count = 0;
+ while !queue.is_empty() {
+ match queue.try_join_next() {
+ Some(Ok(())) => count += 1,
+ Some(Err(err)) => panic!("failed: {err}"),
+ None => {
+ coop_count += 1;
+ tokio::task::yield_now().await;
+ }
+ }
+ }
+ assert_eq!(coop_count, 0);
+ assert_eq!(count, TASK_NUM);
+}
+
+#[tokio::test]
+async fn test_join_queue_try_join_next_with_id_disabled_coop() {
+ // Note that this number is large enough to trigger coop as in
+ // `test_join_queue_try_join_next_coop` test. Without using
+ // `tokio::task::coop::unconstrained` inside `try_join_next_with_id`
+ // this test fails on `assert_eq!(count, TASK_NUM)`.
+ const TASK_NUM: u32 = 1000;
+
+ let (send, recv) = tokio::sync::watch::channel(());
+
+ let mut queue = JoinQueue::new();
+ let mut spawned = Vec::with_capacity(TASK_NUM as usize);
+
+ for _ in 0..TASK_NUM {
+ let mut recv = recv.clone();
+ let handle = queue.spawn(async move { recv.changed().await.unwrap() });
+
+ spawned.push(handle.id());
+ }
+ drop(recv);
+
+ assert!(queue.try_join_next_with_id().is_none());
+
+ send.send_replace(());
+ send.closed().await;
+
+ let mut count = 0;
+ let mut coop_count = 0;
+ let mut joined = Vec::with_capacity(TASK_NUM as usize);
+ while !queue.is_empty() {
+ match queue.try_join_next_with_id() {
+ Some(Ok((id, ()))) => {
+ count += 1;
+ joined.push(id);
+ }
+ Some(Err(err)) => panic!("failed: {err}"),
+ None => {
+ coop_count += 1;
+ tokio::task::yield_now().await;
+ }
+ }
+ }
+
+ assert_eq!(coop_count, 0);
+ assert_eq!(count, TASK_NUM);
+ assert_eq!(joined, spawned);
+}