]> git.feebdaed.xyz Git - 0xmirror/grpc-go.git/commitdiff
credentials/tls: Revert removal of ALPN flag from #8660 (#8664)
authorArjan Singh Bal <46515553+arjan-bal@users.noreply.github.com>
Tue, 21 Oct 2025 18:47:01 +0000 (00:17 +0530)
committerGitHub <noreply@github.com>
Tue, 21 Oct 2025 18:47:01 +0000 (00:17 +0530)
Original PR: https://github.com/grpc/grpc-go/pull/8660

This reverts commit 0037c61d300991605f745bba4d145f406c6e392d.

## Why

There are internal users of this flag that need to be updated. Internal
issue to track removal: b/454048967.

RELEASE NOTES: N/A

credentials/tls.go
credentials/tls_ext_test.go
experimental/credentials/tls_ext_test.go
internal/envconfig/envconfig.go

index 8901d0ec1755bc49b111f411c39f92f933cc6a6c..8277be7d6f855d8aa906c1727a9d5d7f73eed756 100644 (file)
@@ -28,11 +28,15 @@ import (
        "net/url"
        "os"
 
+       "google.golang.org/grpc/grpclog"
        credinternal "google.golang.org/grpc/internal/credentials"
+       "google.golang.org/grpc/internal/envconfig"
 )
 
 const alpnFailureHelpMessage = "If you upgraded from a grpc-go version earlier than 1.67, your TLS connections may have stopped working due to ALPN enforcement. For more details, see: https://github.com/grpc/grpc-go/issues/434"
 
