Skip to content

Commit 3221a8e

Browse files
committed
context
1 parent bb57fd4 commit 3221a8e

File tree

1 file changed

+27
-42
lines changed

1 file changed

+27
-42
lines changed

pkg/acquisition/modules/kinesis/kinesis.go

Lines changed: 27 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,10 @@ func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubsc
195195
return subscriptionRecord.LogEvents, nil
196196
}
197197

198-
func (k *KinesisSource) WaitForConsumerDeregistration(consumerName string, streamARN string) error {
198+
func (k *KinesisSource) WaitForConsumerDeregistration(ctx context.Context, consumerName string, streamARN string) error {
199199
maxTries := k.Config.MaxRetries
200200
for i := range maxTries {
201-
_, err := k.kClient.DescribeStreamConsumer(
202-
context.TODO(),
203-
&kinesis.DescribeStreamConsumerInput{
201+
_, err := k.kClient.DescribeStreamConsumer(ctx, &kinesis.DescribeStreamConsumerInput{
204202
ConsumerName: aws.String(consumerName),
205203
StreamARN: aws.String(streamARN),
206204
})
@@ -221,11 +219,9 @@ func (k *KinesisSource) WaitForConsumerDeregistration(consumerName string, strea
221219
return fmt.Errorf("consumer %s is not deregistered after %d tries", consumerName, maxTries)
222220
}
223221

224-
func (k *KinesisSource) DeregisterConsumer() error {
222+
func (k *KinesisSource) DeregisterConsumer(ctx context.Context) error {
225223
k.logger.Debugf("Deregistering consumer %s if it exists", k.Config.ConsumerName)
226-
_, err := k.kClient.DeregisterStreamConsumer(
227-
context.TODO(),
228-
&kinesis.DeregisterStreamConsumerInput{
224+
_, err := k.kClient.DeregisterStreamConsumer(ctx, &kinesis.DeregisterStreamConsumerInput{
229225
ConsumerName: aws.String(k.Config.ConsumerName),
230226
StreamARN: aws.String(k.Config.StreamARN),
231227
})
@@ -239,20 +235,18 @@ func (k *KinesisSource) DeregisterConsumer() error {
239235
return fmt.Errorf("cannot deregister stream consumer: %w", err)
240236
}
241237

242-
err = k.WaitForConsumerDeregistration(k.Config.ConsumerName, k.Config.StreamARN)
238+
err = k.WaitForConsumerDeregistration(ctx, k.Config.ConsumerName, k.Config.StreamARN)
243239
if err != nil {
244240
return fmt.Errorf("cannot wait for consumer deregistration: %w", err)
245241
}
246242

247243
return nil
248244
}
249245

250-
func (k *KinesisSource) WaitForConsumerRegistration(consumerARN string) error {
246+
func (k *KinesisSource) WaitForConsumerRegistration(ctx context.Context, consumerARN string) error {
251247
maxTries := k.Config.MaxRetries
252248
for i := range maxTries {
253-
describeOutput, err := k.kClient.DescribeStreamConsumer(
254-
context.TODO(),
255-
&kinesis.DescribeStreamConsumerInput{
249+
describeOutput, err := k.kClient.DescribeStreamConsumer(ctx, &kinesis.DescribeStreamConsumerInput{
256250
ConsumerARN: aws.String(consumerARN),
257251
})
258252
if err != nil {
@@ -271,20 +265,18 @@ func (k *KinesisSource) WaitForConsumerRegistration(consumerARN string) error {
271265
return fmt.Errorf("consumer %s is not active after %d tries", consumerARN, maxTries)
272266
}
273267

274-
func (k *KinesisSource) RegisterConsumer() (*kinesis.RegisterStreamConsumerOutput, error) {
268+
func (k *KinesisSource) RegisterConsumer(ctx context.Context) (*kinesis.RegisterStreamConsumerOutput, error) {
275269
k.logger.Debugf("Registering consumer %s", k.Config.ConsumerName)
276270

277-
streamConsumer, err := k.kClient.RegisterStreamConsumer(
278-
context.TODO(),
279-
&kinesis.RegisterStreamConsumerInput{
271+
streamConsumer, err := k.kClient.RegisterStreamConsumer(ctx, &kinesis.RegisterStreamConsumerInput{
280272
ConsumerName: aws.String(k.Config.ConsumerName),
281273
StreamARN: aws.String(k.Config.StreamARN),
282274
})
283275
if err != nil {
284276
return nil, fmt.Errorf("cannot register stream consumer: %w", err)
285277
}
286278

287-
err = k.WaitForConsumerRegistration(*streamConsumer.Consumer.ConsumerARN)
279+
err = k.WaitForConsumerRegistration(ctx, *streamConsumer.Consumer.ConsumerARN)
288280
if err != nil {
289281
return nil, fmt.Errorf("timeout while waiting for consumer to be active: %w", err)
290282
}
@@ -378,10 +370,8 @@ func (k *KinesisSource) ReadFromSubscription(reader kinesis.SubscribeToShardEven
378370
}
379371
}
380372

381-
func (k *KinesisSource) SubscribeToShards(arn arn.ARN, streamConsumer *kinesis.RegisterStreamConsumerOutput, out chan types.Event) error {
382-
shards, err := k.kClient.ListShards(
383-
context.TODO(),
384-
&kinesis.ListShardsInput{
373+
func (k *KinesisSource) SubscribeToShards(ctx context.Context, arn arn.ARN, streamConsumer *kinesis.RegisterStreamConsumerOutput, out chan types.Event) error {
374+
shards, err := k.kClient.ListShards(ctx, &kinesis.ListShardsInput{
385375
StreamName: aws.String(arn.Resource[7:]),
386376
})
387377
if err != nil {
@@ -391,9 +381,7 @@ func (k *KinesisSource) SubscribeToShards(arn arn.ARN, streamConsumer *kinesis.R
391381
for _, shard := range shards.Shards {
392382
shardID := *shard.ShardId
393383

394-
r, err := k.kClient.SubscribeToShard(
395-
context.TODO(),
396-
&kinesis.SubscribeToShardInput{
384+
r, err := k.kClient.SubscribeToShard(ctx, &kinesis.SubscribeToShardInput{
397385
ShardId: aws.String(shardID),
398386
StartingPosition: &kinTypes.StartingPosition{Type: kinTypes.ShardIteratorTypeLatest},
399387
ConsumerARN: streamConsumer.Consumer.ConsumerARN,
@@ -410,7 +398,7 @@ func (k *KinesisSource) SubscribeToShards(arn arn.ARN, streamConsumer *kinesis.R
410398
return nil
411399
}
412400

413-
func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error {
401+
func (k *KinesisSource) EnhancedRead(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
414402
parsedARN, err := arn.Parse(k.Config.StreamARN)
415403
if err != nil {
416404
return fmt.Errorf("cannot parse stream ARN: %w", err)
@@ -423,20 +411,20 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error {
423411
k.logger = k.logger.WithField("stream", parsedARN.Resource[7:])
424412
k.logger.Info("starting kinesis acquisition with enhanced fan-out")
425413

426-
err = k.DeregisterConsumer()
414+
err = k.DeregisterConsumer(ctx)
427415
if err != nil {
428416
return fmt.Errorf("cannot deregister consumer: %w", err)
429417
}
430418

431-
streamConsumer, err := k.RegisterConsumer()
419+
streamConsumer, err := k.RegisterConsumer(ctx)
432420
if err != nil {
433421
return fmt.Errorf("cannot register consumer: %w", err)
434422
}
435423

436424
for {
437425
k.shardReaderTomb = &tomb.Tomb{}
438426

439-
err = k.SubscribeToShards(parsedARN, streamConsumer, out)
427+
err = k.SubscribeToShards(ctx, parsedARN, streamConsumer, out)
440428
if err != nil {
441429
return fmt.Errorf("cannot subscribe to shards: %w", err)
442430
}
@@ -446,7 +434,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error {
446434
k.shardReaderTomb.Kill(nil)
447435
_ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves
448436

449-
err = k.DeregisterConsumer()
437+
err = k.DeregisterConsumer(ctx)
450438
if err != nil {
451439
return fmt.Errorf("cannot deregister consumer: %w", err)
452440
}
@@ -466,12 +454,11 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error {
466454
}
467455
}
468456

469-
func (k *KinesisSource) ReadFromShard(out chan types.Event, shardID string) error {
457+
func (k *KinesisSource) ReadFromShard(ctx context.Context, out chan types.Event, shardID string) error {
470458
logger := k.logger.WithField("shard", shardID)
471459
logger.Debugf("Starting to read shard")
472460

473-
sharIt, err := k.kClient.GetShardIterator(
474-
context.TODO(),
461+
sharIt, err := k.kClient.GetShardIterator(ctx,
475462
&kinesis.GetShardIteratorInput{
476463
ShardId: aws.String(shardID),
477464
StreamName: &k.Config.StreamName,
@@ -489,7 +476,7 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardID string) erro
489476
for {
490477
select {
491478
case <-ticker.C:
492-
records, err := k.kClient.GetRecords(context.TODO(), &kinesis.GetRecordsInput{ShardIterator: it})
479+
records, err := k.kClient.GetRecords(ctx, &kinesis.GetRecordsInput{ShardIterator: it})
493480
it = records.NextShardIterator
494481

495482
var throughputErr *kinTypes.ProvisionedThroughputExceededException
@@ -526,14 +513,12 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardID string) erro
526513
}
527514
}
528515

529-
func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error {
516+
func (k *KinesisSource) ReadFromStream(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
530517
k.logger = k.logger.WithField("stream", k.Config.StreamName)
531518
k.logger.Info("starting kinesis acquisition from shards")
532519

533520
for {
534-
shards, err := k.kClient.ListShards(
535-
context.TODO(),
536-
&kinesis.ListShardsInput{
521+
shards, err := k.kClient.ListShards(ctx, &kinesis.ListShardsInput{
537522
StreamName: aws.String(k.Config.StreamName),
538523
})
539524
if err != nil {
@@ -547,7 +532,7 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error
547532

548533
k.shardReaderTomb.Go(func() error {
549534
defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming/shard")
550-
return k.ReadFromShard(out, shardID)
535+
return k.ReadFromShard(ctx, out, shardID)
551536
})
552537
}
553538
select {
@@ -571,15 +556,15 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error
571556
}
572557
}
573558

574-
func (k *KinesisSource) StreamingAcquisition(_ context.Context, out chan types.Event, t *tomb.Tomb) error {
559+
func (k *KinesisSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
575560
t.Go(func() error {
576561
defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming")
577562

578563
if k.Config.UseEnhancedFanOut {
579-
return k.EnhancedRead(out, t)
564+
return k.EnhancedRead(ctx, out, t)
580565
}
581566

582-
return k.ReadFromStream(out, t)
567+
return k.ReadFromStream(ctx, out, t)
583568
})
584569

585570
return nil

0 commit comments

Comments
 (0)