Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 62 additions & 12 deletions src/aws_durable_execution_sdk_python/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import threading
import time
from abc import ABC, abstractmethod
from collections import Counter
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -98,16 +99,69 @@ class BatchResult(Generic[R], BatchResultProtocol[R]): # noqa: PYI059
completion_reason: CompletionReason

@classmethod
def from_dict(
    cls, data: dict, completion_config: CompletionConfig | None = None
) -> BatchResult[R]:
    """Deserialize a BatchResult from its dict representation.

    Args:
        data: Serialized form; must contain "all" (list of serialized
            BatchItems) and may contain "completionReason".
        completion_config: Optional config used to infer the completion
            reason when "completionReason" is absent from ``data``.

    Returns:
        The reconstructed BatchResult.
    """
    batch_items: list[BatchItem[R]] = [
        BatchItem.from_dict(item) for item in data["all"]
    ]

    completion_reason_value = data.get("completionReason")
    if completion_reason_value is None:
        # Infer completion reason from batch item statuses and completion
        # config. This aligns with the TypeScript implementation, which uses
        # the completion config to reconstruct the reason during replay.
        result = cls.from_items(batch_items, completion_config)
        logger.warning(
            "Missing completionReason in BatchResult deserialization, "
            "inferred '%s' from batch item statuses. "
            "This may indicate incomplete serialization data.",
            result.completion_reason.value,
        )
        return result

    completion_reason = CompletionReason(completion_reason_value)
    return cls(batch_items, completion_reason)

@classmethod
def from_items(
    cls,
    items: list[BatchItem[R]],
    completion_config: CompletionConfig | None = None,
) -> BatchResult[R]:
    """Build a BatchResult, inferring the completion reason from item statuses.

    Follows the same logic as the TypeScript implementation:
    - all items completed (no STARTED items left): ALL_COMPLETED
    - min_successful threshold met while items are still running:
      MIN_SUCCESSFUL_REACHED
    - otherwise: FAILURE_TOLERANCE_EXCEEDED

    Args:
        items: The batch items to wrap.
        completion_config: Optional config whose ``min_successful`` threshold
            enables the MIN_SUCCESSFUL_REACHED inference.

    Returns:
        A BatchResult over ``items`` with the inferred completion reason.
    """
    # Counter returns 0 for missing keys, so no .get() fallback is needed.
    counts = Counter(item.status for item in items)
    succeeded_count = counts[BatchItemStatus.SUCCEEDED]
    failed_count = counts[BatchItemStatus.FAILED]
    started_count = counts[BatchItemStatus.STARTED]

    completed_count = succeeded_count + failed_count
    total_count = started_count + completed_count

    if completed_count == total_count:
        # No STARTED items remain (this also covers the empty-items case).
        completion_reason = CompletionReason.ALL_COMPLETED
    elif (
        completion_config is not None
        and (min_successful := completion_config.min_successful) is not None
        and succeeded_count >= min_successful
    ):
        completion_reason = CompletionReason.MIN_SUCCESSFUL_REACHED
    else:
        # Items still running but the success threshold was not met:
        # assume the failure tolerance was exceeded.
        completion_reason = CompletionReason.FAILURE_TOLERANCE_EXCEEDED

    return cls(items, completion_reason)

def to_dict(self) -> dict:
return {
"all": [item.to_dict() for item in self.all],
Expand Down Expand Up @@ -163,19 +217,15 @@ def get_errors(self) -> list[ErrorObject]:

@property
def success_count(self) -> int:
    """Number of items with status SUCCEEDED."""
    return sum(1 for item in self.all if item.status is BatchItemStatus.SUCCEEDED)

@property
def failure_count(self) -> int:
    """Number of items with status FAILED."""
    return sum(1 for item in self.all if item.status is BatchItemStatus.FAILED)

@property
def started_count(self) -> int:
    """Number of items with status STARTED (not yet completed)."""
    return sum(1 for item in self.all if item.status is BatchItemStatus.STARTED)

@property
def total_count(self) -> int:
Expand Down
265 changes: 263 additions & 2 deletions tests/concurrency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,178 @@ def test_batch_result_from_dict_default_completion_reason():
# No completionReason provided
}

