]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
use synctest for the datagram test (#5398)
authorMarten Seemann <martenseemann@gmail.com>
Thu, 23 Oct 2025 17:32:01 +0000 (19:32 +0200)
committerGitHub <noreply@github.com>
Thu, 23 Oct 2025 17:32:01 +0000 (19:32 +0200)
integrationtests/self/datagram_test.go
integrationtests/self/handshake_drop_test.go
integrationtests/self/simnet_helper_test.go

index fdc499f6932f9c718cb557c4cc95117731a89006..fc21e829ac8640a6277e16e9707f4020c1bf8ad7 100644 (file)
@@ -3,6 +3,7 @@ package self_test
 import (
        "bytes"
        "context"
+       "math"
        mrand "math/rand/v2"
        "net"
        "sync/atomic"
@@ -10,8 +11,9 @@ import (
        "time"
 
        "github.com/quic-go/quic-go"
-       quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
+       "github.com/quic-go/quic-go/internal/synctest"
        "github.com/quic-go/quic-go/internal/wire"
+       "github.com/quic-go/quic-go/testutils/simnet"
 
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
@@ -124,117 +126,129 @@ func TestDatagramSizeLimit(t *testing.T) {
 }
 
 func TestDatagramLoss(t *testing.T) {
-       const rtt = 10 * time.Millisecond
-       const numDatagrams = 100
-       const datagramSize = 500
-
-       server, err := quic.Listen(
-               newUDPConnLocalhost(t),
-               getTLSConfig(),
-               getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
-       )
-       require.NoError(t, err)
-       defer server.Close()
+       synctest.Test(t, func(t *testing.T) {
+               const rtt = 100 * time.Millisecond
+               const numDatagrams = 100
+               const datagramSize = 500
+
+               clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}
+               serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}
+               var droppedToClient, droppedToServer, total atomic.Int32
+               n := &simnet.Simnet{
+                       Router: &directionAwareDroppingRouter{
+                               ClientAddr: clientAddr,
+                               ServerAddr: serverAddr,
+                               Drop: func(d direction, p simnet.Packet) bool {
+                                       if wire.IsLongHeaderPacket(p.Data[0]) { // don't drop Long Header packets
+                                               return false
+                                       }
+                                       if len(p.Data) < datagramSize { // don't drop ACK-only packets
+                                               return false
+                                       }
+                                       total.Add(1)
+                                       // drop about 20% of Short Header packets with DATAGRAM frames
+                                       if mrand.Int()%5 == 0 {
+                                               switch d {
+                                               case directionToClient:
+                                                       droppedToClient.Add(1)
+                                               case directionToServer:
+                                                       droppedToServer.Add(1)
+                                               }
+                                               return true
+                                       }
+                                       return false
+                               },
+                       },
+               }
+               settings := simnet.NodeBiDiLinkSettings{
+                       Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
+                       Uplink:   simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
+               }
+               clientPacketConn := n.NewEndpoint(clientAddr, settings)
+               defer clientPacketConn.Close()
+               serverPacketConn := n.NewEndpoint(serverAddr, settings)
+               defer serverPacketConn.Close()
+               require.NoError(t, n.Start())
+               defer n.Close()
+
+               server, err := quic.Listen(
+                       serverPacketConn,
+                       getTLSConfig(),
+                       getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
+               )
+               require.NoError(t, err)
+               defer server.Close()
+
+               const sendInterval = time.Second // send a datagram every second
+               ctx, cancel := context.WithTimeout(context.Background(), (numDatagrams+10)*sendInterval)
+               defer cancel()
+               clientConn, err := quic.Dial(
+                       ctx,
+                       clientPacketConn,
+                       serverPacketConn.LocalAddr(),
+                       getTLSClientConfig(),
+                       getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
+               )
+               require.NoError(t, err)
+               defer clientConn.CloseWithError(0, "")
 
-       var droppedIncoming, droppedOutgoing, total atomic.Int32
-       proxy := &quicproxy.Proxy{
-               Conn:       newUDPConnLocalhost(t),
-               ServerAddr: server.Addr().(*net.UDPAddr),
-               DropPacket: func(dir quicproxy.Direction, _, _ net.Addr, packet []byte) bool {
-                       if wire.IsLongHeaderPacket(packet[0]) { // don't drop Long Header packets
-                               return false
-                       }
-                       if len(packet) < datagramSize { // don't drop ACK-only packets
-                               return false
-                       }
-                       total.Add(1)
-                       // drop about 20% of Short Header packets with DATAGRAM frames
-                       if mrand.Int()%5 == 0 {
-                               switch dir {
-                               case quicproxy.DirectionIncoming:
-                                       droppedIncoming.Add(1)
-                               case quicproxy.DirectionOutgoing:
-                                       droppedOutgoing.Add(1)
+               serverConn, err := server.Accept(ctx)
+               require.NoError(t, err)
+               defer serverConn.CloseWithError(0, "")
+
+               var clientDatagrams, serverDatagrams int
+               clientErrChan := make(chan error, 1)
+               go func() {
+                       defer close(clientErrChan)
+                       for {
+                               if _, err := clientConn.ReceiveDatagram(ctx); err != nil {
+                                       clientErrChan <- err
+                                       return
                                }
-                               return true
+                               clientDatagrams++
                        }
-                       return false
-               },
-               DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
-       }
-       require.NoError(t, proxy.Start())
-       defer proxy.Close()
-
-       // SendDatagram blocks when the queue is full (maxDatagramSendQueueLen),
-       // add some extra margin for the handshake, networking and ACKs.
-       ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(4*numDatagrams*time.Millisecond))
-       defer cancel()
-       clientConn, err := quic.Dial(
-               ctx,
-               newUDPConnLocalhost(t),
-               proxy.LocalAddr(),
-               getTLSClientConfig(),
-               getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
-       )
-       require.NoError(t, err)
-       defer clientConn.CloseWithError(0, "")
+               }()
 
-       serverConn, err := server.Accept(ctx)
-       require.NoError(t, err)
-       defer serverConn.CloseWithError(0, "")
-
-       var clientDatagrams, serverDatagrams int
-       clientErrChan := make(chan error, 1)
-       go func() {
-               defer close(clientErrChan)
-               for {
-                       if _, err := clientConn.ReceiveDatagram(ctx); err != nil {
-                               clientErrChan <- err
-                               return
-                       }
-                       clientDatagrams++
+               for i := range numDatagrams {
+                       payload := bytes.Repeat([]byte{uint8(i)}, datagramSize)
+                       require.NoError(t, clientConn.SendDatagram(payload))
+                       require.NoError(t, serverConn.SendDatagram(payload))
+                       time.Sleep(sendInterval)
                }
-       }()
 
-       for i := range numDatagrams {
-               payload := bytes.Repeat([]byte{uint8(i)}, datagramSize)
-               require.NoError(t, clientConn.SendDatagram(payload))
-               require.NoError(t, serverConn.SendDatagram(payload))
-               time.Sleep(scaleDuration(time.Millisecond / 2))
-       }
-
-       serverErrChan := make(chan error, 1)
-       go func() {
-               defer close(serverErrChan)
-               for {
-                       if _, err := serverConn.ReceiveDatagram(ctx); err != nil {
-                               serverErrChan <- err
-                               return
+               serverErrChan := make(chan error, 1)
+               go func() {
+                       defer close(serverErrChan)
+                       for {
+                               if _, err := serverConn.ReceiveDatagram(ctx); err != nil {
+                                       serverErrChan <- err
+                                       return
+                               }
+                               serverDatagrams++
                        }
-                       serverDatagrams++
-               }
-       }()
+               }()
 
-       select {
-       case err := <-clientErrChan:
-               require.ErrorIs(t, err, context.DeadlineExceeded)
-       case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
-               t.Fatal("timeout")
-       }
-       select {
-       case err := <-serverErrChan:
-               require.ErrorIs(t, err, context.DeadlineExceeded)
-       case <-time.After(scaleDuration(5 * numDatagrams * time.Millisecond)):
-               t.Fatal("timeout")
-       }
+               select {
+               case err := <-clientErrChan:
+                       require.ErrorIs(t, err, context.DeadlineExceeded)
+               case <-time.After(5 * numDatagrams * sendInterval):
+                       t.Fatal("timeout")
+               }
+               select {
+               case err := <-serverErrChan:
+                       require.ErrorIs(t, err, context.DeadlineExceeded)
+               case <-time.After(5 * numDatagrams * sendInterval):
+                       t.Fatal("timeout")
+               }
 
-       numDroppedIncoming := droppedIncoming.Load()
-       numDroppedOutgoing := droppedOutgoing.Load()
-       t.Logf("dropped %d incoming and %d outgoing out of %d packets", numDroppedIncoming, numDroppedOutgoing, total.Load())
-       assert.NotZero(t, numDroppedIncoming)
-       assert.NotZero(t, numDroppedOutgoing)
-       t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams)
-       assert.EqualValues(t, numDatagrams-numDroppedIncoming, serverDatagrams, "datagrams received by the server")
-       t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams)
-       assert.EqualValues(t, numDatagrams-numDroppedOutgoing, clientDatagrams, "datagrams received by the client")
+               numDroppedToClient := droppedToClient.Load()
+               numDroppedToServer := droppedToServer.Load()
+               t.Logf("dropped %d to client and %d to server out of %d packets", numDroppedToClient, numDroppedToServer, total.Load())
+               assert.NotZero(t, numDroppedToClient)
+               assert.NotZero(t, numDroppedToServer)
+               t.Logf("server received %d out of %d sent datagrams", serverDatagrams, numDatagrams)
+               assert.EqualValues(t, numDatagrams-numDroppedToServer, serverDatagrams, "datagrams received by the server")
+               t.Logf("client received %d out of %d sent datagrams", clientDatagrams, numDatagrams)
+               assert.EqualValues(t, numDatagrams-numDroppedToClient, clientDatagrams, "datagrams received by the client")
+       })
 }
index 63655a0f54d729e372e207f3b6ccd29fb71776ea..09592c0aacc3e5027fefe88bfc303318ff9fd079 100644 (file)
@@ -147,16 +147,16 @@ func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, clientConn ne
 }
 
 func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool {
-       var incoming, outgoing atomic.Int32
+       var toClient, toServer atomic.Int32
        return func(d direction, p simnet.Packet) bool {
                switch d {
-               case directionIncoming:
-                       c := incoming.Add(1)
+               case directionToClient:
+                       c := toClient.Add(1)
                        if d == dir || dir == directionBoth {
                                return slices.Contains(ns, int(c))
                        }
-               case directionOutgoing:
-                       c := outgoing.Add(1)
+               case directionToServer:
+                       c := toServer.Add(1)
                        if dir == d || dir == directionBoth {
                                return slices.Contains(ns, int(c))
                        }
@@ -168,33 +168,33 @@ func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.
 func dropCallbackDropOneThird(_ direction) func(direction, simnet.Packet) bool {
        const maxSequentiallyDropped = 10
        var mx sync.Mutex
-       var incoming, outgoing int
+       var toClient, toServer int
        return func(d direction, p simnet.Packet) bool {
                drop := mrand.IntN(3) == 0
 
                mx.Lock()
                defer mx.Unlock()
                // never drop more than 10 consecutive packets
-               if d == directionIncoming || d == directionBoth {
+               if d == directionToClient || d == directionBoth {
                        if drop {
-                               incoming++
-                               if incoming > maxSequentiallyDropped {
+                               toClient++
+                               if toClient > maxSequentiallyDropped {
                                        drop = false
                                }
                        }
                        if !drop {
-                               incoming = 0
+                               toClient = 0
                        }
                }
-               if d == directionOutgoing || d == directionBoth {
+               if d == directionToServer || d == directionBoth {
                        if drop {
-                               outgoing++
-                               if outgoing > maxSequentiallyDropped {
+                               toServer++
+                               if toServer > maxSequentiallyDropped {
                                        drop = false
                                }
                        }
                        if !drop {
-                               outgoing = 0
+                               toServer = 0
                        }
                }
                return drop
@@ -220,13 +220,13 @@ func TestHandshakeWithPacketLoss(t *testing.T) {
                doRetry       bool
        }
 
-       for _, dir := range []direction{directionIncoming, directionOutgoing, directionBoth} {
+       for _, dir := range []direction{directionToClient, directionToServer, directionBoth} {
                for _, pattern := range []dropPattern{
                        dropPatternDrop1stPacket,
                        dropPatternDropFirst3Packets,
                        dropPatternDropOneThirdOfPackets,
                } {
-                       t.Run(fmt.Sprintf("%s in %s direction", pattern, dir), func(t *testing.T) {
+                       t.Run(fmt.Sprintf("%s in direction %s", pattern, dir), func(t *testing.T) {
                                for _, conf := range []testConfig{
                                        {postQuantum: false, longCertChain: false, doRetry: true},
                                        {postQuantum: false, longCertChain: false, doRetry: false},
index 35d41f70fda466097ac16130e8d830ddab674dea..10f2afd98553f4e727d87a7e8cb720977084fd7e 100644 (file)
@@ -23,17 +23,17 @@ type direction uint8
 
 const (
        directionUnknown = iota
-       directionIncoming
-       directionOutgoing
+       directionToClient
+       directionToServer
        directionBoth
 )
 
 func (d direction) String() string {
        switch d {
-       case directionIncoming:
-               return "incoming"
-       case directionOutgoing:
-               return "outgoing"
+       case directionToClient:
+               return "to client"
+       case directionToServer:
+               return "to server"
        case directionBoth:
                return "both"
        }
@@ -54,9 +54,9 @@ func (d *directionAwareDroppingRouter) SendPacket(p simnet.Packet) error {
        var dir direction
        switch p.To.String() {
        case d.ClientAddr.String():
-               dir = directionIncoming
+               dir = directionToClient
        case d.ServerAddr.String():
-               dir = directionOutgoing
+               dir = directionToServer
        default:
                dir = directionUnknown
        }