]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
http3: use actual QUIC connection and stream in server tests (#5161)
authorMarten Seemann <martenseemann@gmail.com>
Thu, 29 May 2025 09:06:09 +0000 (17:06 +0800)
committerGitHub <noreply@github.com>
Thu, 29 May 2025 09:06:09 +0000 (11:06 +0200)
http3/http3_helper_test.go
http3/server_test.go

index 013989241fbfa6ef3ceef3cbc87b50e3e2d1a796..54a1c78b219658816b54cd3cc3ea74eac146f9e9 100644 (file)
@@ -1,12 +1,20 @@
 package http3
 
 import (
+       "bytes"
+       "crypto/tls"
+       "crypto/x509"
+       "io"
        "net"
+       "net/http"
        "os"
        "strconv"
        "testing"
        "time"
 
+       "github.com/quic-go/qpack"
+       "github.com/quic-go/quic-go/integrationtests/tools"
+
        "github.com/stretchr/testify/require"
 )
 
@@ -28,3 +36,70 @@ func scaleDuration(t time.Duration) time.Duration {
        }
        return time.Duration(scaleFactor) * t
 }
+
+var tlsConfig, tlsClientConfig *tls.Config
+
+func init() {
+       ca, caPrivateKey, err := tools.GenerateCA()
+       if err != nil {
+               panic(err)
+       }
+       leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey)
+       if err != nil {
+               panic(err)
+       }
+       tlsConfig = &tls.Config{
+               Certificates: []tls.Certificate{{
+                       Certificate: [][]byte{leafCert.Raw},
+                       PrivateKey:  leafPrivateKey,
+               }},
+               NextProtos: []string{NextProtoH3},
+       }
+
+       root := x509.NewCertPool()
+       root.AddCert(ca)
+       tlsClientConfig = &tls.Config{
+               ServerName: "localhost",
+               RootCAs:    root,
+               NextProtos: []string{NextProtoH3},
+       }
+}
+
+func getTLSConfig() *tls.Config       { return tlsConfig.Clone() }
+func getTLSClientConfig() *tls.Config { return tlsClientConfig.Clone() }
+
+func encodeRequest(t *testing.T, req *http.Request) []byte {
+       t.Helper()
+
+       var buf bytes.Buffer
+       rw := newRequestWriter()
+       require.NoError(t, rw.WriteRequestHeader(&buf, req, false))
+       if req.Body != nil {
+               body, err := io.ReadAll(req.Body)
+               require.NoError(t, err)
+               buf.Write((&dataFrame{Length: uint64(len(body))}).Append(nil))
+               buf.Write(body)
+       }
+       return buf.Bytes()
+}
+
+func decodeHeader(t *testing.T, r io.Reader) map[string][]string {
+       t.Helper()
+
+       fields := make(map[string][]string)
+       decoder := qpack.NewDecoder(nil)
+
+       frame, err := (&frameParser{r: r}).ParseNext()
+       require.NoError(t, err)
+       require.IsType(t, &headersFrame{}, frame)
+       headersFrame := frame.(*headersFrame)
+       data := make([]byte, headersFrame.Length)
+       _, err = io.ReadFull(r, data)
+       require.NoError(t, err)
+       hfs, err := decoder.DecodeFull(data)
+       require.NoError(t, err)
+       for _, p := range hfs {
+               fields[p.Name] = append(fields[p.Name], p.Value)
+       }
+       return fields
+}
index 237d9ae41706f33a7c00e43cefd8fdd314972e32..933d9633f1bc77db36dd4867d95fe3cd3069f453 100644 (file)
@@ -4,28 +4,54 @@ import (
        "bytes"
        "context"
        "crypto/tls"
+       "errors"
        "fmt"
        "io"
        "log/slog"
        "net"
        "net/http"
        "net/http/httptest"
+       "os"
        "runtime"
        "testing"
        "time"
 
-       "github.com/quic-go/qpack"
        "github.com/quic-go/quic-go"
-       mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
        "github.com/quic-go/quic-go/internal/protocol"
        "github.com/quic-go/quic-go/internal/testdata"
+       "github.com/quic-go/quic-go/qlog"
        "github.com/quic-go/quic-go/quicvarint"
 
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
-       "go.uber.org/mock/gomock"
 )
 
