From: Marten Seemann
Date: Fri, 24 Oct 2025 10:51:09 +0000 (+0200)
Subject: use synctest for the timeout tests (#5403)
X-Git-Url: https://git.feebdaed.xyz/?a=commitdiff_plain;h=f2b845a2f50d6957ee7592d6510a9de054318ff6;p=0xmirror%2Fquic-go.git

use synctest for the timeout tests (#5403)
---
diff --git a/integrationtests/self/simnet_helper_test.go b/integrationtests/self/simnet_helper_test.go
index 10f2afd9..8669b8d8 100644
--- a/integrationtests/self/simnet_helper_test.go
+++ b/integrationtests/self/simnet_helper_test.go
@@ -2,10 +2,40 @@ package self_test
 import (
 	"net"
+	"testing"
+	"time"
 
 	"github.com/quic-go/quic-go/testutils/simnet"
+
+	"github.com/stretchr/testify/require"
 )
 
+func newSimnetLink(t *testing.T, rtt time.Duration) (client, server *simnet.SimConn, close func(t *testing.T)) {
+	t.Helper()
+
+	return newSimnetLinkWithRouter(t, rtt, &simnet.PerfectRouter{})
+}
+
+func newSimnetLinkWithRouter(t *testing.T, rtt time.Duration, router simnet.Router) (client, server *simnet.SimConn, close func(t *testing.T)) {
+	t.Helper()
+
+	n := &simnet.Simnet{Router: router}
+	settings := simnet.NodeBiDiLinkSettings{
+		Downlink: simnet.LinkSettings{BitsPerSecond: 1e8, Latency: rtt / 4},
+		Uplink:   simnet.LinkSettings{BitsPerSecond: 1e8, Latency: rtt / 4},
+	}
+	clientPacketConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}, settings)
+	serverPacketConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}, settings)
+
+	require.NoError(t, n.Start())
+
+	return clientPacketConn, serverPacketConn, func(t *testing.T) {
+		require.NoError(t, clientPacketConn.Close())
+		require.NoError(t, serverPacketConn.Close())
+		require.NoError(t, n.Close())
+	}
+}
+
 type droppingRouter struct {
 	simnet.PerfectRouter
 
diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go
index 996768c6..8ffdf1b0 100644
--- a/integrationtests/self/timeout_test.go
+++ b/integrationtests/self/timeout_test.go
@@ -3,27 +3,29 @@ package self_test
 import (
 	"bytes"
 	"context"
+	"crypto/tls"
 	"errors"
 	"fmt"
 	"io"
 	mrand "math/rand/v2"
 	"net"
-	"runtime"
 	"sync/atomic"
 	"testing"
 	"time"
 
 	"github.com/quic-go/quic-go"
-	quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
 	"github.com/quic-go/quic-go/internal/protocol"
+	"github.com/quic-go/quic-go/internal/synctest"
 	"github.com/quic-go/quic-go/qlog"
 	"github.com/quic-go/quic-go/qlogwriter"
+	"github.com/quic-go/quic-go/testutils/simnet"
 
 	"github.com/stretchr/testify/require"
 )
 
 func requireIdleTimeoutError(t *testing.T, err error) {
 	t.Helper()
+	require.Error(t, err)
 
 	var idleTimeoutErr *quic.IdleTimeoutError
 	require.ErrorAs(t, err, &idleTimeoutErr)
@@ -34,340 +36,325 @@
 }
 
 func TestHandshakeIdleTimeout(t *testing.T) {
-	errChan := make(chan error, 1)
-	go func() {
-		conn := newUDPConnLocalhost(t)
-		_, err := quic.Dial(
+	t.Run("Dial", func(t *testing.T) {
+		testHandshakeIdleTimeout(t, quic.Dial)
+	})
+
+	t.Run("DialEarly", func(t *testing.T) {
+		testHandshakeIdleTimeout(t, quic.DialEarly)
+	})
+}
+
+func testHandshakeIdleTimeout(t *testing.T, dialFn func(context.Context, net.PacketConn, net.Addr, *tls.Config, *quic.Config) (*quic.Conn, error)) {
+	synctest.Test(t, func(t *testing.T) {
+		const handshakeIdleTimeout = 3 * time.Second
+
+		clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, time.Millisecond)
+		defer closeFn(t)
+
+		errChan := make(chan error, 1)
+		start := time.Now()
+		go func() {
+			_, err := dialFn(
+				context.Background(),
+				clientPacketConn,
+				serverPacketConn.LocalAddr(),
+				getTLSClientConfig(),
+				getQuicConfig(&quic.Config{HandshakeIdleTimeout: handshakeIdleTimeout}),
+			)
+			errChan <- err
+		}()
+		select {
+		case err := <-errChan:
+			requireIdleTimeoutError(t, err)
+			require.Equal(t, handshakeIdleTimeout, time.Since(start))
+		case <-time.After(5 * time.Second):
+			t.Fatal("timeout waiting for dial error")
+		}
+	})
+}
+
+func TestIdleTimeout(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		const idleTimeout = 20 * time.Second
+
+		var drop atomic.Bool
+		clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t,
+			time.Millisecond,
+			&droppingRouter{Drop: func(p simnet.Packet) bool { return drop.Load() }},
+		)
+		defer closeFn(t)
+
+		server, err := quic.Listen(
+			serverPacketConn,
+			getTLSConfig(),
+			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
+		)
+		require.NoError(t, err)
+		defer server.Close()
+
+		conn, err := quic.Dial(
 			context.Background(),
-			newUDPConnLocalhost(t),
-			conn.LocalAddr(),
+			clientPacketConn,
+			serverPacketConn.LocalAddr(),
 			getTLSClientConfig(),
-			getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(50 * time.Millisecond)}),
+			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
 		)
-		errChan <- err
-	}()
-	select {
-	case err := <-errChan:
+		require.NoError(t, err)
+
+		serverConn, err := server.Accept(context.Background())
+		require.NoError(t, err)
+		str, err := serverConn.OpenStream()
+		require.NoError(t, err)
+		_, err = str.Write([]byte("foobar"))
+		require.NoError(t, err)
+
+		serverStart := time.Now()
+
+		strIn, err := conn.AcceptStream(context.Background())
+		require.NoError(t, err)
+		strOut, err := conn.OpenStream()
+		require.NoError(t, err)
+		_, err = strIn.Read(make([]byte, 6))
+		require.NoError(t, err)
+
+		clientStart := time.Now()
+
+		drop.Store(true)
+
+		select {
+		case <-serverConn.Context().Done():
+			took := time.Since(serverStart)
+			require.GreaterOrEqual(t, took, idleTimeout)
+			t.Logf("server connection timed out after %s (idle timeout: %s)", took, idleTimeout)
+		case <-time.After(2 * idleTimeout):
+			t.Fatal("timeout waiting for idle timeout")
+		}
+
+		select {
+		case <-conn.Context().Done():
+			took := time.Since(clientStart)
+			require.GreaterOrEqual(t, took, idleTimeout)
+			t.Logf("client connection timed out after %s (idle timeout: %s)", took, idleTimeout)
+		case <-time.After(2 * idleTimeout):
+			t.Fatal("timeout waiting for idle timeout")
+		}
+
+		_, err = strIn.Write([]byte("test"))
 		requireIdleTimeoutError(t, err)
-	case <-time.After(5 * time.Second):
-		t.Fatal("timeout waiting for dial error")
-	}
+		_, err = strIn.Read([]byte{0})
+		requireIdleTimeoutError(t, err)
+		_, err = strOut.Write([]byte("test"))
+		requireIdleTimeoutError(t, err)
+		_, err = strOut.Read([]byte{0})
+		requireIdleTimeoutError(t, err)
+		_, err = conn.OpenStream()
+		requireIdleTimeoutError(t, err)
+		_, err = conn.OpenUniStream()
+		requireIdleTimeoutError(t, err)
+		_, err = conn.AcceptStream(context.Background())
+		requireIdleTimeoutError(t, err)
+		_, err = conn.AcceptUniStream(context.Background())
+		requireIdleTimeoutError(t, err)
+	})
 }
 
-func TestHandshakeTimeoutContext(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
-	defer cancel()
-	errChan := make(chan error)
-	go func() {
-		conn := newUDPConnLocalhost(t)
-		_, err := quic.Dial(
-			ctx,
-			newUDPConnLocalhost(t),
-			conn.LocalAddr(),
-			getTLSClientConfig(),
-			getQuicConfig(nil),
+func TestKeepAlive(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		const idleTimeout = 4 * time.Second
+
+		var drop atomic.Bool
+		clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t,
+			time.Millisecond,
+			&droppingRouter{Drop: func(p simnet.Packet) bool { return drop.Load() }},
 		)
-		errChan <- err
-	}()
-	select {
-	case err := <-errChan:
-		require.ErrorIs(t, err, context.DeadlineExceeded)
-	case <-time.After(5 * time.Second):
-		t.Fatal("timeout waiting for dial error")
-	}
-}
+		defer closeFn(t)
+
+		server, err := quic.Listen(
+			serverPacketConn,
+			getTLSConfig(),
+			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
+		)
+		require.NoError(t, err)
+		defer server.Close()
 
-func TestHandshakeTimeout0RTTContext(t *testing.T) {
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
-	defer cancel()
-	errChan := make(chan error)
-	go func() {
-		conn := newUDPConnLocalhost(t)
-		_, err := quic.DialEarly(
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		conn, err := quic.Dial(
 			ctx,
-			newUDPConnLocalhost(t),
-			conn.LocalAddr(),
+			clientPacketConn,
+			serverPacketConn.LocalAddr(),
 			getTLSClientConfig(),
-			getQuicConfig(nil),
+			getQuicConfig(&quic.Config{
+				MaxIdleTimeout:          idleTimeout,
+				KeepAlivePeriod:         idleTimeout / 2,
+				DisablePathMTUDiscovery: true,
+			}),
 		)
-		errChan <- err
-	}()
-	select {
-	case err := <-errChan:
-		require.ErrorIs(t, err, context.DeadlineExceeded)
-	case <-time.After(5 * time.Second):
-		t.Fatal("timeout waiting for dial error")
-	}
-}
+		require.NoError(t, err)
 
-func TestIdleTimeout(t *testing.T) {
-	idleTimeout := scaleDuration(200 * time.Millisecond)
-
-	server, err := quic.Listen(
-		newUDPConnLocalhost(t),
-		getTLSConfig(),
-		getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
-	)
-	require.NoError(t, err)
-	defer server.Close()
-
-	var drop atomic.Bool
-	proxy := quicproxy.Proxy{
-		Conn:       newUDPConnLocalhost(t),
-		ServerAddr: server.Addr().(*net.UDPAddr),
-		DropPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) bool { return drop.Load() },
-	}
-	require.NoError(t, proxy.Start())
-	defer proxy.Close()
-
-	conn, err := quic.Dial(
-		context.Background(),
-		newUDPConnLocalhost(t),
-		proxy.LocalAddr(),
-		getTLSClientConfig(),
-		getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
-	)
-	require.NoError(t, err)
-
-	serverConn, err := server.Accept(context.Background())
-	require.NoError(t, err)
-	str, err := serverConn.OpenStream()
-	require.NoError(t, err)
-	_, err = str.Write([]byte("foobar"))
-	require.NoError(t, err)
-
-	strIn, err := conn.AcceptStream(context.Background())
-	require.NoError(t, err)
-	strOut, err := conn.OpenStream()
-	require.NoError(t, err)
-	_, err = strIn.Read(make([]byte, 6))
-	require.NoError(t, err)
-
-	drop.Store(true)
-	time.Sleep(2 * idleTimeout)
-	_, err = strIn.Write([]byte("test"))
-	requireIdleTimeoutError(t, err)
-	_, err = strIn.Read([]byte{0})
-	requireIdleTimeoutError(t, err)
-	_, err = strOut.Write([]byte("test"))
-	requireIdleTimeoutError(t, err)
-	_, err = strOut.Read([]byte{0})
-	requireIdleTimeoutError(t, err)
-	_, err = conn.OpenStream()
-	requireIdleTimeoutError(t, err)
-	_, err = conn.OpenUniStream()
-	requireIdleTimeoutError(t, err)
-	_, err = conn.AcceptStream(context.Background())
-	requireIdleTimeoutError(t, err)
-	_, err = conn.AcceptUniStream(context.Background())
-	requireIdleTimeoutError(t, err)
-}
+		serverConn, err := server.Accept(ctx)
+		require.NoError(t, err)
 
-func TestKeepAlive(t *testing.T) {
-	idleTimeout := scaleDuration(150 * time.Millisecond)
-	if runtime.GOOS == "windows" {
-		// increase the duration, since timers on Windows are not very precise
-		idleTimeout = max(idleTimeout, 600*time.Millisecond)
-	}
+		// wait longer than the idle timeout
+		time.Sleep(3 * idleTimeout)
+		str, err := conn.OpenUniStream()
+		require.NoError(t, err)
+		_, err = str.Write([]byte("foobar"))
+		require.NoError(t, err)
 
-	server, err := quic.Listen(
-		newUDPConnLocalhost(t),
-		getTLSConfig(),
-		getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
-	)
-	require.NoError(t, err)
-	defer server.Close()
-
-	var drop atomic.Bool
-	proxy := quicproxy.Proxy{
-		Conn:       newUDPConnLocalhost(t),
-		ServerAddr: server.Addr().(*net.UDPAddr),
-		DropPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) bool { return drop.Load() },
-	}
-	require.NoError(t, proxy.Start())
-	defer proxy.Close()
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	defer cancel()
-	conn, err := quic.Dial(
-		ctx,
-		newUDPConnLocalhost(t),
-		proxy.LocalAddr(),
-		getTLSClientConfig(),
-		getQuicConfig(&quic.Config{
-			MaxIdleTimeout:          idleTimeout,
-			KeepAlivePeriod:         idleTimeout / 2,
-			DisablePathMTUDiscovery: true,
-		}),
-	)
-	require.NoError(t, err)
-
-	serverConn, err := server.Accept(ctx)
-	require.NoError(t, err)
-
-	// wait longer than the idle timeout
-	time.Sleep(3 * idleTimeout)
-	str, err := conn.OpenUniStream()
-	require.NoError(t, err)
-	_, err = str.Write([]byte("foobar"))
-	require.NoError(t, err)
-
-	// verify connection is still alive
-	select {
-	case <-serverConn.Context().Done():
-		t.Fatal("server connection closed unexpectedly")
-	default:
-	}
+		// verify connection is still alive
+		select {
+		case <-serverConn.Context().Done():
+			t.Fatal("server connection closed unexpectedly")
+		default:
+		}
 
-	// idle timeout will still kick in if PINGs are dropped
-	drop.Store(true)
-	time.Sleep(2 * idleTimeout)
-	_, err = str.Write([]byte("foobar"))
-	var nerr net.Error
-	require.True(t, errors.As(err, &nerr))
-	require.True(t, nerr.Timeout())
+		// idle timeout will still kick in if PINGs are dropped
+		drop.Store(true)
+		time.Sleep(2 * idleTimeout)
+		_, err = str.Write([]byte("foobar"))
+		requireIdleTimeoutError(t, err)
 
-	// can't rely on the server connection closing, since we impose a minimum idle timeout of 5s,
-	// see https://github.com/quic-go/quic-go/issues/4751
-	serverConn.CloseWithError(0, "")
+		// can't rely on the server connection closing, since we impose a minimum idle timeout of 5s,
+		// see https://github.com/quic-go/quic-go/issues/4751
+		serverConn.CloseWithError(0, "")
+	})
 }
 
 func TestTimeoutAfterInactivity(t *testing.T) {
-	idleTimeout := scaleDuration(150 * time.Millisecond)
-	if runtime.GOOS == "windows" {
-		// increase the duration, since timers on Windows are not very precise
-		idleTimeout = max(idleTimeout, 600*time.Millisecond)
-	}
+	synctest.Test(t, func(t *testing.T) {
+		const idleTimeout = 15 * time.Second
+
+		clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, time.Millisecond)
+		defer closeFn(t)
+
+		server, err := quic.Listen(
+			serverPacketConn,
+			getTLSConfig(),
+			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
+		)
+		require.NoError(t, err)
+		defer server.Close()
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		counter, tr := newPacketTracer()
+		conn, err := quic.Dial(
+			ctx,
+			clientPacketConn,
+			server.Addr(),
+			getTLSClientConfig(),
+			getQuicConfig(&quic.Config{
+				MaxIdleTimeout:          idleTimeout,
+				Tracer:                  func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tr },
+				DisablePathMTUDiscovery: true,
+			}),
+		)
+		require.NoError(t, err)
+
+		serverConn, err := server.Accept(ctx)
+		require.NoError(t, err)
+		defer serverConn.CloseWithError(0, "")
 
-	server, err := quic.Listen(
-		newUDPConnLocalhost(t),
-		getTLSConfig(),
-		getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
-	)
-	require.NoError(t, err)
-	defer server.Close()
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	defer cancel()
-	counter, tr := newPacketTracer()
-	conn, err := quic.Dial(
-		ctx,
-		newUDPConnLocalhost(t),
-		server.Addr(),
-		getTLSClientConfig(),
-		getQuicConfig(&quic.Config{
-			MaxIdleTimeout:          idleTimeout,
-			Tracer:                  func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tr },
-			DisablePathMTUDiscovery: true,
-		}),
-	)
-	require.NoError(t, err)
-
-	serverConn, err := server.Accept(ctx)
-	require.NoError(t, err)
-
-	ctx, cancel = context.WithTimeout(context.Background(), 2*idleTimeout)
-	defer cancel()
-	_, err = conn.AcceptStream(ctx)
-	requireIdleTimeoutError(t, err)
-
-	var lastAckElicitingPacketSentAt time.Time
-	for _, p := range counter.getSentShortHeaderPackets() {
-		var hasAckElicitingFrame bool
-		for _, f := range p.frames {
-			if _, ok := f.Frame.(qlog.AckFrame); ok {
-				continue
-			}
-			hasAckElicitingFrame = true
-			break
-		}
-		if hasAckElicitingFrame {
-			lastAckElicitingPacketSentAt = p.time
-		}
-	}
-	rcvdPackets := counter.getRcvdShortHeaderPackets()
-	lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
-	// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
-	// This is ok since we're dealing with a lossless connection here,
-	// and we'd expect to receive an ACK for additional other ack-eliciting packet sent.
-	timeSinceLastAckEliciting := time.Since(lastAckElicitingPacketSentAt)
-	timeSinceLastRcvd := time.Since(lastPacketRcvdAt)
-	maxDuration := max(timeSinceLastAckEliciting, timeSinceLastRcvd)
-	require.GreaterOrEqual(t, maxDuration, idleTimeout)
-	require.Less(t, maxDuration, idleTimeout*6/5)
-
-	select {
-	case <-serverConn.Context().Done():
-		t.Fatal("server connection closed unexpectedly")
-	default:
-	}
-
-	serverConn.CloseWithError(0, "")
+		ctx, cancel = context.WithTimeout(context.Background(), 2*idleTimeout)
+		defer cancel()
+		_, err = conn.AcceptStream(ctx)
+		requireIdleTimeoutError(t, err)
+
+		var lastAckElicitingPacketSentAt time.Time
+		for _, p := range counter.getSentShortHeaderPackets() {
+			var hasAckElicitingFrame bool
+			for _, f := range p.frames {
+				if _, ok := f.Frame.(qlog.AckFrame); ok {
+					continue
+				}
+				hasAckElicitingFrame = true
+				break
+			}
+			if hasAckElicitingFrame {
+				lastAckElicitingPacketSentAt = p.time
+			}
+		}
+		rcvdPackets := counter.getRcvdShortHeaderPackets()
+		lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
+		// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
+		// This is ok since we're dealing with a lossless connection here,
+		// and we'd expect to receive an ACK for any additional ack-eliciting packets sent.
+		timeSinceLastAckEliciting := time.Since(lastAckElicitingPacketSentAt)
+		timeSinceLastRcvd := time.Since(lastPacketRcvdAt)
+		require.Equal(t, idleTimeout, max(timeSinceLastAckEliciting, timeSinceLastRcvd))
+
+		select {
+		case <-serverConn.Context().Done():
+			t.Fatal("server connection closed unexpectedly")
+		default:
+		}
+	})
 }
 
 func TestTimeoutAfterSendingPacket(t *testing.T) {
-	idleTimeout := scaleDuration(150 * time.Millisecond)
-	if runtime.GOOS == "windows" {
-		// increase the duration, since timers on Windows are not very precise
-		idleTimeout = max(idleTimeout, 600*time.Millisecond)
-	}
+	synctest.Test(t, func(t *testing.T) {
+		const idleTimeout = 15 * time.Second
 
-	server, err := quic.Listen(
-		newUDPConnLocalhost(t),
-		getTLSConfig(),
-		getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
-	)
-	require.NoError(t, err)
-	defer server.Close()
-
-	var drop atomic.Bool
-	proxy := quicproxy.Proxy{
-		Conn:       newUDPConnLocalhost(t),
-		ServerAddr: server.Addr().(*net.UDPAddr),
-		DropPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) bool { return drop.Load() },
-	}
-	require.NoError(t, proxy.Start())
-	defer proxy.Close()
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	defer cancel()
-	conn, err := quic.Dial(
-		ctx,
-		newUDPConnLocalhost(t),
-		proxy.LocalAddr(),
-		getTLSClientConfig(),
-		getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
-	)
-	require.NoError(t, err)
-
-	serverConn, err := server.Accept(ctx)
-	require.NoError(t, err)
-
-	// wait half the idle timeout, then send a packet
-	time.Sleep(idleTimeout / 2)
-	drop.Store(true)
-	str, err := conn.OpenUniStream()
-	require.NoError(t, err)
-	_, err = str.Write([]byte("foobar"))
-	require.NoError(t, err)
-
-	// now make sure that the idle timeout is based on this packet
-	startTime := time.Now()
-	ctx, cancel = context.WithTimeout(context.Background(), 2*idleTimeout)
-	defer cancel()
-	_, err = conn.AcceptStream(ctx)
-	requireIdleTimeoutError(t, err)
-	dur := time.Since(startTime)
-	require.GreaterOrEqual(t, dur, idleTimeout)
-	require.Less(t, dur, idleTimeout*12/10)
-
-	// Verify server connection is still open
-	select {
-	case <-serverConn.Context().Done():
-		t.Fatal("server connection closed unexpectedly")
-	default:
-	}
-	serverConn.CloseWithError(0, "")
+		var drop atomic.Bool
+		clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t,
+			time.Millisecond,
+			&droppingRouter{Drop: func(p simnet.Packet) bool { return drop.Load() }},
+		)
+		defer closeFn(t)
+
+		server, err := quic.Listen(
+			serverPacketConn,
+			getTLSConfig(),
+			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
+		)
+		require.NoError(t, err)
+		defer server.Close()
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		conn, err := quic.Dial(
+			ctx,
+			clientPacketConn,
+			serverPacketConn.LocalAddr(),
+			getTLSClientConfig(),
+			getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
+		)
+		require.NoError(t, err)
+
+		serverConn, err := server.Accept(ctx)
+		require.NoError(t, err)
+
+		serverStart := time.Now()
+
+		// wait half the idle timeout, then send a packet
+		time.Sleep(idleTimeout / 2)
+		drop.Store(true)
+
+		clientStart := time.Now()
+		str, err := conn.OpenUniStream()
+		require.NoError(t, err)
+		_, err = str.Write([]byte("foobar"))
+		require.NoError(t, err)
+
+		select {
+		case <-serverConn.Context().Done():
+			took := time.Since(serverStart)
+			require.GreaterOrEqual(t, took, idleTimeout)
+			require.Less(t, took, idleTimeout+time.Second)
+		case <-time.After(2 * idleTimeout):
+			t.Fatal("timeout waiting for idle timeout")
+		}
+
+		select {
+		case <-conn.Context().Done():
+			took := time.Since(clientStart)
+			require.Equal(t, idleTimeout, took)
+		case <-time.After(2 * idleTimeout):
+			t.Fatal("timeout waiting for idle timeout")
+		}
+	})
 }
 
 type faultyConn struct {
@@ -405,116 +392,113 @@ func TestFaultyPacketConn(t *testing.T) {
 }
 
 func testFaultyPacketConn(t *testing.T, pers protocol.Perspective) {
-	handshakeTimeout := scaleDuration(100 * time.Millisecond)
+	t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
 
-	runServer := func(ln *quic.Listener) error {
-		conn, err := ln.Accept(context.Background())
-		if err != nil {
-			return err
-		}
-		str, err := conn.OpenUniStream()
-		if err != nil {
-			return err
-		}
-		defer str.Close()
-		_, err = str.Write(PRData)
-		return err
-	}
-	runClient := func(conn *quic.Conn) error {
-		str, err := conn.AcceptUniStream(context.Background())
-		if err != nil {
-			return err
-		}
-		data, err := io.ReadAll(str)
-		if err != nil {
-			return err
-		}
-		if !bytes.Equal(data, PRData) {
-			return fmt.Errorf("wrong data: %q vs %q", data, PRData)
-		}
-		return conn.CloseWithError(0, "done")
-	}
+	synctest.Test(t, func(t *testing.T) {
+		runServer := func(ln *quic.Listener) error {
+			conn, err := ln.Accept(context.Background())
+			if err != nil {
+				return err
+			}
+			str, err := conn.OpenUniStream()
+			if err != nil {
+				return err
+			}
+			defer str.Close()
+			_, err = str.Write(PRData)
+			return err
+		}
+		runClient := func(conn *quic.Conn) error {
+			str, err := conn.AcceptUniStream(context.Background())
+			if err != nil {
+				return err
+			}
+			data, err := io.ReadAll(str)
+			if err != nil {
+				return err
+			}
+			if !bytes.Equal(data, PRData) {
+				return fmt.Errorf("wrong data: %q vs %q", data, PRData)
+			}
+			return conn.CloseWithError(0, "done")
+		}
 
-	var cconn net.PacketConn = newUDPConnLocalhost(t)
-	var sconn net.PacketConn = newUDPConnLocalhost(t)
-	maxPackets := mrand.IntN(25)
-	t.Logf("blocking %s's connection after %d packets", pers, maxPackets)
-	switch pers {
-	case protocol.PerspectiveClient:
-		cconn = &faultyConn{PacketConn: cconn, MaxPackets: maxPackets}
-	case protocol.PerspectiveServer:
-		sconn = &faultyConn{PacketConn: sconn, MaxPackets: maxPackets}
-	}
+		clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, 100*time.Millisecond)
+		defer closeFn(t)
+
+		var cconn, sconn net.PacketConn = clientPacketConn, serverPacketConn
+		maxPackets := mrand.IntN(25)
+		// sanity check: sending PRData should generate at least 25 packets
+		require.Greater(t, len(PRData)/1500, 25)
+
+		t.Logf("blocking %s's connection after %d packets", pers, maxPackets)
+		switch pers {
+		case protocol.PerspectiveClient:
+			cconn = &faultyConn{PacketConn: cconn, MaxPackets: maxPackets}
+		case protocol.PerspectiveServer:
+			sconn = &faultyConn{PacketConn: sconn, MaxPackets: maxPackets}
+		}
 
-	ln, err := quic.Listen(
-		sconn,
-		getTLSConfig(),
-		getQuicConfig(&quic.Config{
-			HandshakeIdleTimeout:    handshakeTimeout,
-			MaxIdleTimeout:          handshakeTimeout,
-			KeepAlivePeriod:         handshakeTimeout / 2,
-			DisablePathMTUDiscovery: true,
-		}),
-	)
-	require.NoError(t, err)
-	defer ln.Close()
-
-	serverErrChan := make(chan error, 1)
-	go func() { serverErrChan <- runServer(ln) }()
-
-	clientErrChan := make(chan error, 1)
-	go func() {
-		conn, err := quic.Dial(
-			context.Background(),
-			cconn,
-			ln.Addr(),
-			getTLSClientConfig(),
-			getQuicConfig(&quic.Config{
-				HandshakeIdleTimeout:    handshakeTimeout,
-				MaxIdleTimeout:          handshakeTimeout,
-				KeepAlivePeriod:         handshakeTimeout / 2,
-				DisablePathMTUDiscovery: true,
-			}),
-		)
-		if err != nil {
-			clientErrChan <- err
-			return
-		}
-		clientErrChan <- runClient(conn)
-	}()
-
-	var clientErr error
-	select {
-	case clientErr = <-clientErrChan:
-	case <-time.After(5 * handshakeTimeout):
-		t.Fatal("timeout waiting for client error")
-	}
-	require.Error(t, clientErr)
-	if pers == protocol.PerspectiveClient {
-		require.Contains(t, clientErr.Error(), io.ErrClosedPipe.Error())
-	} else {
-		var nerr net.Error
-		require.True(t, errors.As(clientErr, &nerr))
-		require.True(t, nerr.Timeout())
-	}
-
-	select {
-	case serverErr := <-serverErrChan: // The handshake completed on the server side.
-		require.Error(t, serverErr)
-		if pers == protocol.PerspectiveServer {
-			require.Contains(t, serverErr.Error(), io.ErrClosedPipe.Error())
-		} else {
-			var nerr net.Error
-			require.True(t, errors.As(serverErr, &nerr))
-			require.True(t, nerr.Timeout())
-		}
-	default: // The handshake didn't complete
-		require.NoError(t, ln.Close())
-		select {
-		case <-serverErrChan:
-		case <-time.After(time.Second):
-			t.Fatal("timeout waiting for server to close")
-		}
-	}
+		ln, err := quic.Listen(
+			sconn,
+			getTLSConfig(),
+			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
+		)
+		require.NoError(t, err)
+		defer ln.Close()
+
+		serverErrChan := make(chan error, 1)
+		go func() { serverErrChan <- runServer(ln) }()
+
+		clientErrChan := make(chan error, 1)
+		go func() {
+			conn, err := quic.Dial(
+				context.Background(),
+				cconn,
+				ln.Addr(),
+				getTLSClientConfig(),
+				getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
+			)
+			if err != nil {
+				clientErrChan <- err
+				return
+			}
+			clientErrChan <- runClient(conn)
+		}()
+
+		var clientErr error
+		select {
+		case clientErr = <-clientErrChan:
+		case <-time.After(time.Hour):
+			t.Fatal("timeout waiting for client error")
+		}
+		require.Error(t, clientErr)
+		if pers == protocol.PerspectiveClient {
+			require.Contains(t, clientErr.Error(), io.ErrClosedPipe.Error())
+		} else {
+			var nerr net.Error
+			require.True(t, errors.As(clientErr, &nerr))
+			require.True(t, nerr.Timeout())
+		}
+
+		select {
+		case serverErr := <-serverErrChan: // The handshake completed on the server side.
+			require.Error(t, serverErr)
+			if pers == protocol.PerspectiveServer {
+				require.Contains(t, serverErr.Error(), io.ErrClosedPipe.Error())
+			} else {
+				var nerr net.Error
+				require.True(t, errors.As(serverErr, &nerr))
+				require.True(t, nerr.Timeout())
+			}
+		default: // The handshake didn't complete
+			require.NoError(t, ln.Close())
+			select {
+			case <-serverErrChan:
+			case <-time.After(time.Hour):
+				t.Fatal("timeout waiting for server to close")
+			}
+		}
+	})
 }
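---

For readers unfamiliar with the pattern this patch adopts: synctest runs the test body in a "bubble" with a virtual clock, so second-scale sleeps and timer-based selects complete instantly and deterministically. That is why the rewritten tests can use realistic timeouts (and even a time.Hour guard) and assert exact elapsed durations. Below is a minimal, self-contained sketch of that behavior. It uses the standard library's testing/synctest directly; the assumption that quic-go's internal/synctest package is a thin compatibility wrapper around it is mine, not something this patch states.

package example_test

import (
	"testing"
	"testing/synctest" // synctest.Test is available in Go 1.25's standard library
	"time"
)

func TestVirtualClock(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		start := time.Now()
		// This does not block for 20s of real time: the bubble's fake clock
		// jumps forward once every goroutine in the bubble is durably blocked.
		time.Sleep(20 * time.Second)
		// Virtual elapsed time is exact, which is what makes assertions like
		// require.Equal(t, handshakeIdleTimeout, time.Since(start)) in the
		// patched tests safe.
		if got := time.Since(start); got != 20*time.Second {
			t.Fatalf("expected exactly 20s of virtual time, got %s", got)
		}
	})
}

The same property makes the select/time.After guards in the patched tests deterministic: a case <-time.After(2 * idleTimeout) branch fires at exactly that virtual instant if the awaited event has not occurred by then.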