Skip to content

Commit

Permalink
Call OnFatalError for workers using Start (#823)
Browse files Browse the repository at this point in the history
Fixes #822
  • Loading branch information
cretz committed Jun 6, 2022
1 parent 17c0144 commit e9dd80f
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 43 deletions.
61 changes: 39 additions & 22 deletions internal/internal_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ type (
// WorkerStopChannel is a read-only channel to listen on for worker close. The worker will close the channel before exit.
WorkerStopChannel <-chan struct{}

// WorkerFatalErrorChannel is a channel for fatal errors that should stop
// the worker. This is sent to asynchronously, so it should be buffered.
WorkerFatalErrorChannel chan<- error
// WorkerFatalErrorCallback is a callback for fatal errors that should stop
// the worker.
WorkerFatalErrorCallback func(error)

// SessionResourceID is a unique identifier of the resource the session will consume
SessionResourceID string
Expand Down Expand Up @@ -288,7 +288,7 @@ func newWorkflowTaskWorkerInternal(
identity: params.Identity,
workerType: "WorkflowWorker",
stopTimeout: params.WorkerStopTimeout,
fatalErrCh: params.WorkerFatalErrorChannel},
fatalErrCb: params.WorkerFatalErrorCallback},
params.Logger,
params.MetricsHandler,
nil,
Expand All @@ -312,7 +312,7 @@ func newWorkflowTaskWorkerInternal(
identity: params.Identity,
workerType: "LocalActivityWorker",
stopTimeout: params.WorkerStopTimeout,
fatalErrCh: params.WorkerFatalErrorChannel},
fatalErrCb: params.WorkerFatalErrorCallback},
params.Logger,
params.MetricsHandler,
nil,
Expand Down Expand Up @@ -428,7 +428,7 @@ func newActivityTaskWorker(taskHandler ActivityTaskHandler, service workflowserv
identity: workerParams.Identity,
workerType: "ActivityWorker",
stopTimeout: workerParams.WorkerStopTimeout,
fatalErrCh: workerParams.WorkerFatalErrorChannel,
fatalErrCb: workerParams.WorkerFatalErrorCallback,
userContextCancel: workerParams.UserContextCancel},
workerParams.Logger,
workerParams.MetricsHandler,
Expand Down Expand Up @@ -856,8 +856,8 @@ type AggregatedWorker struct {
logger log.Logger
registry *registry
stopC chan struct{}
fatalErrCh chan error
fatalErrCb func(error)
fatalErr error
fatalErrLock sync.Mutex
}

// RegisterWorkflow registers workflow implementation with the AggregatedWorker
Expand Down Expand Up @@ -1026,14 +1026,11 @@ func (aw *AggregatedWorker) Run(interruptCh <-chan interface{}) error {
case s := <-interruptCh:
aw.logger.Info("Worker has been stopped.", "Signal", s)
aw.Stop()
case err := <-aw.fatalErrCh:
// Fatal error will already have been logged where it is set
if aw.fatalErrCb != nil {
aw.fatalErrCb(err)
}
aw.Stop()
return err
case <-aw.stopC:
aw.fatalErrLock.Lock()
defer aw.fatalErrLock.Unlock()
// This may be nil if this wasn't stopped due to fatal error
return aw.fatalErr
}
return nil
}
Expand Down Expand Up @@ -1311,9 +1308,30 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke
panic("cannot set MaxConcurrentWorkflowTaskPollers to 1")
}

// We need this buffered since the sender will be sending async and we only
// need the first fatal error
fatalErrCh := make(chan error, 1)
// Need reference to result for fatal error handler
var aw *AggregatedWorker
fatalErrorCallback := func(err error) {
// Set the fatal error if not already set
aw.fatalErrLock.Lock()
alreadySet := aw.fatalErr != nil
if !alreadySet {
aw.fatalErr = err
}
aw.fatalErrLock.Unlock()
// Only do the rest if not already set
if !alreadySet {
// Invoke the callback if present
if options.OnFatalError != nil {
options.OnFatalError(err)
}
// Stop the worker if not already stopped
select {
case <-aw.stopC:
default:
aw.Stop()
}
}
}

cache := NewWorkerCache()
workerParams := workerExecutionParameters{
Expand All @@ -1337,7 +1355,7 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke
WorkflowPanicPolicy: options.WorkflowPanicPolicy,
DataConverter: client.dataConverter,
WorkerStopTimeout: options.WorkerStopTimeout,
WorkerFatalErrorChannel: fatalErrCh,
WorkerFatalErrorCallback: fatalErrorCallback,
ContextPropagators: client.contextPropagators,
DeadlockDetectionTimeout: options.DeadlockDetectionTimeout,
DefaultHeartbeatThrottleInterval: options.DefaultHeartbeatThrottleInterval,
Expand Down Expand Up @@ -1393,17 +1411,16 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke
})
}

return &AggregatedWorker{
aw = &AggregatedWorker{
client: client,
workflowWorker: workflowWorker,
activityWorker: activityWorker,
sessionWorker: sessionWorker,
logger: workerParams.Logger,
registry: registry,
stopC: make(chan struct{}),
fatalErrCh: fatalErrCh,
fatalErrCb: options.OnFatalError,
}
return aw
}

func processTestTags(wOptions *WorkerOptions, ep *workerExecutionParameters) {
Expand Down
12 changes: 5 additions & 7 deletions internal/internal_worker_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ type (
identity string
workerType string
stopTimeout time.Duration
fatalErrCh chan<- error
fatalErrCb func(error)
userContextCancel context.CancelFunc
}

Expand All @@ -172,7 +172,7 @@ type (

pollerRequestCh chan struct{}
taskQueueCh chan interface{}
fatalErrCh chan<- error
fatalErrCb func(error)
sessionTokenBucket *sessionTokenBucket

lastPollTaskErrMessage string
Expand Down Expand Up @@ -214,7 +214,7 @@ func newBaseWorker(
taskSlotsAvailable: int32(options.maxConcurrentTask),
pollerRequestCh: make(chan struct{}, options.maxConcurrentTask),
taskQueueCh: make(chan interface{}), // no buffer, so poller only able to poll new task after previous is dispatched.
fatalErrCh: options.fatalErrCh,
fatalErrCb: options.fatalErrCb,

limiterContext: ctx,
limiterContextCancel: cancel,
Expand Down Expand Up @@ -317,10 +317,8 @@ func (bw *baseWorker) pollTask() {
if err != nil {
if isNonRetriableError(err) {
bw.logger.Error("Worker received non-retriable error. Shutting down.", tagError, err)
// Set the error and assume it is buffered with room
select {
case bw.fatalErrCh <- err:
default:
if bw.fatalErrCb != nil {
bw.fatalErrCb(err)
}
return
}
Expand Down
49 changes: 35 additions & 14 deletions test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2157,7 +2157,15 @@ func (ts *IntegrationTestSuite) TestLargeHistoryReplay() {
ts.Contains(err.Error(), "intentional panic")
}

func (ts *IntegrationTestSuite) TestWorkerFatalError() {
func (ts *IntegrationTestSuite) TestWorkerFatalErrorOnRun() {
ts.testWorkerFatalError(true)
}

func (ts *IntegrationTestSuite) TestWorkerFatalErrorOnStart() {
ts.testWorkerFatalError(false)
}

func (ts *IntegrationTestSuite) testWorkerFatalError(useWorkerRun bool) {
// Make a new client that will fail a poll with a namespace not found
c, err := client.Dial(client.Options{
HostPort: ts.config.ServiceAddr,
Expand All @@ -2175,6 +2183,8 @@ func (ts *IntegrationTestSuite) TestWorkerFatalError() {
opts ...grpc.CallOption,
) error {
if method == "/temporal.api.workflowservice.v1.WorkflowService/PollWorkflowTaskQueue" {
// We sleep a bit to let all internal workers start
time.Sleep(1 * time.Second)
return serviceerror.NewNamespaceNotFound(ts.config.Namespace)
}
return invoker(ctx, method, req, reply, cc, opts...)
Expand All @@ -2186,22 +2196,33 @@ func (ts *IntegrationTestSuite) TestWorkerFatalError() {
defer c.Close()

// Create a worker that uses that client
var lastErr error
w := worker.New(c, "ignored-task-queue", worker.Options{OnFatalError: func(err error) { lastErr = err }})
callbackErrCh := make(chan error, 1)
w := worker.New(c, "ignored-task-queue", worker.Options{OnFatalError: func(err error) { callbackErrCh <- err }})

// Do run-based or start-based worker
runErrCh := make(chan error, 1)
if useWorkerRun {
go func() { runErrCh <- w.Run(nil) }()
} else {
ts.NoError(w.Start())
}

// Run it and confirm it fails
go func() { runErrCh <- w.Run(nil) }()
var runErr error
select {
case <-time.After(10 * time.Second):
ts.Fail("timeout")
case runErr = <-runErrCh:
// Wait for done
var callbackErr, runErr error
for callbackErr == nil || (useWorkerRun && runErr == nil) {
select {
case <-time.After(10 * time.Second):
ts.Fail("timeout")
case callbackErr = <-callbackErrCh:
case runErr = <-runErrCh:
}
}

// Check error
ts.IsType(&serviceerror.NamespaceNotFound{}, callbackErr)
if runErr != nil {
ts.Equal(callbackErr, runErr)
}
ts.Error(lastErr)
ts.Error(runErr)
ts.Equal(lastErr, runErr)
ts.IsType(&serviceerror.NamespaceNotFound{}, runErr)
}

func (ts *IntegrationTestSuite) TestNonDeterminismFailureCauseBadStateMachine() {
Expand Down

0 comments on commit e9dd80f

Please sign in to comment.