Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/aws_durable_execution_sdk_python/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import time
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Self, TypedDict
from typing import TYPE_CHECKING, Literal, Self, TypedDict

BAD_REQUEST_ERROR: int = 400
SERVICE_ERROR: int = 500

if TYPE_CHECKING:
import datetime
Expand All @@ -22,7 +25,7 @@ class AwsErrorObj(TypedDict):
class AwsErrorMetadata(TypedDict):
RequestId: str | None
HostId: str | None
HTTPStatusCode: str | None
HTTPStatusCode: int | None
HTTPHeaders: str | None
RetryAttempts: str | None

Expand Down Expand Up @@ -121,12 +124,16 @@ def __init__(self, message: str, step_id: str | None = None):
self.step_id = step_id


CheckpointErrorKind = Literal["Execution", "Invocation"]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not enum?



class CheckpointError(BotoClientError):
"""Failure to checkpoint. Will terminate the lambda."""

def __init__(
self,
message: str,
error_kind: CheckpointErrorKind,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.category = ErrorCategory.INVOCATION

error: AwsErrorObj | None = None,
response_metadata: AwsErrorMetadata | None = None,
):
Expand All @@ -136,6 +143,40 @@ def __init__(
response_metadata,
termination_reason=TerminationReason.CHECKPOINT_FAILED,
)
self.error_kind: CheckpointErrorKind = error_kind

@classmethod
def from_exception(cls, exception: Exception) -> CheckpointError:
base = BotoClientError.from_exception(exception)
metadata: AwsErrorMetadata | None = base.response_metadata
error: AwsErrorObj | None = base.error
error_kind: CheckpointErrorKind = "Invocation"

# InvalidParameterValueException and error message starts with "Invalid Checkpoint Token" is an InvocationError
# all other 4xx errors are Execution Errors and should be retried
# all 5xx errors are Invocation Errors
status_code: int | None = (metadata and metadata.get("HTTPStatusCode")) or None
if (
status_code
# if we are in 4xx range and is not an InvalidParameterValueException with Invalid Checkpoint Token
# then it's an execution error
and status_code < SERVICE_ERROR
and status_code >= BAD_REQUEST_ERROR
and error
and (
# is not InvalidParam => Execution
(error.get("Code", "") or "") != "InvalidParameterValueException"
# is not Invalid Token => Execution
or not (error.get("Message") or "").startswith(
"Invalid Checkpoint Token"
)
)
):
error_kind = "Execution"
return CheckpointError(str(exception), error_kind, error, metadata)

def should_be_retried(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is_retriable

return self.error_kind == "Execution"


class ValidationError(DurableExecutionsError):
Expand Down
16 changes: 15 additions & 1 deletion src/aws_durable_execution_sdk_python/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,16 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
else:
logger.exception("Checkpoint processing failed")
# Raise the original exception
if (
isinstance(bg_error.source_exception, CheckpointError)
and bg_error.source_exception.should_be_retried()
):
raise bg_error.source_exception from None # Terminate Lambda immediately and have it be retried
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use from None to indicate that the error didn't originate from handling the exception in the first place, but instead is a continuation.

if isinstance(bg_error.source_exception, CheckpointError):
return DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED,
error=ErrorObject.from_exception(bg_error.source_exception),
).to_dict()
raise bg_error.source_exception from bg_error

except SuspendExecution:
Expand All @@ -336,7 +346,11 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
"Checkpoint system failed",
extra=e.build_logger_extras(),
)
raise # Terminate Lambda immediately
if e.should_be_retried():
raise # Terminate Lambda immediately and have it be retried
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we don't do from since this is an unwrapped CheckpointError.

return DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED, error=ErrorObject.from_exception(e)
).to_dict()
except InvocationError:
logger.exception("Invocation error. Must terminate.")
# Throw the error to trigger Lambda retry
Expand Down
75 changes: 74 additions & 1 deletion tests/exceptions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import patch

import pytest
from botocore.exceptions import ClientError # type: ignore[import-untyped]