+var logger = grpclog.Component("credentials")
+
 // TLSInfo contains the auth information for a TLS authenticated connection.
 // It implements the AuthInfo interface.
 type TLSInfo struct {
@@ -140,8 +144,11 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
        //    for using HTTP/2 over TLS. We can terminate the connection immediately.
        np := conn.ConnectionState().NegotiatedProtocol
        if np == "" {
-               conn.Close()
-               return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
+               if envconfig.EnforceALPNEnabled {
+                       conn.Close()
+                       return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
+               }
+               logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
        }
        tlsInfo := TLSInfo{
                State: conn.ConnectionState(),
@@ -167,8 +174,12 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
        // support ALPN. In such cases, we can close the connection since ALPN is required
        // for using HTTP/2 over TLS.
        if cs.NegotiatedProtocol == "" {
-               conn.Close()
-               return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
+               if envconfig.EnforceALPNEnabled {
+                       conn.Close()
+                       return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
+               } else if logger.V(2) {
+                       logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
+               }
        }
        tlsInfo := TLSInfo{
                State: cs,
index d8c559d80019dc972939869821359630e0cd7d0e..ceb810a4a477fa6bfff621f8cf5bff3b8bf4a61e 100644 (file)
@@ -32,6 +32,7 @@ import (
        "google.golang.org/grpc"
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/credentials"
+       "google.golang.org/grpc/internal/envconfig"
        "google.golang.org/grpc/internal/grpctest"
        "google.golang.org/grpc/internal/stubserver"
        "google.golang.org/grpc/status"
@@ -410,6 +411,12 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
 // correctly for a server that doesn't specify the NextProtos field and uses
 // GetConfigForClient to provide the TLS config during the handshake.
 func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
+       initialVal := envconfig.EnforceALPNEnabled
+       defer func() {
+               envconfig.EnforceALPNEnabled = initialVal
+       }()
+       envconfig.EnforceALPNEnabled = true
+
        ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
        defer cancel()
 
@@ -446,104 +453,156 @@ func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
 // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
 // connecting to a server that doesn't support ALPN.
 func (s) TestTLS_DisabledALPNClient(t *testing.T) {
-       listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
-               Certificates: []tls.Certificate{serverCert},
-               NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
-       })
-       if err != nil {
-               t.Fatalf("Error starting TLS server: %v", err)
-       }
-
-       errCh := make(chan error, 1)
-       go func() {
-               conn, err := listener.Accept()
-               if err != nil {
-                       errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
-               } else {
-                       // The first write to the TLS listener initiates the TLS handshake.
-                       conn.Write([]byte("Hello, World!"))
-                       conn.Close()
-               }
-               close(errCh)
+       initialVal := envconfig.EnforceALPNEnabled
+       defer func() {
+               envconfig.EnforceALPNEnabled = initialVal
        }()
 
-       serverAddr := listener.Addr().String()
-       conn, err := net.Dial("tcp", serverAddr)
-       if err != nil {
-               t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
+       tests := []struct {
+               name         string
+               alpnEnforced bool
+               wantErr      bool
+       }{
+               {
+                       name:         "enforced",
+                       alpnEnforced: true,
+                       wantErr:      true,
+               },
+               {
+                       name: "not_enforced",
+               },
        }
-       defer conn.Close()
 
-       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
-       defer cancel()
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       envconfig.EnforceALPNEnabled = tc.alpnEnforced
 
-       clientCfg := tls.Config{
-               ServerName: serverName,
-               RootCAs:    certPool,
-               NextProtos: []string{"h2"},
-       }
-       _, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
+                       listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
+                               Certificates: []tls.Certificate{serverCert},
+                               NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
+                       })
+                       if err != nil {
+                               t.Fatalf("Error starting TLS server: %v", err)
+                       }
 
-       if gotErr, wantErr := (err != nil), true; gotErr != wantErr {
-               t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
-       }
+                       errCh := make(chan error, 1)
+                       go func() {
+                               conn, err := listener.Accept()
+                               if err != nil {
+                                       errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+                               } else {
+                                       // The first write to the TLS listener initiates the TLS handshake.
+                                       conn.Write([]byte("Hello, World!"))
+                                       conn.Close()
+                               }
+                               close(errCh)
+                       }()
+
+                       serverAddr := listener.Addr().String()
+                       conn, err := net.Dial("tcp", serverAddr)
+                       if err != nil {
+                               t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
+                       }
+                       defer conn.Close()
+
+                       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+                       defer cancel()
+
+                       clientCfg := tls.Config{
+                               ServerName: serverName,
+                               RootCAs:    certPool,
+                               NextProtos: []string{"h2"},
+                       }
+                       _, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
+
+                       if gotErr := (err != nil); gotErr != tc.wantErr {
+                               t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+                       }
 
-       select {
-       case err := <-errCh:
-               if err != nil {
-                       t.Fatalf("Unexpected error received from server: %v", err)
-               }
-       case <-ctx.Done():
-               t.Fatalf("Timeout waiting for error from server")
+                       select {
+                       case err := <-errCh:
+                               if err != nil {
+                                       t.Fatalf("Unexpected error received from server: %v", err)
+                               }
+                       case <-ctx.Done():
+                               t.Fatalf("Timeout waiting for error from server")
+                       }
+               })
        }
 }
 
 // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
 // accepting a request from a client that doesn't support ALPN.
 func (s) TestTLS_DisabledALPNServer(t *testing.T) {
-       listener, err := net.Listen("tcp", "localhost:0")
-       if err != nil {
-               t.Fatalf("Error starting server: %v", err)
-       }
-
-       errCh := make(chan error, 1)
-       go func() {
-               conn, err := listener.Accept()
-               if err != nil {
-                       errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
-                       return
-               }
-               defer conn.Close()
-               serverCfg := tls.Config{
-                       Certificates: []tls.Certificate{serverCert},
-                       NextProtos:   []string{"h2"},
-               }
-               _, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
-               if gotErr, wantErr := (err != nil), true; gotErr != wantErr {
-                       t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
-               }
-               close(errCh)
+       initialVal := envconfig.EnforceALPNEnabled
+       defer func() {
+               envconfig.EnforceALPNEnabled = initialVal
        }()
 
-       serverAddr := listener.Addr().String()
-       clientCfg := &tls.Config{
-               Certificates: []tls.Certificate{serverCert},
-               NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
-               RootCAs:      certPool,
-               ServerName:   serverName,
-       }
-       conn, err := tls.Dial("tcp", serverAddr, clientCfg)
-       if err != nil {
-               t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
+       tests := []struct {
+               name         string
+               alpnEnforced bool
+               wantErr      bool
+       }{
+               {
+                       name:         "enforced",
+                       alpnEnforced: true,
+                       wantErr:      true,
+               },
+               {
+                       name: "not_enforced",
+               },
        }
-       defer conn.Close()
-
-       select {
-       case <-time.After(defaultTestTimeout):
-               t.Fatal("Timed out waiting for completion")
-       case err := <-errCh:
-               if err != nil {
-                       t.Fatalf("Unexpected server error: %v", err)
-               }
+
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       envconfig.EnforceALPNEnabled = tc.alpnEnforced
+
+                       listener, err := net.Listen("tcp", "localhost:0")
+                       if err != nil {
+                               t.Fatalf("Error starting server: %v", err)
+                       }
+
+                       errCh := make(chan error, 1)
+                       go func() {
+                               conn, err := listener.Accept()
+                               if err != nil {
+                                       errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+                                       return
+                               }
+                               defer conn.Close()
+                               serverCfg := tls.Config{
+                                       Certificates: []tls.Certificate{serverCert},
+                                       NextProtos:   []string{"h2"},
+                               }
+                               _, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
+                               if gotErr := (err != nil); gotErr != tc.wantErr {
+                                       t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+                               }
+                               close(errCh)
+                       }()
+
+                       serverAddr := listener.Addr().String()
+                       clientCfg := &tls.Config{
+                               Certificates: []tls.Certificate{serverCert},
+                               NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
+                               RootCAs:      certPool,
+                               ServerName:   serverName,
+                       }
+                       conn, err := tls.Dial("tcp", serverAddr, clientCfg)
+                       if err != nil {
+                               t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
+                       }
+                       defer conn.Close()
+
+                       select {
+                       case <-time.After(defaultTestTimeout):
+                               t.Fatal("Timed out waiting for completion")
+                       case err := <-errCh:
+                               if err != nil {
+                                       t.Fatalf("Unexpected server error: %v", err)
+                               }
+                       }
+               })
        }
 }