+func getConnPair(t *testing.T) (client, server quic.Connection) {
+       t.Helper()
+
+       ln, err := quic.Listen(
+               newUDPConnLocalhost(t),
+               getTLSConfig(),
+               &quic.Config{
+                       InitialStreamReceiveWindow:     uint64(protocol.MaxByteCount),
+                       InitialConnectionReceiveWindow: uint64(protocol.MaxByteCount),
+                       Tracer:                         qlog.DefaultConnectionTracer,
+               },
+       )
+       require.NoError(t, err)
+
+       cl, err := quic.Dial(context.Background(), newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), &quic.Config{})
+       require.NoError(t, err)
+       t.Cleanup(func() { cl.CloseWithError(0, "") })
+
+       ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+       defer cancel()
+       conn, err := ln.Accept(ctx)
+       require.NoError(t, err)
+       t.Cleanup(func() { conn.CloseWithError(0, "") })
+       return cl, conn
+}
+
 func TestConfigureTLSConfig(t *testing.T) {
        t.Run("basic config", func(t *testing.T) {
                conf := ConfigureTLSConfig(&tls.Config{})
@@ -84,66 +110,33 @@ func testServerSettings(t *testing.T, enableDatagrams bool, other map[uint64]uin
 
        testDone := make(chan struct{})
        defer close(testDone)
-       settingsChan := make(chan []byte)
-       mockCtrl := gomock.NewController(t)
-       conn := mockquic.NewMockEarlyConnection(mockCtrl)
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
-               settingsChan <- b
-               return len(b), nil
-       })
 
-       conn.EXPECT().OpenUniStream().Return(controlStr, nil)
-       conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
-               <-testDone
-               return nil, assert.AnError
-       }).MaxTimes(1)
-       conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(ctx context.Context) (quic.Stream, error) {
-               <-testDone
-               return nil, assert.AnError
-       }).MaxTimes(1)
-       conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       conn.EXPECT().LocalAddr().AnyTimes()
-       conn.EXPECT().Context().Return(context.Background()).AnyTimes()
-
-       go s.handleConn(conn)
+       clientConn, serverConn := getConnPair(t)
+       go s.handleConn(serverConn)
 
-       select {
-       case b := <-settingsChan:
-               typ, l, err := quicvarint.Parse(b)
-               require.NoError(t, err)
-               require.EqualValues(t, streamTypeControlStream, typ)
-               fp := (&frameParser{r: bytes.NewReader(b[l:])})
-               f, err := fp.ParseNext()
-               require.NoError(t, err)
-               require.IsType(t, &settingsFrame{}, f)
-               settingsFrame := f.(*settingsFrame)
-               // Extended CONNECT is always supported
-               require.True(t, settingsFrame.ExtendedConnect)
-               require.Equal(t, settingsFrame.Datagram, enableDatagrams)
-               require.Equal(t, settingsFrame.Other, other)
-       case <-time.After(time.Second):
-               t.Fatal("timeout")
-       }
-}
-
-func decodeHeader(t *testing.T, r io.Reader) map[string][]string {
-       fields := make(map[string][]string)
-       decoder := qpack.NewDecoder(nil)
+       ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+       defer cancel()
+       settingsStr, err := clientConn.AcceptUniStream(ctx)
+       require.NoError(t, err)
 
-       frame, err := (&frameParser{r: r}).ParseNext()
+       settingsStr.SetReadDeadline(time.Now().Add(time.Second))
+       b := make([]byte, 1024)
+       n, err := settingsStr.Read(b)
        require.NoError(t, err)
-       require.IsType(t, &headersFrame{}, frame)
-       headersFrame := frame.(*headersFrame)
-       data := make([]byte, headersFrame.Length)
-       _, err = io.ReadFull(r, data)
+       b = b[:n]
+
+       typ, l, err := quicvarint.Parse(b)
        require.NoError(t, err)
-       hfs, err := decoder.DecodeFull(data)
+       require.EqualValues(t, streamTypeControlStream, typ)
+       fp := (&frameParser{r: bytes.NewReader(b[l:])})
+       f, err := fp.ParseNext()
        require.NoError(t, err)
-       for _, p := range hfs {
-               fields[p.Name] = append(fields[p.Name], p.Value)
-       }
-       return fields
+       require.IsType(t, &settingsFrame{}, f)
+       settingsFrame := f.(*settingsFrame)
+       // Extended CONNECT is always supported
+       require.True(t, settingsFrame.ExtendedConnect)
+       require.Equal(t, settingsFrame.Datagram, enableDatagrams)
+       require.Equal(t, settingsFrame.Other, other)
 }
 
 func TestServerRequestHandling(t *testing.T) {
@@ -208,47 +201,22 @@ func TestServerRequestHandling(t *testing.T) {
        })
 }
 
