"net/url"
"os"
+ "google.golang.org/grpc/grpclog"
credinternal "google.golang.org/grpc/internal/credentials"
+ "google.golang.org/grpc/internal/envconfig"
)
const alpnFailureHelpMessage = "If you upgraded from a grpc-go version earlier than 1.67, your TLS connections may have stopped working due to ALPN enforcement. For more details, see: https://github.com/grpc/grpc-go/issues/434"
+var logger = grpclog.Component("credentials")
+
// TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface.
type TLSInfo struct {
// for using HTTP/2 over TLS. We can terminate the connection immediately.
np := conn.ConnectionState().NegotiatedProtocol
if np == "" {
- conn.Close()
- return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
+ if envconfig.EnforceALPNEnabled {
+ conn.Close()
+ return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
+ }
+ logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
}
tlsInfo := TLSInfo{
State: conn.ConnectionState(),
// support ALPN. In such cases, we can close the connection since ALPN is required
// for using HTTP/2 over TLS.
if cs.NegotiatedProtocol == "" {
- conn.Close()
- return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
+ if envconfig.EnforceALPNEnabled {
+ conn.Close()
+ return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
+ } else if logger.V(2) {
+ logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
+ }
}
tlsInfo := TLSInfo{
State: cs,
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/status"
// correctly for a server that doesn't specify the NextProtos field and uses
// GetConfigForClient to provide the TLS config during the handshake.
func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
+ initialVal := envconfig.EnforceALPNEnabled
+ defer func() {
+ envconfig.EnforceALPNEnabled = initialVal
+ }()
+ envconfig.EnforceALPNEnabled = true
+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
// connecting to a server that doesn't support ALPN.
func (s) TestTLS_DisabledALPNClient(t *testing.T) {
- listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
- Certificates: []tls.Certificate{serverCert},
- NextProtos: []string{}, // Empty list indicates ALPN is disabled.
- })
- if err != nil {
- t.Fatalf("Error starting TLS server: %v", err)
- }
-
- errCh := make(chan error, 1)
- go func() {
- conn, err := listener.Accept()
- if err != nil {
- errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
- } else {
- // The first write to the TLS listener initiates the TLS handshake.
- conn.Write([]byte("Hello, World!"))
- conn.Close()
- }
- close(errCh)
+ initialVal := envconfig.EnforceALPNEnabled
+ defer func() {
+ envconfig.EnforceALPNEnabled = initialVal
}()
- serverAddr := listener.Addr().String()
- conn, err := net.Dial("tcp", serverAddr)
- if err != nil {
- t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
+ tests := []struct {
+ name string
+ alpnEnforced bool
+ wantErr bool
+ }{
+ {
+ name: "enforced",
+ alpnEnforced: true,
+ wantErr: true,
+ },
+ {
+ name: "not_enforced",
+ },
}
- defer conn.Close()
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ envconfig.EnforceALPNEnabled = tc.alpnEnforced
- clientCfg := tls.Config{
- ServerName: serverName,
- RootCAs: certPool,
- NextProtos: []string{"h2"},
- }
- _, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
+ listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ NextProtos: []string{}, // Empty list indicates ALPN is disabled.
+ })
+ if err != nil {
+ t.Fatalf("Error starting TLS server: %v", err)
+ }
- if gotErr, wantErr := (err != nil), true; gotErr != wantErr {
- t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
- }
+ errCh := make(chan error, 1)
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+ } else {
+ // The first write to the TLS listener initiates the TLS handshake.
+ conn.Write([]byte("Hello, World!"))
+ conn.Close()
+ }
+ close(errCh)
+ }()
+
+ serverAddr := listener.Addr().String()
+ conn, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
+ }
+ defer conn.Close()
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+
+ clientCfg := tls.Config{
+ ServerName: serverName,
+ RootCAs: certPool,
+ NextProtos: []string{"h2"},
+ }
+ _, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
+
+ if gotErr := (err != nil); gotErr != tc.wantErr {
+ t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+ }
- select {
- case err := <-errCh:
- if err != nil {
- t.Fatalf("Unexpected error received from server: %v", err)
- }
- case <-ctx.Done():
- t.Fatalf("Timeout waiting for error from server")
+ select {
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("Unexpected error received from server: %v", err)
+ }
+ case <-ctx.Done():
+ t.Fatalf("Timeout waiting for error from server")
+ }
+ })
}
}
// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
// accepting a request from a client that doesn't support ALPN.
func (s) TestTLS_DisabledALPNServer(t *testing.T) {
- listener, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- t.Fatalf("Error starting server: %v", err)
- }
-
- errCh := make(chan error, 1)
- go func() {
- conn, err := listener.Accept()
- if err != nil {
- errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
- return
- }
- defer conn.Close()
- serverCfg := tls.Config{
- Certificates: []tls.Certificate{serverCert},
- NextProtos: []string{"h2"},
- }
- _, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
- if gotErr, wantErr := (err != nil), true; gotErr != wantErr {
- t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
- }
- close(errCh)
+ initialVal := envconfig.EnforceALPNEnabled
+ defer func() {
+ envconfig.EnforceALPNEnabled = initialVal
}()
- serverAddr := listener.Addr().String()
- clientCfg := &tls.Config{
- Certificates: []tls.Certificate{serverCert},
- NextProtos: []string{}, // Empty list indicates ALPN is disabled.
- RootCAs: certPool,
- ServerName: serverName,
- }
- conn, err := tls.Dial("tcp", serverAddr, clientCfg)
- if err != nil {
- t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
+ tests := []struct {
+ name string
+ alpnEnforced bool
+ wantErr bool
+ }{
+ {
+ name: "enforced",
+ alpnEnforced: true,
+ wantErr: true,
+ },
+ {
+ name: "not_enforced",
+ },
}
- defer conn.Close()
-
- select {
- case <-time.After(defaultTestTimeout):
- t.Fatal("Timed out waiting for completion")
- case err := <-errCh:
- if err != nil {
- t.Fatalf("Unexpected server error: %v", err)
- }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ envconfig.EnforceALPNEnabled = tc.alpnEnforced
+
+ listener, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("Error starting server: %v", err)
+ }
+
+ errCh := make(chan error, 1)
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+ return
+ }
+ defer conn.Close()
+ serverCfg := tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ NextProtos: []string{"h2"},
+ }
+ _, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
+ if gotErr := (err != nil); gotErr != tc.wantErr {
+ t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+ }
+ close(errCh)
+ }()
+
+ serverAddr := listener.Addr().String()
+ clientCfg := &tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ NextProtos: []string{}, // Empty list indicates ALPN is disabled.
+ RootCAs: certPool,
+ ServerName: serverName,
+ }
+ conn, err := tls.Dial("tcp", serverAddr, clientCfg)
+ if err != nil {
+ t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
+ }
+ defer conn.Close()
+
+ select {
+ case <-time.After(defaultTestTimeout):
+ t.Fatal("Timed out waiting for completion")
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("Unexpected server error: %v", err)
+ }
+ }
+ })
}
}
"google.golang.org/grpc/codes"
credsstable "google.golang.org/grpc/credentials"
"google.golang.org/grpc/experimental/credentials"
+ "google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/status"
// correctly for a server that doesn't specify the NextProtos field and uses
// GetConfigForClient to provide the TLS config during the handshake.
func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
+ initialVal := envconfig.EnforceALPNEnabled
+ defer func() {
+ envconfig.EnforceALPNEnabled = initialVal
+ }()
+ envconfig.EnforceALPNEnabled = true
+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
// connecting to a server that doesn't support ALPN.
func (s) TestTLS_DisabledALPNClient(t *testing.T) {
- listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
- Certificates: []tls.Certificate{serverCert},
- NextProtos: []string{}, // Empty list indicates ALPN is disabled.
- })
- if err != nil {
- t.Fatalf("Error starting TLS server: %v", err)
- }
-
- errCh := make(chan error, 1)
- go func() {
- conn, err := listener.Accept()
- if err != nil {
- errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
- } else {
- // The first write to the TLS listener initiates the TLS handshake.
- conn.Write([]byte("Hello, World!"))
- conn.Close()
- }
- close(errCh)
+ initialVal := envconfig.EnforceALPNEnabled
+ defer func() {
+ envconfig.EnforceALPNEnabled = initialVal
}()
- serverAddr := listener.Addr().String()
- conn, err := net.Dial("tcp", serverAddr)
- if err != nil {
- t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
+ tests := []struct {
+ name string
+ alpnEnforced bool
+ wantErr bool
+ }{
+ {
+ name: "enforced",
+ },
+ {
+ name: "not_enforced",
+ },
}
- defer conn.Close()
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ envconfig.EnforceALPNEnabled = tc.alpnEnforced
- clientCfg := tls.Config{
- ServerName: serverName,
- RootCAs: certPool,
- NextProtos: []string{"h2"},
- }
- _, _, err = credentials.NewTLSWithALPNDisabled(&clientCfg).ClientHandshake(ctx, serverName, conn)
+ listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ NextProtos: []string{}, // Empty list indicates ALPN is disabled.
+ })
+ if err != nil {
+ t.Fatalf("Error starting TLS server: %v", err)
+ }
- if gotErr, wantErr := (err != nil), false; gotErr != wantErr {
- t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
- }
+ errCh := make(chan error, 1)
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+ } else {
+ // The first write to the TLS listener initiates the TLS handshake.
+ conn.Write([]byte("Hello, World!"))
+ conn.Close()
+ }
+ close(errCh)
+ }()
+
+ serverAddr := listener.Addr().String()
+ conn, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
+ }
+ defer conn.Close()
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+
+ clientCfg := tls.Config{
+ ServerName: serverName,
+ RootCAs: certPool,
+ NextProtos: []string{"h2"},
+ }
+ _, _, err = credentials.NewTLSWithALPNDisabled(&clientCfg).ClientHandshake(ctx, serverName, conn)
+
+ if gotErr := (err != nil); gotErr != tc.wantErr {
+ t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+ }
- select {
- case err := <-errCh:
- if err != nil {
- t.Fatalf("Unexpected error received from server: %v", err)
- }
- case <-ctx.Done():
- t.Fatalf("Timeout waiting for error from server")
+ select {
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("Unexpected error received from server: %v", err)
+ }
+ case <-ctx.Done():
+ t.Fatalf("Timeout waiting for error from server")
+ }
+ })
}
}
// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
// accepting a request from a client that doesn't support ALPN.
func (s) TestTLS_DisabledALPNServer(t *testing.T) {
- listener, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- t.Fatalf("Error starting server: %v", err)
- }
-
- errCh := make(chan error, 1)
- go func() {
- conn, err := listener.Accept()
- if err != nil {
- errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
- return
- }
- defer conn.Close()
- serverCfg := tls.Config{
- Certificates: []tls.Certificate{serverCert},
- NextProtos: []string{"h2"},
- }
- _, _, err = credentials.NewTLSWithALPNDisabled(&serverCfg).ServerHandshake(conn)
- if gotErr, wantErr := (err != nil), false; gotErr != wantErr {
- t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
- }
- close(errCh)
+ initialVal := envconfig.EnforceALPNEnabled
+ defer func() {
+ envconfig.EnforceALPNEnabled = initialVal
}()
- serverAddr := listener.Addr().String()
- clientCfg := &tls.Config{
- Certificates: []tls.Certificate{serverCert},
- NextProtos: []string{}, // Empty list indicates ALPN is disabled.
- RootCAs: certPool,
- ServerName: serverName,
- }
- conn, err := tls.Dial("tcp", serverAddr, clientCfg)
- if err != nil {
- t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
+ tests := []struct {
+ name string
+ alpnEnforced bool
+ wantErr bool
+ }{
+ {
+ name: "enforced",
+ },
+ {
+ name: "not_enforced",
+ },
}
- defer conn.Close()
-
- select {
- case <-time.After(defaultTestTimeout):
- t.Fatal("Timed out waiting for completion")
- case err := <-errCh:
- if err != nil {
- t.Fatalf("Unexpected server error: %v", err)
- }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ envconfig.EnforceALPNEnabled = tc.alpnEnforced
+
+ listener, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("Error starting server: %v", err)
+ }
+
+ errCh := make(chan error, 1)
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
+ return
+ }
+ defer conn.Close()
+ serverCfg := tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ NextProtos: []string{"h2"},
+ }
+ _, _, err = credentials.NewTLSWithALPNDisabled(&serverCfg).ServerHandshake(conn)
+ if gotErr := (err != nil); gotErr != tc.wantErr {
+ t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
+ }
+ close(errCh)
+ }()
+
+ serverAddr := listener.Addr().String()
+ clientCfg := &tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ NextProtos: []string{}, // Empty list indicates ALPN is disabled.
+ RootCAs: certPool,
+ ServerName: serverName,
+ }
+ conn, err := tls.Dial("tcp", serverAddr, clientCfg)
+ if err != nil {
+ t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
+ }
+ defer conn.Close()
+
+ select {
+ case <-time.After(defaultTestTimeout):
+ t.Fatal("Timed out waiting for completion")
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("Unexpected server error: %v", err)
+ }
+ }
+ })
}
}
// handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
+ // EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
+ // should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
+ // option is present for backward compatibility. This option may be overridden
+ // by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
+ // or "false".
+ EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", true)
+
// NewPickFirstEnabled is set if the new pickfirst leaf policy is to be used
// instead of the exiting pickfirst implementation. This can be disabled by
// setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST"