Skip to content

Commit e3235e7

Browse files
committed
feat: Implement efficient replay mechanism
feat: Implement efficient replay mechanism
Parallel and Map now replay without using threading. Changes: - context.map and context.parallel now call child_handler directly and create their own operation identifier - `map_handler` and `parallel_handler` now check the state of the operation and, when it has already succeeded, use `executor.replay`. We test this by mocking the state, ensuring we first run through execute and then through replay, and verifying we get back the correct results.
1 parent f01f6eb commit e3235e7

File tree

8 files changed

+843
-129
lines changed

8 files changed

+843
-129
lines changed

src/aws_durable_execution_sdk_python/concurrency.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,5 +824,39 @@ def run_in_child_handler():
824824
),
825825
)
826826

827+
def replay(self, execution_state: ExecutionState, executor_context: DurableContext):
    """Replay children from checkpointed results rather than re-running them.

    Reaching this method means the operation is in ``replay_children``: each
    child's operation id is regenerated deterministically from its logical
    index, and the child's outcome is collected from the checkpoint store.

    :param execution_state: source of per-operation checkpointed results.
    :param executor_context: parent context used to derive child step ids.
    :return: a ``BatchResult`` assembled from the checkpointed outcomes.
    """
    items: list[BatchItem[ResultType]] = []
    for executable in self.executables:
        # Step ids are a pure function of the child's logical index, so the
        # ids generated here line up with the ids used on the original run.
        operation_id = executor_context._create_step_id_for_logical_step(  # noqa: SLF001
            executable.index
        )
        checkpoint = execution_state.get_checkpoint_result(operation_id)

        result: ResultType | None = None
        error = None
        status: BatchItemStatus
        if checkpoint.is_succeeded():
            status = BatchItemStatus.SUCCEEDED
            # NOTE(review): the child is driven through the child context even
            # though it already succeeded — presumably the child-level handler
            # short-circuits to the checkpointed value instead of recomputing.
            # Confirm against child_handler before relying on this.
            result = self._execute_item_in_child_context(
                executor_context, executable
            )

        elif checkpoint.is_failed():
            error = checkpoint.error
            status = BatchItemStatus.FAILED
        else:
            # Neither terminal state: the child was started but never
            # checkpointed a result.
            status = BatchItemStatus.STARTED

        batch_item = BatchItem(executable.index, status, result=result, error=error)
        items.append(batch_item)
    return BatchResult.from_items(items, self.completion_config)
860+
827861

828862
# endregion concurrency logic

src/aws_durable_execution_sdk_python/context.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,13 @@ def map(
317317
"""Execute a callable for each item in parallel."""
318318
map_name: str | None = self._resolve_step_name(name, func)
319319

320-
def map_in_child_context(map_context) -> BatchResult[R]:
320+
operation_id = self._create_step_id()
321+
operation_identifier = OperationIdentifier(
322+
operation_id=operation_id, parent_id=self._parent_id, name=map_name
323+
)
324+
map_context = self.create_child_context(parent_id=operation_id)
325+
326+
def map_in_child_context() -> BatchResult[R]:
321327
# map_context is a child_context of the context upon which `.map`
322328
# was called. We are calling it `map_context` to make it explicit
323329
# that any operations happening from hereon are done on the context
@@ -328,11 +334,13 @@ def map_in_child_context(map_context) -> BatchResult[R]:
328334
config=config,
329335
execution_state=self.state,
330336
map_context=map_context,
337+
operation_identifier=operation_identifier,
331338
)
332339