from aws_durable_execution_sdk_python.exceptions import (
CallableRuntimeError,
Expand Down Expand Up @@ -41,13 +42,85 @@ def test_invocation_error():

def test_checkpoint_error():
"""Test CheckpointError exception."""
error = CheckpointError("checkpoint failed")
error = CheckpointError("checkpoint failed", error_kind="Execution")
assert str(error) == "checkpoint failed"
assert isinstance(error, InvocationError)
assert isinstance(error, UnrecoverableError)
assert error.termination_reason == TerminationReason.CHECKPOINT_FAILED


def test_checkpoint_error_classification_invalid_token_invocation():
"""Test 4xx InvalidParameterValueException with Invalid Checkpoint Token is invocation error."""
error_response = {
"Error": {
"Code": "InvalidParameterValueException",
"Message": "Invalid Checkpoint Token: token expired",
},
"ResponseMetadata": {"HTTPStatusCode": 400},
}
client_error = ClientError(error_response, "Checkpoint")

result = CheckpointError.from_exception(client_error)

assert result.error_kind == "Invocation"
assert not result.should_be_retried()


def test_checkpoint_error_classification_other_4xx_execution():
"""Test other 4xx errors are execution errors."""
error_response = {
"Error": {"Code": "ValidationException", "Message": "Invalid parameter value"},
"ResponseMetadata": {"HTTPStatusCode": 400},
}
client_error = ClientError(error_response, "Checkpoint")

result = CheckpointError.from_exception(client_error)

assert result.error_kind == "Execution"
assert result.should_be_retried()


def test_checkpoint_error_classification_invalid_param_without_token_execution():
"""Test 4xx InvalidParameterValueException without Invalid Checkpoint Token is execution error."""
error_response = {
"Error": {
"Code": "InvalidParameterValueException",
"Message": "Some other invalid parameter",
},
"ResponseMetadata": {"HTTPStatusCode": 400},
}
client_error = ClientError(error_response, "Checkpoint")

result = CheckpointError.from_exception(client_error)

assert result.error_kind == "Execution"
assert result.should_be_retried()


def test_checkpoint_error_classification_5xx_invocation():
"""Test 5xx errors are invocation errors."""
error_response = {
"Error": {"Code": "InternalServerError", "Message": "Service unavailable"},
"ResponseMetadata": {"HTTPStatusCode": 500},
}
client_error = ClientError(error_response, "Checkpoint")

result = CheckpointError.from_exception(client_error)

assert result.error_kind == "Invocation"
assert not result.should_be_retried()


def test_checkpoint_error_classification_unknown_invocation():
"""Test unknown errors are invocation errors."""
unknown_error = Exception("Network timeout")

result = CheckpointError.from_exception(unknown_error)

assert result.error_kind == "Invocation"
assert not result.should_be_retried()


def test_validation_error():
"""Test ValidationError exception."""
error = ValidationError("validation failed")
Expand Down
150 changes: 147 additions & 3 deletions tests/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ def test_durable_execution_checkpoint_error_in_background_thread():
# Make the background checkpoint thread fail immediately
def failing_checkpoint(*args, **kwargs):
msg = "Background checkpoint failed"
raise CheckpointError(msg)
raise CheckpointError(msg, error_kind="Execution")

@durable_execution
def test_handler(event: Any, context: DurableContext) -> dict:
Expand Down Expand Up @@ -1088,7 +1088,7 @@ def test_handler(event: Any, context: DurableContext) -> dict:
# endregion durable_execution


def test_durable_execution_checkpoint_error_stops_background():
def test_durable_execution_checkpoint_execution_error_stops_background():
"""Test that CheckpointError handler stops background checkpointing.

When user code raises CheckpointError, the handler should stop the background
Expand All @@ -1100,7 +1100,7 @@ def test_durable_execution_checkpoint_error_stops_background():
def test_handler(event: Any, context: DurableContext) -> dict:
# Directly raise CheckpointError to simulate checkpoint failure
msg = "Checkpoint system failed"
raise CheckpointError(msg)
raise CheckpointError(msg, "Execution")

operation = Operation(
operation_id="exec1",
Expand Down Expand Up @@ -1140,6 +1140,148 @@ def slow_background():
test_handler(invocation_input, lambda_context)


def test_durable_execution_checkpoint_invocation_error_stops_background():
"""Test that CheckpointError handler stops background checkpointing.

When user code raises CheckpointError, the handler should stop the background
thread before re-raising to terminate the Lambda.
"""
mock_client = Mock(spec=DurableServiceClient)

@durable_execution
def test_handler(event: Any, context: DurableContext) -> dict:
# Directly raise CheckpointError to simulate checkpoint failure
msg = "Checkpoint system failed"
raise CheckpointError(msg, "Invocation")

operation = Operation(
operation_id="exec1",
operation_type=OperationType.EXECUTION,
status=OperationStatus.STARTED,
execution_details=ExecutionDetails(input_payload="{}"),
)

initial_state = InitialExecutionState(operations=[operation], next_marker="")

invocation_input = DurableExecutionInvocationInputWithClient(
durable_execution_arn="arn:test:execution",
checkpoint_token="token123", # noqa: S106
initial_execution_state=initial_state,
is_local_runner=False,
service_client=mock_client,
)

lambda_context = Mock()
lambda_context.aws_request_id = "test-request"
lambda_context.client_context = None
lambda_context.identity = None
lambda_context._epoch_deadline_time_in_ms = 1000000 # noqa: SLF001
lambda_context.invoked_function_arn = None
lambda_context.tenant_id = None

# Make background thread sleep so user code completes first
def slow_background():
time.sleep(1)

# Mock checkpoint_batches_forever to sleep (simulates background thread running)
with patch(
"aws_durable_execution_sdk_python.state.ExecutionState.checkpoint_batches_forever",
side_effect=slow_background,
):
response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"


def test_durable_execution_background_thread_execution_error_retries():
"""Test that background thread Execution errors are retried (re-raised)."""
mock_client = Mock(spec=DurableServiceClient)

def failing_checkpoint(*args, **kwargs):
msg = "Background checkpoint failed"
raise CheckpointError(msg, error_kind="Execution")

@durable_execution
def test_handler(event: Any, context: DurableContext) -> dict:
context.step(lambda ctx: "step_result")
return {"result": "success"}

operation = Operation(
operation_id="exec1",
operation_type=OperationType.EXECUTION,
status=OperationStatus.STARTED,
execution_details=ExecutionDetails(input_payload="{}"),
)

initial_state = InitialExecutionState(operations=[operation], next_marker="")

invocation_input = DurableExecutionInvocationInputWithClient(
durable_execution_arn="arn:test:execution",
checkpoint_token="token123", # noqa: S106
initial_execution_state=initial_state,
is_local_runner=False,
service_client=mock_client,
)

lambda_context = Mock()
lambda_context.aws_request_id = "test-request"
lambda_context.client_context = None
lambda_context.identity = None
lambda_context._epoch_deadline_time_in_ms = 1000000 # noqa: SLF001
lambda_context.invoked_function_arn = None
lambda_context.tenant_id = None

mock_client.checkpoint.side_effect = failing_checkpoint

with pytest.raises(CheckpointError, match="Background checkpoint failed"):
test_handler(invocation_input, lambda_context)


def test_durable_execution_background_thread_invocation_error_returns_failed():
"""Test that background thread Invocation errors return FAILED status."""
mock_client = Mock(spec=DurableServiceClient)

def failing_checkpoint(*args, **kwargs):
msg = "Background checkpoint failed"
raise CheckpointError(msg, error_kind="Invocation")

@durable_execution
def test_handler(event: Any, context: DurableContext) -> dict:
context.step(lambda ctx: "step_result")
return {"result": "success"}

operation = Operation(
operation_id="exec1",
operation_type=OperationType.EXECUTION,
status=OperationStatus.STARTED,
execution_details=ExecutionDetails(input_payload="{}"),
)

initial_state = InitialExecutionState(operations=[operation], next_marker="")

invocation_input = DurableExecutionInvocationInputWithClient(
durable_execution_arn="arn:test:execution",
checkpoint_token="token123", # noqa: S106
initial_execution_state=initial_state,
is_local_runner=False,
service_client=mock_client,
)

lambda_context = Mock()
lambda_context.aws_request_id = "test-request"
lambda_context.client_context = None
lambda_context.identity = None
lambda_context._epoch_deadline_time_in_ms = 1000000 # noqa: SLF001
lambda_context.invoked_function_arn = None
lambda_context.tenant_id = None

mock_client.checkpoint.side_effect = failing_checkpoint

response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"


def test_durable_handler_background_thread_failure_on_succeed_checkpoint():
"""Test durable_handler handles background thread failure on SUCCEED checkpoint.

Expand Down Expand Up @@ -1468,6 +1610,7 @@ def test_durable_execution_logs_checkpoint_error_extras_from_background_thread()
def failing_checkpoint(*args, **kwargs):
raise CheckpointError( # noqa TRY003
"Checkpoint failed", # noqa EM101
error_kind="Execution",
error=error_obj,
response_metadata=metadata_obj, # EM101
)
Expand Down Expand Up @@ -1589,6 +1732,7 @@ def test_durable_execution_logs_checkpoint_error_extras_from_user_code():
def test_handler(event: Any, context: DurableContext) -> dict:
raise CheckpointError( # noqa TRY003
"User checkpoint error", # noqa EM101
error_kind="Execution",
error=error_obj,
response_metadata=metadata_obj, # EM101
)
Expand Down
Loading