index cd4cefe69114384df774df29202a9c1ac8b5ed17..3d4e473ff8aa79fa769f783b267ead4fe59ba2e4 100644 (file)
@@ -33,6 +33,7 @@ import (
        "google.golang.org/grpc/codes"
        credsstable "google.golang.org/grpc/credentials"
        "google.golang.org/grpc/experimental/credentials"
+       "google.golang.org/grpc/internal/envconfig"
        "google.golang.org/grpc/internal/grpctest"
        "google.golang.org/grpc/internal/stubserver"
        "google.golang.org/grpc/status"
@@ -410,6 +411,12 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
 // correctly for a server that doesn't specify the NextProtos field and uses
 // GetConfigForClient to provide the TLS config during the handshake.
 func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
+       initialVal := envconfig.EnforceALPNEnabled
+       defer func() {
+               envconfig.EnforceALPNEnabled = initialVal
+       }()
+       envconfig.EnforceALPNEnabled = true
+
        ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
        defer cancel()
 
@@ -446,104 +453,152 @@ func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
 // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
 // connecting to a server that doesn't support ALPN.
 func (s) TestTLS_DisabledALPNClient(t *testing.T) {
-       listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
-               Certificates: []tls.Certificate{serverCert},
-               NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
-       })
-       if err != nil {
-               t.Fatalf("Error starting TLS server: %v", err)
-       }
-
-       errCh := make(chan error, 1)
-       go func() {
-               conn, err := listener.Accept()
-               if err != nil {
-                       errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
-               } else {
-                       // The first write to the TLS listener initiates the TLS handshake.
-                       conn.Write([]byte("Hello, World!"))
-                       conn.Close()
-               }
-               close(errCh)
+       initialVal := envconfig.EnforceALPNEnabled
+       defer func() {
+               envconfig.EnforceALPNEnabled = initialVal
        }()
 
-       serverAddr := listener.Addr().String()
-       conn, err := net.Dial("tcp", serverAddr)
-       if err != nil {
-               t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
+       tests := []struct {
+               name         string
+               alpnEnforced bool
+               wantErr      bool
+       }{
+               {
+                       name: "enforced",
+               },
+               {
+                       name: "not_enforced",
+               },
        }
-       defer conn.Close()
 
-       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
-       defer cancel()
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       envconfig.EnforceALPNEnabled = tc.alpnEnforced
 
-       clientCfg := tls.Config{
-               ServerName: serverName,
-               RootCAs:    certPool,
-               NextProtos: []string{"h2"},
-       }
-       _, _, err = credentials.NewTLSWithALPNDisabled(&clientCfg).ClientHandshake(ctx, serverName, conn)
+                       listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
+                               Certificates: []tls.Certificate{serverCert},
+                               NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
+                       })
+                       if err != nil {
+                               t.Fatalf("Error starting TLS server: %v", err)
+                       }
 
-       if gotErr, wantErr := (err != nil), false; gotErr != wantErr {
-               t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
-       }
+                       errCh := make(chan error, 1)
+                       go func() {
+                               conn, err := listener.Accept()
+                               if err != nil {
+                                       errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+                               } else {
+                                       // The first write to the TLS listener initiates the TLS handshake.
+                                       conn.Write([]byte("Hello, World!"))
+                                       conn.Close()
+                               }
+                               close(errCh)
+                       }()
+
+                       serverAddr := listener.Addr().String()
+                       conn, err := net.Dial("tcp", serverAddr)
+                       if err != nil {
+                               t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
+                       }
+                       defer conn.Close()
+
+                       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+                       defer cancel()
+
+                       clientCfg := tls.Config{
+                               ServerName: serverName,
+                               RootCAs:    certPool,
+                               NextProtos: []string{"h2"},
+                       }
+                       _, _, err = credentials.NewTLSWithALPNDisabled(&clientCfg).ClientHandshake(ctx, serverName, conn)
+
+                       if gotErr := (err != nil); gotErr != tc.wantErr {
+                               t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+                       }
 
