diff --git a/src/aws_durable_execution_sdk_python/operation/callback.py b/src/aws_durable_execution_sdk_python/operation/callback.py index 98bdbb8..600602e 100644 --- a/src/aws_durable_execution_sdk_python/operation/callback.py +++ b/src/aws_durable_execution_sdk_python/operation/callback.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any +from aws_durable_execution_sdk_python.config import StepConfig from aws_durable_execution_sdk_python.exceptions import FatalError from aws_durable_execution_sdk_python.lambda_service import ( CallbackOptions, @@ -97,6 +98,16 @@ def wait_for_callback_handler( def submitter_step(step_context): # noqa: ARG001 return submitter(callback.callback_id) - context.step(func=submitter_step, name=f"{name_with_space}submitter") + step_config = ( + StepConfig( + retry_strategy=config.retry_strategy, + serdes=config.serdes, + ) + if config + else None + ) + context.step( + func=submitter_step, name=f"{name_with_space}submitter", config=step_config + ) return callback.result() diff --git a/src/aws_durable_execution_sdk_python/retries.py b/src/aws_durable_execution_sdk_python/retries.py index 1637eac..5675950 100644 --- a/src/aws_durable_execution_sdk_python/retries.py +++ b/src/aws_durable_execution_sdk_python/retries.py @@ -2,10 +2,11 @@ from __future__ import annotations +import math import random import re -import sys from dataclasses import dataclass, field +from enum import StrEnum from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -13,6 +14,39 @@ Numeric = int | float +# region Jitter + + +class JitterStrategy(StrEnum): + """ + Jitter strategies are used to introduce noise when attempting to retry + an invoke. We introduce noise to prevent a thundering-herd effect where + a group of accesses (e.g. invokes) happen at once. + + Jitter is meant to be used to spread operations across time. 
+ + members: + :NONE: No jitter; use the exact calculated delay + :FULL: Full jitter; random delay between 0 and calculated delay + :HALF: Half jitter; random delay between 0.5x and 1.0x of the calculated delay + """ + + NONE = "NONE" + FULL = "FULL" + HALF = "HALF" + + def compute_jitter(self, delay) -> float: + match self: + case JitterStrategy.NONE: + return 0 + case JitterStrategy.HALF: + return random.random() * 0.5 + 0.5 # noqa: S311 + case _: # default is FULL + return random.random() * delay # noqa: S311 + + +# endregion Jitter + @dataclass class RetryDecision: @@ -34,11 +68,11 @@ def no_retry(cls) -> RetryDecision: @dataclass class RetryStrategyConfig: - max_attempts: int = sys.maxsize # "infinite", practically + max_attempts: int = 3 # "infinite", practically initial_delay_seconds: int = 5 max_delay_seconds: int = 300 # 5 minutes backoff_rate: Numeric = 2.0 - jitter_seconds: Numeric = 1.0 + jitter_strategy: JitterStrategy = field(default=JitterStrategy.FULL) retryable_errors: list[str | re.Pattern] = field( default_factory=lambda: [re.compile(r".*")] ) @@ -77,10 +111,9 @@ def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision: config.initial_delay_seconds * (config.backoff_rate ** (attempts_made - 1)), config.max_delay_seconds, ) - - # Add jitter (random not for cryptographic purposes, hence noqa) - jitter = (random.random() * 2 - 1) * config.jitter_seconds # noqa: S311 - final_delay = max(1, delay + jitter) + delay_with_jitter = delay + config.jitter_strategy.compute_jitter(delay) + delay_with_jitter = math.ceil(delay_with_jitter) + final_delay = max(1, delay_with_jitter) return RetryDecision.retry(round(final_delay)) @@ -93,18 +126,18 @@ class RetryPresets: @classmethod def none(cls) -> Callable[[Exception, int], RetryDecision]: """No retries.""" - return create_retry_strategy(RetryStrategyConfig(max_attempts=0)) + return create_retry_strategy(RetryStrategyConfig(max_attempts=1)) @classmethod def default(cls) -> Callable[[Exception, 
int], RetryDecision]: """Default retries, will be used automatically if retryConfig is missing""" return create_retry_strategy( RetryStrategyConfig( - max_attempts=sys.maxsize, + max_attempts=6, initial_delay_seconds=5, max_delay_seconds=60, backoff_rate=2, - jitter_seconds=1, + jitter_strategy=JitterStrategy.FULL, ) ) @@ -113,10 +146,7 @@ def transient(cls) -> Callable[[Exception, int], RetryDecision]: """Quick retries for transient errors""" return create_retry_strategy( RetryStrategyConfig( - max_attempts=3, - initial_delay_seconds=1, - backoff_rate=2, - jitter_seconds=0.5, + max_attempts=3, backoff_rate=2, jitter_strategy=JitterStrategy.HALF ) ) @@ -129,7 +159,6 @@ def resource_availability(cls) -> Callable[[Exception, int], RetryDecision]: initial_delay_seconds=5, max_delay_seconds=300, backoff_rate=2, - jitter_seconds=1, ) ) @@ -142,6 +171,6 @@ def critical(cls) -> Callable[[Exception, int], RetryDecision]: initial_delay_seconds=1, max_delay_seconds=60, backoff_rate=1.5, - jitter_seconds=0.3, + jitter_strategy=JitterStrategy.NONE, ) ) diff --git a/tests/operation/callback_test.py b/tests/operation/callback_test.py index 8fea68f..e9b7c00 100644 --- a/tests/operation/callback_test.py +++ b/tests/operation/callback_test.py @@ -5,7 +5,11 @@ import pytest -from aws_durable_execution_sdk_python.config import CallbackConfig +from aws_durable_execution_sdk_python.config import ( + CallbackConfig, + StepConfig, + WaitForCallbackConfig, +) from aws_durable_execution_sdk_python.exceptions import FatalError from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( @@ -22,6 +26,8 @@ create_callback_handler, wait_for_callback_handler, ) +from aws_durable_execution_sdk_python.retries import RetryDecision +from aws_durable_execution_sdk_python.serdes import SerDes from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState from aws_durable_execution_sdk_python.types 
import DurableContext, StepContext @@ -269,7 +275,7 @@ def test_wait_for_callback_handler_with_name_and_config(): mock_callback.result.return_value = "named_callback_result" mock_context.create_callback.return_value = mock_callback mock_submitter = Mock() - config = CallbackConfig() + config = WaitForCallbackConfig() result = wait_for_callback_handler( mock_context, mock_submitter, "test_callback", config @@ -291,7 +297,7 @@ def test_wait_for_callback_handler_submitter_called_with_callback_id(): mock_context.create_callback.return_value = mock_callback mock_submitter = Mock() - def capture_step_call(func, name): + def capture_step_call(func, name, config=None): # Execute the step callable to verify submitter is called correctly step_context = Mock(spec=StepContext) func(step_context) @@ -357,7 +363,7 @@ def test_wait_for_callback_handler_with_none_callback_id(): mock_context.create_callback.return_value = mock_callback mock_submitter = Mock() - def execute_step(func, name): + def execute_step(func, name, config=None): step_context = Mock(spec=StepContext) return func(step_context) @@ -378,7 +384,7 @@ def test_wait_for_callback_handler_with_empty_string_callback_id(): mock_context.create_callback.return_value = mock_callback mock_submitter = Mock() - def execute_step(func, name): + def execute_step(func, name, config=None): step_context = Mock(spec=StepContext) return func(step_context) @@ -426,7 +432,9 @@ def test_wait_for_callback_handler_with_unicode_names(): assert result == f"result_for_{name}" expected_name = f"{name} submitter" - mock_context.step.assert_called_once_with(func=ANY, name=expected_name) + mock_context.step.assert_called_once_with( + func=ANY, name=expected_name, config=None + ) mock_context.reset_mock() @@ -591,7 +599,7 @@ def failing_submitter(callback_id): msg = "Submitter failed" raise ValueError(msg) - def step_side_effect(func, name): + def step_side_effect(func, name, config=None): step_context = Mock(spec=StepContext) func(step_context) 
def test_wait_for_callback_handler_step_config_propagation():
    """Test wait_for_callback_handler passes retry_strategy and serdes to the step config.

    The retry strategy and serdes set on the WaitForCallbackConfig must show
    up on the StepConfig handed to ``context.step`` for the submitter step.
    """
    mock_context = Mock(spec=DurableContext)
    mock_callback = Mock()
    mock_callback.callback_id = "step_config_test"
    mock_callback.result.return_value = "step_config_result"
    mock_context.create_callback.return_value = mock_callback
    mock_submitter = Mock()

    def test_retry_strategy(exception, attempt):  # noqa: ARG001
        # Never invoked here (context.step is a Mock); only its identity is
        # compared below. Built with RetryDecision.retry, the factory the
        # retry code in this change is seen calling.
        # NOTE(review): a `retry_after_delay` factory is not visible on
        # RetryDecision in this change - confirm before relying on it.
        return RetryDecision.retry(1)

    mock_serdes = Mock(spec=SerDes)

    config = WaitForCallbackConfig(
        retry_strategy=test_retry_strategy, serdes=mock_serdes
    )

    result = wait_for_callback_handler(
        mock_context, mock_submitter, "step_config_test", config
    )

    assert result == "step_config_result"

    # Verify step was called exactly once and with the expected StepConfig.
    mock_context.step.assert_called_once()
    step_config = mock_context.step.call_args.kwargs["config"]

    assert isinstance(step_config, StepConfig)
    assert step_config.retry_strategy == test_retry_strategy
    assert step_config.serdes == mock_serdes
mock_context.create_callback.return_value = mock_callback - config = CallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60) + config = WaitForCallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60) callback_id = create_callback_handler( state=mock_state, operation_identifier=OperationIdentifier("lifecycle_callback", None), @@ -742,7 +785,7 @@ def mock_submitter(cb_id): assert cb_id == "lifecycle_cb123" return "submitted" - def execute_step(func, name): + def execute_step(func, name, config=None): step_context = Mock(spec=StepContext) return func(step_context) @@ -862,7 +905,7 @@ def complex_submitter(callback_id): msg = "Invalid callback ID" raise ValueError(msg) - def execute_step(func, name): + def execute_step(func, name, config): step_context = Mock(spec=StepContext) return func(step_context) @@ -942,7 +985,9 @@ def test_callback_name_variations(): assert result == f"result_for_{name}" expected_name = f"{name} submitter" if name else "submitter" - mock_context.step.assert_called_once_with(func=ANY, name=expected_name) + mock_context.step.assert_called_once_with( + func=ANY, name=expected_name, config=None + ) mock_context.reset_mock() diff --git a/tests/retries_test.py b/tests/retries_test.py new file mode 100644 index 0000000..3bc1edd --- /dev/null +++ b/tests/retries_test.py @@ -0,0 +1,414 @@ +"""Tests for retry strategies and jitter implementations.""" + +import re +from unittest.mock import patch + +import pytest + +from aws_durable_execution_sdk_python.retries import ( + JitterStrategy, + RetryDecision, + RetryPresets, + RetryStrategyConfig, + create_retry_strategy, +) + + +class TestJitterStrategy: + """Test jitter strategy implementations.""" + + def test_none_jitter_returns_zero(self): + """Test NONE jitter always returns 0.""" + strategy = JitterStrategy.NONE + assert strategy.compute_jitter(10) == 0 + assert strategy.compute_jitter(100) == 0 + + @patch("aws_durable_execution_sdk_python.retries.random.random") + def 
class TestRetryDecision:
    """Checks for the RetryDecision factory classmethods."""

    def test_retry_factory(self):
        """retry() produces a decision that retries after the given delay."""
        outcome = RetryDecision.retry(30)
        assert outcome.delay_seconds == 30
        assert outcome.should_retry is True

    def test_no_retry_factory(self):
        """no_retry() produces a decision that does not retry and has no delay."""
        outcome = RetryDecision.no_retry()
        assert outcome.delay_seconds == 0
        assert outcome.should_retry is False
class TestCreateRetryStrategy:
    """Behavior of the callables produced by create_retry_strategy."""

    def test_max_attempts_exceeded(self):
        """No retry once the attempt count reaches max_attempts."""
        decide = create_retry_strategy(RetryStrategyConfig(max_attempts=2))

        outcome = decide(Exception("test error"), 2)
        assert outcome.should_retry is False

    def test_retryable_error_message_string(self):
        """A plain-string pattern matches against the error message."""
        decide = create_retry_strategy(
            RetryStrategyConfig(retryable_errors=["timeout"])
        )

        outcome = decide(Exception("connection timeout"), 1)
        assert outcome.should_retry is True

    def test_retryable_error_message_regex(self):
        """A compiled regex pattern matches against the error message."""
        pattern = re.compile(r"timeout|error")
        decide = create_retry_strategy(
            RetryStrategyConfig(retryable_errors=[pattern])
        )

        outcome = decide(Exception("network timeout occurred"), 1)
        assert outcome.should_retry is True

    def test_retryable_error_type(self):
        """An error whose type is listed as retryable is retried."""
        decide = create_retry_strategy(
            RetryStrategyConfig(retryable_error_types=[ValueError])
        )

        outcome = decide(ValueError("invalid value"), 1)
        assert outcome.should_retry is True

    def test_non_retryable_error(self):
        """An error matching no configured pattern is not retried."""
        decide = create_retry_strategy(
            RetryStrategyConfig(retryable_errors=["timeout"])
        )

        outcome = decide(Exception("permission denied"), 1)
        assert outcome.should_retry is False

    @patch("aws_durable_execution_sdk_python.retries.random.random")
    def test_exponential_backoff_calculation(self, mock_random):
        """Delay grows by backoff_rate per attempt, with FULL jitter added."""
        mock_random.return_value = 0.5
        decide = create_retry_strategy(
            RetryStrategyConfig(
                initial_delay_seconds=2,
                backoff_rate=2.0,
                jitter_strategy=JitterStrategy.FULL,
            )
        )
        failure = Exception("test error")

        # First attempt: 2 * (2^0) = 2, jitter adds 1, total = 3
        assert decide(failure, 1).delay_seconds == 3

        # Second attempt: 2 * (2^1) = 4, jitter adds 2, total = 6
        assert decide(failure, 2).delay_seconds == 6

    def test_max_delay_cap(self):
        """The computed delay never exceeds max_delay_seconds."""
        decide = create_retry_strategy(
            RetryStrategyConfig(
                initial_delay_seconds=100,
                max_delay_seconds=50,
                backoff_rate=2.0,
                jitter_strategy=JitterStrategy.NONE,
            )
        )

        # Attempt 2 would be 200 seconds without the cap.
        outcome = decide(Exception("test error"), 2)
        assert outcome.delay_seconds == 50

    def test_minimum_delay_one_second(self):
        """The delay is floored at one second."""
        decide = create_retry_strategy(
            RetryStrategyConfig(
                initial_delay_seconds=0, jitter_strategy=JitterStrategy.NONE
            )
        )

        outcome = decide(Exception("test error"), 1)
        assert outcome.delay_seconds == 1

    def test_delay_ceiling_applied(self):
        """Fractional delays are rounded up to whole seconds."""
        random_path = "aws_durable_execution_sdk_python.retries.random.random"
        with patch(random_path, return_value=0.3):
            decide = create_retry_strategy(
                RetryStrategyConfig(
                    initial_delay_seconds=3, jitter_strategy=JitterStrategy.FULL
                )
            )

            outcome = decide(Exception("test error"), 1)
            # 3 + (0.3 * 3) = 3.9, ceil(3.9) = 4
            assert outcome.delay_seconds == 4
def test_none_preset(self): + """Test none preset allows no retries.""" + strategy = RetryPresets.none() + error = Exception("test error") + + decision = strategy(error, 1) + assert decision.should_retry is False + + def test_default_preset_config(self): + """Test default preset configuration.""" + strategy = RetryPresets.default() + error = Exception("test error") + + # Should retry within max attempts + decision = strategy(error, 1) + assert decision.should_retry is True + + # Should not retry after max attempts + decision = strategy(error, 6) + assert decision.should_retry is False + + def test_transient_preset_config(self): + """Test transient preset configuration.""" + strategy = RetryPresets.transient() + error = Exception("test error") + + # Should retry within max attempts + decision = strategy(error, 1) + assert decision.should_retry is True + + # Should not retry after max attempts + decision = strategy(error, 3) + assert decision.should_retry is False + + def test_resource_availability_preset(self): + """Test resource availability preset allows longer retries.""" + strategy = RetryPresets.resource_availability() + error = Exception("test error") + + # Should retry within max attempts + decision = strategy(error, 1) + assert decision.should_retry is True + + # Should not retry after max attempts + decision = strategy(error, 5) + assert decision.should_retry is False + + def test_critical_preset_config(self): + """Test critical preset allows many retries.""" + strategy = RetryPresets.critical() + error = Exception("test error") + + # Should retry within max attempts + decision = strategy(error, 5) + assert decision.should_retry is True + + # Should not retry after max attempts + decision = strategy(error, 10) + assert decision.should_retry is False + + @patch("aws_durable_execution_sdk_python.retries.random.random") + def test_critical_preset_no_jitter(self, mock_random): + """Test critical preset uses no jitter.""" + mock_random.return_value = 0.5 # 
Should be ignored + strategy = RetryPresets.critical() + error = Exception("test error") + + decision = strategy(error, 1) + # With no jitter: 1 * (1.5^0) = 1 + assert decision.delay_seconds == 1 + + +class TestJitterIntegration: + """Test jitter integration with retry strategies.""" + + @patch("aws_durable_execution_sdk_python.retries.random.random") + def test_full_jitter_integration(self, mock_random): + """Test full jitter integration in retry strategy.""" + mock_random.return_value = 0.8 + config = RetryStrategyConfig( + initial_delay_seconds=10, jitter_strategy=JitterStrategy.FULL + ) + strategy = create_retry_strategy(config) + + error = Exception("test error") + decision = strategy(error, 1) + # 10 + (0.8 * 10) = 18 + assert decision.delay_seconds == 18 + + @patch("aws_durable_execution_sdk_python.retries.random.random") + def test_half_jitter_integration(self, mock_random): + """Test half jitter integration in retry strategy.""" + mock_random.return_value = 0.6 + config = RetryStrategyConfig( + initial_delay_seconds=10, jitter_strategy=JitterStrategy.HALF + ) + strategy = create_retry_strategy(config) + + error = Exception("test error") + decision = strategy(error, 1) + # 10 + (0.6 * 0.5 + 0.5) = 10.8, ceil(10.8) = 11 + assert decision.delay_seconds == 11 + + @patch("aws_durable_execution_sdk_python.retries.random.random") + def test_half_jitter_integration_corrected(self, mock_random): + """Test half jitter with corrected understanding of implementation.""" + mock_random.return_value = 0.0 # Minimum jitter + config = RetryStrategyConfig( + initial_delay_seconds=10, jitter_strategy=JitterStrategy.HALF + ) + strategy = create_retry_strategy(config) + + error = Exception("test error") + decision = strategy(error, 1) + # 10 + 0.5 = 10.5, ceil(10.5) = 11 + assert decision.delay_seconds == 11 + + def test_none_jitter_integration(self): + """Test no jitter integration in retry strategy.""" + config = RetryStrategyConfig( + initial_delay_seconds=10, 
class TestEdgeCases:
    """Edge cases and unusual configurations for retry strategies."""

    def test_none_config(self):
        """Passing None as the config still yields a working strategy."""
        decide = create_retry_strategy(None)

        outcome = decide(Exception("test error"), 1)
        assert outcome.should_retry is True
        assert outcome.delay_seconds >= 1

    def test_zero_backoff_rate(self):
        """A backoff rate of zero still gives the initial delay on attempt 1."""
        decide = create_retry_strategy(
            RetryStrategyConfig(
                initial_delay_seconds=5,
                backoff_rate=0,
                jitter_strategy=JitterStrategy.NONE,
            )
        )

        outcome = decide(Exception("test error"), 1)
        # 5 * (0^0) = 5 * 1 = 5
        assert outcome.delay_seconds == 5

    def test_fractional_backoff_rate(self):
        """A backoff rate below one shrinks the delay on later attempts."""
        decide = create_retry_strategy(
            RetryStrategyConfig(
                initial_delay_seconds=8,
                backoff_rate=0.5,
                jitter_strategy=JitterStrategy.NONE,
            )
        )

        outcome = decide(Exception("test error"), 2)
        # 8 * (0.5^1) = 4
        assert outcome.delay_seconds == 4

    def test_empty_retryable_errors_list(self):
        """With no retryable patterns configured, nothing is retried."""
        decide = create_retry_strategy(RetryStrategyConfig(retryable_errors=[]))

        outcome = decide(Exception("test error"), 1)
        assert outcome.should_retry is False

    def test_multiple_error_patterns(self):
        """String and regex patterns can be mixed in one configuration."""
        decide = create_retry_strategy(
            RetryStrategyConfig(
                retryable_errors=["timeout", re.compile(r"network.*error")]
            )
        )

        # String pattern match
        by_string = decide(Exception("connection timeout"), 1)
        assert by_string.should_retry is True

        # Regex pattern match
        by_regex = decide(Exception("network connection error"), 1)
        assert by_regex.should_retry is True

    def test_mixed_error_types_and_patterns(self):
        """A matching error type is enough even when the message does not match."""
        decide = create_retry_strategy(
            RetryStrategyConfig(
                retryable_errors=["timeout"], retryable_error_types=[ValueError]
            )
        )

        # ValueError is retried even though its message lacks "timeout".
        outcome = decide(ValueError("some value error"), 1)
        assert outcome.should_retry is True