diff --git a/cli/exec.go b/cli/exec.go index f9d8f596..bbf82d2d 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -121,6 +121,7 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) { cmd.Action(func(c *kingpin.ParseContext) (err error) { input.Config.MfaPromptMethod = a.PromptDriver(hasBackgroundServer(input)) input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration + input.Config.ChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration input.Config.SSOUseStdout = input.UseStdout input.ShowHelpMessages = !a.Debug && input.Command == "" && isATerminal() && os.Getenv("AWS_VAULT_DISABLE_HELP_MESSAGE") != "1" diff --git a/cli/export.go b/cli/export.go index 21465b23..7b7c0d20 100644 --- a/cli/export.go +++ b/cli/export.go @@ -67,6 +67,7 @@ func ConfigureExportCommand(app *kingpin.Application, a *AwsVault) { cmd.Action(func(c *kingpin.ParseContext) (err error) { input.Config.MfaPromptMethod = a.PromptDriver(false) input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration + input.Config.ChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration input.Config.SSOUseStdout = input.UseStdout diff --git a/cli/login.go b/cli/login.go index 2f330129..5dcc9018 100644 --- a/cli/login.go +++ b/cli/login.go @@ -70,6 +70,7 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) { cmd.Action(func(c *kingpin.ParseContext) (err error) { input.Config.MfaPromptMethod = a.PromptDriver(false) input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration + input.Config.ChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration input.Config.GetFederationTokenDuration = input.SessionDuration keyring, err := a.Keyring() diff --git a/vault/vault.go b/vault/vault.go index 69dd405d..71099be5 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -46,6 +46,17 @@ func isMasterCredentialsProvider(credsProvider aws.CredentialsProvider) bool { return ok } +func isTemporaryCredentialsProvider(credsProvider aws.CredentialsProvider) bool { + switch p := credsProvider.(type) { + case *SessionTokenProvider, *AssumeRoleProvider, *SSORoleCredentialsProvider, *AssumeRoleWithWebIdentityProvider: + return true + case *CachedSessionProvider: + return isTemporaryCredentialsProvider(p.SessionProvider) + default: + return false + } +} + // NewMasterCredentialsProvider creates a provider for the master credentials func NewMasterCredentialsProvider(k *CredentialKeyring, credentialsName string) *KeyringProvider { return &KeyringProvider{k, credentialsName} @@ -254,22 +265,38 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, return nil, err } - if hasStoredCredentials || !config.HasRole() { - if canUseGetSessionToken, reason := t.canUseGetSessionToken(config); !canUseGetSessionToken { - log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) - if !config.HasRole() { - return sourcecredsProvider, nil + isSourceForRoleProfile := config.ChainedFromProfile != nil && config.ChainedFromProfile.HasRole() + + if !config.HasRole() || isSourceForRoleProfile { + if isMasterCredentialsProvider(sourcecredsProvider) || isSourceForRoleProfile { + canUseGetSessionToken, reason := t.canUseGetSessionToken(config) + if !canUseGetSessionToken { + log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) + if !config.HasRole() { + return sourcecredsProvider, nil + } + } else { + t.chainedMfa = config.MfaSerial + log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) + sourcecredsProvider, err = NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + if !config.HasRole() || err != nil { + return sourcecredsProvider, err + } } } - t.chainedMfa = config.MfaSerial - log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) - sourcecredsProvider, err = NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) - if !config.HasRole() || err != nil { - return sourcecredsProvider, err - } } if config.HasRole() { + if isTemporaryCredentialsProvider(sourcecredsProvider) && config.AssumeRoleDuration > roleChainingMaximumDuration { + log.Printf( + "profile %s: capping AssumeRole duration from %s to AWS maximum %s for role chaining", + config.ProfileName, + config.AssumeRoleDuration, + roleChainingMaximumDuration, + ) + config.AssumeRoleDuration = roleChainingMaximumDuration + } + isMfaChained := config.MfaSerial != "" && config.MfaSerial == t.chainedMfa if isMfaChained { config.MfaSerial = "" @@ -278,16 +305,6 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) } - if isMasterCredentialsProvider(sourcecredsProvider) { - canUseGetSessionToken, reason := t.canUseGetSessionToken(config) - if canUseGetSessionToken { - t.chainedMfa = config.MfaSerial - log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) - return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) - } - log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) - } - return sourcecredsProvider, nil } @@ -341,9 +358,6 @@ func (t *TempCredentialsCreator) canUseGetSessionToken(c *ProfileConfig) (bool, return false, fmt.Sprintf("MFA serial doesn't match profile '%s'", c.ChainedFromProfile.ProfileName) } - if c.ChainedFromProfile.AssumeRoleDuration > roleChainingMaximumDuration { - return false, fmt.Sprintf("duration %s in profile '%s' is greater than the AWS maximum %s for chaining MFA", c.ChainedFromProfile.AssumeRoleDuration, c.ChainedFromProfile.ProfileName, roleChainingMaximumDuration) - } } return true, "" diff --git a/vault/vault_test.go b/vault/vault_test.go index 75812bab..fd82d9e6 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -3,9 +3,14 @@ package vault_test import ( "os" "testing" + "bytes" + "log" + "strings" + "time" "github.com/byteness/keyring" "github.com/byteness/aws-vault/v7/vault" + "github.com/aws/aws-sdk-go-v2/aws" ) func TestUsageWebIdentityExample(t *testing.T) { @@ -123,3 +128,251 @@ sso_registration_scopes=sso:account:access t.Fatalf("Expected AccountID to be 2160xxxx, got %s", ssoProvider.AccountID) } } + +// Ensures direct role login is not treated as chained MFA. +func TestDirectRoleLoginDoesNotUseGetSessionToken(t *testing.T) { + f := newConfigFile(t, []byte(` +[profile target] +role_arn=arn:aws:iam::111111111111:role/target +mfa_serial=arn:aws:iam::111111111111:mfa/user +`)) + defer os.Remove(f) + + configFile, err := vault.LoadConfig(f) + if err != nil { + t.Fatal(err) + } + configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "target"} + config, err := configLoader.GetProfileConfig("target") + if err != nil { + t.Fatalf("Should have found a profile: %v", err) + } + config.MfaToken = "123456" + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})} + err = ckr.Set("target", aws.Credentials{AccessKeyID: "id", SecretAccessKey: "secret"}) + if err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + prevWriter := log.Writer() + prevFlags := log.Flags() + prevPrefix := log.Prefix() + log.SetOutput(&buf) + log.SetFlags(0) + log.SetPrefix("") + defer func() { + log.SetOutput(prevWriter) + log.SetFlags(prevFlags) + log.SetPrefix(prevPrefix) + }() + + _, err = vault.NewTempCredentialsProvider(config, ckr, false, true) + if err != nil { + t.Fatal(err) + } + + logs := buf.String() + + if strings.Contains(logs, "profile target: using GetSessionToken") { + t.Fatalf("did not expect GetSessionToken for non-chained role profile, logs:\n%s", logs) + } + if !strings.Contains(logs, "profile target: using AssumeRole") { + t.Fatalf("expected AssumeRole with MFA, logs:\n%s", logs) + } +} + +// Ensures role->role chaining keeps MFA context by priming with GetSessionToken before chained AssumeRole calls. +func TestRoleChainingMfaUsesGetSessionTokenBeforeAssumeRole(t *testing.T) { + f := newConfigFile(t, []byte(` +[profile source] +role_arn=arn:aws:iam::111111111111:role/source +mfa_serial=arn:aws:iam::111111111111:mfa/user + +[profile target] +source_profile=source +role_arn=arn:aws:iam::222222222222:role/target +mfa_serial=arn:aws:iam::111111111111:mfa/user +`)) + defer os.Remove(f) + + configFile, err := vault.LoadConfig(f) + if err != nil { + t.Fatal(err) + } + configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "target"} + config, err := configLoader.GetProfileConfig("target") + if err != nil { + t.Fatalf("Should have found a profile: %v", err) + } + config.MfaToken = "123456" + config.SourceProfile.MfaToken = "123456" + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})} + err = ckr.Set("source", aws.Credentials{AccessKeyID: "id", SecretAccessKey: "secret"}) + if err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + prevWriter := log.Writer() + prevFlags := log.Flags() + prevPrefix := log.Prefix() + log.SetOutput(&buf) + log.SetFlags(0) + log.SetPrefix("") + defer func() { + log.SetOutput(prevWriter) + log.SetFlags(prevFlags) + log.SetPrefix(prevPrefix) + }() + + _, err = vault.NewTempCredentialsProvider(config, ckr, false, true) + if err != nil { + t.Fatal(err) + } + + logs := buf.String() + idxSession := strings.Index(logs, "profile source: using GetSessionToken") + idxSourceAssume := strings.Index(logs, "profile source: using AssumeRole") + idxTargetAssume := strings.Index(logs, "profile target: using AssumeRole") + + if idxSession == -1 || idxSourceAssume == -1 || idxTargetAssume == -1 { + t.Fatalf("expected source GetSessionToken then source/target AssumeRole, logs:\n%s", logs) + } + if !(idxSession < idxSourceAssume && idxSourceAssume < idxTargetAssume) { + t.Fatalf("unexpected flow order, logs:\n%s", logs) + } +} + +// Ensures flows that are not real role chaining do not go through the chained MFA path. +func TestNonRoleChainingFlowDoesNotUseChainedMfaPath(t *testing.T) { + f := newConfigFile(t, []byte(` +[profile source] +role_arn=arn:aws:iam::111111111111:role/source +mfa_serial=arn:aws:iam::111111111111:mfa/user + +[profile target] +source_profile=source +`)) + defer os.Remove(f) + + configFile, err := vault.LoadConfig(f) + if err != nil { + t.Fatal(err) + } + configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "target"} + config, err := configLoader.GetProfileConfig("target") + if err != nil { + t.Fatalf("Should have found a profile: %v", err) + } + config.MfaPromptMethod = "terminal" + config.SourceProfile.MfaToken = "123456" + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})} + err = ckr.Set("source", aws.Credentials{AccessKeyID: "id", SecretAccessKey: "secret"}) + if err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + prevWriter := log.Writer() + prevFlags := log.Flags() + prevPrefix := log.Prefix() + log.SetOutput(&buf) + log.SetFlags(0) + log.SetPrefix("") + defer func() { + log.SetOutput(prevWriter) + log.SetFlags(prevFlags) + log.SetPrefix(prevPrefix) + }() + + _, err = vault.NewTempCredentialsProvider(config, ckr, false, true) + if err != nil { + t.Fatal(err) + } + + logs := buf.String() + + if strings.Contains(logs, "profile source: using GetSessionToken") { + t.Fatalf("did not expect GetSessionToken for role source chained to non-role target, logs:\n%s", logs) + } + if !strings.Contains(logs, "profile source: using AssumeRole") { + t.Fatalf("expected source to AssumeRole with MFA, logs:\n%s", logs) + } +} + +// Ensures chained AssumeRole duration is capped to 1 hour when a higher duration is requested. +func TestRoleChainingCapsAssumeRoleDurationToOneHour(t *testing.T) { + f := newConfigFile(t, []byte(` +[profile source] +role_arn=arn:aws:iam::111111111111:role/source +mfa_serial=arn:aws:iam::111111111111:mfa/user + +[profile target] +source_profile=source +role_arn=arn:aws:iam::222222222222:role/target +mfa_serial=arn:aws:iam::111111111111:mfa/user +`)) + defer os.Remove(f) + + base := vault.ProfileConfig{AssumeRoleDuration: 12 * time.Hour} + configFile, err := vault.LoadConfig(f) + if err != nil { + t.Fatal(err) + } + configLoader := vault.NewConfigLoader(base, configFile, "target") + config, err := configLoader.GetProfileConfig("target") + if err != nil { + t.Fatalf("Should have found a profile: %v", err) + } + config.MfaToken = "123456" + config.SourceProfile.MfaToken = "123456" + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})} + err = ckr.Set("source", aws.Credentials{AccessKeyID: "id", SecretAccessKey: "secret"}) + if err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + prevWriter := log.Writer() + prevFlags := log.Flags() + prevPrefix := log.Prefix() + log.SetOutput(&buf) + log.SetFlags(0) + log.SetPrefix("") + defer func() { + log.SetOutput(prevWriter) + log.SetFlags(prevFlags) + log.SetPrefix(prevPrefix) + }() + + _, err = vault.NewTempCredentialsProvider(config, ckr, false, true) + if err != nil { + t.Fatal(err) + } + + if config.SourceProfile.AssumeRoleDuration != time.Hour { + t.Fatalf("expected source AssumeRole duration to be capped to 1h, got %s", config.SourceProfile.AssumeRoleDuration) + } + if config.AssumeRoleDuration != time.Hour { + t.Fatalf("expected target AssumeRole duration to be capped to 1h, got %s", config.AssumeRoleDuration) + } + + logs := buf.String() + if !strings.Contains(logs, "profile source: capping AssumeRole duration") { + t.Fatalf("expected source AssumeRole capping log, logs:\n%s", logs) + } + if !strings.Contains(logs, "profile target: capping AssumeRole duration") { + t.Fatalf("expected target AssumeRole capping log, logs:\n%s", logs) + } + if !strings.Contains(logs, "using AssumeRole") { + t.Fatalf("expected chained AssumeRole flow after duration cap, logs:\n%s", logs) + } + if !strings.Contains(logs, "using GetSessionToken") { + t.Fatalf("expected source GetSessionToken flow after duration cap, logs:\n%s", logs) + } +}