]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
http3: don't close QUIC listeners created by the application (#5129)
authorMarten Seemann <martenseemann@gmail.com>
Thu, 8 May 2025 05:40:23 +0000 (13:40 +0800)
committerGitHub <noreply@github.com>
Thu, 8 May 2025 05:40:23 +0000 (07:40 +0200)
http3/server.go
http3/server_test.go
integrationtests/self/http_hotswap_test.go
integrationtests/self/http_shutdown_test.go

index 10f3f80a59d3734622cf84b30cf9bccdd9a6861b..be004f136a12ef986f011cc4b7f45c14805f2355 100644 (file)
@@ -94,6 +94,9 @@ var RemoteAddrContextKey = &contextKey{"remote-addr"}
 type listener struct {
        ln   *QUICEarlyListener
        port int // 0 means that no info about port is available
+
+       // if this listener was constructed by the application, it won't be closed when the server is closed
+       createdLocally bool
 }
 
 // Server is a HTTP/3 server.
@@ -273,7 +276,7 @@ func (s *Server) ServeQUICConn(conn quic.Connection) error {
 // ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
 func (s *Server) ServeListener(ln QUICEarlyListener) error {
        s.mutex.Lock()
-       if err := s.addListener(&ln); err != nil {
+       if err := s.addListener(&ln, false); err != nil {
                s.mutex.Unlock()
                return err
        }
@@ -344,7 +347,7 @@ func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn)
        if err != nil {
                return nil, err
        }
-       if err := s.addListener(&ln); err != nil {
+       if err := s.addListener(&ln, true); err != nil {
                return nil, err
        }
        return &ln, nil
@@ -401,7 +404,7 @@ func (s *Server) generateAltSvcHeader() {
        s.altSvcHeader = strings.Join(altSvc, ",")
 }
 
-func (s *Server) addListener(l *QUICEarlyListener) error {
+func (s *Server) addListener(l *QUICEarlyListener, createdLocally bool) error {
        if s.closed {
                return http.ErrServerClosed
        }
@@ -409,14 +412,14 @@ func (s *Server) addListener(l *QUICEarlyListener) error {
 
        laddr := (*l).Addr()
        if port, err := extractPort(laddr.String()); err == nil {
-               s.listeners = append(s.listeners, listener{ln: l, port: port})
+               s.listeners = append(s.listeners, listener{ln: l, port: port, createdLocally: createdLocally})
        } else {
                logger := s.Logger
                if logger == nil {
                        logger = slog.Default()
                }
                logger.Error("Unable to extract port from listener, will not be announced using SetQUICHeaders", "local addr", laddr, "error", err)
-               s.listeners = append(s.listeners, listener{ln: l, port: 0})
+               s.listeners = append(s.listeners, listener{ln: l, port: 0, createdLocally: createdLocally})
        }
        s.generateAltSvcHeader()
        return nil
@@ -688,9 +691,11 @@ func (s *Server) Close() error {
        s.closeCancel()
 
        var err error
-       for _, info := range s.listeners {
-               if cerr := (*info.ln).Close(); cerr != nil && err == nil {
-                       err = cerr
+       for _, l := range s.listeners {
+               if l.createdLocally {
+                       if cerr := (*l.ln).Close(); cerr != nil && err == nil {
+                               err = cerr
+                       }
                }
        }
        if s.connCount.Load() == 0 {
@@ -708,7 +713,7 @@ func (s *Server) Close() error {
 func (s *Server) Shutdown(ctx context.Context) error {
        s.mutex.Lock()
        s.closed = true
-       // server is never used
+       // server was never used
        if s.closeCtx == nil {
                s.mutex.Unlock()
                return nil
index b27d85f545b0e9caccc10c857d6a8b4bccdf3f51..02df9a0ac2c245bf46e8c3dd8d471820d37bafc4 100644 (file)
@@ -880,6 +880,18 @@ func TestServerConcurrentServeAndClose(t *testing.T) {
        }
 }
 
+func TestServerImmediateGracefulShutdown(t *testing.T) {
+       s := &Server{TLSConfig: testdata.GetTLSConfig()}
+       errChan := make(chan error, 1)
+       go func() { errChan <- s.Shutdown(context.Background()) }()
+       select {
+       case err := <-errChan:
+               require.NoError(t, err)
+       case <-time.After(time.Second):
+               t.Fatal("timeout")
+       }
+}
+
 func TestServerGracefulShutdown(t *testing.T) {
        s := &Server{TLSConfig: testdata.GetTLSConfig()}
        s.init()
index 638d8e24d3624867e236193704fe0d6bc98f8837..246322c918328da66b1583fbbeb34f6697cc7436 100644 (file)
@@ -1,12 +1,10 @@
 package self_test
 
 import (
-       "context"
        "io"
        "net"
        "net/http"
        "strconv"
-       "sync/atomic"
        "testing"
        "time"
 
@@ -15,48 +13,6 @@ import (
        "github.com/stretchr/testify/require"
 )
 
-type listenerWrapper struct {
-       http3.QUICEarlyListener
-       listenerClosed bool
-       count          atomic.Int32
-}
-
-func (ln *listenerWrapper) Close() error {
-       ln.listenerClosed = true
-       return ln.QUICEarlyListener.Close()
-}
-
-func (ln *listenerWrapper) Faker() *fakeClosingListener {
-       ln.count.Add(1)
-       ctx, cancel := context.WithCancel(context.Background())
-       return &fakeClosingListener{
-               listenerWrapper: ln,
-               ctx:             ctx,
-               cancel:          cancel,
-       }
-}
-
-type fakeClosingListener struct {
-       *listenerWrapper
-       closed atomic.Bool
-       ctx    context.Context
-       cancel context.CancelFunc
-}
-
-func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) {
-       return ln.listenerWrapper.Accept(ln.ctx)
-}
-
-func (ln *fakeClosingListener) Close() error {
-       if ln.closed.CompareAndSwap(false, true) {
-               ln.cancel()
-               if ln.count.Add(-1) == 0 {
-                       ln.listenerWrapper.Close()
-               }
-       }
-       return nil
-}
-
 func TestHTTP3ServerHotswap(t *testing.T) {
        mux1 := http.NewServeMux()
        mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) {
@@ -78,9 +34,8 @@ func TestHTTP3ServerHotswap(t *testing.T) {
        }
 
        tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
-       quicLn, err := quic.ListenEarly(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil))
+       ln, err := quic.ListenEarly(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil))
        require.NoError(t, err)
-       ln := &listenerWrapper{QUICEarlyListener: quicLn}
        port := strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
 
        rt := &http3.Transport{
@@ -96,12 +51,8 @@ func TestHTTP3ServerHotswap(t *testing.T) {
        }()
 
        // open first server and make single request to it
-       fake1 := ln.Faker()
-       stoppedServing1 := make(chan struct{})
-       go func() {
-               server1.ServeListener(fake1)
-               close(stoppedServing1)
-       }()
+       errChan1 := make(chan error, 1)
+       go func() { errChan1 <- server1.ServeListener(ln) }()
 
        resp, err := client.Get("https://localhost:" + port + "/hello1")
        require.NoError(t, err)
@@ -111,36 +62,29 @@ func TestHTTP3ServerHotswap(t *testing.T) {
        require.Equal(t, "Hello, World 1!\n", string(body))
 
        // open second server with same underlying listener
-       fake2 := ln.Faker()
-       stoppedServing2 := make(chan struct{})
-       go func() {
-               server2.ServeListener(fake2)
-               close(stoppedServing2)
-       }()
+       errChan2 := make(chan error, 1)
+       go func() { errChan2 <- server2.ServeListener(ln) }()
 
-       // Verify both servers are running by waiting a bit and checking channels aren't closed
-       time.Sleep(50 * time.Millisecond)
+       time.Sleep(scaleDuration(20 * time.Millisecond))
        select {
-       case <-stoppedServing1:
-               t.Fatal("server1 stopped unexpectedly")
-       case <-stoppedServing2:
-               t.Fatal("server2 stopped unexpectedly")
+       case err := <-errChan1:
+               t.Fatalf("server1 stopped unexpectedly: %v", err)
+       case err := <-errChan2:
+               t.Fatalf("server2 stopped unexpectedly: %v", err)
        default:
        }
 
        // now close first server
        require.NoError(t, server1.Close())
        select {
-       case <-stoppedServing1:
-       case <-time.After(time.Second):
+       case err := <-errChan1:
+               require.ErrorIs(t, err, http.ErrServerClosed)
+       case <-time.After(5 * time.Second):
                t.Fatal("timed out waiting for server1 to stop")
        }
-       require.True(t, fake1.closed.Load())
-       require.False(t, fake2.closed.Load())
-       require.False(t, ln.listenerClosed)
        require.NoError(t, client.Transport.(*http3.Transport).Close())
 
-       // verify that new connections are being initiated from the second server now
+       // verify that new connections are handled by the second server now
        resp, err = client.Get("https://localhost:" + port + "/hello2")
        require.NoError(t, err)
        require.Equal(t, http.StatusOK, resp.StatusCode)
@@ -148,13 +92,12 @@ func TestHTTP3ServerHotswap(t *testing.T) {
        require.NoError(t, err)
        require.Equal(t, "Hello, World 2!\n", string(body))
 
-       // close the other server - both the fake and the actual listeners must close now
+       // close the other server
        require.NoError(t, server2.Close())
        select {
-       case <-stoppedServing2:
+       case err := <-errChan2:
+               require.ErrorIs(t, err, http.ErrServerClosed)
        case <-time.After(time.Second):
                t.Fatal("timed out waiting for server2 to stop")
        }
-       require.True(t, fake2.closed.Load())
-       require.True(t, ln.listenerClosed)
 }
index fd1e00e2c22ce8d2f120299874499b43af8f312a..0ff40e3ef586414c5851bd6bd2b1ad43e00ddff4 100644 (file)
@@ -6,9 +6,11 @@ import (
        "io"
        "net"
        "net/http"
+       "net/url"
        "testing"
        "time"
 
+       "github.com/quic-go/quic-go"
        "github.com/quic-go/quic-go/http3"
        quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
 
@@ -196,3 +198,118 @@ func TestGracefulShutdownPendingStreams(t *testing.T) {
                t.Fatal("shutdown did not complete")
        }
 }
+
+func TestHTTP3ListenerClosing(t *testing.T) {
+       t.Run("application listener", func(t *testing.T) {
+               testHTTP3ListenerClosing(t, true)
+       })
+       t.Run("listener created by the http3.Server", func(t *testing.T) {
+               testHTTP3ListenerClosing(t, false)
+       })
+}
+
+func testHTTP3ListenerClosing(t *testing.T, useApplicationListener bool) {
+       dial := func(t *testing.T, ctx context.Context, u *url.URL) error {
+               t.Helper()
+               tlsConf := getTLSClientConfig()
+               tlsConf.NextProtos = []string{http3.NextProtoH3}
+               tr := &http3.Transport{TLSClientConfig: tlsConf}
+               defer tr.Close()
+               cl := &http.Client{Transport: tr}
+               req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
+               require.NoError(t, err)
+               resp, err := cl.Do(req)
+               if err != nil {
+                       return err
+               }
+               defer resp.Body.Close()
+               require.Equal(t, http.StatusOK, resp.StatusCode)
+               return nil
+       }
+
+       mux := http.NewServeMux()
+       mux.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) {
+               w.WriteHeader(http.StatusOK)
+       })
+       tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
+       server := &http3.Server{
+               Handler: mux,
+               // the following values will be ignored when using ServeListener
+               TLSConfig:  tlsConf,
+               QUICConfig: getQuicConfig(nil),
+               Addr:       "127.0.0.1:47283",
+       }
+
+       serveChan := make(chan error, 1)
+       var host string
+       var ln *quic.EarlyListener // only set when using application listener
+       if useApplicationListener {
+               var err error
+               ln, err = quic.ListenEarly(newUDPConnLocalhost(t), tlsConf, getQuicConfig(nil))
+               require.NoError(t, err)
+               defer ln.Close()
+               host = ln.Addr().String()
+               go func() { serveChan <- server.ServeListener(ln) }()
+       } else {
+               go func() { serveChan <- server.ListenAndServe() }()
+               host = server.Addr
+       }
+
+       u := &url.URL{Scheme: "https", Host: host, Path: "/ok"}
+       ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+       defer cancel()
+       require.NoError(t, dial(t, ctx, u))
+
+       // close the server
+       require.NoError(t, server.Close())
+
+       select {
+       case err := <-serveChan:
+               require.ErrorIs(t, err, http.ErrServerClosed)
+       case <-time.After(time.Second):
+               t.Fatal("server did not stop")
+       }
+
+       // If the listener was created by the http3.Server, it will now be closed.
+       if !useApplicationListener {
+               ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(10*time.Millisecond))
+               defer cancel()
+               require.ErrorIs(t, dial(t, ctx, u), context.DeadlineExceeded)
+               return
+       }
+
+       // If the listener was created by the application, it will not be closed,
+       // and it can be used to accept new connections.
+       errChan := make(chan error, 1)
+       go func() {
+               for {
+                       conn, err := ln.Accept(context.Background())
+                       if err != nil {
+                               errChan <- err
+                               return
+                       }
+                       select {
+                       case <-conn.HandshakeComplete():
+                               conn.CloseWithError(1337, "")
+                       case <-time.After(time.Second):
+                               errChan <- fmt.Errorf("connection did not complete handshake")
+                       }
+                       errChan <- nil
+               }
+       }()
+
+       for range 3 {
+               ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+               defer cancel()
+               err := dial(t, ctx, u)
+               var h3Err *http3.Error
+               require.ErrorAs(t, err, &h3Err)
+               require.Equal(t, http3.ErrCode(1337), h3Err.ErrorCode)
+               select {
+               case err := <-errChan:
+                       require.NoError(t, err)
+               case <-time.After(time.Second):
+                       t.Fatal("server did not accept connection")
+               }
+       }
+}