]> git.feebdaed.xyz Git - 0xmirror/tokio.git/commitdiff
sync: add `sync::Notify::notified_owned()` (#7465)
authorAria Andika <57490931+ariaandika@users.noreply.github.com>
Sat, 26 Jul 2025 13:45:34 +0000 (20:45 +0700)
committerGitHub <noreply@github.com>
Sat, 26 Jul 2025 13:45:34 +0000 (21:45 +0800)
tokio/src/sync/mod.rs
tokio/src/sync/notify.rs
tokio/tests/async_send_sync.rs
tokio/tests/sync_notify_owned.rs [new file with mode: 0644]

index a2502a76e785202c3ae7b6b7e7d7cb99ac61eea1..20eef09f164d087271ea5b9152e37309f029b369 100644 (file)
 cfg_sync! {
     /// Named future types.
     pub mod futures {
-        pub use super::notify::Notified;
+        pub use super::notify::{Notified, OwnedNotified};
     }
 
     mod barrier;
index dbdb9b1560941475f8ec7c7543316311c1de3676..d460797936dd27c1aaa6144561b754d6ec88990d 100644 (file)
@@ -17,6 +17,7 @@ use std::panic::{RefUnwindSafe, UnwindSafe};
 use std::pin::Pin;
 use std::ptr::NonNull;
 use std::sync::atomic::Ordering::{self, Acquire, Relaxed, Release, SeqCst};
+use std::sync::Arc;
 use std::task::{Context, Poll, Waker};
 
 type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
@@ -397,6 +398,38 @@ pub struct Notified<'a> {
 unsafe impl<'a> Send for Notified<'a> {}
 unsafe impl<'a> Sync for Notified<'a> {}
 
+/// Future returned from [`Notify::notified_owned()`].
+///
+/// This future is fused, so once it has completed, any future calls to poll
+/// will immediately return `Poll::Ready`.
+#[derive(Debug)]
+#[must_use = "futures do nothing unless you `.await` or poll them"]
+pub struct OwnedNotified {
+    /// The `Notify` being received on.
+    notify: Arc<Notify>,
+
+    /// The current state of the receiving process.
+    state: State,
+
+    /// Number of calls to `notify_waiters` at the time of creation.
+    notify_waiters_calls: usize,
+
+    /// Entry in the waiter `LinkedList`.
+    waiter: Waiter,
+}
+
+unsafe impl Sync for OwnedNotified {}
+
+/// A custom `project` implementation is used in place of `pin-project-lite`
+/// as a custom drop for [`Notified`] and [`OwnedNotified`] implementation
+/// is needed.
+struct NotifiedProject<'a> {
+    notify: &'a Notify,
+    state: &'a mut State,
+    notify_waiters_calls: &'a usize,
+    waiter: &'a Waiter,
+}
+
 #[derive(Debug)]
 enum State {
     Init,
@@ -541,6 +574,53 @@ impl Notify {
         }
     }
 
+    /// Wait for a notification with an owned `Future`.
+    ///
+    /// Unlike [`Self::notified`] which returns a future tied to the `Notify`'s
+    /// lifetime, `notified_owned` creates a self-contained future that owns its
+    /// notification state, making it safe to move between threads.
+    ///
+    /// See [`Self::notified`] for more details.
+    ///
+    /// # Cancel safety
+    ///
+    /// This method uses a queue to fairly distribute notifications in the order
+    /// they were requested. Cancelling a call to `notified_owned` makes you lose your
+    /// place in the queue.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use std::sync::Arc;
+    /// use tokio::sync::Notify;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let notify = Arc::new(Notify::new());
+    ///
+    ///     for _ in 0..10 {
+    ///         let notified = notify.clone().notified_owned();
+    ///         tokio::spawn(async move {
+    ///             notified.await;
+    ///             println!("received notification");
+    ///         });
+    ///     }
+    ///
+    ///     println!("sending notification");
+    ///     notify.notify_waiters();
+    /// }
+    /// ```
+    pub fn notified_owned(self: Arc<Self>) -> OwnedNotified {
+        // we load the number of times notify_waiters
+        // was called and store that in the future.
+        let state = self.state.load(SeqCst);
+        OwnedNotified {
+            notify: self,
+            state: State::Init,
+            notify_waiters_calls: get_num_notify_waiters_calls(state),
+            waiter: Waiter::new(),
+        }
+    }
     /// Notifies the first waiting task.
     ///
     /// If a task is currently waiting, that task is notified. Otherwise, a
@@ -911,9 +991,62 @@ impl Notified<'_> {
         self.poll_notified(None).is_ready()
     }
 
