Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dricross committed Jan 29, 2025
1 parent 1889c8a commit 2a01f03
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
21 changes: 14 additions & 7 deletions internal/aws/awsutil/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"log"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -484,9 +485,11 @@ func newStsCredentials(c client.ConfigProvider, roleARN string, region string) *
return credentials.NewCredentials(&stsCredentialProvider{regional: regional, partitional: partitional})
}

var (
sourceAccount = os.Getenv("AMZ_SOURCE_ACCOUNT") // populates the "x-amz-source-account" header
sourceArn = os.Getenv("AMZ_SOURCE_ARN") // populates the "x-amz-source-arn" header
const (
SourceArnHeaderKey = "x-amz-source-arn"
SourceAccountHeaderKey = "x-amz-source-account"
AmzSourceAccount = "AMZ_SOURCE_ACCOUNT"
AmzSourceArn = "AMZ_SOURCE_ARN"
)

// newStsClient creates a new STS client with the provided config and options.
Expand All @@ -497,14 +500,18 @@ var (
//
// See https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html#cross-service-confused-deputy-prevention
func newStsClient(p client.ConfigProvider, cfgs ...*aws.Config) *sts.STS {

sourceAccount := os.Getenv(AmzSourceAccount)
sourceArn := os.Getenv(AmzSourceArn)

client := sts.New(p, cfgs...)
if sourceAccount != "" && sourceArn != "" {
client.Handlers.Sign.PushFront(func(r *request.Request) {
r.ApplyOptions(request.WithSetRequestHeaders(map[string]string{
"x-amz-source-arn": sourceArn,
"x-amz-source-account": sourceAccount,
}))
r.HTTPRequest.Header.Set(SourceArnHeaderKey, sourceArn)
r.HTTPRequest.Header.Set(SourceAccountHeaderKey, sourceAccount)
})

log.Printf("I! Found confused deputy header environment variables: source account: %q, source arn: %q", sourceAccount, sourceArn)
}

return client
Expand Down
76 changes: 76 additions & 0 deletions internal/aws/awsutil/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ import (
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
awsmock "github.com/aws/aws-sdk-go/awstesting/mock"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -178,3 +182,75 @@ func TestLoadEmptyFile(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, certFromFile)
}

func TestConfusedDeputyHeaders(t *testing.T) {
tests := []struct {
name string
envSourceArn string
envSourceAccount string
expectedHeaderArn string
expectedHeaderAccount string
}{
{
name: "unpopulated",
envSourceArn: "",
envSourceAccount: "",
expectedHeaderArn: "",
expectedHeaderAccount: "",
},
{
name: "both populated",
envSourceArn: "arn:aws:ec2:us-east-1:474668408639:instance/i-08293cd9825754f7c",
envSourceAccount: "539247453986",
expectedHeaderArn: "arn:aws:ec2:us-east-1:474668408639:instance/i-08293cd9825754f7c",
expectedHeaderAccount: "539247453986",
},
{
name: "only source arn populated",
envSourceArn: "arn:aws:ec2:us-east-1:474668408639:instance/i-08293cd9825754f7c",
envSourceAccount: "",
expectedHeaderArn: "",
expectedHeaderAccount: "",
},
{
name: "only source account populated",
envSourceArn: "",
envSourceAccount: "539247453986",
expectedHeaderArn: "",
expectedHeaderAccount: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

t.Setenv(AmzSourceAccount, tt.envSourceAccount)
t.Setenv(AmzSourceArn, tt.envSourceArn)

client := newStsClient(awsmock.Session, &aws.Config{
// These are examples credentials pulled from:
// https://docs.aws.amazon.com/STS/latest/APIReference/API_GetAccessKeyInfo.html
Credentials: credentials.NewStaticCredentials("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", ""),
Region: aws.String("us-east-1"),
})

request, _ := client.AssumeRoleRequest(&sts.AssumeRoleInput{
// We aren't going to actually make the assume role call, we are just going
// to verify the headers are present once signed so the RoleArn and RoleSessionName
// arguments are irrelevant. Fill them out with something so the request is valid.
RoleArn: aws.String("arn:aws:iam::012345678912:role/XXXXXXXX"),
RoleSessionName: aws.String("MockSession"),
})

// Headers are generated after the request is signed (but before it's sent)
err := request.Sign()
require.NoError(t, err)

headerSourceArn := request.HTTPRequest.Header.Get(SourceArnHeaderKey)
assert.Equal(t, tt.expectedHeaderArn, headerSourceArn)

headerSourceAccount := request.HTTPRequest.Header.Get(SourceAccountHeaderKey)
assert.Equal(t, tt.expectedHeaderAccount, headerSourceAccount)
})
}

}

0 comments on commit 2a01f03

Please sign in to comment.