Skip to content

Commit 9439f6d

Browse files
author
Alex Wang
committed
feat: exit early when pending
- suspend execution of step and wait_for_condition when the checkpointed result is pending - test cases
1 parent 1a6262f commit 9439f6d

6 files changed

Lines changed: 209 additions & 14 deletions

File tree

src/aws_durable_execution_sdk_python/operation/step.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@
1414
from aws_durable_execution_sdk_python.exceptions import (
1515
FatalError,
1616
StepInterruptedError,
17+
SuspendExecution,
1718
TimedSuspendExecution,
1819
)
19-
from aws_durable_execution_sdk_python.lambda_service import ErrorObject, OperationUpdate
20+
from aws_durable_execution_sdk_python.lambda_service import (
21+
ErrorObject,
22+
Operation,
23+
OperationUpdate,
24+
)
2025
from aws_durable_execution_sdk_python.logger import Logger, LogInfo
2126
from aws_durable_execution_sdk_python.retries import RetryPresets
2227
from aws_durable_execution_sdk_python.serdes import deserialize, serialize
@@ -52,7 +57,9 @@ def step_handler(
5257
if not config:
5358
config = StepConfig()
5459

55-
checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id)
60+
checkpointed_result: CheckpointedResult = state.get_checkpoint_result(
61+
operation_identifier.operation_id
62+
)
5663
if checkpointed_result.is_succeeded():
5764
logger.debug(
5865
"Step already completed, skipping execution for id: %s, name: %s",
@@ -73,6 +80,12 @@ def step_handler(
7380
# have to throw the exact same error on replay as the checkpointed failure
7481
checkpointed_result.raise_callable_error()
7582

83+
if checkpointed_result.is_pending():
84+
_suspend_with_operation(
85+
operation_identifier=operation_identifier,
86+
operation=checkpointed_result.operation,
87+
)
88+
7689
if checkpointed_result.is_started():
7790
# step was previously interrupted
7891
if config.step_semantics is StepSemantics.AT_MOST_ONCE_PER_RETRY:
@@ -193,7 +206,7 @@ def retry_handler(
193206

194207
state.create_checkpoint(operation_update=retry_operation)
195208

196-
_suspend(operation_identifier, retry_decision)
209+
_suspend_with_decision(operation_identifier, retry_decision)
197210

198211
# no retry
199212
fail_operation: OperationUpdate = OperationUpdate.create_step_fail(
@@ -208,10 +221,30 @@ def retry_handler(
208221
raise error_object.to_callable_runtime_error()
209222

210223

211-
def _suspend_with_decision(
    operation_identifier: OperationIdentifier, retry_decision: RetryDecision
):
    """Suspend the execution until the delay chosen by the retry strategy elapses.

    Always raises TimedSuspendExecution; never returns.
    """
    # Absolute wall-clock time at which the step should be retried.
    resume_at = time.time() + retry_decision.delay_seconds
    msg = f"Retry scheduled for {operation_identifier.operation_id} in {retry_decision.delay_seconds} seconds"
    raise TimedSuspendExecution(msg, scheduled_timestamp=resume_at)
233+
234+
235+
def _suspend_with_operation(
    operation_identifier: OperationIdentifier, operation: Operation | None
) -> None:
    """Suspend based on the checkpointed operation's next_attempt_timestamp.

    Raises an untimed SuspendExecution when no retry timestamp is available
    (missing operation, missing step_details, or missing timestamp); otherwise
    raises TimedSuspendExecution scheduled at that timestamp. Never returns.
    """
    # Walk the optional chain; any missing link leaves raw_timestamp as None.
    details = operation.step_details if operation is not None else None
    raw_timestamp = details.next_attempt_timestamp if details is not None else None
    if raw_timestamp is None:
        msg = f"next_attempt_timestamp is None for {operation_identifier.operation_id}"
        raise SuspendExecution(msg)
    scheduled_timestamp = float(raw_timestamp)
    msg = f"Retry scheduled for {operation_identifier.name or operation_identifier.operation_id} will retry at timestamp {scheduled_timestamp}"
    raise TimedSuspendExecution(msg, scheduled_timestamp=scheduled_timestamp)

src/aws_durable_execution_sdk_python/operation/wait_for_condition.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@
88

99
from aws_durable_execution_sdk_python.exceptions import (
1010
FatalError,
11+
SuspendExecution,
1112
TimedSuspendExecution,
1213
)
13-
from aws_durable_execution_sdk_python.lambda_service import ErrorObject, OperationUpdate
14+
from aws_durable_execution_sdk_python.lambda_service import (
15+
ErrorObject,
16+
Operation,
17+
OperationUpdate,
18+
)
1419
from aws_durable_execution_sdk_python.logger import LogInfo
1520
from aws_durable_execution_sdk_python.serdes import deserialize, serialize
1621
from aws_durable_execution_sdk_python.types import WaitForConditionCheckContext
@@ -24,7 +29,10 @@
2429
)
2530
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
2631
from aws_durable_execution_sdk_python.logger import Logger
27-
from aws_durable_execution_sdk_python.state import ExecutionState
32+
from aws_durable_execution_sdk_python.state import (
33+
CheckpointedResult,
34+
ExecutionState,
35+
)
2836

2937

3038
T = TypeVar("T")
@@ -49,7 +57,9 @@ def wait_for_condition_handler(
4957
operation_identifier.name,
5058
)
5159

52-
checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id)
60+
checkpointed_result: CheckpointedResult = state.get_checkpoint_result(
61+
operation_identifier.operation_id
62+
)
5363

5464
# Check if already completed
5565
if checkpointed_result.is_succeeded():
@@ -70,6 +80,12 @@ def wait_for_condition_handler(
7080
if checkpointed_result.is_failed():
7181
checkpointed_result.raise_callable_error()
7282

83+
if checkpointed_result.is_pending():
84+
_suspend_execution_with_operation(
85+
operation_identifier=operation_identifier,
86+
operation=checkpointed_result.operation,
87+
)
88+
7389
attempt: int = 1
7490
if checkpointed_result.is_started_or_ready():
7591
# This is a retry - get state from previous checkpoint
@@ -164,7 +180,7 @@ def wait_for_condition_handler(
164180

165181
state.create_checkpoint(operation_update=retry_operation)
166182

167-
_suspend_execution(operation_identifier, decision)
183+
_suspend_execution_with_decision(operation_identifier, decision)
168184

169185
except Exception as e:
170186
# Mark as failed - waitForCondition doesn't have its own retry logic for errors
@@ -186,7 +202,7 @@ def wait_for_condition_handler(
186202
raise FatalError(msg)
187203

188204

189-
def _suspend_execution(
205+
def _suspend_execution_with_decision(
190206
operation_identifier: OperationIdentifier, decision: WaitForConditionDecision
191207
) -> None:
192208
scheduled_timestamp = time.time() + (decision.delay_seconds or 0)
@@ -195,3 +211,21 @@ def _suspend_execution(
195211
msg,
196212
scheduled_timestamp=scheduled_timestamp,
197213
)
214+
215+
216+
def _suspend_execution_with_operation(
    operation_identifier: OperationIdentifier, operation: Operation | None
) -> None:
    """Suspend a pending wait_for_condition using its checkpointed retry time.

    Raises an untimed SuspendExecution when the operation carries no
    next_attempt_timestamp; otherwise raises TimedSuspendExecution scheduled
    at that timestamp. Never returns.
    """
    # getattr with a None default collapses the optional chain in one pass.
    step_details = getattr(operation, "step_details", None)
    raw_timestamp = getattr(step_details, "next_attempt_timestamp", None)
    if raw_timestamp is None:
        msg = f"next_attempt_timestamp is None for {operation_identifier.operation_id}"
        raise SuspendExecution(msg)
    scheduled_timestamp = float(raw_timestamp)
    msg = f"wait_for_condition {operation_identifier.name or operation_identifier.operation_id} will retry at timestamp {scheduled_timestamp}"
    raise TimedSuspendExecution(msg, scheduled_timestamp=scheduled_timestamp)

src/aws_durable_execution_sdk_python/state.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@ def is_started_or_ready(self) -> bool:
106106
return False
107107
return op.status in (OperationStatus.STARTED, OperationStatus.READY)
108108

109+
def is_pending(self) -> bool:
    """Return True if the checkpointed operation exists and its status is PENDING."""
    op = self.operation
    # A missing/falsy operation can never be pending; otherwise compare status.
    return bool(op) and op.status is OperationStatus.PENDING
115+
109116
def is_timed_out(self) -> bool:
110117
"""Return True if the checkpointed operation is TIMED_OUT."""
111118
op = self.operation

tests/operation/step_test.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def test_step_handler_retry_with_existing_attempts():
375375
operation_id="step12",
376376
operation_type=OperationType.STEP,
377377
status=OperationStatus.PENDING,
378-
step_details=StepDetails(attempt=2),
378+
step_details=StepDetails(attempt=2, next_attempt_timestamp="1764547200"),
379379
)
380380
mock_result = CheckpointedResult.create_from_operation(operation)
381381
mock_state.get_checkpoint_result.return_value = mock_result
@@ -398,10 +398,44 @@ def test_step_handler_retry_with_existing_attempts():
398398
mock_logger,
399399
)
400400

401-
# Verify retry strategy was called with correct attempt count (2 + 1 = 3)
402-
mock_retry_strategy.assert_called_once()
403-
call_args = mock_retry_strategy.call_args[0]
404-
assert call_args[1] == 3 # retry_attempt + 1
401+
# Verify retry strategy was not called because we already have attempt timestamp in the checkpointed location
402+
mock_retry_strategy.assert_not_called()
403+
404+
405+
def test_step_handler_pending_without_existing_attempts():
    """step_handler suspends (untimed SuspendExecution) when the checkpointed
    operation is PENDING but carries no next_attempt_timestamp."""
    mock_state = Mock(spec=ExecutionState)

    # Simulate a PENDING operation checkpointed WITHOUT a next_attempt_timestamp.
    operation = Operation(
        operation_id="step12",
        operation_type=OperationType.STEP,
        status=OperationStatus.PENDING,
        step_details=StepDetails(attempt=2),
    )
    mock_result = CheckpointedResult.create_from_operation(operation)
    mock_state.get_checkpoint_result.return_value = mock_result
    mock_state.durable_execution_arn = "test_arn"

    mock_retry_strategy = Mock(
        return_value=RetryDecision(should_retry=True, delay_seconds=10)
    )
    config = StepConfig(retry_strategy=mock_retry_strategy)
    mock_callable = Mock(side_effect=RuntimeError("Test error"))
    mock_logger = Mock(spec=Logger)
    mock_logger.with_log_info.return_value = mock_logger

    with pytest.raises(SuspendExecution, match="next_attempt_timestamp is None"):
        step_handler(
            mock_callable,
            mock_state,
            OperationIdentifier("step12", None, "test_step"),
            config,
            mock_logger,
        )

    # The pending checkpoint short-circuits execution before the retry
    # strategy is ever consulted.
    mock_retry_strategy.assert_not_called()
405439

406440

407441
@patch("aws_durable_execution_sdk_python.operation.step.retry_handler")

tests/operation/wait_for_condition_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from aws_durable_execution_sdk_python.exceptions import (
1313
CallableRuntimeError,
14+
FatalError,
1415
SuspendExecution,
1516
)
1617
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
@@ -372,6 +373,7 @@ def test_wait_for_condition_no_operation_in_checkpoint():
372373
mock_result = Mock()
373374
mock_result.is_succeeded.return_value = False
374375
mock_result.is_failed.return_value = False
376+
mock_result.is_pending.return_value = False
375377
mock_result.is_started_or_ready.return_value = True
376378
mock_result.is_existent.return_value = True
377379
mock_result.result = json.dumps(10)
@@ -416,6 +418,7 @@ def test_wait_for_condition_operation_no_step_details():
416418
mock_result = Mock()
417419
mock_result.is_succeeded.return_value = False
418420
mock_result.is_failed.return_value = False
421+
mock_result.is_pending.return_value = False
419422
mock_result.is_started_or_ready.return_value = True
420423
mock_result.is_existent.return_value = True
421424
mock_result.result = json.dumps(10)
@@ -651,3 +654,72 @@ def check_func(state, context):
651654
)
652655

653656
assert result == {"key": "value", "number": 42, "list": [1, 2, 3]}
657+
658+
659+
def test_wait_for_condition_pending():
    """wait_for_condition_handler raises a timed suspension when the
    checkpointed operation is PENDING and has a next_attempt_timestamp,
    without invoking the check callable."""
    mock_state = Mock(spec=ExecutionState)
    mock_state.durable_execution_arn = "arn:aws:test"
    operation = Operation(
        operation_id="XXX",
        operation_type=OperationType.STEP,
        status=OperationStatus.PENDING,
        step_details=StepDetails(
            result='{"key": "VALUE", "number": "84", "list": [1, 2, 3]}',
            next_attempt_timestamp="1764547200",
        ),
    )
    mock_result = CheckpointedResult.create_from_operation(operation)
    mock_state.get_checkpoint_result.return_value = mock_result

    mock_logger = Mock(spec=Logger)
    mock_logger.with_log_info.return_value = mock_logger

    op_id = OperationIdentifier("op1", None, "test_wait")

    def check_func(state, context):
        # The handler must suspend before ever calling the check callable.
        msg = "Should not be called"
        raise FatalError(msg)

    config = WaitForConditionConfig(
        initial_state=5,
        wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(),
        serdes=CustomDictSerDes(),
    )

    with pytest.raises(
        SuspendExecution, match="wait_for_condition test_wait will retry at timestamp"
    ):
        wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger)
693+
694+
695+
def test_wait_for_condition_pending_without_next_attempt():
    """wait_for_condition_handler raises an untimed SuspendExecution when the
    checkpointed operation is PENDING but has no next_attempt_timestamp,
    without invoking the check callable."""
    mock_state = Mock(spec=ExecutionState)
    mock_state.durable_execution_arn = "arn:aws:test"
    # PENDING operation checkpointed WITHOUT a next_attempt_timestamp.
    operation = Operation(
        operation_id="XXX",
        operation_type=OperationType.STEP,
        status=OperationStatus.PENDING,
        step_details=StepDetails(
            result='{"key": "VALUE", "number": "84", "list": [1, 2, 3]}',
        ),
    )
    mock_result = CheckpointedResult.create_from_operation(operation)
    mock_state.get_checkpoint_result.return_value = mock_result

    mock_logger = Mock(spec=Logger)
    mock_logger.with_log_info.return_value = mock_logger

    op_id = OperationIdentifier("op1", None, "test_wait")

    def check_func(state, context):
        # The handler must suspend before ever calling the check callable.
        msg = "Should not be called"
        raise FatalError(msg)

    config = WaitForConditionConfig(
        initial_state=5,
        wait_strategy=lambda s, a: WaitForConditionDecision.stop_polling(),
        serdes=CustomDictSerDes(),
    )

    with pytest.raises(SuspendExecution, match="next_attempt_timestamp is None"):
        wait_for_condition_handler(check_func, config, mock_state, op_id, mock_logger)

tests/state_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,21 @@ def test_checkpointed_result_is_failed():
214214
assert result_no_op.is_failed() is False
215215

216216

217+
def test_checkpointerd_result_is_pending():
    """Test CheckpointedResult.is_pending method.

    NOTE(review): the function name has a typo ("checkpointerd") — kept here
    to avoid changing the discoverable test name; consider renaming to
    test_checkpointed_result_is_pending in a follow-up.
    """
    operation = Operation(
        operation_id="op1",
        operation_type=OperationType.STEP,
        status=OperationStatus.PENDING,
    )
    result = CheckpointedResult.create_from_operation(operation)
    assert result.is_pending() is True

    # A result with no underlying operation can never be pending.
    result_no_op = CheckpointedResult.create_not_found()
    assert result_no_op.is_pending() is False
230+
231+
217232
def test_checkpointed_result_is_started():
218233
"""Test CheckpointedResult.is_started method."""
219234
operation = Operation(

0 commit comments

Comments
 (0)