result = BatchResult.from_dict(data)
assert result.completion_reason == CompletionReason.ALL_COMPLETED
with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
result = BatchResult.from_dict(data)
assert result.completion_reason == CompletionReason.ALL_COMPLETED
# Verify warning was logged
mock_logger.warning.assert_called_once()
assert "Missing completionReason" in mock_logger.warning.call_args[0][0]


def test_batch_result_from_dict_infer_all_completed_all_succeeded():
    """Test BatchResult from_dict infers ALL_COMPLETED when all items succeeded."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {"index": 1, "status": "SUCCEEDED", "result": "result2", "error": None},
        ],
        # No completionReason provided
    }

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        # A warning is emitted whenever the reason had to be inferred.
        mock_logger.warning.assert_called_once()


def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed():
    """Test BatchResult from_dict infers FAILURE_TOLERANCE_EXCEEDED when all items failed."""
    error_data = {
        "message": "Test error",
        "type": "TestError",
        "data": None,
        "stackTrace": None,
    }
    data = {
        "all": [
            {"index": 0, "status": "FAILED", "result": None, "error": error_data},
            {"index": 1, "status": "FAILED", "result": None, "error": error_data},
        ],
        # No completionReason provided
    }

    # Even if everything has failed, if we've completed all items, then we've
    # finished as ALL_COMPLETED — matching the TypeScript implementation:
    # https://github.com/aws/aws-durable-execution-sdk-js/blob/f20396f24afa9d6539d8e5056ee851ac7ef62301/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.ts#L324-L335
    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        mock_logger.warning.assert_called_once()


def test_batch_result_from_dict_infer_all_completed_mixed_success_failure():
    """Test BatchResult from_dict infers ALL_COMPLETED when mix of success/failure but no started items."""
    error_data = {
        "message": "Test error",
        "type": "TestError",
        "data": None,
        "stackTrace": None,
    }
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {"index": 1, "status": "FAILED", "result": None, "error": error_data},
            {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None},
        ],
        # No completionReason provided
    }

    # When every item i satisfies hasCompleted(i), the batch terminates
    # with reason ALL_COMPLETED regardless of individual outcomes.
    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        mock_logger.warning.assert_called_once()


def test_batch_result_from_dict_infer_min_successful_reached_has_started():
    """Test BatchResult from_dict infers MIN_SUCCESSFUL_REACHED when items are still started."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {"index": 1, "status": "STARTED", "result": None, "error": None},
            {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None},
        ],
        # No completionReason provided
    }

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        # min_successful=1 is met (2 succeeded) while one item is still STARTED.
        result = BatchResult.from_dict(data, CompletionConfig(1))
        assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
        mock_logger.warning.assert_called_once()


def test_batch_result_from_dict_infer_empty_items():
    """Test BatchResult from_dict infers ALL_COMPLETED for empty items."""
    data = {
        "all": [],
        # No completionReason provided
    }

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        # Vacuously true: zero items means all items are completed.
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        mock_logger.warning.assert_called_once()