+    fn project(self: Pin<&mut Self>) -> NotifiedProject<'_> {
+        unsafe {
+            // Safety: `notify`, `state` and `notify_waiters_calls` are `Unpin`.
+
+            is_unpin::<&Notify>();
+            is_unpin::<State>();
+            is_unpin::<usize>();
+
+            let me = self.get_unchecked_mut();
+            NotifiedProject {
+                notify: me.notify,
+                state: &mut me.state,
+                notify_waiters_calls: &me.notify_waiters_calls,
+                waiter: &me.waiter,
+            }
+        }
+    }
+
+    fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> {
+        self.project().poll_notified(waker)
+    }
+}
+
+impl Future for Notified<'_> {
+    type Output = ();
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
+        self.poll_notified(Some(cx.waker()))
+    }
+}
+
+impl Drop for Notified<'_> {
+    fn drop(&mut self) {
+        // Safety: The type only transitions to a "Waiting" state when pinned.
+        unsafe { Pin::new_unchecked(self) }
+            .project()
+            .drop_notified();
+    }
+}
+
+// ===== impl OwnedNotified =====
+
+impl OwnedNotified {
+    /// Adds this future to the list of futures that are ready to receive
+    /// wakeups from calls to [`notify_one`].
+    ///
+    /// See [`Notified::enable`] for more details.
+    ///
+    /// [`notify_one`]: Notify::notify_one()
+    pub fn enable(self: Pin<&mut Self>) -> bool {
+        self.poll_notified(None).is_ready()
+    }
+
     /// A custom `project` implementation is used in place of `pin-project-lite`
     /// as a custom drop implementation is needed.
-    fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &usize, &Waiter) {
+    fn project(self: Pin<&mut Self>) -> NotifiedProject<'_> {
         unsafe {
             // Safety: `notify`, `state` and `notify_waiters_calls` are `Unpin`.
 
@@ -922,17 +1055,47 @@ impl Notified<'_> {
             is_unpin::<usize>();
 
             let me = self.get_unchecked_mut();
-            (
-                me.notify,
-                &mut me.state,
-                &me.notify_waiters_calls,
-                &me.waiter,
-            )
+            NotifiedProject {
+                notify: &me.notify,
+                state: &mut me.state,
+                notify_waiters_calls: &me.notify_waiters_calls,
+                waiter: &me.waiter,
+            }
         }
     }
 
     fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> {
-        let (notify, state, notify_waiters_calls, waiter) = self.project();
+        self.project().poll_notified(waker)
+    }
+}
+
+impl Future for OwnedNotified {
+    type Output = ();
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
+        self.poll_notified(Some(cx.waker()))
+    }
+}
+
+impl Drop for OwnedNotified {
+    fn drop(&mut self) {
+        // Safety: The type only transitions to a "Waiting" state when pinned.
+        unsafe { Pin::new_unchecked(self) }
+            .project()
+            .drop_notified();
+    }
+}
+
+// ===== impl NotifiedProject =====
+
+impl NotifiedProject<'_> {
+    fn poll_notified(self, waker: Option<&Waker>) -> Poll<()> {
+        let NotifiedProject {
+            notify,
+            state,
+            notify_waiters_calls,
+            waiter,
+        } = self;
 
         'outer_loop: loop {
             match *state {
@@ -1143,20 +1306,14 @@ impl Notified<'_> {
             }
         }
     }
-}
-
-impl Future for Notified<'_> {
-    type Output = ();
 
