Skip to content

Commit 1c513c0

Browse files
committed
Add checkpoint error handling when final checkpointing fails
1 parent eaf9d03 commit 1c513c0

File tree

4 files changed

+243
-47
lines changed

4 files changed

+243
-47
lines changed

src/aws_durable_execution_sdk_python/exceptions.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import time
99
from dataclasses import dataclass
1010
from enum import Enum
11-
from typing import TYPE_CHECKING, Literal, Self, TypedDict
11+
from typing import TYPE_CHECKING, Self, TypedDict
1212

1313
BAD_REQUEST_ERROR: int = 400
1414
SERVICE_ERROR: int = 500
@@ -124,7 +124,9 @@ def __init__(self, message: str, step_id: str | None = None):
124124
self.step_id = step_id
125125

126126

127-
CheckpointErrorKind = Literal["Execution", "Invocation"]
127+
class CheckpointErrorCategory(Enum):
128+
INVOCATION = "INVOCATION"
129+
EXECUTION = "EXECUTION"
128130

129131

130132
class CheckpointError(BotoClientError):
@@ -133,7 +135,7 @@ class CheckpointError(BotoClientError):
133135
def __init__(
134136
self,
135137
message: str,
136-
error_kind: CheckpointErrorKind,
138+
error_category: CheckpointErrorCategory,
137139
error: AwsErrorObj | None = None,
138140
response_metadata: AwsErrorMetadata | None = None,
139141
):
@@ -143,14 +145,14 @@ def __init__(
143145
response_metadata,
144146
termination_reason=TerminationReason.CHECKPOINT_FAILED,
145147
)
146-
self.error_kind: CheckpointErrorKind = error_kind
148+
self.error_category: CheckpointErrorCategory = error_category
147149

148150
@classmethod
149151
def from_exception(cls, exception: Exception) -> CheckpointError:
150152
base = BotoClientError.from_exception(exception)
151153
metadata: AwsErrorMetadata | None = base.response_metadata
152154
error: AwsErrorObj | None = base.error
153-
error_kind: CheckpointErrorKind = "Invocation"
155+
error_category: CheckpointErrorCategory = CheckpointErrorCategory.INVOCATION
154156

155157
# InvalidParameterValueException and error message starts with "Invalid Checkpoint Token" is an InvocationError
156158
# all other 4xx errors are Execution Errors and should be retried
@@ -172,11 +174,11 @@ def from_exception(cls, exception: Exception) -> CheckpointError:
172174
)
173175
)
174176
):
175-
error_kind = "Execution"
176-
return CheckpointError(str(exception), error_kind, error, metadata)
177+
error_category = CheckpointErrorCategory.EXECUTION
178+
return CheckpointError(str(exception), error_category, error, metadata)
177179

178-
def should_be_retried(self):
179-
return self.error_kind == "Execution"
180+
def is_retriable(self):
181+
return self.error_category == CheckpointErrorCategory.EXECUTION
180182

181183

182184
class ValidationError(DurableExecutionsError):

src/aws_durable_execution_sdk_python/execution.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
281281
invocation_input.durable_execution_arn,
282282
)
283283
serialized_result = json.dumps(result)
284-
285284
# large response handling here. Remember if checkpointing to complete, NOT to include
286285
# payload in response
287286
if (
@@ -300,8 +299,12 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
300299
# Must ensure the result is persisted before returning to Lambda.
301300
# Large results exceed Lambda response limits and must be stored durably
302301
# before the execution completes.
303-
execution_state.create_checkpoint(success_operation, is_sync=True)
304-
302+
try:
303+
execution_state.create_checkpoint(
304+
success_operation, is_sync=True
305+
)
306+
except CheckpointError as e:
307+
return handle_checkpoint_error(e).to_dict()
305308
return DurableExecutionInvocationOutput.create_succeeded(
306309
result=""
307310
).to_dict()
@@ -320,17 +323,9 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
320323
)
321324
else:
322325
logger.exception("Checkpoint processing failed")
323-
# Raise the original exception
324-
if (
325-
isinstance(bg_error.source_exception, CheckpointError)
326-
and bg_error.source_exception.should_be_retried()
327-
):
328-
raise bg_error.source_exception from None # Terminate Lambda immediately and have it be retried
326+
# handle the original exception
329327
if isinstance(bg_error.source_exception, CheckpointError):
330-
return DurableExecutionInvocationOutput(
331-
status=InvocationStatus.FAILED,
332-
error=ErrorObject.from_exception(bg_error.source_exception),
333-
).to_dict()
328+
return handle_checkpoint_error(bg_error.source_exception).to_dict()
334329
raise bg_error.source_exception from bg_error
335330

