diff --git a/go.mod b/go.mod index 4a3321926..1569f9499 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( cloud.google.com/go v0.76.0 // indirect cloud.google.com/go/pubsub v1.10.0 github.com/RichardKnop/logging v0.0.0-20190827224416-1a693bdd4fae + github.com/RichardKnop/machinery/v2 v2.0.11 github.com/aws/aws-sdk-go v1.37.16 github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index 1d7f349f9..480012964 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,8 @@ cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2k cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= +cloud.google.com/go/pubsub v1.5.0 h1:9cH52jizPUVSSrSe+J16RC9wB0QI7i/cfuCm5UUCcIk= +cloud.google.com/go/pubsub v1.5.0/go.mod h1:ZEwJccE3z93Z2HWvstpri00jOg7oO4UZDtKhwDwqF0w= cloud.google.com/go/pubsub v1.10.0 h1:JK22g5uNpscGPthjJE/D0siWtA6UlU4Cb6pLcyJkzyQ= cloud.google.com/go/pubsub v1.10.0/go.mod h1:eNpTrkOy7dCpkNyaSNetMa6udbgecJMd0ZsTJS/cuNo= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= @@ -43,9 +45,17 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/RichardKnop/logging v0.0.0-20190827224416-1a693bdd4fae h1:DcFpTQBYQ9Ct2d6sC7ol0/ynxc2pO1cpGUM+f4t5adg= github.com/RichardKnop/logging v0.0.0-20190827224416-1a693bdd4fae/go.mod h1:rJJ84PyA/Wlmw1hO+xTzV2wsSUon6J5ktg0g8BF2PuU= +github.com/RichardKnop/machinery/v2 v2.0.11 h1:BTfLGOmOju3W/OtlZmLX26OjYNZsU4PJo04pQReycdc= +github.com/RichardKnop/machinery/v2 v2.0.11/go.mod h1:b5Q6cT/w7YLlIl4Vi+jpdEoyYiqhTgx+0USoKb1wzqU= +github.com/RichardKnop/redsync v1.2.0 h1:gK35hR3zZkQigHKm8wOGb9MpJ9BsrW6MzxezwjTcHP0= +github.com/RichardKnop/redsync v1.2.0/go.mod h1:9b8nBGAX3bE2uCfJGSnsDvF23mKyHTZzmvmj5FH3Tp0= +github.com/aws/aws-sdk-go v1.33.6 h1:YLoUeMSx05kHwhS+HLDSpdYYpPzJMyp6hn1cWsJ6a+U= +github.com/aws/aws-sdk-go v1.33.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= github.com/aws/aws-sdk-go v1.37.16 h1:Q4YOP2s00NpB9wfmTDZArdcLRuG9ijbnoAwTW3ivleI= github.com/aws/aws-sdk-go v1.37.16/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= +github.com/benbjohnson/clock v1.0.3 h1:vkLuvpK4fmtSCuo60+yC63p7y0BmQ8gm5ZXGuBCJyXg= +github.com/benbjohnson/clock v1.0.3/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM= github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b h1:L/QXpzIa3pOvUGt1D1lA5KjYhPBAN/3iWdP7xeFS9F0= github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -73,15 +83,17 @@ github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5y github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-redis/redis v6.15.8+incompatible h1:BKZuG6mCnRj5AOaWJXoCgf6rqTYnYJLe4en2hxT7r9o= +github.com/go-redis/redis v6.15.8+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= -github.com/go-redis/redis/v7 v7.4.0 h1:7obg6wUoj05T0EpY0o8B59S9w5yeMWql7sw2kwNW1x4= github.com/go-redis/redis/v7 v7.4.0/go.mod h1:JDNMw23GTyLNC4GZu9njt15ctBQVn7xjRfnwdHj/Dcg= +github.com/go-redis/redis/v8 v8.0.0-beta.6 h1:QeXAkG9L5cWJA+eJTBvhkftE7dwpJ0gbMYeBE2NxXS4= +github.com/go-redis/redis/v8 v8.0.0-beta.6/go.mod h1:g79Vpae8JMzg5qjk8BiwU9tK+HmU3iDVyS4UAJLFycI= github.com/go-redis/redis/v8 v8.1.1/go.mod h1:ysgGY09J/QeDYbu3HikWEIPCwaeOkuNoTgKayTEaEOw= github.com/go-redis/redis/v8 v8.6.0 h1:swqbqOrxaPztsj2Hf1p94M3YAgl7hYEpcw21z299hh8= github.com/go-redis/redis/v8 v8.6.0/go.mod h1:DQ9q4Rk2HtwkrwVrdgmphoOQDMfpvcd/nHEwRsicg8s= @@ -260,6 +272,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn github.com/streadway/amqp v1.0.0 h1:kuuDrUJFZL1QYL9hUNuCxNObNzB0bV/ZG5jV3RWAQgo= github.com/streadway/amqp v1.0.0/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/integration-tests/suite_test.go b/integration-tests/suite_test.go index 0e99d1cf4..37d606475 100644 --- a/integration-tests/suite_test.go +++ b/integration-tests/suite_test.go @@ -3,9 +3,11 @@ package integration_test import ( "context" "errors" + "fmt" "log" "reflect" "sort" + "strings" "testing" "time" @@ -44,6 +46,7 @@ func testAll(server Server, t *testing.T) { testSendGroup(server, t, 0) // with unlimited concurrency testSendGroup(server, t, 2) // with limited concurrency (2 parallel tasks at the most) testSendChord(server, t) + testSendChordWithError(server, t) testSendChain(server, t) testReturnJustError(server, t) testReturnMultipleValues(server, t) @@ -212,6 +215,42 @@ func testSendChord(server Server, t *testing.T) { } } +func testSendChordWithError(server Server, t *testing.T) { + t1, t2, t3, t4, t5 := newAddTask(1, 1), newAddTask(2, 2), newErrorTask("chord error", true), newMultipleTask(), newHandleErrorTask("handle") + + group, err := tasks.NewGroup(t1, t2, t3) + if err != nil { + t.Fatal(err) + } + + chord, err := tasks.NewChordWithError(group, t4, t5) + if err != nil { + t.Fatal(err) + } + + chordAsyncResult, err := server.SendChord(chord, 10) + if err != nil { + t.Error(err) + } + + results, err := chordAsyncResult.Get(time.Duration(time.Millisecond * 5)) + if err != nil { + t.Error(err) + } + + if len(results) != 1 { + t.Errorf("Number of results returned = %d. Wanted %d", len(results), 1) + } + + if results[0].Interface().(string) != "handle=chord error" { + t.Errorf( + "result = %v(%v), want handle=chord error", + results[0].Type().String(), + results[0].Interface(), + ) + } +} + func testReturnJustError(server Server, t *testing.T) { // Fails, returns error as the only value task := newErrorTask("Test error", true) @@ -385,6 +424,9 @@ func registerTestTasks(server Server) { "delay_test": func() (int64, error) { return time.Now().UTC().UnixNano(), nil }, + "handle_error": func(msg string, errors []string) (string, error) { + return fmt.Sprintf("%s=%s", msg, strings.Join(errors, ",")), nil + }, } server.RegisterTasks(tasks) @@ -486,3 +528,15 @@ func newDelayTask(eta time.Time) *tasks.Signature { ETA: &eta, } } + +func newHandleErrorTask(msg string) *tasks.Signature { + return &tasks.Signature{ + Name: "handle_error", + Args: []tasks.Arg{ + { + Type: "string", + Value: msg, + }, + }, + } +} diff --git a/v1/backends/iface/interfaces.go b/v1/backends/iface/interfaces.go index 56dff517b..c16747dbe 100644 --- a/v1/backends/iface/interfaces.go +++ b/v1/backends/iface/interfaces.go @@ -5,6 +5,7 @@ import ( ) // Backend - a common interface for all result backends +//go:generate go run github.com/vektra/mockery/cmd/mockery -name Backend type Backend interface { // Group related functions InitGroup(groupUUID string, taskUUIDs []string) error diff --git a/v1/backends/iface/mocks/Backend.go b/v1/backends/iface/mocks/Backend.go new file mode 100644 index 000000000..0035c3144 --- /dev/null +++ b/v1/backends/iface/mocks/Backend.go @@ -0,0 +1,241 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import ( + tasks "github.com/RichardKnop/machinery/v1/tasks" + mock "github.com/stretchr/testify/mock" +) + +// Backend is an autogenerated mock type for the Backend type +type Backend struct { + mock.Mock +} + +// GetState provides a mock function with given fields: taskUUID +func (_m *Backend) GetState(taskUUID string) (*tasks.TaskState, error) { + ret := _m.Called(taskUUID) + + var r0 *tasks.TaskState + if rf, ok := ret.Get(0).(func(string) *tasks.TaskState); ok { + r0 = rf(taskUUID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*tasks.TaskState) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(taskUUID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GroupCompleted provides a mock function with given fields: groupUUID, groupTaskCount +func (_m *Backend) GroupCompleted(groupUUID string, groupTaskCount int) (bool, error) { + ret := _m.Called(groupUUID, groupTaskCount) + + var r0 bool + if rf, ok := ret.Get(0).(func(string, int) bool); ok { + r0 = rf(groupUUID, groupTaskCount) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, int) error); ok { + r1 = rf(groupUUID, groupTaskCount) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GroupTaskStates provides a mock function with given fields: groupUUID, groupTaskCount +func (_m *Backend) GroupTaskStates(groupUUID string, groupTaskCount int) ([]*tasks.TaskState, error) { + ret := _m.Called(groupUUID, groupTaskCount) + + var r0 []*tasks.TaskState + if rf, ok := ret.Get(0).(func(string, int) []*tasks.TaskState); ok { + r0 = rf(groupUUID, groupTaskCount) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*tasks.TaskState) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, int) error); ok { + r1 = rf(groupUUID, groupTaskCount) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// InitGroup provides a mock function with given fields: groupUUID, taskUUIDs +func (_m *Backend) InitGroup(groupUUID string, taskUUIDs []string) error { + ret := _m.Called(groupUUID, taskUUIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(string, []string) error); ok { + r0 = rf(groupUUID, taskUUIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// IsAMQP provides a mock function with given fields: +func (_m *Backend) IsAMQP() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// PurgeGroupMeta provides a mock function with given fields: groupUUID +func (_m *Backend) PurgeGroupMeta(groupUUID string) error { + ret := _m.Called(groupUUID) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(groupUUID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PurgeState provides a mock function with given fields: taskUUID +func (_m *Backend) PurgeState(taskUUID string) error { + ret := _m.Called(taskUUID) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(taskUUID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetStateFailure provides a mock function with given fields: signature, err +func (_m *Backend) SetStateFailure(signature *tasks.Signature, err string) error { + ret := _m.Called(signature, err) + + var r0 error + if rf, ok := ret.Get(0).(func(*tasks.Signature, string) error); ok { + r0 = rf(signature, err) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetStatePending provides a mock function with given fields: signature +func (_m *Backend) SetStatePending(signature *tasks.Signature) error { + ret := _m.Called(signature) + + var r0 error + if rf, ok := ret.Get(0).(func(*tasks.Signature) error); ok { + r0 = rf(signature) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetStateReceived provides a mock function with given fields: signature +func (_m *Backend) SetStateReceived(signature *tasks.Signature) error { + ret := _m.Called(signature) + + var r0 error + if rf, ok := ret.Get(0).(func(*tasks.Signature) error); ok { + r0 = rf(signature) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetStateRetry provides a mock function with given fields: signature +func (_m *Backend) SetStateRetry(signature *tasks.Signature) error { + ret := _m.Called(signature) + + var r0 error + if rf, ok := ret.Get(0).(func(*tasks.Signature) error); ok { + r0 = rf(signature) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetStateStarted provides a mock function with given fields: signature +func (_m *Backend) SetStateStarted(signature *tasks.Signature) error { + ret := _m.Called(signature) + + var r0 error + if rf, ok := ret.Get(0).(func(*tasks.Signature) error); ok { + r0 = rf(signature) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetStateSuccess provides a mock function with given fields: signature, results +func (_m *Backend) SetStateSuccess(signature *tasks.Signature, results []*tasks.TaskResult) error { + ret := _m.Called(signature, results) + + var r0 error + if rf, ok := ret.Get(0).(func(*tasks.Signature, []*tasks.TaskResult) error); ok { + r0 = rf(signature, results) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TriggerChord provides a mock function with given fields: groupUUID +func (_m *Backend) TriggerChord(groupUUID string) (bool, error) { + ret := _m.Called(groupUUID) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(groupUUID) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(groupUUID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/v1/backends/result/async_result.go b/v1/backends/result/async_result.go index 99891150d..013ca743c 100644 --- a/v1/backends/result/async_result.go +++ b/v1/backends/result/async_result.go @@ -27,6 +27,7 @@ type AsyncResult struct { type ChordAsyncResult struct { groupAsyncResults []*AsyncResult chordAsyncResult *AsyncResult + errorAsyncResult *AsyncResult backend iface.Backend } @@ -46,7 +47,7 @@ func NewAsyncResult(signature *tasks.Signature, backend iface.Backend) *AsyncRes } // NewChordAsyncResult creates ChordAsyncResult instance -func NewChordAsyncResult(groupTasks []*tasks.Signature, chordCallback *tasks.Signature, backend iface.Backend) *ChordAsyncResult { +func NewChordAsyncResult(groupTasks []*tasks.Signature, chordCallback *tasks.Signature, errorCallback *tasks.Signature, backend iface.Backend) *ChordAsyncResult { asyncResults := make([]*AsyncResult, len(groupTasks)) for i, task := range groupTasks { asyncResults[i] = NewAsyncResult(task, backend) @@ -54,6 +55,7 @@ func NewChordAsyncResult(groupTasks []*tasks.Signature, chordCallback *tasks.Sig return &ChordAsyncResult{ groupAsyncResults: asyncResults, chordAsyncResult: NewAsyncResult(chordCallback, backend), + errorAsyncResult: NewAsyncResult(errorCallback, backend), backend: backend, } } @@ -168,14 +170,18 @@ func (chordAsyncResult *ChordAsyncResult) Get(sleepDuration time.Duration) ([]re return nil, ErrBackendNotConfigured } - var err error + errorSeen := false for _, asyncResult := range chordAsyncResult.groupAsyncResults { - _, err = asyncResult.Get(sleepDuration) + _, err := asyncResult.Get(sleepDuration) if err != nil { - return nil, err + errorSeen = true } } + if errorSeen { + return chordAsyncResult.errorAsyncResult.Get(sleepDuration) + } + return chordAsyncResult.chordAsyncResult.Get(sleepDuration) } diff --git a/v1/iface/interfaces.go b/v1/iface/interfaces.go new file mode 100644 index 000000000..a9172774c --- /dev/null +++ b/v1/iface/interfaces.go @@ -0,0 +1,54 @@ +package iface + +import ( + "context" + + backendsiface "github.com/RichardKnop/machinery/v1/backends/iface" + "github.com/RichardKnop/machinery/v1/backends/result" + brokersiface "github.com/RichardKnop/machinery/v1/brokers/iface" + "github.com/RichardKnop/machinery/v1/config" + "github.com/RichardKnop/machinery/v1/tasks" +) + +// Server for sending and processing tasks +//go:generate go run github.com/vektra/mockery/cmd/mockery -name Server +type Server interface { + // NewWorker(consumerTag string, concurrency int) Worker + // NewCustomQueueWorker(consumerTag string, concurrency int, queue string) Worker + GetBroker() brokersiface.Broker + SetBroker(broker brokersiface.Broker) + GetBackend() backendsiface.Backend + SetBackend(backend backendsiface.Backend) + GetConfig() *config.Config + SetConfig(cnf *config.Config) + SetPreTaskHandler(handler func(*tasks.Signature)) + RegisterTasks(namedTaskFuncs map[string]interface{}) error + RegisterTask(name string, taskFunc interface{}) error + IsTaskRegistered(name string) bool + GetRegisteredTask(name string) (interface{}, error) + SendTaskWithContext(ctx context.Context, signature *tasks.Signature) (*result.AsyncResult, error) + SendTask(signature *tasks.Signature) (*result.AsyncResult, error) + SendChainWithContext(ctx context.Context, chain *tasks.Chain) (*result.ChainAsyncResult, error) + SendChain(chain *tasks.Chain) (*result.ChainAsyncResult, error) + SendGroupWithContext(ctx context.Context, group *tasks.Group, sendConcurrency int) ([]*result.AsyncResult, error) + SendGroup(group *tasks.Group, sendConcurrency int) ([]*result.AsyncResult, error) + SendChordWithContext(ctx context.Context, chord *tasks.Chord, sendConcurrency int) (*result.ChordAsyncResult, error) + SendChord(chord *tasks.Chord, sendConcurrency int) (*result.ChordAsyncResult, error) + GetRegisteredTaskNames() []string +} + +// Worker represents a single worker process +//go:generate go run github.com/vektra/mockery/cmd/mockery -name Worker +type Worker interface { + Launch() error + LaunchAsync(errorsChan chan<- error) + CustomQueue() string + Quit() + Process(signature *tasks.Signature) error + SetErrorHandler(handler func(err error)) + SetPreTaskHandler(handler func(*tasks.Signature)) + SetPostTaskHandler(handler func(*tasks.Signature)) + SetPreConsumeHandler(handler func(Worker) bool) + GetServer() *Server + PreConsumeHandler() bool +} diff --git a/v1/iface/mocks/Server.go b/v1/iface/mocks/Server.go new file mode 100644 index 000000000..8efcde00f --- /dev/null +++ b/v1/iface/mocks/Server.go @@ -0,0 +1,356 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import ( + brokersiface "github.com/RichardKnop/machinery/v1/brokers/iface" + config "github.com/RichardKnop/machinery/v1/config" + + context "context" + + iface "github.com/RichardKnop/machinery/v1/backends/iface" + + mock "github.com/stretchr/testify/mock" + + result "github.com/RichardKnop/machinery/v1/backends/result" + + tasks "github.com/RichardKnop/machinery/v1/tasks" +) + +// Server is an autogenerated mock type for the Server type +type Server struct { + mock.Mock +} + +// GetBackend provides a mock function with given fields: +func (_m *Server) GetBackend() iface.Backend { + ret := _m.Called() + + var r0 iface.Backend + if rf, ok := ret.Get(0).(func() iface.Backend); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(iface.Backend) + } + } + + return r0 +} + +// GetBroker provides a mock function with given fields: +func (_m *Server) GetBroker() brokersiface.Broker { + ret := _m.Called() + + var r0 brokersiface.Broker + if rf, ok := ret.Get(0).(func() brokersiface.Broker); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(brokersiface.Broker) + } + } + + return r0 +} + +// GetConfig provides a mock function with given fields: +func (_m *Server) GetConfig() *config.Config { + ret := _m.Called() + + var r0 *config.Config + if rf, ok := ret.Get(0).(func() *config.Config); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*config.Config) + } + } + + return r0 +} + +// GetRegisteredTask provides a mock function with given fields: name +func (_m *Server) GetRegisteredTask(name string) (interface{}, error) { + ret := _m.Called(name) + + var r0 interface{} + if rf, ok := ret.Get(0).(func(string) interface{}); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetRegisteredTaskNames provides a mock function with given fields: +func (_m *Server) GetRegisteredTaskNames() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// IsTaskRegistered provides a mock function with given fields: name +func (_m *Server) IsTaskRegistered(name string) bool { + ret := _m.Called(name) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(name) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// RegisterTask provides a mock function with given fields: name, taskFunc +func (_m *Server) RegisterTask(name string, taskFunc interface{}) error { + ret := _m.Called(name, taskFunc) + + var r0 error + if rf, ok := ret.Get(0).(func(string, interface{}) error); ok { + r0 = rf(name, taskFunc) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RegisterTasks provides a mock function with given fields: namedTaskFuncs +func (_m *Server) RegisterTasks(namedTaskFuncs map[string]interface{}) error { + ret := _m.Called(namedTaskFuncs) + + var r0 error + if rf, ok := ret.Get(0).(func(map[string]interface{}) error); ok { + r0 = rf(namedTaskFuncs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SendChain provides a mock function with given fields: chain +func (_m *Server) SendChain(chain *tasks.Chain) (*result.ChainAsyncResult, error) { + ret := _m.Called(chain) + + var r0 *result.ChainAsyncResult + if rf, ok := ret.Get(0).(func(*tasks.Chain) *result.ChainAsyncResult); ok { + r0 = rf(chain) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*result.ChainAsyncResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*tasks.Chain) error); ok { + r1 = rf(chain) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SendChainWithContext provides a mock function with given fields: ctx, chain +func (_m *Server) SendChainWithContext(ctx context.Context, chain *tasks.Chain) (*result.ChainAsyncResult, error) { + ret := _m.Called(ctx, chain) + + var r0 *result.ChainAsyncResult + if rf, ok := ret.Get(0).(func(context.Context, *tasks.Chain) *result.ChainAsyncResult); ok { + r0 = rf(ctx, chain) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*result.ChainAsyncResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *tasks.Chain) error); ok { + r1 = rf(ctx, chain) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SendChord provides a mock function with given fields: chord, sendConcurrency +func (_m *Server) SendChord(chord *tasks.Chord, sendConcurrency int) (*result.ChordAsyncResult, error) { + ret := _m.Called(chord, sendConcurrency) + + var r0 *result.ChordAsyncResult + if rf, ok := ret.Get(0).(func(*tasks.Chord, int) *result.ChordAsyncResult); ok { + r0 = rf(chord, sendConcurrency) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*result.ChordAsyncResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*tasks.Chord, int) error); ok { + r1 = rf(chord, sendConcurrency) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SendChordWithContext provides a mock function with given fields: ctx, chord, sendConcurrency +func (_m *Server) SendChordWithContext(ctx context.Context, chord *tasks.Chord, sendConcurrency int) (*result.ChordAsyncResult, error) { + ret := _m.Called(ctx, chord, sendConcurrency) + + var r0 *result.ChordAsyncResult + if rf, ok := ret.Get(0).(func(context.Context, *tasks.Chord, int) *result.ChordAsyncResult); ok { + r0 = rf(ctx, chord, sendConcurrency) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*result.ChordAsyncResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *tasks.Chord, int) error); ok { + r1 = rf(ctx, chord, sendConcurrency) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SendGroup provides a mock function with given fields: group, sendConcurrency +func (_m *Server) SendGroup(group *tasks.Group, sendConcurrency int) ([]*result.AsyncResult, error) { + ret := _m.Called(group, sendConcurrency) + + var r0 []*result.AsyncResult + if rf, ok := ret.Get(0).(func(*tasks.Group, int) []*result.AsyncResult); ok { + r0 = rf(group, sendConcurrency) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*result.AsyncResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*tasks.Group, int) error); ok { + r1 = rf(group, sendConcurrency) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SendGroupWithContext provides a mock function with given fields: ctx, group, sendConcurrency +func (_m *Server) SendGroupWithContext(ctx context.Context, group *tasks.Group, sendConcurrency int) ([]*result.AsyncResult, error) { + ret := _m.Called(ctx, group, sendConcurrency) + + var r0 []*result.AsyncResult + if rf, ok := ret.Get(0).(func(context.Context, *tasks.Group, int) []*result.AsyncResult); ok { + r0 = rf(ctx, group, sendConcurrency) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*result.AsyncResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *tasks.Group, int) error); ok { + r1 = rf(ctx, group, sendConcurrency) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SendTask provides a mock function with given fields: signature +func (_m *Server) SendTask(signature *tasks.Signature) (*result.AsyncResult, error) { + ret := _m.Called(signature) + + var r0 *result.AsyncResult + if rf, ok := ret.Get(0).(func(*tasks.Signature) *result.AsyncResult); ok { + r0 = rf(signature) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*result.AsyncResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*tasks.Signature) error); ok { + r1 = rf(signature) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SendTaskWithContext provides a mock function with given fields: ctx, signature +func (_m *Server) SendTaskWithContext(ctx context.Context, signature *tasks.Signature) (*result.AsyncResult, error) { + ret := _m.Called(ctx, signature) + + var r0 *result.AsyncResult + if rf, ok := ret.Get(0).(func(context.Context, *tasks.Signature) *result.AsyncResult); ok { + r0 = rf(ctx, signature) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*result.AsyncResult) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *tasks.Signature) error); ok { + r1 = rf(ctx, signature) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SetBackend provides a mock function with given fields: backend +func (_m *Server) SetBackend(backend iface.Backend) { + _m.Called(backend) +} + +// SetBroker provides a mock function with given fields: broker +func (_m *Server) SetBroker(broker brokersiface.Broker) { + _m.Called(broker) +} + +// SetConfig provides a mock function with given fields: cnf +func (_m *Server) SetConfig(cnf *config.Config) { + _m.Called(cnf) +} + +// SetPreTaskHandler provides a mock function with given fields: handler +func (_m *Server) SetPreTaskHandler(handler func(*tasks.Signature)) { + _m.Called(handler) +} diff --git a/v1/iface/mocks/Worker.go b/v1/iface/mocks/Worker.go new file mode 100644 index 000000000..fa60393f7 --- /dev/null +++ b/v1/iface/mocks/Worker.go @@ -0,0 +1,117 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import ( + iface "github.com/RichardKnop/machinery/v1/iface" + mock "github.com/stretchr/testify/mock" + + tasks "github.com/RichardKnop/machinery/v1/tasks" +) + +// Worker is an autogenerated mock type for the Worker type +type Worker struct { + mock.Mock +} + +// CustomQueue provides a mock function with given fields: +func (_m *Worker) CustomQueue() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetServer provides a mock function with given fields: +func (_m *Worker) GetServer() *iface.Server { + ret := _m.Called() + + var r0 *iface.Server + if rf, ok := ret.Get(0).(func() *iface.Server); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*iface.Server) + } + } + + return r0 +} + +// Launch provides a mock function with given fields: +func (_m *Worker) Launch() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// LaunchAsync provides a mock function with given fields: errorsChan +func (_m *Worker) LaunchAsync(errorsChan chan<- error) { + _m.Called(errorsChan) +} + +// PreConsumeHandler provides a mock function with given fields: +func (_m *Worker) PreConsumeHandler() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Process provides a mock function with given fields: signature +func (_m *Worker) Process(signature *tasks.Signature) error { + ret := _m.Called(signature) + + var r0 error + if rf, ok := ret.Get(0).(func(*tasks.Signature) error); ok { + r0 = rf(signature) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Quit provides a mock function with given fields: +func (_m *Worker) Quit() { + _m.Called() +} + +// SetErrorHandler provides a mock function with given fields: handler +func (_m *Worker) SetErrorHandler(handler func(error)) { + _m.Called(handler) +} + +// SetPostTaskHandler provides a mock function with given fields: handler +func (_m *Worker) SetPostTaskHandler(handler func(*tasks.Signature)) { + _m.Called(handler) +} + +// SetPreConsumeHandler provides a mock function with given fields: handler +func (_m *Worker) SetPreConsumeHandler(handler func(iface.Worker) bool) { + _m.Called(handler) +} + +// SetPreTaskHandler provides a mock function with given fields: handler +func (_m *Worker) SetPreTaskHandler(handler func(*tasks.Signature)) { + _m.Called(handler) +} diff --git a/v1/server.go b/v1/server.go index e7a1f84a2..f713d883c 100644 --- a/v1/server.go +++ b/v1/server.go @@ -333,6 +333,7 @@ func (server *Server) SendChordWithContext(ctx context.Context, chord *tasks.Cho return result.NewChordAsyncResult( chord.Group.Tasks, chord.Callback, + chord.ErrorCallback, server.backend, ), nil } diff --git a/v1/tasks/signature.go b/v1/tasks/signature.go index 7a90c4aad..22fdc9901 100644 --- a/v1/tasks/signature.go +++ b/v1/tasks/signature.go @@ -58,7 +58,9 @@ type Signature struct { RetryTimeout int OnSuccess []*Signature OnError []*Signature - ChordCallback *Signature + // Chord members + ChordCallback *Signature + ChordErrorCallback *Signature //MessageGroupId for Broker, e.g. SQS BrokerMessageGroupId string //ReceiptHandle of SQS Message diff --git a/v1/tasks/workflow.go b/v1/tasks/workflow.go index 38a786461..dc9b156a9 100644 --- a/v1/tasks/workflow.go +++ b/v1/tasks/workflow.go @@ -20,8 +20,9 @@ type Group struct { // Chord adds an optional callback to the group to be executed // after all tasks in the group finished type Chord struct { - Group *Group - Callback *Signature + Group *Group + Callback *Signature + ErrorCallback *Signature } // GetUUIDs returns slice of task UUIDS @@ -80,16 +81,29 @@ func NewGroup(signatures ...*Signature) (*Group, error) { // NewChord creates a new chord (a group of tasks with a single callback // to be executed after all tasks in the group has completed) func NewChord(group *Group, callback *Signature) (*Chord, error) { - if callback.UUID == "" { + return NewChordWithError(group, callback, nil) +} + +// NewChordWithError creates a new chord (a group of tasks with a single callback +// to be executed after all tasks in the group has completed) +func NewChordWithError(group *Group, callback *Signature, errorCallback *Signature) (*Chord, error) { + if callback != nil && callback.UUID == "" { // Generate a UUID for the chord callback callbackUUID := uuid.New().String() callback.UUID = fmt.Sprintf("chord_%v", callbackUUID) } + if errorCallback != nil && errorCallback.UUID == "" { + // Generate a UUID for the chord error callback + errorCallbackUUID := uuid.New().String() + errorCallback.UUID = fmt.Sprintf("chord_%v", errorCallbackUUID) + } + // Add a chord callback to all tasks for _, signature := range group.Tasks { signature.ChordCallback = callback + signature.ChordErrorCallback = errorCallback } - return &Chord{Group: group, Callback: callback}, nil + return &Chord{Group: group, Callback: callback, ErrorCallback: errorCallback}, nil } diff --git a/v1/tasks/workflow_test.go b/v1/tasks/workflow_test.go index 5b7eacfe7..361be9ad1 100644 --- a/v1/tasks/workflow_test.go +++ b/v1/tasks/workflow_test.go @@ -4,13 +4,23 @@ import ( "testing" "github.com/RichardKnop/machinery/v1/tasks" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestNewChain(t *testing.T) { - t.Parallel() +type workflowSuite struct { + suite.Suite + task1 *tasks.Signature + task2 *tasks.Signature + task3 *tasks.Signature + task4 *tasks.Signature +} + +func TestWorkflowSuite(t *testing.T) { + suite.Run(t, new(workflowSuite)) +} - task1 := tasks.Signature{ +func (s *workflowSuite) SetupTest() { + s.task1 = &tasks.Signature{ Name: "foo", Args: []tasks.Arg{ { @@ -24,7 +34,7 @@ func TestNewChain(t *testing.T) { }, } - task2 := tasks.Signature{ + s.task2 = &tasks.Signature{ Name: "bar", Args: []tasks.Arg{ { @@ -38,7 +48,7 @@ func TestNewChain(t *testing.T) { }, } - task3 := tasks.Signature{ + s.task3 = &tasks.Signature{ Name: "qux", Args: []tasks.Arg{ { @@ -48,14 +58,65 @@ func TestNewChain(t *testing.T) { }, } - chain, err := tasks.NewChain(&task1, &task2, &task3) - if err != nil { - t.Fatal(err) + s.task4 = &tasks.Signature{ + Name: "box", + Args: []tasks.Arg{ + { + Type: "float64", + Value: interface{}(7), + }, + }, } +} + +func (s *workflowSuite) TestNewChain() { + chain, err := tasks.NewChain(s.task1, s.task2, s.task3) + s.Nil(err) firstTask := chain.Tasks[0] - assert.Equal(t, "foo", firstTask.Name) - assert.Equal(t, "bar", firstTask.OnSuccess[0].Name) - assert.Equal(t, "qux", firstTask.OnSuccess[0].OnSuccess[0].Name) + s.Equal("foo", firstTask.Name) + s.Equal("bar", firstTask.OnSuccess[0].Name) + s.Equal("qux", firstTask.OnSuccess[0].OnSuccess[0].Name) +} + +func (s *workflowSuite) TestNewGroup() { + group, err := tasks.NewGroup(s.task1, s.task2, s.task3) + s.Nil(err) + s.Equal(len(group.Tasks), 3) + for _, task := range group.Tasks { + s.Equal(task.GroupUUID, group.GroupUUID) + s.Equal(task.GroupTaskCount, len(group.Tasks)) + } +} + +func (s *workflowSuite) TestNewChord() { + group, err := tasks.NewGroup(s.task1, s.task2) + s.Nil(err) + + chord, err := tasks.NewChord(group, s.task3) + s.Nil(err) + s.Equal(chord.Callback, s.task3) + s.Nil(chord.ErrorCallback) + + for _, task := range group.Tasks { + s.Equal(task.ChordCallback, s.task3) + s.Nil(task.ChordErrorCallback) + } + +} + +func (s *workflowSuite) TestNewChordWithError() { + group, err := tasks.NewGroup(s.task1, s.task2) + s.Nil(err) + + errChord, err := tasks.NewChordWithError(group, s.task3, s.task4) + s.Nil(err) + s.Equal(errChord.Callback, s.task3) + s.Equal(errChord.ErrorCallback, s.task4) + + for _, task := range group.Tasks { + s.Equal(task.ChordCallback, s.task3) + s.Equal(task.ChordErrorCallback, s.task4) + } } diff --git a/v1/worker.go b/v1/worker.go index ef6b10bd8..01350a5ef 100644 --- a/v1/worker.go +++ b/v1/worker.go @@ -14,6 +14,7 @@ import ( "github.com/RichardKnop/machinery/v1/backends/amqp" "github.com/RichardKnop/machinery/v1/brokers/errs" + "github.com/RichardKnop/machinery/v1/iface" "github.com/RichardKnop/machinery/v1/log" "github.com/RichardKnop/machinery/v1/retry" "github.com/RichardKnop/machinery/v1/tasks" @@ -281,43 +282,46 @@ func (worker *Worker) taskSucceeded(signature *tasks.Signature, taskResults []*t return nil } - // There is no chord callback, just return - if signature.ChordCallback == nil { + return worker.processChords(signature) +} + +func (worker *Worker) processChords(signature *tasks.Signature) error { + return processChords(worker.GetServer(), signature, worker.hasAMQPBackend()) +} + +func processChords(server iface.Server, signature *tasks.Signature, hasAMQPBackend bool) error { + // this is not group chord execution. short circuit out early + if signature == nil || + (signature.ChordCallback == nil && signature.ChordErrorCallback == nil) { return nil } // Check if all task in the group has completed - groupCompleted, err := worker.server.GetBackend().GroupCompleted( + if groupCompleted, err := server.GetBackend().GroupCompleted( signature.GroupUUID, signature.GroupTaskCount, - ) - if err != nil { + ); err != nil { return fmt.Errorf("Completed check for group %s returned error: %s", signature.GroupUUID, err) - } - - // If the group has not yet completed, just return - if !groupCompleted { + } else if !groupCompleted { + // If the group has not yet completed, just return return nil } // Defer purging of group meta queue if we are using AMQP backend - if worker.hasAMQPBackend() { - defer worker.server.GetBackend().PurgeGroupMeta(signature.GroupUUID) + if hasAMQPBackend { + defer server.GetBackend().PurgeGroupMeta(signature.GroupUUID) } // Trigger chord callback - shouldTrigger, err := worker.server.GetBackend().TriggerChord(signature.GroupUUID) - if err != nil { + if shouldTrigger, err := server.GetBackend().TriggerChord(signature.GroupUUID); err != nil { return fmt.Errorf("Triggering chord for group %s returned error: %s", signature.GroupUUID, err) - } - - // Chord has already been triggered - if !shouldTrigger { + } else if !shouldTrigger { + // Chord has already been triggered return nil } // Get task states - taskStates, err := worker.server.GetBackend().GroupTaskStates( + taskStates, err := server.GetBackend().GroupTaskStates( signature.GroupUUID, signature.GroupTaskCount, ) @@ -331,26 +335,40 @@ func (worker *Worker) taskSucceeded(signature *tasks.Signature, taskResults []*t return nil } - // Append group tasks' return values to chord task if it's not immutable + errors := []string{} + successArgs := []tasks.Arg{} + for _, taskState := range taskStates { if !taskState.IsSuccess() { - return nil + errors = append(errors, taskState.Error) + continue + } + + for _, taskResult := range taskState.Results { + successArgs = append(successArgs, tasks.Arg{ + Type: taskResult.Type, + Value: taskResult.Value, + }) } + } + + // No errors in the group => pass results of the task to the chord callback + if len(errors) == 0 && signature.ChordCallback != nil { if signature.ChordCallback.Immutable == false { - // Pass results of the task to the chord callback - for _, taskResult := range taskState.Results { - signature.ChordCallback.Args = append(signature.ChordCallback.Args, tasks.Arg{ - Type: taskResult.Type, - Value: taskResult.Value, - }) - } + signature.ChordCallback.Args = append(signature.ChordCallback.Args, successArgs...) } + + _, err = server.SendTask(signature.ChordCallback) + return err } - // Send the chord task - _, err = worker.server.SendTask(signature.ChordCallback) - if err != nil { + if len(errors) != 0 && signature.ChordErrorCallback != nil { + signature.ChordErrorCallback.Args = append(signature.ChordErrorCallback.Args, + tasks.Arg{Type: "[]string", Value: errors}) + + // Send the chord task + _, err = server.SendTask(signature.ChordErrorCallback) return err } @@ -370,6 +388,11 @@ func (worker *Worker) taskFailed(signature *tasks.Signature, taskErr error) erro log.ERROR.Printf("Failed processing task %s. Error = %v", signature.UUID, taskErr) } + err := worker.processChords(signature) + if err != nil { + return fmt.Errorf("Error while processing ErrorChord=%v", err) + } + // Trigger error callbacks for _, errorTask := range signature.OnError { // Pass error as a first argument to error callbacks @@ -416,7 +439,7 @@ func (worker *Worker) SetPreConsumeHandler(handler func(*Worker) bool) { } //GetServer returns server -func (worker *Worker) GetServer() *Server { +func (worker *Worker) GetServer() iface.Server { return worker.server } diff --git a/v1/worker_test.go b/v1/worker_test.go index 1a02084ba..0e57ff1ef 100644 --- a/v1/worker_test.go +++ b/v1/worker_test.go @@ -1,11 +1,14 @@ -package machinery_test +package machinery import ( "testing" + backendmocks "github.com/RichardKnop/machinery/v1/backends/iface/mocks" + "github.com/RichardKnop/machinery/v1/iface/mocks" + "github.com/RichardKnop/machinery/v1/tasks" + "github.com/RichardKnop/machinery/v2" "github.com/stretchr/testify/assert" - - "github.com/RichardKnop/machinery/v1" + "github.com/stretchr/testify/suite" ) func TestRedactURL(t *testing.T) { @@ -18,7 +21,7 @@ func TestRedactURL(t *testing.T) { func TestPreConsumeHandler(t *testing.T) { t.Parallel() - + worker := &machinery.Worker{} worker.SetPreConsumeHandler(SamplePreConsumeHandler) @@ -28,3 +31,180 @@ func TestPreConsumeHandler(t *testing.T) { func SamplePreConsumeHandler(w *machinery.Worker) bool { return true } + +const ( + groupUUID = "APPLE" + groupTaskCount = 17 +) + +type processChordSuite struct { + suite.Suite + serverMock *mocks.Server + backendMock *backendmocks.Backend + + chordSuccess *tasks.Signature + chordError *tasks.Signature +} + +func TestProcessChordSuite(t *testing.T) { + suite.Run(t, new(processChordSuite)) +} + +func (s *processChordSuite) SetupTest() { + s.backendMock = &backendmocks.Backend{} + s.serverMock = &mocks.Server{} + s.chordSuccess = &tasks.Signature{Name: "SuccessChord"} + s.chordError = &tasks.Signature{Name: "ErrorChord"} +} + +func (s *processChordSuite) AfterTest(suiteName, testName string) { + s.backendMock.AssertExpectations(s.T()) + s.serverMock.AssertExpectations(s.T()) +} + +func (s *processChordSuite) TestNoChords() { + s.Nil(processChords(s.serverMock, nil, false)) + + sigWithoutChord := &tasks.Signature{Name: "BANANA"} + s.Nil(processChords(s.serverMock, sigWithoutChord, false)) +} + +func (s *processChordSuite) TestErrorChord() { + // this task will fail out + taskStates := []*tasks.TaskState{ + { + State: tasks.StateFailure, + Error: "PEACH", + }, + } + + s.backendMock.On("GroupCompleted", groupUUID, groupTaskCount).Return(true, nil) + s.backendMock.On("TriggerChord", groupUUID).Return(true, nil) + s.backendMock.On("GroupTaskStates", groupUUID, groupTaskCount).Return(taskStates, nil) + + s.serverMock.On("GetBackend").Return(s.backendMock) + s.serverMock.On("SendTask", s.chordError).Return(nil, nil) + + sig := &tasks.Signature{ + Name: "BANANA", + GroupUUID: groupUUID, + GroupTaskCount: groupTaskCount, + ChordCallback: s.chordSuccess, + ChordErrorCallback: s.chordError, + } + + s.Nil(processChords(s.serverMock, sig, false)) +} + +func (s *processChordSuite) TestSuccessChord() { + // this task will fail out + taskStates := []*tasks.TaskState{ + { + State: tasks.StateSuccess, + Results: []*tasks.TaskResult{ + {Type: "string", Value: "GRAPE"}, + }, + }, + } + + s.backendMock.On("GroupCompleted", groupUUID, groupTaskCount).Return(true, nil) + s.backendMock.On("TriggerChord", groupUUID).Return(true, nil) + s.backendMock.On("GroupTaskStates", groupUUID, groupTaskCount).Return(taskStates, nil) + + s.serverMock.On("GetBackend").Return(s.backendMock) + s.serverMock.On("SendTask", s.chordSuccess).Return(nil, nil) + + sig := &tasks.Signature{ + Name: "BANANA", + GroupUUID: groupUUID, + GroupTaskCount: groupTaskCount, + ChordCallback: s.chordSuccess, + ChordErrorCallback: s.chordError, + } + + s.Nil(processChords(s.serverMock, sig, false)) +} + +func (s *processChordSuite) TestSuccessWithoutChord() { + // this task will fail out + taskStates := []*tasks.TaskState{ + { + State: tasks.StateSuccess, + Results: []*tasks.TaskResult{ + {Type: "string", Value: "GRAPE"}, + }, + }, + } + + s.backendMock.On("GroupCompleted", groupUUID, groupTaskCount).Return(true, nil) + s.backendMock.On("TriggerChord", groupUUID).Return(true, nil) + s.backendMock.On("GroupTaskStates", groupUUID, groupTaskCount).Return(taskStates, nil) + + s.serverMock.On("GetBackend").Return(s.backendMock) + // no SendTask should get called on the server as there is no Chord + + sig := &tasks.Signature{ + Name: "BANANA", + GroupUUID: groupUUID, + GroupTaskCount: groupTaskCount, + ChordErrorCallback: s.chordError, + } + + s.Nil(processChords(s.serverMock, sig, false)) +} + +func (s *processChordSuite) TestErrorWithoutChord() { + // this task will fail out + taskStates := []*tasks.TaskState{ + { + State: tasks.StateFailure, + Error: "PEACH", + }, + } + + s.backendMock.On("GroupCompleted", groupUUID, groupTaskCount).Return(true, nil) + s.backendMock.On("TriggerChord", groupUUID).Return(true, nil) + s.backendMock.On("GroupTaskStates", groupUUID, groupTaskCount).Return(taskStates, nil) + + s.serverMock.On("GetBackend").Return(s.backendMock) + // No send task + + sig := &tasks.Signature{ + Name: "BANANA", + GroupUUID: groupUUID, + GroupTaskCount: groupTaskCount, + ChordCallback: s.chordSuccess, + } + + s.Nil(processChords(s.serverMock, sig, false)) +} + +func (s *processChordSuite) TestSuccessChordAMQPBackend() { + // this task will fail out + taskStates := []*tasks.TaskState{ + { + State: tasks.StateSuccess, + Results: []*tasks.TaskResult{ + {Type: "string", Value: "GRAPE"}, + }, + }, + } + + s.backendMock.On("GroupCompleted", groupUUID, groupTaskCount).Return(true, nil) + s.backendMock.On("TriggerChord", groupUUID).Return(true, nil) + s.backendMock.On("GroupTaskStates", groupUUID, groupTaskCount).Return(taskStates, nil) + s.backendMock.On("PurgeGroupMeta", groupUUID).Return(nil) + + s.serverMock.On("GetBackend").Return(s.backendMock) + s.serverMock.On("SendTask", s.chordSuccess).Return(nil, nil) + + sig := &tasks.Signature{ + Name: "BANANA", + GroupUUID: groupUUID, + GroupTaskCount: groupTaskCount, + ChordCallback: s.chordSuccess, + ChordErrorCallback: s.chordError, + } + + s.Nil(processChords(s.serverMock, sig, true)) +} diff --git a/v2/server.go b/v2/server.go index db3cb58f9..f30892573 100644 --- a/v2/server.go +++ b/v2/server.go @@ -302,6 +302,7 @@ func (server *Server) SendChordWithContext(ctx context.Context, chord *tasks.Cho return result.NewChordAsyncResult( chord.Group.Tasks, chord.Callback, + chord.ErrorCallback, server.backend, ), nil }