diff --git a/src/Abstractions/TaskOptions.cs b/src/Abstractions/TaskOptions.cs
index 7c0d54ee..cb94ac9b 100644
--- a/src/Abstractions/TaskOptions.cs
+++ b/src/Abstractions/TaskOptions.cs
@@ -39,6 +39,7 @@ public TaskOptions(TaskOptions options)
Check.NotNull(options);
this.Retry = options.Retry;
this.Tags = options.Tags;
+ this.CancellationToken = options.CancellationToken;
}
///
@@ -51,6 +52,80 @@ public TaskOptions(TaskOptions options)
///
public IDictionary? Tags { get; init; }
+ ///
+ /// Gets the cancellation token that can be used to cancel the task.
+ ///
+ ///
+ ///
+ /// The cancellation token provides cooperative cancellation for activities, sub-orchestrators, and retry logic.
+ /// Due to the durable orchestrator execution model, cancellation only occurs at specific points when the
+ /// orchestrator code is executing.
+ ///
+ ///
+ /// Cancellation behavior:
+ ///
+ ///
+ /// 1. Pre-scheduling check: If the token is cancelled before calling
+ /// CallActivityAsync or CallSubOrchestratorAsync, a is thrown
+ /// immediately without scheduling the task.
+ ///
+ ///
+ /// 2. Retry handlers: The cancellation token is passed to custom retry handlers via
+ /// , allowing them to check for cancellation and stop retrying between attempts.
+ ///
+ ///
+ /// Important limitation: Once an activity or sub-orchestrator is scheduled, the orchestrator
+ /// yields execution and waits for the task to complete. During this yield period, the orchestrator code is not
+ /// running, so it cannot respond to cancellation requests. Cancelling the token while waiting will not wake up
+ /// the orchestrator or cancel the waiting task. This is a fundamental limitation of the durable orchestrator
+ /// execution model.
+ ///
+ ///
+ /// Note: Cancelling a parent orchestrator's token does not terminate sub-orchestrator instances that have
+ /// already been scheduled.
+ ///
+ ///
+ /// Example of pre-scheduling cancellation:
+ ///
+ /// using CancellationTokenSource cts = new CancellationTokenSource();
+ /// cts.Cancel(); // Cancel before scheduling
+ ///
+ /// TaskOptions options = new TaskOptions { CancellationToken = cts.Token };
+ ///
+ /// try
+ /// {
+ /// // This will throw TaskCanceledException without scheduling the activity
+ /// string result = await context.CallActivityAsync<string>("MyActivity", "input", options);
+ /// }
+ /// catch (TaskCanceledException)
+ /// {
+ /// // Handle cancellation
+ /// }
+ ///
+ ///
+ ///
+ /// Example of using cancellation with retry logic:
+ ///
+ /// using CancellationTokenSource cts = new CancellationTokenSource();
+ /// TaskOptions options = new TaskOptions
+ /// {
+ /// Retry = TaskRetryOptions.FromRetryHandler(retryContext =>
+ /// {
+ /// if (retryContext.CancellationToken.IsCancellationRequested)
+ /// {
+ /// return false; // Stop retrying
+ /// }
+ /// return retryContext.LastAttemptNumber < 3;
+ /// }),
+ /// CancellationToken = cts.Token
+ /// };
+ ///
+ /// await context.CallActivityAsync("MyActivity", "input", options);
+ ///
+ ///
+ ///
+ public CancellationToken CancellationToken { get; init; }
+
///
/// Returns a new from the provided .
///
diff --git a/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs b/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs
index 98ba2721..5434ab01 100644
--- a/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs
+++ b/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs
@@ -143,15 +143,25 @@ public override async Task CallActivityAsync(
try
{
IDictionary tags = ImmutableDictionary.Empty;
+ CancellationToken cancellationToken = default;
if (options is TaskOptions callActivityOptions)
{
if (callActivityOptions.Tags is not null)
{
tags = callActivityOptions.Tags;
}
+
+ cancellationToken = callActivityOptions.CancellationToken;
+ }
+
+ // If cancellation was requested before starting, throw immediately
+ // Note: Once the activity is scheduled, the orchestrator yields and cannot respond to
+ // cancellation until it resumes, so this pre-check is the only cancellation point.
+ if (cancellationToken.IsCancellationRequested)
+ {
+ throw new TaskCanceledException("The task was cancelled before it could be scheduled.");
}
- // TODO: Cancellation (https://github.com/microsoft/durabletask-dotnet/issues/7)
#pragma warning disable 0618
if (options?.Retry?.Policy is RetryPolicy policy)
{
@@ -176,7 +186,7 @@ public override async Task CallActivityAsync(
parameters: input),
name.Name,
handler,
- default);
+ cancellationToken);
}
else
{
@@ -217,6 +227,16 @@ public override async Task CallSubOrchestratorAsync(
throw new InvalidOperationException(errorMsg);
}
+ CancellationToken cancellationToken = options?.CancellationToken ?? default;
+
+ // If cancellation was requested before starting, throw immediately
+ // Note: Once the sub-orchestrator is scheduled, the orchestrator yields and cannot respond to
+ // cancellation until it resumes, so this pre-check is the only cancellation point.
+ if (cancellationToken.IsCancellationRequested)
+ {
+ throw new TaskCanceledException("The sub-orchestrator was cancelled before it could be scheduled.");
+ }
+
try
{
if (options?.Retry?.Policy is RetryPolicy policy)
@@ -226,7 +246,7 @@ public override async Task CallSubOrchestratorAsync(
version,
instanceId,
policy.ToDurableTaskCoreRetryOptions(),
- input,
+ input,
options.Tags);
}
else if (options?.Retry?.Handler is AsyncRetryHandler handler)
@@ -236,11 +256,11 @@ public override async Task CallSubOrchestratorAsync(
orchestratorName.Name,
version,
instanceId,
- input,
+ input,
options?.Tags),
orchestratorName.Name,
handler,
- default);
+ cancellationToken);
}
else
{
@@ -248,7 +268,7 @@ public override async Task CallSubOrchestratorAsync(
orchestratorName.Name,
version,
instanceId,
- input,
+ input,
options?.Tags);
}
}
diff --git a/test/Grpc.IntegrationTests/CancellationTests.cs b/test/Grpc.IntegrationTests/CancellationTests.cs
new file mode 100644
index 00000000..c8e71937
--- /dev/null
+++ b/test/Grpc.IntegrationTests/CancellationTests.cs
@@ -0,0 +1,403 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using Microsoft.DurableTask.Client;
+using Microsoft.DurableTask.Tests.Logging;
+using Microsoft.DurableTask.Worker;
+using Xunit.Abstractions;
+
+namespace Microsoft.DurableTask.Grpc.Tests;
+
+///
+/// Integration tests for activity and sub-orchestrator cancellation functionality.
+///
+public class CancellationTests(ITestOutputHelper output, GrpcSidecarFixture sidecarFixture) :
+ IntegrationTestBase(output, sidecarFixture)
+{
+ ///
+ /// Tests that an activity can be cancelled using a CancellationToken.
+ ///
+ [Fact]
+ public async Task ActivityCancellation()
+ {
+ TaskName orchestratorName = nameof(ActivityCancellation);
+ TaskName activityName = "SlowActivity";
+
+ bool activityWasInvoked = false;
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks
+ .AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ using CancellationTokenSource cts = new();
+
+ // Cancel immediately
+ cts.Cancel();
+
+ TaskOptions options = new() { CancellationToken = cts.Token };
+
+ try
+ {
+ await ctx.CallActivityAsync(activityName, options);
+ return "Should not reach here";
+ }
+ catch (TaskCanceledException)
+ {
+ return "Cancelled";
+ }
+ })
+ .AddActivityFunc(activityName, (TaskActivityContext activityContext) =>
+ {
+ activityWasInvoked = true;
+ return "Activity completed";
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+
+ Assert.NotNull(metadata);
+ Assert.Equal(instanceId, metadata.InstanceId);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+ Assert.Equal("\"Cancelled\"", metadata.SerializedOutput);
+ Assert.False(activityWasInvoked, "Activity should not have been invoked when cancellation happens before scheduling");
+ }
+
+ ///
+ /// Tests that a sub-orchestrator can be cancelled using a CancellationToken.
+ ///
+ [Fact]
+ public async Task SubOrchestratorCancellation()
+ {
+ TaskName orchestratorName = nameof(SubOrchestratorCancellation);
+ TaskName subOrchestratorName = "SubOrchestrator";
+
+ bool subOrchestratorWasInvoked = false;
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks
+ .AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ using CancellationTokenSource cts = new();
+
+ // Cancel immediately
+ cts.Cancel();
+
+ TaskOptions options = new() { CancellationToken = cts.Token };
+
+ try
+ {
+ await ctx.CallSubOrchestratorAsync(subOrchestratorName, options: options);
+ return "Should not reach here";
+ }
+ catch (TaskCanceledException)
+ {
+ return "Cancelled";
+ }
+ })
+ .AddOrchestratorFunc(subOrchestratorName, ctx =>
+ {
+ subOrchestratorWasInvoked = true;
+ return Task.FromResult("Sub-orchestrator completed");
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+
+ Assert.NotNull(metadata);
+ Assert.Equal(instanceId, metadata.InstanceId);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+ Assert.Equal("\"Cancelled\"", metadata.SerializedOutput);
+ Assert.False(subOrchestratorWasInvoked, "Sub-orchestrator should not have been invoked when cancellation happens before scheduling");
+ }
+
+ ///
+ /// Tests that cancellation token is passed to retry handler.
+ ///
+ [Fact]
+ public async Task RetryHandlerReceivesCancellationToken()
+ {
+ TaskName orchestratorName = nameof(RetryHandlerReceivesCancellationToken);
+
+ int attemptCount = 0;
+ bool cancellationTokenWasCancelled = false;
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks
+ .AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ using CancellationTokenSource cts = new();
+
+ TaskRetryOptions retryOptions = TaskOptions.FromRetryHandler(retryContext =>
+ {
+ attemptCount = retryContext.LastAttemptNumber;
+ cancellationTokenWasCancelled = retryContext.CancellationToken.IsCancellationRequested;
+
+ // Cancel after first attempt
+ if (attemptCount == 1)
+ {
+ cts.Cancel();
+ }
+
+ // Try to retry
+ return attemptCount < 5;
+ }).Retry!;
+
+ TaskOptions options = new(retryOptions)
+ {
+ CancellationToken = cts.Token
+ };
+
+ try
+ {
+ await ctx.CallActivityAsync("FailingActivity", options);
+ return "Should not reach here";
+ }
+ catch (TaskFailedException)
+ {
+ return $"Failed after {attemptCount} attempts";
+ }
+ })
+ .AddActivityFunc("FailingActivity", (TaskActivityContext activityContext) =>
+ {
+ throw new InvalidOperationException("Activity always fails");
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+
+ Assert.NotNull(metadata);
+ Assert.Equal(instanceId, metadata.InstanceId);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+ Assert.True(attemptCount >= 1, "Retry handler should have been called at least once");
+ Assert.True(cancellationTokenWasCancelled, "Cancellation token should have been cancelled in retry handler");
+ }
+
+ ///
+ /// Tests that retry handler can check cancellation token and stop retrying.
+ ///
+ [Fact]
+ public async Task RetryHandlerCanStopOnCancellation()
+ {
+ TaskName orchestratorName = nameof(RetryHandlerCanStopOnCancellation);
+
+ int maxAttempts = 0;
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks
+ .AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ using CancellationTokenSource cts = new();
+
+ TaskRetryOptions retryOptions = TaskOptions.FromRetryHandler(retryContext =>
+ {
+ maxAttempts = retryContext.LastAttemptNumber;
+
+ // Cancel after second attempt
+ if (maxAttempts == 2)
+ {
+ cts.Cancel();
+ }
+
+ // Stop retrying if cancelled
+ if (retryContext.CancellationToken.IsCancellationRequested)
+ {
+ return false;
+ }
+
+ return maxAttempts < 10;
+ }).Retry!;
+
+ TaskOptions options = new(retryOptions)
+ {
+ CancellationToken = cts.Token
+ };
+
+ try
+ {
+ await ctx.CallActivityAsync("FailingActivity", options);
+ return "Should not reach here";
+ }
+ catch (TaskFailedException)
+ {
+ return $"Stopped after {maxAttempts} attempts";
+ }
+ })
+ .AddActivityFunc("FailingActivity", (TaskActivityContext activityContext) =>
+ {
+ throw new InvalidOperationException("Activity always fails");
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+
+ Assert.NotNull(metadata);
+ Assert.Equal(instanceId, metadata.InstanceId);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+ Assert.Equal(2, maxAttempts); // Should stop after 2 attempts due to cancellation
+ Assert.Equal("\"Stopped after 2 attempts\"", metadata.SerializedOutput);
+ }
+
+ ///
+ /// Tests that when a token is cancelled outside the retry handler (between retry attempts),
+ /// the handler receives the cancelled token on the next attempt.
+ ///
+ [Fact]
+ public async Task RetryHandlerReceivesCancelledTokenFromOutside()
+ {
+ TaskName orchestratorName = nameof(RetryHandlerReceivesCancelledTokenFromOutside);
+
+ int attemptCount = 0;
+ bool tokenWasCancelledInHandler = false;
+ CancellationTokenSource? cts = null;
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks
+ .AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ cts = new CancellationTokenSource();
+
+ TaskRetryOptions retryOptions = TaskOptions.FromRetryHandler(retryContext =>
+ {
+ attemptCount = retryContext.LastAttemptNumber;
+
+ // Check if token is cancelled
+ tokenWasCancelledInHandler = retryContext.CancellationToken.IsCancellationRequested;
+
+ // Stop retrying if cancelled
+ if (retryContext.CancellationToken.IsCancellationRequested)
+ {
+ return false;
+ }
+
+ return attemptCount < 5;
+ }).Retry!;
+
+ TaskOptions options = new(retryOptions)
+ {
+ CancellationToken = cts.Token
+ };
+
+ // Cancel the token AFTER creating options but BEFORE first attempt
+ // This tests that the retry handler receives the cancelled token from outside
+ cts.Cancel();
+
+ try
+ {
+ await ctx.CallActivityAsync("FailingActivity", options);
+ return "Should not reach here - activity succeeded";
+ }
+ catch (TaskCanceledException)
+ {
+ // Pre-scheduling check caught the cancelled token before even attempting
+ return $"Cancelled before scheduling, attempts: {attemptCount}";
+ }
+ catch (TaskFailedException)
+ {
+ // Activity failed and retry handler stopped retrying
+ return $"Failed after {attemptCount} attempts, token was cancelled in handler: {tokenWasCancelledInHandler}";
+ }
+ })
+ .AddActivityFunc("FailingActivity", (TaskActivityContext activityContext) =>
+ {
+ throw new InvalidOperationException("Activity always fails");
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+
+ Assert.NotNull(metadata);
+ Assert.Equal(instanceId, metadata.InstanceId);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+
+ // Since token was cancelled before CallActivityAsync, the pre-scheduling check throws
+ // TaskCanceledException and retry handler never gets called
+ Assert.Equal(0, attemptCount);
+ Assert.Contains("Cancelled before scheduling", metadata.SerializedOutput);
+ }
+
+ ///
+ /// Tests that when calling multiple activities in a loop with a cancellation token,
+ /// the loop exits after cancellation instead of continuing to call remaining activities.
+ /// This is the main use case for pre-scheduling cancellation checks.
+ ///
+ [Fact]
+ public async Task MultipleActivitiesInLoopWithCancellation()
+ {
+ TaskName orchestratorName = nameof(MultipleActivitiesInLoopWithCancellation);
+ TaskName activityName = "ProcessItem";
+
+ int activitiesInvoked = 0;
+ int totalItems = 10;
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks
+ .AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ using CancellationTokenSource cts = new();
+ TaskOptions options = new() { CancellationToken = cts.Token };
+
+ List results = new();
+
+ for (int i = 0; i < totalItems; i++)
+ {
+ // Cancel after processing 3 items
+ if (i == 3)
+ {
+ cts.Cancel();
+ }
+
+ try
+ {
+ string result = await ctx.CallActivityAsync(activityName, i, options);
+ results.Add(result);
+ }
+ catch (TaskCanceledException)
+ {
+ // Pre-scheduling check caught cancellation - exit loop
+ results.Add($"Cancelled at item {i}");
+ break;
+ }
+ }
+
+ return $"Processed {results.Count} items: [{string.Join(", ", results)}]";
+ })
+ .AddActivityFunc(activityName, (TaskActivityContext ctx, int item) =>
+ {
+ Interlocked.Increment(ref activitiesInvoked);
+ return $"Item {item}";
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+
+ Assert.NotNull(metadata);
+ Assert.Equal(instanceId, metadata.InstanceId);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+
+ // Should have processed 3 items (0, 1, 2) before cancellation at item 3
+ Assert.Equal(3, activitiesInvoked);
+ Assert.Contains("Processed 4 items", metadata.SerializedOutput); // 3 successful + 1 cancellation message
+ Assert.Contains("Cancelled at item 3", metadata.SerializedOutput);
+ Assert.Contains("Item 0", metadata.SerializedOutput);
+ Assert.Contains("Item 1", metadata.SerializedOutput);
+ Assert.Contains("Item 2", metadata.SerializedOutput);
+ }
+}