Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions src/aws_durable_execution_sdk_python/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@
SuspendExecution,
TimedSuspendExecution,
)
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
from aws_durable_execution_sdk_python.operation.child import child_handler
from aws_durable_execution_sdk_python.types import BatchResult as BatchResultProtocol

if TYPE_CHECKING:
from collections.abc import Callable

from aws_durable_execution_sdk_python.config import CompletionConfig
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
from aws_durable_execution_sdk_python.serdes import SerDes
from aws_durable_execution_sdk_python.state import ExecutionState
from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator
from aws_durable_execution_sdk_python.types import SummaryGenerator


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -615,12 +618,7 @@ def execute_item(
raise NotImplementedError

def execute(
self,
execution_state: ExecutionState,
run_in_child_context: Callable[
[Callable[[DurableContext], ResultType], str | None, ChildConfig | None],
ResultType,
],
self, execution_state: ExecutionState, executor_context: DurableContext
) -> BatchResult[ResultType]:
"""Execute items concurrently with event-driven state management."""
logger.debug(
Expand Down Expand Up @@ -649,7 +647,7 @@ def submit_task(executable_with_state: ExecutableWithState) -> None:
"""Submit task to the thread executor and mark its state as started."""
future = thread_executor.submit(
self._execute_item_in_child_context,
run_in_child_context,
executor_context,
executable_with_state.executable,
)
executable_with_state.run(future)
Expand Down Expand Up @@ -784,21 +782,42 @@ def _create_result(self) -> BatchResult[ResultType]:

def _execute_item_in_child_context(
self,
run_in_child_context: Callable[
[Callable[[DurableContext], ResultType], str | None, ChildConfig | None],
ResultType,
],
executor_context: DurableContext,
executable: Executable[CallableType],
) -> ResultType:
"""Execute a single item in a child context."""
"""
Execute a single item in a derived child context.

Instead of relying on `executor_context.run_in_child_context`,
we generate an operation_id for the child and then call `child_handler`
directly. This avoids the hidden mutation of the context's internal counter.
We can do this because we explicitly control the generation of the step_id,
deriving it from `executable.index`.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

phrasing.

"This avoids the hidden mutation of the context's internal counter.
we can do this because we explicitly control the generation of step_id and do it
using executable.index."



Invariant: the `operation_id` for a given executable is deterministic
and independent of execution order.
"""

operation_id = executor_context._create_step_id_for_logical_step( # noqa: SLF001
executable.index
)
name = f"{self.name_prefix}{executable.index}"
child_context = executor_context.create_child_context(operation_id)
operation_identifier = OperationIdentifier(
operation_id,
executor_context._parent_id, # noqa: SLF001
name,
)

def execute_in_child_context(child_context: DurableContext) -> ResultType:
def run_in_child_handler():
return self.execute_item(child_context, executable)

return run_in_child_context(
execute_in_child_context,
f"{self.name_prefix}{executable.index}",
ChildConfig(
return child_handler(
run_in_child_handler,
child_context.state,
operation_identifier=operation_identifier,
config=ChildConfig(
serdes=self.item_serdes or self.serdes,
sub_type=self.sub_type_iteration,
summary_generator=self.summary_generator,
Expand Down
30 changes: 22 additions & 8 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,17 +222,23 @@ def set_logger(self, new_logger: LoggerInterface):
info=self._log_info,
)

def _create_step_id_for_logical_step(self, step: int) -> str:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`step` may be better described as `counter`

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_create_step_id_with_preset_counter

"""
Generate a step_id based on the given logical step.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on the given counter.

"logical step" doesn't really have a well defined meaning.

This allows us to recover operation ids or even look
forward without changing the internal state of this context.
"""
step_id = f"{self._parent_id}-{step}" if self._parent_id else str(step)
return hashlib.blake2b(step_id.encode()).hexdigest()[:64]

def _create_step_id(self) -> str:
"""Generate a thread-safe step id, incrementing in order of invocation.

This method is an internal implementation detail. Do not rely on the exact format of
the id generated by this method. It is subject to change without notice.
"""
new_counter: int = self._step_counter.increment()
step_id = (
f"{self._parent_id}-{new_counter}" if self._parent_id else str(new_counter)
)
return hashlib.blake2b(step_id.encode()).hexdigest()[:64]
return self._create_step_id_for_logical_step(new_counter)

# region Operations

Expand Down Expand Up @@ -311,13 +317,17 @@ def map(
"""Execute a callable for each item in parallel."""
map_name: str | None = self._resolve_step_name(name, func)

def map_in_child_context(child_context) -> BatchResult[R]:
def map_in_child_context(map_context) -> BatchResult[R]:
# map_context is a child_context of the context upon which `.map`
# was called. We are calling it `map_context` to make it explicit
# that any operations happening from here on are done on the context
# that owns the branches
return map_handler(
items=inputs,
func=func,
config=config,
execution_state=self.state,
run_in_child_context=child_context.run_in_child_context,
map_context=map_context,
)

return self.run_in_child_context(
Expand All @@ -337,12 +347,16 @@ def parallel(
) -> BatchResult[T]:
"""Execute multiple callables in parallel."""

def parallel_in_child_context(child_context) -> BatchResult[T]:
def parallel_in_child_context(parallel_context) -> BatchResult[T]:
# parallel_context is a child_context of the context upon which `.parallel`
# was called. We are calling it `parallel_context` to make it explicit
# that any operations happening from here on are done on the context
# that owns the branches
return parallel_handler(
callables=functions,
config=config,
execution_state=self.state,
run_in_child_context=child_context.run_in_child_context,
parallel_context=parallel_context,
)

return self.run_in_child_context(
Expand Down
11 changes: 5 additions & 6 deletions src/aws_durable_execution_sdk_python/operation/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from aws_durable_execution_sdk_python.lambda_service import OperationSubType

if TYPE_CHECKING:
from aws_durable_execution_sdk_python.config import ChildConfig
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.serdes import SerDes
from aws_durable_execution_sdk_python.state import ExecutionState
from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator
from aws_durable_execution_sdk_python.types import SummaryGenerator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -93,9 +93,7 @@ def map_handler(
func: Callable,
config: MapConfig | None,
execution_state: ExecutionState,
run_in_child_context: Callable[
[Callable[[DurableContext], R], str | None, ChildConfig | None], R
],
map_context: DurableContext,
) -> BatchResult[R]:
"""Execute a callable for each item in parallel."""
# Summary Generator Construction (matches TypeScript implementation):
Expand All @@ -109,7 +107,8 @@ def map_handler(
func=func,
config=config or MapConfig(summary_generator=MapSummaryGenerator()),
)
return executor.execute(execution_state, run_in_child_context)
# we are making it explicit that we are now executing within the map_context
return executor.execute(execution_state, executor_context=map_context)


class MapSummaryGenerator:
Expand Down
10 changes: 4 additions & 6 deletions src/aws_durable_execution_sdk_python/operation/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

if TYPE_CHECKING:
from aws_durable_execution_sdk_python.concurrency import BatchResult
from aws_durable_execution_sdk_python.config import ChildConfig
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.serdes import SerDes
from aws_durable_execution_sdk_python.state import ExecutionState
from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator
from aws_durable_execution_sdk_python.types import SummaryGenerator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,9 +81,7 @@ def parallel_handler(
callables: Sequence[Callable],
config: ParallelConfig | None,
execution_state: ExecutionState,
run_in_child_context: Callable[
[Callable[[DurableContext], R], str | None, ChildConfig | None], R
],
parallel_context: DurableContext,
) -> BatchResult[R]:
"""Execute multiple operations in parallel."""
# Summary Generator Construction (matches TypeScript implementation):
Expand All @@ -96,7 +94,7 @@ def parallel_handler(
callables,
config or ParallelConfig(summary_generator=ParallelSummaryGenerator()),
)
return executor.execute(execution_state, run_in_child_context)
return executor.execute(execution_state, executor_context=parallel_context)


class ParallelSummaryGenerator:
Expand Down
Loading