Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/aws_durable_execution_sdk_python/operation/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import TYPE_CHECKING, Any

from aws_durable_execution_sdk_python.config import StepConfig
from aws_durable_execution_sdk_python.exceptions import FatalError
from aws_durable_execution_sdk_python.lambda_service import (
CallbackOptions,
Expand Down Expand Up @@ -97,6 +98,16 @@ def wait_for_callback_handler(
def submitter_step(step_context): # noqa: ARG001
return submitter(callback.callback_id)

context.step(func=submitter_step, name=f"{name_with_space}submitter")
step_config = (
StepConfig(
retry_strategy=config.retry_strategy,
serdes=config.serdes,
)
if config
else None
)
context.step(
func=submitter_step, name=f"{name_with_space}submitter", config=step_config
)

return callback.result()
61 changes: 45 additions & 16 deletions src/aws_durable_execution_sdk_python/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,51 @@

from __future__ import annotations

import math
import random
import re
import sys
from dataclasses import dataclass, field
from enum import StrEnum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable

Numeric = int | float

# region Jitter


class JitterStrategy(StrEnum):
"""
Jitter strategies are used to introduce noise when attempting to retry
an invoke. We introduce noise to prevent a thundering-herd effect where
a group of accesses (e.g. invokes) happen at once.

Jitter is meant to be used to spread operations across time.

members:
:NONE: No jitter; use the exact calculated delay
:FULL: Full jitter; random delay between 0 and calculated delay
:HALF: Half jitter; random delay between 0.5x and 1.0x of the calculated delay
"""

NONE = "NONE"
FULL = "FULL"
HALF = "HALF"

def compute_jitter(self, delay) -> float:
match self:
case JitterStrategy.NONE:
return 0
case JitterStrategy.HALF:
return random.random() * 0.5 + 0.5 # noqa: S311
case _: # default is FULL
return random.random() * delay # noqa: S311


# endregion Jitter


@dataclass
class RetryDecision:
Expand All @@ -34,11 +68,11 @@ def no_retry(cls) -> RetryDecision:

@dataclass
class RetryStrategyConfig:
max_attempts: int = sys.maxsize # "infinite", practically
max_attempts: int = 3  # maximum number of attempts
initial_delay_seconds: int = 5
max_delay_seconds: int = 300 # 5 minutes
backoff_rate: Numeric = 2.0
jitter_seconds: Numeric = 1.0
jitter_strategy: JitterStrategy = field(default=JitterStrategy.FULL)
retryable_errors: list[str | re.Pattern] = field(
default_factory=lambda: [re.compile(r".*")]
)
Expand Down Expand Up @@ -77,10 +111,9 @@ def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision:
config.initial_delay_seconds * (config.backoff_rate ** (attempts_made - 1)),
config.max_delay_seconds,
)

# Add jitter (random not for cryptographic purposes, hence noqa)
jitter = (random.random() * 2 - 1) * config.jitter_seconds # noqa: S311
final_delay = max(1, delay + jitter)
delay_with_jitter = delay + config.jitter_strategy.compute_jitter(delay)
delay_with_jitter = math.ceil(delay_with_jitter)
final_delay = max(1, delay_with_jitter)

return RetryDecision.retry(round(final_delay))

Expand All @@ -93,18 +126,18 @@ class RetryPresets:
@classmethod
def none(cls) -> Callable[[Exception, int], RetryDecision]:
"""No retries."""
return create_retry_strategy(RetryStrategyConfig(max_attempts=0))
return create_retry_strategy(RetryStrategyConfig(max_attempts=1))

@classmethod
def default(cls) -> Callable[[Exception, int], RetryDecision]:
"""Default retries, will be used automatically if retryConfig is missing"""
return create_retry_strategy(
RetryStrategyConfig(
max_attempts=sys.maxsize,
max_attempts=6,
initial_delay_seconds=5,
max_delay_seconds=60,
backoff_rate=2,
jitter_seconds=1,
jitter_strategy=JitterStrategy.FULL,
)
)

Expand All @@ -113,10 +146,7 @@ def transient(cls) -> Callable[[Exception, int], RetryDecision]:
"""Quick retries for transient errors"""
return create_retry_strategy(
RetryStrategyConfig(
max_attempts=3,
initial_delay_seconds=1,
backoff_rate=2,
jitter_seconds=0.5,
max_attempts=3, backoff_rate=2, jitter_strategy=JitterStrategy.HALF
)
)

Expand All @@ -129,7 +159,6 @@ def resource_availability(cls) -> Callable[[Exception, int], RetryDecision]:
initial_delay_seconds=5,
max_delay_seconds=300,
backoff_rate=2,
jitter_seconds=1,
)
)

