Skip to content

Commit eaf9d03

Browse files
committed
Differentiate Invocation and Execution Checkpoint errors
Changes: - CheckpointError now inspects the service response metadata to decide whether a checkpoint failure is an invocation error or an execution error - Added tests to ensure the classification matches the reference behavior - Added tests to cover the new CheckpointError handling paths in execution.py
1 parent e2d4527 commit eaf9d03

File tree

4 files changed

+279
-7
lines changed

4 files changed

+279
-7
lines changed

src/aws_durable_execution_sdk_python/exceptions.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
import time
99
from dataclasses import dataclass
1010
from enum import Enum
11-
from typing import TYPE_CHECKING, Self, TypedDict
11+
from typing import TYPE_CHECKING, Literal, Self, TypedDict
12+
13+
BAD_REQUEST_ERROR: int = 400
14+
SERVICE_ERROR: int = 500
1215

1316
if TYPE_CHECKING:
1417
import datetime
@@ -22,7 +25,7 @@ class AwsErrorObj(TypedDict):
2225
class AwsErrorMetadata(TypedDict):
2326
RequestId: str | None
2427
HostId: str | None
25-
HTTPStatusCode: str | None
28+
HTTPStatusCode: int | None
2629
HTTPHeaders: str | None
2730
RetryAttempts: str | None
2831

@@ -121,12 +124,16 @@ def __init__(self, message: str, step_id: str | None = None):
121124
self.step_id = step_id
122125

123126

127+
CheckpointErrorKind = Literal["Execution", "Invocation"]
128+
129+
124130
class CheckpointError(BotoClientError):
125131
"""Failure to checkpoint. Will terminate the lambda."""
126132

127133
def __init__(
128134
self,
129135
message: str,
136+
error_kind: CheckpointErrorKind,
130137
error: AwsErrorObj | None = None,
131138
response_metadata: AwsErrorMetadata | None = None,
132139
):
@@ -136,6 +143,40 @@ def __init__(
136143
response_metadata,
137144
termination_reason=TerminationReason.CHECKPOINT_FAILED,
138145
)
146+
self.error_kind: CheckpointErrorKind = error_kind
147+
148+
@classmethod
def from_exception(cls, exception: Exception) -> Self:
    """Build a checkpoint error from *exception* and classify its kind.

    The exception is first normalised through
    ``BotoClientError.from_exception`` to obtain the AWS error object and
    the response metadata. The kind is then decided as follows:

    * 4xx response whose error code is ``InvalidParameterValueException``
      and whose message starts with ``"Invalid Checkpoint Token"``
      -> ``"Invocation"``.
    * any other 4xx response that carries an error object
      -> ``"Execution"``.
    * everything else (5xx, missing or non-integer status code, no error
      object, non-AWS exceptions) -> ``"Invocation"``.

    Args:
        exception: The exception raised while checkpointing.

    Returns:
        A new instance of ``cls`` carrying the original message, the
        computed ``error_kind``, the AWS error object and the metadata.
    """
    base = BotoClientError.from_exception(exception)
    metadata: AwsErrorMetadata | None = base.response_metadata
    error: AwsErrorObj | None = base.error

    status_code = metadata.get("HTTPStatusCode") if metadata else None
    # Guard on int so an unexpected (e.g. string) status cannot raise
    # TypeError in the range comparison; it falls back to "Invocation".
    is_client_error = (
        isinstance(status_code, int)
        and BAD_REQUEST_ERROR <= status_code < SERVICE_ERROR
    )

    error_kind: CheckpointErrorKind = "Invocation"
    if is_client_error and error:
        is_invalid_token = (
            (error.get("Code") or "") == "InvalidParameterValueException"
            and (error.get("Message") or "").startswith("Invalid Checkpoint Token")
        )
        # Only the invalid-token rejection stays an Invocation error; every
        # other 4xx is an Execution error.
        if not is_invalid_token:
            error_kind = "Execution"

    # Use cls so subclasses get an instance of their own type.
    return cls(str(exception), error_kind, error, metadata)
177+
178+
def should_be_retried(self) -> bool:
    """Return True when this error's ``error_kind`` is ``"Execution"``."""
    return self.error_kind == "Execution"
139180

140181

141182
class ValidationError(DurableExecutionsError):

