Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions src/aws_durable_execution_sdk_python/execution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import json
import logging
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -250,9 +251,12 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
)

# Use ThreadPoolExecutor for concurrent execution of user code and background checkpoint processing
with ThreadPoolExecutor(
max_workers=2, thread_name_prefix="dex-handler"
) as executor:
with (
ThreadPoolExecutor(
max_workers=2, thread_name_prefix="dex-handler"
) as executor,
contextlib.closing(execution_state) as execution_state,
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding here the context so that it always closes automatically.

):
# Thread 1: Run background checkpoint processing
executor.submit(execution_state.checkpoint_batches_forever)

Expand Down Expand Up @@ -296,18 +300,12 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
# Must ensure the result is persisted before returning to Lambda.
# Large results exceed Lambda response limits and must be stored durably
# before the execution completes.
execution_state.create_checkpoint_sync(success_operation)

# Stop background checkpointing thread
execution_state.stop_checkpointing()
execution_state.create_checkpoint(success_operation, is_sync=True)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we do create_checkpoint(success_operation, is_sync=True) over create_checkpoint_sync is because that was unwrapping the error. This made it appear like a user error, when in fact it was a checkpoint error.

The change in how we handle user errors made this bug visible.


return DurableExecutionInvocationOutput.create_succeeded(
result=""
).to_dict()

# Stop background checkpointing thread
execution_state.stop_checkpointing()

return DurableExecutionInvocationOutput.create_succeeded(
result=serialized_result
).to_dict()
Expand All @@ -322,33 +320,28 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
)
else:
logger.exception("Checkpoint processing failed")
execution_state.stop_checkpointing()
# Raise the original exception
raise bg_error.source_exception from bg_error

except SuspendExecution:
# User code suspended - stop background checkpointing thread
logger.debug("Suspending execution...")
execution_state.stop_checkpointing()
return DurableExecutionInvocationOutput(
status=InvocationStatus.PENDING
).to_dict()

except CheckpointError as e:
# Checkpoint system is broken - stop background thread and exit immediately
execution_state.stop_checkpointing()
logger.exception(
"Checkpoint system failed",
extra=e.build_logger_extras(),
)
raise # Terminate Lambda immediately
except InvocationError:
execution_state.stop_checkpointing()
logger.exception("Invocation error. Must terminate.")
# Throw the error to trigger Lambda retry
raise
except ExecutionError as e:
execution_state.stop_checkpointing()
logger.exception("Execution error. Must terminate without retry.")
return DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED,
Expand All @@ -357,15 +350,36 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
except Exception as e:
# all user-space errors go here
logger.exception("Execution failed")
failed_operation = OperationUpdate.create_execution_fail(
error=ErrorObject.from_exception(e)
)
# TODO: can optimize, if not too large can just return response rather than checkpoint
execution_state.create_checkpoint_sync(failed_operation)

execution_state.stop_checkpointing()
return DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED
result = DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED, error=ErrorObject.from_exception(e)
).to_dict()

serialized_result = json.dumps(result)

if (
serialized_result
and len(serialized_result) > LAMBDA_RESPONSE_SIZE_LIMIT
):
logger.debug(
"Response size (%s bytes) exceeds Lambda limit (%s) bytes). Checkpointing result.",
len(serialized_result),
LAMBDA_RESPONSE_SIZE_LIMIT,
)
failed_operation = OperationUpdate.create_execution_fail(
error=ErrorObject.from_exception(e)
)

# Checkpoint large result with blocking (is_sync=True, default).
# Must ensure the result is persisted before returning to Lambda.
# Large results exceed Lambda response limits and must be stored durably
# before the execution completes.
execution_state.create_checkpoint_sync(failed_operation)

return DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED
).to_dict()

return result

return wrapper
3 changes: 3 additions & 0 deletions src/aws_durable_execution_sdk_python/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,3 +731,6 @@ def _calculate_operation_size(queued_op: QueuedOperation) -> int:
# Use JSON serialization to estimate size
serialized = json.dumps(queued_op.operation_update.to_dict()).encode("utf-8")
return len(serialized)

def close(self):
self.stop_checkpointing()
Comment on lines +734 to +736
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensures we close with the contextlib.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we add a comment here, and the place that we use the contextlib.closing to point out we are stopping the ckeckpointing

62 changes: 54 additions & 8 deletions tests/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,17 +592,63 @@ def test_handler(event: Any, context: DurableContext) -> dict:

result = test_handler(invocation_input, lambda_context)

# small error, should not call checkpoint
assert result["Status"] == InvocationStatus.FAILED.value
assert result["Error"] == {"ErrorMessage": "Test error", "ErrorType": "ValueError"}

assert not mock_client.checkpoint.called


def test_durable_execution_with_large_error_payload():
"""Test that large error payloads trigger checkpoint."""
mock_client = Mock(spec=DurableServiceClient)
mock_output = CheckpointOutput(
checkpoint_token="new_token", # noqa: S106
new_execution_state=CheckpointUpdatedExecutionState(),
)
mock_client.checkpoint.return_value = mock_output

@durable_execution
def test_handler(event: Any, context: DurableContext) -> dict:
raise ValueError(LARGE_RESULT)

operation = Operation(
operation_id="exec1",
operation_type=OperationType.EXECUTION,
status=OperationStatus.STARTED,
execution_details=ExecutionDetails(input_payload="{}"),
)

initial_state = InitialExecutionState(operations=[operation], next_marker="")

invocation_input = DurableExecutionInvocationInputWithClient(
durable_execution_arn="arn:test:execution",
checkpoint_token="token123", # noqa: S106
initial_execution_state=initial_state,
is_local_runner=False,
service_client=mock_client,
)

lambda_context = Mock()
lambda_context.aws_request_id = "test-request"
lambda_context.client_context = None
lambda_context.identity = None
lambda_context._epoch_deadline_time_in_ms = 1000000 # noqa: SLF001
lambda_context.invoked_function_arn = None
lambda_context.tenant_id = None

result = test_handler(invocation_input, lambda_context)

assert result["Status"] == InvocationStatus.FAILED.value
assert "Error" not in result
mock_client.checkpoint.assert_called_once()

# Verify the checkpoint call was for execution failure
call_args = mock_client.checkpoint.call_args
updates = call_args[1]["updates"]
assert len(updates) == 1
assert updates[0].operation_type == OperationType.EXECUTION
assert updates[0].action.value == "FAIL"
assert updates[0].error.message == "Test error"
assert updates[0].error.type == "ValueError"
assert updates[0].error.message == LARGE_RESULT


def test_durable_execution_fatal_error_handling():
Expand Down Expand Up @@ -1404,11 +1450,11 @@ def test_handler(event: Any, context: DurableContext) -> str:
# Make the service client checkpoint call fail on error handling
mock_client.checkpoint.side_effect = failing_checkpoint

# Verify that the checkpoint error is raised (not the original ValueError)
with pytest.raises(
RuntimeError, match="Background checkpoint failed on error handling"
):
test_handler(invocation_input, lambda_context)
# Verify that errors are not raised, but returned because response is small
resp = test_handler(invocation_input, lambda_context)
assert resp["Error"]["ErrorMessage"] == "User function error"
assert resp["Error"]["ErrorType"] == "ValueError"
assert resp["Status"] == InvocationStatus.FAILED.value


def test_durable_execution_logs_checkpoint_error_extras_from_background_thread():
Expand Down
Loading