]> git.feebdaed.xyz Git - 0xmirror/tokio.git/commitdiff
coop: add `cooperative` and `poll_proceed` (#7405)
authorPepijn Van Eeckhoudt <pepijn@vaneeckhoudt.net>
Fri, 11 Jul 2025 08:07:44 +0000 (10:07 +0200)
committerGitHub <noreply@github.com>
Fri, 11 Jul 2025 08:07:44 +0000 (08:07 +0000)
tokio/src/task/coop/mod.rs
tokio/tests/async_send_sync.rs

index 58f0c848d8259cb2527f1fba3f74ca006bd2fda9..6c4588ccf926818d6f9172713a5970f44bf8e0c0 100644 (file)
@@ -250,14 +250,27 @@ cfg_coop! {
     use pin_project_lite::pin_project;
     use std::cell::Cell;
     use std::future::Future;
+    use std::marker::PhantomData;
     use std::pin::Pin;
     use std::task::{ready, Context, Poll};
 
+    /// Value returned by the [`poll_proceed`] method.
+    #[derive(Debug)]
     #[must_use]
-    pub(crate) struct RestoreOnPending(Cell<Budget>);
+    pub struct RestoreOnPending(Cell<Budget>, PhantomData<*mut ()>);
 
     impl RestoreOnPending {
-        pub(crate) fn made_progress(&self) {
+        fn new(budget: Budget) -> Self {
+            RestoreOnPending(
+                Cell::new(budget),
+                PhantomData,
+            )
+        }
+
+        /// Signals that the task that obtained this `RestoreOnPending` was able to make
+        /// progress. This prevents the task budget from being restored to the value
+        /// it had prior to obtaining this instance when it is dropped.
+        pub fn made_progress(&self) {
             self.0.set(Budget::unconstrained());
         }
     }
@@ -275,27 +288,102 @@ cfg_coop! {
         }
     }
 
-    /// Returns `Poll::Pending` if the current task has exceeded its budget and should yield.
+    /// Decrements the task budget and returns [`Poll::Pending`] if the budget is depleted.
+    /// This indicates that the task should yield to the scheduler. Otherwise, returns
+    /// [`RestoreOnPending`] which can be used to commit the budget consumption.
     ///
-    /// When you call this method, the current budget is decremented. However, to ensure that
-    /// progress is made every time a task is polled, the budget is automatically restored to its
-    /// former value if the returned `RestoreOnPending` is dropped. It is the caller's
-    /// responsibility to call `RestoreOnPending::made_progress` if it made progress, to ensure
-    /// that the budget empties appropriately.
+    /// The returned [`RestoreOnPending`] will revert the budget to its former
+    /// value when dropped unless [`RestoreOnPending::made_progress`]
+    /// is called. It is the caller's responsibility to do so when it _was_ able to
+    /// make progress after the call to [`poll_proceed`].
+    /// Restoring the budget automatically ensures the task can try to make progress in some other
+    /// way.
     ///
-    /// Note that `RestoreOnPending` restores the budget **as it was before `poll_proceed`**.
-    /// Therefore, if the budget is _further_ adjusted between when `poll_proceed` returns and
-    /// `RestRestoreOnPending` is dropped, those adjustments are erased unless the caller indicates
+    /// Note that [`RestoreOnPending`] restores the budget **as it was before [`poll_proceed`]**.
+    /// Therefore, if the budget is _further_ adjusted between when [`poll_proceed`] returns and
+    /// [`RestoreOnPending`] is dropped, those adjustments are erased unless the caller indicates
     /// that progress was made.
+    ///
+    /// # Examples
+    ///
+    /// This example shows a simple countdown latch that uses [`poll_proceed`] to participate in
+    /// cooperative scheduling.
+    ///
+    /// ```
+    /// use std::future::{Future};
+    /// use std::pin::Pin;
+    /// use std::task::{ready, Context, Poll, Waker};
+    /// use tokio::task::coop;
+    ///
+    /// struct CountdownLatch<T> {
+    ///     counter: usize,
+    ///     value: Option<T>,
+    ///     waker: Option<Waker>
+    /// }
+    ///
+    /// impl<T> CountdownLatch<T> {
+    ///     fn new(value: T, count: usize) -> Self {
+    ///         CountdownLatch {
+    ///             counter: count,
+    ///             value: Some(value),
+    ///             waker: None
+    ///         }
+    ///     }
+    ///     fn count_down(&mut self) {
+    ///         if self.counter <= 0 {
+    ///             return;
+    ///         }
+    ///
+    ///         self.counter -= 1;
+    ///         if self.counter == 0 {
+    ///             if let Some(w) = self.waker.take() {
+    ///                 w.wake();
+    ///             }
+    ///         }
+    ///     }
+    /// }
+    ///
+    /// impl<T> Future for CountdownLatch<T> {
+    ///     type Output = T;
+    ///
+    ///     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+    ///         // `poll_proceed` checks with the runtime if this task is still allowed to proceed
+    ///         // with performing work.
+    ///         // If not, `Pending` is returned and `ready!` ensures this function returns.
+    ///         // If we are allowed to proceed, coop now represents the budget consumption
+    ///         let coop = ready!(coop::poll_proceed(cx));
+    ///
+    ///         // Get a mutable reference to the CountdownLatch
+    ///         let this = Pin::get_mut(self);
+    ///
+    ///         // Next we check if the latch is ready to release its value
+    ///         if this.counter == 0 {
+    ///             let t = this.value.take();
+    ///             // The latch made progress so call `made_progress` to ensure the budget
+    ///             // is not reverted.
+    ///             coop.made_progress();
+    ///             Poll::Ready(t.unwrap())
+    ///         } else {
+    ///             // If the latch is not ready so return pending and simply drop `coop`.
+    ///             // This will restore the budget making it available again to perform any
+    ///             // other work.
+    ///             this.waker = Some(cx.waker().clone());
+    ///             Poll::Pending
+    ///         }
+    ///     }
+    /// }
+    ///
+    /// impl<T> Unpin for CountdownLatch<T> {}
+    /// ```
     #[inline]
-    pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll<RestoreOnPending> {
+    pub fn poll_proceed(cx: &mut Context<'_>) -> Poll<RestoreOnPending> {
         context::budget(|cell| {
             let mut budget = cell.get();
 
             let decrement = budget.decrement();
 
             if decrement.success {
-                let restore = RestoreOnPending(Cell::new(cell.get()));
+                let restore = RestoreOnPending::new(cell.get());
                 cell.set(budget);
 
                 // avoid double counting
@@ -308,7 +396,7 @@ cfg_coop! {
                 register_waker(cx);
                 Poll::Pending
             }
-        }).unwrap_or(Poll::Ready(RestoreOnPending(Cell::new(Budget::unconstrained()))))
+        }).unwrap_or(Poll::Ready(RestoreOnPending::new(Budget::unconstrained())))
     }
 
     /// Returns `Poll::Ready` if the current task has budget to consume, and `Poll::Pending` otherwise.
@@ -380,15 +468,9 @@ cfg_coop! {
     }
 
     pin_project! {
-        /// Future wrapper to ensure cooperative scheduling.
-        ///
-        /// When being polled `poll_proceed` is called before the inner future is polled to check
-        /// if the inner future has exceeded its budget. If the inner future resolves, this will
-        /// automatically call `RestoreOnPending::made_progress` before resolving this future with
-        /// the result of the inner one. If polling the inner future is pending, polling this future
-        /// type will also return a `Poll::Pending`.
+        /// Future wrapper to ensure cooperative scheduling created by [`cooperative`].
         #[must_use = "futures do nothing unless polled"]
-        pub(crate) struct Coop<F: Future> {
+        pub struct Coop<F: Future> {
             #[pin]
             pub(crate) fut: F,
         }
@@ -409,11 +491,39 @@ cfg_coop! {
         }
     }
 
-    /// Run a future with a budget constraint for cooperative scheduling.
-    /// If the future exceeds its budget while being polled, control is yielded back to the
-    /// runtime.
+    /// Creates a wrapper future that makes the inner future cooperate with the Tokio scheduler.
+    ///
+    /// When polled, the wrapper will first call [`poll_proceed`] to consume task budget, and
+    /// immediately yield if the budget has been depleted. If budget was available, the inner future
+    /// is polled. The budget consumption will be made final using [`RestoreOnPending::made_progress`]
+    /// if the inner future resolves to its final value.
+    ///
+    /// # Examples
+    ///
+    /// When you call `recv` on the `Receiver` of a [`tokio::sync::mpsc`](crate::sync::mpsc)
+    /// channel, task budget will automatically be consumed when the next value is returned.
+    /// This makes tasks that use Tokio mpsc channels automatically cooperative.
+    ///
+    /// If you're using [`futures::channel::mpsc`](https://docs.rs/futures/latest/futures/channel/mpsc/index.html)
+    /// instead, automatic task budget consumption will not happen. This example shows how can use
+    /// `cooperative` to make `futures::channel::mpsc` channels cooperate with the scheduler in the
+    /// same way Tokio channels do.
+    ///
+    /// ```
+    /// use tokio::task::coop::cooperative;
+    /// use futures::channel::mpsc::Receiver;
+    /// use futures::stream::StreamExt;
+    ///
+    /// async fn receive_next<T>(receiver: &mut Receiver<T>) -> Option<T> {
+    ///     // Use `StreamExt::next` to obtain a `Future` that resolves to the next value
+    ///     let recv_future = receiver.next();
+    ///     // Wrap it a cooperative wrapper
+    ///     let coop_future = cooperative(recv_future);
+    ///     // And await
+    ///     coop_future.await
+    /// }
     #[inline]
-    pub(crate) fn cooperative<F: Future>(fut: F) -> Coop<F> {
+    pub fn cooperative<F: Future>(fut: F) -> Coop<F> {
         Coop { fut }
     }
 }
index aa668ce93ecef0817a1c47fa7ba834f4142b77e9..c9cedc38b02402cb92adcf93ff00cab60922999d 100644 (file)
@@ -454,6 +454,7 @@ assert_value!(tokio::task::JoinSet<NN>: !Send & !Sync & Unpin);
 assert_value!(tokio::task::JoinSet<YN>: Send & Sync & Unpin);
 assert_value!(tokio::task::JoinSet<YY>: Send & Sync & Unpin);
 assert_value!(tokio::task::LocalSet: !Send & !Sync & Unpin);
+assert_value!(tokio::task::coop::RestoreOnPending: !Send & !Sync & Unpin);
 async_assert_fn!(tokio::sync::Barrier::wait(_): Send & Sync & !Unpin);
 async_assert_fn!(tokio::sync::Mutex<NN>::lock(_): !Send & !Sync & !Unpin);
 async_assert_fn!(tokio::sync::Mutex<NN>::lock_owned(_): !Send & !Sync & !Unpin);