]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
wire: optimize parsing logic for STREAM, DATAGRAM and ACK frames (#5227)
authorJannis Seemann <5215310+jannis-seemann@users.noreply.github.com>
Sun, 20 Jul 2025 11:14:38 +0000 (14:14 +0300)
committerGitHub <noreply@github.com>
Sun, 20 Jul 2025 11:14:38 +0000 (13:14 +0200)
ParseOtherFrames-16       148ns ± 4%     150ns ± 3%     ~     (p=0.223 n=8+8)
ParseAckFrame-16          302ns ± 2%     298ns ± 3%     ~     (p=0.246 n=8+8)
ParseStreamFrame-16       262ns ± 3%     213ns ± 2%  -18.61%  (p=0.000 n=8+8)
ParseDatagramFrame-16     561ns ± 5%     547ns ± 4%     ~     (p=0.105 n=8+8)

49 files changed:
connection.go
connection_test.go
fuzzing/frames/fuzz.go
internal/ackhandler/ack_eliciting.go
internal/ackhandler/ack_eliciting_test.go
internal/wire/ack_frame.go
internal/wire/ack_frame_test.go
internal/wire/connection_close_frame.go
internal/wire/connection_close_frame_test.go
internal/wire/crypto_frame.go
internal/wire/crypto_frame_test.go
internal/wire/data_blocked_frame.go
internal/wire/data_blocked_frame_test.go
internal/wire/datagram_frame.go
internal/wire/frame.go
internal/wire/frame_parser.go
internal/wire/frame_parser_test.go
internal/wire/frame_test.go
internal/wire/frame_type.go [new file with mode: 0644]
internal/wire/frame_type_test.go [new file with mode: 0644]
internal/wire/handshake_done_frame.go
internal/wire/handshake_done_frame_test.go
internal/wire/max_data_frame.go
internal/wire/max_data_frame_test.go
internal/wire/max_stream_data_frame.go
internal/wire/max_stream_data_frame_test.go
internal/wire/max_streams_frame.go
internal/wire/max_streams_frame_test.go
internal/wire/new_connection_id_frame.go
internal/wire/new_connection_id_frame_test.go
internal/wire/new_token_frame.go
internal/wire/new_token_frame_test.go
internal/wire/path_challenge_frame.go
internal/wire/path_challenge_frame_test.go
internal/wire/path_response_frame.go
internal/wire/path_response_frame_test.go
internal/wire/ping_frame.go
internal/wire/reset_stream_frame.go
internal/wire/reset_stream_frame_test.go
internal/wire/retire_connection_id_frame.go
internal/wire/retire_connection_id_frame_test.go
internal/wire/stop_sending_frame.go
internal/wire/stop_sending_frame_test.go
internal/wire/stream_data_blocked_frame_test.go
internal/wire/stream_frame.go
internal/wire/stream_frame_test.go
internal/wire/streams_blocked_frame.go
internal/wire/streams_blocked_frame_test.go
packet_packer_test.go

index ce565aa014f96eedf971f9528fb9b0a07271c874..2251af08d5c412ce99b21ee0ccf7378bd0495060 100644 (file)
@@ -1437,39 +1437,100 @@ func (c *Conn) handleFrames(
        }
        handshakeWasComplete := c.handshakeComplete
        var handleErr error
+       var skipHandling bool
+
        for len(data) > 0 {
-               l, frame, err := c.frameParser.ParseNext(data, encLevel, c.version)
+               frameType, l, err := c.frameParser.ParseType(data, encLevel)
                if err != nil {
+                       // The frame parser skips over PADDING frames, and returns an io.EOF if the PADDING
+                       // frames were the last frames in this packet.
+                       if err == io.EOF {
+                               break
+                       }
                        return false, false, nil, err
                }
                data = data[l:]
-               if frame == nil {
-                       break
-               }
-               if ackhandler.IsFrameAckEliciting(frame) {
+
+               if ackhandler.IsFrameTypeAckEliciting(frameType) {
                        isAckEliciting = true
                }
-               if !wire.IsProbingFrame(frame) {
+               if !wire.IsProbingFrameType(frameType) {
                        isNonProbing = true
                }
-               if log != nil {
-                       frames = append(frames, toLoggingFrame(frame))
-               }
-               // An error occurred handling a previous frame.
-               // Don't handle the current frame.
-               if handleErr != nil {
-                       continue
-               }
-               pc, err := c.handleFrame(frame, encLevel, destConnID, rcvTime)
-               if err != nil {
-                       if log == nil {
+
+               // We're inlining common cases, to avoid using interfaces
+               // Fast path: STREAM, DATAGRAM and ACK
+               if frameType.IsStreamFrameType() {
+                       streamFrame, l, err := c.frameParser.ParseStreamFrame(frameType, data, c.version)
+                       if err != nil {
+                               return false, false, nil, err
+                       }
+                       data = data[l:]
+
+                       if log != nil {
+                               frames = append(frames, toLoggingFrame(streamFrame))
+                       }
+                       // an error occurred handling a previous frame, don't handle the current frame
+                       if skipHandling {
+                               continue
+                       }
+                       handleErr = c.streamsMap.HandleStreamFrame(streamFrame, rcvTime)
+               } else if frameType.IsAckFrameType() {
+                       ackFrame, l, err := c.frameParser.ParseAckFrame(frameType, data, encLevel, c.version)
+                       if err != nil {
+                               return false, false, nil, err
+                       }
+                       data = data[l:]
+                       if log != nil {
+                               frames = append(frames, toLoggingFrame(ackFrame))
+                       }
+                       // an error occurred handling a previous frame, don't handle the current frame
+                       if skipHandling {
+                               continue
+                       }
+                       handleErr = c.handleAckFrame(ackFrame, encLevel, rcvTime)
+               } else if frameType.IsDatagramFrameType() {
+                       datagramFrame, l, err := c.frameParser.ParseDatagramFrame(frameType, data, c.version)
+                       if err != nil {
+                               return false, false, nil, err
+                       }
+                       data = data[l:]
+
+                       if log != nil {
+                               frames = append(frames, toLoggingFrame(datagramFrame))
+                       }
+                       // an error occurred handling a previous frame, don't handle the current frame
+                       if skipHandling {
+                               continue
+                       }
+                       handleErr = c.handleDatagramFrame(datagramFrame)
+               } else {
+                       frame, l, err := c.frameParser.ParseLessCommonFrame(frameType, data, c.version)
+                       if err != nil {
                                return false, false, nil, err
                        }
-                       // If we're logging, we need to keep parsing (but not handling) all frames.
+                       data = data[l:]
+
+                       if log != nil {
+                               frames = append(frames, toLoggingFrame(frame))
+                       }
+                       // an error occurred handling a previous frame, don't handle the current frame
+                       if skipHandling {
+                               continue
+                       }
+                       pc, err := c.handleFrame(frame, encLevel, destConnID, rcvTime)
+                       if pc != nil {
+                               pathChallenge = pc
+                       }
                        handleErr = err
                }
-               if pc != nil {
-                       pathChallenge = pc
+
+               if handleErr != nil {
+                       // if we're logging, we need to keep parsing (but not handling) all frames
+                       skipHandling = true
+                       if log == nil {
+                               return false, false, nil, handleErr
+                       }
                }
        }
 
@@ -1503,10 +1564,6 @@ func (c *Conn) handleFrame(
        switch frame := f.(type) {
        case *wire.CryptoFrame:
                err = c.handleCryptoFrame(frame, encLevel, rcvTime)
-       case *wire.StreamFrame:
-               err = c.streamsMap.HandleStreamFrame(frame, rcvTime)
-       case *wire.AckFrame:
-               err = c.handleAckFrame(frame, encLevel, rcvTime)
        case *wire.ConnectionCloseFrame:
                err = c.handleConnectionCloseFrame(frame)
        case *wire.ResetStreamFrame:
@@ -1537,8 +1594,6 @@ func (c *Conn) handleFrame(
                err = c.connIDGenerator.Retire(frame.SequenceNumber, destConnID, rcvTime.Add(3*c.rttStats.PTO(false)))
        case *wire.HandshakeDoneFrame:
                err = c.handleHandshakeDoneFrame(rcvTime)
-       case *wire.DatagramFrame:
-               err = c.handleDatagramFrame(frame)
        default:
                err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name())
        }
index cf976eb69ca59b0570d149f4bf04907e94cbb787..30d7dd8c15c42bf2560f27be4c2b573ed79189cd 100644 (file)
@@ -215,17 +215,19 @@ func TestConnectionHandleStreamRelatedFrames(t *testing.T) {
                name  string
                frame wire.Frame
        }{
-               {name: "STREAM", frame: &wire.StreamFrame{StreamID: id, Data: []byte("foobar")}},
                {name: "RESET_STREAM", frame: &wire.ResetStreamFrame{StreamID: id, ErrorCode: 42, FinalSize: 1337}},
                {name: "STOP_SENDING", frame: &wire.StopSendingFrame{StreamID: id, ErrorCode: 42}},
                {name: "MAX_STREAM_DATA", frame: &wire.MaxStreamDataFrame{StreamID: id, MaximumStreamData: 1337}},
                {name: "STREAM_DATA_BLOCKED", frame: &wire.StreamDataBlockedFrame{StreamID: id, MaximumStreamData: 42}},
+               {name: "STREAM_FRAME", frame: &wire.StreamFrame{StreamID: id, Data: []byte{1, 2, 3, 4, 5, 6, 7, 8}, Offset: 1337}},
        }
 
        for _, test := range tests {
                t.Run(test.name, func(t *testing.T) {
                        tc := newServerTestConnection(t, gomock.NewController(t), nil, false)
-                       _, err := tc.conn.handleFrame(test.frame, protocol.Encryption1RTT, connID, time.Now())
+                       data, err := test.frame.Append(nil, protocol.Version1)
+                       require.NoError(t, err)
+                       _, _, _, err = tc.conn.handleFrames(data, connID, protocol.Encryption1RTT, nil, time.Now())
                        require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.StreamStateError})
                })
        }
@@ -2996,3 +2998,37 @@ func testConnectionMigration(t *testing.T, enabled bool) {
                t.Fatal("timeout")
        }
 }
+
+func TestConnectionDatagrams(t *testing.T) {
+       t.Run("disabled", func(t *testing.T) {
+               testConnectionDatagrams(t, false)
+       })
+       t.Run("enabled", func(t *testing.T) {
+               testConnectionDatagrams(t, true)
+       })
+}
+
+func testConnectionDatagrams(t *testing.T, enabled bool) {
+       tc := newServerTestConnection(t, nil, &Config{EnableDatagrams: enabled}, false)
+
+       data, err := (&wire.DatagramFrame{Data: []byte("foo"), DataLenPresent: true}).Append(nil, protocol.Version1)
+       require.NoError(t, err)
+       data, err = (&wire.DatagramFrame{Data: []byte("bar")}).Append(data, protocol.Version1)
+       require.NoError(t, err)
+       _, _, _, err = tc.conn.handleFrames(data, protocol.ConnectionID{}, protocol.Encryption1RTT, nil, time.Now())
+
+       if !enabled {
+               require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.FrameEncodingError, FrameType: uint64(wire.FrameTypeDatagramWithLength)})
+               return
+       }
+
+       require.NoError(t, err)
+       ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+       defer cancel()
+       d, err := tc.conn.ReceiveDatagram(ctx)
+       require.NoError(t, err)
+       require.Equal(t, []byte("foo"), d)
+       d, err = tc.conn.ReceiveDatagram(ctx)
+       require.NoError(t, err)
+       require.Equal(t, []byte("bar"), d)
+}
index e7b0247aa44f6df2afc759bec4a031afa2b783f2..c82002656809f58ecf5bfb5cdc3a31fee8aa3dce 100644 (file)
@@ -2,6 +2,7 @@ package frames
 
 import (
        "fmt"
+       "io"
 
        "github.com/quic-go/quic-go/internal/ackhandler"
        "github.com/quic-go/quic-go/internal/protocol"
@@ -41,15 +42,32 @@ func Fuzz(data []byte) int {
        var b []byte
        for len(data) > 0 {
                initialLen := len(data)
-               l, f, err := parser.ParseNext(data, encLevel, version)
+               frameType, l, err := parser.ParseType(data, encLevel)
                if err != nil {
+                       if err == io.EOF { // the last frame was a PADDING frame
+                               continue
+                       }
                        break
                }
+
                data = data[l:]
                numFrames++
-               if f == nil { // PADDING frame
-                       continue
+
+               var f wire.Frame
+               switch {
+               case frameType.IsStreamFrameType():
+                       f, l, err = parser.ParseStreamFrame(frameType, data, version)
+               case frameType == wire.FrameTypeAck || frameType == wire.FrameTypeAckECN:
+                       f, l, err = parser.ParseAckFrame(frameType, data, encLevel, version)
+               case frameType == wire.FrameTypeDatagramNoLength || frameType == wire.FrameTypeDatagramWithLength:
+                       f, l, err = parser.ParseDatagramFrame(frameType, data, version)
+               default:
+                       f, l, err = parser.ParseLessCommonFrame(frameType, data, version)
                }
+               if err != nil {
+                       break
+               }
+               data = data[l:]
                wire.IsProbingFrame(f)
                ackhandler.IsFrameAckEliciting(f)
                // We accept empty STREAM frames, but we don't write them.
index 34506b12e011baf3051a1f550614febdf41c6bb0..8d8436123e5529d4b00c326311e3813d6ac4c5f6 100644 (file)
@@ -2,6 +2,19 @@ package ackhandler
 
 import "github.com/quic-go/quic-go/internal/wire"
 
+// IsFrameTypeAckEliciting returns true if the frame is ack-eliciting.
+func IsFrameTypeAckEliciting(t wire.FrameType) bool {
+       //nolint:exhaustive // The default case catches the rest.
+       switch t {
+       case wire.FrameTypeAck, wire.FrameTypeAckECN:
+               return false
+       case wire.FrameTypeConnectionClose, wire.FrameTypeApplicationClose:
+               return false
+       default:
+               return true
+       }
+}
+
 // IsFrameAckEliciting returns true if the frame is ack-eliciting.
 func IsFrameAckEliciting(f wire.Frame) bool {
        _, isAck := f.(*wire.AckFrame)
index 1c363e9304e51bdbf6a7ec2ac2ae9027deb0539c..65cc627e65c845490a2b70de01a5b24833599b33 100644 (file)
@@ -7,6 +7,48 @@ import (
        "github.com/stretchr/testify/require"
 )
 
+func TestIsFrameTypeAckEliciting(t *testing.T) {
+       testCases := map[wire.FrameType]bool{
+               wire.FrameTypePing:               true,
+               wire.FrameTypeAck:                false,
+               wire.FrameTypeAckECN:             false,
+               wire.FrameTypeResetStream:        true,
+               wire.FrameTypeStopSending:        true,
+               wire.FrameTypeCrypto:             true,
+               wire.FrameTypeNewToken:           true,
+               wire.FrameType(0x08):             true,
+               wire.FrameType(0x09):             true,
+               wire.FrameType(0x0a):             true,
+               wire.FrameType(0x0b):             true,
+               wire.FrameType(0x0c):             true,
+               wire.FrameType(0x0d):             true,
+               wire.FrameType(0x0e):             true,
+               wire.FrameType(0x0f):             true,
+               wire.FrameTypeMaxData:            true,
+               wire.FrameTypeMaxStreamData:      true,
+               wire.FrameTypeBidiMaxStreams:     true,
+               wire.FrameTypeUniMaxStreams:      true,
+               wire.FrameTypeDataBlocked:        true,
+               wire.FrameTypeStreamDataBlocked:  true,
+               wire.FrameTypeBidiStreamBlocked:  true,
+               wire.FrameTypeUniStreamBlocked:   true,
+               wire.FrameTypeNewConnectionID:    true,
+               wire.FrameTypeRetireConnectionID: true,
+               wire.FrameTypePathChallenge:      true,
+               wire.FrameTypePathResponse:       true,
+               wire.FrameTypeConnectionClose:    false,
+               wire.FrameTypeApplicationClose:   false,
+               wire.FrameTypeHandshakeDone:      true,
+               wire.FrameTypeResetStreamAt:      true,
+               wire.FrameTypeDatagramNoLength:   true,
+               wire.FrameTypeDatagramWithLength: true,
+       }
+
+       for ft, expected := range testCases {
+               require.Equal(t, expected, IsFrameTypeAckEliciting(ft), "unexpected result for frame type 0x%x", ft)
+       }
+}
+
 func TestAckElicitingFrames(t *testing.T) {
        testCases := map[wire.Frame]bool{
                &wire.AckFrame{}:             false,
index 8befef4f2def998d7f5c250075e2fb802e865706..68bebfa7917242a64a15d7900a427d47e7e7d699 100644 (file)
@@ -21,9 +21,9 @@ type AckFrame struct {
 }
 
 // parseAckFrame reads an ACK frame
-func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8, _ protocol.Version) (int, error) {
+func parseAckFrame(frame *AckFrame, b []byte, typ FrameType, ackDelayExponent uint8, _ protocol.Version) (int, error) {
        startLen := len(b)
-       ecn := typ == ackECNFrameType
+       ecn := typ == FrameTypeAckECN
 
        la, l, err := quicvarint.Parse(b)
        if err != nil {
@@ -122,9 +122,9 @@ func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8
 func (f *AckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
        hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
        if hasECN {
-               b = append(b, ackECNFrameType)
+               b = append(b, byte(FrameTypeAckECN))
        } else {
-               b = append(b, ackFrameType)
+               b = append(b, byte(FrameTypeAck))
        }
        b = quicvarint.Append(b, uint64(f.LargestAcked()))
        b = quicvarint.Append(b, encodeAckDelay(f.DelayTime))
index f6390bfcd68ff5cb946622d0f82092f78ebc7f22..e7f21a0b8ff0bd39b8530aa48b8efcdb9af61be0 100644 (file)
@@ -17,7 +17,7 @@ func TestParseACKWithoutRanges(t *testing.T) {
        data = append(data, encodeVarInt(0)...)  // num blocks
        data = append(data, encodeVarInt(10)...) // first ack block
        var frame AckFrame
-       n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(data), n)
        require.Equal(t, protocol.PacketNumber(100), frame.LargestAcked())
@@ -31,7 +31,7 @@ func TestParseACKSinglePacket(t *testing.T) {
        data = append(data, encodeVarInt(0)...) // num blocks
        data = append(data, encodeVarInt(0)...) // first ack block
        var frame AckFrame
-       n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(data), n)
        require.Equal(t, protocol.PacketNumber(55), frame.LargestAcked())
@@ -45,7 +45,7 @@ func TestParseACKAllPacketsFrom0ToLargest(t *testing.T) {
        data = append(data, encodeVarInt(0)...)  // num blocks
        data = append(data, encodeVarInt(20)...) // first ack block
        var frame AckFrame
-       n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(data), n)
        require.Equal(t, protocol.PacketNumber(20), frame.LargestAcked())
@@ -59,7 +59,7 @@ func TestParseACKRejectFirstBlockLargerThanLargestAcked(t *testing.T) {
        data = append(data, encodeVarInt(0)...)  // num blocks
        data = append(data, encodeVarInt(21)...) // first ack block
        var frame AckFrame
-       _, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1)
+       _, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1)
        require.EqualError(t, err, "invalid first ACK range")
 }
 
@@ -71,7 +71,7 @@ func TestParseACKWithSingleBlock(t *testing.T) {
        data = append(data, encodeVarInt(98)...)  // gap
        data = append(data, encodeVarInt(50)...)  // ack block
        var frame AckFrame
-       n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(data), n)
        require.Equal(t, protocol.PacketNumber(1000), frame.LargestAcked())
@@ -93,7 +93,7 @@ func TestParseACKWithMultipleBlocks(t *testing.T) {
        data = append(data, encodeVarInt(1)...) // gap
        data = append(data, encodeVarInt(1)...) // ack block
        var frame AckFrame
-       n, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(data), n)
        require.Equal(t, protocol.PacketNumber(100), frame.LargestAcked())
@@ -118,7 +118,7 @@ func TestParseACKUseAckDelayExponent(t *testing.T) {
                typ, l, err := quicvarint.Parse(b)
                require.NoError(t, err)
                var frame AckFrame
-               n, err := parseAckFrame(&frame, b[l:], typ, protocol.AckDelayExponent+i, protocol.Version1)
+               n, err := parseAckFrame(&frame, b[l:], FrameType(typ), protocol.AckDelayExponent+i, protocol.Version1)
                require.NoError(t, err)
                require.Equal(t, len(b[l:]), n)
                require.Equal(t, delayTime*(1<<i), frame.DelayTime)
@@ -131,7 +131,7 @@ func TestParseACKHandleDelayTimeOverflow(t *testing.T) {
        data = append(data, encodeVarInt(0)...)                // num blocks
        data = append(data, encodeVarInt(0)...)                // first ack block
        var frame AckFrame
-       _, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1)
+       _, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Greater(t, frame.DelayTime, time.Duration(0))
        // The maximum encodable duration is ~292 years.
@@ -146,11 +146,11 @@ func TestParseACKErrorOnEOF(t *testing.T) {
        data = append(data, encodeVarInt(98)...)  // gap
        data = append(data, encodeVarInt(50)...)  // ack block
        var frame AckFrame
-       _, err := parseAckFrame(&frame, data, ackFrameType, protocol.AckDelayExponent, protocol.Version1)
+       _, err := parseAckFrame(&frame, data, FrameTypeAck, protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        for i := range data {
                var frame AckFrame
-               _, err := parseAckFrame(&frame, data[:i], ackFrameType, protocol.AckDelayExponent, protocol.Version1)
+               _, err := parseAckFrame(&frame, data[:i], FrameTypeAck, protocol.AckDelayExponent, protocol.Version1)
                require.Equal(t, io.EOF, err)
        }
 }
@@ -164,7 +164,7 @@ func TestParseACKECN(t *testing.T) {
        data = append(data, encodeVarInt(0x12345)...)    // ECT(1)
        data = append(data, encodeVarInt(0x12345678)...) // ECN-CE
        var frame AckFrame
-       n, err := parseAckFrame(&frame, data, ackECNFrameType, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, data, FrameTypeAckECN, protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(data), n)
        require.Equal(t, protocol.PacketNumber(100), frame.LargestAcked())
@@ -186,12 +186,12 @@ func TestParseACKECNErrorOnEOF(t *testing.T) {
        data = append(data, encodeVarInt(0x12345)...)    // ECT(1)
        data = append(data, encodeVarInt(0x12345678)...) // ECN-CE
        var frame AckFrame
-       n, err := parseAckFrame(&frame, data, ackECNFrameType, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, data, FrameTypeAckECN, protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(data), n)
        for i := range data {
                var frame AckFrame
-               _, err := parseAckFrame(&frame, data[:i], ackECNFrameType, protocol.AckDelayExponent, protocol.Version1)
+               _, err := parseAckFrame(&frame, data[:i], FrameTypeAckECN, protocol.AckDelayExponent, protocol.Version1)
                require.Equal(t, io.EOF, err)
        }
 }
@@ -202,7 +202,7 @@ func TestWriteACKSimpleFrame(t *testing.T) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{ackFrameType}
+       expected := []byte{byte(FrameTypeAck)}
        expected = append(expected, encodeVarInt(1337)...) // largest acked
        expected = append(expected, 0)                     // delay
        expected = append(expected, encodeVarInt(0)...)    // num ranges
@@ -220,7 +220,7 @@ func TestWriteACKECNFrame(t *testing.T) {
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
        require.Len(t, b, int(f.Length(protocol.Version1)))
-       expected := []byte{ackECNFrameType}
+       expected := []byte{byte(FrameTypeAckECN)}
        expected = append(expected, encodeVarInt(2000)...) // largest acked
        expected = append(expected, 0)                     // delay
        expected = append(expected, encodeVarInt(0)...)    // num ranges
@@ -243,7 +243,7 @@ func TestWriteACKSinglePacket(t *testing.T) {
        require.NoError(t, err)
        b = b[l:]
        var frame AckFrame
-       n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, b, FrameType(typ), protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(b), n)
        require.Equal(t, f, &frame)
@@ -262,7 +262,7 @@ func TestWriteACKManyPackets(t *testing.T) {
        require.NoError(t, err)
        b = b[l:]
        var frame AckFrame
-       n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, b, FrameType(typ), protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(b), n)
        require.Equal(t, f, &frame)
@@ -284,7 +284,7 @@ func TestWriteACKSingleGap(t *testing.T) {
        require.NoError(t, err)
        b = b[l:]
        var frame AckFrame
-       n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, b, FrameType(typ), protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(b), n)
        require.Equal(t, f, &frame)
@@ -308,7 +308,7 @@ func TestWriteACKMultipleRanges(t *testing.T) {
        require.NoError(t, err)
        b = b[l:]
        var frame AckFrame
-       n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, b, FrameType(typ), protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(b), n)
        require.Equal(t, f, &frame)
@@ -333,7 +333,7 @@ func TestWriteACKLimitMaxSize(t *testing.T) {
        require.NoError(t, err)
        b = b[l:]
        var frame AckFrame
-       n, err := parseAckFrame(&frame, b, typ, protocol.AckDelayExponent, protocol.Version1)
+       n, err := parseAckFrame(&frame, b, FrameType(typ), protocol.AckDelayExponent, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(b), n)
        require.True(t, frame.HasMissingRanges())
index be11a1b2ecbb458bce9baf807bd07375521a75aa..6c71aab6b4452f865437db98a606202344886b6d 100644 (file)
@@ -15,9 +15,9 @@ type ConnectionCloseFrame struct {
        ReasonPhrase       string
 }
 
-func parseConnectionCloseFrame(b []byte, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, int, error) {
+func parseConnectionCloseFrame(b []byte, typ FrameType, _ protocol.Version) (*ConnectionCloseFrame, int, error) {
        startLen := len(b)
-       f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType}
+       f := &ConnectionCloseFrame{IsApplicationError: typ == FrameTypeApplicationClose}
        ec, l, err := quicvarint.Parse(b)
        if err != nil {
                return nil, 0, replaceUnexpectedEOF(err)
@@ -60,9 +60,9 @@ func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount {
 
 func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
        if f.IsApplicationError {
-               b = append(b, applicationCloseFrameType)
+               b = append(b, byte(FrameTypeApplicationClose))
        } else {
-               b = append(b, connectionCloseFrameType)
+               b = append(b, byte(FrameTypeConnectionClose))
        }
 
        b = quicvarint.Append(b, f.ErrorCode)
index df76b24a48a8f08b458de45d99a9750c997e4943..ae54113bdbc72a496c8fcfb85a28998e1130ee33 100644 (file)
@@ -15,7 +15,7 @@ func TestParseConnectionCloseTransportError(t *testing.T) {
        data = append(data, encodeVarInt(0x1337)...)              // frame type
        data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length
        data = append(data, []byte(reason)...)
-       frame, l, err := parseConnectionCloseFrame(data, connectionCloseFrameType, protocol.Version1)
+       frame, l, err := parseConnectionCloseFrame(data, FrameTypeConnectionClose, protocol.Version1)
        require.NoError(t, err)
        require.False(t, frame.IsApplicationError)
        require.EqualValues(t, 0x19, frame.ErrorCode)
@@ -29,7 +29,7 @@ func TestParseConnectionCloseWithApplicationError(t *testing.T) {
        data := encodeVarInt(0xcafe)
        data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length
        data = append(data, reason...)
-       frame, l, err := parseConnectionCloseFrame(data, applicationCloseFrameType, protocol.Version1)
+       frame, l, err := parseConnectionCloseFrame(data, FrameTypeApplicationClose, protocol.Version1)
        require.NoError(t, err)
        require.True(t, frame.IsApplicationError)
        require.EqualValues(t, 0xcafe, frame.ErrorCode)
@@ -41,7 +41,7 @@ func TestParseConnectionCloseLongReasonPhrase(t *testing.T) {
        data := encodeVarInt(0xcafe)
        data = append(data, encodeVarInt(0x42)...)   // frame type
        data = append(data, encodeVarInt(0xffff)...) // reason phrase length
-       _, _, err := parseConnectionCloseFrame(data, connectionCloseFrameType, protocol.Version1)
+       _, _, err := parseConnectionCloseFrame(data, FrameTypeConnectionClose, protocol.Version1)
        require.Equal(t, io.EOF, err)
 }
 
@@ -51,11 +51,11 @@ func TestParseConnectionCloseErrorsOnEOFs(t *testing.T) {
        data = append(data, encodeVarInt(0x1337)...)              // frame type
        data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length
        data = append(data, []byte(reason)...)
-       _, l, err := parseConnectionCloseFrame(data, connectionCloseFrameType, protocol.Version1)
+       _, l, err := parseConnectionCloseFrame(data, FrameTypeConnectionClose, protocol.Version1)
        require.Equal(t, len(data), l)
        require.NoError(t, err)
        for i := range data {
-               _, _, err = parseConnectionCloseFrame(data[:i], connectionCloseFrameType, protocol.Version1)
+               _, _, err = parseConnectionCloseFrame(data[:i], FrameTypeConnectionClose, protocol.Version1)
                require.Equal(t, io.EOF, err)
        }
 }
@@ -64,7 +64,7 @@ func TestParseConnectionCloseNoReasonPhrase(t *testing.T) {
        data := encodeVarInt(0xcafe)
        data = append(data, encodeVarInt(0x42)...) // frame type
        data = append(data, encodeVarInt(0)...)
-       frame, l, err := parseConnectionCloseFrame(data, connectionCloseFrameType, protocol.Version1)
+       frame, l, err := parseConnectionCloseFrame(data, FrameTypeConnectionClose, protocol.Version1)
        require.NoError(t, err)
        require.Empty(t, frame.ReasonPhrase)
        require.Equal(t, len(data), l)
@@ -77,7 +77,7 @@ func TestWriteConnectionCloseNoReasonPhrase(t *testing.T) {
        }
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{connectionCloseFrameType}
+       expected := []byte{byte(FrameTypeConnectionClose)}
        expected = append(expected, encodeVarInt(0xbeef)...)
        expected = append(expected, encodeVarInt(0x12345)...) // frame type
        expected = append(expected, encodeVarInt(0)...)       // reason phrase length
@@ -91,7 +91,7 @@ func TestWriteConnectionCloseWithReasonPhrase(t *testing.T) {
        }
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{connectionCloseFrameType}
+       expected := []byte{byte(FrameTypeConnectionClose)}
        expected = append(expected, encodeVarInt(0xdead)...)
        expected = append(expected, encodeVarInt(0)...) // frame type
        expected = append(expected, encodeVarInt(6)...) // reason phrase length
@@ -107,7 +107,7 @@ func TestWriteConnectionCloseWithApplicationError(t *testing.T) {
        }
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{applicationCloseFrameType}
+       expected := []byte{byte(FrameTypeApplicationClose)}
        expected = append(expected, encodeVarInt(0xdead)...)
        expected = append(expected, encodeVarInt(6)...) // reason phrase length
        expected = append(expected, []byte("foobar")...)
index 0aa7fe7bc76ac33077cfbc4373d586a37a0cdf87..60a713f7413cc587db67f341611d5edb3893219a 100644 (file)
@@ -38,7 +38,7 @@ func parseCryptoFrame(b []byte, _ protocol.Version) (*CryptoFrame, int, error) {
 }
 
 func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       b = append(b, cryptoFrameType)
+       b = append(b, byte(FrameTypeCrypto))
        b = quicvarint.Append(b, uint64(f.Offset))
        b = quicvarint.Append(b, uint64(len(f.Data)))
        b = append(b, f.Data...)
index 89965bf5a33dbf8dc95a1f03742a48d5740321ad..674b03acb101d3e02dfa40daed1a04e4928cb183 100644 (file)
@@ -40,7 +40,7 @@ func TestWriteCryptoFrame(t *testing.T) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{cryptoFrameType}
+       expected := []byte{byte(FrameTypeCrypto)}
        expected = append(expected, encodeVarInt(0x123456)...) // offset
        expected = append(expected, encodeVarInt(6)...)        // length
        expected = append(expected, []byte("foobar")...)
index c97d4c62948fa9950041b3798840a3a461fca784..11c72ea32d23339ed1924db27de8f66d56667f0c 100644 (file)
@@ -19,7 +19,7 @@ func parseDataBlockedFrame(b []byte, _ protocol.Version) (*DataBlockedFrame, int
 }
 
 func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
-       b = append(b, dataBlockedFrameType)
+       b = append(b, byte(FrameTypeDataBlocked))
        return quicvarint.Append(b, uint64(f.MaximumData)), nil
 }
 
index 12b1535effb5041416e59f12412c73d63414dd03..ba19e84602f93274c4e669d2777f3694d01b52b8 100644 (file)
@@ -33,7 +33,7 @@ func TestWriteDataBlocked(t *testing.T) {
        frame := DataBlockedFrame{MaximumData: 0xdeadbeef}
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{dataBlockedFrameType}
+       expected := []byte{byte(FrameTypeDataBlocked)}
        expected = append(expected, encodeVarInt(0xdeadbeef)...)
        require.Equal(t, expected, b)
        require.Equal(t, protocol.ByteCount(1+quicvarint.Len(uint64(frame.MaximumData))), frame.Length(protocol.Version1))
index 071fda9a05b69539b72eba9e6da6d0d61a1191ee..a6034867effb748aa6d80348359059b3f5a9ef83 100644 (file)
@@ -19,10 +19,10 @@ type DatagramFrame struct {
        Data           []byte
 }
 
-func parseDatagramFrame(b []byte, typ uint64, _ protocol.Version) (*DatagramFrame, int, error) {
+func parseDatagramFrame(b []byte, typ FrameType, _ protocol.Version) (*DatagramFrame, int, error) {
        startLen := len(b)
        f := &DatagramFrame{}
-       f.DataLenPresent = typ&0x1 > 0
+       f.DataLenPresent = uint64(typ)&0x1 > 0
 
        var length uint64
        if f.DataLenPresent {
index 10d4eebc31cfaff6e5e2d8b34fc2b98e9f1fb0e6..09ea92f754180a2e453c19cdae78406a6603edde 100644 (file)
@@ -19,3 +19,15 @@ func IsProbingFrame(f Frame) bool {
        }
        return false
 }
+
+// IsProbingFrameType returns true if the FrameType is a probing frame.
+// See section 9.1 of RFC 9000.
+func IsProbingFrameType(f FrameType) bool {
+       //nolint:exhaustive // PATH_CHALLENGE, PATH_RESPONSE and NEW_CONNECTION_ID are the only probing frames
+       switch f {
+       case FrameTypePathChallenge, FrameTypePathResponse, FrameTypeNewConnectionID:
+               return true
+       default:
+               return false
+       }
+}
index 794c70ccf10494543ec204e6a50ac197df7d3ddf..e92e29c8bd8e271d7deb6bfa50d0b7c82d36616b 100644 (file)
@@ -4,39 +4,12 @@ import (
        "errors"
        "fmt"
        "io"
-       "reflect"
 
        "github.com/quic-go/quic-go/internal/protocol"
        "github.com/quic-go/quic-go/internal/qerr"
        "github.com/quic-go/quic-go/quicvarint"
 )
 
-const (
-       pingFrameType               = 0x1
-       ackFrameType                = 0x2
-       ackECNFrameType             = 0x3
-       resetStreamFrameType        = 0x4
-       stopSendingFrameType        = 0x5
-       cryptoFrameType             = 0x6
-       newTokenFrameType           = 0x7
-       maxDataFrameType            = 0x10
-       maxStreamDataFrameType      = 0x11
-       bidiMaxStreamsFrameType     = 0x12
-       uniMaxStreamsFrameType      = 0x13
-       dataBlockedFrameType        = 0x14
-       streamDataBlockedFrameType  = 0x15
-       bidiStreamBlockedFrameType  = 0x16
-       uniStreamBlockedFrameType   = 0x17
-       newConnectionIDFrameType    = 0x18
-       retireConnectionIDFrameType = 0x19
-       pathChallengeFrameType      = 0x1a
-       pathResponseFrameType       = 0x1b
-       connectionCloseFrameType    = 0x1c
-       applicationCloseFrameType   = 0x1d
-       handshakeDoneFrameType      = 0x1e
-       resetStreamAtFrameType      = 0x24 // https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/06/
-)
-
 var errUnknownFrameType = errors.New("unknown frame type")
 
 // The FrameParser parses QUIC frames, one by one.
@@ -59,20 +32,15 @@ func NewFrameParser(supportsDatagrams, supportsResetStreamAt bool) *FrameParser
        }
 }
 
-// ParseNext parses the next frame.
-// It skips PADDING frames.
-func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) {
-       frame, l, err := p.parseNext(data, encLevel, v)
-       return l, frame, err
-}
-
-func (p *FrameParser) parseNext(b []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
+// ParseType parses the frame type of the next frame.
+// It skips over PADDING frames.
+func (p *FrameParser) ParseType(b []byte, encLevel protocol.EncryptionLevel) (FrameType, int, error) {
        var parsed int
        for len(b) != 0 {
                typ, l, err := quicvarint.Parse(b)
                parsed += l
                if err != nil {
-                       return nil, parsed, &qerr.TransportError{
+                       return 0, parsed, &qerr.TransportError{
                                ErrorCode:    qerr.FrameEncodingError,
                                ErrorMessage: err.Error(),
                        }
@@ -81,115 +49,126 @@ func (p *FrameParser) parseNext(b []byte, encLevel protocol.EncryptionLevel, v p
                if typ == 0x0 { // skip PADDING frames
                        continue
                }
-
-               f, l, err := p.parseFrame(b, typ, encLevel, v)
-               parsed += l
-               if err != nil {
-                       return nil, parsed, &qerr.TransportError{
+               ft := FrameType(typ)
+               valid := ft.isValidRFC9000() ||
+                       (p.supportsDatagrams && ft.IsDatagramFrameType()) ||
+                       (p.supportsResetStreamAt && ft == FrameTypeResetStreamAt)
+               if !valid {
+                       return 0, parsed, &qerr.TransportError{
+                               ErrorCode:    qerr.FrameEncodingError,
                                FrameType:    typ,
+                               ErrorMessage: errUnknownFrameType.Error(),
+                       }
+               }
+               if !ft.isAllowedAtEncLevel(encLevel) {
+                       return 0, parsed, &qerr.TransportError{
                                ErrorCode:    qerr.FrameEncodingError,
-                               ErrorMessage: err.Error(),
+                               FrameType:    typ,
+                               ErrorMessage: fmt.Sprintf("%d not allowed at encryption level %s", ft, encLevel),
                        }
                }
-               return f, parsed, nil
+               return ft, parsed, nil
        }
-       return nil, parsed, nil
+       return 0, parsed, io.EOF
 }
 
-func (p *FrameParser) parseFrame(b []byte, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
-       var frame Frame
-       var err error
-       var l int
-       if typ&0xf8 == 0x8 {
-               frame, l, err = parseStreamFrame(b, typ, v)
-       } else {
-               switch typ {
-               case pingFrameType:
-                       frame = &PingFrame{}
-               case ackFrameType, ackECNFrameType:
-                       ackDelayExponent := p.ackDelayExponent
-                       if encLevel != protocol.Encryption1RTT {
-                               ackDelayExponent = protocol.DefaultAckDelayExponent
-                       }
-                       p.ackFrame.Reset()
-                       l, err = parseAckFrame(p.ackFrame, b, typ, ackDelayExponent, v)
-                       frame = p.ackFrame
-               case resetStreamFrameType:
-                       frame, l, err = parseResetStreamFrame(b, false, v)
-               case stopSendingFrameType:
-                       frame, l, err = parseStopSendingFrame(b, v)
-               case cryptoFrameType:
-                       frame, l, err = parseCryptoFrame(b, v)
-               case newTokenFrameType:
-                       frame, l, err = parseNewTokenFrame(b, v)
-               case maxDataFrameType:
-                       frame, l, err = parseMaxDataFrame(b, v)
-               case maxStreamDataFrameType:
-                       frame, l, err = parseMaxStreamDataFrame(b, v)
-               case bidiMaxStreamsFrameType, uniMaxStreamsFrameType:
-                       frame, l, err = parseMaxStreamsFrame(b, typ, v)
-               case dataBlockedFrameType:
-                       frame, l, err = parseDataBlockedFrame(b, v)
-               case streamDataBlockedFrameType:
-                       frame, l, err = parseStreamDataBlockedFrame(b, v)
-               case bidiStreamBlockedFrameType, uniStreamBlockedFrameType:
-                       frame, l, err = parseStreamsBlockedFrame(b, typ, v)
-               case newConnectionIDFrameType:
-                       frame, l, err = parseNewConnectionIDFrame(b, v)
-               case retireConnectionIDFrameType:
-                       frame, l, err = parseRetireConnectionIDFrame(b, v)
-               case pathChallengeFrameType:
-                       frame, l, err = parsePathChallengeFrame(b, v)
-               case pathResponseFrameType:
-                       frame, l, err = parsePathResponseFrame(b, v)
-               case connectionCloseFrameType, applicationCloseFrameType:
-                       frame, l, err = parseConnectionCloseFrame(b, typ, v)
-               case handshakeDoneFrameType:
-                       frame = &HandshakeDoneFrame{}
-               case 0x30, 0x31:
-                       if !p.supportsDatagrams {
-                               return nil, 0, errUnknownFrameType
-                       }
-                       frame, l, err = parseDatagramFrame(b, typ, v)
-               case resetStreamAtFrameType:
-                       if !p.supportsResetStreamAt {
-                               return nil, 0, errUnknownFrameType
-                       }
-                       frame, l, err = parseResetStreamFrame(b, true, v)
-               default:
-                       err = errUnknownFrameType
+func (p *FrameParser) ParseStreamFrame(frameType FrameType, data []byte, v protocol.Version) (*StreamFrame, int, error) {
+       frame, n, err := ParseStreamFrame(data, frameType, v)
+       if err != nil {
+               return nil, n, &qerr.TransportError{
+                       ErrorCode:    qerr.FrameEncodingError,
+                       FrameType:    uint64(frameType),
+                       ErrorMessage: err.Error(),
                }
        }
-       if err != nil {
-               return nil, 0, err
+       return frame, n, nil
+}
+
+func (p *FrameParser) ParseAckFrame(frameType FrameType, data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (*AckFrame, int, error) {
+       ackDelayExponent := p.ackDelayExponent
+       if encLevel != protocol.Encryption1RTT {
+               ackDelayExponent = protocol.DefaultAckDelayExponent
        }
-       if !p.isAllowedAtEncLevel(frame, encLevel) {
-               return nil, l, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
+       p.ackFrame.Reset()
+       l, err := parseAckFrame(p.ackFrame, data, frameType, ackDelayExponent, v)
+       if err != nil {
+               return nil, l, &qerr.TransportError{
+                       ErrorCode:    qerr.FrameEncodingError,
+                       FrameType:    uint64(frameType),
+                       ErrorMessage: err.Error(),
+               }
        }
-       return frame, l, nil
+
+       return p.ackFrame, l, nil
 }
 
-func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
-       switch encLevel {
-       case protocol.EncryptionInitial, protocol.EncryptionHandshake:
-               switch f.(type) {
-               case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame:
-                       return true
-               default:
-                       return false
-               }
-       case protocol.Encryption0RTT:
-               switch f.(type) {
-               case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame:
-                       return false
-               default:
-                       return true
+func (p *FrameParser) ParseDatagramFrame(frameType FrameType, data []byte, v protocol.Version) (*DatagramFrame, int, error) {
+       f, l, err := parseDatagramFrame(data, frameType, v)
+       if err != nil {
+               return nil, 0, &qerr.TransportError{
+                       ErrorCode:    qerr.FrameEncodingError,
+                       FrameType:    uint64(frameType),
+                       ErrorMessage: err.Error(),
                }
-       case protocol.Encryption1RTT:
-               return true
+       }
+       return f, l, nil
+}
+
+// ParseLessCommonFrame parses everything except STREAM, ACK or DATAGRAM.
+// These cases should be handled separately for performance reasons.
+func (p *FrameParser) ParseLessCommonFrame(frameType FrameType, data []byte, v protocol.Version) (Frame, int, error) {
+       var frame Frame
+       var l int
+       var err error
+       //nolint:exhaustive // Common frames should already be handled.
+       switch frameType {
+       case FrameTypePing:
+               frame = &PingFrame{}
+       case FrameTypeResetStream:
+               frame, l, err = parseResetStreamFrame(data, false, v)
+       case FrameTypeStopSending:
+               frame, l, err = parseStopSendingFrame(data, v)
+       case FrameTypeCrypto:
+               frame, l, err = parseCryptoFrame(data, v)
+       case FrameTypeNewToken:
+               frame, l, err = parseNewTokenFrame(data, v)
+       case FrameTypeMaxData:
+               frame, l, err = parseMaxDataFrame(data, v)
+       case FrameTypeMaxStreamData:
+               frame, l, err = parseMaxStreamDataFrame(data, v)
+       case FrameTypeBidiMaxStreams, FrameTypeUniMaxStreams:
+               frame, l, err = parseMaxStreamsFrame(data, frameType, v)
+       case FrameTypeDataBlocked:
+               frame, l, err = parseDataBlockedFrame(data, v)
+       case FrameTypeStreamDataBlocked:
+               frame, l, err = parseStreamDataBlockedFrame(data, v)
+       case FrameTypeBidiStreamBlocked, FrameTypeUniStreamBlocked:
+               frame, l, err = parseStreamsBlockedFrame(data, frameType, v)
+       case FrameTypeNewConnectionID:
+               frame, l, err = parseNewConnectionIDFrame(data, v)
+       case FrameTypeRetireConnectionID:
+               frame, l, err = parseRetireConnectionIDFrame(data, v)
+       case FrameTypePathChallenge:
+               frame, l, err = parsePathChallengeFrame(data, v)
+       case FrameTypePathResponse:
+               frame, l, err = parsePathResponseFrame(data, v)
+       case FrameTypeConnectionClose, FrameTypeApplicationClose:
+               frame, l, err = parseConnectionCloseFrame(data, frameType, v)
+       case FrameTypeHandshakeDone:
+               frame = &HandshakeDoneFrame{}
+       case FrameTypeResetStreamAt:
+               frame, l, err = parseResetStreamFrame(data, true, v)
        default:
-               panic("unknown encryption level")
+               err = errUnknownFrameType
+       }
+       if err != nil {
+               return frame, l, &qerr.TransportError{
+                       ErrorCode:    qerr.FrameEncodingError,
+                       FrameType:    uint64(frameType),
+                       ErrorMessage: err.Error(),
+               }
        }
+       return frame, l, err
 }
 
 // SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters).
index fefb5c11ee0bcec762c0521b42f0b8f70221e827..821bc5d65268353df9b488766fc92be13e38a9bd 100644 (file)
@@ -3,22 +3,31 @@ package wire
 import (
        "bytes"
        "crypto/rand"
+       "fmt"
+       "io"
        "slices"
        "testing"
        "time"
 
        "github.com/quic-go/quic-go/internal/protocol"
        "github.com/quic-go/quic-go/internal/qerr"
-
        "github.com/stretchr/testify/require"
 )
 
-func TestFrameParsingReturnsNilWhenNothingToRead(t *testing.T) {
+func TestFrameTypeParsingReturnsNilWhenNothingToRead(t *testing.T) {
        parser := NewFrameParser(true, true)
-       l, f, err := parser.ParseNext(nil, protocol.Encryption1RTT, protocol.Version1)
-       require.NoError(t, err)
+       frameType, l, err := parser.ParseType(nil, protocol.Encryption1RTT)
+       require.Equal(t, io.EOF, err)
+       require.Zero(t, frameType)
+       require.Zero(t, l)
+}
+
+func TestParseLessCommonFrameReturnsEOFWhenNothingToRead(t *testing.T) {
+       parser := NewFrameParser(true, true)
+       l, f, err := parser.ParseLessCommonFrame(FrameTypeMaxStreamData, nil, protocol.Version1)
+       require.IsType(t, &qerr.TransportError{}, err)
        require.Zero(t, l)
-       require.Nil(t, f)
+       require.Zero(t, f)
 }
 
 func TestFrameParsingSkipsPaddingFrames(t *testing.T) {
@@ -26,17 +35,24 @@ func TestFrameParsingSkipsPaddingFrames(t *testing.T) {
        b := []byte{0, 0} // 2 PADDING frames
        b, err := (&PingFrame{}).Append(b, protocol.Version1)
        require.NoError(t, err)
-       l, f, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1)
+
+       frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT)
        require.NoError(t, err)
-       require.Equal(t, &PingFrame{}, f)
-       require.Equal(t, 2+1, l)
+       require.Equal(t, 3, l)
+       require.Equal(t, FrameTypePing, frameType)
+
+       frame, l, err := parser.ParseLessCommonFrame(frameType, b[1:], protocol.Version1)
+       require.NoError(t, err)
+       require.Zero(t, l)
+       require.IsType(t, &PingFrame{}, frame)
 }
 
 func TestFrameParsingHandlesPaddingAtEnd(t *testing.T) {
        parser := NewFrameParser(true, true)
-       l, f, err := parser.ParseNext([]byte{0, 0, 0}, protocol.Encryption1RTT, protocol.Version1)
-       require.NoError(t, err)
-       require.Nil(t, f)
+       b := []byte{0, 0, 0}
+
+       _, l, err := parser.ParseType(b, protocol.Encryption1RTT)
+       require.Equal(t, io.EOF, err)
        require.Equal(t, 3, l)
 }
 
@@ -48,10 +64,15 @@ func TestFrameParsingParsesSingleFrame(t *testing.T) {
                b, err = (&PingFrame{}).Append(b, protocol.Version1)
                require.NoError(t, err)
        }
-       l, f, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1)
+       frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT)
        require.NoError(t, err)
-       require.IsType(t, &PingFrame{}, f)
+       require.Equal(t, FrameTypePing, frameType)
        require.Equal(t, 1, l)
+
+       frame, l, err := parser.ParseLessCommonFrame(frameType, b, protocol.Version1)
+       require.NoError(t, err)
+       require.Zero(t, l)
+       require.IsType(t, &PingFrame{}, frame)
 }
 
 func TestFrameParserACK(t *testing.T) {
@@ -59,12 +80,16 @@ func TestFrameParserACK(t *testing.T) {
        f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}}
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       l, frame, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1)
+       frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT)
+       require.NoError(t, err)
+       require.Equal(t, FrameTypeAck, frameType)
+       require.Equal(t, 1, l)
+
+       frame, l, err := parser.ParseAckFrame(frameType, b[l:], protocol.Encryption1RTT, protocol.Version1)
        require.NoError(t, err)
        require.NotNil(t, frame)
-       require.IsType(t, f, frame)
-       require.Equal(t, protocol.PacketNumber(0x13), frame.(*AckFrame).LargestAcked())
-       require.Equal(t, len(b), l)
+       require.Equal(t, protocol.PacketNumber(0x13), frame.LargestAcked())
+       require.Equal(t, len(b)-1, l)
 }
 
 func TestFrameParserAckDelay(t *testing.T) {
@@ -85,15 +110,31 @@ func testFrameParserAckDelay(t *testing.T, encLevel protocol.EncryptionLevel) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       _, frame, err := parser.ParseNext(b, encLevel, protocol.Version1)
+       frameType, l, err := parser.ParseType(b, encLevel)
+       require.NoError(t, err)
+       require.Equal(t, FrameTypeAck, frameType)
+       require.Equal(t, 1, l)
+
+       frame, l, err := parser.ParseAckFrame(frameType, b[l:], encLevel, protocol.Version1)
        require.NoError(t, err)
+       require.Equal(t, len(b)-1, l)
        if encLevel == protocol.Encryption1RTT {
-               require.Equal(t, 4*time.Second, frame.(*AckFrame).DelayTime)
+               require.Equal(t, 4*time.Second, frame.DelayTime)
        } else {
-               require.Equal(t, time.Second, frame.(*AckFrame).DelayTime)
+               require.Equal(t, time.Second, frame.DelayTime)
        }
 }
 
+func checkFrameUnsupported(t *testing.T, err error, expectedFrameType uint64) {
+       t.Helper()
+       require.ErrorContains(t, err, errUnknownFrameType.Error())
+       var transportErr *qerr.TransportError
+       require.ErrorAs(t, err, &transportErr)
+       require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode)
+       require.Equal(t, expectedFrameType, transportErr.FrameType)
+       require.Equal(t, "unknown frame type", transportErr.ErrorMessage)
+}
+
 func TestFrameParserStreamFrames(t *testing.T) {
        parser := NewFrameParser(true, true)
        f := &StreamFrame{
@@ -104,28 +145,95 @@ func TestFrameParserStreamFrames(t *testing.T) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       l, frame, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1)
+       frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT)
        require.NoError(t, err)
-       require.NotNil(t, frame)
-       require.Equal(t, f, frame)
-       require.Equal(t, len(b), l)
+       require.Equal(t, FrameType(0xd), frameType)
+       require.True(t, frameType.IsStreamFrameType())
+       require.Equal(t, 1, l)
+
+       // ParseLessCommonFrame should not handle Stream Frames
+       frame, l, err := parser.ParseLessCommonFrame(frameType, b[l:], protocol.Version1)
+       checkFrameUnsupported(t, err, 0xd)
+       require.Nil(t, frame)
+       require.Zero(t, l)
+}
+
+func TestParseStreamFrameWrapsError(t *testing.T) {
+       parser := NewFrameParser(true, true)
+       f := &StreamFrame{
+               StreamID:       0x1234,
+               Offset:         0x1000,
+               Data:           []byte("hello world"),
+               DataLenPresent: true,
+       }
+       b, err := f.Append(nil, protocol.Version1)
+       require.NoError(t, err)
+
+       // Corrupt the buffer to trigger a parse error
+       b = b[:len(b)-2] // Remove last 2 bytes to cause an EOF
+
+       frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT)
+       require.NoError(t, err)
+
+       frame, n, err := parser.ParseStreamFrame(frameType, b[l:], protocol.Version1)
+       require.Nil(t, frame)
+       require.Zero(t, n)
+
+       var transportErr *qerr.TransportError
+       require.ErrorAs(t, err, &transportErr)
+       require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode)
+       require.Equal(t, uint64(frameType), transportErr.FrameType)
+       require.Contains(t, transportErr.Error(), "EOF")
+}
+
+func TestParseStreamFrameSuccess(t *testing.T) {
+       parser := NewFrameParser(true, true)
+       original := &StreamFrame{
+               StreamID:       0x1234,
+               Offset:         0x1000,
+               Fin:            true,
+               Data:           []byte("hello world"),
+               DataLenPresent: true,
+       }
+       b, err := original.Append(nil, protocol.Version1)
+       require.NoError(t, err)
+
+       frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT)
+       require.NoError(t, err)
+       require.True(t, frameType.IsStreamFrameType())
+       require.Equal(t, FrameType(0x0f), frameType) // STREAM | OFF | LEN | FIN
+
+       parsed, n, err := parser.ParseStreamFrame(frameType, b[l:], protocol.Version1)
+       require.NoError(t, err)
+       require.NotNil(t, parsed)
+       require.Equal(t, len(b)-l, n)
+
+       require.Equal(t, original.StreamID, parsed.StreamID)
+       require.Equal(t, original.Offset, parsed.Offset)
+       require.Equal(t, original.Fin, parsed.Fin)
+       require.Equal(t, original.DataLenPresent, parsed.DataLenPresent)
+       require.Equal(t, original.Data, parsed.Data)
 }
 
 func TestFrameParserFrames(t *testing.T) {
        tests := []struct {
-               name  string
-               frame Frame
+               name      string
+               frameType FrameType
+               frame     Frame
        }{
                {
-                       name:  "MAX_DATA",
-                       frame: &MaxDataFrame{MaximumData: 0xcafe},
+                       name:      "MAX_DATA",
+                       frameType: FrameTypeMaxData,
+                       frame:     &MaxDataFrame{MaximumData: 0xcafe},
                },
                {
-                       name:  "MAX_STREAM_DATA",
-                       frame: &MaxStreamDataFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdecafbad},
+                       name:      "MAX_STREAM_DATA",
+                       frameType: FrameTypeMaxStreamData,
+                       frame:     &MaxStreamDataFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdecafbad},
                },
                {
-                       name: "RESET_STREAM",
+                       name:      "RESET_STREAM",
+                       frameType: FrameTypeResetStream,
                        frame: &ResetStreamFrame{
                                StreamID:  0xdeadbeef,
                                FinalSize: 0xdecafbad1234,
@@ -133,35 +241,43 @@ func TestFrameParserFrames(t *testing.T) {
                        },
                },
                {
-                       name:  "STOP_SENDING",
-                       frame: &StopSendingFrame{StreamID: 0x42},
+                       name:      "STOP_SENDING",
+                       frameType: FrameTypeStopSending,
+                       frame:     &StopSendingFrame{StreamID: 0x42},
                },
                {
-                       name:  "CRYPTO",
-                       frame: &CryptoFrame{Offset: 0x1337, Data: []byte("lorem ipsum")},
+                       name:      "CRYPTO",
+                       frameType: FrameTypeCrypto,
+                       frame:     &CryptoFrame{Offset: 0x1337, Data: []byte("lorem ipsum")},
                },
                {
-                       name:  "NEW_TOKEN",
-                       frame: &NewTokenFrame{Token: []byte("foobar")},
+                       name:      "NEW_TOKEN",
+                       frameType: FrameTypeNewToken,
+                       frame:     &NewTokenFrame{Token: []byte("foobar")},
                },
                {
-                       name:  "MAX_STREAMS",
-                       frame: &MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 0x1337},
+                       name:      "MAX_STREAMS",
+                       frameType: FrameTypeBidiMaxStreams,
+                       frame:     &MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 0x1337},
                },
                {
-                       name:  "DATA_BLOCKED",
-                       frame: &DataBlockedFrame{MaximumData: 0x1234},
+                       name:      "DATA_BLOCKED",
+                       frameType: FrameTypeDataBlocked,
+                       frame:     &DataBlockedFrame{MaximumData: 0x1234},
                },
                {
-                       name:  "STREAM_DATA_BLOCKED",
-                       frame: &StreamDataBlockedFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdead},
+                       name:      "STREAM_DATA_BLOCKED",
+                       frameType: FrameTypeStreamDataBlocked,
+                       frame:     &StreamDataBlockedFrame{StreamID: 0xdeadbeef, MaximumStreamData: 0xdead},
                },
                {
-                       name:  "STREAMS_BLOCKED",
-                       frame: &StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 0x1234567},
+                       name:      "STREAMS_BLOCKED",
+                       frameType: FrameTypeBidiStreamBlocked,
+                       frame:     &StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 0x1234567},
                },
                {
-                       name: "NEW_CONNECTION_ID",
+                       name:      "NEW_CONNECTION_ID",
+                       frameType: FrameTypeNewConnectionID,
                        frame: &NewConnectionIDFrame{
                                SequenceNumber:      0x1337,
                                ConnectionID:        protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
@@ -169,32 +285,39 @@ func TestFrameParserFrames(t *testing.T) {
                        },
                },
                {
-                       name:  "RETIRE_CONNECTION_ID",
-                       frame: &RetireConnectionIDFrame{SequenceNumber: 0x1337},
+                       name:      "RETIRE_CONNECTION_ID",
+                       frameType: FrameTypeRetireConnectionID,
+                       frame:     &RetireConnectionIDFrame{SequenceNumber: 0x1337},
                },
                {
-                       name:  "PATH_CHALLENGE",
-                       frame: &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}},
+                       name:      "PATH_CHALLENGE",
+                       frameType: FrameTypePathChallenge,
+                       frame:     &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}},
                },
                {
-                       name:  "PATH_RESPONSE",
-                       frame: &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}},
+                       name:      "PATH_RESPONSE",
+                       frameType: FrameTypePathResponse,
+                       frame:     &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}},
                },
                {
-                       name:  "CONNECTION_CLOSE",
-                       frame: &ConnectionCloseFrame{IsApplicationError: true, ReasonPhrase: "foobar"},
+                       name:      "CONNECTION_CLOSE",
+                       frameType: FrameTypeConnectionClose,
+                       frame:     &ConnectionCloseFrame{IsApplicationError: false, ReasonPhrase: "foobar"},
                },
                {
-                       name:  "HANDSHAKE_DONE",
-                       frame: &HandshakeDoneFrame{},
+                       name:      "APPLICATION_CLOSE",
+                       frameType: FrameTypeApplicationClose,
+                       frame:     &ConnectionCloseFrame{IsApplicationError: true, ReasonPhrase: "foobar"},
                },
                {
-                       name:  "DATAGRAM",
-                       frame: &DatagramFrame{Data: []byte("foobar")},
+                       name:      "HANDSHAKE_DONE",
+                       frameType: FrameTypeHandshakeDone,
+                       frame:     &HandshakeDoneFrame{},
                },
                {
-                       name:  "RESET_STREAM_AT",
-                       frame: &ResetStreamFrame{StreamID: 0x1337, ReliableSize: 0x42, FinalSize: 0xdeadbeef},
+                       name:      "RESET_STREAM_AT",
+                       frameType: FrameTypeResetStreamAt,
+                       frame:     &ResetStreamFrame{StreamID: 0x1337, ReliableSize: 0x42, FinalSize: 0xdeadbeef},
                },
        }
 
@@ -203,22 +326,173 @@ func TestFrameParserFrames(t *testing.T) {
                        parser := NewFrameParser(true, true)
                        b, err := test.frame.Append(nil, protocol.Version1)
                        require.NoError(t, err)
-                       l, frame, err := parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1)
+
+                       frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT)
+                       require.NoError(t, err)
+                       require.Equal(t, test.frameType, frameType)
+                       require.Equal(t, 1, l)
+
+                       frame, l, err := parser.ParseLessCommonFrame(frameType, b[l:], protocol.Version1)
                        require.NoError(t, err)
                        require.Equal(t, test.frame, frame)
-                       require.Equal(t, len(b), l)
+                       require.Equal(t, len(b)-1, l)
                })
        }
 }
 
-func checkFrameUnsupported(t *testing.T, err error, expectedFrameType uint64) {
-       t.Helper()
-       require.ErrorContains(t, err, errUnknownFrameType.Error())
-       var transportErr *qerr.TransportError
-       require.ErrorAs(t, err, &transportErr)
-       require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode)
-       require.Equal(t, expectedFrameType, transportErr.FrameType)
-       require.Equal(t, "unknown frame type", transportErr.ErrorMessage)
+func TestFrameAllowedAtEncLevel(t *testing.T) {
+       type testCase struct {
+               name             string
+               frameType        FrameType
+               frame            Frame
+               allowedInitial   bool
+               allowedHandshake bool
+               allowedZeroRTT   bool
+               allowedOneRTT    bool
+       }
+
+       for _, tc := range []testCase{
+               {
+                       name:             "CRYPTO_FRAME",
+                       frameType:        FrameTypeCrypto,
+                       frame:            &CryptoFrame{Offset: 0, Data: []byte("foo")},
+                       allowedInitial:   true,
+                       allowedHandshake: true,
+                       allowedZeroRTT:   false,
+                       allowedOneRTT:    true,
+               },
+               {
+                       name:             "ACK_FRAME",
+                       frameType:        FrameTypeAck,
+                       frame:            &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 1}}},
+                       allowedInitial:   true,
+                       allowedHandshake: true,
+                       allowedZeroRTT:   false,
+                       allowedOneRTT:    true,
+               },
+               {
+                       name:             "CONNECTION_CLOSE_FRAME",
+                       frameType:        FrameTypeConnectionClose,
+                       frame:            &ConnectionCloseFrame{IsApplicationError: false, ReasonPhrase: "err"},
+                       allowedInitial:   true,
+                       allowedHandshake: true,
+                       allowedZeroRTT:   false,
+                       allowedOneRTT:    true,
+               },
+               {
+                       name:             "PING_FRAME",
+                       frameType:        FrameTypePing,
+                       frame:            &PingFrame{},
+                       allowedInitial:   true,
+                       allowedHandshake: true,
+                       allowedZeroRTT:   true,
+                       allowedOneRTT:    true,
+               },
+               {
+                       name:             "NEW_TOKEN_FRAME",
+                       frameType:        FrameTypeNewToken,
+                       frame:            &NewTokenFrame{Token: []byte("tok")},
+                       allowedInitial:   false,
+                       allowedHandshake: false,
+                       allowedZeroRTT:   false,
+                       allowedOneRTT:    true,
+               },
+               {
+                       name:             "PATH_RESPONSE_FRAME",
+                       frameType:        FrameTypePathResponse,
+                       frame:            &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}},
+                       allowedInitial:   false,
+                       allowedHandshake: false,
+                       allowedZeroRTT:   false,
+                       allowedOneRTT:    true,
+               },
+               {
+                       name:             "RETIRE_CONNECTION_ID_FRAME",
+                       frameType:        FrameTypeRetireConnectionID,
+                       frame:            &RetireConnectionIDFrame{SequenceNumber: 1},
+                       allowedInitial:   false,
+                       allowedHandshake: false,
+                       allowedZeroRTT:   false,
+                       allowedOneRTT:    true,
+               },
+               {
+                       name:             "MAX_DATA_FRAME",
+                       frameType:        FrameTypeMaxData,
+                       frame:            &MaxDataFrame{MaximumData: 1},
+                       allowedInitial:   false,
+                       allowedHandshake: false,
+                       allowedZeroRTT:   true,
+                       allowedOneRTT:    true,
+               },
+               {
+                       name:             "STREAM_FRAME",
+                       frameType:        FrameType(0x8),
+                       frame:            &StreamFrame{StreamID: 1, Data: []byte("foobar")},
+                       allowedInitial:   false,
+                       allowedHandshake: false,
+                       allowedZeroRTT:   true,
+                       allowedOneRTT:    true,
+               },
+       } {
+               for _, encLevel := range []protocol.EncryptionLevel{
+                       protocol.EncryptionInitial,
+                       protocol.EncryptionHandshake,
+                       protocol.Encryption0RTT,
+                       protocol.Encryption1RTT,
+               } {
+                       t.Run(fmt.Sprintf("%s/%v", tc.name, encLevel), func(t *testing.T) {
+                               var allowed bool
+                               switch encLevel {
+                               case protocol.EncryptionInitial:
+                                       allowed = tc.allowedInitial
+                               case protocol.EncryptionHandshake:
+                                       allowed = tc.allowedHandshake
+                               case protocol.Encryption0RTT:
+                                       allowed = tc.allowedZeroRTT
+                               case protocol.Encryption1RTT:
+                                       allowed = tc.allowedOneRTT
+                               }
+
+                               parser := NewFrameParser(true, true)
+                               b, err := tc.frame.Append(nil, protocol.Version1)
+                               require.NoError(t, err)
+                               frameType, _, err := parser.ParseType(b, encLevel)
+                               if allowed {
+                                       require.NoError(t, err)
+                                       require.Equal(t, tc.frameType, frameType)
+                               } else {
+                                       require.Error(t, err)
+                                       var transportErr *qerr.TransportError
+                                       require.ErrorAs(t, err, &transportErr)
+                                       require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode)
+                               }
+                       })
+               }
+       }
+}
+
+func TestFrameParserDatagramFrame(t *testing.T) {
+       parser := NewFrameParser(true, true)
+       f := &DatagramFrame{
+               Data: []byte("foobar"),
+       }
+       b, err := f.Append(nil, protocol.Version1)
+       require.NoError(t, err)
+       frameType, l, err := parser.ParseType(b, protocol.Encryption1RTT)
+       require.NoError(t, err)
+       require.Equal(t, FrameTypeDatagramNoLength, frameType)
+       require.Equal(t, 1, l)
+
+       // ParseLessCommonFrame should not be used to handle DATAGRAM frames
+       _, _, err = parser.ParseLessCommonFrame(frameType, b[l:], protocol.Version1)
+       require.Error(t, err)
+
+       // parseDatagramFrame should be used for this type
+       datagramFrame, l, err := parser.ParseDatagramFrame(frameType, b[l:], protocol.Version1)
+       require.NoError(t, err)
+       require.IsType(t, &DatagramFrame{}, datagramFrame)
+       require.Equal(t, 6, l)
+       require.Equal(t, f.Data, datagramFrame.Data)
 }
 
 func TestFrameParserDatagramUnsupported(t *testing.T) {
@@ -226,7 +500,8 @@ func TestFrameParserDatagramUnsupported(t *testing.T) {
        f := &DatagramFrame{Data: []byte("foobar")}
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       _, _, err = parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1)
+
+       _, _, err = parser.ParseType(b, protocol.Encryption1RTT)
        checkFrameUnsupported(t, err, 0x30)
 }
 
