Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion v1/brokers/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,9 @@ func (b *Broker) consumeOne(delivery []byte, taskProcessor iface.TaskProcessor)
conn := b.open()
defer conn.Close()

conn.Do("RPUSH", getQueue(b.GetConfig(), taskProcessor), delivery)
// Adjust routing key (this decides which queue the message will be send back to)
b.Broker.AdjustRoutingKey(signature)
conn.Do("RPUSH", signature.RoutingKey, delivery)
return nil
}

Expand Down
39 changes: 18 additions & 21 deletions v1/brokers/sqs/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ type Broker struct {
stopReceivingChan chan int
sess *session.Session
service sqsiface.SQSAPI
queueUrl *string
}

// New creates new Broker instance
Expand Down Expand Up @@ -65,10 +64,6 @@ func (b *Broker) GetPendingTasks(queue string) ([]*tasks.Signature, error) {
// StartConsuming enters a loop and waits for incoming messages
func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcessor iface.TaskProcessor) (bool, error) {
b.Broker.StartConsuming(consumerTag, concurrency, taskProcessor)
qURL := b.getQueueURL(taskProcessor)
//save it so that it can be used later when attempting to delete task
b.queueUrl = qURL

deliveries := make(chan *awssqs.ReceiveMessageOutput, concurrency)
pool := make(chan struct{}, concurrency)

Expand All @@ -82,15 +77,15 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess
go func() {
defer b.receivingWG.Done()

log.INFO.Printf("[*] Waiting for messages on queue: %s. To exit press CTRL+C\n", *qURL)

log.INFO.Printf("[*] Waiting for messages on queue. To exit press CTRL+C\n")
for {
select {
// A way to stop this goroutine from b.StopConsuming
case <-b.stopReceivingChan:
close(deliveries)
return
case <-pool:
qURL := b.getQueueURL(taskProcessor)
output, err := b.receiveMessage(qURL)
if err == nil && len(output.Messages) > 0 {
deliveries <- output
Expand Down Expand Up @@ -140,7 +135,7 @@ func (b *Broker) Publish(ctx context.Context, signature *tasks.Signature) error

MsgInput := &awssqs.SendMessageInput{
MessageBody: aws.String(string(msg)),
QueueUrl: aws.String(b.GetConfig().Broker + "/" + signature.RoutingKey),
QueueUrl: b.queueToURL(signature.RoutingKey),
}

// if this is a fifo queue, there needs to be some additional parameters.
Expand Down Expand Up @@ -220,7 +215,7 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor
// and leave the message in the queue
if !b.IsTaskRegistered(sig.Name) {
if sig.IgnoreWhenTaskNotRegistered {
b.deleteOne(delivery)
b.deleteOne(delivery, sig)
}
return fmt.Errorf("task %s is not registered", sig.Name)
}
Expand All @@ -234,15 +229,19 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor
return err
}
// Delete message after successfully consuming and processing the message
if err = b.deleteOne(delivery); err != nil {
if err = b.deleteOne(delivery, sig); err != nil {
log.ERROR.Printf("error when deleting the delivery. delivery is %v, Error=%s", delivery, err)
}
return err
}

// deleteOne is a method delete a delivery from AWS SQS
func (b *Broker) deleteOne(delivery *awssqs.ReceiveMessageOutput) error {
func (b *Broker) deleteOne(delivery *awssqs.ReceiveMessageOutput, sig *tasks.Signature) error {
qURL := b.defaultQueueURL()
if sig.RoutingKey != "" {
qURL = b.queueToURL(sig.RoutingKey)
}

_, err := b.service.DeleteMessage(&awssqs.DeleteMessageInput{
QueueUrl: qURL,
ReceiptHandle: delivery.Messages[0].ReceiptHandle,
Expand All @@ -256,12 +255,7 @@ func (b *Broker) deleteOne(delivery *awssqs.ReceiveMessageOutput) error {

// defaultQueueURL is a method returns the default queue url
func (b *Broker) defaultQueueURL() *string {
if b.queueUrl != nil {
return b.queueUrl
} else {
return aws.String(b.GetConfig().Broker + "/" + b.GetConfig().DefaultQueue)
}

return b.queueToURL(b.GetConfig().DefaultQueue)
}

// receiveMessage is a method receives a message from specified queue url
Expand Down Expand Up @@ -360,10 +354,13 @@ func (b *Broker) stopReceiving() {
// getQueueURL is a method returns that returns queueURL first by checking if custom queue was set and usign it
// otherwise using default queueName from config
func (b *Broker) getQueueURL(taskProcessor iface.TaskProcessor) *string {
queueName := b.GetConfig().DefaultQueue
if taskProcessor.CustomQueue() != "" {
queueName = taskProcessor.CustomQueue()
if customQueue := taskProcessor.CustomQueue(); customQueue != "" {
return b.queueToURL(customQueue)
}

return aws.String(b.GetConfig().Broker + "/" + queueName)
return b.defaultQueueURL()
}

func (b *Broker) queueToURL(queue string) *string {
return aws.String(b.GetConfig().Broker + "/" + queue)
}
3 changes: 2 additions & 1 deletion v1/brokers/sqs/sqs_export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/RichardKnop/machinery/v1/brokers/iface"
"github.com/RichardKnop/machinery/v1/common"
"github.com/RichardKnop/machinery/v1/config"
"github.com/RichardKnop/machinery/v1/tasks"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
Expand Down Expand Up @@ -135,7 +136,7 @@ func (b *Broker) ConsumeOneForTest(delivery *awssqs.ReceiveMessageOutput, taskPr
}

func (b *Broker) DeleteOneForTest(delivery *awssqs.ReceiveMessageOutput) error {
return b.deleteOne(delivery)
return b.deleteOne(delivery, &tasks.Signature{RoutingKey: ""})
}

func (b *Broker) DefaultQueueURLForTest() *string {
Expand Down
46 changes: 46 additions & 0 deletions v1/brokers/sqs/sqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,49 @@ func TestPrivateFunc_consumeWithConcurrency(t *testing.T) {
t.Fatal("task not processed in 10 seconds")
}
}

type roundRobinQueues struct {
queues []string
currentIndex int
}

func NewRoundRobinQueues(queues []string) *roundRobinQueues {
return &roundRobinQueues{
queues: queues,
currentIndex: -1,
}
}

func (r *roundRobinQueues) Peek() string {
return r.queues[r.currentIndex]
}

func (r *roundRobinQueues) Next() string {
r.currentIndex += 1
if r.currentIndex >= len(r.queues) {
r.currentIndex = 0
}

q := r.queues[r.currentIndex]
return q
}

func TestPrivateFunc_consumeWithRoundRobinQueues(t *testing.T) {
server1, err := machinery.NewServer(cnf)
if err != nil {
t.Fatal(err)
}

w := server1.NewWorker("test-worker", 0)

// Assigning a getQueueHandler to `Next` method of roundRobinQueues
rr := NewRoundRobinQueues([]string{"custom-queue-0", "custom-queue-1", "custom-queue-2", "custom-queue-3"})
w.SetGetQueueHandler(rr.Next)

for i := 0; i < 5; i++ {
// the queue url of the broker should match the current queue url of roundRobin
// and thus queues are being utilized in round-robin fashion
qURL := testAWSSQSBroker.GetQueueURLForTest(w)
assert.Equal(t, qURL, testAWSSQSBroker.GetCustomQueueURL(rr.Peek()))
}
}
11 changes: 11 additions & 0 deletions v1/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Worker struct {
errorHandler func(err error)
preTaskHandler func(*tasks.Signature)
postTaskHandler func(*tasks.Signature)
getQueueHandler func() string
}

// Launch starts a new worker process. The worker subscribes
Expand Down Expand Up @@ -109,6 +110,11 @@ func (worker *Worker) LaunchAsync(errorsChan chan<- error) {

// CustomQueue returns Custom Queue of the running worker process
func (worker *Worker) CustomQueue() string {
// if the handler is defined, use that to fetch the queue name
if worker.getQueueHandler != nil {
return worker.getQueueHandler()
}

return worker.Queue
}

Expand Down Expand Up @@ -392,6 +398,11 @@ func (worker *Worker) SetPostTaskHandler(handler func(*tasks.Signature)) {
worker.postTaskHandler = handler
}

//SetGetQueueHandler sets a get queue handler to fetch queue name from
func (worker *Worker) SetGetQueueHandler(handler func() string) {
worker.getQueueHandler = handler
}

//GetServer returns server
func (worker *Worker) GetServer() *Server {
return worker.server
Expand Down