diff --git a/internal/aws/awsutil/conn.go b/internal/aws/awsutil/conn.go index ca8eb9ca0c15..1cad10e98e4f 100644 --- a/internal/aws/awsutil/conn.go +++ b/internal/aws/awsutil/conn.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "crypto/x509" "errors" + "log" "net/http" "net/url" "os" @@ -484,9 +485,11 @@ func newStsCredentials(c client.ConfigProvider, roleARN string, region string) * return credentials.NewCredentials(&stsCredentialProvider{regional: regional, partitional: partitional}) } -var ( - sourceAccount = os.Getenv("AMZ_SOURCE_ACCOUNT") // populates the "x-amz-source-account" header - sourceArn = os.Getenv("AMZ_SOURCE_ARN") // populates the "x-amz-source-arn" header +const ( + SourceArnHeaderKey = "x-amz-source-arn" + SourceAccountHeaderKey = "x-amz-source-account" + AmzSourceAccount = "AMZ_SOURCE_ACCOUNT" + AmzSourceArn = "AMZ_SOURCE_ARN" ) // newStsClient creates a new STS client with the provided config and options. @@ -497,14 +500,18 @@ var ( // // See https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html#cross-service-confused-deputy-prevention func newStsClient(p client.ConfigProvider, cfgs ...*aws.Config) *sts.STS { + + sourceAccount := os.Getenv(AmzSourceAccount) + sourceArn := os.Getenv(AmzSourceArn) + client := sts.New(p, cfgs...) if sourceAccount != "" && sourceArn != "" { client.Handlers.Sign.PushFront(func(r *request.Request) { - r.ApplyOptions(request.WithSetRequestHeaders(map[string]string{ - "x-amz-source-arn": sourceArn, - "x-amz-source-account": sourceAccount, - })) + r.HTTPRequest.Header.Set(SourceArnHeaderKey, sourceArn) + r.HTTPRequest.Header.Set(SourceAccountHeaderKey, sourceAccount) }) + + log.Printf("I! Found confused deputy header environment variables: source account: %q, source arn: %q", sourceAccount, sourceArn) } return client diff --git a/internal/aws/awsutil/conn_test.go b/internal/aws/awsutil/conn_test.go index 2da8aeb33407..82539e0e4e52 100644 --- a/internal/aws/awsutil/conn_test.go +++ b/internal/aws/awsutil/conn_test.go @@ -8,9 +8,13 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" + awsmock "github.com/aws/aws-sdk-go/awstesting/mock" + "github.com/aws/aws-sdk-go/service/sts" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -178,3 +182,75 @@ func TestLoadEmptyFile(t *testing.T) { assert.Error(t, err) assert.Nil(t, certFromFile) } + +func TestConfusedDeputyHeaders(t *testing.T) { + tests := []struct { + name string + envSourceArn string + envSourceAccount string + expectedHeaderArn string + expectedHeaderAccount string + }{ + { + name: "unpopulated", + envSourceArn: "", + envSourceAccount: "", + expectedHeaderArn: "", + expectedHeaderAccount: "", + }, + { + name: "both populated", + envSourceArn: "arn:aws:ec2:us-east-1:474668408639:instance/i-08293cd9825754f7c", + envSourceAccount: "539247453986", + expectedHeaderArn: "arn:aws:ec2:us-east-1:474668408639:instance/i-08293cd9825754f7c", + expectedHeaderAccount: "539247453986", + }, + { + name: "only source arn populated", + envSourceArn: "arn:aws:ec2:us-east-1:474668408639:instance/i-08293cd9825754f7c", + envSourceAccount: "", + expectedHeaderArn: "", + expectedHeaderAccount: "", + }, + { + name: "only source account populated", + envSourceArn: "", + envSourceAccount: "539247453986", + expectedHeaderArn: "", + expectedHeaderAccount: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + t.Setenv(AmzSourceAccount, tt.envSourceAccount) + t.Setenv(AmzSourceArn, tt.envSourceArn) + + client := newStsClient(awsmock.Session, &aws.Config{ + // These are examples credentials pulled from: + // https://docs.aws.amazon.com/STS/latest/APIReference/API_GetAccessKeyInfo.html + Credentials: credentials.NewStaticCredentials("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", ""), + Region: aws.String("us-east-1"), + }) + + request, _ := client.AssumeRoleRequest(&sts.AssumeRoleInput{ + // We aren't going to actually make the assume role call, we are just going + // to verify the headers are present once signed so the RoleArn and RoleSessionName + // arguments are irrelevant. Fill them out with something so the request is valid. + RoleArn: aws.String("arn:aws:iam::012345678912:role/XXXXXXXX"), + RoleSessionName: aws.String("MockSession"), + }) + + // Headers are generated after the request is signed (but before it's sent) + err := request.Sign() + require.NoError(t, err) + + headerSourceArn := request.HTTPRequest.Header.Get(SourceArnHeaderKey) + assert.Equal(t, tt.expectedHeaderArn, headerSourceArn) + + headerSourceAccount := request.HTTPRequest.Header.Get(SourceAccountHeaderKey) + assert.Equal(t, tt.expectedHeaderAccount, headerSourceAccount) + }) + } + +}