diff --git a/src/strands_tools/mem0_memory.py b/src/strands_tools/mem0_memory.py index b001d0f6..4dfff380 100644 --- a/src/strands_tools/mem0_memory.py +++ b/src/strands_tools/mem0_memory.py @@ -162,19 +162,6 @@ class Mem0ServiceClient: "max_tokens": int(os.environ.get("MEM0_LLM_MAX_TOKENS", 2000)), }, }, - "vector_store": { - "provider": "opensearch", - "config": { - "port": 443, - "collection_name": os.environ.get("OPENSEARCH_COLLECTION", "mem0"), - "host": os.environ.get("OPENSEARCH_HOST"), - "embedding_model_dims": 1024, - "connection_class": RequestsHttpConnection, - "pool_maxsize": 20, - "use_ssl": True, - "verify_certs": True, - }, - }, } def __init__(self, config: Optional[Dict] = None): @@ -204,19 +191,32 @@ def _initialize_client(self, config: Optional[Dict] = None) -> Any: logger.debug("Using Mem0 Platform backend (MemoryClient)") return MemoryClient() - if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"): - logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)") - config = self._configure_neptune_analytics_backend(config) + if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER") and os.environ.get("OPENSEARCH_HOST"): + raise RuntimeError("""Conflicting backend configurations: + Only one environment variable of NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER or OPENSEARCH_HOST can be set.""") + # Vector search providers if os.environ.get("OPENSEARCH_HOST"): logger.debug("Using OpenSearch backend (Mem0Memory with OpenSearch)") - return self._initialize_opensearch_client(config) + merged_config = self._initialize_opensearch_client(config) - logger.debug("Using FAISS backend (Mem0Memory with FAISS)") - return self._initialize_faiss_client(config) + elif os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"): + logger.debug("Using Neptune Analytics vector backend (Mem0Memory with Neptune Analytics)") + merged_config = self._configure_neptune_analytics_vector_backend(config) - def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) -> Dict: - """Initialize a Mem0 client with Neptune Analytics graph backend. + else: + logger.debug("Using FAISS backend (Mem0Memory with FAISS)") + merged_config = self._initialize_faiss_client(config) + + # Graph backend providers + if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"): + logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)") + merged_config = self._configure_neptune_analytics_graph_backend(merged_config) + + return Mem0Memory.from_config(config_dict=merged_config) + + def _configure_neptune_analytics_vector_backend(self, config: Optional[Dict] = None) -> Dict: + """Initialize a Mem0 client with Neptune Analytics vector backend. Args: config: Optional configuration dictionary to override defaults. @@ -225,13 +225,16 @@ def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) -> An configuration dict with graph backend. """ config = config or {} - config["graph_store"] = { + config["vector_store"] = { "provider": "neptune", - "config": {"endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}"}, + "config": { + "collection_name": os.environ.get("NEPTUNE_ANALYTICS_VECTOR_COLLECTION", "mem0"), + "endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}", + }, } - return config + return self._merge_config(config) - def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Memory: + def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Dict: """Initialize a Mem0 client with OpenSearch backend. Args: @@ -240,6 +243,22 @@ def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Me Returns: An initialized Mem0Memory instance configured for OpenSearch. """ + # Add vector portion of the config + config = config or {} + config["vector_store"] = { + "provider": "opensearch", + "config": { + "port": 443, + "collection_name": os.environ.get("OPENSEARCH_COLLECTION", "mem0"), + "host": os.environ.get("OPENSEARCH_HOST"), + "embedding_model_dims": 1024, + "connection_class": RequestsHttpConnection, + "pool_maxsize": 20, + "use_ssl": True, + "verify_certs": True, + }, + } + # Set up AWS region self.region = os.environ.get("AWS_REGION", "us-west-2") if not os.environ.get("AWS_REGION"): @@ -254,9 +273,9 @@ def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Me merged_config = self._merge_config(config) merged_config["vector_store"]["config"].update({"http_auth": auth, "host": os.environ["OPENSEARCH_HOST"]}) - return Mem0Memory.from_config(config_dict=merged_config) + return merged_config - def _initialize_faiss_client(self, config: Optional[Dict] = None) -> Mem0Memory: + def _initialize_faiss_client(self, config: Optional[Dict] = None) -> Dict: """Initialize a Mem0 client with FAISS backend. Args: @@ -284,8 +303,22 @@ def _initialize_faiss_client(self, config: Optional[Dict] = None) -> Mem0Memory: "path": "/tmp/mem0_384_faiss", }, } + return merged_config - return Mem0Memory.from_config(config_dict=merged_config) + def _configure_neptune_analytics_graph_backend(self, config: Dict) -> Dict: + """Initialize a Mem0 client with Neptune Analytics graph backend. + + Args: + config: Configuration dictionary to add graph backend to. + + Returns: + An configuration dict with graph backend. + """ + config["graph_store"] = { + "provider": "neptune", + "config": {"endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}"}, + } + return config def _merge_config(self, config: Optional[Dict] = None) -> Dict: """Merge user-provided configuration with default configuration. @@ -457,13 +490,14 @@ def format_retrieve_response(memories: List[Dict]) -> Panel: def format_retrieve_graph_response(memories: List[Dict]) -> Panel: """Format retrieve response for graph data""" if not memories: - return Panel("No graph memories found matching the query.", - title="[bold yellow]No Matches", border_style="yellow") + return Panel( + "No graph memories found matching the query.", title="[bold yellow]No Matches", border_style="yellow" + ) table = Table(title="Search Results", show_header=True, header_style="bold magenta") - table.add_column("Source", style="cyan") - table.add_column("Relationship", style="yellow", width=50) - table.add_column("Destination", style="green") + table.add_column("Source", style="cyan", width=25) + table.add_column("Relationship", style="yellow", width=45) + table.add_column("Destination", style="green", width=30) for memory in memories: source = memory.get("source", "N/A") @@ -481,9 +515,9 @@ def format_list_graph_response(memories: List[Dict]) -> Panel: return Panel("No graph memories found.", title="[bold yellow]No Memories", border_style="yellow") table = Table(title="Graph Memories", show_header=True, header_style="bold magenta") - table.add_column("Source", style="cyan") - table.add_column("Relationship", style="yellow", width=50) - table.add_column("Target", style="green") + table.add_column("Source", style="cyan", width=25) + table.add_column("Relationship", style="yellow", width=45) + table.add_column("Target", style="green", width=30) for memory in memories: source = memory.get("source", "N/A") @@ -544,6 +578,26 @@ def format_store_response(results: List[Dict]) -> Panel: return Panel(table, title="[bold green]Memory Stored", border_style="green") +def format_store_graph_response(memories: List[Dict]) -> Panel: + """Format store response for graph data""" + if not memories: + return Panel("No graph memories stored.", title="[bold yellow]No Memories Stored", border_style="yellow") + + table = Table(title="Graph Memories Stored", show_header=True, header_style="bold magenta") + table.add_column("Source", style="cyan", width=25) + table.add_column("Relationship", style="yellow", width=45) + table.add_column("Target", style="green", width=30) + + for memory in memories: + source = memory[0].get("source", "N/A") + relationship = memory[0].get("relationship", "N/A") + destination = memory[0].get("target", "N/A") + + table.add_row(source, relationship, destination) + + return Panel(table, title="[bold green]Memories Stored (Graph)", border_style="green") + + def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: """ Memory management tool for storing, retrieving, and managing memories in Mem0. @@ -655,6 +709,14 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: if results_list: panel = format_store_response(results_list) console.print(panel) + + # Process graph relations (If any) + if "relations" in results: + relationships_list = results.get("relations", [])["added_entities"] + results_list.extend(relationships_list) + panel_graph = format_store_graph_response(relationships_list) + console.print(panel_graph) + return ToolResult( toolUseId=tool_use_id, status="success", diff --git a/tests/test_mem0.py b/tests/test_mem0.py index be836788..9b806d1e 100644 --- a/tests/test_mem0.py +++ b/tests/test_mem0.py @@ -424,13 +424,25 @@ def test_mem0_service_client_init(mock_opensearch, mock_mem0_memory, mock_sessio client = Mem0ServiceClient() assert client.region == os.environ.get("AWS_REGION", "us-west-2") - # Test with optional Graph backend + # Test with conflict scenario with patch.dict( os.environ, - {"OPENSEARCH_HOST": "test.opensearch.amazonaws.com", "NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER": "g-5aaaaa1234"}, + { + "OPENSEARCH_HOST": "test.opensearch.amazonaws.com", + "NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER": "g-5aaaaa1234", + }, + ): + with pytest.raises(RuntimeError): + Mem0ServiceClient() + + # Test with Neptune Analytics for both vector and graph + with patch.dict( + os.environ, + { + "NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER": "g-5aaaaa1234", + }, ): client = Mem0ServiceClient() - assert client.region == os.environ.get("AWS_REGION", "us-west-2") assert client.mem0 is not None # Test with custom config (OpenSearch)