Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/aws_durable_execution_sdk_python/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from enum import Enum
from typing import TYPE_CHECKING, Self, TypedDict

# Inclusive lower bound of the HTTP 4xx status range checked when
# classifying checkpoint failures.
BAD_REQUEST_ERROR: int = 400
# First HTTP 5xx status; used as the exclusive upper bound of the 4xx range
# when classifying checkpoint failures.
SERVICE_ERROR: int = 500

if TYPE_CHECKING:
import datetime

Expand All @@ -22,7 +25,7 @@ class AwsErrorObj(TypedDict):
class AwsErrorMetadata(TypedDict):
RequestId: str | None
HostId: str | None
HTTPStatusCode: str | None
HTTPStatusCode: int | None
HTTPHeaders: str | None
RetryAttempts: str | None

Expand Down Expand Up @@ -121,12 +124,18 @@ def __init__(self, message: str, step_id: str | None = None):
self.step_id = step_id


class CheckpointErrorCategory(Enum):
    """Classification assigned to a checkpoint failure."""

    # Default category chosen by CheckpointError.from_exception when no more
    # specific rule matches.
    INVOCATION = "INVOCATION"
    # The category CheckpointError.is_retriable() reports as retriable.
    EXECUTION = "EXECUTION"


class CheckpointError(BotoClientError):
"""Failure to checkpoint. Will terminate the lambda."""

def __init__(
self,
message: str,
error_category: CheckpointErrorCategory,
error: AwsErrorObj | None = None,
response_metadata: AwsErrorMetadata | None = None,
):
Expand All @@ -136,6 +145,40 @@ def __init__(
response_metadata,
termination_reason=TerminationReason.CHECKPOINT_FAILED,
)
self.error_category: CheckpointErrorCategory = error_category

@classmethod
def from_exception(cls, exception: Exception) -> CheckpointError:
    """Build a categorized checkpoint error from an arbitrary exception.

    The exception is first normalized with ``BotoClientError.from_exception``
    so its AWS error payload and response metadata can be inspected.

    Classification rules:
      * 4xx status, code ``InvalidParameterValueException`` and a message
        starting with "Invalid Checkpoint Token" -> INVOCATION.
      * Any other 4xx status that carries an error payload -> EXECUTION.
      * Everything else (5xx, no status code, no error payload) -> INVOCATION.

    Args:
        exception: The exception raised while checkpointing.

    Returns:
        An instance of ``cls`` whose message is ``str(exception)`` and which
        carries the extracted error payload, metadata and category.
    """
    base = BotoClientError.from_exception(exception)
    metadata: AwsErrorMetadata | None = base.response_metadata
    error: AwsErrorObj | None = base.error
    error_category: CheckpointErrorCategory = CheckpointErrorCategory.INVOCATION

    # A missing (or zero) status code leaves the INVOCATION default in place.
    status_code: int | None = (metadata and metadata.get("HTTPStatusCode")) or None

    if status_code and BAD_REQUEST_ERROR <= status_code < SERVICE_ERROR and error:
        # Only this exact code + message prefix keeps a 4xx response in the
        # INVOCATION category; every other 4xx is an EXECUTION error.
        is_invalid_token = (error.get("Code", "") or "") == (
            "InvalidParameterValueException"
        ) and (error.get("Message") or "").startswith("Invalid Checkpoint Token")
        if not is_invalid_token:
            error_category = CheckpointErrorCategory.EXECUTION

    # Construct through ``cls`` so subclasses get instances of their own type.
    return cls(str(exception), error_category, error, metadata)

def is_retriable(self):
    """Return True when this error's category is EXECUTION, False otherwise."""
    retriable_category = CheckpointErrorCategory.EXECUTION
    return self.error_category == retriable_category


class ValidationError(DurableExecutionsError):
Expand Down
29 changes: 22 additions & 7 deletions src/aws_durable_execution_sdk_python/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
invocation_input.durable_execution_arn,
)
serialized_result = json.dumps(result)

# large response handling here. Remember if checkpointing to complete, NOT to include
# payload in response
if (
Expand All @@ -300,8 +299,12 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
# Must ensure the result is persisted before returning to Lambda.
# Large results exceed Lambda response limits and must be stored durably
# before the execution completes.
execution_state.create_checkpoint(success_operation, is_sync=True)

try:
execution_state.create_checkpoint(
success_operation, is_sync=True
)
except CheckpointError as e:
return handle_checkpoint_error(e).to_dict()
return DurableExecutionInvocationOutput.create_succeeded(
result=""
).to_dict()
Expand All @@ -320,7 +323,9 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
)
else:
logger.exception("Checkpoint processing failed")
# Raise the original exception
# handle the original exception
if isinstance(bg_error.source_exception, CheckpointError):
return handle_checkpoint_error(bg_error.source_exception).to_dict()
raise bg_error.source_exception from bg_error

