Skip to content

Commit f36d1b6

Browse files
committed
fix(sdk): match reference behaviour for large error payloads
Changes: - When payloads are large, we checkpoint the error and return only failed. - When payloads are small, we return back the error fixes: #41
1 parent 8367bcc commit f36d1b6

File tree

3 files changed

+98
-32
lines changed

3 files changed

+98
-32
lines changed

src/aws_durable_execution_sdk_python/execution.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import json
45
import logging
56
from concurrent.futures import ThreadPoolExecutor
67
from dataclasses import dataclass
78
from enum import Enum
8-
from typing import TYPE_CHECKING, Any
9+
from typing import TYPE_CHECKING, Any, Tuple
910

1011
from aws_durable_execution_sdk_python.context import DurableContext, ExecutionState
1112
from aws_durable_execution_sdk_python.exceptions import (
@@ -191,6 +192,7 @@ def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput:
191192
# endregion Invocation models
192193

193194

195+
194196
def durable_execution(
195197
func: Callable[[Any, DurableContext], Any],
196198
) -> Callable[[Any, LambdaContext], Any]:
@@ -250,9 +252,12 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
250252
)
251253

252254
# Use ThreadPoolExecutor for concurrent execution of user code and background checkpoint processing
253-
with ThreadPoolExecutor(
254-
max_workers=2, thread_name_prefix="dex-handler"
255-
) as executor:
255+
with (
256+
ThreadPoolExecutor(
257+
max_workers=2, thread_name_prefix="dex-handler"
258+
) as executor,
259+
contextlib.closing(execution_state) as execution_state,
260+
):
256261
# Thread 1: Run background checkpoint processing
257262
executor.submit(execution_state.checkpoint_batches_forever)
258263

@@ -296,18 +301,12 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
296301
# Must ensure the result is persisted before returning to Lambda.
297302
# Large results exceed Lambda response limits and must be stored durably
298303
# before the execution completes.
299-
execution_state.create_checkpoint_sync(success_operation)
300-
301-
# Stop background checkpointing thread
302-
execution_state.stop_checkpointing()
304+
execution_state.create_checkpoint(success_operation, is_sync=True)
303305

304306
return DurableExecutionInvocationOutput.create_succeeded(
305307
result=""
306308
).to_dict()
307309

308-
# Stop background checkpointing thread
309-
execution_state.stop_checkpointing()
310-
311310
return DurableExecutionInvocationOutput.create_succeeded(
312311
result=serialized_result
313312
).to_dict()
@@ -322,33 +321,28 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
322321
)
323322
else:
324323
logger.exception("Checkpoint processing failed")
325-
execution_state.stop_checkpointing()
326324
# Raise the original exception
327325
raise bg_error.source_exception from bg_error
328326

329327
except SuspendExecution:
330328
# User code suspended - stop background checkpointing thread
331329
logger.debug("Suspending execution...")
332-
execution_state.stop_checkpointing()
333330
return DurableExecutionInvocationOutput(
334331
status=InvocationStatus.PENDING
335332
).to_dict()
336333