-func encodeRequest(t *testing.T, req *http.Request) io.Reader {
-       var buf bytes.Buffer
-       rw := newRequestWriter()
-       require.NoError(t, rw.WriteRequestHeader(&buf, req, false))
-       if req.Body != nil {
-               body, err := io.ReadAll(req.Body)
-               require.NoError(t, err)
-               buf.Write((&dataFrame{Length: uint64(len(body))}).Append(nil))
-               buf.Write(body)
-       }
-       return bytes.NewReader(buf.Bytes())
-}
-
 func testServerRequestHandling(t *testing.T,
        handler http.HandlerFunc,
        req *http.Request,
 ) (responseHeaders map[string][]string, body []byte) {
-       responseBuf := &bytes.Buffer{}
-       mockCtrl := gomock.NewController(t)
-       str := NewMockDatagramStream(mockCtrl)
-       str.EXPECT().Context().Return(context.Background()).AnyTimes()
-       str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
-       str.EXPECT().CancelRead(gomock.Any())
-       str.EXPECT().Close()
-       str.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, req).Read).AnyTimes()
-
-       s := &Server{
-               TLSConfig: testdata.GetTLSConfig(),
-               Handler:   handler,
-       }
+       clientConn, serverConn := getConnPair(t)
+       str, err := clientConn.OpenStream()
+       require.NoError(t, err)
+       _, err = str.Write(encodeRequest(t, req))
+       require.NoError(t, err)
+       require.NoError(t, str.Close())
 
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
-       qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       qconn.EXPECT().LocalAddr().AnyTimes()
-       qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
-       qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
-       conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
+       s := &Server{Handler: handler}
+       go s.ServeQUICConn(serverConn)
 
-       s.handleRequest(conn, str, qpack.NewDecoder(nil))
-       hfs := decodeHeader(t, responseBuf)
-       fp := frameParser{r: responseBuf}
+       hfs := decodeHeader(t, str)
+       fp := frameParser{r: str}
        var content []byte
        for {
                frame, err := fp.ParseNext()
@@ -258,13 +226,39 @@ func testServerRequestHandling(t *testing.T,
                require.NoError(t, err)
                require.IsType(t, &dataFrame{}, frame)
                b := make([]byte, frame.(*dataFrame).Length)
-               _, err = io.ReadFull(responseBuf, b)
+               _, err = io.ReadFull(str, b)
                require.NoError(t, err)
                content = append(content, b...)
        }
        return hfs, content
 }
 
+func TestServerFirstFrameNotHeaders(t *testing.T) {
+       clientConn, serverConn := getConnPair(t)
+       str, err := clientConn.OpenStream()
+       require.NoError(t, err)
+
+       var buf bytes.Buffer
+       buf.Write((&dataFrame{Length: 6}).Append(nil))
+       buf.Write([]byte("foobar"))
+       _, err = str.Write(buf.Bytes())
+       require.NoError(t, err)
+       require.NoError(t, str.Close())
+
+       s := &Server{}
+       go s.ServeQUICConn(serverConn)
+
+       select {
+       case <-clientConn.Context().Done():
+               err := context.Cause(clientConn.Context())
+               var appErr *quic.ApplicationError
+               require.ErrorAs(t, err, &appErr)
+               require.Equal(t, quic.ApplicationErrorCode(ErrCodeFrameUnexpected), appErr.ErrorCode)
+       case <-time.After(time.Second):
+               t.Fatal("timeout")
+       }
+}
+
 func TestServerHandlerBodyNotRead(t *testing.T) {
        t.Run("GET request with a body", func(t *testing.T) {
                testServerHandlerBodyNotRead(t,
@@ -293,83 +287,74 @@ func TestServerHandlerBodyNotRead(t *testing.T) {
        })
 }
 
-func TestServerFirstFrameNotHeaders(t *testing.T) {
-       mockCtrl := gomock.NewController(t)
-       str := NewMockDatagramStream(mockCtrl)
-       str.EXPECT().Write(gomock.Any()).AnyTimes()
-       str.EXPECT().Context().Return(context.Background()).AnyTimes()
-       var buf bytes.Buffer
-       buf.Write((&dataFrame{Length: 6}).Append(nil))
-       buf.Write([]byte("foobar"))
-       str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
-
-       s := &Server{TLSConfig: testdata.GetTLSConfig()}
-
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
-       qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       qconn.EXPECT().LocalAddr().AnyTimes()
-       qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
-       conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
-
-       s.handleRequest(conn, str, qpack.NewDecoder(nil))
-}
-
 func testServerHandlerBodyNotRead(t *testing.T, req *http.Request, handler http.HandlerFunc) {
-       mockCtrl := gomock.NewController(t)
-       str := NewMockDatagramStream(mockCtrl)
-       str.EXPECT().Write(gomock.Any()).AnyTimes()
-       str.EXPECT().Context().Return(context.Background()).AnyTimes()
-       str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
-       str.EXPECT().Close().MaxTimes(1)
-       str.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, req).Read).AnyTimes()
+       clientConn, serverConn := getConnPair(t)
+       str, err := clientConn.OpenStream()
+       require.NoError(t, err)
+       _, err = str.Write(encodeRequest(t, req))
+       require.NoError(t, err)
+       // require.NoError(t, str.Close())
 
-       var called bool
+       done := make(chan struct{})
        s := &Server{
-               TLSConfig: testdata.GetTLSConfig(),
                Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-                       called = true
+                       defer close(done)
                        handler(w, r)
                }),
        }
 
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
-       qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       qconn.EXPECT().LocalAddr().AnyTimes()
-       qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
-       qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
-       conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
+       go s.ServeQUICConn(serverConn)
+
+       select {
+       case <-done:
+       case <-time.After(time.Second):
+               t.Fatal("timeout")
+       }
+}
+
+func expectStreamReadReset(t *testing.T, str quic.ReceiveStream, errCode quic.StreamErrorCode) {
+       t.Helper()
+       str.SetReadDeadline(time.Now().Add(time.Second))
+       _, err := str.Read([]byte{0})
+       require.Error(t, err)
+       if errors.Is(err, os.ErrDeadlineExceeded) {
+               t.Fatal("didn't receive a stream reset")
+       }
+       var strErr *quic.StreamError
+       require.ErrorAs(t, err, &strErr)
+       require.Equal(t, errCode, strErr.ErrorCode)
+}
 
