diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index bd38af3..b7b6d10 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -16,7 +16,7 @@ WaitForConditionConfig, ) from aws_durable_execution_sdk_python.exceptions import ( - FatalError, + CallbackError, SuspendExecution, ValidationError, ) @@ -125,11 +125,17 @@ def result(self) -> T | None: checkpointed_result: CheckpointedResult = self.state.get_checkpoint_result( self.operation_id ) - if checkpointed_result.is_started(): - msg: str = "Calback result not received yet. Suspending execution while waiting for result." - raise SuspendExecution(msg) - if checkpointed_result.is_failed() or checkpointed_result.is_timed_out(): + if not checkpointed_result.is_existent(): + msg = "Callback operation must exist" + raise CallbackError(msg) + + if ( + checkpointed_result.is_failed() + or checkpointed_result.is_cancelled() + or checkpointed_result.is_timed_out() + or checkpointed_result.is_stopped() + ): checkpointed_result.raise_callable_error() if checkpointed_result.is_succeeded(): @@ -143,8 +149,10 @@ def result(self) -> T | None: durable_execution_arn=self.state.durable_execution_arn, ) - msg = "Callback must be started before you can await the result." - raise FatalError(msg) + # operation exists; it has not terminated (successfully or otherwise) + # therefore we should wait + msg = "Callback result not received yet. Suspending execution while waiting for result." + raise SuspendExecution(msg) class DurableContext(DurableContextProtocol): diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index b612138..caed701 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -7,19 +7,90 @@ from dataclasses import dataclass from datetime import UTC, datetime, timedelta +from enum import Enum + + +class TerminationReason(Enum): + """Reasons why a durable execution terminated.""" + + UNHANDLED_ERROR = "UNHANDLED_ERROR" + INVOCATION_ERROR = "INVOCATION_ERROR" + EXECUTION_ERROR = "EXECUTION_ERROR" + CHECKPOINT_FAILED = "CHECKPOINT_FAILED" + NON_DETERMINISTIC_EXECUTION = "NON_DETERMINISTIC_EXECUTION" + STEP_INTERRUPTED = "STEP_INTERRUPTED" + CALLBACK_ERROR = "CALLBACK_ERROR" + SERIALIZATION_ERROR = "SERIALIZATION_ERROR" class DurableExecutionsError(Exception): """Base class for Durable Executions exceptions""" -class FatalError(DurableExecutionsError): - """Unrecoverable error. Will not retry.""" +class UnrecoverableError(DurableExecutionsError): + """Base class for errors that terminate execution.""" + + def __init__(self, message: str, termination_reason: TerminationReason): + super().__init__(message) + self.termination_reason = termination_reason + + +class ExecutionError(UnrecoverableError): + """Error that returns FAILED status without retry.""" + + def __init__( + self, + message: str, + termination_reason: TerminationReason = TerminationReason.EXECUTION_ERROR, + ): + super().__init__(message, termination_reason) + + +class InvocationError(UnrecoverableError): + """Error that should cause Lambda retry by throwing from handler.""" + + def __init__( + self, + message: str, + termination_reason: TerminationReason = TerminationReason.INVOCATION_ERROR, + ): + super().__init__(message, termination_reason) + + +class CallbackError(ExecutionError): + """Error in callback handling.""" + + def __init__(self, message: str, callback_id: str | None = None): + super().__init__(message, TerminationReason.CALLBACK_ERROR) + self.callback_id = callback_id + + +class CheckpointFailedError(InvocationError): + """Error when checkpoint operation fails.""" + + def __init__(self, message: str, step_id: str | None = None): + super().__init__(message, TerminationReason.CHECKPOINT_FAILED) + self.step_id = step_id + + +class NonDeterministicExecutionError(ExecutionError): + """Error when execution is non-deterministic.""" + def __init__(self, message: str, step_id: str | None = None): + super().__init__(message, TerminationReason.NON_DETERMINISTIC_EXECUTION) + self.step_id = step_id -class CheckpointError(FatalError): + +class CheckpointError(CheckpointFailedError): """Failure to checkpoint. Will terminate the lambda.""" + def __init__(self, message: str): + super().__init__(message) + + @classmethod + def from_exception(cls, exception: Exception) -> CheckpointError: + return cls(message=str(exception)) + class ValidationError(DurableExecutionsError): """Incorrect arguments to a Durable Function operation.""" @@ -50,9 +121,13 @@ def __init__( self.stack_trace = stack_trace -class StepInterruptedError(UserlandError): +class StepInterruptedError(InvocationError): """Raised when a step is interrupted before it checkpointed at the end.""" + def __init__(self, message: str, step_id: str | None = None): + super().__init__(message, TerminationReason.STEP_INTERRUPTED) + self.step_id = step_id + class SuspendExecution(BaseException): """Raise this exception to suspend the current execution by returning PENDING to DAR. diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index 8f3a9c2..e21a828 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -12,7 +12,8 @@ from aws_durable_execution_sdk_python.exceptions import ( CheckpointError, DurableExecutionsError, - FatalError, + ExecutionError, + InvocationError, SuspendExecution, ) from aws_durable_execution_sdk_python.lambda_service import ( @@ -291,10 +292,16 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: logger.exception("Failed to checkpoint") # Throw the error to terminate the lambda raise - except FatalError as e: - logger.exception("Fatal error") + + except InvocationError: + logger.exception("Invocation error. Must terminate.") + # Throw the error to trigger Lambda retry + raise + except ExecutionError as e: + logger.exception("Execution error. Must terminate without retry.") return DurableExecutionInvocationOutput( - status=InvocationStatus.PENDING, error=ErrorObject.from_exception(e) + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(e), ).to_dict() except Exception as e: # all user-space errors go here diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py index 17d58a0..46b2c8b 100644 --- a/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/src/aws_durable_execution_sdk_python/lambda_service.py @@ -983,7 +983,7 @@ def checkpoint( return CheckpointOutput.from_dict(result) except Exception as e: logger.exception("Failed to checkpoint.") - raise CheckpointError(e) from e + raise CheckpointError.from_exception(e) from e def get_execution_state( self, diff --git a/src/aws_durable_execution_sdk_python/operation/callback.py b/src/aws_durable_execution_sdk_python/operation/callback.py index 600602e..303e662 100644 --- a/src/aws_durable_execution_sdk_python/operation/callback.py +++ b/src/aws_durable_execution_sdk_python/operation/callback.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any from aws_durable_execution_sdk_python.config import StepConfig -from aws_durable_execution_sdk_python.exceptions import FatalError +from aws_durable_execution_sdk_python.exceptions import CallbackError from aws_durable_execution_sdk_python.lambda_service import ( CallbackOptions, OperationUpdate, @@ -58,8 +58,8 @@ def create_callback_handler( not checkpointed_result.operation or not checkpointed_result.operation.callback_details ): - msg = "Missing callback details" - raise FatalError(msg) + msg = f"Missing callback details for operation: {operation_identifier.operation_id}" + raise CallbackError(msg) return checkpointed_result.operation.callback_details.callback_id @@ -74,8 +74,8 @@ def create_callback_handler( ) if not result.operation or not result.operation.callback_details: - msg = "Missing callback details" - raise FatalError(msg) + msg = f"Missing callback details for operation: {operation_identifier.operation_id}" + raise CallbackError(msg) return result.operation.callback_details.callback_id diff --git a/src/aws_durable_execution_sdk_python/operation/child.py b/src/aws_durable_execution_sdk_python/operation/child.py index e2e8c72..bd649a5 100644 --- a/src/aws_durable_execution_sdk_python/operation/child.py +++ b/src/aws_durable_execution_sdk_python/operation/child.py @@ -6,7 +6,10 @@ from typing import TYPE_CHECKING, TypeVar from aws_durable_execution_sdk_python.config import ChildConfig -from aws_durable_execution_sdk_python.exceptions import FatalError, SuspendExecution +from aws_durable_execution_sdk_python.exceptions import ( + InvocationError, + SuspendExecution, +) from aws_durable_execution_sdk_python.lambda_service import ( ContextOptions, ErrorObject, @@ -138,7 +141,11 @@ def child_handler( ) state.create_checkpoint(operation_update=fail_operation) - # TODO: rethink FatalError - if isinstance(e, FatalError): + # InvocationError and its derivatives can be retried + # When we encounter an invocation error (in all of its forms), we bubble that + # error upwards (with the checkpoint in place) such that we reach the + # execution handler at the very top, which will then induce a retry from the + # dataplane. + if isinstance(e, InvocationError): raise raise error_object.to_callable_runtime_error() from e diff --git a/src/aws_durable_execution_sdk_python/operation/invoke.py b/src/aws_durable_execution_sdk_python/operation/invoke.py index f662476..826f764 100644 --- a/src/aws_durable_execution_sdk_python/operation/invoke.py +++ b/src/aws_durable_execution_sdk_python/operation/invoke.py @@ -6,9 +6,7 @@ from typing import TYPE_CHECKING, TypeVar from aws_durable_execution_sdk_python.config import InvokeConfig -from aws_durable_execution_sdk_python.exceptions import ( - FatalError, -) +from aws_durable_execution_sdk_python.exceptions import ExecutionError from aws_durable_execution_sdk_python.lambda_service import ( ChainedInvokeOptions, OperationUpdate, @@ -107,5 +105,6 @@ def invoke_handler( ) suspend_with_optional_timeout(msg, config.timeout_seconds) # This line should never be reached since suspend_with_optional_timeout always raises + # if it is ever reached, we will crash in a non-retryable manner via ExecutionError msg = "suspend_with_optional_timeout should have raised an exception, but did not." - raise FatalError(msg) from None + raise ExecutionError(msg) from None diff --git a/src/aws_durable_execution_sdk_python/operation/step.py b/src/aws_durable_execution_sdk_python/operation/step.py index 8491f4e..f1df8d6 100644 --- a/src/aws_durable_execution_sdk_python/operation/step.py +++ b/src/aws_durable_execution_sdk_python/operation/step.py @@ -11,7 +11,7 @@ StepSemantics, ) from aws_durable_execution_sdk_python.exceptions import ( - FatalError, + ExecutionError, StepInterruptedError, ) from aws_durable_execution_sdk_python.lambda_service import ( @@ -151,14 +151,14 @@ def step_handler( ) return raw_result # noqa: TRY300 except Exception as e: - if isinstance(e, FatalError): + if isinstance(e, ExecutionError): # no retry on fatal - e.g checkpoint exception logger.debug( "💥 Fatal error for id: %s, name: %s", operation_identifier.operation_id, operation_identifier.name, ) - # this bubbles up to execution.durable_handler, where it will exit with PENDING. TODO: confirm if still correct + # this bubbles up to execution.durable_handler, where it will exit with FAILED raise logger.exception( @@ -168,8 +168,10 @@ def step_handler( ) retry_handler(e, state, operation_identifier, config, checkpointed_result) + # if we've failed to raise an exception from the retry_handler, then we are in a + # weird state, and should crash terminate the execution msg = "retry handler should have raised an exception, but did not." - raise FatalError(msg) from None + raise ExecutionError(msg) from None # TODO: I don't much like this func, needs refactor. Messy grab-bag of args, refine. diff --git a/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py b/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py index 9c566df..bddd4a3 100644 --- a/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py +++ b/src/aws_durable_execution_sdk_python/operation/wait_for_condition.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, TypeVar from aws_durable_execution_sdk_python.exceptions import ( - FatalError, + ExecutionError, ) from aws_durable_execution_sdk_python.lambda_service import ( ErrorObject, @@ -203,4 +203,4 @@ def wait_for_condition_handler( raise msg: str = "wait_for_condition should never reach this point" - raise FatalError(msg) + raise ExecutionError(msg) diff --git a/src/aws_durable_execution_sdk_python/serdes.py b/src/aws_durable_execution_sdk_python/serdes.py index e17cd56..e979a72 100644 --- a/src/aws_durable_execution_sdk_python/serdes.py +++ b/src/aws_durable_execution_sdk_python/serdes.py @@ -34,7 +34,7 @@ from aws_durable_execution_sdk_python.exceptions import ( DurableExecutionsError, - FatalError, + ExecutionError, SerDesError, ) @@ -440,9 +440,12 @@ def serialize( try: return active_serdes.serialize(value, serdes_context) except Exception as e: - logger.exception("⚠️ Serialization failed for id: %s", operation_id) - msg = f"Serialization failed for id: {operation_id}, error: {e}" - raise FatalError(msg) from e + logger.exception( + "⚠️ Serialization failed for id: %s", + operation_id, + ) + msg = f"Serialization failed for id: {operation_id}, error: {e}." + raise ExecutionError(msg) from e def deserialize( @@ -469,4 +472,4 @@ def deserialize( except Exception as e: logger.exception("⚠️ Deserialization failed for id: %s", operation_id) msg = f"Deserialization failed for id: {operation_id}" - raise FatalError(msg) from e + raise ExecutionError(msg) from e diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 084f431..d8258ee 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -92,6 +92,11 @@ def is_succeeded(self) -> bool: return op.status is OperationStatus.SUCCEEDED + def is_cancelled(self) -> bool: + if op := self.operation: + return op.status is OperationStatus.CANCELLED + return False + def is_failed(self) -> bool: """Return True if the checkpointed operation is FAILED.""" op = self.operation diff --git a/tests/context_test.py b/tests/context_test.py index 9b48e01..de26153 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -20,7 +20,7 @@ from aws_durable_execution_sdk_python.context import Callback, DurableContext from aws_durable_execution_sdk_python.exceptions import ( CallableRuntimeError, - FatalError, + CallbackError, SuspendExecution, ValidationError, ) @@ -108,7 +108,7 @@ def test_callback_result_started_no_timeout(): callback = Callback("callback3", "op3", mock_state) - with pytest.raises(SuspendExecution, match="Calback result not received yet"): + with pytest.raises(SuspendExecution, match="Callback result not received yet"): callback.result() @@ -126,7 +126,7 @@ def test_callback_result_started_with_timeout(): callback = Callback("callback4", "op4", mock_state) - with pytest.raises(SuspendExecution, match="Calback result not received yet"): + with pytest.raises(SuspendExecution, match="Callback result not received yet"): callback.result() @@ -159,7 +159,7 @@ def test_callback_result_not_started(): callback = Callback("callback6", "op6", mock_state) - with pytest.raises(FatalError, match="Callback must be started"): + with pytest.raises(CallbackError, match="Callback operation must exist"): callback.result() diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py index ed61348..b70a1a5 100644 --- a/tests/exceptions_test.py +++ b/tests/exceptions_test.py @@ -10,11 +10,14 @@ CallableRuntimeErrorSerializableDetails, CheckpointError, DurableExecutionsError, - FatalError, + ExecutionError, + InvocationError, OrderedLockError, StepInterruptedError, SuspendExecution, + TerminationReason, TimedSuspendExecution, + UnrecoverableError, UserlandError, ValidationError, ) @@ -27,18 +30,22 @@ def test_durable_executions_error(): assert isinstance(error, Exception) -def test_fatal_error(): - """Test FatalError exception.""" - error = FatalError("fatal error") - assert str(error) == "fatal error" +def test_invocation_error(): + """Test InvocationError exception.""" + error = InvocationError("invocation error") + assert str(error) == "invocation error" + assert isinstance(error, UnrecoverableError) assert isinstance(error, DurableExecutionsError) + assert error.termination_reason == TerminationReason.INVOCATION_ERROR def test_checkpoint_error(): """Test CheckpointError exception.""" error = CheckpointError("checkpoint failed") assert str(error) == "checkpoint failed" - assert isinstance(error, FatalError) + assert isinstance(error, InvocationError) + assert isinstance(error, UnrecoverableError) + assert error.termination_reason == TerminationReason.CHECKPOINT_FAILED def test_validation_error(): @@ -64,7 +71,6 @@ def test_callable_runtime_error(): assert error.message == "runtime error" assert error.error_type == "ValueError" assert error.data == "error data" - assert error.stack_trace == ["line1", "line2"] assert isinstance(error, UserlandError) @@ -74,14 +80,16 @@ def test_callable_runtime_error_with_none_values(): assert error.message is None assert error.error_type is None assert error.data is None - assert error.stack_trace is None def test_step_interrupted_error(): """Test StepInterruptedError exception.""" - error = StepInterruptedError("step interrupted") + error = StepInterruptedError("step interrupted", "step_123") assert str(error) == "step interrupted" - assert isinstance(error, UserlandError) + assert isinstance(error, InvocationError) + assert isinstance(error, UnrecoverableError) + assert error.termination_reason == TerminationReason.STEP_INTERRUPTED + assert error.step_id == "step_123" def test_suspend_execution(): @@ -225,3 +233,27 @@ def test_timed_suspend_execution_from_delay_calculation_accuracy(): assert expected_min <= error.scheduled_timestamp <= expected_max assert str(error) == message assert isinstance(error, TimedSuspendExecution) + + +def test_unrecoverable_error(): + """Test UnrecoverableError base class.""" + error = UnrecoverableError("unrecoverable error", TerminationReason.EXECUTION_ERROR) + assert str(error) == "unrecoverable error" + assert error.termination_reason == TerminationReason.EXECUTION_ERROR + assert isinstance(error, DurableExecutionsError) + + +def test_execution_error(): + """Test ExecutionError exception.""" + error = ExecutionError("execution error") + assert str(error) == "execution error" + assert isinstance(error, UnrecoverableError) + assert isinstance(error, DurableExecutionsError) + assert error.termination_reason == TerminationReason.EXECUTION_ERROR + + +def test_execution_error_with_custom_termination_reason(): + """Test ExecutionError with custom termination reason.""" + error = ExecutionError("custom error", TerminationReason.SERIALIZATION_ERROR) + assert str(error) == "custom error" + assert error.termination_reason == TerminationReason.SERIALIZATION_ERROR diff --git a/tests/execution_test.py b/tests/execution_test.py index 305be28..3144fa2 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -8,7 +8,11 @@ import pytest from aws_durable_execution_sdk_python.context import DurableContext -from aws_durable_execution_sdk_python.exceptions import CheckpointError, FatalError +from aws_durable_execution_sdk_python.exceptions import ( + CheckpointError, + ExecutionError, + InvocationError, +) from aws_durable_execution_sdk_python.execution import ( DurableExecutionInvocationInput, DurableExecutionInvocationInputWithClient, @@ -581,8 +585,47 @@ def test_durable_execution_fatal_error_handling(): @durable_execution def test_handler(event: Any, context: DurableContext) -> dict: - msg = "Fatal error occurred" - raise FatalError(msg) + msg = "Retriable invocation error occurred" + raise InvocationError(msg) + + operation = Operation( + operation_id="exec1", + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails(input_payload="{}"), + ) + + initial_state = InitialExecutionState(operations=[operation], next_marker="") + + invocation_input = DurableExecutionInvocationInputWithClient( + durable_execution_arn="arn:test:execution", + checkpoint_token="token123", # noqa: S106 + initial_execution_state=initial_state, + is_local_runner=False, + service_client=mock_client, + ) + + lambda_context = Mock() + lambda_context.aws_request_id = "test-request" + lambda_context.client_context = None + lambda_context.identity = None + lambda_context._epoch_deadline_time_in_ms = 1000000 # noqa: SLF001 + lambda_context.invoked_function_arn = None + lambda_context.tenant_id = None + + # expect raise; backend will retry + with pytest.raises(InvocationError, match="Retriable invocation error occurred"): + test_handler(invocation_input, lambda_context) + + +def test_durable_execution_execution_error_handling(): + """Test durable_execution handles InvocationError correctly.""" + mock_client = Mock(spec=DurableServiceClient) + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + msg = "Retriable invocation error occurred" + raise ExecutionError(msg) operation = Operation( operation_id="exec1", @@ -609,10 +652,15 @@ def test_handler(event: Any, context: DurableContext) -> dict: lambda_context.invoked_function_arn = None lambda_context.tenant_id = None + # ExecutionError should return FAILED status with ErrorObject in result field result = test_handler(invocation_input, lambda_context) + assert result["Status"] == InvocationStatus.FAILED.value + + # Parse the ErrorObject from the result field + error_data = result["Error"] - assert result["Status"] == InvocationStatus.PENDING.value - assert "Fatal error occurred" in result["Error"]["ErrorMessage"] + assert error_data["ErrorMessage"] == "Retriable invocation error occurred" + assert error_data["ErrorType"] == "ExecutionError" def test_durable_execution_client_selection_local_runner(): diff --git a/tests/operation/callback_test.py b/tests/operation/callback_test.py index e9b7c00..3943f76 100644 --- a/tests/operation/callback_test.py +++ b/tests/operation/callback_test.py @@ -10,7 +10,7 @@ StepConfig, WaitForCallbackConfig, ) -from aws_durable_execution_sdk_python.exceptions import FatalError +from aws_durable_execution_sdk_python.exceptions import CallbackError from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( CallbackDetails, @@ -171,7 +171,7 @@ def test_create_callback_handler_existing_started_missing_callback_details(): mock_result = CheckpointedResult.create_from_operation(operation) mock_state.get_checkpoint_result.return_value = mock_result - with pytest.raises(FatalError, match="Missing callback details"): + with pytest.raises(CallbackError, match="Missing callback details"): create_callback_handler( state=mock_state, operation_identifier=OperationIdentifier("callback5", None), @@ -193,7 +193,7 @@ def test_create_callback_handler_new_operation_missing_callback_details_after_ch CheckpointedResult.create_from_operation(operation), ] - with pytest.raises(FatalError, match="Missing callback details"): + with pytest.raises(CallbackError, match="Missing callback details"): create_callback_handler( state=mock_state, operation_identifier=OperationIdentifier("callback6", None), @@ -236,7 +236,7 @@ def test_create_callback_handler_existing_timed_out_missing_callback_details(): mock_result = CheckpointedResult.create_from_operation(operation) mock_state.get_checkpoint_result.return_value = mock_result - with pytest.raises(FatalError, match="Missing callback details"): + with pytest.raises(CallbackError, match="Missing callback details"): create_callback_handler( state=mock_state, operation_identifier=OperationIdentifier( @@ -319,7 +319,7 @@ def test_create_callback_handler_with_none_operation_in_result(): mock_result.operation = None mock_state.get_checkpoint_result.return_value = mock_result - with pytest.raises(FatalError, match="Missing callback details"): + with pytest.raises(CallbackError, match="Missing callback details"): create_callback_handler( state=mock_state, operation_identifier=OperationIdentifier("none_operation", None), @@ -473,7 +473,7 @@ def test_create_callback_handler_existing_succeeded_missing_callback_details(): mock_result = CheckpointedResult.create_from_operation(operation) mock_state.get_checkpoint_result.return_value = mock_result - with pytest.raises(FatalError, match="Missing callback details"): + with pytest.raises(CallbackError, match="Missing callback details"): create_callback_handler( state=mock_state, operation_identifier=OperationIdentifier( diff --git a/tests/operation/child_test.py b/tests/operation/child_test.py index 5413172..e888ebb 100644 --- a/tests/operation/child_test.py +++ b/tests/operation/child_test.py @@ -7,7 +7,7 @@ import pytest from aws_durable_execution_sdk_python.config import ChildConfig -from aws_durable_execution_sdk_python.exceptions import CallableRuntimeError, FatalError +from aws_durable_execution_sdk_python.exceptions import CallableRuntimeError from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( ErrorObject, @@ -241,8 +241,8 @@ def test_child_handler_callable_exception( assert fail_operation.error == ErrorObject.from_exception(ValueError("Test error")) -def test_child_handler_fatal_error_propagated(): - """Test child_handler propagates FatalError without wrapping.""" +def test_child_handler_error_wrapped(): + """Test child_handler wraps regular errors as CallableRuntimeError.""" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = "test_arn" mock_result = Mock() @@ -250,10 +250,10 @@ def test_child_handler_fatal_error_propagated(): mock_result.is_failed.return_value = False mock_result.is_started.return_value = False mock_state.get_checkpoint_result.return_value = mock_result - fatal_error = FatalError("Fatal test error") - mock_callable = Mock(side_effect=fatal_error) + test_error = RuntimeError("Test error") + mock_callable = Mock(side_effect=test_error) - with pytest.raises(FatalError, match="Fatal test error"): + with pytest.raises(CallableRuntimeError): child_handler( mock_callable, mock_state, diff --git a/tests/operation/invoke_test.py b/tests/operation/invoke_test.py index c0535d8..b9e38bc 100644 --- a/tests/operation/invoke_test.py +++ b/tests/operation/invoke_test.py @@ -10,7 +10,7 @@ from aws_durable_execution_sdk_python.config import InvokeConfig from aws_durable_execution_sdk_python.exceptions import ( CallableRuntimeError, - FatalError, + ExecutionError, SuspendExecution, TimedSuspendExecution, ) @@ -519,7 +519,7 @@ def test_invoke_handler_suspend_does_not_raise(mock_suspend): mock_suspend.return_value = None with pytest.raises( - FatalError, + ExecutionError, match="suspend_with_optional_timeout should have raised an exception, but did not.", ): invoke_handler( diff --git a/tests/operation/step_test.py b/tests/operation/step_test.py index 826efb9..93845e8 100644 --- a/tests/operation/step_test.py +++ b/tests/operation/step_test.py @@ -13,7 +13,7 @@ ) from aws_durable_execution_sdk_python.exceptions import ( CallableRuntimeError, - FatalError, + ExecutionError, StepInterruptedError, SuspendExecution, ) @@ -261,18 +261,18 @@ def test_step_handler_success_at_most_once(): assert success_operation.action is OperationAction.SUCCEED -def test_step_handler_fatal_error(): - """Test step_handler with FatalError exception.""" +def test_step_handler_non_retriable_execution_error(): + """Test step_handler with ExecutionError exception.""" mock_state = Mock(spec=ExecutionState) mock_result = CheckpointedResult.create_not_found() mock_state.get_checkpoint_result.return_value = mock_result mock_state.durable_execution_arn = "test_arn" - mock_callable = Mock(side_effect=FatalError("Fatal error")) + mock_callable = Mock(side_effect=ExecutionError("Do Not Retry")) mock_logger = Mock(spec=Logger) mock_logger.with_log_info.return_value = mock_logger - with pytest.raises(FatalError, match="Fatal error"): + with pytest.raises(ExecutionError, match="Do Not Retry"): step_handler( mock_callable, mock_state, @@ -487,7 +487,8 @@ def test_step_handler_retry_handler_no_exception(mock_retry_handler): mock_logger.with_log_info.return_value = mock_logger with pytest.raises( - FatalError, match="retry handler should have raised an exception, but did not." + ExecutionError, + match="retry handler should have raised an exception, but did not.", ): step_handler( mock_callable, diff --git a/tests/operation/wait_for_condition_test.py b/tests/operation/wait_for_condition_test.py index c6f36a1..945d7f4 100644 --- a/tests/operation/wait_for_condition_test.py +++ b/tests/operation/wait_for_condition_test.py @@ -12,7 +12,7 @@ ) from aws_durable_execution_sdk_python.exceptions import ( CallableRuntimeError, - FatalError, + InvocationError, SuspendExecution, ) from aws_durable_execution_sdk_python.identifier import OperationIdentifier @@ -681,7 +681,7 @@ def test_wait_for_condition_pending(): def check_func(state, context): msg = "Should not be called" - raise FatalError(msg) + raise InvocationError(msg) config = WaitForConditionConfig( initial_state=5, @@ -716,7 +716,7 @@ def test_wait_for_condition_pending_without_next_attempt(): def check_func(state, context): msg = "Should not be called" - raise FatalError(msg) + raise InvocationError(msg) config = WaitForConditionConfig( initial_state=5, diff --git a/tests/serdes_test.py b/tests/serdes_test.py index 9768b27..91baf2c 100644 --- a/tests/serdes_test.py +++ b/tests/serdes_test.py @@ -10,7 +10,7 @@ from aws_durable_execution_sdk_python.exceptions import ( DurableExecutionsError, - FatalError, + ExecutionError, SerDesError, ) from aws_durable_execution_sdk_python.serdes import ( @@ -123,13 +123,13 @@ def test_serialize_invalid_json(): circular_ref = {"a": 1} circular_ref["self"] = circular_ref - with pytest.raises(FatalError) as exc_info: + with pytest.raises(ExecutionError) as exc_info: serialize(None, circular_ref, "test-op", "test-arn") assert "Serialization failed" in str(exc_info.value) def test_deserialize_invalid_json(): - with pytest.raises(FatalError) as exc_info: + with pytest.raises(ExecutionError) as exc_info: deserialize(None, "invalid json", "test-op", "test-arn") assert "Deserialization failed" in str(exc_info.value) @@ -583,7 +583,7 @@ def test_envelope_handles_json_incompatible_types(): } # JsonSerDes should fail - with pytest.raises(FatalError): + with pytest.raises(ExecutionError): serialize(json_serdes, complex_data, "test-op", "test-arn") # EnvelopeSerDes should succeed @@ -598,11 +598,11 @@ def test_envelope_error_handling_with_main_api(): envelope_serdes = ExtendedTypeSerDes() # Test serialization error - with pytest.raises(FatalError, match="Serialization failed"): + with pytest.raises(ExecutionError, match="Serialization failed"): serialize(envelope_serdes, object(), "test-op", "test-arn") # Test deserialization error - with pytest.raises(FatalError, match="Deserialization failed"): + with pytest.raises(ExecutionError, match="Deserialization failed"): deserialize(envelope_serdes, "invalid json", "test-op", "test-arn")