]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
http3: use actual QUIC connection and stream in conn tests (#5163)
authorMarten Seemann <martenseemann@gmail.com>
Thu, 29 May 2025 12:42:36 +0000 (20:42 +0800)
committerGitHub <noreply@github.com>
Thu, 29 May 2025 12:42:36 +0000 (14:42 +0200)
* http3: use actual QUIC connection and stream in conn tests

* add errror assertion

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
http3/conn_test.go
http3/http3_helper_test.go

index 5096c3a9cea68c579705700c526af4ad913fa1f5..34f9fdf8a0da7205d08144120a32422f00b140ec 100644 (file)
@@ -3,29 +3,23 @@ package http3
 import (
        "bytes"
        "context"
-       "errors"
        "io"
        "testing"
        "time"
 
        "github.com/quic-go/quic-go"
-       mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
        "github.com/quic-go/quic-go/internal/protocol"
-       "github.com/quic-go/quic-go/internal/qerr"
        "github.com/quic-go/quic-go/quicvarint"
 
-       "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
-       "go.uber.org/mock/gomock"
 )
 
 func TestConnReceiveSettings(t *testing.T) {
-       mockCtrl := gomock.NewController(t)
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
-       qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(nil, errors.New("no datagrams")).MaxTimes(1)
+       clientConn, serverConn := newConnPair(t)
+
        conn := newConnection(
-               context.Background(),
-               qconn,
+               serverConn.Context(),
+               serverConn,
                false,
                protocol.PerspectiveServer,
                nil,
@@ -37,11 +31,11 @@ func TestConnReceiveSettings(t *testing.T) {
                ExtendedConnect: true,
                Other:           map[uint64]uint64{1337: 42},
        }).Append(b)
-       r := bytes.NewReader(b)
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
+       controlStr, err := clientConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = controlStr.Write(b)
+       require.NoError(t, err)
+
        done := make(chan struct{})
        go func() {
                defer close(done)
@@ -71,11 +65,11 @@ func TestConnRejectDuplicateStreams(t *testing.T) {
 }
 
 func testConnRejectDuplicateStreams(t *testing.T, typ uint64) {
-       mockCtrl := gomock.NewController(t)
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+       clientConn, serverConn := newConnPair(t)
+
        conn := newConnection(
                context.Background(),
-               qconn,
+               serverConn,
                false,
                protocol.PerspectiveServer,
                nil,
@@ -85,25 +79,26 @@ func testConnRejectDuplicateStreams(t *testing.T, typ uint64) {
        if typ == streamTypeControlStream {
                b = (&settingsFrame{}).Append(b)
        }
-       controlStr1 := mockquic.NewMockStream(mockCtrl)
-       controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(bytes.NewReader(b).Read).AnyTimes()
-       controlStr2 := mockquic.NewMockStream(mockCtrl)
-       controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(bytes.NewReader(b).Read).AnyTimes()
+       controlStr1, err := clientConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = controlStr1.Write(b)
+       require.NoError(t, err)
+       controlStr2, err := clientConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = controlStr2.Write(b)
+       require.NoError(t, err)
+
        done := make(chan struct{})
-       closed := make(chan struct{})
-       qconn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error {
-               close(closed)
-               return nil
-       })
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr1, nil)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr2, nil)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
        go func() {
                defer close(done)
                conn.handleUnidirectionalStreams(nil)
        }()
        select {
-       case <-closed:
+       case <-clientConn.Context().Done():
+               require.ErrorIs(t,
+                       context.Cause(clientConn.Context()),
+                       &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeStreamCreationError)},
+               )
        case <-time.After(time.Second):
                t.Fatal("timeout waiting for duplicate stream")
        }
@@ -115,38 +110,25 @@ func testConnRejectDuplicateStreams(t *testing.T, typ uint64) {
 }
 
 func TestConnResetUnknownUniStream(t *testing.T) {
-       mockCtrl := gomock.NewController(t)
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+       clientConn, serverConn := newConnPair(t)
+
        conn := newConnection(
                context.Background(),
-               qconn,
+               serverConn,
                false,
                protocol.PerspectiveServer,
                nil,
                0,
        )
        buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))
