]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
implement receiver side behavior for RESET_STREAM_AT (#5235)
authorMarten Seemann <martenseemann@gmail.com>
Thu, 26 Jun 2025 06:42:08 +0000 (14:42 +0800)
committerGitHub <noreply@github.com>
Thu, 26 Jun 2025 06:42:08 +0000 (08:42 +0200)
* implement receiver side behavior for RESET_STREAM_AT

* simplify reliable offset tracking

receive_stream.go
receive_stream_test.go

index 3b2d618ce57c73a7968379b8aa7ea9ccfd152748..61f4cdf4244c44777fc346668b4949354861ca8d 100644 (file)
@@ -42,6 +42,9 @@ type ReceiveStream struct {
        cancelErr           *StreamError
        closeForShutdownErr error
 
+       readPos      protocol.ByteCount
+       reliableSize protocol.ByteCount
+
        readChan chan struct{}
        readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
        deadline time.Time
@@ -128,7 +131,7 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
                s.errorRead = true
                return false, false, 0, io.EOF
        }
-       if s.cancelledRemotely || s.cancelledLocally {
+       if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) {
                s.errorRead = true
                return false, false, 0, s.cancelErr
        }
@@ -151,9 +154,9 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
                        if s.closeForShutdownErr != nil {
                                return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr
                        }
-                       if s.cancelledRemotely || s.cancelledLocally {
+                       if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) {
                                s.errorRead = true
-                               return hasStreamWindowUpdate, hasConnWindowUpdate, 0, s.cancelErr
+                               return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr
                        }
 
                        deadline := s.deadline
@@ -194,14 +197,11 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
                if s.readPosInFrame > len(s.currentFrame) {
                        return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
                }
-
                m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
-               s.readPosInFrame += m
-               bytesRead += m
 
                // when a RESET_STREAM was received, the flow controller was already
-               // informed about the final byteOffset for this stream
-               if !s.cancelledRemotely {
+               // informed about the final offset for this stream
+               if !s.cancelledRemotely || s.readPos < s.reliableSize {
                        hasStream, hasConn := s.flowController.AddBytesRead(protocol.ByteCount(m))
                        if hasStream {
                                s.queuedMaxStreamData = true
@@ -212,6 +212,14 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
                        }
                }
 
+               s.readPosInFrame += m
+               s.readPos += protocol.ByteCount(m)
+               bytesRead += m
+
+               if s.cancelledRemotely && s.readPos >= s.reliableSize {
+                       s.flowController.Abandon()
+               }
+
                if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
                        s.currentFrame = nil
                        if s.currentFrameDone != nil {
@@ -221,6 +229,10 @@ func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnW
                        return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, io.EOF
                }
        }
+       if s.cancelledRemotely && s.readPos >= s.reliableSize {
+               s.errorRead = true
+               return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr
+       }
        return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, nil
 }
 
@@ -231,7 +243,7 @@ func (s *ReceiveStream) dequeueNextFrame() {
                s.currentFrameDone()
        }
        offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop()
-       s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset
+       s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset && !s.cancelledRemotely
        s.readPosInFrame = 0
 }
 
@@ -323,11 +335,19 @@ func (s *ReceiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame,
        }
        s.finalOffset = frame.FinalSize
 
+       // senders are allowed to reduce the reliable size, but frames might have been reordered
+       if (!s.cancelledRemotely && s.reliableSize == 0) || frame.ReliableSize < s.reliableSize {
+               s.reliableSize = frame.ReliableSize
+       }
+       if s.readPos >= s.reliableSize {
+               // calling Abandon multiple times is a no-op
+               s.flowController.Abandon()
+       }
        // ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
        if s.cancelledRemotely {
                return nil
        }
-       s.flowController.Abandon()
+
        // don't save the error if the RESET_STREAM frames was received after CancelRead was called
        if s.cancelledLocally {
                return nil
index 2b9f38d2aa0e04e4bd8f85103156d972ee1a518c..3619fa3ca40c305499fde8317fedd84c853b78d3 100644 (file)
@@ -447,7 +447,16 @@ func TestReceiveStreamCancellation(t *testing.T) {
        require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1234, Remote: false})
 }
 