333-
return self.run_in_child_context(
340+
return child_handler(
334341
func=map_in_child_context,
335-
name=map_name,
342+
state=self.state,
343+
operation_identifier=operation_identifier,
336344
config=ChildConfig(
337345
sub_type=OperationSubType.MAP,
338346
serdes=config.serdes if config is not None else None,
@@ -346,8 +354,14 @@ def parallel(
346354
config: ParallelConfig | None = None,
347355
) -> BatchResult[T]:
348356
"""Execute multiple callables in parallel."""
357+
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
358+
operation_id = self._create_step_id()
359+
parallel_context = self.create_child_context(parent_id=operation_id)
360+
operation_identifier = OperationIdentifier(
361+
operation_id=operation_id, parent_id=self._parent_id, name=name
362+
)
349363

350-
def parallel_in_child_context(parallel_context) -> BatchResult[T]:
364+
def parallel_in_child_context() -> BatchResult[T]:
351365
# parallel_context is a child_context of the context upon which `.map`
352366
# was called. We are calling it `parallel_context` to make it explicit
353367
# that any operations happening from hereon are done on the context
@@ -357,11 +371,13 @@ def parallel_in_child_context(parallel_context) -> BatchResult[T]:
357371
config=config,
358372
execution_state=self.state,
359373
parallel_context=parallel_context,
374+
operation_identifier=operation_identifier,
360375
)
361376

362-
return self.run_in_child_context(
377+
return child_handler(
363378
func=parallel_in_child_context,
364-
name=name,
379+
state=self.state,
380+
operation_identifier=operation_identifier,
365381
config=ChildConfig(
366382
sub_type=OperationSubType.PARALLEL,
367383
serdes=config.serdes if config is not None else None,

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717

1818
if TYPE_CHECKING:
1919
from aws_durable_execution_sdk_python.context import DurableContext
20+
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
2021
from aws_durable_execution_sdk_python.serdes import SerDes
21-
from aws_durable_execution_sdk_python.state import ExecutionState
22+
from aws_durable_execution_sdk_python.state import (
23+
CheckpointedResult,
24+
ExecutionState,
25+
)
2226
from aws_durable_execution_sdk_python.types import SummaryGenerator
2327

2428
logger = logging.getLogger(__name__)
@@ -94,6 +98,7 @@ def map_handler(
9498
config: MapConfig | None,
9599
execution_state: ExecutionState,
96100
map_context: DurableContext,
101+
operation_identifier: OperationIdentifier,
97102
) -> BatchResult[R]:
98103
"""Execute a callable for each item in parallel."""
99104
# Summary Generator Construction (matches TypeScript implementation):
@@ -107,6 +112,13 @@ def map_handler(
107112
func=func,
108113
config=config or MapConfig(summary_generator=MapSummaryGenerator()),
109114
)
115+
116+
checkpoint: CheckpointedResult = execution_state.get_checkpoint_result(
117+
operation_identifier.operation_id
118+
)
119+
if checkpoint.is_succeeded():
120+
# if we've reached this point, then not only is the step succeeded, but it is also `replay_children`.
121+
return executor.replay(execution_state, map_context)
110122
# we are making it explicit that we are now executing within the map_context
111123
return executor.execute(execution_state, executor_context=map_context)
112124

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
if TYPE_CHECKING:
1515
from aws_durable_execution_sdk_python.concurrency import BatchResult
1616
from aws_durable_execution_sdk_python.context import DurableContext
17+
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1718
from aws_durable_execution_sdk_python.serdes import SerDes
1819
from aws_durable_execution_sdk_python.state import ExecutionState
1920
from aws_durable_execution_sdk_python.types import SummaryGenerator
@@ -82,6 +83,7 @@ def parallel_handler(
8283
config: ParallelConfig | None,
8384
execution_state: ExecutionState,
8485
parallel_context: DurableContext,
86+
operation_identifier: OperationIdentifier,
8587
) -> BatchResult[R]:
8688
"""Execute multiple operations in parallel."""
8789
# Summary Generator Construction (matches TypeScript implementation):
@@ -94,6 +96,12 @@ def parallel_handler(
9496
callables,
9597
config or ParallelConfig(summary_generator=ParallelSummaryGenerator()),
9698
)
99+
100+
checkpoint = execution_state.get_checkpoint_result(
101+
operation_identifier.operation_id
102+
)
103+
if checkpoint.is_succeeded():
104+
return executor.replay(execution_state, parallel_context)
97105
return executor.execute(execution_state, executor_context=parallel_context)
98106

99107

tests/concurrency_test.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@
2222
ExecutionCounters,
2323
TimerScheduler,
2424
)
25-
from aws_durable_execution_sdk_python.config import CompletionConfig
25+
from aws_durable_execution_sdk_python.config import CompletionConfig, MapConfig
2626
from aws_durable_execution_sdk_python.exceptions import (
2727
CallableRuntimeError,
2828
InvalidStateError,
2929
SuspendExecution,
3030
TimedSuspendExecution,
3131
)
3232
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
33+
from aws_durable_execution_sdk_python.operation.map import MapExecutor
3334

3435

3536
def test_batch_item_status_enum():
@@ -2535,3 +2536,143 @@ def create_child_context(operation_id):
25352536
assert all(
25362537
assoc1 == assoc2 for assoc1, assoc2 in combinations(associations_per_run, 2)
25372538
)
2539+
2540+
2541+
def test_concurrent_executor_replay_with_succeeded_operations():
    """Replay over an all-succeeded checkpoint set yields the cached results."""

    def func1(ctx, item, idx, items):
        return f"result_{item}"

    executor = MapExecutor.from_items(
        items=["a", "b"],
        func=func1,
        config=MapConfig(),
    )

    def checkpoint_for(operation_id):
        # Every operation is checkpointed as succeeded, carrying properly
        # serialized JSON payloads.
        checkpoint = Mock()
        checkpoint.is_succeeded.return_value = True
        checkpoint.is_failed.return_value = False
        checkpoint.is_replay_children.return_value = False
        checkpoint.is_existent.return_value = True
        checkpoint.result = f'"cached_result_{operation_id}"'  # JSON string
        return checkpoint

    state = Mock()
    state.durable_execution_arn = (
        "arn:aws:durable:us-east-1:123456789012:execution/test"
    )
    state.get_checkpoint_result = checkpoint_for

    # Parent context hands out deterministic per-index operation ids and a
    # child context that shares the same execution state.
    parent_context = Mock()
    parent_context._create_step_id_for_logical_step = (  # noqa
        lambda step: f"op_{step}"
    )
    child_context = Mock()
    child_context.state = state
    parent_context.create_child_context = Mock(return_value=child_context)
    parent_context._parent_id = "parent_id"  # noqa

    result = executor.replay(state, parent_context)

    assert isinstance(result, BatchResult)
    assert len(result.all) == 2
    assert result.all[0].status == BatchItemStatus.SUCCEEDED
    assert result.all[0].result == "cached_result_op_0"
    assert result.all[1].status == BatchItemStatus.SUCCEEDED
    assert result.all[1].result == "cached_result_op_1"
2598+
2599+
def test_concurrent_executor_replay_with_failed_operations():
    """Replay over a failed checkpoint surfaces the stored error."""

    def func1(ctx, item, idx, items):
        return f"result_{item}"

    executor = MapExecutor.from_items(
        items=["a"],
        func=func1,
        config=MapConfig(),
    )

    def failed_checkpoint(operation_id):
        checkpoint = Mock()
        checkpoint.is_succeeded.return_value = False
        checkpoint.is_failed.return_value = True
        checkpoint.error = Exception("Test error")
        return checkpoint

    state = Mock()
    state.get_checkpoint_result = failed_checkpoint

    ctx = Mock()
    ctx._create_step_id_for_logical_step = Mock(return_value="op_1")  # noqa: SLF001

    result = executor.replay(state, ctx)

    assert isinstance(result, BatchResult)
    assert len(result.all) == 1
    assert result.all[0].status == BatchItemStatus.FAILED
    assert result.all[0].error is not None
2636+
2637+
2638+
def test_concurrent_executor_replay_with_replay_children():
    """A succeeded checkpoint that needs replay drives the child execution path."""

    def func1(ctx, item, idx, items):
        return f"result_{item}"

    executor = MapExecutor.from_items(
        items=["a"],
        func=func1,
        config=MapConfig(),
    )

    def succeeded_checkpoint(operation_id):
        checkpoint = Mock()
        checkpoint.is_succeeded.return_value = True
        checkpoint.is_failed.return_value = False
        checkpoint.is_replay_children.return_value = True
        return checkpoint

    state = Mock()
    state.get_checkpoint_result = succeeded_checkpoint

    ctx = Mock()
    ctx._create_step_id_for_logical_step = Mock(return_value="op_1")  # noqa: SLF001

    # Running the item in its child context should yield the fresh value.
    with patch.object(
        executor, "_execute_item_in_child_context", return_value="re_executed_result"
    ):
        result = executor.replay(state, ctx)

    assert isinstance(result, BatchResult)
    assert len(result.all) == 1
    assert result.all[0].status == BatchItemStatus.SUCCEEDED
    assert result.all[0].result == "re_executed_result"

0 commit comments

Comments
 (0)