-       s.handleRequest(conn, str, qpack.NewDecoder(nil))
-       require.True(t, called)
+func expectStreamWriteReset(t *testing.T, str quic.SendStream, errCode quic.StreamErrorCode) {
+       t.Helper()
+       select {
+       case <-str.Context().Done():
+       case <-time.After(time.Second):
+               t.Fatal("timeout")
+       }
+       _, err := str.Write([]byte{0})
+       require.Error(t, err)
+       var strErr *quic.StreamError
+       require.ErrorAs(t, err, &strErr)
+       require.Equal(t, errCode, strErr.ErrorCode)
 }
 
 func TestServerStreamResetByClient(t *testing.T) {
-       mockCtrl := gomock.NewController(t)
-       str := NewMockDatagramStream(mockCtrl)
-       done := make(chan struct{})
-       str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
-       str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
-       str.EXPECT().Read(gomock.Any()).Return(0, assert.AnError)
+       clientConn, serverConn := getConnPair(t)
+       str, err := clientConn.OpenStream()
+       require.NoError(t, err)
+       str.CancelWrite(1337)
 
        var called bool
        s := &Server{
-               TLSConfig: testdata.GetTLSConfig(),
                Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                        called = true
                }),
        }
 
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
-       qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       qconn.EXPECT().LocalAddr().AnyTimes()
-       conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
+       go s.ServeQUICConn(serverConn)
 
-       s.handleRequest(conn, str, qpack.NewDecoder(nil))
-       select {
-       case <-done:
-       case <-time.After(time.Second):
-               t.Fatal("timeout")
-       }
+       expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestIncomplete))
        require.False(t, called)
 }
 
@@ -392,30 +377,24 @@ func TestServerPanickingHandler(t *testing.T) {
 }
 
 func testServerPanickingHandler(t *testing.T, handler http.HandlerFunc) (logOutput string) {
-       mockCtrl := gomock.NewController(t)
-       str := NewMockDatagramStream(mockCtrl)
-       str.EXPECT().Context().Return(context.Background()).AnyTimes()
-       str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
-       str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
-       str.EXPECT().Read(gomock.Any()).DoAndReturn(
-               encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)).Read,
-       ).AnyTimes()
+       clientConn, serverConn := getConnPair(t)
+       str, err := clientConn.OpenStream()
+       require.NoError(t, err)
+       _, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)))
+       require.NoError(t, err)
+       require.NoError(t, str.Close())
 
        var logBuf bytes.Buffer
        s := &Server{
-               TLSConfig: testdata.GetTLSConfig(),
-               Handler:   handler,
-               Logger:    slog.New(slog.NewTextHandler(&logBuf, nil)),
+               Handler: handler,
+               Logger:  slog.New(slog.NewTextHandler(&logBuf, nil)),
        }
 
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
-       qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       qconn.EXPECT().LocalAddr().AnyTimes()
-       qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
-       qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
-       conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
+       go s.ServeQUICConn(serverConn)
+
+       expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeInternalError))
+       s.Close()
 
