Skip to content

Commit 6991b07

Browse files
author
Rares Polenciuc
committed
pass retry config to the Step that wraps the submitter
1 parent 87f08b2 commit 6991b07

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

src/aws_durable_execution_sdk_python/operation/callback.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import TYPE_CHECKING, Any
66

7+
from aws_durable_execution_sdk_python.config import StepConfig
78
from aws_durable_execution_sdk_python.exceptions import FatalError
89
from aws_durable_execution_sdk_python.lambda_service import (
910
CallbackOptions,
@@ -16,7 +17,7 @@
1617
from aws_durable_execution_sdk_python.config import (
1718
CallbackConfig,
1819
WaitForCallbackConfig,
19-
)
20+
)
2021
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
2122
from aws_durable_execution_sdk_python.state import (
2223
CheckpointedResult,
@@ -97,6 +98,13 @@ def wait_for_callback_handler(
9798
def submitter_step(step_context): # noqa: ARG001
9899
return submitter(callback.callback_id)
99100

100-
context.step(func=submitter_step, name=f"{name_with_space}submitter")
101+
if config:
102+
step_config = StepConfig(
103+
retry_strategy=config.retry_strategy,
104+
serdes=config.serdes,
105+
)
106+
context.step(func=submitter_step, name=f"{name_with_space}submitter", config=step_config)
107+
else:
108+
context.step(func=submitter_step, name=f"{name_with_space}submitter")
101109

102110
return callback.result()

tests/operation/callback_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from aws_durable_execution_sdk_python.config import CallbackConfig
8+
from aws_durable_execution_sdk_python.config import CallbackConfig, WaitForCallbackConfig
99
from aws_durable_execution_sdk_python.exceptions import FatalError
1010
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1111
from aws_durable_execution_sdk_python.lambda_service import (
@@ -269,7 +269,7 @@ def test_wait_for_callback_handler_with_name_and_config():
269269
mock_callback.result.return_value = "named_callback_result"
270270
mock_context.create_callback.return_value = mock_callback
271271
mock_submitter = Mock()
272-
config = CallbackConfig()
272+
config = WaitForCallbackConfig()
273273

274274
result = wait_for_callback_handler(
275275
mock_context, mock_submitter, "test_callback", config
@@ -291,7 +291,7 @@ def test_wait_for_callback_handler_submitter_called_with_callback_id():
291291
mock_context.create_callback.return_value = mock_callback
292292
mock_submitter = Mock()
293293

294-
def capture_step_call(func, name):
294+
def capture_step_call(func, name, config=None):
295295
# Execute the step callable to verify submitter is called correctly
296296
step_context = Mock(spec=StepContext)
297297
func(step_context)
@@ -357,7 +357,7 @@ def test_wait_for_callback_handler_with_none_callback_id():
357357
mock_context.create_callback.return_value = mock_callback
358358
mock_submitter = Mock()
359359

360-
def execute_step(func, name):
360+
def execute_step(func, name, config=None):
361361
step_context = Mock(spec=StepContext)
362362
return func(step_context)
363363

@@ -378,7 +378,7 @@ def test_wait_for_callback_handler_with_empty_string_callback_id():
378378
mock_context.create_callback.return_value = mock_callback
379379
mock_submitter = Mock()
380380

381-
def execute_step(func, name):
381+
def execute_step(func, name, config=None):
382382
step_context = Mock(spec=StepContext)
383383
return func(step_context)
384384

@@ -591,7 +591,7 @@ def failing_submitter(callback_id):
591591
msg = "Submitter failed"
592592
raise ValueError(msg)
593593

594-
def step_side_effect(func, name):
594+
def step_side_effect(func, name, config=None):
595595
step_context = Mock(spec=StepContext)
596596
func(step_context)
597597

@@ -675,7 +675,7 @@ def test_wait_for_callback_handler_config_propagation():
675675
mock_context.create_callback.return_value = mock_callback
676676
mock_submitter = Mock()
677677

678-
config = CallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=30)
678+
config = WaitForCallbackConfig(timeout_seconds=120, heartbeat_timeout_seconds=30)
679679

680680
result = wait_for_callback_handler(
681681
mock_context, mock_submitter, "config_test", config
@@ -729,7 +729,7 @@ def test_callback_lifecycle_complete_flow():
729729
mock_callback.result.return_value = {"status": "completed", "data": "test_data"}
730730
mock_context.create_callback.return_value = mock_callback
731731

732-
config = CallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60)
732+
config = WaitForCallbackConfig(timeout_seconds=300, heartbeat_timeout_seconds=60)
733733
callback_id = create_callback_handler(
734734
state=mock_state,
735735
operation_identifier=OperationIdentifier("lifecycle_callback", None),
@@ -742,7 +742,7 @@ def mock_submitter(cb_id):
742742
assert cb_id == "lifecycle_cb123"
743743
return "submitted"
744744

745-
def execute_step(func, name):
745+
def execute_step(func, name, config=None):
746746
step_context = Mock(spec=StepContext)
747747
return func(step_context)
748748

0 commit comments

Comments
 (0)