Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cli/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) {
cmd.Action(func(c *kingpin.ParseContext) (err error) {
input.Config.MfaPromptMethod = a.PromptDriver(hasBackgroundServer(input))
input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration
input.Config.ChainedGetSessionTokenDuration = input.SessionDuration
input.Config.AssumeRoleDuration = input.SessionDuration
input.Config.SSOUseStdout = input.UseStdout
input.ShowHelpMessages = !a.Debug && input.Command == "" && isATerminal() && os.Getenv("AWS_VAULT_DISABLE_HELP_MESSAGE") != "1"
Expand Down
1 change: 1 addition & 0 deletions cli/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func ConfigureExportCommand(app *kingpin.Application, a *AwsVault) {
cmd.Action(func(c *kingpin.ParseContext) (err error) {
input.Config.MfaPromptMethod = a.PromptDriver(false)
input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration
input.Config.ChainedGetSessionTokenDuration = input.SessionDuration
input.Config.AssumeRoleDuration = input.SessionDuration
input.Config.SSOUseStdout = input.UseStdout

Expand Down
1 change: 1 addition & 0 deletions cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) {
cmd.Action(func(c *kingpin.ParseContext) (err error) {
input.Config.MfaPromptMethod = a.PromptDriver(false)
input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration
input.Config.ChainedGetSessionTokenDuration = input.SessionDuration
input.Config.AssumeRoleDuration = input.SessionDuration
input.Config.GetFederationTokenDuration = input.SessionDuration
keyring, err := a.Keyring()
Expand Down
62 changes: 38 additions & 24 deletions vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ func isMasterCredentialsProvider(credsProvider aws.CredentialsProvider) bool {
return ok
}

func isTemporaryCredentialsProvider(credsProvider aws.CredentialsProvider) bool {
switch p := credsProvider.(type) {
case *SessionTokenProvider, *AssumeRoleProvider, *SSORoleCredentialsProvider, *AssumeRoleWithWebIdentityProvider:
return true
case *CachedSessionProvider:
return isTemporaryCredentialsProvider(p.SessionProvider)
default:
return false
}
}

// NewMasterCredentialsProvider creates a provider for the master credentials
func NewMasterCredentialsProvider(k *CredentialKeyring, credentialsName string) *KeyringProvider {
return &KeyringProvider{k, credentialsName}
Expand Down Expand Up @@ -254,22 +265,38 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig,
return nil, err
}

if hasStoredCredentials || !config.HasRole() {
if canUseGetSessionToken, reason := t.canUseGetSessionToken(config); !canUseGetSessionToken {
log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason)
if !config.HasRole() {
return sourcecredsProvider, nil
isSourceForRoleProfile := config.ChainedFromProfile != nil && config.ChainedFromProfile.HasRole()

if !config.HasRole() || isSourceForRoleProfile {
if isMasterCredentialsProvider(sourcecredsProvider) || isSourceForRoleProfile {
canUseGetSessionToken, reason := t.canUseGetSessionToken(config)
if !canUseGetSessionToken {
log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason)
if !config.HasRole() {
return sourcecredsProvider, nil
}
} else {
t.chainedMfa = config.MfaSerial
log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config))
sourcecredsProvider, err = NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
if !config.HasRole() || err != nil {
return sourcecredsProvider, err
}
}
}
t.chainedMfa = config.MfaSerial
log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config))
sourcecredsProvider, err = NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
if !config.HasRole() || err != nil {
return sourcecredsProvider, err
}
}

if config.HasRole() {
if isTemporaryCredentialsProvider(sourcecredsProvider) && config.AssumeRoleDuration > roleChainingMaximumDuration {
log.Printf(
"profile %s: capping AssumeRole duration from %s to AWS maximum %s for role chaining",
config.ProfileName,
config.AssumeRoleDuration,
roleChainingMaximumDuration,
)
config.AssumeRoleDuration = roleChainingMaximumDuration
}

isMfaChained := config.MfaSerial != "" && config.MfaSerial == t.chainedMfa
if isMfaChained {
config.MfaSerial = ""
Expand All @@ -278,16 +305,6 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig,
return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
}

if isMasterCredentialsProvider(sourcecredsProvider) {
canUseGetSessionToken, reason := t.canUseGetSessionToken(config)
if canUseGetSessionToken {
t.chainedMfa = config.MfaSerial
log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config))
return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache)
}
log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason)
}

return sourcecredsProvider, nil
}

Expand Down Expand Up @@ -341,9 +358,6 @@ func (t *TempCredentialsCreator) canUseGetSessionToken(c *ProfileConfig) (bool,
return false, fmt.Sprintf("MFA serial doesn't match profile '%s'", c.ChainedFromProfile.ProfileName)
}

if c.ChainedFromProfile.AssumeRoleDuration > roleChainingMaximumDuration {
return false, fmt.Sprintf("duration %s in profile '%s' is greater than the AWS maximum %s for chaining MFA", c.ChainedFromProfile.AssumeRoleDuration, c.ChainedFromProfile.ProfileName, roleChainingMaximumDuration)
}
}

