From da030ae810385006105fc1ed9ea9363ecf7e7dd2 Mon Sep 17 00:00:00 2001 From: Arsalan Khairani Date: Thu, 24 Oct 2019 16:11:55 +0500 Subject: [PATCH] Ability to switch between multiple queues --- v1/brokers/redis/redis.go | 4 ++- v1/brokers/sqs/sqs.go | 39 ++++++++++++-------------- v1/brokers/sqs/sqs_export_test.go | 3 +- v1/brokers/sqs/sqs_test.go | 46 +++++++++++++++++++++++++++++++ v1/worker.go | 11 ++++++++ 5 files changed, 80 insertions(+), 23 deletions(-) diff --git a/v1/brokers/redis/redis.go b/v1/brokers/redis/redis.go index edf240125..10a4ea499 100644 --- a/v1/brokers/redis/redis.go +++ b/v1/brokers/redis/redis.go @@ -290,7 +290,9 @@ func (b *Broker) consumeOne(delivery []byte, taskProcessor iface.TaskProcessor) conn := b.open() defer conn.Close() - conn.Do("RPUSH", getQueue(b.GetConfig(), taskProcessor), delivery) + // Adjust routing key (this decides which queue the message will be send back to) + b.Broker.AdjustRoutingKey(signature) + conn.Do("RPUSH", signature.RoutingKey, delivery) return nil } diff --git a/v1/brokers/sqs/sqs.go b/v1/brokers/sqs/sqs.go index 318481d65..ffb6ddce3 100644 --- a/v1/brokers/sqs/sqs.go +++ b/v1/brokers/sqs/sqs.go @@ -35,7 +35,6 @@ type Broker struct { stopReceivingChan chan int sess *session.Session service sqsiface.SQSAPI - queueUrl *string } // New creates new Broker instance @@ -65,10 +64,6 @@ func (b *Broker) GetPendingTasks(queue string) ([]*tasks.Signature, error) { // StartConsuming enters a loop and waits for incoming messages func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcessor iface.TaskProcessor) (bool, error) { b.Broker.StartConsuming(consumerTag, concurrency, taskProcessor) - qURL := b.getQueueURL(taskProcessor) - //save it so that it can be used later when attempting to delete task - b.queueUrl = qURL - deliveries := make(chan *awssqs.ReceiveMessageOutput, concurrency) pool := make(chan struct{}, concurrency) @@ -82,8 +77,7 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess go func() { defer b.receivingWG.Done() - log.INFO.Printf("[*] Waiting for messages on queue: %s. To exit press CTRL+C\n", *qURL) - + log.INFO.Printf("[*] Waiting for messages on queue. To exit press CTRL+C\n") for { select { // A way to stop this goroutine from b.StopConsuming @@ -91,6 +85,7 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess close(deliveries) return case <-pool: + qURL := b.getQueueURL(taskProcessor) output, err := b.receiveMessage(qURL) if err == nil && len(output.Messages) > 0 { deliveries <- output @@ -140,7 +135,7 @@ func (b *Broker) Publish(ctx context.Context, signature *tasks.Signature) error MsgInput := &awssqs.SendMessageInput{ MessageBody: aws.String(string(msg)), - QueueUrl: aws.String(b.GetConfig().Broker + "/" + signature.RoutingKey), + QueueUrl: b.queueToURL(signature.RoutingKey), } // if this is a fifo queue, there needs to be some additional parameters. @@ -220,7 +215,7 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor // and leave the message in the queue if !b.IsTaskRegistered(sig.Name) { if sig.IgnoreWhenTaskNotRegistered { - b.deleteOne(delivery) + b.deleteOne(delivery, sig) } return fmt.Errorf("task %s is not registered", sig.Name) } @@ -234,15 +229,19 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor return err } // Delete message after successfully consuming and processing the message - if err = b.deleteOne(delivery); err != nil { + if err = b.deleteOne(delivery, sig); err != nil { log.ERROR.Printf("error when deleting the delivery. delivery is %v, Error=%s", delivery, err) } return err } // deleteOne is a method delete a delivery from AWS SQS -func (b *Broker) deleteOne(delivery *awssqs.ReceiveMessageOutput) error { +func (b *Broker) deleteOne(delivery *awssqs.ReceiveMessageOutput, sig *tasks.Signature) error { qURL := b.defaultQueueURL() + if sig.RoutingKey != "" { + qURL = b.queueToURL(sig.RoutingKey) + } + _, err := b.service.DeleteMessage(&awssqs.DeleteMessageInput{ QueueUrl: qURL, ReceiptHandle: delivery.Messages[0].ReceiptHandle, @@ -256,12 +255,7 @@ func (b *Broker) deleteOne(delivery *awssqs.ReceiveMessageOutput) error { // defaultQueueURL is a method returns the default queue url func (b *Broker) defaultQueueURL() *string { - if b.queueUrl != nil { - return b.queueUrl - } else { - return aws.String(b.GetConfig().Broker + "/" + b.GetConfig().DefaultQueue) - } - + return b.queueToURL(b.GetConfig().DefaultQueue) } // receiveMessage is a method receives a message from specified queue url @@ -360,10 +354,13 @@ func (b *Broker) stopReceiving() { // getQueueURL is a method returns that returns queueURL first by checking if custom queue was set and usign it // otherwise using default queueName from config func (b *Broker) getQueueURL(taskProcessor iface.TaskProcessor) *string { - queueName := b.GetConfig().DefaultQueue - if taskProcessor.CustomQueue() != "" { - queueName = taskProcessor.CustomQueue() + if customQueue := taskProcessor.CustomQueue(); customQueue != "" { + return b.queueToURL(customQueue) } - return aws.String(b.GetConfig().Broker + "/" + queueName) + return b.defaultQueueURL() +} + +func (b *Broker) queueToURL(queue string) *string { + return aws.String(b.GetConfig().Broker + "/" + queue) } diff --git a/v1/brokers/sqs/sqs_export_test.go b/v1/brokers/sqs/sqs_export_test.go index 25b0348b2..2d7208ecd 100644 --- a/v1/brokers/sqs/sqs_export_test.go +++ b/v1/brokers/sqs/sqs_export_test.go @@ -9,6 +9,7 @@ import ( "github.com/RichardKnop/machinery/v1/brokers/iface" "github.com/RichardKnop/machinery/v1/common" "github.com/RichardKnop/machinery/v1/config" + "github.com/RichardKnop/machinery/v1/tasks" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs/sqsiface" @@ -135,7 +136,7 @@ func (b *Broker) ConsumeOneForTest(delivery *awssqs.ReceiveMessageOutput, taskPr } func (b *Broker) DeleteOneForTest(delivery *awssqs.ReceiveMessageOutput) error { - return b.deleteOne(delivery) + return b.deleteOne(delivery, &tasks.Signature{RoutingKey: ""}) } func (b *Broker) DefaultQueueURLForTest() *string { diff --git a/v1/brokers/sqs/sqs_test.go b/v1/brokers/sqs/sqs_test.go index f066ef6d7..3c0c7854a 100644 --- a/v1/brokers/sqs/sqs_test.go +++ b/v1/brokers/sqs/sqs_test.go @@ -302,3 +302,49 @@ func TestPrivateFunc_consumeWithConcurrency(t *testing.T) { t.Fatal("task not processed in 10 seconds") } } + +type roundRobinQueues struct { + queues []string + currentIndex int +} + +func NewRoundRobinQueues(queues []string) *roundRobinQueues { + return &roundRobinQueues{ + queues: queues, + currentIndex: -1, + } +} + +func (r *roundRobinQueues) Peek() string { + return r.queues[r.currentIndex] +} + +func (r *roundRobinQueues) Next() string { + r.currentIndex += 1 + if r.currentIndex >= len(r.queues) { + r.currentIndex = 0 + } + + q := r.queues[r.currentIndex] + return q +} + +func TestPrivateFunc_consumeWithRoundRobinQueues(t *testing.T) { + server1, err := machinery.NewServer(cnf) + if err != nil { + t.Fatal(err) + } + + w := server1.NewWorker("test-worker", 0) + + // Assigning a getQueueHandler to `Next` method of roundRobinQueues + rr := NewRoundRobinQueues([]string{"custom-queue-0", "custom-queue-1", "custom-queue-2", "custom-queue-3"}) + w.SetGetQueueHandler(rr.Next) + + for i := 0; i < 5; i++ { + // the queue url of the broker should match the current queue url of roundRobin + // and thus queues are being utilized in round-robin fashion + qURL := testAWSSQSBroker.GetQueueURLForTest(w) + assert.Equal(t, qURL, testAWSSQSBroker.GetCustomQueueURL(rr.Peek())) + } +} diff --git a/v1/worker.go b/v1/worker.go index 1cfdec497..10b09e571 100644 --- a/v1/worker.go +++ b/v1/worker.go @@ -26,6 +26,7 @@ type Worker struct { errorHandler func(err error) preTaskHandler func(*tasks.Signature) postTaskHandler func(*tasks.Signature) + getQueueHandler func() string } // Launch starts a new worker process. The worker subscribes @@ -109,6 +110,11 @@ func (worker *Worker) LaunchAsync(errorsChan chan<- error) { // CustomQueue returns Custom Queue of the running worker process func (worker *Worker) CustomQueue() string { + // if the handler is defined, use that to fetch the queue name + if worker.getQueueHandler != nil { + return worker.getQueueHandler() + } + return worker.Queue } @@ -392,6 +398,11 @@ func (worker *Worker) SetPostTaskHandler(handler func(*tasks.Signature)) { worker.postTaskHandler = handler } +//SetGetQueueHandler sets a get queue handler to fetch queue name from +func (worker *Worker) SetGetQueueHandler(handler func() string) { + worker.getQueueHandler = handler +} + //GetServer returns server func (worker *Worker) GetServer() *Server { return worker.server