Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/aws_durable_execution_sdk_python/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,5 +824,39 @@ def run_in_child_handler():
),
)

def replay(self, execution_state: ExecutionState, executor_context: DurableContext):
    """
    Rebuild the batch result from checkpointed child results instead of re-running.

    Reaching this method means the parent operation is in replay_children:
    every child's operation id is pre-generated deterministically and its
    checkpointed outcome is collected into a BatchResult.
    """
    replayed: list[BatchItem[ResultType]] = []
    for entry in self.executables:
        step_id = executor_context._create_step_id_for_logical_step(  # noqa: SLF001
            entry.index
        )
        checkpoint = execution_state.get_checkpoint_result(step_id)

        outcome: ResultType | None = None
        failure = None
        item_status: BatchItemStatus
        if checkpoint.is_succeeded():
            item_status = BatchItemStatus.SUCCEEDED
            # Succeeded children are replayed through the child context so
            # their recorded results are surfaced rather than recomputed.
            outcome = self._execute_item_in_child_context(
                executor_context, entry
            )
        elif checkpoint.is_failed():
            failure = checkpoint.error
            item_status = BatchItemStatus.FAILED
        else:
            # No terminal checkpoint yet: the child is still in flight.
            item_status = BatchItemStatus.STARTED

        replayed.append(
            BatchItem(entry.index, item_status, result=outcome, error=failure)
        )
    return BatchResult.from_items(replayed, self.completion_config)


# endregion concurrency logic
28 changes: 22 additions & 6 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,13 @@ def map(
"""Execute a callable for each item in parallel."""
map_name: str | None = self._resolve_step_name(name, func)

def map_in_child_context(map_context) -> BatchResult[R]:
operation_id = self._create_step_id()
operation_identifier = OperationIdentifier(
operation_id=operation_id, parent_id=self._parent_id, name=map_name
)
map_context = self.create_child_context(parent_id=operation_id)

def map_in_child_context() -> BatchResult[R]:
# map_context is a child_context of the context upon which `.map`
# was called. We are calling it `map_context` to make it explicit
# that any operations happening from hereon are done on the context
Expand All @@ -328,11 +334,13 @@ def map_in_child_context(map_context) -> BatchResult[R]:
config=config,
execution_state=self.state,
map_context=map_context,
operation_identifier=operation_identifier,
)

