diff --git a/src/Abstractions/TaskOptions.cs b/src/Abstractions/TaskOptions.cs index 7c0d54ee..cb94ac9b 100644 --- a/src/Abstractions/TaskOptions.cs +++ b/src/Abstractions/TaskOptions.cs @@ -39,6 +39,7 @@ public TaskOptions(TaskOptions options) Check.NotNull(options); this.Retry = options.Retry; this.Tags = options.Tags; + this.CancellationToken = options.CancellationToken; } /// @@ -51,6 +52,80 @@ public TaskOptions(TaskOptions options) /// public IDictionary? Tags { get; init; } + /// + /// Gets the cancellation token that can be used to cancel the task. + /// + /// + /// + /// The cancellation token provides cooperative cancellation for activities, sub-orchestrators, and retry logic. + /// Due to the durable orchestrator execution model, cancellation only occurs at specific points when the + /// orchestrator code is executing. + /// + /// + /// Cancellation behavior: + /// + /// + /// 1. Pre-scheduling check: If the token is cancelled before calling + /// CallActivityAsync or CallSubOrchestratorAsync, a is thrown + /// immediately without scheduling the task. + /// + /// + /// 2. Retry handlers: The cancellation token is passed to custom retry handlers via + /// , allowing them to check for cancellation and stop retrying between attempts. + /// + /// + /// Important limitation: Once an activity or sub-orchestrator is scheduled, the orchestrator + /// yields execution and waits for the task to complete. During this yield period, the orchestrator code is not + /// running, so it cannot respond to cancellation requests. Cancelling the token while waiting will not wake up + /// the orchestrator or cancel the waiting task. This is a fundamental limitation of the durable orchestrator + /// execution model. + /// + /// + /// Note: Cancelling a parent orchestrator's token does not terminate sub-orchestrator instances that have + /// already been scheduled. + /// + /// + /// Example of pre-scheduling cancellation: + /// + /// using CancellationTokenSource cts = new CancellationTokenSource(); + /// cts.Cancel(); // Cancel before scheduling + /// + /// TaskOptions options = new TaskOptions { CancellationToken = cts.Token }; + /// + /// try + /// { + /// // This will throw TaskCanceledException without scheduling the activity + /// string result = await context.CallActivityAsync<string>("MyActivity", "input", options); + /// } + /// catch (TaskCanceledException) + /// { + /// // Handle cancellation + /// } + /// + /// + /// + /// Example of using cancellation with retry logic: + /// + /// using CancellationTokenSource cts = new CancellationTokenSource(); + /// TaskOptions options = new TaskOptions + /// { + /// Retry = TaskRetryOptions.FromRetryHandler(retryContext => + /// { + /// if (retryContext.CancellationToken.IsCancellationRequested) + /// { + /// return false; // Stop retrying + /// } + /// return retryContext.LastAttemptNumber < 3; + /// }), + /// CancellationToken = cts.Token + /// }; + /// + /// await context.CallActivityAsync("MyActivity", "input", options); + /// + /// + /// + public CancellationToken CancellationToken { get; init; } + /// /// Returns a new from the provided . /// diff --git a/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs b/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs index 98ba2721..5434ab01 100644 --- a/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs +++ b/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs @@ -143,15 +143,25 @@ public override async Task CallActivityAsync( try { IDictionary tags = ImmutableDictionary.Empty; + CancellationToken cancellationToken = default; if (options is TaskOptions callActivityOptions) { if (callActivityOptions.Tags is not null) { tags = callActivityOptions.Tags; } + + cancellationToken = callActivityOptions.CancellationToken; + } + + // If cancellation was requested before starting, throw immediately + // Note: Once the activity is scheduled, the orchestrator yields and cannot respond to + // cancellation until it resumes, so this pre-check is the only cancellation point. + if (cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException("The task was cancelled before it could be scheduled."); } - // TODO: Cancellation (https://github.com/microsoft/durabletask-dotnet/issues/7) #pragma warning disable 0618 if (options?.Retry?.Policy is RetryPolicy policy) { @@ -176,7 +186,7 @@ public override async Task CallActivityAsync( parameters: input), name.Name, handler, - default); + cancellationToken); } else { @@ -217,6 +227,16 @@ public override async Task CallSubOrchestratorAsync( throw new InvalidOperationException(errorMsg); } + CancellationToken cancellationToken = options?.CancellationToken ?? default; + + // If cancellation was requested before starting, throw immediately + // Note: Once the sub-orchestrator is scheduled, the orchestrator yields and cannot respond to + // cancellation until it resumes, so this pre-check is the only cancellation point. + if (cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException("The sub-orchestrator was cancelled before it could be scheduled."); + } + try { if (options?.Retry?.Policy is RetryPolicy policy) @@ -226,7 +246,7 @@ public override async Task CallSubOrchestratorAsync( version, instanceId, policy.ToDurableTaskCoreRetryOptions(), - input, + input, options.Tags); } else if (options?.Retry?.Handler is AsyncRetryHandler handler) @@ -236,11 +256,11 @@ public override async Task CallSubOrchestratorAsync( orchestratorName.Name, version, instanceId, - input, + input, options?.Tags), orchestratorName.Name, handler, - default); + cancellationToken); } else { @@ -248,7 +268,7 @@ public override async Task CallSubOrchestratorAsync( orchestratorName.Name, version, instanceId, - input, + input, options?.Tags); } } diff --git a/test/Grpc.IntegrationTests/CancellationTests.cs b/test/Grpc.IntegrationTests/CancellationTests.cs new file mode 100644 index 00000000..c8e71937 --- /dev/null +++ b/test/Grpc.IntegrationTests/CancellationTests.cs @@ -0,0 +1,403 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Tests.Logging; +using Microsoft.DurableTask.Worker; +using Xunit.Abstractions; + +namespace Microsoft.DurableTask.Grpc.Tests; + +/// +/// Integration tests for activity and sub-orchestrator cancellation functionality. +/// +public class CancellationTests(ITestOutputHelper output, GrpcSidecarFixture sidecarFixture) : + IntegrationTestBase(output, sidecarFixture) +{ + /// + /// Tests that an activity can be cancelled using a CancellationToken. + /// + [Fact] + public async Task ActivityCancellation() + { + TaskName orchestratorName = nameof(ActivityCancellation); + TaskName activityName = "SlowActivity"; + + bool activityWasInvoked = false; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async ctx => + { + using CancellationTokenSource cts = new(); + + // Cancel immediately + cts.Cancel(); + + TaskOptions options = new() { CancellationToken = cts.Token }; + + try + { + await ctx.CallActivityAsync(activityName, options); + return "Should not reach here"; + } + catch (TaskCanceledException) + { + return "Cancelled"; + } + }) + .AddActivityFunc(activityName, (TaskActivityContext activityContext) => + { + activityWasInvoked = true; + return "Activity completed"; + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("\"Cancelled\"", metadata.SerializedOutput); + Assert.False(activityWasInvoked, "Activity should not have been invoked when cancellation happens before scheduling"); + } + + /// + /// Tests that a sub-orchestrator can be cancelled using a CancellationToken. + /// + [Fact] + public async Task SubOrchestratorCancellation() + { + TaskName orchestratorName = nameof(SubOrchestratorCancellation); + TaskName subOrchestratorName = "SubOrchestrator"; + + bool subOrchestratorWasInvoked = false; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async ctx => + { + using CancellationTokenSource cts = new(); + + // Cancel immediately + cts.Cancel(); + + TaskOptions options = new() { CancellationToken = cts.Token }; + + try + { + await ctx.CallSubOrchestratorAsync(subOrchestratorName, options: options); + return "Should not reach here"; + } + catch (TaskCanceledException) + { + return "Cancelled"; + } + }) + .AddOrchestratorFunc(subOrchestratorName, ctx => + { + subOrchestratorWasInvoked = true; + return Task.FromResult("Sub-orchestrator completed"); + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("\"Cancelled\"", metadata.SerializedOutput); + Assert.False(subOrchestratorWasInvoked, "Sub-orchestrator should not have been invoked when cancellation happens before scheduling"); + } + + /// + /// Tests that cancellation token is passed to retry handler. + /// + [Fact] + public async Task RetryHandlerReceivesCancellationToken() + { + TaskName orchestratorName = nameof(RetryHandlerReceivesCancellationToken); + + int attemptCount = 0; + bool cancellationTokenWasCancelled = false; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async ctx => + { + using CancellationTokenSource cts = new(); + + TaskRetryOptions retryOptions = TaskOptions.FromRetryHandler(retryContext => + { + attemptCount = retryContext.LastAttemptNumber; + cancellationTokenWasCancelled = retryContext.CancellationToken.IsCancellationRequested; + + // Cancel after first attempt + if (attemptCount == 1) + { + cts.Cancel(); + } + + // Try to retry + return attemptCount < 5; + }).Retry!; + + TaskOptions options = new(retryOptions) + { + CancellationToken = cts.Token + }; + + try + { + await ctx.CallActivityAsync("FailingActivity", options); + return "Should not reach here"; + } + catch (TaskFailedException) + { + return $"Failed after {attemptCount} attempts"; + } + }) + .AddActivityFunc("FailingActivity", (TaskActivityContext activityContext) => + { + throw new InvalidOperationException("Activity always fails"); + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.True(attemptCount >= 1, "Retry handler should have been called at least once"); + Assert.True(cancellationTokenWasCancelled, "Cancellation token should have been cancelled in retry handler"); + } + + /// + /// Tests that retry handler can check cancellation token and stop retrying. + /// + [Fact] + public async Task RetryHandlerCanStopOnCancellation() + { + TaskName orchestratorName = nameof(RetryHandlerCanStopOnCancellation); + + int maxAttempts = 0; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async ctx => + { + using CancellationTokenSource cts = new(); + + TaskRetryOptions retryOptions = TaskOptions.FromRetryHandler(retryContext => + { + maxAttempts = retryContext.LastAttemptNumber; + + // Cancel after second attempt + if (maxAttempts == 2) + { + cts.Cancel(); + } + + // Stop retrying if cancelled + if (retryContext.CancellationToken.IsCancellationRequested) + { + return false; + } + + return maxAttempts < 10; + }).Retry!; + + TaskOptions options = new(retryOptions) + { + CancellationToken = cts.Token + }; + + try + { + await ctx.CallActivityAsync("FailingActivity", options); + return "Should not reach here"; + } + catch (TaskFailedException) + { + return $"Stopped after {maxAttempts} attempts"; + } + }) + .AddActivityFunc("FailingActivity", (TaskActivityContext activityContext) => + { + throw new InvalidOperationException("Activity always fails"); + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal(2, maxAttempts); // Should stop after 2 attempts due to cancellation + Assert.Equal("\"Stopped after 2 attempts\"", metadata.SerializedOutput); + } + + /// + /// Tests that when a token is cancelled outside the retry handler (between retry attempts), + /// the handler receives the cancelled token on the next attempt. + /// + [Fact] + public async Task RetryHandlerReceivesCancelledTokenFromOutside() + { + TaskName orchestratorName = nameof(RetryHandlerReceivesCancelledTokenFromOutside); + + int attemptCount = 0; + bool tokenWasCancelledInHandler = false; + CancellationTokenSource? cts = null; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async ctx => + { + cts = new CancellationTokenSource(); + + TaskRetryOptions retryOptions = TaskOptions.FromRetryHandler(retryContext => + { + attemptCount = retryContext.LastAttemptNumber; + + // Check if token is cancelled + tokenWasCancelledInHandler = retryContext.CancellationToken.IsCancellationRequested; + + // Stop retrying if cancelled + if (retryContext.CancellationToken.IsCancellationRequested) + { + return false; + } + + return attemptCount < 5; + }).Retry!; + + TaskOptions options = new(retryOptions) + { + CancellationToken = cts.Token + }; + + // Cancel the token AFTER creating options but BEFORE first attempt + // This tests that the retry handler receives the cancelled token from outside + cts.Cancel(); + + try + { + await ctx.CallActivityAsync("FailingActivity", options); + return "Should not reach here - activity succeeded"; + } + catch (TaskCanceledException) + { + // Pre-scheduling check caught the cancelled token before even attempting + return $"Cancelled before scheduling, attempts: {attemptCount}"; + } + catch (TaskFailedException) + { + // Activity failed and retry handler stopped retrying + return $"Failed after {attemptCount} attempts, token was cancelled in handler: {tokenWasCancelledInHandler}"; + } + }) + .AddActivityFunc("FailingActivity", (TaskActivityContext activityContext) => + { + throw new InvalidOperationException("Activity always fails"); + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + // Since token was cancelled before CallActivityAsync, the pre-scheduling check throws + // TaskCanceledException and retry handler never gets called + Assert.Equal(0, attemptCount); + Assert.Contains("Cancelled before scheduling", metadata.SerializedOutput); + } + + /// + /// Tests that when calling multiple activities in a loop with a cancellation token, + /// the loop exits after cancellation instead of continuing to call remaining activities. + /// This is the main use case for pre-scheduling cancellation checks. + /// + [Fact] + public async Task MultipleActivitiesInLoopWithCancellation() + { + TaskName orchestratorName = nameof(MultipleActivitiesInLoopWithCancellation); + TaskName activityName = "ProcessItem"; + + int activitiesInvoked = 0; + int totalItems = 10; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async ctx => + { + using CancellationTokenSource cts = new(); + TaskOptions options = new() { CancellationToken = cts.Token }; + + List results = new(); + + for (int i = 0; i < totalItems; i++) + { + // Cancel after processing 3 items + if (i == 3) + { + cts.Cancel(); + } + + try + { + string result = await ctx.CallActivityAsync(activityName, i, options); + results.Add(result); + } + catch (TaskCanceledException) + { + // Pre-scheduling check caught cancellation - exit loop + results.Add($"Cancelled at item {i}"); + break; + } + } + + return $"Processed {results.Count} items: [{string.Join(", ", results)}]"; + }) + .AddActivityFunc(activityName, (TaskActivityContext ctx, int item) => + { + Interlocked.Increment(ref activitiesInvoked); + return $"Item {item}"; + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + // Should have processed 3 items (0, 1, 2) before cancellation at item 3 + Assert.Equal(3, activitiesInvoked); + Assert.Contains("Processed 4 items", metadata.SerializedOutput); // 3 successful + 1 cancellation message + Assert.Contains("Cancelled at item 3", metadata.SerializedOutput); + Assert.Contains("Item 0", metadata.SerializedOutput); + Assert.Contains("Item 1", metadata.SerializedOutput); + Assert.Contains("Item 2", metadata.SerializedOutput); + } +}