diff --git a/internal/internal_decision_state_machine.go b/internal/internal_decision_state_machine.go
index 7982fd96f..4be15f3f0 100644
--- a/internal/internal_decision_state_machine.go
+++ b/internal/internal_decision_state_machine.go
@@ -874,8 +874,8 @@ func (h *commandsHelper) addCommand(command commandStateMachine) {
 // might be in the same workflow task. In practice this only seems to happen during unhandled command events.
 func (h *commandsHelper) removeCancelOfResolvedCommand(commandID commandID) {
 	// Ensure this isn't misused for non-cancel commands
-	if commandID.commandType != commandTypeCancelTimer && commandID.commandType != commandTypeRequestCancelActivityTask {
-		panic("removeCancelOfResolvedCommand should only be called for cancel timer / activity")
+	if commandID.commandType != commandTypeCancelTimer {
+		panic("removeCancelOfResolvedCommand should only be called for cancel timer")
 	}
 	orderedCmdEl, ok := h.commands[commandID]
 	if ok {
@@ -913,10 +913,6 @@ func (h *commandsHelper) requestCancelActivityTask(activityID string) commandSta
 
 func (h *commandsHelper) handleActivityTaskClosed(activityID string, scheduledEventID int64) commandStateMachine {
 	command := h.getCommand(makeCommandID(commandTypeActivity, activityID))
-	// If, for whatever reason, we were going to send an activity cancel request, don't do that anymore
-	// since we already know the activity is resolved.
-	possibleCancelID := makeCommandID(commandTypeRequestCancelActivityTask, activityID)
-	h.removeCancelOfResolvedCommand(possibleCancelID)
 	command.handleCompletionEvent()
 	delete(h.scheduledEventIDToActivityID, scheduledEventID)
 	return command
diff --git a/test/activity_test.go b/test/activity_test.go
index 79dccf59a..d011b80e6 100644
--- a/test/activity_test.go
+++ b/test/activity_test.go
@@ -186,6 +186,19 @@ func (a *Activities) WaitForWorkerStop(ctx context.Context, timeout time.Duratio
 	}
 }
 
+// HeartbeatUntilCanceled records an activity heartbeat every heartbeatFreq until ctx is done, at which point it
+// returns nil rather than an error.
+func (a *Activities) HeartbeatUntilCanceled(ctx context.Context, heartbeatFreq time.Duration) error {
+	for {
+		select {
+		case <-ctx.Done():
+			return nil
+		case <-time.After(heartbeatFreq):
+			activity.RecordHeartbeat(ctx)
+		}
+	}
+}
+
 func (a *Activities) Panicked(ctx context.Context) ([]string, error) {
 	panic(fmt.Sprintf("simulated panic on attempt %v", activity.GetInfo(ctx).Attempt))
 }
diff --git a/test/integration_test.go b/test/integration_test.go
index 5039336a9..0b1cf8575 100644
--- a/test/integration_test.go
+++ b/test/integration_test.go
@@ -1288,6 +1288,67 @@ func (ts *IntegrationTestSuite) TestCancelChildAndExecuteActivityRace() {
 	ts.NoError(err)
 }
 
+// TestAdvancedPostCancellation starts the AdvancedPostCancellation workflow with several input combinations,
+// cancels each run once it reports that it is waiting for cancel, and expects the run to finish without error.
+func (ts *IntegrationTestSuite) TestAdvancedPostCancellation() {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	assertPostCancellation := func(in *AdvancedPostCancellationInput) {
+		// Start workflow
+		run, err := ts.client.ExecuteWorkflow(ctx, ts.startWorkflowOptions("test-advanced-post-cancellation-"+uuid.New()),
+			ts.workflows.AdvancedPostCancellation, in)
+		ts.NoError(err)
+
+		// Poll to check if waiting for cancel
+		var waitingForCancel bool
+		for i := 0; !waitingForCancel && i < 30; i++ {
+			time.Sleep(50 * time.Millisecond)
+			val, err := ts.client.QueryWorkflow(ctx, run.GetID(), run.GetRunID(), "waiting-for-cancel")
+			// Ignore query failed because it means query may not be registered yet
+			var queryFailed *serviceerror.QueryFailed
+			if errors.As(err, &queryFailed) {
+				continue
+			}
+			ts.NoError(err)
+			ts.NoError(val.Get(&waitingForCancel))
+		}
+		ts.True(waitingForCancel)
+
+		// Now cancel it
+		ts.NoError(ts.client.CancelWorkflow(ctx, run.GetID(), run.GetRunID()))
+
+		// Confirm no error
+		ts.NoError(run.Get(ctx, nil))
+	}
+
+	// Check just activity and timer
+	assertPostCancellation(&AdvancedPostCancellationInput{
+		PreCancelActivity:  true,
+		PostCancelActivity: true,
+	})
+	assertPostCancellation(&AdvancedPostCancellationInput{
+		PreCancelTimer:  true,
+		PostCancelTimer: true,
+	})
+	// Check mixed
+	assertPostCancellation(&AdvancedPostCancellationInput{
+		PreCancelActivity: true,
+		PostCancelTimer:   true,
+	})
+	assertPostCancellation(&AdvancedPostCancellationInput{
+		PreCancelTimer:     true,
+		PostCancelActivity: true,
+	})
+	// Check all
+	assertPostCancellation(&AdvancedPostCancellationInput{
+		PreCancelActivity:  true,
+		PreCancelTimer:     true,
+		PostCancelActivity: true,
+		PostCancelTimer:    true,
+	})
+}
+
 func (ts *IntegrationTestSuite) registerNamespace() {
 	client, err := client.NewNamespaceClient(client.Options{HostPort: ts.config.ServiceAddr})
 	ts.NoError(err)
diff --git a/test/workflow_test.go b/test/workflow_test.go
index f90cc9eb4..d057ef415 100644
--- a/test/workflow_test.go
+++ b/test/workflow_test.go
@@ -1356,6 +1356,78 @@ func (w *Workflows) SleepForDuration(ctx workflow.Context, d time.Duration) erro
 	return workflow.Sleep(ctx, d)
 }
 
+// AdvancedPostCancellationInput selects which pieces AdvancedPostCancellation starts before it waits (Pre*) and
+// which pieces it runs afterwards on a disconnected context (Post*).
+type AdvancedPostCancellationInput struct {
+	PreCancelActivity  bool
+	PostCancelActivity bool
+	PreCancelTimer     bool
+	PostCancelTimer    bool
+}
+
+// AdvancedPostCancellation starts the requested pre-cancel activity and/or timer and waits on them, expecting the
+// activity to return without error and the timer to report a canceled error. It then runs the requested post-cancel
+// pieces on a disconnected context. The "waiting-for-cancel" query returns true once the pre-cancel pieces started.
+func (w *Workflows) AdvancedPostCancellation(ctx workflow.Context, in *AdvancedPostCancellationInput) error {
+	// Setup query to tell caller we're waiting for cancel
+	waitingForCancel := false
+	err := workflow.SetQueryHandler(ctx, "waiting-for-cancel", func() (bool, error) {
+		return waitingForCancel, nil
+	})
+	if err != nil {
+		return err
+	}
+
+	ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
+		StartToCloseTimeout: 5 * time.Minute,
+		HeartbeatTimeout:    5 * time.Second,
+		WaitForCancellation: true,
+	})
+	var a *Activities
+
+	// Start pre-cancel pieces
+	var actFut, timerFut workflow.Future
+	if in.PreCancelActivity {
+		actFut = workflow.ExecuteActivity(ctx, a.HeartbeatUntilCanceled, 1*time.Second)
+	}
+	if in.PreCancelTimer {
+		timerFut = workflow.NewTimer(ctx, 10*time.Minute)
+	}
+
+	// Set as waiting and wait for futures
+	waitingForCancel = true
+	if actFut != nil {
+		if err := actFut.Get(ctx, nil); err != nil {
+			return fmt.Errorf("activity did not gracefully cancel: %w", err)
+		}
+	}
+	if timerFut != nil {
+		// A nil error means the timer fired instead of being canceled. Report that case separately so that the
+		// message below never wraps a nil error with %w.
+		err = timerFut.Get(ctx, nil)
+		if err == nil {
+			return fmt.Errorf("timer fired instead of being canceled")
+		}
+		if !temporal.IsCanceledError(err) {
+			return fmt.Errorf("timer did not get canceled error, got: %w", err)
+		}
+	}
+
+	// Run post-cancel pieces with a disconnected context that is not considered canceled
+	ctx, _ = workflow.NewDisconnectedContext(ctx)
+	if in.PostCancelActivity {
+		if err := workflow.ExecuteActivity(ctx, a.Sleep, 1*time.Millisecond).Get(ctx, nil); err != nil {
+			return fmt.Errorf("failed post-cancel activity: %w", err)
+		}
+	}
+	if in.PostCancelTimer {
+		if err := workflow.NewTimer(ctx, 1*time.Millisecond).Get(ctx, nil); err != nil {
+			return fmt.Errorf("failed post-cancel timer: %w", err)
+		}
+	}
+	return nil
+}
+
 func (w *Workflows) register(worker worker.Worker) {
 	worker.RegisterWorkflow(w.ActivityCancelRepro)
 	worker.RegisterWorkflow(w.ActivityCompletionUsingID)
@@ -1413,6 +1485,7 @@ func (w *Workflows) register(worker worker.Worker) {
 	worker.RegisterWorkflow(w.CancelMultipleCommandsOverMultipleTasks)
 	worker.RegisterWorkflow(w.CancelChildAndExecuteActivityRace)
 	worker.RegisterWorkflow(w.SleepForDuration)
+	worker.RegisterWorkflow(w.AdvancedPostCancellation)
 
 	worker.RegisterWorkflow(w.child)
 	worker.RegisterWorkflow(w.childForMemoAndSearchAttr)