Skip to content

Commit 1996c35

Browse files
author
Astraea Quinn S
authored
Merge branch 'main' into exit-paths
2 parents 90a4dcd + 1171521 commit 1996c35

File tree

16 files changed

+1332
-137
lines changed

16 files changed

+1332
-137
lines changed

src/aws_durable_execution_sdk_python/concurrency.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import heapq
66
import logging
77
import threading
8-
import time
98
from abc import ABC, abstractmethod
109
from collections import Counter
1110
from concurrent.futures import Future, ThreadPoolExecutor
1211
from dataclasses import dataclass
12+
from datetime import UTC, datetime
1313
from enum import Enum
1414
from typing import TYPE_CHECKING, Generic, Self, TypeVar
1515

@@ -258,7 +258,7 @@ def __init__(self, executable: Executable[CallableType]):
258258
self.executable = executable
259259
self._status = BranchStatus.PENDING
260260
self._future: Future | None = None
261-
self._suspend_until: float | None = None
261+
self._suspend_until: datetime | None = None
262262
self._result: ResultType = None # type: ignore[assignment]
263263
self._is_result_set: bool = False
264264
self._error: Exception | None = None
@@ -293,7 +293,7 @@ def error(self) -> Exception:
293293
return self._error
294294

295295
@property
296-
def suspend_until(self) -> float | None:
296+
def suspend_until(self) -> datetime | None:
297297
"""Get suspend timestamp."""
298298
return self._suspend_until
299299

@@ -308,7 +308,7 @@ def can_resume(self) -> bool:
308308
return self._status is BranchStatus.SUSPENDED or (
309309
self._status is BranchStatus.SUSPENDED_WITH_TIMEOUT
310310
and self._suspend_until is not None
311-
and time.time() >= self._suspend_until
311+
and datetime.now(UTC) >= self._suspend_until
312312
)
313313

314314
@property
@@ -333,7 +333,7 @@ def suspend(self) -> None:
333333
self._status = BranchStatus.SUSPENDED
334334
self._suspend_until = None
335335

336-
def suspend_with_timeout(self, timestamp: float) -> None:
336+
def suspend_with_timeout(self, timestamp: datetime) -> None:
337337
"""Transition to SUSPENDED_WITH_TIMEOUT state."""
338338
self._status = BranchStatus.SUSPENDED_WITH_TIMEOUT
339339
self._suspend_until = timestamp
@@ -507,11 +507,11 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
507507
self.shutdown()
508508

509509
def schedule_resume(
510-
self, exe_state: ExecutableWithState, resume_time: float
510+
self, exe_state: ExecutableWithState, resume_time: datetime
511511
) -> None:
512512
"""Schedule a task to resume at the specified time."""
513513
with self._lock:
514-
heapq.heappush(self._pending_resumes, (resume_time, exe_state))
514+
heapq.heappush(self._pending_resumes, (resume_time.timestamp(), exe_state))
515515

516516
def shutdown(self) -> None:
517517
"""Shutdown the timer thread and cancel all pending resumes."""
@@ -534,7 +534,7 @@ def _timer_loop(self) -> None:
534534
self._shutdown.wait(timeout=0.1)
535535
continue
536536

