Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions python/packages/mem0/agent_framework_mem0/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,12 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], *
if not input_text.strip():
return Context(messages=None)

# Build filters from init parameters
filters = self._build_filters()

search_response: MemorySearchResponse_v1_1 | MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
query=input_text,
user_id=self.user_id,
agent_id=self.agent_id,
run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id,
filters=filters,
)

# Depending on the API version, the response schema varies slightly
Expand Down Expand Up @@ -185,6 +186,29 @@ def _validate_filters(self) -> None:
"At least one of the filters: agent_id, user_id, application_id, or thread_id is required."
)

def _build_filters(self) -> dict[str, Any]:
"""Build search filters from initialization parameters.

Returns:
Filter dictionary for mem0 v2 search API containing initialization parameters.
In the v2 API, filters holds the user_id, agent_id, run_id (thread_id), and app_id
(application_id) which are required for scoping memory search operations.
"""
filters: dict[str, Any] = {}

if self.user_id:
filters["user_id"] = self.user_id
if self.agent_id:
filters["agent_id"] = self.agent_id
if self.scope_to_per_operation_thread_id and self._per_operation_thread_id:
filters["run_id"] = self._per_operation_thread_id
elif self.thread_id:
filters["run_id"] = self.thread_id
if self.application_id:
filters["app_id"] = self.application_id

return filters

def _validate_per_operation_thread_id(self, thread_id: str | None) -> None:
"""Validates that a new thread ID doesn't conflict with an existing one when scoped.

Expand Down
91 changes: 87 additions & 4 deletions python/packages/mem0/tests/test_mem0_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def test_model_invoking_single_message(self, mock_mem0_client: AsyncMock)
mock_mem0_client.search.assert_called_once()
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["query"] == "What's the weather?"
assert call_args.kwargs["user_id"] == "user123"
assert call_args.kwargs["filters"] == {"user_id": "user123"}

assert isinstance(context, Context)
expected_instructions = (
Expand Down Expand Up @@ -373,8 +373,7 @@ async def test_model_invoking_with_agent_id(self, mock_mem0_client: AsyncMock) -
await provider.invoking(message)

call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["agent_id"] == "agent123"
assert call_args.kwargs["user_id"] is None
assert call_args.kwargs["filters"] == {"agent_id": "agent123"}

async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
"""Test invoking with scope_to_per_operation_thread_id enabled."""
Expand All @@ -392,7 +391,7 @@ async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_m
await provider.invoking(message)

call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["run_id"] == "operation_thread"
assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "operation_thread"}

async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock) -> None:
"""Test that no memories returns context with None instructions."""
Expand Down Expand Up @@ -510,3 +509,87 @@ def test_validate_per_operation_thread_id_disabled_scope(self, mock_mem0_client:

# Should not raise exception even with different thread ID
provider._validate_per_operation_thread_id("different_thread")


class TestMem0ProviderBuildFilters:
"""Test the _build_filters method."""

def test_build_filters_with_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
"""Test building filters with only user_id."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)

filters = provider._build_filters()
assert filters == {"user_id": "user123"}

def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> None:
"""Test building filters with all initialization parameters."""
provider = Mem0Provider(
user_id="user123",
agent_id="agent456",
thread_id="thread789",
application_id="app999",
mem0_client=mock_mem0_client,
)

filters = provider._build_filters()
assert filters == {
"user_id": "user123",
"agent_id": "agent456",
"run_id": "thread789",
"app_id": "app999",
}

def test_build_filters_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
"""Test that None values are excluded from filters."""
provider = Mem0Provider(
user_id="user123",
agent_id=None,
thread_id=None,
application_id=None,
mem0_client=mock_mem0_client,
)

filters = provider._build_filters()
assert filters == {"user_id": "user123"}
assert "agent_id" not in filters
assert "run_id" not in filters
assert "app_id" not in filters

def test_build_filters_with_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
"""Test that per-operation thread ID takes precedence over base thread_id."""
provider = Mem0Provider(
user_id="user123",
thread_id="base_thread",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
provider._per_operation_thread_id = "operation_thread"

filters = provider._build_filters()
assert filters == {
"user_id": "user123",
"run_id": "operation_thread", # Per-operation thread, not base_thread
}

def test_build_filters_uses_base_thread_when_no_per_operation(self, mock_mem0_client: AsyncMock) -> None:
"""Test that base thread_id is used when per-operation thread is not set."""
provider = Mem0Provider(
user_id="user123",
thread_id="base_thread",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
# _per_operation_thread_id is None

filters = provider._build_filters()
assert filters == {
"user_id": "user123",
"run_id": "base_thread", # Falls back to base thread_id
}

def test_build_filters_returns_empty_dict_when_no_parameters(self, mock_mem0_client: AsyncMock) -> None:
"""Test that _build_filters returns an empty dict when no parameters are set."""
provider = Mem0Provider(mem0_client=mock_mem0_client)

filters = provider._build_filters()
assert filters == {}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ async def main() -> None:
result = await agent.run(query)
print(f"Agent: {result}\n")

# Mem0 processes and indexes memories asynchronously.
# Wait for memories to be indexed before querying in a new thread.
# In production, consider implementing retry logic or using Mem0's
# eventual consistency handling instead of a fixed delay.
print("Waiting for memories to be processed...")
await asyncio.sleep(12) # Empirically determined delay for Mem0 indexing

print("\nRequest within a new thread:")
# Create a new thread for the agent.
# The new thread has no context of the previous conversation.
Expand Down
Loading