"google.golang.org/grpc/codes"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt"
ht *serverHandlerTransport
}
-func newHandleStreamTest(t *testing.T) *handleStreamTest {
+type mockStatsHandler struct {
+ rpcStatsCh chan stats.RPCStats
+}
+
+func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
+ return ctx
+}
+
+func (h *mockStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
+ h.rpcStatsCh <- s
+}
+
+func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
+ return ctx
+}
+
+func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) {
+}
+
+func newHandleStreamTest(t *testing.T, statsHandlers []stats.Handler) *handleStreamTest {
bodyr, bodyw := io.Pipe()
req := &http.Request{
ProtoMajor: 2,
Body: bodyr,
}
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
- ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
+ ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool())
if err != nil {
t.Fatal(err)
}
}
func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
- st := newHandleStreamTest(t)
+ st := newHandleStreamTest(t, nil)
handleStream := func(s *ServerStream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
- st := newHandleStreamTest(t)
+ st := newHandleStreamTest(t, nil)
handleStream := func(s *ServerStream) {
s.WriteStatus(status.New(statusCode, msg))
}
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
- st := newHandleStreamTest(t)
+ st := newHandleStreamTest(t, nil)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
t.Cleanup(cancel)
st.ht.HandleStreams(
t.Fatal(err)
}
- hst := newHandleStreamTest(t)
+ hst := newHandleStreamTest(t, nil)
handleStream := func(s *ServerStream) {
s.WriteStatus(st)
}
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
}
+// Tests the use of stats handlers and ensures there are no data races while
+// accessing trailers.
+func (s) TestHandlerTransport_HandleStreams_StatsHandlers(t *testing.T) {
+ errDetails := []protoadapt.MessageV1{
+ &epb.RetryInfo{
+ RetryDelay: &durationpb.Duration{Seconds: 60},
+ },
+ &epb.ResourceInfo{
+ ResourceType: "foo bar",
+ ResourceName: "service.foo.bar",
+ Owner: "User",
+ },
+ }
+
+ statusCode := codes.ResourceExhausted
+ msg := "you are being throttled"
+ st, err := status.New(statusCode, msg).WithDetails(errDetails...)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ stBytes, err := proto.Marshal(st.Proto())
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Add mock stats handlers to exercise the stats handler code path.
+ statsHandler := &mockStatsHandler{
+ rpcStatsCh: make(chan stats.RPCStats, 2),
+ }
+ hst := newHandleStreamTest(t, []stats.Handler{statsHandler})
+ handleStream := func(s *ServerStream) {
+ if err := s.SendHeader(metadata.New(map[string]string{})); err != nil {
+ t.Error(err)
+ }
+ if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
+ t.Error(err)
+ }
+ s.WriteStatus(st)
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ hst.ht.HandleStreams(
+ ctx, func(s *ServerStream) { go handleStream(s) },
+ )
+ wantHeader := http.Header{
+ "Date": nil,
+ "Content-Type": {"application/grpc"},
+ "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
+ }
+ wantTrailer := http.Header{
+ "Grpc-Status": {fmt.Sprint(uint32(statusCode))},
+ "Grpc-Message": {encodeGrpcMessage(msg)},
+ "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
+ "Custom-Trailer": []string{"Custom trailer value"},
+ }
+
+ checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
+ wantStatTypes := []stats.RPCStats{&stats.OutHeader{}, &stats.OutTrailer{}}
+ for _, wantType := range wantStatTypes {
+ select {
+ case <-ctx.Done():
+ t.Fatal("Context timed out waiting for statsHandler.HandleRPC() to be called.")
+ case s := <-statsHandler.rpcStatsCh:
+ if reflect.TypeOf(s) != reflect.TypeOf(wantType) {
+ t.Fatalf("Received RPCStats of type %T, want %T", s, wantType)
+ }
+ }
+ }
+}
+
// TestHandlerTransport_Drain verifies that Drain() is not implemented
// by `serverHandlerTransport`.
func (s) TestHandlerTransport_Drain(t *testing.T) {
defer func() { recover() }()
- st := newHandleStreamTest(t)
+ st := newHandleStreamTest(t, nil)
st.ht.Drain("whatever")
t.Errorf("serverHandlerTransport.Drain() should have panicked")
}