diff --git a/.specs/agent-memory-ontology/design.md b/.specs/agent-memory-ontology/design.md new file mode 100644 index 000000000..9192cf666 --- /dev/null +++ b/.specs/agent-memory-ontology/design.md @@ -0,0 +1,137 @@ +# Design Document: Agent Memory Ontology (Graphiti Service) + +## Overview + +Copilot Chat and Codex both ingest conversation history into Graphiti so agents can recall durable context (preferences, terminology, ownership, project state) across sessions and workspaces. Today, each client has to “decide” what schema to use (entity types, relation types) and how to format episodes. This leads to drift and makes it hard to evolve the memory graph consistently. + +This feature adds a **Graphiti-side ontology registry** so clients can select a named schema (starting with `agent_memory_v1`) and get consistent extraction behavior. It also hardens the Graphiti service’s async ingestion so queued jobs run reliably without blocking the request path. + +### Goals + +- Provide a **single source of truth** for the “agent memory” schema (entity types + relation types) on the Graphiti service. +- Allow clients to opt into the schema via an explicit `schema_id`, and default automatically for `<graphiti_episode>` payloads. +- Keep ingestion **non-blocking** (fast `202 Accepted`) and resilient (job failures don’t stop the worker). +- Keep overhead low: schema should not require per-edge/per-node attribute extraction by default. + +### Non-goals + +- Changing Graphiti core extraction prompts or algorithms (`graphiti_core/*`). +- Enforcing authorization / ACLs based on ownership facts (ownership is modeled, not enforced). +- Mandating a client-side message format beyond the existing `<graphiti_episode>` convention. + +## Current Architecture + +### Graphiti service ingest (today) + +- `POST /messages` accepts `group_id` and a list of message DTOs. +- Each message is enqueued into an in-process async worker queue.
+- Worker calls `graphiti.add_episode(...)` which runs Graphiti core extraction and writes nodes/edges into Neo4j. + +### Pain points + +- **Schema drift:** no service-level mechanism to select/standardize `entity_types`, `edge_types`, or `edge_type_map`. + +- **Async ingest reliability risk:** background jobs must not depend on per-request resources that can be closed once the HTTP request finishes. + +## Proposed Architecture + +### Ontology registry + schema selection + +Introduce a `graph_service.ontologies` module with: + +- `schema_id` strings (start with `agent_memory_v1`) +- `entity_types`, `edge_types`, `edge_type_map`, `excluded_entity_types` for each schema +- a resolver: `resolve_ontology(schema_id: str | None, message_content: str) -> Ontology | None` + +Ingest routing: + +- Extend `AddMessagesRequest` with optional `schema_id`. +- For each message: + - If `request.schema_id` is present, use it. + - Else, if `message.content` contains `<graphiti_episode>`, auto-select `agent_memory_v1`; otherwise apply no schema. + +## Components + +- `server/graph_service/ontologies/agent_memory_v1.py` + - Defines `agent_memory_v1` schema: entity types + edge types (docstring-driven, no fields). +- `server/graph_service/ontologies/registry.py` + - Central registry + resolver helpers. +- `server/graph_service/group_ids.py` + - Canonical group id hashing + resolver for `(scope, key)`. +- `server/graph_service/routers/groups.py` + - `POST /groups/resolve` endpoint returning canonical `group_id`. +- `server/graph_service/dto/ingest.py` + - Adds `schema_id` to `AddMessagesRequest`. +- `server/graph_service/routers/ingest.py` + - Selects ontology per message and passes types/maps into `graphiti.add_episode(...)`. + - Worker resiliency. +- `server/graph_service/main.py` / `server/graph_service/zep_graphiti.py` + - App-scoped Graphiti initialization and dependency injection. + +## Data & Control Flow + +``` +Copilot Chat / Codex + └─ POST /messages (group_id, messages[], schema_id?)
+ └─ enqueue jobs (202 Accepted) + └─ async worker executes sequentially + └─ Graphiti.add_episode(..., entity_types/edge_types/edge_type_map?) + └─ extract nodes + edges (LLM) + └─ write to Neo4j + build embeddings + └─ POST /search (group_ids[], query) + └─ returns relevant edges (“facts”) for recall +``` + +## Integration Points + +- **Clients** can remain unchanged if they already wrap durable/structured memory as `<graphiti_episode>`. + +- **Explicit opt-in**: clients may set `schema_id=agent_memory_v1` in `POST /messages` for deterministic behavior. + +- **Shared group ids**: clients can call `POST /groups/resolve` once per session/workspace/user identity and cache the returned group ids. + +- Docker compose: pass through `OPENAI_BASE_URL`, `MODEL_NAME`, `EMBEDDING_MODEL_NAME` so service behavior matches client/test environments. + +## Migration / Rollout Strategy + +- Backward compatible: `schema_id` is optional. +- Safe default: schema only auto-applies when a message contains `<graphiti_episode>` (or when clients set `schema_id` explicitly). diff --git a/.specs/agent-memory-ontology/tasks.md b/.specs/agent-memory-ontology/tasks.md new file mode 100644 index 000000000..1987535fb --- /dev/null +++ b/.specs/agent-memory-ontology/tasks.md @@ -0,0 +1,29 @@ +# Implementation Plan + +- [x] 1. Add ontology specs _Requirements: 1, 2, 3, 4_ +- [x] 2. Stabilize async ingest lifecycle _Requirements: 3.1, 3.2, 3.3_ + - [x] 2.1 Keep a single Graphiti instance in `app.state` _Requirements: 3.2_ + - [x] 2.2 Start/stop worker in app lifespan _Requirements: 3.1_ + - [x] 2.3 Ensure worker continues after failures _Requirements: 3.3_ +- [x] 3. Implement `agent_memory_v1` ontology registry _Requirements: 2.1, 2.2, 2.3_ + - [x] 3.1 Add `graph_service/ontologies/agent_memory_v1.py` _Requirements: 2.1_ + - [x] 3.2 Add `graph_service/ontologies/registry.py` resolver _Requirements: 1.2, 1.3, 1.4_ +- [x] 4.
Extend ingest API for schema selection _Requirements: 1.1, 1.2, 1.3, 1.4, 1.5_ + - [x] 4.1 Add `schema_id` to `AddMessagesRequest` _Requirements: 1.1_ + - [x] 4.2 Apply schema per message in `POST /messages` _Requirements: 1.2, 1.3, 1.4_ + - [x] 4.3 Validate unknown schema ids _Requirements: 1.5_ +- [x] 5. Add server tests and docs/demo _Requirements: 3.3, 4.1, 4.2_ + - [x] 5.1 Add/adjust pytest discovery so root tests stay isolated _Requirements: 3.3_ + - [x] 5.2 Add demo under `examples/` or `server/README.md` _Requirements: 4.1, 4.2_ +- [x] 6. Verify, commit, push, PR _Requirements: 1, 2, 3, 4_ +- [x] 7. Add canonical group id resolver _Requirements: 5.1, 5.2, 5.3_ + - [x] 7.1 Add group id hashing helper in `graph_service` _Requirements: 5.3_ + - [x] 7.2 Add `POST /groups/resolve` endpoint _Requirements: 5.1, 5.2_ +- [x] 8. Document identity key recommendations _Requirements: 5.4_ + - [x] 8.1 Document GitHub-login based keys in `server/README.md` _Requirements: 5.4_ +- [x] 9. Verify, commit, push, PR (origin) _Requirements: 5_ + +## Current Status Summary + +- Phase: implementation (complete; PR open) +- Next: merge PR and redeploy Graphiti service. 
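The per-message schema selection targeted by tasks 3.2 and 4.2 can be sketched as a pure function. This is an illustrative sketch, not the service code: it assumes the auto-detect marker is the `<graphiti_episode>` wrapper described in the design, and that unknown explicit ids are rejected with a 400 before this point.

```python
# Sketch of the per-message schema-selection precedence.
# Assumption: '<graphiti_episode>' is the wrapper that triggers auto-detection.
KNOWN_SCHEMAS = {'agent_memory_v1'}


def select_schema(schema_id, message_content):
    """Return the schema id to apply for one message, or None for no schema."""
    if schema_id is not None:
        # Explicit opt-in wins (the real endpoint 400s on unknown ids first).
        return schema_id if schema_id in KNOWN_SCHEMAS else None
    if '<graphiti_episode' in message_content:
        # Wrapped payloads default to the agent-memory schema.
        return 'agent_memory_v1'
    return None
```

Messages without an explicit `schema_id` or the wrapper fall through to legacy (schema-less) extraction.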
diff --git a/docker-compose.yml b/docker-compose.yml index 1b5ba06df..3f850767c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -21,6 +21,9 @@ services: condition: service_healthy environment: - OPENAI_API_KEY=${OPENAI_API_KEY} + - OPENAI_BASE_URL=${OPENAI_BASE_URL} + - MODEL_NAME=${MODEL_NAME} + - EMBEDDING_MODEL_NAME=${EMBEDDING_MODEL_NAME} - NEO4J_URI=bolt://neo4j:${NEO4J_PORT:-7687} - NEO4J_USER=${NEO4J_USER:-neo4j} - NEO4J_PASSWORD=${NEO4J_PASSWORD:-password} @@ -80,6 +83,9 @@ services: retries: 3 environment: - OPENAI_API_KEY=${OPENAI_API_KEY} + - OPENAI_BASE_URL=${OPENAI_BASE_URL} + - MODEL_NAME=${MODEL_NAME} + - EMBEDDING_MODEL_NAME=${EMBEDDING_MODEL_NAME} - FALKORDB_HOST=falkordb - FALKORDB_PORT=6379 - FALKORDB_DATABASE=default_db diff --git a/examples/agent_memory_ontology/README.md b/examples/agent_memory_ontology/README.md new file mode 100644 index 000000000..27b14889a --- /dev/null +++ b/examples/agent_memory_ontology/README.md @@ -0,0 +1,49 @@ +# Agent Memory Ontology Demo (`agent_memory_v1`) + +This demo shows how to ingest agent memory (preferences/terminology/etc.) into the Graphiti service using the service-owned `agent_memory_v1` schema, and then retrieve facts via `/search`. + +## Prerequisites + +- Graphiti service running (from repo root): + +```bash +docker compose up -d --build graph neo4j +``` + +- Verify health: + +```bash +curl -sS http://localhost:8000/healthcheck +``` + +## Ingest a memory directive + +```bash +curl -sS http://localhost:8000/messages \ + -H 'content-type: application/json' \ + -d '{ + "group_id": "workspace-demo", + "schema_id": "agent_memory_v1", + "messages": [{ + "role_type": "user", + "role": "user", + "content": "preference (workspace): Keep diffs small and focused.", + "source_description": "demo" + }] + }' +``` + +Ingestion is asynchronous; depending on your LLM/embedding backend, processing may take a little while. 
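Because `POST /messages` returns `202 Accepted` before extraction finishes, a client typically polls `/search` until facts show up. A minimal, transport-agnostic polling helper (hypothetical, not part of the service; `probe` stands in for whatever function performs the `/search` call):

```python
import time


def wait_for_facts(probe, timeout=60.0, interval=2.0):
    """Call `probe()` (e.g. a /search request returning a list of facts)
    until it yields a non-empty result or the timeout elapses."""
    deadline = time.monotonic() + timeout
    while True:
        facts = probe()
        if facts or time.monotonic() >= deadline:
            return facts
        time.sleep(interval)
```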
+ +## Retrieve facts + +```bash +curl -sS http://localhost:8000/search \ + -H 'content-type: application/json' \ + -d '{ + "group_ids": ["workspace-demo"], + "query": "What are my preferences for diffs?", + "max_facts": 5 + }' +``` + diff --git a/pytest.ini b/pytest.ini index 7699537e0..2eacb26fe 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,5 @@ [pytest] +testpaths = tests markers = integration: marks tests as integration tests asyncio_default_fixture_loop_scope = function diff --git a/server/README.md b/server/README.md index 626d2929f..484c9dea3 100644 --- a/server/README.md +++ b/server/README.md @@ -33,6 +33,11 @@ Only stable releases are built automatically (pre-release versions are skipped). ``` OPENAI_API_KEY=your_openai_api_key + # Optional (useful for Azure OpenAI / local gateways) + OPENAI_BASE_URL=https://api.openai.com/v1 + # Optional (defaults depend on graphiti-core) + MODEL_NAME=gpt-4o-mini + EMBEDDING_MODEL_NAME=text-embedding-3-small NEO4J_USER=your_neo4j_user NEO4J_PASSWORD=your_neo4j_password NEO4J_PORT=your_neo4j_port @@ -75,4 +80,72 @@ Only stable releases are built automatically (pre-release versions are skipped). 6. You may access the swagger docs at `http://localhost:8000/docs`. You may also access redocs at `http://localhost:8000/redoc`. -7. You may also access the neo4j browser at `http://localhost:7474` (the port depends on the neo4j instance you are using). \ No newline at end of file +7. You may also access the neo4j browser at `http://localhost:7474` (the port depends on the neo4j instance you are using). + +## Healthcheck + +- `GET /healthcheck` returns a simple JSON payload (`{"status":"healthy"}`) when the service is up. + +## Schema Selection (Agent Memory) + +`POST /messages` supports an optional `schema_id` to select a service-owned ontology. + +### `agent_memory_v1` + +Designed for agent “memory” extraction across tools like Copilot Chat and Codex (ownership, preferences, terminology, tasks). 
+ +If `schema_id` is omitted, the service auto-selects `agent_memory_v1` when a message contains `<graphiti_episode>`. + +Example: + +```bash +curl -sS http://localhost:8000/messages \ + -H 'content-type: application/json' \ + -d '{ + "group_id": "workspace-demo", + "schema_id": "agent_memory_v1", + "messages": [{ + "role_type": "user", + "role": "user", + "content": "terminology (workspace): \"playbook\" means \"runbook docs\"", + "source_description": "demo" + }] + }' +``` + +To retrieve facts: + +```bash +curl -sS http://localhost:8000/search \ + -H 'content-type: application/json' \ + -d '{ + "group_ids": ["workspace-demo"], + "query": "What does playbook mean here?", + "max_facts": 5 + }' +``` + +## Shared Memory Across Tools (Canonical Group IDs) + +To share durable memory across multiple clients (e.g. Copilot Chat + Codex), use the same `group_id` for the same scope. + +The Graphiti service can resolve a canonical `group_id` from a `(scope, key)` pair: + +```bash +curl -sS http://localhost:8000/groups/resolve \ + -H 'content-type: application/json' \ + -d '{ + "scope": "user", + "key": "github_login:yulongbai-nov" + }' +``` + +Recommended key for user scope: + +- Use GitHub login (common across tools): `github_login:<login>` + - CLI: parse from `gh auth status` + - VS Code: use the GitHub auth session account label + +## Troubleshooting + +- If `POST /messages` returns `202` but no episodes/facts appear, ensure you are running a build that keeps a single Graphiti client alive for background jobs (app-scoped client + app-scoped worker). Rebuild/redeploy the container (`docker compose up -d --build`).
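The `/groups/resolve` derivation is a stable hash, so every client that resolves the same `(scope, key)` lands in the same graph partition without any coordination. Mirrored locally, the scheme (matching this PR's `resolve_group_id` helper) looks like:

```python
import hashlib


def canonical_group_id(scope, key, prefix='graphiti'):
    # Same derivation as the service's resolve_group_id helper: sha256 of the
    # stable key, truncated to 32 hex chars, namespaced by prefix and scope.
    digest = hashlib.sha256(key.encode('utf-8')).hexdigest()[:32]
    return f'{prefix}_{scope}_{digest}'
```

A client that cannot reach the endpoint can compute the id itself, provided it keeps the key stable (e.g. `github_login:octocat`).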
diff --git a/server/graph_service/dto/__init__.py b/server/graph_service/dto/__init__.py index 375c9c432..51e7f07dd 100644 --- a/server/graph_service/dto/__init__.py +++ b/server/graph_service/dto/__init__.py @@ -1,4 +1,5 @@ from .common import Message, Result +from .groups import ResolveGroupIdRequest, ResolveGroupIdResponse from .ingest import AddEntityNodeRequest, AddMessagesRequest from .retrieve import FactResult, GetMemoryRequest, GetMemoryResponse, SearchQuery, SearchResults @@ -12,4 +13,6 @@ 'Result', 'GetMemoryRequest', 'GetMemoryResponse', + 'ResolveGroupIdRequest', + 'ResolveGroupIdResponse', ] diff --git a/server/graph_service/dto/groups.py b/server/graph_service/dto/groups.py new file mode 100644 index 000000000..be715b29c --- /dev/null +++ b/server/graph_service/dto/groups.py @@ -0,0 +1,14 @@ +from typing import Literal + +from pydantic import BaseModel, Field + + +class ResolveGroupIdRequest(BaseModel): + scope: Literal['user', 'workspace', 'session'] = Field( + ..., description='The scope to resolve a group id for' + ) + key: str = Field(..., min_length=1, description='Stable key used to derive the group id') + + +class ResolveGroupIdResponse(BaseModel): + group_id: str diff --git a/server/graph_service/dto/ingest.py b/server/graph_service/dto/ingest.py index 9b0159c85..47ccea31c 100644 --- a/server/graph_service/dto/ingest.py +++ b/server/graph_service/dto/ingest.py @@ -5,6 +5,10 @@ class AddMessagesRequest(BaseModel): group_id: str = Field(..., description='The group id of the messages to add') + schema_id: str | None = Field( + default=None, + description='Optional schema id to apply during ingestion (e.g. 
agent_memory_v1)', + ) messages: list[Message] = Field(..., description='The messages to add') diff --git a/server/graph_service/group_ids.py b/server/graph_service/group_ids.py new file mode 100644 index 000000000..447e84f3f --- /dev/null +++ b/server/graph_service/group_ids.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import hashlib +from typing import Literal + +Scope = Literal['user', 'workspace', 'session'] + + +def resolve_group_id(scope: Scope, key: str, prefix: str = 'graphiti') -> str: + if not key: + raise ValueError('key must not be empty') + + digest = hashlib.sha256(key.encode('utf-8')).hexdigest()[:32] + return f'{prefix}_{scope}_{digest}' diff --git a/server/graph_service/main.py b/server/graph_service/main.py index e85638ecc..42e41109a 100644 --- a/server/graph_service/main.py +++ b/server/graph_service/main.py @@ -4,17 +4,21 @@ from fastapi.responses import JSONResponse from graph_service.config import get_settings -from graph_service.routers import ingest, retrieve -from graph_service.zep_graphiti import initialize_graphiti +from graph_service.routers import groups, ingest, retrieve +from graph_service.zep_graphiti import create_graphiti @asynccontextmanager -async def lifespan(_: FastAPI): +async def lifespan(app: FastAPI): settings = get_settings() - await initialize_graphiti(settings) + graphiti = create_graphiti(settings) + await graphiti.build_indices_and_constraints() + app.state.graphiti = graphiti + await ingest.async_worker.start() yield # Shutdown - # No need to close Graphiti here, as it's handled per-request + await ingest.async_worker.stop() + await graphiti.close() app = FastAPI(lifespan=lifespan) @@ -22,6 +26,7 @@ async def lifespan(_: FastAPI): app.include_router(retrieve.router) app.include_router(ingest.router) +app.include_router(groups.router) @app.get('/healthcheck') diff --git a/server/graph_service/ontologies/__init__.py b/server/graph_service/ontologies/__init__.py new file mode 100644 index 
000000000..6822f1a84 --- /dev/null +++ b/server/graph_service/ontologies/__init__.py @@ -0,0 +1,3 @@ +from graph_service.ontologies.registry import Ontology, is_known_schema_id, resolve_ontology + +__all__ = ['Ontology', 'is_known_schema_id', 'resolve_ontology'] diff --git a/server/graph_service/ontologies/agent_memory_v1.py b/server/graph_service/ontologies/agent_memory_v1.py new file mode 100644 index 000000000..949b6ee77 --- /dev/null +++ b/server/graph_service/ontologies/agent_memory_v1.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from collections.abc import Mapping + +from pydantic import BaseModel + +SCHEMA_ID = 'agent_memory_v1' + + +class UserEntity(BaseModel): + """A stable identifier for the human user interacting with an agent. + + Prefer non-PII identifiers (hashed keys or provider ids) such as `github:12345` or `user:hash`. + """ + + +class WorkspaceEntity(BaseModel): + """A stable identifier for the current project/workspace (repo + working directory context). + + Examples: `repo:owner/name`, `workspace:/abs/path`. + """ + + +class SessionEntity(BaseModel): + """A stable identifier for a single interactive session/run. + + Examples: `session:<session-id>`, `chat:<chat-id>`. + """ + + +ENTITY_TYPES: Mapping[str, type[BaseModel]] = { + 'User': UserEntity, + 'Workspace': WorkspaceEntity, + 'Session': SessionEntity, +} + + +class OwnsRelation(BaseModel): + """Use for explicit ownership relations. + + Examples: + - `User owns Workspace` + - `User owns Session` + - `Workspace owns Asset` (repo, folder, file, branch) + """ + + +class PrefersRelation(BaseModel): + """Use for stable preferences and defaults that should influence future agent behavior. + + Examples: + - Output preferences (concise, bullet-first, include diffs, etc.) + - Tooling preferences (`rg` over `grep`, `just` over `make`, etc.)
+ - Workflow preferences (spec-first, small commits, run tests before push) + """ + + +class MeansRelation(BaseModel): + """Use for workspace-specific terminology mapping. + + Examples: + - `'playbook' means 'runbook docs'` + - `'MCP' means 'Model Context Protocol'` (if project-specific) + """ + + +class WorkingOnRelation(BaseModel): + """Use for active goals, tasks, or planned work items. + + Examples: + - `Workspace is working on 'graphiti memory integration'` + - `User is working on 'fix flaky tests'` + """ + + +class DecidedRelation(BaseModel): + """Use for decisions/policies agreed for a repo or project. + + Examples: + - `Workspace decided to keep changes behind feature flags` + - `Workspace decided to use Neo4j for memory backend` + """ + + +class LearnedRelation(BaseModel): + """Use for generalized lessons learned or durable technical insights. + + Examples: + - `User learned that /healthcheck is the Graphiti health endpoint` + - `Workspace learned that background jobs must not capture closed resources` + """ + + +class BlockedByRelation(BaseModel): + """Use for explicit dependency/blocker relationships between tasks or work items.""" + + +EDGE_TYPES: Mapping[str, type[BaseModel]] = { + 'OWNS': OwnsRelation, + 'PREFERS': PrefersRelation, + 'MEANS': MeansRelation, + 'WORKING_ON': WorkingOnRelation, + 'DECIDED': DecidedRelation, + 'LEARNED': LearnedRelation, + 'BLOCKED_BY': BlockedByRelation, +} + + +EDGE_TYPE_MAP: Mapping[tuple[str, str], list[str]] = { + ('Entity', 'Entity'): list(EDGE_TYPES.keys()), +} diff --git a/server/graph_service/ontologies/registry.py b/server/graph_service/ontologies/registry.py new file mode 100644 index 000000000..00e94a249 --- /dev/null +++ b/server/graph_service/ontologies/registry.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from pydantic import BaseModel + +from graph_service.ontologies.agent_memory_v1 import EDGE_TYPE_MAP as 
AGENT_MEMORY_V1_EDGE_TYPE_MAP +from graph_service.ontologies.agent_memory_v1 import EDGE_TYPES as AGENT_MEMORY_V1_EDGE_TYPES +from graph_service.ontologies.agent_memory_v1 import ENTITY_TYPES as AGENT_MEMORY_V1_ENTITY_TYPES +from graph_service.ontologies.agent_memory_v1 import SCHEMA_ID as AGENT_MEMORY_V1_SCHEMA_ID + +SchemaId = Literal['agent_memory_v1'] + + +@dataclass(frozen=True) +class Ontology: + schema_id: SchemaId + entity_types: dict[str, type[BaseModel]] | None = None + excluded_entity_types: list[str] | None = None + edge_types: dict[str, type[BaseModel]] | None = None + edge_type_map: dict[tuple[str, str], list[str]] | None = None + + +_ONTOLOGIES: dict[SchemaId, Ontology] = { + AGENT_MEMORY_V1_SCHEMA_ID: Ontology( + schema_id=AGENT_MEMORY_V1_SCHEMA_ID, + entity_types=dict(AGENT_MEMORY_V1_ENTITY_TYPES), + edge_types=dict(AGENT_MEMORY_V1_EDGE_TYPES), + edge_type_map=dict(AGENT_MEMORY_V1_EDGE_TYPE_MAP), + ), +} + + +def resolve_ontology(schema_id: str | None, message_content: str) -> Ontology | None: + if schema_id is not None: + return _ONTOLOGIES.get(schema_id) # type: ignore[arg-type] + + if '<graphiti_episode' in message_content: + return _ONTOLOGIES[AGENT_MEMORY_V1_SCHEMA_ID] + + return None + + +def is_known_schema_id(schema_id: str) -> bool: + return schema_id in _ONTOLOGIES # type: ignore[operator] diff --git a/server/graph_service/routers/groups.py b/server/graph_service/routers/groups.py new file mode 100644 index 000000000..8b01876dc --- /dev/null +++ b/server/graph_service/routers/groups.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter + +from graph_service.dto import ResolveGroupIdRequest, ResolveGroupIdResponse +from graph_service.group_ids import resolve_group_id + +router = APIRouter() + + +@router.post('/groups/resolve') +async def resolve_group(request: ResolveGroupIdRequest) -> ResolveGroupIdResponse: + return ResolveGroupIdResponse(group_id=resolve_group_id(request.scope, request.key)) diff --git a/server/graph_service/routers/ingest.py b/server/graph_service/routers/ingest.py index d03563105..067cb4352 100644 --- a/server/graph_service/routers/ingest.py +++
b/server/graph_service/routers/ingest.py @@ -1,14 +1,18 @@ import asyncio -from contextlib import asynccontextmanager +import logging from functools import partial +from typing import Any, cast -from fastapi import APIRouter, FastAPI, status +from fastapi import APIRouter, HTTPException, status from graphiti_core.nodes import EpisodeType # type: ignore from graphiti_core.utils.maintenance.graph_data_operations import clear_data # type: ignore from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result +from graph_service.ontologies import is_known_schema_id, resolve_ontology from graph_service.zep_graphiti import ZepGraphitiDep +logger = logging.getLogger(__name__) + class AsyncWorker: def __init__(self): @@ -17,12 +21,17 @@ def __init__(self): async def worker(self): while True: + job = None try: - print(f'Got a job: (size of remaining queue: {self.queue.qsize()})') job = await self.queue.get() await job() except asyncio.CancelledError: break + except Exception: + logger.exception('Graphiti background job failed.') + finally: + if job is not None: + self.queue.task_done() async def start(self): self.task = asyncio.create_task(self.worker()) @@ -37,15 +46,7 @@ async def stop(self): async_worker = AsyncWorker() - -@asynccontextmanager -async def lifespan(_: FastAPI): - await async_worker.start() - yield - await async_worker.stop() - - -router = APIRouter(lifespan=lifespan) +router = APIRouter() @router.post('/messages', status_code=status.HTTP_202_ACCEPTED) @@ -53,7 +54,14 @@ async def add_messages( request: AddMessagesRequest, graphiti: ZepGraphitiDep, ): + if request.schema_id is not None and not is_known_schema_id(request.schema_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f'Unknown schema_id: {request.schema_id}', + ) + async def add_messages_task(m: Message): + ontology = resolve_ontology(request.schema_id, m.content) await graphiti.add_episode( uuid=m.uuid, group_id=request.group_id, @@ -62,6 +70,9 
@@ async def add_messages_task(m: Message): reference_time=m.timestamp, source=EpisodeType.message, source_description=m.source_description, + entity_types=cast(Any, ontology.entity_types) if ontology else None, + edge_types=cast(Any, ontology.edge_types) if ontology else None, + edge_type_map=ontology.edge_type_map if ontology else None, ) for m in request.messages: diff --git a/server/graph_service/zep_graphiti.py b/server/graph_service/zep_graphiti.py index 097c9f391..eb9ff3a87 100644 --- a/server/graph_service/zep_graphiti.py +++ b/server/graph_service/zep_graphiti.py @@ -1,11 +1,17 @@ import logging from typing import Annotated -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, Request from graphiti_core import Graphiti # type: ignore +from graphiti_core.cross_encoder import CrossEncoderClient, OpenAIRerankerClient # type: ignore from graphiti_core.edges import EntityEdge # type: ignore +from graphiti_core.embedder import ( # type: ignore + EmbedderClient, + OpenAIEmbedder, + OpenAIEmbedderConfig, +) from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError, NodeNotFoundError -from graphiti_core.llm_client import LLMClient # type: ignore +from graphiti_core.llm_client import LLMClient, LLMConfig, OpenAIClient # type: ignore from graphiti_core.nodes import EntityNode, EpisodicNode # type: ignore from graph_service.config import ZepEnvDep @@ -15,8 +21,23 @@ class ZepGraphiti(Graphiti): - def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None): - super().__init__(uri, user, password, llm_client) + def __init__( + self, + uri: str, + user: str, + password: str, + llm_client: LLMClient | None = None, + embedder: EmbedderClient | None = None, + cross_encoder: CrossEncoderClient | None = None, + ): + super().__init__( + uri, + user, + password, + llm_client=llm_client, + embedder=embedder, + cross_encoder=cross_encoder, + ) async def save_entity_node(self, name: str, uuid: 
str, group_id: str, summary: str = ''): new_node = EntityNode( @@ -71,32 +92,50 @@ async def delete_episodic_node(self, uuid: str): raise HTTPException(status_code=404, detail=e.message) from e -async def get_graphiti(settings: ZepEnvDep): - client = ZepGraphiti( - uri=settings.neo4j_uri, - user=settings.neo4j_user, - password=settings.neo4j_password, +def build_graphiti_clients( + settings: ZepEnvDep, +) -> tuple[OpenAIClient, OpenAIEmbedder, OpenAIRerankerClient]: + llm_config = LLMConfig( + api_key=settings.openai_api_key, + base_url=settings.openai_base_url, + model=settings.model_name, + ) + llm_client = OpenAIClient(config=llm_config) + + embedding_model = settings.embedding_model_name or 'text-embedding-3-small' + embedder_config = OpenAIEmbedderConfig( + api_key=settings.openai_api_key, + base_url=settings.openai_base_url, + embedding_model=embedding_model, + ) + embedder = OpenAIEmbedder(config=embedder_config) + + cross_encoder_config = LLMConfig( + api_key=settings.openai_api_key, + base_url=settings.openai_base_url, ) - if settings.openai_base_url is not None: - client.llm_client.config.base_url = settings.openai_base_url - if settings.openai_api_key is not None: - client.llm_client.config.api_key = settings.openai_api_key - if settings.model_name is not None: - client.llm_client.model = settings.model_name + cross_encoder = OpenAIRerankerClient(config=cross_encoder_config) - try: - yield client - finally: - await client.close() + return llm_client, embedder, cross_encoder -async def initialize_graphiti(settings: ZepEnvDep): - client = ZepGraphiti( +def create_graphiti(settings: ZepEnvDep) -> ZepGraphiti: + llm_client, embedder, cross_encoder = build_graphiti_clients(settings) + return ZepGraphiti( uri=settings.neo4j_uri, user=settings.neo4j_user, password=settings.neo4j_password, + llm_client=llm_client, + embedder=embedder, + cross_encoder=cross_encoder, ) - await client.build_indices_and_constraints() + + +async def get_graphiti(request: Request) -> 
ZepGraphiti: + graphiti = getattr(request.app.state, 'graphiti', None) + if graphiti is None: + raise HTTPException(status_code=503, detail='Graphiti is not initialized') + return graphiti def get_fact_result_from_edge(edge: EntityEdge): diff --git a/server/tests/conftest.py b/server/tests/conftest.py new file mode 100644 index 000000000..8d607c9b6 --- /dev/null +++ b/server/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +from pathlib import Path + +# Ensure `server/` is on `sys.path` so tests can import `graph_service.*` when running from repo root. +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) diff --git a/server/tests/test_async_worker.py b/server/tests/test_async_worker.py new file mode 100644 index 000000000..3b08763df --- /dev/null +++ b/server/tests/test_async_worker.py @@ -0,0 +1,26 @@ +import asyncio + +import pytest + +from graph_service.routers.ingest import AsyncWorker + + +@pytest.mark.asyncio +async def test_async_worker_continues_after_failure(): + worker = AsyncWorker() + await worker.start() + + ran_second_job = asyncio.Event() + + async def failing_job(): + raise RuntimeError('boom') + + async def ok_job(): + ran_second_job.set() + + await worker.queue.put(failing_job) + await worker.queue.put(ok_job) + + await asyncio.wait_for(ran_second_job.wait(), timeout=2) + + await worker.stop() diff --git a/server/tests/test_group_ids.py b/server/tests/test_group_ids.py new file mode 100644 index 000000000..b117602e7 --- /dev/null +++ b/server/tests/test_group_ids.py @@ -0,0 +1,25 @@ +import re + +import pytest + +from graph_service.group_ids import resolve_group_id + + +def test_resolve_group_id_is_deterministic(): + assert resolve_group_id('user', 'github_login:octocat') == resolve_group_id( + 'user', 'github_login:octocat' + ) + + +def test_resolve_group_id_changes_with_scope(): + assert resolve_group_id('user', 'k') != resolve_group_id('workspace', 'k') + + +def test_resolve_group_id_matches_group_id_charset(): + group_id = 
resolve_group_id('user', 'github_login:octocat') + assert re.match(r'^[a-zA-Z0-9_-]+$', group_id) + + +def test_resolve_group_id_rejects_empty_key(): + with pytest.raises(ValueError): + resolve_group_id('user', '') diff --git a/server/tests/test_groups_router.py b/server/tests/test_groups_router.py new file mode 100644 index 000000000..ae15f93d0 --- /dev/null +++ b/server/tests/test_groups_router.py @@ -0,0 +1,10 @@ +import pytest + +from graph_service.dto import ResolveGroupIdRequest +from graph_service.routers.groups import resolve_group + + +@pytest.mark.asyncio +async def test_resolve_group_returns_group_id(): + response = await resolve_group(ResolveGroupIdRequest(scope='user', key='github_login:octocat')) + assert response.group_id.startswith('graphiti_user_') diff --git a/server/tests/test_ingest_schema_validation.py b/server/tests/test_ingest_schema_validation.py new file mode 100644 index 000000000..eea457176 --- /dev/null +++ b/server/tests/test_ingest_schema_validation.py @@ -0,0 +1,30 @@ +from datetime import datetime, timezone +from unittest.mock import Mock + +import pytest +from fastapi import HTTPException + +from graph_service.dto import AddMessagesRequest, Message +from graph_service.routers.ingest import add_messages + + +@pytest.mark.asyncio +async def test_add_messages_rejects_unknown_schema_id(): + request = AddMessagesRequest( + group_id='test', + schema_id='unknown_schema', + messages=[ + Message( + content='hello', + role_type='user', + role='user', + timestamp=datetime.now(timezone.utc), + source_description='test', + ) + ], + ) + + with pytest.raises(HTTPException) as exc: + await add_messages(request, graphiti=Mock()) + + assert exc.value.status_code == 400 diff --git a/server/tests/test_ontology_registry.py b/server/tests/test_ontology_registry.py new file mode 100644 index 000000000..ac701e2ef --- /dev/null +++ b/server/tests/test_ontology_registry.py @@ -0,0 +1,18 @@ +from graph_service.ontologies import is_known_schema_id, 
resolve_ontology + + +def test_is_known_schema_id(): + assert is_known_schema_id('agent_memory_v1') is True + assert is_known_schema_id('does_not_exist') is False + + +def test_resolve_ontology_auto_detects_graphiti_episode(): + ontology = resolve_ontology( + None, '<graphiti_episode>...</graphiti_episode>' + ) + assert ontology is not None + assert ontology.schema_id == 'agent_memory_v1' + + +def test_resolve_ontology_returns_none_for_plain_message(): + assert resolve_ontology(None, 'plain message') is None diff --git a/server/tests/test_zep_graphiti.py b/server/tests/test_zep_graphiti.py new file mode 100644 index 000000000..c9bc3493e --- /dev/null +++ b/server/tests/test_zep_graphiti.py @@ -0,0 +1,21 @@ +from graph_service.config import Settings +from graph_service.zep_graphiti import build_graphiti_clients + + +def test_build_graphiti_clients_applies_openai_base_url_to_clients(): + settings = Settings( + openai_api_key='test_key', + openai_base_url='http://example.test/v1', + model_name='gpt-4o-mini', + embedding_model_name='text-embedding-3-small', + neo4j_uri='bolt://neo4j:7687', + neo4j_user='neo4j', + neo4j_password='password', + ) + + llm_client, embedder, cross_encoder = build_graphiti_clients(settings) + + assert str(llm_client.client.base_url).rstrip('/') == 'http://example.test/v1' + assert str(embedder.client.base_url).rstrip('/') == 'http://example.test/v1' + assert str(cross_encoder.client.base_url).rstrip('/') == 'http://example.test/v1' + assert embedder.config.embedding_model == 'text-embedding-3-small'