-       str := mockquic.NewMockStream(mockCtrl)
-       str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
-       reset := make(chan struct{})
-       str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(reset) })
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(str, nil)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
-       done := make(chan struct{})
-       go func() {
-               defer close(done)
-               conn.handleUnidirectionalStreams(nil)
-       }()
-       select {
-       case <-reset:
-       case <-time.After(time.Second):
-               t.Fatal("timeout waiting for reset")
-       }
-       select {
-       case <-done:
-       case <-time.After(time.Second):
-               t.Fatal("timeout")
-       }
+       str, err := clientConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = str.Write(buf.Bytes())
+       require.NoError(t, err)
+
+       go conn.handleUnidirectionalStreams(nil)
+
+       expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeStreamCreationError))
 }
 
 func TestConnControlStreamFailures(t *testing.T) {
@@ -210,47 +192,48 @@ func TestConnGoAwayFailures(t *testing.T) {
 }
 
 func testConnControlStreamFailures(t *testing.T, data []byte, readErr error, expectedErr ErrCode) {
-       mockCtrl := gomock.NewController(t)
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+       clientConn, serverConn := newConnPair(t)
+
        conn := newConnection(
-               context.Background(),
-               qconn,
+               clientConn.Context(),
+               clientConn,
                false,
                protocol.PerspectiveClient,
                nil,
                0,
        )
-       b := quicvarint.Append(nil, streamTypeControlStream)
-       b = append(b, data...)
-       r := bytes.NewReader(b)
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
-               if r.Len() == 0 {
-                       return 0, readErr
-               }
-               return r.Read(b)
-       }).AnyTimes()
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
-       closed := make(chan struct{})
-
-       str := mockquic.NewMockStream(mockCtrl)
-       str.EXPECT().StreamID().Return(4).AnyTimes()
-       str.EXPECT().Context().Return(context.Background()).AnyTimes()
-       qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(str, nil)
+       controlStr, err := serverConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = controlStr.Write(quicvarint.Append(nil, streamTypeControlStream))
+       require.NoError(t, err)
+
+       switch readErr {
+       case nil:
+               _, err = controlStr.Write(data)
+               require.NoError(t, err)
+       case io.EOF:
+               _, err = controlStr.Write(data)
+               require.NoError(t, err)
+               require.NoError(t, controlStr.Close())
+       default:
+               // make sure the stream type is received
+               time.Sleep(scaleDuration(10 * time.Millisecond))
+               controlStr.CancelWrite(1337)
+       }
+
        conn.openRequestStream(context.Background(), nil, nil, true, 1000)
 
-       qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(expectedErr), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
-               close(closed)
-               return nil
-       })
        done := make(chan struct{})
        go func() {
                defer close(done)
                conn.handleUnidirectionalStreams(nil)
        }()
        select {
-       case <-closed:
+       case <-serverConn.Context().Done():
+               require.ErrorIs(t,
+                       context.Cause(serverConn.Context()),
+                       &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(expectedErr)},
+               )
        case <-time.After(time.Second):
                t.Fatal("timeout waiting for close")
        }
