diff --git a/v1/brokers/redis/redis.go b/v1/brokers/redis/redis.go index b130bb240..bc3e7d9bf 100644 --- a/v1/brokers/redis/redis.go +++ b/v1/brokers/redis/redis.go @@ -82,10 +82,12 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess // Channel to which we will push tasks ready for processing by worker deliveries := make(chan []byte, concurrency) pool := make(chan struct{}, concurrency) + nextTask := make(chan struct{}, concurrency) // initialize worker pool with maxWorkers workers for i := 0; i < concurrency; i++ { pool <- struct{}{} + nextTask <- struct{}{} } // A receiving goroutine keeps popping messages from the queue by BLPOP @@ -94,7 +96,7 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess go func() { log.INFO.Print("[*] Waiting for messages. To exit press CTRL+C") - + var gotTask bool for { select { // A way to stop this goroutine from b.StopConsuming @@ -110,10 +112,16 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess } if taskProcessor.PreConsumeHandler() { + if !gotTask { + <-nextTask + gotTask = true + } + task, _ := b.nextTask(getQueue(b.GetConfig(), taskProcessor)) //TODO: should this error be ignored? if len(task) > 0 { deliveries <- task + gotTask = false } } @@ -153,7 +161,7 @@ func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcess } }() - if err := b.consume(deliveries, concurrency, taskProcessor); err != nil { + if err := b.consume(deliveries, nextTask, concurrency, taskProcessor); err != nil { return b.GetRetry(), err } @@ -266,7 +274,7 @@ func (b *Broker) GetDelayedTasks() ([]*tasks.Signature, error) { // consume takes delivered messages from the channel and manages a worker pool // to process tasks concurrently -func (b *Broker) consume(deliveries <-chan []byte, concurrency int, taskProcessor iface.TaskProcessor) error { +func (b *Broker) consume(deliveries <-chan []byte, nextPool chan<- struct{}, concurrency int, taskProcessor iface.TaskProcessor) error { errorsChan := make(chan error, concurrency*2) pool := make(chan struct{}, concurrency) @@ -300,6 +308,11 @@ func (b *Broker) consume(deliveries <-chan []byte, concurrency int, taskProcesso // Consume the task inside a goroutine so multiple tasks // can be processed concurrently go func() { + defer func() { + nextPool <- struct{}{} + }() + + if err := b.consumeOne(d, taskProcessor); err != nil { errorsChan <- err } @@ -475,4 +488,4 @@ func (b *Broker) requeueMessage(delivery []byte, taskProcessor iface.TaskProcess conn := b.open() defer conn.Close() conn.Do("RPUSH", getQueue(b.GetConfig(), taskProcessor), delivery) -} \ No newline at end of file +}