336331
except SuspendExecution:
@@ -346,11 +341,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
346341
"Checkpoint system failed",
347342
extra=e.build_logger_extras(),
348343
)
349-
if e.should_be_retried():
350-
raise # Terminate Lambda immediately and have it be retried
351-
return DurableExecutionInvocationOutput(
352-
status=InvocationStatus.FAILED, error=ErrorObject.from_exception(e)
353-
).to_dict()
344+
return handle_checkpoint_error(e).to_dict()
354345
except InvocationError:
355346
logger.exception("Invocation error. Must terminate.")
356347
# Throw the error to trigger Lambda retry
@@ -388,12 +379,22 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
388379
# Must ensure the result is persisted before returning to Lambda.
389380
# Large results exceed Lambda response limits and must be stored durably
390381
# before the execution completes.
391-
execution_state.create_checkpoint_sync(failed_operation)
392-
382+
try:
383+
execution_state.create_checkpoint_sync(failed_operation)
384+
except CheckpointError as e:
385+
return handle_checkpoint_error(e).to_dict()
393386
return DurableExecutionInvocationOutput(
394387
status=InvocationStatus.FAILED
395388
).to_dict()
396389

397390
return result
398391

399392
return wrapper
393+
394+
395+
def handle_checkpoint_error(error: CheckpointError) -> DurableExecutionInvocationOutput:
396+
if error.is_retriable():
397+
raise error from None # Terminate Lambda immediately and have it be retried
398+
return DurableExecutionInvocationOutput(
399+
status=InvocationStatus.FAILED, error=ErrorObject.from_exception(error)
400+
)

tests/exceptions_test.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
CallableRuntimeError,
1111
CallableRuntimeErrorSerializableDetails,
1212
CheckpointError,
13+
CheckpointErrorCategory,
1314
DurableExecutionsError,
1415
ExecutionError,
1516
InvocationError,
@@ -42,7 +43,9 @@ def test_invocation_error():
4243

4344
def test_checkpoint_error():
4445
"""Test CheckpointError exception."""
45-
error = CheckpointError("checkpoint failed", error_kind="Execution")
46+
error = CheckpointError(
47+
"checkpoint failed", error_category=CheckpointErrorCategory.EXECUTION
48+
)
4649
assert str(error) == "checkpoint failed"
4750
assert isinstance(error, InvocationError)
4851
assert isinstance(error, UnrecoverableError)
@@ -62,8 +65,8 @@ def test_checkpoint_error_classification_invalid_token_invocation():
6265

6366
result = CheckpointError.from_exception(client_error)
6467

65-
assert result.error_kind == "Invocation"
66-
assert not result.should_be_retried()
68+
assert result.error_category == CheckpointErrorCategory.INVOCATION
69+
assert not result.is_retriable()
6770

6871

6972
def test_checkpoint_error_classification_other_4xx_execution():
@@ -76,8 +79,8 @@ def test_checkpoint_error_classification_other_4xx_execution():
7679

7780
result = CheckpointError.from_exception(client_error)
7881

79-
assert result.error_kind == "Execution"
80-
assert result.should_be_retried()
82+
assert result.error_category == CheckpointErrorCategory.EXECUTION
83+
assert result.is_retriable()
8184

8285

8386
def test_checkpoint_error_classification_invalid_param_without_token_execution():
@@ -93,8 +96,8 @@ def test_checkpoint_error_classification_invalid_param_without_token_execution()
9396

9497
result = CheckpointError.from_exception(client_error)
9598

96-
assert result.error_kind == "Execution"
97-
assert result.should_be_retried()
99+
assert result.error_category == CheckpointErrorCategory.EXECUTION
100+
assert result.is_retriable()
98101

99102

100103
def test_checkpoint_error_classification_5xx_invocation():
@@ -107,8 +110,8 @@ def test_checkpoint_error_classification_5xx_invocation():
107110

108111
result = CheckpointError.from_exception(client_error)
109112

110-
assert result.error_kind == "Invocation"
111-
assert not result.should_be_retried()
113+
assert result.error_category == CheckpointErrorCategory.INVOCATION
114+
assert not result.is_retriable()
112115

113116

114117
def test_checkpoint_error_classification_unknown_invocation():
@@ -117,8 +120,8 @@ def test_checkpoint_error_classification_unknown_invocation():
117120

118121
result = CheckpointError.from_exception(unknown_error)
119122

120-
assert result.error_kind == "Invocation"
121-
assert not result.should_be_retried()
123+
assert result.error_category == CheckpointErrorCategory.INVOCATION
124+
assert not result.is_retriable()
122125

123126

124127
def test_validation_error():

0 commit comments

Comments
 (0)