@@ -271,11 +254,11 @@ func TestConnGoAway(t *testing.T) {
 }
 
 func testConnGoAway(t *testing.T, withStream bool) {
-       mockCtrl := gomock.NewController(t)
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+       clientConn, serverConn := newConnPair(t)
+
        conn := newConnection(
-               context.Background(),
-               qconn,
+               clientConn.Context(),
+               clientConn,
                false,
                protocol.PerspectiveClient,
                nil,
@@ -285,72 +268,48 @@ func testConnGoAway(t *testing.T, withStream bool) {
        b = (&settingsFrame{}).Append(b)
        b = (&goAwayFrame{StreamID: 8}).Append(b)
 
-       var mockStr *mockquic.MockStream
        var str quic.Stream
        if withStream {
-               mockStr = mockquic.NewMockStream(mockCtrl)
-               mockStr.EXPECT().StreamID().Return(0).AnyTimes()
-               mockStr.EXPECT().Context().Return(context.Background()).AnyTimes()
-               qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(mockStr, nil)
                s, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
                require.NoError(t, err)
                str = s
        }
 
-       done := make(chan struct{})
-       defer close(done)
-       r := bytes.NewReader(b)
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
-               if r.Len() == 0 {
-                       <-done
-                       return 0, errors.New("test done")
-               }
-               return r.Read(b)
-       }).AnyTimes()
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
-       closed := make(chan struct{})
-       qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
-               close(closed)
-               return nil
-       })
-       // duplicate calls to CloseWithError are a no-op
-       qconn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).AnyTimes()
+       controlStr, err := serverConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = controlStr.Write(b)
+       require.NoError(t, err)
+
        go conn.handleUnidirectionalStreams(nil)
 
        // the connection should be closed after the stream is closed
        if withStream {
                select {
-               case <-closed:
+               case <-serverConn.Context().Done():
                        t.Fatal("connection closed")
                case <-time.After(scaleDuration(10 * time.Millisecond)):
                }
 
                // The stream ID in the GOAWAY frame is 8, so it's possible to open stream 4.
-               mockStr2 := mockquic.NewMockStream(mockCtrl)
-               mockStr2.EXPECT().StreamID().Return(4).AnyTimes()
-               mockStr2.EXPECT().Context().Return(context.Background()).AnyTimes()
-               qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(mockStr2, nil)
                str2, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
                require.NoError(t, err)
-               mockStr2.EXPECT().Close()
                str2.Close()
-               mockStr2.EXPECT().CancelRead(gomock.Any())
                str2.CancelRead(1337)
 
                // It's not possible to open stream 8.
                _, err = conn.openRequestStream(context.Background(), nil, nil, true, 1000)
                require.ErrorIs(t, err, errGoAway)
 
-               mockStr.EXPECT().Close()
                str.Close()
-               mockStr.EXPECT().CancelRead(gomock.Any())
                str.CancelRead(1337)
        }
 
        select {
-       case <-closed:
+       case <-serverConn.Context().Done():
+               require.ErrorIs(t,
+                       context.Cause(serverConn.Context()),
+                       &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)},
+               )
        case <-time.After(time.Second):
                t.Fatal("timeout waiting for close")
        }
@@ -366,33 +325,33 @@ func TestConnRejectPushStream(t *testing.T) {
 }
 
 func testConnRejectPushStream(t *testing.T, pers protocol.Perspective, expectedErr ErrCode) {
-       mockCtrl := gomock.NewController(t)
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+       clientConn, serverConn := newConnPair(t)
+
        conn := newConnection(
-               context.Background(),
-               qconn,
+               clientConn.Context(),
+               clientConn,
                false,
                pers.Opposite(),
                nil,
                0,
        )
        buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
-       closed := make(chan struct{})
-       qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(expectedErr), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
-               close(closed)
-               return nil
-       })
+       controlStr, err := serverConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = controlStr.Write(buf.Bytes())
+       require.NoError(t, err)
+
        done := make(chan struct{})
        go func() {
                defer close(done)
                conn.handleUnidirectionalStreams(nil)
        }()
        select {
-       case <-closed:
+       case <-serverConn.Context().Done():
+               require.ErrorIs(t,
+                       context.Cause(serverConn.Context()),
+                       &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(expectedErr)},
+               )
        case <-time.After(time.Second):
                t.Fatal("timeout waiting for close")
        }
@@ -404,11 +363,11 @@ func testConnRejectPushStream(t *testing.T, pers protocol.Perspective, expectedE
 }
 
 func TestConnInconsistentDatagramSupport(t *testing.T) {
-       mockCtrl := gomock.NewController(t)
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+       clientConn, serverConn := newConnPair(t)
+
        conn := newConnection(
-               context.Background(),
-               qconn,
+               clientConn.Context(),
+               clientConn,
                true,
                protocol.PerspectiveClient,
                nil,
@@ -416,39 +375,29 @@ func TestConnInconsistentDatagramSupport(t *testing.T) {
        )
        b := quicvarint.Append(nil, streamTypeControlStream)
        b = (&settingsFrame{Datagram: true}).Append(b)
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(bytes.NewReader(b).Read).AnyTimes()
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
-       qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
-       closed := make(chan struct{})
-       qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error {
-               close(closed)
-               return nil
-       })
-       done := make(chan struct{})
-       go func() {
-               defer close(done)
-               conn.handleUnidirectionalStreams(nil)
-       }()
+       controlStr, err := serverConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = controlStr.Write(b)
+       require.NoError(t, err)
+
+       go conn.handleUnidirectionalStreams(nil)
+
        select {
-       case <-closed:
+       case <-serverConn.Context().Done():
+               err := context.Cause(serverConn.Context())
+               require.ErrorIs(t, err, &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeSettingsError)})
+               require.ErrorContains(t, err, "missing QUIC Datagram support")
        case <-time.After(time.Second):
                t.Fatal("timeout waiting for close")
        }
