From 6681d5ecbc82293626cab693625c9cc9fb453381 Mon Sep 17 00:00:00 2001 From: Adrien F Date: Thu, 25 Apr 2024 14:52:56 +0200 Subject: [PATCH] Add AWS MSK IAM authentication to Kafka scaler (#5692) Signed-off-by: Adrien Fillon --- CHANGELOG.md | 1 + go.mod | 1 + go.sum | 2 + .../kafka_scaler_oauth_token_provider.go | 67 +++- pkg/scalers/kafka_scaler.go | 242 ++++++++++---- pkg/scalers/kafka_scaler_test.go | 135 ++++++-- pkg/scaling/scalers_builder.go | 2 +- .../aws/aws-msk-iam-sasl-signer-go/LICENSE | 175 ++++++++++ .../aws/aws-msk-iam-sasl-signer-go/NOTICE | 1 + .../signer/msk_auth_token_provider.go | 305 ++++++++++++++++++ .../signer/version.go | 3 + vendor/modules.txt | 3 + 12 files changed, 843 insertions(+), 94 deletions(-) create mode 100644 vendor/github.com/aws/aws-msk-iam-sasl-signer-go/LICENSE create mode 100644 vendor/github.com/aws/aws-msk-iam-sasl-signer-go/NOTICE create mode 100644 vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/msk_auth_token_provider.go create mode 100644 vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/version.go diff --git a/CHANGELOG.md b/CHANGELOG.md index b5f384238b1..bd671c7516b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,7 @@ New deprecation(s): ### New - **General**: Provide capability to filter CloudEvents ([#3533](https://github.com/kedacore/keda/issues/3533)) +- **Kafka**: Support Kafka SASL MSK IAM authentication ([#5540](https://github.com/kedacore/keda/issues/5540)) - **NATS Scaler**: Add TLS authentication ([#2296](https://github.com/kedacore/keda/issues/2296)) - **ScaledObject**: Ability to specify `initialCooldownPeriod` ([#5008](https://github.com/kedacore/keda/issues/5008)) diff --git a/go.mod b/go.mod index 64fa2055d41..69c37dbb8c6 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/Huawei/gophercloud v1.0.21 github.com/IBM/sarama v1.43.1 github.com/arangodb/go-driver v1.6.2 + github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0 github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2/config v1.27.11 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 diff --git a/go.sum b/go.sum index 5a79728635c..c0d4dc45484 100644 --- a/go.sum +++ b/go.sum @@ -1448,6 +1448,8 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPd github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0 h1:UyjtGmO0Uwl/K+zpzPwLoXzMhcN9xmnR2nrqJoBrg3c= +github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0/go.mod h1:TJAXuFs2HcMib3sN5L0gUC+Q01Qvy3DemvA55WuC+iA= github.com/aws/aws-sdk-go-v2 v1.16.12/go.mod h1:C+Ym0ag2LIghJbXhfXZ0YEEp49rBWowxKzJLUoob0ts= github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= diff --git a/pkg/scalers/kafka/kafka_scaler_oauth_token_provider.go b/pkg/scalers/kafka/kafka_scaler_oauth_token_provider.go index 084d8c9d15b..6fa36cfa0ba 100644 --- a/pkg/scalers/kafka/kafka_scaler_oauth_token_provider.go +++ b/pkg/scalers/kafka/kafka_scaler_oauth_token_provider.go @@ -18,18 +18,27 @@ package kafka import ( "context" + "sync" + "time" "github.com/IBM/sarama" + "github.com/aws/aws-msk-iam-sasl-signer-go/signer" + "github.com/aws/aws-sdk-go-v2/aws" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" ) -type TokenProvider struct { +type TokenProvider interface { + sarama.AccessTokenProvider + String() string +} + +type oauthBearerTokenProvider struct { tokenSource oauth2.TokenSource extensions map[string]string } -func OAuthBearerTokenProvider(clientID, clientSecret, tokenURL string, scopes []string, extensions map[string]string) sarama.AccessTokenProvider { +func OAuthBearerTokenProvider(clientID, clientSecret, tokenURL string, scopes []string, extensions map[string]string) TokenProvider { cfg := clientcredentials.Config{ ClientID: clientID, ClientSecret: clientSecret, @@ -37,17 +46,63 @@ func OAuthBearerTokenProvider(clientID, clientSecret, tokenURL string, scopes [] Scopes: scopes, } - return &TokenProvider{ + return &oauthBearerTokenProvider{ tokenSource: cfg.TokenSource(context.Background()), extensions: extensions, } } -func (t *TokenProvider) Token() (*sarama.AccessToken, error) { - token, err := t.tokenSource.Token() +func (o *oauthBearerTokenProvider) Token() (*sarama.AccessToken, error) { + token, err := o.tokenSource.Token() + if err != nil { + return nil, err + } + + return &sarama.AccessToken{Token: token.AccessToken, Extensions: o.extensions}, nil +} + +func (o *oauthBearerTokenProvider) String() string { + return "OAuthBearer" +} + +type mskTokenProvider struct { + sync.Mutex + expireAt *time.Time + token string + region string + credentialsProvider aws.CredentialsProvider +} + +func OAuthMSKTokenProvider(cfg *aws.Config) TokenProvider { + return &mskTokenProvider{ + region: cfg.Region, + credentialsProvider: cfg.Credentials, + } +} + +func (m *mskTokenProvider) Token() (*sarama.AccessToken, error) { + m.Lock() + defer m.Unlock() + + if m.expireAt != nil && time.Now().Before(*m.expireAt) { + return &sarama.AccessToken{Token: m.token}, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + token, expirationMs, err := signer.GenerateAuthTokenFromCredentialsProvider(ctx, m.region, m.credentialsProvider) if err != nil { return nil, err } - return &sarama.AccessToken{Token: token.AccessToken, Extensions: t.extensions}, nil + expirationTime := time.UnixMilli(expirationMs) + m.expireAt = &expirationTime + m.token = token + + return &sarama.AccessToken{Token: token}, err +} + +func (m *mskTokenProvider) String() string { + return "MSK" } diff --git a/pkg/scalers/kafka_scaler.go b/pkg/scalers/kafka_scaler.go index 86acc802d81..ec9c7fc5927 100644 --- a/pkg/scalers/kafka_scaler.go +++ b/pkg/scalers/kafka_scaler.go @@ -33,6 +33,7 @@ import ( v2 "k8s.io/api/autoscaling/v2" "k8s.io/metrics/pkg/apis/external_metrics" + awsutils "github.com/kedacore/keda/v2/pkg/scalers/aws" "github.com/kedacore/keda/v2/pkg/scalers/kafka" "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" kedautil "github.com/kedacore/keda/v2/pkg/util" @@ -81,10 +82,15 @@ type kafkaMetadata struct { kerberosServiceName string // OAUTHBEARER + tokenProvider kafkaSaslOAuthTokenProvider scopes []string oauthTokenEndpointURI string oauthExtensions map[string]string + // MSK + awsRegion string + awsAuthorization awsutils.AuthorizationMetadata + // TLS enableTLS bool cert string @@ -115,6 +121,14 @@ const ( KafkaSASLTypeGSSAPI kafkaSaslType = "gssapi" ) +type kafkaSaslOAuthTokenProvider string + +// supported SASL OAuth token provider types +const ( + KafkaSASLOAuthTokenProviderBearer kafkaSaslOAuthTokenProvider = "bearer" + KafkaSASLOAuthTokenProviderAWSMSKIAM kafkaSaslOAuthTokenProvider = "aws_msk_iam" +) + const ( lagThresholdMetricName = "lagThreshold" activationLagThresholdMetricName = "activationLagThreshold" @@ -126,7 +140,7 @@ const ( ) // NewKafkaScaler creates a new kafkaScaler -func NewKafkaScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { +func NewKafkaScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) { metricType, err := GetMetricTargetType(config) if err != nil { return nil, fmt.Errorf("error getting scaler metric type: %w", err) @@ -139,7 +153,7 @@ func NewKafkaScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { return nil, fmt.Errorf("error parsing kafka metadata: %w", err) } - client, admin, err := getKafkaClients(kafkaMetadata) + client, admin, err := getKafkaClients(ctx, kafkaMetadata) if err != nil { return nil, err } @@ -157,6 +171,40 @@ func NewKafkaScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { } func parseKafkaAuthParams(config *scalersconfig.ScalerConfig, meta *kafkaMetadata) error { + meta.enableTLS = false + enableTLS := false + if val, ok := config.TriggerMetadata["tls"]; ok { + switch val { + case stringEnable: + enableTLS = true + case stringDisable: + enableTLS = false + default: + return fmt.Errorf("error incorrect TLS value given, got %s", val) + } + } + + if val, ok := config.AuthParams["tls"]; ok { + val = strings.TrimSpace(val) + if enableTLS { + return errors.New("unable to set `tls` in both ScaledObject and TriggerAuthentication together") + } + switch val { + case stringEnable: + enableTLS = true + case stringDisable: + enableTLS = false + default: + return fmt.Errorf("error incorrect TLS value given, got %s", val) + } + } + + if enableTLS { + if err := parseTLS(config, meta); err != nil { + return err + } + } + meta.saslType = KafkaSASLTypeNone var saslAuthType string switch { @@ -177,11 +225,16 @@ func parseKafkaAuthParams(config *scalersconfig.ScalerConfig, meta *kafkaMetadat mode := kafkaSaslType(saslAuthType) switch { - case mode == KafkaSASLTypePlaintext || mode == KafkaSASLTypeSCRAMSHA256 || mode == KafkaSASLTypeSCRAMSHA512 || mode == KafkaSASLTypeOAuthbearer: + case mode == KafkaSASLTypePlaintext || mode == KafkaSASLTypeSCRAMSHA256 || mode == KafkaSASLTypeSCRAMSHA512: err := parseSaslParams(config, meta, mode) if err != nil { return err } + case mode == KafkaSASLTypeOAuthbearer: + err := parseSaslOAuthParams(config, meta, mode) + if err != nil { + return err + } case mode == KafkaSASLTypeGSSAPI: err := parseKerberosParams(config, meta, mode) if err != nil { @@ -192,38 +245,97 @@ func parseKafkaAuthParams(config *scalersconfig.ScalerConfig, meta *kafkaMetadat } } - meta.enableTLS = false - enableTLS := false - if val, ok := config.TriggerMetadata["tls"]; ok { - switch val { - case stringEnable: - enableTLS = true - case stringDisable: - enableTLS = false - default: - return fmt.Errorf("error incorrect TLS value given, got %s", val) - } + return nil +} + +func parseSaslOAuthParams(config *scalersconfig.ScalerConfig, meta *kafkaMetadata, mode kafkaSaslType) error { + var tokenProviderTypeValue string + if val, ok := config.TriggerMetadata["saslTokenProvider"]; ok { + tokenProviderTypeValue = val } - if val, ok := config.AuthParams["tls"]; ok { - val = strings.TrimSpace(val) - if enableTLS { - return errors.New("unable to set `tls` in both ScaledObject and TriggerAuthentication together") + if val, ok := config.AuthParams["saslTokenProvider"]; ok { + if tokenProviderTypeValue != "" { + return errors.New("unable to set `saslTokenProvider` in both ScaledObject and TriggerAuthentication together") } - switch val { - case stringEnable: - enableTLS = true - case stringDisable: - enableTLS = false - default: - return fmt.Errorf("error incorrect TLS value given, got %s", val) + tokenProviderTypeValue = val + } + + tokenProviderType := KafkaSASLOAuthTokenProviderBearer + if tokenProviderTypeValue != "" { + tokenProviderType = kafkaSaslOAuthTokenProvider(strings.TrimSpace(tokenProviderTypeValue)) + } + + var tokenProviderErr error + switch tokenProviderType { + case KafkaSASLOAuthTokenProviderBearer: + tokenProviderErr = parseSaslOAuthBearerParams(config, meta) + case KafkaSASLOAuthTokenProviderAWSMSKIAM: + tokenProviderErr = parseSaslOAuthAWSMSKIAMParams(config, meta) + default: + return fmt.Errorf("err SASL OAuth token provider %s given", tokenProviderType) + } + + if tokenProviderErr != nil { + return fmt.Errorf("error parsing OAuth token provider configuration: %w", tokenProviderErr) + } + + meta.saslType = mode + meta.tokenProvider = tokenProviderType + + return nil +} + +func parseSaslOAuthBearerParams(config *scalersconfig.ScalerConfig, meta *kafkaMetadata) error { + if config.AuthParams["username"] == "" { + return errors.New("no username given") + } + meta.username = strings.TrimSpace(config.AuthParams["username"]) + + if config.AuthParams["password"] == "" { + return errors.New("no password given") + } + meta.password = strings.TrimSpace(config.AuthParams["password"]) + + meta.scopes = strings.Split(config.AuthParams["scopes"], ",") + + if config.AuthParams["oauthTokenEndpointUri"] == "" { + return errors.New("no oauth token endpoint uri given") + } + meta.oauthTokenEndpointURI = strings.TrimSpace(config.AuthParams["oauthTokenEndpointUri"]) + + meta.oauthExtensions = make(map[string]string) + oauthExtensionsRaw := config.AuthParams["oauthExtensions"] + if oauthExtensionsRaw != "" { + for _, extension := range strings.Split(oauthExtensionsRaw, ",") { + splittedExtension := strings.Split(extension, "=") + if len(splittedExtension) != 2 { + return errors.New("invalid OAuthBearer extension, must be of format key=value") + } + meta.oauthExtensions[splittedExtension[0]] = splittedExtension[1] } } - if enableTLS { - return parseTLS(config, meta) + return nil +} + +func parseSaslOAuthAWSMSKIAMParams(config *scalersconfig.ScalerConfig, meta *kafkaMetadata) error { + if !meta.enableTLS { + return errors.New("TLS is required for AWS MSK authentication") + } + + if config.TriggerMetadata["awsRegion"] == "" { + return errors.New("no awsRegion given") + } + + meta.awsRegion = config.TriggerMetadata["awsRegion"] + + auth, err := awsutils.GetAwsAuthorization(config.TriggerUniqueKey, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv) + if err != nil { + return fmt.Errorf("error getting AWS authorization: %w", err) } + meta.awsAuthorization = auth return nil } @@ -312,26 +424,6 @@ func parseSaslParams(config *scalersconfig.ScalerConfig, meta *kafkaMetadata, mo meta.password = strings.TrimSpace(config.AuthParams["password"]) meta.saslType = mode - if mode == KafkaSASLTypeOAuthbearer { - meta.scopes = strings.Split(config.AuthParams["scopes"], ",") - - if config.AuthParams["oauthTokenEndpointUri"] == "" { - return errors.New("no oauth token endpoint uri given") - } - meta.oauthTokenEndpointURI = strings.TrimSpace(config.AuthParams["oauthTokenEndpointUri"]) - - meta.oauthExtensions = make(map[string]string) - oauthExtensionsRaw := config.AuthParams["oauthExtensions"] - if oauthExtensionsRaw != "" { - for _, extension := range strings.Split(oauthExtensionsRaw, ",") { - splittedExtension := strings.Split(extension, "=") - if len(splittedExtension) != 2 { - return errors.New("invalid OAuthBearer extension, must be of format key=value") - } - meta.oauthExtensions[splittedExtension[0]] = splittedExtension[1] - } - } - } return nil } @@ -505,7 +597,29 @@ func parseKafkaMetadata(config *scalersconfig.ScalerConfig, logger logr.Logger) return meta, nil } -func getKafkaClients(metadata kafkaMetadata) (sarama.Client, sarama.ClusterAdmin, error) { +func getKafkaClients(ctx context.Context, metadata kafkaMetadata) (sarama.Client, sarama.ClusterAdmin, error) { + config, err := getKafkaClientConfig(ctx, metadata) + if err != nil { + return nil, nil, fmt.Errorf("error getting kafka client config: %w", err) + } + + client, err := sarama.NewClient(metadata.bootstrapServers, config) + if err != nil { + return nil, nil, fmt.Errorf("error creating kafka client: %w", err) + } + + admin, err := sarama.NewClusterAdminFromClient(client) + if err != nil { + if !client.Closed() { + client.Close() + } + return nil, nil, fmt.Errorf("error creating kafka admin: %w", err) + } + + return client, admin, nil +} + +func getKafkaClientConfig(ctx context.Context, metadata kafkaMetadata) (*sarama.Config, error) { config := sarama.NewConfig() config.Version = metadata.version @@ -519,7 +633,7 @@ func getKafkaClients(metadata kafkaMetadata) (sarama.Client, sarama.ClusterAdmin config.Net.TLS.Enable = true tlsConfig, err := kedautil.NewTLSConfigWithPassword(metadata.cert, metadata.key, metadata.keyPassword, metadata.ca, metadata.unsafeSsl) if err != nil { - return nil, nil, err + return nil, err } config.Net.TLS.Config = tlsConfig } @@ -540,7 +654,19 @@ func getKafkaClients(metadata kafkaMetadata) (sarama.Client, sarama.ClusterAdmin if metadata.saslType == KafkaSASLTypeOAuthbearer { config.Net.SASL.Mechanism = sarama.SASLTypeOAuth - config.Net.SASL.TokenProvider = kafka.OAuthBearerTokenProvider(metadata.username, metadata.password, metadata.oauthTokenEndpointURI, metadata.scopes, metadata.oauthExtensions) + switch metadata.tokenProvider { + case KafkaSASLOAuthTokenProviderBearer: + config.Net.SASL.TokenProvider = kafka.OAuthBearerTokenProvider(metadata.username, metadata.password, metadata.oauthTokenEndpointURI, metadata.scopes, metadata.oauthExtensions) + case KafkaSASLOAuthTokenProviderAWSMSKIAM: + awsAuth, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization) + if err != nil { + return nil, fmt.Errorf("error getting AWS config: %w", err) + } + + config.Net.SASL.TokenProvider = kafka.OAuthMSKTokenProvider(awsAuth) + default: + return nil, fmt.Errorf("err SASL OAuth token provider %s given but not supported", metadata.tokenProvider) + } } if metadata.saslType == KafkaSASLTypeGSSAPI { @@ -562,21 +688,7 @@ func getKafkaClients(metadata kafkaMetadata) (sarama.Client, sarama.ClusterAdmin config.Net.SASL.GSSAPI.Password = metadata.password } } - - client, err := sarama.NewClient(metadata.bootstrapServers, config) - if err != nil { - return nil, nil, fmt.Errorf("error creating kafka client: %w", err) - } - - admin, err := sarama.NewClusterAdminFromClient(client) - if err != nil { - if !client.Closed() { - client.Close() - } - return nil, nil, fmt.Errorf("error creating kafka admin: %w", err) - } - - return client, admin, nil + return config, nil } func (s *kafkaScaler) getTopicPartitions() (map[string][]int32, error) { diff --git a/pkg/scalers/kafka_scaler_test.go b/pkg/scalers/kafka_scaler_test.go index 81ebc746443..57a3f95eba9 100644 --- a/pkg/scalers/kafka_scaler_test.go +++ b/pkg/scalers/kafka_scaler_test.go @@ -12,6 +12,7 @@ import ( "github.com/IBM/sarama" "github.com/go-logr/logr" + kafka_oauth "github.com/kedacore/keda/v2/pkg/scalers/kafka" "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" ) @@ -35,6 +36,13 @@ type parseKafkaAuthParamsTestData struct { enableTLS bool } +type parseKafkaOAuthbearerAuthParamsTestData = struct { + metadata map[string]string + authParams map[string]string + isError bool + enableTLS bool +} + // Testing the case where `tls` and `sasl` are specified in ScaledObject type parseAuthParamsTestDataSecondAuthMethod struct { metadata map[string]string @@ -295,23 +303,35 @@ var parseAuthParamsTestDataset = []parseAuthParamsTestDataSecondAuthMethod{ {map[string]string{"bootstrapServers": "foobar:9092", "consumerGroup": "my-group", "topic": "my-topic", "allowIdleConsumers": "true", "version": "1.0.0"}, map[string]string{"sasl": "scram_sha512 ", "username": "admin", "password": "admin"}, false, true}, } -var parseKafkaOAuthbrearerAuthParamsTestDataset = []parseKafkaAuthParamsTestData{ +var parseKafkaOAuthbearerAuthParamsTestDataset = []parseKafkaOAuthbearerAuthParamsTestData{ // success, SASL OAUTHBEARER + TLS - {map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, false, false}, + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, false, false}, + // success, SASL OAUTHBEARER + tokenProvider + TLS + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "saslTokenProvider": "bearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, false, false}, // success, SASL OAUTHBEARER + TLS multiple scopes - {map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope1, scope2", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, false, false}, + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope1, scope2", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, false, false}, // success, SASL OAUTHBEARER + TLS missing scope - {map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, false, false}, + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, false, false}, // failure, SASL OAUTHBEARER + TLS bad sasl type - {map[string]string{"sasl": "foo", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, true, false}, + {map[string]string{}, map[string]string{"sasl": "foo", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable"}, true, false}, // failure, SASL OAUTHBEARER + TLS missing oauthTokenEndpointUri - {map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "", "tls": "disable"}, true, false}, + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "", "tls": "disable"}, true, false}, // success, SASL OAUTHBEARER + extension - {map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable", "oauthExtensions": "extension_foo=bar"}, false, false}, + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable", "oauthExtensions": "extension_foo=bar"}, false, false}, // success, SASL OAUTHBEARER + multiple extensions - {map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable", "oauthExtensions": "extension_foo=bar,extension_baz=baz"}, false, false}, + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable", "oauthExtensions": "extension_foo=bar,extension_baz=baz"}, false, false}, // failure, SASL OAUTHBEARER + bad extension - {map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable", "oauthExtensions": "extension_foo=bar,extension_bazbaz"}, true, false}, + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "scopes": "scope", "oauthTokenEndpointUri": "https://website.com", "tls": "disable", "oauthExtensions": "extension_foo=bar,extension_bazbaz"}, true, false}, + // success, SASL OAUTHBEARER MSK + TLS + Credentials + {map[string]string{"awsRegion": "eu-west-1"}, map[string]string{"sasl": "oauthbearer", "saslTokenProvider": "aws_msk_iam", "tls": "enable", "awsAccessKeyID": "none", "awsSecretAccessKey": "none"}, false, true}, + // success, SASL OAUTHBEARER MSK + TLS + Role + {map[string]string{"awsRegion": "eu-west-1"}, map[string]string{"sasl": "oauthbearer", "saslTokenProvider": "aws_msk_iam", "tls": "enable", "awsRegion": "eu-west-1", "awsRoleArn": "none"}, false, true}, + // failure, SASL OAUTHBEARER MSK + no TLS + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "saslTokenProvider": "aws_msk_iam", "tls": "disable"}, true, false}, + // failure, SASL OAUTHBEARER MSK + TLS + no region + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "saslTokenProvider": "aws_msk_iam", "tls": "enable", "awsRegion": ""}, true, true}, + // failure, SASL OAUTHBEARER MSK + TLS + no credentials + {map[string]string{}, map[string]string{"sasl": "oauthbearer", "saslTokenProvider": "aws_msk_iam", "tls": "enable", "awsRegion": "eu-west-1"}, true, true}, } var kafkaMetricIdentifiers = []kafkaMetricIdentifier{ @@ -384,7 +404,7 @@ func TestKafkaAuthParamsInTriggerAuthentication(t *testing.T) { if testData.isError && err == nil { t.Error("Expected error but got success") } - if meta.enableTLS != testData.enableTLS { + if !testData.isError && meta.enableTLS != testData.enableTLS { t.Errorf("Expected enableTLS to be set to %v but got %v\n", testData.enableTLS, meta.enableTLS) } if meta.enableTLS { @@ -483,29 +503,100 @@ func testFileContents(testData parseKafkaAuthParamsTestData, meta kafkaMetadata, return nil } -func TestKafkaOAuthbrearerAuthParams(t *testing.T) { - for _, testData := range parseKafkaOAuthbrearerAuthParamsTestDataset { - meta, err := parseKafkaMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: validKafkaMetadata, AuthParams: testData.authParams}, logr.Discard()) +func TestKafkaOAuthbearerAuthParams(t *testing.T) { + for _, testData := range parseKafkaOAuthbearerAuthParamsTestDataset { + for k, v := range validKafkaMetadata { + testData.metadata[k] = v + } + + meta, err := parseKafkaMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, AuthParams: testData.authParams}, logr.Discard()) if err != nil && !testData.isError { - t.Error("Expected success but got error", err) + t.Fatal("Expected success but got error", err) } if testData.isError && err == nil { - t.Error("Expected error but got success") + t.Fatal("Expected error but got success") } - if testData.authParams["scopes"] == "" { - if len(meta.scopes) != strings.Count(testData.authParams["scopes"], ",")+1 { - t.Errorf("Expected scopes to be set to %v but got %v\n", strings.Count(testData.authParams["scopes"], ","), len(meta.scopes)) + + if testData.authParams["saslTokenProvider"] == "" || testData.authParams["saslTokenProvider"] == "bearer" { + if !testData.isError && meta.tokenProvider != KafkaSASLOAuthTokenProviderBearer { + t.Errorf("Expected tokenProvider to be set to %v but got %v\n", KafkaSASLOAuthTokenProviderBearer, meta.tokenProvider) } - } - if err == nil && testData.authParams["oauthExtensions"] != "" { - if len(meta.oauthExtensions) != strings.Count(testData.authParams["oauthExtensions"], ",")+1 { - t.Errorf("Expected number of extensions to be set to %v but got %v\n", strings.Count(testData.authParams["oauthExtensions"], ",")+1, len(meta.oauthExtensions)) + + if testData.authParams["scopes"] == "" { + if len(meta.scopes) != strings.Count(testData.authParams["scopes"], ",")+1 { + t.Errorf("Expected scopes to be set to %v but got %v\n", strings.Count(testData.authParams["scopes"], ","), len(meta.scopes)) + } + } + + if err == nil && testData.authParams["oauthExtensions"] != "" { + if len(meta.oauthExtensions) != strings.Count(testData.authParams["oauthExtensions"], ",")+1 { + t.Errorf("Expected number of extensions to be set to %v but got %v\n", strings.Count(testData.authParams["oauthExtensions"], ",")+1, len(meta.oauthExtensions)) + } + } + } else if testData.authParams["saslTokenProvider"] == "aws_msk_iam" { + if !testData.isError && meta.tokenProvider != KafkaSASLOAuthTokenProviderAWSMSKIAM { + t.Errorf("Expected tokenProvider to be set to %v but got %v\n", KafkaSASLOAuthTokenProviderAWSMSKIAM, meta.tokenProvider) + } + + if testData.metadata["awsRegion"] != "" && meta.awsRegion != testData.metadata["awsRegion"] { + t.Errorf("Expected awsRegion to be set to %v but got %v\n", testData.metadata["awsRegion"], meta.awsRegion) + } + + if testData.authParams["awsAccessKeyID"] != "" { + if meta.awsAuthorization.AwsAccessKeyID != testData.authParams["awsAccessKeyID"] { + t.Errorf("Expected awsAccessKeyID to be set to %v but got %v\n", testData.authParams["awsAccessKeyID"], meta.awsAuthorization.AwsAccessKeyID) + } + + if meta.awsAuthorization.AwsSecretAccessKey != testData.authParams["awsSecretAccessKey"] { + t.Errorf("Expected awsSecretAccessKey to be set to %v but got %v\n", testData.authParams["awsSecretAccessKey"], meta.awsAuthorization.AwsSecretAccessKey) + } + } else if testData.authParams["awsRoleArn"] != "" && meta.awsAuthorization.AwsRoleArn != testData.authParams["awsRoleArn"] { + t.Errorf("Expected awsRoleArn to be set to %v but got %v\n", testData.authParams["awsRoleArn"], meta.awsAuthorization.AwsRoleArn) } } } } +func TestKafkaClientsOAuthTokenProvider(t *testing.T) { + testData := []struct { + name string + metadata map[string]string + authParams map[string]string + expectedTokenProvider string + }{ + {"oauthbearer_bearer", map[string]string{"bootstrapServers": "foobar:9092", "consumerGroup": "my-group", "topic": "my-topic", "partitionLimitation": "1,2"}, map[string]string{"sasl": "oauthbearer", "username": "admin", "password": "admin", "oauthTokenEndpointUri": "https://website.com"}, "OAuthBearer"}, + {"oauthbearer_aws_msk_iam", map[string]string{"bootstrapServers": "foobar:9092", "consumerGroup": "my-group", "topic": "my-topic", "partitionLimitation": "1,2", "tls": "enable", "awsRegion": "eu-west-1"}, map[string]string{"sasl": "oauthbearer", "saslTokenProvider": "aws_msk_iam", "awsRegion": "eu-west-1", "awsAccessKeyID": "none", "awsSecretAccessKey": "none"}, "MSK"}, + } + + for _, tt := range testData { + t.Run(tt.name, func(t *testing.T) { + meta, err := parseKafkaMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: tt.metadata, AuthParams: tt.authParams}, logr.Discard()) + if err != nil { + t.Fatal("Could not parse metadata:", err) + } + + cfg, err := getKafkaClientConfig(context.TODO(), meta) + if err != nil { + t.Error("Expected success but got error", err) + } + + if !cfg.Net.SASL.Enable { + t.Error("Expected SASL to be enabled on client") + } + + tokenProvider, ok := cfg.Net.SASL.TokenProvider.(kafka_oauth.TokenProvider) + if !ok { + t.Error("Expected token provider to be set on client") + } + + if tokenProvider.String() != tt.expectedTokenProvider { + t.Errorf("Expected token provider to be %v but got %v", tt.expectedTokenProvider, tokenProvider.String()) + } + }) + } +} + func TestKafkaGetMetricSpecForScaling(t *testing.T) { for _, testData := range kafkaMetricIdentifiers { meta, err := parseKafkaMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, AuthParams: validWithAuthParams, TriggerIndex: testData.triggerIndex}, logr.Discard()) diff --git a/pkg/scaling/scalers_builder.go b/pkg/scaling/scalers_builder.go index afcfc574bb1..3c328930d90 100644 --- a/pkg/scaling/scalers_builder.go +++ b/pkg/scaling/scalers_builder.go @@ -198,7 +198,7 @@ func buildScaler(ctx context.Context, client client.Client, triggerType string, case "influxdb": return scalers.NewInfluxDBScaler(config) case "kafka": - return scalers.NewKafkaScaler(config) + return scalers.NewKafkaScaler(ctx, config) case "kubernetes-workload": return scalers.NewKubernetesWorkloadScaler(client, config) case "liiklus": diff --git a/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/LICENSE b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/LICENSE new file mode 100644 index 00000000000..67db8588217 --- /dev/null +++ b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. diff --git a/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/NOTICE b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/NOTICE new file mode 100644 index 00000000000..616fc588945 --- /dev/null +++ b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/NOTICE @@ -0,0 +1 @@ +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/msk_auth_token_provider.go b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/msk_auth_token_provider.go new file mode 100644 index 00000000000..1c2670db2fa --- /dev/null +++ b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/msk_auth_token_provider.go @@ -0,0 +1,305 @@ +package signer + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + + "log" + "net/http" + "net/url" + "runtime" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +const ( + ActionType = "Action" // ActionType represents the key for the action type in the request. + ActionName = "kafka-cluster:Connect" // ActionName represents the specific action name for connecting to a Kafka cluster. + SigningName = "kafka-cluster" // SigningName represents the signing name for the Kafka cluster. + UserAgentKey = "User-Agent" // UserAgentKey represents the key for the User-Agent parameter in the request. + LibName = "aws-msk-iam-sasl-signer-go" // LibName represents the name of the library. + ExpiresQueryKey = "X-Amz-Expires" // ExpiresQueryKey represents the key for the expiration time in the query parameters. + DefaultSessionName = "MSKSASLDefaultSession" // DefaultSessionName represents the default session name for assuming a role. + DefaultExpirySeconds = 900 // DefaultExpirySeconds represents the default expiration time in seconds. +) + +var ( + endpointURLTemplate = "kafka.%s.amazonaws.com" // endpointURLTemplate represents the template for the Kafka endpoint URL + AwsDebugCreds = false // AwsDebugCreds flag indicates whether credentials should be debugged +) + +// GenerateAuthToken generates base64 encoded signed url as auth token from default credentials. +// Loads the IAM credentials from default credentials provider chain. +func GenerateAuthToken(ctx context.Context, region string) (string, int64, error) { + credentials, err := loadDefaultCredentials(ctx, region) + + if err != nil { + return "", 0, fmt.Errorf("failed to load credentials: %w", err) + } + + return constructAuthToken(ctx, region, credentials) +} + +// GenerateAuthTokenFromProfile generates base64 encoded signed url as auth token by loading IAM credentials from an AWS named profile. +func GenerateAuthTokenFromProfile(ctx context.Context, region string, awsProfile string) (string, int64, error) { + credentials, err := loadCredentialsFromProfile(ctx, region, awsProfile) + + if err != nil { + return "", 0, fmt.Errorf("failed to load credentials: %w", err) + } + + return constructAuthToken(ctx, region, credentials) +} + +// GenerateAuthTokenFromRole generates base64 encoded signed url as auth token by loading IAM credentials from an aws role Arn +func GenerateAuthTokenFromRole( + ctx context.Context, region string, roleArn string, stsSessionName string, +) (string, int64, error) { + if stsSessionName == "" { + stsSessionName = DefaultSessionName + } + credentials, err := loadCredentialsFromRoleArn(ctx, region, roleArn, stsSessionName) + + if err != nil { + return "", 0, fmt.Errorf("failed to load credentials: %w", err) + } + + return constructAuthToken(ctx, region, credentials) +} + +// GenerateAuthTokenFromCredentialsProvider generates base64 encoded signed url as auth token by loading IAM credentials +// from an aws credentials provider +func GenerateAuthTokenFromCredentialsProvider( + ctx context.Context, region string, credentialsProvider aws.CredentialsProvider, +) (string, int64, error) { + credentials, err := loadCredentialsFromCredentialsProvider(ctx, credentialsProvider) + + if err != nil { + return "", 0, fmt.Errorf("failed to load credentials: %w", err) + } + + return constructAuthToken(ctx, region, credentials) +} + +// Loads credentials from the default credential chain. +func loadDefaultCredentials(ctx context.Context, region string) (*aws.Credentials, error) { + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + + if err != nil { + return nil, fmt.Errorf("unable to load SDK config: %w", err) + } + + return loadCredentialsFromCredentialsProvider(ctx, cfg.Credentials) +} + +// Loads credentials from a named aws profile. +func loadCredentialsFromProfile(ctx context.Context, region string, awsProfile string) (*aws.Credentials, error) { + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(region), + config.WithSharedConfigProfile(awsProfile), + ) + + if err != nil { + return nil, fmt.Errorf("unable to load SDK config: %w", err) + } + + return loadCredentialsFromCredentialsProvider(ctx, cfg.Credentials) +} + +// Loads credentials from a named by assuming the passed role. +// This implementation creates a new sts client for every call to get or refresh token. In order to avoid this, please +// use your own credentials provider. +// If you wish to use regional endpoint, please pass your own credentials provider. +func loadCredentialsFromRoleArn( + ctx context.Context, region string, roleArn string, stsSessionName string, +) (*aws.Credentials, error) { + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + + if err != nil { + return nil, fmt.Errorf("unable to load SDK config: %w", err) + } + + stsClient := sts.NewFromConfig(cfg) + + assumeRoleInput := &sts.AssumeRoleInput{ + RoleArn: aws.String(roleArn), + RoleSessionName: aws.String(stsSessionName), + } + assumeRoleOutput, err := stsClient.AssumeRole(ctx, assumeRoleInput) + if err != nil { + return nil, fmt.Errorf("unable to assume role, %s: %w", roleArn, err) + } + + //Create new aws.Credentials instance using the credentials from AssumeRoleOutput.Credentials + creds := aws.Credentials{ + AccessKeyID: *assumeRoleOutput.Credentials.AccessKeyId, + SecretAccessKey: *assumeRoleOutput.Credentials.SecretAccessKey, + SessionToken: *assumeRoleOutput.Credentials.SessionToken, + } + + return &creds, nil +} + +// Loads credentials from the credentials provider +func loadCredentialsFromCredentialsProvider( + ctx context.Context, credentialsProvider aws.CredentialsProvider, +) (*aws.Credentials, error) { + creds, err := credentialsProvider.Retrieve(ctx) + return &creds, err +} + +// Constructs Auth Token. +func constructAuthToken(ctx context.Context, region string, credentials *aws.Credentials) (string, int64, error) { + endpointURL := fmt.Sprintf(endpointURLTemplate, region) + + if credentials == nil || credentials.AccessKeyID == "" || credentials.SecretAccessKey == "" { + return "", 0, fmt.Errorf("aws credentials cannot be empty") + } + + if AwsDebugCreds { + logCallerIdentity(ctx, region, *credentials) + } + + req, err := buildRequest(DefaultExpirySeconds, endpointURL) + if err != nil { + return "", 0, fmt.Errorf("failed to build request for signing: %w", err) + } + + signedURL, err := signRequest(ctx, req, region, credentials) + if err != nil { + return "", 0, fmt.Errorf("failed to sign request with aws sig v4: %w", err) + } + + expirationTimeMs, err := getExpirationTimeMs(signedURL) + if err != nil { + return "", 0, fmt.Errorf("failed to extract expiration from signed url: %w", err) + } + + signedURLWithUserAgent, err := addUserAgent(signedURL) + if err != nil { + return "", 0, fmt.Errorf("failed to add user agent to the signed url: %w", err) + } + + return base64Encode(signedURLWithUserAgent), expirationTimeMs, nil +} + +// Build https request with query parameters in order to sign. +func buildRequest(expirySeconds int, endpointURL string) (*http.Request, error) { + query := url.Values{ + ActionType: {ActionName}, + ExpiresQueryKey: {strconv.FormatInt(int64(expirySeconds), 10)}, + } + + authURL := url.URL{ + Host: endpointURL, + Scheme: "https", + Path: "/", + RawQuery: query.Encode(), + } + + return http.NewRequest(http.MethodGet, authURL.String(), nil) +} + +// Sign request with aws sig v4. +func signRequest(ctx context.Context, req *http.Request, region string, credentials *aws.Credentials) (string, error) { + signer := v4.NewSigner() + signedURL, _, err := signer.PresignHTTP(ctx, *credentials, req, + calculateSHA256Hash(""), + SigningName, + region, + time.Now().UTC(), + ) + + return signedURL, err +} + +// Parses the URL and gets the expiration time in millis associated with the signed url +func getExpirationTimeMs(signedURL string) (int64, error) { + parsedURL, err := url.Parse(signedURL) + + if err != nil { + return 0, fmt.Errorf("failed to parse the signed url: %w", err) + } + + params := parsedURL.Query() + date, err := time.Parse("20060102T150405Z", params.Get("X-Amz-Date")) + + if err != nil { + return 0, fmt.Errorf("failed to parse the 'X-Amz-Date' param from signed url: %w", err) + } + + signingTimeMs := date.UnixNano() / int64(time.Millisecond) + expiryDurationSeconds, err := strconv.ParseInt(params.Get("X-Amz-Expires"), 10, 64) + + if err != nil { + return 0, fmt.Errorf("failed to parse the 'X-Amz-Expires' param from signed url: %w", err) + } + + expiryDurationMs := expiryDurationSeconds * 1000 + expiryMs := signingTimeMs + expiryDurationMs + return expiryMs, nil +} + +// Calculate sha256Hash and hex encode it. +func calculateSHA256Hash(input string) string { + hash := sha256.Sum256([]byte(input)) + return hex.EncodeToString(hash[:]) +} + +// Base64 encode with raw url encoding. +func base64Encode(signedURL string) string { + signedURLBytes := []byte(signedURL) + return base64.RawURLEncoding.EncodeToString(signedURLBytes) +} + +// Add user agent to the signed url +func addUserAgent(signedURL string) (string, error) { + parsedSignedURL, err := url.Parse(signedURL) + + if err != nil { + return "", fmt.Errorf("failed to parse signed url: %w", err) + } + + query := parsedSignedURL.Query() + userAgent := strings.Join([]string{LibName, version, runtime.Version()}, "/") + query.Set(UserAgentKey, userAgent) + parsedSignedURL.RawQuery = query.Encode() + + return parsedSignedURL.String(), nil +} + +// Log caller identity to debug which credentials are being picked up +func logCallerIdentity(ctx context.Context, region string, awsCredentials aws.Credentials) { + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(region), + config.WithCredentialsProvider(credentials.StaticCredentialsProvider{ + Value: awsCredentials, + }), + ) + if err != nil { + log.Printf("failed to load AWS configuration: %v", err) + } + + stsClient := sts.NewFromConfig(cfg) + + callerIdentity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + + if err != nil { + log.Printf("failed to get caller identity: %v", err) + } + + log.Printf("Credentials Identity: {UserId: %s, Account: %s, Arn: %s}\n", + *callerIdentity.UserId, + *callerIdentity.Account, + *callerIdentity.Arn) +} diff --git a/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/version.go b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/version.go new file mode 100644 index 00000000000..d723e9085de --- /dev/null +++ b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/version.go @@ -0,0 +1,3 @@ +package signer + +const version = "1.0.0" diff --git a/vendor/modules.txt b/vendor/modules.txt index 807ec84f06f..fd269297d6f 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -297,6 +297,9 @@ github.com/arangodb/go-velocypack # github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 ## explicit; go 1.13 github.com/asaskevich/govalidator +# github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0 +## explicit; go 1.17 +github.com/aws/aws-msk-iam-sasl-signer-go/signer # github.com/aws/aws-sdk-go-v2 v1.26.1 ## explicit; go 1.20 github.com/aws/aws-sdk-go-v2/aws