Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 98 additions & 36 deletions src/strands_tools/mem0_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,6 @@ class Mem0ServiceClient:
"max_tokens": int(os.environ.get("MEM0_LLM_MAX_TOKENS", 2000)),
},
},
"vector_store": {
"provider": "opensearch",
"config": {
"port": 443,
"collection_name": os.environ.get("OPENSEARCH_COLLECTION", "mem0"),
"host": os.environ.get("OPENSEARCH_HOST"),
"embedding_model_dims": 1024,
"connection_class": RequestsHttpConnection,
"pool_maxsize": 20,
"use_ssl": True,
"verify_certs": True,
},
},
}

def __init__(self, config: Optional[Dict] = None):
Expand Down Expand Up @@ -204,19 +191,32 @@ def _initialize_client(self, config: Optional[Dict] = None) -> Any:
logger.debug("Using Mem0 Platform backend (MemoryClient)")
return MemoryClient()

if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"):
logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)")
config = self._configure_neptune_analytics_backend(config)
if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER") and os.environ.get("OPENSEARCH_HOST"):
raise RuntimeError("""Conflicting backend configurations:
Only one environment variable of NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER or OPENSEARCH_HOST can be set.""")

# Vector search providers
if os.environ.get("OPENSEARCH_HOST"):
logger.debug("Using OpenSearch backend (Mem0Memory with OpenSearch)")
return self._initialize_opensearch_client(config)
merged_config = self._initialize_opensearch_client(config)

logger.debug("Using FAISS backend (Mem0Memory with FAISS)")
return self._initialize_faiss_client(config)
elif os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"):
logger.debug("Using Neptune Analytics vector backend (Mem0Memory with Neptune Analytics)")
merged_config = self._configure_neptune_analytics_vector_backend(config)

def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) -> Dict:
"""Initialize a Mem0 client with Neptune Analytics graph backend.
else:
logger.debug("Using FAISS backend (Mem0Memory with FAISS)")
merged_config = self._initialize_faiss_client(config)

# Graph backend providers
if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"):
logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)")
merged_config = self._configure_neptune_analytics_graph_backend(merged_config)

return Mem0Memory.from_config(config_dict=merged_config)

def _configure_neptune_analytics_vector_backend(self, config: Optional[Dict] = None) -> Dict:
"""Initialize a Mem0 client with Neptune Analytics vector backend.

Args:
config: Optional configuration dictionary to override defaults.
Expand All @@ -225,13 +225,16 @@ def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) ->
A configuration dict with the graph backend applied.
"""
config = config or {}
config["graph_store"] = {
config["vector_store"] = {
"provider": "neptune",
"config": {"endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}"},
"config": {
"collection_name": os.environ.get("NEPTUNE_ANALYTICS_VECTOR_COLLECTION", "mem0"),
"endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}",
},
}
return config
return self._merge_config(config)

def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Memory:
def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Dict:
"""Initialize a Mem0 client with OpenSearch backend.

Args:
Expand All @@ -240,6 +243,22 @@ def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Me
Returns:
An initialized Mem0Memory instance configured for OpenSearch.
"""
# Add vector portion of the config
config = config or {}
config["vector_store"] = {
"provider": "opensearch",
"config": {
"port": 443,
"collection_name": os.environ.get("OPENSEARCH_COLLECTION", "mem0"),
"host": os.environ.get("OPENSEARCH_HOST"),
"embedding_model_dims": 1024,
"connection_class": RequestsHttpConnection,
"pool_maxsize": 20,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit; this parameter might be os.getenv too, but definitely not a blocker

"use_ssl": True,
"verify_certs": True,
},
}

# Set up AWS region
self.region = os.environ.get("AWS_REGION", "us-west-2")
if not os.environ.get("AWS_REGION"):
Expand All @@ -254,9 +273,9 @@ def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Me
merged_config = self._merge_config(config)
merged_config["vector_store"]["config"].update({"http_auth": auth, "host": os.environ["OPENSEARCH_HOST"]})

return Mem0Memory.from_config(config_dict=merged_config)
return merged_config

def _initialize_faiss_client(self, config: Optional[Dict] = None) -> Mem0Memory:
def _initialize_faiss_client(self, config: Optional[Dict] = None) -> Dict:
"""Initialize a Mem0 client with FAISS backend.

Args:
Expand Down Expand Up @@ -284,8 +303,22 @@ def _initialize_faiss_client(self, config: Optional[Dict] = None) -> Mem0Memory:
"path": "/tmp/mem0_384_faiss",
},
}
return merged_config

return Mem0Memory.from_config(config_dict=merged_config)
def _configure_neptune_analytics_graph_backend(self, config: Dict) -> Dict:
"""Initialize a Mem0 client with Neptune Analytics graph backend.

Args:
config: Configuration dictionary to add graph backend to.

Returns:
An configuration dict with graph backend.
"""
config["graph_store"] = {
"provider": "neptune",
"config": {"endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}"},
}
return config

