diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index e351b7a..877fda0 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib import logging from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeVar @@ -220,9 +221,10 @@ def _create_step_id(self) -> str: the id generated by this method. It is subject to change without notice. """ new_counter: int = self._step_counter.increment() - return ( + step_id = ( f"{self._parent_id}-{new_counter}" if self._parent_id else str(new_counter) ) + return hashlib.blake2b(step_id.encode()).hexdigest()[:64] # region Operations diff --git a/tests/context_test.py b/tests/context_test.py index e48237a..9b48e01 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -1,6 +1,8 @@ """Unit tests for context.""" import json +import random +from itertools import islice from unittest.mock import ANY, Mock, patch import pytest @@ -32,6 +34,7 @@ ) from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState from tests.serdes_test import CustomDictSerDes +from tests.test_helpers import operation_id_sequence def test_durable_context(): @@ -221,17 +224,19 @@ def test_create_callback_basic(mock_handler): ) context = DurableContext(state=mock_state) + operation_ids = operation_id_sequence() + expected_operation_id = next(operation_ids) callback = context.create_callback() assert isinstance(callback, Callback) assert callback.callback_id == "callback123" - assert callback.operation_id == "1" + assert callback.operation_id == expected_operation_id assert callback.state is mock_state mock_handler.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier("1", None, None), + operation_identifier=OperationIdentifier(expected_operation_id, None, None), config=CallbackConfig(), ) @@ -247,16 +252,19 @@ def test_create_callback_with_name_and_config(mock_handler): config = CallbackConfig() context = DurableContext(state=mock_state) + operation_ids = operation_id_sequence() + [next(operation_ids) for _ in range(5)] # Skip 5 IDs + expected_operation_id = next(operation_ids) # Get the 6th ID [context._create_step_id() for _ in range(5)] # Set counter to 5 # noqa: SLF001 callback = context.create_callback(config=config) assert callback.callback_id == "callback456" - assert callback.operation_id == "6" + assert callback.operation_id == expected_operation_id mock_handler.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier("6", None, None), + operation_identifier=OperationIdentifier(expected_operation_id, None, None), config=config, ) @@ -264,6 +272,7 @@ def test_create_callback_with_name_and_config(mock_handler): @patch("aws_durable_execution_sdk_python.context.create_callback_handler") def test_create_callback_with_parent_id(mock_handler): """Test create_callback with parent_id.""" + mock_handler.return_value = "callback789" mock_state = Mock(spec=ExecutionState) mock_state.durable_execution_arn = ( @@ -271,15 +280,18 @@ def test_create_callback_with_parent_id(mock_handler): ) context = DurableContext(state=mock_state, parent_id="parent123") + operation_ids = operation_id_sequence("parent123") + [next(operation_ids) for _ in range(2)] # Skip 2 IDs + expected_operation_id = next(operation_ids) # Get the 3rd ID [context._create_step_id() for _ in range(2)] # Set counter to 2 # noqa: SLF001 callback = context.create_callback() - assert callback.operation_id == "parent123-3" + assert callback.operation_id == expected_operation_id mock_handler.assert_called_once_with( state=mock_state, - operation_identifier=OperationIdentifier("parent123-3", "parent123"), + operation_identifier=OperationIdentifier(expected_operation_id, "parent123"), config=CallbackConfig(), ) @@ -299,8 +311,14 @@ def test_create_callback_increments_counter(mock_handler): callback1 = context.create_callback() callback2 = context.create_callback() - assert callback1.operation_id == "11" - assert callback2.operation_id == "12" + # Use operation_id_sequence to get expected IDs + seq = operation_id_sequence() + [next(seq) for _ in range(10)] # Skip first 10 + expected_id1 = next(seq) # 11th + expected_id2 = next(seq) # 12th + + assert callback1.operation_id == expected_id1 + assert callback2.operation_id == expected_id2 assert context._step_counter.get_current() == 12 # noqa: SLF001 @@ -322,6 +340,8 @@ def test_step_basic(mock_handler): ) # Ensure _original_name doesn't exist context = DurableContext(state=mock_state) + operation_ids = operation_id_sequence() + expected_operation_id = next(operation_ids) result = context.step(mock_callable) @@ -330,7 +350,7 @@ def test_step_basic(mock_handler): func=mock_callable, config=None, state=mock_state, - operation_identifier=OperationIdentifier("1", None, None), + operation_identifier=OperationIdentifier(expected_operation_id, None, None), context_logger=ANY, ) @@ -354,12 +374,17 @@ def test_step_with_name_and_config(mock_handler): result = context.step(mock_callable, config=config) + # Get expected ID + seq = operation_id_sequence() + [next(seq) for _ in range(5)] # Skip first 5 + expected_id = next(seq) # 6th + assert result == "configured_result" mock_handler.assert_called_once_with( func=mock_callable, config=config, state=mock_state, - operation_identifier=OperationIdentifier("6", None, None), + operation_identifier=OperationIdentifier(expected_id, None, None), context_logger=ANY, ) @@ -382,11 +407,16 @@ def test_step_with_parent_id(mock_handler): context.step(mock_callable) + # Get expected ID with parent + seq = operation_id_sequence("parent123") + [next(seq) for _ in range(2)] # Skip first 2 + expected_id = next(seq) # 3rd + mock_handler.assert_called_once_with( func=mock_callable, config=None, state=mock_state, - operation_identifier=OperationIdentifier("parent123-3", "parent123"), + operation_identifier=OperationIdentifier(expected_id, "parent123"), context_logger=ANY, ) @@ -410,13 +440,19 @@ def test_step_increments_counter(mock_handler): context.step(mock_callable) context.step(mock_callable) + # Get expected IDs + seq = operation_id_sequence() + [next(seq) for _ in range(10)] # Skip first 10 + expected_id1 = next(seq) # 11th + expected_id2 = next(seq) # 12th + assert context._step_counter.get_current() == 12 # noqa: SLF001 assert mock_handler.call_args_list[0][1][ "operation_identifier" - ] == OperationIdentifier("11", None, None) + ] == OperationIdentifier(expected_id1, None, None) assert mock_handler.call_args_list[1][1][ "operation_identifier" - ] == OperationIdentifier("12", None, None) + ] == OperationIdentifier(expected_id2, None, None) @patch("aws_durable_execution_sdk_python.context.step_handler") @@ -434,11 +470,15 @@ def test_step_with_original_name(mock_handler): context.step(mock_callable, name="override_name") + # Get expected ID + seq = operation_id_sequence() + expected_id = next(seq) # 1st + mock_handler.assert_called_once_with( func=mock_callable, config=None, state=mock_state, - operation_identifier=OperationIdentifier("1", None, "override_name"), + operation_identifier=OperationIdentifier(expected_id, None, "override_name"), context_logger=ANY, ) @@ -457,6 +497,8 @@ def test_invoke_basic(mock_handler): ) context = DurableContext(state=mock_state) + operation_ids = operation_id_sequence() + expected_operation_id = next(operation_ids) result = context.invoke("test_function", "test_payload") @@ -466,7 +508,7 @@ def test_invoke_basic(mock_handler): function_name="test_function", payload="test_payload", state=mock_state, - operation_identifier=OperationIdentifier("1", None, None), + operation_identifier=OperationIdentifier(expected_operation_id, None, None), config=None, ) @@ -488,12 +530,17 @@ def test_invoke_with_name_and_config(mock_handler): "test_function", {"key": "value"}, name="named_invoke", config=config ) + # Get expected ID + seq = operation_id_sequence() + [next(seq) for _ in range(5)] # Skip first 5 + expected_id = next(seq) # 6th + assert result == "configured_result" mock_handler.assert_called_once_with( function_name="test_function", payload={"key": "value"}, state=mock_state, - operation_identifier=OperationIdentifier("6", None, "named_invoke"), + operation_identifier=OperationIdentifier(expected_id, None, "named_invoke"), config=config, ) @@ -512,11 +559,15 @@ def test_invoke_with_parent_id(mock_handler): context.invoke("test_function", None) + seq = operation_id_sequence("parent123") + [next(seq) for _ in range(2)] + expected_id = next(seq) + mock_handler.assert_called_once_with( function_name="test_function", payload=None, state=mock_state, - operation_identifier=OperationIdentifier("parent123-3", "parent123", None), + operation_identifier=OperationIdentifier(expected_id, "parent123", None), config=None, ) @@ -536,13 +587,18 @@ def test_invoke_increments_counter(mock_handler): context.invoke("function1", "payload1") context.invoke("function2", "payload2") + seq = operation_id_sequence() + [next(seq) for _ in range(10)] + expected_id1 = next(seq) + expected_id2 = next(seq) + assert context._step_counter.get_current() == 12 # noqa: SLF001 assert mock_handler.call_args_list[0][1][ "operation_identifier" - ] == OperationIdentifier("11", None, None) + ] == OperationIdentifier(expected_id1, None, None) assert mock_handler.call_args_list[1][1][ "operation_identifier" - ] == OperationIdentifier("12", None, None) + ] == OperationIdentifier(expected_id2, None, None) @patch("aws_durable_execution_sdk_python.context.invoke_handler") @@ -558,13 +614,16 @@ def test_invoke_with_none_payload(mock_handler): result = context.invoke("test_function", None) + seq = operation_id_sequence() + expected_id = next(seq) + assert result is None mock_handler.assert_called_once_with( function_name="test_function", payload=None, state=mock_state, - operation_identifier=OperationIdentifier("1", None, None), + operation_identifier=OperationIdentifier(expected_id, None, None), config=None, ) @@ -593,12 +652,17 @@ def test_invoke_with_custom_serdes(mock_handler): config=config, ) + seq = operation_id_sequence() + expected_id = next(seq) + assert result == {"transformed": "data"} mock_handler.assert_called_once_with( function_name="test_function", payload={"original": "data"}, state=mock_state, - operation_identifier=OperationIdentifier("1", None, "custom_serdes_invoke"), + operation_identifier=OperationIdentifier( + expected_id, None, "custom_serdes_invoke" + ), config=config, ) @@ -616,13 +680,15 @@ def test_wait_basic(mock_handler): ) context = DurableContext(state=mock_state) + operation_ids = operation_id_sequence() + expected_operation_id = next(operation_ids) context.wait(30) mock_handler.assert_called_once_with( seconds=30, state=mock_state, - operation_identifier=OperationIdentifier("1", None, None), + operation_identifier=OperationIdentifier(expected_operation_id, None, None), ) @@ -639,10 +705,14 @@ def test_wait_with_name(mock_handler): context.wait(60, name="test_wait") + seq = operation_id_sequence() + [next(seq) for _ in range(5)] + expected_id = next(seq) + mock_handler.assert_called_once_with( seconds=60, state=mock_state, - operation_identifier=OperationIdentifier("6", None, "test_wait"), + operation_identifier=OperationIdentifier(expected_id, None, "test_wait"), ) @@ -659,10 +729,14 @@ def test_wait_with_parent_id(mock_handler): context.wait(45) + seq = operation_id_sequence("parent123") + [next(seq) for _ in range(2)] + expected_id = next(seq) + mock_handler.assert_called_once_with( seconds=45, state=mock_state, - operation_identifier=OperationIdentifier("parent123-3", "parent123"), + operation_identifier=OperationIdentifier(expected_id, "parent123"), ) @@ -680,13 +754,18 @@ def test_wait_increments_counter(mock_handler): context.wait(15) context.wait(25) + seq = operation_id_sequence() + [next(seq) for _ in range(10)] + expected_id1 = next(seq) + expected_id2 = next(seq) + assert context._step_counter.get_current() == 12 # noqa: SLF001 assert mock_handler.call_args_list[0][1][ "operation_identifier" - ] == OperationIdentifier("11", None, None) + ] == OperationIdentifier(expected_id1, None, None) assert mock_handler.call_args_list[1][1][ "operation_identifier" - ] == OperationIdentifier("12", None, None) + ] == OperationIdentifier(expected_id2, None, None) @patch("aws_durable_execution_sdk_python.context.wait_handler") @@ -736,6 +815,8 @@ def test_run_in_child_context_basic(mock_handler): ) # Ensure _original_name doesn't exist context = DurableContext(state=mock_state) + operation_ids = operation_id_sequence() + expected_operation_id = next(operation_ids) result = context.run_in_child_context(mock_callable) @@ -745,7 +826,9 @@ def test_run_in_child_context_basic(mock_handler): # Verify the callable was wrapped with child context call_args = mock_handler.call_args assert call_args[1]["state"] is mock_state - assert call_args[1]["operation_identifier"] == OperationIdentifier("1", None, None) + assert call_args[1]["operation_identifier"] == OperationIdentifier( + expected_operation_id, None, None + ) assert call_args[1]["config"] is None @@ -767,10 +850,14 @@ def test_run_in_child_context_with_name_and_config(mock_handler): result = context.run_in_child_context(mock_callable, config=config) + seq = operation_id_sequence() + [next(seq) for _ in range(3)] + expected_id = next(seq) + assert result == "configured_child_result" call_args = mock_handler.call_args assert call_args[1]["operation_identifier"] == OperationIdentifier( - "4", None, "original_function" + expected_id, None, "original_function" ) assert call_args[1]["config"] is config @@ -793,9 +880,13 @@ def test_run_in_child_context_with_parent_id(mock_handler): context.run_in_child_context(mock_callable) + seq = operation_id_sequence("parent456") + [next(seq) for _ in range(1)] + expected_id = next(seq) + call_args = mock_handler.call_args assert call_args[1]["operation_identifier"] == OperationIdentifier( - "parent456-2", "parent456", None + expected_id, "parent456", None ) @@ -807,11 +898,14 @@ def test_run_in_child_context_creates_child_context(mock_handler): "arn:aws:durable:us-east-1:123456789012:execution/test" ) + seq = operation_id_sequence() + expected_parent_id = next(seq) + def capture_child_context(child_context): # Verify child context properties assert isinstance(child_context, DurableContext) assert child_context.state is mock_state - assert child_context._parent_id == "1" # noqa: SLF001 + assert child_context._parent_id == expected_parent_id # noqa: SLF001 return "child_executed" mock_callable = Mock(side_effect=capture_child_context) @@ -844,13 +938,18 @@ def test_run_in_child_context_increments_counter(mock_handler): context.run_in_child_context(mock_callable) context.run_in_child_context(mock_callable) + seq = operation_id_sequence() + [next(seq) for _ in range(5)] + expected_id1 = next(seq) + expected_id2 = next(seq) + assert context._step_counter.get_current() == 7 # noqa: SLF001 assert mock_handler.call_args_list[0][1][ "operation_identifier" - ] == OperationIdentifier("6", None, None) + ] == OperationIdentifier(expected_id1, None, None) assert mock_handler.call_args_list[1][1][ "operation_identifier" - ] == OperationIdentifier("7", None, None) + ] == OperationIdentifier(expected_id2, None, None) @patch("aws_durable_execution_sdk_python.context.child_handler") @@ -1511,3 +1610,66 @@ def test_wait_strategy(state, attempt): # Verify wait_for_condition_handler was called (line 425) mock_handler.assert_called_once() assert result == "final_state" + + +# region operation_id generation +def test_operation_id_conditional_on_parent(): + """ + - ensure that for all unique parents we produce unique sequences for the children + """ + all_sequences = set() + + for i in range(10): + parent = f"parent_{i}" + seq = operation_id_sequence(parent) + sequence = tuple(islice(seq, 10)) + all_sequences.add(sequence) + + assert len(all_sequences) == 10 + + +def test_operation_id_generation_conditional_on_name_and_parent(): + """ + ensure that for all given (name, parent), None included, we observe unique sequences + """ + + parents = [f"parent_{i}" for i in range(9)] + [None] + random.shuffle(parents) + all_sequences = set() + + for parent in parents: + seq = operation_id_sequence(parent) + sequence = tuple(islice(seq, 5)) + all_sequences.add(sequence) + + assert len(all_sequences) == 10 + + +def test_operation_id_generation_deterministic(): + """ + ensure that any sequence with any seed name and parent is deterministic + """ + + random.seed(43) + parents = [f"parent_{i}" for i in range(9)] + [None] + random.shuffle(parents) + + for parent in parents: + seq1 = operation_id_sequence(parent) + sequence1 = tuple(islice(seq1, 10)) + + seq2 = operation_id_sequence(parent) + sequence2 = tuple(islice(seq2, 10)) + + assert sequence1 == sequence2 + + +def test_operation_id_generation_unique(): + """ + ensure that for any sequence, any two adjacent operation ids are unique + """ + seq = operation_id_sequence() + ids = [next(seq) for _ in range(100)] + + for i in range(len(ids) - 1): + assert ids[i] != ids[i + 1] diff --git a/tests/e2e/execution_int_test.py b/tests/e2e/execution_int_test.py index b21de6b..7549580 100644 --- a/tests/e2e/execution_int_test.py +++ b/tests/e2e/execution_int_test.py @@ -23,6 +23,7 @@ OperationType, ) from aws_durable_execution_sdk_python.logger import LoggerInterface +from tests.test_helpers import operation_id_sequence if TYPE_CHECKING: from aws_durable_execution_sdk_python.types import StepContext @@ -208,16 +209,17 @@ def mock_checkpoint( # 1 START checkpoint, 1 SUCCEED checkpoint assert len(checkpoint_calls) == 2 + operation_id = next(operation_id_sequence()) checkpoint = checkpoint_calls[0][0] assert checkpoint.operation_type == OperationType.STEP assert checkpoint.action == OperationAction.START - assert checkpoint.operation_id == "1" + assert checkpoint.operation_id == operation_id # Check the wait checkpoint checkpoint = checkpoint_calls[1][0] assert checkpoint.operation_type == OperationType.STEP assert checkpoint.action == OperationAction.SUCCEED - assert checkpoint.operation_id == "1" + assert checkpoint.operation_id == operation_id def test_wait_inside_run_in_childcontext(): @@ -295,19 +297,24 @@ def mock_checkpoint( # Assert that checkpoints were created assert len(checkpoint_calls) == 2 # One for child context start, one for wait + expected_parent_id = next(operation_id_sequence()) + expected_child_id = next(operation_id_sequence(expected_parent_id)) + # Check first checkpoint (child context start) first_checkpoint = checkpoint_calls[0][0] assert first_checkpoint.operation_type is OperationType.CONTEXT assert first_checkpoint.action is OperationAction.START - assert first_checkpoint.operation_id == "1" + assert first_checkpoint.operation_id == expected_parent_id # Check second checkpoint (wait operation) second_checkpoint = checkpoint_calls[1][0] assert second_checkpoint.operation_type is OperationType.WAIT assert second_checkpoint.action is OperationAction.START - assert second_checkpoint.operation_id == "1-1" + assert second_checkpoint.operation_id == expected_child_id assert second_checkpoint.wait_options.wait_seconds == 1 + assert second_checkpoint.operation_id != first_checkpoint.operation_id + mock_inside_child.assert_called_once_with(10, 20) @@ -379,6 +386,7 @@ def mock_checkpoint( # Execute the handler result = my_handler(event, lambda_context) + operation_ids = operation_id_sequence() # Assert the execution returns PENDING status assert result["Status"] == InvocationStatus.PENDING.value @@ -390,5 +398,5 @@ def mock_checkpoint( checkpoint = checkpoint_calls[0][0] assert checkpoint.operation_type is OperationType.WAIT assert checkpoint.action is OperationAction.START - assert checkpoint.operation_id == "1" + assert checkpoint.operation_id == next(operation_ids) assert checkpoint.wait_options.wait_seconds == 1 diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 0000000..dca15a0 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,17 @@ +"""Test helpers for generating expected step IDs.""" + +from unittest.mock import Mock + +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import ExecutionState + + +def operation_id_sequence(parent_id: str | None = None): + """Generator that yields step IDs in sequence using DurableContext.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test-arn" + + context = DurableContext(state=mock_state, parent_id=parent_id) + + while True: + yield context._create_step_id() # noqa: SLF001