]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
ackhandler: store skipped packet numbers separately (#5314)
authorMarten Seemann <martenseemann@gmail.com>
Fri, 29 Aug 2025 03:09:46 +0000 (11:09 +0800)
committerGitHub <noreply@github.com>
Fri, 29 Aug 2025 03:09:46 +0000 (05:09 +0200)
internal/ackhandler/packet.go
internal/ackhandler/sent_packet_handler.go
internal/ackhandler/sent_packet_history.go
internal/ackhandler/sent_packet_history_test.go

index e2f4e7d2e8b47624d818fb16349eddcca0c24826..1677d4fbc94b22688fe8441f6126d18b8d311afa 100644 (file)
@@ -25,12 +25,11 @@ type packet struct {
 
        includedInBytesInFlight bool
        declaredLost            bool
-       skippedPacket           bool
        isPathProbePacket       bool
 }
 
 func (p *packet) outstanding() bool {
-       return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket && !p.isPathProbePacket
+       return !p.declaredLost && !p.IsPathMTUProbePacket && !p.isPathProbePacket
 }
 
 var packetPool = sync.Pool{New: func() any { return &packet{} }}
@@ -46,7 +45,6 @@ func getPacket() *packet {
        p.IsPathMTUProbePacket = false
        p.includedInBytesInFlight = false
        p.declaredLost = false
-       p.skippedPacket = false
        p.isPathProbePacket = false
        return p
 }
index 622814a7ebe2721a0a7e9008e7a8462a857dd752..5e4623a87c94cdddb9eaff5fc5b3cbf23bba9102 100644 (file)
@@ -201,7 +201,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t
                // When 0-RTT is rejected, all application data sent so far becomes invalid.
                // Delete the packets from the history and remove them from bytes_in_flight.
                for pn, p := range h.appDataPackets.history.Packets() {
-                       if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
+                       if p.EncryptionLevel != protocol.Encryption0RTT {
                                break
                        }
                        h.removeFromBytesInFlight(p)
@@ -431,11 +431,24 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
 
 // Packets are returned in ascending packet number order.
 func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]packetWithPacketNumber, error) {
-       pnSpace := h.getPacketNumberSpace(encLevel)
-       ackRangeIndex := 0
        if len(h.ackedPackets) > 0 {
                return nil, errors.New("ackhandler BUG: ackedPackets slice not empty")
        }
+
+       pnSpace := h.getPacketNumberSpace(encLevel)
+
+       if encLevel == protocol.Encryption1RTT {
+               for p := range pnSpace.history.SkippedPackets() {
+                       if ack.AcksPacket(p) {
+                               return nil, &qerr.TransportError{
+                                       ErrorCode:    qerr.ProtocolViolation,
+                                       ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p, encLevel),
+                               }
+                       }
+               }
+       }
+
+       var ackRangeIndex int
        lowestAcked := ack.LowestAcked()
        largestAcked := ack.LargestAcked()
        for pn, p := range pnSpace.history.Packets() {
@@ -462,12 +475,6 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
                                return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", pn, ackRange.Smallest, ackRange.Largest)
                        }
                }
-               if p.skippedPacket {
-                       return nil, &qerr.TransportError{
-                               ErrorCode:    qerr.ProtocolViolation,
-                               ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", pn, encLevel),
-                       }
-               }
                if p.isPathProbePacket {
                        probePacket := pnSpace.history.RemovePathProbe(pn)
                        // the probe packet might already have been declared lost
@@ -692,11 +699,10 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
                        break
                }
 
-               isRegularPacket := !p.skippedPacket && !p.isPathProbePacket
                var packetLost bool
                if !p.SendTime.After(lostSendTime) {
                        packetLost = true
-                       if isRegularPacket {
+                       if !p.isPathProbePacket {
                                if h.logger.Debug() {
                                        h.logger.Debugf("\tlost packet %d (time threshold)", pn)
                                }
@@ -706,7 +712,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
                        }
                } else if pnSpace.largestAcked >= pn+packetThreshold {
                        packetLost = true
-                       if isRegularPacket {
+                       if !p.isPathProbePacket {
                                if h.logger.Debug() {
                                        h.logger.Debugf("\tlost packet %d (reordering threshold)", pn)
                                }
@@ -724,7 +730,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
                }
                if packetLost {
                        pnSpace.history.DeclareLost(pn)
-                       if isRegularPacket {
+                       if !p.isPathProbePacket {
                                // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
                                h.removeFromBytesInFlight(p)
                                h.queueFramesForRetransmission(p)
@@ -948,14 +954,14 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) {
                if firstPacketSendTime.IsZero() {
                        firstPacketSendTime = p.SendTime
                }
-               if !p.declaredLost && !p.skippedPacket {
+               if !p.declaredLost {
                        h.queueFramesForRetransmission(p)
                }
        }
        // All application data packets sent at this point are 0-RTT packets.
        // In the case of a Retry, we can assume that the server dropped all of them.
        for _, p := range h.appDataPackets.history.Packets() {
-               if !p.declaredLost && !p.skippedPacket {
+               if !p.declaredLost {
                        h.queueFramesForRetransmission(p)
                }
        }
@@ -991,7 +997,7 @@ func (h *sentPacketHandler) MigratedPath(now time.Time, initialMaxDatagramSize p
        h.rttStats.ResetForPathMigration()
        for pn, p := range h.appDataPackets.history.Packets() {
                h.appDataPackets.history.DeclareLost(pn)
-               if !p.skippedPacket && !p.isPathProbePacket {
+               if !p.isPathProbePacket {
                        h.removeFromBytesInFlight(p)
                        h.queueFramesForRetransmission(p)
                }
index 5ea399478d75cfd8bed6b152e529e6012c90f256..c92c642c99bdb0950b29ab34f33f2b5c85444e49 100644 (file)
@@ -10,6 +10,7 @@ import (
 type sentPacketHistory struct {
        packets          []*packet
        pathProbePackets []packetWithPacketNumber
+       skippedPackets   []protocol.PacketNumber
 
        numOutstanding int
 
@@ -44,7 +45,10 @@ func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNum
 
 func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
        h.checkSequentialPacketNumberUse(pn)
-       h.packets = append(h.packets, &packet{skippedPacket: true})
+       if len(h.packets) > 0 {
+               h.packets = append(h.packets, nil)
+       }
+       h.skippedPackets = append(h.skippedPackets, pn)
 }
 
 func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) {
@@ -115,10 +119,22 @@ func (h *sentPacketHistory) FirstOutstandingPathProbe() (protocol.PacketNumber,
        return h.pathProbePackets[0].PacketNumber, h.pathProbePackets[0].packet
 }
 
+func (h *sentPacketHistory) SkippedPackets() iter.Seq[protocol.PacketNumber] {
+       return func(yield func(protocol.PacketNumber) bool) {
+               for _, p := range h.skippedPackets {
+                       if !yield(p) {
+                               return
+                       }
+               }
+       }
+}
+
 func (h *sentPacketHistory) Len() int {
        return len(h.packets)
 }
 
+// Remove removes a packet from the sent packet history.
+// It must not be used for skipped packet numbers.
 func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
        idx, ok := h.getIndex(pn)
        if !ok {
@@ -133,19 +149,26 @@ func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
        }
        h.packets[idx] = nil
        // clean up all skipped packets directly before this packet number
+       var hasPacketBefore bool
        for idx > 0 {
                idx--
-               p := h.packets[idx]
-               if p == nil || !p.skippedPacket {
+               if h.packets[idx] != nil {
+                       hasPacketBefore = true
                        break
                }
-               h.packets[idx] = nil
        }
-       if idx == 0 {
+       if !hasPacketBefore {
                h.cleanupStart()
        }
        if len(h.packets) > 0 && h.packets[0] == nil {
-               panic("remove failed")
+               panic("cleanup failed")
+       }
+       if len(h.packets) > 0 && len(h.skippedPackets) > 0 {
+               for _, p := range h.skippedPackets {
+                       if p < h.firstPacketNumber {
+                               h.skippedPackets = h.skippedPackets[1:]
+                       }
+               }
        }
        return nil
 }
index ea5157de7ac4394580765cee280d0a5be96d9800..8c08923adc7ac677c81bdf226fd03674454741a1 100644 (file)
@@ -1,6 +1,7 @@
 package ackhandler
 
 import (
+       "slices"
        "testing"
        "time"
 
@@ -11,20 +12,8 @@ import (
 
 func (h *sentPacketHistory) getPacketNumbers() []protocol.PacketNumber {
        pns := make([]protocol.PacketNumber, 0, len(h.packets))
-       for pn, p := range h.Packets() {
-               if p != nil && !p.skippedPacket {
-                       pns = append(pns, pn)
-               }
-       }
-       return pns
-}
-
-func (h *sentPacketHistory) getSkippedPacketNumbers() []protocol.PacketNumber {
-       var pns []protocol.PacketNumber
-       for pn, p := range h.Packets() {
-               if p != nil && p.skippedPacket {
-                       pns = append(pns, pn)
-               }
+       for pn := range h.Packets() {
+               pns = append(pns, pn)
        }
        return pns
 }
@@ -55,7 +44,7 @@ func testSentPacketHistoryPacketTracking(t *testing.T, firstPacketAckEliciting b
        hist.SentAckElicitingPacket(1, &packet{})
        hist.SentAckElicitingPacket(2, &packet{})
        require.Equal(t, append(firstPacketNumber, 1, 2), hist.getPacketNumbers())
-       require.Empty(t, hist.getSkippedPacketNumbers())
+       require.Empty(t, slices.Collect(hist.SkippedPackets()))
        if firstPacketAckEliciting {
                require.Equal(t, 3, hist.Len())
        } else {
@@ -76,7 +65,7 @@ func testSentPacketHistoryPacketTracking(t *testing.T, firstPacketAckEliciting b
        hist.SkippedPacket(10)
        hist.SentAckElicitingPacket(11, &packet{SendTime: now})
        require.Equal(t, append(firstPacketNumber, 1, 2, 4, 6, 8, 11), hist.getPacketNumbers())
-       require.Equal(t, []protocol.PacketNumber{7, 10}, hist.getSkippedPacketNumbers())
+       require.Equal(t, []protocol.PacketNumber{7, 10}, slices.Collect(hist.SkippedPackets()))
        if firstPacketAckEliciting {
                require.Equal(t, 12, hist.Len())
        } else {
@@ -103,12 +92,13 @@ func TestSentPacketHistoryRemovePackets(t *testing.T) {
        hist.SkippedPacket(5)
        hist.SentAckElicitingPacket(6, &packet{})
        require.Equal(t, []protocol.PacketNumber{0, 1, 4, 6}, hist.getPacketNumbers())
-       require.Equal(t, []protocol.PacketNumber{2, 3, 5}, hist.getSkippedPacketNumbers())
+       require.Equal(t, []protocol.PacketNumber{2, 3, 5}, slices.Collect(hist.SkippedPackets()))
 
        require.NoError(t, hist.Remove(0))
+       require.Equal(t, []protocol.PacketNumber{2, 3, 5}, slices.Collect(hist.SkippedPackets()))
        require.NoError(t, hist.Remove(1))
        require.Equal(t, []protocol.PacketNumber{4, 6}, hist.getPacketNumbers())
-       require.Equal(t, []protocol.PacketNumber{2, 3, 5}, hist.getSkippedPacketNumbers())
+       require.Equal(t, []protocol.PacketNumber{5}, slices.Collect(hist.SkippedPackets()))
 
        // add one more packet
        hist.SentAckElicitingPacket(7, &packet{})
@@ -129,7 +119,7 @@ func TestSentPacketHistoryRemovePackets(t *testing.T) {
        require.NoError(t, hist.Remove(6))
        require.NoError(t, hist.Remove(8))
        require.Empty(t, hist.getPacketNumbers())
-       require.Empty(t, hist.getSkippedPacketNumbers())
+       require.Empty(t, slices.Collect(hist.SkippedPackets()))
        require.False(t, hist.HasOutstandingPackets())
 }
 
@@ -171,20 +161,17 @@ func TestSentPacketHistoryIterating(t *testing.T) {
        hist.SkippedPacket(4)
        hist.SkippedPacket(5)
        hist.SentAckElicitingPacket(6, &packet{})
+       require.Equal(t, []protocol.PacketNumber{0, 4, 5}, slices.Collect(hist.SkippedPackets()))
        require.NoError(t, hist.Remove(3))
-       require.NoError(t, hist.Remove(4))
 
-       var packets, skippedPackets []protocol.PacketNumber
+       var packets []protocol.PacketNumber
        for pn, p := range hist.Packets() {
-               if p.skippedPacket {
-                       skippedPackets = append(skippedPackets, pn)
-               } else {
-                       packets = append(packets, pn)
-               }
+               require.NotNil(t, p)
+               packets = append(packets, pn)
        }
 
        require.Equal(t, []protocol.PacketNumber{1, 2, 6}, packets)
-       require.Equal(t, []protocol.PacketNumber{0, 5}, skippedPackets)
+       require.Equal(t, []protocol.PacketNumber{4, 5}, slices.Collect(hist.SkippedPackets()))
 }
 
 func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
@@ -202,14 +189,14 @@ func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
                switch pn {
                case 0:
                        require.NoError(t, hist.Remove(0))
-               case 4:
-                       require.NoError(t, hist.Remove(4))
+               case 3:
+                       require.NoError(t, hist.Remove(3))
                }
        }
 
-       require.Equal(t, []protocol.PacketNumber{0, 1, 2, 3, 4, 5}, iterations)
-       require.Equal(t, []protocol.PacketNumber{1, 3, 5}, hist.getPacketNumbers())
-       require.Equal(t, []protocol.PacketNumber{2}, hist.getSkippedPacketNumbers())
+       require.Equal(t, []protocol.PacketNumber{0, 1, 3, 5}, iterations)
+       require.Equal(t, []protocol.PacketNumber{1, 5}, hist.getPacketNumbers())
+       require.Equal(t, []protocol.PacketNumber{2, 4}, slices.Collect(hist.SkippedPackets()))
 }
 
 func TestSentPacketHistoryPathProbes(t *testing.T) {