From a104f99e3b81b9b06006b31e4fd8af5010d34b35 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 13 Jan 2026 23:58:06 +0100 Subject: [PATCH] refactor: adding proper type annotations --- pyproject.toml | 2 +- .../execution.py | 4 +- .../lambda_service.py | 52 +++++++++++-------- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d80d37c..08ef201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = ["coverage[toml]", "pytest", "pytest-cov"] cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/aws_durable_execution_sdk_python --cov-fail-under=98" [tool.hatch.envs.types] -extra-dependencies = ["mypy>=1.0.0", "pytest"] +extra-dependencies = ["mypy>=1.0.0", "pytest", "boto3-stubs[lambda]"] [tool.hatch.envs.types.scripts] check = "mypy --install-types --non-interactive {args:src/aws_durable_execution_sdk_python tests}" diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index 0ce1ac3..fc3058f 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from collections.abc import Callable, MutableMapping - import boto3 # type: ignore + from mypy_boto3_lambda import LambdaClient as Boto3LambdaClient from aws_durable_execution_sdk_python.types import LambdaContext @@ -237,7 +237,7 @@ def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: def durable_execution( func: Callable[[Any, DurableContext], Any] | None = None, *, - boto3_client: boto3.client | None = None, + boto3_client: Boto3LambdaClient | None = None, ) -> Callable[[Any, LambdaContext], Any]: # Decorator called with parameters if func is None: diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py index 5b079fa..b11f950 100644 --- a/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/src/aws_durable_execution_sdk_python/lambda_service.py @@ -3,12 +3,13 @@ import copy import datetime import logging +from collections.abc import MutableMapping from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Protocol, TypeAlias +from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, cast -import boto3 # type: ignore -from botocore.config import Config # type: ignore +import boto3 +from botocore.config import Config from aws_durable_execution_sdk_python.exceptions import ( CallableRuntimeError, @@ -17,7 +18,11 @@ ) if TYPE_CHECKING: - from collections.abc import MutableMapping + from mypy_boto3_lambda import LambdaClient as Boto3LambdaClient + from mypy_boto3_lambda.type_defs import ( + CheckpointDurableExecutionResponseTypeDef, + GetDurableExecutionStateResponseTypeDef, + ) from aws_durable_execution_sdk_python.identifier import OperationIdentifier @@ -1031,9 +1036,9 @@ def get_execution_state( class LambdaClient(DurableServiceClient): """Persist durable operations to the Lambda Durable Function APIs.""" - _cached_boto_client: Any = None + _cached_boto_client: Boto3LambdaClient | None = None - def __init__(self, client: Any) -> None: + def __init__(self, client: Boto3LambdaClient) -> None: self.client = client @classmethod @@ -1066,19 +1071,20 @@ def checkpoint( client_token: str | None, ) -> CheckpointOutput: try: - params = { - "DurableExecutionArn": durable_execution_arn, - "CheckpointToken": checkpoint_token, - "Updates": [o.to_dict() for o in updates], - } + optional_params: dict[str, str] = {} if client_token is not None: - params["ClientToken"] = client_token - - result: MutableMapping[str, Any] = self.client.checkpoint_durable_execution( - **params + optional_params["ClientToken"] = client_token + + result: CheckpointDurableExecutionResponseTypeDef = ( + self.client.checkpoint_durable_execution( + DurableExecutionArn=durable_execution_arn, + CheckpointToken=checkpoint_token, + Updates=cast(Any, [o.to_dict() for o in updates]), + **optional_params, # type: ignore[arg-type] + ) ) - return CheckpointOutput.from_dict(result) + return CheckpointOutput.from_dict(cast(MutableMapping[str, Any], result)) except Exception as e: checkpoint_error = CheckpointError.from_exception(e) logger.exception( @@ -1094,13 +1100,15 @@ def get_execution_state( max_items: int = 1000, ) -> StateOutput: try: - result: MutableMapping[str, Any] = self.client.get_durable_execution_state( - DurableExecutionArn=durable_execution_arn, - CheckpointToken=checkpoint_token, - Marker=next_marker, - MaxItems=max_items, + result: GetDurableExecutionStateResponseTypeDef = ( + self.client.get_durable_execution_state( + DurableExecutionArn=durable_execution_arn, + CheckpointToken=checkpoint_token, + Marker=next_marker, + MaxItems=max_items, + ) ) - return StateOutput.from_dict(result) + return StateOutput.from_dict(cast(MutableMapping[str, Any], result)) except Exception as e: error = GetExecutionStateError.from_exception(e) logger.exception(