def test_batch_result_from_dict_with_explicit_completion_reason():
    """Test BatchResult from_dict uses explicit completionReason when provided."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}
        ],
        "completionReason": "MIN_SUCCESSFUL_REACHED",
    }

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
        # No warning should be logged when completionReason is provided
        mock_logger.warning.assert_not_called()


def test_batch_result_infer_completion_reason_edge_cases():
    """Test from_items completion-reason inference with various edge cases."""
    # Test with only started items and min_successful=0: the threshold is
    # trivially met, so MIN_SUCCESSFUL_REACHED is inferred.
    started_items = [
        BatchItem(0, BatchItemStatus.STARTED).to_dict(),
        BatchItem(1, BatchItemStatus.STARTED).to_dict(),
    ]
    items = {"all": started_items}
    batch = BatchResult.from_dict(items, CompletionConfig(0))
    # NOTE: this state is not reachable in practice with CompletionConfig(0);
    # the assertion pins the inference fallback behavior only.
    assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

    # Test with only started items and no completion config: with no
    # threshold to satisfy, the inference falls through to the
    # FAILURE_TOLERANCE_EXCEEDED default.
    started_items = [
        BatchItem(0, BatchItemStatus.STARTED).to_dict(),
        BatchItem(1, BatchItemStatus.STARTED).to_dict(),
    ]
    items = {"all": started_items}
    batch = BatchResult.from_dict(items)
    assert batch.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED

    # Test with only failed items: all completed, so ALL_COMPLETED.
    failed_items = [
        BatchItem(
            0, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
        ).to_dict(),
        BatchItem(
            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
        ).to_dict(),
    ]
    failed_items = {"all": failed_items}
    batch = BatchResult.from_dict(failed_items)
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED

    # Test with only succeeded items
    succeeded_items = [
        BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(),
        BatchItem(1, BatchItemStatus.SUCCEEDED, "result2").to_dict(),
    ]
    succeeded_items = {"all": succeeded_items}
    batch = BatchResult.from_dict(succeeded_items)
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED

    # Test with mixed but no started (all completed), via from_items directly
    mixed_items = [
        BatchItem(0, BatchItemStatus.SUCCEEDED, "result1"),
        BatchItem(
            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
        ),
    ]

    batch = BatchResult.from_items(mixed_items)
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED


def test_batch_result_get_results_empty():
Expand Down Expand Up @@ -1692,3 +1862,94 @@ def test_timer_scheduler_future_time_condition_false():

# Callback should not be called since time is in future
callback.assert_not_called()


def test_batch_result_from_dict_with_completion_config():
    """Test BatchResult from_dict with completion config parameter."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {"index": 1, "status": "STARTED", "result": None, "error": None},
        ],
        # No completionReason provided
    }

    # With started items, should infer MIN_SUCCESSFUL_REACHED
    completion_config = CompletionConfig(min_successful=1)

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data, completion_config)
        assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
        mock_logger.warning.assert_called_once()


def test_batch_result_from_dict_all_completed():
    """Test BatchResult from_dict infers ALL_COMPLETED when all items are completed."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
            {
                "index": 1,
                "status": "FAILED",
                "result": None,
                "error": {
                    "message": "error",
                    "type": "Error",
                    "data": None,
                    "stackTrace": None,
                },
            },
        ],
        # No completionReason provided
    }

    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
        result = BatchResult.from_dict(data)
        assert result.completion_reason == CompletionReason.ALL_COMPLETED
        mock_logger.warning.assert_called_once()


def test_batch_result_from_dict_backward_compatibility():
    """Test BatchResult from_dict maintains backward compatibility when no completion_config provided."""
    data = {
        "all": [
            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}
        ],
        "completionReason": "MIN_SUCCESSFUL_REACHED",
    }

    # Should work without completion_config parameter
    result = BatchResult.from_dict(data)
    assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

    # Should also work with None completion_config
    result2 = BatchResult.from_dict(data, None)
    assert result2.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED


def test_batch_result_infer_completion_reason_basic_cases():
    """Test from_items completion-reason inference with basic scenarios."""
    # Test with started items - should be MIN_SUCCESSFUL_REACHED
    items = {
        "all": [
            BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(),
            BatchItem(1, BatchItemStatus.STARTED).to_dict(),
        ]
    }
    batch = BatchResult.from_dict(items, CompletionConfig(1))
    assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

    # Test with all completed items - should be ALL_COMPLETED
    completed_items = [
        BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(),
        BatchItem(
            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
        ).to_dict(),
    ]
    completed_items = {"all": completed_items}
    batch = BatchResult.from_dict(completed_items, CompletionConfig(1))
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED

    # Test empty items - should be ALL_COMPLETED
    batch = BatchResult.from_dict({"all": []}, CompletionConfig(1))
    assert batch.completion_reason == CompletionReason.ALL_COMPLETED
Loading