]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
use synctest for the handshake drop test (#5397)
authorMarten Seemann <martenseemann@gmail.com>
Thu, 23 Oct 2025 14:56:35 +0000 (16:56 +0200)
committerGitHub <noreply@github.com>
Thu, 23 Oct 2025 14:56:35 +0000 (16:56 +0200)
This test used to take 15-30s locally, and even more on CI. It now runs in less than 1ms.

integrationtests/self/handshake_drop_test.go
integrationtests/self/self_go124_test.go [new file with mode: 0644]
integrationtests/self/self_go125_test.go [new file with mode: 0644]
integrationtests/self/simnet_helper_test.go

index 211d635606afb598dc327021abc72f704f746664..63655a0f54d729e372e207f3b6ccd29fb71776ea 100644 (file)
@@ -3,65 +3,35 @@ package self_test
 import (
        "bytes"
        "context"
-       "crypto/rand"
        "crypto/tls"
        "fmt"
        "io"
+       "math"
        mrand "math/rand/v2"
        "net"
+       "runtime"
+       "slices"
+       "strings"
        "sync"
        "sync/atomic"
        "testing"
        "time"
 
        "github.com/quic-go/quic-go"
-       quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
-       "github.com/quic-go/quic-go/internal/wire"
+       "github.com/quic-go/quic-go/internal/synctest"
+       "github.com/quic-go/quic-go/testutils/simnet"
 
        "github.com/stretchr/testify/require"
 )
 
-func startDropTestListenerAndProxy(t *testing.T, rtt, timeout time.Duration, dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (_ *quic.Listener, proxyAddr net.Addr) {
-       t.Helper()
-       conf := getQuicConfig(&quic.Config{
-               MaxIdleTimeout:          timeout,
-               HandshakeIdleTimeout:    timeout,
-               DisablePathMTUDiscovery: true,
-       })
-       var tlsConf *tls.Config
-       if longCertChain {
-               tlsConf = getTLSConfigWithLongCertChain()
-       } else {
-               tlsConf = getTLSConfig()
-       }
-       tr := &quic.Transport{
-               Conn:                newUDPConnLocalhost(t),
-               VerifySourceAddress: func(net.Addr) bool { return doRetry },
-       }
-       t.Cleanup(func() { tr.Close() })
-       ln, err := tr.Listen(tlsConf, conf)
-       require.NoError(t, err)
-       t.Cleanup(func() { ln.Close() })
-
-       proxy := quicproxy.Proxy{
-               Conn:        newUDPConnLocalhost(t),
-               ServerAddr:  ln.Addr().(*net.UDPAddr),
-               DropPacket:  dropCallback,
-               DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
-       }
-       require.NoError(t, proxy.Start())
-       t.Cleanup(func() { proxy.Close() })
-       return ln, proxy.LocalAddr()
-}
-
-func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, addr net.Addr, timeout time.Duration, data []byte) {
+func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
        ctx, cancel := context.WithTimeout(context.Background(), timeout)
        defer cancel()
        conn, err := quic.Dial(
                ctx,
-               newUDPConnLocalhost(t),
-               addr,
-               getTLSClientConfig(),
+               clientConn,
+               ln.Addr(),
+               clientConf,
                getQuicConfig(&quic.Config{
                        MaxIdleTimeout:          timeout,
                        HandshakeIdleTimeout:    timeout,
@@ -88,16 +58,18 @@ func dropTestProtocolClientSpeaksFirst(t *testing.T, ln *quic.Listener, addr net
        require.NoError(t, err)
        require.Equal(t, b, data)
        serverConn.CloseWithError(0, "")
+
+       return conn
 }
 
-func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, addr net.Addr, timeout time.Duration, data []byte) {
+func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn {
        ctx, cancel := context.WithTimeout(context.Background(), timeout)
        defer cancel()
        conn, err := quic.Dial(
                ctx,
-               newUDPConnLocalhost(t),
-               addr,
-               getTLSClientConfig(),
+               clientConn,
+               ln.Addr(),
+               clientConf,
                getQuicConfig(&quic.Config{
                        MaxIdleTimeout:          timeout,
                        HandshakeIdleTimeout:    timeout,
@@ -146,16 +118,18 @@ func dropTestProtocolServerSpeaksFirst(t *testing.T, ln *quic.Listener, addr net
        case <-time.After(timeout):
                t.Fatal("server connection not closed")
        }
+
+       return conn
 }
 
-func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, addr net.Addr, timeout time.Duration) {
+func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, _ []byte) *quic.Conn {
        ctx, cancel := context.WithTimeout(context.Background(), timeout)
        defer cancel()
        conn, err := quic.Dial(
                ctx,
-               newUDPConnLocalhost(t),
-               addr,
-               getTLSClientConfig(),
+               clientConn,
+               ln.Addr(),
+               clientConf,
                getQuicConfig(&quic.Config{
                        MaxIdleTimeout:          timeout,
                        HandshakeIdleTimeout:    timeout,
@@ -168,33 +142,40 @@ func dropTestProtocolNobodySpeaks(t *testing.T, ln *quic.Listener, addr net.Addr
        serverConn, err := ln.Accept(ctx)
        require.NoError(t, err)
        serverConn.CloseWithError(0, "")
+
+       return conn
 }
 
-func dropCallbackDropNthPacket(direction quicproxy.Direction, n int) quicproxy.DropCallback {
+func dropCallbackDropNthPacket(dir direction, ns ...int) func(direction, simnet.Packet) bool {
        var incoming, outgoing atomic.Int32
-       return func(d quicproxy.Direction, _, _ net.Addr, packet []byte) bool {
-               var p int32
+       return func(d direction, p simnet.Packet) bool {
                switch d {
-               case quicproxy.DirectionIncoming:
-                       p = incoming.Add(1)
-               case quicproxy.DirectionOutgoing:
-                       p = outgoing.Add(1)
+               case directionIncoming:
+                       c := incoming.Add(1)
+                       if d == dir || dir == directionBoth {
+                               return slices.Contains(ns, int(c))
+                       }
+               case directionOutgoing:
+                       c := outgoing.Add(1)
+                       if dir == d || dir == directionBoth {
+                               return slices.Contains(ns, int(c))
+                       }
                }
-               return p == int32(n) && d.Is(direction)
+               return false
        }
 }
 
-func dropCallbackDropOneThird(direction quicproxy.Direction) quicproxy.DropCallback {
+func dropCallbackDropOneThird(_ direction) func(direction, simnet.Packet) bool {
        const maxSequentiallyDropped = 10
        var mx sync.Mutex
        var incoming, outgoing int
-       return func(d quicproxy.Direction, _, _ net.Addr, _ []byte) bool {
+       return func(d direction, p simnet.Packet) bool {
                drop := mrand.IntN(3) == 0
 
                mx.Lock()
                defer mx.Unlock()
                // never drop more than 10 consecutive packets
-               if d.Is(quicproxy.DirectionIncoming) {
+               if d == directionIncoming || d == directionBoth {
                        if drop {
                                incoming++
                                if incoming > maxSequentiallyDropped {
@@ -205,7 +186,7 @@ func dropCallbackDropOneThird(direction quicproxy.Direction) quicproxy.DropCallb
                                incoming = 0
                        }
                }
-               if d.Is(quicproxy.DirectionOutgoing) {
+               if d == directionOutgoing || d == directionBoth {
                        if drop {
                                outgoing++
                                if outgoing > maxSequentiallyDropped {
@@ -225,63 +206,127 @@ func TestHandshakeWithPacketLoss(t *testing.T) {
        const timeout = 2 * time.Minute
        const rtt = 20 * time.Millisecond
 
-       type dropPattern struct {
-               name string
-               fn   quicproxy.DropCallback
-       }
+       type dropPattern string
 
-       type serverConfig struct {
+       const (
+               dropPatternDrop1stPacket         dropPattern = "drop 1st packet"
+               dropPatternDropFirst3Packets     dropPattern = "drop first 3 packets"
+               dropPatternDropOneThirdOfPackets dropPattern = "drop 1/3 of packets"
+       )
+
+       type testConfig struct {
+               postQuantum   bool
                longCertChain bool
                doRetry       bool
        }
 
-       for _, direction := range []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth} {
-               for _, dropPattern := range []dropPattern{
-                       {name: "drop 1st packet", fn: dropCallbackDropNthPacket(direction, 1)},
-                       {name: "drop 2nd packet", fn: dropCallbackDropNthPacket(direction, 2)},
-                       {name: "drop 1/3 of packets", fn: dropCallbackDropOneThird(direction)},
+       for _, dir := range []direction{directionIncoming, directionOutgoing, directionBoth} {
+               for _, pattern := range []dropPattern{
+                       dropPatternDrop1stPacket,
+                       dropPatternDropFirst3Packets,
+                       dropPatternDropOneThirdOfPackets,
                } {
-                       t.Run(fmt.Sprintf("%s in %s direction", dropPattern.name, direction), func(t *testing.T) {
-                               for _, conf := range []serverConfig{
-                                       {longCertChain: false, doRetry: true},
-                                       {longCertChain: false, doRetry: false},
-                                       {longCertChain: true, doRetry: false},
+                       t.Run(fmt.Sprintf("%s in %s direction", pattern, dir), func(t *testing.T) {
+                               for _, conf := range []testConfig{
+                                       {postQuantum: false, longCertChain: false, doRetry: true},
+                                       {postQuantum: false, longCertChain: false, doRetry: false},
+                                       {postQuantum: false, longCertChain: true, doRetry: false},
+                                       {postQuantum: true, longCertChain: false, doRetry: false},
+                                       {postQuantum: true, longCertChain: true, doRetry: false},
                                } {
-                                       t.Run(fmt.Sprintf("retry: %t", conf.doRetry), func(t *testing.T) {
-                                               t.Run("client speaks first", func(t *testing.T) {
-                                                       ln, proxyAddr := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
-                                                       dropTestProtocolClientSpeaksFirst(t, ln, proxyAddr, timeout, data)
-                                               })
+                                       for _, test := range []struct {
+                                               name string
+                                               fn   func(t *testing.T, ln *quic.Listener, clientConn net.PacketConn, clientConf *tls.Config, timeout time.Duration, data []byte) *quic.Conn
+                                       }{
+                                               {"client speaks first", dropTestProtocolClientSpeaksFirst},
+                                               {"server speaks first", dropTestProtocolServerSpeaksFirst},
+                                               {"nobody speaks", dropTestProtocolNobodySpeaks},
+                                       } {
+                                               t.Run(fmt.Sprintf("retry: %t/%s", conf.doRetry, test.name), func(t *testing.T) {
+                                                       synctest.Test(t, func(t *testing.T) {
+                                                               clientAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}
+                                                               serverAddr := &net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}
+                                                               var fn func(direction, simnet.Packet) bool
+                                                               switch pattern {
+                                                               case dropPatternDrop1stPacket:
+                                                                       fn = dropCallbackDropNthPacket(dir, 1)
+                                                               case dropPatternDropFirst3Packets:
+                                                                       fn = dropCallbackDropNthPacket(dir, 1, 2, 3)
+                                                               case dropPatternDropOneThirdOfPackets:
+                                                                       fn = dropCallbackDropOneThird(dir)
+                                                               }
+                                                               var numDropped atomic.Int32
+                                                               n := &simnet.Simnet{
+                                                                       Router: &directionAwareDroppingRouter{
+                                                                               ClientAddr: clientAddr,
+                                                                               ServerAddr: serverAddr,
+                                                                               Drop: func(d direction, p simnet.Packet) bool {
+                                                                                       drop := fn(d, p)
+                                                                                       if drop {
+                                                                                               numDropped.Add(1)
+                                                                                       }
+                                                                                       return drop
+                                                                               },
+                                                                       },
+                                                               }
+                                                               settings := simnet.NodeBiDiLinkSettings{
+                                                                       Downlink: simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
+                                                                       Uplink:   simnet.LinkSettings{BitsPerSecond: math.MaxInt, Latency: rtt / 4},
+                                                               }
+                                                               clientConn := n.NewEndpoint(clientAddr, settings)
+                                                               defer clientConn.Close()
+                                                               serverConn := n.NewEndpoint(serverAddr, settings)
+                                                               defer serverConn.Close()
+                                                               require.NoError(t, n.Start())
+                                                               defer n.Close()
 
-                                               t.Run("server speaks first", func(t *testing.T) {
-                                                       ln, proxyAddr := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
-                                                       dropTestProtocolServerSpeaksFirst(t, ln, proxyAddr, timeout, data)
-                                               })
+                                                               var tlsConf *tls.Config
+                                                               if conf.longCertChain {
+                                                                       tlsConf = getTLSConfigWithLongCertChain()
+                                                               } else {
+                                                                       tlsConf = getTLSConfig()
+                                                               }
+                                                               clientConf := getTLSClientConfig()
+                                                               if !conf.postQuantum {
+                                                                       clientConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
+                                                               }
+
+                                                               tr := &quic.Transport{
+                                                                       Conn:                serverConn,
+                                                                       VerifySourceAddress: func(net.Addr) bool { return conf.doRetry },
+                                                               }
+                                                               defer tr.Close()
+
+                                                               ln, err := tr.Listen(
+                                                                       tlsConf,
+                                                                       getQuicConfig(&quic.Config{
+                                                                               MaxIdleTimeout:          timeout,
+                                                                               HandshakeIdleTimeout:    timeout,
+                                                                               DisablePathMTUDiscovery: true,
+                                                                       }),
+                                                               )
+                                                               require.NoError(t, err)
+                                                               defer ln.Close()
 
-                                               t.Run("nobody speaks", func(t *testing.T) {
-                                                       ln, proxyAddr := startDropTestListenerAndProxy(t, rtt, timeout, dropPattern.fn, conf.doRetry, conf.longCertChain)
-                                                       dropTestProtocolNobodySpeaks(t, ln, proxyAddr, timeout)
+                                                               conn := test.fn(t, ln, clientConn, clientConf, timeout, data)
+                                                               if !strings.HasPrefix(runtime.Version(), "go1.24") {
+                                                                       curveID := getCurveID(conn.ConnectionState().TLS)
+                                                                       if conf.postQuantum {
+                                                                               require.Equal(t, tls.X25519MLKEM768, curveID)
+                                                                       } else {
+                                                                               require.Equal(t, tls.CurveP384, curveID)
+                                                                       }
+                                                               }
+
+                                                               if pattern != dropPatternDropOneThirdOfPackets {
+                                                                       require.NotZero(t, numDropped.Load())
+                                                               }
+                                                               t.Logf("dropped %d packets", numDropped.Load())
+                                                       })
                                                })
-                                       })
+                                       }
                                }
                        })
                }
        }
 }
-
-func TestPostQuantumClientHello(t *testing.T) {
-       origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient
-       t.Cleanup(func() { wire.AdditionalTransportParametersClient = origAdditionalTransportParametersClient })
-
-       b := make([]byte, 2500) // the ClientHello will now span across 3 packets
-       rand.Read(b)
-       wire.AdditionalTransportParametersClient = map[uint64][]byte{
-               // We don't use a greased transport parameter here, since the transport parameter serialization function
-               // will add a greased transport parameter, and therefore there's a risk of a collision.
-               // Instead, we just use pseudorandom constant value.
-               1234567: b,
-       }
-
-       ln, proxyPort := startDropTestListenerAndProxy(t, 10*time.Millisecond, 20*time.Second, dropCallbackDropOneThird(quicproxy.DirectionIncoming), false, false)
-       dropTestProtocolClientSpeaksFirst(t, ln, proxyPort, time.Minute, GeneratePRData(5000))
-}
diff --git a/integrationtests/self/self_go124_test.go b/integrationtests/self/self_go124_test.go
new file mode 100644 (file)
index 0000000..16146e7
--- /dev/null
@@ -0,0 +1,9 @@
+//go:build !go1.25
+
+package self_test
+
+import "crypto/tls"
+
+func getCurveID(connState tls.ConnectionState) tls.CurveID {
+       return 0
+}
diff --git a/integrationtests/self/self_go125_test.go b/integrationtests/self/self_go125_test.go
new file mode 100644 (file)
index 0000000..3474e15
--- /dev/null
@@ -0,0 +1,9 @@
+//go:build go1.25
+
+package self_test
+
+import "crypto/tls"
+
+func getCurveID(connState tls.ConnectionState) tls.CurveID {
+       return connState.CurveID
+}
index 8d56c25567ec53555c29537e47f64ac8e22110b2..35d41f70fda466097ac16130e8d830ddab674dea 100644 (file)
@@ -1,6 +1,10 @@
 package self_test
 
-import "github.com/quic-go/quic-go/testutils/simnet"
+import (
+       "net"
+
+       "github.com/quic-go/quic-go/testutils/simnet"
+)
 
 type droppingRouter struct {
        simnet.PerfectRouter
@@ -15,4 +19,49 @@ func (d *droppingRouter) SendPacket(p simnet.Packet) error {
        return d.PerfectRouter.SendPacket(p)
 }
 
+type direction uint8
+
+const (
+       directionUnknown = iota
+       directionIncoming
+       directionOutgoing
+       directionBoth
+)
+
+func (d direction) String() string {
+       switch d {
+       case directionIncoming:
+               return "incoming"
+       case directionOutgoing:
+               return "outgoing"
+       case directionBoth:
+               return "both"
+       }
+       return "unknown"
+}
+
 var _ simnet.Router = &droppingRouter{}
+
+type directionAwareDroppingRouter struct {
+       simnet.PerfectRouter
+
+       ClientAddr, ServerAddr *net.UDPAddr
+
+       Drop func(direction direction, p simnet.Packet) bool
+}
+
+func (d *directionAwareDroppingRouter) SendPacket(p simnet.Packet) error {
+       var dir direction
+       switch p.To.String() {
+       case d.ClientAddr.String():
+               dir = directionIncoming
+       case d.ServerAddr.String():
+               dir = directionOutgoing
+       default:
+               dir = directionUnknown
+       }
+       if d.Drop(dir, p) {
+               return nil
+       }
+       return d.PerfectRouter.SendPacket(p)
+}