-       s.handleRequest(conn, str, qpack.NewDecoder(nil))
        return logBuf.String()
 }
 
@@ -441,92 +420,44 @@ func TestServerRequestHeaderTooLarge(t *testing.T) {
 func testServerRequestHeaderTooLarge(t *testing.T, req *http.Request, maxHeaderBytes int) {
        var called bool
        s := &Server{
-               TLSConfig:      testdata.GetTLSConfig(),
                MaxHeaderBytes: maxHeaderBytes,
                Handler:        http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true }),
        }
        s.init()
 
-       done := make(chan struct{}, 2)
-       mockCtrl := gomock.NewController(t)
-       str := mockquic.NewMockStream(mockCtrl)
-       str.EXPECT().Context().Return(context.Background()).AnyTimes()
-       str.EXPECT().StreamID().AnyTimes()
-       str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { done <- struct{}{} })
-       str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { done <- struct{}{} })
-       str.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, req).Read).AnyTimes()
+       clientConn, serverConn := getConnPair(t)
+       str, err := clientConn.OpenStream()
+       require.NoError(t, err)
+       _, err = str.Write(encodeRequest(t, req))
+       require.NoError(t, err)
+       require.NoError(t, str.Close())
+
+       go s.ServeQUICConn(serverConn)
+
+       expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeFrameError))
+       expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeFrameError))
 
-       testDone := make(chan struct{})
-       conn := mockquic.NewMockEarlyConnection(mockCtrl)
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Write(gomock.Any())
-       conn.EXPECT().OpenUniStream().Return(controlStr, nil)
-       conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
-               <-testDone
-               return nil, assert.AnError
-       }).MaxTimes(1)
-       conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
-       conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, assert.AnError)
-       conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       conn.EXPECT().LocalAddr().AnyTimes()
-       conn.EXPECT().Context().Return(context.Background()).AnyTimes()
-
-       s.handleConn(conn)
-       for range 2 {
-               select {
-               case <-done:
-               case <-time.After(time.Second):
-                       t.Fatal("timeout")
-               }
-       }
        require.False(t, called)
 }
 
 func TestServerRequestContext(t *testing.T) {
-       responseBuf := &bytes.Buffer{}
-       mockCtrl := gomock.NewController(t)
-       str := mockquic.NewMockStream(mockCtrl)
-       strCtx, strCtxCancel := context.WithCancel(context.Background())
-       str.EXPECT().StreamID().AnyTimes()
-       str.EXPECT().Context().Return(strCtx).AnyTimes()
-       str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
-       str.EXPECT().CancelRead(gomock.Any())
-       str.EXPECT().Close()
-       str.EXPECT().Read(gomock.Any()).DoAndReturn(
-               encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)).Read,
-       ).AnyTimes()
+       clientConn, serverConn := getConnPair(t)
+       str, err := clientConn.OpenStream()
+       require.NoError(t, err)
+       _, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)))
+       require.NoError(t, err)
 
        ctxChan := make(chan context.Context, 1)
+       block := make(chan struct{})
        s := &Server{
-               TLSConfig: testdata.GetTLSConfig(),
                Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                        ctxChan <- r.Context()
+                       <-block
                }),
        }
-       s.init()
 
-       testDone := make(chan struct{})
-       defer close(testDone)
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Write(gomock.Any()).AnyTimes()
-       conn := mockquic.NewMockEarlyConnection(mockCtrl)
-       conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1337}).AnyTimes()
-       conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 42}).AnyTimes()
-       conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
-       connCtx := context.WithValue(context.Background(), "connection context", "connection context value")
-       conn.EXPECT().Context().Return(connCtx).AnyTimes()
-       conn.EXPECT().OpenUniStream().Return(str, nil)
-       conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
-               <-testDone
-               return nil, assert.AnError
-       }).MaxTimes(1)
-       conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
-       conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
-               <-testDone
-               return nil, assert.AnError
-       }).MaxTimes(1)
-
-       go s.handleConn(conn)
+       go s.ServeQUICConn(serverConn)
+
        var requestContext context.Context
        select {
        case requestContext = <-ctxChan:
@@ -534,54 +465,58 @@ func TestServerRequestContext(t *testing.T) {
                t.Fatal("timeout")
        }
 
-       require.Equal(t, "connection context value", requestContext.Value("connection context"))
        require.Equal(t, s, requestContext.Value(ServerContextKey))
-       require.Equal(t, &net.UDPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 42}, requestContext.Value(http.LocalAddrContextKey))
-       require.Equal(t, &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1337}, requestContext.Value(RemoteAddrContextKey))
+       require.Equal(t, serverConn.LocalAddr(), requestContext.Value(http.LocalAddrContextKey))
+       require.Equal(t, serverConn.RemoteAddr(), requestContext.Value(RemoteAddrContextKey))
        select {
        case <-requestContext.Done():
                t.Fatal("request context was canceled")
        case <-time.After(scaleDuration(10 * time.Millisecond)):
        }
 
