diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index eb704b9..9dffd91 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -64,6 +64,11 @@ def create_from_operation(cls, operation: Operation) -> CheckpointedResult: result = invoke_details.result if invoke_details else None error = invoke_details.error if invoke_details else None + case OperationType.CONTEXT: + context_details = operation.context_details + result = context_details.result if context_details else None + error = context_details.error if context_details else None + return cls( operation=operation, status=operation.status, result=result, error=error ) diff --git a/tests/state_test.py b/tests/state_test.py index a5fdbc3..11ef88e 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -9,6 +9,7 @@ CallbackDetails, CheckpointOutput, CheckpointUpdatedExecutionState, + ContextDetails, ErrorObject, InvokeDetails, LambdaClient, @@ -127,6 +128,74 @@ def test_checkpointed_result_create_from_operation_invoke_with_both_result_and_e assert result.error == error +def test_checkpointed_result_create_from_operation_context(): + """Test CheckpointedResult.create_from_operation with CONTEXT operation.""" + context_details = ContextDetails(result="context_result") + operation = Operation( + operation_id="op1", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + context_details=context_details, + ) + result = CheckpointedResult.create_from_operation(operation) + assert result.operation == operation + assert result.status == OperationStatus.SUCCEEDED + assert result.result == "context_result" + assert result.error is None + + +def test_checkpointed_result_create_from_operation_context_with_error(): + """Test CheckpointedResult.create_from_operation with CONTEXT operation and error.""" + error = ErrorObject( + message="Context error", type="ContextError", data=None, stack_trace=None + ) + context_details = ContextDetails(error=error) + operation = Operation( + operation_id="op1", + operation_type=OperationType.CONTEXT, + status=OperationStatus.FAILED, + context_details=context_details, + ) + result = CheckpointedResult.create_from_operation(operation) + assert result.operation == operation + assert result.status == OperationStatus.FAILED + assert result.result is None + assert result.error == error + + +def test_checkpointed_result_create_from_operation_context_no_details(): + """Test CheckpointedResult.create_from_operation with CONTEXT operation but no context_details.""" + operation = Operation( + operation_id="op1", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + result = CheckpointedResult.create_from_operation(operation) + assert result.operation == operation + assert result.status == OperationStatus.STARTED + assert result.result is None + assert result.error is None + + +def test_checkpointed_result_create_from_operation_context_with_both_result_and_error(): + """Test CheckpointedResult.create_from_operation with CONTEXT operation having both result and error.""" + error = ErrorObject( + message="Context error", type="ContextError", data=None, stack_trace=None + ) + context_details = ContextDetails(result="context_result", error=error) + operation = Operation( + operation_id="op1", + operation_type=OperationType.CONTEXT, + status=OperationStatus.FAILED, + context_details=context_details, + ) + result = CheckpointedResult.create_from_operation(operation) + assert result.operation == operation + assert result.status == OperationStatus.FAILED + assert result.result == "context_result" + assert result.error == error + + def test_checkpointed_result_create_from_operation_unknown_type(): """Test CheckpointedResult.create_from_operation with unknown operation type.""" # Create operation with a mock operation type that doesn't match any case