-func TestReceiveStreamCancelReadAfterFINReceived(t *testing.T) {
+func TestReceiveStreamCancelReadAfterFIN(t *testing.T) {
+       t.Run("FIN not read", func(t *testing.T) {
+               testReceiveStreamCancelReadAfterFIN(t, false)
+       })
+       t.Run("FIN read", func(t *testing.T) {
+               testReceiveStreamCancelReadAfterFIN(t, true)
+       })
+}
+
+func testReceiveStreamCancelReadAfterFIN(t *testing.T, finRead bool) {
        mockCtrl := gomock.NewController(t)
        mockFC := mocks.NewMockStreamFlowController(mockCtrl)
        mockSender := NewMockStreamSender(mockCtrl)
@@ -456,46 +465,38 @@ func TestReceiveStreamCancelReadAfterFINReceived(t *testing.T) {
        mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any())
        mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
        require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now()))
+       if finRead {
+               mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
+               n, err := str.Read(make([]byte, 10))
+               require.ErrorIs(t, err, io.EOF)
+               require.Equal(t, 6, n)
+       }
 
        // if the FIN was received, but not read yet, a STOP_SENDING frame is queued
-       mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str)
-       mockFC.EXPECT().Abandon()
+       if !finRead {
+               mockFC.EXPECT().Abandon()
+               mockSender.EXPECT().onHasStreamControlFrame(str.StreamID(), str)
+       }
        str.CancelRead(1337)
        f, ok, hasMore := str.getControlFrame(time.Now())
-       require.True(t, ok)
-       require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame)
-       require.False(t, hasMore)
-
-       // Read returns the error
-       n, err := str.Read([]byte{0})
-       require.Zero(t, n)
-       require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false})
-}
-
-func TestReceiveStreamCancelReadAfterFINRead(t *testing.T) {
-       mockCtrl := gomock.NewController(t)
-       mockFC := mocks.NewMockStreamFlowController(mockCtrl)
-       mockSender := NewMockStreamSender(mockCtrl)
-       str := newReceiveStream(42, mockSender, mockFC)
-
-       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true, gomock.Any())
-       mockFC.EXPECT().AddBytesRead(protocol.ByteCount(6))
-       mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
-       require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now()))
-       n, err := str.Read(make([]byte, 10))
-       require.ErrorIs(t, err, io.EOF)
-       require.Equal(t, 6, n)
-
        // if the EOF was already read, no STOP_SENDING frame is queued
-       str.CancelRead(1234)
-       _, ok, hasMore := str.getControlFrame(time.Now())
-       require.False(t, ok)
-       require.False(t, hasMore)
+       if finRead {
+               require.False(t, ok)
+               require.False(t, hasMore)
+       } else {
+               require.True(t, ok)
+               require.Equal(t, &wire.StopSendingFrame{StreamID: 42, ErrorCode: 1337}, f.Frame)
+               require.False(t, hasMore)
+       }
 
        // Read returns the error
-       n, err = str.Read([]byte{0})
+       n, err := str.Read([]byte{0})
        require.Zero(t, n)