Expand All @@ -142,6 +171,6 @@ def critical(cls) -> Callable[[Exception, int], RetryDecision]:
initial_delay_seconds=1,
max_delay_seconds=60,
backoff_rate=1.5,
jitter_seconds=0.3,
jitter_strategy=JitterStrategy.NONE,
)
)
69 changes: 57 additions & 12 deletions tests/operation/callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import pytest

from aws_durable_execution_sdk_python.config import CallbackConfig
from aws_durable_execution_sdk_python.config import (
CallbackConfig,
StepConfig,
WaitForCallbackConfig,
)
from aws_durable_execution_sdk_python.exceptions import FatalError
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.lambda_service import (
Expand All @@ -22,6 +26,8 @@
create_callback_handler,
wait_for_callback_handler,
)
from aws_durable_execution_sdk_python.retries import RetryDecision
from aws_durable_execution_sdk_python.serdes import SerDes
from aws_durable_execution_sdk_python.state import CheckpointedResult, ExecutionState
from aws_durable_execution_sdk_python.types import DurableContext, StepContext

Expand Down Expand Up @@ -269,7 +275,7 @@ def test_wait_for_callback_handler_with_name_and_config():
mock_callback.result.return_value = "named_callback_result"
mock_context.create_callback.return_value = mock_callback
mock_submitter = Mock()
config = CallbackConfig()
config = WaitForCallbackConfig()

result = wait_for_callback_handler(
mock_context, mock_submitter, "test_callback", config
Expand All @@ -291,7 +297,7 @@ def test_wait_for_callback_handler_submitter_called_with_callback_id():
mock_context.create_callback.return_value = mock_callback
mock_submitter = Mock()

def capture_step_call(func, name):
def capture_step_call(func, name, config=None):
# Execute the step callable to verify submitter is called correctly
step_context = Mock(spec=StepContext)
func(step_context)
Expand Down Expand Up @@ -357,7 +363,7 @@ def test_wait_for_callback_handler_with_none_callback_id():
mock_context.create_callback.return_value = mock_callback
mock_submitter = Mock()

def execute_step(func, name):
def execute_step(func, name, config=None):
step_context = Mock(spec=StepContext)
return func(step_context)

Expand All @@ -378,7 +384,7 @@ def test_wait_for_callback_handler_with_empty_string_callback_id():
mock_context.create_callback.return_value = mock_callback
mock_submitter = Mock()

def execute_step(func, name):
def execute_step(func, name, config=None):
step_context = Mock(spec=StepContext)
return func(step_context)

Expand Down Expand Up @@ -426,7 +432,9 @@ def test_wait_for_callback_handler_with_unicode_names():

assert result == f"result_for_{name}"
expected_name = f"{name} submitter"
mock_context.step.assert_called_once_with(func=ANY, name=expected_name)
mock_context.step.assert_called_once_with(
func=ANY, name=expected_name, config=None
)
mock_context.reset_mock()


Expand Down Expand Up @@ -591,7 +599,7 @@ def failing_submitter(callback_id):
msg = "Submitter failed"
raise ValueError(msg)

def step_side_effect(func, name):
def step_side_effect(func, name, config=None):
step_context = Mock(spec=StepContext)
func(step_context)

Expand Down Expand Up @@ -675,7 +683,7 @@ def test_wait_for_callback_handler_config_propagation():
mock_context.create_callback.return_value = mock_callback
mock_submitter = Mock()

config = CallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=30)
config = WaitForCallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=30)

