Skip to content

Commit 139c240

Browse files
committed
[Parity]: Add completion reason inference
- Adds completion reason inference based on the rules outlined in the typescript implementation / spec. fixes: #36
1 parent 8bdb236 commit 139c240

2 files changed

Lines changed: 337 additions & 15 deletions

File tree

src/aws_durable_execution_sdk_python/concurrency.py

Lines changed: 74 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import threading
88
import time
99
from abc import ABC, abstractmethod
10+
from collections import Counter
1011
from concurrent.futures import Future, ThreadPoolExecutor
1112
from dataclasses import dataclass
1213
from enum import Enum
@@ -22,7 +23,7 @@
2223
from aws_durable_execution_sdk_python.types import BatchResult as BatchResultProtocol
2324

2425
if TYPE_CHECKING:
25-
from collections.abc import Callable
26+
from collections.abc import Callable, Iterable
2627

2728
from aws_durable_execution_sdk_python.config import CompletionConfig
2829
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
@@ -67,6 +68,42 @@ def suspend(exception: SuspendExecution) -> SuspendResult:
6768
return SuspendResult(should_suspend=True, exception=exception)
6869

6970

71+
def _get_completion_reason(
72+
items: Iterable[BatchItemStatus], completion_config: CompletionConfig | None = None
73+
) -> CompletionReason:
74+
"""
75+
Infer completion reason based on batch item statuses and completion config.
76+
77+
This follows the same logic as the TypeScript implementation:
78+
- If all items completed: ALL_COMPLETED
79+
- If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED
80+
- Otherwise: FAILURE_TOLERANCE_EXCEEDED
81+
"""
82+
83+
counts = Counter(items)
84+
succeeded_count = counts.get(BatchItemStatus.SUCCEEDED, 0)
85+
failed_count = counts.get(BatchItemStatus.FAILED, 0)
86+
started_count = counts.get(BatchItemStatus.STARTED, 0)
87+
88+
completed_count = succeeded_count + failed_count
89+
total_count = started_count + completed_count
90+
91+
# If all items completed (no started items), it's ALL_COMPLETED
92+
if completed_count == total_count:
93+
return CompletionReason.ALL_COMPLETED
94+
95+
# If we have completion config and minSuccessful threshold is met
96+
if (
97+
completion_config
98+
and (min_successful := completion_config.min_successful) is not None
99+
and succeeded_count >= min_successful
100+
):
101+
return CompletionReason.MIN_SUCCESSFUL_REACHED
102+
103+
# Otherwise, assume failure tolerance was exceeded
104+
return CompletionReason.FAILURE_TOLERANCE_EXCEEDED
105+
106+
70107
@dataclass(frozen=True)
71108
class BatchItem(Generic[R]):
72109
index: int
@@ -98,16 +135,44 @@ class BatchResult(Generic[R], BatchResultProtocol[R]): # noqa: PYI059
98135
completion_reason: CompletionReason
99136

100137
@classmethod
101-
def from_dict(cls, data: dict) -> BatchResult[R]:
138+
def from_dict(
139+
cls, data: dict, completion_config: CompletionConfig | None = None
140+
) -> BatchResult[R]:
102141
batch_items: list[BatchItem[R]] = [
103142
BatchItem.from_dict(item) for item in data["all"]
104143
]
105-
# TODO: is this valid? assuming completion reason is ALL_COMPLETED?
106-
completion_reason = CompletionReason(
107-
data.get("completionReason", "ALL_COMPLETED")
108-
)
144+
145+
completion_reason_value = data.get("completionReason")
146+
if completion_reason_value is None:
147+
# Infer completion reason from batch item statuses and completion config
148+
# This aligns with the TypeScript implementation that uses completion config
149+
# to accurately reconstruct the completion reason during replay
150+
result = cls.from_items(batch_items, completion_config)
151+
logger.warning(
152+
"Missing completionReason in BatchResult deserialization, "
153+
"inferred '%s' from batch item statuses. "
154+
"This may indicate incomplete serialization data.",
155+
result.completion_reason.value,
156+
)
157+
return result
158+
159+
completion_reason = CompletionReason(completion_reason_value)
109160
return cls(batch_items, completion_reason)
110161

162+
@classmethod
163+
def from_items(
164+
cls,
165+
items: Iterable[BatchItem[R]],
166+
completion_config: CompletionConfig | None = None,
167+
):
168+
items = list(items)
169+
# Infer completion reason from batch item statuses and completion config
170+
# This aligns with the TypeScript implementation that uses completion config
171+
# to accurately reconstruct the completion reason during replay
172+
statuses = (item.status for item in items)
173+
completion_reason = _get_completion_reason(statuses, completion_config)
174+
return cls(items, completion_reason)
175+
111176
def to_dict(self) -> dict:
112177
return {
113178
"all": [item.to_dict() for item in self.all],
@@ -163,19 +228,15 @@ def get_errors(self) -> list[ErrorObject]:
163228

164229
@property
165230
def success_count(self) -> int:
166-
return len(
167-
[item for item in self.all if item.status is BatchItemStatus.SUCCEEDED]
168-
)
231+
return sum(1 for item in self.all if item.status is BatchItemStatus.SUCCEEDED)
169232

170233
@property
171234
def failure_count(self) -> int:
172-
return len([item for item in self.all if item.status is BatchItemStatus.FAILED])
235+
return sum(1 for item in self.all if item.status is BatchItemStatus.FAILED)
173236

174237
@property
175238
def started_count(self) -> int:
176-
return len(
177-
[item for item in self.all if item.status is BatchItemStatus.STARTED]
178-
)
239+
return sum(1 for item in self.all if item.status is BatchItemStatus.STARTED)
179240

180241
@property
181242
def total_count(self) -> int:

0 commit comments

Comments
 (0)