except SuspendExecution:
Expand All @@ -336,7 +341,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
"Checkpoint system failed",
extra=e.build_logger_extras(),
)
raise # Terminate Lambda immediately
return handle_checkpoint_error(e).to_dict()
except InvocationError:
logger.exception("Invocation error. Must terminate.")
# Throw the error to trigger Lambda retry
Expand Down Expand Up @@ -374,12 +379,22 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
# Must ensure the result is persisted before returning to Lambda.
# Large results exceed Lambda response limits and must be stored durably
# before the execution completes.
execution_state.create_checkpoint_sync(failed_operation)

try:
execution_state.create_checkpoint_sync(failed_operation)
except CheckpointError as e:
return handle_checkpoint_error(e).to_dict()
return DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED
).to_dict()

return result

return wrapper


def handle_checkpoint_error(error: CheckpointError) -> DurableExecutionInvocationOutput:
    """Turn a checkpoint failure into either a raised error or a FAILED output.

    Args:
        error: The checkpoint error to handle.

    Returns:
        A FAILED invocation output describing ``error`` when it is not
        retriable.

    Raises:
        CheckpointError: ``error`` itself (with exception context suppressed)
            when ``error.is_retriable()`` is true.
    """
    if not error.is_retriable():
        failure_details = ErrorObject.from_exception(error)
        return DurableExecutionInvocationOutput(
            status=InvocationStatus.FAILED, error=failure_details
        )
    raise error from None  # Terminate Lambda immediately and have it be retried
78 changes: 77 additions & 1 deletion tests/exceptions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from unittest.mock import patch

import pytest
from botocore.exceptions import ClientError # type: ignore[import-untyped]

from aws_durable_execution_sdk_python.exceptions import (
CallableRuntimeError,
CallableRuntimeErrorSerializableDetails,
CheckpointError,
CheckpointErrorCategory,
DurableExecutionsError,
ExecutionError,
InvocationError,
Expand Down Expand Up @@ -41,13 +43,87 @@ def test_invocation_error():

def test_checkpoint_error():
    """Test CheckpointError construction, hierarchy and stored category."""
    # error_category is a required constructor argument, so it is always given.
    error = CheckpointError(
        "checkpoint failed", error_category=CheckpointErrorCategory.EXECUTION
    )
    assert str(error) == "checkpoint failed"
    assert isinstance(error, InvocationError)
    assert isinstance(error, UnrecoverableError)
    assert error.termination_reason == TerminationReason.CHECKPOINT_FAILED
    # The category passed to the constructor is stored and drives is_retriable().
    assert error.error_category == CheckpointErrorCategory.EXECUTION
    assert error.is_retriable()


def test_checkpoint_error_classification_invalid_token_invocation():
    """A 4xx InvalidParameterValueException for an invalid checkpoint token is INVOCATION."""
    aws_error = {
        "Code": "InvalidParameterValueException",
        "Message": "Invalid Checkpoint Token: token expired",
    }
    boto_failure = ClientError(
        {"Error": aws_error, "ResponseMetadata": {"HTTPStatusCode": 400}},
        "Checkpoint",
    )

    classified = CheckpointError.from_exception(boto_failure)

    assert classified.error_category == CheckpointErrorCategory.INVOCATION
    assert not classified.is_retriable()


def test_checkpoint_error_classification_other_4xx_execution():
    """A 4xx response with any other error code is classified as EXECUTION."""
    payload = {
        "Error": {"Code": "ValidationException", "Message": "Invalid parameter value"},
        "ResponseMetadata": {"HTTPStatusCode": 400},
    }

    outcome = CheckpointError.from_exception(ClientError(payload, "Checkpoint"))

    assert outcome.error_category == CheckpointErrorCategory.EXECUTION
    assert outcome.is_retriable()


def test_checkpoint_error_classification_invalid_param_without_token_execution():
    """InvalidParameterValueException without the token message prefix is EXECUTION."""
    aws_error = {
        "Code": "InvalidParameterValueException",
        "Message": "Some other invalid parameter",
    }
    http_metadata = {"HTTPStatusCode": 400}
    raised = ClientError(
        {"Error": aws_error, "ResponseMetadata": http_metadata}, "Checkpoint"
    )

    converted = CheckpointError.from_exception(raised)

    assert converted.error_category == CheckpointErrorCategory.EXECUTION
    assert converted.is_retriable()


def test_checkpoint_error_classification_5xx_invocation():
    """A 5xx response is classified as INVOCATION and is not retriable."""
    server_fault = ClientError(
        {
            "Error": {"Code": "InternalServerError", "Message": "Service unavailable"},
            "ResponseMetadata": {"HTTPStatusCode": 500},
        },
        "Checkpoint",
    )

    checkpoint_error = CheckpointError.from_exception(server_fault)

    assert checkpoint_error.error_category == CheckpointErrorCategory.INVOCATION
    assert not checkpoint_error.is_retriable()


def test_checkpoint_error_classification_unknown_invocation():
    """A plain Exception is classified as INVOCATION and is not retriable."""
    wrapped = CheckpointError.from_exception(Exception("Network timeout"))

    assert wrapped.error_category == CheckpointErrorCategory.INVOCATION
    assert not wrapped.is_retriable()


def test_validation_error():
"""Test ValidationError exception."""
error = ValidationError("validation failed")
Expand Down
Loading