Skip to content

Commit be50716

Browse files
FullyTypedAstraea Quinn S
authored and committed
parity: Add completion reason inference
- Adds completion reason inference based on the rules outlined in the typescript implementation / spec. fixes: #36
1 parent 2e44b69 commit be50716

2 files changed

Lines changed: 325 additions & 14 deletions

File tree

src/aws_durable_execution_sdk_python/concurrency.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import threading
88
import time
99
from abc import ABC, abstractmethod
10+
from collections import Counter
1011
from concurrent.futures import Future, ThreadPoolExecutor
1112
from dataclasses import dataclass
1213
from enum import Enum
@@ -98,16 +99,69 @@ class BatchResult(Generic[R], BatchResultProtocol[R]): # noqa: PYI059
9899
completion_reason: CompletionReason
99100

100101
@classmethod
def from_dict(
    cls, data: dict, completion_config: CompletionConfig | None = None
) -> BatchResult[R]:
    """Deserialize a BatchResult from its dict representation.

    Args:
        data: Dict with an "all" list of serialized batch items and an
            optional "completionReason" string.
        completion_config: Optional completion config used to infer the
            completion reason when "completionReason" is absent.

    Returns:
        The reconstructed BatchResult.
    """
    batch_items: list[BatchItem[R]] = [
        BatchItem.from_dict(item) for item in data["all"]
    ]

    completion_reason_value = data.get("completionReason")
    if completion_reason_value is None:
        # Infer completion reason from batch item statuses and completion config
        # This aligns with the TypeScript implementation that uses completion config
        # to accurately reconstruct the completion reason during replay
        result = cls.from_items(batch_items, completion_config)
        logger.warning(
            "Missing completionReason in BatchResult deserialization, "
            "inferred '%s' from batch item statuses. "
            "This may indicate incomplete serialization data.",
            result.completion_reason.value,
        )
        return result

    # Explicit reason present: trust the serialized value (the CompletionReason
    # constructor raises ValueError for an unknown string).
    completion_reason = CompletionReason(completion_reason_value)
    return cls(batch_items, completion_reason)
110125

126+
@classmethod
def from_items(
    cls,
    items: list[BatchItem[R]],
    completion_config: CompletionConfig | None = None,
) -> BatchResult[R]:
    """
    Infer completion reason based on batch item statuses and completion config.

    This follows the same logic as the TypeScript implementation:
    - If all items completed: ALL_COMPLETED
    - If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED
    - Otherwise: FAILURE_TOLERANCE_EXCEEDED
    """

    statuses = (item.status for item in items)
    counts = Counter(statuses)
    succeeded_count = counts.get(BatchItemStatus.SUCCEEDED, 0)
    failed_count = counts.get(BatchItemStatus.FAILED, 0)
    started_count = counts.get(BatchItemStatus.STARTED, 0)

    # NOTE(review): only SUCCEEDED/FAILED/STARTED contribute to total_count;
    # any other BatchItemStatus member would be silently excluded — confirm
    # no other statuses can appear here.
    completed_count = succeeded_count + failed_count
    total_count = started_count + completed_count

    # If all items completed (no started items), it's ALL_COMPLETED
    # (an empty batch also lands here, 0 == 0)
    if completed_count == total_count:
        completion_reason = CompletionReason.ALL_COMPLETED
    elif (  # If we have completion config and minSuccessful threshold is met
        completion_config
        and (min_successful := completion_config.min_successful) is not None
        and succeeded_count >= min_successful
    ):
        completion_reason = CompletionReason.MIN_SUCCESSFUL_REACHED
    else:
        # Otherwise, assume failure tolerance was exceeded
        completion_reason = CompletionReason.FAILURE_TOLERANCE_EXCEEDED

    return cls(items, completion_reason)
164+
111165
def to_dict(self) -> dict:
112166
return {
113167
"all": [item.to_dict() for item in self.all],
@@ -163,19 +217,15 @@ def get_errors(self) -> list[ErrorObject]:
163217

164218
@property
def success_count(self) -> int:
    """Number of batch items whose status is SUCCEEDED."""
    return sum(item.status is BatchItemStatus.SUCCEEDED for item in self.all)
169221

170222
@property
def failure_count(self) -> int:
    """Number of batch items whose status is FAILED."""
    return sum(item.status is BatchItemStatus.FAILED for item in self.all)
173225

174226
@property
def started_count(self) -> int:
    """Number of batch items whose status is STARTED (not yet completed)."""
    return sum(item.status is BatchItemStatus.STARTED for item in self.all)
179229

180230
@property
181231
def total_count(self) -> int:

tests/concurrency_test.py

Lines changed: 263 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,178 @@ def test_batch_result_from_dict_default_completion_reason():
319319
# No completionReason provided
320320
}
321321

