"bytes"
"context"
"crypto/tls"
+ "errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httptest"
+ "os"
"runtime"
"testing"
"time"
- "github.com/quic-go/qpack"
"github.com/quic-go/quic-go"
- mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/testdata"
+ "github.com/quic-go/quic-go/qlog"
"github.com/quic-go/quic-go/quicvarint"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "go.uber.org/mock/gomock"
)
+func getConnPair(t *testing.T) (client, server quic.Connection) {
+ t.Helper()
+
+ ln, err := quic.Listen(
+ newUDPConnLocalhost(t),
+ getTLSConfig(),
+ &quic.Config{
+ InitialStreamReceiveWindow: uint64(protocol.MaxByteCount),
+ InitialConnectionReceiveWindow: uint64(protocol.MaxByteCount),
+ Tracer: qlog.DefaultConnectionTracer,
+ },
+ )
+ require.NoError(t, err)
+
+ cl, err := quic.Dial(context.Background(), newUDPConnLocalhost(t), ln.Addr(), getTLSClientConfig(), &quic.Config{})
+ require.NoError(t, err)
+ t.Cleanup(func() { cl.CloseWithError(0, "") })
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ conn, err := ln.Accept(ctx)
+ require.NoError(t, err)
+ t.Cleanup(func() { conn.CloseWithError(0, "") })
+ return cl, conn
+}
+
func TestConfigureTLSConfig(t *testing.T) {
t.Run("basic config", func(t *testing.T) {
conf := ConfigureTLSConfig(&tls.Config{})
testDone := make(chan struct{})
defer close(testDone)
- settingsChan := make(chan []byte)
- mockCtrl := gomock.NewController(t)
- conn := mockquic.NewMockEarlyConnection(mockCtrl)
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
- settingsChan <- b
- return len(b), nil
- })
- conn.EXPECT().OpenUniStream().Return(controlStr, nil)
- conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
- <-testDone
- return nil, assert.AnError
- }).MaxTimes(1)
- conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(ctx context.Context) (quic.Stream, error) {
- <-testDone
- return nil, assert.AnError
- }).MaxTimes(1)
- conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- conn.EXPECT().LocalAddr().AnyTimes()
- conn.EXPECT().Context().Return(context.Background()).AnyTimes()
-
- go s.handleConn(conn)
+ clientConn, serverConn := getConnPair(t)
+ go s.handleConn(serverConn)
- select {
- case b := <-settingsChan:
- typ, l, err := quicvarint.Parse(b)
- require.NoError(t, err)
- require.EqualValues(t, streamTypeControlStream, typ)
- fp := (&frameParser{r: bytes.NewReader(b[l:])})
- f, err := fp.ParseNext()
- require.NoError(t, err)
- require.IsType(t, &settingsFrame{}, f)
- settingsFrame := f.(*settingsFrame)
- // Extended CONNECT is always supported
- require.True(t, settingsFrame.ExtendedConnect)
- require.Equal(t, settingsFrame.Datagram, enableDatagrams)
- require.Equal(t, settingsFrame.Other, other)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
-}
-
-func decodeHeader(t *testing.T, r io.Reader) map[string][]string {
- fields := make(map[string][]string)
- decoder := qpack.NewDecoder(nil)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ settingsStr, err := clientConn.AcceptUniStream(ctx)
+ require.NoError(t, err)
- frame, err := (&frameParser{r: r}).ParseNext()
+ settingsStr.SetReadDeadline(time.Now().Add(time.Second))
+ b := make([]byte, 1024)
+ n, err := settingsStr.Read(b)
require.NoError(t, err)
- require.IsType(t, &headersFrame{}, frame)
- headersFrame := frame.(*headersFrame)
- data := make([]byte, headersFrame.Length)
- _, err = io.ReadFull(r, data)
+ b = b[:n]
+
+ typ, l, err := quicvarint.Parse(b)
require.NoError(t, err)
- hfs, err := decoder.DecodeFull(data)
+ require.EqualValues(t, streamTypeControlStream, typ)
+ fp := (&frameParser{r: bytes.NewReader(b[l:])})
+ f, err := fp.ParseNext()
require.NoError(t, err)
- for _, p := range hfs {
- fields[p.Name] = append(fields[p.Name], p.Value)
- }
- return fields
+ require.IsType(t, &settingsFrame{}, f)
+ settingsFrame := f.(*settingsFrame)
+ // Extended CONNECT is always supported
+ require.True(t, settingsFrame.ExtendedConnect)
+ require.Equal(t, settingsFrame.Datagram, enableDatagrams)
+ require.Equal(t, settingsFrame.Other, other)
}
func TestServerRequestHandling(t *testing.T) {
})
}
-func encodeRequest(t *testing.T, req *http.Request) io.Reader {
- var buf bytes.Buffer
- rw := newRequestWriter()
- require.NoError(t, rw.WriteRequestHeader(&buf, req, false))
- if req.Body != nil {
- body, err := io.ReadAll(req.Body)
- require.NoError(t, err)
- buf.Write((&dataFrame{Length: uint64(len(body))}).Append(nil))
- buf.Write(body)
- }
- return bytes.NewReader(buf.Bytes())
-}
-
func testServerRequestHandling(t *testing.T,
handler http.HandlerFunc,
req *http.Request,
) (responseHeaders map[string][]string, body []byte) {
- responseBuf := &bytes.Buffer{}
- mockCtrl := gomock.NewController(t)
- str := NewMockDatagramStream(mockCtrl)
- str.EXPECT().Context().Return(context.Background()).AnyTimes()
- str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
- str.EXPECT().CancelRead(gomock.Any())
- str.EXPECT().Close()
- str.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, req).Read).AnyTimes()
-
- s := &Server{
- TLSConfig: testdata.GetTLSConfig(),
- Handler: handler,
- }
+ clientConn, serverConn := getConnPair(t)
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ _, err = str.Write(encodeRequest(t, req))
+ require.NoError(t, err)
+ require.NoError(t, str.Close())
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
- qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- qconn.EXPECT().LocalAddr().AnyTimes()
- qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
- qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
- conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
+ s := &Server{Handler: handler}
+ go s.ServeQUICConn(serverConn)
- s.handleRequest(conn, str, qpack.NewDecoder(nil))
- hfs := decodeHeader(t, responseBuf)
- fp := frameParser{r: responseBuf}
+ hfs := decodeHeader(t, str)
+ fp := frameParser{r: str}
var content []byte
for {
frame, err := fp.ParseNext()
require.NoError(t, err)
require.IsType(t, &dataFrame{}, frame)
b := make([]byte, frame.(*dataFrame).Length)
- _, err = io.ReadFull(responseBuf, b)
+ _, err = io.ReadFull(str, b)
require.NoError(t, err)
content = append(content, b...)
}
return hfs, content
}
+func TestServerFirstFrameNotHeaders(t *testing.T) {
+ clientConn, serverConn := getConnPair(t)
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+
+ var buf bytes.Buffer
+ buf.Write((&dataFrame{Length: 6}).Append(nil))
+ buf.Write([]byte("foobar"))
+ _, err = str.Write(buf.Bytes())
+ require.NoError(t, err)
+ require.NoError(t, str.Close())
+
+ s := &Server{}
+ go s.ServeQUICConn(serverConn)
+
+ select {
+ case <-clientConn.Context().Done():
+ err := context.Cause(clientConn.Context())
+ var appErr *quic.ApplicationError
+ require.ErrorAs(t, err, &appErr)
+ require.Equal(t, quic.ApplicationErrorCode(ErrCodeFrameUnexpected), appErr.ErrorCode)
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+}
+
func TestServerHandlerBodyNotRead(t *testing.T) {
t.Run("GET request with a body", func(t *testing.T) {
testServerHandlerBodyNotRead(t,
})
}
-func TestServerFirstFrameNotHeaders(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- str := NewMockDatagramStream(mockCtrl)
- str.EXPECT().Write(gomock.Any()).AnyTimes()
- str.EXPECT().Context().Return(context.Background()).AnyTimes()
- var buf bytes.Buffer
- buf.Write((&dataFrame{Length: 6}).Append(nil))
- buf.Write([]byte("foobar"))
- str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
-
- s := &Server{TLSConfig: testdata.GetTLSConfig()}
-
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
- qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- qconn.EXPECT().LocalAddr().AnyTimes()
- qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
- conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
-
- s.handleRequest(conn, str, qpack.NewDecoder(nil))
-}
-
func testServerHandlerBodyNotRead(t *testing.T, req *http.Request, handler http.HandlerFunc) {
- mockCtrl := gomock.NewController(t)
- str := NewMockDatagramStream(mockCtrl)
- str.EXPECT().Write(gomock.Any()).AnyTimes()
- str.EXPECT().Context().Return(context.Background()).AnyTimes()
- str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
- str.EXPECT().Close().MaxTimes(1)
- str.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, req).Read).AnyTimes()
+ clientConn, serverConn := getConnPair(t)
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ _, err = str.Write(encodeRequest(t, req))
+ require.NoError(t, err)
+ // require.NoError(t, str.Close())
- var called bool
+ done := make(chan struct{})
s := &Server{
- TLSConfig: testdata.GetTLSConfig(),
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- called = true
+ defer close(done)
handler(w, r)
}),
}
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
- qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- qconn.EXPECT().LocalAddr().AnyTimes()
- qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
- qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
- conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
+ go s.ServeQUICConn(serverConn)
+
+ select {
+ case <-done:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+}
+
+func expectStreamReadReset(t *testing.T, str quic.ReceiveStream, errCode quic.StreamErrorCode) {
+ t.Helper()
+ str.SetReadDeadline(time.Now().Add(time.Second))
+ _, err := str.Read([]byte{0})
+ require.Error(t, err)
+ if errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Fatal("didn't receive a stream reset")
+ }
+ var strErr *quic.StreamError
+ require.ErrorAs(t, err, &strErr)
+ require.Equal(t, errCode, strErr.ErrorCode)
+}
- s.handleRequest(conn, str, qpack.NewDecoder(nil))
- require.True(t, called)
+func expectStreamWriteReset(t *testing.T, str quic.SendStream, errCode quic.StreamErrorCode) {
+ t.Helper()
+ select {
+ case <-str.Context().Done():
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+ _, err := str.Write([]byte{0})
+ require.Error(t, err)
+ var strErr *quic.StreamError
+ require.ErrorAs(t, err, &strErr)
+ require.Equal(t, errCode, strErr.ErrorCode)
}
func TestServerStreamResetByClient(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- str := NewMockDatagramStream(mockCtrl)
- done := make(chan struct{})
- str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
- str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
- str.EXPECT().Read(gomock.Any()).Return(0, assert.AnError)
+ clientConn, serverConn := getConnPair(t)
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ str.CancelWrite(1337)
var called bool
s := &Server{
- TLSConfig: testdata.GetTLSConfig(),
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
}),
}
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
- qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- qconn.EXPECT().LocalAddr().AnyTimes()
- conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
+ go s.ServeQUICConn(serverConn)
- s.handleRequest(conn, str, qpack.NewDecoder(nil))
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestIncomplete))
require.False(t, called)
}
}
func testServerPanickingHandler(t *testing.T, handler http.HandlerFunc) (logOutput string) {
- mockCtrl := gomock.NewController(t)
- str := NewMockDatagramStream(mockCtrl)
- str.EXPECT().Context().Return(context.Background()).AnyTimes()
- str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
- str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError))
- str.EXPECT().Read(gomock.Any()).DoAndReturn(
- encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)).Read,
- ).AnyTimes()
+ clientConn, serverConn := getConnPair(t)
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ _, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)))
+ require.NoError(t, err)
+ require.NoError(t, str.Close())
var logBuf bytes.Buffer
s := &Server{
- TLSConfig: testdata.GetTLSConfig(),
- Handler: handler,
- Logger: slog.New(slog.NewTextHandler(&logBuf, nil)),
+ Handler: handler,
+ Logger: slog.New(slog.NewTextHandler(&logBuf, nil)),
}
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
- qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- qconn.EXPECT().LocalAddr().AnyTimes()
- qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
- qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
- conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
+ go s.ServeQUICConn(serverConn)
+
+ expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeInternalError))
+ s.Close()
- s.handleRequest(conn, str, qpack.NewDecoder(nil))
return logBuf.String()
}
func testServerRequestHeaderTooLarge(t *testing.T, req *http.Request, maxHeaderBytes int) {
var called bool
s := &Server{
- TLSConfig: testdata.GetTLSConfig(),
MaxHeaderBytes: maxHeaderBytes,
Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true }),
}
s.init()
- done := make(chan struct{}, 2)
- mockCtrl := gomock.NewController(t)
- str := mockquic.NewMockStream(mockCtrl)
- str.EXPECT().Context().Return(context.Background()).AnyTimes()
- str.EXPECT().StreamID().AnyTimes()
- str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { done <- struct{}{} })
- str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { done <- struct{}{} })
- str.EXPECT().Read(gomock.Any()).DoAndReturn(encodeRequest(t, req).Read).AnyTimes()
+ clientConn, serverConn := getConnPair(t)
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ _, err = str.Write(encodeRequest(t, req))
+ require.NoError(t, err)
+ require.NoError(t, str.Close())
+
+ go s.ServeQUICConn(serverConn)
+
+ expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeFrameError))
+ expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeFrameError))
- testDone := make(chan struct{})
- conn := mockquic.NewMockEarlyConnection(mockCtrl)
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Write(gomock.Any())
- conn.EXPECT().OpenUniStream().Return(controlStr, nil)
- conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
- <-testDone
- return nil, assert.AnError
- }).MaxTimes(1)
- conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
- conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, assert.AnError)
- conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- conn.EXPECT().LocalAddr().AnyTimes()
- conn.EXPECT().Context().Return(context.Background()).AnyTimes()
-
- s.handleConn(conn)
- for range 2 {
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
- }
require.False(t, called)
}
func TestServerRequestContext(t *testing.T) {
- responseBuf := &bytes.Buffer{}
- mockCtrl := gomock.NewController(t)
- str := mockquic.NewMockStream(mockCtrl)
- strCtx, strCtxCancel := context.WithCancel(context.Background())
- str.EXPECT().StreamID().AnyTimes()
- str.EXPECT().Context().Return(strCtx).AnyTimes()
- str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
- str.EXPECT().CancelRead(gomock.Any())
- str.EXPECT().Close()
- str.EXPECT().Read(gomock.Any()).DoAndReturn(
- encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)).Read,
- ).AnyTimes()
+ clientConn, serverConn := getConnPair(t)
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ _, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)))
+ require.NoError(t, err)
ctxChan := make(chan context.Context, 1)
+ block := make(chan struct{})
s := &Server{
- TLSConfig: testdata.GetTLSConfig(),
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxChan <- r.Context()
+ <-block
}),
}
- s.init()
- testDone := make(chan struct{})
- defer close(testDone)
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Write(gomock.Any()).AnyTimes()
- conn := mockquic.NewMockEarlyConnection(mockCtrl)
- conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1337}).AnyTimes()
- conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 42}).AnyTimes()
- conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
- connCtx := context.WithValue(context.Background(), "connection context", "connection context value")
- conn.EXPECT().Context().Return(connCtx).AnyTimes()
- conn.EXPECT().OpenUniStream().Return(str, nil)
- conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
- <-testDone
- return nil, assert.AnError
- }).MaxTimes(1)
- conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
- conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
- <-testDone
- return nil, assert.AnError
- }).MaxTimes(1)
-
- go s.handleConn(conn)
+ go s.ServeQUICConn(serverConn)
+
var requestContext context.Context
select {
case requestContext = <-ctxChan:
t.Fatal("timeout")
}
- require.Equal(t, "connection context value", requestContext.Value("connection context"))
require.Equal(t, s, requestContext.Value(ServerContextKey))
- require.Equal(t, &net.UDPAddr{IP: net.IPv4(192, 168, 1, 2), Port: 42}, requestContext.Value(http.LocalAddrContextKey))
- require.Equal(t, &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1337}, requestContext.Value(RemoteAddrContextKey))
+ require.Equal(t, serverConn.LocalAddr(), requestContext.Value(http.LocalAddrContextKey))
+ require.Equal(t, serverConn.RemoteAddr(), requestContext.Value(RemoteAddrContextKey))
select {
case <-requestContext.Done():
t.Fatal("request context was canceled")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
- strCtxCancel()
+ str.CancelRead(1337)
+
select {
case <-requestContext.Done():
case <-time.After(time.Second):
t.Fatal("timeout")
}
require.Equal(t, context.Canceled, requestContext.Err())
+ close(block)
}
func TestServerHTTPStreamHijacking(t *testing.T) {
- responseBuf := &bytes.Buffer{}
- mockCtrl := gomock.NewController(t)
- str := NewMockDatagramStream(mockCtrl)
- str.EXPECT().Context().Return(context.Background()).AnyTimes()
- str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
- str.EXPECT().Read(gomock.Any()).DoAndReturn(
- encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)).Read,
- ).AnyTimes()
+ clientConn, serverConn := getConnPair(t)
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ _, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)))
+ require.NoError(t, err)
+ require.NoError(t, str.Close())
s := &Server{
- TLSConfig: testdata.GetTLSConfig(),
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.(HTTPStreamer).HTTPStream()
+ str := w.(HTTPStreamer).HTTPStream()
str.Write([]byte("foobar"))
+ str.Close()
}),
}
+ go s.ServeQUICConn(serverConn)
- qconn := mockquic.NewMockEarlyConnection(mockCtrl)
- qconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- qconn.EXPECT().LocalAddr().AnyTimes()
- qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
- qconn.EXPECT().Context().Return(context.Background()).AnyTimes()
- conn := newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0)
-
- s.handleRequest(conn, str, qpack.NewDecoder(nil))
- hfs := decodeHeader(t, responseBuf)
+ str.SetReadDeadline(time.Now().Add(time.Second))
+ rsp, err := io.ReadAll(str)
+ require.NoError(t, err)
+ r := bytes.NewReader(rsp)
+ hfs := decodeHeader(t, r)
require.Equal(t, hfs[":status"], []string{"200"})
- require.Equal(t, []byte("foobar"), responseBuf.Bytes())
+ fp := frameParser{r: r}
+ frame, err := fp.ParseNext()
+ require.NoError(t, err)
+ require.IsType(t, &dataFrame{}, frame)
+ dataFrame := frame.(*dataFrame)
+ require.Equal(t, uint64(6), dataFrame.Length)
+ data, err := io.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, []byte("foobar"), data)
}
func TestServerStreamHijacking(t *testing.T) {
}
func testServerHijackBidirectionalStream(t *testing.T, bidirectional bool, doHijack bool, hijackErr error) {
- id := quic.ConnectionTracingID(1337)
type hijackCall struct {
ft FrameType // for bidirectional streams
st StreamType // for unidirectional streams
hijackChan := make(chan hijackCall, 1)
testDone := make(chan struct{})
s := &Server{
- TLSConfig: testdata.GetTLSConfig(),
StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
defer close(testDone)
hijackChan <- hijackCall{ft: ft, connTracingID: connTracingID, e: e}
return doHijack
},
}
- s.init()
+
+ clientConn, serverConn := getConnPair(t)
+ go s.ServeQUICConn(serverConn)
buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
- mockCtrl := gomock.NewController(t)
- unknownStr := mockquic.NewMockStream(mockCtrl)
- unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes()
- unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
- unknownStr.EXPECT().StreamID().AnyTimes()
- if !doHijack || hijackErr != nil {
- if bidirectional {
- unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
- unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete))
- } else {
- unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
- }
- }
- conn := mockquic.NewMockEarlyConnection(mockCtrl)
if bidirectional {
- conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ _, err = str.Write(buf.Bytes())
+ require.NoError(t, err)
+
+ if hijackErr != nil {
+ expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestIncomplete))
+ expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeRequestIncomplete))
+ }
+ // if the stream is not hijacked, the frame parser will skip the frame
} else {
- conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
+ str, err := clientConn.OpenUniStream()
+ require.NoError(t, err)
+ _, err = str.Write(buf.Bytes())
+ require.NoError(t, err)
+
+ if !doHijack || hijackErr != nil {
+ expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeStreamCreationError))
+ }
}
- conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
- <-testDone
- return nil, assert.AnError
- })
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStr.EXPECT().Write(gomock.Any())
- conn.EXPECT().OpenUniStream().Return(controlStr, nil)
- conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- conn.EXPECT().LocalAddr().AnyTimes()
- conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
- <-testDone
- return nil, assert.AnError
- })
- ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
- conn.EXPECT().Context().Return(ctx).AnyTimes()
- s.handleConn(conn)
+
select {
case hijackCall := <-hijackChan:
if bidirectional {
assert.Equal(t, hijackCall.st, StreamType(0x41))
assert.Zero(t, hijackCall.ft)
}
- assert.Equal(t, hijackCall.connTracingID, id)
+ assert.Equal(t, serverConn.Context().Value(quic.ConnectionTracingKey), hijackCall.connTracingID)
assert.NoError(t, hijackCall.e)
case <-time.After(time.Second):
t.Fatal("hijack call not received")
}
func testServerAltSvcFromListenersAndConns(t *testing.T, versions []quic.Version) {
- ln1, err := quic.ListenEarly(newUDPConnLocalhost(t), testdata.GetTLSConfig(), nil)
+ ln1, err := quic.ListenEarly(newUDPConnLocalhost(t), getTLSConfig(), nil)
require.NoError(t, err)
port1 := ln1.Addr().(*net.UDPAddr).Port
s := &Server{
Addr: ":1337", // will be ignored since we're using listeners
- TLSConfig: testdata.GetTLSConfig(),
+ TLSConfig: getTLSConfig(),
QUICConfig: &quic.Config{Versions: versions},
}
done1 := make(chan struct{})
_, ok := getAltSvc(s)
require.False(t, ok)
- ln, err := quic.ListenEarly(newUDPConnLocalhost(t), testdata.GetTLSConfig(), nil)
+ ln, err := quic.ListenEarly(newUDPConnLocalhost(t), getTLSConfig(), nil)
require.NoError(t, err)
done := make(chan struct{})
go func() {
}
func TestServerClosing(t *testing.T) {
- s := &Server{TLSConfig: testdata.GetTLSConfig()}
+ s := &Server{TLSConfig: getTLSConfig()}
require.NoError(t, s.Close())
require.NoError(t, s.Close()) // duplicate calls are ok
require.ErrorIs(t, s.ListenAndServe(), http.ErrServerClosed)
}
func TestServerGracefulShutdown(t *testing.T) {
- s := &Server{TLSConfig: testdata.GetTLSConfig()}
- s.init()
-
- mockCtrl := gomock.NewController(t)
- controlStr := mockquic.NewMockStream(mockCtrl)
- controlStrChan := make(chan []byte, 1)
- controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
- controlStrChan <- b
- return len(b), nil
- }).AnyTimes()
-
- streamChan := make(chan quic.Stream, 1)
- testDone := make(chan struct{})
- conn := mockquic.NewMockEarlyConnection(mockCtrl)
- conn.EXPECT().OpenUniStream().Return(controlStr, nil)
- conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
- <-testDone
- return nil, assert.AnError
- }).MaxTimes(1)
- conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(ctx context.Context) (quic.Stream, error) {
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case str, ok := <-streamChan:
- if !ok {
- return nil, assert.AnError
- }
- return str, nil
- }
- }).AnyTimes()
- conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
- conn.EXPECT().LocalAddr().AnyTimes()
- conn.EXPECT().Context().Return(context.Background()).AnyTimes()
-
- firstStream := mockquic.NewMockStream(mockCtrl)
- firstStreamAccepted := make(chan struct{}, 1)
- firstStream.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
- firstStreamAccepted <- struct{}{}
- <-testDone
- return 0, assert.AnError
- })
- firstStream.EXPECT().StreamID().Return(quic.StreamID(1337)).AnyTimes()
- firstStream.EXPECT().Context().Return(context.Background()).AnyTimes()
- streamChan <- firstStream
-
- go s.ServeQUICConn(conn)
+ requestChan := make(chan struct{}, 1)
+ s := &Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ requestChan <- struct{}{}
+ })}
- var r bytes.Buffer
- fp := &frameParser{r: &r}
+ clientConn, serverConn := getConnPair(t)
+ go s.ServeQUICConn(serverConn)
- select {
- case b := <-controlStrChan:
- _, l, err := quicvarint.Parse(b)
- require.NoError(t, err)
- r.Write(b[l:])
- f, err := fp.ParseNext()
- require.NoError(t, err)
- require.IsType(t, &settingsFrame{}, f)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ firstStream, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ _, err = firstStream.Write(encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)))
+ require.NoError(t, err)
select {
- case <-firstStreamAccepted:
+ case <-requestChan:
case <-time.After(time.Second):
t.Fatal("timeout")
}
- time.Sleep(scaleDuration(10 * time.Millisecond))
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ controlStr, err := clientConn.AcceptUniStream(ctx)
+ require.NoError(t, err)
+ typ, err := quicvarint.Read(quicvarint.NewReader(controlStr))
+ require.NoError(t, err)
+ require.EqualValues(t, streamTypeControlStream, typ)
+ fp := &frameParser{r: controlStr}
+ f, err := fp.ParseNext()
+ require.NoError(t, err)
+ require.IsType(t, &settingsFrame{}, f)
- ctx, cancel := context.WithCancel(context.Background())
+ shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
errChan := make(chan error)
go func() {
- errChan <- s.Shutdown(ctx)
+ errChan <- s.Shutdown(shutdownCtx)
}()
- select {
- case b := <-controlStrChan:
- r.Write(b)
- f, err := fp.ParseNext()
- require.NoError(t, err)
- require.Equal(t, &goAwayFrame{StreamID: 1337 + 4}, f)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ f, err = fp.ParseNext()
+ require.NoError(t, err)
+ require.Equal(t, &goAwayFrame{StreamID: 4}, f)
select {
case <-errChan:
// all further streams are getting rejected
for range 3 {
- resetChan := make(chan struct{}, 2)
- str := mockquic.NewMockStream(mockCtrl)
- str.EXPECT().StreamID().AnyTimes()
- str.EXPECT().Context().Return(context.Background()).AnyTimes()
- str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestRejected)).Do(func(sec quic.StreamErrorCode) {
- resetChan <- struct{}{}
- })
- str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestRejected)).Do(func(sec quic.StreamErrorCode) {
- resetChan <- struct{}{}
- })
- streamChan <- str
-
- for range 2 {
- select {
- case <-resetChan:
- case <-time.After(time.Second):
- t.Fatal("expected stream reset")
- }
- }
+ str, err := clientConn.OpenStream()
+ require.NoError(t, err)
+ _, _ = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)))
+ expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestRejected))
+ expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeRequestRejected))
}
// cancel the context passed to Shutdown
- conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), gomock.Any())
- cancel()
- firstStream.EXPECT().CancelRead(gomock.Any())
- firstStream.EXPECT().CancelWrite(gomock.Any())
- close(testDone)
+ shutdownCancel()
select {
case err := <-errChan: