]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
ackhandler: avoid storing packet number in packet struct (#5312)
authorMarten Seemann <martenseemann@gmail.com>
Fri, 29 Aug 2025 02:18:45 +0000 (10:18 +0800)
committerGitHub <noreply@github.com>
Fri, 29 Aug 2025 02:18:45 +0000 (04:18 +0200)
* ackhandler: optimize memory layout of packet struct

The packet number can be derived from the position that this packet is
stored at in the packets slice in the sent packet history. There is no
need to store the packet number, saving 8 bytes per packet.

* ackhandler: avoid copying the packet struct

internal/ackhandler/ecn.go
internal/ackhandler/ecn_test.go
internal/ackhandler/mock_ecn_handler_test.go
internal/ackhandler/packet.go
internal/ackhandler/sent_packet_handler.go
internal/ackhandler/sent_packet_handler_test.go
internal/ackhandler/sent_packet_history.go
internal/ackhandler/sent_packet_history_test.go

index 68415ac6c08dab81d6f9a2b9555c9de6e54bb655..1b462a606550aa8f6409c16908a3352ce1d26ac7 100644 (file)
@@ -24,7 +24,7 @@ const numECNTestingPackets = 10
 type ecnHandler interface {
        SentPacket(protocol.PacketNumber, protocol.ECN)
        Mode() protocol.ECN
-       HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool)
+       HandleNewlyAcked(packets []packetWithPacketNumber, ect0, ect1, ecnce int64) (congested bool)
        LostPacket(protocol.PacketNumber)
 }
 
@@ -144,7 +144,7 @@ func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) {
 // HandleNewlyAcked handles the ECN counts on an ACK frame.
 // It must only be called for ACK frames that increase the largest acknowledged packet number,
 // see section 13.4.2.1 of RFC 9000.
-func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) {
+func (e *ecnTracker) HandleNewlyAcked(packets []packetWithPacketNumber, ect0, ect1, ecnce int64) (congested bool) {
        if e.state == ecnStateFailed {
                return false
        }
index c9a04ae1361debda4a1a662008431db9091094f0..5346cfe923ab5be2f581f6b7f3b020afb75f6f88 100644 (file)
@@ -12,10 +12,10 @@ import (
        "go.uber.org/mock/gomock"
 )
 
-func getAckedPackets(pns ...protocol.PacketNumber) []*packet {
-       var packets []*packet
+func getAckedPackets(pns ...protocol.PacketNumber) []packetWithPacketNumber {
+       var packets []packetWithPacketNumber
        for _, p := range pns {
-               packets = append(packets, &packet{PacketNumber: p})
+               packets = append(packets, packetWithPacketNumber{PacketNumber: p})
        }
        return packets
 }
@@ -129,7 +129,12 @@ func TestECNValidationFailures(t *testing.T) {
        })
 }
 
-func testECNValidationFailure(t *testing.T, ackedPackets []*packet, ect0, ect1, ecnce int64, expectedTrigger logging.ECNStateTrigger) {
+func testECNValidationFailure(
+       t *testing.T,
+       ackedPackets []packetWithPacketNumber,
+       ect0, ect1, ecnce int64,
+       expectedTrigger logging.ECNStateTrigger,
+) {
        mockCtrl := gomock.NewController(t)
        tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
        ecnTracker := newECNTracker(utils.DefaultLogger, tr)
index 0a8cc9bda7182cd14a77630ce9325d1f0b9aa142..2e69bfa1de1ffb0dd40f0061539fb865e2058afb 100644 (file)
@@ -41,7 +41,7 @@ func (m *MockECNHandler) EXPECT() *MockECNHandlerMockRecorder {
 }
 
 // HandleNewlyAcked mocks base method.
-func (m *MockECNHandler) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) bool {
+func (m *MockECNHandler) HandleNewlyAcked(packets []packetWithPacketNumber, ect0, ect1, ecnce int64) bool {
        m.ctrl.T.Helper()
        ret := m.ctrl.Call(m, "HandleNewlyAcked", packets, ect0, ect1, ecnce)
        ret0, _ := ret[0].(bool)
@@ -67,13 +67,13 @@ func (c *MockECNHandlerHandleNewlyAckedCall) Return(congested bool) *MockECNHand
 }
 
 // Do rewrite *gomock.Call.Do
-func (c *MockECNHandlerHandleNewlyAckedCall) Do(f func([]*packet, int64, int64, int64) bool) *MockECNHandlerHandleNewlyAckedCall {
+func (c *MockECNHandlerHandleNewlyAckedCall) Do(f func([]packetWithPacketNumber, int64, int64, int64) bool) *MockECNHandlerHandleNewlyAckedCall {
        c.Call = c.Call.Do(f)
        return c
 }
 
 // DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockECNHandlerHandleNewlyAckedCall) DoAndReturn(f func([]*packet, int64, int64, int64) bool) *MockECNHandlerHandleNewlyAckedCall {
+func (c *MockECNHandlerHandleNewlyAckedCall) DoAndReturn(f func([]packetWithPacketNumber, int64, int64, int64) bool) *MockECNHandlerHandleNewlyAckedCall {
        c.Call = c.Call.DoAndReturn(f)
        return c
 }
index bd225a6ff5f46af8f30d42c11dedf91d508c64fa..e2f4e7d2e8b47624d818fb16349eddcca0c24826 100644 (file)
@@ -7,10 +7,14 @@ import (
        "github.com/quic-go/quic-go/internal/protocol"
 )
 
+type packetWithPacketNumber struct {
+       PacketNumber protocol.PacketNumber
+       *packet
+}
+
 // A Packet is a packet
 type packet struct {
        SendTime        time.Time
-       PacketNumber    protocol.PacketNumber
        StreamFrames    []StreamFrame
        Frames          []Frame
        LargestAcked    protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
@@ -33,7 +37,6 @@ var packetPool = sync.Pool{New: func() any { return &packet{} }}
 
 func getPacket() *packet {
        p := packetPool.Get().(*packet)
-       p.PacketNumber = 0
        p.StreamFrames = nil
        p.Frames = nil
        p.LargestAcked = 0
index c01ca91c7f443abbcf4bed4fa405f3313e0b3316..622814a7ebe2721a0a7e9008e7a8462a857dd752 100644 (file)
@@ -84,7 +84,7 @@ type sentPacketHandler struct {
        // Only applies to the application-data packet number space.
        lowestNotConfirmedAcked protocol.PacketNumber
 
-       ackedPackets []*packet // to avoid allocations in detectAndRemoveAckedPackets
+       ackedPackets []packetWithPacketNumber // to avoid allocations in detectAndRemoveAckedPackets
 
        bytesInFlight protocol.ByteCount
 
@@ -181,7 +181,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t
                if pnSpace == nil {
                        return
                }
-               for p := range pnSpace.history.Packets() {
+               for _, p := range pnSpace.history.Packets() {
                        h.removeFromBytesInFlight(p)
                }
        }
@@ -200,12 +200,12 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t
                // and not when the client drops 0-RTT keys when the handshake completes.
                // When 0-RTT is rejected, all application data sent so far becomes invalid.
                // Delete the packets from the history and remove them from bytes_in_flight.
-               for p := range h.appDataPackets.history.Packets() {
+               for pn, p := range h.appDataPackets.history.Packets() {
                        if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
                                break
                        }
                        h.removeFromBytesInFlight(p)
-                       h.appDataPackets.history.Remove(p.PacketNumber)
+                       h.appDataPackets.history.Remove(pn)
                }
        default:
                panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
@@ -275,12 +275,11 @@ func (h *sentPacketHandler) SentPacket(
        if isPathProbePacket {
                p := getPacket()
                p.SendTime = t
-               p.PacketNumber = pn
                p.EncryptionLevel = encLevel
                p.Length = size
                p.Frames = frames
                p.isPathProbePacket = true
-               pnSpace.history.SentPathProbePacket(p)
+               pnSpace.history.SentPathProbePacket(pn, p)
                h.setLossDetectionTimer(t)
                return
        }
@@ -307,7 +306,6 @@ func (h *sentPacketHandler) SentPacket(
 
        p := getPacket()
        p.SendTime = t
-       p.PacketNumber = pn
        p.EncryptionLevel = encLevel
        p.Length = size
        p.LargestAcked = largestAcked
@@ -316,7 +314,7 @@ func (h *sentPacketHandler) SentPacket(
        p.IsPathMTUProbePacket = isPathMTUProbePacket
        p.includedInBytesInFlight = true
 
-       pnSpace.history.SentAckElicitingPacket(p)
+       pnSpace.history.SentAckElicitingPacket(pn, p)
        if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
                h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
        }
@@ -399,9 +397,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
                if p.EncryptionLevel == protocol.Encryption1RTT {
                        acked1RTTPacket = true
                }
-               h.removeFromBytesInFlight(p)
+               h.removeFromBytesInFlight(p.packet)
                if !p.isPathProbePacket {
-                       putPacket(p)
+                       putPacket(p.packet)
                }
        }
        // After this point, we must not use ackedPackets any longer!
@@ -432,7 +430,7 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
 }
 
 // Packets are returned in ascending packet number order.
-func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*packet, error) {
+func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]packetWithPacketNumber, error) {
        pnSpace := h.getPacketNumberSpace(encLevel)
        ackRangeIndex := 0
        if len(h.ackedPackets) > 0 {
@@ -440,45 +438,45 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
        }
        lowestAcked := ack.LowestAcked()
        largestAcked := ack.LargestAcked()
-       for p := range pnSpace.history.Packets() {
+       for pn, p := range pnSpace.history.Packets() {
                // ignore packets below the lowest acked
-               if p.PacketNumber < lowestAcked {
+               if pn < lowestAcked {
                        continue
                }
-               if p.PacketNumber > largestAcked {
+               if pn > largestAcked {
                        break
                }
 
                if ack.HasMissingRanges() {
                        ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
 
-                       for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 {
+                       for pn > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 {
                                ackRangeIndex++
                                ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
                        }
 
-                       if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
+                       if pn < ackRange.Smallest { // packet not contained in ACK range
                                continue
                        }
-                       if p.PacketNumber > ackRange.Largest {
-                               return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
+                       if pn > ackRange.Largest {
+                               return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", pn, ackRange.Smallest, ackRange.Largest)
                        }
                }
                if p.skippedPacket {
                        return nil, &qerr.TransportError{
                                ErrorCode:    qerr.ProtocolViolation,
-                               ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
+                               ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", pn, encLevel),
                        }
                }
                if p.isPathProbePacket {
-                       probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber)
+                       probePacket := pnSpace.history.RemovePathProbe(pn)
                        // the probe packet might already have been declared lost
                        if probePacket != nil {
-                               h.ackedPackets = append(h.ackedPackets, probePacket)
+                               h.ackedPackets = append(h.ackedPackets, packetWithPacketNumber{PacketNumber: pn, packet: probePacket})
                        }
                        continue
                }
-               h.ackedPackets = append(h.ackedPackets, p)
+               h.ackedPackets = append(h.ackedPackets, packetWithPacketNumber{PacketNumber: pn, packet: p})
        }
        if h.logger.Debug() && len(h.ackedPackets) > 0 {
                pns := make([]protocol.PacketNumber, len(h.ackedPackets))
@@ -623,7 +621,7 @@ func (h *sentPacketHandler) lossDetectionTime(now time.Time) alarmTimer {
 
        var pathProbeLossTime time.Time
        if h.appDataPackets.history.HasOutstandingPathProbes() {
-               if p := h.appDataPackets.history.FirstOutstandingPathProbe(); p != nil {
+               if _, p := h.appDataPackets.history.FirstOutstandingPathProbe(); p != nil {
                        pathProbeLossTime = p.SendTime.Add(pathProbePacketLossTimeout)
                }
        }
@@ -661,10 +659,10 @@ func (h *sentPacketHandler) detectLostPathProbes(now time.Time) {
        }
        lossTime := now.Add(-pathProbePacketLossTimeout)
        // RemovePathProbe cannot be called while iterating.
-       var lostPathProbes []*packet
-       for p := range h.appDataPackets.history.PathProbes() {
+       var lostPathProbes []packetWithPacketNumber
+       for pn, p := range h.appDataPackets.history.PathProbes() {
                if !p.SendTime.After(lossTime) {
-                       lostPathProbes = append(lostPathProbes, p)
+                       lostPathProbes = append(lostPathProbes, packetWithPacketNumber{PacketNumber: pn, packet: p})
                }
        }
        for _, p := range lostPathProbes {
@@ -689,8 +687,8 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
        lostSendTime := now.Add(-lossDelay)
 
        priorInFlight := h.bytesInFlight
-       for p := range pnSpace.history.Packets() {
-               if p.PacketNumber > pnSpace.largestAcked {
+       for pn, p := range pnSpace.history.Packets() {
+               if pn > pnSpace.largestAcked {
                        break
                }
 
@@ -700,41 +698,41 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
                        packetLost = true
                        if isRegularPacket {
                                if h.logger.Debug() {
-                                       h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
+                                       h.logger.Debugf("\tlost packet %d (time threshold)", pn)
                                }
                                if h.tracer != nil && h.tracer.LostPacket != nil {
-                                       h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold)
+                                       h.tracer.LostPacket(p.EncryptionLevel, pn, logging.PacketLossTimeThreshold)
                                }
                        }
-               } else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
+               } else if pnSpace.largestAcked >= pn+packetThreshold {
                        packetLost = true
                        if isRegularPacket {
                                if h.logger.Debug() {
-                                       h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
+                                       h.logger.Debugf("\tlost packet %d (reordering threshold)", pn)
                                }
                                if h.tracer != nil && h.tracer.LostPacket != nil {
-                                       h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold)
+                                       h.tracer.LostPacket(p.EncryptionLevel, pn, logging.PacketLossReorderingThreshold)
                                }
                        }
                } else if pnSpace.lossTime.IsZero() {
                        // Note: This conditional is only entered once per call
                        lossTime := p.SendTime.Add(lossDelay)
                        if h.logger.Debug() {
-                               h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime)
+                               h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", pn, encLevel, lossDelay, lossTime)
                        }
                        pnSpace.lossTime = lossTime
                }
                if packetLost {
-                       pnSpace.history.DeclareLost(p.PacketNumber)
+                       pnSpace.history.DeclareLost(pn)
                        if isRegularPacket {
                                // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
                                h.removeFromBytesInFlight(p)
                                h.queueFramesForRetransmission(p)
                                if !p.IsPathMTUProbePacket {
-                                       h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight)
+                                       h.congestion.OnCongestionEvent(pn, p.Length, priorInFlight)
                                }
                                if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
-                                       h.ecnTracker.LostPacket(p.PacketNumber)
+                                       h.ecnTracker.LostPacket(pn)
                                }
                        }
                }
@@ -913,7 +911,7 @@ func (h *sentPacketHandler) isAmplificationLimited() bool {
 
 func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool {
        pnSpace := h.getPacketNumberSpace(encLevel)
-       p := pnSpace.history.FirstOutstanding()
+       pn, p := pnSpace.history.FirstOutstanding()
        if p == nil {
                return false
        }
@@ -921,7 +919,7 @@ func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel)
        // TODO: don't declare the packet lost here.
        // Keep track of acknowledged frames instead.
        h.removeFromBytesInFlight(p)
-       pnSpace.history.DeclareLost(p.PacketNumber)
+       pnSpace.history.DeclareLost(pn)
        return true
 }
 
@@ -946,7 +944,7 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
 func (h *sentPacketHandler) ResetForRetry(now time.Time) {
        h.bytesInFlight = 0
        var firstPacketSendTime time.Time
-       for p := range h.initialPackets.history.Packets() {
+       for _, p := range h.initialPackets.history.Packets() {
                if firstPacketSendTime.IsZero() {
                        firstPacketSendTime = p.SendTime
                }
@@ -956,7 +954,7 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) {
        }
        // All application data packets sent at this point are 0-RTT packets.
        // In the case of a Retry, we can assume that the server dropped all of them.
-       for p := range h.appDataPackets.history.Packets() {
+       for _, p := range h.appDataPackets.history.Packets() {
                if !p.declaredLost && !p.skippedPacket {
                        h.queueFramesForRetransmission(p)
                }
@@ -991,15 +989,15 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) {
 
 func (h *sentPacketHandler) MigratedPath(now time.Time, initialMaxDatagramSize protocol.ByteCount) {
        h.rttStats.ResetForPathMigration()
-       for p := range h.appDataPackets.history.Packets() {
-               h.appDataPackets.history.DeclareLost(p.PacketNumber)
+       for pn, p := range h.appDataPackets.history.Packets() {
+               h.appDataPackets.history.DeclareLost(pn)
                if !p.skippedPacket && !p.isPathProbePacket {
                        h.removeFromBytesInFlight(p)
                        h.queueFramesForRetransmission(p)
                }
        }
-       for p := range h.appDataPackets.history.PathProbes() {
-               h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
+       for pn := range h.appDataPackets.history.PathProbes() {
+               h.appDataPackets.history.RemovePathProbe(pn)
        }
        h.congestion = congestion.NewCubicSender(
                congestion.DefaultClock{},
index 65309ffbce20595ce16b07798e7e50f63e561f5f..13dcdc0b850d35aa2814df916cb662ffd13212f4 100644 (file)
@@ -1058,10 +1058,10 @@ func TestSentPacketHandlerECN(t *testing.T) {
        // Receive an ACK with a short RTT, such that the first packet is lost.
        cong.EXPECT().OnCongestionEvent(gomock.Any(), gomock.Any(), gomock.Any())
        ecnHandler.EXPECT().LostPacket(pns[0])
-       ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(10), int64(11), int64(12)).DoAndReturn(func(packets []*packet, _, _, _ int64) bool {
+       ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), int64(10), int64(11), int64(12)).DoAndReturn(func(packets []packetWithPacketNumber, _, _, _ int64) bool {
                require.Len(t, packets, 2)
-               require.Equal(t, packets[0].PacketNumber, pns[2])
-               require.Equal(t, packets[1].PacketNumber, pns[3])
+               require.Equal(t, pns[2], packets[0].PacketNumber)
+               require.Equal(t, pns[3], packets[1].PacketNumber)
                return false
        })
        _, err := sph.ReceivedAck(
@@ -1089,7 +1089,7 @@ func TestSentPacketHandlerECN(t *testing.T) {
        pns[0] = sendPacket(now, protocol.ECT1)
        pns[1] = sendPacket(now, protocol.ECT1)
        ecnHandler.EXPECT().HandleNewlyAcked(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
-               func(packets []*packet, _, _, _ int64) bool {
+               func(packets []packetWithPacketNumber, _, _, _ int64) bool {
                        require.Len(t, packets, 1)
                        require.Equal(t, pns[1], packets[0].PacketNumber)
                        return false
index 0aabc6d93f31c435465e1ed67da7468a219e571c..5ea399478d75cfd8bed6b152e529e6012c90f256 100644 (file)
@@ -9,16 +9,18 @@ import (
 
 type sentPacketHistory struct {
        packets          []*packet
-       pathProbePackets []*packet
+       pathProbePackets []packetWithPacketNumber
 
        numOutstanding int
 
+       firstPacketNumber   protocol.PacketNumber
        highestPacketNumber protocol.PacketNumber
 }
 
 func newSentPacketHistory(isAppData bool) *sentPacketHistory {
        h := &sentPacketHistory{
                highestPacketNumber: protocol.InvalidPacketNumber,
+               firstPacketNumber:   protocol.InvalidPacketNumber,
        }
        if isAppData {
                h.packets = make([]*packet, 0, 32)
@@ -35,14 +37,14 @@ func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNum
                }
        }
        h.highestPacketNumber = pn
+       if len(h.packets) == 0 {
+               h.firstPacketNumber = pn
+       }
 }
 
 func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
        h.checkSequentialPacketNumberUse(pn)
-       h.packets = append(h.packets, &packet{
-               PacketNumber:  pn,
-               skippedPacket: true,
-       })
+       h.packets = append(h.packets, &packet{skippedPacket: true})
 }
 
 func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) {
@@ -52,40 +54,40 @@ func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber)
        }
 }
 
-func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) {
-       h.checkSequentialPacketNumberUse(p.PacketNumber)
+func (h *sentPacketHistory) SentAckElicitingPacket(pn protocol.PacketNumber, p *packet) {
+       h.checkSequentialPacketNumberUse(pn)
        h.packets = append(h.packets, p)
        if p.outstanding() {
                h.numOutstanding++
        }
 }
 
-func (h *sentPacketHistory) SentPathProbePacket(p *packet) {
-       h.checkSequentialPacketNumberUse(p.PacketNumber)
-       h.packets = append(h.packets, &packet{
-               PacketNumber:      p.PacketNumber,
-               isPathProbePacket: true,
-       })
-       h.pathProbePackets = append(h.pathProbePackets, p)
+func (h *sentPacketHistory) SentPathProbePacket(pn protocol.PacketNumber, p *packet) {
+       h.checkSequentialPacketNumberUse(pn)
+       h.packets = append(h.packets, &packet{isPathProbePacket: true})
+       h.pathProbePackets = append(h.pathProbePackets, packetWithPacketNumber{PacketNumber: pn, packet: p})
 }
 
-func (h *sentPacketHistory) Packets() iter.Seq[*packet] {
-       return func(yield func(*packet) bool) {
-               for _, p := range h.packets {
+func (h *sentPacketHistory) Packets() iter.Seq2[protocol.PacketNumber, *packet] {
+       return func(yield func(protocol.PacketNumber, *packet) bool) {
+               // h.firstPacketNumber might be updated in the yield function,
+               // so we need to save it here.
+               firstPacketNumber := h.firstPacketNumber
+               for i, p := range h.packets {
                        if p == nil {
                                continue
                        }
-                       if !yield(p) {
+                       if !yield(firstPacketNumber+protocol.PacketNumber(i), p) {
                                return
                        }
                }
        }
 }
 
-func (h *sentPacketHistory) PathProbes() iter.Seq[*packet] {
-       return func(yield func(*packet) bool) {
+func (h *sentPacketHistory) PathProbes() iter.Seq2[protocol.PacketNumber, *packet] {
+       return func(yield func(protocol.PacketNumber, *packet) bool) {
                for _, p := range h.pathProbePackets {
-                       if !yield(p) {
+                       if !yield(p.PacketNumber, p.packet) {
                                return
                        }
                }
@@ -93,24 +95,24 @@ func (h *sentPacketHistory) PathProbes() iter.Seq[*packet] {
 }
 
 // FirstOutstanding returns the first outstanding packet.
-func (h *sentPacketHistory) FirstOutstanding() *packet {
+func (h *sentPacketHistory) FirstOutstanding() (protocol.PacketNumber, *packet) {
        if !h.HasOutstandingPackets() {
-               return nil
+               return protocol.InvalidPacketNumber, nil
        }
-       for _, p := range h.packets {
+       for i, p := range h.packets {
                if p != nil && p.outstanding() {
-                       return p
+                       return h.firstPacketNumber + protocol.PacketNumber(i), p
                }
        }
-       return nil
+       return protocol.InvalidPacketNumber, nil
 }
 
 // FirstOutstandingPathProbe returns the first outstanding path probe packet
-func (h *sentPacketHistory) FirstOutstandingPathProbe() *packet {
+func (h *sentPacketHistory) FirstOutstandingPathProbe() (protocol.PacketNumber, *packet) {
        if len(h.pathProbePackets) == 0 {
-               return nil
+               return protocol.InvalidPacketNumber, nil
        }
-       return h.pathProbePackets[0]
+       return h.pathProbePackets[0].PacketNumber, h.pathProbePackets[0].packet
 }
 
 func (h *sentPacketHistory) Len() int {
@@ -156,7 +158,7 @@ func (h *sentPacketHistory) RemovePathProbe(pn protocol.PacketNumber) *packet {
        idx := -1
        for i, p := range h.pathProbePackets {
                if p.PacketNumber == pn {
-                       packetToDelete = p
+                       packetToDelete = p.packet
                        idx = i
                        break
                }
@@ -174,11 +176,10 @@ func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) {
        if len(h.packets) == 0 {
                return 0, false
        }
-       first := h.packets[0].PacketNumber
-       if p < first {
+       if p < h.firstPacketNumber {
                return 0, false
        }
-       index := int(p - first)
+       index := int(p - h.firstPacketNumber)
        if index > len(h.packets)-1 {
                return 0, false
        }
@@ -198,17 +199,19 @@ func (h *sentPacketHistory) cleanupStart() {
        for i, p := range h.packets {
                if p != nil {
                        h.packets = h.packets[i:]
+                       h.firstPacketNumber += protocol.PacketNumber(i)
                        return
                }
        }
        h.packets = h.packets[:0]
+       h.firstPacketNumber = protocol.InvalidPacketNumber
 }
 
 func (h *sentPacketHistory) LowestPacketNumber() protocol.PacketNumber {
        if len(h.packets) == 0 {
                return protocol.InvalidPacketNumber
        }
-       return h.packets[0].PacketNumber
+       return h.firstPacketNumber
 }
 
 func (h *sentPacketHistory) DeclareLost(pn protocol.PacketNumber) {
index a74155d7e1b5543369744e8dfbc174fb2b94cd8c..ea5157de7ac4394580765cee280d0a5be96d9800 100644 (file)
@@ -11,9 +11,9 @@ import (
 
 func (h *sentPacketHistory) getPacketNumbers() []protocol.PacketNumber {
        pns := make([]protocol.PacketNumber, 0, len(h.packets))
-       for _, p := range h.packets {
+       for pn, p := range h.Packets() {
                if p != nil && !p.skippedPacket {
-                       pns = append(pns, p.PacketNumber)
+                       pns = append(pns, pn)
                }
        }
        return pns
@@ -21,63 +21,87 @@ func (h *sentPacketHistory) getPacketNumbers() []protocol.PacketNumber {
 
 func (h *sentPacketHistory) getSkippedPacketNumbers() []protocol.PacketNumber {
        var pns []protocol.PacketNumber
-       for _, p := range h.packets {
+       for pn, p := range h.Packets() {
                if p != nil && p.skippedPacket {
-                       pns = append(pns, p.PacketNumber)
+                       pns = append(pns, pn)
                }
        }
        return pns
 }
 
 func TestSentPacketHistoryPacketTracking(t *testing.T) {
+       t.Run("first packet ack-eliciting", func(t *testing.T) {
+               testSentPacketHistoryPacketTracking(t, true)
+       })
+       t.Run("first packet non-ack-eliciting", func(t *testing.T) {
+               testSentPacketHistoryPacketTracking(t, false)
+       })
+}
+
+func testSentPacketHistoryPacketTracking(t *testing.T, firstPacketAckEliciting bool) {
        hist := newSentPacketHistory(true)
        now := time.Now()
 
+       var firstPacketNumber []protocol.PacketNumber
        require.False(t, hist.HasOutstandingPackets())
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 0})
-       require.True(t, hist.HasOutstandingPackets())
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 1})
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 2})
-       require.Equal(t, []protocol.PacketNumber{0, 1, 2}, hist.getPacketNumbers())
+       if firstPacketAckEliciting {
+               hist.SentAckElicitingPacket(0, &packet{})
+               require.True(t, hist.HasOutstandingPackets())
+               firstPacketNumber = append(firstPacketNumber, 0)
+       } else {
+               hist.SentNonAckElicitingPacket(0)
+               require.False(t, hist.HasOutstandingPackets())
+       }
+       hist.SentAckElicitingPacket(1, &packet{})
+       hist.SentAckElicitingPacket(2, &packet{})
+       require.Equal(t, append(firstPacketNumber, 1, 2), hist.getPacketNumbers())
        require.Empty(t, hist.getSkippedPacketNumbers())
-       require.Equal(t, 3, hist.Len())
+       if firstPacketAckEliciting {
+               require.Equal(t, 3, hist.Len())
+       } else {
+               require.Equal(t, 2, hist.Len())
+       }
 
        // non-ack-eliciting packets are not saved
        hist.SentNonAckElicitingPacket(3)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 4, SendTime: now})
+       hist.SentAckElicitingPacket(4, &packet{SendTime: now})
        hist.SentNonAckElicitingPacket(5)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 6, SendTime: now})
-       require.Equal(t, []protocol.PacketNumber{0, 1, 2, 4, 6}, hist.getPacketNumbers())
+       hist.SentAckElicitingPacket(6, &packet{SendTime: now})
+       require.Equal(t, append(firstPacketNumber, 1, 2, 4, 6), hist.getPacketNumbers())
 
        // handle skipped packet numbers
        hist.SkippedPacket(7)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 8})
+       hist.SentAckElicitingPacket(8, &packet{SendTime: now})
        hist.SentNonAckElicitingPacket(9)
        hist.SkippedPacket(10)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 11})
-       require.Equal(t, []protocol.PacketNumber{0, 1, 2, 4, 6, 8, 11}, hist.getPacketNumbers())
+       hist.SentAckElicitingPacket(11, &packet{SendTime: now})
+       require.Equal(t, append(firstPacketNumber, 1, 2, 4, 6, 8, 11), hist.getPacketNumbers())
        require.Equal(t, []protocol.PacketNumber{7, 10}, hist.getSkippedPacketNumbers())
-       require.Equal(t, 12, hist.Len())
+       if firstPacketAckEliciting {
+               require.Equal(t, 12, hist.Len())
+       } else {
+               require.Equal(t, 11, hist.Len())
+       }
 }
 
 func TestSentPacketHistoryNonSequentialPacketNumberUse(t *testing.T) {
        hist := newSentPacketHistory(true)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 100})
+       hist.SentAckElicitingPacket(100, &packet{})
        require.Panics(t, func() {
-               hist.SentAckElicitingPacket(&packet{PacketNumber: 102})
+               hist.SentAckElicitingPacket(102, &packet{})
        })
 }
 
 func TestSentPacketHistoryRemovePackets(t *testing.T) {
        hist := newSentPacketHistory(true)
 
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 0})
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 1})
+       hist.SentAckElicitingPacket(0, &packet{})
+       hist.SentAckElicitingPacket(1, &packet{})
        hist.SkippedPacket(2)
        hist.SkippedPacket(3)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 4})
+       hist.SentAckElicitingPacket(4, &packet{})
        hist.SkippedPacket(5)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 6})
+       hist.SentAckElicitingPacket(6, &packet{})
        require.Equal(t, []protocol.PacketNumber{0, 1, 4, 6}, hist.getPacketNumbers())
        require.Equal(t, []protocol.PacketNumber{2, 3, 5}, hist.getSkippedPacketNumbers())
 
@@ -87,12 +111,12 @@ func TestSentPacketHistoryRemovePackets(t *testing.T) {
        require.Equal(t, []protocol.PacketNumber{2, 3, 5}, hist.getSkippedPacketNumbers())
 
        // add one more packet
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 7})
+       hist.SentAckElicitingPacket(7, &packet{})
        require.Equal(t, []protocol.PacketNumber{4, 6, 7}, hist.getPacketNumbers())
 
        // remove last packet and add another
        require.NoError(t, hist.Remove(7))
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 8})
+       hist.SentAckElicitingPacket(8, &packet{})
        require.Equal(t, []protocol.PacketNumber{4, 6, 8}, hist.getPacketNumbers())
 
        // try to remove non-existent packet
@@ -112,48 +136,50 @@ func TestSentPacketHistoryRemovePackets(t *testing.T) {
 func TestSentPacketHistoryFirstOutstandingPacket(t *testing.T) {
        hist := newSentPacketHistory(true)
 
-       require.Nil(t, hist.FirstOutstanding())
+       pn, p := hist.FirstOutstanding()
+       require.Equal(t, protocol.InvalidPacketNumber, pn)
+       require.Nil(t, p)
 
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 2})
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 3})
-       front := hist.FirstOutstanding()
-       require.NotNil(t, front)
-       require.Equal(t, protocol.PacketNumber(2), front.PacketNumber)
+       hist.SentAckElicitingPacket(2, &packet{})
+       hist.SentAckElicitingPacket(3, &packet{})
+       pn, p = hist.FirstOutstanding()
+       require.Equal(t, protocol.PacketNumber(2), pn)
+       require.NotNil(t, p)
 
        // remove the first packet
        hist.Remove(2)
-       front = hist.FirstOutstanding()
-       require.NotNil(t, front)
-       require.Equal(t, protocol.PacketNumber(3), front.PacketNumber)
+       pn, p = hist.FirstOutstanding()
+       require.Equal(t, protocol.PacketNumber(3), pn)
+       require.NotNil(t, p)
 
        // Path MTU packets are not regarded as outstanding
        hist = newSentPacketHistory(true)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 2})
+       hist.SentAckElicitingPacket(2, &packet{})
        hist.SkippedPacket(3)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 4, IsPathMTUProbePacket: true})
-       front = hist.FirstOutstanding()
-       require.NotNil(t, front)
-       require.Equal(t, protocol.PacketNumber(2), front.PacketNumber)
+       hist.SentAckElicitingPacket(4, &packet{IsPathMTUProbePacket: true})
+       pn, p = hist.FirstOutstanding()
+       require.NotNil(t, p)
+       require.Equal(t, protocol.PacketNumber(2), pn)
 }
 
 func TestSentPacketHistoryIterating(t *testing.T) {
        hist := newSentPacketHistory(true)
        hist.SkippedPacket(0)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 1})
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 2})
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 3})
+       hist.SentAckElicitingPacket(1, &packet{})
+       hist.SentAckElicitingPacket(2, &packet{})
+       hist.SentAckElicitingPacket(3, &packet{})
        hist.SkippedPacket(4)
        hist.SkippedPacket(5)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 6})
+       hist.SentAckElicitingPacket(6, &packet{})
        require.NoError(t, hist.Remove(3))
        require.NoError(t, hist.Remove(4))
 
        var packets, skippedPackets []protocol.PacketNumber
-       for p := range hist.Packets() {
+       for pn, p := range hist.Packets() {
                if p.skippedPacket {
-                       skippedPackets = append(skippedPackets, p.PacketNumber)
+                       skippedPackets = append(skippedPackets, pn)
                } else {
-                       packets = append(packets, p.PacketNumber)
+                       packets = append(packets, pn)
                }
        }
 
@@ -163,17 +189,17 @@ func TestSentPacketHistoryIterating(t *testing.T) {
 
 func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
        hist := newSentPacketHistory(true)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 0})
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 1})
+       hist.SentAckElicitingPacket(0, &packet{})
+       hist.SentAckElicitingPacket(1, &packet{})
        hist.SkippedPacket(2)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 3})
