|
7 | 7 | import threading |
8 | 8 | import time |
9 | 9 | from abc import ABC, abstractmethod |
| 10 | +from collections import Counter |
10 | 11 | from concurrent.futures import Future, ThreadPoolExecutor |
11 | 12 | from dataclasses import dataclass |
12 | 13 | from enum import Enum |
|
22 | 23 | from aws_durable_execution_sdk_python.types import BatchResult as BatchResultProtocol |
23 | 24 |
|
24 | 25 | if TYPE_CHECKING: |
25 | | - from collections.abc import Callable |
| 26 | + from collections.abc import Callable, Iterable |
26 | 27 |
|
27 | 28 | from aws_durable_execution_sdk_python.config import CompletionConfig |
28 | 29 | from aws_durable_execution_sdk_python.lambda_service import OperationSubType |
@@ -67,6 +68,42 @@ def suspend(exception: SuspendExecution) -> SuspendResult: |
67 | 68 | return SuspendResult(should_suspend=True, exception=exception) |
68 | 69 |
|
69 | 70 |
|
| 71 | +def _get_completion_reason( |
| 72 | + items: Iterable[BatchItemStatus], completion_config: CompletionConfig | None = None |
| 73 | +) -> CompletionReason: |
| 74 | + """ |
| 75 | + Infer completion reason based on batch item statuses and completion config. |
| 76 | +
|
| 77 | + This follows the same logic as the TypeScript implementation: |
| 78 | + - If all items completed: ALL_COMPLETED |
| 79 | + - If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED |
| 80 | + - Otherwise: FAILURE_TOLERANCE_EXCEEDED |
| 81 | + """ |
| 82 | + |
| 83 | + counts = Counter(items) |
| 84 | + succeeded_count = counts.get(BatchItemStatus.SUCCEEDED, 0) |
| 85 | + failed_count = counts.get(BatchItemStatus.FAILED, 0) |
| 86 | + started_count = counts.get(BatchItemStatus.STARTED, 0) |
| 87 | + |
| 88 | + completed_count = succeeded_count + failed_count |
| 89 | + total_count = started_count + completed_count |
| 90 | + |
| 91 | + # If all items completed (no started items), it's ALL_COMPLETED |
| 92 | + if completed_count == total_count: |
| 93 | + return CompletionReason.ALL_COMPLETED |
| 94 | + |
| 95 | + # If we have completion config and minSuccessful threshold is met |
| 96 | + if ( |
| 97 | + completion_config |
| 98 | + and (min_successful := completion_config.min_successful) is not None |
| 99 | + and succeeded_count >= min_successful |
| 100 | + ): |
| 101 | + return CompletionReason.MIN_SUCCESSFUL_REACHED |
| 102 | + |
| 103 | + # Otherwise, assume failure tolerance was exceeded |
| 104 | + return CompletionReason.FAILURE_TOLERANCE_EXCEEDED |
| 105 | + |
| 106 | + |
70 | 107 | @dataclass(frozen=True) |
71 | 108 | class BatchItem(Generic[R]): |
72 | 109 | index: int |
@@ -98,16 +135,44 @@ class BatchResult(Generic[R], BatchResultProtocol[R]): # noqa: PYI059 |
98 | 135 | completion_reason: CompletionReason |
99 | 136 |
|
100 | 137 | @classmethod |
101 | | - def from_dict(cls, data: dict) -> BatchResult[R]: |
| 138 | + def from_dict( |
| 139 | + cls, data: dict, completion_config: CompletionConfig | None = None |
| 140 | + ) -> BatchResult[R]: |
102 | 141 | batch_items: list[BatchItem[R]] = [ |
103 | 142 | BatchItem.from_dict(item) for item in data["all"] |
104 | 143 | ] |
105 | | - # TODO: is this valid? assuming completion reason is ALL_COMPLETED? |
106 | | - completion_reason = CompletionReason( |
107 | | - data.get("completionReason", "ALL_COMPLETED") |
108 | | - ) |
| 144 | + |
| 145 | + completion_reason_value = data.get("completionReason") |
| 146 | + if completion_reason_value is None: |
| 147 | + # Infer completion reason from batch item statuses and completion config |
| 148 | + # This aligns with the TypeScript implementation that uses completion config |
| 149 | + # to accurately reconstruct the completion reason during replay |
| 150 | + result = cls.from_items(batch_items, completion_config) |
| 151 | + logger.warning( |
| 152 | + "Missing completionReason in BatchResult deserialization, " |
| 153 | + "inferred '%s' from batch item statuses. " |
| 154 | + "This may indicate incomplete serialization data.", |
| 155 | + result.completion_reason.value, |
| 156 | + ) |
| 157 | + return result |
| 158 | + |
| 159 | + completion_reason = CompletionReason(completion_reason_value) |
109 | 160 | return cls(batch_items, completion_reason) |
110 | 161 |
|
| 162 | + @classmethod |
| 163 | + def from_items( |
| 164 | + cls, |
| 165 | + items: Iterable[BatchItem[R]], |
| 166 | + completion_config: CompletionConfig | None = None, |
| 167 | + ): |
| 168 | + items = list(items) |
| 169 | + # Infer completion reason from batch item statuses and completion config |
| 170 | + # This aligns with the TypeScript implementation that uses completion config |
| 171 | + # to accurately reconstruct the completion reason during replay |
| 172 | + statuses = (item.status for item in items) |
| 173 | + completion_reason = _get_completion_reason(statuses, completion_config) |
| 174 | + return cls(items, completion_reason) |
| 175 | + |
111 | 176 | def to_dict(self) -> dict: |
112 | 177 | return { |
113 | 178 | "all": [item.to_dict() for item in self.all], |
@@ -163,19 +228,15 @@ def get_errors(self) -> list[ErrorObject]: |
163 | 228 |
|
164 | 229 | @property |
165 | 230 | def success_count(self) -> int: |
166 | | - return len( |
167 | | - [item for item in self.all if item.status is BatchItemStatus.SUCCEEDED] |
168 | | - ) |
| 231 | + return sum(1 for item in self.all if item.status is BatchItemStatus.SUCCEEDED) |
169 | 232 |
|
170 | 233 | @property |
171 | 234 | def failure_count(self) -> int: |
172 | | - return len([item for item in self.all if item.status is BatchItemStatus.FAILED]) |
| 235 | + return sum(1 for item in self.all if item.status is BatchItemStatus.FAILED) |
173 | 236 |
|
174 | 237 | @property |
175 | 238 | def started_count(self) -> int: |
176 | | - return len( |
177 | | - [item for item in self.all if item.status is BatchItemStatus.STARTED] |
178 | | - ) |
| 239 | + return sum(1 for item in self.all if item.status is BatchItemStatus.STARTED) |
179 | 240 |
|
180 | 241 | @property |
181 | 242 | def total_count(self) -> int: |
|
0 commit comments