src/aws_durable_execution_sdk_python/execution.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,16 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
321321
else:
322322
logger.exception("Checkpoint processing failed")
323323
# Raise the original exception
324+
if (
325+
isinstance(bg_error.source_exception, CheckpointError)
326+
and bg_error.source_exception.should_be_retried()
327+
):
328+
raise bg_error.source_exception from None # Terminate Lambda immediately and have it be retried
329+
if isinstance(bg_error.source_exception, CheckpointError):
330+
return DurableExecutionInvocationOutput(
331+
status=InvocationStatus.FAILED,
332+
error=ErrorObject.from_exception(bg_error.source_exception),
333+
).to_dict()
324334
raise bg_error.source_exception from bg_error
325335

326336
except SuspendExecution:
@@ -336,7 +346,11 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
336346
"Checkpoint system failed",
337347
extra=e.build_logger_extras(),
338348
)
339-
raise # Terminate Lambda immediately
349+
if e.should_be_retried():
350+
raise # Terminate Lambda immediately and have it be retried
351+
return DurableExecutionInvocationOutput(
352+
status=InvocationStatus.FAILED, error=ErrorObject.from_exception(e)
353+
).to_dict()
340354
except InvocationError:
341355
logger.exception("Invocation error. Must terminate.")
342356
# Throw the error to trigger Lambda retry

tests/exceptions_test.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from unittest.mock import patch
55

66
import pytest
7+
from botocore.exceptions import ClientError # type: ignore[import-untyped]
78

