diff --git a/go.mod b/go.mod
index bbe78280458..dc99abcc5ee 100644
--- a/go.mod
+++ b/go.mod
@@ -11,11 +11,12 @@ require (
 	github.com/alexliesenfeld/health v0.8.1
 	github.com/appleboy/gin-jwt/v2 v2.10.3
 	github.com/aws/aws-lambda-go v1.47.0
-	github.com/aws/aws-sdk-go v1.52.0
 	github.com/aws/aws-sdk-go-v2 v1.38.3
 	github.com/aws/aws-sdk-go-v2/config v1.31.6
+	github.com/aws/aws-sdk-go-v2/credentials v1.18.10 // indirect
 	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.4
 	github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.57.2
+	github.com/aws/aws-sdk-go-v2/service/kinesis v1.40.1
 	github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
 	github.com/aws/aws-sdk-go-v2/service/sqs v1.42.3
 	github.com/beevik/etree v1.4.1
@@ -107,7 +108,6 @@ require (
 	github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
 	github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect
-	github.com/aws/aws-sdk-go-v2/credentials v1.18.10 // indirect
 	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.6 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.6 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.6 // indirect
@@ -164,7 +164,6 @@ require (
 	github.com/jackc/pgproto3/v2 v2.3.3 // indirect
 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
 	github.com/jackc/pgtype v1.14.0 // indirect
-	github.com/jmespath/go-jmespath v0.4.0 // indirect
 	github.com/josharian/intern v1.0.0 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
diff --git a/go.sum b/go.sum
index 9c00be29cc9..b56bbd573fd 100644
--- a/go.sum
+++ b/go.sum
@@ -46,8 +46,6 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d
 github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
 github.com/aws/aws-lambda-go v1.47.0 h1:0H8s0vumYx/YKs4sE7YM0ktwL2eWse+kfopsRI1sXVI=
 github.com/aws/aws-lambda-go v1.47.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A=
-github.com/aws/aws-sdk-go v1.52.0 h1:ptgek/4B2v/ljsjYSEvLQ8LTD+SQyrqhOOWvHc/VGPI=
-github.com/aws/aws-sdk-go v1.52.0/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
 github.com/aws/aws-sdk-go-v2 v1.38.3 h1:B6cV4oxnMs45fql4yRH+/Po/YU+597zgWqvDpYMturk=
 github.com/aws/aws-sdk-go-v2 v1.38.3/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 h1:i8p8P4diljCr60PpJp6qZXNlgX4m2yQFpYk+9ZT+J4E=
@@ -78,6 +76,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6 h1:LHS1YAIJX
 github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.6/go.mod h1:c9PCiTEuh0wQID5/KqA32J+HAgZxN9tOGXKCiYJjTZI=
 github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6 h1:nEXUSAwyUfLTgnc9cxlDWy637qsq4UWwp3sNAfl0Z3Y=
 github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.6/go.mod h1:HGzIULx4Ge3Do2V0FaiYKcyKzOqwrhUZgCI77NisswQ=
+github.com/aws/aws-sdk-go-v2/service/kinesis v1.40.1 h1:9QC0AF6gakV1TZuGp3NEUNl/6gXt3rfIifnxd+dWwbw=
+github.com/aws/aws-sdk-go-v2/service/kinesis v1.40.1/go.mod h1:UpSQbmXxFiDGDrvqsTgEm3YijDf9cg/Ti+s2W0SeFEU=
 github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 h1:ETkfWcXP2KNPLecaDa++5bsQhCRa5M5sLUJa5DWYIIg=
 github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3/go.mod h1:+/3ZTqoYb3Ur7DObD00tarKMLMuKg8iqz5CHEanqTnw=
 github.com/aws/aws-sdk-go-v2/service/sqs v1.42.3 h1:0dWg1Tkz3FnEo48DgAh7CT22hYyMShly8WMd3sGx0xI=
@@ -362,10 +362,6 @@ github.com/jedib0t/go-pretty/v6 v6.6.7 h1:m+LbHpm0aIAPLzLbMfn8dc3Ht8MW7lsSO4MPIt
 github.com/jedib0t/go-pretty/v6 v6.6.7/go.mod h1:YwC5CE4fJ1HFUDeivSV1r//AmANFHyqczZk+U6BDALU=
 github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgfCL6c=
 github.com/jhump/protoreflect v1.15.1/go.mod h1:jD/2GMKKE6OqX8qTjhADU1e6DShO+gavG9e0Q693nKo=
-github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
-github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
-github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
-github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
 github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
 github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
diff --git a/pkg/acquisition/modules/kinesis/creds.go b/pkg/acquisition/modules/kinesis/creds.go
new file mode 100644
index 00000000000..a639d239390
--- /dev/null
+++ b/pkg/acquisition/modules/kinesis/creds.go
@@ -0,0 +1,9 @@
+//go:build !test
+
+package kinesisacquisition
+
+import "github.com/aws/aws-sdk-go-v2/aws"
+
+func defaultCreds() aws.CredentialsProvider {
+	return nil
+}
diff --git a/pkg/acquisition/modules/kinesis/creds_test.go b/pkg/acquisition/modules/kinesis/creds_test.go
new file mode 100644
index 00000000000..c4ea9d78bb5
--- /dev/null
+++ b/pkg/acquisition/modules/kinesis/creds_test.go
@@ -0,0 +1,9 @@
+//go:build test
+
+package kinesisacquisition
+
+import "github.com/aws/aws-sdk-go-v2/aws"
+
+func defaultCreds() aws.CredentialsProvider {
+	return aws.AnonymousCredentials{}
+}
diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go
index 728fabaeed9..9f20443eba3 100644
--- a/pkg/acquisition/modules/kinesis/kinesis.go
+++ b/pkg/acquisition/modules/kinesis/kinesis.go
@@ -11,10 +11,12 @@ import (
 	"strings"
 	"time"
 
-	"github.com/aws/aws-sdk-go/aws"
-	"github.com/aws/aws-sdk-go/aws/arn"
-	"github.com/aws/aws-sdk-go/aws/session"
-	"github.com/aws/aws-sdk-go/service/kinesis"
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/aws/arn"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/service/kinesis"
+	kinTypes "github.com/aws/aws-sdk-go-v2/service/kinesis/types"
+
 	yaml "github.com/goccy/go-yaml"
 	"github.com/prometheus/client_golang/prometheus"
 	log "github.com/sirupsen/logrus"
@@ -45,7 +47,7 @@ type KinesisSource struct {
 	metricsLevel    metrics.AcquisitionMetricsLevel
 	Config          KinesisConfiguration
 	logger          *log.Entry
-	kClient         *kinesis.Kinesis
+	kClient         *kinesis.Client
 	shardReaderTomb *tomb.Tomb
 }
@@ -68,39 +70,37 @@ func (k *KinesisSource) GetUuid() string {
 	return k.Config.UniqueId
 }
 
-func (k *KinesisSource) newClient() error {
-	var sess *session.Session
-
-	if k.Config.AwsProfile != nil {
-		sess = session.Must(session.NewSessionWithOptions(session.Options{
-			SharedConfigState: session.SharedConfigEnable,
-			Profile:           *k.Config.AwsProfile,
-		}))
-	} else {
-		sess = session.Must(session.NewSessionWithOptions(session.Options{
-			SharedConfigState: session.SharedConfigEnable,
-		}))
+func (k *KinesisSource) newClient(ctx context.Context) error {
+	var loadOpts []func(*config.LoadOptions) error
+	if k.Config.AwsProfile != nil && *k.Config.AwsProfile != "" {
+		loadOpts = append(loadOpts, config.WithSharedConfigProfile(*k.Config.AwsProfile))
 	}
 
-	if sess == nil {
-		return errors.New("failed to create aws session")
+	region := k.Config.AwsRegion
+	if region == "" {
+		region = "us-east-1"
 	}
 
-	config := aws.NewConfig()
+	loadOpts = append(loadOpts, config.WithRegion(region))
 
-	if k.Config.AwsRegion != "" {
-		config = config.WithRegion(k.Config.AwsRegion)
+	if c := defaultCreds(); c != nil {
+		loadOpts = append(loadOpts, config.WithCredentialsProvider(c))
 	}
 
-	if k.Config.AwsEndpoint != "" {
-		config = config.WithEndpoint(k.Config.AwsEndpoint)
+	cfg, err := config.LoadDefaultConfig(ctx, loadOpts...)
+	if err != nil {
+		return fmt.Errorf("failed to load aws config: %w", err)
 	}
 
-	k.kClient = kinesis.New(sess, config)
-	if k.kClient == nil {
-		return errors.New("failed to create kinesis client")
+	var clientOpts []func(*kinesis.Options)
+	if k.Config.AwsEndpoint != "" {
+		clientOpts = append(clientOpts, func(o *kinesis.Options) {
+			o.BaseEndpoint = aws.String(k.Config.AwsEndpoint)
+		})
 	}
 
+	k.kClient = kinesis.NewFromConfig(cfg, clientOpts...)
+
 	return nil
 }
@@ -156,7 +156,7 @@ func (k *KinesisSource) Configure(yamlConfig []byte, logger *log.Entry, metricsL
 		return err
 	}
 
-	err = k.newClient()
+	err = k.newClient(context.TODO())
 	if err != nil {
 		return fmt.Errorf("cannot create kinesis client: %w", err)
 	}
@@ -200,15 +200,15 @@ func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubsc
 	return subscriptionRecord.LogEvents, nil
 }
 
-func (k *KinesisSource) WaitForConsumerDeregistration(consumerName string, streamARN string) error {
+func (k *KinesisSource) WaitForConsumerDeregistration(ctx context.Context, consumerName string, streamARN string) error {
 	maxTries := k.Config.MaxRetries
 
 	for i := range maxTries {
-		_, err := k.kClient.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{
-			ConsumerName: aws.String(consumerName),
-			StreamARN:    aws.String(streamARN),
-		})
+		_, err := k.kClient.DescribeStreamConsumer(ctx, &kinesis.DescribeStreamConsumerInput{
+			ConsumerName: aws.String(consumerName),
+			StreamARN:    aws.String(streamARN),
+		})
 
-		var resourceNotFoundErr *kinesis.ResourceNotFoundException
+		var resourceNotFoundErr *kinTypes.ResourceNotFoundException
 		if errors.As(err, &resourceNotFoundErr) {
 			return nil
 		}
@@ -224,14 +224,14 @@ func (k *KinesisSource) WaitForConsumerDeregistration(consumerName string, strea
 	return fmt.Errorf("consumer %s is not deregistered after %d tries", consumerName, maxTries)
 }
 
-func (k *KinesisSource) DeregisterConsumer() error {
+func (k *KinesisSource) DeregisterConsumer(ctx context.Context) error {
 	k.logger.Debugf("Deregistering consumer %s if it exists", k.Config.ConsumerName)
 
-	_, err := k.kClient.DeregisterStreamConsumer(&kinesis.DeregisterStreamConsumerInput{
-		ConsumerName: aws.String(k.Config.ConsumerName),
-		StreamARN:    aws.String(k.Config.StreamARN),
-	})
+	_, err := k.kClient.DeregisterStreamConsumer(ctx, &kinesis.DeregisterStreamConsumerInput{
+		ConsumerName: aws.String(k.Config.ConsumerName),
+		StreamARN:    aws.String(k.Config.StreamARN),
+	})
 
-	var resourceNotFoundErr *kinesis.ResourceNotFoundException
+	var resourceNotFoundErr *kinTypes.ResourceNotFoundException
 	if errors.As(err, &resourceNotFoundErr) {
 		return nil
 	}
@@ -240,7 +240,7 @@ func (k *KinesisSource) DeregisterConsumer() error {
 		return fmt.Errorf("cannot deregister stream consumer: %w", err)
 	}
 
-	err = k.WaitForConsumerDeregistration(k.Config.ConsumerName, k.Config.StreamARN)
+	err = k.WaitForConsumerDeregistration(ctx, k.Config.ConsumerName, k.Config.StreamARN)
 	if err != nil {
 		return fmt.Errorf("cannot wait for consumer deregistration: %w", err)
 	}
@@ -248,17 +248,17 @@ func (k *KinesisSource) DeregisterConsumer() error {
 	return nil
 }
 
-func (k *KinesisSource) WaitForConsumerRegistration(consumerARN string) error {
+func (k *KinesisSource) WaitForConsumerRegistration(ctx context.Context, consumerARN string) error {
 	maxTries := k.Config.MaxRetries
 
 	for i := range maxTries {
-		describeOutput, err := k.kClient.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{
-			ConsumerARN: aws.String(consumerARN),
-		})
+		describeOutput, err := k.kClient.DescribeStreamConsumer(ctx, &kinesis.DescribeStreamConsumerInput{
+			ConsumerARN: aws.String(consumerARN),
+		})
 		if err != nil {
 			return fmt.Errorf("cannot describe stream consumer: %w", err)
 		}
 
-		if *describeOutput.ConsumerDescription.ConsumerStatus == "ACTIVE" {
+		if describeOutput.ConsumerDescription.ConsumerStatus == "ACTIVE" {
 			k.logger.Debugf("Consumer %s is active", consumerARN)
 			return nil
 		}
@@ -270,18 +270,18 @@ func (k *KinesisSource) WaitForConsumerRegistration(consumerARN string) error {
 	return fmt.Errorf("consumer %s is not active after %d tries", consumerARN, maxTries)
 }
 
-func (k *KinesisSource) RegisterConsumer() (*kinesis.RegisterStreamConsumerOutput, error) {
+func (k *KinesisSource) RegisterConsumer(ctx context.Context) (*kinesis.RegisterStreamConsumerOutput, error) {
 	k.logger.Debugf("Registering consumer %s", k.Config.ConsumerName)
 
-	streamConsumer, err := k.kClient.RegisterStreamConsumer(&kinesis.RegisterStreamConsumerInput{
-		ConsumerName: aws.String(k.Config.ConsumerName),
-		StreamARN:    aws.String(k.Config.StreamARN),
-	})
+	streamConsumer, err := k.kClient.RegisterStreamConsumer(ctx, &kinesis.RegisterStreamConsumerInput{
+		ConsumerName: aws.String(k.Config.ConsumerName),
+		StreamARN:    aws.String(k.Config.StreamARN),
+	})
 	if err != nil {
 		return nil, fmt.Errorf("cannot register stream consumer: %w", err)
 	}
 
-	err = k.WaitForConsumerRegistration(*streamConsumer.Consumer.ConsumerARN)
+	err = k.WaitForConsumerRegistration(ctx, *streamConsumer.Consumer.ConsumerARN)
 	if err != nil {
 		return nil, fmt.Errorf("timeout while waiting for consumer to be active: %w", err)
 	}
@@ -289,7 +289,7 @@ func (k *KinesisSource) RegisterConsumer() (*kinesis.RegisterStreamConsumerOutpu
 	return streamConsumer, nil
 }
 
-func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan types.Event, logger *log.Entry, shardID string) {
+func (k *KinesisSource) ParseAndPushRecords(records []kinTypes.Record, out chan types.Event, logger *log.Entry, shardID string) {
 	for _, record := range records {
 		if k.Config.StreamARN != "" {
 			if k.metricsLevel != metrics.AcquisitionMetricsLevelNone {
@@ -337,6 +337,7 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan
 
 		evt := types.MakeEvent(k.Config.UseTimeMachine, types.LOG, true)
 		evt.Line = l
+
 		out <- evt
 	}
 }
@@ -365,20 +366,20 @@ func (k *KinesisSource) ReadFromSubscription(reader kinesis.SubscribeToShardEven
 				return nil
 			}
 
-			switch event := event.(type) {
-			case *kinesis.SubscribeToShardEvent:
-				k.ParseAndPushRecords(event.Records, out, logger, shardID)
-			case *kinesis.SubscribeToShardEventStreamUnknownEvent:
-				logger.Infof("got an unknown event, what to do ?")
+			switch et := event.(type) {
+			case *kinTypes.SubscribeToShardEventStreamMemberSubscribeToShardEvent:
+				k.ParseAndPushRecords(et.Value.Records, out, logger, shardID)
+			default:
+				logger.Infof("unhandled SubscribeToShard event: %T", et)
 			}
 		}
 	}
 }
 
-func (k *KinesisSource) SubscribeToShards(arn arn.ARN, streamConsumer *kinesis.RegisterStreamConsumerOutput, out chan types.Event) error {
-	shards, err := k.kClient.ListShards(&kinesis.ListShardsInput{
-		StreamName: aws.String(arn.Resource[7:]),
-	})
+func (k *KinesisSource) SubscribeToShards(ctx context.Context, arn arn.ARN, streamConsumer *kinesis.RegisterStreamConsumerOutput, out chan types.Event) error {
+	shards, err := k.kClient.ListShards(ctx, &kinesis.ListShardsInput{
+		StreamName: aws.String(arn.Resource[7:]),
+	})
 	if err != nil {
 		return fmt.Errorf("cannot list shards for enhanced_read: %w", err)
 	}
@@ -386,24 +387,24 @@ func (k *KinesisSource) SubscribeToShards(arn arn.ARN, streamConsumer *kinesis.R
 	for _, shard := range shards.Shards {
 		shardID := *shard.ShardId
 
-		r, err := k.kClient.SubscribeToShard(&kinesis.SubscribeToShardInput{
-			ShardId:          aws.String(shardID),
-			StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)},
-			ConsumerARN:      streamConsumer.Consumer.ConsumerARN,
-		})
+		r, err := k.kClient.SubscribeToShard(ctx, &kinesis.SubscribeToShardInput{
+			ShardId:          aws.String(shardID),
+			StartingPosition: &kinTypes.StartingPosition{Type: kinTypes.ShardIteratorTypeLatest},
+			ConsumerARN:      streamConsumer.Consumer.ConsumerARN,
+		})
 		if err != nil {
 			return fmt.Errorf("cannot subscribe to shard: %w", err)
 		}
 
 		k.shardReaderTomb.Go(func() error {
-			return k.ReadFromSubscription(r.GetEventStream().Reader, out, shardID, arn.Resource[7:])
+			return k.ReadFromSubscription(r.GetStream().Reader, out, shardID, arn.Resource[7:])
 		})
 	}
 
 	return nil
 }
 
-func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error {
+func (k *KinesisSource) EnhancedRead(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
 	parsedARN, err := arn.Parse(k.Config.StreamARN)
 	if err != nil {
 		return fmt.Errorf("cannot parse stream ARN: %w", err)
@@ -416,12 +417,12 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error {
 	k.logger = k.logger.WithField("stream", parsedARN.Resource[7:])
 	k.logger.Info("starting kinesis acquisition with enhanced fan-out")
 
-	err = k.DeregisterConsumer()
+	err = k.DeregisterConsumer(ctx)
 	if err != nil {
 		return fmt.Errorf("cannot deregister consumer: %w", err)
 	}
 
-	streamConsumer, err := k.RegisterConsumer()
+	streamConsumer, err := k.RegisterConsumer(ctx)
 	if err != nil {
 		return fmt.Errorf("cannot register consumer: %w", err)
 	}
@@ -429,17 +430,18 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error {
 	for {
 		k.shardReaderTomb = &tomb.Tomb{}
 
-		err = k.SubscribeToShards(parsedARN, streamConsumer, out)
+		err = k.SubscribeToShards(ctx, parsedARN, streamConsumer, out)
 		if err != nil {
 			return fmt.Errorf("cannot subscribe to shards: %w", err)
 		}
+
 		select {
 		case <-t.Dying():
 			k.logger.Infof("Kinesis source is dying")
 			k.shardReaderTomb.Kill(nil)
 			_ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves
 
-			err = k.DeregisterConsumer()
+			err = k.DeregisterConsumer(ctx)
 			if err != nil {
 				return fmt.Errorf("cannot deregister consumer: %w", err)
 			}
@@ -459,15 +461,16 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error {
 	}
 }
 
-func (k *KinesisSource) ReadFromShard(out chan types.Event, shardID string) error {
+func (k *KinesisSource) ReadFromShard(ctx context.Context, out chan types.Event, shardID string) error {
 	logger := k.logger.WithField("shard", shardID)
 	logger.Debugf("Starting to read shard")
 
-	sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{
-		ShardId:           aws.String(shardID),
-		StreamName:        &k.Config.StreamName,
-		ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest),
-	})
+	sharIt, err := k.kClient.GetShardIterator(ctx,
+		&kinesis.GetShardIteratorInput{
+			ShardId:           aws.String(shardID),
+			StreamName:        &k.Config.StreamName,
+			ShardIteratorType: kinTypes.ShardIteratorTypeLatest,
+		})
 	if err != nil {
 		logger.Errorf("Cannot get shard iterator: %s", err)
 		return fmt.Errorf("cannot get shard iterator: %w", err)
@@ -480,16 +483,17 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardID string) erro
 	for {
 		select {
 		case <-ticker.C:
-			records, err := k.kClient.GetRecords(&kinesis.GetRecordsInput{ShardIterator: it})
+			records, err := k.kClient.GetRecords(ctx, &kinesis.GetRecordsInput{ShardIterator: it})
+
 			it = records.NextShardIterator
 
-			var throughputErr *kinesis.ProvisionedThroughputExceededException
+			var throughputErr *kinTypes.ProvisionedThroughputExceededException
 			if errors.As(err, &throughputErr) {
 				logger.Warn("Provisioned throughput exceeded")
 				// TODO: implement exponential backoff
 				continue
 			}
 
-			var expiredIteratorErr *kinesis.ExpiredIteratorException
+			var expiredIteratorErr *kinTypes.ExpiredIteratorException
 			if errors.As(err, &expiredIteratorErr) {
 				logger.Warn("Expired iterator")
 				continue
@@ -516,14 +520,14 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardID string) erro
 	}
 }
 
-func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error {
+func (k *KinesisSource) ReadFromStream(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
 	k.logger = k.logger.WithField("stream", k.Config.StreamName)
 	k.logger.Info("starting kinesis acquisition from shards")
 
 	for {
-		shards, err := k.kClient.ListShards(&kinesis.ListShardsInput{
-			StreamName: aws.String(k.Config.StreamName),
-		})
+		shards, err := k.kClient.ListShards(ctx, &kinesis.ListShardsInput{
+			StreamName: aws.String(k.Config.StreamName),
+		})
 		if err != nil {
 			return fmt.Errorf("cannot list shards: %w", err)
 		}
@@ -535,9 +539,10 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error
 			k.shardReaderTomb.Go(func() error {
 				defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming/shard")
 
-				return k.ReadFromShard(out, shardID)
+				return k.ReadFromShard(ctx, out, shardID)
 			})
 		}
+
 		select {
 		case <-t.Dying():
 			k.logger.Info("kinesis source is dying")
@@ -559,15 +564,15 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error
 	}
 }
 
-func (k *KinesisSource) StreamingAcquisition(_ context.Context, out chan types.Event, t *tomb.Tomb) error {
+func (k *KinesisSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
 	t.Go(func() error {
 		defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming")
 
 		if k.Config.UseEnhancedFanOut {
-			return k.EnhancedRead(out, t)
+			return k.EnhancedRead(ctx, out, t)
 		}
 
-		return k.ReadFromStream(out, t)
+		return k.ReadFromStream(ctx, out, t)
 	})
 
 	return nil
diff --git a/pkg/acquisition/modules/kinesis/kinesis_test.go b/pkg/acquisition/modules/kinesis/kinesis_test.go
index 91615a60326..4d7622d9d23 100644
--- a/pkg/acquisition/modules/kinesis/kinesis_test.go
+++ b/pkg/acquisition/modules/kinesis/kinesis_test.go
@@ -10,9 +10,10 @@ import (
 	"testing"
 	"time"
 
-	"github.com/aws/aws-sdk-go/aws"
-	"github.com/aws/aws-sdk-go/aws/session"
-	"github.com/aws/aws-sdk-go/service/kinesis"
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/service/kinesis"
+
 	log "github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -43,6 +44,7 @@ func GenSubObject(t *testing.T, i int) []byte {
 	require.NoError(t, err)
 
 	var b bytes.Buffer
+
 	gz := gzip.NewWriter(&b)
 	_, err = gz.Write(body)
 	require.NoError(t, err)
@@ -53,8 +55,14 @@
 }
 
 func WriteToStream(t *testing.T, endpoint string, streamName string, count int, shards int, sub bool) {
-	sess := session.Must(session.NewSession())
-	kinesisClient := kinesis.New(sess, aws.NewConfig().WithEndpoint(endpoint).WithRegion("us-east-1"))
+	ctx := t.Context()
+
+	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion("us-east-1"), config.WithCredentialsProvider(aws.AnonymousCredentials{}))
+	require.NoError(t, err)
+
+	kinesisClient := kinesis.NewFromConfig(cfg, func(o *kinesis.Options) {
+		o.BaseEndpoint = aws.String(endpoint)
+	})
 
 	for i := range count {
 		partition := "partition"
@@ -70,7 +78,7 @@ func WriteToStream(t *testing.T, endpoint string, streamName string, count int,
 			data = []byte(strconv.Itoa(i))
 		}
 
-		_, err := kinesisClient.PutRecord(&kinesis.PutRecordInput{
+		_, err := kinesisClient.PutRecord(ctx, &kinesis.PutRecordInput{
 			Data:         data,
 			PartitionKey: aws.String(partition),
 			StreamName:   aws.String(streamName),
diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go
index 844d93ca707..3a1e6cb699b 100644
--- a/pkg/acquisition/modules/s3/s3.go
+++ b/pkg/acquisition/modules/s3/s3.go
@@ -143,7 +143,9 @@ func (s *S3Source) newS3Client() error {
 	var clientOpts []func(*s3.Options)
 
 	if s.Config.AwsEndpoint != "" {
-		clientOpts = append(clientOpts, func(o *s3.Options) { o.BaseEndpoint = aws.String(s.Config.AwsEndpoint) })
+		clientOpts = append(clientOpts, func(o *s3.Options) {
+			o.BaseEndpoint = aws.String(s.Config.AwsEndpoint)
+		})
 	}
 
 	s.s3Client = s3.NewFromConfig(cfg, clientOpts...)
@@ -354,16 +356,19 @@ func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, erro
 			s.Config.SQSFormat = SQSFormatEventBridge
 			return bucket, key, nil
 		}
+
 		bucket, key, err = extractBucketAndPrefixFromS3Notif(message)
 		if err == nil {
 			s.Config.SQSFormat = SQSFormatS3Notification
 			return bucket, key, nil
 		}
+
 		bucket, key, err = extractBucketAndPrefixFromSNSNotif(message)
 		if err == nil {
 			s.Config.SQSFormat = SQSFormatSNS
 			return bucket, key, nil
 		}
+
 		return "", "", errors.New("SQS message format not supported")
 	}
 }
@@ -379,6 +384,7 @@ func (s *S3Source) sqsPoll() error {
 			return nil
 		default:
 			logger.Trace("Polling SQS queue")
+
 			out, err := s.sqsClient.ReceiveMessage(s.ctx, &sqs.ReceiveMessageInput{
 				QueueUrl:            aws.String(s.Config.SQSName),
 				MaxNumberOfMessages: 10,
@@ -388,12 +394,15 @@ func (s *S3Source) sqsPoll() error {
 				logger.Errorf("Error while polling SQS: %s", err)
 				continue
 			}
+
 			logger.Tracef("SQS output: %v", out)
 			logger.Debugf("Received %d messages from SQS", len(out.Messages))
+
 			for _, message := range out.Messages {
 				if s.metricsLevel != metrics.AcquisitionMetricsLevelNone {
 					metrics.S3DataSourceSQSMessagesReceived.WithLabelValues(s.Config.SQSName).Inc()
 				}
+
 				bucket, key, err := s.extractBucketAndPrefix(message.Body)
 				if err != nil {
 					logger.Errorf("Error while parsing SQS message: %s", err)
@@ -406,10 +415,14 @@ func (s *S3Source) sqsPoll() error {
 					if err != nil {
 						logger.Errorf("Error while deleting SQS message: %s", err)
 					}
+
 					continue
 				}
+
 				logger.Debugf("Received SQS message for object %s/%s", bucket, key)
+
 				s.readerChan <- S3Object{Key: key, Bucket: bucket}
+
 				_, err = s.sqsClient.DeleteMessage(s.ctx, &sqs.DeleteMessageInput{
 					QueueUrl:      aws.String(s.Config.SQSName),
@@ -418,6 +431,7 @@ func (s *S3Source) sqsPoll() error {
 				if err != nil {
 					logger.Errorf("Error while deleting SQS message: %s", err)
 				}
+
 				logger.Debugf("Deleted SQS message for object %s/%s", bucket, key)
 			}
 		}
@@ -446,10 +460,12 @@ func (s *S3Source) readFile(bucket string, key string) error {
 	if strings.HasSuffix(key, ".gz") {
 		// This *might* be a gzipped file, but sometimes the SDK will decompress the data for us (it's not clear when it happens, only had the issue with cloudtrail logs)
 		header := make([]byte, 2)
+
 		_, err := output.Body.Read(header)
 		if err != nil {
 			return fmt.Errorf("failed to read header of object %s/%s: %w", bucket, key, err)
 		}
+
 		if header[0] == 0x1f && header[1] == 0x8b {
 			gz, err := gzip.NewReader(io.MultiReader(bytes.NewReader(header), output.Body))
 			if err != nil {
@@ -462,11 +478,14 @@ func (s *S3Source) readFile(bucket string, key string) error {
 	} else {
 		scanner = bufio.NewScanner(output.Body)
 	}
+
 	if s.Config.MaxBufferSize > 0 {
 		s.logger.Infof("Setting max buffer size to %d", s.Config.MaxBufferSize)
+
 		buf := make([]byte, 0, bufio.MaxScanTokenSize)
 		scanner.Buffer(buf, s.Config.MaxBufferSize)
 	}
+
 	for scanner.Scan() {
 		select {
 		case <-s.t.Dying():
@@ -475,32 +494,40 @@ func (s *S3Source) readFile(bucket string, key string) error {
 		default:
 			text := scanner.Text()
 			logger.Tracef("Read line %s", text)
+
 			if s.metricsLevel != metrics.AcquisitionMetricsLevelNone {
 				metrics.S3DataSourceLinesRead.With(prometheus.Labels{"bucket": bucket, "datasource_type": "s3", "acquis_type": s.Config.Labels["type"]}).Inc()
 			}
+
 			l := types.Line{}
 			l.Raw = text
 			l.Labels = s.Config.Labels
 			l.Time = time.Now().UTC()
 			l.Process = true
 			l.Module = s.GetName()
+
 			switch s.metricsLevel {
 			case metrics.AcquisitionMetricsLevelFull:
 				l.Src = bucket + "/" + key
 			case metrics.AcquisitionMetricsLevelAggregated, metrics.AcquisitionMetricsLevelNone: // Even if metrics are disabled, we want the source in the event
 				l.Src = bucket
 			}
+
 			evt := types.MakeEvent(s.Config.UseTimeMachine, types.LOG, true)
 			evt.Line = l
+
 			s.out <- evt
 		}
 	}
+
 	if err := scanner.Err(); err != nil {
 		return fmt.Errorf("failed to read object %s/%s: %s", bucket, key, err)
 	}
+
 	if s.metricsLevel != metrics.AcquisitionMetricsLevelNone {
 		metrics.S3DataSourceObjectsRead.WithLabelValues(bucket).Inc()
 	}
+
 	return nil
 }
@@ -518,13 +545,16 @@ func (*S3Source) GetAggregMetrics() []prometheus.Collector {
 
 func (s *S3Source) UnmarshalConfig(yamlConfig []byte) error {
 	s.Config = S3Configuration{}
+
 	err := yaml.UnmarshalWithOptions(yamlConfig, &s.Config, yaml.Strict())
 	if err != nil {
 		return fmt.Errorf("cannot parse S3Acquisition configuration: %s", yaml.FormatError(err, false, false))
 	}
+
 	if s.Config.Mode == "" {
 		s.Config.Mode = configuration.TAIL_MODE
 	}
+
 	if s.Config.PollingMethod == "" {
 		s.Config.PollingMethod = PollMethodList
 	}
@@ -610,8 +640,10 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger *
 		"bucket": s.Config.BucketName,
 		"prefix": s.Config.Prefix,
 	})
+
 	dsn = strings.TrimPrefix(dsn, "s3://")
 	args := strings.Split(dsn, "?")
+
 	if args[0] == "" {
 		return errors.New("empty s3:// DSN")
 	}
@@ -621,25 +653,30 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger *
 	if err != nil {
 		return fmt.Errorf("could not parse s3 args: %w", err)
 	}
+
 	for key, value := range params {
 		switch key {
 		case "log_level":
 			if len(value) != 1 {
 				return errors.New("expected zero or one value for 'log_level'")
 			}
+
 			lvl, err := log.ParseLevel(value[0])
 			if err != nil {
 				return fmt.Errorf("unknown level %s: %w", value[0], err)
 			}
+
 			s.logger.Logger.SetLevel(lvl)
 		case "max_buffer_size":
 			if len(value) != 1 {
 				return errors.New("expected zero or one value for 'max_buffer_size'")
 			}
+
 			maxBufferSize, err := strconv.Atoi(value[0])
 			if err != nil {
 				return fmt.Errorf("invalid value for 'max_buffer_size': %w", err)
 			}
+
 			s.logger.Debugf("Setting max buffer size to %d", maxBufferSize)
 			s.Config.MaxBufferSize = maxBufferSize
 		default:
@@ -692,6 +729,7 @@ func (s *S3Source) OneShotAcquisition(ctx context.Context, out chan types.Event,
 	s.ctx, s.cancel = context.WithCancel(ctx)
 	s.Config.UseTimeMachine = true
 	s.t = t
+
 	if s.Config.Key != "" {
 		err := s.readFile(s.Config.BucketName, s.Config.Key)
 		if err != nil {
@@ -703,6 +741,7 @@ func (s *S3Source) OneShotAcquisition(ctx context.Context, out chan types.Event,
 		if err != nil {
 			return err
 		}
+
 		for _, object := range objects {
 			err := s.readFile(s.Config.BucketName, *object.Key)
 			if err != nil {
@@ -710,7 +749,9 @@ func (s *S3Source) OneShotAcquisition(ctx context.Context, out chan types.Event,
 			}
 		}
 	}
+
 	t.Kill(nil)
+
 	return nil
 }
@@ -724,12 +765,14 @@ func (s *S3Source) StreamingAcquisition(ctx context.Context, out chan types.Even
 		s.readManager()
 		return nil
 	})
+
 	if s.Config.PollingMethod == PollMethodSQS {
 		t.Go(func() error {
 			err := s.sqsPoll()
 			if err != nil {
 				return err
 			}
+
 			return nil
 		})
 	} else {
@@ -738,9 +781,11 @@ func (s *S3Source) StreamingAcquisition(ctx context.Context, out chan types.Even
 			if err != nil {
 				return err
 			}
+
 			return nil
 		})
 	}
+
 	return nil
 }
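
Editor's note, not part of the patch: the same SDK v2 construction pattern recurs across kinesis.go, kinesis_test.go, and s3.go — build an aws.Config with functional options on config.LoadDefaultConfig, override the endpoint per client via BaseEndpoint in NewFromConfig, and match modeled errors from the service's types package with errors.As. The snippet below is a minimal, self-contained sketch of that pattern; the region, endpoint, consumer name, and stream ARN are placeholders, and it assumes something Kinesis-compatible (e.g. LocalStack) is listening at the endpoint.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/kinesis"
	kinTypes "github.com/aws/aws-sdk-go-v2/service/kinesis/types"
)

func main() {
	ctx := context.Background()

	// v2 replaces session.NewSessionWithOptions with LoadDefaultConfig plus
	// functional options (region, shared profile, credentials, ...).
	cfg, err := config.LoadDefaultConfig(ctx,
		config.WithRegion("us-east-1"), // placeholder region
		// Anonymous credentials, as the test build of defaultCreds() returns.
		config.WithCredentialsProvider(aws.AnonymousCredentials{}),
	)
	if err != nil {
		log.Fatalf("failed to load aws config: %s", err)
	}

	// Endpoint overrides move from aws.Config.WithEndpoint to per-client options.
	client := kinesis.NewFromConfig(cfg, func(o *kinesis.Options) {
		o.BaseEndpoint = aws.String("http://localhost:4566") // placeholder endpoint
	})

	// Modeled errors now live in the service's types package; match with errors.As.
	_, err = client.DescribeStreamConsumer(ctx, &kinesis.DescribeStreamConsumerInput{
		ConsumerName: aws.String("example-consumer"),                                     // placeholder
		StreamARN:    aws.String("arn:aws:kinesis:us-east-1:000000000000:stream/example"), // placeholder
	})

	var notFound *kinTypes.ResourceNotFoundException
	switch {
	case errors.As(err, &notFound):
		fmt.Println("consumer does not exist")
	case err != nil:
		log.Fatalf("describe stream consumer: %s", err)
	default:
		fmt.Println("consumer exists")
	}
}
```

The creds.go/creds_test.go pair relies on standard Go build constraints: a file tagged //go:build test only compiles when the suite is built with -tags test, so tests get aws.AnonymousCredentials{} against a local stand-in, while regular builds return nil from defaultCreds() and fall back to the default credential provider chain.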