import (
"context"
- "fmt"
+ "errors"
"io"
+ "math/rand/v2"
"net"
"testing"
"time"
"github.com/quic-go/quic-go"
- quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
+ "github.com/quic-go/quic-go/internal/protocol"
+ "github.com/quic-go/quic-go/internal/synctest"
+ "github.com/quic-go/quic-go/qlogwriter"
+ "github.com/quic-go/quic-go/testutils/simnet"
+
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-func runServerForRTTTest(t *testing.T) (net.Addr, <-chan error) {
- ln, err := quic.Listen(
- newUDPConnLocalhost(t),
- getTLSConfig(),
- getQuicConfig(nil),
- )
+func runServerForRTTTest(t *testing.T, serverConn net.PacketConn) <-chan *quic.Conn {
+ t.Helper()
+
+ ln, err := quic.Listen(serverConn, getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
- errChan := make(chan error, 1)
+ connChan := make(chan *quic.Conn, 1)
go func() {
- defer close(errChan)
for {
conn, err := ln.Accept(context.Background())
if err != nil {
- errChan <- fmt.Errorf("accept error: %w", err)
+ if !errors.Is(err, net.ErrClosed) {
+ t.Logf("accept error: %v", err)
+ }
return
}
+ connChan <- conn
str, err := conn.OpenStream()
if err != nil {
- errChan <- fmt.Errorf("open stream error: %w", err)
+ t.Logf("open stream error: %v", err)
return
}
- _, err = str.Write(PRData)
- if err != nil {
- errChan <- fmt.Errorf("write error: %w", err)
+ if _, err := str.Write(PRData); err != nil {
+ t.Logf("write error: %v", err)
return
}
str.Close()
}
}()
- return ln.Addr(), errChan
+ return connChan
}
func TestDownloadWithFixedRTT(t *testing.T) {
- addr, errChan := runServerForRTTTest(t)
-
for _, rtt := range []time.Duration{
10 * time.Millisecond,
100 * time.Millisecond,
- 250 * time.Millisecond,
+ 500 * time.Millisecond,
} {
- t.Run(fmt.Sprintf("RTT %s", rtt), func(t *testing.T) {
- t.Cleanup(func() {
- select {
- case err := <-errChan:
- t.Errorf("server error: %v", err)
- default:
+ t.Run(rtt.String(), func(t *testing.T) {
+ synctest.Test(t, func(t *testing.T) {
+ clientConn, serverConn, closeFn := newSimnetLink(t, rtt)
+ defer closeFn(t)
+
+ sconnChan := runServerForRTTTest(t, serverConn)
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
+ defer cancel()
+ conn, err := quic.Dial(
+ ctx,
+ clientConn,
+ serverConn.LocalAddr(),
+ getTLSClientConfig(),
+ getQuicConfig(nil),
+ )
+ require.NoError(t, err)
+ defer conn.CloseWithError(0, "")
+
+ str, err := conn.AcceptStream(ctx)
+ require.NoError(t, err)
+ data, err := io.ReadAll(str)
+ require.NoError(t, err)
+ require.Equal(t, PRData, data)
+
+ sconn := <-sconnChan
+ defer sconn.CloseWithError(0, "")
+
+ checkRTTs := func(stats quic.ConnectionStats) {
+ t.Helper()
+ require.GreaterOrEqual(t, stats.SmoothedRTT, rtt)
+ require.GreaterOrEqual(t, stats.MinRTT, rtt)
+ require.LessOrEqual(t, stats.SmoothedRTT, rtt+time.Millisecond)
+ require.LessOrEqual(t, stats.MinRTT, rtt+time.Millisecond)
}
+ checkRTTs(conn.ConnectionStats())
+ checkRTTs(sconn.ConnectionStats())
})
-
- proxy := quicproxy.Proxy{
- Conn: newUDPConnLocalhost(t),
- ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
- DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
- }
- require.NoError(t, proxy.Start())
- t.Cleanup(func() { proxy.Close() })
-
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- conn, err := quic.Dial(
- ctx,
- newUDPConnLocalhost(t),
- proxy.LocalAddr(),
- getTLSClientConfig(),
- getQuicConfig(nil),
- )
- require.NoError(t, err)
- defer conn.CloseWithError(0, "")
-
- str, err := conn.AcceptStream(ctx)
- require.NoError(t, err)
- data, err := io.ReadAll(str)
- require.NoError(t, err)
- require.Equal(t, PRData, data)
})
}
}
func TestDownloadWithReordering(t *testing.T) {
- addr, errChan := runServerForRTTTest(t)
-
for _, rtt := range []time.Duration{
5 * time.Millisecond,
- 30 * time.Millisecond,
+ 50 * time.Millisecond,
+ 250 * time.Millisecond,
} {
- t.Run(fmt.Sprintf("RTT %s", rtt), func(t *testing.T) {
- t.Cleanup(func() {
- select {
- case err := <-errChan:
- t.Errorf("server error: %v", err)
- default:
+ t.Run(rtt.String(), func(t *testing.T) {
+ synctest.Test(t, func(t *testing.T) {
+ counter, tracer := newPacketTracer()
+
+ n := &simnet.Simnet{Router: &simnet.PerfectRouter{}}
+ settings := simnet.NodeBiDiLinkSettings{
+ LatencyFunc: func(simnet.Packet) time.Duration {
+ min := rtt * 9 / 10
+ max := rtt * 11 / 10
+ return (min + time.Duration(rand.IntN(int(max-min)))) / 2
+ },
}
- })
+ clientConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001}, settings)
+ serverConn := n.NewEndpoint(&net.UDPAddr{IP: net.ParseIP("1.0.0.2"), Port: 9002}, settings)
+ require.NoError(t, n.Start())
+ defer func() {
+ require.NoError(t, clientConn.Close())
+ require.NoError(t, serverConn.Close())
+ require.NoError(t, n.Close())
+ }()
- proxy := quicproxy.Proxy{
- Conn: newUDPConnLocalhost(t),
- ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
- DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration {
- return randomDuration(rtt/2, rtt*3/2) / 2
- },
- }
- require.NoError(t, proxy.Start())
- t.Cleanup(func() { proxy.Close() })
-
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- conn, err := quic.Dial(
- ctx,
- newUDPConnLocalhost(t),
- proxy.LocalAddr(),
- getTLSClientConfig(),
- getQuicConfig(nil),
- )
- require.NoError(t, err)
- defer conn.CloseWithError(0, "")
-
- str, err := conn.AcceptStream(ctx)
- require.NoError(t, err)
- data, err := io.ReadAll(str)
- require.NoError(t, err)
- require.Equal(t, PRData, data)
+ sconnChan := runServerForRTTTest(t, serverConn)
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
+ defer cancel()
+ conn, err := quic.Dial(
+ ctx,
+ clientConn,
+ serverConn.LocalAddr(),
+ getTLSClientConfig(),
+ getQuicConfig(&quic.Config{
+ Tracer: func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tracer },
+ }),
+ )
+ require.NoError(t, err)
+ defer conn.CloseWithError(0, "")
+
+ str, err := conn.AcceptStream(ctx)
+ require.NoError(t, err)
+ data, err := io.ReadAll(str)
+ require.NoError(t, err)
+ require.Equal(t, PRData, data)
+
+ sconn := <-sconnChan
+ defer sconn.CloseWithError(0, "")
+
+ packets := counter.getRcvdShortHeaderPackets()
+ var packetNumbers []protocol.PacketNumber
+ for _, p := range packets {
+ packetNumbers = append(packetNumbers, p.hdr.PacketNumber)
+ }
+
+ // Count reorderings: a reordering occurs when we receive a packet
+ // with a number lower than the highest number we've seen so far
+ var reorderings int
+ var highestSeen protocol.PacketNumber
+ for _, pn := range packetNumbers {
+ if pn < highestSeen {
+ reorderings++
+ }
+ if pn > highestSeen {
+ highestSeen = pn
+ }
+ }
+
+ t.Logf("Smoothed RTT: %s", conn.ConnectionStats().SmoothedRTT)
+ assert.GreaterOrEqual(t, conn.ConnectionStats().SmoothedRTT, rtt*9/10)
+ assert.LessOrEqual(t, conn.ConnectionStats().SmoothedRTT, rtt*11/10)
+ t.Logf("received %d short header packets, detected %d reorderings", len(packetNumbers), reorderings)
+ assert.GreaterOrEqual(t, reorderings, 20, "expected at least 20 reorderings")
+ })
})
}
}