Skip to content

Commit b67f72c

Browse files
committed
feat: Implement efficient replay mechanism
Parallel and Map now replays without using threading. Changes: - context.map and context.parallel now call child_handler directly and create their own operation identifier - `map_handler` and `parallel_handler` now checks the state of the operation and when we've already succeeded, we use `executor.replay` We test this by mocking the state, and ensuring we first run through execute and then through replay and get back the correct results.
1 parent f01f6eb commit b67f72c

File tree

8 files changed

+915
-128
lines changed

8 files changed

+915
-128
lines changed

src/aws_durable_execution_sdk_python/concurrency.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
2323
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
2424
from aws_durable_execution_sdk_python.operation.child import child_handler
25+
from aws_durable_execution_sdk_python.serdes import deserialize
2526
from aws_durable_execution_sdk_python.types import BatchResult as BatchResultProtocol
2627

2728
if TYPE_CHECKING:
@@ -824,5 +825,46 @@ def run_in_child_handler():
824825
),
825826
)
826827

828+
def replay(self, execution_state: ExecutionState, executor_context: DurableContext):
829+
# if we are here, then we are in replay_children
830+
# we will generate again all operation_ids for the children, and collect them
831+
items: list[BatchItem[ResultType]] = []
832+
for executable in self.executables:
833+
operation_id = executor_context._create_step_id_for_logical_step( # noqa: SLF001
834+
executable.index
835+
)
836+
checkpoint = execution_state.get_checkpoint_result(operation_id)
837+
838+
result: ResultType | None = None
839+
error = None
840+
status: BatchItemStatus
841+
if checkpoint.is_succeeded():
842+
status = BatchItemStatus.SUCCEEDED
843+
if checkpoint.is_replay_children():
844+
result = self._execute_item_in_child_context(
845+
executor_context, executable
846+
)
847+
else:
848+
serialized_result = checkpoint.result
849+
if serialized_result is None:
850+
result = None
851+
else:
852+
result = deserialize( # type: ignore[assignment]
853+
serdes=self.item_serdes or self.serdes,
854+
data=serialized_result,
855+
operation_id=operation_id,
856+
durable_execution_arn=execution_state.durable_execution_arn,
857+
)
858+
859+
elif checkpoint.is_failed():
860+
error = checkpoint.error
861+
status = BatchItemStatus.FAILED
862+
else:
863+
status = BatchItemStatus.STARTED
864+
865+
batch_item = BatchItem(executable.index, status, result=result, error=error)
866+
items.append(batch_item)
867+
return BatchResult.from_items(items, self.completion_config)
868+
827869

828870
# endregion concurrency logic

src/aws_durable_execution_sdk_python/context.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,13 @@ def map(
317317
"""Execute a callable for each item in parallel."""
318318
map_name: str | None = self._resolve_step_name(name, func)
319319

320-
def map_in_child_context(map_context) -> BatchResult[R]:
320+
operation_id = self._create_step_id()
321+
operation_identifier = OperationIdentifier(
322+
operation_id=operation_id, parent_id=self._parent_id, name=map_name
323+
)
324+
map_context = self.create_child_context(parent_id=operation_id)
325+
326+
def map_in_child_context() -> BatchResult[R]:
321327
# map_context is a child_context of the context upon which `.map`
322328
# was called. We are calling it `map_context` to make it explicit
323329
# that any operations happening from hereon are done on the context
@@ -328,11 +334,13 @@ def map_in_child_context(map_context) -> BatchResult[R]:
328334
config=config,
329335
execution_state=self.state,
330336
map_context=map_context,
337+
operation_identifier=operation_identifier,
331338
)
332339