89
from aws_durable_execution_sdk_python.exceptions import (
910
CallableRuntimeError,
@@ -41,13 +42,85 @@ def test_invocation_error():
4142

4243
def test_checkpoint_error():
4344
"""Test CheckpointError exception."""
44-
error = CheckpointError("checkpoint failed")
45+
error = CheckpointError("checkpoint failed", error_kind="Execution")
4546
assert str(error) == "checkpoint failed"
4647
assert isinstance(error, InvocationError)
4748
assert isinstance(error, UnrecoverableError)
4849
assert error.termination_reason == TerminationReason.CHECKPOINT_FAILED
4950

5051

52+
def test_checkpoint_error_classification_invalid_token_invocation():
    """A 400 InvalidParameterValueException whose message begins with
    "Invalid Checkpoint Token" is classified as an Invocation error."""
    aws_error = {
        "Code": "InvalidParameterValueException",
        "Message": "Invalid Checkpoint Token: token expired",
    }
    boto_error = ClientError(
        {"Error": aws_error, "ResponseMetadata": {"HTTPStatusCode": 400}},
        "Checkpoint",
    )

    classified = CheckpointError.from_exception(boto_error)

    assert classified.error_kind == "Invocation"
    assert not classified.should_be_retried()
67+
68+
69+
def test_checkpoint_error_classification_other_4xx_execution():
    """A 400 response with a different error code is classified as Execution."""
    aws_error = {"Code": "ValidationException", "Message": "Invalid parameter value"}
    boto_error = ClientError(
        {"Error": aws_error, "ResponseMetadata": {"HTTPStatusCode": 400}},
        "Checkpoint",
    )

    classified = CheckpointError.from_exception(boto_error)

    assert classified.error_kind == "Execution"
    assert classified.should_be_retried()
81+
82+
83+
def test_checkpoint_error_classification_invalid_param_without_token_execution():
    """A 400 InvalidParameterValueException whose message does not begin with
    "Invalid Checkpoint Token" is classified as an Execution error."""
    aws_error = {
        "Code": "InvalidParameterValueException",
        "Message": "Some other invalid parameter",
    }
    boto_error = ClientError(
        {"Error": aws_error, "ResponseMetadata": {"HTTPStatusCode": 400}},
        "Checkpoint",
    )

    classified = CheckpointError.from_exception(boto_error)

    assert classified.error_kind == "Execution"
    assert classified.should_be_retried()
98+
99+
100+
def test_checkpoint_error_classification_5xx_invocation():
    """A 500 response is classified as an Invocation error."""
    aws_error = {"Code": "InternalServerError", "Message": "Service unavailable"}
    boto_error = ClientError(
        {"Error": aws_error, "ResponseMetadata": {"HTTPStatusCode": 500}},
        "Checkpoint",
    )

    classified = CheckpointError.from_exception(boto_error)

    assert classified.error_kind == "Invocation"
    assert not classified.should_be_retried()
112+
113+
114+
def test_checkpoint_error_classification_unknown_invocation():
    """A plain Exception (no AWS response attached) is classified as Invocation."""
    classified = CheckpointError.from_exception(Exception("Network timeout"))

    assert classified.error_kind == "Invocation"
    assert not classified.should_be_retried()
122+
123+
51124
def test_validation_error():
52125
"""Test ValidationError exception."""
53126
error = ValidationError("validation failed")

tests/execution_test.py

Lines changed: 147 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,7 @@ def test_durable_execution_checkpoint_error_in_background_thread():
10451045
# Make the background checkpoint thread fail immediately
10461046
def failing_checkpoint(*args, **kwargs):
10471047
msg = "Background checkpoint failed"
1048-
raise CheckpointError(msg)
1048+
raise CheckpointError(msg, error_kind="Execution")
10491049

10501050
@durable_execution
10511051
def test_handler(event: Any, context: DurableContext) -> dict:
@@ -1088,7 +1088,7 @@ def test_handler(event: Any, context: DurableContext) -> dict:
10881088
# endregion durable_execution
10891089

10901090

1091-
def test_durable_execution_checkpoint_error_stops_background():
1091+
def test_durable_execution_checkpoint_execution_error_stops_background():
10921092
"""Test that CheckpointError handler stops background checkpointing.
10931093
10941094
When user code raises CheckpointError, the handler should stop the background
@@ -1100,7 +1100,7 @@ def test_durable_execution_checkpoint_error_stops_background():
11001100
def test_handler(event: Any, context: DurableContext) -> dict:
11011101
# Directly raise CheckpointError to simulate checkpoint failure
11021102
msg = "Checkpoint system failed"
1103-
raise CheckpointError(msg)
1103+
raise CheckpointError(msg, "Execution")
11041104

11051105
operation = Operation(
11061106
operation_id="exec1",
@@ -1140,6 +1140,148 @@ def slow_background():
11401140
test_handler(invocation_input, lambda_context)
11411141

11421142

1143+
def test_durable_execution_checkpoint_invocation_error_stops_background():
    """Test that an "Invocation" CheckpointError from user code yields a FAILED output.

    When user code raises a CheckpointError whose kind is "Invocation", the
    handler is expected to return a FAILED invocation output carrying the
    error (ErrorType "CheckpointError") instead of re-raising it. The
    background checkpoint loop is patched to sleep so that the user code
    finishes while the background thread is still running.
    """
    mock_client = Mock(spec=DurableServiceClient)

    @durable_execution
    def test_handler(event: Any, context: DurableContext) -> dict:
        # Directly raise CheckpointError to simulate checkpoint failure
        msg = "Checkpoint system failed"
        raise CheckpointError(msg, "Invocation")

    operation = Operation(
        operation_id="exec1",
        operation_type=OperationType.EXECUTION,
        status=OperationStatus.STARTED,
        execution_details=ExecutionDetails(input_payload="{}"),
    )

    initial_state = InitialExecutionState(operations=[operation], next_marker="")

    invocation_input = DurableExecutionInvocationInputWithClient(
        durable_execution_arn="arn:test:execution",
        checkpoint_token="token123",  # noqa: S106
        initial_execution_state=initial_state,
        is_local_runner=False,
        service_client=mock_client,
    )

    lambda_context = Mock()
    lambda_context.aws_request_id = "test-request"
    lambda_context.client_context = None
    lambda_context.identity = None
    lambda_context._epoch_deadline_time_in_ms = 1000000  # noqa: SLF001
    lambda_context.invoked_function_arn = None
    lambda_context.tenant_id = None

    # Make background thread sleep so user code completes first
    def slow_background():
        time.sleep(1)

    # Mock checkpoint_batches_forever to sleep (simulates background thread running)
    with patch(
        "aws_durable_execution_sdk_python.state.ExecutionState.checkpoint_batches_forever",
        side_effect=slow_background,
    ):
        # Invocation-kind errors are reported in the returned output, not raised.
        response = test_handler(invocation_input, lambda_context)
        assert response["Status"] == InvocationStatus.FAILED.value
        assert response["Error"]["ErrorType"] == "CheckpointError"
1194+
1195+
1196+
def test_durable_execution_background_thread_execution_error_retries():
    """An "Execution" CheckpointError raised by the service client's checkpoint
    call must propagate out of the handler (re-raised so it can be retried)."""

    def raise_execution_error(*args, **kwargs):
        raise CheckpointError("Background checkpoint failed", error_kind="Execution")

    @durable_execution
    def test_handler(event: Any, context: DurableContext) -> dict:
        context.step(lambda ctx: "step_result")
        return {"result": "success"}

    service_client = Mock(spec=DurableServiceClient)

    # Minimal started execution used as the initial state of the invocation.
    started_execution = Operation(
        operation_id="exec1",
        operation_type=OperationType.EXECUTION,
        status=OperationStatus.STARTED,
        execution_details=ExecutionDetails(input_payload="{}"),
    )
    handler_input = DurableExecutionInvocationInputWithClient(
        durable_execution_arn="arn:test:execution",
        checkpoint_token="token123",  # noqa: S106
        initial_execution_state=InitialExecutionState(
            operations=[started_execution], next_marker=""
        ),
        is_local_runner=False,
        service_client=service_client,
    )

    # Mock keyword arguments are applied as attributes on the new mock.
    fake_context = Mock(
        **{
            "aws_request_id": "test-request",
            "client_context": None,
            "identity": None,
            "_epoch_deadline_time_in_ms": 1000000,
            "invoked_function_arn": None,
            "tenant_id": None,
        }
    )

    service_client.checkpoint.side_effect = raise_execution_error

    with pytest.raises(CheckpointError, match="Background checkpoint failed"):
        test_handler(handler_input, fake_context)
1238+
1239+
1240+
def test_durable_execution_background_thread_invocation_error_returns_failed():
    """An "Invocation" CheckpointError raised by the service client's checkpoint
    call must be reported as a FAILED invocation output rather than raised."""

    def raise_invocation_error(*args, **kwargs):
        raise CheckpointError("Background checkpoint failed", error_kind="Invocation")

    @durable_execution
    def test_handler(event: Any, context: DurableContext) -> dict:
        context.step(lambda ctx: "step_result")
        return {"result": "success"}

    service_client = Mock(spec=DurableServiceClient)

    # Minimal started execution used as the initial state of the invocation.
    started_execution = Operation(
        operation_id="exec1",
        operation_type=OperationType.EXECUTION,
        status=OperationStatus.STARTED,
        execution_details=ExecutionDetails(input_payload="{}"),
    )
    handler_input = DurableExecutionInvocationInputWithClient(
        durable_execution_arn="arn:test:execution",
        checkpoint_token="token123",  # noqa: S106
        initial_execution_state=InitialExecutionState(
            operations=[started_execution], next_marker=""
        ),
        is_local_runner=False,
        service_client=service_client,
    )

    # Mock keyword arguments are applied as attributes on the new mock.
    fake_context = Mock(
        **{
            "aws_request_id": "test-request",
            "client_context": None,
            "identity": None,
            "_epoch_deadline_time_in_ms": 1000000,
            "invoked_function_arn": None,
            "tenant_id": None,
        }
    )

    service_client.checkpoint.side_effect = raise_invocation_error

    outcome = test_handler(handler_input, fake_context)
    assert outcome["Status"] == InvocationStatus.FAILED.value
    assert outcome["Error"]["ErrorType"] == "CheckpointError"
1283+
1284+
11431285
def test_durable_handler_background_thread_failure_on_succeed_checkpoint():
11441286
"""Test durable_handler handles background thread failure on SUCCEED checkpoint.
11451287
@@ -1468,6 +1610,7 @@ def test_durable_execution_logs_checkpoint_error_extras_from_background_thread()
14681610
def failing_checkpoint(*args, **kwargs):
14691611
raise CheckpointError( # noqa TRY003
14701612
"Checkpoint failed", # noqa EM101
1613+
error_kind="Execution",
14711614
error=error_obj,
14721615
response_metadata=metadata_obj, # EM101
14731616
)
@@ -1589,6 +1732,7 @@ def test_durable_execution_logs_checkpoint_error_extras_from_user_code():
15891732
def test_handler(event: Any, context: DurableContext) -> dict:
15901733
raise CheckpointError( # noqa TRY003
15911734
"User checkpoint error", # noqa EM101
1735+
error_kind="Execution",
15921736
error=error_obj,
15931737
response_metadata=metadata_obj, # EM101
15941738
)

0 commit comments

Comments
 (0)