Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions python/packages/mem0/agent_framework_mem0/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def __init__(
user_id: The user ID for scoping memories or None.
scope_to_per_operation_thread_id: Whether to scope memories to per-operation thread ID.
context_prompt: The prompt to prepend to retrieved memories.

Note:
Currently, filters are set at initialization time via user_id, agent_id, thread_id,
and application_id. Run-level filtering support is planned for a future release.
"""
should_close_client = False
if mem0_client is None:
Expand Down Expand Up @@ -150,11 +154,12 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], *
if not input_text.strip():
return Context(messages=None)

# Build filters from init parameters
filters = self._build_filters()

search_response: MemorySearchResponse_v1_1 | MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
query=input_text,
user_id=self.user_id,
agent_id=self.agent_id,
run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id,
filters=filters,
)

# Depending on the API version, the response schema varies slightly
Expand Down Expand Up @@ -185,6 +190,29 @@ def _validate_filters(self) -> None:
"At least one of the filters: agent_id, user_id, application_id, or thread_id is required."
)

def _build_filters(self) -> dict[str, Any]:
"""Build search filters from initialization parameters.

Returns:
Filter dictionary for mem0 v2 search API containing initialization parameters.
In the v2 API, filters holds the user_id, agent_id, run_id (thread_id), and app_id
(application_id) which are required for scoping memory search operations.
"""
filters: dict[str, Any] = {}

if self.user_id:
filters["user_id"] = self.user_id
if self.agent_id:
filters["agent_id"] = self.agent_id
if self.scope_to_per_operation_thread_id and self._per_operation_thread_id:
filters["run_id"] = self._per_operation_thread_id
elif self.thread_id:
filters["run_id"] = self.thread_id
if self.application_id:
filters["app_id"] = self.application_id

return filters

def _validate_per_operation_thread_id(self, thread_id: str | None) -> None:
"""Validates that a new thread ID doesn't conflict with an existing one when scoped.

Expand Down
29 changes: 25 additions & 4 deletions python/packages/mem0/tests/test_mem0_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def test_model_invoking_single_message(self, mock_mem0_client: AsyncMock)
mock_mem0_client.search.assert_called_once()
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["query"] == "What's the weather?"
assert call_args.kwargs["user_id"] == "user123"
assert call_args.kwargs["filters"] == {"user_id": "user123"}

assert isinstance(context, Context)
expected_instructions = (
Expand Down Expand Up @@ -373,8 +373,7 @@ async def test_model_invoking_with_agent_id(self, mock_mem0_client: AsyncMock) -
await provider.invoking(message)

call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["agent_id"] == "agent123"
assert call_args.kwargs["user_id"] is None
assert call_args.kwargs["filters"] == {"agent_id": "agent123"}

async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
"""Test invoking with scope_to_per_operation_thread_id enabled."""
Expand All @@ -392,7 +391,7 @@ async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_m
await provider.invoking(message)

call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["run_id"] == "operation_thread"
assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "operation_thread"}

async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock) -> None:
"""Test that no memories returns context with None instructions."""
Expand Down Expand Up @@ -510,3 +509,25 @@ def test_validate_per_operation_thread_id_disabled_scope(self, mock_mem0_client:

# Should not raise exception even with different thread ID
provider._validate_per_operation_thread_id("different_thread")


class TestMem0ProviderBuildFilters:
"""Test the _build_filters method."""

def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> None:
"""Test building filters with all initialization parameters."""
provider = Mem0Provider(
user_id="user123",
agent_id="agent456",
thread_id="thread789",
application_id="app999",
mem0_client=mock_mem0_client,
)

filters = provider._build_filters()
assert filters == {
"user_id": "user123",
"agent_id": "agent456",
"run_id": "thread789",
"app_id": "app999",
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ async def main() -> None:
result = await agent.run(query)
print(f"Agent: {result}\n")

# Mem0 processes and indexes memories asynchronously.
# Wait for memories to be indexed before querying in a new thread.
# In production, consider implementing retry logic or using Mem0's
# eventual consistency handling instead of a fixed delay.
print("Waiting for memories to be processed...")
await asyncio.sleep(12) # Empirically determined delay for Mem0 indexing

print("\nRequest within a new thread:")
# Create a new thread for the agent.
# The new thread has no context of the previous conversation.
Expand Down
Loading