From 806e3eca6744b54783f61dd289068b19357847ef Mon Sep 17 00:00:00 2001 From: Engin Polat <118744+polatengin@users.noreply.github.com> Date: Wed, 17 Dec 2025 10:24:43 -0800 Subject: [PATCH 1/2] feat: implement credential caching and retrieval mechanism in auth --- internal/api/auth.go | 223 ++++++++++++------- internal/api/auth_test.go | 447 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 592 insertions(+), 78 deletions(-) diff --git a/internal/api/auth.go b/internal/api/auth.go index 9808058b0..b5aa54aff 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -13,6 +13,7 @@ import ( "net/url" "os" "strings" + "sync" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -31,8 +32,30 @@ func (e *TokenExpiredError) Error() string { return e.Message } +type credentialType string + +const ( + credTypeClientSecret credentialType = "client_secret" + credTypeClientCertificate credentialType = "client_certificate" + credTypeCLI credentialType = "cli" + credTypeDevCLI credentialType = "dev_cli" + credTypeAzDOPipelines credentialType = "azdo_pipelines" + credTypeOIDC credentialType = "oidc" + credTypeUserManagedIdentity credentialType = "user_managed_identity" + credTypeSystemManagedIdentity credentialType = "system_managed_identity" +) + +type credentialHolder struct { + credential azcore.TokenCredential + once sync.Once + err error +} + type Auth struct { config *config.ProviderConfig + + credentials map[credentialType]*credentialHolder + mu sync.RWMutex } type OidcCredential struct { @@ -55,32 +78,65 @@ type OidcCredentialOptions struct { func NewAuthBase(configValue *config.ProviderConfig) *Auth { return &Auth{ - config: configValue, + config: configValue, + credentials: make(map[credentialType]*credentialHolder), } } -func (client *Auth) AuthenticateClientCertificate(ctx context.Context, scopes []string) (string, time.Time, error) { - cert, key, err := helpers.ConvertBase64ToCert(client.config.ClientCertificateRaw, client.config.ClientCertificatePassword) - if err != nil { - return "", time.Time{}, err +func (client *Auth) getOrCreateCredential(ctx context.Context, credType credentialType, factory func() (azcore.TokenCredential, error)) (azcore.TokenCredential, error) { + client.mu.RLock() + holder, exists := client.credentials[credType] + client.mu.RUnlock() + + if !exists { + client.mu.Lock() + holder, exists = client.credentials[credType] + if !exists { + holder = &credentialHolder{} + client.credentials[credType] = holder + tflog.Debug(ctx, fmt.Sprintf("Created credential holder for type: %s", credType)) + } + client.mu.Unlock() } - azureCertCredentials, err := azidentity.NewClientCertificateCredential( - client.config.TenantId, - client.config.ClientId, - cert, - key, - &azidentity.ClientCertificateCredentialOptions{ - AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, - ClientOptions: azcore.ClientOptions{ - Cloud: client.config.Cloud, + holder.once.Do(func() { + tflog.Debug(ctx, fmt.Sprintf("Initializing credential for type: %s", credType)) + holder.credential, holder.err = factory() + if holder.err != nil { + tflog.Error(ctx, fmt.Sprintf("Failed to create credential for type %s: %s", credType, holder.err.Error())) + } else { + tflog.Debug(ctx, fmt.Sprintf("Successfully created credential for type: %s", credType)) + } + }) + + return holder.credential, holder.err +} + +func (client *Auth) AuthenticateClientCertificate(ctx context.Context, scopes []string) (string, time.Time, error) { + cred, err := client.getOrCreateCredential(ctx, credTypeClientCertificate, func() (azcore.TokenCredential, error) { + cert, key, certErr := helpers.ConvertBase64ToCert(client.config.ClientCertificateRaw, client.config.ClientCertificatePassword) + if certErr != nil { + return nil, certErr + } + + return azidentity.NewClientCertificateCredential( + client.config.TenantId, + client.config.ClientId, + cert, + key, + &azidentity.ClientCertificateCredentialOptions{ + AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, + ClientOptions: azcore.ClientOptions{ + Cloud: client.config.Cloud, + }, }, - }, - ) + ) + }) if err != nil { return "", time.Time{}, err } - accessToken, err := azureCertCredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) + + accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) if err != nil { return "", time.Time{}, err } @@ -88,15 +144,17 @@ func (client *Auth) AuthenticateClientCertificate(ctx context.Context, scopes [] } func (client *Auth) AuthenticateUsingCli(ctx context.Context, scopes []string) (string, time.Time, error) { - azureCLICredentials, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{ - AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, - TenantID: client.config.TenantId, + cred, err := client.getOrCreateCredential(ctx, credTypeCLI, func() (azcore.TokenCredential, error) { + return azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{ + AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, + TenantID: client.config.TenantId, + }) }) if err != nil { return "", time.Time{}, err } - accessToken, err := azureCLICredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) + accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) if err != nil { return "", time.Time{}, err } @@ -105,15 +163,17 @@ func (client *Auth) AuthenticateUsingCli(ctx context.Context, scopes []string) ( } func (client *Auth) AuthenticateUsingAzureDeveloperCli(ctx context.Context, scopes []string) (string, time.Time, error) { - azureDeveloperCLICredentials, err := azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{ - AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, - TenantID: client.config.TenantId, + cred, err := client.getOrCreateCredential(ctx, credTypeDevCLI, func() (azcore.TokenCredential, error) { + return azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{ + AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, + TenantID: client.config.TenantId, + }) }) if err != nil { return "", time.Time{}, err } - accessToken, err := azureDeveloperCLICredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) + accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) if err != nil { return "", time.Time{}, err } @@ -122,21 +182,24 @@ func (client *Auth) AuthenticateUsingAzureDeveloperCli(ctx context.Context, scop } func (client *Auth) AuthenticateClientSecret(ctx context.Context, scopes []string) (string, time.Time, error) { - clientSecretCredential, err := azidentity.NewClientSecretCredential( - client.config.TenantId, - client.config.ClientId, - client.config.ClientSecret, &azidentity.ClientSecretCredentialOptions{ - AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, - ClientOptions: azcore.ClientOptions{ - Cloud: client.config.Cloud, + cred, err := client.getOrCreateCredential(ctx, credTypeClientSecret, func() (azcore.TokenCredential, error) { + return azidentity.NewClientSecretCredential( + client.config.TenantId, + client.config.ClientId, + client.config.ClientSecret, + &azidentity.ClientSecretCredentialOptions{ + AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, + ClientOptions: azcore.ClientOptions{ + Cloud: client.config.Cloud, + }, }, - }) + ) + }) if err != nil { return "", time.Time{}, err } - accessToken, err := clientSecretCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) - + accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) if err != nil { return "", time.Time{}, err } @@ -182,32 +245,30 @@ func (w *OidcCredential) GetToken(ctx context.Context, opts policy.TokenRequestO } func (client *Auth) AuthenticateOIDC(ctx context.Context, scopes []string) (string, time.Time, error) { - var creds []azcore.TokenCredential - - oidcCred, err := client.NewOidcCredential(&OidcCredentialOptions{ - ClientOptions: azcore.ClientOptions{ - Cloud: client.config.Cloud, - }, - TenantID: client.config.TenantId, - ClientID: client.config.ClientId, - RequestToken: client.config.OidcRequestToken, - RequestUrl: client.config.OidcRequestUrl, - Token: client.config.OidcToken, - TokenFilePath: client.config.OidcTokenFilePath, - }) + cred, err := client.getOrCreateCredential(ctx, credTypeOIDC, func() (azcore.TokenCredential, error) { + oidcCred, oidcErr := client.NewOidcCredential(&OidcCredentialOptions{ + ClientOptions: azcore.ClientOptions{ + Cloud: client.config.Cloud, + }, + TenantID: client.config.TenantId, + ClientID: client.config.ClientId, + RequestToken: client.config.OidcRequestToken, + RequestUrl: client.config.OidcRequestUrl, + Token: client.config.OidcToken, + TokenFilePath: client.config.OidcTokenFilePath, + }) + if oidcErr != nil { + return nil, oidcErr + } + return azidentity.NewChainedTokenCredential([]azcore.TokenCredential{oidcCred}, nil) + }) if err != nil { tflog.Error(ctx, fmt.Sprintf("newDefaultAzureCredential failed to initialize oidc credential:\n\t%s", err.Error())) return "", time.Time{}, err } - creds = append(creds, oidcCred) - chain, err := azidentity.NewChainedTokenCredential(creds, nil) - if err != nil { - return "", time.Time{}, err - } - - accessToken, err := chain.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) + accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) if err != nil { return "", time.Time{}, err } @@ -216,17 +277,19 @@ func (client *Auth) AuthenticateOIDC(ctx context.Context, scopes []string) (stri } func (client *Auth) AuthenticateUserManagedIdentity(ctx context.Context, scopes []string) (string, time.Time, error) { - userManagedIdentityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ - ID: azidentity.ClientID(client.config.ClientId), - ClientOptions: azcore.ClientOptions{ - Cloud: client.config.Cloud, - }, + cred, err := client.getOrCreateCredential(ctx, credTypeUserManagedIdentity, func() (azcore.TokenCredential, error) { + return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ + ID: azidentity.ClientID(client.config.ClientId), + ClientOptions: azcore.ClientOptions{ + Cloud: client.config.Cloud, + }, + }) }) if err != nil { return "", time.Time{}, err } - accessToken, err := userManagedIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) + accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) if err != nil { return "", time.Time{}, err } @@ -235,16 +298,18 @@ func (client *Auth) AuthenticateUserManagedIdentity(ctx context.Context, scopes } func (client *Auth) AuthenticateSystemManagedIdentity(ctx context.Context, scopes []string) (string, time.Time, error) { - systemManagedIdentityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ - ClientOptions: azcore.ClientOptions{ - Cloud: client.config.Cloud, - }, + cred, err := client.getOrCreateCredential(ctx, credTypeSystemManagedIdentity, func() (azcore.TokenCredential, error) { + return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ + ClientOptions: azcore.ClientOptions{ + Cloud: client.config.Cloud, + }, + }) }) if err != nil { return "", time.Time{}, err } - accessToken, err := systemManagedIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) + accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) if err != nil { return "", time.Time{}, err } @@ -266,23 +331,25 @@ func (client *Auth) AuthenticateAzDOWorkloadIdentityFederation(ctx context.Conte return "", time.Time{}, errors.New("could not obtain an OIDC request token for Azure DevOps Workload Identity Federation") } - azdoWorkloadIdentityCredential, err := azidentity.NewAzurePipelinesCredential( - client.config.TenantId, - client.config.ClientId, - client.config.AzDOServiceConnectionID, - client.config.OidcRequestToken, - &azidentity.AzurePipelinesCredentialOptions{ - AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, - ClientOptions: azcore.ClientOptions{ - Cloud: client.config.Cloud, + cred, err := client.getOrCreateCredential(ctx, credTypeAzDOPipelines, func() (azcore.TokenCredential, error) { + return azidentity.NewAzurePipelinesCredential( + client.config.TenantId, + client.config.ClientId, + client.config.AzDOServiceConnectionID, + client.config.OidcRequestToken, + &azidentity.AzurePipelinesCredentialOptions{ + AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs, + ClientOptions: azcore.ClientOptions{ + Cloud: client.config.Cloud, + }, }, - }, - ) + ) + }) if err != nil { return "", time.Time{}, err } - accessToken, err := azdoWorkloadIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) + accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes)) if err != nil { return "", time.Time{}, err } diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index cfe833a43..f37822fb7 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -5,12 +5,41 @@ package api import ( "context" + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "sync/atomic" "testing" + "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/microsoft/terraform-provider-power-platform/internal/config" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +type mockTokenCredential struct { + getTokenCallCount int32 + tokenValue string + expiresOn time.Time + err error +} + +func (m *mockTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { + atomic.AddInt32(&m.getTokenCallCount, 1) + if m.err != nil { + return azcore.AccessToken{}, m.err + } + return azcore.AccessToken{ + Token: m.tokenValue, + ExpiresOn: m.expiresOn, + }, nil +} + func TestUnitCreateTokenRequestOptions(t *testing.T) { scopes := []string{"https://management.azure.com/.default"} ctx := context.Background() @@ -112,3 +141,421 @@ func TestUnitAuthenticateUsingAzureDeveloperCli_ConfigurationCheck(t *testing.T) }) } } + +func TestUnit_GetOrCreateCredential_ReusesCredential(t *testing.T) { + ctx := context.Background() + providerConfig := &config.ProviderConfig{} + authClient := NewAuthBase(providerConfig) + + factoryCallCount := 0 + mockCred := &mockTokenCredential{tokenValue: "test-token"} + + factory := func() (azcore.TokenCredential, error) { + factoryCallCount++ + return mockCred, nil + } + + // First call - should create the credential + cred1, err1 := authClient.getOrCreateCredential(ctx, credTypeClientSecret, factory) + require.NoError(t, err1) + assert.NotNil(t, cred1) + assert.Equal(t, 1, factoryCallCount, "Factory should be called exactly once on first call") + + // Second call - should reuse the credential + cred2, err2 := authClient.getOrCreateCredential(ctx, credTypeClientSecret, factory) + require.NoError(t, err2) + assert.NotNil(t, cred2) + assert.Equal(t, 1, factoryCallCount, "Factory should not be called again on second call") + + // Third call - should still reuse the credential + cred3, err3 := authClient.getOrCreateCredential(ctx, credTypeClientSecret, factory) + require.NoError(t, err3) + assert.NotNil(t, cred3) + assert.Equal(t, 1, factoryCallCount, "Factory should not be called again on third call") + + // Verify same instance is returned + assert.Same(t, cred1, cred2, "Same credential instance should be returned") + assert.Same(t, cred2, cred3, "Same credential instance should be returned") +} + +func TestUnit_GetOrCreateCredential_DifferentTypesAreSeparate(t *testing.T) { + ctx := context.Background() + providerConfig := &config.ProviderConfig{} + authClient := NewAuthBase(providerConfig) + + clientSecretCallCount := 0 + cliCallCount := 0 + + clientSecretCred := &mockTokenCredential{tokenValue: "client-secret-token"} + cliCred := &mockTokenCredential{tokenValue: "cli-token"} + + clientSecretFactory := func() (azcore.TokenCredential, error) { + clientSecretCallCount++ + return clientSecretCred, nil + } + + cliFactory := func() (azcore.TokenCredential, error) { + cliCallCount++ + return cliCred, nil + } + + // Create client secret credential + cred1, err := authClient.getOrCreateCredential(ctx, credTypeClientSecret, clientSecretFactory) + require.NoError(t, err) + assert.Equal(t, 1, clientSecretCallCount) + assert.Equal(t, 0, cliCallCount) + + // Create CLI credential + cred2, err := authClient.getOrCreateCredential(ctx, credTypeCLI, cliFactory) + require.NoError(t, err) + assert.Equal(t, 1, clientSecretCallCount) + assert.Equal(t, 1, cliCallCount) + + // Verify different instances + assert.NotSame(t, cred1, cred2, "Different credential types should return different instances") + assert.Same(t, cred1, clientSecretCred) + assert.Same(t, cred2, cliCred) + + // Re-request both - should reuse + cred1Again, _ := authClient.getOrCreateCredential(ctx, credTypeClientSecret, clientSecretFactory) + cred2Again, _ := authClient.getOrCreateCredential(ctx, credTypeCLI, cliFactory) + + assert.Equal(t, 1, clientSecretCallCount, "Factory should not be called again") + assert.Equal(t, 1, cliCallCount, "Factory should not be called again") + assert.Same(t, cred1, cred1Again) + assert.Same(t, cred2, cred2Again) +} + +func TestUnit_GetOrCreateCredential_CachesErrors(t *testing.T) { + ctx := context.Background() + providerConfig := &config.ProviderConfig{} + authClient := NewAuthBase(providerConfig) + + factoryCallCount := 0 + expectedErr := errors.New("authentication failed") + + factory := func() (azcore.TokenCredential, error) { + factoryCallCount++ + return nil, expectedErr + } + + // First call - should get the error + cred1, err1 := authClient.getOrCreateCredential(ctx, credTypeClientSecret, factory) + assert.Nil(t, cred1) + assert.Equal(t, expectedErr, err1) + assert.Equal(t, 1, factoryCallCount, "Factory should be called exactly once") + + // Second call - should return the cached error without calling factory again + cred2, err2 := authClient.getOrCreateCredential(ctx, credTypeClientSecret, factory) + assert.Nil(t, cred2) + assert.Equal(t, expectedErr, err2) + assert.Equal(t, 1, factoryCallCount, "Factory should not be called again") +} + +func TestUnit_GetOrCreateCredential_ConcurrentAccess(t *testing.T) { + ctx := context.Background() + providerConfig := &config.ProviderConfig{} + authClient := NewAuthBase(providerConfig) + + var factoryCallCount int32 + mockCred := &mockTokenCredential{tokenValue: "concurrent-token"} + + factory := func() (azcore.TokenCredential, error) { + atomic.AddInt32(&factoryCallCount, 1) + time.Sleep(10 * time.Millisecond) + return mockCred, nil + } + + const numGoroutines = 50 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + credentials := make([]azcore.TokenCredential, numGoroutines) + errs := make([]error, numGoroutines) + + for i := range numGoroutines { + go func(index int) { + defer wg.Done() + cred, err := authClient.getOrCreateCredential(ctx, credTypeClientSecret, factory) + credentials[index] = cred + errs[index] = err + }(i) + } + + wg.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&factoryCallCount), + "Factory should be called exactly once even with concurrent access") + + for i, cred := range credentials { + assert.NoError(t, errs[i]) + assert.Same(t, mockCred, cred, "All goroutines should receive the same credential instance") + } +} + +func TestUnit_GetOrCreateCredential_ConcurrentAccessDifferentTypes(t *testing.T) { + ctx := context.Background() + providerConfig := &config.ProviderConfig{} + authClient := NewAuthBase(providerConfig) + + var clientSecretCallCount, cliCallCount, devCliCallCount int32 + clientSecretCred := &mockTokenCredential{tokenValue: "client-secret"} + cliCred := &mockTokenCredential{tokenValue: "cli"} + devCliCred := &mockTokenCredential{tokenValue: "dev-cli"} + + factories := map[credentialType]func() (azcore.TokenCredential, error){ + credTypeClientSecret: func() (azcore.TokenCredential, error) { + atomic.AddInt32(&clientSecretCallCount, 1) + time.Sleep(5 * time.Millisecond) + return clientSecretCred, nil + }, + credTypeCLI: func() (azcore.TokenCredential, error) { + atomic.AddInt32(&cliCallCount, 1) + time.Sleep(5 * time.Millisecond) + return cliCred, nil + }, + credTypeDevCLI: func() (azcore.TokenCredential, error) { + atomic.AddInt32(&devCliCallCount, 1) + time.Sleep(5 * time.Millisecond) + return devCliCred, nil + }, + } + + expectedCreds := map[credentialType]azcore.TokenCredential{ + credTypeClientSecret: clientSecretCred, + credTypeCLI: cliCred, + credTypeDevCLI: devCliCred, + } + + const numGoroutinesPerType = 20 + var wg sync.WaitGroup + + type result struct { + credType credentialType + cred azcore.TokenCredential + err error + } + results := make(chan result, numGoroutinesPerType*3) + + for credType, factory := range factories { + for range numGoroutinesPerType { + wg.Add(1) + go func(ct credentialType, f func() (azcore.TokenCredential, error)) { + defer wg.Done() + cred, err := authClient.getOrCreateCredential(ctx, ct, f) + results <- result{credType: ct, cred: cred, err: err} + }(credType, factory) + } + } + + wg.Wait() + close(results) + + assert.Equal(t, int32(1), atomic.LoadInt32(&clientSecretCallCount)) + assert.Equal(t, int32(1), atomic.LoadInt32(&cliCallCount)) + assert.Equal(t, int32(1), atomic.LoadInt32(&devCliCallCount)) + + for r := range results { + assert.NoError(t, r.err) + assert.Same(t, expectedCreds[r.credType], r.cred, + "Credential for type %s should be the expected instance", r.credType) + } +} + +func TestUnit_NewAuthBase_InitializesCredentialsMap(t *testing.T) { + providerConfig := &config.ProviderConfig{} + authClient := NewAuthBase(providerConfig) + + assert.NotNil(t, authClient.credentials, "Credentials map should be initialized") + assert.Empty(t, authClient.credentials, "Credentials map should be empty initially") + assert.Equal(t, providerConfig, authClient.config, "Config should be set correctly") +} + +func TestUnit_CredentialTypes_AreUnique(t *testing.T) { + types := []credentialType{ + credTypeClientSecret, + credTypeClientCertificate, + credTypeCLI, + credTypeDevCLI, + credTypeAzDOPipelines, + credTypeOIDC, + credTypeUserManagedIdentity, + credTypeSystemManagedIdentity, + } + + seen := make(map[credentialType]bool) + for _, ct := range types { + assert.False(t, seen[ct], "Credential type %s should be unique", ct) + seen[ct] = true + } +} + +func TestUnit_GetTokenForScopes_TestMode(t *testing.T) { + ctx := context.Background() + providerConfig := &config.ProviderConfig{ + TestMode: true, + } + authClient := NewAuthBase(providerConfig) + + scopes := []string{"https://management.azure.com/.default"} + + token, err := authClient.GetTokenForScopes(ctx, scopes) + + assert.NoError(t, err) + assert.NotNil(t, token) + assert.Equal(t, "test_mode_mock_token_value", *token) + + assert.Empty(t, authClient.credentials, "No credentials should be created in test mode") +} + +func TestUnit_GetTokenForScopes_NoCredentials(t *testing.T) { + ctx := context.Background() + providerConfig := &config.ProviderConfig{} + authClient := NewAuthBase(providerConfig) + + scopes := []string{"https://management.azure.com/.default"} + + token, err := authClient.GetTokenForScopes(ctx, scopes) + + assert.Nil(t, token) + assert.Error(t, err) + assert.Equal(t, "no credentials provided", err.Error()) +} + +func TestUnit_TokenExpiredError(t *testing.T) { + err := &TokenExpiredError{Message: "token has expired"} + assert.Equal(t, "token has expired", err.Error()) +} + +type testableAuth struct { + *Auth +} + +func (ta *testableAuth) injectCredential(credType credentialType, cred azcore.TokenCredential) { + ta.mu.Lock() + defer ta.mu.Unlock() + holder := &credentialHolder{credential: cred} + holder.once.Do(func() {}) + ta.credentials[credType] = holder +} + +func (ta *testableAuth) injectCredentialError(credType credentialType, err error) { + ta.mu.Lock() + defer ta.mu.Unlock() + holder := &credentialHolder{err: err} + holder.once.Do(func() {}) + ta.credentials[credType] = holder +} + +func TestUnit_AuthenticateUserManagedIdentity_WithMockedCredential(t *testing.T) { + ctx := context.Background() + expiresOn := time.Now().Add(1 * time.Hour) + mockCred := &mockTokenCredential{ + tokenValue: "user-managed-identity-token", + expiresOn: expiresOn, + } + + providerConfig := &config.ProviderConfig{ + UseMsi: true, + ClientId: "test-client-id", + } + authClient := &testableAuth{Auth: NewAuthBase(providerConfig)} + authClient.injectCredential(credTypeUserManagedIdentity, mockCred) + + scopes := []string{"https://management.azure.com/.default"} + + token, tokenExpiry, err := authClient.AuthenticateUserManagedIdentity(ctx, scopes) + + require.NoError(t, err) + assert.Equal(t, "user-managed-identity-token", token) + assert.Equal(t, expiresOn, tokenExpiry) + assert.Equal(t, int32(1), atomic.LoadInt32(&mockCred.getTokenCallCount)) +} + +func TestUnit_AuthenticateUserManagedIdentity_GetTokenError(t *testing.T) { + ctx := context.Background() + mockCred := &mockTokenCredential{ + err: errors.New("failed to get token for user managed identity"), + } + + providerConfig := &config.ProviderConfig{ + UseMsi: true, + ClientId: "test-client-id", + } + authClient := &testableAuth{Auth: NewAuthBase(providerConfig)} + authClient.injectCredential(credTypeUserManagedIdentity, mockCred) + + scopes := []string{"https://management.azure.com/.default"} + + token, tokenExpiry, err := authClient.AuthenticateUserManagedIdentity(ctx, scopes) + + assert.Error(t, err) + assert.Equal(t, "", token) + assert.Equal(t, time.Time{}, tokenExpiry) + assert.Contains(t, err.Error(), "failed to get token") +} + +func TestUnit_AuthenticateUserManagedIdentity_CredentialCreationError(t *testing.T) { + ctx := context.Background() + + providerConfig := &config.ProviderConfig{ + UseMsi: true, + ClientId: "test-client-id", + } + authClient := &testableAuth{Auth: NewAuthBase(providerConfig)} + authClient.injectCredentialError(credTypeUserManagedIdentity, errors.New("credential creation failed")) + + scopes := []string{"https://management.azure.com/.default"} + + token, tokenExpiry, err := authClient.AuthenticateUserManagedIdentity(ctx, scopes) + + assert.Error(t, err) + assert.Equal(t, "", token) + assert.Equal(t, time.Time{}, tokenExpiry) + assert.Equal(t, "credential creation failed", err.Error()) +} + +func TestUnit_AuthenticateSystemManagedIdentity_WithMockedCredential(t *testing.T) { + ctx := context.Background() + expiresOn := time.Now().Add(1 * time.Hour) + mockCred := &mockTokenCredential{ + tokenValue: "system-managed-identity-token", + expiresOn: expiresOn, + } + + providerConfig := &config.ProviderConfig{ + UseMsi: true, + } + authClient := &testableAuth{Auth: NewAuthBase(providerConfig)} + authClient.injectCredential(credTypeSystemManagedIdentity, mockCred) + + scopes := []string{"https://management.azure.com/.default"} + + token, tokenExpiry, err := authClient.AuthenticateSystemManagedIdentity(ctx, scopes) + + require.NoError(t, err) + assert.Equal(t, "system-managed-identity-token", token) + assert.Equal(t, expiresOn, tokenExpiry) + assert.Equal(t, int32(1), atomic.LoadInt32(&mockCred.getTokenCallCount)) +} + +func TestUnit_AuthenticateSystemManagedIdentity_GetTokenError(t *testing.T) { + ctx := context.Background() + mockCred := &mockTokenCredential{ + err: errors.New("system identity token retrieval failed"), + } + + providerConfig := &config.ProviderConfig{ + UseMsi: true, + } + authClient := &testableAuth{Auth: NewAuthBase(providerConfig)} + authClient.injectCredential(credTypeSystemManagedIdentity, mockCred) + + scopes := []string{"https://management.azure.com/.default"} + + token, tokenExpiry, err := authClient.AuthenticateSystemManagedIdentity(ctx, scopes) + + assert.Error(t, err) + assert.Equal(t, "", token) + assert.Equal(t, time.Time{}, tokenExpiry) +} From 6e4932bdfb3858e8967e0d70a10695e8dba32159 Mon Sep 17 00:00:00 2001 From: Engin Polat <118744+polatengin@users.noreply.github.com> Date: Mon, 22 Dec 2025 07:45:05 -0800 Subject: [PATCH 2/2] Update internal/api/auth_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/api/auth_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index f37822fb7..8c02f1e64 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -6,10 +6,6 @@ package api import ( "context" "errors" - "net/http" - "net/http/httptest" - "os" - "path/filepath" "sync" "sync/atomic" "testing"