Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
self.logger: Logger = logger or Logger.from_log_info(
logger=logging.getLogger(),
info=log_info,
execution_state=state,
)

# region factories
Expand Down Expand Up @@ -212,6 +213,8 @@ def set_logger(self, new_logger: LoggerInterface):
self.logger = Logger.from_log_info(
logger=new_logger,
info=self._log_info,
execution_state=self.state,
visited_operations=self.logger.visited_operations,
)

def _create_step_id(self) -> str:
Expand Down Expand Up @@ -248,6 +251,9 @@ def create_callback(
if not config:
config = CallbackConfig()
operation_id: str = self._create_step_id()
# Mark operation as visited before execution
self.logger.mark_operation_visited(operation_id)

callback_id: str = create_callback_handler(
state=self.state,
operation_identifier=OperationIdentifier(
Expand Down Expand Up @@ -281,12 +287,16 @@ def invoke(
Returns:
The result of the invoked function
"""
operation_id = self._create_step_id()
# Mark operation as visited before execution
self.logger.mark_operation_visited(operation_id)

return invoke_handler(
function_name=function_name,
payload=payload,
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=self._create_step_id(),
operation_id=operation_id,
parent_id=self._parent_id,
name=name,
),
Expand Down Expand Up @@ -361,6 +371,8 @@ def run_in_child_context(
step_name: str | None = self._resolve_step_name(name, func)
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._create_step_id()
# Mark operation as visited before execution
self.logger.mark_operation_visited(operation_id)

def callable_with_child_context():
return func(self.create_child_context(parent_id=operation_id))
Expand All @@ -383,12 +395,16 @@ def step(
step_name = self._resolve_step_name(name, func)
logger.debug("Step name: %s", step_name)

operation_id = self._create_step_id()
# Mark operation as visited before execution
self.logger.mark_operation_visited(operation_id)

return step_handler(
func=func,
config=config,
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=self._create_step_id(),
operation_id=operation_id,
parent_id=self._parent_id,
name=step_name,
),
Expand All @@ -405,11 +421,16 @@ def wait(self, seconds: int, name: str | None = None) -> None:
if seconds < 1:
msg = "seconds must be an integer greater than 0"
raise ValidationError(msg)

operation_id = self._create_step_id()
# Mark operation as visited before execution
self.logger.mark_operation_visited(operation_id)

wait_handler(
seconds=seconds,
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=self._create_step_id(),
operation_id=operation_id,
parent_id=self._parent_id,
name=name,
),
Expand Down Expand Up @@ -455,12 +476,16 @@ def wait_for_condition(
msg = "`config` is required for wait_for_condition"
raise ValidationError(msg)

operation_id = self._create_step_id()
# Mark operation as visited before execution
self.logger.mark_operation_visited(operation_id)

return wait_for_condition_handler(
check=check,
config=config,
state=self.state,
operation_identifier=OperationIdentifier(
operation_id=self._create_step_id(),
operation_id=operation_id,
parent_id=self._parent_id,
name=name,
),
Expand Down
93 changes: 90 additions & 3 deletions src/aws_durable_execution_sdk_python/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from aws_durable_execution_sdk_python.lambda_service import OperationType
from aws_durable_execution_sdk_python.types import LoggerInterface

if TYPE_CHECKING:
from collections.abc import Mapping, MutableMapping

from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.state import ExecutionState


@dataclass(frozen=True)
Expand Down Expand Up @@ -44,13 +46,25 @@ def with_parent_id(self, parent_id: str) -> LogInfo:

class Logger(LoggerInterface):
def __init__(
    self,
    logger: LoggerInterface,
    default_extra: Mapping[str, object],
    execution_state: ExecutionState | None = None,
    visited_operations: set[str] | None = None,
) -> None:
    """Wrap *logger*, merging *default_extra* into every record and
    tracking replay state.

    Args:
        logger: Underlying logger that receives the emitted records.
        default_extra: Mapping merged into each log call's ``extra``.
        execution_state: Durable execution state consulted by
            ``is_replay``; ``None`` disables replay detection.
        visited_operations: Set of operation ids already executed in
            this run. Pass an existing set so clones share tracking.
    """
    self._logger = logger
    self._default_extra = default_extra
    self._execution_state = execution_state
    # `is not None` rather than truthiness: a caller-supplied *empty*
    # set must still be shared, not replaced by a fresh private one —
    # clones made via with_log_info/set_logger rely on sharing the
    # same set object so visits recorded later are seen by all.
    self._visited_operations = (
        visited_operations if visited_operations is not None else set()
    )

@classmethod
def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
def from_log_info(
cls,
logger: LoggerInterface,
info: LogInfo,
execution_state: ExecutionState | None = None,
visited_operations: set[str] | None = None,
) -> Logger:
"""Create a new logger with the given LogInfo."""
extra: MutableMapping[str, object] = {"execution_arn": info.execution_arn}
if info.parent_id:
Expand All @@ -59,45 +73,118 @@ def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
extra["name"] = info.name
if info.attempt:
extra["attempt"] = info.attempt
return cls(logger, extra)
return cls(logger, extra, execution_state, visited_operations)

def with_log_info(self, info: LogInfo) -> Logger:
    """Clone this logger under *info*, carrying over the execution
    state and the shared visited-operations set."""
    return Logger.from_log_info(
        self._logger,
        info,
        self._execution_state,
        self._visited_operations,
    )

def get_logger(self) -> LoggerInterface:
    """Return the wrapped logger instance."""
    wrapped = self._logger
    return wrapped

def is_replay(self) -> bool:
    """Check if we are currently in replay mode.

    Returns True when the execution state still contains user
    operations that have not been visited in this run — i.e. we are
    replaying previously executed operations.

    NOTE(review): this heuristic can over-report replay — with
    optimized replay (e.g. a parallel completed via minSuccessful == 1)
    some recorded operations are never revisited, so unvisited entries
    may remain after replay ends; confirm against the replay design.
    """
    state = self._execution_state
    # Without state (or with an empty operation log) there is nothing
    # to replay.
    if not state or not state.operations:
        return False

    # EXECUTION entries are system bookkeeping, not user operations;
    # any other unvisited operation means we are replaying.
    return any(
        op_id not in self._visited_operations
        for op_id, op in state.operations.items()
        if op.operation_type != OperationType.EXECUTION
    )

def mark_operation_visited(self, operation_id: str) -> None:
    """Record *operation_id* as executed in this run.

    Called by context operations before execution so that is_replay()
    can distinguish replayed operations from new ones.
    """
    self._visited_operations.add(operation_id)

def _should_log(self) -> bool:
    """Gate for log emission: the default logger is silent during replay."""
    if self.is_replay():
        return False
    return True

def _emit(
    self,
    level: str,
    msg: object,
    args: tuple[object, ...],
    extra: Mapping[str, object] | None,
) -> None:
    """Shared emit path for all levels: apply the replay gate, merge
    the default extras with the per-call extras, then delegate to the
    same-named method on the wrapped logger."""
    if not self._should_log():
        return
    # Per-call extras win over defaults on key collision.
    merged_extra = {**self._default_extra, **(extra or {})}
    getattr(self._logger, level)(msg, *args, extra=merged_extra)

def debug(
    self, msg: object, *args: object, extra: Mapping[str, object] | None = None
) -> None:
    """Log *msg* at DEBUG level unless suppressed by replay."""
    self._emit("debug", msg, args, extra)

def info(
    self, msg: object, *args: object, extra: Mapping[str, object] | None = None
) -> None:
    """Log *msg* at INFO level unless suppressed by replay."""
    self._emit("info", msg, args, extra)

def warning(
    self, msg: object, *args: object, extra: Mapping[str, object] | None = None
) -> None:
    """Log *msg* at WARNING level unless suppressed by replay."""
    self._emit("warning", msg, args, extra)

def error(
    self, msg: object, *args: object, extra: Mapping[str, object] | None = None
) -> None:
    """Log *msg* at ERROR level unless suppressed by replay."""
    self._emit("error", msg, args, extra)

def exception(
    self, msg: object, *args: object, extra: Mapping[str, object] | None = None
) -> None:
    """Log *msg* with exception info unless suppressed by replay."""
    self._emit("exception", msg, args, extra)

@property
def visited_operations(self) -> set[str]:
    """The set of operation ids marked visited; shared with clones
    created via with_log_info/from_log_info."""
    return self._visited_operations


class ReplayAwareLogger(Logger):
    """A logger that provides custom replay behavior for advanced users.

    This logger allows users to customize logging behavior during replay
    by overriding the _should_log method. By default, it behaves the same
    as the base Logger (silent during replay).
    """

    def _should_log(self) -> bool:
        """Override this method to customize replay logging behavior.

        Returns:
            bool: True if logging should occur, False otherwise.

        Example:
            def _should_log(self) -> bool:
                # Always log, even during replay
                return True

        NOTE(review): Logger defines no per-level state (there is no
        `_current_log_level` attribute), so a level-dependent policy
        must track the level itself — e.g. set a flag in an overridden
        error() before delegating to super().
        """
        return super()._should_log()
Loading