Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
WaitForConditionConfig,
)
from aws_durable_execution_sdk_python.exceptions import (
FatalError,
CallbackError,
SuspendExecution,
ValidationError,
)
Expand Down Expand Up @@ -125,11 +125,17 @@ def result(self) -> T | None:
checkpointed_result: CheckpointedResult = self.state.get_checkpoint_result(
self.operation_id
)
if checkpointed_result.is_started():
msg: str = "Calback result not received yet. Suspending execution while waiting for result."
raise SuspendExecution(msg)

if checkpointed_result.is_failed() or checkpointed_result.is_timed_out():
if not checkpointed_result.is_existent():
msg = "Callback operation must exist"
raise CallbackError(msg)

if (
checkpointed_result.is_failed()
or checkpointed_result.is_cancelled()
or checkpointed_result.is_timed_out()
or checkpointed_result.is_stopped()
):
checkpointed_result.raise_callable_error()

if checkpointed_result.is_succeeded():
Expand All @@ -143,8 +149,10 @@ def result(self) -> T | None:
durable_execution_arn=self.state.durable_execution_arn,
)

msg = "Callback must be started before you can await the result."
raise FatalError(msg)
# operation exists; it has not terminated (successfully or otherwise)
# therefore we should wait
msg = "Callback result not received yet. Suspending execution while waiting for result."
raise SuspendExecution(msg)


class DurableContext(DurableContextProtocol):
Expand Down
83 changes: 79 additions & 4 deletions src/aws_durable_execution_sdk_python/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,90 @@

from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from enum import Enum


class TerminationReason(Enum):
"""Reasons why a durable execution terminated."""

UNHANDLED_ERROR = "UNHANDLED_ERROR"
INVOCATION_ERROR = "INVOCATION_ERROR"
EXECUTION_ERROR = "EXECUTION_ERROR"
CHECKPOINT_FAILED = "CHECKPOINT_FAILED"
NON_DETERMINISTIC_EXECUTION = "NON_DETERMINISTIC_EXECUTION"
STEP_INTERRUPTED = "STEP_INTERRUPTED"
CALLBACK_ERROR = "CALLBACK_ERROR"
SERIALIZATION_ERROR = "SERIALIZATION_ERROR"


class DurableExecutionsError(Exception):
"""Base class for Durable Executions exceptions"""


class FatalError(DurableExecutionsError):
"""Unrecoverable error. Will not retry."""
class UnrecoverableError(DurableExecutionsError):
"""Base class for errors that terminate execution."""

def __init__(self, message: str, termination_reason: TerminationReason):
super().__init__(message)
self.termination_reason = termination_reason


class ExecutionError(UnrecoverableError):
"""Error that returns FAILED status without retry."""

def __init__(
self,
message: str,
termination_reason: TerminationReason = TerminationReason.EXECUTION_ERROR,
):
super().__init__(message, termination_reason)


class InvocationError(UnrecoverableError):
"""Error that should cause Lambda retry by throwing from handler."""

def __init__(
self,
message: str,
termination_reason: TerminationReason = TerminationReason.INVOCATION_ERROR,
):
super().__init__(message, termination_reason)


class CallbackError(ExecutionError):
"""Error in callback handling."""

def __init__(self, message: str, callback_id: str | None = None):
super().__init__(message, TerminationReason.CALLBACK_ERROR)
self.callback_id = callback_id


class CheckpointFailedError(InvocationError):
"""Error when checkpoint operation fails."""

def __init__(self, message: str, step_id: str | None = None):
super().__init__(message, TerminationReason.CHECKPOINT_FAILED)
self.step_id = step_id


class NonDeterministicExecutionError(ExecutionError):
"""Error when execution is non-deterministic."""

def __init__(self, message: str, step_id: str | None = None):
super().__init__(message, TerminationReason.NON_DETERMINISTIC_EXECUTION)
self.step_id = step_id

class CheckpointError(FatalError):

class CheckpointError(CheckpointFailedError):
"""Failure to checkpoint. Will terminate the lambda."""

def __init__(self, message: str):
super().__init__(message)

@classmethod
def from_exception(cls, exception: Exception) -> CheckpointError:
return cls(message=str(exception))