-       select {
-       case err := <-errCh:
-               if err != nil {
-                       t.Fatalf("Unexpected error received from server: %v", err)
-               }
-       case <-ctx.Done():
-               t.Fatalf("Timeout waiting for error from server")
+                       select {
+                       case err := <-errCh:
+                               if err != nil {
+                                       t.Fatalf("Unexpected error received from server: %v", err)
+                               }
+                       case <-ctx.Done():
+                               t.Fatalf("Timeout waiting for error from server")
+                       }
+               })
        }
 }
 
 // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
 // accepting a request from a client that doesn't support ALPN.
 func (s) TestTLS_DisabledALPNServer(t *testing.T) {
-       listener, err := net.Listen("tcp", "localhost:0")
-       if err != nil {
-               t.Fatalf("Error starting server: %v", err)
-       }
-
-       errCh := make(chan error, 1)
-       go func() {
-               conn, err := listener.Accept()
-               if err != nil {
-                       errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
-                       return
-               }
-               defer conn.Close()
-               serverCfg := tls.Config{
-                       Certificates: []tls.Certificate{serverCert},
-                       NextProtos:   []string{"h2"},
-               }
-               _, _, err = credentials.NewTLSWithALPNDisabled(&serverCfg).ServerHandshake(conn)
-               if gotErr, wantErr := (err != nil), false; gotErr != wantErr {
-                       t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
-               }
-               close(errCh)
+       initialVal := envconfig.EnforceALPNEnabled
+       defer func() {
+               envconfig.EnforceALPNEnabled = initialVal
        }()
 
-       serverAddr := listener.Addr().String()
-       clientCfg := &tls.Config{
-               Certificates: []tls.Certificate{serverCert},
-               NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
-               RootCAs:      certPool,
-               ServerName:   serverName,
-       }
-       conn, err := tls.Dial("tcp", serverAddr, clientCfg)
-       if err != nil {
-               t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
+       tests := []struct {
+               name         string
+               alpnEnforced bool
+               wantErr      bool
+       }{
+               {
+                       name: "enforced",
+               },
+               {
+                       name: "not_enforced",
+               },
        }
-       defer conn.Close()
-
-       select {
-       case <-time.After(defaultTestTimeout):
-               t.Fatal("Timed out waiting for completion")
-       case err := <-errCh:
-               if err != nil {
-                       t.Fatalf("Unexpected server error: %v", err)
-               }
+
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       envconfig.EnforceALPNEnabled = tc.alpnEnforced
+
+                       listener, err := net.Listen("tcp", "localhost:0")
+                       if err != nil {
+                               t.Fatalf("Error starting server: %v", err)
+                       }
+
+                       errCh := make(chan error, 1)
+                       go func() {
+                               conn, err := listener.Accept()
+                               if err != nil {
+                                       errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+                                       return
+                               }
+                               defer conn.Close()
+                               serverCfg := tls.Config{
+                                       Certificates: []tls.Certificate{serverCert},
+                                       NextProtos:   []string{"h2"},
+                               }
+                               _, _, err = credentials.NewTLSWithALPNDisabled(&serverCfg).ServerHandshake(conn)
+                               if gotErr := (err != nil); gotErr != tc.wantErr {
+                                       t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+                               }
+                               close(errCh)
+                       }()
+
+                       serverAddr := listener.Addr().String()
+                       clientCfg := &tls.Config{
+                               Certificates: []tls.Certificate{serverCert},
+                               NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
+                               RootCAs:      certPool,
+                               ServerName:   serverName,
+                       }
+                       conn, err := tls.Dial("tcp", serverAddr, clientCfg)
+                       if err != nil {
+                               t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
+                       }
+                       defer conn.Close()
+
+                       select {
+                       case <-time.After(defaultTestTimeout):
+                               t.Fatal("Timed out waiting for completion")
+                       case err := <-errCh:
+                               if err != nil {
+                                       t.Fatalf("Unexpected server error: %v", err)
+                               }
+                       }
+               })
        }
 }
index 069bd89ab805203b9b39b697fa7146eee2f1765a..293a9a40b241b2f2614412f13efc8450e9fb09b7 100644 (file)
@@ -45,6 +45,13 @@ var (
        // handshakes that can be performed.
        ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
 
+       // EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
+       // should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
+       // option is present for backward compatibility. This option may be overridden
+       // by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
+       // or "false".
+       EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", true)
+
        // NewPickFirstEnabled is set if the new pickfirst leaf policy is to be used
        // instead of the exiting pickfirst implementation. This can be disabled by
        // setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST"