import (
"bytes"
"context"
+ "math"
mrand "math/rand/v2"
"net"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
- quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
+ "github.com/quic-go/quic-go/internal/synctest"
"github.com/quic-go/quic-go/internal/wire"
+ "github.com/quic-go/quic-go/testutils/simnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDatagramLoss(t *testing.T) {
- const rtt = 10 * time.Millisecond
- const numDatagrams = 100
- const datagramSize = 500
-
- server, err := quic.Listen(
- newUDPConnLocalhost(t),
- getTLSConfig(),
- getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
- )
- require.NoError(t, err)
- defer server.Close()
+ synctest.Test(t, func(t *testing.T) {
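+ // The whole test runs inside a synctest bubble: timers, sleeps and deadlines use a
+ // virtual clock (assuming the internal synctest wrapper delegates to Go's
+ // testing/synctest), so the multi-second send schedule below completes instantly in real time.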
+ const rtt = 100 * time.Millisecond
+ const numDatagrams = 100
+ const datagramSize = 500
+
+ clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}
+ serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}
+ var droppedToClient, droppedToServer, total atomic.Int32
+ n := &simnet.Simnet{
+ Router: &directionAwareDroppingRouter{
+ ClientAddr: clientAddr,
+ ServerAddr: serverAddr,
+ Drop: func(d direction, p simnet.Packet) bool {
+ if wire.IsLongHeaderPacket(p.Data[0]) { // don't drop Long Header packets
+ return false
+ }
+ if len(p.Data) < datagramSize { // don't drop ACK-only packets
+ return false
+ }
+ total.Add(1)
+ // drop about 20% of Short Header packets with DATAGRAM frames
+ if mrand.Int()%5 == 0 {
+ switch d {
+ case directionToClient:
+ droppedToClient.Add(1)
+ case directionToServer:
+ droppedToServer.Add(1)
+ }
+ return true
+ }
+ return false
+ },
+ },
+ }
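+ // Each simulated link adds rtt/4 of one-way latency; a packet traverses the
+ // sender's uplink and the receiver's downlink (rtt/2 one way, assuming uplink
+ // and downlink latencies add per hop), so a full round trip comes to rtt.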
+ settings := simnet.NodeBiDiLinkSettings{
+ Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
+ Uplink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
+ }
+ clientPacketConn := n.NewEndpoint(clientAddr, settings)
+ defer clientPacketConn.Close()
+ serverPacketConn := n.NewEndpoint(serverAddr, settings)
+ defer serverPacketConn.Close()
+ require.NoError(t, n.Start())
+ defer n.Close()
+
+ server, err := quic.Listen(
+ serverPacketConn,
+ getTLSConfig(),
+ getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
+ )
+ require.NoError(t, err)
+ defer server.Close()
+
+ const sendInterval = time.Second // send a datagram every second
+ ctx, cancel := context.WithTimeout(context.Background(), (numDatagrams+10)*sendInterval)
+ defer cancel()
+ clientConn, err := quic.Dial(
+ ctx,
+ clientPacketConn,
+ serverPacketConn.LocalAddr(),
+ getTLSClientConfig(),
+ getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
+ )
+ require.NoError(t, err)
+ defer clientConn.CloseWithError(0, "")
- var droppedIncoming, droppedOutgoing, total atomic.Int32
- proxy := &quicproxy.Proxy{
- Conn: newUDPConnLocalhost(t),
- ServerAddr: server.Addr().(*net.UDPAddr),
- DropPacket: func(dir quicproxy.Direction, _, _ net.Addr, packet []byte) bool {
- if wire.IsLongHeaderPacket(packet[0]) { // don't drop Long Header packets
- return false
- }
- if len(packet) < datagramSize { // don't drop ACK-only packets
- return false
- }
- total.Add(1)
- // drop about 20% of Short Header packets with DATAGRAM frames
- if mrand.Int()%5 == 0 {
- switch dir {
- case quicproxy.DirectionIncoming:
- droppedIncoming.Add(1)
- case quicproxy.DirectionOutgoing:
- droppedOutgoing.Add(1)
+ serverConn, err := server.Accept(ctx)
+ require.NoError(t, err)
+ defer serverConn.CloseWithError(0, "")
+
+ var clientDatagrams, serverDatagrams int
+ clientErrChan := make(chan error, 1)
+ go func() {
+ defer close(clientErrChan)
+ for {
+ if _, err := clientConn.ReceiveDatagram(ctx); err != nil {
+ clientErrChan <- err
+ return
}
- return true
+ clientDatagrams++
}
- return false
- },
- DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
- }
- require.NoError(t, proxy.Start())
- defer proxy.Close()
-
- // SendDatagram blocks when the queue is full (maxDatagramSendQueueLen),
- // add some extra margin for the handshake, networking and ACKs.
- ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(4*numDatagrams*time.Millisecond))
- defer cancel()
- clientConn, err := quic.Dial(
- ctx,
- newUDPConnLocalhost(t),
- proxy.LocalAddr(),
- getTLSClientConfig(),
- getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
- )
- require.NoError(t, err)
- defer clientConn.CloseWithError(0, "")
+ }()
- serverConn, err := server.Accept(ctx)
- require.NoError(t, err)
- defer serverConn.CloseWithError(0, "")
-
- var clientDatagrams, serverDatagrams int
- clientErrChan := make(chan error, 1)
- go func() {
- defer close(clientErrChan)
- for {
- if _, err := clientConn.ReceiveDatagram(ctx); err != nil {
- clientErrChan <- err
- return
- }
- clientDatagrams++
+ for i := range numDatagrams {
+ payload := bytes.Repeat([]byte{uint8(i)}, datagramSize)
+ require.NoError(t, clientConn.SendDatagram(payload))
+ require.NoError(t, serverConn.SendDatagram(payload))
+ time.Sleep(sendInterval)
}
- }()
- for i := range numDatagrams {
- payload := bytes.Repeat([]byte{uint8(i)}, datagramSize)
- require.NoError(t, clientConn.SendDatagram(payload))
- require.NoError(t, serverConn.SendDatagram(payload))
- time.Sleep(scaleDuration(time.Millisecond / 2))
- }
-
- serverErrChan := make(chan error, 1)
- go func() {
- defer close(serverErrChan)
- for {
- if _, err := serverConn.ReceiveDatagram(ctx); err != nil {
- serverErrChan <- err
- return
+ serverErrChan := make(chan error, 1)
+ go func() {
+ defer close(serverErrChan)
+ for {
+ if _, err := serverConn.ReceiveDatagram(ctx); err != nil {
+ serverErrChan <- err
+ return
+ }
+ serverDatagrams++
}
- serverDatagrams++
- }
- }()
+ }()
- select {
- case err := <-clientErrChan:
- require.ErrorIs(t, err, context.DeadlineExceeded)
- case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
- t.Fatal("timeout")
- }
- select {
- case err := <-serverErrChan:
- require.ErrorIs(t, err, context.DeadlineExceeded)
- case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
- t.Fatal("timeout")
- }
+ select {
+ case err := <-clientErrChan:
+ require.ErrorIs(t, err, context.DeadlineExceeded)
+ case <-time.After(5 * numDatagrams * sendInterval):
+ t.Fatal("timeout")
+ }
+ select {
+ case err := <-serverErrChan:
+ require.ErrorIs(t, err, context.DeadlineExceeded)
+ case <-time.After(5 * numDatagrams * sendInterval):
+ t.Fatal("timeout")
+ }
- numDroppedIncoming := droppedIncoming.Load()
- numDroppedOutgoing := droppedOutgoing.Load()
- t.Logf("dropped %d incoming and %d outgoing out of %d packets", numDroppedIncoming, numDroppedOutgoing, total.Load())
- assert.NotZero(t, numDroppedIncoming)
- assert.NotZero(t, numDroppedOutgoing)
- t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams)
- assert.EqualValues(t, numDatagrams-numDroppedIncoming, serverDatagrams, "datagrams received by the server")
- t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams)
- assert.EqualValues(t, numDatagrams-numDroppedOutgoing, clientDatagrams, "datagrams received by the client")
+ numDroppedToClient := droppedToClient.Load()
+ numDroppedToServer := droppedToServer.Load()
+ t.Logf("dropped %d to client and %d to server out of %d packets", numDroppedToClient, numDroppedToServer, total.Load())
+ assert.NotZero(t, numDroppedToClient)
+ assert.NotZero(t, numDroppedToServer)
+ t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams)
+ assert.EqualValues(t, numDatagrams-numDroppedToServer, serverDatagrams, "datagrams received by the server")
+ t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams)
+ assert.EqualValues(t, numDatagrams-numDroppedToClient, clientDatagrams, "datagrams received by the client")
+ })
}
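// Note: the direction type and the directionToClient / directionToServer /
// directionBoth constants used above and below are assumed to be defined
// elsewhere in this change. A minimal sketch of such a definition, with only the
// constant names taken from the code here (the string values are hypothetical):

type direction int

const (
	directionToClient direction = iota
	directionToServer
	directionBoth
)

func (d direction) String() string {
	switch d {
	case directionToClient:
		return "to client"
	case directionToServer:
		return "to server"
	case directionBoth:
		return "both"
	default:
		return "unknown"
	}
}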
}
func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool {
- var incoming, outgoing atomic.Int32
+ var toClient, toServer atomic.Int32
return func(d direction, p simnet.Packet) bool {
switch d {
- case directionIncoming:
- c := incoming.Add(1)
+ case directionToClient:
+ c := toClient.Add(1)
if d == dir || dir == directionBoth {
return slices.Contains(ns, int(c))
}
- case directionOutgoing:
- c := outgoing.Add(1)
+ case directionToServer:
+ c := toServer.Add(1)
if dir == d || dir == directionBoth {
return slices.Contains(ns, int(c))
}
func dropCallbackDropOneThird(_ direction) func(direction, simnet.Packet) bool {
const maxSequentiallyDropped = 10
var mx sync.Mutex
- var incoming, outgoing int
+ var toClient, toServer int
return func(d direction, p simnet.Packet) bool {
drop := mrand.IntN(3) == 0
mx.Lock()
defer mx.Unlock()
// never drop more than 10 consecutive packets
- if d == directionIncoming || d == directionBoth {
+ if d == directionToClient || d == directionBoth {
if drop {
- incoming++
- if incoming > maxSequentiallyDropped {
+ toClient++
+ if toClient > maxSequentiallyDropped {
drop = false
}
}
if !drop {
- incoming = 0
+ toClient = 0
}
}
- if d == directionOutgoing || d == directionBoth {
+ if d == directionToServer || d == directionBoth {
if drop {
- outgoing++
- if outgoing > maxSequentiallyDropped {
+ toServer++
+ if toServer > maxSequentiallyDropped {
drop = false
}
}
if !drop {
- outgoing = 0
+ toServer = 0
}
}
return drop
doRetry bool
}
- for _, dir := range []direction{directionIncoming, directionOutgoing, directionBoth} {
+ for _, dir := range []direction{directionToClient, directionToServer, directionBoth} {
for _, pattern := range []dropPattern{
dropPatternDrop1stPacket,
dropPatternDropFirst3Packets,
dropPatternDropOneThirdOfPackets,
} {
- t.Run(fmt.Sprintf("%s in %s direction", pattern, dir), func(t *testing.T) {
+ t.Run(fmt.Sprintf("%s in direction %s", pattern, dir), func(t *testing.T) {
for _, conf := range []testConfig{
{postQuantum: false, longCertChain: false, doRetry: true},
{postQuantum: false, longCertChain: false, doRetry: false},