Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 145 additions & 78 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
Expand All @@ -31,8 +32,30 @@ func (e *TokenExpiredError) Error() string {
return e.Message
}

type credentialType string

const (
credTypeClientSecret credentialType = "client_secret"
credTypeClientCertificate credentialType = "client_certificate"
credTypeCLI credentialType = "cli"
credTypeDevCLI credentialType = "dev_cli"
credTypeAzDOPipelines credentialType = "azdo_pipelines"
credTypeOIDC credentialType = "oidc"
credTypeUserManagedIdentity credentialType = "user_managed_identity"
credTypeSystemManagedIdentity credentialType = "system_managed_identity"
)

type credentialHolder struct {
credential azcore.TokenCredential
once sync.Once
err error
}

type Auth struct {
config *config.ProviderConfig

credentials map[credentialType]*credentialHolder
mu sync.RWMutex
}

type OidcCredential struct {
Expand All @@ -55,48 +78,83 @@ type OidcCredentialOptions struct {

func NewAuthBase(configValue *config.ProviderConfig) *Auth {
return &Auth{
config: configValue,
config: configValue,
credentials: make(map[credentialType]*credentialHolder),
}
}

func (client *Auth) AuthenticateClientCertificate(ctx context.Context, scopes []string) (string, time.Time, error) {
cert, key, err := helpers.ConvertBase64ToCert(client.config.ClientCertificateRaw, client.config.ClientCertificatePassword)
if err != nil {
return "", time.Time{}, err
func (client *Auth) getOrCreateCredential(ctx context.Context, credType credentialType, factory func() (azcore.TokenCredential, error)) (azcore.TokenCredential, error) {
client.mu.RLock()
holder, exists := client.credentials[credType]
client.mu.RUnlock()

if !exists {
client.mu.Lock()
holder, exists = client.credentials[credType]
if !exists {
holder = &credentialHolder{}
client.credentials[credType] = holder
tflog.Debug(ctx, fmt.Sprintf("Created credential holder for type: %s", credType))
}
client.mu.Unlock()
}

azureCertCredentials, err := azidentity.NewClientCertificateCredential(
client.config.TenantId,
client.config.ClientId,
cert,
key,
&azidentity.ClientCertificateCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
holder.once.Do(func() {
tflog.Debug(ctx, fmt.Sprintf("Initializing credential for type: %s", credType))
holder.credential, holder.err = factory()
if holder.err != nil {
tflog.Error(ctx, fmt.Sprintf("Failed to create credential for type %s: %s", credType, holder.err.Error()))
} else {
tflog.Debug(ctx, fmt.Sprintf("Successfully created credential for type: %s", credType))
}
})

return holder.credential, holder.err
Comment on lines +102 to +112
Copy link

Copilot AI Dec 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The credential caching mechanism caches errors permanently using sync.Once. This means if credential creation fails due to a transient error (e.g., temporary network issue, rate limiting), all subsequent attempts will return the same cached error without retrying credential creation. Consider whether this behavior is acceptable for your use case, or if you need a mechanism to retry credential creation after transient failures. For long-running provider instances, this could cause persistent authentication failures even after the underlying issue is resolved.

Suggested change
holder.once.Do(func() {
tflog.Debug(ctx, fmt.Sprintf("Initializing credential for type: %s", credType))
holder.credential, holder.err = factory()
if holder.err != nil {
tflog.Error(ctx, fmt.Sprintf("Failed to create credential for type %s: %s", credType, holder.err.Error()))
} else {
tflog.Debug(ctx, fmt.Sprintf("Successfully created credential for type: %s", credType))
}
})
return holder.credential, holder.err
tflog.Debug(ctx, fmt.Sprintf("Initializing credential for type: %s", credType))
credential, err := factory()
if err != nil {
holder.err = err
tflog.Error(ctx, fmt.Sprintf("Failed to create credential for type %s: %s", credType, err.Error()))
return nil, err
}
holder.credential = credential
holder.err = nil
tflog.Debug(ctx, fmt.Sprintf("Successfully created credential for type: %s", credType))
return holder.credential, nil

Copilot uses AI. Check for mistakes.
}

func (client *Auth) AuthenticateClientCertificate(ctx context.Context, scopes []string) (string, time.Time, error) {
cred, err := client.getOrCreateCredential(ctx, credTypeClientCertificate, func() (azcore.TokenCredential, error) {
cert, key, certErr := helpers.ConvertBase64ToCert(client.config.ClientCertificateRaw, client.config.ClientCertificatePassword)
if certErr != nil {
return nil, certErr
}

return azidentity.NewClientCertificateCredential(
client.config.TenantId,
client.config.ClientId,
cert,
key,
&azidentity.ClientCertificateCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
},
},
)
)
})
if err != nil {
return "", time.Time{}, err
}
accessToken, err := azureCertCredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))

accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
return accessToken.Token, accessToken.ExpiresOn, nil
}

