benchmark/client: add context for cancellation (#8614)
author     Ivan Mamaev <mamaevivan543@gmail.com>
           Tue, 28 Oct 2025 04:55:12 +0000 (07:55 +0300)
committer  GitHub <noreply@github.com>
           Tue, 28 Oct 2025 04:55:12 +0000 (21:55 -0700)
Fixes: #8596
RELEASE NOTES: None
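
The commit removes the benchmark client's dedicated stop channel and instead
threads a context.Context from the worker's RunClient stream through
performRPCs and into the RPC loops, which now poll ctx.Err() rather than
selecting on the channel. A minimal sketch of the pattern, with hypothetical
names, not the benchmark code itself:

    package main

    import (
    	"context"
    	"fmt"
    	"time"
    )

    // runLoop does work until its context is canceled, replacing a
    // select on a dedicated stop channel.
    func runLoop(ctx context.Context) {
    	for {
    		if ctx.Err() != nil { // canceled or deadline exceeded
    			return
    		}
    		// ... issue one RPC, record its latency ...
    		time.Sleep(10 * time.Millisecond)
    	}
    }

    func main() {
    	ctx, cancel := context.WithCancel(context.Background())
    	go runLoop(ctx)
    	time.Sleep(50 * time.Millisecond)
    	cancel() // replaces close(stop)
    	fmt.Println("cancellation signaled")
    }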

benchmark/worker/benchmark_client.go
benchmark/worker/main.go

diff --git a/benchmark/worker/benchmark_client.go b/benchmark/worker/benchmark_client.go
index c28312dd6aab9bd681e092faadc94ade556d8224..af81445ec5bfb2fa2c50b4a1b124bcb83293f9f8 100644
--- a/benchmark/worker/benchmark_client.go
+++ b/benchmark/worker/benchmark_client.go
@@ -73,7 +73,6 @@ func (h *lockingHistogram) mergeInto(merged *stats.Histogram) {
 
 type benchmarkClient struct {
        closeConns        func()
-       stop              chan bool
        lastResetTime     time.Time
        histogramOptions  stats.HistogramOptions
        lockingHistograms []lockingHistogram
@@ -168,7 +167,7 @@ func createConns(config *testpb.ClientConfig) ([]*grpc.ClientConn, func(), error
        }, nil
 }
 
-func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benchmarkClient) error {
+func performRPCs(ctx context.Context, config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benchmarkClient) error {
        // Read payload size and type from config.
        var (
                payloadReqSize, payloadRespSize int
@@ -212,9 +211,9 @@ func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benc
 
        switch config.RpcType {
        case testpb.RpcType_UNARY:
-               bc.unaryLoop(conns, rpcCountPerConn, payloadReqSize, payloadRespSize, poissonLambda)
+               bc.unaryLoop(ctx, conns, rpcCountPerConn, payloadReqSize, payloadRespSize, poissonLambda)
        case testpb.RpcType_STREAMING:
-               bc.streamingLoop(conns, rpcCountPerConn, payloadReqSize, payloadRespSize, payloadType, poissonLambda)
+               bc.streamingLoop(ctx, conns, rpcCountPerConn, payloadReqSize, payloadRespSize, payloadType, poissonLambda)
        default:
                return status.Errorf(codes.InvalidArgument, "unknown rpc type: %v", config.RpcType)
        }
@@ -222,7 +221,7 @@ func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benc
        return nil
 }
 
-func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) {
+func startBenchmarkClient(ctx context.Context, config *testpb.ClientConfig) (*benchmarkClient, error) {
        printClientConfig(config)
 
        // Set running environment like how many cores to use.
@@ -243,13 +242,12 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error)
                },
                lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns)),
 
-               stop:            make(chan bool),
                lastResetTime:   time.Now(),
                closeConns:      closeConns,
                rusageLastReset: syscall.GetRusage(),
        }
 
-       if err = performRPCs(config, conns, bc); err != nil {
+       if err = performRPCs(ctx, config, conns, bc); err != nil {
                // Close all connections if performRPCs failed.
                closeConns()
                return nil, err
@@ -258,7 +256,7 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error)
        return bc, nil
 }
 
-func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, poissonLambda *float64) {
+func (bc *benchmarkClient) unaryLoop(ctx context.Context, conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, poissonLambda *float64) {
        for ic, conn := range conns {
                client := testgrpc.NewBenchmarkServiceClient(conn)
                // For each connection, create rpcCountPerConn goroutines to do rpc.
@@ -274,10 +272,8 @@ func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn i
                                // before starting benchmark.
                                if poissonLambda == nil { // Closed loop.
                                        for {
-                                               select {
-                                               case <-bc.stop:
-                                                       return
-                                               default:
+                                               if ctx.Err() != nil {
+                                                       break
                                                }
                                                start := time.Now()
                                                if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil {
@@ -292,13 +288,12 @@ func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn i
                                                bc.poissonUnary(client, idx, reqSize, respSize, *poissonLambda)
                                        })
                                }
-
                        }(idx)
                }
        }
 }
 
-func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, payloadType string, poissonLambda *float64) {
+func (bc *benchmarkClient) streamingLoop(ctx context.Context, conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, payloadType string, poissonLambda *float64) {
        var doRPC func(testgrpc.BenchmarkService_StreamingCallClient, int, int) error
        if payloadType == "bytebuf" {
                doRPC = benchmark.DoByteBufStreamingRoundTrip
@@ -329,10 +324,8 @@ func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerCo
                                                }
                                                elapse := time.Since(start)
                                                bc.lockingHistograms[idx].add(int64(elapse))
-                                               select {
-                                               case <-bc.stop:
+                                               if ctx.Err() != nil {
                                                        return
-                                               default:
                                                }
                                        }
                                }(idx)
@@ -364,6 +357,7 @@ func (bc *benchmarkClient) poissonUnary(client testgrpc.BenchmarkServiceClient,
 func (bc *benchmarkClient) poissonStreaming(stream testgrpc.BenchmarkService_StreamingCallClient, idx int, reqSize int, respSize int, lambda float64, doRPC func(testgrpc.BenchmarkService_StreamingCallClient, int, int) error) {
        go func() {
                start := time.Now()
+
                if err := doRPC(stream, reqSize, respSize); err != nil {
                        return
                }
@@ -430,6 +424,5 @@ func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats {
 }
 
 func (bc *benchmarkClient) shutdown() {
-       close(bc.stop)
        bc.closeConns()
 }
diff --git a/benchmark/worker/main.go b/benchmark/worker/main.go
index 45893d7b15a2d4d9a6ca5a16ed92cbd92d3f3606..0a6dd965e177d2e066d09fff42d0c0b813f77210 100644
--- a/benchmark/worker/main.go
+++ b/benchmark/worker/main.go
@@ -139,7 +139,9 @@ func (s *workerServer) RunServer(stream testgrpc.WorkerService_RunServerServer)
 
 func (s *workerServer) RunClient(stream testgrpc.WorkerService_RunClientServer) error {
        var bc *benchmarkClient
+       ctx, cancel := context.WithCancel(stream.Context())
        defer func() {
+               cancel()
                // Shut down benchmark client when stream ends.
                logger.Infof("shutting down benchmark client")
                if bc != nil {
@@ -163,7 +165,7 @@ func (s *workerServer) RunClient(stream testgrpc.WorkerService_RunClientServer)
                                logger.Infof("client setup received when client already exists, shutting down the existing client")
                                bc.shutdown()
                        }
-                       bc, err = startBenchmarkClient(t.Setup)
+                       bc, err = startBenchmarkClient(ctx, t.Setup)
                        if err != nil {
                                return err
                        }
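
Because the client's context is derived via context.WithCancel(stream.Context()),
the benchmark goroutines observe cancellation both when the deferred cancel()
runs and when the RunClient stream itself terminates, which is why shutdown()
no longer needs to close a stop channel.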