diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py
index e790665..f3f45a3 100644
--- a/src/aws_durable_execution_sdk_python/context.py
+++ b/src/aws_durable_execution_sdk_python/context.py
@@ -167,6 +167,7 @@ def __init__(
         self.logger: Logger = logger or Logger.from_log_info(
             logger=logging.getLogger(),
             info=log_info,
+            execution_state=state,
         )
 
     # region factories
@@ -212,6 +213,8 @@ def set_logger(self, new_logger: LoggerInterface):
         self.logger = Logger.from_log_info(
             logger=new_logger,
             info=self._log_info,
+            execution_state=self.state,
+            visited_operations=self.logger.visited_operations,
         )
 
     def _create_step_id(self) -> str:
@@ -248,6 +251,9 @@ def create_callback(
         if not config:
             config = CallbackConfig()
         operation_id: str = self._create_step_id()
+        # Mark operation as visited before execution
+        self.logger.mark_operation_visited(operation_id)
+
         callback_id: str = create_callback_handler(
             state=self.state,
             operation_identifier=OperationIdentifier(
@@ -281,12 +287,16 @@ def invoke(
         Returns:
             The result of the invoked function
         """
+        operation_id = self._create_step_id()
+        # Mark operation as visited before execution
+        self.logger.mark_operation_visited(operation_id)
+
         return invoke_handler(
             function_name=function_name,
             payload=payload,
             state=self.state,
             operation_identifier=OperationIdentifier(
-                operation_id=self._create_step_id(),
+                operation_id=operation_id,
                 parent_id=self._parent_id,
                 name=name,
             ),
@@ -361,6 +371,8 @@ def run_in_child_context(
         step_name: str | None = self._resolve_step_name(name, func)
         # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
         operation_id = self._create_step_id()
+        # Mark operation as visited before execution
+        self.logger.mark_operation_visited(operation_id)
 
         def callable_with_child_context():
             return func(self.create_child_context(parent_id=operation_id))
@@ -383,12 +395,16 @@ def step(
         step_name = self._resolve_step_name(name, func)
         logger.debug("Step name: %s", step_name)
 
+        operation_id = self._create_step_id()
+        # Mark operation as visited before execution
+        self.logger.mark_operation_visited(operation_id)
+
         return step_handler(
             func=func,
             config=config,
             state=self.state,
             operation_identifier=OperationIdentifier(
-                operation_id=self._create_step_id(),
+                operation_id=operation_id,
                 parent_id=self._parent_id,
                 name=step_name,
             ),
@@ -405,11 +421,16 @@ def wait(self, seconds: int, name: str | None = None) -> None:
         if seconds < 1:
             msg = "seconds must be an integer greater than 0"
             raise ValidationError(msg)
+
+        operation_id = self._create_step_id()
+        # Mark operation as visited before execution
+        self.logger.mark_operation_visited(operation_id)
+
         wait_handler(
             seconds=seconds,
             state=self.state,
             operation_identifier=OperationIdentifier(
-                operation_id=self._create_step_id(),
+                operation_id=operation_id,
                 parent_id=self._parent_id,
                 name=name,
             ),
@@ -455,12 +476,16 @@ def wait_for_condition(
             msg = "`config` is required for wait_for_condition"
             raise ValidationError(msg)
 
+        operation_id = self._create_step_id()
+        # Mark operation as visited before execution
+        self.logger.mark_operation_visited(operation_id)
+
         return wait_for_condition_handler(
             check=check,
             config=config,
             state=self.state,
             operation_identifier=OperationIdentifier(
-                operation_id=self._create_step_id(),
+                operation_id=operation_id,
                 parent_id=self._parent_id,
                 name=name,
             ),
diff --git a/src/aws_durable_execution_sdk_python/logger.py b/src/aws_durable_execution_sdk_python/logger.py
index f68b9b8..7ef1b09 100644
--- a/src/aws_durable_execution_sdk_python/logger.py
+++ b/src/aws_durable_execution_sdk_python/logger.py
@@ -5,12 +5,14 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
+from aws_durable_execution_sdk_python.lambda_service import OperationType
 from aws_durable_execution_sdk_python.types import LoggerInterface
 
 if TYPE_CHECKING:
     from collections.abc import Mapping, MutableMapping
 
     from aws_durable_execution_sdk_python.identifier import OperationIdentifier
+    from aws_durable_execution_sdk_python.state import ExecutionState
 
 
 @dataclass(frozen=True)
@@ -44,13 +46,29 @@ def with_parent_id(self, parent_id: str) -> LogInfo:
 
 class Logger(LoggerInterface):
     def __init__(
-        self, logger: LoggerInterface, default_extra: Mapping[str, object]
+        self,
+        logger: LoggerInterface,
+        default_extra: Mapping[str, object],
+        execution_state: ExecutionState | None = None,
+        visited_operations: set[str] | None = None,
     ) -> None:
         self._logger = logger
         self._default_extra = default_extra
+        self._execution_state = execution_state
+        # Compare against None, not truthiness: an empty set handed over by
+        # with_log_info() must remain the same shared object, not a fresh copy.
+        self._visited_operations: set[str] = (
+            set() if visited_operations is None else visited_operations
+        )
 
     @classmethod
-    def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
+    def from_log_info(
+        cls,
+        logger: LoggerInterface,
+        info: LogInfo,
+        execution_state: ExecutionState | None = None,
+        visited_operations: set[str] | None = None,
+    ) -> Logger:
         """Create a new logger with the given LogInfo."""
         extra: MutableMapping[str, object] = {"execution_arn": info.execution_arn}
         if info.parent_id:
@@ -59,45 +77,119 @@ def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
             extra["name"] = info.name
         if info.attempt:
             extra["attempt"] = info.attempt
-        return cls(logger, extra)
+        return cls(logger, extra, execution_state, visited_operations)
 
     def with_log_info(self, info: LogInfo) -> Logger:
         """Clone the existing logger with new LogInfo."""
         return Logger.from_log_info(
             logger=self._logger,
             info=info,
+            execution_state=self._execution_state,
+            visited_operations=self._visited_operations,
         )
 
     def get_logger(self) -> LoggerInterface:
         """Get the underlying logger."""
         return self._logger
 
+    def is_replay(self) -> bool:
+        """Check if we are currently in replay mode.
+
+        Returns True if there are operations in the execution state that haven't been visited yet.
+        This indicates we are replaying previously executed operations.
+        """
+        if not self._execution_state:
+            return False
+
+        # If there are no operations, we're not in replay
+        if not self._execution_state.operations:
+            return False
+
+        # Check if there are any operations in the execution state that we haven't visited
+        # Only consider operations that are not EXECUTION type (which are system operations)
+        for operation_id, operation in self._execution_state.operations.items():
+            # Skip EXECUTION operations as they are system operations, not user operations
+            if operation.operation_type == OperationType.EXECUTION:
+                continue
+            if operation_id not in self._visited_operations:
+                return True
+        return False
+
+    def mark_operation_visited(self, operation_id: str) -> None:
+        """Mark an operation as visited."""
+        self._visited_operations.add(operation_id)
+
+    def _should_log(self) -> bool:
+        """Determine if logging should occur based on replay state."""
+        # For the default logger, only log when not in replay
+        return not self.is_replay()
+
     def debug(
         self, msg: object, *args: object, extra: Mapping[str, object] | None = None
     ) -> None:
+        if not self._should_log():
+            return
         merged_extra = {**self._default_extra, **(extra or {})}
         self._logger.debug(msg, *args, extra=merged_extra)
 
     def info(
         self, msg: object, *args: object, extra: Mapping[str, object] | None = None
     ) -> None:
+        if not self._should_log():
+            return
         merged_extra = {**self._default_extra, **(extra or {})}
         self._logger.info(msg, *args, extra=merged_extra)
 
     def warning(
         self, msg: object, *args: object, extra: Mapping[str, object] | None = None
     ) -> None:
+        if not self._should_log():
+            return
         merged_extra = {**self._default_extra, **(extra or {})}
         self._logger.warning(msg, *args, extra=merged_extra)
 
     def error(
         self, msg: object, *args: object, extra: Mapping[str, object] | None = None
     ) -> None:
+        if not self._should_log():
+            return
         merged_extra = {**self._default_extra, **(extra or {})}
         self._logger.error(msg, *args, extra=merged_extra)
 
     def exception(
         self, msg: object, *args: object, extra: Mapping[str, object] | None = None
     ) -> None:
+        if not self._should_log():
+            return
         merged_extra = {**self._default_extra, **(extra or {})}
         self._logger.exception(msg, *args, extra=merged_extra)
+
+    @property
+    def visited_operations(self) -> set[str]:
+        """Return the live set of operation ids marked as visited."""
+        return self._visited_operations
+
+
+class ReplayAwareLogger(Logger):
+    """A logger that provides custom replay behavior for advanced users.
+
+    This logger allows users to customize logging behavior during replay by overriding
+    the _should_log method. By default, it behaves the same as the base Logger.
+    """
+
+    def _should_log(self) -> bool:
+        """Override this method to customize replay logging behavior.
+
+        Returns:
+            bool: True if logging should occur, False otherwise.
+
+        Example:
+            def _should_log(self) -> bool:
+                # Always log, even during replay
+                return True
+
+            def _should_log(self) -> bool:
+                # Errors only during replay; the subclass must track _current_log_level
+                return not self.is_replay() or self._current_log_level == 'error'
+        """
+        return super()._should_log()
diff --git a/tests/logger_test.py b/tests/logger_test.py
index d3b76aa..dc2f3ae 100644
--- a/tests/logger_test.py
+++ b/tests/logger_test.py
@@ -4,7 +4,14 @@
 from unittest.mock import Mock
 
 from aws_durable_execution_sdk_python.identifier import OperationIdentifier
-from aws_durable_execution_sdk_python.logger import Logger, LoggerInterface, LogInfo
+from aws_durable_execution_sdk_python.lambda_service import Operation, OperationType
+from aws_durable_execution_sdk_python.logger import (
+    Logger,
+    LoggerInterface,
+    LogInfo,
+    ReplayAwareLogger,
+)
+from aws_durable_execution_sdk_python.state import ExecutionState
 
 
 class PowertoolsLoggerStub:
@@ -325,3 +332,220 @@ def test_logger_extra_override():
         "new_field": "value",
     }
     mock_logger.info.assert_called_once_with("test", extra=expected_extra)
+
+
+def test_replay_detection_with_operations():
+    """Test replay detection when operations exist but not visited."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {
+        "step1": Mock(spec=Operation, operation_type=OperationType.STEP),
+        "step2": Mock(spec=Operation, operation_type=OperationType.STEP),
+    }
+
+    log_info = LogInfo("arn:aws:test")
+    logger = Logger.from_log_info(Mock(), log_info, execution_state=mock_state)
+
+    # Should be in replay mode when operations exist but not visited
+    assert logger.is_replay() is True
+
+
+def test_replay_detection_partial_visited():
+    """Test replay detection when some operations are visited."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {
+        "step1": Mock(spec=Operation, operation_type=OperationType.STEP),
+        "step2": Mock(spec=Operation, operation_type=OperationType.STEP),
+    }
+
+    log_info = LogInfo("arn:aws:test")
+    logger = Logger.from_log_info(Mock(), log_info, execution_state=mock_state)
+
+    # Mark one operation as visited
+    logger.mark_operation_visited("step1")
+
+    # Should still be in replay mode because step2 not visited
+    assert logger.is_replay() is True
+
+
+def test_replay_detection_all_visited():
+    """Test replay detection when all operations are visited."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {
+        "step1": Mock(spec=Operation, operation_type=OperationType.STEP),
+        "step2": Mock(spec=Operation, operation_type=OperationType.STEP),
+    }
+
+    log_info = LogInfo("arn:aws:test")
+    logger = Logger.from_log_info(Mock(), log_info, execution_state=mock_state)
+
+    # Mark all operations as visited
+    logger.mark_operation_visited("step1")
+    logger.mark_operation_visited("step2")
+
+    # Should not be in replay mode
+    assert logger.is_replay() is False
+
+
+def test_replay_detection_ignores_execution_operations():
+    """Test that EXECUTION operations don't trigger replay mode."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {
+        "execution-1": Mock(spec=Operation, operation_type=OperationType.EXECUTION),
+    }
+
+    log_info = LogInfo("arn:aws:test")
+    logger = Logger.from_log_info(Mock(), log_info, execution_state=mock_state)
+
+    # Should NOT be in replay mode because EXECUTION operations are ignored
+    assert logger.is_replay() is False
+
+
+def test_replay_detection_no_execution_state():
+    """Test replay detection when no execution state is provided."""
+    log_info = LogInfo("arn:aws:test")
+    logger = Logger.from_log_info(Mock(), log_info)
+
+    # Should not be in replay mode when no execution state
+    assert logger.is_replay() is False
+
+
+def test_replay_detection_empty_operations():
+    """Test replay detection when execution state has no operations."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {}
+
+    log_info = LogInfo("arn:aws:test")
+    logger = Logger.from_log_info(Mock(), log_info, execution_state=mock_state)
+
+    # Should not be in replay mode when no operations
+    assert logger.is_replay() is False
+
+
+def test_logging_suppressed_during_replay():
+    """Test that logging is suppressed during replay."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {
+        "step1": Mock(spec=Operation, operation_type=OperationType.STEP),
+    }
+
+    mock_underlying_logger = Mock()
+    log_info = LogInfo("arn:aws:test")
+    logger = Logger.from_log_info(
+        mock_underlying_logger, log_info, execution_state=mock_state
+    )
+
+    # During replay, logging should be suppressed
+    logger.info("This should not be logged during replay")
+    logger.debug("This should not be logged during replay")
+    logger.warning("This should not be logged during replay")
+    logger.error("This should not be logged during replay")
+    logger.exception("This should not be logged during replay")
+
+    mock_underlying_logger.info.assert_not_called()
+    mock_underlying_logger.debug.assert_not_called()
+    mock_underlying_logger.warning.assert_not_called()
+    mock_underlying_logger.error.assert_not_called()
+    mock_underlying_logger.exception.assert_not_called()
+
+
+def test_logging_works_after_replay():
+    """Test that logging works after replay ends."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {
+        "step1": Mock(spec=Operation, operation_type=OperationType.STEP),
+    }
+
+    mock_underlying_logger = Mock()
+    log_info = LogInfo("arn:aws:test")
+    logger = Logger.from_log_info(
+        mock_underlying_logger, log_info, execution_state=mock_state
+    )
+
+    # Mark operation as visited (end replay)
+    logger.mark_operation_visited("step1")
+
+    # Now logging should work
+    logger.info("This should be logged after replay")
+    mock_underlying_logger.info.assert_called_once()
+
+
+def test_mark_operation_visited():
+    """Test marking operations as visited."""
+    log_info = LogInfo("arn:aws:test")
+    logger = Logger.from_log_info(Mock(), log_info)
+
+    # Initially no operations visited
+    assert "op1" not in logger.visited_operations
+
+    # Mark operation as visited
+    logger.mark_operation_visited("op1")
+    assert "op1" in logger.visited_operations
+
+    # Mark another operation
+    logger.mark_operation_visited("op2")
+    assert "op2" in logger.visited_operations
+    assert "op1" in logger.visited_operations
+
+
+def test_replay_aware_logger():
+    """Test ReplayAwareLogger custom behavior."""
+
+    class AlwaysLogLogger(ReplayAwareLogger):
+        def _should_log(self) -> bool:
+            return True  # Always log, even during replay
+
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {
+        "step1": Mock(spec=Operation, operation_type=OperationType.STEP),
+    }
+
+    mock_underlying_logger = Mock()
+    log_info = LogInfo("arn:aws:test")
+
+    custom_logger = AlwaysLogLogger.from_log_info(
+        mock_underlying_logger, log_info, execution_state=mock_state
+    )
+
+    # Should be in replay mode
+    assert custom_logger.is_replay() is True
+
+    # But should still log because of custom _should_log implementation
+    custom_logger.info("This should log even during replay")
+    mock_underlying_logger.info.assert_called_once()
+
+
+def test_with_log_info_preserves_execution_state():
+    """Test that with_log_info preserves execution state and visited operations."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {
+        "step1": Mock(spec=Operation, operation_type=OperationType.STEP),
+    }
+
+    original_info = LogInfo("arn:aws:test", "parent1")
+    logger = Logger.from_log_info(Mock(), original_info, execution_state=mock_state)
+    logger.mark_operation_visited("step1")
+
+    new_info = LogInfo("arn:aws:new", "parent2", "new_name")
+    new_logger = logger.with_log_info(new_info)
+
+    # Should preserve execution state and visited operations
+    assert new_logger._execution_state is mock_state  # noqa: SLF001
+    assert "step1" in new_logger.visited_operations
+    assert new_logger.is_replay() is False  # Because step1 is visited
+
+
+def test_with_log_info_shares_empty_visited_operations():
+    """Test that a clone made before any visit shares the visited set."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.operations = {
+        "step1": Mock(spec=Operation, operation_type=OperationType.STEP),
+    }
+
+    logger = Logger.from_log_info(
+        Mock(), LogInfo("arn:aws:test"), execution_state=mock_state
+    )
+    cloned_logger = logger.with_log_info(LogInfo("arn:aws:test", "parent1"))
+
+    # Visits recorded on the original must be visible to the clone
+    logger.mark_operation_visited("step1")
+    assert cloned_logger.is_replay() is False