-       select {
-       case <-done:
-       case <-time.After(time.Second):
-               t.Fatal("timeout")
-       }
 }
 
 func TestConnSendAndReceiveDatagram(t *testing.T) {
-       mockCtrl := gomock.NewController(t)
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+       clientConn, serverConn := newConnPairWithDatagrams(t)
+
        conn := newConnection(
-               context.Background(),
-               qconn,
+               clientConn.Context(),
+               clientConn,
                true,
                protocol.PerspectiveClient,
                nil,
@@ -456,53 +405,33 @@ func TestConnSendAndReceiveDatagram(t *testing.T) {
        )
        b := quicvarint.Append(nil, streamTypeControlStream)
        b = (&settingsFrame{Datagram: true}).Append(b)
-       r := bytes.NewReader(b)
-       done := make(chan struct{})
-       defer close(done)
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
-               if r.Len() == 0 {
-                       <-done
-               }
-               return r.Read(b)
-       }).AnyTimes()
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil).MaxTimes(1)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done")).MaxTimes(1)
-       qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: true}).MaxTimes(1)
+       controlStr, err := serverConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = controlStr.Write(b)
+       require.NoError(t, err)
+
+       go conn.handleUnidirectionalStreams(nil)
 
        const strID = 4
 
        // first deliver a datagram...
        // since the stream is not open yet, it will be dropped
        quarterStreamID := quicvarint.Append([]byte{}, strID/4)
-       delivered := make(chan struct{})
-       qconn.EXPECT().ReceiveDatagram(gomock.Any()).DoAndReturn(func(context.Context) ([]byte, error) {
-               close(delivered)
-               return append(quarterStreamID, []byte("foo")...), nil
-       })
-       streamOpened := make(chan struct{})
-       qconn.EXPECT().ReceiveDatagram(gomock.Any()).DoAndReturn(func(context.Context) ([]byte, error) {
-               <-streamOpened
-               return append(quarterStreamID, []byte("bar")...), nil
-       }).MaxTimes(1)
-       qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(nil, errors.New("test done")).MaxTimes(1)
-       go func() { conn.handleUnidirectionalStreams(nil) }()
-       select {
-       case <-delivered:
-       case <-time.After(time.Second):
-               t.Fatal("timeout waiting for datagram delivery")
-       }
 
-       // now open the stream...
-       qstr := mockquic.NewMockStream(mockCtrl)
-       qstr.EXPECT().StreamID().Return(strID).MinTimes(1)
-       qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
-       qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(qstr, nil)
+       require.NoError(t, serverConn.SendDatagram(append(quarterStreamID, []byte("foo")...)))
+       time.Sleep(scaleDuration(10 * time.Millisecond)) // give the datagram a chance to be delivered
+
+       // don't use stream 0, since that makes it hard to test that the quarter stream ID is used
+       str1, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
+       require.NoError(t, err)
+       str1.Close()
+
        str, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
        require.NoError(t, err)
+       require.Equal(t, protocol.StreamID(strID), str.StreamID())
 
-       // ... then deliver another datagram
-       close(streamOpened)
+       // now open the stream...
+       require.NoError(t, serverConn.SendDatagram(append(quarterStreamID, []byte("bar")...)))
 
        ctx, cancel := context.WithTimeout(context.Background(), time.Second)
        defer cancel()