+       hist.SentAckElicitingPacket(3, &packet{})
        hist.SkippedPacket(4)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 5})
+       hist.SentAckElicitingPacket(5, &packet{})
 
        var iterations []protocol.PacketNumber
-       for p := range hist.Packets() {
-               iterations = append(iterations, p.PacketNumber)
-               switch p.PacketNumber {
+       for pn := range hist.Packets() {
+               iterations = append(iterations, pn)
+               switch pn {
                case 0:
                        require.NoError(t, hist.Remove(0))
                case 4:
@@ -188,19 +214,19 @@ func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
 
 func TestSentPacketHistoryPathProbes(t *testing.T) {
        hist := newSentPacketHistory(true)
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 0})
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 1})
-       hist.SentPathProbePacket(&packet{PacketNumber: 2})
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 3})
-       hist.SentAckElicitingPacket(&packet{PacketNumber: 4})
-       hist.SentPathProbePacket(&packet{PacketNumber: 5})
+       hist.SentAckElicitingPacket(0, &packet{})
+       hist.SentAckElicitingPacket(1, &packet{})
+       hist.SentPathProbePacket(2, &packet{})
+       hist.SentAckElicitingPacket(3, &packet{})
+       hist.SentAckElicitingPacket(4, &packet{})
+       hist.SentPathProbePacket(5, &packet{})
 
        getPacketsInHistory := func(t *testing.T) []protocol.PacketNumber {
                t.Helper()
                var pns []protocol.PacketNumber
-               for p := range hist.Packets() {
-                       pns = append(pns, p.PacketNumber)
-                       switch p.PacketNumber {
+               for pn, p := range hist.Packets() {
+                       pns = append(pns, pn)
+                       switch pn {
                        case 2, 5:
                                require.True(t, p.isPathProbePacket)
                        default:
@@ -213,8 +239,8 @@ func TestSentPacketHistoryPathProbes(t *testing.T) {
        getPacketsInPathProbeHistory := func(t *testing.T) []protocol.PacketNumber {
                t.Helper()
                var pns []protocol.PacketNumber
-               for p := range hist.PathProbes() {
-                       pns = append(pns, p.PacketNumber)
+               for pn := range hist.PathProbes() {
+                       pns = append(pns, pn)
                }
                return pns
        }
@@ -233,45 +259,53 @@ func TestSentPacketHistoryPathProbes(t *testing.T) {
        require.Equal(t, []protocol.PacketNumber{2, 5}, getPacketsInPathProbeHistory(t))
        require.True(t, hist.HasOutstandingPackets())
        require.True(t, hist.HasOutstandingPathProbes())
-       firstOutstanding := hist.FirstOutstanding()
-       require.NotNil(t, firstOutstanding)
-       require.Equal(t, protocol.PacketNumber(4), firstOutstanding.PacketNumber)
-       firstOutStandingPathProbe := hist.FirstOutstandingPathProbe()
-       require.NotNil(t, firstOutStandingPathProbe)
-       require.Equal(t, protocol.PacketNumber(2), firstOutStandingPathProbe.PacketNumber)
+       pn, p := hist.FirstOutstanding()
+       require.Equal(t, protocol.PacketNumber(4), pn)
+       require.NotNil(t, p)
+       pn, p = hist.FirstOutstandingPathProbe()
+       require.NotNil(t, p)
+       require.Equal(t, protocol.PacketNumber(2), pn)
 
        hist.RemovePathProbe(2)
        require.Equal(t, []protocol.PacketNumber{4, 5}, getPacketsInHistory(t))
        require.Equal(t, []protocol.PacketNumber{5}, getPacketsInPathProbeHistory(t))
        require.True(t, hist.HasOutstandingPathProbes())
-       firstOutStandingPathProbe = hist.FirstOutstandingPathProbe()
-       require.NotNil(t, firstOutStandingPathProbe)
-       require.Equal(t, protocol.PacketNumber(5), firstOutStandingPathProbe.PacketNumber)
+       pn, p = hist.FirstOutstandingPathProbe()
+       require.NotNil(t, p)
+       require.Equal(t, protocol.PacketNumber(5), pn)
 
        hist.RemovePathProbe(5)
        require.Equal(t, []protocol.PacketNumber{4, 5}, getPacketsInHistory(t))
        require.Empty(t, getPacketsInPathProbeHistory(t))
        require.True(t, hist.HasOutstandingPackets())
        require.False(t, hist.HasOutstandingPathProbes())
-       require.Nil(t, hist.FirstOutstandingPathProbe())
+       pn, p = hist.FirstOutstandingPathProbe()
+       require.Equal(t, protocol.InvalidPacketNumber, pn)
+       require.Nil(t, p)
 
        require.NoError(t, hist.Remove(4))
        require.NoError(t, hist.Remove(5))
        require.Empty(t, getPacketsInHistory(t))
        require.False(t, hist.HasOutstandingPackets())
-       require.Nil(t, hist.FirstOutstanding())
+       pn, p = hist.FirstOutstanding()
+       require.Equal(t, protocol.InvalidPacketNumber, pn)
+       require.Nil(t, p)
 
        // path probe packets are considered outstanding
-       hist.SentPathProbePacket(&packet{PacketNumber: 6})
+       hist.SentPathProbePacket(6, &packet{})
        require.False(t, hist.HasOutstandingPackets())
        require.True(t, hist.HasOutstandingPathProbes())
-       firstOutStandingPathProbe = hist.FirstOutstandingPathProbe()
-       require.NotNil(t, firstOutStandingPathProbe)
-       require.Equal(t, protocol.PacketNumber(6), firstOutStandingPathProbe.PacketNumber)
+       pn, p = hist.FirstOutstandingPathProbe()
+       require.NotNil(t, p)
+       require.Equal(t, protocol.PacketNumber(6), pn)
 
        hist.RemovePathProbe(6)
        require.False(t, hist.HasOutstandingPackets())
-       require.Nil(t, hist.FirstOutstanding())
+       pn, p = hist.FirstOutstanding()
+       require.Equal(t, protocol.InvalidPacketNumber, pn)
+       require.Nil(t, p)
        require.False(t, hist.HasOutstandingPathProbes())
-       require.Nil(t, hist.FirstOutstandingPathProbe())
+       pn, p = hist.FirstOutstandingPathProbe()
+       require.Equal(t, protocol.InvalidPacketNumber, pn)
+       require.Nil(t, p)
 }