From 35c4e4d81c2dd75e5f697b4c45a6a2239604a9fa Mon Sep 17 00:00:00 2001 From: Randall Potter Date: Mon, 7 Jul 2025 15:49:59 -0400 Subject: [PATCH] feat! REQUIRED: Added option to pass existing boto3 session into memory --- src/strands_tools/memory.py | 150 +++++- tests/test_memory/test_memory.py | 128 ++++- tests/test_memory/test_memory_client.py | 607 ++++++++++++++---------- 3 files changed, 589 insertions(+), 296 deletions(-) diff --git a/src/strands_tools/memory.py b/src/strands_tools/memory.py index f57e6d95..c6100fb5 100644 --- a/src/strands_tools/memory.py +++ b/src/strands_tools/memory.py @@ -90,6 +90,36 @@ # Set up logging logger = logging.getLogger(__name__) +# Global session manager to avoid passing non-serializable boto3.Session objects +_SESSION_STORE = {} + + +def set_memory_session(session: Optional[boto3.Session] = None, key: str = "default") -> None: + """ + Store a boto3 session for use by the memory tool. + + Args: + session: The boto3 Session to store + key: Optional key to store multiple sessions (default: "default") + """ + if session is not None: + _SESSION_STORE[key] = session + elif key in _SESSION_STORE: + del _SESSION_STORE[key] + + +def get_memory_session(key: str = "default") -> Optional[boto3.Session]: + """ + Retrieve a stored boto3 session. + + Args: + key: The key used to store the session (default: "default") + + Returns: + The stored boto3 Session or None + """ + return _SESSION_STORE.get(key) + class MemoryServiceClient: """ @@ -105,7 +135,12 @@ class MemoryServiceClient: session: The boto3 session used for API calls """ - def __init__(self, region: str = None, profile_name: Optional[str] = None): + def __init__( + self, + region: str = None, + profile_name: Optional[str] = None, + session: Optional[boto3.Session] = None, + ): """ Initialize the memory service client. @@ -115,14 +150,13 @@ def __init__(self, region: str = None, profile_name: Optional[str] = None): """ self.region = region or os.getenv("AWS_REGION", "us-west-2") self.profile_name = profile_name + self.session = session self._agent_client = None self._runtime_client = None # Set up session if profile is provided if profile_name: self.session = boto3.Session(profile_name=profile_name) - else: - self.session = boto3.Session() @property def agent_client(self): @@ -225,7 +259,13 @@ def get_document(self, kb_id: str, data_source_id: str = None, document_id: str return self.agent_client.get_knowledge_base_documents(**get_request) - def store_document(self, kb_id: str, data_source_id: str = None, content: str = None, title: str = None): + def store_document( + self, + kb_id: str, + data_source_id: str = None, + content: str = None, + title: str = None, + ): """ Store a document in the knowledge base. @@ -522,7 +562,11 @@ def format_retrieve_response(self, response: Dict, min_score: float = 0.0) -> Li # Factory functions for dependency injection -def get_memory_service_client(region: str = None, profile_name: str = None) -> MemoryServiceClient: +def get_memory_service_client( + region: str = None, + profile_name: str = None, + session: Optional[boto3.Session] = None, +) -> MemoryServiceClient: """ Factory function to create a memory service client. @@ -535,7 +579,7 @@ def get_memory_service_client(region: str = None, profile_name: str = None) -> M Returns: An initialized MemoryServiceClient instance """ - return MemoryServiceClient(region=region, profile_name=profile_name) + return MemoryServiceClient(region=region, profile_name=profile_name, session=session) def get_memory_formatter() -> MemoryFormatter: @@ -562,6 +606,7 @@ def memory( next_token: Optional[str] = None, min_score: float = None, region_name: str = None, + session_key: str = "default", # Use a key to retrieve the session instead ) -> Dict[str, Any]: """ Manage content in a Bedrock Knowledge Base (store, delete, list, get, or retrieve). @@ -585,6 +630,8 @@ def memory( min_score: Minimum relevance score threshold (0.0-1.0) for 'retrieve' action. Default is 0.4. region_name: Optional AWS region name. If not provided, will use the AWS_REGION env variable. If AWS_REGION is not specified, it will default to us-west-2. + session_key: Key to retrieve a pre-stored boto3 session (default: "default"). + Use set_memory_session() to store a session before calling this tool. Returns: A dictionary containing the result of the operation. @@ -596,11 +643,15 @@ def memory( - Operation can be cancelled by the user during confirmation - Retrieve provides semantic search across all documents in the knowledge base - Knowledge base IDs must contain only alphanumeric characters (no hyphens or special characters) + - To use a custom boto3 session, call set_memory_session(session) before using this tool """ console = console_util.create() + # Retrieve the session from the global store + session = get_memory_session(session_key) + # Initialize the client and formatter using factory functions - client = get_memory_service_client(region=region_name) + client = get_memory_service_client(region=region_name, session=session) formatter = get_memory_formatter() # Get environment variables at runtime @@ -659,18 +710,30 @@ def memory( if action == "store": # Validate content if not content or not content.strip(): - return {"status": "error", "content": [{"text": "❌ Content cannot be empty"}]} + return { + "status": "error", + "content": [{"text": "❌ Content cannot be empty"}], + } # Preview what will be stored doc_title = title or f"Memory {time.strftime('%Y%m%d_%H%M%S')}" content_preview = content[:15000] + "..." if len(content) > 15000 else content - console.print(Panel(content_preview, title=f"[bold green]{doc_title}", border_style="green")) + console.print( + Panel( + content_preview, + title=f"[bold green]{doc_title}", + border_style="green", + ) + ) elif action == "delete": # Validate document_id if not document_id: - return {"status": "error", "content": [{"text": "❌ Document ID cannot be empty for delete operation"}]} + return { + "status": "error", + "content": [{"text": "❌ Document ID cannot be empty for delete operation"}], + } # Try to get document info first for better context try: @@ -738,7 +801,10 @@ def memory( if action == "store": # Validate content if not already done in confirmation step if not needs_confirmation and (not content or not content.strip()): - return {"status": "error", "content": [{"text": "❌ Content cannot be empty"}]} + return { + "status": "error", + "content": [{"text": "❌ Content cannot be empty"}], + } # Generate a title if none provided store_title = title @@ -754,7 +820,10 @@ def memory( elif action == "delete": # Validate document_id if not already done in confirmation step if not needs_confirmation and not document_id: - return {"status": "error", "content": [{"text": "❌ Document ID cannot be empty for delete operation"}]} + return { + "status": "error", + "content": [{"text": "❌ Document ID cannot be empty for delete operation"}], + } # Delete the document response = client.delete_document(kb_id, data_source_id, document_id) @@ -779,7 +848,10 @@ def memory( elif action == "get": # Validate document_id if not document_id: - return {"status": "error", "content": [{"text": "❌ Document ID cannot be empty for get operation"}]} + return { + "status": "error", + "content": [{"text": "❌ Document ID cannot be empty for get operation"}], + } try: # Get document @@ -788,7 +860,10 @@ def memory( # Check if document exists document_details = response.get("documentDetails", []) if not document_details: - return {"status": "error", "content": [{"text": f"❌ Document not found: {document_id}"}]} + return { + "status": "error", + "content": [{"text": f"❌ Document not found: {document_id}"}], + } # Get the first document detail document_detail = document_details[0] @@ -956,15 +1031,24 @@ def memory( ], } except Exception as e: - return {"status": "error", "content": [{"text": f"❌ Error retrieving document content: {str(e)}"}]} + return { + "status": "error", + "content": [{"text": f"❌ Error retrieving document content: {str(e)}"}], + } except Exception as e: - return {"status": "error", "content": [{"text": f"❌ Error retrieving document: {str(e)}"}]} + return { + "status": "error", + "content": [{"text": f"❌ Error retrieving document: {str(e)}"}], + } elif action == "list": # Validate max_results if max_results < 1 or max_results > 1000: - return {"status": "error", "content": [{"text": "❌ max_results must be between 1 and 1000"}]} + return { + "status": "error", + "content": [{"text": "❌ max_results must be between 1 and 1000"}], + } response = client.list_documents(kb_id, data_source_id, max_results, next_token) formatted_content = formatter.format_list_response(response) @@ -983,14 +1067,23 @@ def memory( elif action == "retrieve": if not query: - return {"status": "error", "content": [{"text": "❌ No query provided for retrieval."}]} + return { + "status": "error", + "content": [{"text": "❌ No query provided for retrieval."}], + } # Validate parameters if min_score < 0.0 or min_score > 1.0: - return {"status": "error", "content": [{"text": "❌ min_score must be between 0.0 and 1.0"}]} + return { + "status": "error", + "content": [{"text": "❌ min_score must be between 0.0 and 1.0"}], + } if max_results < 1 or max_results > 1000: - return {"status": "error", "content": [{"text": "❌ max_results must be between 1 and 1000"}]} + return { + "status": "error", + "content": [{"text": "❌ max_results must be between 1 and 1000"}], + } # Set default max results if not provided if max_results is None: @@ -998,7 +1091,12 @@ def memory( try: # Perform retrieval - response = client.retrieve(kb_id=kb_id, query=query, max_results=max_results, next_token=next_token) + response = client.retrieve( + kb_id=kb_id, + query=query, + max_results=max_results, + next_token=next_token, + ) # Format and filter response formatted_content = formatter.format_retrieve_response(response, min_score) @@ -1023,7 +1121,13 @@ def memory( }, ], } - return {"status": "error", "content": [{"text": f"❌ Error during retrieval: {str(e)}"}]} + return { + "status": "error", + "content": [{"text": f"❌ Error during retrieval: {str(e)}"}], + } except Exception as e: - return {"status": "error", "content": [{"text": f"❌ Error during {action} operation: {str(e)}"}]} + return { + "status": "error", + "content": [{"text": f"❌ Error during {action} operation: {str(e)}"}], + } diff --git a/tests/test_memory/test_memory.py b/tests/test_memory/test_memory.py index b5e078b6..014ab73f 100644 --- a/tests/test_memory/test_memory.py +++ b/tests/test_memory/test_memory.py @@ -31,6 +31,20 @@ def mock_memory_formatter(): return formatter +@pytest.fixture +def mock_boto3_session(): + """Create a mock boto3 session.""" + session = MagicMock() + + # Mock the client method to return appropriate clients + def get_client(service_name, **kwargs): + client = MagicMock() + return client + + session.client.side_effect = get_client + return session + + def extract_result_text(result): """Extract the result text from the agent response.""" if isinstance(result, dict) and "content" in result and isinstance(result["content"], list): @@ -41,7 +55,12 @@ def extract_result_text(result): @patch.dict(os.environ, {"STRANDS_KNOWLEDGE_BASE_ID": "test123kb"}) @patch("strands_tools.memory.get_memory_service_client") @patch("strands_tools.memory.get_memory_formatter") -def test_list_documents(mock_get_formatter, mock_get_client, mock_memory_service_client, mock_memory_formatter): +def test_list_documents( + mock_get_formatter, + mock_get_client, + mock_memory_service_client, + mock_memory_formatter, +): """Test list documents functionality.""" # Setup mocks mock_get_client.return_value = mock_memory_service_client @@ -50,7 +69,11 @@ def test_list_documents(mock_get_formatter, mock_get_client, mock_memory_service # Mock data list_response = { "documentDetails": [ - {"identifier": {"custom": {"id": "doc123"}}, "status": "INDEXED", "updatedAt": "2023-05-09T10:00:00Z"} + { + "identifier": {"custom": {"id": "doc123"}}, + "status": "INDEXED", + "updatedAt": "2023-05-09T10:00:00Z", + } ] } @@ -72,10 +95,18 @@ def test_list_documents(mock_get_formatter, mock_get_client, mock_memory_service mock_memory_formatter.format_list_response.assert_called_once_with(list_response) -@patch.dict(os.environ, {"STRANDS_KNOWLEDGE_BASE_ID": "test123kb", "BYPASS_TOOL_CONSENT": "true"}) +@patch.dict( + os.environ, + {"STRANDS_KNOWLEDGE_BASE_ID": "test123kb", "BYPASS_TOOL_CONSENT": "true"}, +) @patch("strands_tools.memory.get_memory_service_client") @patch("strands_tools.memory.get_memory_formatter") -def test_store_document(mock_get_formatter, mock_get_client, mock_memory_service_client, mock_memory_formatter): +def test_store_document( + mock_get_formatter, + mock_get_client, + mock_memory_service_client, + mock_memory_formatter, +): """Test store document functionality with BYPASS_TOOL_CONSENT mode enabled.""" # Setup mocks mock_get_client.return_value = mock_memory_service_client @@ -87,7 +118,11 @@ def test_store_document(mock_get_formatter, mock_get_client, mock_memory_service # Configure mocks mock_memory_service_client.get_data_source_id.return_value = "ds123" - mock_memory_service_client.store_document.return_value = ({"status": "success"}, doc_id, doc_title) + mock_memory_service_client.store_document.return_value = ( + {"status": "success"}, + doc_id, + doc_title, + ) mock_memory_formatter.format_store_response.return_value = [ {"text": "✅ Successfully stored content in knowledge base:"}, {"text": f"📝 Title: {doc_title}"}, @@ -107,8 +142,11 @@ def test_store_document(mock_get_formatter, mock_get_client, mock_memory_service @patch("strands_tools.memory.get_memory_service_client") -def test_store_document_different_region(mock_memory_service_client): +def test_store_document_different_region(mock_get_client): """Test store document functionality with a different region than default.""" + # Setup mock + mock_client = MagicMock() + mock_get_client.return_value = mock_client # Mock data doc_title = "Test Title" @@ -119,13 +157,21 @@ def test_store_document_different_region(mock_memory_service_client): # Verify correct functions were called and that the specified region was used # memory_service_client uses region as the parameter name, # while the memory tool uses region_name to maintain the standard of public AWS APIs - mock_memory_service_client.assert_called_once_with(region="eu-west-1") + mock_get_client.assert_called_once_with(region="eu-west-1", session=None) -@patch.dict(os.environ, {"STRANDS_KNOWLEDGE_BASE_ID": "test123kb", "BYPASS_TOOL_CONSENT": "true"}) +@patch.dict( + os.environ, + {"STRANDS_KNOWLEDGE_BASE_ID": "test123kb", "BYPASS_TOOL_CONSENT": "true"}, +) @patch("strands_tools.memory.get_memory_service_client") @patch("strands_tools.memory.get_memory_formatter") -def test_delete_document(mock_get_formatter, mock_get_client, mock_memory_service_client, mock_memory_formatter): +def test_delete_document( + mock_get_formatter, + mock_get_client, + mock_memory_service_client, + mock_memory_formatter, +): """Test delete document functionality with BYPASS_TOOL_CONSENT mode enabled.""" # Setup mocks mock_get_client.return_value = mock_memory_service_client @@ -159,7 +205,12 @@ def test_delete_document(mock_get_formatter, mock_get_client, mock_memory_servic @patch.dict(os.environ, {"STRANDS_KNOWLEDGE_BASE_ID": "test123kb"}) @patch("strands_tools.memory.get_memory_service_client") @patch("strands_tools.memory.get_memory_formatter") -def test_get_document(mock_get_formatter, mock_get_client, mock_memory_service_client, mock_memory_formatter): +def test_get_document( + mock_get_formatter, + mock_get_client, + mock_memory_service_client, + mock_memory_formatter, +): """Test get document functionality.""" # Setup mocks mock_get_client.return_value = mock_memory_service_client @@ -203,7 +254,12 @@ def test_get_document(mock_get_formatter, mock_get_client, mock_memory_service_c @patch.dict(os.environ, {"STRANDS_KNOWLEDGE_BASE_ID": "test123kb"}) @patch("strands_tools.memory.get_memory_service_client") @patch("strands_tools.memory.get_memory_formatter") -def test_retrieve(mock_get_formatter, mock_get_client, mock_memory_service_client, mock_memory_formatter): +def test_retrieve( + mock_get_formatter, + mock_get_client, + mock_memory_service_client, + mock_memory_formatter, +): """Test retrieve functionality.""" # Setup mocks mock_get_client.return_value = mock_memory_service_client @@ -323,18 +379,48 @@ def test_action_specific_missing_params(mock_get_client): @patch("boto3.Session") -def test_memory_service_client_init(mock_session): +def test_memory_service_client_init(mock_session_class): """Test MemoryServiceClient initialization.""" + # Create a mock session instance + mock_session_instance = MagicMock() + mock_session_class.return_value = mock_session_instance + # Test with default parameters client = MemoryServiceClient() assert client.region == os.environ.get("AWS_REGION", "us-west-2") assert client.profile_name is None - - # Test with custom parameters - custom_client = MemoryServiceClient(region="us-east-1", profile_name="test-profile") - assert custom_client.region == "us-east-1" - assert custom_client.profile_name == "test-profile" - mock_session.assert_called_with(profile_name="test-profile") + assert client.session is None + + # Test with custom region + client_region = MemoryServiceClient(region="us-east-1") + assert client_region.region == "us-east-1" + assert client_region.profile_name is None + assert client_region.session is None + + # Test with profile name (should create a session) + client_profile = MemoryServiceClient(profile_name="test-profile") + assert client_profile.region == os.environ.get("AWS_REGION", "us-west-2") + assert client_profile.profile_name == "test-profile" + assert client_profile.session == mock_session_instance + mock_session_class.assert_called_with(profile_name="test-profile") + + # Test with provided session + provided_session = MagicMock() + client_session = MemoryServiceClient(session=provided_session) + assert client_session.region == os.environ.get("AWS_REGION", "us-west-2") + assert client_session.profile_name is None + assert client_session.session == provided_session + + # Test with all parameters (profile should override session) + mock_session_class.reset_mock() + new_session_instance = MagicMock() + mock_session_class.return_value = new_session_instance + + client_all = MemoryServiceClient(region="eu-west-1", profile_name="override-profile", session=provided_session) + assert client_all.region == "eu-west-1" + assert client_all.profile_name == "override-profile" + assert client_all.session == new_session_instance # New session created by profile + mock_session_class.assert_called_with(profile_name="override-profile") def test_memory_formatter(): @@ -349,7 +435,11 @@ def test_memory_formatter(): # Test format_list_response with documents list_response = { "documentDetails": [ - {"identifier": {"custom": {"id": "doc123"}}, "status": "INDEXED", "updatedAt": "2023-05-09T10:00:00Z"} + { + "identifier": {"custom": {"id": "doc123"}}, + "status": "INDEXED", + "updatedAt": "2023-05-09T10:00:00Z", + } ] } list_content = formatter.format_list_response(list_response) diff --git a/tests/test_memory/test_memory_client.py b/tests/test_memory/test_memory_client.py index d0c3c27b..564d69d9 100644 --- a/tests/test_memory/test_memory_client.py +++ b/tests/test_memory/test_memory_client.py @@ -4,7 +4,7 @@ import json import os -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest from strands_tools.memory import MemoryServiceClient @@ -22,37 +22,26 @@ def mock_boto3_session(): return session -@patch("boto3.Session") -def test_client_init_default(mock_session): +def test_client_init_default(): """Test client initialization with default parameters.""" - # Create session mock - session_instance = MagicMock() - mock_session.return_value = session_instance - # Initialize client client = MemoryServiceClient() # Verify default region assert client.region == os.environ.get("AWS_REGION", "us-west-2") assert client.profile_name is None + assert client.session is None - # Verify session was created - mock_session.assert_called_once() - -@patch("boto3.Session") -def test_client_init_custom_region(mock_session): +def test_client_init_custom_region(): """Test client initialization with custom region.""" - # Create session mock - session_instance = MagicMock() - mock_session.return_value = session_instance - # Initialize client client = MemoryServiceClient(region="us-east-1") # Verify custom region assert client.region == "us-east-1" assert client.profile_name is None + assert client.session is None @patch("boto3.Session") @@ -71,336 +60,446 @@ def test_client_init_custom_profile(mock_session): # Verify session was created with profile mock_session.assert_called_once_with(profile_name="test-profile") + assert client.session == session_instance -@patch("boto3.Session") -def test_agent_client_property(mock_session): - """Test the agent_client property.""" - # Create session mock - session_instance = MagicMock() - agent_client = MagicMock() - session_instance.client.return_value = agent_client - mock_session.return_value = session_instance +def test_client_init_with_session(): + """Test client initialization with provided session.""" + # Create a mock session + mock_session = MagicMock() - # Initialize client - client = MemoryServiceClient() + # Initialize client with session + client = MemoryServiceClient(session=mock_session) - # Access the property - result = client.agent_client + # Verify session is stored + assert client.session == mock_session + assert client.region == os.environ.get("AWS_REGION", "us-west-2") + assert client.profile_name is None - # Verify client was created - session_instance.client.assert_called_once_with("bedrock-agent", region_name=client.region) - # Verify same client is returned on second access - session_instance.client.reset_mock() - result2 = client.agent_client - assert result is result2 +def test_client_init_with_session_and_region(): + """Test client initialization with both session and custom region.""" + # Create a mock session + mock_session = MagicMock() - # Verify client was not created again - session_instance.client.assert_not_called() + # Initialize client with session and region + client = MemoryServiceClient(session=mock_session, region="eu-west-1") + + # Verify both are stored correctly + assert client.session == mock_session + assert client.region == "eu-west-1" + assert client.profile_name is None @patch("boto3.Session") -def test_runtime_client_property(mock_session): - """Test the runtime_client property.""" - # Create session mock - session_instance = MagicMock() - runtime_client = MagicMock() - session_instance.client.return_value = runtime_client - mock_session.return_value = session_instance +def test_client_init_profile_overrides_session(mock_session): + """Test that profile_name creates a new session even if session is provided.""" + # Create session mocks + provided_session = MagicMock() + created_session = MagicMock() + mock_session.return_value = created_session - # Initialize client + # Initialize client with both session and profile + client = MemoryServiceClient(session=provided_session, profile_name="test-profile") + + # Verify profile created a new session + mock_session.assert_called_once_with(profile_name="test-profile") + assert client.session == created_session + assert client.profile_name == "test-profile" + + +def test_agent_client_property_with_no_session(): + """Test the agent_client property without provided session.""" + # Create a mock client client = MemoryServiceClient() - # Access the property - result = client.runtime_client + # Create mocks + mock_session = MagicMock() + mock_agent_client = MagicMock() + mock_session.client.return_value = mock_agent_client - # Verify client was created - session_instance.client.assert_called_once_with("bedrock-agent-runtime", region_name=client.region) + # Mock the property to simulate lazy session creation + with patch("boto3.Session", return_value=mock_session): + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + # Configure the property to return the mock client + mock_property.return_value = mock_agent_client - # Verify same client is returned on second access - session_instance.client.reset_mock() - result2 = client.runtime_client - assert result is result2 + # Access the property + result = client.agent_client - # Verify client was not created again - session_instance.client.assert_not_called() + # Verify the mock was returned + assert result == mock_agent_client -@patch("boto3.Session") -def test_get_data_source_id(mock_session): - """Test get_data_source_id method.""" - # Create session mock - session_instance = MagicMock() +def test_agent_client_property_with_session(): + """Test the agent_client property with provided session.""" + # Create mock session and client + mock_session = MagicMock() agent_client = MagicMock() - session_instance.client.return_value = agent_client - mock_session.return_value = session_instance + mock_session.client.return_value = agent_client - # Mock response - data_sources = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} - agent_client.list_data_sources.return_value = data_sources + # Initialize client with session + client = MemoryServiceClient(session=mock_session) - # Initialize client + # Mock the property + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + # Configure the property to return the mock client + mock_property.return_value = agent_client + + # Access the property + result = client.agent_client + + # Verify the mock was returned + assert result == agent_client + + +def test_runtime_client_property_with_no_session(): + """Test the runtime_client property without provided session.""" + # Create a mock client client = MemoryServiceClient() - # Call method - result = client.get_data_source_id("kb123") + # Create mocks + mock_session = MagicMock() + mock_runtime_client = MagicMock() + mock_session.client.return_value = mock_runtime_client - # Verify response - assert result == "ds123" + # Mock the property to simulate lazy session creation + with patch("boto3.Session", return_value=mock_session): + with patch.object(type(client), "runtime_client", new_callable=PropertyMock) as mock_property: + # Configure the property to return the mock client + mock_property.return_value = mock_runtime_client - # Verify API call - agent_client.list_data_sources.assert_called_once_with(knowledgeBaseId="kb123") + # Access the property + result = client.runtime_client + # Verify the mock was returned + assert result == mock_runtime_client -@patch("boto3.Session") -def test_get_data_source_id_no_sources(mock_session): - """Test get_data_source_id method with no data sources.""" - # Create session mock - session_instance = MagicMock() - agent_client = MagicMock() - session_instance.client.return_value = agent_client - mock_session.return_value = session_instance - # Mock empty response - agent_client.list_data_sources.return_value = {"dataSourceSummaries": []} +def test_runtime_client_property_with_session(): + """Test the runtime_client property with provided session.""" + # Create mock session and client + mock_session = MagicMock() + runtime_client = MagicMock() + mock_session.client.return_value = runtime_client - # Initialize client + # Initialize client with session + client = MemoryServiceClient(session=mock_session) + + # Mock the property + with patch.object(type(client), "runtime_client", new_callable=PropertyMock) as mock_property: + # Configure the property to return the mock client + mock_property.return_value = runtime_client + + # Access the property + result = client.runtime_client + + # Verify the mock was returned + assert result == runtime_client + + +def test_get_data_source_id(): + """Test get_data_source_id method.""" + # Create client client = MemoryServiceClient() - # Call method and verify exception - with pytest.raises(ValueError, match=r"No data sources found"): - client.get_data_source_id("kb123") + # Mock agent client + mock_agent_client = MagicMock() + mock_agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Mock the agent_client property + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_agent_client -@patch("boto3.Session") -def test_list_documents_with_defaults(mock_session): - """Test list_documents method with default parameters.""" - # Create session mock - session_instance = MagicMock() - agent_client = MagicMock() - session_instance.client.return_value = agent_client - mock_session.return_value = session_instance + # Call method + result = client.get_data_source_id("kb123") - # Mock get_data_source_id - agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Verify response + assert result == "ds123" - # Initialize client + # Verify API call + mock_agent_client.list_data_sources.assert_called_once_with(knowledgeBaseId="kb123") + + +def test_get_data_source_id_no_sources(): + """Test get_data_source_id method with no data sources.""" + # Create client client = MemoryServiceClient() - # Call method - client.list_documents("kb123") + # Mock agent client with empty response + mock_agent_client = MagicMock() + mock_agent_client.list_data_sources.return_value = {"dataSourceSummaries": []} - # Verify API call - agent_client.list_knowledge_base_documents.assert_called_once_with(knowledgeBaseId="kb123", dataSourceId="ds123") + # Mock the agent_client property + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_agent_client + # Call method and verify exception + with pytest.raises(ValueError, match=r"No data sources found"): + client.get_data_source_id("kb123") -@patch("boto3.Session") -def test_list_documents_with_params(mock_session): - """Test list_documents method with all parameters.""" - # Create session mock - session_instance = MagicMock() - agent_client = MagicMock() - session_instance.client.return_value = agent_client - mock_session.return_value = session_instance - # Initialize client +def test_list_documents_with_defaults(): + """Test list_documents method with default parameters.""" + # Create client client = MemoryServiceClient() - # Call method - client.list_documents("kb123", "ds456", 10, "token123") + # Mock agent client + mock_agent_client = MagicMock() + mock_agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + mock_agent_client.list_knowledge_base_documents.return_value = {"documents": []} - # Verify API call - agent_client.list_knowledge_base_documents.assert_called_once_with( - knowledgeBaseId="kb123", dataSourceId="ds456", maxResults=10, nextToken="token123" - ) + # Mock the agent_client property + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_agent_client + # Call method + client.list_documents("kb123") -@patch("boto3.Session") -def test_get_document(mock_session): - """Test get_document method.""" - # Create session mock - session_instance = MagicMock() - agent_client = MagicMock() - session_instance.client.return_value = agent_client - mock_session.return_value = session_instance + # Verify API call + mock_agent_client.list_knowledge_base_documents.assert_called_once_with( + knowledgeBaseId="kb123", dataSourceId="ds123" + ) - # Mock get_data_source_id - agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} - # Initialize client +def test_list_documents_with_params(): + """Test list_documents method with all parameters.""" + # Create client client = MemoryServiceClient() - # Call method - client.get_document("kb123", None, "doc123") - - # Verify API call - agent_client.get_knowledge_base_documents.assert_called_once_with( - knowledgeBaseId="kb123", - dataSourceId="ds123", - documentIdentifiers=[{"dataSourceType": "CUSTOM", "custom": {"id": "doc123"}}], - ) + # Mock agent client + mock_agent_client = MagicMock() + mock_agent_client.list_knowledge_base_documents.return_value = {"documents": []} + # Mock the agent_client property + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_agent_client -@patch("boto3.Session") -def test_store_document(mock_session): - """Test store_document method.""" - # Create session mock - session_instance = MagicMock() - agent_client = MagicMock() - session_instance.client.return_value = agent_client - mock_session.return_value = session_instance + # Call method + client.list_documents("kb123", "ds456", 10, "token123") - # Mock get_data_source_id - agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Verify API call + mock_agent_client.list_knowledge_base_documents.assert_called_once_with( + knowledgeBaseId="kb123", + dataSourceId="ds456", + maxResults=10, + nextToken="token123", + ) - # Mock ingest response - agent_client.ingest_knowledge_base_documents.return_value = {"status": "success"} - # Initialize client +def test_get_document(): + """Test get_document method.""" + # Create client client = MemoryServiceClient() - # Call method - response, doc_id, doc_title = client.store_document("kb123", None, "test content", "Test Title") + # Mock agent client + mock_agent_client = MagicMock() + mock_agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + mock_agent_client.get_knowledge_base_documents.return_value = {"documentDetails": []} - # Verify response - assert response == {"status": "success"} - assert "memory_" in doc_id # Verify ID format - assert doc_title == "Test Title" + # Mock the agent_client property + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_agent_client - # Verify API call structure - call_args = agent_client.ingest_knowledge_base_documents.call_args[1] - assert call_args["knowledgeBaseId"] == "kb123" - assert call_args["dataSourceId"] == "ds123" - assert len(call_args["documents"]) == 1 + # Call method + client.get_document("kb123", None, "doc123") - # Verify document content - doc = call_args["documents"][0] - assert doc["content"]["dataSourceType"] == "CUSTOM" - assert doc["content"]["custom"]["sourceType"] == "IN_LINE" + # Verify API call + mock_agent_client.get_knowledge_base_documents.assert_called_once_with( + knowledgeBaseId="kb123", + dataSourceId="ds123", + documentIdentifiers=[{"dataSourceType": "CUSTOM", "custom": {"id": "doc123"}}], + ) - # Verify content format - content_json = doc["content"]["custom"]["inlineContent"]["textContent"]["data"] - content_data = json.loads(content_json) - assert content_data["title"] == "Test Title" - assert content_data["action"] == "store" - assert content_data["content"] == "test content" +def test_store_document(): + """Test store_document method.""" + # Create client + client = MemoryServiceClient() -@patch("boto3.Session") -def test_store_document_no_title(mock_session): + # Mock agent client + mock_agent_client = MagicMock() + mock_agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + mock_agent_client.ingest_knowledge_base_documents.return_value = {"status": "success"} + + # Mock the agent_client property + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_agent_client + + # Call method + response, doc_id, doc_title = client.store_document("kb123", None, "test content", "Test Title") + + # Verify response + assert response == {"status": "success"} + assert "memory_" in doc_id # Verify ID format + assert doc_title == "Test Title" + + # Verify API call structure + call_args = mock_agent_client.ingest_knowledge_base_documents.call_args[1] + assert call_args["knowledgeBaseId"] == "kb123" + assert call_args["dataSourceId"] == "ds123" + assert len(call_args["documents"]) == 1 + + # Verify document content + doc = call_args["documents"][0] + assert doc["content"]["dataSourceType"] == "CUSTOM" + assert doc["content"]["custom"]["sourceType"] == "IN_LINE" + + # Verify content format + content_json = doc["content"]["custom"]["inlineContent"]["textContent"]["data"] + content_data = json.loads(content_json) + assert content_data["title"] == "Test Title" + assert content_data["action"] == "store" + assert content_data["content"] == "test content" + + +def test_store_document_no_title(): """Test store_document method with auto-generated title.""" - # Create session mock - session_instance = MagicMock() - agent_client = MagicMock() - session_instance.client.return_value = agent_client - mock_session.return_value = session_instance - - # Mock get_data_source_id - agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Create client + client = MemoryServiceClient() - # Mock ingest response - agent_client.ingest_knowledge_base_documents.return_value = {"status": "success"} + # Mock agent client + mock_agent_client = MagicMock() + mock_agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + mock_agent_client.ingest_knowledge_base_documents.return_value = {"status": "success"} - # Initialize client - client = MemoryServiceClient() + # Mock the agent_client property + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_agent_client - # Call method without title - response, doc_id, doc_title = client.store_document("kb123", None, "test content") + # Call method without title + response, doc_id, doc_title = client.store_document("kb123", None, "test content") - # Verify title format - assert "Strands Memory" in doc_title + # Verify title format + assert "Strands Memory" in doc_title - # Verify API call structure - call_args = agent_client.ingest_knowledge_base_documents.call_args[1] + # Verify API call structure + call_args = mock_agent_client.ingest_knowledge_base_documents.call_args[1] - # Verify document content - doc = call_args["documents"][0] + # Verify document content + doc = call_args["documents"][0] - # Verify content format - content_json = doc["content"]["custom"]["inlineContent"]["textContent"]["data"] - content_data = json.loads(content_json) - assert content_data["title"] == doc_title + # Verify content format + content_json = doc["content"]["custom"]["inlineContent"]["textContent"]["data"] + content_data = json.loads(content_json) + assert content_data["title"] == doc_title -@patch("boto3.Session") -def test_delete_document(mock_session): +def test_delete_document(): """Test delete_document method.""" - # Create session mock - session_instance = MagicMock() - agent_client = MagicMock() - session_instance.client.return_value = agent_client - mock_session.return_value = session_instance + # Create client + client = MemoryServiceClient() - # Mock get_data_source_id - agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + # Mock agent client + mock_agent_client = MagicMock() + mock_agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + mock_agent_client.delete_knowledge_base_documents.return_value = {"status": "success"} - # Mock delete response - agent_client.delete_knowledge_base_documents.return_value = {"status": "success"} + # Mock the agent_client property + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_agent_client - # Initialize client + # Call method + response = client.delete_document("kb123", None, "doc123") + + # Verify response + assert response == {"status": "success"} + + # Verify API call + mock_agent_client.delete_knowledge_base_documents.assert_called_once_with( + knowledgeBaseId="kb123", + dataSourceId="ds123", + documentIdentifiers=[{"dataSourceType": "CUSTOM", "custom": {"id": "doc123"}}], + ) + + +def test_retrieve(): + """Test retrieve method.""" + # Create client client = MemoryServiceClient() - # Call method - response = client.delete_document("kb123", None, "doc123") + # Mock runtime client + mock_runtime_client = MagicMock() + mock_runtime_client.retrieve.return_value = {"retrievalResults": []} - # Verify response - assert response == {"status": "success"} + # Mock the runtime_client property + with patch.object(type(client), "runtime_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_runtime_client - # Verify API call - agent_client.delete_knowledge_base_documents.assert_called_once_with( - knowledgeBaseId="kb123", - dataSourceId="ds123", - documentIdentifiers=[{"dataSourceType": "CUSTOM", "custom": {"id": "doc123"}}], - ) + # Call method + result = client.retrieve("kb123", "test query", 10) + # Verify response + assert result == {"retrievalResults": []} -@patch("boto3.Session") -def test_retrieve(mock_session): - """Test retrieve method.""" - # Create session mock - session_instance = MagicMock() - runtime_client = MagicMock() - session_instance.client.return_value = runtime_client - mock_session.return_value = session_instance + # Verify API call + mock_runtime_client.retrieve.assert_called_once_with( + retrievalQuery={"text": "test query"}, + knowledgeBaseId="kb123", + retrievalConfiguration={ + "vectorSearchConfiguration": {"numberOfResults": 10}, + }, + ) - # Mock retrieve response - runtime_client.retrieve.return_value = {"retrievalResults": []} - # Initialize client +def test_retrieve_with_token(): + """Test retrieve method with pagination token.""" + # Create client client = MemoryServiceClient() - # Call method - result = client.retrieve("kb123", "test query", 10) + # Mock runtime client + mock_runtime_client = MagicMock() + mock_runtime_client.retrieve.return_value = {"retrievalResults": []} - # Verify response - assert result == {"retrievalResults": []} + # Mock the runtime_client property + with patch.object(type(client), "runtime_client", new_callable=PropertyMock) as mock_property: + mock_property.return_value = mock_runtime_client - # Verify API call - runtime_client.retrieve.assert_called_once_with( - retrievalQuery={"text": "test query"}, - knowledgeBaseId="kb123", - retrievalConfiguration={ - "vectorSearchConfiguration": {"numberOfResults": 10}, - }, - ) + # Call method + client.retrieve("kb123", "test query", 10, "token123") + # Verify API call includes token + call_args = mock_runtime_client.retrieve.call_args[1] + assert call_args["nextToken"] == "token123" -@patch("boto3.Session") -def test_retrieve_with_token(mock_session): - """Test retrieve method with pagination token.""" - # Create session mock - session_instance = MagicMock() + +def test_all_methods_with_provided_session(): + """Test that all methods work correctly with a provided session.""" + # Create mock session with both clients + mock_session = MagicMock() + agent_client = MagicMock() runtime_client = MagicMock() - session_instance.client.return_value = runtime_client - mock_session.return_value = session_instance - # Initialize client - client = MemoryServiceClient() + # Configure session to return appropriate client + def get_client(service_name, **kwargs): + if service_name == "bedrock-agent": + return agent_client + elif service_name == "bedrock-agent-runtime": + return runtime_client + + mock_session.client.side_effect = get_client + + # Mock responses + agent_client.list_data_sources.return_value = {"dataSourceSummaries": [{"dataSourceId": "ds123"}]} + agent_client.list_knowledge_base_documents.return_value = {"documents": []} + agent_client.ingest_knowledge_base_documents.return_value = {"status": "success"} + runtime_client.retrieve.return_value = {"retrievalResults": []} + + # Initialize client with session + client = MemoryServiceClient(session=mock_session, region="us-east-1") + + # Mock the properties to return the correct clients + with patch.object(type(client), "agent_client", new_callable=PropertyMock) as mock_agent_prop: + with patch.object(type(client), "runtime_client", new_callable=PropertyMock) as mock_runtime_prop: + mock_agent_prop.return_value = agent_client + mock_runtime_prop.return_value = runtime_client - # Call method - client.retrieve("kb123", "test query", 10, "token123") + # Test various operations + client.list_documents("kb123") + client.store_document("kb123", None, "content", "title") + client.retrieve("kb123", "query") - # Verify API call includes token - call_args = runtime_client.retrieve.call_args[1] - assert call_args["nextToken"] == "token123" + # Verify methods were called + assert agent_client.list_data_sources.called + assert agent_client.list_knowledge_base_documents.called + assert runtime_client.retrieve.called