--- /dev/null
+/*
+ *
+ * Copyright 2025 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package jwt
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+)
+
+func (s) TestJWTFileReader_ReadToken_FileErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ create bool
+ contents string
+ wantErr error
+ }{
+ {
+ name: "nonexistent_file",
+ create: false,
+ contents: "",
+ wantErr: errTokenFileAccess,
+ },
+ {
+ name: "empty_file",
+ create: true,
+ contents: "",
+ wantErr: errJWTValidation,
+ },
+ {
+ name: "file_with_whitespace_only",
+ create: true,
+ contents: " \n\t ",
+ wantErr: errJWTValidation,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var tokenFile string
+ if !tt.create {
+ tokenFile = "/does-not-exist"
+ } else {
+ tokenFile = writeTempFile(t, "token", tt.contents)
+ }
+
+ reader := jwtFileReader{tokenFilePath: tokenFile}
+ if _, _, err := reader.readToken(); err == nil {
+ t.Fatal("ReadToken() expected error, got nil")
+ } else if !errors.Is(err, tt.wantErr) {
+ t.Fatalf("ReadToken() error = %v, want error %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func (s) TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) {
+ now := time.Now().Truncate(time.Second)
+ tests := []struct {
+ name string
+ tokenContent string
+ wantErr error
+ }{
+ {
+ name: "valid_token_without_expiration",
+ tokenContent: createTestJWT(t, time.Time{}),
+ wantErr: errJWTValidation,
+ },
+ {
+ name: "expired_token",
+ tokenContent: createTestJWT(t, now.Add(-time.Hour)),
+ wantErr: errJWTValidation,
+ },
+ {
+ name: "malformed_JWT_not_enough_parts",
+ tokenContent: "invalid.jwt",
+ wantErr: errJWTValidation,
+ },
+ {
+ name: "malformed_JWT_invalid_base64",
+ tokenContent: "header.invalid_base64!@#.signature",
+ wantErr: errJWTValidation,
+ },
+ {
+ name: "malformed_JWT_invalid_JSON",
+ tokenContent: createInvalidJWT(t),
+ wantErr: errJWTValidation,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tokenFile := writeTempFile(t, "token", tt.tokenContent)
+
+ reader := jwtFileReader{tokenFilePath: tokenFile}
+ if _, _, err := reader.readToken(); err == nil {
+ t.Fatal("ReadToken() expected error, got nil")
+ } else if !errors.Is(err, tt.wantErr) {
+ t.Fatalf("ReadToken() error = %v, want error %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func (s) TestJWTFileReader_ReadToken_ValidToken(t *testing.T) {
+ now := time.Now().Truncate(time.Second)
+ tokenExp := now.Add(time.Hour)
+ token := createTestJWT(t, tokenExp)
+ tokenFile := writeTempFile(t, "token", token)
+
+ reader := jwtFileReader{tokenFilePath: tokenFile}
+ readToken, expiry, err := reader.readToken()
+ if err != nil {
+ t.Fatalf("ReadToken() unexpected error: %v", err)
+ }
+
+ if readToken != token {
+ t.Errorf("ReadToken() token = %q, want %q", readToken, token)
+ }
+
+ if !expiry.Equal(tokenExp) {
+ t.Errorf("ReadToken() expiry = %v, want %v", expiry, tokenExp)
+ }
+}
+
+// createInvalidJWT creates a JWT with invalid JSON in the payload.
+func createInvalidJWT(t *testing.T) string {
+ t.Helper()
+
+ header := map[string]any{
+ "typ": "JWT",
+ "alg": "HS256",
+ }
+
+ headerBytes, err := json.Marshal(header)
+ if err != nil {
+ t.Fatalf("Failed to marshal header: %v", err)
+ }
+
+ headerB64 := base64.URLEncoding.EncodeToString(headerBytes)
+ headerB64 = strings.TrimRight(headerB64, "=")
+
+ // Create invalid JSON payload
+ invalidJSON := "invalid json content"
+ payloadB64 := base64.URLEncoding.EncodeToString([]byte(invalidJSON))
+ payloadB64 = strings.TrimRight(payloadB64, "=")
+
+ signature := base64.URLEncoding.EncodeToString([]byte("fake_signature"))
+ signature = strings.TrimRight(signature, "=")
+
+ return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signature)
+}
--- /dev/null
+/*
+ *
+ * Copyright 2025 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package jwt
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/internal/backoff"
+ "google.golang.org/grpc/status"
+)
+
+const preemptiveRefreshThreshold = time.Minute
+
+// jwtTokenFileCallCreds provides JWT token-based PerRPCCredentials that reads
+// tokens from a file.
+// This implementation follows the A97 JWT Call Credentials specification.
+type jwtTokenFileCallCreds struct {
+ fileReader *jwtFileReader
+ backoffStrategy backoff.Strategy
+
+ // cached data protected by mu
+ mu sync.Mutex
+ cachedAuthHeader string // "Bearer " + token
+ cachedExpiry time.Time // Slightly less than actual expiration time
+ cachedError error // Error from last failed attempt
+ retryAttempt int // Current retry attempt number
+ nextRetryTime time.Time // When next retry is allowed
+ pendingRefresh bool // Whether a refresh is currently in progress
+}
+
+// NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens
+// from the specified file path.
+func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCredentials, error) {
+ if tokenFilePath == "" {
+ return nil, fmt.Errorf("tokenFilePath cannot be empty")
+ }
+
+ creds := &jwtTokenFileCallCreds{
+ fileReader: &jwtFileReader{tokenFilePath: tokenFilePath},
+ backoffStrategy: backoff.DefaultExponential,
+ }
+
+ return creds, nil
+}
+
+// GetRequestMetadata gets the current request metadata, refreshing tokens if
+// required. This implementation follows the PerRPCCredentials interface. The
+// tokens will get automatically refreshed if they are about to expire or if
+// they haven't been loaded successfully yet.
+// If it's not possible to extract a token from the file, UNAVAILABLE is
+// returned.
+// If the token is extracted but invalid, then UNAUTHENTICATED is returned.
+// If errors are encoutered, a backoff is applied before retrying.
+func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
+ ri, _ := credentials.RequestInfoFromContext(ctx)
+ if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
+ return nil, fmt.Errorf("cannot send secure credentials on an insecure connection: %v", err)
+ }
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.isTokenValidLocked() {
+ needsPreemptiveRefresh := time.Until(c.cachedExpiry) < preemptiveRefreshThreshold
+ if needsPreemptiveRefresh && !c.pendingRefresh {
+ // Start refresh if not pending (handling the prior RPC may have
+ // just spawned a goroutine).
+ c.pendingRefresh = true
+ go c.refreshToken()
+ }
+ return map[string]string{
+ "authorization": c.cachedAuthHeader,
+ }, nil
+ }
+
+ // If in backoff state, just return the cached error.
+ if c.cachedError != nil && time.Now().Before(c.nextRetryTime) {
+ return nil, c.cachedError
+ }
+
+ // At this point, the token is either invalid or expired and we are no
+ // longer backing off from any encountered errors. So refresh it.
+ // NB: We are holding the lock while reading the token from file. This will
+ // cause other RPCs to block until the read completes (sucecssfully or not)
+ // and the cache is updated. Subsequent RPCs will end up using the cache.
+ // This is per A97.
+ token, expiry, err := c.fileReader.readToken()
+ c.updateCacheLocked(token, expiry, err)
+
+ if c.cachedError != nil {
+ return nil, c.cachedError
+ }
+ return map[string]string{
+ "authorization": c.cachedAuthHeader,
+ }, nil
+}
+
+// RequireTransportSecurity indicates whether the credentials requires
+// transport security.
+func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool {
+ return true
+}
+
+// isTokenValidLocked checks if the cached token is still valid.
+// Caller must hold c.mu lock.
+func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool {
+ if c.cachedAuthHeader == "" {
+ return false
+ }
+ return c.cachedExpiry.After(time.Now())
+}
+
+// refreshToken reads the token from file and updates the cached data.
+func (c *jwtTokenFileCallCreds) refreshToken() {
+ // Deliberately not locking c.mu here. This way other RPCs can proceed
+ // while we read the token. This is per gRFC A97.
+ token, expiry, err := c.fileReader.readToken()
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.updateCacheLocked(token, expiry, err)
+ c.pendingRefresh = false
+}
+
+// updateCacheLocked updates the cached token, expiry, and error state.
+// If an error is provided, it determines whether to set it as an UNAVAILABLE
+// or UNAUTHENTICATED error based on the error type.
+// NOTE: This method (and its callers) do not queue up a token refresh/retry if
+// the expiration is soon / an error was encountered. Instead, this is done when
+// handling RPCs. This is as per gRFC A97, which states that it is
+// undesirable to retry loading the token if the channel is idle.
+// Caller must hold c.mu lock.
+func (c *jwtTokenFileCallCreds) updateCacheLocked(token string, expiry time.Time, err error) {
+ if err != nil {
+ // Convert to gRPC status codes
+ if errors.Is(err, errTokenFileAccess) {
+ c.cachedError = status.Error(codes.Unavailable, err.Error())
+ } else if errors.Is(err, errJWTValidation) {
+ c.cachedError = status.Error(codes.Unauthenticated, err.Error())
+ } else {
+ // Should not happen. Treat unknown errors as UNAUTHENTICATED.
+ c.cachedError = status.Error(codes.Unauthenticated, err.Error())
+ }
+ c.retryAttempt++
+ backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1)
+ c.nextRetryTime = time.Now().Add(backoffDelay)
+ return
+ }
+ // Success - clear any cached error and update token cache
+ c.cachedError = nil
+ c.retryAttempt = 0
+ c.nextRetryTime = time.Time{}
+
+ c.cachedAuthHeader = "Bearer " + token
+ // Per gRFC A97: consider token invalid if it expires within the next 30
+ // seconds to accommodate for clock skew and server processing time.
+ c.cachedExpiry = expiry.Add(-30 * time.Second)
+}
--- /dev/null
+/*
+ *
+ * Copyright 2025 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package jwt
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/internal/grpctest"
+ "google.golang.org/grpc/status"
+)
+
+const defaultTestTimeout = 5 * time.Second
+
+type s struct {
+ grpctest.Tester
+}
+
+func Test(t *testing.T) {
+ grpctest.RunSubTests(t, s{})
+}
+
+func (s) TestNewTokenFileCallCredentialsValidFilepath(t *testing.T) {
+ creds, err := NewTokenFileCallCredentials("/path/to/token")
+ if err != nil {
+ t.Fatalf("NewTokenFileCallCredentials() unexpected error: %v", err)
+ }
+ if creds == nil {
+ t.Fatal("NewTokenFileCallCredentials() returned nil credentials")
+ }
+}
+
+func (s) TestNewTokenFileCallCredentialsMissingFilepath(t *testing.T) {
+ if _, err := NewTokenFileCallCredentials(""); err == nil {
+ t.Fatalf("NewTokenFileCallCredentials() expected error, got nil")
+ }
+}
+
+func (s) TestTokenFileCallCreds_RequireTransportSecurity(t *testing.T) {
+ creds, err := NewTokenFileCallCredentials("/path/to/token")
+ if err != nil {
+ t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+ }
+
+ if !creds.RequireTransportSecurity() {
+ t.Error("RequireTransportSecurity() = false, want true")
+ }
+}
+
+func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) {
+ now := time.Now().Truncate(time.Second)
+ tests := []struct {
+ name string
+ invalidTokenPath bool
+ tokenContent string
+ authInfo credentials.AuthInfo
+ wantCode codes.Code
+ wantMetadata map[string]string
+ }{
+ {
+ name: "valid_token_with_future_expiration",
+ tokenContent: createTestJWT(t, now.Add(time.Hour)),
+ authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+ wantCode: codes.OK,
+ wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, now.Add(time.Hour))},
+ },
+ {
+ name: "insufficient_security_level",
+ tokenContent: createTestJWT(t, now.Add(time.Hour)),
+ authInfo: &testAuthInfo{secLevel: credentials.NoSecurity},
+ wantCode: codes.Unknown, // http2Client.getCallAuthData actually transforms such errors into into Unauthenticated
+ },
+ {
+ name: "unreachable_token_file",
+ invalidTokenPath: true,
+ authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+ wantCode: codes.Unavailable,
+ },
+ {
+ name: "empty_file",
+ tokenContent: "",
+ authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+ wantCode: codes.Unauthenticated,
+ },
+ {
+ name: "malformed_JWT_token",
+ tokenContent: "bad contents",
+ authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+ wantCode: codes.Unauthenticated,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var tokenFile string
+ if tt.invalidTokenPath {
+ tokenFile = "/does-not-exist"
+ } else {
+ tokenFile = writeTempFile(t, "token", tt.tokenContent)
+ }
+ creds, err := NewTokenFileCallCredentials(tokenFile)
+ if err != nil {
+ t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+ AuthInfo: tt.authInfo,
+ })
+
+ metadata, err := creds.GetRequestMetadata(ctx)
+ if gotCode := status.Code(err); gotCode != tt.wantCode {
+ t.Fatalf("GetRequestMetadata() = %v, want %v", gotCode, tt.wantCode)
+ }
+
+ if diff := cmp.Diff(tt.wantMetadata, metadata); diff != "" {
+ t.Errorf("GetRequestMetadata() returned unexpected metadata (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) {
+ token := createTestJWT(t, time.Now().Add(time.Hour))
+ tokenFile := writeTempFile(t, "token", token)
+
+ creds, err := NewTokenFileCallCredentials(tokenFile)
+ if err != nil {
+ t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+ })
+
+ // First call should read from file.
+ metadata1, err := creds.GetRequestMetadata(ctx)
+ if err != nil {
+ t.Fatalf("First GetRequestMetadata() failed: %v", err)
+ }
+ wantMetadata := map[string]string{"authorization": "Bearer " + token}
+ if diff := cmp.Diff(wantMetadata, metadata1); diff != "" {
+ t.Errorf("First GetRequestMetadata() returned unexpected metadata (-want +got):\n%s", diff)
+ }
+
+ // Update the file with a different token.
+ newToken := createTestJWT(t, time.Now().Add(2*time.Hour))
+ if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
+ t.Fatalf("Failed to update token file: %v", err)
+ }
+
+ // Second call should return cached token (not the updated one).
+ metadata2, err := creds.GetRequestMetadata(ctx)
+ if err != nil {
+ t.Fatalf("Second GetRequestMetadata() failed: %v", err)
+ }
+
+ if diff := cmp.Diff(metadata1, metadata2); diff != "" {
+ t.Errorf("Second GetRequestMetadata() returned unexpected metadata (-want +got):\n%s", diff)
+ }
+}
+
+// testAuthInfo implements credentials.AuthInfo for testing.
+type testAuthInfo struct {
+ secLevel credentials.SecurityLevel
+}
+
+func (t *testAuthInfo) AuthType() string {
+ return "test"
+}
+
+func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo {
+ return credentials.CommonAuthInfo{SecurityLevel: t.secLevel}
+}
+
+// Tests that cached token expiration is set to 30 seconds before actual token
+// expiration.
+// TODO: Refactor the test to avoid inspecting and mutating internal state.
+func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) {
+ // Create token that expires in 2 hours.
+ tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour)
+ token := createTestJWT(t, tokenExp)
+ tokenFile := writeTempFile(t, "token", token)
+
+ creds, err := NewTokenFileCallCredentials(tokenFile)
+ if err != nil {
+ t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+ })
+
+ // Get token to trigger caching.
+ if _, err = creds.GetRequestMetadata(ctx); err != nil {
+ t.Fatalf("GetRequestMetadata() failed: %v", err)
+ }
+
+ // Verify cached expiration is 30 seconds before actual token expiration.
+ impl := creds.(*jwtTokenFileCallCreds)
+ impl.mu.Lock()
+ cachedExp := impl.cachedExpiry
+ impl.mu.Unlock()
+
+ wantExp := tokenExp.Add(-30 * time.Second)
+ if !cachedExp.Equal(wantExp) {
+ t.Errorf("Cache expiration = %v, want %v", cachedExp, wantExp)
+ }
+}
+
+// Tests that pre-emptive refresh is triggered within 1 minute of expiration.
+// This is tested as follows:
+// * A token which expires "soon" is created.
+// * On the first call to GetRequestMetadata, the token will get loaded and returned.
+// * Another token is created and overwrites the file.
+// * On the second call we will still return the (valid) first token but also
+// detect that a refresh needs to happen and trigger it.
+// * On the third call we confirm the new token has been loaded and returned.
+func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) {
+ // Create token that expires in 80 seconds (=> cache expires in ~50s).
+ // This ensures pre-emptive refresh triggers since 50s < the 1 minute check.
+ tokenExp := time.Now().Add(80 * time.Second)
+ expiringToken := createTestJWT(t, tokenExp)
+ tokenFile := writeTempFile(t, "token", expiringToken)
+
+ creds, err := NewTokenFileCallCredentials(tokenFile)
+ if err != nil {
+ t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+ })
+
+ // First call should read from file synchronously.
+ metadata1, err := creds.GetRequestMetadata(ctx)
+ if err != nil {
+ t.Fatalf("GetRequestMetadata() failed: %v", err)
+ }
+ wantAuth1 := "Bearer " + expiringToken
+ gotAuth1 := metadata1["authorization"]
+ if gotAuth1 != wantAuth1 {
+ t.Fatalf("First call should return original token: got %q, want %q", gotAuth1, wantAuth1)
+ }
+
+ // Verify token was cached and confirm expectation that refresh should be
+ // triggered.
+ impl := creds.(*jwtTokenFileCallCreds)
+ impl.mu.Lock()
+ cacheExp := impl.cachedExpiry
+ tokenCached := impl.cachedAuthHeader != ""
+ shouldTriggerRefresh := time.Until(cacheExp) < preemptiveRefreshThreshold
+ impl.mu.Unlock()
+
+ if !tokenCached {
+ t.Fatal("Token should be cached after successful GetRequestMetadata")
+ }
+
+ if !shouldTriggerRefresh {
+ timeUntilExp := time.Until(cacheExp)
+ t.Fatalf("Cache expires in %v; test precondition requires that this triggers preemptive refresh", timeUntilExp)
+ }
+
+ // Create new token file with different expiration while refresh is
+ // happening.
+ newToken := createTestJWT(t, time.Now().Add(2*time.Hour))
+ if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
+ t.Fatalf("Failed to write updated token file: %v", err)
+ }
+
+ // Get token again - this call should trigger a refresh given that the first
+ // one was cached but expiring soon.
+ // However, the function should have returned right away with the current
+ // cached token because it is still valid and the preemptive refresh is
+ // meant to happen without blocking the RPC.
+ metadata2, err := creds.GetRequestMetadata(ctx)
+ if err != nil {
+ t.Fatalf("Second GetRequestMetadata() failed: %v", err)
+ }
+ wantAuth2 := wantAuth1
+ gotAuth2 := metadata2["authorization"]
+ if gotAuth2 != wantAuth2 {
+ t.Fatalf("Second call should return the original token: got %q, want %q", gotAuth2, wantAuth2)
+ }
+
+ // Now should get the new token which was refreshed in the background.
+ wantAuth3 := "Bearer " + newToken
+ ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+ })
+ for ; ; <-time.After(time.Millisecond) {
+ if ctx.Err() != nil {
+ t.Fatal("Context deadline expired before pre-emptive refresh completed")
+ }
+ // If the newly returned metadata is different to the old one, verify
+ // that it matches the token from the updated file. If not, fail the
+ // test.
+ metadata3, err := creds.GetRequestMetadata(ctx)
+ if err != nil {
+ t.Fatalf("Second GetRequestMetadata() failed: %v", err)
+ }
+ // Pre-emptive refresh not completed yet, try again.
+ gotAuth3 := metadata3["authorization"]
+ if gotAuth3 == gotAuth2 {
+ continue
+ }
+ if gotAuth3 != wantAuth3 {
+ t.Fatalf("Third call should return the new token: got %q, want %q", gotAuth3, wantAuth3)
+ }
+ break
+ }
+}
+
+// Tests that backoff behavior handles file read errors correctly.
+// It has the following expectations:
+// First call to GetRequestMetadata() fails with UNAVAILABLE due to a
+// missing file.
+// Second call to GetRequestMetadata() fails with UNAVAILABLE due backoff.
+// Third call to GetRequestMetadata() fails with UNAVAILABLE due to retry.
+// Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff
+// even though file exists.
+// Fifth call to GetRequestMetadata() succeeds after reading the file and
+// backoff has expired.
+// TODO: Refactor the test to avoid inspecting and mutating internal state.
+func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) {
+ tempDir := t.TempDir()
+ nonExistentFile := filepath.Join(tempDir, "nonexistent")
+
+ creds, err := NewTokenFileCallCredentials(nonExistentFile)
+ if err != nil {
+ t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
+ AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
+ })
+
+ // First call should fail with UNAVAILABLE.
+ beforeCallRetryTime := time.Now()
+ _, err = creds.GetRequestMetadata(ctx)
+ if err == nil {
+ t.Fatal("Expected error from nonexistent file")
+ }
+ if status.Code(err) != codes.Unavailable {
+ t.Fatalf("GetRequestMetadata() = %v, want UNAVAILABLE", status.Code(err))
+ }
+
+ // Verify error is cached internally.
+ impl := creds.(*jwtTokenFileCallCreds)
+ impl.mu.Lock()
+ retryAttempt := impl.retryAttempt
+ nextRetryTime := impl.nextRetryTime
+ impl.mu.Unlock()
+
+ if retryAttempt != 1 {
+ t.Errorf("Expected retry attempt to be 1, got %d", retryAttempt)
+ }
+ if !nextRetryTime.After(beforeCallRetryTime) {
+ t.Error("Next retry time should be set to a time after the first call")
+ }
+
+ // Second call should still return cached error and not retry.
+ // Set nextRetryTime far enough in the future to ensure that's the case.
+ impl.mu.Lock()
+ impl.nextRetryTime = time.Now().Add(1 * time.Minute)
+ wantNextRetryTime := impl.nextRetryTime
+ impl.mu.Unlock()
+ _, err = creds.GetRequestMetadata(ctx)
+ if err == nil {
+ t.Fatalf("creds.GetRequestMetadata() = %v, want non-nil", err)
+ }
+ if status.Code(err) != codes.Unavailable {
+ t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err))
+ }
+
+ impl.mu.Lock()
+ retryAttempt2 := impl.retryAttempt
+ nextRetryTime2 := impl.nextRetryTime
+ impl.mu.Unlock()
+ if !nextRetryTime2.Equal(wantNextRetryTime) {
+ t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime2, wantNextRetryTime)
+ }
+ if retryAttempt2 != 1 {
+ t.Error("Retry attempt should not change due to backoff")
+ }
+
+ // Third call should retry but still fail with UNAVAILABLE.
+ // Set the backoff retry time in the past to allow next retry attempt.
+ impl.mu.Lock()
+ impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
+ beforeCallRetryTime = impl.nextRetryTime
+ impl.mu.Unlock()
+ _, err = creds.GetRequestMetadata(ctx)
+ if err == nil {
+ t.Fatalf("creds.GetRequestMetadata() = %v, want non-nil", err)
+ }
+ if status.Code(err) != codes.Unavailable {
+ t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err))
+ }
+
+ impl.mu.Lock()
+ retryAttempt3 := impl.retryAttempt
+ nextRetryTime3 := impl.nextRetryTime
+ impl.mu.Unlock()
+
+ if !nextRetryTime3.After(beforeCallRetryTime) {
+ t.Error("nextRetryTime3 should have been updated after third call")
+ }
+ if retryAttempt3 != 2 {
+ t.Error("Expected retry attempt to increase after retry")
+ }
+
+ // Create valid token file.
+ validToken := createTestJWT(t, time.Now().Add(time.Hour))
+ if err := os.WriteFile(nonExistentFile, []byte(validToken), 0600); err != nil {
+ t.Fatalf("Failed to create valid token file: %v", err)
+ }
+
+ // Fourth call should still fail even though the file now exists due to backoff.
+ // Set nextRetryTime far enough in the future to ensure that's the case.
+ _, err = creds.GetRequestMetadata(ctx)
+ impl.mu.Lock()
+ impl.nextRetryTime = time.Now().Add(1 * time.Minute)
+ wantNextRetryTime = impl.nextRetryTime
+ impl.mu.Unlock()
+ if err == nil {
+ t.Fatalf("creds.GetRequestMetadata() = %v, want non-nil", err)
+ }
+ if status.Code(err) != codes.Unavailable {
+ t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err))
+ }
+
+ impl.mu.Lock()
+ retryAttempt4 := impl.retryAttempt
+ nextRetryTime4 := impl.nextRetryTime
+ impl.mu.Unlock()
+
+ if !nextRetryTime4.Equal(wantNextRetryTime) {
+ t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime4, wantNextRetryTime)
+ }
+ if retryAttempt4 != retryAttempt3 {
+ t.Error("Retry attempt should not change due to backoff")
+ }
+
+ // Fifth call should succeed since the file now exists and the backoff has
+ // expired.
+ // Set the backoff retry time in the past to allow next retry attempt.
+ impl.mu.Lock()
+ impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
+ impl.mu.Unlock()
+ _, err = creds.GetRequestMetadata(ctx)
+ if err != nil {
+ t.Fatalf("After creating valid token file, backoff should expire and trigger a token reload on the next RPC. GetRequestMetadata() should eventually succeed, but got: %v", err)
+ }
+ // If successful, verify error cache and backoff state were cleared.
+ impl.mu.Lock()
+ clearedErr := impl.cachedError
+ retryAttempt = impl.retryAttempt
+ nextRetryTime = impl.nextRetryTime
+ impl.mu.Unlock()
+
+ if clearedErr != nil {
+ t.Errorf("After successful retry, cached error should be cleared, got: %v", clearedErr)
+ }
+ if retryAttempt != 0 {
+ t.Errorf("After successful retry, retry attempt should be reset, got: %d", retryAttempt)
+ }
+ if !nextRetryTime.IsZero() {
+ t.Error("After successful retry, next retry time should be cleared")
+ }
+}
+
+// createTestJWT creates a test JWT token with the specified expiration.
+func createTestJWT(t *testing.T, expiration time.Time) string {
+ t.Helper()
+
+ claims := map[string]any{}
+ if !expiration.IsZero() {
+ claims["exp"] = expiration.Unix()
+ }
+
+ header := map[string]any{
+ "typ": "JWT",
+ "alg": "HS256",
+ }
+ headerBytes, err := json.Marshal(header)
+ if err != nil {
+ t.Fatalf("Failed to marshal header: %v", err)
+ }
+
+ claimsBytes, err := json.Marshal(claims)
+ if err != nil {
+ t.Fatalf("Failed to marshal claims: %v", err)
+ }
+
+ headerB64 := base64.URLEncoding.EncodeToString(headerBytes)
+ claimsB64 := base64.URLEncoding.EncodeToString(claimsBytes)
+
+ // Remove padding for URL-safe base64
+ headerB64 = strings.TrimRight(headerB64, "=")
+ claimsB64 = strings.TrimRight(claimsB64, "=")
+
+ // For testing, we'll use a fake signature
+ signature := base64.URLEncoding.EncodeToString([]byte("fake_signature"))
+ signature = strings.TrimRight(signature, "=")
+
+ return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature)
+}
+
+func writeTempFile(t *testing.T, name, content string) string {
+ t.Helper()
+ tempDir := t.TempDir()
+ filePath := filepath.Join(tempDir, name)
+ if err := os.WriteFile(filePath, []byte(content), 0600); err != nil {
+ t.Fatalf("Failed to write temp file: %v", err)
+ }
+ return filePath
+}