-       strCtxCancel()
+       str.CancelRead(1337)
+
        select {
        case <-requestContext.Done():
        case <-time.After(time.Second):
                t.Fatal("timeout")
        }
        require.Equal(t, context.Canceled, requestContext.Err())
+       close(block)
 }
 
 func TestServerHTTPStreamHijacking(t *testing.T) {
-       responseBuf := &bytes.Buffer{}
-       mockCtrl := gomock.NewController(t)
-       str := NewMockDatagramStream(mockCtrl)
-       str.EXPECT().Context().Return(context.Background()).AnyTimes()
-       str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
-       str.EXPECT().Read(gomock.Any()).DoAndReturn(
-               encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)).Read,
-       ).AnyTimes()
+       clientConn, serverConn := getConnPair(t)
+       str, err := clientConn.OpenStream()
+       require.NoError(t, err)
+       _, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)))
+       require.NoError(t, err)
+       require.NoError(t, str.Close())
 
        s := &Server{
-               TLSConfig: testdata.GetTLSConfig(),
                Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-                       w.(HTTPStreamer).HTTPStream()
+                       str := w.(HTTPStreamer).HTTPStream()
                        str.Write([]byte("foobar"))
+                       str.Close()
                }),
        }
+       go s.ServeQUICConn(serverConn)
 
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
-       qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       qconn.EXPECT().LocalAddr().AnyTimes()
-       qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
-       qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
-       conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
-
-       s.handleRequest(conn, str, qpack.NewDecoder(nil))
-       hfs := decodeHeader(t, responseBuf)
+       str.SetReadDeadline(time.Now().Add(time.Second))
+       rsp, err := io.ReadAll(str)
+       require.NoError(t, err)
+       r := bytes.NewReader(rsp)
+       hfs := decodeHeader(t, r)
        require.Equal(t, hfs[":status"], []string{"200"})
-       require.Equal(t, []byte("foobar"), responseBuf.Bytes())
+       fp := frameParser{r: r}
+       frame, err := fp.ParseNext()
+       require.NoError(t, err)
+       require.IsType(t, &dataFrame{}, frame)
+       dataFrame := frame.(*dataFrame)
+       require.Equal(t, uint64(6), dataFrame.Length)
+       data, err := io.ReadAll(r)
+       require.NoError(t, err)
+       require.Equal(t, []byte("foobar"), data)
 }
 
 func TestServerStreamHijacking(t *testing.T) {
@@ -605,7 +540,6 @@ func TestServerStreamHijacking(t *testing.T) {
 }
 
 func testServerHijackBidirectionalStream(t *testing.T, bidirectional bool, doHijack bool, hijackErr error) {
-       id := quic.ConnectionTracingID(1337)
        type hijackCall struct {
                ft            FrameType  // for bidirectional streams
                st            StreamType // for unidirectional streams
@@ -615,7 +549,6 @@ func testServerHijackBidirectionalStream(t *testing.T, bidirectional bool, doHij
        hijackChan := make(chan hijackCall, 1)
        testDone := make(chan struct{})
        s := &Server{
-               TLSConfig: testdata.GetTLSConfig(),
                StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
                        defer close(testDone)
                        hijackChan <- hijackCall{ft: ft, connTracingID: connTracingID, e: e}
@@ -627,44 +560,33 @@ func testServerHijackBidirectionalStream(t *testing.T, bidirectional bool, doHij
                        return doHijack
                },
        }
-       s.init()
+
+       clientConn, serverConn := getConnPair(t)
+       go s.ServeQUICConn(serverConn)
 
        buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
-       mockCtrl := gomock.NewController(t)
-       unknownStr := mockquic.NewMockStream(mockCtrl)
-       unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
-       unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
-       unknownStr.EXPECT().StreamID().AnyTimes()
-       if !doHijack || hijackErr != nil {
-               if bidirectional {
-                       unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
-                       unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
-               } else {
-                       unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
-               }
-       }
-       conn := mockquic.NewMockEarlyConnection(mockCtrl)
        if bidirectional {
-               conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
+               str, err := clientConn.OpenStream()
+               require.NoError(t, err)
+               _, err = str.Write(buf.Bytes())
+               require.NoError(t, err)
+
+               if hijackErr != nil {
+                       expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestIncomplete))
+                       expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeRequestIncomplete))
+               }
+               // if the stream is not hijacked, the frame parser will skip the frame
        } else {
-               conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
+               str, err := clientConn.OpenUniStream()
+               require.NoError(t, err)
+               _, err = str.Write(buf.Bytes())
+               require.NoError(t, err)
+
+               if !doHijack || hijackErr != nil {
+                       expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeStreamCreationError))
+               }
        }
-       conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
-               <-testDone
-               return nil, assert.AnError
-       })
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Write(gomock.Any())
-       conn.EXPECT().OpenUniStream().Return(controlStr, nil)
-       conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       conn.EXPECT().LocalAddr().AnyTimes()
-       conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
-               <-testDone
-               return nil, assert.AnError
-       })
-       ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
-       conn.EXPECT().Context().Return(ctx).AnyTimes()
-       s.handleConn(conn)
+
        select {
        case hijackCall := <-hijackChan:
                if bidirectional {
@@ -674,7 +596,7 @@ func testServerHijackBidirectionalStream(t *testing.T, bidirectional bool, doHij
                        assert.Equal(t, hijackCall.st, StreamType(0x41))
                        assert.Zero(t, hijackCall.ft)
                }
-               assert.Equal(t, hijackCall.connTracingID, id)
+               assert.Equal(t, serverConn.Context().Value(quic.ConnectionTracingKey), hijackCall.connTracingID)
                assert.NoError(t, hijackCall.e)
        case <-time.After(time.Second):
                t.Fatal("hijack call not received")
@@ -705,13 +627,13 @@ func TestServerAltSvcFromListenersAndConns(t *testing.T) {
 }
 
 func testServerAltSvcFromListenersAndConns(t *testing.T, versions []quic.Version) {
-       ln1, err := quic.ListenEarly(newUDPConnLocalhost(t), testdata.GetTLSConfig(), nil)
+       ln1, err := quic.ListenEarly(newUDPConnLocalhost(t), getTLSConfig(), nil)
        require.NoError(t, err)
        port1 := ln1.Addr().(*net.UDPAddr).Port
 
        s := &Server{
                Addr:       ":1337", // will be ignored since we're using listeners
-               TLSConfig:  testdata.GetTLSConfig(),
+               TLSConfig:  getTLSConfig(),
                QUICConfig: &quic.Config{Versions: versions},
        }
        done1 := make(chan struct{})
@@ -767,7 +689,7 @@ func TestServerAltSvcFromPort(t *testing.T) {
        _, ok := getAltSvc(s)
        require.False(t, ok)
 
-       ln, err := quic.ListenEarly(newUDPConnLocalhost(t), testdata.GetTLSConfig(), nil)
+       ln, err := quic.ListenEarly(newUDPConnLocalhost(t), getTLSConfig(), nil)
        require.NoError(t, err)
        done := make(chan struct{})
        go func() {
@@ -850,7 +772,7 @@ func TestServerListenAndServeErrors(t *testing.T) {
 }
 
 func TestServerClosing(t *testing.T) {
-       s := &Server{TLSConfig: testdata.GetTLSConfig()}
+       s := &Server{TLSConfig: getTLSConfig()}
        require.NoError(t, s.Close())
        require.NoError(t, s.Close()) // duplicate calls are ok
        require.ErrorIs(t, s.ListenAndServe(), http.ErrServerClosed)
@@ -893,91 +815,46 @@ func TestServerImmediateGracefulShutdown(t *testing.T) {
 }
 
 func TestServerGracefulShutdown(t *testing.T) {
-       s := &Server{TLSConfig: testdata.GetTLSConfig()}
-       s.init()
-
-       mockCtrl := gomock.NewController(t)
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStrChan := make(chan []byte, 1)
-       controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
-               controlStrChan <- b
-               return len(b), nil
-       }).AnyTimes()
-
-       streamChan := make(chan quic.Stream, 1)
-       testDone := make(chan struct{})
-       conn := mockquic.NewMockEarlyConnection(mockCtrl)
-       conn.EXPECT().OpenUniStream().Return(controlStr, nil)
-       conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
-               <-testDone
-               return nil, assert.AnError
-       }).MaxTimes(1)
-       conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(ctx context.Context) (quic.Stream, error) {
-               select {
-               case <-ctx.Done():
-                       return nil, ctx.Err()
-               case str, ok := <-streamChan:
-                       if !ok {
-                               return nil, assert.AnError
-                       }
-                       return str, nil
-               }
-       }).AnyTimes()
-       conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
-       conn.EXPECT().LocalAddr().AnyTimes()
-       conn.EXPECT().Context().Return(context.Background()).AnyTimes()
-
-       firstStream := mockquic.NewMockStream(mockCtrl)
-       firstStreamAccepted := make(chan struct{}, 1)
-       firstStream.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
-               firstStreamAccepted <- struct{}{}
-               <-testDone
-               return 0, assert.AnError
-       })
-       firstStream.EXPECT().StreamID().Return(quic.StreamID(1337)).AnyTimes()
-       firstStream.EXPECT().Context().Return(context.Background()).AnyTimes()
-       streamChan <- firstStream
-
-       go s.ServeQUICConn(conn)
+       requestChan := make(chan struct{}, 1)
+       s := &Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               requestChan <- struct{}{}
+       })}
 