result = wait_for_callback_handler(
mock_context, mock_submitter, "config_test", config
Expand All @@ -687,6 +695,41 @@ def test_wait_for_callback_handler_config_propagation():
)


def test_wait_for_callback_handler_step_config_propagation():
    """Test wait_for_callback_handler properly passes retry_strategy and serdes to step config.

    A WaitForCallbackConfig carrying a custom retry strategy and serdes is
    handed to the handler; the submitter step must then be invoked with a
    StepConfig that carries those same two objects.
    """
    mock_context = Mock(spec=DurableContext)
    mock_callback = Mock()
    mock_callback.callback_id = "step_config_test"
    mock_callback.result.return_value = "step_config_result"
    mock_context.create_callback.return_value = mock_callback
    mock_submitter = Mock()

    # Named so it does not look like a test; the test only checks that this
    # exact callable is forwarded. It uses RetryDecision.retry, the
    # constructor the retries module itself calls.
    def stub_retry_strategy(exception, attempt):  # noqa: ARG001
        return RetryDecision.retry(1)

    mock_serdes = Mock(spec=SerDes)

    config = WaitForCallbackConfig(
        retry_strategy=stub_retry_strategy, serdes=mock_serdes
    )

    result = wait_for_callback_handler(
        mock_context, mock_submitter, "step_config_test", config
    )

    assert result == "step_config_result"

    # Verify step was called with correct StepConfig
    mock_context.step.assert_called_once()
    step_config = mock_context.step.call_args.kwargs["config"]

    assert isinstance(step_config, StepConfig)
    assert step_config.retry_strategy == stub_retry_strategy
    assert step_config.serdes == mock_serdes


def test_wait_for_callback_handler_with_various_result_types():
"""Test wait_for_callback_handler with various result types."""
result_types = [None, True, False, 0, math.pi, "", "string", [], {"key": "value"}]
Expand Down Expand Up @@ -729,7 +772,7 @@ def test_callback_lifecycle_complete_flow():
mock_callback.result.return_value = {"status": "completed", "data": "test_data"}
mock_context.create_callback.return_value = mock_callback

config = CallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60)
config = WaitForCallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60)
callback_id = create_callback_handler(
state=mock_state,
operation_identifier=OperationIdentifier("lifecycle_callback", None),
Expand All @@ -742,7 +785,7 @@ def mock_submitter(cb_id):
assert cb_id == "lifecycle_cb123"
return "submitted"

def execute_step(func, name):
def execute_step(func, name, config=None):
step_context = Mock(spec=StepContext)
return func(step_context)

Expand Down Expand Up @@ -862,7 +905,7 @@ def complex_submitter(callback_id):
msg = "Invalid callback ID"
raise ValueError(msg)

def execute_step(func, name):
def execute_step(func, name, config=None):  # noqa: ARG001
    """Stand-in for ``context.step``: run *func* with a mocked StepContext.

    ``config`` defaults to ``None`` like the other step stubs in this file,
    so the stub accepts calls with or without a step config.
    """
    step_context = Mock(spec=StepContext)
    return func(step_context)

Expand Down Expand Up @@ -942,7 +985,9 @@ def test_callback_name_variations():

assert result == f"result_for_{name}"
expected_name = f"{name} submitter" if name else "submitter"
mock_context.step.assert_called_once_with(func=ANY, name=expected_name)
mock_context.step.assert_called_once_with(
func=ANY, name=expected_name, config=None
)
mock_context.reset_mock()


Expand Down
Loading