return self.run_in_child_context(
return child_handler(
func=map_in_child_context,
name=map_name,
state=self.state,
operation_identifier=operation_identifier,
config=ChildConfig(
sub_type=OperationSubType.MAP,
serdes=config.serdes if config is not None else None,
Expand All @@ -346,8 +354,14 @@ def parallel(
config: ParallelConfig | None = None,
) -> BatchResult[T]:
"""Execute multiple callables in parallel."""
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._create_step_id()
parallel_context = self.create_child_context(parent_id=operation_id)
operation_identifier = OperationIdentifier(
operation_id=operation_id, parent_id=self._parent_id, name=name
)

def parallel_in_child_context(parallel_context) -> BatchResult[T]:
def parallel_in_child_context() -> BatchResult[T]:
# parallel_context is a child_context of the context upon which `.map`
# was called. We are calling it `parallel_context` to make it explicit
# that any operations happening from hereon are done on the context
Expand All @@ -357,11 +371,13 @@ def parallel_in_child_context(parallel_context) -> BatchResult[T]:
config=config,
execution_state=self.state,
parallel_context=parallel_context,
operation_identifier=operation_identifier,
)

return self.run_in_child_context(
return child_handler(
func=parallel_in_child_context,
name=name,
state=self.state,
operation_identifier=operation_identifier,
config=ChildConfig(
sub_type=OperationSubType.PARALLEL,
serdes=config.serdes if config is not None else None,
Expand Down
14 changes: 13 additions & 1 deletion src/aws_durable_execution_sdk_python/operation/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@

if TYPE_CHECKING:
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.serdes import SerDes
from aws_durable_execution_sdk_python.state import ExecutionState
from aws_durable_execution_sdk_python.state import (
CheckpointedResult,
ExecutionState,
)
from aws_durable_execution_sdk_python.types import SummaryGenerator

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -94,6 +98,7 @@ def map_handler(
config: MapConfig | None,
execution_state: ExecutionState,
map_context: DurableContext,
operation_identifier: OperationIdentifier,
) -> BatchResult[R]:
"""Execute a callable for each item in parallel."""
# Summary Generator Construction (matches TypeScript implementation):
Expand All @@ -107,6 +112,13 @@ def map_handler(
func=func,
config=config or MapConfig(summary_generator=MapSummaryGenerator()),
)

checkpoint: CheckpointedResult = execution_state.get_checkpoint_result(
operation_identifier.operation_id
)
if checkpoint.is_succeeded():
# if we've reached this point, then not only is the step succeeded, but it is also `replay_children`.
return executor.replay(execution_state, map_context)
# we are making it explicit that we are now executing within the map_context
return executor.execute(execution_state, executor_context=map_context)

Expand Down
8 changes: 8 additions & 0 deletions src/aws_durable_execution_sdk_python/operation/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
if TYPE_CHECKING:
from aws_durable_execution_sdk_python.concurrency import BatchResult
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.serdes import SerDes
from aws_durable_execution_sdk_python.state import ExecutionState
from aws_durable_execution_sdk_python.types import SummaryGenerator
Expand Down Expand Up @@ -82,6 +83,7 @@ def parallel_handler(
config: ParallelConfig | None,
execution_state: ExecutionState,
parallel_context: DurableContext,
operation_identifier: OperationIdentifier,
) -> BatchResult[R]:
"""Execute multiple operations in parallel."""
# Summary Generator Construction (matches TypeScript implementation):
Expand All @@ -94,6 +96,12 @@ def parallel_handler(
callables,
config or ParallelConfig(summary_generator=ParallelSummaryGenerator()),
)

checkpoint = execution_state.get_checkpoint_result(
operation_identifier.operation_id
)
if checkpoint.is_succeeded():
return executor.replay(execution_state, parallel_context)
return executor.execute(execution_state, executor_context=parallel_context)


Expand Down
143 changes: 142 additions & 1 deletion tests/concurrency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
ExecutionCounters,
TimerScheduler,
)
from aws_durable_execution_sdk_python.config import CompletionConfig
from aws_durable_execution_sdk_python.config import CompletionConfig, MapConfig
from aws_durable_execution_sdk_python.exceptions import (
CallableRuntimeError,
InvalidStateError,
SuspendExecution,
TimedSuspendExecution,
)
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
from aws_durable_execution_sdk_python.operation.map import MapExecutor


def test_batch_item_status_enum():
Expand Down Expand Up @@ -2535,3 +2536,143 @@ def create_child_context(operation_id):
assert all(
assoc1 == assoc2 for assoc1, assoc2 in combinations(associations_per_run, 2)
)


def test_concurrent_executor_replay_with_succeeded_operations():
    """Test ConcurrentExecutor replay method with succeeded operations."""

    def func1(ctx, item, idx, items):
        return f"result_{item}"

    executor = MapExecutor.from_items(
        items=["a", "b"],
        func=func1,
        config=MapConfig(),
    )

    # Execution state whose checkpoints all report success.
    state = Mock()
    state.durable_execution_arn = (
        "arn:aws:durable:us-east-1:123456789012:execution/test"
    )

    def fake_checkpoint(operation_id):
        checkpoint = Mock()
        checkpoint.is_succeeded.return_value = True
        checkpoint.is_failed.return_value = False
        checkpoint.is_replay_children.return_value = False
        checkpoint.is_existent.return_value = True
        # Properly serialized JSON payload, as a real checkpoint would hold.
        checkpoint.result = f'"cached_result_{operation_id}"'  # JSON string
        return checkpoint

    state.get_checkpoint_result = fake_checkpoint

    def fake_step_id(step):
        return f"op_{step}"

    # Executor context with deterministic per-step operation ids.
    context = Mock()
    context._create_step_id_for_logical_step = fake_step_id  # noqa

    # Child context sharing the same execution state.
    child = Mock()
    child.state = state
    context.create_child_context = Mock(return_value=child)
    context._parent_id = "parent_id"  # noqa

    result = executor.replay(state, context)

    assert isinstance(result, BatchResult)
    assert len(result.all) == 2
    assert result.all[0].status == BatchItemStatus.SUCCEEDED
    assert result.all[0].result == "cached_result_op_0"
    assert result.all[1].status == BatchItemStatus.SUCCEEDED
    assert result.all[1].result == "cached_result_op_1"


def test_concurrent_executor_replay_with_failed_operations():
    """Test ConcurrentExecutor replay method with failed operations."""

    def func1(ctx, item, idx, items):
        return f"result_{item}"

    executor = MapExecutor.from_items(
        items=["a"],
        func=func1,
        config=MapConfig(),
    )

    # Execution state whose single checkpoint reports a failure.
    state = Mock()

    def fake_checkpoint(operation_id):
        checkpoint = Mock()
        checkpoint.is_succeeded.return_value = False
        checkpoint.is_failed.return_value = True
        checkpoint.error = Exception("Test error")
        return checkpoint

    state.get_checkpoint_result = fake_checkpoint

    context = Mock()
    context._create_step_id_for_logical_step = Mock(return_value="op_1")  # noqa: SLF001

    result = executor.replay(state, context)

    assert isinstance(result, BatchResult)
    assert len(result.all) == 1
    assert result.all[0].status == BatchItemStatus.FAILED
    assert result.all[0].error is not None


def test_concurrent_executor_replay_with_replay_children():
    """Test ConcurrentExecutor replay method when children need re-execution."""

    def func1(ctx, item, idx, items):
        return f"result_{item}"

    executor = MapExecutor.from_items(
        items=["a"],
        func=func1,
        config=MapConfig(),
    )

    # Execution state whose checkpoint succeeded but flags replay of children.
    state = Mock()

    def fake_checkpoint(operation_id):
        checkpoint = Mock()
        checkpoint.is_succeeded.return_value = True
        checkpoint.is_failed.return_value = False
        checkpoint.is_replay_children.return_value = True
        return checkpoint

    state.get_checkpoint_result = fake_checkpoint

    context = Mock()
    context._create_step_id_for_logical_step = Mock(return_value="op_1")  # noqa: SLF001

    # Stub the child execution so replay surfaces its return value.
    with patch.object(
        executor, "_execute_item_in_child_context", return_value="re_executed_result"
    ):
        result = executor.replay(state, context)

    assert isinstance(result, BatchResult)
    assert len(result.all) == 1
    assert result.all[0].status == BatchItemStatus.SUCCEEDED
    assert result.all[0].result == "re_executed_result"
Loading
Loading