]> git.feebdaed.xyz Git - 0xmirror/tokio.git/commitdiff
net: ignore `NotConnected` in `TcpStream::shutdown` (#7290)
authorsoundofspace <116737867+soundofspace@users.noreply.github.com>
Tue, 6 May 2025 08:27:40 +0000 (10:27 +0200)
committerGitHub <noreply@github.com>
Tue, 6 May 2025 08:27:40 +0000 (17:27 +0900)
tokio/src/net/tcp/stream.rs
tokio/tests/tcp_shutdown.rs

index b0e3ec27ce16c6341f5134af74752ca57144837b..f64a526b4fd5f3b96f225a2bef4108c90c030ddd 100644 (file)
@@ -1112,8 +1112,16 @@ impl TcpStream {
     /// This function will cause all pending and future I/O on the specified
     /// portions to return immediately with an appropriate value (see the
     /// documentation of `Shutdown`).
+    ///
+    /// Remark: this function transforms `Err(std::io::ErrorKind::NotConnected)` to `Ok(())`.
+    /// It does this to abstract away OS specific logic and to prevent a race condition between
+    /// this function call and the OS closing this socket because of external events (e.g. TCP reset).
+    /// See <https://github.com/tokio-rs/tokio/issues/4665> for more information.
     pub(super) fn shutdown_std(&self, how: Shutdown) -> io::Result<()> {
-        self.io.shutdown(how)
+        match self.io.shutdown(how) {
+            Err(err) if err.kind() == std::io::ErrorKind::NotConnected => Ok(()),
+            result => result,
+        }
     }
 
     /// Gets the value of the `TCP_NODELAY` option on this socket.
index 2497c1a401d1cfd2cac464d11c43e612962a27a9..837e6123053b4ff3a6648180d072578da190e373 100644 (file)
@@ -2,8 +2,10 @@
 #![cfg(all(feature = "full", not(target_os = "wasi"), not(miri)))] // Wasi doesn't support bind
                                                                    // No `socket` on miri.
 
+use std::time::Duration;
 use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
 use tokio::net::{TcpListener, TcpStream};
+use tokio::sync::oneshot::channel;
 use tokio_test::assert_ok;
 
 #[tokio::test]
@@ -11,7 +13,7 @@ async fn shutdown() {
     let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
     let addr = assert_ok!(srv.local_addr());
 
-    tokio::spawn(async move {
+    let handle = tokio::spawn(async move {
         let mut stream = assert_ok!(TcpStream::connect(&addr).await);
 
         assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
@@ -26,4 +28,55 @@ async fn shutdown() {
 
     let n = assert_ok!(io::copy(&mut rd, &mut wr).await);
     assert_eq!(n, 0);
+    assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
+    handle.await.unwrap()
+}
+
+#[tokio::test]
+async fn shutdown_after_tcp_reset() {
+    let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+    let addr = assert_ok!(srv.local_addr());
+
+    let (connected_tx, connected_rx) = channel();
+    let (dropped_tx, dropped_rx) = channel();
+
+    let handle = tokio::spawn(async move {
+        let mut stream = assert_ok!(TcpStream::connect(&addr).await);
+        connected_tx.send(()).unwrap();
+
+        dropped_rx.await.unwrap();
+        assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
+    });
+
+    let (stream, _) = assert_ok!(srv.accept().await);
+    // By setting linger to 0 we will trigger a TCP reset
+    stream.set_linger(Some(Duration::new(0, 0))).unwrap();
+    connected_rx.await.unwrap();
+
+    drop(stream);
+    dropped_tx.send(()).unwrap();
+
+    handle.await.unwrap();
+}
+
+#[tokio::test]
+async fn shutdown_multiple_calls() {
+    let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
+    let addr = assert_ok!(srv.local_addr());
+
+    let (connected_tx, connected_rx) = channel();
+
+    let handle = tokio::spawn(async move {
+        let mut stream = assert_ok!(TcpStream::connect(&addr).await);
+        connected_tx.send(()).unwrap();
+        assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
+        assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
+        assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
+    });
+
+    let (mut stream, _) = assert_ok!(srv.accept().await);
+    connected_rx.await.unwrap();
+
+    assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
+    handle.await.unwrap();
 }