def _merge_config(self, config: Optional[Dict] = None) -> Dict:
"""Merge user-provided configuration with default configuration.
Expand Down Expand Up @@ -457,13 +490,14 @@ def format_retrieve_response(memories: List[Dict]) -> Panel:
def format_retrieve_graph_response(memories: List[Dict]) -> Panel:
"""Format retrieve response for graph data"""
if not memories:
return Panel("No graph memories found matching the query.",
title="[bold yellow]No Matches", border_style="yellow")
return Panel(
"No graph memories found matching the query.", title="[bold yellow]No Matches", border_style="yellow"
)

table = Table(title="Search Results", show_header=True, header_style="bold magenta")
table.add_column("Source", style="cyan")
table.add_column("Relationship", style="yellow", width=50)
table.add_column("Destination", style="green")
table.add_column("Source", style="cyan", width=25)
table.add_column("Relationship", style="yellow", width=45)
table.add_column("Destination", style="green", width=30)

for memory in memories:
source = memory.get("source", "N/A")
Expand All @@ -481,9 +515,9 @@ def format_list_graph_response(memories: List[Dict]) -> Panel:
return Panel("No graph memories found.", title="[bold yellow]No Memories", border_style="yellow")

table = Table(title="Graph Memories", show_header=True, header_style="bold magenta")
table.add_column("Source", style="cyan")
table.add_column("Relationship", style="yellow", width=50)
table.add_column("Target", style="green")
table.add_column("Source", style="cyan", width=25)
table.add_column("Relationship", style="yellow", width=45)
table.add_column("Target", style="green", width=30)

for memory in memories:
source = memory.get("source", "N/A")
Expand Down Expand Up @@ -544,6 +578,26 @@ def format_store_response(results: List[Dict]) -> Panel:
return Panel(table, title="[bold green]Memory Stored", border_style="green")


def format_store_graph_response(memories: List[Dict]) -> Panel:
    """Format a store response for graph (relationship) data as a rich Panel.

    Args:
        memories: Stored graph relationships. Each entry is expected to be a
            dict with "source", "relationship", and "target" keys; entries
            wrapped in a single-element list (as mem0's "added_entities" can
            return) are also accepted.

    Returns:
        A rich Panel containing a table of stored relationships, or a yellow
        "no memories" panel when the input is empty.
    """
    if not memories:
        return Panel("No graph memories stored.", title="[bold yellow]No Memories Stored", border_style="yellow")

    table = Table(title="Graph Memories Stored", show_header=True, header_style="bold magenta")
    table.add_column("Source", style="cyan", width=25)
    table.add_column("Relationship", style="yellow", width=45)
    table.add_column("Target", style="green", width=30)

    for memory in memories:
        # mem0 may emit each relationship either as a bare dict or nested in a
        # one-element list; unwrap so both shapes render correctly. The bare
        # List[Dict] shape matches this function's annotation and the sibling
        # graph formatters (format_retrieve_graph_response, format_list_graph_response).
        entry = memory[0] if isinstance(memory, list) else memory
        source = entry.get("source", "N/A")
        relationship = entry.get("relationship", "N/A")
        destination = entry.get("target", "N/A")

        table.add_row(source, relationship, destination)

    return Panel(table, title="[bold green]Memories Stored (Graph)", border_style="green")


def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
"""
Memory management tool for storing, retrieving, and managing memories in Mem0.
Expand Down Expand Up @@ -655,6 +709,14 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
if results_list:
panel = format_store_response(results_list)
console.print(panel)

# Process graph relations (If any)
if "relations" in results:
relationships_list = results.get("relations", [])["added_entities"]
results_list.extend(relationships_list)
panel_graph = format_store_graph_response(relationships_list)
console.print(panel_graph)

return ToolResult(
toolUseId=tool_use_id,
status="success",
Expand Down
18 changes: 15 additions & 3 deletions tests/test_mem0.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,25 @@ def test_mem0_service_client_init(mock_opensearch, mock_mem0_memory, mock_sessio
client = Mem0ServiceClient()
assert client.region == os.environ.get("AWS_REGION", "us-west-2")

# Test with optional Graph backend
# Test with conflict scenario
with patch.dict(
os.environ,
{"OPENSEARCH_HOST": "test.opensearch.amazonaws.com", "NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER": "g-5aaaaa1234"},
{
"OPENSEARCH_HOST": "test.opensearch.amazonaws.com",
"NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER": "g-5aaaaa1234",
},
):
with pytest.raises(RuntimeError):
Mem0ServiceClient()

# Test with Neptune Analytics for both vector and graph
with patch.dict(
os.environ,
{
"NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER": "g-5aaaaa1234",
},
):
client = Mem0ServiceClient()
assert client.region == os.environ.get("AWS_REGION", "us-west-2")
assert client.mem0 is not None

# Test with custom config (OpenSearch)
Expand Down