"github.com/quic-go/quic-go/internal/monotime"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
+ "github.com/quic-go/quic-go/internal/synctest"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
}
func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- &Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)},
- false,
- connectionOptTracer(tr),
- )
- tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- gomock.InOrder(
- tracer.EXPECT().ClosedConnection(&IdleTimeoutError{}),
- tracer.EXPECT().Close(),
- )
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- select {
- case err := <-errChan:
- require.ErrorIs(t, err, &IdleTimeoutError{})
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ synctest.Test(t, func(t *testing.T) {
+ const timeout = 7 * time.Second
+ mockCtrl := gomock.NewController(t)
+ tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ &Config{HandshakeIdleTimeout: timeout},
+ false,
+ connectionOptTracer(tr),
+ )
+ tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ gomock.InOrder(
+ tracer.EXPECT().ClosedConnection(&IdleTimeoutError{}),
+ tracer.EXPECT().Close(),
+ )
+ start := monotime.Now()
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+
+ synctest.Wait()
+
+ select {
+ case err := <-errChan:
+ require.ErrorIs(t, err, &IdleTimeoutError{})
+ require.Equal(t, timeout, monotime.Since(start))
+ case <-time.After(timeout + time.Nanosecond):
+ t.Fatal("timeout")
+ }
+ })
}
func TestConnectionHandshakeIdleTimeout(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- &Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)},
- false,
- connectionOptTracer(tr),
- func(c *Conn) { c.creationTime = monotime.Now().Add(-10 * time.Second) },
- )
- tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- gomock.InOrder(
- tracer.EXPECT().ClosedConnection(&HandshakeTimeoutError{}),
- tracer.EXPECT().Close(),
- )
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- select {
- case err := <-errChan:
- require.ErrorIs(t, err, &HandshakeTimeoutError{})
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ &Config{HandshakeIdleTimeout: 7 * time.Second},
+ false,
+ connectionOptTracer(tr),
+ func(c *Conn) { c.creationTime = monotime.Now().Add(-20 * time.Second) },
+ )
+ tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ gomock.InOrder(
+ tracer.EXPECT().ClosedConnection(&HandshakeTimeoutError{}),
+ tracer.EXPECT().Close(),
+ )
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ select {
+ case err := <-errChan:
+ require.ErrorIs(t, err, &HandshakeTimeoutError{})
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+ })
}
func TestConnectionTransportParameters(t *testing.T) {
}
func TestConnectionHandleMaxStreamsFrame(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, &utils.RTTStats{}, utils.DefaultLogger)
- tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC))
- tc.conn.handleTransportParameters(&wire.TransportParameters{})
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ connFC := flowcontrol.NewConnectionFlowController(0, 0, nil, &utils.RTTStats{}, utils.DefaultLogger)
+ tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC))
+ tc.conn.handleTransportParameters(&wire.TransportParameters{})
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ uniStreamChan := make(chan error)
+ go func() {
+ _, err := tc.conn.OpenUniStreamSync(ctx)
+ uniStreamChan <- err
+ }()
+ bidiStreamChan := make(chan error)
+ go func() {
+ _, err := tc.conn.OpenStreamSync(ctx)
+ bidiStreamChan <- err
+ }()
+
+ synctest.Wait()
+ select {
+ case <-uniStreamChan:
+ t.Fatal("uni stream should be blocked")
+ case <-bidiStreamChan:
+ t.Fatal("bidi stream should be blocked")
+ default:
+ }
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- uniStreamChan := make(chan error)
- go func() {
- _, err := tc.conn.OpenUniStreamSync(ctx)
- uniStreamChan <- err
- }()
- bidiStreamChan := make(chan error)
- go func() {
- _, err := tc.conn.OpenStreamSync(ctx)
- bidiStreamChan <- err
- }()
+ // MAX_STREAMS frame for bidirectional stream
+ _, err := tc.conn.handleFrame(
+ &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10},
+ protocol.Encryption1RTT,
+ protocol.ConnectionID{},
+ monotime.Now(),
+ )
+ require.NoError(t, err)
- select {
- case <-uniStreamChan:
- t.Fatal("uni stream should be blocked")
- case <-bidiStreamChan:
- t.Fatal("bidi stream should be blocked")
- case <-time.After(scaleDuration(10 * time.Millisecond)):
- }
+ synctest.Wait()
- // MAX_STREAMS frame for bidirectional stream
- _, err := tc.conn.handleFrame(
- &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10},
- protocol.Encryption1RTT,
- protocol.ConnectionID{},
- monotime.Now(),
- )
- require.NoError(t, err)
+ select {
+ case <-uniStreamChan:
+ t.Fatal("uni stream should be blocked")
+ default:
+ }
+ select {
+ case err := <-bidiStreamChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("bidi stream should be unblocked")
+ }
- select {
- case <-uniStreamChan:
- t.Fatal("uni stream should be blocked")
- case <-time.After(scaleDuration(10 * time.Millisecond)):
- }
- select {
- case err := <-bidiStreamChan:
+ // MAX_STREAMS frame for unidirectional stream
+ _, err = tc.conn.handleFrame(
+ &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 10},
+ protocol.Encryption1RTT,
+ protocol.ConnectionID{},
+ monotime.Now(),
+ )
require.NoError(t, err)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
- // MAX_STREAMS frame for bidirectional stream
- _, err = tc.conn.handleFrame(
- &wire.MaxStreamsFrame{Type: protocol.StreamTypeUni, MaxStreamNum: 10},
- protocol.Encryption1RTT,
- protocol.ConnectionID{},
- monotime.Now(),
- )
- require.NoError(t, err)
- select {
- case err := <-uniStreamChan:
- require.NoError(t, err)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ synctest.Wait()
+ select {
+ case err := <-uniStreamChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("uni stream should be unblocked")
+ }
+ })
}
func TestConnectionTransportParameterValidationFailureServer(t *testing.T) {
}
func TestConnectionPacketPacing(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- sender := NewMockSender(mockCtrl)
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ sender := NewMockSender(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- nil,
- false,
- connectionOptSentPacketHandler(sph),
- connectionOptSender(sender),
- connectionOptHandshakeConfirmed(),
- // set a fixed RTT, so that the idle timeout doesn't interfere with this test
- connectionOptRTT(10*time.Second),
- )
- sender.EXPECT().Run()
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ nil,
+ false,
+ connectionOptSentPacketHandler(sph),
+ connectionOptSender(sender),
+ connectionOptHandshakeConfirmed(),
+ )
+ sender.EXPECT().Run()
- step := scaleDuration(50 * time.Millisecond)
+ const step = 50 * time.Millisecond
- sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes()
- gomock.InOrder(
- // 1. allow 2 packets to be sent
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
- // 2. become pacing limited for 25ms
- sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }),
- // 3. send another packet
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
- // 4. become pacing limited for 25ms...
- sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }),
- // ... but this time we're still pacing limited when waking up.
- // In this case, we can only send an ACK.
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
- // 5. stop the test by becoming pacing limited forever
- sph.EXPECT().TimeUntilSend().Return(monotime.Now().Add(time.Hour)),
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
- )
- sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
- for i := 0; i < 3; i++ {
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), Version1).DoAndReturn(
- func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) {
- buf.Data = append(buf.Data, []byte("packet"+strconv.Itoa(i+1))...)
- return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i + 1)}, nil
+ sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes()
+ gomock.InOrder(
+ // 1. allow 2 packets to be sent
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
+ // 2. become pacing limited for 25ms
+ sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }),
+ // 3. send another packet
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
+ // 4. become pacing limited for 25ms...
+ sph.EXPECT().TimeUntilSend().DoAndReturn(func() monotime.Time { return monotime.Now().Add(step) }),
+ // ... but this time we're still pacing limited when waking up.
+ // In this case, we can only send an ACK.
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
+ // 5. stop the test by becoming pacing limited forever
+ sph.EXPECT().TimeUntilSend().Return(monotime.Now().Add(time.Hour)),
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
+ )
+ sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
+ for i := range 3 {
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), Version1).DoAndReturn(
+ func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) {
+ buf.Data = append(buf.Data, []byte("packet"+strconv.Itoa(i+1))...)
+ return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i + 1)}, nil
+ },
+ )
+ }
+ tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(_ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
+ buf := getPacketBuffer()
+ buf.Data = []byte("ack")
+ return shortHeaderPacket{PacketNumber: 1}, buf, nil
},
)
- }
- tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(_ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
- buf := getPacketBuffer()
- buf.Data = []byte("ack")
- return shortHeaderPacket{PacketNumber: 1}, buf, nil
- },
- )
- sender.EXPECT().WouldBlock().AnyTimes()
-
- type sentPacket struct {
- time monotime.Time
- data []byte
- }
- sendChan := make(chan sentPacket, 10)
- sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
- sendChan <- sentPacket{time: monotime.Now(), data: b.Data}
- }).Times(4)
-
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
+ sender.EXPECT().WouldBlock().AnyTimes()
- var times []monotime.Time
- for i := 0; i < 3; i++ {
+ type sentPacket struct {
+ time monotime.Time
+ data []byte
+ }
+ sendChan := make(chan sentPacket, 10)
+ sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
+ sendChan <- sentPacket{time: monotime.Now(), data: b.Data}
+ }).Times(4)
+
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
+
+ synctest.Wait()
+
+ var times []monotime.Time
+ for i := range 3 {
+ select {
+ case b := <-sendChan:
+ require.Equal(t, []byte("packet"+strconv.Itoa(i+1)), b.data)
+ times = append(times, b.time)
+ case <-time.After(time.Hour):
+ t.Fatal("should have sent a packet")
+ }
+ }
select {
case b := <-sendChan:
- require.Equal(t, []byte("packet"+strconv.Itoa(i+1)), b.data)
+ require.Equal(t, []byte("ack"), b.data)
times = append(times, b.time)
- case <-time.After(scaleDuration(time.Second)):
+ case <-time.After(time.Second):
t.Fatal("timeout")
}
- }
- select {
- case b := <-sendChan:
- require.Equal(t, []byte("ack"), b.data)
- times = append(times, b.time)
- case <-time.After(scaleDuration(time.Second)):
- t.Fatal("timeout")
- }
- require.InDelta(t, times[0].Sub(times[1]).Seconds(), 0, scaleDuration(10*time.Millisecond).Seconds())
- require.InDelta(t, times[2].Sub(times[1]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds())
- require.InDelta(t, times[3].Sub(times[2]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds())
+ require.Equal(t, times[0], times[1])
+ require.Equal(t, times[2], times[1].Add(step))
+ require.Equal(t, times[3], times[2].Add(step))
- time.Sleep(scaleDuration(step)) // make sure that no more packets are sent
- require.True(t, mockCtrl.Satisfied())
+ synctest.Wait() // make sure that no more packets are sent
+ require.True(t, mockCtrl.Satisfied())
- // test teardown
- sender.EXPECT().Close()
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.destroy(nil)
- select {
- case <-sendChan:
- t.Fatal("should not have sent any more packets")
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
+ // test teardown
+ sender.EXPECT().Close()
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ tc.conn.destroy(nil)
+
+ synctest.Wait()
+
+ select {
+ case <-sendChan:
+ t.Fatal("should not have sent any more packets")
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("run loop should have returned")
+ }
+ })
}
// When the send queue blocks, we need to reset the pacing timer, otherwise the run loop might busy-loop.
// See https://github.com/quic-go/quic-go/pull/4943 for more details.
func TestConnectionPacingAndSendQueue(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- sender := NewMockSender(mockCtrl)
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ sender := NewMockSender(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- nil,
- false,
- connectionOptSentPacketHandler(sph),
- connectionOptSender(sender),
- connectionOptHandshakeConfirmed(),
- // set a fixed RTT, so that the idle timeout doesn't interfere with this test
- connectionOptRTT(10*time.Second),
- )
- sender.EXPECT().Run()
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ nil,
+ false,
+ connectionOptSentPacketHandler(sph),
+ connectionOptSender(sender),
+ connectionOptHandshakeConfirmed(),
+ )
+ sender.EXPECT().Run()
+
+ sendQueueAvailable := make(chan struct{})
+ pacingDeadline := monotime.Now().Add(-time.Millisecond)
+ var counter int
+ // allow exactly one packet to be sent, then become blocked
+ sender.EXPECT().WouldBlock().Return(false)
+ sender.EXPECT().WouldBlock().DoAndReturn(func() bool { counter++; return true }).AnyTimes()
+ sender.EXPECT().Available().Return(sendQueueAvailable).AnyTimes()
+ sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes()
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited).AnyTimes()
+ sph.EXPECT().TimeUntilSend().Return(pacingDeadline).AnyTimes()
+ sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNNon).AnyTimes()
+ tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(
+ shortHeaderPacket{}, nil, errNothingToPack,
+ )
- sendQueueAvailable := make(chan struct{})
- pacingDeadline := monotime.Now().Add(-time.Millisecond)
- var counter int
- // allow exactly one packet to be sent, then become blocked
- sender.EXPECT().WouldBlock().Return(false)
- sender.EXPECT().WouldBlock().DoAndReturn(func() bool { counter++; return true }).AnyTimes()
- sender.EXPECT().Available().Return(sendQueueAvailable).AnyTimes()
- sph.EXPECT().GetLossDetectionTimeout().Return(monotime.Now().Add(time.Hour)).AnyTimes()
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited).AnyTimes()
- sph.EXPECT().TimeUntilSend().Return(pacingDeadline).AnyTimes()
- sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNNon).AnyTimes()
- tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(
- shortHeaderPacket{}, nil, errNothingToPack,
- )
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
+ synctest.Wait()
- time.Sleep(scaleDuration(10 * time.Millisecond))
+ // test teardown
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ sender.EXPECT().Close()
+ tc.conn.destroy(nil)
- // test teardown
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- sender.EXPECT().Close()
- tc.conn.destroy(nil)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ synctest.Wait()
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("run loop should have returned")
+ }
- // make sure the run loop didn't do too many iterations
- require.Less(t, counter, 3)
+ // make sure the run loop didn't do too many iterations
+ require.Less(t, counter, 3)
+ })
}
func TestConnectionIdleTimeout(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- &Config{MaxIdleTimeout: time.Second},
- false,
- connectionOptHandshakeConfirmed(),
- connectionOptSentPacketHandler(sph),
- connectionOptRTT(time.Millisecond),
- )
- // the idle timeout is set when the transport parameters are received
- idleTimeout := scaleDuration(50 * time.Millisecond)
- require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
- MaxIdleTimeout: idleTimeout,
- }))
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ &Config{MaxIdleTimeout: time.Minute},
+ false,
+ connectionOptHandshakeConfirmed(),
+ connectionOptSentPacketHandler(sph),
+ connectionOptRTT(time.Millisecond),
+ )
+ // the idle timeout is set when the transport parameters are received
+ const idleTimeout = 500 * time.Millisecond
+ require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
+ MaxIdleTimeout: idleTimeout,
+ }))
+
+ sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
+ sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
+ var lastSendTime monotime.Time
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) {
+ buf.Data = append(buf.Data, []byte("foobar")...)
+ lastSendTime = monotime.Now()
+ return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
+ },
+ )
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
+ tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
- sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
- var lastSendTime monotime.Time
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) {
- buf.Data = append(buf.Data, []byte("foobar")...)
- lastSendTime = monotime.Now()
- return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
- },
- )
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
- tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
+ synctest.Wait()
- select {
- case err := <-errChan:
- require.ErrorIs(t, err, &IdleTimeoutError{})
- require.NotZero(t, lastSendTime)
- require.InDelta(t,
- monotime.Since(lastSendTime).Seconds(),
- idleTimeout.Seconds(),
- scaleDuration(10*time.Millisecond).Seconds(),
- )
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
+ select {
+ case err := <-errChan:
+ require.ErrorIs(t, err, &IdleTimeoutError{})
+ require.NotZero(t, lastSendTime)
+ require.Equal(t, idleTimeout, monotime.Since(lastSendTime))
+ case <-time.After(time.Hour):
+ t.Fatal("should have timed out")
+ }
+ })
}
func TestConnectionKeepAlive(t *testing.T) {
}
func testConnectionKeepAlive(t *testing.T, enable, expectKeepAlive bool) {
- var keepAlivePeriod time.Duration
- if enable {
- keepAlivePeriod = time.Second
- }
-
- mockCtrl := gomock.NewController(t)
- unpacker := NewMockUnpacker(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- &Config{MaxIdleTimeout: time.Second, KeepAlivePeriod: keepAlivePeriod},
- false,
- connectionOptUnpacker(unpacker),
- connectionOptHandshakeConfirmed(),
- connectionOptRTT(time.Millisecond),
- )
- // the idle timeout is set when the transport parameters are received
- idleTimeout := scaleDuration(50 * time.Millisecond)
- require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
- MaxIdleTimeout: idleTimeout,
- }))
-
- // Receive a packet. This starts the keep-alive timer.
- buf := getPacketBuffer()
- var err error
- buf.Data, err = wire.AppendShortHeader(buf.Data, tc.srcConnID, 1, protocol.PacketNumberLen1, protocol.KeyPhaseZero)
- require.NoError(t, err)
- buf.Data = append(buf.Data, []byte("packet")...)
+ synctest.Test(t, func(t *testing.T) {
+ var keepAlivePeriod time.Duration
+ if enable {
+ keepAlivePeriod = time.Second
+ }
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
+ mockCtrl := gomock.NewController(t)
+ unpacker := NewMockUnpacker(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ &Config{MaxIdleTimeout: time.Second, KeepAlivePeriod: keepAlivePeriod},
+ false,
+ connectionOptUnpacker(unpacker),
+ connectionOptHandshakeConfirmed(),
+ connectionOptRTT(time.Millisecond),
+ )
+ // the idle timeout is set when the transport parameters are received
+ const idleTimeout = 50 * time.Millisecond
+ require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
+ MaxIdleTimeout: idleTimeout,
+ }))
+
+ // Receive a packet. This starts the keep-alive timer.
+ buf := getPacketBuffer()
+ var err error
+ buf.Data, err = wire.AppendShortHeader(buf.Data, tc.srcConnID, 1, protocol.PacketNumberLen1, protocol.KeyPhaseZero)
+ require.NoError(t, err)
+ buf.Data = append(buf.Data, []byte("packet")...)
- var unpackTime, packTime monotime.Time
- done := make(chan struct{})
- unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
- func(t monotime.Time, bytes []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
- unpackTime = monotime.Now()
- return protocol.PacketNumber(1), protocol.PacketNumberLen1, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil
- },
- )
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
- switch expectKeepAlive {
- case true:
- // record the time of the keep-alive is sent
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
- packTime = monotime.Now()
- close(done)
- return shortHeaderPacket{}, errNothingToPack
+ var unpackTime, packTime monotime.Time
+ done := make(chan struct{})
+ unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
+ func(t monotime.Time, bytes []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
+ unpackTime = monotime.Now()
+ return protocol.PacketNumber(1), protocol.PacketNumberLen1, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil
},
)
- tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr})
- select {
- case <-done:
- // the keep-alive packet should be sent after half the idle timeout
- diff := packTime.Sub(unpackTime)
- require.InDelta(t, diff.Seconds(), idleTimeout.Seconds()/2, scaleDuration(10*time.Millisecond).Seconds())
- case <-time.After(idleTimeout):
- t.Fatal("timeout")
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
+
+ switch expectKeepAlive {
+ case true:
+ // record the time of the keep-alive is sent
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
+ packTime = monotime.Now()
+ close(done)
+ return shortHeaderPacket{}, errNothingToPack
+ },
+ )
+ tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr})
+ select {
+ case <-done:
+ // the keep-alive packet should be sent after half the idle timeout
+ require.Equal(t, unpackTime.Add(idleTimeout/2), packTime)
+ case <-time.After(idleTimeout):
+ t.Fatal("timeout")
+ }
+ case false: // if keep-alives are disabled, the connection will run into an idle timeout
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr})
}
- case false: // if keep-alives are disabled, the connection will run into an idle timeout
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr})
+
+ // test teardown
+ if expectKeepAlive {
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ tc.conn.destroy(nil)
+ }
+
+ synctest.Wait()
+
select {
- case <-time.After(3 * time.Second):
+ case err := <-errChan:
+ if expectKeepAlive {
+ require.NoError(t, err)
+ } else {
+ require.ErrorIs(t, err, &IdleTimeoutError{})
+ }
+ case <-time.After(time.Hour):
t.Fatal("timeout")
- case <-time.After(idleTimeout):
}
- }
+ })
+}
- // test teardown
- if expectKeepAlive {
+func TestConnectionACKTimer(t *testing.T) {
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ &Config{MaxIdleTimeout: time.Second},
+ false,
+ connectionOptHandshakeConfirmed(),
+ connectionOptReceivedPacketHandler(rph),
+ connectionOptSentPacketHandler(sph),
+ )
+ const alarmTimeout = 500 * time.Millisecond
+
+ sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
+ sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
+ rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(time.Hour))
+ tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
+
+ var times []monotime.Time
+ done := make(chan struct{}, 5)
+ var calls []any
+ for i := range 2 {
+ calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) {
+ buf.Data = append(buf.Data, []byte("foobar")...)
+ times = append(times, monotime.Now())
+ return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
+ },
+ ))
+ calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error) {
+ done <- struct{}{}
+ return shortHeaderPacket{}, errNothingToPack
+ },
+ ))
+ if i == 0 {
+ calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(alarmTimeout)))
+ } else {
+ calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(time.Hour)))
+ }
+ }
+ gomock.InOrder(calls...)
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
+
+ for range 2 {
+ synctest.Wait()
+
+ select {
+ case <-done:
+ case <-time.After(time.Hour):
+ t.Fatal("timeout")
+ }
+ }
+
+ assert.Len(t, times, 2)
+ require.Equal(t, times[0].Add(alarmTimeout), times[1])
+
+ // test teardown
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
tc.conn.destroy(nil)
- }
- select {
- case err := <-errChan:
- if expectKeepAlive {
+
+ synctest.Wait()
+ select {
+ case err := <-errChan:
require.NoError(t, err)
- } else {
- require.ErrorIs(t, err, &IdleTimeoutError{})
+ default:
+ t.Fatal("should have timed out")
}
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
+ })
}
-func TestConnectionACKTimer(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- &Config{MaxIdleTimeout: time.Second},
- false,
- connectionOptHandshakeConfirmed(),
- connectionOptReceivedPacketHandler(rph),
- connectionOptSentPacketHandler(sph),
- connectionOptRTT(10*time.Second),
- )
- alarmTimeout := scaleDuration(50 * time.Millisecond)
-
- sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
- sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
- rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(time.Hour))
- tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
+// Send a GSO batch, until we have no more data to send.
+func TestConnectionGSOBatch(t *testing.T) {
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ nil,
+ true,
+ connectionOptHandshakeConfirmed(),
+ connectionOptSentPacketHandler(sph),
+ )
- var times []monotime.Time
- done := make(chan struct{}, 5)
- var calls []any
- for i := 0; i < 2; i++ {
- calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) {
- buf.Data = append(buf.Data, []byte("foobar")...)
- times = append(times, monotime.Now())
- return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
- },
- ))
- calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buf *packetBuffer, _ protocol.ByteCount, _ monotime.Time, _ protocol.Version) (shortHeaderPacket, error) {
- done <- struct{}{}
- return shortHeaderPacket{}, errNothingToPack
- },
- ))
- if i == 0 {
- calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(alarmTimeout)))
- } else {
- calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(monotime.Now().Add(time.Hour)).MaxTimes(1))
+ // allow packets to be sent
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
+ sph.EXPECT().TimeUntilSend().AnyTimes()
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
+ sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
+ sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
+
+ maxPacketSize := tc.conn.maxPacketSize()
+ var expectedData []byte
+ for i := range 4 {
+ data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
+ expectedData = append(expectedData, data...)
+
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
+ buffer.Data = append(buffer.Data, data...)
+ return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
+ },
+ )
}
- }
- gomock.InOrder(calls...)
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
+ done := make(chan struct{})
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
+ tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
+ func([]byte, uint16, protocol.ECN) error { close(done); return nil },
+ )
+
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
+
+ synctest.Wait()
- for i := 0; i < 2; i++ {
select {
case <-done:
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
+ default:
+ t.Fatal("should have sent a packet")
}
- }
- assert.Len(t, times, 2)
- require.InDelta(t, times[1].Sub(times[0]).Seconds(), alarmTimeout.Seconds(), scaleDuration(10*time.Millisecond).Seconds())
+ // test teardown
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ tc.conn.destroy(nil)
- // test teardown
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.destroy(nil)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
-}
+ synctest.Wait()
-// Send a GSO batch, until we have no more data to send.
-func TestConnectionGSOBatch(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- nil,
- true,
- connectionOptHandshakeConfirmed(),
- connectionOptSentPacketHandler(sph),
- )
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("should have timed out")
+ }
+ })
+}
- // allow packets to be sent
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
- sph.EXPECT().TimeUntilSend().AnyTimes()
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
- sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
- sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
+// Send a GSO batch, until a packet smaller than the maximum size is packed
+func TestConnectionGSOBatchPacketSize(t *testing.T) {
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ nil,
+ true,
+ connectionOptHandshakeConfirmed(),
+ connectionOptSentPacketHandler(sph),
+ )
- maxPacketSize := tc.conn.maxPacketSize()
- var expectedData []byte
- for i := 0; i < 4; i++ {
- data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
- expectedData = append(expectedData, data...)
+ // allow packets to be sent
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
+ sph.EXPECT().TimeUntilSend().AnyTimes()
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
+ sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
+ sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
+
+ maxPacketSize := tc.conn.maxPacketSize()
+ var expectedData []byte
+ var calls []any
+ for i := range 4 {
+ var data []byte
+ if i == 3 {
+ data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize-1))
+ } else {
+ data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
+ }
+ expectedData = append(expectedData, data...)
+
+ calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
+ buffer.Data = append(buffer.Data, data...)
+ return shortHeaderPacket{PacketNumber: protocol.PacketNumber(10 + i)}, nil
+ },
+ ))
+ }
+ // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
+ // We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
+ calls = append(calls,
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
+ buffer.Data = append(buffer.Data, []byte("foobar")...)
+ return shortHeaderPacket{PacketNumber: protocol.PacketNumber(14)}, nil
+ },
+ ),
+ )
+ calls = append(calls,
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
+ )
+ gomock.InOrder(calls...)
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
- buffer.Data = append(buffer.Data, data...)
- return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
- },
+ done := make(chan struct{})
+ gomock.InOrder(
+ tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1),
+ tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
+ func([]byte, uint16, protocol.ECN) error { close(done); return nil },
+ ),
)
- }
- done := make(chan struct{})
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
- tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
- func([]byte, uint16, protocol.ECN) error { close(done); return nil },
- )
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
+ synctest.Wait()
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ select {
+ case <-done:
+ default:
+ t.Fatal("should have sent a packet")
+ }
- // test teardown
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.destroy(nil)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
-}
+ // test teardown
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ tc.conn.destroy(nil)
-// Send a GSO batch, until a packet smaller than the maximum size is packed
-func TestConnectionGSOBatchPacketSize(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- nil,
- true,
- connectionOptHandshakeConfirmed(),
- connectionOptSentPacketHandler(sph),
- )
+ synctest.Wait()
- // allow packets to be sent
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
- sph.EXPECT().TimeUntilSend().AnyTimes()
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
- sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
- sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
-
- maxPacketSize := tc.conn.maxPacketSize()
- var expectedData []byte
- var calls []any
- for i := 0; i < 4; i++ {
- var data []byte
- if i == 3 {
- data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize-1))
- } else {
- data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("should have timed out")
}
- expectedData = append(expectedData, data...)
+ })
+}
- calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
- buffer.Data = append(buffer.Data, data...)
- return shortHeaderPacket{PacketNumber: protocol.PacketNumber(10 + i)}, nil
- },
- ))
- }
- // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
- // We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
- calls = append(calls,
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
- buffer.Data = append(buffer.Data, []byte("foobar")...)
- return shortHeaderPacket{PacketNumber: protocol.PacketNumber(14)}, nil
- },
- ),
- )
- calls = append(calls,
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
- )
- gomock.InOrder(calls...)
+func TestConnectionGSOBatchECN(t *testing.T) {
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ nil,
+ true,
+ connectionOptHandshakeConfirmed(),
+ connectionOptSentPacketHandler(sph),
+ )
- done := make(chan struct{})
- gomock.InOrder(
- tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1),
- tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
- func([]byte, uint16, protocol.ECN) error { close(done); return nil },
- ),
- )
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ // allow packets to be sent
+ ecnMode := protocol.ECT1
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
+ sph.EXPECT().TimeUntilSend().AnyTimes()
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
+ sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
+ sph.EXPECT().ECNMode(gomock.Any()).DoAndReturn(func(bool) protocol.ECN { return ecnMode }).AnyTimes()
+
+ // 3. Send a GSO batch, until the ECN marking changes.
+ var expectedData []byte
+ var calls []any
+ maxPacketSize := tc.conn.maxPacketSize()
+ for i := range 3 {
+ data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
+ expectedData = append(expectedData, data...)
+
+ calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
+ buffer.Data = append(buffer.Data, data...)
+ if i == 2 {
+ ecnMode = protocol.ECNCE
+ }
+ return shortHeaderPacket{PacketNumber: protocol.PacketNumber(20 + i)}, nil
+ },
+ ))
+ }
+ // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
+ // We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
+ calls = append(calls,
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
+ buffer.Data = append(buffer.Data, []byte("foobar")...)
+ return shortHeaderPacket{PacketNumber: protocol.PacketNumber(24)}, nil
+ },
+ ),
+ )
+ calls = append(calls,
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
+ )
+ gomock.InOrder(calls...)
- // test teardown
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.destroy(nil)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
-}
+ done3 := make(chan struct{})
+ tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1)
+ tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECNCE).DoAndReturn(
+ func([]byte, uint16, protocol.ECN) error { close(done3); return nil },
+ )
-func TestConnectionGSOBatchECN(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- nil,
- true,
- connectionOptHandshakeConfirmed(),
- connectionOptSentPacketHandler(sph),
- )
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
- // allow packets to be sent
- ecnMode := protocol.ECT1
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
- sph.EXPECT().TimeUntilSend().AnyTimes()
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
- sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
- sph.EXPECT().ECNMode(gomock.Any()).DoAndReturn(func(bool) protocol.ECN { return ecnMode }).AnyTimes()
-
- // 3. Send a GSO batch, until the ECN marking changes.
- var expectedData []byte
- var calls []any
- maxPacketSize := tc.conn.maxPacketSize()
- for i := 0; i < 3; i++ {
- data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
- expectedData = append(expectedData, data...)
-
- calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
- buffer.Data = append(buffer.Data, data...)
- if i == 2 {
- ecnMode = protocol.ECNCE
- }
- return shortHeaderPacket{PacketNumber: protocol.PacketNumber(20 + i)}, nil
- },
- ))
- }
- // The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
- // We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
- calls = append(calls,
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
- buffer.Data = append(buffer.Data, []byte("foobar")...)
- return shortHeaderPacket{PacketNumber: protocol.PacketNumber(24)}, nil
- },
- ),
- )
- calls = append(calls,
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
- )
- gomock.InOrder(calls...)
+ synctest.Wait()
- done3 := make(chan struct{})
- tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1)
- tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECNCE).DoAndReturn(
- func([]byte, uint16, protocol.ECN) error { close(done3); return nil },
- )
+ select {
+ case <-done3:
+ default:
+ t.Fatal("should have sent a packet")
+ }
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
+ // test teardown
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ tc.conn.destroy(nil)
- select {
- case <-done3:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ synctest.Wait()
- // test teardown
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.destroy(nil)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("should have timed out")
+ }
+ })
}
func TestConnectionPTOProbePackets(t *testing.T) {
}
func testConnectionPTOProbePackets(t *testing.T, encLevel protocol.EncryptionLevel) {
- mockCtrl := gomock.NewController(t)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- nil,
- false,
- connectionOptSentPacketHandler(sph),
- )
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ nil,
+ false,
+ connectionOptSentPacketHandler(sph),
+ )
- var sendMode ackhandler.SendMode
- switch encLevel {
- case protocol.EncryptionInitial:
- sendMode = ackhandler.SendPTOInitial
- case protocol.EncryptionHandshake:
- sendMode = ackhandler.SendPTOHandshake
- case protocol.Encryption1RTT:
- sendMode = ackhandler.SendPTOAppData
- }
-
- sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
- sph.EXPECT().TimeUntilSend().AnyTimes()
- sph.EXPECT().SendMode(gomock.Any()).Return(sendMode)
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
- sph.EXPECT().ECNMode(gomock.Any())
- sph.EXPECT().QueueProbePacket(encLevel).Return(false)
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
-
- tc.packer.EXPECT().PackPTOProbePacket(encLevel, gomock.Any(), true, gomock.Any(), protocol.Version1).DoAndReturn(
- func(protocol.EncryptionLevel, protocol.ByteCount, bool, monotime.Time, protocol.Version) (*coalescedPacket, error) {
- return &coalescedPacket{
- buffer: getPacketBuffer(),
- shortHdrPacket: &shortHeaderPacket{PacketNumber: 1},
- }, nil
- },
- )
- done := make(chan struct{})
- tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
- func([]byte, uint16, protocol.ECN) error { close(done); return nil },
- )
+ var sendMode ackhandler.SendMode
+ switch encLevel {
+ case protocol.EncryptionInitial:
+ sendMode = ackhandler.SendPTOInitial
+ case protocol.EncryptionHandshake:
+ sendMode = ackhandler.SendPTOHandshake
+ case protocol.Encryption1RTT:
+ sendMode = ackhandler.SendPTOAppData
+ }
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
+ sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
+ sph.EXPECT().TimeUntilSend().AnyTimes()
+ sph.EXPECT().SendMode(gomock.Any()).Return(sendMode)
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
+ sph.EXPECT().ECNMode(gomock.Any())
+ sph.EXPECT().QueueProbePacket(encLevel).Return(false)
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
+
+ tc.packer.EXPECT().PackPTOProbePacket(encLevel, gomock.Any(), true, gomock.Any(), protocol.Version1).DoAndReturn(
+ func(protocol.EncryptionLevel, protocol.ByteCount, bool, monotime.Time, protocol.Version) (*coalescedPacket, error) {
+ return &coalescedPacket{
+ buffer: getPacketBuffer(),
+ shortHdrPacket: &shortHeaderPacket{PacketNumber: 1},
+ }, nil
+ },
+ )
+ done := make(chan struct{})
+ tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
+ func([]byte, uint16, protocol.ECN) error { close(done); return nil },
+ )
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
- // test teardown
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.destroy(nil)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
+		// Wait for the event loop to become durably blocked, then check that the
+		// probe packet was written. This matches the synctest pattern used by the
+		// other migrated tests (synctest.Wait + non-blocking select) instead of
+		// the old wall-clock time.After timeout.
+		synctest.Wait()
+
+		select {
+		case <-done:
+		default:
+			t.Fatal("should have sent a packet")
+		}
+
+ // test teardown
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ tc.conn.destroy(nil)
+
+ synctest.Wait()
+
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("should have timed out")
+ }
+ })
}
func TestConnectionCongestionControl(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- nil,
- false,
- connectionOptHandshakeConfirmed(),
- connectionOptSentPacketHandler(sph),
- connectionOptRTT(10*time.Second),
- )
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ nil,
+ false,
+ connectionOptHandshakeConfirmed(),
+ connectionOptSentPacketHandler(sph),
+ )
- sph.EXPECT().TimeUntilSend().AnyTimes()
- sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
- sph.EXPECT().ECNMode(true).AnyTimes()
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck).MaxTimes(1)
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
- // Since we're already sending out packets, we don't expect any calls to PackAckOnlyPacket
- for i := 0; i < 2; i++ {
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
- buffer.Data = append(buffer.Data, []byte("foobar")...)
- return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
+ sph.EXPECT().TimeUntilSend().AnyTimes()
+ sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
+ sph.EXPECT().ECNMode(true).AnyTimes()
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck).MaxTimes(1)
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
+ // Since we're already sending out packets, we don't expect any calls to PackAckOnlyPacket
+ for i := range 2 {
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(buffer *packetBuffer, count protocol.ByteCount, t monotime.Time, version protocol.Version) (shortHeaderPacket, error) {
+ buffer.Data = append(buffer.Data, []byte("foobar")...)
+ return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
+ },
+ )
+ }
+ tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
+ done1 := make(chan struct{})
+ tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
+ func([]byte, uint16, protocol.ECN) error { close(done1); return nil },
+ )
+
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
+
+ synctest.Wait()
+
+ select {
+ case <-done1:
+ default:
+ t.Fatal("should have sent a packet")
+ }
+ require.True(t, mockCtrl.Satisfied())
+
+ // Now that we're congestion limited, we can only send an ack-only packet
+ done2 := make(chan struct{})
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
+ tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
+ close(done2)
+ return shortHeaderPacket{}, nil, errNothingToPack
},
)
- }
- tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
- done1 := make(chan struct{})
- tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
- func([]byte, uint16, protocol.ECN) error { close(done1); return nil },
- )
+ tc.conn.scheduleSending()
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
- select {
- case <-done1:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
- require.True(t, mockCtrl.Satisfied())
+ synctest.Wait()
- // Now that we're congestion limited, we can only send an ack-only packet
- done2 := make(chan struct{})
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
- tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
- close(done2)
- return shortHeaderPacket{}, nil, errNothingToPack
- },
- )
- tc.conn.scheduleSending()
- select {
- case <-done2:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
- require.True(t, mockCtrl.Satisfied())
+ select {
+ case <-done2:
+ default:
+ t.Fatal("should have sent an ack-only packet")
+ }
+ require.True(t, mockCtrl.Satisfied())
- // If the send mode is "none", we can't even send an ack-only packet
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
- tc.conn.scheduleSending()
- time.Sleep(scaleDuration(10 * time.Millisecond)) // make sure there are no calls to the packer
+ // If the send mode is "none", we can't even send an ack-only packet
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
+ tc.conn.scheduleSending()
+ synctest.Wait() // make sure there are no calls to the packer
- // test teardown
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.destroy(nil)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
+		// test teardown
+		tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+		tc.conn.destroy(nil)
+
+		// destroy only signals the run loop; wait for the goroutine to actually
+		// return before doing a non-blocking read on errChan, otherwise the
+		// default branch can fire spuriously. All sibling tests do the same.
+		synctest.Wait()
+
+		select {
+		case err := <-errChan:
+			require.NoError(t, err)
+		default:
+			t.Fatal("should have timed out")
+		}
+	})
}
func TestConnectionSendQueue(t *testing.T) {
}
func testConnectionSendQueue(t *testing.T, enableGSO bool) {
- mockCtrl := gomock.NewController(t)
- sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
- sender := NewMockSender(mockCtrl)
- tc := newServerTestConnection(t,
- mockCtrl,
- nil,
- enableGSO,
- connectionOptSender(sender),
- connectionOptHandshakeConfirmed(),
- connectionOptSentPacketHandler(sph),
- )
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
+ sender := NewMockSender(mockCtrl)
+ tc := newServerTestConnection(t,
+ mockCtrl,
+ nil,
+ enableGSO,
+ connectionOptSender(sender),
+ connectionOptHandshakeConfirmed(),
+ connectionOptSentPacketHandler(sph),
+ )
- sender.EXPECT().Run().MaxTimes(1)
- sender.EXPECT().WouldBlock()
- sender.EXPECT().WouldBlock().Return(true).Times(2)
- available := make(chan struct{})
- blocked := make(chan struct{})
- sender.EXPECT().Available().DoAndReturn(
- func() <-chan struct{} {
- close(blocked)
- return available
- },
- )
- sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
- sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
- sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
- sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
- shortHeaderPacket{PacketNumber: protocol.PacketNumber(1)}, nil,
- )
- sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
+ sender.EXPECT().Run().MaxTimes(1)
+ sender.EXPECT().WouldBlock()
+ sender.EXPECT().WouldBlock().Return(true).Times(2)
+ available := make(chan struct{})
+ blocked := make(chan struct{})
+ sender.EXPECT().Available().DoAndReturn(
+ func() <-chan struct{} {
+ close(blocked)
+ return available
+ },
+ )
+ sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
+ sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
+ sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
+ sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
+ shortHeaderPacket{PacketNumber: protocol.PacketNumber(1)}, nil,
+ )
+ sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.scheduleSending()
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.scheduleSending()
- select {
- case <-blocked:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
- require.True(t, mockCtrl.Satisfied())
+ synctest.Wait()
- // now make room in the send queue
- sender.EXPECT().WouldBlock().AnyTimes()
- unblocked := make(chan struct{})
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error) {
- close(unblocked)
- return shortHeaderPacket{}, errNothingToPack
- },
- )
- available <- struct{}{}
- select {
- case <-unblocked:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ select {
+ case <-blocked:
+ default:
+ t.Fatal("should have blocked")
+ }
+ require.True(t, mockCtrl.Satisfied())
- // test teardown
- sender.EXPECT().Close()
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.destroy(nil)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(3 * time.Second):
- t.Fatal("timeout")
- }
+ // now make room in the send queue
+ sender.EXPECT().WouldBlock().AnyTimes()
+ unblocked := make(chan struct{})
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(*packetBuffer, protocol.ByteCount, monotime.Time, protocol.Version) (shortHeaderPacket, error) {
+ close(unblocked)
+ return shortHeaderPacket{}, errNothingToPack
+ },
+ )
+ available <- struct{}{}
+
+ synctest.Wait()
+
+ select {
+ case <-unblocked:
+ default:
+ t.Fatal("should have unblocked")
+ }
+
+ // test teardown
+ sender.EXPECT().Close()
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ tc.conn.destroy(nil)
+
+ synctest.Wait()
+
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("timeout")
+ }
+ })
}
func getVersionNegotiationPacket(src, dest protocol.ConnectionID, versions []protocol.Version) receivedPacket {
}
func TestConnectionVersionNegotiation(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
- tc := newClientTestConnection(t,
- mockCtrl,
- nil,
- false,
- connectionOptTracer(tr),
- )
-
- tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
- var tracerVersions []logging.Version
- gomock.InOrder(
- tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) {
- tracerVersions = versions
- }),
- tracer.EXPECT().NegotiatedVersion(protocol.Version2, gomock.Any(), gomock.Any()),
- tc.connRunner.EXPECT().Remove(gomock.Any()),
- )
-
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.handlePacket(getVersionNegotiationPacket(
- tc.destConnID,
- tc.srcConnID,
- []protocol.Version{1234, protocol.Version2},
- ))
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
+ tc := newClientTestConnection(t,
+ mockCtrl,
+ nil,
+ false,
+ connectionOptTracer(tr),
+ )
- select {
- case err := <-errChan:
- var rerr *errCloseForRecreating
- require.ErrorAs(t, err, &rerr)
- require.Equal(t, rerr.nextVersion, protocol.Version2)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
- require.Contains(t, tracerVersions, protocol.Version(1234))
- require.Contains(t, tracerVersions, protocol.Version2)
+ tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
+ var tracerVersions []logging.Version
+ gomock.InOrder(
+ tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) {
+ tracerVersions = versions
+ }),
+ tracer.EXPECT().NegotiatedVersion(protocol.Version2, gomock.Any(), gomock.Any()),
+ tc.connRunner.EXPECT().Remove(gomock.Any()),
+ )
+
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.handlePacket(getVersionNegotiationPacket(
+ tc.destConnID,
+ tc.srcConnID,
+ []protocol.Version{1234, protocol.Version2},
+ ))
+
+ synctest.Wait()
+
+ select {
+ case err := <-errChan:
+ var rerr *errCloseForRecreating
+ require.ErrorAs(t, err, &rerr)
+ require.Equal(t, rerr.nextVersion, protocol.Version2)
+ default:
+ t.Fatal("should have received a Version Negotiation packet")
+ }
+ require.Contains(t, tracerVersions, protocol.Version(1234))
+ require.Contains(t, tracerVersions, protocol.Version2)
+ })
}
func TestConnectionVersionNegotiationNoMatch(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
- tc := newClientTestConnection(t,
- mockCtrl,
- &Config{Versions: []protocol.Version{protocol.Version1}},
- false,
- connectionOptTracer(tr),
- )
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
+ tc := newClientTestConnection(t,
+ mockCtrl,
+ &Config{Versions: []protocol.Version{protocol.Version1}},
+ false,
+ connectionOptTracer(tr),
+ )
- tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
- var tracerVersions []logging.Version
- tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(
- func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { tracerVersions = versions },
- )
- tracer.EXPECT().ClosedConnection(gomock.Any())
- tracer.EXPECT().Close()
- tc.connRunner.EXPECT().Remove(gomock.Any())
+ tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
+ var tracerVersions []logging.Version
+ tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(
+ func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { tracerVersions = versions },
+ )
+ tracer.EXPECT().ClosedConnection(gomock.Any())
+ tracer.EXPECT().Close()
+ tc.connRunner.EXPECT().Remove(gomock.Any())
+
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+ tc.conn.handlePacket(getVersionNegotiationPacket(
+ tc.destConnID,
+ tc.srcConnID,
+ []protocol.Version{protocol.Version2},
+ ))
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
- tc.conn.handlePacket(getVersionNegotiationPacket(
- tc.destConnID,
- tc.srcConnID,
- []protocol.Version{protocol.Version2},
- ))
+ synctest.Wait()
- select {
- case err := <-errChan:
- var verr *VersionNegotiationError
- require.ErrorAs(t, err, &verr)
- require.Contains(t, verr.Theirs, protocol.Version2)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
- require.Contains(t, tracerVersions, protocol.Version2)
+ select {
+ case err := <-errChan:
+ var verr *VersionNegotiationError
+ require.ErrorAs(t, err, &verr)
+ require.Contains(t, verr.Theirs, protocol.Version2)
+ default:
+ t.Fatal("should have received a Version Negotiation packet")
+ }
+ require.Contains(t, tracerVersions, protocol.Version2)
+ })
}
func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) {
}
func testConnectionConnectionIDChanges(t *testing.T, sendRetry bool) {
- makeInitialPacket := func(t *testing.T, hdr *wire.ExtendedHeader) []byte {
- t.Helper()
- data, err := hdr.Append(nil, protocol.Version1)
- require.NoError(t, err)
- data = append(data, make([]byte, hdr.Length-protocol.ByteCount(hdr.PacketNumberLen))...)
- return data
- }
+ synctest.Test(t, func(t *testing.T) {
+ makeInitialPacket := func(t *testing.T, hdr *wire.ExtendedHeader) []byte {
+ t.Helper()
+ data, err := hdr.Append(nil, protocol.Version1)
+ require.NoError(t, err)
+ data = append(data, make([]byte, hdr.Length-protocol.ByteCount(hdr.PacketNumberLen))...)
+ return data
+ }
- mockCtrl := gomock.NewController(t)
- tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
- unpacker := NewMockUnpacker(mockCtrl)
- tc := newClientTestConnection(t,
- mockCtrl,
- nil,
- false,
- connectionOptTracer(tr),
- connectionOptUnpacker(unpacker),
- )
+ mockCtrl := gomock.NewController(t)
+ tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
+ unpacker := NewMockUnpacker(mockCtrl)
+ tc := newClientTestConnection(t,
+ mockCtrl,
+ nil,
+ false,
+ connectionOptTracer(tr),
+ connectionOptUnpacker(unpacker),
+ )
- dstConnID := tc.destConnID
- b := make([]byte, 3*10)
- rand.Read(b)
- newConnID := protocol.ParseConnectionID(b[:11])
- newConnID2 := protocol.ParseConnectionID(b[11:20])
+ dstConnID := tc.destConnID
+ b := make([]byte, 3*10)
+ rand.Read(b)
+ newConnID := protocol.ParseConnectionID(b[:11])
+ newConnID2 := protocol.ParseConnectionID(b[11:20])
- tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
- tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
+ tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
+ tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
+
+ require.Equal(t, dstConnID, tc.conn.connIDManager.Get())
+
+ var retryConnID protocol.ConnectionID
+ if sendRetry {
+ retryConnID = protocol.ParseConnectionID(b[20:30])
+ hdrChan := make(chan *wire.Header)
+ tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { hdrChan <- hdr })
+ tc.packer.EXPECT().SetToken([]byte("foobar"))
- require.Equal(t, dstConnID, tc.conn.connIDManager.Get())
+ tc.conn.handlePacket(getRetryPacket(t, retryConnID, tc.srcConnID, tc.destConnID, []byte("foobar")))
- var retryConnID protocol.ConnectionID
- if sendRetry {
- retryConnID = protocol.ParseConnectionID(b[20:30])
- hdrChan := make(chan *wire.Header)
- tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { hdrChan <- hdr })
- tc.packer.EXPECT().SetToken([]byte("foobar"))
+ synctest.Wait()
+
+ select {
+ case hdr := <-hdrChan:
+ assert.Equal(t, retryConnID, hdr.SrcConnectionID)
+ assert.Equal(t, []byte("foobar"), hdr.Token)
+ require.Equal(t, retryConnID, tc.conn.connIDManager.Get())
+ default:
+ t.Fatal("should have received the retry packet")
+ }
+ }
+
+ // Send the first packet. The server changes the connection ID to newConnID.
+ hdr1 := wire.ExtendedHeader{
+ Header: wire.Header{
+ SrcConnectionID: newConnID,
+ DestConnectionID: tc.srcConnID,
+ Type: protocol.PacketTypeInitial,
+ Length: 200,
+ Version: protocol.Version1,
+ },
+ PacketNumber: 1,
+ PacketNumberLen: protocol.PacketNumberLen2,
+ }
+ hdr2 := hdr1
+ hdr2.SrcConnectionID = newConnID2
+
+ receivedFirst := make(chan struct{})
+ gomock.InOrder(
+ unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
+ &unpackedPacket{
+ hdr: &hdr1,
+ encryptionLevel: protocol.EncryptionInitial,
+ }, nil,
+ ),
+ tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(
+ func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame) { close(receivedFirst) },
+ ),
+ )
+
+ tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr1), buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr})
+
+ synctest.Wait()
- tc.conn.handlePacket(getRetryPacket(t, retryConnID, tc.srcConnID, tc.destConnID, []byte("foobar")))
select {
- case hdr := <-hdrChan:
- assert.Equal(t, retryConnID, hdr.SrcConnectionID)
- assert.Equal(t, []byte("foobar"), hdr.Token)
- require.Equal(t, retryConnID, tc.conn.connIDManager.Get())
- case <-time.After(time.Second):
- t.Fatal("timeout")
+ case <-receivedFirst:
+ require.Equal(t, newConnID, tc.conn.connIDManager.Get())
+ default:
+ t.Fatal("should have received the first packet")
}
- }
- // Send the first packet. The server changes the connection ID to newConnID.
- hdr1 := wire.ExtendedHeader{
- Header: wire.Header{
- SrcConnectionID: newConnID,
- DestConnectionID: tc.srcConnID,
- Type: protocol.PacketTypeInitial,
- Length: 200,
- Version: protocol.Version1,
- },
- PacketNumber: 1,
- PacketNumberLen: protocol.PacketNumberLen2,
- }
- hdr2 := hdr1
- hdr2.SrcConnectionID = newConnID2
+ // Send the second packet. We refuse to accept it, because the connection ID is changed again.
+ dropped := make(chan struct{})
+ tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, gomock.Any(), gomock.Any(), logging.PacketDropUnknownConnectionID).Do(
+ func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
+ close(dropped)
+ },
+ )
- receivedFirst := make(chan struct{})
- gomock.InOrder(
- unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
- &unpackedPacket{
- hdr: &hdr1,
- encryptionLevel: protocol.EncryptionInitial,
- }, nil,
- ),
- tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(
- func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame) { close(receivedFirst) },
- ),
- )
+ tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr2), buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr})
- tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr1), buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr})
+ synctest.Wait()
- select {
- case <-receivedFirst:
- require.Equal(t, newConnID, tc.conn.connIDManager.Get())
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ select {
+ case <-dropped:
+ // the connection ID should not have changed
+ require.Equal(t, newConnID, tc.conn.connIDManager.Get())
+ default:
+ t.Fatal("should have dropped the packet")
+ }
- // Send the second packet. We refuse to accept it, because the connection ID is changed again.
- dropped := make(chan struct{})
- tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, gomock.Any(), gomock.Any(), logging.PacketDropUnknownConnectionID).Do(
- func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
- close(dropped)
- },
- )
+ // test teardown
+ tracer.EXPECT().ClosedConnection(gomock.Any())
+ tracer.EXPECT().Close()
+ tc.connRunner.EXPECT().Remove(gomock.Any())
+ tc.conn.destroy(nil)
- tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr2), buffer: getPacketBuffer(), rcvTime: monotime.Now(), remoteAddr: tc.remoteAddr})
- select {
- case <-dropped:
- // the connection ID should not have changed
- require.Equal(t, newConnID, tc.conn.connIDManager.Get())
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ synctest.Wait()
- // test teardown
- tracer.EXPECT().ClosedConnection(gomock.Any())
- tracer.EXPECT().Close()
- tc.connRunner.EXPECT().Remove(gomock.Any())
- tc.conn.destroy(nil)
- select {
- case err := <-errChan:
- require.NoError(t, err)
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("should have shut down")
+ }
+ })
}
// When the connection is closed before sending the first packet,
// This can happen if there's something wrong with the tls.Config, and
// crypto/tls refuses to start the handshake.
func TestConnectionEarlyClose(t *testing.T) {
- mockCtrl := gomock.NewController(t)
- tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
- cryptoSetup := mocks.NewMockCryptoSetup(mockCtrl)
- tc := newClientTestConnection(t,
- mockCtrl,
- nil,
- false,
- connectionOptTracer(tr),
- connectionOptCryptoSetup(cryptoSetup),
- )
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
+ cryptoSetup := mocks.NewMockCryptoSetup(mockCtrl)
+ tc := newClientTestConnection(t,
+ mockCtrl,
+ nil,
+ false,
+ connectionOptTracer(tr),
+ connectionOptCryptoSetup(cryptoSetup),
+ )
- tc.conn.sentFirstPacket = false
- tracer.EXPECT().ClosedConnection(gomock.Any())
- tracer.EXPECT().Close()
- cryptoSetup.EXPECT().StartHandshake(gomock.Any()).Do(func(context.Context) error {
- tc.conn.closeLocal(errors.New("early error"))
- return nil
- })
- cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
- cryptoSetup.EXPECT().Close()
- tc.connRunner.EXPECT().Remove(gomock.Any())
+ tc.conn.sentFirstPacket = false
+ tracer.EXPECT().ClosedConnection(gomock.Any())
+ tracer.EXPECT().Close()
+ cryptoSetup.EXPECT().StartHandshake(gomock.Any()).Do(func(context.Context) error {
+ tc.conn.closeLocal(errors.New("early error"))
+ return nil
+ })
+ cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
+ cryptoSetup.EXPECT().Close()
+ tc.connRunner.EXPECT().Remove(gomock.Any())
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
- select {
- case err := <-errChan:
- require.Error(t, err)
- require.ErrorContains(t, err, "early error")
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ synctest.Wait()
+
+ select {
+ case err := <-errChan:
+ require.Error(t, err)
+ require.ErrorContains(t, err, "early error")
+ default:
+ t.Fatal("should have shut down")
+ }
+ })
}
func TestConnectionPathValidation(t *testing.T) {
}
func testConnectionPathValidation(t *testing.T, isNATRebinding bool) {
- mockCtrl := gomock.NewController(t)
- unpacker := NewMockUnpacker(mockCtrl)
- tc := newServerTestConnection(
- t,
- mockCtrl,
- nil,
- false,
- connectionOptUnpacker(unpacker),
- connectionOptHandshakeConfirmed(),
- connectionOptRTT(time.Second),
- )
- require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{MaxUDPPayloadSize: 1456}))
-
- newRemoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}
- require.NotEqual(t, tc.remoteAddr, newRemoteAddr)
-
- errChan := make(chan error, 1)
- go func() { errChan <- tc.conn.run() }()
-
- probeSent := make(chan struct{})
- var pathChallenge *wire.PathChallengeFrame
- payload := []byte{0} // PADDING frame
- if isNATRebinding {
- payload = []byte{1} // PING frame
- }
- gomock.InOrder(
- unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
- protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
- ),
- tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(_ protocol.ConnectionID, frames []ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
- pathChallenge = frames[0].Frame.(*wire.PathChallengeFrame)
- return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil
- },
- ),
- tc.sendConn.EXPECT().WriteTo(gomock.Any(), newRemoteAddr).DoAndReturn(
- func([]byte, net.Addr) error { close(probeSent); return nil },
- ),
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
- shortHeaderPacket{}, errNothingToPack,
- ),
- )
+ synctest.Test(t, func(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ unpacker := NewMockUnpacker(mockCtrl)
+ tc := newServerTestConnection(
+ t,
+ mockCtrl,
+ nil,
+ false,
+ connectionOptUnpacker(unpacker),
+ connectionOptHandshakeConfirmed(),
+ connectionOptRTT(time.Second),
+ )
+ require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{MaxUDPPayloadSize: 1456}))
- tc.conn.handlePacket(receivedPacket{
- data: make([]byte, 10),
- buffer: getPacketBuffer(),
- remoteAddr: newRemoteAddr,
- rcvTime: monotime.Now(),
- })
+ newRemoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}
+ require.NotEqual(t, tc.remoteAddr, newRemoteAddr)
- select {
- case <-probeSent:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ errChan := make(chan error, 1)
+ go func() { errChan <- tc.conn.run() }()
- // Receive a packed containing a PATH_RESPONSE frame.
- // Only if the first packet received on the path was a probing packet
- // (i.e. we're dealing with a NAT rebinding), this makes us switch to the new path.
- migrated := make(chan struct{})
- data, err := (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(nil, protocol.Version1)
- require.NoError(t, err)
- calls := []any{
- unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
- protocol.PacketNumber(11), protocol.PacketNumberLen2, protocol.KeyPhaseZero, data, nil,
- ),
- }
- if isNATRebinding {
- calls = append(calls,
- tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do(
- func(net.Addr, packetInfo) { close(migrated) },
+ probeSent := make(chan struct{})
+ var pathChallenge *wire.PathChallengeFrame
+ payload := []byte{0} // PADDING frame
+ if isNATRebinding {
+ payload = []byte{1} // PING frame
+ }
+ gomock.InOrder(
+ unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
+ protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
+ ),
+ tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(_ protocol.ConnectionID, frames []ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
+ pathChallenge = frames[0].Frame.(*wire.PathChallengeFrame)
+ return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil
+ },
+ ),
+ tc.sendConn.EXPECT().WriteTo(gomock.Any(), newRemoteAddr).DoAndReturn(
+ func([]byte, net.Addr) error { close(probeSent); return nil },
+ ),
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
+ shortHeaderPacket{}, errNothingToPack,
),
)
- }
- calls = append(calls,
- tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
- shortHeaderPacket{}, errNothingToPack,
- ).MaxTimes(1),
- )
- gomock.InOrder(calls...)
- require.Equal(t, tc.remoteAddr, tc.conn.RemoteAddr())
- // the PATH_RESPONSE can be sent on the old path, if the client is just probing the new path
- addr := tc.remoteAddr
- if isNATRebinding {
- addr = newRemoteAddr
- }
- tc.conn.handlePacket(receivedPacket{
- data: make([]byte, 100),
- buffer: getPacketBuffer(),
- remoteAddr: addr,
- rcvTime: monotime.Now(),
- })
- if !isNATRebinding {
- // If the first packet was a probing packet, we only switch to the new path when we
- // receive a non-probing packet on that path.
+ tc.conn.handlePacket(receivedPacket{
+ data: make([]byte, 10),
+ buffer: getPacketBuffer(),
+ remoteAddr: newRemoteAddr,
+ rcvTime: monotime.Now(),
+ })
+
+ synctest.Wait()
+
select {
- case <-migrated:
- t.Fatal("didn't expect a migration yet")
- case <-time.After(scaleDuration(10 * time.Millisecond)):
+ case <-probeSent:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
}
- payload := []byte{1} // PING frame
- payload, err = (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(payload, protocol.Version1)
+	// Receive a packet containing a PATH_RESPONSE frame.
+ // Only if the first packet received on the path was a probing packet
+ // (i.e. we're dealing with a NAT rebinding), this makes us switch to the new path.
+ migrated := make(chan struct{})
+ data, err := (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(nil, protocol.Version1)
require.NoError(t, err)
- gomock.InOrder(
+ calls := []any{
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
- protocol.PacketNumber(12), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
- ),
- tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do(
- func(net.Addr, packetInfo) { close(migrated) },
+ protocol.PacketNumber(11), protocol.PacketNumberLen2, protocol.KeyPhaseZero, data, nil,
),
+ }
+ if isNATRebinding {
+ calls = append(calls,
+ tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do(
+ func(net.Addr, packetInfo) { close(migrated) },
+ ),
+ )
+ }
+ calls = append(calls,
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
shortHeaderPacket{}, errNothingToPack,
).MaxTimes(1),
)
+ gomock.InOrder(calls...)
+ require.Equal(t, tc.remoteAddr, tc.conn.RemoteAddr())
+ // the PATH_RESPONSE can be sent on the old path, if the client is just probing the new path
+ addr := tc.remoteAddr
+ if isNATRebinding {
+ addr = newRemoteAddr
+ }
tc.conn.handlePacket(receivedPacket{
data: make([]byte, 100),
buffer: getPacketBuffer(),
- remoteAddr: newRemoteAddr,
+ remoteAddr: addr,
rcvTime: monotime.Now(),
})
- }
- select {
- case <-migrated:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ synctest.Wait()
- // test teardown
- tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
- tc.conn.destroy(nil)
- select {
- case <-errChan:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ if !isNATRebinding {
+ // If the first packet was a probing packet, we only switch to the new path when we
+ // receive a non-probing packet on that path.
+ select {
+ case <-migrated:
+ t.Fatal("didn't expect a migration yet")
+ default:
+ }
+
+ payload := []byte{1} // PING frame
+ payload, err = (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(payload, protocol.Version1)
+ require.NoError(t, err)
+ gomock.InOrder(
+ unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
+ protocol.PacketNumber(12), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
+ ),
+ tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do(
+ func(net.Addr, packetInfo) { close(migrated) },
+ ),
+ tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
+ shortHeaderPacket{}, errNothingToPack,
+ ).MaxTimes(1),
+ )
+ tc.conn.handlePacket(receivedPacket{
+ data: make([]byte, 100),
+ buffer: getPacketBuffer(),
+ remoteAddr: newRemoteAddr,
+ rcvTime: monotime.Now(),
+ })
+ }
+
+ synctest.Wait()
+
+ select {
+ case <-migrated:
+ default:
+ t.Fatal("should have migrated")
+ }
+
+ // test teardown
+ tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
+ tc.conn.destroy(nil)
+
+ synctest.Wait()
+
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ default:
+ t.Fatal("should have shut down")
+ }
+ })
}
func TestConnectionMigrationServer(t *testing.T) {