-       var r bytes.Buffer
-       fp := &frameParser{r: &r}
+       clientConn, serverConn := getConnPair(t)
+       go s.ServeQUICConn(serverConn)
 
-       select {
-       case b := <-controlStrChan:
-               _, l, err := quicvarint.Parse(b)
-               require.NoError(t, err)
-               r.Write(b[l:])
-               f, err := fp.ParseNext()
-               require.NoError(t, err)
-               require.IsType(t, &settingsFrame{}, f)
-       case <-time.After(time.Second):
-               t.Fatal("timeout")
-       }
+       firstStream, err := clientConn.OpenStream()
+       require.NoError(t, err)
+       _, err = firstStream.Write(encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)))
+       require.NoError(t, err)
 
        select {
-       case <-firstStreamAccepted:
+       case <-requestChan:
        case <-time.After(time.Second):
                t.Fatal("timeout")
        }
 
-       time.Sleep(scaleDuration(10 * time.Millisecond))
+       ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+       defer cancel()
+       controlStr, err := clientConn.AcceptUniStream(ctx)
+       require.NoError(t, err)
+       typ, err := quicvarint.Read(quicvarint.NewReader(controlStr))
+       require.NoError(t, err)
+       require.EqualValues(t, streamTypeControlStream, typ)
+       fp := &frameParser{r: controlStr}
+       f, err := fp.ParseNext()
+       require.NoError(t, err)
+       require.IsType(t, &settingsFrame{}, f)
 
-       ctx, cancel := context.WithCancel(context.Background())
+       shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
        errChan := make(chan error)
        go func() {
-               errChan <- s.Shutdown(ctx)
+               errChan <- s.Shutdown(shutdownCtx)
        }()
 
-       select {
-       case b := <-controlStrChan:
-               r.Write(b)
-               f, err := fp.ParseNext()
-               require.NoError(t, err)
-               require.Equal(t, &goAwayFrame{StreamID: 1337 + 4}, f)
-       case <-time.After(time.Second):
-               t.Fatal("timeout")
-       }
+       f, err = fp.ParseNext()
+       require.NoError(t, err)
+       require.Equal(t, &goAwayFrame{StreamID: 4}, f)
 
        select {
        case <-errChan:
@@ -987,33 +864,15 @@ func TestServerGracefulShutdown(t *testing.T) {
 
        // all further streams are getting rejected
        for range 3 {
-               resetChan := make(chan struct{}, 2)
-               str := mockquic.NewMockStream(mockCtrl)
-               str.EXPECT().StreamID().AnyTimes()
-               str.EXPECT().Context().Return(context.Background()).AnyTimes()
-               str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestRejected)).Do(func(sec quic.StreamErrorCode) {
-                       resetChan <- struct{}{}
-               })
-               str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestRejected)).Do(func(sec quic.StreamErrorCode) {
-                       resetChan <- struct{}{}
-               })
-               streamChan <- str
-
-               for range 2 {
-                       select {
-                       case <-resetChan:
-                       case <-time.After(time.Second):
-                               t.Fatal("expected stream reset")
-                       }
-               }
+               str, err := clientConn.OpenStream()
+               require.NoError(t, err)
+               _, _ = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)))
+               expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestRejected))
+               expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeRequestRejected))
        }
 
        // cancel the context passed to Shutdown
-       conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), gomock.Any())
-       cancel()
-       firstStream.EXPECT().CancelRead(gomock.Any())
-       firstStream.EXPECT().CancelWrite(gomock.Any())
-       close(testDone)
+       shutdownCancel()
 
        select {
        case err := <-errChan: