Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/aws_durable_execution_sdk_python/operation/child.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def child_handler(
config.sub_type if config.sub_type else OperationSubType.RUN_IN_CHILD_CONTEXT
)

if not checkpointed_result.is_started():
if not checkpointed_result.is_existent():
start_operation = OperationUpdate.create_context_start(
identifier=operation_identifier,
sub_type=sub_type,
Expand Down
31 changes: 18 additions & 13 deletions src/aws_durable_execution_sdk_python/operation/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,26 +87,31 @@ def step_handler(
datetime_timestamp=scheduled_timestamp,
)

if checkpointed_result.is_started():
if (
checkpointed_result.is_started()
and config.step_semantics is StepSemantics.AT_MOST_ONCE_PER_RETRY
):
# step was previously interrupted
if config.step_semantics is StepSemantics.AT_MOST_ONCE_PER_RETRY:
msg = f"Step operation_id={operation_identifier.operation_id} name={operation_identifier.name} was previously interrupted"
retry_handler(
StepInterruptedError(msg),
state,
operation_identifier,
config,
checkpointed_result,
)
msg = f"Step operation_id={operation_identifier.operation_id} name={operation_identifier.name} was previously interrupted"
retry_handler(
StepInterruptedError(msg),
state,
operation_identifier,
config,
checkpointed_result,
)

checkpointed_result.raise_callable_error()

if config.step_semantics is StepSemantics.AT_MOST_ONCE_PER_RETRY:
# At least once needs checkpoint at the start
if not (
checkpointed_result.is_started()
and config.step_semantics is StepSemantics.AT_LEAST_ONCE_PER_RETRY
):
# Do not checkpoint start for started & AT_LEAST_ONCE execution
# Checkpoint start for the other
start_operation: OperationUpdate = OperationUpdate.create_step_start(
identifier=operation_identifier,
)

state.create_checkpoint(operation_update=start_operation)

attempt: int = 0
Expand Down
23 changes: 15 additions & 8 deletions src/aws_durable_execution_sdk_python/operation/wait.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

if TYPE_CHECKING:
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.state import ExecutionState
from aws_durable_execution_sdk_python.state import (
CheckpointedResult,
ExecutionState,
)

logger = logging.getLogger(__name__)

Expand All @@ -25,20 +28,24 @@ def wait_handler(
operation_identifier.name,
)

if state.get_checkpoint_result(operation_identifier.operation_id).is_succeeded():
checkpointed_result: CheckpointedResult = state.get_checkpoint_result(
operation_identifier.operation_id
)

if checkpointed_result.is_succeeded():
logger.debug(
"Wait already completed, skipping wait for id: %s, name: %s",
operation_identifier.operation_id,
operation_identifier.name,
)
return

operation = OperationUpdate.create_wait_start(
identifier=operation_identifier,
wait_options=WaitOptions(seconds=seconds),
)

state.create_checkpoint(operation_update=operation)
if not checkpointed_result.is_existent():
operation = OperationUpdate.create_wait_start(
identifier=operation_identifier,
wait_options=WaitOptions(seconds=seconds),
)
state.create_checkpoint(operation_update=operation)

# Calculate when to resume
resume_time = time.time() + seconds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def wait_for_condition_handler(
current_state = config.initial_state

# Checkpoint START for observability.
if not checkpointed_result.is_existent():
if not checkpointed_result.is_started():
start_operation: OperationUpdate = (
OperationUpdate.create_wait_for_condition_start(
identifier=operation_identifier,
Expand Down
12 changes: 9 additions & 3 deletions tests/e2e/execution_int_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def mock_checkpoint(
== '["from step 123 str", "from step no args", "from step plain"]'
)

assert len(checkpoint_calls) == 3
# 3 START checkpoint, 3 SUCCEED checkpoint
assert len(checkpoint_calls) == 6

checkpoint = checkpoint_calls[-1][0]
assert checkpoint.operation_type is OperationType.STEP
Expand Down Expand Up @@ -205,11 +206,16 @@ def mock_checkpoint(

assert result["Status"] == InvocationStatus.SUCCEEDED.value

assert len(checkpoint_calls) == 1
# 1 START checkpoint, 1 SUCCEED checkpoint
assert len(checkpoint_calls) == 2

# Check the wait checkpoint
checkpoint = checkpoint_calls[0][0]
assert checkpoint.operation_type == OperationType.STEP
assert checkpoint.action == OperationAction.START
assert checkpoint.operation_id == "1"
# Check the wait checkpoint
checkpoint = checkpoint_calls[1][0]
assert checkpoint.operation_type == OperationType.STEP
assert checkpoint.action == OperationAction.SUCCEED
assert checkpoint.operation_id == "1"

Expand Down
5 changes: 5 additions & 0 deletions tests/operation/child_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_child_handler_not_started(
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_result.is_replay_children.return_value = False
mock_result.is_existent.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock(return_value="fresh_result")

Expand Down Expand Up @@ -203,6 +204,7 @@ def test_child_handler_callable_exception(
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_existent.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock(side_effect=ValueError("Test error"))

Expand Down Expand Up @@ -313,6 +315,7 @@ def test_child_handler_custom_serdes_not_start() -> None:
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_result.is_existent.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
mock_callable = Mock(return_value=complex_result)
Expand Down Expand Up @@ -372,6 +375,7 @@ def test_child_handler_large_payload_with_summary_generator() -> None:
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_result.is_existent.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
large_result = "large" * 256 * 1024
mock_callable = Mock(return_value=large_result)
Expand Down Expand Up @@ -406,6 +410,7 @@ def test_child_handler_large_payload_without_summary_generator() -> None:
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_result.is_existent.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
large_result = "large" * 256 * 1024
mock_callable = Mock(return_value=large_result)
Expand Down
57 changes: 42 additions & 15 deletions tests/operation/step_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,17 +165,16 @@ def test_step_handler_started_at_least_once():
mock_state.get_checkpoint_result.return_value = mock_result

config = StepConfig(step_semantics=StepSemantics.AT_LEAST_ONCE_PER_RETRY)
mock_callable = Mock()
mock_callable = Mock(return_value="success_result")
mock_logger = Mock(spec=Logger)

with pytest.raises(CallableRuntimeError):
step_handler(
mock_callable,
mock_state,
OperationIdentifier("step5", None, "test_step"),
config,
mock_logger,
)
step_handler(
mock_callable,
mock_state,
OperationIdentifier("step5", None, "test_step"),
config,
mock_logger,
)


def test_step_handler_success_at_least_once():
Expand All @@ -200,10 +199,18 @@ def test_step_handler_success_at_least_once():

assert result == "success_result"

assert mock_state.create_checkpoint.call_count == 1
assert mock_state.create_checkpoint.call_count == 2

# Verify start checkpoint
start_call = mock_state.create_checkpoint.call_args_list[0]
start_operation = start_call[1]["operation_update"]
assert start_operation.operation_id == "step6"
assert start_operation.operation_type is OperationType.STEP
assert start_operation.sub_type is OperationSubType.STEP
assert start_operation.action is OperationAction.START

# Verify only success checkpoint
success_call = mock_state.create_checkpoint.call_args_list[0]
# Verify success checkpoint
success_call = mock_state.create_checkpoint.call_args_list[1]
success_operation = success_call[1]["operation_update"]
assert success_operation.operation_id == "step6"
assert success_operation.payload == json.dumps("success_result")
Expand Down Expand Up @@ -299,8 +306,18 @@ def test_step_handler_retry_success():
mock_logger,
)

assert mock_state.create_checkpoint.call_count == 2

# Verify start checkpoint
start_call = mock_state.create_checkpoint.call_args_list[0]
start_operation = start_call[1]["operation_update"]
assert start_operation.operation_id == "step9"
assert start_operation.operation_type is OperationType.STEP
assert start_operation.sub_type is OperationSubType.STEP
assert start_operation.action is OperationAction.START

# Verify retry checkpoint
retry_call = mock_state.create_checkpoint.call_args_list[0]
retry_call = mock_state.create_checkpoint.call_args_list[1]
retry_operation = retry_call[1]["operation_update"]
assert retry_operation.operation_id == "step9"
assert retry_operation.operation_type is OperationType.STEP
Expand Down Expand Up @@ -332,8 +349,18 @@ def test_step_handler_retry_exhausted():
mock_logger,
)

assert mock_state.create_checkpoint.call_count == 2

# Verify start checkpoint
start_call = mock_state.create_checkpoint.call_args_list[0]
start_operation = start_call[1]["operation_update"]
assert start_operation.operation_id == "step10"
assert start_operation.operation_type is OperationType.STEP
assert start_operation.sub_type is OperationSubType.STEP
assert start_operation.action is OperationAction.START

# Verify fail checkpoint
fail_call = mock_state.create_checkpoint.call_args_list[0]
fail_call = mock_state.create_checkpoint.call_args_list[1]
fail_operation = fail_call[1]["operation_update"]
assert fail_operation.operation_id == "step10"
assert fail_operation.operation_type is OperationType.STEP
Expand Down Expand Up @@ -499,7 +526,7 @@ def test_step_handler_custom_serdes_success():
'{"key": "VALUE", "number": "84", "list": [1, 2, 3]}'
)

success_call = mock_state.create_checkpoint.call_args_list[0]
success_call = mock_state.create_checkpoint.call_args_list[1]
success_operation = success_call[1]["operation_update"]
assert success_operation.payload == expected_checkpoointed_result

Expand Down
21 changes: 21 additions & 0 deletions tests/operation/wait_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_wait_handler_not_completed():
mock_state = Mock(spec=ExecutionState)
mock_result = Mock(spec=CheckpointedResult)
mock_result.is_succeeded.return_value = False
mock_result.is_existent.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result

with pytest.raises(SuspendExecution, match="Wait for 30 seconds"):
Expand Down Expand Up @@ -68,6 +69,7 @@ def test_wait_handler_with_none_name():
mock_state = Mock(spec=ExecutionState)
mock_result = Mock(spec=CheckpointedResult)
mock_result.is_succeeded.return_value = False
mock_result.is_existent.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result

with pytest.raises(SuspendExecution, match="Wait for 5 seconds"):
Expand All @@ -88,3 +90,22 @@ def test_wait_handler_with_none_name():
mock_state.create_checkpoint.assert_called_once_with(
operation_update=expected_operation
)


def test_wait_handler_with_existent():
"""Test wait_handler with existent operation."""
mock_state = Mock(spec=ExecutionState)
mock_result = Mock(spec=CheckpointedResult)
mock_result.is_succeeded.return_value = False
mock_result.is_existent.return_value = True
mock_state.get_checkpoint_result.return_value = mock_result

with pytest.raises(SuspendExecution, match="Wait for 5 seconds"):
wait_handler(
seconds=5,
state=mock_state,
operation_identifier=OperationIdentifier("wait4", None),
)

mock_state.get_checkpoint_result.assert_called_once_with("wait4")
mock_state.create_checkpoint.assert_not_called()