Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.

Large diffs are not rendered by default.

477 changes: 477 additions & 0 deletions src/aws_durable_execution_sdk_python/concurrency/models.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/aws_durable_execution_sdk_python/operation/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Generic, TypeVar

from aws_durable_execution_sdk_python.concurrency import (
from aws_durable_execution_sdk_python.concurrency.impl import ConcurrentExecutor
from aws_durable_execution_sdk_python.concurrency.models import (
BatchResult,
ConcurrentExecutor,
Executable,
)
from aws_durable_execution_sdk_python.config import MapConfig
Expand Down
5 changes: 3 additions & 2 deletions src/aws_durable_execution_sdk_python/operation/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar

from aws_durable_execution_sdk_python.concurrency import ConcurrentExecutor, Executable
from aws_durable_execution_sdk_python.concurrency.impl import ConcurrentExecutor
from aws_durable_execution_sdk_python.concurrency.models import Executable
from aws_durable_execution_sdk_python.config import ParallelConfig
from aws_durable_execution_sdk_python.lambda_service import OperationSubType

if TYPE_CHECKING:
from aws_durable_execution_sdk_python.concurrency import BatchResult
from aws_durable_execution_sdk_python.concurrency.models import BatchResult
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.serdes import SerDes
Expand Down
19 changes: 18 additions & 1 deletion src/aws_durable_execution_sdk_python/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from enum import StrEnum
from typing import Any, Generic, Protocol, TypeVar

from aws_durable_execution_sdk_python.concurrency.models import BatchResult
from aws_durable_execution_sdk_python.exceptions import (
DurableExecutionsError,
ExecutionError,
Expand Down Expand Up @@ -62,6 +63,7 @@ class TypeTag(StrEnum):
TUPLE = "t"
LIST = "l"
DICT = "m"
BATCH_RESULT = "br"


@dataclass(frozen=True)
Expand Down Expand Up @@ -206,7 +208,14 @@ def dispatcher(self):

def encode(self, obj: Any) -> EncodedValue:
"""Encode container using dispatcher for recursive elements."""

match obj:
case BatchResult():
# Encode BatchResult as dict with special tag
return EncodedValue(
TypeTag.BATCH_RESULT,
self._wrap(obj.to_dict(), self.dispatcher).value,
)
case list():
return EncodedValue(
TypeTag.LIST, [self._wrap(v, self.dispatcher) for v in obj]
Expand All @@ -230,7 +239,13 @@ def encode(self, obj: Any) -> EncodedValue:

def decode(self, tag: TypeTag, value: Any) -> Any:
"""Decode container using dispatcher for recursive elements."""

match tag:
case TypeTag.BATCH_RESULT:
# Decode BatchResult from dict - value is already the dict structure
# First decode it as a dict to unwrap all nested EncodedValues
decoded_dict = self.decode(TypeTag.DICT, value)
return BatchResult.from_dict(decoded_dict)
case TypeTag.LIST:
if not isinstance(value, list):
msg = f"Expected list, got {type(value)}"
Expand Down Expand Up @@ -295,6 +310,8 @@ def encode(self, obj: Any) -> EncodedValue:
case list() | tuple() | dict():
return self.container_codec.encode(obj)
case _:
if isinstance(obj, BatchResult):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why here too? (see line 248)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oooopsies, remnant from cyclical dependency

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah will fix.

return self.container_codec.encode(obj)
msg = f"Unsupported type: {type(obj)}"
raise SerDesError(msg)

Expand All @@ -316,7 +333,7 @@ def decode(self, tag: TypeTag, value: Any) -> Any:
return self.decimal_codec.decode(tag, value)
case TypeTag.DATETIME | TypeTag.DATE:
return self.datetime_codec.decode(tag, value)
case TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT:
case TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT | TypeTag.BATCH_RESULT:
return self.container_codec.decode(tag, value)
case _:
msg = f"Unknown type tag: {tag}"
Expand Down
145 changes: 110 additions & 35 deletions tests/concurrency_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the concurrency module."""

import json
import random
import threading
import time
Expand All @@ -10,17 +11,19 @@

import pytest

from aws_durable_execution_sdk_python.concurrency import (
from aws_durable_execution_sdk_python.concurrency.impl import (
ConcurrentExecutor,
TimerScheduler,
)
from aws_durable_execution_sdk_python.concurrency.models import (
BatchItem,
BatchItemStatus,
BatchResult,
BranchStatus,
CompletionReason,
ConcurrentExecutor,
Executable,
ExecutableWithState,
ExecutionCounters,
TimerScheduler,
)
from aws_durable_execution_sdk_python.config import CompletionConfig, MapConfig
from aws_durable_execution_sdk_python.exceptions import (
Expand Down Expand Up @@ -102,28 +105,6 @@ def test_batch_item_from_dict():
assert item.error is None


def test_batch_item_from_dict_with_error():
    """Verify BatchItem.from_dict parses a failed item carrying an error payload."""
    payload = {
        "index": 1,
        "status": "FAILED",
        "result": None,
        "error": {
            "message": "Test error",
            "type": "TestError",
            "data": None,
            "stackTrace": None,
        },
    }

    parsed = BatchItem.from_dict(payload)

    assert parsed.index == 1
    assert parsed.status == BatchItemStatus.FAILED
    assert parsed.result is None
    assert parsed.error is not None


def test_batch_result_creation():
"""Test BatchResult creation."""
items = [
Expand Down Expand Up @@ -323,7 +304,9 @@ def test_batch_result_from_dict_default_completion_reason():
# No completionReason provided
}

with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
with patch(
"aws_durable_execution_sdk_python.concurrency.models.logger"
) as mock_logger:
result = BatchResult.from_dict(data)
assert result.completion_reason == CompletionReason.ALL_COMPLETED
# Verify warning was logged
Expand All @@ -341,7 +324,9 @@ def test_batch_result_from_dict_infer_all_completed_all_succeeded():
# No completionReason provided
}

with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
with patch(
"aws_durable_execution_sdk_python.concurrency.models.logger"
) as mock_logger:
result = BatchResult.from_dict(data)
assert result.completion_reason == CompletionReason.ALL_COMPLETED
mock_logger.warning.assert_called_once()
Expand All @@ -365,7 +350,9 @@ def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed():

# even if everything has failed, if we've completed all items, then we've finished as ALL_COMPLETED
# https://github.com/aws/aws-durable-execution-sdk-js/blob/f20396f24afa9d6539d8e5056ee851ac7ef62301/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.ts#L324-L335
with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
with patch(
"aws_durable_execution_sdk_python.concurrency.models.logger"
) as mock_logger:
result = BatchResult.from_dict(data)
assert result.completion_reason == CompletionReason.ALL_COMPLETED
mock_logger.warning.assert_called_once()
Expand All @@ -389,7 +376,9 @@ def test_batch_result_from_dict_infer_all_completed_mixed_success_failure():
}

# the logic is that when \every item i: hasCompleted(i) then terminate due to all_completed
with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
with patch(
"aws_durable_execution_sdk_python.concurrency.models.logger"
) as mock_logger:
result = BatchResult.from_dict(data)
assert result.completion_reason == CompletionReason.ALL_COMPLETED
mock_logger.warning.assert_called_once()
Expand All @@ -406,7 +395,9 @@ def test_batch_result_from_dict_infer_min_successful_reached_has_started():
# No completionReason provided
}

with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
with patch(
"aws_durable_execution_sdk_python.concurrency.models.logger"
) as mock_logger:
result = BatchResult.from_dict(data, CompletionConfig(1))
assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
mock_logger.warning.assert_called_once()
Expand All @@ -419,7 +410,9 @@ def test_batch_result_from_dict_infer_empty_items():
# No completionReason provided
}

with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
with patch(
"aws_durable_execution_sdk_python.concurrency.models.logger"
) as mock_logger:
result = BatchResult.from_dict(data)
assert result.completion_reason == CompletionReason.ALL_COMPLETED
mock_logger.warning.assert_called_once()
Expand All @@ -434,7 +427,9 @@ def test_batch_result_from_dict_with_explicit_completion_reason():
"completionReason": "MIN_SUCCESSFUL_REACHED",
}

with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
with patch(
"aws_durable_execution_sdk_python.concurrency.models.logger"
) as mock_logger:
result = BatchResult.from_dict(data)
assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
# No warning should be logged when completionReason is provided
Expand Down Expand Up @@ -2373,7 +2368,9 @@ def test_batch_result_from_dict_with_completion_config():
# With started items, should infer MIN_SUCCESSFUL_REACHED
completion_config = CompletionConfig(min_successful=1)

with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
with patch(
"aws_durable_execution_sdk_python.concurrency.models.logger"
) as mock_logger:
result = BatchResult.from_dict(data, completion_config)
assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
mock_logger.warning.assert_called_once()
Expand All @@ -2399,7 +2396,9 @@ def test_batch_result_from_dict_all_completed():
# No completionReason provided
}

with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
with patch(
"aws_durable_execution_sdk_python.concurrency.models.logger"
) as mock_logger:
result = BatchResult.from_dict(data)
assert result.completion_reason == CompletionReason.ALL_COMPLETED
mock_logger.warning.assert_called_once()
Expand Down Expand Up @@ -2520,7 +2519,7 @@ def create_child_context(operation_id):
executor_context.create_child_context = create_child_context

with patch(
"aws_durable_execution_sdk_python.concurrency.child_handler",
"aws_durable_execution_sdk_python.concurrency.impl.child_handler",
patched_child_handler,
):
executor.execute(execution_state, executor_context)
Expand Down Expand Up @@ -2676,3 +2675,79 @@ def mock_get_checkpoint_result(operation_id):
assert len(result.all) == 1
assert result.all[0].status == BatchItemStatus.SUCCEEDED
assert result.all[0].result == "re_executed_result"


def test_batch_item_from_dict_with_error():
    """Verify BatchItem.from_dict() restores a FAILED item together with its error."""
    # NOTE(review): this fixture's error keys are PascalCase ("ErrorType",
    # "ErrorMessage", "StackTrace"), unlike the camelCase error fixtures used
    # elsewhere in this file — confirm which schema the deserializer accepts.
    payload = {
        "index": 3,
        "status": "FAILED",
        "result": None,
        "error": {
            "ErrorType": "ValueError",
            "ErrorMessage": "bad value",
            "StackTrace": [],
        },
    }

    parsed = BatchItem.from_dict(payload)

    assert parsed.index == 3
    assert parsed.status == BatchItemStatus.FAILED
    assert parsed.error.type == "ValueError"
    assert parsed.error.message == "bad value"


def test_batch_result_with_mixed_statuses():
    """Round-trip a BatchResult whose items succeeded, failed, and merely started."""
    failure = ErrorObject(message="msg", type="E", data=None, stack_trace=[])
    items = [
        BatchItem(0, BatchItemStatus.SUCCEEDED, result="success"),
        BatchItem(1, BatchItemStatus.FAILED, error=failure),
        BatchItem(2, BatchItemStatus.STARTED),
    ]
    original = BatchResult(
        all=items,
        completion_reason=CompletionReason.FAILURE_TOLERANCE_EXCEEDED,
    )

    # Force a real JSON text round-trip rather than handing the dict straight back.
    round_tripped = BatchResult.from_dict(json.loads(json.dumps(original.to_dict())))

    assert len(round_tripped.all) == 3
    assert round_tripped.all[0].status == BatchItemStatus.SUCCEEDED
    assert round_tripped.all[1].status == BatchItemStatus.FAILED
    assert round_tripped.all[2].status == BatchItemStatus.STARTED
    assert round_tripped.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED


def test_batch_result_empty_list():
    """An item-less BatchResult survives a JSON round-trip intact."""
    original = BatchResult(all=[], completion_reason=CompletionReason.ALL_COMPLETED)

    restored = BatchResult.from_dict(json.loads(json.dumps(original.to_dict())))

    assert len(restored.all) == 0
    assert restored.completion_reason == CompletionReason.ALL_COMPLETED


def test_batch_result_complex_nested_data():
    """Nested dict/list payloads inside an item survive serialization unchanged."""
    payload = {
        "users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}],
        "metadata": {"count": 2, "timestamp": "2025-10-31"},
    }
    original = BatchResult(
        all=[BatchItem(0, BatchItemStatus.SUCCEEDED, result=payload)],
        completion_reason=CompletionReason.ALL_COMPLETED,
    )

    restored = BatchResult.from_dict(json.loads(json.dumps(original.to_dict())))

    assert restored.all[0].result == payload
    assert restored.all[0].result["users"][0]["name"] == "Alice"
Loading
Loading