333-
return self.run_in_child_context(
340+
return child_handler(
334341
func=map_in_child_context,
335-
name=map_name,
342+
state=self.state,
343+
operation_identifier=operation_identifier,
336344
config=ChildConfig(
337345
sub_type=OperationSubType.MAP,
338346
serdes=config.serdes if config is not None else None,
@@ -346,8 +354,12 @@ def parallel(
346354
config: ParallelConfig | None = None,
347355
) -> BatchResult[T]:
348356
"""Execute multiple callables in parallel."""
357+
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
358+
operation_id = self._create_step_id()
359+
parallel_context = self.create_child_context(parent_id=operation_id)
360+
operation_identifier: OperationIdentifier
349361

350-
def parallel_in_child_context(parallel_context) -> BatchResult[T]:
362+
def parallel_in_child_context() -> BatchResult[T]:
351363
# parallel_context is a child_context of the context upon which `.map`
352364
# was called. We are calling it `parallel_context` to make it explicit
353365
# that any operations happening from hereon are done on the context
@@ -357,11 +369,18 @@ def parallel_in_child_context(parallel_context) -> BatchResult[T]:
357369
config=config,
358370
execution_state=self.state,
359371
parallel_context=parallel_context,
372+
operation_identifier=operation_identifier,
360373
)
361374

362-
return self.run_in_child_context(
375+
step_name: str | None = self._resolve_step_name(name, parallel_in_child_context)
376+
operation_identifier = OperationIdentifier(
377+
operation_id=operation_id, parent_id=self._parent_id, name=step_name
378+
)
379+
380+
return child_handler(
363381
func=parallel_in_child_context,
364-
name=name,
382+
state=self.state,
383+
operation_identifier=operation_identifier,
365384
config=ChildConfig(
366385
sub_type=OperationSubType.PARALLEL,
367386
serdes=config.serdes if config is not None else None,

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
if TYPE_CHECKING:
1919
from aws_durable_execution_sdk_python.context import DurableContext
20+
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
2021
from aws_durable_execution_sdk_python.serdes import SerDes
2122
from aws_durable_execution_sdk_python.state import ExecutionState
2223
from aws_durable_execution_sdk_python.types import SummaryGenerator
@@ -94,6 +95,7 @@ def map_handler(
9495
config: MapConfig | None,
9596
execution_state: ExecutionState,
9697
map_context: DurableContext,
98+
operation_identifier: OperationIdentifier,
9799
) -> BatchResult[R]:
98100
"""Execute a callable for each item in parallel."""
99101
# Summary Generator Construction (matches TypeScript implementation):
@@ -107,6 +109,14 @@ def map_handler(
107109
func=func,
108110
config=config or MapConfig(summary_generator=MapSummaryGenerator()),
109111
)
112+
113+
checkpoint = execution_state.get_checkpoint_result(
114+
operation_identifier.operation_id
115+
)
116+
if checkpoint.is_succeeded():
117+
# if we've reached this point, then not only is the step succeeded, but it is also `replay_children`.
118+
# `map_handler` only executes on replay children and gives back
119+
return executor.replay(execution_state, map_context)
110120
# we are making it explicit that we are now executing within the map_context
111121
return executor.execute(execution_state, executor_context=map_context)
112122

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
if TYPE_CHECKING:
1515
from aws_durable_execution_sdk_python.concurrency import BatchResult
1616
from aws_durable_execution_sdk_python.context import DurableContext
17+
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1718
from aws_durable_execution_sdk_python.serdes import SerDes
1819
from aws_durable_execution_sdk_python.state import ExecutionState
1920
from aws_durable_execution_sdk_python.types import SummaryGenerator
@@ -82,6 +83,7 @@ def parallel_handler(
8283
config: ParallelConfig | None,
8384
execution_state: ExecutionState,
8485
parallel_context: DurableContext,
86+
operation_identifier: OperationIdentifier,
8587
) -> BatchResult[R]:
8688
"""Execute multiple operations in parallel."""
8789
# Summary Generator Construction (matches TypeScript implementation):
@@ -94,6 +96,12 @@ def parallel_handler(
9496
callables,
9597
config or ParallelConfig(summary_generator=ParallelSummaryGenerator()),
9698
)
99+
100+
checkpoint = execution_state.get_checkpoint_result(
101+
operation_identifier.operation_id
102+
)
103+
if checkpoint.is_succeeded():
104+
return executor.replay(execution_state, parallel_context)
97105
return executor.execute(execution_state, executor_context=parallel_context)
98106

99107

0 commit comments

Comments
 (0)