diff --git a/components/payments/cmd/connectors/internal/task/errors.go b/components/payments/cmd/connectors/internal/task/errors.go new file mode 100644 index 0000000000..a471a841ff --- /dev/null +++ b/components/payments/cmd/connectors/internal/task/errors.go @@ -0,0 +1,13 @@ +package task + +import "github.com/pkg/errors" + +var ( + // ErrRetryableError will be sent by the task if we can retry the task, + // e.g. if the task failed because of a temporary network issue. + ErrRetryableError = errors.New("retryable error") + + // ErrNonRetryableError will be sent by the task if we can't retry the task, + // e.g. if the task failed because of a validation error. + ErrNonRetryableError = errors.New("non-retryable error") +) diff --git a/components/payments/cmd/connectors/internal/task/scheduler.go b/components/payments/cmd/connectors/internal/task/scheduler.go index 69b384b544..4769b5d211 100644 --- a/components/payments/cmd/connectors/internal/task/scheduler.go +++ b/components/payments/cmd/connectors/internal/task/scheduler.go @@ -360,6 +360,9 @@ func (s *DefaultTaskScheduler) startTask(ctx context.Context, descriptor models. err = container.Provide(func() models.TaskID { return models.TaskID(task.ID) }) + if err != nil { + panic(err) + } err = container.Provide(func() StopChan { s.mu.Lock() @@ -431,159 +434,233 @@ func (s *DefaultTaskScheduler) startTask(ctx context.Context, descriptor models. } fallthrough case models.OPTIONS_RUN_IN_DURATION: - go func() { - if options.Duration > 0 { - logger.Infof("Waiting %s before starting task...", options.Duration) - // todo(gfyrag): need to listen on stopChan if the application is stopped - time.Sleep(options.Duration) - } + go s.runTaskOnce( + ctx, + logger, + holder, + descriptor, + options, + taskResolver, + container, + sendError, + errChan, + 1, + ) + case models.OPTIONS_RUN_PERIODICALLY: + go s.runTaskPeriodically( + ctx, + logger, + holder, + descriptor, + options, + taskResolver, + container, + ) + } - logger.Infof("Starting task...") + if !sendError { + close(errChan) + } - defer func() { - defer s.deleteTask(ctx, holder) + return errChan +} - if sendError { - defer close(errChan) - } +func (s *DefaultTaskScheduler) runTaskOnce( + ctx context.Context, + logger logging.Logger, + holder *taskHolder, + descriptor models.TaskDescriptor, + options models.TaskSchedulerOptions, + taskResolver Task, + container *dig.Container, + sendError bool, + errChan chan error, + attempt int, +) { + // If attempt is > 1, it means that the task is being retried, so no need + // to wait again + if options.Duration > 0 && attempt == 1 { + logger.Infof("Waiting %s before starting task...", options.Duration) + select { + case <-ctx.Done(): + return + case ch := <-holder.stopChan: + logger.Infof("Stopping task...") + close(ch) + return + case <-time.After(options.Duration): + } + } - if e := recover(); e != nil { - switch v := e.(type) { - case error: - if errors.Is(v, pond.ErrSubmitOnStoppedPool) { - // Pool is stopped and task is marked as active, - // nothing to do as they will be restarted on - // next startup - return - } - } - - s.registerTaskError(ctx, holder, e) - debug.PrintStack() - - if sendError { - switch v := e.(type) { - case error: - errChan <- v - default: - errChan <- fmt.Errorf("%s", v) - } - } - } - }() + logger.Infof("Starting task...") - done := make(chan struct{}) - s.workerPool.Submit(func() { - defer close(done) - err = container.Invoke(taskResolver) - }) - select { - case <-done: - case <-ctx.Done(): - return - } - if err != nil { - s.registerTaskError(ctx, holder, err) + defer func() { + defer s.deleteTask(ctx, holder) - if sendError { - errChan <- err + if sendError { + defer close(errChan) + } + + if e := recover(); e != nil { + switch v := e.(type) { + case error: + if errors.Is(v, pond.ErrSubmitOnStoppedPool) { + // Pool is stopped and task is marked as active, + // nothing to do as they will be restarted on + // next startup return } - - return } - logger.Infof("Task terminated with success") + s.registerTaskError(ctx, holder, e) + debug.PrintStack() - err = s.store.UpdateTaskStatus(ctx, s.connectorID, descriptor, models.TaskStatusTerminated, "") - if err != nil { - logger.Errorf("Error updating task status: %s", err) - if sendError { - errChan <- err + if sendError { + switch v := e.(type) { + case error: + errChan <- v + default: + errChan <- fmt.Errorf("%s", v) } } - }() - case models.OPTIONS_RUN_PERIODICALLY: - go func() { - defer func() { - defer s.deleteTask(ctx, holder) + } + }() - if e := recover(); e != nil { - s.registerTaskError(ctx, holder, e) - debug.PrintStack() + runF := func() error { + var err error - return - } - }() - - processFunc := func() (bool, error) { - done := make(chan struct{}) - s.workerPool.Submit(func() { - defer close(done) - err = container.Invoke(taskResolver) - }) - select { - case <-done: - case <-ctx.Done(): - return true, nil - case ch := <-holder.stopChan: - logger.Infof("Stopping task...") - close(ch) - return true, nil - } - if err != nil { - s.registerTaskError(ctx, holder, err) - return false, err - } + done := make(chan struct{}) + s.workerPool.Submit(func() { + defer close(done) + err = container.Invoke(taskResolver) + }) + select { + case <-done: + case <-ctx.Done(): + return ctx.Err() + } - return false, err - } + return err + } - // launch it once before starting the ticker - stopped, err := processFunc() - if err != nil { - // error is already registered +loop: + for { + err := runF() + switch { + case err == nil: + break loop + case errors.Is(err, ErrRetryableError): + continue + case errors.Is(err, ErrNonRetryableError): + fallthrough + default: + if err == context.Canceled { + // Context was canceled, which means the scheduler was stopped + // either by the application being stopped or by the connector + // being removed. In this case, we don't want to update the + // task status, as it will be restarted on next startup. return } - if stopped { - // Task is stopped or context is done - return - } + // All other errors + s.registerTaskError(ctx, holder, err) - logger.Infof("Starting task...") - ticker := time.NewTicker(options.Duration) - for { - select { - case ch := <-holder.stopChan: - logger.Infof("Stopping task...") - close(ch) - return - case <-ctx.Done(): - return - case <-ticker.C: - logger.Infof("Polling trigger, running task...") - stop, err := processFunc() - if err != nil { - // error is already registered - return - } - - if stop { - // Task is stopped or context is done - return - } - } + if sendError { + errChan <- err } - }() + return + } } - if !sendError { - close(errChan) + logger.Infof("Task terminated with success") + + err := s.store.UpdateTaskStatus(ctx, s.connectorID, descriptor, models.TaskStatusTerminated, "") + if err != nil { + logger.Errorf("Error updating task status: %s", err) + if sendError { + errChan <- err + } } +} - return errChan +func (s *DefaultTaskScheduler) runTaskPeriodically( + ctx context.Context, + logger logging.Logger, + holder *taskHolder, + descriptor models.TaskDescriptor, + options models.TaskSchedulerOptions, + taskResolver Task, + container *dig.Container, +) { + defer func() { + defer s.deleteTask(ctx, holder) + + if e := recover(); e != nil { + s.registerTaskError(ctx, holder, e) + debug.PrintStack() + + return + } + }() + + processFunc := func() (bool, error) { + var err error + done := make(chan struct{}) + s.workerPool.Submit(func() { + defer close(done) + err = container.Invoke(taskResolver) + }) + select { + case <-done: + case <-ctx.Done(): + return true, nil + case ch := <-holder.stopChan: + logger.Infof("Stopping task...") + close(ch) + return true, nil + } + if err != nil { + return false, err + } + + return false, err + } + + logger.Infof("Starting task...") + ticker := time.NewTicker(options.Duration) + for { + stopped, err := processFunc() + switch { + case err == nil: + // Doing nothing, waiting for the next tick + case errors.Is(err, ErrRetryableError): + ticker.Reset(options.Duration) + continue + case errors.Is(err, ErrNonRetryableError): + fallthrough + default: + // All other errors + s.registerTaskError(ctx, holder, err) + return + } + + if stopped { + // Task is stopped or context is done + return + } + + select { + case ch := <-holder.stopChan: + logger.Infof("Stopping task...") + close(ch) + return + case <-ctx.Done(): + return + case <-ticker.C: + logger.Infof("Polling trigger, running task...") + } + } } func (s *DefaultTaskScheduler) logger(ctx context.Context) logging.Logger { diff --git a/components/payments/cmd/connectors/internal/task/scheduler_test.go b/components/payments/cmd/connectors/internal/task/scheduler_test.go index 6cd8f07d4d..6aa4109857 100644 --- a/components/payments/cmd/connectors/internal/task/scheduler_test.go +++ b/components/payments/cmd/connectors/internal/task/scheduler_test.go @@ -2,13 +2,13 @@ package task import ( "context" - "errors" "testing" "time" "github.com/formancehq/payments/cmd/connectors/internal/metrics" "github.com/formancehq/payments/internal/models" "github.com/google/uuid" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "go.uber.org/dig" @@ -272,4 +272,148 @@ func TestTaskScheduler(t *testing.T) { require.Eventually(t, TaskActive(store, connectorID, mainDescriptor), time.Second, 100*time.Millisecond) require.Eventually(t, TaskActive(store, connectorID, workerDescriptor), time.Second, 100*time.Millisecond) }) + + t.Run("errors and retryable errors", func(t *testing.T) { + t.Parallel() + + connectorID := models.ConnectorID{ + Reference: uuid.New(), + Provider: models.ConnectorProviderDummyPay, + } + store := NewInMemoryStore() + nonRetryableDescriptor := newDescriptor() + retryableDescriptor := newDescriptor() + otherErrorDescriptor := newDescriptor() + noErrorDescriptor := newDescriptor() + + scheduler := NewDefaultScheduler(connectorID, store, DefaultContainerFactory, + ResolverFn(func(descriptor models.TaskDescriptor) Task { + switch string(descriptor) { + case string(nonRetryableDescriptor): + return func(ctx context.Context, scheduler Scheduler) error { + return ErrNonRetryableError + } + case string(retryableDescriptor): + return func(ctx context.Context, scheduler Scheduler) error { + return ErrRetryableError + } + case string(otherErrorDescriptor): + return func(ctx context.Context, scheduler Scheduler) error { + return errors.New("test") + } + case string(noErrorDescriptor): + return func(ctx context.Context, scheduler Scheduler) error { + return nil + } + default: + return func() { + + } + } + }), metrics.NewNoOpMetricsRegistry(), 1) + + require.NoError(t, scheduler.Schedule(context.TODO(), nonRetryableDescriptor, models.TaskSchedulerOptions{ + ScheduleOption: models.OPTIONS_RUN_NOW, + RestartOption: models.OPTIONS_RESTART_NEVER, + })) + require.Eventually(t, TaskFailed(store, connectorID, nonRetryableDescriptor, "non-retryable error"), time.Second, 100*time.Millisecond) + + require.NoError(t, scheduler.Schedule(context.TODO(), otherErrorDescriptor, models.TaskSchedulerOptions{ + ScheduleOption: models.OPTIONS_RUN_NOW, + RestartOption: models.OPTIONS_RESTART_NEVER, + })) + require.Eventually(t, TaskFailed(store, connectorID, otherErrorDescriptor, "test"), time.Second, 100*time.Millisecond) + + require.NoError(t, scheduler.Schedule(context.TODO(), noErrorDescriptor, models.TaskSchedulerOptions{ + ScheduleOption: models.OPTIONS_RUN_NOW, + RestartOption: models.OPTIONS_RESTART_NEVER, + })) + require.Eventually(t, TaskTerminated(store, connectorID, noErrorDescriptor), time.Second, 100*time.Millisecond) + + require.NoError(t, scheduler.Schedule(context.TODO(), retryableDescriptor, models.TaskSchedulerOptions{ + ScheduleOption: models.OPTIONS_RUN_NOW, + RestartOption: models.OPTIONS_RESTART_NEVER, + })) + require.Eventually(t, TaskActive(store, connectorID, retryableDescriptor), time.Second, 100*time.Millisecond) + require.NoError(t, scheduler.Shutdown(context.TODO())) + + require.Eventually(t, TaskFailed(store, connectorID, nonRetryableDescriptor, "non-retryable error"), time.Second, 100*time.Millisecond) + require.Eventually(t, TaskFailed(store, connectorID, otherErrorDescriptor, "test"), time.Second, 100*time.Millisecond) + require.Eventually(t, TaskTerminated(store, connectorID, noErrorDescriptor), time.Second, 100*time.Millisecond) + require.Eventually(t, TaskActive(store, connectorID, retryableDescriptor), time.Second, 100*time.Millisecond) + }) + + t.Run("errors and retryable errors", func(t *testing.T) { + t.Parallel() + + connectorID := models.ConnectorID{ + Reference: uuid.New(), + Provider: models.ConnectorProviderDummyPay, + } + store := NewInMemoryStore() + nonRetryableDescriptor := newDescriptor() + retryableDescriptor := newDescriptor() + otherErrorDescriptor := newDescriptor() + noErrorDescriptor := newDescriptor() + + scheduler := NewDefaultScheduler(connectorID, store, DefaultContainerFactory, + ResolverFn(func(descriptor models.TaskDescriptor) Task { + switch string(descriptor) { + case string(nonRetryableDescriptor): + return func(ctx context.Context, scheduler Scheduler) error { + return ErrNonRetryableError + } + case string(retryableDescriptor): + return func(ctx context.Context, scheduler Scheduler) error { + return ErrRetryableError + } + case string(otherErrorDescriptor): + return func(ctx context.Context, scheduler Scheduler) error { + return errors.New("test") + } + case string(noErrorDescriptor): + return func(ctx context.Context, scheduler Scheduler) error { + return nil + } + default: + return func() { + + } + } + }), metrics.NewNoOpMetricsRegistry(), 1) + + require.NoError(t, scheduler.Schedule(context.TODO(), nonRetryableDescriptor, models.TaskSchedulerOptions{ + ScheduleOption: models.OPTIONS_RUN_PERIODICALLY, + Duration: 1 * time.Second, + RestartOption: models.OPTIONS_RESTART_NEVER, + })) + require.Eventually(t, TaskFailed(store, connectorID, nonRetryableDescriptor, "non-retryable error"), time.Second, 100*time.Millisecond) + + require.NoError(t, scheduler.Schedule(context.TODO(), otherErrorDescriptor, models.TaskSchedulerOptions{ + ScheduleOption: models.OPTIONS_RUN_PERIODICALLY, + Duration: 1 * time.Second, + RestartOption: models.OPTIONS_RESTART_NEVER, + })) + require.Eventually(t, TaskFailed(store, connectorID, otherErrorDescriptor, "test"), time.Second, 100*time.Millisecond) + + require.NoError(t, scheduler.Schedule(context.TODO(), noErrorDescriptor, models.TaskSchedulerOptions{ + ScheduleOption: models.OPTIONS_RUN_PERIODICALLY, + Duration: 1 * time.Second, + RestartOption: models.OPTIONS_RESTART_NEVER, + })) + require.Eventually(t, TaskActive(store, connectorID, noErrorDescriptor), time.Second, 100*time.Millisecond) + + require.NoError(t, scheduler.Schedule(context.TODO(), retryableDescriptor, models.TaskSchedulerOptions{ + ScheduleOption: models.OPTIONS_RUN_PERIODICALLY, + Duration: 1 * time.Second, + RestartOption: models.OPTIONS_RESTART_NEVER, + })) + require.Eventually(t, TaskActive(store, connectorID, retryableDescriptor), time.Second, 100*time.Millisecond) + require.NoError(t, scheduler.Shutdown(context.TODO())) + + require.Eventually(t, TaskFailed(store, connectorID, nonRetryableDescriptor, "non-retryable error"), time.Second, 100*time.Millisecond) + require.Eventually(t, TaskFailed(store, connectorID, otherErrorDescriptor, "test"), time.Second, 100*time.Millisecond) + require.Eventually(t, TaskActive(store, connectorID, noErrorDescriptor), time.Second, 100*time.Millisecond) + require.Eventually(t, TaskActive(store, connectorID, retryableDescriptor), time.Second, 100*time.Millisecond) + }) } diff --git a/components/payments/cmd/connectors/internal/task/storememory.go b/components/payments/cmd/connectors/internal/task/storememory.go index 66a101a9c5..ecb05020e1 100644 --- a/components/payments/cmd/connectors/internal/task/storememory.go +++ b/components/payments/cmd/connectors/internal/task/storememory.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "strings" + "sync" "time" "github.com/formancehq/payments/cmd/connectors/internal/storage" @@ -15,6 +16,7 @@ import ( ) type InMemoryStore struct { + mu sync.RWMutex tasks map[uuid.UUID]models.Task statuses map[string]models.TaskStatus created map[string]time.Time @@ -22,6 +24,9 @@ type InMemoryStore struct { } func (s *InMemoryStore) GetTask(ctx context.Context, id uuid.UUID) (*models.Task, error) { + s.mu.RLock() + defer s.mu.RUnlock() + task, ok := s.tasks[id] if !ok { return nil, storage.ErrNotFound @@ -35,6 +40,9 @@ func (s *InMemoryStore) GetTaskByDescriptor( connectorID models.ConnectorID, descriptor models.TaskDescriptor, ) (*models.Task, error) { + s.mu.RLock() + defer s.mu.RUnlock() + id, err := descriptor.EncodeToString() if err != nil { return nil, err @@ -58,6 +66,9 @@ func (s *InMemoryStore) ListTasks(ctx context.Context, connectorID models.ConnectorID, q storage.ListTasksQuery, ) (*api.Cursor[models.Task], error) { + s.mu.RLock() + defer s.mu.RUnlock() + ret := make([]models.Task, 0) for id, status := range s.statuses { @@ -89,6 +100,9 @@ func (s *InMemoryStore) ReadOldestPendingTask( ctx context.Context, connectorID models.ConnectorID, ) (*models.Task, error) { + s.mu.RLock() + defer s.mu.RUnlock() + var ( oldestDate time.Time oldestID string @@ -182,6 +196,9 @@ func (s *InMemoryStore) UpdateTaskStatus( status models.TaskStatus, taskError string, ) error { + s.mu.Lock() + defer s.mu.Unlock() + taskID, err := descriptor.EncodeToString() if err != nil { return err @@ -203,6 +220,9 @@ func (s *InMemoryStore) Result( connectorID models.ConnectorID, descriptor models.TaskDescriptor, ) (models.TaskStatus, string, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + taskID, err := descriptor.EncodeToString() if err != nil { panic(err)