-       require.ErrorIs(t, err, io.EOF)
+       if finRead {
+               require.ErrorIs(t, err, io.EOF)
+       } else {
+               require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: false})
+       }
 }
 
 func TestReceiveStreamReset(t *testing.T) {
@@ -520,7 +521,7 @@ func TestReceiveStreamReset(t *testing.T) {
        mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
        gomock.InOrder(
                mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true, gomock.Any()),
-               mockFC.EXPECT().Abandon(),
+               mockFC.EXPECT().Abandon().MinTimes(1),
        )
        require.NoError(t, str.handleResetStreamFrame(
                &wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1234, FinalSize: 42},
@@ -616,14 +617,14 @@ func TestReceiveStreamConcurrentReads(t *testing.T) {
 
        const num = 3
        errChan := make(chan error, num)
-       for i := 0; i < num; i++ {
+       for range num {
                go func() {
                        _, err := str.Read(make([]byte, 8))
                        errChan <- err
                }()
        }
        require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar"), Fin: true}, time.Now()))
-       for i := 0; i < num; i++ {
+       for range num {
                select {
                case err := <-errChan:
                        require.ErrorIs(t, err, io.EOF)
@@ -634,3 +635,141 @@ func TestReceiveStreamConcurrentReads(t *testing.T) {
        require.Equal(t, protocol.ByteCount(6), bytesRead)
        require.Equal(t, int32(1), numCompleted.Load())
 }
+
+func TestReceiveStreamResetStreamAtBeforeReadOffset(t *testing.T) {
+       mockCtrl := gomock.NewController(t)
+       mockFC := mocks.NewMockStreamFlowController(mockCtrl)
+       mockSender := NewMockStreamSender(mockCtrl)
+       str := newReceiveStream(42, mockSender, mockFC)
+
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any())
+       require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now()))
+       mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3))
+       b := make([]byte, 3)
+       n, err := str.Read(b)
+       require.NoError(t, err)
+       require.Equal(t, 3, n)
+       require.Equal(t, []byte("foo"), b)
+
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
+       mockFC.EXPECT().Abandon()
+       str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 3}, time.Now())
+       require.True(t, mockCtrl.Satisfied())
+
+       // Read returns the error
+       mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
+       n, err = str.Read([]byte{0})
+       require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true})
+       require.Zero(t, n)
+}
+
+func TestReceiveStreamResetStreamAtAfterReadOffset(t *testing.T) {
+       mockCtrl := gomock.NewController(t)
+       mockFC := mocks.NewMockStreamFlowController(mockCtrl)
+       mockSender := NewMockStreamSender(mockCtrl)
+       str := newReceiveStream(42, mockSender, mockFC)
+
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any())
+       require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now()))
+       mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
+       b := make([]byte, 2)
+       n, err := str.Read(b)
+       require.NoError(t, err)
+       require.Equal(t, 2, n)
+       require.Equal(t, []byte("fo"), b)
+
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
+       str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, time.Now())
+       require.True(t, mockCtrl.Satisfied())
+
+       // Read returns the error
+       mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2))
+       n, err = str.Read(b)
+       require.NoError(t, err)
+       require.Equal(t, 2, n)
+       require.Equal(t, []byte("ob"), b)
+       require.True(t, mockCtrl.Satisfied())
+
+       gomock.InOrder(
+               mockFC.EXPECT().AddBytesRead(protocol.ByteCount(2)),
+               mockFC.EXPECT().Abandon(),
+       )
+       mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
+       n, err = str.Read(b)
+       require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true})
+       require.Equal(t, 2, n)
+       require.Equal(t, []byte("ar"), b)
+}
+
+func TestReceiveStreamMultipleResetStreamAt(t *testing.T) {
+       mockCtrl := gomock.NewController(t)
+       mockFC := mocks.NewMockStreamFlowController(mockCtrl)
+       mockSender := NewMockStreamSender(mockCtrl)
+       str := newReceiveStream(42, mockSender, mockFC)
+
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any())
+       require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now()))
+
+       mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3))
+       b := make([]byte, 3)
+       n, err := str.Read(b)
+       require.NoError(t, err)
+       require.Equal(t, 3, n)
+       require.Equal(t, []byte("foo"), b)
+       require.True(t, mockCtrl.Satisfied())
+
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
+       str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 6}, time.Now())
+       require.True(t, mockCtrl.Satisfied())
+
+       // receiving a reordered RESET_STREAM_AT frame has no effect
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
+       str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, time.Now())
+       require.True(t, mockCtrl.Satisfied())
+
+       // receiving a RESET_STREAM_AT frame with a smaller reliable size is valid
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
+       mockFC.EXPECT().Abandon()
+       str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 3}, time.Now())
+
+       // Read returns the error
+       mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
+       n, err = str.Read(b)
+       require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true})
+       require.Zero(t, n)
+}
+
+func TestReceiveStreamResetStreamAtAfterResetStream(t *testing.T) {
+       mockCtrl := gomock.NewController(t)
+       mockFC := mocks.NewMockStreamFlowController(mockCtrl)
+       mockSender := NewMockStreamSender(mockCtrl)
+       str := newReceiveStream(42, mockSender, mockFC)
+
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false, gomock.Any())
+       require.NoError(t, str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")}, time.Now()))
+
+       mockFC.EXPECT().AddBytesRead(protocol.ByteCount(3))
+       b := make([]byte, 3)
+       n, err := str.Read(b)
+       require.NoError(t, err)
+       require.Equal(t, 3, n)
+       require.Equal(t, []byte("foo"), b)
+       require.True(t, mockCtrl.Satisfied())
+
+       mockFC.EXPECT().Abandon()
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
+       str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10}, time.Now())
+       require.True(t, mockCtrl.Satisfied())
+
+       // receiving a reordered RESET_STREAM_AT frame has no effect
+       mockFC.EXPECT().Abandon()
+       mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true, gomock.Any())
+       str.handleResetStreamFrame(&wire.ResetStreamFrame{StreamID: 42, ErrorCode: 1337, FinalSize: 10, ReliableSize: 8}, time.Now())
+       require.True(t, mockCtrl.Satisfied())
+
+       // Read returns the error
+       mockSender.EXPECT().onStreamCompleted(protocol.StreamID(42))
+       n, err = str.Read(b)
+       require.ErrorIs(t, err, &StreamError{StreamID: 42, ErrorCode: 1337, Remote: true})
+       require.Zero(t, n)
+}