322-
result = BatchResult.from_dict(data)
323-
assert result.completion_reason == CompletionReason.ALL_COMPLETED
322+
with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
323+
result = BatchResult.from_dict(data)
324+
assert result.completion_reason == CompletionReason.ALL_COMPLETED
325+
# Verify warning was logged
326+
mock_logger.warning.assert_called_once()
327+
assert "Missing completionReason" in mock_logger.warning.call_args[0][0]
328+
329+
330+
def test_batch_result_from_dict_infer_all_completed_all_succeeded():
    """Test BatchResult from_dict infers ALL_COMPLETED when all items succeeded."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {"index": 1, "status": "SUCCEEDED", "result": "result2", "error": None},
        ],
        # No completionReason provided
    }

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        # Inference path always logs a warning about missing completionReason
        mock_logger.warning.assert_called_once()
345+
346+
def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed():
    """Test BatchResult from_dict infers ALL_COMPLETED when all items failed.

    NOTE(review): the test name says FAILURE_TOLERANCE_EXCEEDED but the
    expected reason is ALL_COMPLETED (all items reached a terminal state) —
    consider renaming for clarity.
    """
    error_data = {
        "message": "Test error",
        "type": "TestError",
        "data": None,
        "stackTrace": None,
    }
    data = {
        "all": [
            {"index": 0, "status": "FAILED", "result": None, "error": error_data},
            {"index": 1, "status": "FAILED", "result": None, "error": error_data},
        ],
        # No completionReason provided
    }

    # even if everything has failed, if we've completed all items, then we've finished as ALL_COMPLETED
    # https://github.com/aws/aws-durable-execution-sdk-js/blob/f20396f24afa9d6539d8e5056ee851ac7ef62301/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.ts#L324-L335
    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        mock_logger.warning.assert_called_once()
368+
369+
370+
def test_batch_result_from_dict_infer_all_completed_mixed_success_failure():
    """Test BatchResult from_dict infers ALL_COMPLETED when mix of success/failure but no started items."""
    error_data = {
        "message": "Test error",
        "type": "TestError",
        "data": None,
        "stackTrace": None,
    }
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {"index": 1, "status": "FAILED", "result": None, "error": error_data},
            {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None},
        ],
        # No completionReason provided
    }

    # the logic is that when every item i satisfies hasCompleted(i),
    # the batch terminates due to ALL_COMPLETED
    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        mock_logger.warning.assert_called_once()
392+
393+
394+
def test_batch_result_from_dict_infer_min_successful_reached_has_started():
    """Test BatchResult from_dict infers MIN_SUCCESSFUL_REACHED when items are still started."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {"index": 1, "status": "STARTED", "result": None, "error": None},
            {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None},
        ],
        # No completionReason provided
    }

    # min_successful=1 is satisfied (2 succeeded) while one item is still STARTED
    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data, CompletionConfig(1))
        assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
        mock_logger.warning.assert_called_once()
409+
410+
411+
def test_batch_result_from_dict_infer_empty_items():
    """Test BatchResult from_dict infers ALL_COMPLETED for empty items."""
    data = {
        "all": [],
        # No completionReason provided
    }

    # Empty batch: completed_count == total_count (0 == 0) -> ALL_COMPLETED
    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        mock_logger.warning.assert_called_once()