return true, ""
Expand Down
253 changes: 253 additions & 0 deletions vault/vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@ package vault_test
import (
"os"
"testing"
"bytes"
"log"
"strings"
"time"

"github.com/byteness/keyring"
"github.com/byteness/aws-vault/v7/vault"
"github.com/aws/aws-sdk-go-v2/aws"
)

func TestUsageWebIdentityExample(t *testing.T) {
Expand Down Expand Up @@ -123,3 +128,251 @@ sso_registration_scopes=sso:account:access
t.Fatalf("Expected AccountID to be 2160xxxx, got %s", ssoProvider.AccountID)
}
}

// Ensures direct role login is not treated as chained MFA.
func TestDirectRoleLoginDoesNotUseGetSessionToken(t *testing.T) {
f := newConfigFile(t, []byte(`
[profile target]
role_arn=arn:aws:iam::111111111111:role/target
mfa_serial=arn:aws:iam::111111111111:mfa/user
`))
defer os.Remove(f)

configFile, err := vault.LoadConfig(f)
if err != nil {
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "target"}
config, err := configLoader.GetProfileConfig("target")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
config.MfaToken = "123456"

ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})}
err = ckr.Set("target", aws.Credentials{AccessKeyID: "id", SecretAccessKey: "secret"})
if err != nil {
t.Fatal(err)
}

var buf bytes.Buffer
prevWriter := log.Writer()
prevFlags := log.Flags()
prevPrefix := log.Prefix()
log.SetOutput(&buf)
log.SetFlags(0)
log.SetPrefix("")
defer func() {
log.SetOutput(prevWriter)
log.SetFlags(prevFlags)
log.SetPrefix(prevPrefix)
}()

_, err = vault.NewTempCredentialsProvider(config, ckr, false, true)
if err != nil {
t.Fatal(err)
}

logs := buf.String()

if strings.Contains(logs, "profile target: using GetSessionToken") {
t.Fatalf("did not expect GetSessionToken for non-chained role profile, logs:\n%s", logs)
}
if !strings.Contains(logs, "profile target: using AssumeRole") {
t.Fatalf("expected AssumeRole with MFA, logs:\n%s", logs)
}
}

// Ensures role->role chaining keeps MFA context by priming with GetSessionToken before chained AssumeRole calls.
func TestRoleChainingMfaUsesGetSessionTokenBeforeAssumeRole(t *testing.T) {
f := newConfigFile(t, []byte(`
[profile source]
role_arn=arn:aws:iam::111111111111:role/source
mfa_serial=arn:aws:iam::111111111111:mfa/user

[profile target]
source_profile=source
role_arn=arn:aws:iam::222222222222:role/target
mfa_serial=arn:aws:iam::111111111111:mfa/user
`))
defer os.Remove(f)

configFile, err := vault.LoadConfig(f)
if err != nil {
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "target"}
config, err := configLoader.GetProfileConfig("target")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
config.MfaToken = "123456"
config.SourceProfile.MfaToken = "123456"

ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})}
err = ckr.Set("source", aws.Credentials{AccessKeyID: "id", SecretAccessKey: "secret"})
if err != nil {
t.Fatal(err)
}

var buf bytes.Buffer
prevWriter := log.Writer()
prevFlags := log.Flags()
prevPrefix := log.Prefix()
log.SetOutput(&buf)
log.SetFlags(0)
log.SetPrefix("")
defer func() {
log.SetOutput(prevWriter)
log.SetFlags(prevFlags)
log.SetPrefix(prevPrefix)
}()

_, err = vault.NewTempCredentialsProvider(config, ckr, false, true)
if err != nil {
t.Fatal(err)
}

logs := buf.String()
idxSession := strings.Index(logs, "profile source: using GetSessionToken")
idxSourceAssume := strings.Index(logs, "profile source: using AssumeRole")
idxTargetAssume := strings.Index(logs, "profile target: using AssumeRole")

