From 9c2cd38a2115d210f4499cd472c09f66951968d3 Mon Sep 17 00:00:00 2001 From: Pranjali-2501 <87357388+Pranjali-2501@users.noreply.github.com> Date: Mon, 10 Nov 2025 16:01:28 +0530 Subject: [PATCH] credentials/xds: fix goroutine leak in testServer (#8699) Fixes #8694 This PR fixes a goroutine leak in `credentials/xds/xds_client_test.go`. Previously, the `testServer` used standard `Send()` calls . If a test timed out or failed before reading the expected value, the `testServer` goroutine would block indefinitely on the channel, causing a leak. Replaced blocking `Send` calls with `SendContext` in `handleConn`. This ensures that if the test ends (canceling the context), the `testServer` stops trying to send and exits its goroutine gracefully. RELEASE NOTES: None --- credentials/xds/xds_client_test.go | 36 ++++++++++++++++-------------- credentials/xds/xds_server_test.go | 30 ++++++++++++------------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/credentials/xds/xds_client_test.go b/credentials/xds/xds_client_test.go index 023fbd4d..e013184f 100644 --- a/credentials/xds/xds_client_test.go +++ b/credentials/xds/xds_client_test.go @@ -95,32 +95,32 @@ type testHandshakeFunc func(net.Conn) handshakeResult // newTestServerWithHandshakeFunc starts a new testServer which listens for // connections on a local TCP port, and uses the provided custom handshake // function to perform TLS handshake. -func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer { +func newTestServerWithHandshakeFunc(ctx context.Context, f testHandshakeFunc) *testServer { ts := &testServer{ handshakeFunc: f, hsResult: testutils.NewChannel(), } - ts.start() + ts.start(ctx) return ts } // starts actually starts listening on a local TCP port, and spawns a goroutine // to handle new connections. -func (ts *testServer) start() error { +func (ts *testServer) start(ctx context.Context) error { lis, err := net.Listen("tcp", "localhost:0") if err != nil { return err } ts.lis = lis ts.address = lis.Addr().String() - go ts.handleConn() + go ts.handleConn(ctx) return nil } // handleConn accepts a new raw connection, and invokes the test provided // handshake function to perform TLS handshake, and returns the result on the // `hsResult` channel. -func (ts *testServer) handleConn() { +func (ts *testServer) handleConn(ctx context.Context) { for { rawConn, err := ts.lis.Accept() if err != nil { @@ -128,7 +128,7 @@ func (ts *testServer) handleConn() { return } hsr := ts.handshakeFunc(rawConn) - ts.hsResult.Send(hsr) + ts.hsResult.SendContext(ctx, hsr) } } @@ -388,7 +388,9 @@ func (s) TestClientCredsSuccess(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - ts := newTestServerWithHandshakeFunc(test.handshakeFunc) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ts := newTestServerWithHandshakeFunc(ctx, test.handshakeFunc) defer ts.stop() opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} @@ -403,8 +405,6 @@ func (s) TestClientCredsSuccess(t *testing.T) { } defer conn.Close() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() _, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx), authority, conn) if err != nil { t.Fatalf("ClientHandshake() returned failed: %q", err) @@ -418,11 +418,13 @@ func (s) TestClientCredsSuccess(t *testing.T) { func (s) TestClientCredsHandshakeTimeout(t *testing.T) { clientDone := make(chan struct{}) + ctx, sCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer sCancel() // A handshake function which simulates a handshake timeout from the // server-side by simply blocking on the client-side handshake to timeout // and not writing any handshake data. hErr := errors.New("server handshake error") - ts := newTestServerWithHandshakeFunc(func(net.Conn) handshakeResult { + ts := newTestServerWithHandshakeFunc(ctx, func(net.Conn) handshakeResult { <-clientDone return handshakeResult{err: hErr} }) @@ -442,7 +444,7 @@ func (s) TestClientCredsHandshakeTimeout(t *testing.T) { sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer sCancel() - ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN) + ctx = newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN) if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { t.Fatal("ClientHandshake() succeeded when expected to timeout") } @@ -489,7 +491,9 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - ts := newTestServerWithHandshakeFunc(test.handshakeFunc) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ts := newTestServerWithHandshakeFunc(ctx, test.handshakeFunc) defer ts.stop() opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} @@ -504,8 +508,6 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) { } defer conn.Close() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san) if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) { t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr) @@ -520,7 +522,9 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) { // approximation of the flow of events when the control plane specifies new // security config which results in new certificate providers being used. func (s) TestClientCredsProviderSwitch(t *testing.T) { - ts := newTestServerWithHandshakeFunc(testServerTLSHandshake) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ts := newTestServerWithHandshakeFunc(ctx, testServerTLSHandshake) defer ts.stop() opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} @@ -535,8 +539,6 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { } defer conn.Close() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() // Create a root provider which will fail the handshake because it does not // use the correct trust roots. root1 := makeRootProvider(t, "x509/client_ca_cert.pem") diff --git a/credentials/xds/xds_server_test.go b/credentials/xds/xds_server_test.go index 4547cf33..e83b710d 100644 --- a/credentials/xds/xds_server_test.go +++ b/credentials/xds/xds_server_test.go @@ -178,10 +178,12 @@ func (s) TestServerCredsHandshake_XDSHandshakeInfoError(t *testing.T) { if err != nil { t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Create a test server which uses the xDS server credentials created above // to perform TLS handshake on incoming connections. - ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult { // Create a wrapped conn which returns a nil HandshakeInfo and a non-nil error. conn := newWrappedConn(rawConn, nil, time.Now().Add(defaultTestTimeout)) hiErr := errors.New("xdsHandshakeInfo error") @@ -208,8 +210,6 @@ func (s) TestServerCredsHandshake_XDSHandshakeInfoError(t *testing.T) { // Read handshake result from the testServer which will return an error if // the handshake succeeded. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() val, err := ts.hsResult.Receive(ctx) if err != nil { t.Fatalf("testServer failed to return handshake result: %v", err) @@ -229,10 +229,12 @@ func (s) TestServerCredsHandshakeTimeout(t *testing.T) { if err != nil { t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Create a test server which uses the xDS server credentials created above // to perform TLS handshake on incoming connections. - ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult { hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true) // Create a wrapped conn which can return the HandshakeInfo created @@ -258,8 +260,6 @@ func (s) TestServerCredsHandshakeTimeout(t *testing.T) { defer rawConn.Close() // Read handshake result from the testServer and expect a failure result. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() val, err := ts.hsResult.Receive(ctx) if err != nil { t.Fatalf("testServer failed to return handshake result: %v", err) @@ -279,10 +279,12 @@ func (s) TestServerCredsHandshakeFailure(t *testing.T) { if err != nil { t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Create a test server which uses the xDS server credentials created above // to perform TLS handshake on incoming connections. - ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult { // Create a HandshakeInfo which has a root provider which does not match // the certificate sent by the client. hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true) @@ -314,8 +316,6 @@ func (s) TestServerCredsHandshakeFailure(t *testing.T) { // Read handshake result from the testServer which will return an error if // the handshake succeeded. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() val, err := ts.hsResult.Receive(ctx) if err != nil { t.Fatalf("testServer failed to return handshake result: %v", err) @@ -361,10 +361,12 @@ func (s) TestServerCredsHandshakeSuccess(t *testing.T) { if err != nil { t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Create a test server which uses the xDS server credentials // created above to perform TLS handshake on incoming connections. - ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult { // Create a HandshakeInfo with information from the test table. hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert) @@ -406,8 +408,6 @@ func (s) TestServerCredsHandshakeSuccess(t *testing.T) { // Read the handshake result from the testServer which contains the // TLS connection state on the server-side and compare it with the // one received on the client-side. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() val, err := ts.hsResult.Receive(ctx) if err != nil { t.Fatalf("testServer failed to return handshake result: %v", err) @@ -433,6 +433,8 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) { if err != nil { t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // The first time the handshake function is invoked, it returns a // HandshakeInfo which is expected to fail. Further invocations return a @@ -440,7 +442,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) { cnt := 0 // Create a test server which uses the xDS server credentials created above // to perform TLS handshake on incoming connections. - ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { + ts := newTestServerWithHandshakeFunc(ctx, func(rawConn net.Conn) handshakeResult { cnt++ var hi *xdsinternal.HandshakeInfo if cnt == 1 { @@ -501,8 +503,6 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) { // Read the handshake result from the testServer which contains the // TLS connection state on the server-side and compare it with the // one received on the client-side. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() val, err := ts.hsResult.Receive(ctx) if err != nil { t.Fatalf("testServer failed to return handshake result: %v", err) -- 2.43.0