537-
current_time = time.time()
537+
current_time = datetime.now(UTC).timestamp()
538538
if current_time >= next_resume_time:
539539
# Time to resume
540540
with self._lock:
@@ -566,6 +566,7 @@ def __init__(
566566
sub_type_iteration: OperationSubType,
567567
name_prefix: str,
568568
serdes: SerDes | None,
569+
item_serdes: SerDes | None = None,
569570
summary_generator: SummaryGenerator | None = None,
570571
):
571572
"""Initialize ConcurrentExecutor.
@@ -604,6 +605,7 @@ def __init__(
604605
)
605606
self.executables_with_state: list[ExecutableWithState] = []
606607
self.serdes = serdes
608+
self.item_serdes = item_serdes
607609

608610
@abstractmethod
609611
def execute_item(
@@ -673,7 +675,7 @@ def on_done(future: Future) -> None:
673675

674676
def should_execution_suspend(self) -> SuspendResult:
675677
"""Check if execution should suspend."""
676-
earliest_timestamp: float = float("inf")
678+
earliest_timestamp: datetime | None = None
677679
indefinite_suspend_task: (
678680
ExecutableWithState[CallableType, ResultType] | None
679681
) = None
@@ -683,16 +685,16 @@ def should_execution_suspend(self) -> SuspendResult:
683685
# Exit here! Still have tasks that can make progress, don't suspend.
684686
return SuspendResult.do_not_suspend()
685687
if exe_state.status is BranchStatus.SUSPENDED_WITH_TIMEOUT:
686-
if (
687-
exe_state.suspend_until
688-
and exe_state.suspend_until < earliest_timestamp
688+
if exe_state.suspend_until and (
689+
earliest_timestamp is None
690+
or exe_state.suspend_until < earliest_timestamp
689691
):
690692
earliest_timestamp = exe_state.suspend_until
691693
elif exe_state.status is BranchStatus.SUSPENDED:
692694
indefinite_suspend_task = exe_state
693695

694696
# All tasks are in final states and at least one of them is a suspend.
695-
if earliest_timestamp != float("inf"):
697+
if earliest_timestamp is not None:
696698
return SuspendResult.suspend(
697699
TimedSuspendExecution(
698700
"All concurrent work complete or suspended pending retry.",
@@ -797,7 +799,7 @@ def execute_in_child_context(child_context: DurableContext) -> ResultType:
797799
execute_in_child_context,
798800
f"{self.name_prefix}{executable.index}",
799801
ChildConfig(
800-
serdes=self.serdes,
802+
serdes=self.item_serdes or self.serdes,
801803
sub_type=self.sub_type_iteration,
802804
summary_generator=self.summary_generator,
803805
),

src/aws_durable_execution_sdk_python/config.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,21 @@ class ParallelConfig:
125125
Default is CompletionConfig.all_successful() which requires all branches
126126
to succeed. Other options include first_successful() and all_completed().
127127
128-
serdes: Custom serialization/deserialization configuration for parallel results.
129-
If None, uses the default serializer. This allows custom handling of
130-
complex result types or optimization for large result sets.
128+
serdes: Custom serialization/deserialization configuration for BatchResult.
129+
Applied at the handler level to serialize the entire BatchResult object.
130+
If None, uses the default JSON serializer for BatchResult.
131+
132+
Backward Compatibility: If only 'serdes' is provided (no item_serdes),
133+
it will be used for both individual functions AND BatchResult serialization
134+
to maintain existing behavior.
135+
136+
item_serdes: Custom serialization/deserialization configuration for individual functions.
137+
Applied to each function's result as tasks complete in child contexts.
138+
If None, uses the default JSON serializer for individual function results.
139+
140+
When both 'serdes' and 'item_serdes' are provided:
141+
- item_serdes: Used for individual function results in child contexts
142+
- serdes: Used for the entire BatchResult at handler level
131143
132144
summary_generator: Function to generate compact summaries for large results (>256KB).
133145
When the serialized result exceeds CHECKPOINT_SIZE_LIMIT, this generator
@@ -150,6 +162,7 @@ class ParallelConfig:
150162
default_factory=CompletionConfig.all_successful
151163
)
152164
serdes: SerDes | None = None
165+
item_serdes: SerDes | None = None
153166
summary_generator: SummaryGenerator | None = None
154167

155168

@@ -181,9 +194,21 @@ class ChildConfig(Generic[T]):
181194
matching the TypeScript ChildConfig interface behavior.
182195
183196
Args:
184-
serdes: Custom serialization/deserialization configuration for the child context data.
185-
If None, uses the default serializer. This allows different serialization
186-
strategies for child operations vs parent operations.
197+
serdes: Custom serialization/deserialization configuration for BatchResult.
198+
Applied at the handler level to serialize the entire BatchResult object.
199+
If None, uses the default JSON serializer for BatchResult.
200+
201+
Backward Compatibility: If only 'serdes' is provided (no item_serdes),
202+
it will be used for both individual items AND BatchResult serialization
203+
to maintain existing behavior.
204+
205+
item_serdes: Custom serialization/deserialization configuration for individual items.
206+
Applied to each item's result as tasks complete in child contexts.
207+
If None, uses the default JSON serializer for individual items.
208+
209+
When both 'serdes' and 'item_serdes' are provided:
210+
- item_serdes: Used for individual item results in child contexts
211+
- serdes: Used for the entire BatchResult at handler level
187212
188213
sub_type: Operation subtype identifier used for tracking and debugging.
189214
Examples: OperationSubType.MAP_ITERATION, OperationSubType.PARALLEL_BRANCH.
@@ -208,6 +233,7 @@ class ChildConfig(Generic[T]):
208233

209234
# checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
210235
serdes: SerDes | None = None
236+
item_serdes: SerDes | None = None
211237
sub_type: OperationSubType | None = None
212238
summary_generator: SummaryGenerator | None = None
213239

@@ -273,9 +299,21 @@ class MapConfig:
273299
Default allows any number of failures. Use CompletionConfig.all_successful()
274300
to require all items to succeed.
275301
276-
serdes: Custom serialization/deserialization configuration for map results.
277-
If None, uses the default serializer. This allows custom handling of
278-
complex item types or optimization for large result collections.
302+
serdes: Custom serialization/deserialization configuration for BatchResult.
303+
Applied at the handler level to serialize the entire BatchResult object.
304+
If None, uses the default JSON serializer for BatchResult.
305+
306+
Backward Compatibility: If only 'serdes' is provided (no item_serdes),
307+
it will be used for both individual items AND BatchResult serialization
308+
to maintain existing behavior.
309+
310+
item_serdes: Custom serialization/deserialization configuration for individual items.
311+
Applied to each item's result as tasks complete in child contexts.
312+
If None, uses the default JSON serializer for individual items.
313+
314+
When both 'serdes' and 'item_serdes' are provided:
315+
- item_serdes: Used for individual item results in child contexts
316+
- serdes: Used for the entire BatchResult at handler level
279317
280318
summary_generator: Function to generate compact summaries for large results (>256KB).
281319
When the serialized result exceeds CHECKPOINT_SIZE_LIMIT, this generator
@@ -298,6 +336,7 @@ class MapConfig:
298336
item_batcher: ItemBatcher = field(default_factory=ItemBatcher)
299337
completion_config: CompletionConfig = field(default_factory=CompletionConfig)
300338
serdes: SerDes | None = None
339+
item_serdes: SerDes | None = None
301340
summary_generator: SummaryGenerator | None = None
302341

303342

src/aws_durable_execution_sdk_python/context.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def map(
311311
"""Execute a callable for each item in parallel."""
312312
map_name: str | None = self._resolve_step_name(name, func)
313313

314-
def map_in_child_context(child_context):
314+
def map_in_child_context(child_context) -> BatchResult[R]:
315315
return map_handler(
316316
items=inputs,
317317
func=func,
@@ -323,7 +323,10 @@ def map_in_child_context(child_context):
323323
return self.run_in_child_context(
324324
func=map_in_child_context,
325325
name=map_name,
326-
config=ChildConfig(sub_type=OperationSubType.MAP),
326+
config=ChildConfig(
327+
sub_type=OperationSubType.MAP,
328+
serdes=config.serdes if config is not None else None,
329+
),
327330
)
328331

329332
def parallel(
@@ -345,7 +348,10 @@ def parallel_in_child_context(child_context) -> BatchResult[T]:
345348
return self.run_in_child_context(
346349
func=parallel_in_child_context,
347350
name=name,
348-
config=ChildConfig(sub_type=OperationSubType.PARALLEL),
351+
config=ChildConfig(
352+
sub_type=OperationSubType.PARALLEL,
353+
serdes=config.serdes if config is not None else None,
354+
),
349355
)
350356

351357
def run_in_child_context(

src/aws_durable_execution_sdk_python/exceptions.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55

66
from __future__ import annotations
77

8-
import time
98
from dataclasses import dataclass
9+
1010
from enum import Enum
1111
from typing import TYPE_CHECKING
12+
from datetime import UTC, datetime, timedelta
1213

13-
if TYPE_CHECKING:
14-
import datetime
1514

1615

1716
class TerminationReason(Enum):
@@ -150,10 +149,10 @@ class TimedSuspendExecution(SuspendExecution):
150149
This is a specialized form of SuspendExecution that includes a scheduled resume time.
151150
152151
Attributes:
153-
scheduled_timestamp (float): Unix timestamp in seconds at which to resume.
152+
scheduled_timestamp (datetime): DateTime at which to resume.
154153
"""
155154

156-
def __init__(self, message: str, scheduled_timestamp: float):
155+
def __init__(self, message: str, scheduled_timestamp: datetime):
157156
super().__init__(message)
158157
self.scheduled_timestamp = scheduled_timestamp
159158

@@ -172,23 +171,23 @@ def from_delay(cls, message: str, delay_seconds: int) -> TimedSuspendExecution:
172171
>>> exception = TimedSuspendExecution.from_delay("Waiting for callback", 30)
173172
>>> # Will suspend for 30 seconds from now
174173
"""
175-
resume_time = time.time() + delay_seconds
174+
resume_time = datetime.now(UTC) + timedelta(seconds=delay_seconds)
176175
return cls(message, scheduled_timestamp=resume_time)
177176

178177
@classmethod
179178
def from_datetime(
180-
cls, message: str, datetime_timestamp: datetime.datetime
179+
cls, message: str, datetime_timestamp: datetime
181180
) -> TimedSuspendExecution:
182181
"""Create a timed suspension with the delay calculated from now.
183182
184183
Args:
185184
message: Descriptive message for the suspension
186-
datetime_timestamp: Unix datetime timestamp in seconds at which to resume
185+
datetime_timestamp: DateTime timestamp at which to resume
187186
188187
Returns:
189188
TimedSuspendExecution: Instance with calculated resume time
190189
"""
191-
return cls(message, scheduled_timestamp=datetime_timestamp.timestamp())
190+
return cls(message, scheduled_timestamp=datetime_timestamp)
192191

193192

194193
class OrderedLockError(DurableExecutionsError):
@@ -243,3 +242,7 @@ def __str__(self) -> str:
243242
A string in the format "type: message"
244243
"""
245244
return f"{self.type}: {self.message}"
245+
246+
247+
class SerDesError(DurableExecutionsError):
248+
"""Raised when serialization fails."""

