Skip to content

Commit 7070cb0

Browse files
Merge pull request dapr#17 from passuied/feature/asyncio-dapr
Add async version of durabletask client
2 parents 06357df + 3d8528d commit 7070cb0

17 files changed

+959
-11
lines changed

.github/workflows/pr-validation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
- name: Install dependencies
2929
run: |
3030
python -m pip install --upgrade pip
31-
pip install flake8 pytest
31+
pip install flake8 pytest pytest-cov pytest-asyncio
3232
pip install -r requirements.txt
3333
- name: Lint with flake8
3434
run: |

dev-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python
1+
grpcio-tools==1.62.3 # pinned: 1.62.x is the last release before the protobuf 1.26.x switch, which has breaking changes for Python; keep aligned with the checked-in generated code

durabletask/aio/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .client import AsyncTaskHubGrpcClient

# Public API of the durabletask.aio package.
__all__ = [
    "AsyncTaskHubGrpcClient",
]

durabletask/aio/client.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) The Dapr Authors.
2+
# Licensed under the MIT License.
3+
4+
import logging
5+
import uuid
6+
from datetime import datetime
7+
from typing import Any, Optional, Sequence, Union
8+
9+
import grpc
10+
from google.protobuf import wrappers_pb2
11+
12+
import durabletask.internal.helpers as helpers
13+
import durabletask.internal.orchestrator_service_pb2 as pb
14+
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
15+
import durabletask.internal.shared as shared
16+
from durabletask.aio.internal.shared import get_grpc_aio_channel, ClientInterceptor
17+
from durabletask import task
18+
from durabletask.client import OrchestrationState, OrchestrationStatus, new_orchestration_state, TInput, TOutput
19+
from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl
20+
21+
22+
class AsyncTaskHubGrpcClient:
    """Asynchronous (grpc.aio) client for a Durable Task sidecar.

    Mirrors the synchronous ``TaskHubGrpcClient`` API with coroutine methods
    for scheduling, querying, waiting on, signalling, and purging
    orchestrations. Use as an async context manager, or call :meth:`aclose`
    explicitly to release the underlying gRPC channel.
    """

    def __init__(self, *,
                 host_address: Optional[str] = None,
                 metadata: Optional[list[tuple[str, str]]] = None,
                 log_handler: Optional[logging.Handler] = None,
                 log_formatter: Optional[logging.Formatter] = None,
                 secure_channel: bool = False,
                 interceptors: Optional[Sequence[ClientInterceptor]] = None):
        """Create the client and open a channel to the sidecar.

        Args:
            host_address: sidecar endpoint; a default is used when ``None``.
            metadata: key/value pairs attached to every outgoing call via a
                metadata-injecting interceptor.
            log_handler: optional handler for the client logger.
            log_formatter: optional formatter for the client logger.
            secure_channel: use TLS when True (an explicit protocol prefix in
                ``host_address`` may override this — see get_grpc_aio_channel).
            interceptors: additional grpc.aio client interceptors.
        """
        # Combine caller-supplied interceptors with a metadata-injecting
        # interceptor when metadata was provided.
        if interceptors is not None:
            interceptors = list(interceptors)
            if metadata is not None:
                interceptors.append(DefaultClientInterceptorImpl(metadata))
        elif metadata is not None:
            interceptors = [DefaultClientInterceptorImpl(metadata)]
        else:
            interceptors = None

        channel = get_grpc_aio_channel(
            host_address=host_address,
            secure_channel=secure_channel,
            interceptors=interceptors
        )
        self._channel = channel
        self._stub = stubs.TaskHubSidecarServiceStub(channel)
        self._logger = shared.get_logger("client", log_handler, log_formatter)

    async def aclose(self):
        """Close the underlying gRPC channel."""
        await self._channel.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()
        return False  # never suppress the caller's exception

    async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
                                         input: Optional[TInput] = None,
                                         instance_id: Optional[str] = None,
                                         start_at: Optional[datetime] = None,
                                         reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str:
        """Schedule a new orchestration instance and return its instance ID.

        ``orchestrator`` may be the registered orchestrator function or its
        name. When ``instance_id`` is omitted, a random hex UUID is generated.
        """
        name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)

        req = pb.CreateInstanceRequest(
            name=name,
            instanceId=instance_id if instance_id else uuid.uuid4().hex,
            input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
            scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
            version=helpers.get_string_value(None),
            orchestrationIdReusePolicy=reuse_id_policy,
        )

        self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
        res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
        return res.instanceId

    async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
        """Fetch the current state of an orchestration instance."""
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
        return new_orchestration_state(req.instanceId, res)

    async def wait_for_orchestration_start(self, instance_id: str, *,
                                           fetch_payloads: bool = False,
                                           timeout: int = 0) -> Optional[OrchestrationState]:
        """Wait until the instance starts running.

        A ``timeout`` of 0 waits indefinitely. Raises TimeoutError when the
        deadline elapses before the instance starts.
        """
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        try:
            # gRPC interprets timeout=None as "no deadline".
            grpc_timeout = None if timeout == 0 else timeout
            self._logger.info(
                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.")
            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
            return new_orchestration_state(req.instanceId, res)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:  # type: ignore
                # Replace gRPC error with the built-in TimeoutError
                raise TimeoutError("Timed-out waiting for the orchestration to start")
            else:
                raise

    async def wait_for_orchestration_completion(self, instance_id: str, *,
                                                fetch_payloads: bool = True,
                                                timeout: int = 0) -> Optional[OrchestrationState]:
        """Wait until the instance reaches a terminal state and return it.

        A ``timeout`` of 0 waits indefinitely. Raises TimeoutError when the
        deadline elapses before the instance completes.
        """
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        try:
            grpc_timeout = None if timeout == 0 else timeout
            self._logger.info(
                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.")
            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout)
            state = new_orchestration_state(req.instanceId, res)
            if not state:
                return None

            # Log a summary of how the orchestration ended.
            if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
                details = state.failure_details
                self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
            elif state.runtime_status == OrchestrationStatus.TERMINATED:
                self._logger.info(f"Instance '{instance_id}' was terminated.")
            elif state.runtime_status == OrchestrationStatus.COMPLETED:
                self._logger.info(f"Instance '{instance_id}' completed.")

            return state
        except grpc.RpcError as rpc_error:
            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:  # type: ignore
                # Replace gRPC error with the built-in TimeoutError
                raise TimeoutError("Timed-out waiting for the orchestration to complete")
            else:
                raise

    async def raise_orchestration_event(
            self,
            instance_id: str,
            event_name: str,
            *,
            data: Optional[Any] = None):
        """Deliver an external event to a running orchestration instance."""
        req = pb.RaiseEventRequest(
            instanceId=instance_id,
            name=event_name,
            # "is not None" (not plain truthiness) so falsy payloads such as
            # 0, "" or False are still serialized and delivered, matching
            # schedule_new_orchestration's input handling.
            input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data is not None else None)

        self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
        await self._stub.RaiseEvent(req)

    async def terminate_orchestration(self, instance_id: str, *,
                                      output: Optional[Any] = None,
                                      recursive: bool = True):
        """Terminate an instance (and, when recursive, its sub-orchestrations)."""
        req = pb.TerminateRequest(
            instanceId=instance_id,
            # "is not None" so falsy terminal outputs (0, "", False) are kept.
            output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output is not None else None,
            recursive=recursive)

        self._logger.info(f"Terminating instance '{instance_id}'.")
        await self._stub.TerminateInstance(req)

    async def suspend_orchestration(self, instance_id: str):
        """Suspend a running orchestration instance."""
        req = pb.SuspendRequest(instanceId=instance_id)
        self._logger.info(f"Suspending instance '{instance_id}'.")
        await self._stub.SuspendInstance(req)

    async def resume_orchestration(self, instance_id: str):
        """Resume a previously suspended orchestration instance."""
        req = pb.ResumeRequest(instanceId=instance_id)
        self._logger.info(f"Resuming instance '{instance_id}'.")
        await self._stub.ResumeInstance(req)

    async def purge_orchestration(self, instance_id: str, recursive: bool = True):
        """Purge the state of a completed instance (and, when recursive, its children)."""
        req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
        self._logger.info(f"Purging instance '{instance_id}'.")
        await self._stub.PurgeInstances(req)

durabletask/aio/internal/__init__.py

Whitespace-only changes.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) The Dapr Authors.
2+
# Licensed under the MIT License.
3+
4+
from collections import namedtuple
5+
6+
from grpc import aio as grpc_aio
7+
8+
9+
class _ClientCallDetails(
        namedtuple(
            '_ClientCallDetails',
            ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
        grpc_aio.ClientCallDetails):
    """Immutable grpc.aio.ClientCallDetails implementation.

    grpc.aio.ClientCallDetails is abstract; this namedtuple subclass lets the
    interceptor rebuild call details with modified metadata while carrying the
    remaining fields through unchanged.
    """
    pass
15+
16+
17+
class DefaultClientInterceptorImpl(
        grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor,
        grpc_aio.StreamUnaryClientInterceptor, grpc_aio.StreamStreamClientInterceptor):
    """Async gRPC client interceptor that appends a fixed set of metadata
    entries to every outgoing call, for all four RPC shapes."""

    def __init__(self, metadata: list[tuple[str, str]]):
        super().__init__()
        self._metadata = metadata

    def _intercept_call(self, client_call_details: _ClientCallDetails) -> grpc_aio.ClientCallDetails:
        # Nothing configured: pass the call details through untouched.
        if self._metadata is None:
            return client_call_details

        # Start from any metadata already present on the call, then append ours.
        combined = list(client_call_details.metadata or [])
        combined.extend(self._metadata)
        return _ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            combined,
            client_call_details.credentials,
            client_call_details.wait_for_ready,
            client_call_details.compression)

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        return await continuation(self._intercept_call(client_call_details), request)

    async def intercept_unary_stream(self, continuation, client_call_details, request):
        return await continuation(self._intercept_call(client_call_details), request)

    async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return await continuation(self._intercept_call(client_call_details), request_iterator)

    async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return await continuation(self._intercept_call(client_call_details), request_iterator)

durabletask/aio/internal/shared.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) The Dapr Authors.
2+
# Licensed under the MIT License.
3+
4+
from typing import Optional, Sequence, Union
5+
6+
import grpc
7+
from grpc import aio as grpc_aio
8+
9+
from durabletask.internal.shared import (
10+
get_default_host_address,
11+
SECURE_PROTOCOLS,
12+
INSECURE_PROTOCOLS,
13+
)
14+
15+
16+
# Any of the four grpc.aio client interceptor base classes is acceptable
# wherever a client interceptor is expected (e.g. get_grpc_aio_channel).
ClientInterceptor = Union[
    grpc_aio.UnaryUnaryClientInterceptor,
    grpc_aio.UnaryStreamClientInterceptor,
    grpc_aio.StreamUnaryClientInterceptor,
    grpc_aio.StreamStreamClientInterceptor
]
22+
23+
24+
def get_grpc_aio_channel(
        host_address: Optional[str],
        secure_channel: bool = False,
        interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc_aio.Channel:
    """Open a grpc.aio channel to *host_address*.

    Falls back to the default sidecar address when *host_address* is None.
    A recognized protocol prefix on the address (secure or insecure) overrides
    the *secure_channel* flag and is stripped before dialing.
    """
    address = host_address if host_address is not None else get_default_host_address()

    # A secure protocol prefix forces TLS regardless of the flag.
    for prefix in SECURE_PROTOCOLS:
        if address.lower().startswith(prefix):
            secure_channel = True
            address = address[len(prefix):]
            break

    # Likewise, an insecure prefix forces plaintext.
    for prefix in INSECURE_PROTOCOLS:
        if address.lower().startswith(prefix):
            secure_channel = False
            address = address[len(prefix):]
            break

    if secure_channel:
        return grpc_aio.secure_channel(address, grpc.ssl_channel_credentials(), interceptors=interceptors)
    return grpc_aio.insecure_channel(address, interceptors=interceptors)

durabletask/task.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,18 @@ def get_tasks(self) -> list[Task]:
283283
def on_child_completed(self, task: Task[T]):
284284
pass
285285

286+
286287
class WhenAllTask(CompositeTask[list[T]]):
287288
"""A task that completes when all of its child tasks complete."""
288289

289290
def __init__(self, tasks: list[Task[T]]):
    """Initialize completion counters for the child tasks."""
    super().__init__(tasks)
    self._completed_tasks = 0
    self._failed_tasks = 0
    # An empty WhenAll has nothing to wait on, so it completes right away
    # with an empty result list.
    if not self._tasks:
        self._result = []  # type: ignore[assignment]
        self._is_complete = True
293298

294299
@property
295300
def pending_tasks(self) -> int:
@@ -387,6 +392,10 @@ class WhenAnyTask(CompositeTask[Task]):
387392

388393
def __init__(self, tasks: list[Task]):
    """Initialize the composite; completes when any child completes."""
    super().__init__(tasks)
    # With no children there is nothing to race, so mark the composite as
    # finished immediately. NOTE(review): the empty-list result does not
    # match the declared Task result type — confirm callers tolerate it.
    if not self._tasks:
        self._result = []  # type: ignore[assignment]
        self._is_complete = True
390399

391400
def on_child_completed(self, task: Task):
392401
# The first task to complete is the result of the WhenAnyTask.

durabletask/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,13 +880,13 @@ class ExecutionResults:
880880
actions: list[pb.OrchestratorAction]
881881
encoded_custom_status: Optional[str]
882882

883-
884883
def __init__(
885884
self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
886885
):
887886
self.actions = actions
888887
self.encoded_custom_status = encoded_custom_status
889888

889+
890890
class _OrchestrationExecutor:
891891
_generator: Optional[task.Orchestrator] = None
892892

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
autopep8
22
grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible
33
protobuf
4+
asyncio
45
pytest
56
pytest-cov
6-
asyncio
7+
pytest-asyncio
8+
flake8

0 commit comments

Comments
 (0)