"math"
"net"
"runtime"
+ "strings"
"sync/atomic"
"syscall"
"testing"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
+ "github.com/quic-go/quic-go/internal/synctest"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/qlog"
"github.com/quic-go/quic-go/qlogwriter"
"github.com/quic-go/quic-go/testutils/events"
+ "github.com/quic-go/quic-go/testutils/simnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
func (h *mockPacketHandler) closeWithTransportError(code qerr.TransportErrorCode) {}
+func newSimnetLink(t *testing.T, rtt time.Duration) (client, server net.PacketConn, close func()) {
+ t.Helper()
+
+ n := &simnet.Simnet{Router: &simnet.PerfectRouter{}}
+ settings := simnet.NodeBiDiLinkSettings{
+ Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt / 1024, Latency: rtt / 4},
+ Uplink: simnet.LinkSettings{BitsPerSecond: math.MaxInt / 1024, Latency: rtt / 4},
+ }
+
+ client = n.NewEndpoint(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9001}, settings)
+ server = n.NewEndpoint(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9002}, settings)
+ require.NoError(t, n.Start())
+ return client, server, func() {
+ require.NoError(t, n.Close())
+ }
+}
+
func TestTransportPacketHandling(t *testing.T) {
tr := &Transport{Conn: newUDPConnLocalhost(t)}
tr.init(true)
func TestTransportErrFromConn(t *testing.T) {
t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
- readErrChan := make(chan error, 2)
- tr := Transport{
- Conn: &mockPacketConn{
- readErrs: readErrChan,
- localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234},
- },
- }
- defer tr.Close()
- tr.init(true)
+ synctest.Test(t, func(t *testing.T) {
+ readErrChan := make(chan error, 2)
+ tr := Transport{
+ Conn: &mockPacketConn{
+ readErrs: readErrChan,
+ localAddr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234},
+ },
+ }
+ defer tr.Close()
+ tr.init(true)
+
+ errChan := make(chan error, 1)
+ ph := &mockPacketHandler{destruction: errChan}
+ (*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph)
+
+ // temporary errors don't lead to a shutdown...
+ var tempErr deadlineError
+ require.True(t, tempErr.Temporary())
+ readErrChan <- tempErr
+ // don't expect any calls to phm.Close
+ synctest.Wait()
+
+ // ...but non-temporary errors do
+ readErrChan <- errors.New("read failed")
+ synctest.Wait()
+
+ select {
+ case err := <-errChan:
+ require.ErrorIs(t, err, ErrTransportClosed)
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
- errChan := make(chan error, 1)
- ph := &mockPacketHandler{destruction: errChan}
- (*packetHandlerMap)(&tr).Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), ph)
-
- // temporary errors don't lead to a shutdown...
- var tempErr deadlineError
- require.True(t, tempErr.Temporary())
- readErrChan <- tempErr
- // don't expect any calls to phm.Close
- time.Sleep(scaleDuration(10 * time.Millisecond))
-
- // ...but non-temporary errors do
- readErrChan <- errors.New("read failed")
- select {
- case err := <-errChan:
+ _, err := tr.Listen(&tls.Config{}, nil)
require.ErrorIs(t, err, ErrTransportClosed)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
-
- _, err := tr.Listen(&tls.Config{}, nil)
- require.ErrorIs(t, err, ErrTransportClosed)
+ })
}
func TestTransportStatelessResetReceiving(t *testing.T) {
}
func TestTransportStatelessResetSending(t *testing.T) {
- var eventRecorder events.Recorder
- tr := &Transport{
- Conn: newUDPConnLocalhost(t),
- ConnectionIDLength: 4,
- StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
- Tracer: &eventRecorder,
- }
- tr.init(true)
- defer tr.Close()
+ synctest.Test(t, func(t *testing.T) {
+ const rtt = 10 * time.Millisecond
+ clientConn, serverConn, closeFn := newSimnetLink(t, rtt)
+ defer closeFn()
+
+ var eventRecorder events.Recorder
+ tr := &Transport{
+ Conn: serverConn,
+ ConnectionIDLength: 4,
+ StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
+ Tracer: &eventRecorder,
+ }
+ tr.init(true)
+ defer tr.Close()
- connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12})
+ connID := protocol.ParseConnectionID([]byte{9, 10, 11, 12})
- // now send a packet with a connection ID that doesn't exist
- b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne)
- require.NoError(t, err)
+ // now send a packet with a connection ID that doesn't exist
+ b, err := wire.AppendShortHeader(nil, connID, 1337, 2, protocol.KeyPhaseOne)
+ require.NoError(t, err)
- conn := newUDPConnLocalhost(t)
+ // no stateless reset sent for packets smaller than MinStatelessResetSize
+ smallPacket := append(b, make([]byte, protocol.MinStatelessResetSize-len(b))...)
+ _, err = clientConn.WriteTo(smallPacket, tr.Conn.LocalAddr())
+ require.NoError(t, err)
- // no stateless reset sent for packets smaller than MinStatelessResetSize
- smallPacket := append(b, make([]byte, protocol.MinStatelessResetSize-len(b))...)
- _, err = conn.WriteTo(smallPacket, tr.Conn.LocalAddr())
- require.NoError(t, err)
- require.Eventually(t,
- func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 },
- time.Second,
- 10*time.Millisecond,
- )
- require.Equal(t,
- []qlogwriter.Event{
- qlog.PacketDropped{
- Header: qlog.PacketHeader{PacketType: qlog.PacketType1RTT},
- Raw: qlog.RawInfo{Length: len(smallPacket)},
- Trigger: qlog.PacketDropUnknownConnectionID,
+ time.Sleep(rtt) // so that the packet arrives at the server
+
+ require.Equal(t,
+ []qlogwriter.Event{
+ qlog.PacketDropped{
+ Header: qlog.PacketHeader{PacketType: qlog.PacketType1RTT},
+ Raw: qlog.RawInfo{Length: len(smallPacket)},
+ Trigger: qlog.PacketDropUnknownConnectionID,
+ },
},
- },
- eventRecorder.Events(qlog.PacketDropped{}),
- )
+ eventRecorder.Events(qlog.PacketDropped{}),
+ )
- // but a stateless reset is sent for packets larger than MinStatelessResetSize
- _, err = conn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr())
- require.NoError(t, err)
- conn.SetReadDeadline(time.Now().Add(time.Second))
- p := make([]byte, 1024)
- n, addr, err := conn.ReadFrom(p)
- require.NoError(t, err)
- require.Equal(t, addr, tr.Conn.LocalAddr())
- srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID)
- require.Contains(t, string(p[:n]), string(srt[:]))
+ // but a stateless reset is sent for packets larger than MinStatelessResetSize
+ _, err = clientConn.WriteTo(append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...), tr.Conn.LocalAddr())
+ require.NoError(t, err)
+ clientConn.SetReadDeadline(time.Now().Add(time.Second))
+ p := make([]byte, 1024)
+ n, addr, err := clientConn.ReadFrom(p)
+ require.NoError(t, err)
+ require.Equal(t, addr, tr.Conn.LocalAddr())
+ srt := newStatelessResetter(tr.StatelessResetKey).GetStatelessResetToken(connID)
+ require.Contains(t, string(p[:n]), string(srt[:]))
+ })
}
func TestTransportUnparseableQUICPackets(t *testing.T) {
- var eventRecorder events.Recorder
- tr := &Transport{
- Conn: newUDPConnLocalhost(t),
- ConnectionIDLength: 10,
- Tracer: &eventRecorder,
- }
- require.NoError(t, tr.init(true))
- defer tr.Close()
+ synctest.Test(t, func(t *testing.T) {
+ const rtt = 10 * time.Millisecond
+ clientConn, serverConn, closeFn := newSimnetLink(t, rtt)
+ defer closeFn()
+
+ var eventRecorder events.Recorder
+ tr := &Transport{
+ Conn: serverConn,
+ ConnectionIDLength: 10,
+ Tracer: &eventRecorder,
+ }
+ require.NoError(t, tr.init(true))
+ defer tr.Close()
- conn := newUDPConnLocalhost(t)
- _, err := conn.WriteTo([]byte{0x40 /* set the QUIC bit */, 1, 2, 3}, tr.Conn.LocalAddr())
- require.NoError(t, err)
+ _, err := clientConn.WriteTo([]byte{0x40 /* set the QUIC bit */, 1, 2, 3}, tr.Conn.LocalAddr())
+ require.NoError(t, err)
- require.Eventually(t,
- func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 },
- time.Second,
- 10*time.Millisecond,
- )
- require.Equal(t,
- []qlogwriter.Event{
- qlog.PacketDropped{
- Raw: qlog.RawInfo{Length: 4},
- Trigger: qlog.PacketDropHeaderParseError,
+ time.Sleep(rtt) // so that the packet arrives at the server
+
+ require.Equal(t,
+ []qlogwriter.Event{
+ qlog.PacketDropped{
+ Raw: qlog.RawInfo{Length: 4},
+ Trigger: qlog.PacketDropHeaderParseError,
+ },
},
- },
- eventRecorder.Events(qlog.PacketDropped{}),
- )
+ eventRecorder.Events(qlog.PacketDropped{}),
+ )
+ })
}
func TestTransportListening(t *testing.T) {
- var eventRecorder events.Recorder
- tr := &Transport{
- Conn: newUDPConnLocalhost(t),
- ConnectionIDLength: 5,
- Tracer: &eventRecorder,
- }
- require.NoError(t, tr.init(true))
- defer tr.Close()
+ synctest.Test(t, func(t *testing.T) {
+ const rtt = 10 * time.Millisecond
+ clientConn, serverConn, closeFn := newSimnetLink(t, rtt)
+ defer closeFn()
+
+ var eventRecorder events.Recorder
+ tr := &Transport{
+ Conn: serverConn,
+ ConnectionIDLength: 5,
+ Tracer: &eventRecorder,
+ }
+ require.NoError(t, tr.init(true))
+ defer tr.Close()
- conn := newUDPConnLocalhost(t)
- data := wire.ComposeVersionNegotiation([]byte{1, 2, 3, 4, 5}, []byte{6, 7, 8, 9, 10}, []protocol.Version{protocol.Version1})
+ data := wire.ComposeVersionNegotiation([]byte{1, 2, 3, 4, 5}, []byte{6, 7, 8, 9, 10}, []protocol.Version{protocol.Version1})
- _, err := conn.WriteTo(data, tr.Conn.LocalAddr())
- require.NoError(t, err)
- require.Eventually(t,
- func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 },
- time.Second,
- 10*time.Millisecond,
- )
- require.Equal(t,
- []qlogwriter.Event{
- qlog.PacketDropped{
- Raw: qlog.RawInfo{Length: len(data)},
- Trigger: qlog.PacketDropUnknownConnectionID,
+ _, err := clientConn.WriteTo(data, tr.Conn.LocalAddr())
+ require.NoError(t, err)
+
+ time.Sleep(rtt) // so that the packet arrives at the server
+
+ require.Equal(t,
+ []qlogwriter.Event{
+ qlog.PacketDropped{
+ Raw: qlog.RawInfo{Length: len(data)},
+ Trigger: qlog.PacketDropUnknownConnectionID,
+ },
},
- },
- eventRecorder.Events(qlog.PacketDropped{}),
- )
- eventRecorder.Clear()
+ eventRecorder.Events(qlog.PacketDropped{}),
+ )
+ eventRecorder.Clear()
- ln, err := tr.Listen(&tls.Config{}, nil)
- require.NoError(t, err)
+ ln, err := tr.Listen(&tls.Config{}, nil)
+ require.NoError(t, err)
- _, err = conn.WriteTo(data, tr.Conn.LocalAddr())
- require.NoError(t, err)
- require.Eventually(t,
- func() bool { return len(eventRecorder.Events(qlog.PacketDropped{})) > 0 },
- time.Second,
- 10*time.Millisecond,
- )
- require.Equal(t,
- []qlogwriter.Event{
- qlog.PacketDropped{
- Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation},
- Raw: qlog.RawInfo{Length: len(data)},
- Trigger: qlog.PacketDropUnexpectedPacket,
+ _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr())
+ require.NoError(t, err)
+ time.Sleep(rtt) // so that the packet arrives at the server
+
+ require.Equal(t,
+ []qlogwriter.Event{
+ qlog.PacketDropped{
+ Header: qlog.PacketHeader{PacketType: qlog.PacketTypeVersionNegotiation},
+ Raw: qlog.RawInfo{Length: len(data)},
+ Trigger: qlog.PacketDropUnexpectedPacket,
+ },
},
- },
- eventRecorder.Events(qlog.PacketDropped{}),
- )
+ eventRecorder.Events(qlog.PacketDropped{}),
+ )
- // only a single listener can be set
- _, err = tr.Listen(&tls.Config{}, nil)
- require.Error(t, err)
- require.ErrorIs(t, err, errListenerAlreadySet)
+ // only a single listener can be set
+ _, err = tr.Listen(&tls.Config{}, nil)
+ require.Error(t, err)
+ require.ErrorIs(t, err, errListenerAlreadySet)
- require.NoError(t, ln.Close())
- // now it's possible to add a new listener
- ln, err = tr.Listen(&tls.Config{}, nil)
- require.NoError(t, err)
- defer ln.Close()
+ require.NoError(t, ln.Close())
+ // now it's possible to add a new listener
+ ln, err = tr.Listen(&tls.Config{}, nil)
+ require.NoError(t, err)
+ defer ln.Close()
+ })
}
func TestTransportNonQUICPackets(t *testing.T) {
- tr := &Transport{Conn: newUDPConnLocalhost(t)}
- defer tr.Close()
-
- ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(5*time.Millisecond))
- defer cancel()
- _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 1024))
- require.Error(t, err)
- require.ErrorIs(t, err, context.DeadlineExceeded)
+ synctest.Test(t, func(t *testing.T) {
+ const rtt = 10 * time.Millisecond
+ clientConn, serverConn, closeFn := newSimnetLink(t, rtt)
+ defer closeFn()
- conn := newUDPConnLocalhost(t)
- data := []byte{0 /* don't set the QUIC bit */, 1, 2, 3}
- _, err = conn.WriteTo(data, tr.Conn.LocalAddr())
- require.NoError(t, err)
- _, err = conn.WriteTo(data, tr.Conn.LocalAddr())
- require.NoError(t, err)
+ tr := &Transport{Conn: serverConn}
+ defer tr.Close()
- ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(time.Second))
- defer cancel()
- b := make([]byte, 1024)
- n, addr, err := tr.ReadNonQUICPacket(ctx, b)
- require.NoError(t, err)
- require.Equal(t, data, b[:n])
- require.Equal(t, addr, conn.LocalAddr())
+ ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+ defer cancel()
+ _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 1024))
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.DeadlineExceeded)
- // now send a lot of packets without reading them
- for i := range 2 * maxQueuedNonQUICPackets {
- data := append([]byte{0 /* don't set the QUIC bit */, uint8(i)}, bytes.Repeat([]byte{uint8(i)}, 1000)...)
- _, err = conn.WriteTo(data, tr.Conn.LocalAddr())
+ data := []byte{0 /* don't set the QUIC bit */, 1, 2, 3}
+ _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr())
+ require.NoError(t, err)
+ _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr())
require.NoError(t, err)
- }
- time.Sleep(scaleDuration(10 * time.Millisecond))
- var received int
- for {
- ctx, cancel = context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond))
+ ctx, cancel = context.WithTimeout(context.Background(), time.Second)
defer cancel()
- _, _, err := tr.ReadNonQUICPacket(ctx, b)
- if errors.Is(err, context.DeadlineExceeded) {
- break
- }
+ b := make([]byte, 1024)
+ n, addr, err := tr.ReadNonQUICPacket(ctx, b)
require.NoError(t, err)
- received++
- }
- require.Equal(t, received, maxQueuedNonQUICPackets)
+ require.Equal(t, data, b[:n])
+ require.Equal(t, addr, clientConn.LocalAddr())
+
+ // now send a lot of packets without reading them
+ for i := range 2 * maxQueuedNonQUICPackets {
+ data := append([]byte{0 /* don't set the QUIC bit */, uint8(i)}, bytes.Repeat([]byte{uint8(i)}, 1000)...)
+ _, err = clientConn.WriteTo(data, tr.Conn.LocalAddr())
+ require.NoError(t, err)
+ }
+
+ time.Sleep(rtt) // so that all packets arrive at the server
+
+ var received int
+ for {
+ ctx, cancel = context.WithTimeout(context.Background(), 20*time.Millisecond)
+ defer cancel()
+ _, _, err := tr.ReadNonQUICPacket(ctx, b)
+ if errors.Is(err, context.DeadlineExceeded) {
+ break
+ }
+ require.NoError(t, err)
+ received++
+ }
+ require.Equal(t, received, maxQueuedNonQUICPackets)
+ })
}
type faultySyscallConn struct{ net.PacketConn }
originalClientConnConstructor := newClientConnection
t.Cleanup(func() { newClientConnection = originalClientConnConstructor })
- var conn *connTestHooks
- handshakeChan := make(chan struct{})
- blockRun := make(chan struct{})
- if early {
- conn = &connTestHooks{
- earlyConnReady: func() <-chan struct{} { return handshakeChan },
- handshakeComplete: func() <-chan struct{} { return make(chan struct{}) },
+ synctest.Test(t, func(t *testing.T) {
+ _, serverConn, closeFn := newSimnetLink(t, 10*time.Millisecond)
+ defer closeFn()
+
+ var conn *connTestHooks
+ handshakeChan := make(chan struct{})
+ blockRun := make(chan struct{})
+ if early {
+ conn = &connTestHooks{
+ earlyConnReady: func() <-chan struct{} { return handshakeChan },
+ handshakeComplete: func() <-chan struct{} { return make(chan struct{}) },
+ }
+ } else {
+ conn = &connTestHooks{
+ handshakeComplete: func() <-chan struct{} { return handshakeChan },
+ }
}
- } else {
- conn = &connTestHooks{
- handshakeComplete: func() <-chan struct{} { return handshakeChan },
+ conn.run = func() error { <-blockRun; return errors.New("done") }
+ defer close(blockRun)
+
+ newClientConnection = func(
+ _ context.Context,
+ _ sendConn,
+ _ connRunner,
+ _ protocol.ConnectionID,
+ _ protocol.ConnectionID,
+ _ ConnectionIDGenerator,
+ _ *statelessResetter,
+ _ *Config,
+ _ *tls.Config,
+ _ protocol.PacketNumber,
+ _ bool,
+ _ bool,
+ _ qlogwriter.Trace,
+ _ utils.Logger,
+ _ protocol.Version,
+ ) *wrappedConn {
+ return &wrappedConn{testHooks: conn}
}
- }
- conn.run = func() error { <-blockRun; return errors.New("done") }
- defer close(blockRun)
- newClientConnection = func(
- _ context.Context,
- _ sendConn,
- _ connRunner,
- _ protocol.ConnectionID,
- _ protocol.ConnectionID,
- _ ConnectionIDGenerator,
- _ *statelessResetter,
- _ *Config,
- _ *tls.Config,
- _ protocol.PacketNumber,
- _ bool,
- _ bool,
- _ qlogwriter.Trace,
- _ utils.Logger,
- _ protocol.Version,
- ) *wrappedConn {
- return &wrappedConn{testHooks: conn}
- }
+ tr := &Transport{Conn: serverConn}
+ tr.init(true)
+ defer tr.Close()
+
+ errChan := make(chan error, 1)
+ go func() {
+ var err error
+ if early {
+ _, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil)
+ } else {
+ _, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil)
+ }
+ errChan <- err
+ }()
- tr := &Transport{Conn: newUDPConnLocalhost(t)}
- tr.init(true)
- defer tr.Close()
+ synctest.Wait()
- errChan := make(chan error, 1)
- go func() {
- var err error
- if early {
- _, err = tr.DialEarly(context.Background(), nil, &tls.Config{}, nil)
- } else {
- _, err = tr.Dial(context.Background(), nil, &tls.Config{}, nil)
+ select {
+ case <-errChan:
+ t.Fatal("Dial shouldn't have returned")
+ default:
}
- errChan <- err
- }()
- select {
- case <-errChan:
- t.Fatal("Dial shouldn't have returned")
- case <-time.After(scaleDuration(10 * time.Millisecond)):
- }
+ close(handshakeChan)
- close(handshakeChan)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(time.Second):
- }
+ synctest.Wait()
+
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ }
+ })
}
func TestTransportDialingVersionNegotiation(t *testing.T) {
}
func TestTransportReplaceWithClosed(t *testing.T) {
- t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
-
+ // synctest works slightly differently on Go 1.24,
+ // so we skip the test
+ if strings.HasPrefix(runtime.Version(), "go1.24") {
+ t.Skip("skipping on Go 1.24 due to synctest issues")
+ }
t.Run("local", func(t *testing.T) {
testTransportReplaceWithClosed(t, true)
})
}
func testTransportReplaceWithClosed(t *testing.T, local bool) {
- srk := StatelessResetKey{1, 2, 3, 4}
- tr := &Transport{
- Conn: newUDPConnLocalhost(t),
- ConnectionIDLength: 4,
- StatelessResetKey: &srk,
- }
- tr.init(true)
- defer tr.Close()
-
- dur := scaleDuration(20 * time.Millisecond)
-
- var closePacket []byte
- if local {
- closePacket = []byte("foobar")
- }
+ synctest.Test(t, func(t *testing.T) {
+ clientConn, serverConn, closeFn := newSimnetLink(t, 10*time.Millisecond)
+ defer closeFn()
+
+ srk := StatelessResetKey{1, 2, 3, 4}
+ tr := &Transport{
+ Conn: serverConn,
+ ConnectionIDLength: 4,
+ StatelessResetKey: &srk,
+ }
+ tr.init(true)
+ defer tr.Close()
- handler := &mockPacketHandler{}
- connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
- m := (*packetHandlerMap)(tr)
- require.True(t, m.Add(connID, handler))
- m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, dur)
+ var closePacket []byte
+ if local {
+ closePacket = []byte("foobar")
+ }
- p := make([]byte, 100)
- p[0] = 0x40 // QUIC bit
- copy(p[1:], connID.Bytes())
+ const expiry = 50 * time.Millisecond
+ handler := &mockPacketHandler{}
+ connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
+ m := (*packetHandlerMap)(tr)
+ require.True(t, m.Add(connID, handler))
+ m.ReplaceWithClosed([]protocol.ConnectionID{connID}, closePacket, expiry)
+
+ p := make([]byte, 100)
+ p[0] = 0x40 // QUIC bit
+ copy(p[1:], connID.Bytes())
+
+ var sent atomic.Int64
+ errChan := make(chan error, 1)
+ stopSending := make(chan struct{})
+ go func() {
+ defer close(errChan)
+ ticker := time.NewTicker(expiry / 200)
+ timeout := time.NewTimer(time.Second)
+ for {
+ select {
+ case <-stopSending:
+ return
+ case <-timeout.C:
+ errChan <- errors.New("timeout")
+ return
+ case <-ticker.C:
+ }
+ if _, err := clientConn.WriteTo(p, tr.Conn.LocalAddr()); err != nil {
+ errChan <- err
+ return
+ }
+ sent.Add(1)
+ }
+ }()
- conn := newUDPConnLocalhost(t)
- var sent atomic.Int64
- errChan := make(chan error, 1)
- stopSending := make(chan struct{})
- go func() {
- defer close(errChan)
- ticker := time.NewTicker(dur / 50)
- timeout := time.NewTimer(scaleDuration(time.Second))
+ // For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff
+ var received int
+ clientConn.SetReadDeadline(time.Now().Add(time.Hour))
for {
- select {
- case <-stopSending:
- return
- case <-timeout.C:
- errChan <- errors.New("timeout")
- return
- case <-ticker.C:
- }
- if _, err := conn.WriteTo(p, tr.Conn.LocalAddr()); err != nil {
- errChan <- err
- return
+ b := make([]byte, 100)
+ n, _, err := clientConn.ReadFrom(b)
+ require.NoError(t, err)
+ // at some point, the connection is cleaned up, and we'll receive a stateless reset
+ if !bytes.Equal(b[:n], []byte("foobar")) {
+ require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize)
+ close(stopSending) // stop sending packets
+ break
}
- sent.Add(1)
+ received++
}
- }()
-
- // For locally closed connections, CONNECTION_CLOSE packets are sent with an exponential backoff
- var received int
- conn.SetReadDeadline(time.Now().Add(scaleDuration(time.Second)))
- for {
- b := make([]byte, 100)
- n, _, err := conn.ReadFrom(b)
- require.NoError(t, err)
- // at some point, the connection is cleaned up, and we'll receive a stateless reset
- if !bytes.Equal(b[:n], []byte("foobar")) {
- require.GreaterOrEqual(t, n, protocol.MinStatelessResetSize)
- close(stopSending) // stop sending packets
- break
- }
- received++
- }
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
- numSent := sent.Load()
- if !local {
- require.Zero(t, received)
- t.Logf("sent %d packets", numSent)
- return
- }
- t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received)
- // timer resolution on Windows is terrible
- if runtime.GOOS != "windows" {
- require.GreaterOrEqual(t, numSent, int64(8))
- }
- require.GreaterOrEqual(t, received, int(math.Floor(math.Log2(float64(numSent)))))
- require.LessOrEqual(t, received, int(math.Ceil(math.Log2(float64(numSent)))))
+ numSent := sent.Load()
+ if !local {
+ require.Zero(t, received)
+ t.Logf("sent %d packets", numSent)
+ return
+ }
+ t.Logf("sent %d packets, received %d CONNECTION_CLOSE copies", numSent, received)
+ require.Equal(t, int(math.Ceil(math.Log2(float64(numSent)))), received)
+ })
}