422+
423+
424+
def test_batch_result_from_dict_with_explicit_completion_reason():
    """Test BatchResult from_dict uses explicit completionReason when provided."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}
        ],
        "completionReason": "MIN_SUCCESSFUL_REACHED",
    }

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
        # No warning should be logged when completionReason is provided
        mock_logger.warning.assert_not_called()
438+
439+
440+
def test_batch_result_infer_completion_reason_edge_cases():
    """Test from_dict/from_items completion-reason inference with edge cases."""
    # Only started items with min_successful=0: threshold (0) is trivially met
    started_items = [
        BatchItem(0, BatchItemStatus.STARTED).to_dict(),
        BatchItem(1, BatchItemStatus.STARTED).to_dict(),
    ]
    items = {"all": started_items}
    batch = BatchResult.from_dict(items, CompletionConfig(0))
    # this state is not reachable in practice with CompletionConfig(0),
    # but the inference still resolves to MIN_SUCCESSFUL_REACHED
    assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

    # Only started items with no completion config: falls through to the
    # FAILURE_TOLERANCE_EXCEEDED default branch
    started_items = [
        BatchItem(0, BatchItemStatus.STARTED).to_dict(),
        BatchItem(1, BatchItemStatus.STARTED).to_dict(),
    ]
    items = {"all": started_items}
    batch = BatchResult.from_dict(items)
    assert batch.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED

    # Only failed items: all items are terminal, so ALL_COMPLETED
    failed_items = [
        BatchItem(
            0, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
        ).to_dict(),
        BatchItem(
            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
        ).to_dict(),
    ]
    failed_items = {"all": failed_items}
    batch = BatchResult.from_dict(failed_items)
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED

    # Only succeeded items: ALL_COMPLETED
    succeeded_items = [
        BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(),
        BatchItem(1, BatchItemStatus.SUCCEEDED, "result2").to_dict(),
    ]
    succeeded_items = {"all": succeeded_items}
    batch = BatchResult.from_dict(succeeded_items)
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED

    # Mixed success/failure but no started items (all completed),
    # going through from_items directly with BatchItem instances
    mixed_items = [
        BatchItem(0, BatchItemStatus.SUCCEEDED, "result1"),
        BatchItem(
            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
        ),
    ]

    batch = BatchResult.from_items(mixed_items)
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED
324494

325495

326496
def test_batch_result_get_results_empty():
@@ -1692,3 +1862,94 @@ def test_timer_scheduler_future_time_condition_false():
16921862

16931863
# Callback should not be called since time is in future
16941864
callback.assert_not_called()
1865+
1866+
1867+
def test_batch_result_from_dict_with_completion_config():
    """Test BatchResult from_dict with completion config parameter."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {"index": 1, "status": "STARTED", "result": None, "error": None},
        ],
        # No completionReason provided
    }

    # With started items, should infer MIN_SUCCESSFUL_REACHED
    completion_config = CompletionConfig(min_successful=1)

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data, completion_config)
        assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
        mock_logger.warning.assert_called_once()
1884+
1885+
1886+
def test_batch_result_from_dict_all_completed():
    """Test BatchResult from_dict infers ALL_COMPLETED when all items are completed."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {
                "index": 1,
                "status": "FAILED",
                "result": None,
                "error": {
                    "message": "error",
                    "type": "Error",
                    "data": None,
                    "stackTrace": None,
                },
            },
        ],
        # No completionReason provided
    }

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        mock_logger.warning.assert_called_once()
1910+
1911+
1912+
def test_batch_result_from_dict_backward_compatibility():
    """Test BatchResult from_dict maintains backward compatibility when no completion_config provided."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}
        ],
        "completionReason": "MIN_SUCCESSFUL_REACHED",
    }

    # Should work without completion_config parameter
    result = BatchResult.from_dict(data)
    assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

    # Should also work with None completion_config
    result2 = BatchResult.from_dict(data, None)
    assert result2.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
1928+
1929+
1930+
def test_batch_result_infer_completion_reason_basic_cases():
    """Test from_dict completion-reason inference with basic scenarios."""
    # Started items present and min_successful met - should be MIN_SUCCESSFUL_REACHED
    items = {
        "all": [
            BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(),
            BatchItem(1, BatchItemStatus.STARTED).to_dict(),
        ]
    }
    batch = BatchResult.from_dict(items, CompletionConfig(1))
    assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

    # All items completed - should be ALL_COMPLETED
    completed_items = [
        BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(),
        BatchItem(
            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
        ).to_dict(),
    ]
    completed_items = {"all": completed_items}
    batch = BatchResult.from_dict(completed_items, CompletionConfig(1))
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED

    # Empty items - should be ALL_COMPLETED
    batch = BatchResult.from_dict({"all": []}, CompletionConfig(1))
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED

0 commit comments

Comments
 (0)