]> git.feebdaed.xyz Git - 0xmirror/grpc-go.git/commitdiff
transport: ensure header mutex is held while copying trailers in handler_server ...
authorArjan Singh Bal <46515553+arjan-bal@users.noreply.github.com>
Thu, 21 Aug 2025 06:50:13 +0000 (12:20 +0530)
committerGitHub <noreply@github.com>
Thu, 21 Aug 2025 06:50:13 +0000 (12:20 +0530)
Fixes: https://github.com/grpc/grpc-go/issues/8514
The mutex that guards the trailers should be held while copying the
trailers. We do lock the mutex in [the regular gRPC server
transport](https://github.com/grpc/grpc-go/blob/9ac0ec87ca2ecc66b3c0c084708aef768637aef6/internal/transport/http2_server.go#L1140-L1142),
but have missed it in the std lib http/2 transport. The only place where
a write happens is `writeStatus()` is when the status contains a proto.

https://github.com/grpc/grpc-go/blob/4375c784450aa7e43ff15b8b2879c896d0917130/internal/transport/handler_server.go#L251-L252

RELEASE NOTES:
* transport: Fix a data race while copying headers for stats handlers in
the std lib http2 server transport.

internal/transport/handler_server.go
internal/transport/handler_server_test.go

index 3dea23573518856507cde136a345871622d40ec9..d954a64c38f40aeccdbd294d4d009627c82eefa0 100644 (file)
@@ -277,11 +277,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status
        if err == nil { // transport has not been closed
                // Note: The trailer fields are compressed with hpack after this call returns.
                // No WireLength field is set here.
+               s.hdrMu.Lock()
                for _, sh := range ht.stats {
                        sh.HandleRPC(s.Context(), &stats.OutTrailer{
                                Trailer: s.trailer.Copy(),
                        })
                }
+               s.hdrMu.Unlock()
        }
        ht.Close(errors.New("finished writing status"))
        return err
index 911022834322110c643c8aa613366e2723cc796d..e64af27411da2f057feb83ef1513bb9ff2323489 100644 (file)
@@ -35,6 +35,7 @@ import (
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/mem"
        "google.golang.org/grpc/metadata"
+       "google.golang.org/grpc/stats"
        "google.golang.org/grpc/status"
        "google.golang.org/protobuf/proto"
        "google.golang.org/protobuf/protoadapt"
@@ -246,7 +247,26 @@ type handleStreamTest struct {
        ht    *serverHandlerTransport
 }
 
-func newHandleStreamTest(t *testing.T) *handleStreamTest {
+type mockStatsHandler struct {
+       rpcStatsCh chan stats.RPCStats
+}
+
+func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
+       return ctx
+}
+
+func (h *mockStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
+       h.rpcStatsCh <- s
+}
+
+func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
+       return ctx
+}
+
+func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) {
+}
+
+func newHandleStreamTest(t *testing.T, statsHandlers []stats.Handler) *handleStreamTest {
        bodyr, bodyw := io.Pipe()
        req := &http.Request{
                ProtoMajor: 2,
@@ -260,7 +280,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
                Body: bodyr,
        }
        rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
-       ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
+       ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool())
        if err != nil {
                t.Fatal(err)
        }
@@ -273,7 +293,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
 }
 
 func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
-       st := newHandleStreamTest(t)
+       st := newHandleStreamTest(t, nil)
        handleStream := func(s *ServerStream) {
                if want := "/service/foo.bar"; s.method != want {
                        t.Errorf("stream method = %q; want %q", s.method, want)
@@ -342,7 +362,7 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
 }
 
 func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
-       st := newHandleStreamTest(t)
+       st := newHandleStreamTest(t, nil)
 
        handleStream := func(s *ServerStream) {
                s.WriteStatus(status.New(statusCode, msg))
@@ -451,7 +471,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
 }
 
 func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
-       st := newHandleStreamTest(t)
+       st := newHandleStreamTest(t, nil)
        ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
        t.Cleanup(cancel)
        st.ht.HandleStreams(
@@ -483,7 +503,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
                t.Fatal(err)
        }
 
-       hst := newHandleStreamTest(t)
+       hst := newHandleStreamTest(t, nil)
        handleStream := func(s *ServerStream) {
                s.WriteStatus(st)
        }
@@ -506,11 +526,81 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
        checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
 }
 
+// Tests the use of stats handlers and ensures there are no data races while
+// accessing trailers.
+func (s) TestHandlerTransport_HandleStreams_StatsHandlers(t *testing.T) {
+       errDetails := []protoadapt.MessageV1{
+               &epb.RetryInfo{
+                       RetryDelay: &durationpb.Duration{Seconds: 60},
+               },
+               &epb.ResourceInfo{
+                       ResourceType: "foo bar",
+                       ResourceName: "service.foo.bar",
+                       Owner:        "User",
+               },
+       }
+
+       statusCode := codes.ResourceExhausted
+       msg := "you are being throttled"
+       st, err := status.New(statusCode, msg).WithDetails(errDetails...)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       stBytes, err := proto.Marshal(st.Proto())
+       if err != nil {
+               t.Fatal(err)
+       }
+       // Add mock stats handlers to exercise the stats handler code path.
+       statsHandler := &mockStatsHandler{
+               rpcStatsCh: make(chan stats.RPCStats, 2),
+       }
+       hst := newHandleStreamTest(t, []stats.Handler{statsHandler})
+       handleStream := func(s *ServerStream) {
+               if err := s.SendHeader(metadata.New(map[string]string{})); err != nil {
+                       t.Error(err)
+               }
+               if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
+                       t.Error(err)
+               }
+               s.WriteStatus(st)
+       }
+       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+       defer cancel()
+       hst.ht.HandleStreams(
+               ctx, func(s *ServerStream) { go handleStream(s) },
+       )
+       wantHeader := http.Header{
+               "Date":         nil,
+               "Content-Type": {"application/grpc"},
+               "Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
+       }
+       wantTrailer := http.Header{
+               "Grpc-Status":             {fmt.Sprint(uint32(statusCode))},
+               "Grpc-Message":            {encodeGrpcMessage(msg)},
+               "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
+               "Custom-Trailer":          []string{"Custom trailer value"},
+       }
+
+       checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
+       wantStatTypes := []stats.RPCStats{&stats.OutHeader{}, &stats.OutTrailer{}}
+       for _, wantType := range wantStatTypes {
+               select {
+               case <-ctx.Done():
+                       t.Fatal("Context timed out waiting for statsHandler.HandleRPC() to be called.")
+               case s := <-statsHandler.rpcStatsCh:
+                       if reflect.TypeOf(s) != reflect.TypeOf(wantType) {
+                               t.Fatalf("Received RPCStats of type %T, want %T", s, wantType)
+                       }
+               }
+       }
+}
+
 // TestHandlerTransport_Drain verifies that Drain() is not implemented
 // by `serverHandlerTransport`.
 func (s) TestHandlerTransport_Drain(t *testing.T) {
        defer func() { recover() }()
-       st := newHandleStreamTest(t)
+       st := newHandleStreamTest(t, nil)
        st.ht.Drain("whatever")
        t.Errorf("serverHandlerTransport.Drain() should have panicked")
 }