import (
"bytes"
"context"
- "errors"
"io"
"testing"
"time"
"github.com/quic-go/quic-go"
- mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
"github.com/quic-go/quic-go/internal/protocol"
- "github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/quicvarint"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "go.uber.org/mock/gomock"
)
func TestConnReceiveSettings(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
- qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(nil, errors.New("no datagrams")).MaxTimes(1)
+ clientConn, serverConn := newConnPair(t)
+
conn := newConnection(
- context.Background(),
- qconn,
+ serverConn.Context(),
+ serverConn,
false,
protocol.PerspectiveServer,
nil,
ExtendedConnect: true,
Other: map[uint64]uint64{1337: 42},
}).Append(b)
- r := bytes.NewReader(b)
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
+ controlStr, err := clientConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = controlStr.Write(b)
+ require.NoError(t, err)
+
done := make(chan struct{})
go func() {
defer close(done)
}
func testConnRejectDuplicateStreams(t *testing.T, typ uint64) {
- mockCtrl := gomock.NewController(t)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+ clientConn, serverConn := newConnPair(t)
+
conn := newConnection(
context.Background(),
- qconn,
+ serverConn,
false,
protocol.PerspectiveServer,
nil,
if typ == streamTypeControlStream {
b = (&settingsFrame{}).Append(b)
}
- controlStr1 := mockquic.NewMockStream(mockCtrl)
- controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(bytes.NewReader(b).Read).AnyTimes()
- controlStr2 := mockquic.NewMockStream(mockCtrl)
- controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(bytes.NewReader(b).Read).AnyTimes()
+ controlStr1, err := clientConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = controlStr1.Write(b)
+ require.NoError(t, err)
+ controlStr2, err := clientConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = controlStr2.Write(b)
+ require.NoError(t, err)
+
done := make(chan struct{})
- closed := make(chan struct{})
- qconn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error {
- close(closed)
- return nil
- })
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr1, nil)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr2, nil)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
go func() {
defer close(done)
conn.handleUnidirectionalStreams(nil)
}()
select {
- case <-closed:
+ case <-clientConn.Context().Done():
+ require.ErrorIs(t,
+ context.Cause(clientConn.Context()),
+ &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeStreamCreationError)},
+ )
case <-time.After(time.Second):
t.Fatal("timeout waiting for duplicate stream")
}
}
func TestConnResetUnknownUniStream(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+ clientConn, serverConn := newConnPair(t)
+
conn := newConnection(
context.Background(),
- qconn,
+ serverConn,
false,
protocol.PerspectiveServer,
nil,
0,
)
buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))
- str := mockquic.NewMockStream(mockCtrl)
- str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
- reset := make(chan struct{})
- str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(reset) })
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(str, nil)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
- done := make(chan struct{})
- go func() {
- defer close(done)
- conn.handleUnidirectionalStreams(nil)
- }()
- select {
- case <-reset:
- case <-time.After(time.Second):
- t.Fatal("timeout waiting for reset")
- }
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ str, err := clientConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = str.Write(buf.Bytes())
+ require.NoError(t, err)
+
+ go conn.handleUnidirectionalStreams(nil)
+
+ expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeStreamCreationError))
}
func TestConnControlStreamFailures(t *testing.T) {
}
func testConnControlStreamFailures(t *testing.T, data []byte, readErr error, expectedErr ErrCode) {
- mockCtrl := gomock.NewController(t)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+ clientConn, serverConn := newConnPair(t)
+
conn := newConnection(
- context.Background(),
- qconn,
+ clientConn.Context(),
+ clientConn,
false,
protocol.PerspectiveClient,
nil,
0,
)
- b := quicvarint.Append(nil, streamTypeControlStream)
- b = append(b, data...)
- r := bytes.NewReader(b)
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
- if r.Len() == 0 {
- return 0, readErr
- }
- return r.Read(b)
- }).AnyTimes()
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
- closed := make(chan struct{})
-
- str := mockquic.NewMockStream(mockCtrl)
- str.EXPECT().StreamID().Return(4).AnyTimes()
- str.EXPECT().Context().Return(context.Background()).AnyTimes()
- qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(str, nil)
+ controlStr, err := serverConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = controlStr.Write(quicvarint.Append(nil, streamTypeControlStream))
+ require.NoError(t, err)
+
+ switch readErr {
+ case nil:
+ _, err = controlStr.Write(data)
+ require.NoError(t, err)
+ case io.EOF:
+ _, err = controlStr.Write(data)
+ require.NoError(t, err)
+ require.NoError(t, controlStr.Close())
+ default:
+ // make sure the stream type is received
+ time.Sleep(scaleDuration(10 * time.Millisecond))
+ controlStr.CancelWrite(1337)
+ }
+
conn.openRequestStream(context.Background(), nil, nil, true, 1000)
- qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(expectedErr), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
- close(closed)
- return nil
- })
done := make(chan struct{})
go func() {
defer close(done)
conn.handleUnidirectionalStreams(nil)
}()
select {
- case <-closed:
+ case <-serverConn.Context().Done():
+ require.ErrorIs(t,
+ context.Cause(serverConn.Context()),
+ &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(expectedErr)},
+ )
case <-time.After(time.Second):
t.Fatal("timeout waiting for close")
}
}
func testConnGoAway(t *testing.T, withStream bool) {
- mockCtrl := gomock.NewController(t)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+ clientConn, serverConn := newConnPair(t)
+
conn := newConnection(
- context.Background(),
- qconn,
+ clientConn.Context(),
+ clientConn,
false,
protocol.PerspectiveClient,
nil,
b = (&settingsFrame{}).Append(b)
b = (&goAwayFrame{StreamID: 8}).Append(b)
- var mockStr *mockquic.MockStream
var str quic.Stream
if withStream {
- mockStr = mockquic.NewMockStream(mockCtrl)
- mockStr.EXPECT().StreamID().Return(0).AnyTimes()
- mockStr.EXPECT().Context().Return(context.Background()).AnyTimes()
- qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(mockStr, nil)
s, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
require.NoError(t, err)
str = s
}
- done := make(chan struct{})
- defer close(done)
- r := bytes.NewReader(b)
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
- if r.Len() == 0 {
- <-done
- return 0, errors.New("test done")
- }
- return r.Read(b)
- }).AnyTimes()
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
- closed := make(chan struct{})
- qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
- close(closed)
- return nil
- })
- // duplicate calls to CloseWithError are a no-op
- qconn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).AnyTimes()
+ controlStr, err := serverConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = controlStr.Write(b)
+ require.NoError(t, err)
+
go conn.handleUnidirectionalStreams(nil)
// the connection should be closed after the stream is closed
if withStream {
select {
- case <-closed:
+ case <-serverConn.Context().Done():
t.Fatal("connection closed")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
// The stream ID in the GOAWAY frame is 8, so it's possible to open stream 4.
- mockStr2 := mockquic.NewMockStream(mockCtrl)
- mockStr2.EXPECT().StreamID().Return(4).AnyTimes()
- mockStr2.EXPECT().Context().Return(context.Background()).AnyTimes()
- qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(mockStr2, nil)
str2, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
require.NoError(t, err)
- mockStr2.EXPECT().Close()
str2.Close()
- mockStr2.EXPECT().CancelRead(gomock.Any())
str2.CancelRead(1337)
// It's not possible to open stream 8.
_, err = conn.openRequestStream(context.Background(), nil, nil, true, 1000)
require.ErrorIs(t, err, errGoAway)
- mockStr.EXPECT().Close()
str.Close()
- mockStr.EXPECT().CancelRead(gomock.Any())
str.CancelRead(1337)
}
select {
- case <-closed:
+ case <-serverConn.Context().Done():
+ require.ErrorIs(t,
+ context.Cause(serverConn.Context()),
+ &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)},
+ )
case <-time.After(time.Second):
t.Fatal("timeout waiting for close")
}
}
func testConnRejectPushStream(t *testing.T, pers protocol.Perspective, expectedErr ErrCode) {
- mockCtrl := gomock.NewController(t)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+ clientConn, serverConn := newConnPair(t)
+
conn := newConnection(
- context.Background(),
- qconn,
+ clientConn.Context(),
+ clientConn,
false,
pers.Opposite(),
nil,
0,
)
buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
- closed := make(chan struct{})
- qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(expectedErr), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
- close(closed)
- return nil
- })
+ controlStr, err := serverConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = controlStr.Write(buf.Bytes())
+ require.NoError(t, err)
+
done := make(chan struct{})
go func() {
defer close(done)
conn.handleUnidirectionalStreams(nil)
}()
select {
- case <-closed:
+ case <-serverConn.Context().Done():
+ require.ErrorIs(t,
+ context.Cause(serverConn.Context()),
+ &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(expectedErr)},
+ )
case <-time.After(time.Second):
t.Fatal("timeout waiting for close")
}
}
func TestConnInconsistentDatagramSupport(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+ clientConn, serverConn := newConnPair(t)
+
conn := newConnection(
- context.Background(),
- qconn,
+ clientConn.Context(),
+ clientConn,
true,
protocol.PerspectiveClient,
nil,
)
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{Datagram: true}).Append(b)
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(bytes.NewReader(b).Read).AnyTimes()
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
- qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
- closed := make(chan struct{})
- qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error {
- close(closed)
- return nil
- })
- done := make(chan struct{})
- go func() {
- defer close(done)
- conn.handleUnidirectionalStreams(nil)
- }()
+ controlStr, err := serverConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = controlStr.Write(b)
+ require.NoError(t, err)
+
+ go conn.handleUnidirectionalStreams(nil)
+
select {
- case <-closed:
+ case <-serverConn.Context().Done():
+ err := context.Cause(serverConn.Context())
+ require.ErrorIs(t, err, &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeSettingsError)})
+ require.ErrorContains(t, err, "missing QUIC Datagram support")
case <-time.After(time.Second):
t.Fatal("timeout waiting for close")
}
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
}
func TestConnSendAndReceiveDatagram(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+ clientConn, serverConn := newConnPairWithDatagrams(t)
+
conn := newConnection(
- context.Background(),
- qconn,
+ clientConn.Context(),
+ clientConn,
true,
protocol.PerspectiveClient,
nil,
)
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{Datagram: true}).Append(b)
- r := bytes.NewReader(b)
- done := make(chan struct{})
- defer close(done)
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
- if r.Len() == 0 {
- <-done
- }
- return r.Read(b)
- }).AnyTimes()
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil).MaxTimes(1)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done")).MaxTimes(1)
- qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: true}).MaxTimes(1)
+ controlStr, err := serverConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = controlStr.Write(b)
+ require.NoError(t, err)
+
+ go conn.handleUnidirectionalStreams(nil)
const strID = 4
// first deliver a datagram...
// since the stream is not open yet, it will be dropped
quarterStreamID := quicvarint.Append([]byte{}, strID/4)
- delivered := make(chan struct{})
- qconn.EXPECT().ReceiveDatagram(gomock.Any()).DoAndReturn(func(context.Context) ([]byte, error) {
- close(delivered)
- return append(quarterStreamID, []byte("foo")...), nil
- })
- streamOpened := make(chan struct{})
- qconn.EXPECT().ReceiveDatagram(gomock.Any()).DoAndReturn(func(context.Context) ([]byte, error) {
- <-streamOpened
- return append(quarterStreamID, []byte("bar")...), nil
- }).MaxTimes(1)
- qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(nil, errors.New("test done")).MaxTimes(1)
- go func() { conn.handleUnidirectionalStreams(nil) }()
- select {
- case <-delivered:
- case <-time.After(time.Second):
- t.Fatal("timeout waiting for datagram delivery")
- }
- // now open the stream...
- qstr := mockquic.NewMockStream(mockCtrl)
- qstr.EXPECT().StreamID().Return(strID).MinTimes(1)
- qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
- qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(qstr, nil)
+ require.NoError(t, serverConn.SendDatagram(append(quarterStreamID, []byte("foo")...)))
+ time.Sleep(scaleDuration(10 * time.Millisecond)) // give the datagram a chance to be delivered
+
+ // don't use stream 0, since that makes it hard to test that the quarter stream ID is used
+ str1, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
+ require.NoError(t, err)
+ str1.Close()
+
str, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
require.NoError(t, err)
+ require.Equal(t, protocol.StreamID(strID), str.StreamID())
- // ... then deliver another datagram
- close(streamOpened)
+ // now open the stream...
+ require.NoError(t, serverConn.SendDatagram(append(quarterStreamID, []byte("bar")...)))
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
require.Equal(t, []byte("bar"), data)
// now send a datagram
- const strID2 = 404
- expected := quicvarint.Append([]byte{}, strID2/4)
+ str.SendDatagram([]byte("foobaz"))
+
+ expected := quicvarint.Append([]byte{}, strID/4)
expected = append(expected, []byte("foobaz")...)
- qconn.EXPECT().SendDatagram(expected).Return(assert.AnError)
- require.ErrorIs(t, conn.sendDatagram(strID2, []byte("foobaz")), assert.AnError)
- qconn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).AnyTimes()
+ data, err = serverConn.ReceiveDatagram(ctx)
+ require.NoError(t, err)
+ require.Equal(t, expected, data)
}
func TestConnDatagramFailures(t *testing.T) {
}
func testConnDatagramFailures(t *testing.T, datagram []byte) {
- mockCtrl := gomock.NewController(t)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
+ clientConn, serverConn := newConnPairWithDatagrams(t)
+
conn := newConnection(
- context.Background(),
- qconn,
+ clientConn.Context(),
+ clientConn,
true,
protocol.PerspectiveClient,
nil,
)
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{Datagram: true}).Append(b)
- r := bytes.NewReader(b)
- done := make(chan struct{})
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
- if r.Len() == 0 {
- <-done
- }
- return r.Read(b)
- }).AnyTimes()
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil).MaxTimes(1)
- qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done")).MaxTimes(1)
- qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: true}).MaxTimes(1)
-
- qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(datagram, nil)
- qconn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeDatagramError), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error {
- close(done)
- return nil
- })
- qconn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).AnyTimes() // further calls to CloseWithError are a no-op
+ controlStr, err := serverConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = controlStr.Write(b)
+ require.NoError(t, err)
+
+ require.NoError(t, serverConn.SendDatagram(datagram))
+
go func() { conn.handleUnidirectionalStreams(nil) }()
select {
- case <-done:
+ case <-serverConn.Context().Done():
+ require.ErrorIs(t,
+ context.Cause(serverConn.Context()),
+ &quic.ApplicationError{Remote: true, ErrorCode: quic.ApplicationErrorCode(ErrCodeDatagramError)},
+ )
case <-time.After(time.Second):
t.Fatal("timeout waiting for close")
}