diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 3e10014..e351b7a 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -53,6 +53,7 @@ from collections.abc import Callable, Sequence from aws_durable_execution_sdk_python.state import CheckpointedResult + from aws_durable_execution_sdk_python.types import LambdaContext P = TypeVar("P") # Payload type R = TypeVar("R") # Result type @@ -149,7 +150,7 @@ class DurableContext(DurableContextProtocol): def __init__( self, state: ExecutionState, - lambda_context: Any | None = None, + lambda_context: LambdaContext | None = None, parent_id: str | None = None, logger: Logger | None = None, ) -> None: @@ -171,7 +172,7 @@ def __init__( @staticmethod def from_lambda_context( state: ExecutionState, - lambda_context: Any, + lambda_context: LambdaContext, ): return DurableContext( state=state, diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index a685233..95c086d 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -25,6 +25,8 @@ if TYPE_CHECKING: from collections.abc import Callable, MutableMapping + from aws_durable_execution_sdk_python.types import LambdaContext + logger = logging.getLogger(__name__) @@ -187,10 +189,10 @@ def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: def durable_handler( func: Callable[[Any, DurableContext], Any], -) -> Callable[[Any, Any], Any]: +) -> Callable[[Any, LambdaContext], Any]: logger.debug("Starting durable execution handler...") - def wrapper(event: Any, context: Any) -> MutableMapping[str, Any]: + def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: invocation_input: DurableExecutionInvocationInput service_client: DurableServiceClient diff --git a/src/aws_durable_execution_sdk_python/types.py b/src/aws_durable_execution_sdk_python/types.py index 94eb0ef..acc5525 100644 --- a/src/aws_durable_execution_sdk_python/types.py +++ b/src/aws_durable_execution_sdk_python/types.py @@ -135,3 +135,19 @@ def create_callback( ) -> Callback: """Create a callback.""" ... # pragma: no cover + + +class LambdaContext(Protocol): # pragma: no cover + aws_request_id: str + log_group_name: str | None = None + log_stream_name: str | None = None + function_name: str | None = None + memory_limit_in_mb: str | None = None + function_version: str | None = None + invoked_function_arn: str | None = None + tenant_id: str | None = None + client_context: Any | None = None + identity: Any | None = None + + def get_remaining_time_in_millis(self) -> int: ... + def log(self, msg) -> None: ...