Skip to content

Commit cf53784

Browse files
committed
Adjusted implementation per dapr version
linting
1 parent 3525018 commit cf53784

File tree

8 files changed

+28
-29
lines changed

8 files changed

+28
-29
lines changed

durabletask/aio/client.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ def __init__(self, *,
2424
log_handler: Optional[logging.Handler] = None,
2525
log_formatter: Optional[logging.Formatter] = None,
2626
secure_channel: bool = False,
27-
interceptors: Optional[Sequence[AioClientInterceptor]] = None,
28-
default_version: Optional[str] = None):
27+
interceptors: Optional[Sequence[AioClientInterceptor]] = None):
2928

3029
if interceptors is not None:
3130
interceptors = list(interceptors)
@@ -44,7 +43,6 @@ def __init__(self, *,
4443
self._channel = channel
4544
self._stub = stubs.TaskHubSidecarServiceStub(channel)
4645
self._logger = shared.get_logger("client", log_handler, log_formatter)
47-
self.default_version = default_version
4846

4947
async def aclose(self):
5048
await self._channel.close()
@@ -53,9 +51,7 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator
5351
input: Optional[TInput] = None,
5452
instance_id: Optional[str] = None,
5553
start_at: Optional[datetime] = None,
56-
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
57-
tags: Optional[dict[str, str]] = None,
58-
version: Optional[str] = None) -> str:
54+
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str:
5955

6056
name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
6157

@@ -64,9 +60,8 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator
6460
instanceId=instance_id if instance_id else uuid.uuid4().hex,
6561
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
6662
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
67-
version=helpers.get_string_value(version if version else self.default_version),
63+
version=helpers.get_string_value(None),
6864
orchestrationIdReusePolicy=reuse_id_policy,
69-
tags=tags
7065
)
7166

7267
self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
@@ -80,25 +75,30 @@ async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: boo
8075

8176
async def wait_for_orchestration_start(self, instance_id: str, *,
8277
fetch_payloads: bool = False,
83-
timeout: int = 60) -> Optional[OrchestrationState]:
78+
timeout: int = 0) -> Optional[OrchestrationState]:
8479
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
8580
try:
86-
self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
87-
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout)
81+
grpc_timeout = None if timeout == 0 else timeout
82+
self._logger.info(
83+
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.")
84+
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
8885
return new_orchestration_state(req.instanceId, res)
8986
except grpc.RpcError as rpc_error:
9087
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
88+
# Replace gRPC error with the built-in TimeoutError
9189
raise TimeoutError("Timed-out waiting for the orchestration to start")
9290
else:
9391
raise
9492

9593
async def wait_for_orchestration_completion(self, instance_id: str, *,
9694
fetch_payloads: bool = True,
97-
timeout: int = 60) -> Optional[OrchestrationState]:
95+
timeout: int = 0) -> Optional[OrchestrationState]:
9896
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
9997
try:
100-
self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
101-
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout)
98+
grpc_timeout = None if timeout == 0 else timeout
99+
self._logger.info(
100+
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.")
101+
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout)
102102
state = new_orchestration_state(req.instanceId, res)
103103
if not state:
104104
return None
@@ -114,6 +114,7 @@ async def wait_for_orchestration_completion(self, instance_id: str, *,
114114
return state
115115
except grpc.RpcError as rpc_error:
116116
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
117+
# Replace gRPC error with the built-in TimeoutError
117118
raise TimeoutError("Timed-out waiting for the orchestration to complete")
118119
else:
119120
raise

durabletask/task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def get_tasks(self) -> list[Task]:
283283
def on_child_completed(self, task: Task[T]):
284284
pass
285285

286+
286287
class WhenAllTask(CompositeTask[list[T]]):
287288
"""A task that completes when all of its child tasks complete."""
288289

durabletask/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,13 +880,13 @@ class ExecutionResults:
880880
actions: list[pb.OrchestratorAction]
881881
encoded_custom_status: Optional[str]
882882

883-
884883
def __init__(
885884
self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
886885
):
887886
self.actions = actions
888887
self.encoded_custom_status = encoded_custom_status
889888

889+
890890
class _OrchestrationExecutor:
891891
_generator: Optional[task.Orchestrator] = None
892892

tests/durabletask/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def test_get_grpc_channel_secure():
2121
get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS)
2222
mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)
2323

24+
2425
def test_get_grpc_channel_default_host_address():
2526
with patch('grpc.insecure_channel') as mock_channel:
2627
get_grpc_channel(None, False, interceptors=INTERCEPTORS)

tests/durabletask/test_orchestration_e2e.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
316316
output = "Recursive termination = {recurse}"
317317
task_hub_client.terminate_orchestration(instance_id, output=output, recursive=recurse)
318318

319-
320319
metadata = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30)
321320

322321
assert metadata is not None

tests/durabletask/test_orchestration_e2e_async.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import json
33
import threading
4-
import time
54
from datetime import timedelta
65

76
import pytest
@@ -16,13 +15,14 @@
1615
# durabletask-go --port 4001
1716
pytestmark = [pytest.mark.e2e, pytest.mark.anyio]
1817

18+
1919
@pytest.fixture
2020
def anyio_backend():
2121
return 'asyncio'
2222

2323

2424
async def test_empty_orchestration():
25-
25+
2626
invoked = False
2727

2828
def empty_orchestrator(ctx: task.OrchestrationContext, _):
@@ -35,7 +35,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
3535
w.start()
3636

3737
c = AsyncTaskHubGrpcClient()
38-
id = await c.schedule_new_orchestration(empty_orchestrator, tags={'Tagged': 'true'})
38+
id = await c.schedule_new_orchestration(empty_orchestrator)
3939
state = await c.wait_for_orchestration_completion(id, timeout=30)
4040
await c.aclose()
4141

@@ -59,7 +59,7 @@ def sequence(ctx: task.OrchestrationContext, start_val: int):
5959
numbers = [start_val]
6060
current = start_val
6161
for _ in range(10):
62-
current = yield ctx.call_activity(plus_one, input=current, tags={'Activity': 'PlusOne'})
62+
current = yield ctx.call_activity(plus_one, input=current)
6363
numbers.append(current)
6464
return numbers
6565

@@ -70,7 +70,7 @@ def sequence(ctx: task.OrchestrationContext, start_val: int):
7070
w.start()
7171

7272
client = AsyncTaskHubGrpcClient()
73-
id = await client.schedule_new_orchestration(sequence, input=1, tags={'Orchestration': 'Sequence'})
73+
id = await client.schedule_new_orchestration(sequence, input=1)
7474
state = await client.wait_for_orchestration_completion(id, timeout=30)
7575
await client.aclose()
7676

@@ -485,5 +485,3 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
485485
assert state.serialized_input is None
486486
assert state.serialized_output is None
487487
assert state.serialized_custom_status == "\"foobaz\""
488-
489-

tests/durabletask/test_orchestration_executor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
634634
yield ctx.call_sub_orchestrator(suborchestrator, input=None)
635635

636636
registry = worker._Registry()
637-
suborchestrator_name = registry.add_orchestrator(suborchestrator)
637+
registry.add_orchestrator(suborchestrator)
638638
orchestrator_name = registry.add_orchestrator(orchestrator)
639639

640640
exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)
@@ -666,7 +666,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
666666
yield ctx.call_sub_orchestrator(suborchestrator, input=None, app_id="target-app")
667667

668668
registry = worker._Registry()
669-
suborchestrator_name = registry.add_orchestrator(suborchestrator)
669+
registry.add_orchestrator(suborchestrator)
670670
orchestrator_name = registry.add_orchestrator(orchestrator)
671671

672672
exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)

tests/durabletask/test_orchestration_wait.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
from unittest.mock import patch, ANY, Mock
1+
from unittest.mock import Mock
22

33
from durabletask.client import TaskHubGrpcClient
4-
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
5-
from durabletask.internal.shared import (get_default_host_address,
6-
get_grpc_channel)
74
import pytest
85

6+
97
@pytest.mark.parametrize("timeout", [None, 0, 5])
108
def test_wait_for_orchestration_start_timeout(timeout):
119
instance_id = "test-instance"
@@ -34,6 +32,7 @@ def test_wait_for_orchestration_start_timeout(timeout):
3432
else:
3533
assert kwargs.get('timeout') == timeout
3634

35+
3736
@pytest.mark.parametrize("timeout", [None, 0, 5])
3837
def test_wait_for_orchestration_completion_timeout(timeout):
3938
instance_id = "test-instance"

0 commit comments

Comments
 (0)