-    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
-        self.poll_notified(Some(cx.waker()))
-    }
-}
-
-impl Drop for Notified<'_> {
-    fn drop(&mut self) {
-        // Safety: The type only transitions to a "Waiting" state when pinned.
-        let (notify, state, _, waiter) = unsafe { Pin::new_unchecked(self).project() };
+    fn drop_notified(self) {
+        let NotifiedProject {
+            notify,
+            state,
+            waiter,
+            ..
+        } = self;
 
         // This is where we ensure safety. The `Notified` value is being
         // dropped, which means we must ensure that the waiter entry is no
index 7d4bb9bcd01d7184e5b596d92f6785ff5d5e1b1b..19cc6aed71140247c2cb18b230355793b17926ad 100644 (file)
@@ -401,6 +401,7 @@ assert_value!(tokio::sync::broadcast::WeakSender<NN>: !Send & !Sync & Unpin);
 assert_value!(tokio::sync::broadcast::WeakSender<YN>: Send & Sync & Unpin);
 assert_value!(tokio::sync::broadcast::WeakSender<YY>: Send & Sync & Unpin);
 assert_value!(tokio::sync::futures::Notified<'_>: Send & Sync & !Unpin);
+assert_value!(tokio::sync::futures::OwnedNotified: Send & Sync & !Unpin);
 assert_value!(tokio::sync::mpsc::OwnedPermit<NN>: !Send & !Sync & Unpin);
 assert_value!(tokio::sync::mpsc::OwnedPermit<YN>: Send & Sync & Unpin);
 assert_value!(tokio::sync::mpsc::OwnedPermit<YY>: Send & Sync & Unpin);
diff --git a/tokio/tests/sync_notify_owned.rs b/tokio/tests/sync_notify_owned.rs
new file mode 100644 (file)
index 0000000..06a0f6a
--- /dev/null
@@ -0,0 +1,304 @@
+#![warn(rust_2018_idioms)]
+#![cfg(feature = "sync")]
+
+#[cfg(all(target_family = "wasm", not(target_os = "wasi")))]
+use wasm_bindgen_test::wasm_bindgen_test as test;
+
+use std::sync::Arc;
+use tokio::sync::Notify;
+use tokio_test::task::spawn;
+use tokio_test::*;
+
+#[allow(unused)]
+trait AssertSend: Send + Sync {}
+impl AssertSend for Notify {}
+
+#[test]
+fn notify_notified_one() {
+    let notify = Arc::new(Notify::new());
+    let mut notified = spawn(async { notify.clone().notified_owned().await });
+
+    notify.notify_one();
+    assert_ready!(notified.poll());
+}
+
+#[test]
+fn notify_multi_notified_one() {
+    let notify = Arc::new(Notify::new());
+    let mut notified1 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified2 = spawn(async { notify.clone().notified_owned().await });
+
+    // add two waiters into the queue
+    assert_pending!(notified1.poll());
+    assert_pending!(notified2.poll());
+
+    // should wakeup the first one
+    notify.notify_one();
+    assert_ready!(notified1.poll());
+    assert_pending!(notified2.poll());
+}
+
+#[test]
+fn notify_multi_notified_last() {
+    let notify = Arc::new(Notify::new());
+    let mut notified1 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified2 = spawn(async { notify.clone().notified_owned().await });
+
+    // add two waiters into the queue
+    assert_pending!(notified1.poll());
+    assert_pending!(notified2.poll());
+
+    // should wakeup the last one
+    notify.notify_last();
+    assert_pending!(notified1.poll());
+    assert_ready!(notified2.poll());
+}
+
+#[test]
+fn notified_one_notify() {
+    let notify = Arc::new(Notify::new());
+    let mut notified = spawn(async { notify.clone().notified_owned().await });
+
+    assert_pending!(notified.poll());
+
+    notify.notify_one();
+    assert!(notified.is_woken());
+    assert_ready!(notified.poll());
+}
+
+#[test]
+fn notified_multi_notify() {
+    let notify = Arc::new(Notify::new());
+    let mut notified1 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified2 = spawn(async { notify.clone().notified_owned().await });
+
+    assert_pending!(notified1.poll());
+    assert_pending!(notified2.poll());
+
+    notify.notify_one();
+    assert!(notified1.is_woken());
+    assert!(!notified2.is_woken());
+
+    assert_ready!(notified1.poll());
+    assert_pending!(notified2.poll());
+}
+
+#[test]
+fn notify_notified_multi() {
+    let notify = Arc::new(Notify::new());
+
+    notify.notify_one();
+
+    let mut notified1 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified2 = spawn(async { notify.clone().notified_owned().await });
+
+    assert_ready!(notified1.poll());
+    assert_pending!(notified2.poll());
+
+    notify.notify_one();
+
+    assert!(notified2.is_woken());
+    assert_ready!(notified2.poll());
+}
+
+#[test]
+fn notified_drop_notified_notify() {
+    let notify = Arc::new(Notify::new());
+    let mut notified1 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified2 = spawn(async { notify.clone().notified_owned().await });
+
+    assert_pending!(notified1.poll());
+
+    drop(notified1);
+
+    assert_pending!(notified2.poll());
+
+    notify.notify_one();
+    assert!(notified2.is_woken());
+    assert_ready!(notified2.poll());
+}
+
+#[test]
+fn notified_multi_notify_drop_one() {
+    let notify = Arc::new(Notify::new());
+    let mut notified1 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified2 = spawn(async { notify.clone().notified_owned().await });
+
+    assert_pending!(notified1.poll());
+    assert_pending!(notified2.poll());
+
+    notify.notify_one();
+
+    assert!(notified1.is_woken());
+    assert!(!notified2.is_woken());
+
+    drop(notified1);
+
+    assert!(notified2.is_woken());
+    assert_ready!(notified2.poll());
+}
+
+#[test]
+fn notified_multi_notify_one_drop() {
+    let notify = Arc::new(Notify::new());
+    let mut notified1 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified2 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified3 = spawn(async { notify.clone().notified_owned().await });
+
+    // add waiters by order of poll execution
+    assert_pending!(notified1.poll());
+    assert_pending!(notified2.poll());
+    assert_pending!(notified3.poll());
+
+    // by default fifo
+    notify.notify_one();
+
+    drop(notified1);
+
+    // next waiter should be the one to be to woken up
+    assert_ready!(notified2.poll());
+    assert_pending!(notified3.poll());
+}
+
+#[test]
+fn notified_multi_notify_last_drop() {
+    let notify = Arc::new(Notify::new());
+    let mut notified1 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified2 = spawn(async { notify.clone().notified_owned().await });
+    let mut notified3 = spawn(async { notify.clone().notified_owned().await });
+
+    // add waiters by order of poll execution
+    assert_pending!(notified1.poll());
+    assert_pending!(notified2.poll());
+    assert_pending!(notified3.poll());
+
+    notify.notify_last();
+
+    drop(notified3);
+
+    // latest waiter added should be the one to woken up
+    assert_ready!(notified2.poll());
+    assert_pending!(notified1.poll());
+}
+
+#[test]
+fn notify_in_drop_after_wake() {
+    use futures::task::ArcWake;
+    use std::future::Future;
+    use std::sync::Arc;
+
+    let notify = Arc::new(Notify::new());
+
+    struct NotifyOnDrop(Arc<Notify>);
+
+    impl ArcWake for NotifyOnDrop {
+        fn wake_by_ref(_arc_self: &Arc<Self>) {}
+    }
+
+    impl Drop for NotifyOnDrop {
+        fn drop(&mut self) {
+            self.0.notify_waiters();
+        }
+    }
+
+    let mut fut = Box::pin(async {
+        notify.clone().notified_owned().await;
+    });
+
+    {
+        let waker = futures::task::waker(Arc::new(NotifyOnDrop(notify.clone())));
+        let mut cx = std::task::Context::from_waker(&waker);
+        assert!(fut.as_mut().poll(&mut cx).is_pending());
+    }
+
+    // Now, notifying **should not** deadlock
+    notify.notify_waiters();
+}
+
+#[test]
+fn notify_one_after_dropped_all() {
+    let notify = Arc::new(Notify::new());
+    let mut notified1 = spawn(async { notify.clone().notified_owned().await });
+
+    assert_pending!(notified1.poll());
+
+    notify.notify_waiters();
+    notify.notify_one();
+
+    drop(notified1);
+
+    let mut notified2 = spawn(async { notify.clone().notified_owned().await });
+
+    assert_ready!(notified2.poll());
+}
+
+#[test]
+fn test_notify_one_not_enabled() {
+    let notify = Arc::new(Notify::new());
+    let mut future = spawn(notify.clone().notified_owned());
+
+    notify.notify_one();
+    assert_ready!(future.poll());
+}
+
+#[test]
+fn test_notify_one_after_enable() {
+    let notify = Arc::new(Notify::new());
+    let mut future = spawn(notify.clone().notified_owned());
+
+    future.enter(|_, fut| assert!(!fut.enable()));
+
+    notify.notify_one();
+    assert_ready!(future.poll());
+    future.enter(|_, fut| assert!(fut.enable()));
+}
+
+#[test]
+fn test_poll_after_enable() {
+    let notify = Arc::new(Notify::new());
+    let mut future = spawn(notify.clone().notified_owned());
+
+    future.enter(|_, fut| assert!(!fut.enable()));
+    assert_pending!(future.poll());
+}
+
+#[test]
+fn test_enable_after_poll() {
+    let notify = Arc::new(Notify::new());
+    let mut future = spawn(notify.clone().notified_owned());
+
+    assert_pending!(future.poll());
+    future.enter(|_, fut| assert!(!fut.enable()));
+}
+
+#[test]
+fn test_enable_consumes_permit() {
+    let notify = Arc::new(Notify::new());
+
+    // Add a permit.
+    notify.notify_one();
+
+    let mut future1 = spawn(notify.clone().notified_owned());
+    future1.enter(|_, fut| assert!(fut.enable()));
+
+    let mut future2 = spawn(notify.clone().notified_owned());
+    future2.enter(|_, fut| assert!(!fut.enable()));
+}
+
+#[test]
+fn test_waker_update() {
+    use futures::task::noop_waker;
+    use std::future::Future;
+    use std::task::Context;
+
+    let notify = Arc::new(Notify::new());
+    let mut future = spawn(notify.clone().notified_owned());
+
+    let noop = noop_waker();
+    future.enter(|_, fut| assert_pending!(fut.poll(&mut Context::from_waker(&noop))));
+
+    assert_pending!(future.poll());
+    notify.notify_one();
+
+    assert!(future.is_woken());
+}