]> git.feebdaed.xyz Git - 0xmirror/tokio.git/commitdiff
task: remove raw-entry feature from hashbrown dep (#7252)
authorConrad Ludgate <oon@conradludgate.com>
Tue, 22 Jul 2025 13:52:54 +0000 (14:52 +0100)
committerGitHub <noreply@github.com>
Tue, 22 Jul 2025 13:52:54 +0000 (15:52 +0200)
tokio-util/Cargo.toml
tokio-util/src/task/join_map.rs
tokio-util/tests/task_join_map.rs

index 4187ebfb53e29d4e0b39f3d286b60dab0d4b40c8..084d123fa58d45172d4eacd034a09c1958f112ab 100644 (file)
@@ -45,7 +45,7 @@ slab = { version = "0.4.4", optional = true } # Backs `DelayQueue`
 tracing = { version = "0.1.29", default-features = false, features = ["std"], optional = true }
 
 [target.'cfg(tokio_unstable)'.dependencies]
-hashbrown = { version = "0.15.0", default-features = false, features = ["raw-entry"], optional = true }
+hashbrown = { version = "0.15.0", default-features = false, optional = true }
 
 [dev-dependencies]
 tokio = { version = "1.0.0", path = "../tokio", features = ["full"] }
index d5c5d2a42c2d6fce8cfa9e6fc7480d11f52c07c7..bf416562065fca5e37da20c4788fcccd6c124161 100644 (file)
@@ -1,5 +1,5 @@
-use hashbrown::hash_map::RawEntryMut;
-use hashbrown::HashMap;
+use hashbrown::hash_table::Entry;
+use hashbrown::{HashMap, HashTable};
 use std::borrow::Borrow;
 use std::collections::hash_map::RandomState;
 use std::fmt;
@@ -103,13 +103,8 @@ use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
 #[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
 pub struct JoinMap<K, V, S = RandomState> {
     /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
-    /// indexed by their keys and task IDs.
-    ///
-    /// The [`Key`] type contains both the task's `K`-typed key provided when
-    /// spawning tasks, and the task's IDs. The IDs are stored here to resolve
-    /// hash collisions when looking up tasks based on their pre-computed hash
-    /// (as stored in the `hashes_by_task` map).
-    tasks_by_key: HashMap<Key<K>, AbortHandle, S>,
+    /// indexed by their keys.
+    tasks_by_key: HashTable<(K, AbortHandle)>,
 
     /// A map from task IDs to the hash of the key associated with that task.
     ///
@@ -125,21 +120,6 @@ pub struct JoinMap<K, V, S = RandomState> {
     tasks: JoinSet<V>,
 }
 
-/// A [`JoinMap`] key.
-///
-/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
-/// a task ID, so that hash collisions between `K`-typed keys can be resolved
-/// using either `K`'s `Eq` impl *or* by checking the task IDs.
-///
-/// This allows looking up a task using either an actual key (such as when the
-/// user queries the map with a key), *or* using a task ID and a hash (such as
-/// when removing completed tasks from the map).
-#[derive(Debug)]
-struct Key<K> {
-    key: K,
-    id: Id,
-}
-
 impl<K, V> JoinMap<K, V> {
     /// Creates a new empty `JoinMap`.
     ///
@@ -176,7 +156,7 @@ impl<K, V> JoinMap<K, V> {
     }
 }
 
-impl<K, V, S: Clone> JoinMap<K, V, S> {
+impl<K, V, S> JoinMap<K, V, S> {
     /// Creates an empty `JoinMap` which will use the given hash builder to hash
     /// keys.
     ///
@@ -226,7 +206,7 @@ impl<K, V, S: Clone> JoinMap<K, V, S> {
     #[must_use]
     pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
         Self {
-            tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
+            tasks_by_key: HashTable::with_capacity(capacity),
             hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
             tasks: JoinSet::new(),
         }
@@ -415,33 +395,42 @@ where
         self.insert(key, task)
     }
 
-    fn insert(&mut self, key: K, abort: AbortHandle) {
-        let hash = self.hash(&key);
+    fn insert(&mut self, mut key: K, mut abort: AbortHandle) {
+        let hash_builder = self.hashes_by_task.hasher();
+        let hash = hash_one(hash_builder, &key);
         let id = abort.id();
-        let map_key = Key { id, key };
 
         // Insert the new key into the map of tasks by keys.
-        let entry = self
-            .tasks_by_key
-            .raw_entry_mut()
-            .from_hash(hash, |k| k.key == map_key.key);
+        let entry =
+            self.tasks_by_key
+                .entry(hash, |(k, _)| *k == key, |(k, _)| hash_one(hash_builder, k));
         match entry {
-            RawEntryMut::Occupied(mut occ) => {
+            Entry::Occupied(occ) => {
                 // There was a previous task spawned with the same key! Cancel
                 // that task, and remove its ID from the map of hashes by task IDs.
-                let Key { id: prev_id, .. } = occ.insert_key(map_key);
-                occ.insert(abort).abort();
-                let _prev_hash = self.hashes_by_task.remove(&prev_id);
+                (key, abort) = std::mem::replace(occ.into_mut(), (key, abort));
+
+                // Remove the old task ID.
+                let _prev_hash = self.hashes_by_task.remove(&abort.id());
                 debug_assert_eq!(Some(hash), _prev_hash);
+
+                // Associate the key's hash with the new task's ID, for looking up tasks by ID.
+                let _prev = self.hashes_by_task.insert(id, hash);
+                debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
+
+                // Note: it's important to drop `key` and abort the task here.
+                // This defends against any panics during drop handling for causing inconsistent state.
+                abort.abort();
+                drop(key);
             }
-            RawEntryMut::Vacant(vac) => {
-                vac.insert(map_key, abort);
+            Entry::Vacant(vac) => {
+                vac.insert((key, abort));
+
+                // Associate the key's hash with this task's ID, for looking up tasks by ID.
+                let _prev = self.hashes_by_task.insert(id, hash);
+                debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
             }
         };
-
-        // Associate the key's hash with this task's ID, for looking up tasks by ID.
-        let _prev = self.hashes_by_task.insert(id, hash);
-        debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
     }
 
     /// Waits until one of the tasks in the map completes and returns its
@@ -623,7 +612,7 @@ where
         // Note: this method iterates over the tasks and keys *without* removing
         // any entries, so that the keys from aborted tasks can still be
         // returned when calling `join_next` in the future.
-        for (Key { ref key, .. }, task) in &self.tasks_by_key {
+        for (key, task) in &self.tasks_by_key {
             if predicate(key) {
                 task.abort();
             }
@@ -638,7 +627,7 @@ where
     /// [`join_next`]: fn@Self::join_next
     pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
         JoinMapKeys {
-            iter: self.tasks_by_key.keys(),
+            iter: self.tasks_by_key.iter(),
             _value: PhantomData,
         }
     }
@@ -666,7 +655,7 @@ where
     /// [`join_next`]: fn@Self::join_next
     /// [task ID]: tokio::task::Id
     pub fn contains_task(&self, task: &Id) -> bool {
-        self.get_by_id(task).is_some()
+        self.hashes_by_task.contains_key(task)
     }
 
     /// Reserves capacity for at least `additional` more tasks to be spawned
@@ -690,7 +679,9 @@ where
     /// ```
     #[inline]
     pub fn reserve(&mut self, additional: usize) {
-        self.tasks_by_key.reserve(additional);
+        let hash_builder = self.hashes_by_task.hasher();
+        self.tasks_by_key
+            .reserve(additional, |(k, _)| hash_one(hash_builder, k));
         self.hashes_by_task.reserve(additional);
     }
 
@@ -716,7 +707,9 @@ where
     #[inline]
     pub fn shrink_to_fit(&mut self) {
         self.hashes_by_task.shrink_to_fit();
-        self.tasks_by_key.shrink_to_fit();
+        let hash_builder = self.hashes_by_task.hasher();
+        self.tasks_by_key
+            .shrink_to_fit(|(k, _)| hash_one(hash_builder, k));
     }
 
     /// Shrinks the capacity of the map with a lower limit. It will drop
@@ -745,27 +738,20 @@ where
     #[inline]
     pub fn shrink_to(&mut self, min_capacity: usize) {
         self.hashes_by_task.shrink_to(min_capacity);
-        self.tasks_by_key.shrink_to(min_capacity)
+        let hash_builder = self.hashes_by_task.hasher();
+        self.tasks_by_key
+            .shrink_to(min_capacity, |(k, _)| hash_one(hash_builder, k))
     }
 
     /// Look up a task in the map by its key, returning the key and abort handle.
-    fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
+    fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<&'map (K, AbortHandle)>
     where
         Q: Hash + Eq,
         K: Borrow<Q>,
     {
-        let hash = self.hash(key);
-        self.tasks_by_key
-            .raw_entry()
-            .from_hash(hash, |k| k.key.borrow() == key)
-    }
-
-    /// Look up a task in the map by its task ID, returning the key and abort handle.
-    fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
-        let hash = self.hashes_by_task.get(id)?;
-        self.tasks_by_key
-            .raw_entry()
-            .from_hash(*hash, |k| &k.id == id)
+        let hash_builder = self.hashes_by_task.hasher();
+        let hash = hash_one(hash_builder, key);
+        self.tasks_by_key.find(hash, |(k, _)| k.borrow() == key)
     }
 
     /// Remove a task from the map by ID, returning the key for that task.
@@ -776,28 +762,25 @@ where
         // Remove the entry for that hash.
         let entry = self
             .tasks_by_key
-            .raw_entry_mut()
-            .from_hash(hash, |k| k.id == id);
-        let (Key { id: _key_id, key }, handle) = match entry {
-            RawEntryMut::Occupied(entry) => entry.remove_entry(),
+            .find_entry(hash, |(_, abort)| abort.id() == id);
+        let (key, _) = match entry {
+            Ok(entry) => entry.remove().0,
             _ => return None,
         };
-        debug_assert_eq!(_key_id, id);
-        debug_assert_eq!(id, handle.id());
         self.hashes_by_task.remove(&id);
         Some(key)
     }
+}
 
-    /// Returns the hash for a given key.
-    #[inline]
-    fn hash<Q: ?Sized>(&self, key: &Q) -> u64
-    where
-        Q: Hash,
-    {
-        let mut hasher = self.tasks_by_key.hasher().build_hasher();
-        key.hash(&mut hasher);
-        hasher.finish()
-    }
+/// Returns the hash for a given key.
+#[inline]
+fn hash_one<S: BuildHasher, Q: ?Sized>(hash_builder: &S, key: &Q) -> u64
+where
+    Q: Hash,
+{
+    let mut hasher = hash_builder.build_hasher();
+    key.hash(&mut hasher);
+    hasher.finish()
 }
 
 impl<K, V, S> JoinMap<K, V, S>
@@ -831,11 +814,11 @@ impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
         // printing the key and task ID pairs, without format the `Key` struct
         // itself or the `AbortHandle`, which would just format the task's ID
         // again.
-        struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
-        impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
+        struct KeySet<'a, K: fmt::Debug>(&'a HashTable<(K, AbortHandle)>);
+        impl<K: fmt::Debug> fmt::Debug for KeySet<'_, K> {
             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                 f.debug_map()
-                    .entries(self.0.keys().map(|Key { key, id }| (key, id)))
+                    .entries(self.0.iter().map(|(key, abort)| (key, abort.id())))
                     .finish()
             }
         }
@@ -856,31 +839,10 @@ impl<K, V> Default for JoinMap<K, V> {
     }
 }
 
-// === impl Key ===
-
-impl<K: Hash> Hash for Key<K> {
-    // Don't include the task ID in the hash.
-    #[inline]
-    fn hash<H: Hasher>(&self, hasher: &mut H) {
-        self.key.hash(hasher);
-    }
-}
-
-// Because we override `Hash` for this type, we must also override the
-// `PartialEq` impl, so that all instances with the same hash are equal.
-impl<K: PartialEq> PartialEq for Key<K> {
-    #[inline]
-    fn eq(&self, other: &Self) -> bool {
-        self.key == other.key
-    }
-}
-
-impl<K: Eq> Eq for Key<K> {}
-
 /// An iterator over the keys of a [`JoinMap`].
 #[derive(Debug, Clone)]
 pub struct JoinMapKeys<'a, K, V> {
-    iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
+    iter: hashbrown::hash_table::Iter<'a, (K, AbortHandle)>,
     /// To make it easier to change `JoinMap` in the future, keep V as a generic
     /// parameter.
     _value: PhantomData<&'a V>,
@@ -890,7 +852,7 @@ impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
     type Item = &'a K;
 
     fn next(&mut self) -> Option<&'a K> {
-        self.iter.next().map(|key| &key.key)
+        self.iter.next().map(|(key, _)| key)
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
index c86d76dafa0574ff2e2b9ef842922293af1d3a82..2dcb18804be05638ba0927725efb6facacc8ce86 100644 (file)
@@ -1,6 +1,8 @@
 #![warn(rust_2018_idioms)]
 #![cfg(all(feature = "rt", tokio_unstable))]
 
+use std::panic::AssertUnwindSafe;
+
 use tokio::sync::oneshot;
 use tokio::time::Duration;
 use tokio_util::task::JoinMap;
@@ -343,3 +345,34 @@ async fn duplicate_keys2() {
 
     assert!(map.join_next().await.is_none());
 }
+
+#[cfg_attr(not(panic = "unwind"), ignore)]
+#[tokio::test]
+async fn duplicate_keys_drop() {
+    #[derive(Hash, Debug, PartialEq, Eq)]
+    struct Key;
+    impl Drop for Key {
+        fn drop(&mut self) {
+            panic!("drop called for key");
+        }
+    }
+
+    let (send, recv) = oneshot::channel::<()>();
+
+    let mut map = JoinMap::new();
+
+    map.spawn(Key, async { recv.await.unwrap() });
+
+    // replace the task, force it to drop the key and abort the task
+    // we should expect it to panic when dropping the key.
+    let _ = std::panic::catch_unwind(AssertUnwindSafe(|| map.spawn(Key, async {}))).unwrap_err();
+
+    // don't panic when this key drops.
+    let (key, _) = map.join_next().await.unwrap();
+    std::mem::forget(key);
+
+    // original task should have been aborted, so the sender should be dangling.
+    assert!(send.is_closed());
+
+    assert!(map.join_next().await.is_none());
+}