From 73e5bf45d559c762e962f823802330ee7407f221 Mon Sep 17 00:00:00 2001 From: Quinn Sinclair Date: Fri, 7 Nov 2025 22:33:47 +0000 Subject: [PATCH] fix(sdk): match reference behaviour for large error payloads Changes: - When payloads are large, we checkpoint the error and return only the FAILED status. - When payloads are small, we return the error in the response. fixes: #41 --- .../execution.py | 60 +++++++++++------- src/aws_durable_execution_sdk_python/state.py | 3 + tests/execution_test.py | 62 ++++++++++++++++--- 3 files changed, 94 insertions(+), 31 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index d2efd80..bc80948 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import json import logging from concurrent.futures import ThreadPoolExecutor @@ -250,9 +251,12 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: ) # Use ThreadPoolExecutor for concurrent execution of user code and background checkpoint processing - with ThreadPoolExecutor( - max_workers=2, thread_name_prefix="dex-handler" - ) as executor: + with ( + ThreadPoolExecutor( + max_workers=2, thread_name_prefix="dex-handler" + ) as executor, + contextlib.closing(execution_state) as execution_state, + ): # Thread 1: Run background checkpoint processing executor.submit(execution_state.checkpoint_batches_forever) @@ -296,18 +300,12 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: # Must ensure the result is persisted before returning to Lambda. # Large results exceed Lambda response limits and must be stored durably # before the execution completes. 
- execution_state.create_checkpoint_sync(success_operation) - - # Stop background checkpointing thread - execution_state.stop_checkpointing() + execution_state.create_checkpoint(success_operation, is_sync=True) return DurableExecutionInvocationOutput.create_succeeded( result="" ).to_dict() - # Stop background checkpointing thread - execution_state.stop_checkpointing() - return DurableExecutionInvocationOutput.create_succeeded( result=serialized_result ).to_dict() @@ -322,33 +320,28 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: ) else: logger.exception("Checkpoint processing failed") - execution_state.stop_checkpointing() # Raise the original exception raise bg_error.source_exception from bg_error except SuspendExecution: # User code suspended - stop background checkpointing thread logger.debug("Suspending execution...") - execution_state.stop_checkpointing() return DurableExecutionInvocationOutput( status=InvocationStatus.PENDING ).to_dict() except CheckpointError as e: # Checkpoint system is broken - stop background thread and exit immediately - execution_state.stop_checkpointing() logger.exception( "Checkpoint system failed", extra=e.build_logger_extras(), ) raise # Terminate Lambda immediately except InvocationError: - execution_state.stop_checkpointing() logger.exception("Invocation error. Must terminate.") # Throw the error to trigger Lambda retry raise except ExecutionError as e: - execution_state.stop_checkpointing() logger.exception("Execution error. 
Must terminate without retry.") return DurableExecutionInvocationOutput( status=InvocationStatus.FAILED, @@ -357,15 +350,36 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: except Exception as e: # all user-space errors go here logger.exception("Execution failed") - failed_operation = OperationUpdate.create_execution_fail( - error=ErrorObject.from_exception(e) - ) - # TODO: can optimize, if not too large can just return response rather than checkpoint - execution_state.create_checkpoint_sync(failed_operation) - execution_state.stop_checkpointing() - return DurableExecutionInvocationOutput( - status=InvocationStatus.FAILED + result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=ErrorObject.from_exception(e) ).to_dict() + serialized_result = json.dumps(result) + + if ( + serialized_result + and len(serialized_result) > LAMBDA_RESPONSE_SIZE_LIMIT + ): + logger.debug( + "Response size (%s bytes) exceeds Lambda limit (%s bytes). Checkpointing result.", + len(serialized_result), + LAMBDA_RESPONSE_SIZE_LIMIT, + ) + failed_operation = OperationUpdate.create_execution_fail( + error=ErrorObject.from_exception(e) + ) + + # Checkpoint the large error with a blocking (synchronous) call. + # Must ensure the error is persisted before returning to Lambda. + # Large errors exceed Lambda response limits and must be stored durably + # before the execution completes. 
+ execution_state.create_checkpoint_sync(failed_operation) + + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED + ).to_dict() + + return result + return wrapper diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index d97d19d..ff251ed 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -731,3 +731,6 @@ def _calculate_operation_size(queued_op: QueuedOperation) -> int: # Use JSON serialization to estimate size serialized = json.dumps(queued_op.operation_update.to_dict()).encode("utf-8") return len(serialized) + + def close(self): + self.stop_checkpointing() diff --git a/tests/execution_test.py b/tests/execution_test.py index df7c8c0..a2b72c1 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -592,17 +592,63 @@ def test_handler(event: Any, context: DurableContext) -> dict: result = test_handler(invocation_input, lambda_context) + # small error, should not call checkpoint assert result["Status"] == InvocationStatus.FAILED.value + assert result["Error"] == {"ErrorMessage": "Test error", "ErrorType": "ValueError"} + + assert not mock_client.checkpoint.called + + +def test_durable_execution_with_large_error_payload(): + """Test that large error payloads trigger checkpoint.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + @durable_execution + def test_handler(event: Any, context: DurableContext) -> dict: + raise ValueError(LARGE_RESULT) + + operation = Operation( + operation_id="exec1", + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails(input_payload="{}"), + ) + + initial_state = InitialExecutionState(operations=[operation], next_marker="") + + 
invocation_input = DurableExecutionInvocationInputWithClient( + durable_execution_arn="arn:test:execution", + checkpoint_token="token123", # noqa: S106 + initial_execution_state=initial_state, + is_local_runner=False, + service_client=mock_client, + ) + + lambda_context = Mock() + lambda_context.aws_request_id = "test-request" + lambda_context.client_context = None + lambda_context.identity = None + lambda_context._epoch_deadline_time_in_ms = 1000000 # noqa: SLF001 + lambda_context.invoked_function_arn = None + lambda_context.tenant_id = None + + result = test_handler(invocation_input, lambda_context) + + assert result["Status"] == InvocationStatus.FAILED.value + assert "Error" not in result mock_client.checkpoint.assert_called_once() - # Verify the checkpoint call was for execution failure call_args = mock_client.checkpoint.call_args updates = call_args[1]["updates"] assert len(updates) == 1 assert updates[0].operation_type == OperationType.EXECUTION assert updates[0].action.value == "FAIL" - assert updates[0].error.message == "Test error" - assert updates[0].error.type == "ValueError" + assert updates[0].error.message == LARGE_RESULT def test_durable_execution_fatal_error_handling(): @@ -1404,11 +1450,11 @@ def test_handler(event: Any, context: DurableContext) -> str: # Make the service client checkpoint call fail on error handling mock_client.checkpoint.side_effect = failing_checkpoint - # Verify that the checkpoint error is raised (not the original ValueError) - with pytest.raises( - RuntimeError, match="Background checkpoint failed on error handling" - ): - test_handler(invocation_input, lambda_context) + # Verify that errors are not raised, but returned because response is small + resp = test_handler(invocation_input, lambda_context) + assert resp["Error"]["ErrorMessage"] == "User function error" + assert resp["Error"]["ErrorType"] == "ValueError" + assert resp["Status"] == InvocationStatus.FAILED.value def 
test_durable_execution_logs_checkpoint_error_extras_from_background_thread():