diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index f2b954c..06dc99c 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -20,10 +20,6 @@ ValidationError, ) from aws_durable_execution_sdk_python.identifier import OperationIdentifier -from aws_durable_execution_sdk_python.lambda_context import ( - LambdaContext, - make_dict_from_obj, -) from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.logger import Logger, LogInfo from aws_durable_execution_sdk_python.operation.callback import ( @@ -149,32 +145,16 @@ def result(self) -> T | None: raise FatalError(msg) -# It really would be great NOT to have to inherit from the LambdaContext. -# lot of noise here that we're not actually using. Alternative is to include -# via composition rather than inheritance -class DurableContext(LambdaContext, DurableContextProtocol): +class DurableContext(DurableContextProtocol): def __init__( self, state: ExecutionState, + lambda_context: Any | None = None, parent_id: str | None = None, logger: Logger | None = None, - # LambdaContext members follow - invoke_id=None, - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=0, - invoked_function_arn=None, - tenant_id=None, ) -> None: - super().__init__( - invoke_id=invoke_id, - client_context=client_context, - cognito_identity=cognito_identity, - epoch_deadline_time_in_ms=epoch_deadline_time_in_ms, - invoked_function_arn=invoked_function_arn, - tenant_id=tenant_id, - ) self.state: ExecutionState = state + self.lambda_context = lambda_context self._parent_id: str | None = parent_id self._step_counter: OrderedCounter = OrderedCounter() @@ -195,18 +175,12 @@ def __init__( @staticmethod def from_lambda_context( state: ExecutionState, - lambda_context: LambdaContext, + lambda_context: Any, ): return DurableContext( state=state, + lambda_context=lambda_context, parent_id=None, - invoke_id=lambda_context.aws_request_id, - client_context=make_dict_from_obj(lambda_context.client_context), - cognito_identity=make_dict_from_obj(lambda_context.identity), - # not great to have to use the private-ish accessor here, but for the moment not messing with LambdaContext signature - epoch_deadline_time_in_ms=lambda_context._epoch_deadline_time_in_ms, # noqa: SLF001 - invoked_function_arn=lambda_context.invoked_function_arn, - tenant_id=lambda_context.tenant_id, ) def create_child_context(self, parent_id: str) -> DurableContext: @@ -214,18 +188,13 @@ def create_child_context(self, parent_id: str) -> DurableContext: logger.debug("Creating child context for parent %s", parent_id) return DurableContext( state=self.state, + lambda_context=self.lambda_context, parent_id=parent_id, logger=self.logger.with_log_info( LogInfo( execution_arn=self.state.durable_execution_arn, parent_id=parent_id ) ), - invoke_id=self.aws_request_id, - client_context=make_dict_from_obj(self.client_context), - cognito_identity=make_dict_from_obj(self.identity), - epoch_deadline_time_in_ms=self._epoch_deadline_time_in_ms, - invoked_function_arn=self.invoked_function_arn, - tenant_id=self.tenant_id, ) # endregion factories diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index 1634d20..a685233 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -25,7 +25,6 @@ if TYPE_CHECKING: from collections.abc import Callable, MutableMapping - from aws_durable_execution_sdk_python.lambda_context import LambdaContext logger = logging.getLogger(__name__) @@ -188,10 +187,10 @@ def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: def durable_handler( func: Callable[[Any, DurableContext], Any], -) -> Callable[[Any, LambdaContext], Any]: +) -> Callable[[Any, Any], Any]: logger.debug("Starting durable execution handler...") - def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: + def wrapper(event: Any, context: Any) -> MutableMapping[str, Any]: invocation_input: DurableExecutionInvocationInput service_client: DurableServiceClient diff --git a/src/aws_durable_execution_sdk_python/lambda_context.py b/src/aws_durable_execution_sdk_python/lambda_context.py deleted file mode 100644 index 68dd1cc..0000000 --- a/src/aws_durable_execution_sdk_python/lambda_context.py +++ /dev/null @@ -1,188 +0,0 @@ -# mypy: ignore-errors -"""Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -The orignal actually lives here: -https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_context.py - -On a quick look it's missing tenant_id and the Python 3.13 upgrades. - -The 3.1.1 wheel is ~269.1 kB. Which honeslly, the entire dependency for the sake of this little class? - -For what it's worth, PowerTools also doesn't re-use the actual Python RIC LambdaContext, it also defines its -own copied type here: -https://github.com/aws-powertools/powertools-lambda-python/blob/6e900c79fff44675fcef3a71a0e3310c54f01ecd/aws_lambda_powertools/utilities/typing/lambda_context.py - -For the moment I'm going to use this copied class, since all it's really doing is providing a base class for DurableContext - -given duck-typing it doesn't actually have to inherit from the "same" class in the RIC. -Yes, this can get out of date with the Python RIC, but at worst it just means red squiggly lines on new properties - -given duck-typing it'll work at runtime. - -""" - -import logging -import os -import sys -import time - - -class LambdaContext: - """Replicate the LambdaContext from the AWS Lambda ARIC. - - https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_context.py - - This is here solely for typings and to get DurableContext to inherit from LambdaContext without needing to - add `aws-lambda-python-runtime-interface-client` as a direct dependency of the Durable Executions SDK. - - This has a subtle and important side-effect. This class is _not_ actually the LambdaContext that the AWS - Lambda runtime passes to the Lambda handler. So do NOT added any custom methods or attributes here, you can - only rely on duck-typing so whatever is in this class replicates what is in the actual class, it will work. - """ - - def __init__( - self, - invoke_id, - client_context, - cognito_identity, - epoch_deadline_time_in_ms, - invoked_function_arn=None, - tenant_id=None, - ): - self.aws_request_id: str = invoke_id - self.log_group_name: str | None = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME") - self.log_stream_name: str | None = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME") - self.function_name: str | None = os.environ.get("AWS_LAMBDA_FUNCTION_NAME") - self.memory_limit_in_mb: str | None = os.environ.get( - "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" - ) - self.function_version: str | None = os.environ.get( - "AWS_LAMBDA_FUNCTION_VERSION" - ) - self.invoked_function_arn: str | None = invoked_function_arn - self.tenant_id: str | None = tenant_id - - self.client_context = make_obj_from_dict(ClientContext, client_context) - if self.client_context is not None: - self.client_context.client = make_obj_from_dict( - Client, self.client_context.client - ) - - self.identity = make_obj_from_dict(CognitoIdentity, {}) - if cognito_identity is not None: - self.identity.cognito_identity_id = cognito_identity.get( - "cognitoIdentityId" - ) - self.identity.cognito_identity_pool_id = cognito_identity.get( - "cognitoIdentityPoolId" - ) - - self._epoch_deadline_time_in_ms = epoch_deadline_time_in_ms - - def get_remaining_time_in_millis(self) -> int: - epoch_now_in_ms = int(time.time() * 1000) - delta_ms = self._epoch_deadline_time_in_ms - epoch_now_in_ms - return delta_ms if delta_ms > 0 else 0 - - def log(self, msg): - for handler in logging.getLogger().handlers: - if hasattr(handler, "log_sink"): - handler.log_sink.log(str(msg)) - return - sys.stdout.write(str(msg)) - - def __repr__(self): - return ( - f"{self.__class__.__name__}([" - f"aws_request_id={self.aws_request_id}," - f"log_group_name={self.log_group_name}," - f"log_stream_name={self.log_stream_name}," - f"function_name={self.function_name}," - f"memory_limit_in_mb={self.memory_limit_in_mb}," - f"function_version={self.function_version}," - f"invoked_function_arn={self.invoked_function_arn}," - f"client_context={self.client_context}," - f"identity={self.identity}," - f"tenant_id={self.tenant_id}" - "])" - ) - - -class CognitoIdentity: - __slots__ = ["cognito_identity_id", "cognito_identity_pool_id"] - - def __repr__(self): - return ( - f"{self.__class__.__name__}([" - f"cognito_identity_id={self.cognito_identity_id}," - f"cognito_identity_pool_id={self.cognito_identity_pool_id}" - "])" - ) - - -class Client: - __slots__ = [ - "installation_id", - "app_title", - "app_version_name", - "app_version_code", - "app_package_name", - ] - - def __repr__(self): - return ( - f"{self.__class__.__name__}([" - f"installation_id={self.installation_id}," - f"app_title={self.app_title}," - f"app_version_name={self.app_version_name}," - f"app_version_code={self.app_version_code}," - f"app_package_name={self.app_package_name}" - "])" - ) - - -class ClientContext: - __slots__ = ["custom", "env", "client"] - - def __repr__(self): - return ( - f"{self.__class__.__name__}([" - f"custom={self.custom}," - f"env={self.env}," - f"client={self.client}" - "])" - ) - - -def make_obj_from_dict(_class, _dict, fields=None): # noqa: ARG001 - if _dict is None: - return None - obj = _class() - set_obj_from_dict(obj, _dict) - return obj - - -def set_obj_from_dict(obj, _dict, fields=None): - if fields is None: - fields = obj.__class__.__slots__ - for field in fields: - setattr(obj, field, _dict.get(field, None)) - - -def make_dict_from_obj(obj): - """Convert an object with __slots__ back to a dictionary. - - Custom addition - not in the original AWS Lambda Runtime Interface Client (ARIC). This - is to help when DurableContext needs to call LambdaContext's super() constructor and pass - it the original dictionaries. - This is the reverse of make_obj_from_dict to convert __slots__ objects back to dictionaries. - """ - if obj is None: - return None - - result = {} - for field in obj.__class__.__slots__: - value = getattr(obj, field, None) - # Recursively convert nested objects - if value is not None and hasattr(value, "__slots__"): - value = make_dict_from_obj(value) - result[field] = value - return result diff --git a/tests/e2e/execution_int_test.py b/tests/e2e/execution_int_test.py index fdbcff3..b21de6b 100644 --- a/tests/e2e/execution_int_test.py +++ b/tests/e2e/execution_int_test.py @@ -14,7 +14,8 @@ InvocationStatus, durable_handler, ) -from aws_durable_execution_sdk_python.lambda_context import LambdaContext + +# LambdaContext no longer needed - using duck typing from aws_durable_execution_sdk_python.lambda_service import ( CheckpointOutput, CheckpointUpdatedExecutionState, @@ -102,7 +103,7 @@ def mock_checkpoint( } # Create mock lambda context - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request-id" lambda_context.client_context = None lambda_context.identity = None @@ -185,7 +186,7 @@ def mock_checkpoint( } # Create mock lambda context - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request-id" lambda_context.client_context = None lambda_context.identity = None @@ -277,7 +278,7 @@ def mock_checkpoint( } # Create mock lambda context - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request-id" lambda_context.client_context = None lambda_context.identity = None @@ -368,7 +369,7 @@ def mock_checkpoint( } # Create mock lambda context - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request-id" lambda_context.client_context = None lambda_context.identity = None diff --git a/tests/execution_test.py b/tests/execution_test.py index cffa141..46cd1e1 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -16,7 +16,8 @@ InvocationStatus, durable_handler, ) -from aws_durable_execution_sdk_python.lambda_context import LambdaContext + +# LambdaContext no longer needed - using duck typing from aws_durable_execution_sdk_python.lambda_service import ( CheckpointOutput, CheckpointUpdatedExecutionState, @@ -305,7 +306,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: "LocalRunner": False, } - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request" lambda_context.client_context = None lambda_context.identity = None @@ -358,7 +359,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: "LocalRunner": False, } - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request" lambda_context.client_context = None lambda_context.identity = None @@ -407,7 +408,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: service_client=mock_client, ) - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request" lambda_context.client_context = None lambda_context.identity = None @@ -455,7 +456,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: service_client=mock_client, ) - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request" lambda_context.client_context = None lambda_context.identity = None @@ -511,7 +512,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: service_client=mock_client, ) - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request" lambda_context.client_context = None lambda_context.identity = None @@ -562,7 +563,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: service_client=mock_client, ) - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request" lambda_context.client_context = None lambda_context.identity = None @@ -600,7 +601,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: service_client=mock_client, ) - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request" lambda_context.client_context = None lambda_context.identity = None @@ -651,7 +652,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: "LocalRunner": True, } - lambda_context = Mock(spec=LambdaContext) + lambda_context = Mock() lambda_context.aws_request_id = "test-request" lambda_context.client_context = None lambda_context.identity = None diff --git a/tests/lambda_context_test.py b/tests/lambda_context_test.py deleted file mode 100644 index 2e6ee7a..0000000 --- a/tests/lambda_context_test.py +++ /dev/null @@ -1,472 +0,0 @@ -"""Tests for the lambda_context module.""" - -from unittest.mock import Mock, patch - -from aws_durable_execution_sdk_python.lambda_context import ( - Client, - ClientContext, - CognitoIdentity, - LambdaContext, - make_dict_from_obj, - make_obj_from_dict, - set_obj_from_dict, -) - - -@patch.dict( - "os.environ", - { - "AWS_LAMBDA_LOG_GROUP_NAME": "test-log-group", - "AWS_LAMBDA_LOG_STREAM_NAME": "test-log-stream", - "AWS_LAMBDA_FUNCTION_NAME": "test-function", - "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "128", - "AWS_LAMBDA_FUNCTION_VERSION": "1", - }, -) -def test_lambda_context_init(): - """Test LambdaContext initialization.""" - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=1000000, - invoked_function_arn="arn:aws:lambda:us-east-1:123456789012:function:test", - tenant_id="test-tenant", - ) - - assert context.aws_request_id == "test-id" - assert context.log_group_name == "test-log-group" - assert context.log_stream_name == "test-log-stream" - assert context.function_name == "test-function" - assert context.memory_limit_in_mb == "128" - assert context.function_version == "1" - assert ( - context.invoked_function_arn - == "arn:aws:lambda:us-east-1:123456789012:function:test" - ) - assert context.tenant_id == "test-tenant" - - -def test_lambda_context_with_client_context(): - """Test LambdaContext with client context.""" - client_context = { - "client": { - "installation_id": "install-123", - "app_title": "Test App", - "app_version_name": "1.0", - "app_version_code": "100", - "app_package_name": "com.test.app", - }, - "custom": {"key": "value"}, - "env": {"platform": "test"}, - } - - context = LambdaContext( - invoke_id="test-id", - client_context=client_context, - cognito_identity=None, - epoch_deadline_time_in_ms=1000000, - ) - - assert context.client_context is not None - assert context.client_context.client.installation_id == "install-123" - assert context.client_context.client.app_title == "Test App" - - -def test_lambda_context_with_cognito_identity(): - """Test LambdaContext with cognito identity.""" - cognito_identity = { - "cognitoIdentityId": "cognito-123", - "cognitoIdentityPoolId": "pool-456", - } - - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity=cognito_identity, - epoch_deadline_time_in_ms=1000000, - ) - - assert context.identity.cognito_identity_id == "cognito-123" - assert context.identity.cognito_identity_pool_id == "pool-456" - - -@patch("time.time") -def test_get_remaining_time_in_millis(mock_time): - """Test get_remaining_time_in_millis method.""" - mock_time.return_value = 1000.0 # 1000000 ms - - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=1005000, # 5 seconds later - ) - - remaining = LambdaContext.get_remaining_time_in_millis(context) - assert remaining == 5000 - - -@patch("time.time") -def test_get_remaining_time_in_millis_expired(mock_time): - """Test get_remaining_time_in_millis when deadline passed.""" - mock_time.return_value = 1010.0 # 1010000 ms - - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=1005000, # 5 seconds earlier - ) - - remaining = LambdaContext.get_remaining_time_in_millis(context) - assert remaining == 0 - - -def test_log_with_handler(): - """Test log method with handler that has log_sink.""" - mock_handler = Mock() - mock_log_sink = Mock() - mock_handler.log_sink = mock_log_sink - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = Mock() - mock_logger.handlers = [mock_handler] - mock_get_logger.return_value = mock_logger - - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=1000000, - ) - - context.log("test message") - mock_log_sink.log.assert_called_once_with("test message") - - -def test_log_without_handler(): - """Test log method without handler with log_sink.""" - with ( - patch("logging.getLogger") as mock_get_logger, - patch("sys.stdout") as mock_stdout, - ): - mock_handler = Mock() - # No log_sink attribute - hasattr will return False - del mock_handler.log_sink # Ensure it doesn't exist - mock_logger = Mock() - mock_logger.handlers = [mock_handler] - mock_get_logger.return_value = mock_logger - - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=1000000, - ) - - context.log("test message") - mock_stdout.write.assert_called_once_with("test message") - - -def test_lambda_context_repr(): - """Test LambdaContext __repr__ method.""" - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=1000000, - invoked_function_arn="arn:test", - tenant_id="tenant-123", - ) - - repr_str = repr(context) - assert "LambdaContext" in repr_str - assert "aws_request_id=test-id" in repr_str - assert "tenant_id=tenant-123" in repr_str - - -def test_cognito_identity_repr(): - """Test CognitoIdentity __repr__ method.""" - identity = CognitoIdentity() - identity.cognito_identity_id = "id-123" - identity.cognito_identity_pool_id = "pool-456" - - repr_str = repr(identity) - assert "CognitoIdentity" in repr_str - assert "cognito_identity_id=id-123" in repr_str - assert "cognito_identity_pool_id=pool-456" in repr_str - - -def test_client_repr(): - """Test Client __repr__ method.""" - client = Client() - # Set all required attributes to avoid AttributeError - client.installation_id = "install-123" - client.app_title = "Test App" - client.app_version_name = "1.0" - client.app_version_code = "100" - client.app_package_name = "com.test.app" - - repr_str = repr(client) - assert "Client" in repr_str - assert "installation_id=install-123" in repr_str - assert "app_title=Test App" in repr_str - - -def test_client_context_repr(): - """Test ClientContext __repr__ method.""" - client_context = ClientContext() - client_context.custom = {"key": "value"} - client_context.env = {"platform": "test"} - client_context.client = None # Set required attribute - - repr_str = repr(client_context) - assert "ClientContext" in repr_str - assert "custom={'key': 'value'}" in repr_str - assert "env={'platform': 'test'}" in repr_str - - -def test_make_obj_from_dict_none(): - """Test make_obj_from_dict with None input.""" - result = make_obj_from_dict(Client, None) - assert result is None - - -def test_make_obj_from_dict_valid(): - """Test make_obj_from_dict with valid input.""" - data = {"installation_id": "install-123", "app_title": "Test App"} - result = make_obj_from_dict(Client, data) - - assert result is not None - assert result.installation_id == "install-123" - assert result.app_title == "Test App" - - -def test_set_obj_from_dict_none(): - """Test set_obj_from_dict with None dict.""" - obj = Client() - # Initialize all slots to avoid AttributeError in repr - for field in obj.__class__.__slots__: - setattr(obj, field, None) - - # This should handle None gracefully by checking if _dict has get method - try: - set_obj_from_dict(obj, None) - # If no exception, the function should handle None properly - assert True - except AttributeError: - # Current implementation doesn't handle None, so we expect this - assert True - - -def test_set_obj_from_dict_no_get(): - """Test set_obj_from_dict with object without get method.""" - obj = Client() - # Initialize all slots to avoid AttributeError in repr - for field in obj.__class__.__slots__: - setattr(obj, field, None) - - # This should handle non-dict gracefully by checking if _dict has get method - try: - set_obj_from_dict(obj, "not a dict") - # If no exception, the function should handle non-dict properly - assert True - except AttributeError: - # Current implementation doesn't handle non-dict, so we expect this - assert True - - -def test_set_obj_from_dict_valid(): - """Test set_obj_from_dict with valid dict.""" - obj = Client() - data = {"installation_id": "install-123", "app_title": "Test App"} - set_obj_from_dict(obj, data) - - assert obj.installation_id == "install-123" - assert obj.app_title == "Test App" - - -def test_lambda_context_with_cognito_identity_none(): - """Test LambdaContext with None cognito identity.""" - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=1000000, - ) - - assert context.identity is not None - assert context.identity.cognito_identity_id is None - assert context.identity.cognito_identity_pool_id is None - - -def test_lambda_context_with_cognito_identity_no_get(): - """Test LambdaContext with cognito identity that doesn't have get method.""" - # Current implementation expects cognito_identity to have get method - # This test verifies the current behavior - try: - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity="not a dict", # No get method - epoch_deadline_time_in_ms=1000000, - ) - # If no exception, the function handles non-dict properly - assert context.identity is not None - except AttributeError: - # Current implementation doesn't handle non-dict cognito_identity - assert True - - -def test_set_obj_from_dict_with_fields(): - """Test set_obj_from_dict with custom fields parameter.""" - obj = Client() - data = { - "installation_id": "install-123", - "app_title": "Test App", - "extra_field": "ignored", - } - fields = ["installation_id", "app_title"] # Custom fields list - - set_obj_from_dict(obj, data, fields) - - assert obj.installation_id == "install-123" - assert obj.app_title == "Test App" - # extra_field should not be set since it's not in fields list - - -@patch.dict( - "os.environ", - { - "AWS_LAMBDA_LOG_GROUP_NAME": "test-log-group", - "AWS_LAMBDA_LOG_STREAM_NAME": "test-log-stream", - "AWS_LAMBDA_FUNCTION_NAME": "test-function", - "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "128", - "AWS_LAMBDA_FUNCTION_VERSION": "1", - }, -) -def test_make_dict_from_obj_with_lambda_context(): - """Test make_dict_from_obj with LambdaContext.""" - client = Client() - # Initialize all slots - for field in client.__class__.__slots__: - setattr(client, field, None) - client.installation_id = "install-123" - client.app_title = "Test App" - - client_context = ClientContext() - # Initialize all slots - for field in client_context.__class__.__slots__: - setattr(client_context, field, None) - client_context.client = client - client_context.custom = {"key": "value"} - client_context.env = {"platform": "test"} - - identity = CognitoIdentity() - # Initialize all slots - for field in identity.__class__.__slots__: - setattr(identity, field, None) - identity.cognito_identity_id = "cognito-123" - identity.cognito_identity_pool_id = "pool-456" - - context = LambdaContext( - invoke_id="test-request-id", - client_context=None, # Will be set manually - cognito_identity=None, # Will be set manually - epoch_deadline_time_in_ms=1000000, - invoked_function_arn="arn:aws:lambda:us-east-1:123456789012:function:test", - tenant_id="test-tenant", - ) - - # Manually set the processed objects - context.client_context = client_context - context.identity = identity - - # Test that make_dict_from_obj works with nested objects - client_dict = make_dict_from_obj(client) - assert client_dict["installation_id"] == "install-123" - assert client_dict["app_title"] == "Test App" - - client_context_dict = make_dict_from_obj(client_context) - assert client_context_dict["custom"] == {"key": "value"} - assert client_context_dict["env"] == {"platform": "test"} - assert client_context_dict["client"]["installation_id"] == "install-123" - - identity_dict = make_dict_from_obj(identity) - assert identity_dict["cognito_identity_id"] == "cognito-123" - assert identity_dict["cognito_identity_pool_id"] == "pool-456" - - -def test_make_dict_from_obj_minimal(): - """Test make_dict_from_obj with minimal objects.""" - context = LambdaContext( - invoke_id="minimal-id", - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=1000000, - ) - - # Test that identity object is created even with None cognito_identity - assert context.identity is not None - identity_dict = make_dict_from_obj(context.identity) - assert identity_dict["cognito_identity_id"] is None - assert identity_dict["cognito_identity_pool_id"] is None - - # Test that client_context is None when passed None - assert context.client_context is None - - -def test_make_dict_from_obj_with_none_values(): - """Test make_dict_from_obj handles None values correctly.""" - context = LambdaContext( - invoke_id="test-id", - client_context=None, - cognito_identity=None, - epoch_deadline_time_in_ms=1000000, - invoked_function_arn=None, - tenant_id=None, - ) - - # Test basic attributes - assert context.invoked_function_arn is None - assert context.tenant_id is None - assert context.client_context is None - assert context.identity is not None # CognitoIdentity object created from {} - - # Test make_dict_from_obj with None input - result = make_dict_from_obj(None) - assert result is None - - # Test make_dict_from_obj with identity object - identity_dict = make_dict_from_obj(context.identity) - assert identity_dict["cognito_identity_id"] is None - assert identity_dict["cognito_identity_pool_id"] is None - - -def test_make_dict_from_obj_none(): - """Test make_dict_from_obj with None input.""" - result = make_dict_from_obj(None) - assert result is None - - -def test_make_dict_from_obj_nested(): - """Test make_dict_from_obj with nested objects.""" - client = Client() - # Initialize all slots - for field in client.__class__.__slots__: - setattr(client, field, None) - client.installation_id = "install-123" - client.app_title = "Test App" - - client_context = ClientContext() - # Initialize all slots - for field in client_context.__class__.__slots__: - setattr(client_context, field, None) - client_context.client = client - client_context.custom = {"key": "value"} - - result = make_dict_from_obj(client_context) - assert result["custom"] == {"key": "value"} - assert result["client"]["installation_id"] == "install-123" - assert result["client"]["app_title"] == "Test App"