@@ -511,13 +440,14 @@ func TestConnSendAndReceiveDatagram(t *testing.T) {
        require.Equal(t, []byte("bar"), data)
 
        // now send a datagram
-       const strID2 = 404
-       expected := quicvarint.Append([]byte{}, strID2/4)
+       str.SendDatagram([]byte("foobaz"))
+
+       expected := quicvarint.Append([]byte{}, strID/4)
        expected = append(expected, []byte("foobaz")...)
-       qconn.EXPECT().SendDatagram(expected).Return(assert.AnError)
-       require.ErrorIs(t, conn.sendDatagram(strID2, []byte("foobaz")), assert.AnError)
 
-       qconn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).AnyTimes()
+       data, err = serverConn.ReceiveDatagram(ctx)
+       require.NoError(t, err)
+       require.Equal(t, expected, data)
 }
 
 func TestConnDatagramFailures(t *testing.T) {
@@ -530,11 +460,11 @@ func TestConnDatagramFailures(t *testing.T) {
 }
 
 func testConnDatagramFailures(t *testing.T, datagram []byte) {
-       mockCtrl := gomock.NewController(t)
-       qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+       clientConn, serverConn := newConnPairWithDatagrams(t)
+
        conn := newConnection(
-               context.Background(),
-               qconn,
+               clientConn.Context(),
+               clientConn,
                true,
                protocol.PerspectiveClient,
                nil,
@@ -542,28 +472,20 @@ func testConnDatagramFailures(t *testing.T, datagram []byte) {
        )
        b := quicvarint.Append(nil, streamTypeControlStream)
        b = (&settingsFrame{Datagram: true}).Append(b)
-       r := bytes.NewReader(b)
-       done := make(chan struct{})
-       controlStr := mockquic.NewMockStream(mockCtrl)
-       controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
-               if r.Len() == 0 {
-                       <-done
-               }
-               return r.Read(b)
-       }).AnyTimes()
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil).MaxTimes(1)
-       qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done")).MaxTimes(1)
-       qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: true}).MaxTimes(1)
-
-       qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(datagram, nil)
-       qconn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeDatagramError), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error {
-               close(done)
-               return nil
-       })
-       qconn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).AnyTimes() // further calls to CloseWithError are a no-op
+       controlStr, err := serverConn.OpenUniStream()
+       require.NoError(t, err)
+       _, err = controlStr.Write(b)
+       require.NoError(t, err)
+
+       require.NoError(t, serverConn.SendDatagram(datagram))
+
        go func() { conn.handleUnidirectionalStreams(nil) }()
        select {
-       case <-done:
+       case <-serverConn.Context().Done():
+               require.ErrorIs(t,
+                       context.Cause(serverConn.Context()),
+                       &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeDatagramError)},
+               )
        case <-time.After(time.Second):
                t.Fatal("timeout waiting for close")
        }
index cf504e45ed883dfd429819dc07448bb25f68a897..ebf8b815481d4a275e852b2d12d233dcf1a12c89 100644 (file)
@@ -85,12 +85,38 @@ func newConnPair(t *testing.T) (client, server quic.Connection) {
        )
        require.NoError(t, err)
 
-       cl, err := quic.Dial(context.Background(), newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), &quic.Config{})
+       ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+       defer cancel()
+       cl, err := quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), &quic.Config{})
        require.NoError(t, err)
        t.Cleanup(func() { cl.CloseWithError(0, "") })
 
+       conn, err := ln.Accept(ctx)
+       require.NoError(t, err)
+       t.Cleanup(func() { conn.CloseWithError(0, "") })
+       return cl, conn
+}
+
+func newConnPairWithDatagrams(t *testing.T) (client, server quic.Connection) {
+       t.Helper()
+
+       ln, err := quic.Listen(
+               newUDPConnLocalhost(t),
+               getTLSConfig(),
+               &quic.Config{
+                       InitialStreamReceiveWindow:     uint64(protocol.MaxByteCount),
+                       InitialConnectionReceiveWindow: uint64(protocol.MaxByteCount),
+                       EnableDatagrams:                true,
+               },
+       )
+       require.NoError(t, err)
+
        ctx, cancel := context.WithTimeout(context.Background(), time.Second)
        defer cancel()
+       cl, err := quic.Dial(ctx, newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), &quic.Config{EnableDatagrams: true})
+       require.NoError(t, err)
+       t.Cleanup(func() { cl.CloseWithError(0, "") })
+
        conn, err := ln.Accept(ctx)
        require.NoError(t, err)
        t.Cleanup(func() { conn.CloseWithError(0, "") })