@@ -235,14 +510,22 @@ func TestFrameParserResetStreamAtUnsupported(t *testing.T) {
        f := &ResetStreamFrame{StreamID: 0x1337, ReliableSize: 0x42, FinalSize: 0xdeadbeef}
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       _, _, err = parser.ParseNext(b, protocol.Encryption1RTT, protocol.Version1)
+
+       _, _, err = parser.ParseType(b, protocol.Encryption1RTT)
        checkFrameUnsupported(t, err, 0x24)
 }
 
 func TestFrameParserInvalidFrameType(t *testing.T) {
        parser := NewFrameParser(true, true)
-       _, _, err := parser.ParseNext(encodeVarInt(0x42), protocol.Encryption1RTT, protocol.Version1)
-       checkFrameUnsupported(t, err, 0x42)
+
+       _, l, err := parser.ParseType(encodeVarInt(0x42), protocol.Encryption1RTT)
+
+       require.Equal(t, 2, l)
+
+       require.Error(t, err)
+       var transportErr *qerr.TransportError
+       require.ErrorAs(t, err, &transportErr)
+       require.Equal(t, qerr.FrameEncodingError, transportErr.ErrorCode)
 }
 
 func TestFrameParsingErrorsOnInvalidFrames(t *testing.T) {
@@ -253,7 +536,13 @@ func TestFrameParsingErrorsOnInvalidFrames(t *testing.T) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       _, _, err = parser.ParseNext(b[:len(b)-2], protocol.Encryption1RTT, protocol.Version1)
+
+       frameType, l, err := parser.ParseType(b[:len(b)-2], protocol.Encryption1RTT)
+       require.NoError(t, err)
+       require.Equal(t, FrameTypeMaxStreamData, frameType)
+       require.Equal(t, 1, l)
+
+       _, _, err = parser.ParseLessCommonFrame(frameType, b[1:len(b)-2], protocol.Version1)
        require.Error(t, err)
        var transportErr *qerr.TransportError
        require.ErrorAs(t, err, &transportErr)
@@ -274,102 +563,160 @@ func writeFrames(tb testing.TB, frames ...Frame) []byte {
 // We can therefore not use the require framework, as it allocates.
 func parseFrames(tb testing.TB, parser *FrameParser, data []byte, frames ...Frame) {
        for _, expectedFrame := range frames {
-               l, frame, err := parser.ParseNext(data, protocol.Encryption1RTT, protocol.Version1)
+               frameType, l, err := parser.ParseType(data, protocol.Encryption1RTT)
                if err != nil {
                        tb.Fatal(err)
                }
                data = data[l:]
-               if frame == nil {
-                       break
-               }
 
-               // Use type switch approach (like master branch)
-               switch f := frame.(type) {
-               case *StreamFrame:
+               if frameType.IsStreamFrameType() {
                        sf := expectedFrame.(*StreamFrame)
-                       if sf.StreamID != f.StreamID || sf.Offset != f.Offset || !bytes.Equal(sf.Data, f.Data) {
-                               tb.Fatalf("STREAM frame does not match: %v vs %v", sf, f)
+                       frame, l, err := ParseStreamFrame(data, frameType, protocol.Version1)
+                       if err != nil {
+                               tb.Fatal(err)
+                       }
+                       if sf.StreamID != frame.StreamID || sf.Offset != frame.Offset {
+                               tb.Fatalf("STREAM frame does not match: %v vs %v", sf, frame)
                        }
-                       f.PutBack()
-               case *AckFrame:
+                       frame.PutBack()
+                       data = data[l:]
+                       continue
+               }
+
+               if frameType.IsAckFrameType() {
                        af, ok := expectedFrame.(*AckFrame)
                        if !ok {
                                tb.Fatalf("expected ACK, but got %v", expectedFrame)
                        }
+
+                       f, l, err := parser.ParseAckFrame(frameType, data, protocol.Encryption1RTT, protocol.Version1)
                        if f.DelayTime != af.DelayTime || f.ECNCE != af.ECNCE || f.ECT0 != af.ECT0 || f.ECT1 != af.ECT1 {
+                               tb.Fatal(err)
+                       }
+                       if f.DelayTime != af.DelayTime {
                                tb.Fatalf("ACK frame does not match: %v vs %v", af, f)
                        }
                        if !slices.Equal(f.AckRanges, af.AckRanges) {
                                tb.Fatalf("ACK frame ACK ranges don't match: %v vs %v", af, f)
                        }
-               case *DatagramFrame:
+                       data = data[l:]
+                       continue
+               }
+
+               if frameType.IsDatagramFrameType() {
                        df, ok := expectedFrame.(*DatagramFrame)
                        if !ok {
                                tb.Fatalf("expected DATAGRAM, but got %v", expectedFrame)
                        }
+
+                       f, l, err := parser.ParseDatagramFrame(frameType, data, protocol.Version1)
+                       if err != nil {
+                               tb.Fatal(err)
+                       }
                        if df.DataLenPresent != f.DataLenPresent || !bytes.Equal(df.Data, f.Data) {
                                tb.Fatalf("DATAGRAM frame does not match: %v vs %v", df, f)
                        }
-               case *MaxDataFrame:
+                       data = data[l:]
+                       continue
+               }
+
+               f, l, err := parser.ParseLessCommonFrame(frameType, data, protocol.Version1)
+               if err != nil {
+                       tb.Fatal(err)
+               }
+               data = data[l:]
+
+               switch frameType {
+               case FrameTypeMaxData:
                        mdf, ok := expectedFrame.(*MaxDataFrame)
                        if !ok {
                                tb.Fatalf("expected MAX_DATA, but got %v", expectedFrame)
                        }
-                       if *f != *mdf {
+                       if *f.(*MaxDataFrame) != *mdf {
                                tb.Fatalf("MAX_DATA frame does not match: %v vs %v", f, mdf)
                        }
-               case *MaxStreamsFrame:
+               case FrameTypeUniMaxStreams:
                        msf, ok := expectedFrame.(*MaxStreamsFrame)
                        if !ok {
                                tb.Fatalf("expected MAX_STREAMS, but got %v", expectedFrame)
                        }
-                       if *f != *msf {
+                       if *f.(*MaxStreamsFrame) != *msf {
                                tb.Fatalf("MAX_STREAMS frame does not match: %v vs %v", f, msf)
                        }
-               case *MaxStreamDataFrame:
+               case FrameTypeMaxStreamData:
                        mdf, ok := expectedFrame.(*MaxStreamDataFrame)
                        if !ok {
                                tb.Fatalf("expected MAX_STREAM_DATA, but got %v", expectedFrame)
                        }
-                       if *f != *mdf {
+                       if *f.(*MaxStreamDataFrame) != *mdf {
                                tb.Fatalf("MAX_STREAM_DATA frame does not match: %v vs %v", f, mdf)
                        }
-               case *CryptoFrame:
+               case FrameTypeCrypto:
                        cf, ok := expectedFrame.(*CryptoFrame)
                        if !ok {
                                tb.Fatalf("expected CRYPTO, but got %v", expectedFrame)
                        }
-                       if f.Offset != cf.Offset || !bytes.Equal(f.Data, cf.Data) {
+                       frame := f.(*CryptoFrame)
+                       if frame.Offset != cf.Offset || !bytes.Equal(frame.Data, cf.Data) {
                                tb.Fatalf("CRYPTO frame does not match: %v vs %v", f, cf)
                        }
-               case *PingFrame:
-                       _ = f
-               case *ResetStreamFrame:
+               case FrameTypePing:
+                       _ = f.(*PingFrame)
+               case FrameTypeResetStream:
                        rsf, ok := expectedFrame.(*ResetStreamFrame)
                        if !ok {
                                tb.Fatalf("expected RESET_STREAM, but got %v", expectedFrame)
                        }
-                       if *f != *rsf {
+                       if *f.(*ResetStreamFrame) != *rsf {
                                tb.Fatalf("RESET_STREAM frame does not match: %v vs %v", f, rsf)
                        }
+                       continue
                default:
-                       tb.Fatalf("Frame type not supported in benchmark: %T", f)
+                       tb.Fatalf("Frame type not supported in benchmark or should not occur: %v", frameType)
                }
        }
 }
 
-func benchmarkFrames(b *testing.B, frames ...Frame) {
-       buf := writeFrames(b, frames...)
+func TestFrameParserAllocs(t *testing.T) {
+       t.Run("STREAM", func(t *testing.T) {
+               var frames []Frame
+               for i := range 10 {
+                       frames = append(frames, &StreamFrame{
+                               StreamID:       protocol.StreamID(1337 + i),
+                               Offset:         protocol.ByteCount(1e7 + i),
+                               Data:           make([]byte, 200+i),
+                               DataLenPresent: true,
+                       })
+               }
+               require.Zero(t, testFrameParserAllocs(t, frames))
+       })
 
+       t.Run("ACK", func(t *testing.T) {
+               var frames []Frame
+               for i := range 10 {
+                       frames = append(frames, &AckFrame{
+                               AckRanges: []AckRange{
+                                       {Smallest: protocol.PacketNumber(5000 + i), Largest: protocol.PacketNumber(5200 + i)},
+                                       {Smallest: protocol.PacketNumber(1 + i), Largest: protocol.PacketNumber(4200 + i)},
+                               },
+                               DelayTime: time.Duration(int64(time.Millisecond) * int64(i)),
+                               ECT0:      uint64(5000 + i),
+                               ECT1:      uint64(i),
+                               ECNCE:     uint64(10 + i),
+                       })
+               }
+               require.Zero(t, testFrameParserAllocs(t, frames))
+       })
+}
+
+func testFrameParserAllocs(t *testing.T, frames []Frame) float64 {
+       buf := writeFrames(t, frames...)
        parser := NewFrameParser(true, true)
        parser.SetAckDelayExponent(3)
 
-       b.ResetTimer()
-       b.ReportAllocs()
-
-       for range b.N {
-               parseFrames(b, parser, buf, frames...)
-       }
+       return testing.AllocsPerRun(100, func() {
+               parseFrames(t, parser, buf, frames...)
+       })
 }
 
 func BenchmarkParseOtherFrames(b *testing.B) {
@@ -428,3 +775,17 @@ func BenchmarkParseDatagramFrame(b *testing.B) {
        }
        benchmarkFrames(b, frames...)
 }
+
+func benchmarkFrames(b *testing.B, frames ...Frame) {
+       buf := writeFrames(b, frames...)
+
+       parser := NewFrameParser(true, true)
+       parser.SetAckDelayExponent(3)
+
+       b.ResetTimer()
+       b.ReportAllocs()
+
+       for range b.N {
+               parseFrames(b, parser, buf, frames...)
+       }
+}
index 3012b558d7bc9ba333515bc6983fd89cf379b12f..5dbfba8963b9a880e9ddf2939a1a56822c5997a7 100644 (file)
@@ -27,3 +27,16 @@ func TestProbingFrames(t *testing.T) {
                require.Equal(t, expected, IsProbingFrame(f))
        }
 }
+
+func TestIsProbingFrameType(t *testing.T) {
+       tests := map[FrameType]bool{
+               FrameTypePathChallenge:   true,
+               FrameTypePathResponse:    true,
+               FrameTypeNewConnectionID: true,
+               FrameType(0x01):          false,
+               FrameType(0xFF):          false,
+       }
+       for ft, expected := range tests {
+               require.Equal(t, expected, IsProbingFrameType(ft))
+       }
+}
diff --git a/internal/wire/frame_type.go b/internal/wire/frame_type.go
new file mode 100644 (file)
index 0000000..0576657
--- /dev/null
@@ -0,0 +1,77 @@
+package wire
+
+import "github.com/quic-go/quic-go/internal/protocol"
+
+type FrameType uint64
+
+// These constants correspond to those defined in RFC 9000.
+// Stream frame types are not listed explicitly here; use FrameType.IsStreamFrameType() to identify them.
+const (
+       FrameTypePing        FrameType = 0x1
+       FrameTypeAck         FrameType = 0x2
+       FrameTypeAckECN      FrameType = 0x3
+       FrameTypeResetStream FrameType = 0x4
+       FrameTypeStopSending FrameType = 0x5
+       FrameTypeCrypto      FrameType = 0x6
+       FrameTypeNewToken    FrameType = 0x7
+
+       FrameTypeMaxData            FrameType = 0x10
+       FrameTypeMaxStreamData      FrameType = 0x11
+       FrameTypeBidiMaxStreams     FrameType = 0x12
+       FrameTypeUniMaxStreams      FrameType = 0x13
+       FrameTypeDataBlocked        FrameType = 0x14
+       FrameTypeStreamDataBlocked  FrameType = 0x15
+       FrameTypeBidiStreamBlocked  FrameType = 0x16
+       FrameTypeUniStreamBlocked   FrameType = 0x17
+       FrameTypeNewConnectionID    FrameType = 0x18
+       FrameTypeRetireConnectionID FrameType = 0x19
+       FrameTypePathChallenge      FrameType = 0x1a
+       FrameTypePathResponse       FrameType = 0x1b
+       FrameTypeConnectionClose    FrameType = 0x1c
+       FrameTypeApplicationClose   FrameType = 0x1d
+       FrameTypeHandshakeDone      FrameType = 0x1e
+       FrameTypeResetStreamAt      FrameType = 0x24 // https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/06/
+
+       FrameTypeDatagramNoLength   FrameType = 0x30
+       FrameTypeDatagramWithLength FrameType = 0x31
+)
+
+func (t FrameType) IsStreamFrameType() bool {
+       return t >= 0x8 && t <= 0xf
+}
+
+func (t FrameType) isValidRFC9000() bool {
+       return t <= 0x1e
+}
+
+func (t FrameType) IsAckFrameType() bool {
+       return t == FrameTypeAck || t == FrameTypeAckECN
+}
+
+func (t FrameType) IsDatagramFrameType() bool {
+       return t == FrameTypeDatagramNoLength || t == FrameTypeDatagramWithLength
+}
+
+func (t FrameType) isAllowedAtEncLevel(encLevel protocol.EncryptionLevel) bool {
+       //nolint:exhaustive
+       switch encLevel {
+       case protocol.EncryptionInitial, protocol.EncryptionHandshake:
+               switch t {
+               case FrameTypeCrypto, FrameTypeAck, FrameTypeAckECN, FrameTypeConnectionClose, FrameTypePing:
+                       return true
+               default:
+                       return false
+               }
+       case protocol.Encryption0RTT:
+               switch t {
+               case FrameTypeCrypto, FrameTypeAck, FrameTypeAckECN, FrameTypeConnectionClose, FrameTypeNewToken, FrameTypePathResponse, FrameTypeRetireConnectionID:
+                       return false
+               default:
+                       return true
+               }
+       case protocol.Encryption1RTT:
+               return true
+       default:
+               panic("unknown encryption level")
+       }
+}
diff --git a/internal/wire/frame_type_test.go b/internal/wire/frame_type_test.go
new file mode 100644 (file)
index 0000000..2702af3
--- /dev/null
@@ -0,0 +1,29 @@
+package wire
+
+import (
+       "testing"
+
+       "github.com/stretchr/testify/require"
+)
+
+func TestIsStreamFrameType(t *testing.T) {
+       for i := 0x08; i <= 0x0f; i++ {
+               require.Truef(t, FrameType(i).IsStreamFrameType(), "FrameType(0x%x).IsStreamFrameType() = false, want true", i)
+       }
+
+       require.False(t, FrameType(0x1).IsStreamFrameType())
+}
+
+func TestIsAckFrameType(t *testing.T) {
+       require.True(t, FrameTypeAck.IsAckFrameType(), "AckFrameType should be recognized as ACK")
+       require.True(t, FrameTypeAckECN.IsAckFrameType(), "AckECNFrameType should be recognized as ACK")
+       require.False(t, FrameTypePing.IsAckFrameType(), "PingFrameType should not be recognized as ACK")
+       require.False(t, FrameType(0x10).IsAckFrameType(), "MaxDataFrameType should not be recognized as ACK")
+}
+
+func TestIsDatagramFrameType(t *testing.T) {
+       require.True(t, FrameTypeDatagramNoLength.IsDatagramFrameType(), "DatagramNoLengthFrameType should be recognized as DATAGRAM")
+       require.True(t, FrameTypeDatagramWithLength.IsDatagramFrameType(), "DatagramWithLengthFrameType should be recognized as DATAGRAM")
+       require.False(t, FrameTypePing.IsDatagramFrameType(), "PingFrameType should not be recognized as DATAGRAM")
+       require.False(t, FrameType(0x1e).IsDatagramFrameType(), "HandshakeDoneFrameType should not be recognized as DATAGRAM")
+}
index 85dd64745596088bfd1882ceb56c7c3fe303c436..bf95f525b8cf21f840e8999c79c090b52c47cf71 100644 (file)
@@ -8,7 +8,7 @@ import (
 type HandshakeDoneFrame struct{}
 
 func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       return append(b, handshakeDoneFrameType), nil
+       return append(b, byte(FrameTypeHandshakeDone)), nil
 }
 
 // Length of a written frame
index 51381df455dad009f80a806beff6270779e70164..bec44ec97338da8972ae85d5f2bdfde1369922f6 100644 (file)
@@ -11,6 +11,6 @@ func TestWriteHandshakeDoneSampleFrame(t *testing.T) {
        frame := HandshakeDoneFrame{}
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       require.Equal(t, []byte{handshakeDoneFrameType}, b)
+       require.Equal(t, []byte{byte(FrameTypeHandshakeDone)}, b)
        require.Equal(t, protocol.ByteCount(1), frame.Length(protocol.Version1))
 }
index 5819c0273933064534c0cb8be03c3f43352ec56f..bfbdcba6666c546a554047a7274204dd4ced52ea 100644 (file)
@@ -22,7 +22,7 @@ func parseMaxDataFrame(b []byte, _ protocol.Version) (*MaxDataFrame, int, error)
 }
 
 func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       b = append(b, maxDataFrameType)
+       b = append(b, byte(FrameTypeMaxData))
        b = quicvarint.Append(b, uint64(f.MaximumData))
        return b, nil
 }
index 2c8060894d36a2b2dd11811222493c1d8573c9c7..a6cb3d255679f1462122ce80c6c48e87089603ec 100644 (file)
@@ -32,7 +32,7 @@ func TestWriteMaxDataFrame(t *testing.T) {
        f := &MaxDataFrame{MaximumData: 0xdeadbeefcafe}
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{maxDataFrameType}
+       expected := []byte{byte(FrameTypeMaxData)}
        expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
        require.Equal(t, expected, b)
        require.Len(t, b, int(f.Length(protocol.Version1)))
index db9091af8e11280de80139c80831fd546dd7e18c..0966ea46954e275c01cedaacfa5f404bfd85daf0 100644 (file)
@@ -31,7 +31,7 @@ func parseMaxStreamDataFrame(b []byte, _ protocol.Version) (*MaxStreamDataFrame,
 }
 
 func (f *MaxStreamDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       b = append(b, maxStreamDataFrameType)
+       b = append(b, byte(FrameTypeMaxStreamData))
        b = quicvarint.Append(b, uint64(f.StreamID))
        b = quicvarint.Append(b, uint64(f.MaximumStreamData))
        return b, nil
index 05559ebfa1743ca63b8c59ba1393c05666836ddf..f4757da5c019faab40cdf15897630ef7b3869a01 100644 (file)
@@ -36,7 +36,7 @@ func TestWriteMaxStreamDataFrame(t *testing.T) {
                StreamID:          0xdecafbad,
                MaximumStreamData: 0xdeadbeefcafe42,
        }
-       expected := []byte{maxStreamDataFrameType}
+       expected := []byte{byte(FrameTypeMaxStreamData)}
        expected = append(expected, encodeVarInt(0xdecafbad)...)
        expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...)
        b, err := f.Append(nil, protocol.Version1)
index a8745bd124d63d6939fe7667579dc2bc2cb69de5..30612e23bc49f634feaa99eb71cff289436e6551 100644 (file)
@@ -13,12 +13,13 @@ type MaxStreamsFrame struct {
        MaxStreamNum protocol.StreamNum
 }
 
-func parseMaxStreamsFrame(b []byte, typ uint64, _ protocol.Version) (*MaxStreamsFrame, int, error) {
+func parseMaxStreamsFrame(b []byte, typ FrameType, _ protocol.Version) (*MaxStreamsFrame, int, error) {
        f := &MaxStreamsFrame{}
+       //nolint:exhaustive // Function will only be called with BidiMaxStreamsFrameType or UniMaxStreamsFrameType
        switch typ {
-       case bidiMaxStreamsFrameType:
+       case FrameTypeBidiMaxStreams:
                f.Type = protocol.StreamTypeBidi
-       case uniMaxStreamsFrameType:
+       case FrameTypeUniMaxStreams:
                f.Type = protocol.StreamTypeUni
        }
        streamID, l, err := quicvarint.Parse(b)
@@ -35,9 +36,9 @@ func parseMaxStreamsFrame(b []byte, typ uint64, _ protocol.Version) (*MaxStreams
 func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
        switch f.Type {
        case protocol.StreamTypeBidi:
-               b = append(b, bidiMaxStreamsFrameType)
+               b = append(b, byte(FrameTypeBidiMaxStreams))
        case protocol.StreamTypeUni:
-               b = append(b, uniMaxStreamsFrameType)
+               b = append(b, byte(FrameTypeUniMaxStreams))
        }
        b = quicvarint.Append(b, uint64(f.MaxStreamNum))
        return b, nil
index f5be03e1e22730d4807bfde46dbcbfbd8ae1f48d..7afbce2ff8bf6a6f82fe5348d2d54c4beb251799 100644 (file)
@@ -13,7 +13,7 @@ import (
 
 func TestParseMaxStreamsFrameBidirectional(t *testing.T) {
        data := encodeVarInt(0xdecaf)
-       f, l, err := parseMaxStreamsFrame(data, bidiMaxStreamsFrameType, protocol.Version1)
+       f, l, err := parseMaxStreamsFrame(data, FrameTypeBidiMaxStreams, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamTypeBidi, f.Type)
        require.EqualValues(t, 0xdecaf, f.MaxStreamNum)
@@ -22,7 +22,7 @@ func TestParseMaxStreamsFrameBidirectional(t *testing.T) {
 
 func TestParseMaxStreamsFrameUnidirectional(t *testing.T) {
        data := encodeVarInt(0xdecaf)
-       f, l, err := parseMaxStreamsFrame(data, uniMaxStreamsFrameType, protocol.Version1)
+       f, l, err := parseMaxStreamsFrame(data, FrameTypeUniMaxStreams, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamTypeUni, f.Type)
        require.EqualValues(t, 0xdecaf, f.MaxStreamNum)
@@ -59,7 +59,7 @@ func TestParseMaxStreamsMaxValue(t *testing.T) {
                        typ, l, err := quicvarint.Parse(b)
                        require.NoError(t, err)
                        b = b[l:]
-                       frame, _, err := parseMaxStreamsFrame(b, typ, protocol.Version1)
+                       frame, _, err := parseMaxStreamsFrame(b, FrameType(typ), protocol.Version1)
                        require.NoError(t, err)
                        require.Equal(t, f, frame)
                })
@@ -84,7 +84,7 @@ func TestParseMaxStreamsErrorsOnTooLargeStreamCount(t *testing.T) {
                        typ, l, err := quicvarint.Parse(b)
                        require.NoError(t, err)
                        b = b[l:]
-                       _, _, err = parseMaxStreamsFrame(b, typ, protocol.Version1)
+                       _, _, err = parseMaxStreamsFrame(b, FrameType(typ), protocol.Version1)
                        require.EqualError(t, err, fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))
                })
        }
@@ -97,7 +97,7 @@ func TestWriteMaxStreamsBidirectional(t *testing.T) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{bidiMaxStreamsFrameType}
+       expected := []byte{byte(FrameTypeBidiMaxStreams)}
        expected = append(expected, encodeVarInt(0xdeadbeef)...)
        require.Equal(t, expected, b)
        require.Len(t, b, int(f.Length(protocol.Version1)))
@@ -110,7 +110,7 @@ func TestWriteMaxStreamsUnidirectional(t *testing.T) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{uniMaxStreamsFrameType}
+       expected := []byte{byte(FrameTypeUniMaxStreams)}
        expected = append(expected, encodeVarInt(0xdecafbad)...)
        require.Equal(t, expected, b)
        require.Len(t, b, int(f.Length(protocol.Version1)))
index 6f2287f44b770650d0e0106d0821db7991aa8063..058319266f5a12d86bdf8e743466b57c44ecbad8 100644 (file)
@@ -61,7 +61,7 @@ func parseNewConnectionIDFrame(b []byte, _ protocol.Version) (*NewConnectionIDFr
 }
 
 func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       b = append(b, newConnectionIDFrameType)
+       b = append(b, byte(FrameTypeNewConnectionID))
        b = quicvarint.Append(b, f.SequenceNumber)
        b = quicvarint.Append(b, f.RetirePriorTo)
        connIDLen := f.ConnectionID.Len()
index 739cf2cfda920365160f40b5d50f46036dba1c85..3292e0f3c91efa54f73590115c7777773368715b 100644 (file)
@@ -77,7 +77,7 @@ func TestWriteNewConnectionIDFrame(t *testing.T) {
        }
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{newConnectionIDFrameType}
+       expected := []byte{byte(FrameTypeNewConnectionID)}
        expected = append(expected, encodeVarInt(0x1337)...)
        expected = append(expected, encodeVarInt(0x42)...)
        expected = append(expected, 6)
index f1d4d00fe6613761e5e96f77ad9f4a45628b23d3..73d356b1ad181d13bbb8ced56fca28967c5d8d43 100644 (file)
@@ -31,7 +31,7 @@ func parseNewTokenFrame(b []byte, _ protocol.Version) (*NewTokenFrame, int, erro
 }
 
 func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       b = append(b, newTokenFrameType)
+       b = append(b, byte(FrameTypeNewToken))
        b = quicvarint.Append(b, uint64(len(f.Token)))
        b = append(b, f.Token...)
        return b, nil
index 77da62e8f8b70cf2d522ae5e710f425d5b42e285..cd2ae3dafce7b60720cdd2a37c2595a8d16f7ae0 100644 (file)
@@ -43,7 +43,7 @@ func TestWriteNewTokenFrame(t *testing.T) {
        f := &NewTokenFrame{Token: []byte(token)}
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{newTokenFrameType}
+       expected := []byte{byte(FrameTypeNewToken)}
        expected = append(expected, encodeVarInt(uint64(len(token)))...)
        expected = append(expected, token...)
        require.Equal(t, expected, b)
index 2aca989fa6bd34e360a62fc0454f306974a57244..7a4a767e5180792cf3001c2686f1436720b30c9a 100644 (file)
@@ -21,7 +21,7 @@ func parsePathChallengeFrame(b []byte, _ protocol.Version) (*PathChallengeFrame,
 }
 
 func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       b = append(b, pathChallengeFrameType)
+       b = append(b, byte(FrameTypePathChallenge))
        b = append(b, f.Data[:]...)
        return b, nil
 }
index 3a755e89177350015ba77cb593f1e17b6f7fdb08..f5e521aa59170b1634c50d5d8d1fe505a3613213 100644 (file)
@@ -32,6 +32,6 @@ func TestWritePathChallenge(t *testing.T) {
        frame := PathChallengeFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}}
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       require.Equal(t, []byte{pathChallengeFrameType, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b)
+       require.Equal(t, []byte{byte(FrameTypePathChallenge), 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b)
        require.Len(t, b, int(frame.Length(protocol.Version1)))
 }
index 76532c8527b8ca4c6c5f2700d09ff632adf4453c..e76d037b15100a95538a5ba393c06295a129fec1 100644 (file)
@@ -21,7 +21,7 @@ func parsePathResponseFrame(b []byte, _ protocol.Version) (*PathResponseFrame, i
 }
 
 func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       b = append(b, pathResponseFrameType)
+       b = append(b, byte(FrameTypePathResponse))
        b = append(b, f.Data[:]...)
        return b, nil
 }
index d939b0a1c4223e22139971a4885ce68dfb7a8391..884c407df766b158954c163ccc8accc2302ffd63 100644 (file)
@@ -32,6 +32,6 @@ func TestWritePathResponse(t *testing.T) {
        frame := PathResponseFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}}
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       require.Equal(t, []byte{pathResponseFrameType, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b)
+       require.Equal(t, []byte{byte(FrameTypePathResponse), 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, b)
        require.Len(t, b, int(frame.Length(protocol.Version1)))
 }
index 71f8d16c38fd5aa64244b9436b0846481829b521..5d344d447f8c6b330d9cbd38ab09054ddc0b47f5 100644 (file)
@@ -8,7 +8,7 @@ import (
 type PingFrame struct{}
 
 func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       return append(b, pingFrameType), nil
+       return append(b, byte(FrameTypePing)), nil
 }
 
 // Length of a written frame
index cb678bf453d49d53dab6d77583025483126698ca..4101b76b26fb3442575b84c654f0eeddd8b931a0 100644 (file)
@@ -56,9 +56,9 @@ func parseResetStreamFrame(b []byte, isResetStreamAt bool, _ protocol.Version) (
 
 func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
        if f.ReliableSize == 0 {
-               b = quicvarint.Append(b, resetStreamFrameType)
+               b = quicvarint.Append(b, uint64(FrameTypeResetStream))
        } else {
-               b = quicvarint.Append(b, resetStreamAtFrameType)
+               b = quicvarint.Append(b, uint64(FrameTypeResetStreamAt))
        }
        b = quicvarint.Append(b, uint64(f.StreamID))
        b = quicvarint.Append(b, uint64(f.ErrorCode))
index 3f85a985394abb8fb33e1cd0759a899ad38c9132..d85d616ba33506ec56830e086a35d8fc2e471f8b 100644 (file)
@@ -77,7 +77,7 @@ func TestWriteResetStream(t *testing.T) {
        }
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{resetStreamFrameType}
+       expected := []byte{byte(FrameTypeResetStream)}
        expected = append(expected, encodeVarInt(0x1337)...)
        expected = append(expected, encodeVarInt(0xcafe)...)
        expected = append(expected, encodeVarInt(0x11223344decafbad)...)
@@ -94,7 +94,7 @@ func TestWriteResetStreamAt(t *testing.T) {
        }
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{resetStreamAtFrameType}
+       expected := []byte{byte(FrameTypeResetStreamAt)}
        expected = append(expected, encodeVarInt(1337)...)
        expected = append(expected, encodeVarInt(0xcafe)...)
        expected = append(expected, encodeVarInt(42)...)
index 27aeff8428bb9fb1c9572265a3037dda04476129..1927f9dc07b4818683540ccc135e07ce9ff62f0d 100644 (file)
@@ -19,7 +19,7 @@ func parseRetireConnectionIDFrame(b []byte, _ protocol.Version) (*RetireConnecti
 }
 
 func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       b = append(b, retireConnectionIDFrameType)
+       b = append(b, byte(FrameTypeRetireConnectionID))
        b = quicvarint.Append(b, f.SequenceNumber)
        return b, nil
 }
index 9c76151c9619d46203fc1ea624bc2d967d0277e2..e1b64ccfdf424e63005ab28bcd3a391d1f8fcca1 100644 (file)
@@ -32,7 +32,7 @@ func TestWriteRetireConnectionID(t *testing.T) {
        frame := &RetireConnectionIDFrame{SequenceNumber: 0x1337}
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{retireConnectionIDFrameType}
+       expected := []byte{byte(FrameTypeRetireConnectionID)}
        expected = append(expected, encodeVarInt(0x1337)...)
        require.Equal(t, expected, b)
        require.Len(t, b, int(frame.Length(protocol.Version1)))
index a2326f8ec429314d6a6ccbe496b3688d4329aa48..2b15c7109f6a2977853b73013dfc8b76b7af021f 100644 (file)
@@ -38,7 +38,7 @@ func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount {
 }
 
 func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
-       b = append(b, stopSendingFrameType)
+       b = append(b, byte(FrameTypeStopSending))
        b = quicvarint.Append(b, uint64(f.StreamID))
        b = quicvarint.Append(b, uint64(f.ErrorCode))
        return b, nil
index b670c047a7a7a3692f28b967aa486d04a77cc28b..90bc8a32e6683061c9cbad93239b8b85f46b82c2 100644 (file)
@@ -39,7 +39,7 @@ func TestWriteStopSendingFrame(t *testing.T) {
        }
        b, err := frame.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{stopSendingFrameType}
+       expected := []byte{byte(FrameTypeStopSending)}
        expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
        expected = append(expected, encodeVarInt(0xdecafbad)...)
        require.Equal(t, expected, b)
index b73e45815c7d09bdaf515bced8a49a8b22c0b17d..325cb58fd7f8763d7eb226daca9aa13d7cbb1c74 100644 (file)
@@ -38,7 +38,7 @@ func TestWriteStreamDataBlocked(t *testing.T) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{streamDataBlockedFrameType}
+       expected := []byte{byte(FrameTypeStreamDataBlocked)}
        expected = append(expected, encodeVarInt(uint64(f.StreamID))...)
        expected = append(expected, encodeVarInt(uint64(f.MaximumStreamData))...)
        require.Equal(t, expected, b)
index cdc32722fbc38150a53ef236210201ad94a1f56d..e53962b193c6fd4741274c373ae7868402891eec 100644 (file)
@@ -19,7 +19,7 @@ type StreamFrame struct {
        fromPool bool
 }
 
-func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, int, error) {
+func ParseStreamFrame(b []byte, typ FrameType, _ protocol.Version) (*StreamFrame, int, error) {
        startLen := len(b)
        hasOffset := typ&0b100 > 0
        fin := typ&0b1 > 0
index b61775712534be5cbecebfbae4ed66d426da0ba1..3c658100135372c66bec6906421940cc98f5a057 100644 (file)
@@ -14,7 +14,7 @@ func TestParseStreamFrameWithOffBit(t *testing.T) {
        data := encodeVarInt(0x12345)                    // stream ID
        data = append(data, encodeVarInt(0xdecafbad)...) // offset
        data = append(data, []byte("foobar")...)
-       frame, l, err := parseStreamFrame(data, 0x8^0x4, protocol.Version1)
+       frame, l, err := ParseStreamFrame(data, 0x8^0x4, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamID(0x12345), frame.StreamID)
        require.Equal(t, []byte("foobar"), frame.Data)
@@ -27,7 +27,7 @@ func TestParseStreamFrameRespectsLEN(t *testing.T) {
        data := encodeVarInt(0x12345)           // stream ID
        data = append(data, encodeVarInt(4)...) // data length
        data = append(data, []byte("foobar")...)
-       frame, l, err := parseStreamFrame(data, 0x8^0x2, protocol.Version1)
+       frame, l, err := ParseStreamFrame(data, 0x8^0x2, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamID(0x12345), frame.StreamID)
        require.Equal(t, []byte("foob"), frame.Data)
@@ -39,7 +39,7 @@ func TestParseStreamFrameRespectsLEN(t *testing.T) {
 func TestParseStreamFrameWithFINBit(t *testing.T) {
        data := encodeVarInt(9) // stream ID
        data = append(data, []byte("foobar")...)
-       frame, l, err := parseStreamFrame(data, 0x8^0x1, protocol.Version1)
+       frame, l, err := ParseStreamFrame(data, 0x8^0x1, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamID(9), frame.StreamID)
        require.Equal(t, []byte("foobar"), frame.Data)
@@ -51,7 +51,7 @@ func TestParseStreamFrameWithFINBit(t *testing.T) {
 func TestParseStreamFrameAllowsEmpty(t *testing.T) {
        data := encodeVarInt(0x1337)                  // stream ID
        data = append(data, encodeVarInt(0x12345)...) // offset
-       f, l, err := parseStreamFrame(data, 0x8^0x4, protocol.Version1)
+       f, l, err := ParseStreamFrame(data, 0x8^0x4, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamID(0x1337), f.StreamID)
        require.Equal(t, protocol.ByteCount(0x12345), f.Offset)
@@ -64,7 +64,7 @@ func TestParseStreamFrameRejectsOverflow(t *testing.T) {
        data := encodeVarInt(0x12345)                                         // stream ID
        data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset
        data = append(data, []byte("foobar")...)
-       _, _, err := parseStreamFrame(data, 0x8^0x4, protocol.Version1)
+       _, _, err := ParseStreamFrame(data, 0x8^0x4, protocol.Version1)
        require.EqualError(t, err, "stream data overflows maximum offset")
 }
 
@@ -72,7 +72,7 @@ func TestParseStreamFrameRejectsLongFrames(t *testing.T) {
        data := encodeVarInt(0x12345)                                                // stream ID
        data = append(data, encodeVarInt(uint64(protocol.MaxPacketBufferSize)+1)...) // data length
        data = append(data, make([]byte, protocol.MaxPacketBufferSize+1)...)
-       _, _, err := parseStreamFrame(data, 0x8^0x2, protocol.Version1)
+       _, _, err := ParseStreamFrame(data, 0x8^0x2, protocol.Version1)
        require.Equal(t, io.EOF, err)
 }
 
@@ -80,7 +80,7 @@ func TestParseStreamFrameRejectsFramesExceedingRemainingSize(t *testing.T) {
        data := encodeVarInt(0x12345)           // stream ID
        data = append(data, encodeVarInt(7)...) // data length
        data = append(data, []byte("foobar")...)
-       _, _, err := parseStreamFrame(data, 0x8^0x2, protocol.Version1)
+       _, _, err := ParseStreamFrame(data, 0x8^0x2, protocol.Version1)
        require.Equal(t, io.EOF, err)
 }
 
@@ -90,10 +90,10 @@ func TestParseStreamFrameErrorsOnEOFs(t *testing.T) {
        data = append(data, encodeVarInt(0xdecafbad)...) // offset
        data = append(data, encodeVarInt(6)...)          // data length
        data = append(data, []byte("foobar")...)
-       _, _, err := parseStreamFrame(data, typ, protocol.Version1)
+       _, _, err := ParseStreamFrame(data, FrameType(typ), protocol.Version1)
        require.NoError(t, err)
        for i := range data {
-               _, _, err = parseStreamFrame(data[:i], typ, protocol.Version1)
+               _, _, err = ParseStreamFrame(data[:i], FrameType(typ), protocol.Version1)
                require.Error(t, err)
        }
 }
@@ -101,7 +101,7 @@ func TestParseStreamFrameErrorsOnEOFs(t *testing.T) {
 func TestParseStreamUsesBufferForLongFrames(t *testing.T) {
        data := encodeVarInt(0x12345) // stream ID
        data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...)
-       frame, l, err := parseStreamFrame(data, 0x8, protocol.Version1)
+       frame, l, err := ParseStreamFrame(data, 0x8, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamID(0x12345), frame.StreamID)
        require.Equal(t, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize), frame.Data)
@@ -115,7 +115,7 @@ func TestParseStreamUsesBufferForLongFrames(t *testing.T) {
 func TestParseStreamDoesNotUseBufferForShortFrames(t *testing.T) {
        data := encodeVarInt(0x12345) // stream ID
        data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...)
-       frame, l, err := parseStreamFrame(data, 0x8, protocol.Version1)
+       frame, l, err := ParseStreamFrame(data, 0x8, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamID(0x12345), frame.StreamID)
        require.Equal(t, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1), frame.Data)
index c946fec31bfa4ef415876bbd8465845e852d7d56..d98fde46c010f708552e3bee62932a0b5602ed18 100644 (file)
@@ -13,12 +13,13 @@ type StreamsBlockedFrame struct {
        StreamLimit protocol.StreamNum
 }
 
-func parseStreamsBlockedFrame(b []byte, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, int, error) {
+func parseStreamsBlockedFrame(b []byte, typ FrameType, _ protocol.Version) (*StreamsBlockedFrame, int, error) {
        f := &StreamsBlockedFrame{}
+       //nolint:exhaustive // This will only be called with a BidiStreamBlockedFrameType or a UniStreamBlockedFrameType.
        switch typ {
-       case bidiStreamBlockedFrameType:
+       case FrameTypeBidiStreamBlocked:
                f.Type = protocol.StreamTypeBidi
-       case uniStreamBlockedFrameType:
+       case FrameTypeUniStreamBlocked:
                f.Type = protocol.StreamTypeUni
        }
        streamLimit, l, err := quicvarint.Parse(b)
@@ -35,9 +36,9 @@ func parseStreamsBlockedFrame(b []byte, typ uint64, _ protocol.Version) (*Stream
 func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
        switch f.Type {
        case protocol.StreamTypeBidi:
-               b = append(b, bidiStreamBlockedFrameType)
+               b = append(b, byte(FrameTypeBidiStreamBlocked))
        case protocol.StreamTypeUni:
-               b = append(b, uniStreamBlockedFrameType)
+               b = append(b, byte(FrameTypeUniStreamBlocked))
        }
        b = quicvarint.Append(b, uint64(f.StreamLimit))
        return b, nil
index ae8913ac97076e812d041e047b4cb22027357d63..e49a124d0ca2174a26e42f9d6a46dbc236a15cf9 100644 (file)
@@ -13,7 +13,7 @@ import (
 
 func TestParseStreamsBlockedFrameBidirectional(t *testing.T) {
        data := encodeVarInt(0x1337)
-       f, l, err := parseStreamsBlockedFrame(data, bidiStreamBlockedFrameType, protocol.Version1)
+       f, l, err := parseStreamsBlockedFrame(data, FrameTypeBidiStreamBlocked, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamTypeBidi, f.Type)
        require.EqualValues(t, 0x1337, f.StreamLimit)
@@ -22,7 +22,7 @@ func TestParseStreamsBlockedFrameBidirectional(t *testing.T) {
 
 func TestParseStreamsBlockedFrameUnidirectional(t *testing.T) {
        data := encodeVarInt(0x7331)
-       f, l, err := parseStreamsBlockedFrame(data, uniStreamBlockedFrameType, protocol.Version1)
+       f, l, err := parseStreamsBlockedFrame(data, FrameTypeUniStreamBlocked, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, protocol.StreamTypeUni, f.Type)
        require.EqualValues(t, 0x7331, f.StreamLimit)
@@ -31,11 +31,11 @@ func TestParseStreamsBlockedFrameUnidirectional(t *testing.T) {
 
 func TestParseStreamsBlockedFrameErrorsOnEOFs(t *testing.T) {
        data := encodeVarInt(0x12345678)
-       _, l, err := parseStreamsBlockedFrame(data, bidiStreamBlockedFrameType, protocol.Version1)
+       _, l, err := parseStreamsBlockedFrame(data, FrameTypeBidiStreamBlocked, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, len(data), l)
        for i := range data {
-               _, _, err := parseStreamsBlockedFrame(data[:i], bidiStreamBlockedFrameType, protocol.Version1)
+               _, _, err := parseStreamsBlockedFrame(data[:i], FrameTypeBidiStreamBlocked, protocol.Version1)
                require.Equal(t, io.EOF, err)
        }
 }
@@ -58,7 +58,7 @@ func TestParseStreamsBlockedFrameMaxStreamCount(t *testing.T) {
                        typ, l, err := quicvarint.Parse(b)
                        require.NoError(t, err)
                        b = b[l:]
-                       frame, l, err := parseStreamsBlockedFrame(b, typ, protocol.Version1)
+                       frame, l, err := parseStreamsBlockedFrame(b, FrameType(typ), protocol.Version1)
                        require.NoError(t, err)
                        require.Equal(t, f, frame)
                        require.Equal(t, len(b), l)
@@ -84,7 +84,7 @@ func TestParseStreamsBlockedFrameErrorOnTooLargeStreamCount(t *testing.T) {
                        typ, l, err := quicvarint.Parse(b)
                        require.NoError(t, err)
                        b = b[l:]
-                       _, _, err = parseStreamsBlockedFrame(b, typ, protocol.Version1)
+                       _, _, err = parseStreamsBlockedFrame(b, FrameType(typ), protocol.Version1)
                        require.EqualError(t, err, fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))
                })
        }
@@ -97,7 +97,7 @@ func TestWriteStreamsBlockedFrameBidirectional(t *testing.T) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{bidiStreamBlockedFrameType}
+       expected := []byte{byte(FrameTypeBidiStreamBlocked)}
        expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
        require.Equal(t, expected, b)
        require.Equal(t, int(f.Length(protocol.Version1)), len(b))
@@ -110,7 +110,7 @@ func TestWriteStreamsBlockedFrameUnidirectional(t *testing.T) {
        }
        b, err := f.Append(nil, protocol.Version1)
        require.NoError(t, err)
-       expected := []byte{uniStreamBlockedFrameType}
+       expected := []byte{byte(FrameTypeUniStreamBlocked)}
        expected = append(expected, encodeVarInt(0xdeadbeefcafe)...)
        require.Equal(t, expected, b)
        require.Equal(t, int(f.Length(protocol.Version1)), len(b))
index 5f2ec0c14c92fc3f78bbca7281df4b35676dffb7..748c061e9805caa10217a44fa2d1608223c860e4 100644 (file)
@@ -736,10 +736,15 @@ func TestPackLongHeaderPadToAtLeast4Bytes(t *testing.T) {
        require.Equal(t, []byte{0, 0}, data[:2])
        // ...followed by the PING frame
        frameParser := wire.NewFrameParser(false, false)
-       l, frame, err := frameParser.ParseNext(data[2:], protocol.EncryptionHandshake, protocol.Version1)
+
+       frameType, lt, err := frameParser.ParseType(data[2:], protocol.EncryptionHandshake)
+       require.NoError(t, err)
+       require.Equal(t, 1, lt)
+       frame, l, err := frameParser.ParseLessCommonFrame(frameType, data[2+lt:], protocol.Version1)
        require.NoError(t, err)
        require.IsType(t, &wire.PingFrame{}, frame)
-       require.Equal(t, sealer.Overhead(), len(data)-2-l)
+       require.Zero(t, l)
+       require.Equal(t, sealer.Overhead(), len(data)-2-lt)
 }
 
 func TestPackShortHeaderPadToAtLeast4Bytes(t *testing.T) {
@@ -774,10 +779,15 @@ func TestPackShortHeaderPadToAtLeast4Bytes(t *testing.T) {
 
        // ... followed by the STREAM frame
        frameParser := wire.NewFrameParser(false, false)
-       frameLen, frame, err := frameParser.ParseNext(payload[1:], protocol.Encryption1RTT, protocol.Version1)
+       frameType, l, err := frameParser.ParseType(payload[1:], protocol.Encryption1RTT)
+       require.NoError(t, err)
+       require.Equal(t, 1, l)
+       require.True(t, frameType.IsStreamFrameType())
+
+       frame, frameLen, err := wire.ParseStreamFrame(payload[1+l:], frameType, protocol.Version1)
        require.NoError(t, err)
        require.Equal(t, f, frame)
-       require.Equal(t, len(payload)-1, frameLen)
+       require.Equal(t, len(payload)-2, frameLen)
 }
 
 func TestPackInitialProbePacket(t *testing.T) {