diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py index 200aede..5b079fa 100644 --- a/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/src/aws_durable_execution_sdk_python/lambda_service.py @@ -1031,19 +1031,32 @@ def get_execution_state( class LambdaClient(DurableServiceClient): """Persist durable operations to the Lambda Durable Function APIs.""" + _cached_boto_client: Any = None + def __init__(self, client: Any) -> None: self.client = client - @staticmethod - def initialize_client() -> LambdaClient: - client = boto3.client( - "lambda", - config=Config( - connect_timeout=5, - read_timeout=50, - ), - ) - return LambdaClient(client=client) + @classmethod + def initialize_client(cls) -> LambdaClient: + """Initialize or return cached Lambda client. + + Implements lazy initialization with class-level caching to optimize + Lambda warm starts. The boto3 client is created once and reused across + invocations, avoiding repeated credential resolution and connection + pool setup. + + Returns: + LambdaClient: A new LambdaClient instance wrapping the cached boto3 client. + """ + if cls._cached_boto_client is None: + cls._cached_boto_client = boto3.client( + "lambda", + config=Config( + connect_timeout=5, + read_timeout=50, + ), + ) + return cls(client=cls._cached_boto_client) def checkpoint( self, diff --git a/tests/lambda_service_test.py b/tests/lambda_service_test.py index c812069..e8757cc 100644 --- a/tests/lambda_service_test.py +++ b/tests/lambda_service_test.py @@ -39,6 +39,19 @@ WaitOptions, ) +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def reset_lambda_client_cache(): + """Reset the class-level boto3 client cache before and after each test.""" + LambdaClient._cached_boto_client = None # noqa: SLF001 + yield + LambdaClient._cached_boto_client = None # noqa: SLF001 + + # ============================================================================= # Tests for Data Classes (ExecutionDetails, ContextDetails, ErrorObject, etc.) # ============================================================================= @@ -1910,7 +1923,9 @@ def test_lambda_client_constructor(): @patch.dict("os.environ", {}, clear=True) @patch("boto3.client") -def test_lambda_client_initialize_client_default(mock_boto_client): +def test_lambda_client_initialize_client_default( + mock_boto_client, reset_lambda_client_cache +): """Test LambdaClient.initialize_client with default endpoint.""" mock_client = Mock() mock_boto_client.return_value = mock_client @@ -1930,7 +1945,9 @@ def test_lambda_client_initialize_client_default(mock_boto_client): @patch.dict("os.environ", {"AWS_ENDPOINT_URL_LAMBDA": "http://localhost:3000"}) @patch("boto3.client") -def test_lambda_client_initialize_client_with_endpoint(mock_boto_client): +def test_lambda_client_initialize_client_with_endpoint( + mock_boto_client, reset_lambda_client_cache +): """Test LambdaClient.initialize_client with custom endpoint (boto3 handles it automatically).""" mock_client = Mock() mock_boto_client.return_value = mock_client @@ -2008,7 +2025,9 @@ def test_checkpoint_error_handling(): @patch.dict("os.environ", {}, clear=True) @patch("boto3.client") -def test_lambda_client_initialize_client_no_endpoint(mock_boto_client): +def test_lambda_client_initialize_client_no_endpoint( + mock_boto_client, reset_lambda_client_cache +): """Test LambdaClient.initialize_client without AWS_ENDPOINT_URL_LAMBDA.""" mock_client = Mock() mock_boto_client.return_value = mock_client @@ -2047,6 +2066,60 @@ def test_lambda_client_checkpoint_with_non_none_client_token(): # ============================================================================= +# Tests for LambdaClient caching behavior +# ============================================================================= + + +@patch("boto3.client") +def test_lambda_client_cache_reuses_client(mock_boto_client, reset_lambda_client_cache): + """Test that initialize_client reuses the same boto3 client on subsequent calls.""" + mock_client = Mock() + mock_boto_client.return_value = mock_client + + # First call should create the boto3 client + client1 = LambdaClient.initialize_client() + + # Second call should reuse the same boto3 client + client2 = LambdaClient.initialize_client() + + # boto3.client should only be called once + mock_boto_client.assert_called_once() + + # Both LambdaClient instances should wrap the same boto3 client + assert client1.client is client2.client + + +@patch("boto3.client") +def test_lambda_client_cache_creates_client_only_once( + mock_boto_client, reset_lambda_client_cache +): + """Test that boto3.client is called only once even with multiple initialize_client calls.""" + mock_client = Mock() + mock_boto_client.return_value = mock_client + + # Call initialize_client multiple times + for _ in range(5): + LambdaClient.initialize_client() + + # boto3.client should only be called once + assert mock_boto_client.call_count == 1 + + +@patch("boto3.client") +def test_lambda_client_cache_is_class_level( + mock_boto_client, reset_lambda_client_cache +): + """Test that the boto3 client cache is stored at class level.""" + mock_client = Mock() + mock_boto_client.return_value = mock_client + + # Create client + LambdaClient.initialize_client() + + # Verify the boto3 client is cached at class level + assert LambdaClient._cached_boto_client is mock_client # noqa: SLF001 + + # Tests for Operation JSON Serialization Methods # =============================================================================