From 787cd804777279835487ce5e072390a6afbb6ea6 Mon Sep 17 00:00:00 2001 From: Quinn Klassen Date: Mon, 27 Mar 2023 08:07:45 -0700 Subject: [PATCH] Track the worker for each workflow context (#1070) If a worker tries to use a cached workflow context that was created by a different worker, discard that context and create a new one. --- internal/internal_task_handlers.go | 33 +++++++++++--- test/integration_test.go | 72 ++++++++++++++++++++++++++++++ test/workflow_test.go | 38 ++++++++++++++++ 3 files changed, 137 insertions(+), 6 deletions(-) diff --git a/internal/internal_task_handlers.go b/internal/internal_task_handlers.go index bdc1f7529..af499d628 100644 --- a/internal/internal_task_handlers.go +++ b/internal/internal_task_handlers.go @@ -602,20 +602,41 @@ func (wth *workflowTaskHandlerImpl) getOrCreateWorkflowContext( if task.Query == nil || (task.Query != nil && !isFullHistory) { workflowContext = wth.cache.getWorkflowContext(runID) } - + // Verify the cached state is current and for the correct worker if workflowContext != nil { workflowContext.Lock() - if task.Query != nil && !isFullHistory { + if task.Query != nil && !isFullHistory && wth == workflowContext.wth { // query task and we have a valid cached state metricsHandler.Counter(metrics.StickyCacheHit).Inc(1) - } else if history.Events[0].GetEventId() == workflowContext.previousStartedEventID+1 { + } else if history.Events[0].GetEventId() == workflowContext.previousStartedEventID+1 && wth == workflowContext.wth { // non query task and we have a valid cached state metricsHandler.Counter(metrics.StickyCacheHit).Inc(1) } else { - // non query task and cached state is missing events, we need to discard the cached state and rebuild one. - _ = workflowContext.ResetIfStale(task, historyIterator) + // possible another task already destroyed this context. + if !workflowContext.IsDestroyed() { + // non query task and cached state is missing events, we need to discard the cached state and build a new one. 
+ if history.Events[0].GetEventId() != workflowContext.previousStartedEventID+1 { + wth.logger.Debug("Cached state staled, new task has unexpected events", + tagWorkflowID, task.WorkflowExecution.GetWorkflowId(), + tagRunID, task.WorkflowExecution.GetRunId(), + tagAttempt, task.Attempt, + tagCachedPreviousStartedEventID, workflowContext.previousStartedEventID, + tagTaskFirstEventID, task.History.Events[0].GetEventId(), + tagTaskStartedEventID, task.GetStartedEventId(), + tagPreviousStartedEventID, task.GetPreviousStartedEventId(), + ) + } else { + wth.logger.Debug("Cached state started on different worker, creating new context") + } + wth.cache.removeWorkflowContext(runID) + workflowContext.clearState() + } + workflowContext.Unlock(err) + workflowContext = nil } - } else { + } + // If the workflow was not cached or the cache was stale. + if workflowContext == nil { if !isFullHistory { // we are getting partial history task, but cached state was already evicted. // we need to reset history so we get events from beginning to replay/rebuild the state diff --git a/test/integration_test.go b/test/integration_test.go index 9705dfde7..eb6339785 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -2420,6 +2420,78 @@ func (ts *IntegrationTestSuite) TestDeterminismUpsertSearchAttributesConditional ts.testStaleCacheReplayDeterminism(ctx, run, maxTicks) } +func (ts *IntegrationTestSuite) TestLocalActivityWorkerRestart() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + maxTicks := 3 + options := ts.startWorkflowOptions("test-local-activity-worker-restart-" + uuid.New()) + + run, err := ts.client.ExecuteWorkflow( + ctx, + options, + ts.workflows.LocalActivityStaleCache, + maxTicks, + ) + ts.NoError(err) + + // clean up if test fails + defer func() { _ = ts.client.TerminateWorkflow(ctx, run.GetID(), run.GetRunID(), "", nil) }() + ts.waitForQueryTrue(run, "is-wait-tick-count", 1) + + // Restart worker + 
ts.workerStopped = true + currentWorker := ts.worker + currentWorker.Stop() + currentWorker = worker.New(ts.client, ts.taskQueueName, worker.Options{}) + ts.registerWorkflowsAndActivities(currentWorker) + ts.NoError(currentWorker.Start()) + defer currentWorker.Stop() + + for i := 0; i < maxTicks-1; i++ { + ts.NoError(ts.client.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "tick", nil)) + ts.waitForQueryTrue(run, "is-wait-tick-count", 2+i) + } + err = run.Get(ctx, nil) + ts.NoError(err) +} + +func (ts *IntegrationTestSuite) TestLocalActivityStaleCache() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + maxTicks := 3 + options := ts.startWorkflowOptions("test-local-activity-stale-cache-" + uuid.New()) + + run, err := ts.client.ExecuteWorkflow( + ctx, + options, + ts.workflows.LocalActivityStaleCache, + maxTicks, + ) + ts.NoError(err) + + // clean up if test fails + defer func() { _ = ts.client.TerminateWorkflow(ctx, run.GetID(), run.GetRunID(), "", nil) }() + ts.waitForQueryTrue(run, "is-wait-tick-count", 1) + + ts.workerStopped = true + currentWorker := ts.worker + currentWorker.Stop() + for i := 0; i < maxTicks-1; i++ { + func() { + ts.NoError(ts.client.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "tick", nil)) + currentWorker = worker.New(ts.client, ts.taskQueueName, worker.Options{}) + defer currentWorker.Stop() + ts.registerWorkflowsAndActivities(currentWorker) + ts.NoError(currentWorker.Start()) + ts.waitForQueryTrue(run, "is-wait-tick-count", 2+i) + }() + } + err = run.Get(ctx, nil) + ts.NoError(err) +} + func (ts *IntegrationTestSuite) TestDeterminismUpsertMemoConditional() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() diff --git a/test/workflow_test.go b/test/workflow_test.go index 3a3d8cb11..d0f3f7af3 100644 --- a/test/workflow_test.go +++ b/test/workflow_test.go @@ -1984,6 +1984,43 @@ func (w *Workflows) UpsertMemoConditional(ctx workflow.Context, maxTicks 
int) er } } +func (w *Workflows) LocalActivityStaleCache(ctx workflow.Context, maxTicks int) error { + var waitTickCount int + tickCh := workflow.GetSignalChannel(ctx, "tick") + err := workflow.SetQueryHandler( + ctx, + "is-wait-tick-count", + func(v int) (bool, error) { return waitTickCount == v, nil }, + ) + if err != nil { + return err + } + oneRetry := &temporal.RetryPolicy{InitialInterval: 1 * time.Nanosecond, MaximumAttempts: 1} + + ctx = workflow.WithLocalActivityOptions(ctx, workflow.LocalActivityOptions{ + StartToCloseTimeout: 5 * time.Second, + RetryPolicy: oneRetry, + }) + + // Now just wait for signals over and over + for { + waitTickCount++ + if waitTickCount >= maxTicks { + return nil + } + tickCh.Receive(ctx, nil) + err = workflow.ExecuteLocalActivity(ctx, func(tickCount int) error { + log.Printf("Running local activity on tickCount %d", waitTickCount) + return nil + }, waitTickCount).Get(ctx, nil) + if err != nil { + return err + } + + log.Printf("Signal received (replaying? %v)", workflow.IsReplaying(ctx)) + } +} + func (w *Workflows) MutableSideEffect(ctx workflow.Context, startVal int) (currVal int, err error) { // Make some mutable side effect calls with timers in between sideEffector := func(retVal int) (newVal int, err error) { @@ -2113,6 +2150,7 @@ func (w *Workflows) register(worker worker.Worker) { worker.RegisterWorkflow(w.WorkflowWithLocalActivityStartWhenTimerCancel) worker.RegisterWorkflow(w.WorkflowWithParallelSideEffects) worker.RegisterWorkflow(w.WorkflowWithParallelMutableSideEffects) + worker.RegisterWorkflow(w.LocalActivityStaleCache) worker.RegisterWorkflow(w.SignalWorkflow) worker.RegisterWorkflow(w.CronWorkflow) worker.RegisterWorkflow(w.CancelTimerConcurrentWithOtherCommandWorkflow)