]> git.feebdaed.xyz Git - 0xmirror/grpc-go.git/commitdiff
credentials: implement file-based JWT Call Credentials (part 1 for A97) (#8431)
authorDimitar Pavlov <dimpavloff@users.noreply.github.com>
Tue, 16 Sep 2025 07:20:07 +0000 (08:20 +0100)
committerGitHub <noreply@github.com>
Tue, 16 Sep 2025 07:20:07 +0000 (12:50 +0530)
Part one for https://github.com/grpc/proposal/pull/492 (A97).
This is done in a new `credentials/jwt` package to provide file-based
PerRPCCallCredentials. It can be used beyond XDS. The package handles
token reloading, caching, and validation as per A97 .

There will be a separate PR which uses it in `xds/bootstrap`.

Whilst implementing the above, I considered `credentials/oauth` and
`credentials/xds` packages instead of creating a new one. The former
package has `NewJWTAccessFromKey` and `jwtAccess` which seem very
relevant at first. However, I think the `jwtAccess` behaviour seems more
tailored towards Google services. Also, the refresh, caching, and error
behaviour for A97 is quite different than what's already there and
therefore a separate implementation would have still made sense.
WRT `credentials/xds`, it could have been extended to both handle
transport and call credentials. However, this is a bit at odds with A97
which says that the implementation should be non-XDS specific and, from
reading between the lines, usable beyond XDS.
I think the current approach makes review easier but because of the
similarities with the other two packages, it is a bit confusing to
navigate. Please let me know whether the structure should change.

Relates to https://github.com/istio/istio/issues/53532

RELEASE NOTES:
- credentials: Add `credentials/jwt` package providing file-based JWT
PerRPCCredentials (A97).

credentials/jwt/doc.go [new file with mode: 0644]
credentials/jwt/file_reader.go [new file with mode: 0644]
credentials/jwt/file_reader_test.go [new file with mode: 0644]
credentials/jwt/token_file_call_creds.go [new file with mode: 0644]
credentials/jwt/token_file_call_creds_test.go [new file with mode: 0644]

diff --git a/credentials/jwt/doc.go b/credentials/jwt/doc.go
new file mode 100644 (file)
index 0000000..bba687f
--- /dev/null
@@ -0,0 +1,50 @@
+/*
+ *
+ * Copyright 2025 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package jwt implements JWT token file-based call credentials.
+//
+// This package provides support for A97 JWT Call Credentials, allowing gRPC
+// clients to authenticate using JWT tokens read from files. While originally
+// designed for xDS environments, these credentials are general-purpose.
+//
+// The credentials can be used directly in gRPC clients or configured via xDS.
+//
+// # Token Requirements
+//
+// JWT tokens must:
+//   - Be valid, well-formed JWT tokens with header, payload, and signature
+//   - Include an "exp" (expiration) claim
+//   - Be readable from the specified file path
+//
+// # Considerations
+//
+// - Tokens are cached until expiration to avoid excessive file I/O
+// - Transport security is required (RequireTransportSecurity returns true)
+// - Errors in reading tokens or parsing JWTs will result in RPC UNAVAILALBE or
+// UNAUTHENTICATED errors. The errors are cached and retried with exponential
+// backoff.
+//
+// This implementation is originally intended for use in service mesh
+// environments like Istio where JWT tokens are provisioned and rotated by the
+// infrastructure.
+//
+// # Experimental
+//
+// Notice: All APIs in this package are experimental and may be removed in a
+// later release.
+package jwt
diff --git a/credentials/jwt/file_reader.go b/credentials/jwt/file_reader.go
new file mode 100644 (file)
index 0000000..8337608
--- /dev/null
@@ -0,0 +1,118 @@
+/*
+ *
+ * Copyright 2025 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package jwt
+
+import (
+       "encoding/base64"
+       "encoding/json"
+       "errors"
+       "fmt"
+       "os"
+       "strings"
+       "time"
+)
+
+var (
+       errTokenFileAccess = errors.New("token file access error")
+       errJWTValidation   = errors.New("invalid JWT")
+)
+
+// jwtClaims represents the JWT claims structure for extracting expiration time.
+type jwtClaims struct {
+       Exp int64 `json:"exp"`
+}
+
+// jwtFileReader handles reading and parsing JWT tokens from files.
+// It is safe to call methods on this type concurrently as no state is stored.
+type jwtFileReader struct {
+       tokenFilePath string
+}
+
+// readToken reads and parses a JWT token from the configured file.
+// Returns the token string, expiration time, and any error encountered.
+func (r *jwtFileReader) readToken() (string, time.Time, error) {
+       tokenBytes, err := os.ReadFile(r.tokenFilePath)
+       if err != nil {
+               return "", time.Time{}, fmt.Errorf("%v: %w", err, errTokenFileAccess)
+       }
+
+       token := strings.TrimSpace(string(tokenBytes))
+       if token == "" {
+               return "", time.Time{}, fmt.Errorf("token file %q is empty: %w", r.tokenFilePath, errJWTValidation)
+       }
+
+       exp, err := r.extractExpiration(token)
+       if err != nil {
+               return "", time.Time{}, fmt.Errorf("token file %q: %v: %w", r.tokenFilePath, err, errJWTValidation)
+       }
+
+       return token, exp, nil
+}
+
+const tokenDelim = "."
+
+// extractClaimsRaw returns the JWT's claims part as raw string. Even though the
+// header and signature are not used, it still expects that the input string to
+// be well-formed (ie comprised of exactly three parts, separated by a dot
+// character).
+func extractClaimsRaw(s string) (string, bool) {
+       _, s, ok := strings.Cut(s, tokenDelim)
+       if !ok { // no period found
+               return "", false
+       }
+       claims, s, ok := strings.Cut(s, tokenDelim)
+       if !ok { // only one period found
+               return "", false
+       }
+       _, _, ok = strings.Cut(s, tokenDelim)
+       if ok { // three periods found
+               return "", false
+       }
+       return claims, true
+}
+
+// extractExpiration parses the JWT token to extract the expiration time.
+func (r *jwtFileReader) extractExpiration(token string) (time.Time, error) {
+       claimsRaw, ok := extractClaimsRaw(token)
+       if !ok {
+               return time.Time{}, fmt.Errorf("expected 3 parts in token")
+       }
+       payloadBytes, err := base64.RawURLEncoding.DecodeString(claimsRaw)
+       if err != nil {
+               return time.Time{}, fmt.Errorf("decode error: %v", err)
+       }
+
+       var claims jwtClaims
+       if err := json.Unmarshal(payloadBytes, &claims); err != nil {
+               return time.Time{}, fmt.Errorf("unmarshal error: %v", err)
+       }
+
+       if claims.Exp == 0 {
+               return time.Time{}, fmt.Errorf("no expiration claims")
+       }
+
+       expTime := time.Unix(claims.Exp, 0)
+
+       // Check if token is already expired.
+       if expTime.Before(time.Now()) {
+               return time.Time{}, fmt.Errorf("expired token")
+       }
+
+       return expTime, nil
+}
diff --git a/credentials/jwt/file_reader_test.go b/credentials/jwt/file_reader_test.go
new file mode 100644 (file)
index 0000000..481f01f
--- /dev/null
@@ -0,0 +1,172 @@
+/*
+ *
+ * Copyright 2025 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package jwt
+
+import (
+       "encoding/base64"
+       "encoding/json"
+       "errors"
+       "fmt"
+       "strings"
+       "testing"
+       "time"
+)
+
+func (s) TestJWTFileReader_ReadToken_FileErrors(t *testing.T) {
+       tests := []struct {
+               name     string
+               create   bool
+               contents string
+               wantErr  error
+       }{
+               {
+                       name:     "nonexistent_file",
+                       create:   false,
+                       contents: "",
+                       wantErr:  errTokenFileAccess,
+               },
+               {
+                       name:     "empty_file",
+                       create:   true,
+                       contents: "",
+                       wantErr:  errJWTValidation,
+               },
+               {
+                       name:     "file_with_whitespace_only",
+                       create:   true,
+                       contents: "   \n\t  ",
+                       wantErr:  errJWTValidation,
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       var tokenFile string
+                       if !tt.create {
+                               tokenFile = "/does-not-exist"
+                       } else {
+                               tokenFile = writeTempFile(t, "token", tt.contents)
+                       }
+
+                       reader := jwtFileReader{tokenFilePath: tokenFile}
+                       if _, _, err := reader.readToken(); err == nil {
+                               t.Fatal("ReadToken() expected error, got nil")
+                       } else if !errors.Is(err, tt.wantErr) {
+                               t.Fatalf("ReadToken() error = %v, want error %v", err, tt.wantErr)
+                       }
+               })
+       }
+}
+
+func (s) TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) {
+       now := time.Now().Truncate(time.Second)
+       tests := []struct {
+               name         string
+               tokenContent string
+               wantErr      error
+       }{
+               {
+                       name:         "valid_token_without_expiration",
+                       tokenContent: createTestJWT(t, time.Time{}),
+                       wantErr:      errJWTValidation,
+               },
+               {
+                       name:         "expired_token",
+                       tokenContent: createTestJWT(t, now.Add(-time.Hour)),
+                       wantErr:      errJWTValidation,
+               },
+               {
+                       name:         "malformed_JWT_not_enough_parts",
+                       tokenContent: "invalid.jwt",
+                       wantErr:      errJWTValidation,
+               },
+               {
+                       name:         "malformed_JWT_invalid_base64",
+                       tokenContent: "header.invalid_base64!@#.signature",
+                       wantErr:      errJWTValidation,
+               },
+               {
+                       name:         "malformed_JWT_invalid_JSON",
+                       tokenContent: createInvalidJWT(t),
+                       wantErr:      errJWTValidation,
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       tokenFile := writeTempFile(t, "token", tt.tokenContent)
+
+                       reader := jwtFileReader{tokenFilePath: tokenFile}
+                       if _, _, err := reader.readToken(); err == nil {
+                               t.Fatal("ReadToken() expected error, got nil")
+                       } else if !errors.Is(err, tt.wantErr) {
+                               t.Fatalf("ReadToken() error = %v, want error %v", err, tt.wantErr)
+                       }
+               })
+       }
+}
+
+func (s) TestJWTFileReader_ReadToken_ValidToken(t *testing.T) {
+       now := time.Now().Truncate(time.Second)
+       tokenExp := now.Add(time.Hour)
+       token := createTestJWT(t, tokenExp)
+       tokenFile := writeTempFile(t, "token", token)
+
+       reader := jwtFileReader{tokenFilePath: tokenFile}
+       readToken, expiry, err := reader.readToken()
+       if err != nil {
+               t.Fatalf("ReadToken() unexpected error: %v", err)
+       }
+
+       if readToken != token {
+               t.Errorf("ReadToken() token = %q, want %q", readToken, token)
+       }
+
+       if !expiry.Equal(tokenExp) {
+               t.Errorf("ReadToken() expiry = %v, want %v", expiry, tokenExp)
+       }
+}
+
+// createInvalidJWT creates a JWT with invalid JSON in the payload.
+func createInvalidJWT(t *testing.T) string {
+       t.Helper()
+
+       header := map[string]any{
+               "typ": "JWT",
+               "alg": "HS256",
+       }
+
+       headerBytes, err := json.Marshal(header)
+       if err != nil {
+               t.Fatalf("Failed to marshal header: %v", err)
+       }
+
+       headerB64 := base64.URLEncoding.EncodeToString(headerBytes)
+       headerB64 = strings.TrimRight(headerB64, "=")
+
+       // Create invalid JSON payload
+       invalidJSON := "invalid json content"
+       payloadB64 := base64.URLEncoding.EncodeToString([]byte(invalidJSON))
+       payloadB64 = strings.TrimRight(payloadB64, "=")
+
+       signature := base64.URLEncoding.EncodeToString([]byte("fake_signature"))
+       signature = strings.TrimRight(signature, "=")
+
+       return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signature)
+}
diff --git a/credentials/jwt/token_file_call_creds.go b/credentials/jwt/token_file_call_creds.go
new file mode 100644 (file)
index 0000000..2005cf4
--- /dev/null
@@ -0,0 +1,180 @@
+/*
+ *
+ * Copyright 2025 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package jwt
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "sync"
+       "time"
+
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/credentials"
+       "google.golang.org/grpc/internal/backoff"
+       "google.golang.org/grpc/status"
+)
+
+const preemptiveRefreshThreshold = time.Minute
+
+// jwtTokenFileCallCreds provides JWT token-based PerRPCCredentials that reads
+// tokens from a file.
+// This implementation follows the A97 JWT Call Credentials specification.
+type jwtTokenFileCallCreds struct {
+       fileReader      *jwtFileReader
+       backoffStrategy backoff.Strategy
+
+       // cached data protected by mu
+       mu               sync.Mutex
+       cachedAuthHeader string    // "Bearer " + token
+       cachedExpiry     time.Time // Slightly less than actual expiration time
+       cachedError      error     // Error from last failed attempt
+       retryAttempt     int       // Current retry attempt number
+       nextRetryTime    time.Time // When next retry is allowed
+       pendingRefresh   bool      // Whether a refresh is currently in progress
+}
+
+// NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens
+// from the specified file path.
+func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCredentials, error) {
+       if tokenFilePath == "" {
+               return nil, fmt.Errorf("tokenFilePath cannot be empty")
+       }
+
+       creds := &jwtTokenFileCallCreds{
+               fileReader:      &jwtFileReader{tokenFilePath: tokenFilePath},
+               backoffStrategy: backoff.DefaultExponential,
+       }
+
+       return creds, nil
+}
+
+// GetRequestMetadata gets the current request metadata, refreshing tokens if
+// required. This implementation follows the PerRPCCredentials interface.  The
+// tokens will get automatically refreshed if they are about to expire or if
+// they haven't been loaded successfully yet.
+// If it's not possible to extract a token from the file, UNAVAILABLE is
+// returned.
+// If the token is extracted but invalid, then UNAUTHENTICATED is returned.
+// If errors are encoutered, a backoff is applied before retrying.
+func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
+       ri, _ := credentials.RequestInfoFromContext(ctx)
+       if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
+               return nil, fmt.Errorf("cannot send secure credentials on an insecure connection: %v", err)
+       }
+
+       c.mu.Lock()
+       defer c.mu.Unlock()
+
+       if c.isTokenValidLocked() {
+               needsPreemptiveRefresh := time.Until(c.cachedExpiry) < preemptiveRefreshThreshold
+               if needsPreemptiveRefresh && !c.pendingRefresh {
+                       // Start refresh if not pending (handling the prior RPC may have
+                       // just spawned a goroutine).
+                       c.pendingRefresh = true
+                       go c.refreshToken()
+               }
+               return map[string]string{
+                       "authorization": c.cachedAuthHeader,
+               }, nil
+       }
+
+       // If in backoff state, just return the cached error.
+       if c.cachedError != nil && time.Now().Before(c.nextRetryTime) {
+               return nil, c.cachedError
+       }
+
+       // At this point, the token is either invalid or expired and we are no
+       // longer backing off from any encountered errors. So refresh it.
+       // NB: We are holding the lock while reading the token from file. This will
+       // cause other RPCs to block until the read completes (sucecssfully or not)
+       // and the cache is updated. Subsequent RPCs will end up using the cache.
+       // This is per A97.
+       token, expiry, err := c.fileReader.readToken()
+       c.updateCacheLocked(token, expiry, err)
+
+       if c.cachedError != nil {
+               return nil, c.cachedError
+       }
+       return map[string]string{
+               "authorization": c.cachedAuthHeader,
+       }, nil
+}
+
+// RequireTransportSecurity indicates whether the credentials requires
+// transport security.
+func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool {
+       return true
+}
+
+// isTokenValidLocked checks if the cached token is still valid.
+// Caller must hold c.mu lock.
+func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool {
+       if c.cachedAuthHeader == "" {
+               return false
+       }
+       return c.cachedExpiry.After(time.Now())
+}
+
+// refreshToken reads the token from file and updates the cached data.
+func (c *jwtTokenFileCallCreds) refreshToken() {
+       // Deliberately not locking c.mu here. This way other RPCs can proceed
+       // while we read the token. This is per gRFC A97.
+       token, expiry, err := c.fileReader.readToken()
+
+       c.mu.Lock()
+       defer c.mu.Unlock()
+       c.updateCacheLocked(token, expiry, err)
+       c.pendingRefresh = false
+}
+
+// updateCacheLocked updates the cached token, expiry, and error state.
+// If an error is provided, it determines whether to set it as an UNAVAILABLE
+// or UNAUTHENTICATED error based on the error type.
+// NOTE: This method (and its callers) do not queue up a token refresh/retry if
+// the expiration is soon / an error was encountered. Instead, this is done when
+// handling RPCs. This is as per gRFC A97, which states that it is
+// undesirable to retry loading the token if the channel is idle.
+// Caller must hold c.mu lock.
+func (c *jwtTokenFileCallCreds) updateCacheLocked(token string, expiry time.Time, err error) {
+       if err != nil {
+               // Convert to gRPC status codes
+               if errors.Is(err, errTokenFileAccess) {
+                       c.cachedError = status.Error(codes.Unavailable, err.Error())
+               } else if errors.Is(err, errJWTValidation) {
+                       c.cachedError = status.Error(codes.Unauthenticated, err.Error())
+               } else {
+                       // Should not happen. Treat unknown errors as UNAUTHENTICATED.
+                       c.cachedError = status.Error(codes.Unauthenticated, err.Error())
+               }
+               c.retryAttempt++
+               backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1)
+               c.nextRetryTime = time.Now().Add(backoffDelay)
+               return
+       }
+       // Success - clear any cached error and update token cache
+       c.cachedError = nil
+       c.retryAttempt = 0
+       c.nextRetryTime = time.Time{}
+
+       c.cachedAuthHeader = "Bearer " + token
+       // Per gRFC A97: consider token invalid if it expires within the next 30
+       // seconds to accommodate for clock skew and server processing time.
+       c.cachedExpiry = expiry.Add(-30 * time.Second)
+}
diff --git a/credentials/jwt/token_file_call_creds_test.go b/credentials/jwt/token_file_call_creds_test.go
new file mode 100644 (file)
index 0000000..48c1b4f
--- /dev/null
@@ -0,0 +1,555 @@
+/*
+ *
+ * Copyright 2025 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package jwt
+
+import (
+       "context"
+       "encoding/base64"
+       "encoding/json"
+       "fmt"
+       "os"
+       "path/filepath"
+       "strings"
+       "testing"
+       "time"
+
+       "github.com/google/go-cmp/cmp"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/credentials"
+       "google.golang.org/grpc/internal/grpctest"
+       "google.golang.org/grpc/status"
+)
+
+const defaultTestTimeout = 5 * time.Second
+
+type s struct {
+       grpctest.Tester
+}
+
+func Test(t *testing.T) {
+       grpctest.RunSubTests(t, s{})
+}
+
+func (s) TestNewTokenFileCallCredentialsValidFilepath(t *testing.T) {
+       creds, err := NewTokenFileCallCredentials("/path/to/token")
+       if err != nil {
+               t.Fatalf("NewTokenFileCallCredentials() unexpected error: %v", err)
+       }
+       if creds == nil {
+               t.Fatal("NewTokenFileCallCredentials() returned nil credentials")
+       }
+}
+
+func (s) TestNewTokenFileCallCredentialsMissingFilepath(t *testing.T) {
+       if _, err := NewTokenFileCallCredentials(""); err == nil {
+               t.Fatalf("NewTokenFileCallCredentials() expected error, got nil")
+       }
+}
+
+func (s) TestTokenFileCallCreds_RequireTransportSecurity(t *testing.T) {
+       creds, err := NewTokenFileCallCredentials("/path/to/token")
+       if err != nil {
+               t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+       }
+
+       if !creds.RequireTransportSecurity() {
+               t.Error("RequireTransportSecurity() = false, want true")
+       }
+}
+
+func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) {
+       now := time.Now().Truncate(time.Second)
+       tests := []struct {
+               name             string
+               invalidTokenPath bool
+               tokenContent     string
+               authInfo         credentials.AuthInfo
+               wantCode         codes.Code
+               wantMetadata     map[string]string
+       }{
+               {
+                       name:         "valid_token_with_future_expiration",
+                       tokenContent: createTestJWT(t, now.Add(time.Hour)),
+                       authInfo:     &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+                       wantCode:     codes.OK,
+                       wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, now.Add(time.Hour))},
+               },
+               {
+                       name:         "insufficient_security_level",
+                       tokenContent: createTestJWT(t, now.Add(time.Hour)),
+                       authInfo:     &testAuthInfo{secLevel: credentials.NoSecurity},
+                       wantCode:     codes.Unknown, // http2Client.getCallAuthData actually transforms such errors into into Unauthenticated
+               },
+               {
+                       name:             "unreachable_token_file",
+                       invalidTokenPath: true,
+                       authInfo:         &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+                       wantCode:         codes.Unavailable,
+               },
+               {
+                       name:         "empty_file",
+                       tokenContent: "",
+                       authInfo:     &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+                       wantCode:     codes.Unauthenticated,
+               },
+               {
+                       name:         "malformed_JWT_token",
+                       tokenContent: "bad contents",
+                       authInfo:     &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+                       wantCode:     codes.Unauthenticated,
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       var tokenFile string
+                       if tt.invalidTokenPath {
+                               tokenFile = "/does-not-exist"
+                       } else {
+                               tokenFile = writeTempFile(t, "token", tt.tokenContent)
+                       }
+                       creds, err := NewTokenFileCallCredentials(tokenFile)
+                       if err != nil {
+                               t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+                       }
+
+                       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+                       defer cancel()
+                       ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+                               AuthInfo: tt.authInfo,
+                       })
+
+                       metadata, err := creds.GetRequestMetadata(ctx)
+                       if gotCode := status.Code(err); gotCode != tt.wantCode {
+                               t.Fatalf("GetRequestMetadata() = %v, want %v", gotCode, tt.wantCode)
+                       }
+
+                       if diff := cmp.Diff(tt.wantMetadata, metadata); diff != "" {
+                               t.Errorf("GetRequestMetadata() returned unexpected metadata (-want +got):\n%s", diff)
+                       }
+               })
+       }
+}
+
+func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) {
+       token := createTestJWT(t, time.Now().Add(time.Hour))
+       tokenFile := writeTempFile(t, "token", token)
+
+       creds, err := NewTokenFileCallCredentials(tokenFile)
+       if err != nil {
+               t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+       }
+
+       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+       defer cancel()
+       ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+               AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+       })
+
+       // First call should read from file.
+       metadata1, err := creds.GetRequestMetadata(ctx)
+       if err != nil {
+               t.Fatalf("First GetRequestMetadata() failed: %v", err)
+       }
+       wantMetadata := map[string]string{"authorization": "Bearer " + token}
+       if diff := cmp.Diff(wantMetadata, metadata1); diff != "" {
+               t.Errorf("First GetRequestMetadata() returned unexpected metadata (-want +got):\n%s", diff)
+       }
+
+       // Update the file with a different token.
+       newToken := createTestJWT(t, time.Now().Add(2*time.Hour))
+       if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
+               t.Fatalf("Failed to update token file: %v", err)
+       }
+
+       // Second call should return cached token (not the updated one).
+       metadata2, err := creds.GetRequestMetadata(ctx)
+       if err != nil {
+               t.Fatalf("Second GetRequestMetadata() failed: %v", err)
+       }
+
+       if diff := cmp.Diff(metadata1, metadata2); diff != "" {
+               t.Errorf("Second GetRequestMetadata() returned unexpected metadata (-want +got):\n%s", diff)
+       }
+}
+
+// testAuthInfo implements credentials.AuthInfo for testing.
+type testAuthInfo struct {
+       secLevel credentials.SecurityLevel
+}
+
+func (t *testAuthInfo) AuthType() string {
+       return "test"
+}
+
+func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo {
+       return credentials.CommonAuthInfo{SecurityLevel: t.secLevel}
+}
+
+// Tests that cached token expiration is set to 30 seconds before actual token
+// expiration.
+// TODO: Refactor the test to avoid inspecting and mutating internal state.
+func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) {
+       // Create token that expires in 2 hours.
+       tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour)
+       token := createTestJWT(t, tokenExp)
+       tokenFile := writeTempFile(t, "token", token)
+
+       creds, err := NewTokenFileCallCredentials(tokenFile)
+       if err != nil {
+               t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+       }
+
+       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+       defer cancel()
+       ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+               AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+       })
+
+       // Get token to trigger caching.
+       if _, err = creds.GetRequestMetadata(ctx); err != nil {
+               t.Fatalf("GetRequestMetadata() failed: %v", err)
+       }
+
+       // Verify cached expiration is 30 seconds before actual token expiration.
+       impl := creds.(*jwtTokenFileCallCreds)
+       impl.mu.Lock()
+       cachedExp := impl.cachedExpiry
+       impl.mu.Unlock()
+
+       wantExp := tokenExp.Add(-30 * time.Second)
+       if !cachedExp.Equal(wantExp) {
+               t.Errorf("Cache expiration = %v, want %v", cachedExp, wantExp)
+       }
+}
+
+// Tests that pre-emptive refresh is triggered within 1 minute of expiration.
+// This is tested as follows:
+// * A token which expires "soon" is created.
+// * On the first call to GetRequestMetadata, the token will get loaded and returned.
+// * Another token is created and overwrites the file.
+// * On the second call we will still return the (valid) first token but also
+// detect that a refresh needs to happen and trigger it.
+// * On the third call we confirm the new token has been loaded and returned.
+func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) {
+       // Create token that expires in 80 seconds (=> cache expires in ~50s).
+       // This ensures pre-emptive refresh triggers since 50s < the 1 minute check.
+       tokenExp := time.Now().Add(80 * time.Second)
+       expiringToken := createTestJWT(t, tokenExp)
+       tokenFile := writeTempFile(t, "token", expiringToken)
+
+       creds, err := NewTokenFileCallCredentials(tokenFile)
+       if err != nil {
+               t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+       }
+
+       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+       defer cancel()
+       ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+               AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+       })
+
+       // First call should read from file synchronously.
+       metadata1, err := creds.GetRequestMetadata(ctx)
+       if err != nil {
+               t.Fatalf("GetRequestMetadata() failed: %v", err)
+       }
+       wantAuth1 := "Bearer " + expiringToken
+       gotAuth1 := metadata1["authorization"]
+       if gotAuth1 != wantAuth1 {
+               t.Fatalf("First call should return original token: got %q, want %q", gotAuth1, wantAuth1)
+       }
+
+       // Verify token was cached and confirm expectation that refresh should be
+       // triggered.
+       impl := creds.(*jwtTokenFileCallCreds)
+       impl.mu.Lock()
+       cacheExp := impl.cachedExpiry
+       tokenCached := impl.cachedAuthHeader != ""
+       shouldTriggerRefresh := time.Until(cacheExp) < preemptiveRefreshThreshold
+       impl.mu.Unlock()
+
+       if !tokenCached {
+               t.Fatal("Token should be cached after successful GetRequestMetadata")
+       }
+
+       if !shouldTriggerRefresh {
+               timeUntilExp := time.Until(cacheExp)
+               t.Fatalf("Cache expires in %v; test precondition requires that this triggers preemptive refresh", timeUntilExp)
+       }
+
+       // Create new token file with different expiration while refresh is
+       // happening.
+       newToken := createTestJWT(t, time.Now().Add(2*time.Hour))
+       if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
+               t.Fatalf("Failed to write updated token file: %v", err)
+       }
+
+       // Get token again - this call should trigger a refresh given that the first
+       // one was cached but expiring soon.
+       // However, the function should have returned right away with the current
+       // cached token because it is still valid and the preemptive refresh is
+       // meant to happen without blocking the RPC.
+       metadata2, err := creds.GetRequestMetadata(ctx)
+       if err != nil {
+               t.Fatalf("Second GetRequestMetadata() failed: %v", err)
+       }
+       wantAuth2 := wantAuth1
+       gotAuth2 := metadata2["authorization"]
+       if gotAuth2 != wantAuth2 {
+               t.Fatalf("Second call should return the original token: got %q, want %q", gotAuth2, wantAuth2)
+       }
+
+       // Now should get the new token which was refreshed in the background.
+       wantAuth3 := "Bearer " + newToken
+       ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+       defer cancel()
+       ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+               AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+       })
+       for ; ; <-time.After(time.Millisecond) {
+               if ctx.Err() != nil {
+                       t.Fatal("Context deadline expired before pre-emptive refresh completed")
+               }
+               // If the newly returned metadata is different to the old one, verify
+               // that it matches the token from the updated file. If not, fail the
+               // test.
+               metadata3, err := creds.GetRequestMetadata(ctx)
+               if err != nil {
+                       t.Fatalf("Second GetRequestMetadata() failed: %v", err)
+               }
+               // Pre-emptive refresh not completed yet, try again.
+               gotAuth3 := metadata3["authorization"]
+               if gotAuth3 == gotAuth2 {
+                       continue
+               }
+               if gotAuth3 != wantAuth3 {
+                       t.Fatalf("Third call should return the new token: got %q, want %q", gotAuth3, wantAuth3)
+               }
+               break
+       }
+}
+
+// Tests that backoff behavior handles file read errors correctly.
+// It has the following expectations:
+// First call to GetRequestMetadata() fails with UNAVAILABLE due to a
+// missing file.
+// Second call to GetRequestMetadata() fails with UNAVAILABLE due backoff.
+// Third call to GetRequestMetadata() fails with UNAVAILABLE due to retry.
+// Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff
+// even though file exists.
+// Fifth call to GetRequestMetadata() succeeds after reading the file and
+// backoff has expired.
+// TODO: Refactor the test to avoid inspecting and mutating internal state.
+func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) {
+       tempDir := t.TempDir()
+       nonExistentFile := filepath.Join(tempDir, "nonexistent")
+
+       creds, err := NewTokenFileCallCredentials(nonExistentFile)
+       if err != nil {
+               t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+       }
+
+       ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+       defer cancel()
+       ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+               AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+       })
+
+       // First call should fail with UNAVAILABLE.
+       beforeCallRetryTime := time.Now()
+       _, err = creds.GetRequestMetadata(ctx)
+       if err == nil {
+               t.Fatal("Expected error from nonexistent file")
+       }
+       if status.Code(err) != codes.Unavailable {
+               t.Fatalf("GetRequestMetadata() = %v, want UNAVAILABLE", status.Code(err))
+       }
+
+       // Verify error is cached internally.
+       impl := creds.(*jwtTokenFileCallCreds)
+       impl.mu.Lock()
+       retryAttempt := impl.retryAttempt
+       nextRetryTime := impl.nextRetryTime
+       impl.mu.Unlock()
+
+       if retryAttempt != 1 {
+               t.Errorf("Expected retry attempt to be 1, got %d", retryAttempt)
+       }
+       if !nextRetryTime.After(beforeCallRetryTime) {
+               t.Error("Next retry time should be set to a time after the first call")
+       }
+
+       // Second call should still return cached error and not retry.
+       // Set nextRetryTime far enough in the future to ensure that's the case.
+       impl.mu.Lock()
+       impl.nextRetryTime = time.Now().Add(1 * time.Minute)
+       wantNextRetryTime := impl.nextRetryTime
+       impl.mu.Unlock()
+       _, err = creds.GetRequestMetadata(ctx)
+       if err == nil {
+               t.Fatalf("creds.GetRequestMetadata() = %v, want non-nil", err)
+       }
+       if status.Code(err) != codes.Unavailable {
+               t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err))
+       }
+
+       impl.mu.Lock()
+       retryAttempt2 := impl.retryAttempt
+       nextRetryTime2 := impl.nextRetryTime
+       impl.mu.Unlock()
+       if !nextRetryTime2.Equal(wantNextRetryTime) {
+               t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime2, wantNextRetryTime)
+       }
+       if retryAttempt2 != 1 {
+               t.Error("Retry attempt should not change due to backoff")
+       }
+
+       // Third call should retry but still fail with UNAVAILABLE.
+       // Set the backoff retry time in the past to allow next retry attempt.
+       impl.mu.Lock()
+       impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
+       beforeCallRetryTime = impl.nextRetryTime
+       impl.mu.Unlock()
+       _, err = creds.GetRequestMetadata(ctx)
+       if err == nil {
+               t.Fatalf("creds.GetRequestMetadata() = %v, want non-nil", err)
+       }
+       if status.Code(err) != codes.Unavailable {
+               t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err))
+       }
+
+       impl.mu.Lock()
+       retryAttempt3 := impl.retryAttempt
+       nextRetryTime3 := impl.nextRetryTime
+       impl.mu.Unlock()
+
+       if !nextRetryTime3.After(beforeCallRetryTime) {
+               t.Error("nextRetryTime3 should have been updated after third call")
+       }
+       if retryAttempt3 != 2 {
+               t.Error("Expected retry attempt to increase after retry")
+       }
+
+       // Create valid token file.
+       validToken := createTestJWT(t, time.Now().Add(time.Hour))
+       if err := os.WriteFile(nonExistentFile, []byte(validToken), 0600); err != nil {
+               t.Fatalf("Failed to create valid token file: %v", err)
+       }
+
+       // Fourth call should still fail even though the file now exists due to backoff.
+       // Set nextRetryTime far enough in the future to ensure that's the case.
+       _, err = creds.GetRequestMetadata(ctx)
+       impl.mu.Lock()
+       impl.nextRetryTime = time.Now().Add(1 * time.Minute)
+       wantNextRetryTime = impl.nextRetryTime
+       impl.mu.Unlock()
+       if err == nil {
+               t.Fatalf("creds.GetRequestMetadata() = %v, want non-nil", err)
+       }
+       if status.Code(err) != codes.Unavailable {
+               t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err))
+       }
+
+       impl.mu.Lock()
+       retryAttempt4 := impl.retryAttempt
+       nextRetryTime4 := impl.nextRetryTime
+       impl.mu.Unlock()
+
+       if !nextRetryTime4.Equal(wantNextRetryTime) {
+               t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime4, wantNextRetryTime)
+       }
+       if retryAttempt4 != retryAttempt3 {
+               t.Error("Retry attempt should not change due to backoff")
+       }
+
+       // Fifth call should succeed since the file now exists and the backoff has
+       // expired.
+       // Set the backoff retry time in the past to allow next retry attempt.
+       impl.mu.Lock()
+       impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
+       impl.mu.Unlock()
+       _, err = creds.GetRequestMetadata(ctx)
+       if err != nil {
+               t.Fatalf("After creating valid token file, backoff should expire and trigger a token reload on the next RPC. GetRequestMetadata() should eventually succeed, but got: %v", err)
+       }
+       // If successful, verify error cache and backoff state were cleared.
+       impl.mu.Lock()
+       clearedErr := impl.cachedError
+       retryAttempt = impl.retryAttempt
+       nextRetryTime = impl.nextRetryTime
+       impl.mu.Unlock()
+
+       if clearedErr != nil {
+               t.Errorf("After successful retry, cached error should be cleared, got: %v", clearedErr)
+       }
+       if retryAttempt != 0 {
+               t.Errorf("After successful retry, retry attempt should be reset, got: %d", retryAttempt)
+       }
+       if !nextRetryTime.IsZero() {
+               t.Error("After successful retry, next retry time should be cleared")
+       }
+}
+
+// createTestJWT creates a test JWT token with the specified expiration.
+func createTestJWT(t *testing.T, expiration time.Time) string {
+       t.Helper()
+
+       claims := map[string]any{}
+       if !expiration.IsZero() {
+               claims["exp"] = expiration.Unix()
+       }
+
+       header := map[string]any{
+               "typ": "JWT",
+               "alg": "HS256",
+       }
+       headerBytes, err := json.Marshal(header)
+       if err != nil {
+               t.Fatalf("Failed to marshal header: %v", err)
+       }
+
+       claimsBytes, err := json.Marshal(claims)
+       if err != nil {
+               t.Fatalf("Failed to marshal claims: %v", err)
+       }
+
+       headerB64 := base64.URLEncoding.EncodeToString(headerBytes)
+       claimsB64 := base64.URLEncoding.EncodeToString(claimsBytes)
+
+       // Remove padding for URL-safe base64
+       headerB64 = strings.TrimRight(headerB64, "=")
+       claimsB64 = strings.TrimRight(claimsB64, "=")
+
+       // For testing, we'll use a fake signature
+       signature := base64.URLEncoding.EncodeToString([]byte("fake_signature"))
+       signature = strings.TrimRight(signature, "=")
+
+       return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature)
+}
+
+func writeTempFile(t *testing.T, name, content string) string {
+       t.Helper()
+       tempDir := t.TempDir()
+       filePath := filepath.Join(tempDir, name)
+       if err := os.WriteFile(filePath, []byte(content), 0600); err != nil {
+               t.Fatalf("Failed to write temp file: %v", err)
+       }
+       return filePath
+}