]> git.feebdaed.xyz Git - 0xmirror/quic-go.git/commitdiff
convert SendStream to a struct (#5172)
authorMarten Seemann <martenseemann@gmail.com>
Sun, 1 Jun 2025 03:53:53 +0000 (11:53 +0800)
committerGitHub <noreply@github.com>
Sun, 1 Jun 2025 03:53:53 +0000 (05:53 +0200)
18 files changed:
connection.go
connection_test.go
framer.go
framer_test.go
http3/conn.go
http3/http3_helper_test.go
integrationtests/self/cancelation_test.go
interface.go
mock_quic_conn_test.go
mock_send_stream_internal_test.go [deleted file]
mock_stream_frame_getter_test.go [new file with mode: 0644]
mock_stream_manager_test.go
mock_stream_sender_test.go
mockgen.go
send_stream.go
send_stream_test.go
stream.go
streams_map.go

index 78482a7226fbfcf45aa0553de3d214aa0cfdbe9c..18bb6e031aa01612ba5a019b3c0b4b7f9582ee52 100644 (file)
@@ -31,12 +31,12 @@ type unpacker interface {
 }
 
 type streamManager interface {
-       GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
+       GetOrOpenSendStream(protocol.StreamID) (*SendStream, error)
        GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
        OpenStream() (*Stream, error)
-       OpenUniStream() (SendStream, error)
+       OpenUniStream() (*SendStream, error)
        OpenStreamSync(context.Context) (*Stream, error)
-       OpenUniStreamSync(context.Context) (SendStream, error)
+       OpenUniStreamSync(context.Context) (*SendStream, error)
        AcceptStream(context.Context) (*Stream, error)
        AcceptUniStream(context.Context) (ReceiveStream, error)
        DeleteStream(protocol.StreamID) error
@@ -2509,11 +2509,11 @@ func (s *connection) OpenStreamSync(ctx context.Context) (*Stream, error) {
        return s.streamsMap.OpenStreamSync(ctx)
 }
 
-func (s *connection) OpenUniStream() (SendStream, error) {
+func (s *connection) OpenUniStream() (*SendStream, error) {
        return s.streamsMap.OpenUniStream()
 }
 
-func (s *connection) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
+func (s *connection) OpenUniStreamSync(ctx context.Context) (*SendStream, error) {
        return s.streamsMap.OpenUniStreamSync(ctx)
 }
 
@@ -2572,7 +2572,7 @@ func (s *connection) queueControlFrame(f wire.Frame) {
 
 func (s *connection) onHasConnectionData() { s.scheduleSending() }
 
-func (s *connection) onHasStreamData(id protocol.StreamID, str sendStreamI) {
+func (s *connection) onHasStreamData(id protocol.StreamID, str *SendStream) {
        s.framer.AddActiveStream(id, str)
        s.scheduleSending()
 }
index d8e3b3da18f7331b34d01c378d4574f07be471b7..1dcfa8a42557e0d8f8f69861450922facd2479cc 100644 (file)
@@ -284,15 +284,17 @@ func TestConnectionHandleSendStreamFrames(t *testing.T) {
                mockCtrl := gomock.NewController(t)
                streamsMap := NewMockStreamManager(mockCtrl)
                tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
-               str := NewMockSendStreamI(mockCtrl)
+               mockSender := NewMockStreamSender(mockCtrl)
+               mockSender.EXPECT().onHasStreamControlFrame(streamID, gomock.Any()).AnyTimes()
+               mockFC := mocks.NewMockStreamFlowController(mockCtrl)
+               str := newSendStream(context.Background(), streamID, mockSender, mockFC)
                // STOP_SENDING frame
                streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
-               str.EXPECT().handleStopSendingFrame(ss)
                _, err := tc.conn.handleFrame(ss, protocol.Encryption1RTT, connID, now)
                require.NoError(t, err)
                // MAX_STREAM_DATA frame
+               mockFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337))
                streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
-               str.EXPECT().updateSendWindow(msd.MaximumStreamData)
                _, err = tc.conn.handleFrame(msd, protocol.Encryption1RTT, connID, now)
                require.NoError(t, err)
        })
@@ -363,29 +365,30 @@ func TestConnectionOpenStreams(t *testing.T) {
        tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
 
        // using OpenStream
-       mstr := &Stream{}
-       streamsMap.EXPECT().OpenStream().Return(mstr, nil)
+       str1 := &Stream{}
+       streamsMap.EXPECT().OpenStream().Return(str1, nil)
        str, err := tc.conn.OpenStream()
        require.NoError(t, err)
-       require.Equal(t, mstr, str)
+       require.Equal(t, str1, str)
 
        // using OpenStreamSync
-       streamsMap.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil)
+       streamsMap.EXPECT().OpenStreamSync(context.Background()).Return(str1, nil)
        str, err = tc.conn.OpenStreamSync(context.Background())
        require.NoError(t, err)
-       require.Equal(t, mstr, str)
+       require.Equal(t, str1, str)
 
        // using OpenUniStream
-       streamsMap.EXPECT().OpenUniStream().Return(mstr, nil)
+       str2 := &SendStream{}
+       streamsMap.EXPECT().OpenUniStream().Return(str2, nil)
        ustr, err := tc.conn.OpenUniStream()
        require.NoError(t, err)
-       require.Equal(t, mstr, ustr)
+       require.Equal(t, str2, ustr)
 
        // using OpenUniStreamSync
-       streamsMap.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil)
+       streamsMap.EXPECT().OpenUniStreamSync(context.Background()).Return(str2, nil)
        ustr, err = tc.conn.OpenUniStreamSync(context.Background())
        require.NoError(t, err)
-       require.Equal(t, mstr, ustr)
+       require.Equal(t, str2, ustr)
 }
 
 func TestConnectionAcceptStreams(t *testing.T) {
index fee3163155809c9a071dd4cb29c38e3e44291f7d..a331202d02ee12bb1c683943f7ccde795780b5c8 100644 (file)
--- a/framer.go
+++ b/framer.go
@@ -22,6 +22,10 @@ const (
 // (which is the RESET_STREAM frame).
 const maxStreamControlFrameSize = 25
 
+type streamFrameGetter interface {
+       popStreamFrame(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool)
+}
+
 type streamControlFrameGetter interface {
        getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool)
 }
@@ -29,7 +33,7 @@ type streamControlFrameGetter interface {
 type framer struct {
        mutex sync.Mutex
 
-       activeStreams            map[protocol.StreamID]sendStreamI
+       activeStreams            map[protocol.StreamID]streamFrameGetter
        streamQueue              ringbuffer.RingBuffer[protocol.StreamID]
        streamsWithControlFrames map[protocol.StreamID]streamControlFrameGetter
 
@@ -42,7 +46,7 @@ type framer struct {
 
 func newFramer(connFlowController flowcontrol.ConnectionFlowController) *framer {
        return &framer{
-               activeStreams:            make(map[protocol.StreamID]sendStreamI),
+               activeStreams:            make(map[protocol.StreamID]streamFrameGetter),
                streamsWithControlFrames: make(map[protocol.StreamID]streamControlFrameGetter),
                connFlowController:       connFlowController,
        }
@@ -214,7 +218,7 @@ func (f *framer) QueuedTooManyControlFrames() bool {
        return f.queuedTooManyControlFrames
 }
 
-func (f *framer) AddActiveStream(id protocol.StreamID, str sendStreamI) {
+func (f *framer) AddActiveStream(id protocol.StreamID, str streamFrameGetter) {
        f.mutex.Lock()
        if _, ok := f.activeStreams[id]; !ok {
                f.streamQueue.PushBack(id)
index 9e4bfe580c31647bec97d2f5f1fc8e6a6ed4b521..660d0f91c28414545221831c364ad5b0f7d9b45a 100644 (file)
@@ -114,7 +114,7 @@ func TestFramerStreamDataBlocked(t *testing.T) {
 // in the next packet.
 func testFramerStreamDataBlocked(t *testing.T, fits bool) {
        const streamID = 5
-       str := NewMockSendStreamI(gomock.NewController(t))
+       str := NewMockStreamFrameGetter(gomock.NewController(t))
        framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil))
        framer.AddActiveStream(streamID, str)
        str.EXPECT().popStreamFrame(gomock.Any(), gomock.Any()).DoAndReturn(
@@ -175,7 +175,7 @@ func testFramerDataBlocked(t *testing.T, fits bool) {
        fc.UpdateSendWindow(offset)
        fc.AddBytesSent(offset)
 
-       str := NewMockSendStreamI(gomock.NewController(t))
+       str := NewMockStreamFrameGetter(gomock.NewController(t))
        framer := newFramer(fc)
        framer.AddActiveStream(streamID, str)
 
@@ -292,9 +292,9 @@ func TestFramerAppendStreamFrames(t *testing.T) {
 
        // add two streams
        mockCtrl := gomock.NewController(t)
-       str1 := NewMockSendStreamI(mockCtrl)
+       str1 := NewMockStreamFrameGetter(mockCtrl)
        str1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f1}, nil, true)
-       str2 := NewMockSendStreamI(mockCtrl)
+       str2 := NewMockStreamFrameGetter(mockCtrl)
        str2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).Return(ackhandler.StreamFrame{Frame: f2}, nil, false)
        framer.AddActiveStream(str1ID, str1)
        framer.AddActiveStream(str1ID, str1) // duplicate calls are ok (they're no-ops)
@@ -332,7 +332,7 @@ func TestFramerRemoveActiveStream(t *testing.T) {
        const id = protocol.StreamID(42)
        framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil))
        require.False(t, framer.HasData())
-       framer.AddActiveStream(id, NewMockSendStreamI(gomock.NewController(t)))
+       framer.AddActiveStream(id, NewMockStreamFrameGetter(gomock.NewController(t)))
        require.True(t, framer.HasData())
        framer.RemoveActiveStream(id) // no calls will be issued to the mock stream
        // we can't assert on framer.HasData here, since it's not removed from the ringbuffer
@@ -344,7 +344,7 @@ func TestFramerRemoveActiveStream(t *testing.T) {
 func TestFramerMinStreamFrameSize(t *testing.T) {
        const id = protocol.StreamID(42)
        framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil))
-       str := NewMockSendStreamI(gomock.NewController(t))
+       str := NewMockStreamFrameGetter(gomock.NewController(t))
        framer.AddActiveStream(id, str)
 
        require.True(t, framer.HasData())
@@ -369,7 +369,7 @@ func TestFramerMinStreamFrameSize(t *testing.T) {
 func TestFramerMinStreamFrameSizeMultipleStreamFrames(t *testing.T) {
        const id = protocol.StreamID(42)
        framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil))
-       str := NewMockSendStreamI(gomock.NewController(t))
+       str := NewMockStreamFrameGetter(gomock.NewController(t))
        framer.AddActiveStream(id, str)
 
        // pop a frame such that the remaining size is one byte less than the minimum STREAM frame size
@@ -388,7 +388,7 @@ func TestFramerMinStreamFrameSizeMultipleStreamFrames(t *testing.T) {
 
 func TestFramerFillPacketOneStream(t *testing.T) {
        const id = protocol.StreamID(42)
-       str := NewMockSendStreamI(gomock.NewController(t))
+       str := NewMockStreamFrameGetter(gomock.NewController(t))
        framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil))
 
        for i := protocol.MinStreamFrameSize; i < 2000; i++ {
@@ -418,8 +418,8 @@ func TestFramerFillPacketMultipleStreams(t *testing.T) {
                id2 = protocol.StreamID(11)
        )
        mockCtrl := gomock.NewController(t)
-       stream1 := NewMockSendStreamI(mockCtrl)
-       stream2 := NewMockSendStreamI(mockCtrl)
+       stream1 := NewMockStreamFrameGetter(mockCtrl)
+       stream2 := NewMockStreamFrameGetter(mockCtrl)
        framer := newFramer(flowcontrol.NewConnectionFlowController(0, 0, nil, nil, nil))
 
        for i := 2 * protocol.MinStreamFrameSize; i < 2000; i++ {
@@ -464,7 +464,7 @@ func TestFramer0RTTRejection(t *testing.T) {
        framer.QueueControlFrame(&wire.StreamsBlockedFrame{StreamLimit: 13})
        framer.QueueControlFrame(pc)
 
-       framer.AddActiveStream(10, NewMockSendStreamI(gomock.NewController(t)))
+       framer.AddActiveStream(10, NewMockStreamFrameGetter(gomock.NewController(t)))
 
        framer.Handle0RTTRejection()
        controlFrames, streamFrames, _ := framer.Append(nil, nil, protocol.MaxByteCount, time.Now(), protocol.Version1)
index 3b94c5503cc83ef764d633d6e77bdbd92a0d75c9..4ef53e879825db948be750be18f190c46ae42b06 100644 (file)
@@ -30,8 +30,8 @@ var errGoAway = errors.New("connection in graceful shutdown")
 type Connection interface {
        OpenStream() (*quic.Stream, error)
        OpenStreamSync(context.Context) (*quic.Stream, error)
-       OpenUniStream() (quic.SendStream, error)
-       OpenUniStreamSync(context.Context) (quic.SendStream, error)
+       OpenUniStream() (*quic.SendStream, error)
+       OpenUniStreamSync(context.Context) (*quic.SendStream, error)
        LocalAddr() net.Addr
        RemoteAddr() net.Addr
        CloseWithError(quic.ApplicationErrorCode, string) error
index eff6196c49c3c993d688b9fd01df9d48fa789766..b5f24b7e745b343ce30fa93a140da3cb6d1127ee 100644 (file)
@@ -147,7 +147,12 @@ func expectStreamReadReset(t *testing.T, str quic.ReceiveStream, errCode quic.St
        require.Equal(t, errCode, strErr.ErrorCode)
 }
 
-func expectStreamWriteReset(t *testing.T, str quic.SendStream, errCode quic.StreamErrorCode) {
+type quicSendStream interface {
+       io.Writer
+       Context() context.Context
+}
+
+func expectStreamWriteReset(t *testing.T, str quicSendStream, errCode quic.StreamErrorCode) {
        t.Helper()
 
        select {
index 987666459ce14a1ffa78acdb58840cc73678d023..4839d7c1a24af3d33f2c152fe8c3b6bda51fd1d8 100644 (file)
@@ -79,7 +79,7 @@ func TestStreamReadCancellation(t *testing.T) {
 
 func TestStreamWriteCancellation(t *testing.T) {
        t.Run("immediate", func(t *testing.T) {
-               testStreamCancellation(t, nil, func(str quic.SendStream) error {
+               testStreamCancellation(t, nil, func(str *quic.SendStream) error {
                        str.CancelWrite(quic.StreamErrorCode(str.StreamID()))
                        _, err := str.Write([]byte{0})
                        return err
@@ -87,7 +87,7 @@ func TestStreamWriteCancellation(t *testing.T) {
        })
 
        t.Run("after writing some data", func(t *testing.T) {
-               testStreamCancellation(t, nil, func(str quic.SendStream) error {
+               testStreamCancellation(t, nil, func(str *quic.SendStream) error {
                        length := rand.IntN(len(PRData) - 1)
                        if _, err := str.Write(PRData[:length]); err != nil {
                                return fmt.Errorf("writing stream data failed: %w", err)
@@ -101,7 +101,7 @@ func TestStreamWriteCancellation(t *testing.T) {
        // This test is especially valuable when run with race detector,
        // see https://github.com/quic-go/quic-go/issues/3239.
        t.Run("concurrent", func(t *testing.T) {
-               testStreamCancellation(t, nil, func(str quic.SendStream) error {
+               testStreamCancellation(t, nil, func(str *quic.SendStream) error {
                        errChan := make(chan error, 1)
                        go func() {
                                var offset int
@@ -146,7 +146,7 @@ func TestStreamReadWriteCancellation(t *testing.T) {
                                _, err := str.Read([]byte{0})
                                return err
                        },
-                       func(str quic.SendStream) error {
+                       func(str *quic.SendStream) error {
                                str.CancelWrite(quic.StreamErrorCode(str.StreamID()))
                                _, err := str.Write([]byte{0})
                                return err
@@ -165,7 +165,7 @@ func TestStreamReadWriteCancellation(t *testing.T) {
                                _, err := str.Read([]byte{0})
                                return err
                        },
-                       func(str quic.SendStream) error {
+                       func(str *quic.SendStream) error {
                                length := rand.IntN(len(PRData) - 1)
                                if _, err := str.Write(PRData[:length]); err != nil {
                                        return fmt.Errorf("writing stream data failed: %w", err)
@@ -183,7 +183,7 @@ func TestStreamReadWriteCancellation(t *testing.T) {
 func testStreamCancellation(
        t *testing.T,
        readFunc func(str quic.ReceiveStream) error,
-       writeFunc func(str quic.SendStream) error,
+       writeFunc func(str *quic.SendStream) error,
 ) {
        const numStreams = 80
 
@@ -457,7 +457,7 @@ func TestCancelOpenStreamSync(t *testing.T) {
                                continue
                        }
                        numOpened++
-                       go func(str quic.SendStream) {
+                       go func(str *quic.SendStream) {
                                defer str.Close()
                                if _, err := str.Write(PRData); err != nil {
                                        serverErrChan <- err
@@ -467,7 +467,7 @@ func TestCancelOpenStreamSync(t *testing.T) {
        }()
 
        clientErrChan := make(chan error, numStreams)
-       for i := 0; i < numStreams; i++ {
+       for range numStreams {
                <-msg
                str, err := conn.AcceptUniStream(context.Background())
                require.NoError(t, err)
index c149d068cfd2a9a5a2c7da34012ac0b2e0e616e7..9fb81124a85692d31c82b470836289c33ddc2eb2 100644 (file)
@@ -89,40 +89,6 @@ type ReceiveStream interface {
        SetReadDeadline(t time.Time) error
 }
 
-// A SendStream is a unidirectional Send Stream.
-type SendStream interface {
-       // StreamID returns the stream ID.
-       StreamID() StreamID
-       // Write writes data to the stream.
-       // Write can be made to time out using SetDeadline and SetWriteDeadline.
-       // If the stream was canceled, the error is a StreamError.
-       io.Writer
-       // Close closes the write-direction of the stream.
-       // Future calls to Write are not permitted after calling Close.
-       // It must not be called concurrently with Write.
-       // It must not be called after calling CancelWrite.
-       io.Closer
-       // CancelWrite aborts sending on this stream.
-       // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
-       // Write will unblock immediately, and future calls to Write will fail.
-       // When called multiple times it is a no-op.
-       // When called after Close, it aborts delivery. Note that there is no guarantee if
-       // the peer will receive the FIN or the reset first.
-       CancelWrite(StreamErrorCode)
-       // The Context is canceled as soon as the write-side of the stream is closed.
-       // This happens when Close() or CancelWrite() is called, or when the peer
-       // cancels the read-side of their stream.
-       // The cancellation cause is set to the error that caused the stream to
-       // close, or `context.Canceled` in case the stream is closed without error.
-       Context() context.Context
-       // SetWriteDeadline sets the deadline for future Write calls
-       // and any currently-blocked Write call.
-       // Even if write times out, it may return n > 0, indicating that
-       // some data was successfully written.
-       // A zero value for t means Write will not time out.
-       SetWriteDeadline(t time.Time) error
-}
-
 // A Connection is a QUIC connection between two peers.
 // Calls to the connection (and to streams) can return the following types of errors:
 //   - [ApplicationError]: for errors triggered by the application running on top of QUIC
@@ -155,13 +121,13 @@ type Connection interface {
        // or the stream has been reset or closed.
        // When reaching the peer's stream limit, it is not possible to open a new stream until the
        // peer raises the stream limit. In that case, a StreamLimitReachedError is returned.
-       OpenUniStream() (SendStream, error)
+       OpenUniStream() (*SendStream, error)
        // OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
        // It blocks until a new stream can be opened.
        // There is no signaling to the peer about new streams:
        // The peer can only accept the stream after data has been sent on the stream,
        // or the stream has been reset or closed.
-       OpenUniStreamSync(context.Context) (SendStream, error)
+       OpenUniStreamSync(context.Context) (*SendStream, error)
        // LocalAddr returns the local address.
        LocalAddr() net.Addr
        // RemoteAddr returns the address of the peer.
index 2eb5590da78be12a75a4164da6b94d895a5d6f1d..c24ace7bfeb0016b54dc65b51e21d9bc1ef63c12 100644 (file)
@@ -466,10 +466,10 @@ func (c *MockQUICConnOpenStreamSyncCall) DoAndReturn(f func(context.Context) (*S
 }
 
 // OpenUniStream mocks base method.
-func (m *MockQUICConn) OpenUniStream() (SendStream, error) {
+func (m *MockQUICConn) OpenUniStream() (*SendStream, error) {
        m.ctrl.T.Helper()
        ret := m.ctrl.Call(m, "OpenUniStream")
-       ret0, _ := ret[0].(SendStream)
+       ret0, _ := ret[0].(*SendStream)
        ret1, _ := ret[1].(error)
        return ret0, ret1
 }
@@ -487,28 +487,28 @@ type MockQUICConnOpenUniStreamCall struct {
 }
 
 // Return rewrite *gomock.Call.Return
-func (c *MockQUICConnOpenUniStreamCall) Return(arg0 SendStream, arg1 error) *MockQUICConnOpenUniStreamCall {
+func (c *MockQUICConnOpenUniStreamCall) Return(arg0 *SendStream, arg1 error) *MockQUICConnOpenUniStreamCall {
        c.Call = c.Call.Return(arg0, arg1)
        return c
 }
 
 // Do rewrite *gomock.Call.Do
-func (c *MockQUICConnOpenUniStreamCall) Do(f func() (SendStream, error)) *MockQUICConnOpenUniStreamCall {
+func (c *MockQUICConnOpenUniStreamCall) Do(f func() (*SendStream, error)) *MockQUICConnOpenUniStreamCall {
        c.Call = c.Call.Do(f)
        return c
 }
 
 // DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockQUICConnOpenUniStreamCall) DoAndReturn(f func() (SendStream, error)) *MockQUICConnOpenUniStreamCall {
+func (c *MockQUICConnOpenUniStreamCall) DoAndReturn(f func() (*SendStream, error)) *MockQUICConnOpenUniStreamCall {
        c.Call = c.Call.DoAndReturn(f)
        return c
 }
 
 // OpenUniStreamSync mocks base method.
-func (m *MockQUICConn) OpenUniStreamSync(arg0 context.Context) (SendStream, error) {
+func (m *MockQUICConn) OpenUniStreamSync(arg0 context.Context) (*SendStream, error) {
        m.ctrl.T.Helper()
        ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
-       ret0, _ := ret[0].(SendStream)
+       ret0, _ := ret[0].(*SendStream)
        ret1, _ := ret[1].(error)
        return ret0, ret1
 }
@@ -526,19 +526,19 @@ type MockQUICConnOpenUniStreamSyncCall struct {
 }
 
 // Return rewrite *gomock.Call.Return
-func (c *MockQUICConnOpenUniStreamSyncCall) Return(arg0 SendStream, arg1 error) *MockQUICConnOpenUniStreamSyncCall {
+func (c *MockQUICConnOpenUniStreamSyncCall) Return(arg0 *SendStream, arg1 error) *MockQUICConnOpenUniStreamSyncCall {
        c.Call = c.Call.Return(arg0, arg1)
        return c
 }
 
 // Do rewrite *gomock.Call.Do
-func (c *MockQUICConnOpenUniStreamSyncCall) Do(f func(context.Context) (SendStream, error)) *MockQUICConnOpenUniStreamSyncCall {
+func (c *MockQUICConnOpenUniStreamSyncCall) Do(f func(context.Context) (*SendStream, error)) *MockQUICConnOpenUniStreamSyncCall {
        c.Call = c.Call.Do(f)
        return c
 }
 
 // DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockQUICConnOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (SendStream, error)) *MockQUICConnOpenUniStreamSyncCall {
+func (c *MockQUICConnOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (*SendStream, error)) *MockQUICConnOpenUniStreamSyncCall {
        c.Call = c.Call.DoAndReturn(f)
        return c
 }
diff --git a/mock_send_stream_internal_test.go b/mock_send_stream_internal_test.go
deleted file mode 100644 (file)
index d7a0e1e..0000000
+++ /dev/null
@@ -1,458 +0,0 @@
-// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/quic-go/quic-go (interfaces: SendStreamI)
-//
-// Generated by this command:
-//
-//     mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI
-//
-
-// Package quic is a generated GoMock package.
-package quic
-
-import (
-       context "context"
-       reflect "reflect"
-       time "time"
-
-       ackhandler "github.com/quic-go/quic-go/internal/ackhandler"
-       protocol "github.com/quic-go/quic-go/internal/protocol"
-       wire "github.com/quic-go/quic-go/internal/wire"
-       gomock "go.uber.org/mock/gomock"
-)
-
-// MockSendStreamI is a mock of SendStreamI interface.
-type MockSendStreamI struct {
-       ctrl     *gomock.Controller
-       recorder *MockSendStreamIMockRecorder
-       isgomock struct{}
-}
-
-// MockSendStreamIMockRecorder is the mock recorder for MockSendStreamI.
-type MockSendStreamIMockRecorder struct {
-       mock *MockSendStreamI
-}
-
-// NewMockSendStreamI creates a new mock instance.
-func NewMockSendStreamI(ctrl *gomock.Controller) *MockSendStreamI {
-       mock := &MockSendStreamI{ctrl: ctrl}
-       mock.recorder = &MockSendStreamIMockRecorder{mock}
-       return mock
-}
-
-// EXPECT returns an object that allows the caller to indicate expected use.
-func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder {
-       return m.recorder
-}
-
-// CancelWrite mocks base method.
-func (m *MockSendStreamI) CancelWrite(arg0 StreamErrorCode) {
-       m.ctrl.T.Helper()
-       m.ctrl.Call(m, "CancelWrite", arg0)
-}
-
-// CancelWrite indicates an expected call of CancelWrite.
-func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 any) *MockSendStreamICancelWriteCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0)
-       return &MockSendStreamICancelWriteCall{Call: call}
-}
-
-// MockSendStreamICancelWriteCall wrap *gomock.Call
-type MockSendStreamICancelWriteCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamICancelWriteCall) Return() *MockSendStreamICancelWriteCall {
-       c.Call = c.Call.Return()
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamICancelWriteCall) Do(f func(StreamErrorCode)) *MockSendStreamICancelWriteCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamICancelWriteCall) DoAndReturn(f func(StreamErrorCode)) *MockSendStreamICancelWriteCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// Close mocks base method.
-func (m *MockSendStreamI) Close() error {
-       m.ctrl.T.Helper()
-       ret := m.ctrl.Call(m, "Close")
-       ret0, _ := ret[0].(error)
-       return ret0
-}
-
-// Close indicates an expected call of Close.
-func (mr *MockSendStreamIMockRecorder) Close() *MockSendStreamICloseCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close))
-       return &MockSendStreamICloseCall{Call: call}
-}
-
-// MockSendStreamICloseCall wrap *gomock.Call
-type MockSendStreamICloseCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamICloseCall) Return(arg0 error) *MockSendStreamICloseCall {
-       c.Call = c.Call.Return(arg0)
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamICloseCall) Do(f func() error) *MockSendStreamICloseCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamICloseCall) DoAndReturn(f func() error) *MockSendStreamICloseCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// Context mocks base method.
-func (m *MockSendStreamI) Context() context.Context {
-       m.ctrl.T.Helper()
-       ret := m.ctrl.Call(m, "Context")
-       ret0, _ := ret[0].(context.Context)
-       return ret0
-}
-
-// Context indicates an expected call of Context.
-func (mr *MockSendStreamIMockRecorder) Context() *MockSendStreamIContextCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context))
-       return &MockSendStreamIContextCall{Call: call}
-}
-
-// MockSendStreamIContextCall wrap *gomock.Call
-type MockSendStreamIContextCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamIContextCall) Return(arg0 context.Context) *MockSendStreamIContextCall {
-       c.Call = c.Call.Return(arg0)
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamIContextCall) Do(f func() context.Context) *MockSendStreamIContextCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamIContextCall) DoAndReturn(f func() context.Context) *MockSendStreamIContextCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// SetWriteDeadline mocks base method.
-func (m *MockSendStreamI) SetWriteDeadline(t time.Time) error {
-       m.ctrl.T.Helper()
-       ret := m.ctrl.Call(m, "SetWriteDeadline", t)
-       ret0, _ := ret[0].(error)
-       return ret0
-}
-
-// SetWriteDeadline indicates an expected call of SetWriteDeadline.
-func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(t any) *MockSendStreamISetWriteDeadlineCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), t)
-       return &MockSendStreamISetWriteDeadlineCall{Call: call}
-}
-
-// MockSendStreamISetWriteDeadlineCall wrap *gomock.Call
-type MockSendStreamISetWriteDeadlineCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamISetWriteDeadlineCall) Return(arg0 error) *MockSendStreamISetWriteDeadlineCall {
-       c.Call = c.Call.Return(arg0)
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamISetWriteDeadlineCall) Do(f func(time.Time) error) *MockSendStreamISetWriteDeadlineCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamISetWriteDeadlineCall) DoAndReturn(f func(time.Time) error) *MockSendStreamISetWriteDeadlineCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// StreamID mocks base method.
-func (m *MockSendStreamI) StreamID() StreamID {
-       m.ctrl.T.Helper()
-       ret := m.ctrl.Call(m, "StreamID")
-       ret0, _ := ret[0].(StreamID)
-       return ret0
-}
-
-// StreamID indicates an expected call of StreamID.
-func (mr *MockSendStreamIMockRecorder) StreamID() *MockSendStreamIStreamIDCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID))
-       return &MockSendStreamIStreamIDCall{Call: call}
-}
-
-// MockSendStreamIStreamIDCall wrap *gomock.Call
-type MockSendStreamIStreamIDCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamIStreamIDCall) Return(arg0 StreamID) *MockSendStreamIStreamIDCall {
-       c.Call = c.Call.Return(arg0)
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamIStreamIDCall) Do(f func() StreamID) *MockSendStreamIStreamIDCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamIStreamIDCall) DoAndReturn(f func() StreamID) *MockSendStreamIStreamIDCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// Write mocks base method.
-func (m *MockSendStreamI) Write(p []byte) (int, error) {
-       m.ctrl.T.Helper()
-       ret := m.ctrl.Call(m, "Write", p)
-       ret0, _ := ret[0].(int)
-       ret1, _ := ret[1].(error)
-       return ret0, ret1
-}
-
-// Write indicates an expected call of Write.
-func (mr *MockSendStreamIMockRecorder) Write(p any) *MockSendStreamIWriteCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), p)
-       return &MockSendStreamIWriteCall{Call: call}
-}
-
-// MockSendStreamIWriteCall wrap *gomock.Call
-type MockSendStreamIWriteCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamIWriteCall) Return(n int, err error) *MockSendStreamIWriteCall {
-       c.Call = c.Call.Return(n, err)
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamIWriteCall) Do(f func([]byte) (int, error)) *MockSendStreamIWriteCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamIWriteCall) DoAndReturn(f func([]byte) (int, error)) *MockSendStreamIWriteCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// closeForShutdown mocks base method.
-func (m *MockSendStreamI) closeForShutdown(arg0 error) {
-       m.ctrl.T.Helper()
-       m.ctrl.Call(m, "closeForShutdown", arg0)
-}
-
-// closeForShutdown indicates an expected call of closeForShutdown.
-func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 any) *MockSendStreamIcloseForShutdownCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0)
-       return &MockSendStreamIcloseForShutdownCall{Call: call}
-}
-
-// MockSendStreamIcloseForShutdownCall wrap *gomock.Call
-type MockSendStreamIcloseForShutdownCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamIcloseForShutdownCall) Return() *MockSendStreamIcloseForShutdownCall {
-       c.Call = c.Call.Return()
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamIcloseForShutdownCall) Do(f func(error)) *MockSendStreamIcloseForShutdownCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamIcloseForShutdownCall) DoAndReturn(f func(error)) *MockSendStreamIcloseForShutdownCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// handleStopSendingFrame mocks base method.
-func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) {
-       m.ctrl.T.Helper()
-       m.ctrl.Call(m, "handleStopSendingFrame", arg0)
-}
-
-// handleStopSendingFrame indicates an expected call of handleStopSendingFrame.
-func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 any) *MockSendStreamIhandleStopSendingFrameCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0)
-       return &MockSendStreamIhandleStopSendingFrameCall{Call: call}
-}
-
-// MockSendStreamIhandleStopSendingFrameCall wrap *gomock.Call
-type MockSendStreamIhandleStopSendingFrameCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamIhandleStopSendingFrameCall) Return() *MockSendStreamIhandleStopSendingFrameCall {
-       c.Call = c.Call.Return()
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamIhandleStopSendingFrameCall) Do(f func(*wire.StopSendingFrame)) *MockSendStreamIhandleStopSendingFrameCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamIhandleStopSendingFrameCall) DoAndReturn(f func(*wire.StopSendingFrame)) *MockSendStreamIhandleStopSendingFrameCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// hasData mocks base method.
-func (m *MockSendStreamI) hasData() bool {
-       m.ctrl.T.Helper()
-       ret := m.ctrl.Call(m, "hasData")
-       ret0, _ := ret[0].(bool)
-       return ret0
-}
-
-// hasData indicates an expected call of hasData.
-func (mr *MockSendStreamIMockRecorder) hasData() *MockSendStreamIhasDataCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockSendStreamI)(nil).hasData))
-       return &MockSendStreamIhasDataCall{Call: call}
-}
-
-// MockSendStreamIhasDataCall wrap *gomock.Call
-type MockSendStreamIhasDataCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamIhasDataCall) Return(arg0 bool) *MockSendStreamIhasDataCall {
-       c.Call = c.Call.Return(arg0)
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamIhasDataCall) Do(f func() bool) *MockSendStreamIhasDataCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamIhasDataCall) DoAndReturn(f func() bool) *MockSendStreamIhasDataCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// popStreamFrame mocks base method.
-func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) {
-       m.ctrl.T.Helper()
-       ret := m.ctrl.Call(m, "popStreamFrame", arg0, arg1)
-       ret0, _ := ret[0].(ackhandler.StreamFrame)
-       ret1, _ := ret[1].(*wire.StreamDataBlockedFrame)
-       ret2, _ := ret[2].(bool)
-       return ret0, ret1, ret2
-}
-
-// popStreamFrame indicates an expected call of popStreamFrame.
-func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0, arg1 any) *MockSendStreamIpopStreamFrameCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0, arg1)
-       return &MockSendStreamIpopStreamFrameCall{Call: call}
-}
-
-// MockSendStreamIpopStreamFrameCall wrap *gomock.Call
-type MockSendStreamIpopStreamFrameCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamIpopStreamFrameCall) Return(arg0 ackhandler.StreamFrame, arg1 *wire.StreamDataBlockedFrame, hasMore bool) *MockSendStreamIpopStreamFrameCall {
-       c.Call = c.Call.Return(arg0, arg1, hasMore)
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamIpopStreamFrameCall) Do(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool)) *MockSendStreamIpopStreamFrameCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamIpopStreamFrameCall) DoAndReturn(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool)) *MockSendStreamIpopStreamFrameCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
-
-// updateSendWindow mocks base method.
-func (m *MockSendStreamI) updateSendWindow(arg0 protocol.ByteCount) {
-       m.ctrl.T.Helper()
-       m.ctrl.Call(m, "updateSendWindow", arg0)
-}
-
-// updateSendWindow indicates an expected call of updateSendWindow.
-func (mr *MockSendStreamIMockRecorder) updateSendWindow(arg0 any) *MockSendStreamIupdateSendWindowCall {
-       mr.mock.ctrl.T.Helper()
-       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockSendStreamI)(nil).updateSendWindow), arg0)
-       return &MockSendStreamIupdateSendWindowCall{Call: call}
-}
-
-// MockSendStreamIupdateSendWindowCall wrap *gomock.Call
-type MockSendStreamIupdateSendWindowCall struct {
-       *gomock.Call
-}
-
-// Return rewrite *gomock.Call.Return
-func (c *MockSendStreamIupdateSendWindowCall) Return() *MockSendStreamIupdateSendWindowCall {
-       c.Call = c.Call.Return()
-       return c
-}
-
-// Do rewrite *gomock.Call.Do
-func (c *MockSendStreamIupdateSendWindowCall) Do(f func(protocol.ByteCount)) *MockSendStreamIupdateSendWindowCall {
-       c.Call = c.Call.Do(f)
-       return c
-}
-
-// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockSendStreamIupdateSendWindowCall) DoAndReturn(f func(protocol.ByteCount)) *MockSendStreamIupdateSendWindowCall {
-       c.Call = c.Call.DoAndReturn(f)
-       return c
-}
diff --git a/mock_stream_frame_getter_test.go b/mock_stream_frame_getter_test.go
new file mode 100644 (file)
index 0000000..2d64103
--- /dev/null
@@ -0,0 +1,83 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/quic-go/quic-go (interfaces: StreamFrameGetter)
+//
+// Generated by this command:
+//
+//     mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_frame_getter_test.go github.com/quic-go/quic-go StreamFrameGetter
+//
+
+// Package quic is a generated GoMock package.
+package quic
+
+import (
+       reflect "reflect"
+
+       ackhandler "github.com/quic-go/quic-go/internal/ackhandler"
+       protocol "github.com/quic-go/quic-go/internal/protocol"
+       wire "github.com/quic-go/quic-go/internal/wire"
+       gomock "go.uber.org/mock/gomock"
+)
+
+// MockStreamFrameGetter is a mock of StreamFrameGetter interface.
+type MockStreamFrameGetter struct {
+       ctrl     *gomock.Controller
+       recorder *MockStreamFrameGetterMockRecorder
+       isgomock struct{}
+}
+
+// MockStreamFrameGetterMockRecorder is the mock recorder for MockStreamFrameGetter.
+type MockStreamFrameGetterMockRecorder struct {
+       mock *MockStreamFrameGetter
+}
+
+// NewMockStreamFrameGetter creates a new mock instance.
+func NewMockStreamFrameGetter(ctrl *gomock.Controller) *MockStreamFrameGetter {
+       mock := &MockStreamFrameGetter{ctrl: ctrl}
+       mock.recorder = &MockStreamFrameGetterMockRecorder{mock}
+       return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockStreamFrameGetter) EXPECT() *MockStreamFrameGetterMockRecorder {
+       return m.recorder
+}
+
+// popStreamFrame mocks base method.
+func (m *MockStreamFrameGetter) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool) {
+       m.ctrl.T.Helper()
+       ret := m.ctrl.Call(m, "popStreamFrame", arg0, arg1)
+       ret0, _ := ret[0].(ackhandler.StreamFrame)
+       ret1, _ := ret[1].(*wire.StreamDataBlockedFrame)
+       ret2, _ := ret[2].(bool)
+       return ret0, ret1, ret2
+}
+
+// popStreamFrame indicates an expected call of popStreamFrame.
+func (mr *MockStreamFrameGetterMockRecorder) popStreamFrame(arg0, arg1 any) *MockStreamFrameGetterpopStreamFrameCall {
+       mr.mock.ctrl.T.Helper()
+       call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamFrameGetter)(nil).popStreamFrame), arg0, arg1)
+       return &MockStreamFrameGetterpopStreamFrameCall{Call: call}
+}
+
+// MockStreamFrameGetterpopStreamFrameCall wrap *gomock.Call
+type MockStreamFrameGetterpopStreamFrameCall struct {
+       *gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockStreamFrameGetterpopStreamFrameCall) Return(arg0 ackhandler.StreamFrame, arg1 *wire.StreamDataBlockedFrame, arg2 bool) *MockStreamFrameGetterpopStreamFrameCall {
+       c.Call = c.Call.Return(arg0, arg1, arg2)
+       return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockStreamFrameGetterpopStreamFrameCall) Do(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool)) *MockStreamFrameGetterpopStreamFrameCall {
+       c.Call = c.Call.Do(f)
+       return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockStreamFrameGetterpopStreamFrameCall) DoAndReturn(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool)) *MockStreamFrameGetterpopStreamFrameCall {
+       c.Call = c.Call.DoAndReturn(f)
+       return c
+}
index 8d5d92708410e1719a5128dcb6220207f0888bd5..b87955eab17239f0ae3deeb1722f4e68bf4cf86d 100644 (file)
@@ -234,10 +234,10 @@ func (c *MockStreamManagerGetOrOpenReceiveStreamCall) DoAndReturn(f func(protoco
 }
 
 // GetOrOpenSendStream mocks base method.
-func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStreamI, error) {
+func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (*SendStream, error) {
        m.ctrl.T.Helper()
        ret := m.ctrl.Call(m, "GetOrOpenSendStream", arg0)
-       ret0, _ := ret[0].(sendStreamI)
+       ret0, _ := ret[0].(*SendStream)
        ret1, _ := ret[1].(error)
        return ret0, ret1
 }
@@ -255,19 +255,19 @@ type MockStreamManagerGetOrOpenSendStreamCall struct {
 }
 
 // Return rewrite *gomock.Call.Return
-func (c *MockStreamManagerGetOrOpenSendStreamCall) Return(arg0 sendStreamI, arg1 error) *MockStreamManagerGetOrOpenSendStreamCall {
+func (c *MockStreamManagerGetOrOpenSendStreamCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerGetOrOpenSendStreamCall {
        c.Call = c.Call.Return(arg0, arg1)
        return c
 }
 
 // Do rewrite *gomock.Call.Do
-func (c *MockStreamManagerGetOrOpenSendStreamCall) Do(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamManagerGetOrOpenSendStreamCall {
+func (c *MockStreamManagerGetOrOpenSendStreamCall) Do(f func(protocol.StreamID) (*SendStream, error)) *MockStreamManagerGetOrOpenSendStreamCall {
        c.Call = c.Call.Do(f)
        return c
 }
 
 // DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockStreamManagerGetOrOpenSendStreamCall) DoAndReturn(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamManagerGetOrOpenSendStreamCall {
+func (c *MockStreamManagerGetOrOpenSendStreamCall) DoAndReturn(f func(protocol.StreamID) (*SendStream, error)) *MockStreamManagerGetOrOpenSendStreamCall {
        c.Call = c.Call.DoAndReturn(f)
        return c
 }
@@ -387,10 +387,10 @@ func (c *MockStreamManagerOpenStreamSyncCall) DoAndReturn(f func(context.Context
 }
 
 // OpenUniStream mocks base method.
-func (m *MockStreamManager) OpenUniStream() (SendStream, error) {
+func (m *MockStreamManager) OpenUniStream() (*SendStream, error) {
        m.ctrl.T.Helper()
        ret := m.ctrl.Call(m, "OpenUniStream")
-       ret0, _ := ret[0].(SendStream)
+       ret0, _ := ret[0].(*SendStream)
        ret1, _ := ret[1].(error)
        return ret0, ret1
 }
@@ -408,28 +408,28 @@ type MockStreamManagerOpenUniStreamCall struct {
 }
 
 // Return rewrite *gomock.Call.Return
-func (c *MockStreamManagerOpenUniStreamCall) Return(arg0 SendStream, arg1 error) *MockStreamManagerOpenUniStreamCall {
+func (c *MockStreamManagerOpenUniStreamCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerOpenUniStreamCall {
        c.Call = c.Call.Return(arg0, arg1)
        return c
 }
 
 // Do rewrite *gomock.Call.Do
-func (c *MockStreamManagerOpenUniStreamCall) Do(f func() (SendStream, error)) *MockStreamManagerOpenUniStreamCall {
+func (c *MockStreamManagerOpenUniStreamCall) Do(f func() (*SendStream, error)) *MockStreamManagerOpenUniStreamCall {
        c.Call = c.Call.Do(f)
        return c
 }
 
 // DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockStreamManagerOpenUniStreamCall) DoAndReturn(f func() (SendStream, error)) *MockStreamManagerOpenUniStreamCall {
+func (c *MockStreamManagerOpenUniStreamCall) DoAndReturn(f func() (*SendStream, error)) *MockStreamManagerOpenUniStreamCall {
        c.Call = c.Call.DoAndReturn(f)
        return c
 }
 
 // OpenUniStreamSync mocks base method.
-func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, error) {
+func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (*SendStream, error) {
        m.ctrl.T.Helper()
        ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0)
-       ret0, _ := ret[0].(SendStream)
+       ret0, _ := ret[0].(*SendStream)
        ret1, _ := ret[1].(error)
        return ret0, ret1
 }
@@ -447,19 +447,19 @@ type MockStreamManagerOpenUniStreamSyncCall struct {
 }
 
 // Return rewrite *gomock.Call.Return
-func (c *MockStreamManagerOpenUniStreamSyncCall) Return(arg0 SendStream, arg1 error) *MockStreamManagerOpenUniStreamSyncCall {
+func (c *MockStreamManagerOpenUniStreamSyncCall) Return(arg0 *SendStream, arg1 error) *MockStreamManagerOpenUniStreamSyncCall {
        c.Call = c.Call.Return(arg0, arg1)
        return c
 }
 
 // Do rewrite *gomock.Call.Do
-func (c *MockStreamManagerOpenUniStreamSyncCall) Do(f func(context.Context) (SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall {
+func (c *MockStreamManagerOpenUniStreamSyncCall) Do(f func(context.Context) (*SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall {
        c.Call = c.Call.Do(f)
        return c
 }
 
 // DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockStreamManagerOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall {
+func (c *MockStreamManagerOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (*SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall {
        c.Call = c.Call.DoAndReturn(f)
        return c
 }
index 7d3a76e57b6746cdabeabb38f538ea76f16144a9..e8a77f2b3401939d5bec645262fa39f89f8c98d1 100644 (file)
@@ -113,7 +113,7 @@ func (c *MockStreamSenderonHasStreamControlFrameCall) DoAndReturn(f func(protoco
 }
 
 // onHasStreamData mocks base method.
-func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID, arg1 sendStreamI) {
+func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID, arg1 *SendStream) {
        m.ctrl.T.Helper()
        m.ctrl.Call(m, "onHasStreamData", arg0, arg1)
 }
@@ -137,13 +137,13 @@ func (c *MockStreamSenderonHasStreamDataCall) Return() *MockStreamSenderonHasStr
 }
 
 // Do rewrite *gomock.Call.Do
-func (c *MockStreamSenderonHasStreamDataCall) Do(f func(protocol.StreamID, sendStreamI)) *MockStreamSenderonHasStreamDataCall {
+func (c *MockStreamSenderonHasStreamDataCall) Do(f func(protocol.StreamID, *SendStream)) *MockStreamSenderonHasStreamDataCall {
        c.Call = c.Call.Do(f)
        return c
 }
 
 // DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockStreamSenderonHasStreamDataCall) DoAndReturn(f func(protocol.StreamID, sendStreamI)) *MockStreamSenderonHasStreamDataCall {
+func (c *MockStreamSenderonHasStreamDataCall) DoAndReturn(f func(protocol.StreamID, *SendStream)) *MockStreamSenderonHasStreamDataCall {
        c.Call = c.Call.DoAndReturn(f)
        return c
 }
index 8ba31825afda9ac82916de9df70cb8a8d6aaf43a..aeb6094a59be33a816a026ce26b9f8995fdef651 100644 (file)
@@ -14,15 +14,15 @@ type Sender = sender
 //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI"
 type ReceiveStreamI = receiveStreamI
 
-//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI"
-type SendStreamI = sendStreamI
-
 //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender"
 type StreamSender = streamSender
 
 //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_control_frame_getter_test.go github.com/quic-go/quic-go StreamControlFrameGetter"
 type StreamControlFrameGetter = streamControlFrameGetter
 
+//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_frame_getter_test.go github.com/quic-go/quic-go StreamFrameGetter"
+type StreamFrameGetter = streamFrameGetter
+
 //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource"
 type FrameSource = frameSource
 
index a588cc8ad636a7a9c352c529f85121b27fdaf078..b7424cbaf2555aff1f28e8fa0796ea16d302e423 100644 (file)
@@ -14,16 +14,8 @@ import (
        "github.com/quic-go/quic-go/internal/wire"
 )
 
-type sendStreamI interface {
-       SendStream
-       handleStopSendingFrame(*wire.StopSendingFrame)
-       hasData() bool
-       popStreamFrame(protocol.ByteCount, protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool)
-       closeForShutdown(error)
-       updateSendWindow(protocol.ByteCount)
-}
-
-type sendStream struct {
+// A SendStream is a unidirectional Send Stream.
+type SendStream struct {
        mutex sync.Mutex
 
        numOutstandingFrames int64 // outstanding STREAM and RESET_STREAM frames
@@ -62,19 +54,15 @@ type sendStream struct {
        flowController flowcontrol.StreamFlowController
 }
 
-var (
-       _ SendStream               = &sendStream{}
-       _ sendStreamI              = &sendStream{}
-       _ streamControlFrameGetter = &sendStream{}
-)
+var _ streamControlFrameGetter = &SendStream{}
 
 func newSendStream(
        ctx context.Context,
        streamID protocol.StreamID,
        sender streamSender,
        flowController flowcontrol.StreamFlowController,
-) *sendStream {
-       s := &sendStream{
+) *SendStream {
+       s := &SendStream{
                streamID:       streamID,
                sender:         sender,
                flowController: flowController,
@@ -85,11 +73,15 @@ func newSendStream(
        return s
 }
 
-func (s *sendStream) StreamID() protocol.StreamID {
+// StreamID returns the stream ID.
+func (s *SendStream) StreamID() StreamID {
        return s.streamID // same for receiveStream and sendStream
 }
 
-func (s *sendStream) Write(p []byte) (int, error) {
+// Write writes data to the stream.
+// Write can be made to time out using SetDeadline and SetWriteDeadline.
+// If the stream was canceled, the error is a StreamError.
+func (s *SendStream) Write(p []byte) (int, error) {
        // Concurrent use of Write is not permitted (and doesn't make any sense),
        // but sometimes people do it anyway.
        // Make sure that we only execute one call at any given time to avoid hard to debug failures.
@@ -103,7 +95,7 @@ func (s *sendStream) Write(p []byte) (int, error) {
        return n, err
 }
 
-func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) {
+func (s *SendStream) write(p []byte) (bool /* is newly completed */, int, error) {
        s.mutex.Lock()
        defer s.mutex.Unlock()
 
@@ -207,7 +199,7 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
        return false, bytesWritten, nil
 }
 
-func (s *sendStream) canBufferStreamFrame() bool {
+func (s *SendStream) canBufferStreamFrame() bool {
        var l protocol.ByteCount
        if s.nextFrame != nil {
                l = s.nextFrame.DataLen()
@@ -217,7 +209,7 @@ func (s *sendStream) canBufferStreamFrame() bool {
 
 // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
 // maxBytes is the maximum length this frame (including frame header) will have.
-func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) {
+func (s *SendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) {
        s.mutex.Lock()
        f, blocked, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
        if f != nil {
@@ -234,7 +226,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
        }, blocked, hasMoreData
 }
 
-func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) {
+func (s *SendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) {
        if s.finalError != nil {
                return nil, nil, false
        }
@@ -290,7 +282,7 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun
        return f, blocked, hasMoreData
 }
 
-func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
+func (s *SendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
        if s.nextFrame != nil {
                maxDataLen := min(sendWindow, s.nextFrame.MaxDataLen(maxBytes, v))
                if maxDataLen == 0 {
@@ -327,7 +319,7 @@ func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount,
        return f, hasMoreData
 }
 
-func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
+func (s *SendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
        maxDataLen := f.MaxDataLen(maxBytes, v)
        if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
                return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
@@ -337,7 +329,7 @@ func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxByte
        return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
 }
 
-func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
+func (s *SendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
        f := s.retransmissionQueue[0]
        newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
        if needsSplit {
@@ -347,14 +339,7 @@ func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v proto
        return f, len(s.retransmissionQueue) > 0
 }
 
-func (s *sendStream) hasData() bool {
-       s.mutex.Lock()
-       hasData := len(s.dataForWriting) > 0
-       s.mutex.Unlock()
-       return hasData
-}
-
-func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
+func (s *SendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
        if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
                f.Data = f.Data[:len(s.dataForWriting)]
                copy(f.Data, s.dataForWriting)
@@ -370,7 +355,7 @@ func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.By
        }
 }
 
-func (s *sendStream) isNewlyCompleted() bool {
+func (s *SendStream) isNewlyCompleted() bool {
        if s.completed {
                return false
        }
@@ -395,7 +380,11 @@ func (s *sendStream) isNewlyCompleted() bool {
        return false
 }
 
-func (s *sendStream) Close() error {
+// Close closes the write-direction of the stream.
+// Future calls to Write are not permitted after calling Close.
+// It must not be called concurrently with Write.
+// It must not be called after calling CancelWrite.
+func (s *SendStream) Close() error {
        s.mutex.Lock()
        if s.closedForShutdown || s.finishedWriting {
                s.mutex.Unlock()
@@ -421,14 +410,20 @@ func (s *sendStream) Close() error {
        return nil
 }
 
-func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
+// CancelWrite aborts sending on this stream.
+// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
+// Write will unblock immediately, and future calls to Write will fail.
+// When called multiple times it is a no-op.
+// When called after Close, it aborts delivery. Note that there is no guarantee if
+// the peer will receive the FIN or the reset first.
+func (s *SendStream) CancelWrite(errorCode StreamErrorCode) {
        s.cancelWrite(errorCode, false)
 }
 
 // cancelWrite cancels the stream
 // It is possible to cancel a stream after it has been closed, both locally and remotely.
 // This is useful to prevent the retransmission of outstanding stream data.
-func (s *sendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) {
+func (s *SendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) {
        s.mutex.Lock()
        if s.closedForShutdown {
                s.mutex.Unlock()
@@ -468,7 +463,7 @@ func (s *sendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) {
        s.sender.onHasStreamControlFrame(s.streamID, s)
 }
 
-func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
+func (s *SendStream) updateSendWindow(limit protocol.ByteCount) {
        updated := s.flowController.UpdateSendWindow(limit)
        if !updated { // duplicate or reordered MAX_STREAM_DATA frame
                return
@@ -481,11 +476,11 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
        }
 }
 
-func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
+func (s *SendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
        s.cancelWrite(frame.ErrorCode, true)
 }
 
-func (s *sendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
+func (s *SendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
        s.mutex.Lock()
        defer s.mutex.Unlock()
 
@@ -501,11 +496,21 @@ func (s *sendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore
        return f, true, false
 }
 
-func (s *sendStream) Context() context.Context {
+// The Context is canceled as soon as the write-side of the stream is closed.
+// This happens when Close() or CancelWrite() is called, or when the peer
+// cancels the read-side of their stream.
+// The cancellation cause is set to the error that caused the stream to
+// close, or `context.Canceled` in case the stream is closed without error.
+func (s *SendStream) Context() context.Context {
        return s.ctx
 }
 
-func (s *sendStream) SetWriteDeadline(t time.Time) error {
+// SetWriteDeadline sets the deadline for future Write calls
+// and any currently-blocked Write call.
+// Even if write times out, it may return n > 0, indicating that
+// some data was successfully written.
+// A zero value for t means Write will not time out.
+func (s *SendStream) SetWriteDeadline(t time.Time) error {
        s.mutex.Lock()
        s.deadline = t
        s.mutex.Unlock()
@@ -516,7 +521,7 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
 // CloseForShutdown closes a stream abruptly.
 // It makes Write unblock (and return the error) immediately.
 // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
-func (s *sendStream) closeForShutdown(err error) {
+func (s *SendStream) closeForShutdown(err error) {
        s.mutex.Lock()
        s.closedForShutdown = true
        if s.finalError == nil && !s.finishedWriting {
@@ -527,14 +532,14 @@ func (s *sendStream) closeForShutdown(err error) {
 }
 
 // signalWrite performs a non-blocking send on the writeChan
-func (s *sendStream) signalWrite() {
+func (s *SendStream) signalWrite() {
        select {
        case s.writeChan <- struct{}{}:
        default:
        }
 }
 
-type sendStreamAckHandler sendStream
+type sendStreamAckHandler SendStream
 
 var _ ackhandler.FrameHandler = &sendStreamAckHandler{}
 
@@ -550,7 +555,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
        if s.numOutstandingFrames < 0 {
                panic("numOutStandingFrames negative")
        }
-       completed := (*sendStream)(s).isNewlyCompleted()
+       completed := (*SendStream)(s).isNewlyCompleted()
        s.mutex.Unlock()
 
        if completed {
@@ -573,10 +578,10 @@ func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
        }
        s.mutex.Unlock()
 
-       s.sender.onHasStreamData(s.streamID, (*sendStream)(s))
+       s.sender.onHasStreamData(s.streamID, (*SendStream)(s))
 }
 
-type sendStreamResetStreamHandler sendStream
+type sendStreamResetStreamHandler SendStream
 
 var _ ackhandler.FrameHandler = &sendStreamResetStreamHandler{}
 
@@ -586,7 +591,7 @@ func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) {
        if s.numOutstandingFrames < 0 {
                panic("numOutStandingFrames negative")
        }
-       completed := (*sendStream)(s).isNewlyCompleted()
+       completed := (*SendStream)(s).isNewlyCompleted()
        s.mutex.Unlock()
 
        if completed {
@@ -599,5 +604,5 @@ func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) {
        s.queuedResetStreamFrame = f.(*wire.ResetStreamFrame)
        s.numOutstandingFrames--
        s.mutex.Unlock()
-       s.sender.onHasStreamControlFrame(s.streamID, (*sendStream)(s))
+       s.sender.onHasStreamControlFrame(s.streamID, (*SendStream)(s))
 }
index 3d0af2eb0a23a2a5a9eb78d1fc400f698e04e4d8..dcd593cd776b58fd146a47beaffe4bc30447b8cc 100644 (file)
@@ -573,7 +573,7 @@ func TestSendStreamCancellation(t *testing.T) {
        require.True(t, mockCtrl.Satisfied())
 
        wrote := make(chan struct{})
-       mockSender.EXPECT().onHasStreamData(streamID, str).Do(func(protocol.StreamID, sendStreamI) { close(wrote) })
+       mockSender.EXPECT().onHasStreamData(streamID, str).Do(func(protocol.StreamID, *SendStream) { close(wrote) })
        errChan := make(chan error, 1)
        go func() {
                _, err := strWithTimeout.Write(make([]byte, 2000))
index 3f979d17bed1e95cae8cf4b885299f12cc650727..85b70bb28c25f426a1f249bcae4ca112c560bc81 100644 (file)
--- a/stream.go
+++ b/stream.go
@@ -24,7 +24,7 @@ var errDeadline net.Error = &deadlineError{}
 // The streamSender is notified by the stream about various events.
 type streamSender interface {
        onHasConnectionData()
-       onHasStreamData(protocol.StreamID, sendStreamI)
+       onHasStreamData(protocol.StreamID, *SendStream)
        onHasStreamControlFrame(protocol.StreamID, streamControlFrameGetter)
        // must be called without holding the mutex that is acquired by closeForShutdown
        onStreamCompleted(protocol.StreamID)
@@ -38,7 +38,7 @@ type uniStreamSender struct {
        onHasStreamControlFrameImpl func(protocol.StreamID, streamControlFrameGetter)
 }
 
-func (s *uniStreamSender) onHasStreamData(id protocol.StreamID, str sendStreamI) {
+func (s *uniStreamSender) onHasStreamData(id protocol.StreamID, str *SendStream) {
        s.streamSender.onHasStreamData(id, str)
 }
 func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { s.onStreamCompletedImpl() }
@@ -50,7 +50,7 @@ var _ streamSender = &uniStreamSender{}
 
 type Stream struct {
        receiveStream
-       sendStream
+       *SendStream
 
        completedMutex         sync.Mutex
        sender                 streamSender
@@ -80,7 +80,7 @@ func newStream(
                        sender.onHasStreamControlFrame(streamID, s)
                },
        }
-       s.sendStream = *newSendStream(ctx, streamID, senderForSendStream, flowController)
+       s.SendStream = newSendStream(ctx, streamID, senderForSendStream, flowController)
        senderForReceiveStream := &uniStreamSender{
                streamSender: sender,
                onStreamCompletedImpl: func() {
@@ -100,15 +100,15 @@ func newStream(
 // need to define StreamID() here, since both receiveStream and readStream have a StreamID()
 func (s *Stream) StreamID() protocol.StreamID {
        // the result is same for receiveStream and sendStream
-       return s.sendStream.StreamID()
+       return s.SendStream.StreamID()
 }
 
 func (s *Stream) Close() error {
-       return s.sendStream.Close()
+       return s.SendStream.Close()
 }
 
 func (s *Stream) getControlFrame(now time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
-       f, ok, _ := s.sendStream.getControlFrame(now)
+       f, ok, _ := s.SendStream.getControlFrame(now)
        if ok {
                return f, true, true
        }
@@ -125,7 +125,7 @@ func (s *Stream) SetDeadline(t time.Time) error {
 // It makes Read and Write unblock (and return the error) immediately.
 // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
 func (s *Stream) closeForShutdown(err error) {
-       s.sendStream.closeForShutdown(err)
+       s.SendStream.closeForShutdown(err)
        s.receiveStream.closeForShutdown(err)
 }
 
index 186d4e6b8327d826a266f3a6d37fc9993c1c9599..0f2f01ff47c746a78c6f7d5efbf0338425647da3 100644 (file)
@@ -52,7 +52,7 @@ type streamsMap struct {
 
        mutex               sync.Mutex
        outgoingBidiStreams *outgoingStreamsMap[*Stream]
-       outgoingUniStreams  *outgoingStreamsMap[sendStreamI]
+       outgoingUniStreams  *outgoingStreamsMap[*SendStream]
        incomingBidiStreams *incomingStreamsMap[*Stream]
        incomingUniStreams  *incomingStreamsMap[receiveStreamI]
        reset               bool
@@ -102,7 +102,7 @@ func (m *streamsMap) initMaps() {
        )
        m.outgoingUniStreams = newOutgoingStreamsMap(
                protocol.StreamTypeUni,
-               func(num protocol.StreamNum) sendStreamI {
+               func(num protocol.StreamNum) *SendStream {
                        id := num.StreamID(protocol.StreamTypeUni, m.perspective)
                        return newSendStream(m.ctx, id, m.sender, m.newFlowController(id))
                },
@@ -143,7 +143,7 @@ func (m *streamsMap) OpenStreamSync(ctx context.Context) (*Stream, error) {
        return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
 }
 
-func (m *streamsMap) OpenUniStream() (SendStream, error) {
+func (m *streamsMap) OpenUniStream() (*SendStream, error) {
        m.mutex.Lock()
        reset := m.reset
        mm := m.outgoingUniStreams
@@ -155,7 +155,7 @@ func (m *streamsMap) OpenUniStream() (SendStream, error) {
        return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
 }
 
-func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
+func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (*SendStream, error) {
        m.mutex.Lock()
        reset := m.reset
        mm := m.outgoingUniStreams
@@ -247,7 +247,7 @@ func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStream
        panic("")
 }
 
-func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
+func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
        str, err := m.getOrOpenSendStream(id)
        if err != nil {
                return nil, &qerr.TransportError{
@@ -258,12 +258,15 @@ func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, err
        return str, nil
 }
 
-func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
+func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (*SendStream, error) {
        num := id.StreamNum()
        switch id.Type() {
        case protocol.StreamTypeUni:
                if id.InitiatedBy() == m.perspective {
                        str, err := m.outgoingUniStreams.GetStream(num)
+                       if str == nil && err == nil {
+                               return nil, nil
+                       }
                        return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
                }
                // an incoming unidirectional stream is a receive stream, not a send stream
@@ -274,13 +277,19 @@ func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, err
                        if str == nil && err == nil {
                                return nil, nil
                        }
-                       return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
+                       if err != nil {
+                               return nil, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
+                       }
+                       return str.SendStream, nil
                } else {
                        str, err := m.incomingBidiStreams.GetOrOpenStream(num)
                        if str == nil && err == nil {
                                return nil, nil
                        }
-                       return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
+                       if err != nil {
+                               return nil, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
+                       }
+                       return str.SendStream, nil
                }
        }
        panic("")