func (c *ClientConn) doRequest(req *http.Request, str *RequestStream) (*http.Response, error) {
trace := httptrace.ContextClientTrace(req.Context())
+ var sendingReqFailed bool
if err := str.sendRequestHeader(req); err != nil {
traceWroteRequest(trace, err)
- return nil, err
+ if c.logger != nil {
+ c.logger.Debug("error writing request", "error", err)
+ }
+ sendingReqFailed = true
}
- if req.Body == nil {
- traceWroteRequest(trace, nil)
- str.Close()
- } else {
- // send the request body asynchronously
- go func() {
- contentLength := int64(-1)
- // According to the documentation for http.Request.ContentLength,
- // a value of 0 with a non-nil Body is also treated as unknown content length.
- if req.ContentLength > 0 {
- contentLength = req.ContentLength
- }
- err := c.sendRequestBody(str, req.Body, contentLength)
- traceWroteRequest(trace, err)
- if err != nil {
- if c.logger != nil {
- c.logger.Debug("error writing request", "error", err)
- }
- }
+ if !sendingReqFailed {
+ if req.Body == nil {
+ traceWroteRequest(trace, nil)
str.Close()
- }()
+ } else {
+ // send the request body asynchronously
+ go func() {
+ contentLength := int64(-1)
+ // According to the documentation for http.Request.ContentLength,
+ // a value of 0 with a non-nil Body is also treated as unknown content length.
+ if req.ContentLength > 0 {
+ contentLength = req.ContentLength
+ }
+ err := c.sendRequestBody(str, req.Body, contentLength)
+ traceWroteRequest(trace, err)
+ if err != nil {
+ if c.logger != nil {
+ c.logger.Debug("error writing request", "error", err)
+ }
+ }
+ str.Close()
+ }()
+ }
}
// copy from net/http: support 1xx responses
"context"
"errors"
"io"
+ mrand "math/rand/v2"
"net/http"
"net/http/httptest"
"testing"
return res.rsp
}
+func randomString(length int) string {
+ const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+ b := make([]byte, length)
+ for i := range b {
+ n := mrand.IntN(len(alphabet))
+ b[i] = alphabet[n]
+ }
+ return string(b)
+}
+
+func TestClientRequestError(t *testing.T) {
+ clientConn, serverConn := newConnPair(t)
+
+ req, err := http.NewRequest(http.MethodGet, "http://quic-go.net", nil)
+ require.NoError(t, err)
+ for range 1000 {
+ req.Header.Add(randomString(50), randomString(50))
+ }
+
+ type result struct {
+ rsp *http.Response
+ err error
+ }
+ resultChan := make(chan result, 1)
+ go func() {
+ cc := (&Transport{}).NewClientConn(clientConn)
+ rsp, err := cc.RoundTrip(req)
+ resultChan <- result{rsp: rsp, err: err}
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ str, err := serverConn.AcceptStream(ctx)
+ require.NoError(t, err)
+ str.CancelRead(quic.StreamErrorCode(ErrCodeExcessiveLoad))
+
+ _, err = str.Write(encodeResponse(t, http.StatusTeapot))
+ require.NoError(t, err)
+
+ var res result
+ select {
+ case res = <-resultChan:
+ require.NoError(t, res.err)
+ require.Equal(t, http.StatusTeapot, res.rsp.StatusCode)
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+}
+
func TestClientResponseValidation(t *testing.T) {
t.Run("HEADERS frame too large", func(t *testing.T) {
require.ErrorContains(t,