337334
except CheckpointError as e:
338335
# Checkpoint system is broken - stop background thread and exit immediately
339-
execution_state.stop_checkpointing()
340336
logger.exception(
341337
"Checkpoint system failed",
342338
extra=e.build_logger_extras(),
343339
)
344340
raise # Terminate Lambda immediately
345341
except InvocationError:
346-
execution_state.stop_checkpointing()
347342
logger.exception("Invocation error. Must terminate.")
348343
# Throw the error to trigger Lambda retry
349344
raise
350345
except ExecutionError as e:
351-
execution_state.stop_checkpointing()
352346
logger.exception("Execution error. Must terminate without retry.")
353347
return DurableExecutionInvocationOutput(
354348
status=InvocationStatus.FAILED,
@@ -357,15 +351,37 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
357351
except Exception as e:
358352
# all user-space errors go here
359353
logger.exception("Execution failed")
360-
failed_operation = OperationUpdate.create_execution_fail(
361-
error=ErrorObject.from_exception(e)
362-
)
363-
# TODO: can optimize, if not too large can just return response rather than checkpoint
364-
execution_state.create_checkpoint_sync(failed_operation)
365354

366-
execution_state.stop_checkpointing()
367-
return DurableExecutionInvocationOutput(
368-
status=InvocationStatus.FAILED
355+
result = DurableExecutionInvocationOutput(
356+
status=InvocationStatus.FAILED,
357+
error=ErrorObject.from_exception(e)
369358
).to_dict()
370359

360+
serialized_result = json.dumps(result)
361+
362+
if (
363+
serialized_result
364+
and len(serialized_result) > LAMBDA_RESPONSE_SIZE_LIMIT
365+
):
366+
logger.debug(
367+
"Response size (%s bytes) exceeds Lambda limit (%s) bytes). Checkpointing result.",
368+
len(serialized_result),
369+
LAMBDA_RESPONSE_SIZE_LIMIT,
370+
)
371+
failed_operation = OperationUpdate.create_execution_fail(
372+
error=ErrorObject.from_exception(e)
373+
)
374+
375+
# Checkpoint large result with blocking (is_sync=True, default).
376+
# Must ensure the result is persisted before returning to Lambda.
377+
# Large results exceed Lambda response limits and must be stored durably
378+
# before the execution completes.
379+
execution_state.create_checkpoint_sync(failed_operation)
380+
381+
return DurableExecutionInvocationOutput(
382+
status=InvocationStatus.FAILED
383+
).to_dict()
384+
385+
return result
386+
371387
return wrapper

src/aws_durable_execution_sdk_python/state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,3 +731,6 @@ def _calculate_operation_size(queued_op: QueuedOperation) -> int:
731731
# Use JSON serialization to estimate size
732732
serialized = json.dumps(queued_op.operation_update.to_dict()).encode("utf-8")
733733
return len(serialized)
734+
735+
def close(self):
736+
self.stop_checkpointing()

tests/execution_test.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -592,17 +592,63 @@ def test_handler(event: Any, context: DurableContext) -> dict:
592592

593593
result = test_handler(invocation_input, lambda_context)
594594

595+
# small error, should not call checkpoint
595596
assert result["Status"] == InvocationStatus.FAILED.value
597+
assert result["Error"] == {'ErrorMessage': 'Test error', 'ErrorType': 'ValueError'}
598+
599+
assert not mock_client.checkpoint.called
600+
601+
602+
def test_durable_execution_with_large_error_payload():
603+
"""Test that large error payloads trigger checkpoint."""
604+
mock_client = Mock(spec=DurableServiceClient)
605+
mock_output = CheckpointOutput(
606+
checkpoint_token="new_token", # noqa: S106
607+
new_execution_state=CheckpointUpdatedExecutionState(),
608+
)
609+
mock_client.checkpoint.return_value = mock_output
610+
611+
@durable_execution
612+
def test_handler(event: Any, context: DurableContext) -> dict:
613+
raise ValueError(LARGE_RESULT)
614+
615+
operation = Operation(
616+
operation_id="exec1",
617+
operation_type=OperationType.EXECUTION,
618+
status=OperationStatus.STARTED,
619+
execution_details=ExecutionDetails(input_payload="{}"),
620+
)
621+
622+
initial_state = InitialExecutionState(operations=[operation], next_marker="")
623+
624+
invocation_input = DurableExecutionInvocationInputWithClient(
625+
durable_execution_arn="arn:test:execution",
626+
checkpoint_token="token123", # noqa: S106
627+
initial_execution_state=initial_state,
628+
is_local_runner=False,
629+
service_client=mock_client,
630+
)
631+
632+
lambda_context = Mock()
633+
lambda_context.aws_request_id = "test-request"
634+
lambda_context.client_context = None
635+
lambda_context.identity = None
636+
lambda_context._epoch_deadline_time_in_ms = 1000000 # noqa: SLF001
637+
lambda_context.invoked_function_arn = None
638+
lambda_context.tenant_id = None
639+
640+
result = test_handler(invocation_input, lambda_context)
641+
642+
assert result["Status"] == InvocationStatus.FAILED.value
643+
assert "Error" not in result
596644
mock_client.checkpoint.assert_called_once()
597645

598-
# Verify the checkpoint call was for execution failure
599646
call_args = mock_client.checkpoint.call_args
600647
updates = call_args[1]["updates"]
601648
assert len(updates) == 1
602649
assert updates[0].operation_type == OperationType.EXECUTION
603650
assert updates[0].action.value == "FAIL"
604-
assert updates[0].error.message == "Test error"
605-
assert updates[0].error.type == "ValueError"
651+
assert updates[0].error.message == LARGE_RESULT
606652

607653

608654
def test_durable_execution_fatal_error_handling():
@@ -1404,11 +1450,12 @@ def test_handler(event: Any, context: DurableContext) -> str:
14041450
# Make the service client checkpoint call fail on error handling
14051451
mock_client.checkpoint.side_effect = failing_checkpoint
14061452

1407-
# Verify that the checkpoint error is raised (not the original ValueError)
1408-
with pytest.raises(
1409-
RuntimeError, match="Background checkpoint failed on error handling"
1410-
):
1411-
test_handler(invocation_input, lambda_context)
1453+
# Verify that errors are not raised, but returned because response is small
1454+
resp = test_handler(invocation_input, lambda_context)
1455+
assert resp['Error']['ErrorMessage'] == "User function error"
1456+
assert resp['Error']['ErrorType'] == "ValueError"
1457+
assert resp['Status'] == InvocationStatus.FAILED.value
1458+
14121459

14131460

14141461
def test_durable_execution_logs_checkpoint_error_extras_from_background_thread():

0 commit comments

Comments
 (0)