Skip to content

Commit

Permalink
Support for activity retry policies (#83)
Browse files Browse the repository at this point in the history
Adds the ActivityRetryPolicy struct and updates CallActivity to execute a different code path if retries are needed.

Signed-off-by: Fabian Martinez <[email protected]>
  • Loading branch information
famarting authored Oct 24, 2024
1 parent 65c308b commit 0c4afbc
Show file tree
Hide file tree
Showing 9 changed files with 436 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Add API to set custom status ([#81](https://github.com/microsoft/durabletask-go/pull/81)) - by [@famarting](https://github.com/famarting)
- Add missing purge orchestration options ([#82](https://github.com/microsoft/durabletask-go/pull/82)) - by [@famarting](https://github.com/famarting)
- Add support for activity retry policies ([#83](https://github.com/microsoft/durabletask-go/pull/83)) - by [@famarting](https://github.com/famarting)

### Changed

Expand Down
95 changes: 95 additions & 0 deletions samples/retries/retries.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package main

import (
"context"
"encoding/json"
"errors"
"log"
"math/rand"
"time"

"github.com/microsoft/durabletask-go/backend"
"github.com/microsoft/durabletask-go/backend/sqlite"
"github.com/microsoft/durabletask-go/task"
)

// main runs the retry sample end to end: it registers the orchestrator and
// activity, starts a worker, schedules one orchestration, waits for it to
// finish and logs the resulting metadata as indented JSON.
func main() {
	// Create a new task registry and add the orchestrator and activities
	r := task.NewTaskRegistry()
	r.AddOrchestrator(RetryActivityOrchestrator)
	r.AddActivity(RandomFailActivity)

	// Init the client
	ctx := context.Background()
	client, worker, err := Init(ctx, r)
	if err != nil {
		// Nothing has been started successfully yet, so exiting directly is fine.
		log.Fatalf("Failed to initialize the client: %v", err)
	}
	// Covers the normal return path.
	defer worker.Shutdown(ctx)

	// log.Fatalf terminates the process through os.Exit, which does not run
	// deferred calls. Shut the worker down explicitly before exiting so the
	// failure paths below clean up the same way the success path does.
	fatalf := func(format string, args ...any) {
		// Best-effort shutdown: the process is exiting with the original
		// failure regardless of the outcome.
		worker.Shutdown(ctx)
		log.Fatalf(format, args...)
	}

	// Start a new orchestration
	id, err := client.ScheduleNewOrchestration(ctx, RetryActivityOrchestrator)
	if err != nil {
		fatalf("Failed to schedule new orchestration: %v", err)
	}

	// Wait for the orchestration to complete
	metadata, err := client.WaitForOrchestrationCompletion(ctx, id)
	if err != nil {
		fatalf("Failed to wait for orchestration to complete: %v", err)
	}

	// Print the results
	metadataEnc, err := json.MarshalIndent(metadata, "", " ")
	if err != nil {
		fatalf("Failed to encode result to JSON: %v", err)
	}
	log.Printf("Orchestration completed: %v", string(metadataEnc))
}

// Init builds a task hub worker and client that share one sqlite backend,
// starts the worker, and returns both. If the worker fails to start, the
// start error is returned and both the client and the worker are nil.
func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, backend.TaskHubWorker, error) {
	logger := backend.DefaultLogger()
	executor := task.NewTaskExecutor(r)

	// Passing "" as the file path selects the in-memory sqlite provider.
	be := sqlite.NewSqliteBackend(sqlite.NewSqliteOptions(""), logger)

	// The hub worker is composed of one orchestration worker and one activity
	// worker, all driven by the same executor and backend.
	worker := backend.NewTaskHubWorker(
		be,
		backend.NewOrchestrationWorker(be, executor, logger),
		backend.NewActivityTaskWorker(be, executor, logger),
		logger,
	)

	if startErr := worker.Start(ctx); startErr != nil {
		return nil, nil, startErr
	}

	// The client is only created once the worker is running.
	return backend.NewTaskHubClient(be), worker, nil
}

// RetryActivityOrchestrator calls RandomFailActivity with a retry policy of up
// to 10 attempts, starting at a 100ms retry interval that doubles on each
// retry and is capped at 3s. It produces no output; the returned error is the
// one reported by awaiting the activity task, or nil on success.
func RetryActivityOrchestrator(ctx *task.OrchestrationContext) (any, error) {
	policy := &task.ActivityRetryPolicy{
		MaxAttempts:          10,
		InitialRetryInterval: 100 * time.Millisecond,
		BackoffCoefficient:   2,
		MaxRetryInterval:     3 * time.Second,
	}

	activityTask := ctx.CallActivity(RandomFailActivity, task.WithRetryPolicy(policy))

	// The activity result is not needed, so nothing is decoded from the task.
	return nil, activityTask.Await(nil)
}

// RandomFailActivity is a demo activity that fails randomly so that the retry
// policy of the calling orchestrator has something to retry. It returns "ok"
// on success and an error on failure.
func RandomFailActivity(ctx task.ActivityContext) (any, error) {
	// 70% possibility for activity failure: rand.Intn(100) yields a value in
	// [0, 100), so "< 70" matches exactly 70 of the 100 possible outcomes
	// ("<= 70" would match 71).
	if rand.Intn(100) < 70 {
		log.Println("random activity failure")
		return "", errors.New("random activity failure")
	}
	return "ok", nil
}
52 changes: 51 additions & 1 deletion task/activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package task

import (
"context"
"fmt"
"math"
"time"

"github.com/microsoft/durabletask-go/internal/protos"
"google.golang.org/protobuf/types/known/wrapperspb"
Expand All @@ -10,7 +13,23 @@ import (
// callActivityOption applies one setting to a callActivityOptions value. It
// may return an error to reject the setting.
type callActivityOption func(*callActivityOptions) error

// callActivityOptions holds the settings collected from callActivityOption
// functions for a single activity call.
type callActivityOptions struct {
	// rawInput is the activity input wrapped as a protobuf string value.
	rawInput *wrapperspb.StringValue
	// retryPolicy is the retry policy for the call; nil means no policy was set.
	retryPolicy *ActivityRetryPolicy
}

// ActivityRetryPolicy describes how a failed activity call is retried.
// WithRetryPolicy validates a policy and substitutes defaults for the
// optional fields that are left unset.
type ActivityRetryPolicy struct {
	// Max number of attempts to try the activity call, first execution inclusive.
	// WithRetryPolicy treats values <= 0 as 1, which is equivalent to not retrying.
	MaxAttempts int
	// Timespan to wait for the first retry.
	// WithRetryPolicy rejects values <= 0 with an error.
	InitialRetryInterval time.Duration
	// Used to determine rate of increase of back-off: the delay is
	// InitialRetryInterval multiplied by this coefficient raised to the
	// retry number. WithRetryPolicy treats values <= 0 as 1.
	BackoffCoefficient float64
	// Max timespan to wait for a retry.
	// WithRetryPolicy treats values <= 0 as "no cap".
	MaxRetryInterval time.Duration
	// Total timeout across all the retries performed, measured from the first attempt.
	// WithRetryPolicy treats values <= 0 as "no timeout".
	RetryTimeout time.Duration
	// Optional function to control if retries should proceed: it receives the
	// activity error and returns false to stop retrying.
	// WithRetryPolicy substitutes a function that always returns true when nil.
	Handle func(error) bool
}

// WithActivityInput configures an input for an activity invocation.
Expand All @@ -34,6 +53,37 @@ func WithRawActivityInput(input string) callActivityOption {
}
}

// WithRetryPolicy configures an activity invocation to be retried according
// to policy.
//
// A nil policy is ignored and leaves the invocation without retries. A policy
// whose InitialRetryInterval is not greater than zero is rejected with an
// error. All other unset (<= 0 or nil) fields are replaced by defaults:
// MaxAttempts becomes 1, BackoffCoefficient becomes 1, MaxRetryInterval and
// RetryTimeout become unbounded, and Handle accepts every error.
//
// The defaults are applied to a private copy, so the caller's policy value is
// never modified and can safely be shared between calls.
func WithRetryPolicy(policy *ActivityRetryPolicy) callActivityOption {
	return func(opt *callActivityOptions) error {
		if policy == nil {
			return nil
		}
		if policy.InitialRetryInterval <= 0 {
			return fmt.Errorf("InitialRetryInterval must be greater than 0")
		}

		// Normalize a copy rather than writing through the caller's pointer.
		normalized := *policy
		if normalized.MaxAttempts <= 0 {
			// setting 1 max attempt is equivalent to not retrying
			normalized.MaxAttempts = 1
		}
		if normalized.BackoffCoefficient <= 0 {
			normalized.BackoffCoefficient = 1
		}
		if normalized.MaxRetryInterval <= 0 {
			normalized.MaxRetryInterval = math.MaxInt64
		}
		if normalized.RetryTimeout <= 0 {
			normalized.RetryTimeout = math.MaxInt64
		}
		if normalized.Handle == nil {
			normalized.Handle = func(err error) bool {
				return true
			}
		}
		opt.retryPolicy = &normalized
		return nil
	}
}

// ActivityContext is the context parameter type for activity implementations.
type ActivityContext interface {
GetInput(resultPtr any) error
Expand Down
61 changes: 61 additions & 0 deletions task/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"container/list"
"encoding/json"
"fmt"
"math"
"strings"
"time"

Expand Down Expand Up @@ -236,6 +237,16 @@ func (ctx *OrchestrationContext) CallActivity(activity interface{}, opts ...call
}
}

if options.retryPolicy != nil {
return ctx.internalCallActivityWithRetries(ctx.CurrentTimeUtc, func() Task {
return ctx.internalScheduleActivity(activity, options)
}, *options.retryPolicy, 0)
}

return ctx.internalScheduleActivity(activity, options)
}

func (ctx *OrchestrationContext) internalScheduleActivity(activity interface{}, options *callActivityOptions) Task {
scheduleTaskAction := helpers.NewScheduleTaskAction(
ctx.getNextSequenceNumber(),
helpers.GetTaskFunctionName(activity),
Expand All @@ -248,6 +259,56 @@ func (ctx *OrchestrationContext) CallActivity(activity interface{}, opts ...call
return task
}

// internalCallActivityWithRetries schedules one attempt through schedule and
// returns a task that, when its awaited result is an error, waits on a timer
// and then awaits a further attempt.
//
// initialAttempt is the time of the first attempt and is passed through
// unchanged to every retry. retryCount is the number of retries already
// performed (0 for the first attempt). The original error is returned once
// the next attempt would exceed policy.MaxAttempts or computeNextDelay
// reports a zero delay.
func (ctx *OrchestrationContext) internalCallActivityWithRetries(initialAttempt time.Time, schedule func() Task, policy ActivityRetryPolicy, retryCount int) Task {
	// The attempt is scheduled right away; only the retry decision is deferred
	// until the result is awaited.
	attempt := schedule()

	retryOnFailure := func(v any, err error) error {
		if err == nil {
			return nil
		}

		// The attempt that just failed counts towards the maximum.
		if attemptsMade := retryCount + 1; attemptsMade >= policy.MaxAttempts {
			return err
		}

		// A zero delay is the signal that no further retry should happen.
		delay := computeNextDelay(ctx.CurrentTimeUtc, policy, retryCount, initialAttempt, err)
		if delay == 0 {
			return err
		}

		if timerErr := ctx.createTimerInternal(delay).Await(nil); timerErr != nil {
			// TODO use errors.Join when updating golang
			return fmt.Errorf("%v %w", timerErr, err)
		}

		// Recurse for the next attempt; its outcome becomes the outcome of this task.
		return ctx.internalCallActivityWithRetries(initialAttempt, schedule, policy, retryCount+1).Await(v)
	}

	return &taskWrapper{
		delegate:      attempt,
		onAwaitResult: retryOnFailure,
	}
}

// computeNextDelay returns how long to wait before the next retry, or 0 when
// no further retry should be made.
//
// currentTimeUtc is the current time, firstAttempt the time of the first
// attempt, attempt the number of retries already performed (0 for the first
// retry) and err the error of the failed attempt.
//
// No retry is made (0 is returned) when policy.Handle rejects err, or when
// policy.RetryTimeout has elapsed since firstAttempt. A nil Handle accepts
// every error, and a RetryTimeout of math.MaxInt64 means "no timeout".
// Otherwise the delay is InitialRetryInterval * BackoffCoefficient^attempt,
// capped at MaxRetryInterval.
func computeNextDelay(currentTimeUtc time.Time, policy ActivityRetryPolicy, attempt int, firstAttempt time.Time, err error) time.Duration {
	// Policies built by hand (not through WithRetryPolicy) may have no Handle;
	// treat that as "always retry" instead of panicking on a nil call.
	if policy.Handle != nil && !policy.Handle(err) {
		return 0
	}

	// The MaxInt64 sentinel is skipped explicitly because adding it to a
	// time.Time would overflow.
	if policy.RetryTimeout != math.MaxInt64 && currentTimeUtc.After(firstAttempt.Add(policy.RetryTimeout)) {
		return 0
	}

	// Compute at nanosecond resolution. Working in whole milliseconds would
	// truncate sub-millisecond delays to 0, which callers interpret as
	// "stop retrying".
	nextDelay := float64(policy.InitialRetryInterval) * math.Pow(policy.BackoffCoefficient, float64(attempt))
	if nextDelay < float64(policy.MaxRetryInterval) {
		return time.Duration(nextDelay)
	}
	// Also reached when the computation overflowed to +Inf or produced NaN.
	return policy.MaxRetryInterval
}

func (ctx *OrchestrationContext) CallSubOrchestrator(orchestrator interface{}, opts ...subOrchestratorOption) Task {
options := new(callSubOrchestratorOptions)
for _, configure := range opts {
Expand Down
133 changes: 133 additions & 0 deletions task/orchestrator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package task

import (
"testing"
"time"
)

// Test_computeNextDelay checks the delay computed for successive retries of a
// policy with a 2s initial interval capped at 10s, including the case where
// the retry timeout has already elapsed and the case of a constant back-off.
func Test_computeNextDelay(t *testing.T) {
	// Every case runs one minute after the first attempt.
	firstAttempt := time.Now()
	now := time.Now().Add(1 * time.Minute)

	// newPolicy returns the policy shared by all cases; only the back-off
	// coefficient and the retry timeout vary between them.
	newPolicy := func(backoff float64, timeout time.Duration) ActivityRetryPolicy {
		return ActivityRetryPolicy{
			MaxAttempts:          3,
			InitialRetryInterval: 2 * time.Second,
			BackoffCoefficient:   backoff,
			MaxRetryInterval:     10 * time.Second,
			Handle:               func(err error) bool { return true },
			RetryTimeout:         timeout,
		}
	}

	cases := []struct {
		name    string
		policy  ActivityRetryPolicy
		attempt int
		want    time.Duration
	}{
		{name: "first attempt", policy: newPolicy(2, 2*time.Minute), attempt: 0, want: 2 * time.Second},
		{name: "second attempt", policy: newPolicy(2, 2*time.Minute), attempt: 1, want: 4 * time.Second},
		{name: "third attempt", policy: newPolicy(2, 2*time.Minute), attempt: 2, want: 8 * time.Second},
		// 2s * 2^3 = 16s exceeds the 10s cap.
		{name: "fourth attempt", policy: newPolicy(2, 2*time.Minute), attempt: 3, want: 10 * time.Second},
		// A 30s timeout has already elapsed one minute after the first attempt.
		{name: "expired", policy: newPolicy(2, 30*time.Second), attempt: 3, want: 0},
		{name: "fourth attempt backoff 1", policy: newPolicy(1, 2*time.Minute), attempt: 3, want: 2 * time.Second},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// No case depends on the error value, so nil is passed throughout.
			got := computeNextDelay(now, tc.policy, tc.attempt, firstAttempt, nil)
			if got != tc.want {
				t.Errorf("computeNextDelay() = %v, want %v", got, tc.want)
			}
		})
	}
}
Loading

0 comments on commit 0c4afbc

Please sign in to comment.