if idxSession == -1 || idxSourceAssume == -1 || idxTargetAssume == -1 {
t.Fatalf("expected source GetSessionToken then source/target AssumeRole, logs:\n%s", logs)
}
if !(idxSession < idxSourceAssume && idxSourceAssume < idxTargetAssume) {
t.Fatalf("unexpected flow order, logs:\n%s", logs)
}
}

// Ensures flows that are not real role chaining do not go through the chained MFA path.
func TestNonRoleChainingFlowDoesNotUseChainedMfaPath(t *testing.T) {
f := newConfigFile(t, []byte(`
[profile source]
role_arn=arn:aws:iam::111111111111:role/source
mfa_serial=arn:aws:iam::111111111111:mfa/user

[profile target]
source_profile=source
`))
defer os.Remove(f)

configFile, err := vault.LoadConfig(f)
if err != nil {
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "target"}
config, err := configLoader.GetProfileConfig("target")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
config.MfaPromptMethod = "terminal"
config.SourceProfile.MfaToken = "123456"

ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})}
err = ckr.Set("source", aws.Credentials{AccessKeyID: "id", SecretAccessKey: "secret"})
if err != nil {
t.Fatal(err)
}

var buf bytes.Buffer
prevWriter := log.Writer()
prevFlags := log.Flags()
prevPrefix := log.Prefix()
log.SetOutput(&buf)
log.SetFlags(0)
log.SetPrefix("")
defer func() {
log.SetOutput(prevWriter)
log.SetFlags(prevFlags)
log.SetPrefix(prevPrefix)
}()

_, err = vault.NewTempCredentialsProvider(config, ckr, false, true)
if err != nil {
t.Fatal(err)
}

logs := buf.String()

if strings.Contains(logs, "profile source: using GetSessionToken") {
t.Fatalf("did not expect GetSessionToken for role source chained to non-role target, logs:\n%s", logs)
}
if !strings.Contains(logs, "profile source: using AssumeRole") {
t.Fatalf("expected source to AssumeRole with MFA, logs:\n%s", logs)
}
}

// Ensures chained AssumeRole duration is capped to 1 hour when a higher duration is requested.
func TestRoleChainingCapsAssumeRoleDurationToOneHour(t *testing.T) {
f := newConfigFile(t, []byte(`
[profile source]
role_arn=arn:aws:iam::111111111111:role/source
mfa_serial=arn:aws:iam::111111111111:mfa/user

[profile target]
source_profile=source
role_arn=arn:aws:iam::222222222222:role/target
mfa_serial=arn:aws:iam::111111111111:mfa/user
`))
defer os.Remove(f)

base := vault.ProfileConfig{AssumeRoleDuration: 12 * time.Hour}
configFile, err := vault.LoadConfig(f)
if err != nil {
t.Fatal(err)
}
configLoader := vault.NewConfigLoader(base, configFile, "target")
config, err := configLoader.GetProfileConfig("target")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
config.MfaToken = "123456"
config.SourceProfile.MfaToken = "123456"

ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})}
err = ckr.Set("source", aws.Credentials{AccessKeyID: "id", SecretAccessKey: "secret"})
if err != nil {
t.Fatal(err)
}

var buf bytes.Buffer
prevWriter := log.Writer()
prevFlags := log.Flags()
prevPrefix := log.Prefix()
log.SetOutput(&buf)
log.SetFlags(0)
log.SetPrefix("")
defer func() {
log.SetOutput(prevWriter)
log.SetFlags(prevFlags)
log.SetPrefix(prevPrefix)
}()

_, err = vault.NewTempCredentialsProvider(config, ckr, false, true)
if err != nil {
t.Fatal(err)
}

if config.SourceProfile.AssumeRoleDuration != time.Hour {
t.Fatalf("expected source AssumeRole duration to be capped to 1h, got %s", config.SourceProfile.AssumeRoleDuration)
}
if config.AssumeRoleDuration != time.Hour {
t.Fatalf("expected target AssumeRole duration to be capped to 1h, got %s", config.AssumeRoleDuration)
}

logs := buf.String()
if !strings.Contains(logs, "profile source: capping AssumeRole duration") {
t.Fatalf("expected source AssumeRole capping log, logs:\n%s", logs)
}
if !strings.Contains(logs, "profile target: capping AssumeRole duration") {
t.Fatalf("expected target AssumeRole capping log, logs:\n%s", logs)
}
if !strings.Contains(logs, "using AssumeRole") {
t.Fatalf("expected chained AssumeRole flow after duration cap, logs:\n%s", logs)
}
if !strings.Contains(logs, "using GetSessionToken") {
t.Fatalf("expected source GetSessionToken flow after duration cap, logs:\n%s", logs)
}
}
Loading