func (client *Auth) AuthenticateUsingCli(ctx context.Context, scopes []string) (string, time.Time, error) {
azureCLICredentials, err := azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
cred, err := client.getOrCreateCredential(ctx, credTypeCLI, func() (azcore.TokenCredential, error) {
return azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := azureCLICredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -105,15 +163,17 @@ func (client *Auth) AuthenticateUsingCli(ctx context.Context, scopes []string) (
}

func (client *Auth) AuthenticateUsingAzureDeveloperCli(ctx context.Context, scopes []string) (string, time.Time, error) {
azureDeveloperCLICredentials, err := azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
cred, err := client.getOrCreateCredential(ctx, credTypeDevCLI, func() (azcore.TokenCredential, error) {
return azidentity.NewAzureDeveloperCLICredential(&azidentity.AzureDeveloperCLICredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
TenantID: client.config.TenantId,
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := azureDeveloperCLICredentials.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -122,21 +182,24 @@ func (client *Auth) AuthenticateUsingAzureDeveloperCli(ctx context.Context, scop
}

func (client *Auth) AuthenticateClientSecret(ctx context.Context, scopes []string) (string, time.Time, error) {
clientSecretCredential, err := azidentity.NewClientSecretCredential(
client.config.TenantId,
client.config.ClientId,
client.config.ClientSecret, &azidentity.ClientSecretCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
cred, err := client.getOrCreateCredential(ctx, credTypeClientSecret, func() (azcore.TokenCredential, error) {
return azidentity.NewClientSecretCredential(
client.config.TenantId,
client.config.ClientId,
client.config.ClientSecret,
&azidentity.ClientSecretCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
},
})
)
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := clientSecretCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))

accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand Down Expand Up @@ -182,32 +245,30 @@ func (w *OidcCredential) GetToken(ctx context.Context, opts policy.TokenRequestO
}

func (client *Auth) AuthenticateOIDC(ctx context.Context, scopes []string) (string, time.Time, error) {
var creds []azcore.TokenCredential

oidcCred, err := client.NewOidcCredential(&OidcCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
TenantID: client.config.TenantId,
ClientID: client.config.ClientId,
RequestToken: client.config.OidcRequestToken,
RequestUrl: client.config.OidcRequestUrl,
Token: client.config.OidcToken,
TokenFilePath: client.config.OidcTokenFilePath,
})
cred, err := client.getOrCreateCredential(ctx, credTypeOIDC, func() (azcore.TokenCredential, error) {
oidcCred, oidcErr := client.NewOidcCredential(&OidcCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
TenantID: client.config.TenantId,
ClientID: client.config.ClientId,
RequestToken: client.config.OidcRequestToken,
RequestUrl: client.config.OidcRequestUrl,
Token: client.config.OidcToken,
TokenFilePath: client.config.OidcTokenFilePath,
})
if oidcErr != nil {
return nil, oidcErr
}

return azidentity.NewChainedTokenCredential([]azcore.TokenCredential{oidcCred}, nil)
})
if err != nil {
tflog.Error(ctx, fmt.Sprintf("newDefaultAzureCredential failed to initialize oidc credential:\n\t%s", err.Error()))
return "", time.Time{}, err
}
creds = append(creds, oidcCred)

chain, err := azidentity.NewChainedTokenCredential(creds, nil)
if err != nil {
return "", time.Time{}, err
}

accessToken, err := chain.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -216,17 +277,19 @@ func (client *Auth) AuthenticateOIDC(ctx context.Context, scopes []string) (stri
}

func (client *Auth) AuthenticateUserManagedIdentity(ctx context.Context, scopes []string) (string, time.Time, error) {
userManagedIdentityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(client.config.ClientId),
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
cred, err := client.getOrCreateCredential(ctx, credTypeUserManagedIdentity, func() (azcore.TokenCredential, error) {
return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(client.config.ClientId),
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := userManagedIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -235,16 +298,18 @@ func (client *Auth) AuthenticateUserManagedIdentity(ctx context.Context, scopes
}

func (client *Auth) AuthenticateSystemManagedIdentity(ctx context.Context, scopes []string) (string, time.Time, error) {
systemManagedIdentityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
cred, err := client.getOrCreateCredential(ctx, credTypeSystemManagedIdentity, func() (azcore.TokenCredential, error) {
return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
})
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := systemManagedIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand All @@ -266,23 +331,25 @@ func (client *Auth) AuthenticateAzDOWorkloadIdentityFederation(ctx context.Conte
return "", time.Time{}, errors.New("could not obtain an OIDC request token for Azure DevOps Workload Identity Federation")
}

azdoWorkloadIdentityCredential, err := azidentity.NewAzurePipelinesCredential(
client.config.TenantId,
client.config.ClientId,
client.config.AzDOServiceConnectionID,
client.config.OidcRequestToken,
&azidentity.AzurePipelinesCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
cred, err := client.getOrCreateCredential(ctx, credTypeAzDOPipelines, func() (azcore.TokenCredential, error) {
return azidentity.NewAzurePipelinesCredential(
client.config.TenantId,
client.config.ClientId,
client.config.AzDOServiceConnectionID,
client.config.OidcRequestToken,
&azidentity.AzurePipelinesCredentialOptions{
AdditionallyAllowedTenants: client.config.AuxiliaryTenantIDs,
ClientOptions: azcore.ClientOptions{
Cloud: client.config.Cloud,
},
},
},
)
)
})
if err != nil {
return "", time.Time{}, err
}

accessToken, err := azdoWorkloadIdentityCredential.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
accessToken, err := cred.GetToken(ctx, client.createTokenRequestOptions(ctx, scopes))
if err != nil {
return "", time.Time{}, err
}
Expand Down
Loading
Loading