class ValidationError(DurableExecutionsError):
"""Incorrect arguments to a Durable Function operation."""
Expand Down Expand Up @@ -50,9 +121,13 @@ def __init__(
self.stack_trace = stack_trace


class StepInterruptedError(UserlandError):
class StepInterruptedError(InvocationError):
"""Raised when a step is interrupted before it checkpointed at the end."""

def __init__(self, message: str, step_id: str | None = None):
super().__init__(message, TerminationReason.STEP_INTERRUPTED)
self.step_id = step_id


class SuspendExecution(BaseException):
"""Raise this exception to suspend the current execution by returning PENDING to DAR.
Expand Down
15 changes: 11 additions & 4 deletions src/aws_durable_execution_sdk_python/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from aws_durable_execution_sdk_python.exceptions import (
CheckpointError,
DurableExecutionsError,
FatalError,
ExecutionError,
InvocationError,
SuspendExecution,
)
from aws_durable_execution_sdk_python.lambda_service import (
Expand Down Expand Up @@ -291,10 +292,16 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
logger.exception("Failed to checkpoint")
# Throw the error to terminate the lambda
raise
except FatalError as e:
logger.exception("Fatal error")

except InvocationError:
logger.exception("Invocation error. Must terminate.")
# Throw the error to trigger Lambda retry
raise
except ExecutionError as e:
logger.exception("Execution error. Must terminate without retry.")
return DurableExecutionInvocationOutput(
status=InvocationStatus.PENDING, error=ErrorObject.from_exception(e)
status=InvocationStatus.FAILED,
error=ErrorObject.from_exception(e),
).to_dict()
except Exception as e:
# all user-space errors go here
Expand Down
2 changes: 1 addition & 1 deletion src/aws_durable_execution_sdk_python/lambda_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def checkpoint(
return CheckpointOutput.from_dict(result)
except Exception as e:
logger.exception("Failed to checkpoint.")
raise CheckpointError(e) from e
raise CheckpointError.from_exception(e) from e

def get_execution_state(
self,
Expand Down
10 changes: 5 additions & 5 deletions src/aws_durable_execution_sdk_python/operation/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any

from aws_durable_execution_sdk_python.config import StepConfig
from aws_durable_execution_sdk_python.exceptions import FatalError
from aws_durable_execution_sdk_python.exceptions import CallbackError
from aws_durable_execution_sdk_python.lambda_service import (
CallbackOptions,
OperationUpdate,
Expand Down Expand Up @@ -58,8 +58,8 @@ def create_callback_handler(
not checkpointed_result.operation
or not checkpointed_result.operation.callback_details
):
msg = "Missing callback details"
raise FatalError(msg)
msg = f"Missing callback details for operation: {operation_identifier.operation_id}"
raise CallbackError(msg)

return checkpointed_result.operation.callback_details.callback_id

Expand All @@ -74,8 +74,8 @@ def create_callback_handler(
)

if not result.operation or not result.operation.callback_details:
msg = "Missing callback details"
raise FatalError(msg)
msg = f"Missing callback details for operation: {operation_identifier.operation_id}"
raise CallbackError(msg)

return result.operation.callback_details.callback_id

Expand Down
13 changes: 10 additions & 3 deletions src/aws_durable_execution_sdk_python/operation/child.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from typing import TYPE_CHECKING, TypeVar

from aws_durable_execution_sdk_python.config import ChildConfig
from aws_durable_execution_sdk_python.exceptions import FatalError, SuspendExecution
from aws_durable_execution_sdk_python.exceptions import (
InvocationError,
SuspendExecution,
)
from aws_durable_execution_sdk_python.lambda_service import (
ContextOptions,
ErrorObject,
Expand Down Expand Up @@ -138,7 +141,11 @@ def child_handler(
)
state.create_checkpoint(operation_update=fail_operation)

# TODO: rethink FatalError
if isinstance(e, FatalError):
# InvocationError and its derivatives can be retried
# When we encounter an invocation error (in all of its forms), we bubble that
# error upwards (with the checkpoint in place) such that we reach the
# execution handler at the very top, which will then induce a retry from the
# dataplane.
if isinstance(e, InvocationError):
raise
raise error_object.to_callable_runtime_error() from e
7 changes: 3 additions & 4 deletions src/aws_durable_execution_sdk_python/operation/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from typing import TYPE_CHECKING, TypeVar

from aws_durable_execution_sdk_python.config import InvokeConfig
from aws_durable_execution_sdk_python.exceptions import (
FatalError,
)
from aws_durable_execution_sdk_python.exceptions import ExecutionError
from aws_durable_execution_sdk_python.lambda_service import (
ChainedInvokeOptions,
OperationUpdate,
Expand Down Expand Up @@ -107,5 +105,6 @@ def invoke_handler(
)
suspend_with_optional_timeout(msg, config.timeout_seconds)
# This line should never be reached since suspend_with_optional_timeout always raises
# if it is ever reached, we will crash in a non-retryable manner via ExecutionError
msg = "suspend_with_optional_timeout should have raised an exception, but did not."
raise FatalError(msg) from None
raise ExecutionError(msg) from None
10 changes: 6 additions & 4 deletions src/aws_durable_execution_sdk_python/operation/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
StepSemantics,
)
from aws_durable_execution_sdk_python.exceptions import (
FatalError,
ExecutionError,
StepInterruptedError,
)
from aws_durable_execution_sdk_python.lambda_service import (
Expand Down Expand Up @@ -151,14 +151,14 @@ def step_handler(
)
return raw_result # noqa: TRY300
except Exception as e:
if isinstance(e, FatalError):
if isinstance(e, ExecutionError):
# no retry on fatal - e.g checkpoint exception
logger.debug(
"💥 Fatal error for id: %s, name: %s",
operation_identifier.operation_id,
operation_identifier.name,
)
# this bubbles up to execution.durable_handler, where it will exit with PENDING. TODO: confirm if still correct
# this bubbles up to execution.durable_handler, where it will exit with FAILED
raise

logger.exception(
Expand All @@ -168,8 +168,10 @@ def step_handler(
)

retry_handler(e, state, operation_identifier, config, checkpointed_result)
# if we've failed to raise an exception from the retry_handler, then we are in a
# weird state, and should crash terminate the execution
msg = "retry handler should have raised an exception, but did not."
raise FatalError(msg) from None
raise ExecutionError(msg) from None


# TODO: I don't much like this func, needs refactor. Messy grab-bag of args, refine.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, TypeVar

from aws_durable_execution_sdk_python.exceptions import (
FatalError,
ExecutionError,
)
from aws_durable_execution_sdk_python.lambda_service import (
ErrorObject,
Expand Down Expand Up @@ -203,4 +203,4 @@ def wait_for_condition_handler(
raise

msg: str = "wait_for_condition should never reach this point"
raise FatalError(msg)
raise ExecutionError(msg)
13 changes: 8 additions & 5 deletions src/aws_durable_execution_sdk_python/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from aws_durable_execution_sdk_python.exceptions import (
DurableExecutionsError,
FatalError,
ExecutionError,
SerDesError,
)

Expand Down Expand Up @@ -440,9 +440,12 @@ def serialize(
try:
return active_serdes.serialize(value, serdes_context)
except Exception as e:
logger.exception("⚠️ Serialization failed for id: %s", operation_id)
msg = f"Serialization failed for id: {operation_id}, error: {e}"
raise FatalError(msg) from e
logger.exception(
"⚠️ Serialization failed for id: %s",
operation_id,
)
msg = f"Serialization failed for id: {operation_id}, error: {e}."
raise ExecutionError(msg) from e


def deserialize(
Expand All @@ -469,4 +472,4 @@ def deserialize(
except Exception as e:
logger.exception("⚠️ Deserialization failed for id: %s", operation_id)
msg = f"Deserialization failed for id: {operation_id}"
raise FatalError(msg) from e
raise ExecutionError(msg) from e
5 changes: 5 additions & 0 deletions src/aws_durable_execution_sdk_python/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def is_succeeded(self) -> bool:

return op.status is OperationStatus.SUCCEEDED

def is_cancelled(self) -> bool:
if op := self.operation:
return op.status is OperationStatus.CANCELLED
return False

def is_failed(self) -> bool:
"""Return True if the checkpointed operation is FAILED."""
op = self.operation
Expand Down
8 changes: 4 additions & 4 deletions tests/context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from aws_durable_execution_sdk_python.context import Callback, DurableContext
from aws_durable_execution_sdk_python.exceptions import (
CallableRuntimeError,
FatalError,
CallbackError,
SuspendExecution,
ValidationError,
)
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_callback_result_started_no_timeout():

callback = Callback("callback3", "op3", mock_state)

with pytest.raises(SuspendExecution, match="Calback result not received yet"):
with pytest.raises(SuspendExecution, match="Callback result not received yet"):
callback.result()


Expand All @@ -126,7 +126,7 @@ def test_callback_result_started_with_timeout():

callback = Callback("callback4", "op4", mock_state)

with pytest.raises(SuspendExecution, match="Calback result not received yet"):
with pytest.raises(SuspendExecution, match="Callback result not received yet"):
callback.result()


Expand Down Expand Up @@ -159,7 +159,7 @@ def test_callback_result_not_started():

callback = Callback("callback6", "op6", mock_state)

with pytest.raises(FatalError, match="Callback must be started"):
with pytest.raises(CallbackError, match="Callback operation must exist"):
callback.result()


Expand Down
Loading