src/aws_durable_execution_sdk_python/execution.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import logging
55
from dataclasses import dataclass
66
from enum import Enum
7+
from functools import wraps
78
from typing import TYPE_CHECKING, Any
9+
from warnings import deprecated
810

911
from aws_durable_execution_sdk_python.context import DurableContext, ExecutionState
1012
from aws_durable_execution_sdk_python.exceptions import (
@@ -188,7 +190,7 @@ def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput:
188190
# endregion Invocation models
189191

190192

191-
def durable_handler(
193+
def durable_execution(
192194
func: Callable[[Any, DurableContext], Any],
193195
) -> Callable[[Any, LambdaContext], Any]:
194196
logger.debug("Starting durable execution handler...")
@@ -315,3 +317,9 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
315317
).to_dict()
316318

317319
return wrapper
320+
321+
322+
@deprecated("Use `durable_execution` instead.")
323+
@wraps(durable_execution)
324+
def durable_handler(*args, **kwargs):
325+
return durable_execution(*args, **kwargs)

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from aws_durable_execution_sdk_python.state import ExecutionState
2222
from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator
2323

24-
2524
logger = logging.getLogger(__name__)
2625

2726
# Input item type
@@ -42,6 +41,7 @@ def __init__(
4241
name_prefix: str,
4342
serdes: SerDes | None,
4443
summary_generator: SummaryGenerator | None = None,
44+
item_serdes: SerDes | None = None,
4545
):
4646
super().__init__(
4747
executables=executables,
@@ -52,6 +52,7 @@ def __init__(
5252
name_prefix=name_prefix,
5353
serdes=serdes,
5454
summary_generator=summary_generator,
55+
item_serdes=item_serdes,
5556
)
5657
self.items = items
5758

0 commit comments

Comments
 (0)