diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py
index c1a355f31..504599aa3 100644
--- a/graphiti_core/driver/driver.py
+++ b/graphiti_core/driver/driver.py
@@ -110,6 +110,19 @@ def with_database(self, database: str) -> 'GraphDriver':
     async def build_indices_and_constraints(self, delete_existing: bool = False):
         raise NotImplementedError()
 
+    async def ensure_edge_type_index(self, edge_type: str) -> None:
+        """
+        Ensure a fulltext index exists for a custom edge type.
+
+        This method should be called when custom edge types are used, to enable
+        BM25 fulltext search on those relationship types. The default
+        implementation is a no-op; drivers that support fulltext indexes should
+        override it.
+
+        Args:
+            edge_type: The relationship type name (e.g., 'SANCTION', 'OWNERSHIP')
+        """
+        pass  # Default no-op for drivers that don't support fulltext indexes
+
     def clone(self, database: str) -> 'GraphDriver':
         """Clone the driver with a different database or graph name."""
         return self
diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py
index de469d53a..7b24d8fb1 100644
--- a/graphiti_core/driver/falkordb_driver.py
+++ b/graphiti_core/driver/falkordb_driver.py
@@ -249,6 +249,30 @@ async def build_indices_and_constraints(self, delete_existing=False):
         for query in index_queries:
             await self.execute_query(query)
 
+    async def ensure_edge_type_index(self, edge_type: str) -> None:
+        """
+        Ensure a fulltext index exists for a custom edge type.
+
+        Creates a fulltext index on (name, fact, group_id) for the specified
+        edge type if it doesn't already exist, enabling BM25 fulltext search
+        on custom relationship types.
+
+        Args:
+            edge_type: The relationship type name (e.g., 'SANCTION', 'OWNERSHIP')
+        """
+        if edge_type == 'RELATES_TO':
+            # The RELATES_TO index is created by build_indices_and_constraints
+            return
+
+        query = f"""CREATE FULLTEXT INDEX FOR ()-[e:{edge_type}]-() ON (e.name, e.fact, e.group_id)"""
+        try:
+            await self.execute_query(query)
+            logger.info(f'Created fulltext index for edge type: {edge_type}')
+        except Exception as e:
+            # The index may already exist
+            if 'already indexed' not in str(e).lower() and 'already exists' not in str(e).lower():
+                logger.warning(f'Failed to create fulltext index for {edge_type}: {e}')
+
     def clone(self, database: str) -> 'GraphDriver':
         """
         Returns a shallow copy of this driver with a different default database.
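
The hook is idempotent from the caller's perspective: drivers without fulltext support inherit the no-op, and the FalkorDB override swallows "already exists" errors. A minimal usage sketch (assuming a local FalkorDB instance; connection details are illustrative):

```python
import asyncio

from graphiti_core.driver.falkordb_driver import FalkorDriver


async def main():
    driver = FalkorDriver(host='localhost', port=6379)
    # The first call issues CREATE FULLTEXT INDEX; repeating it is safe
    # because the override treats "already indexed" errors as success.
    await driver.ensure_edge_type_index('OWNERSHIP')
    await driver.ensure_edge_type_index('OWNERSHIP')


asyncio.run(main())
```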
diff --git a/graphiti_core/driver/search_interface/search_interface.py b/graphiti_core/driver/search_interface/search_interface.py
index 0abf024d5..2393ff985 100644
--- a/graphiti_core/driver/search_interface/search_interface.py
+++ b/graphiti_core/driver/search_interface/search_interface.py
@@ -31,6 +31,7 @@ async def edge_fulltext_search(
         search_filter: Any,
         group_ids: list[str] | None = None,
         limit: int = 100,
+        edge_types: list[str] | None = None,
     ) -> list[Any]:
         raise NotImplementedError
diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py
index 8e4cca4e3..2b5ba7024 100644
--- a/graphiti_core/graph_queries.py
+++ b/graphiti_core/graph_queries.py
@@ -150,9 +150,12 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
         return f'vector.similarity.cosine({vec1}, {vec2})'
 
 
-def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
+def get_relationships_query(
+    name: str, limit: int, provider: GraphProvider, edge_type: str | None = None
+) -> str:
     if provider == GraphProvider.FALKORDB:
-        label = NEO4J_TO_FALKORDB_MAPPING[name]
+        # Use provided edge_type or fall back to mapping
+        label = edge_type or NEO4J_TO_FALKORDB_MAPPING[name]
         return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
 
     if provider == GraphProvider.KUZU:
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index af7d0b344..bb8cf03c0 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -744,6 +744,11 @@ async def add_episode_endpoint(episode_data: EpisodeData):
             else {('Entity', 'Entity'): []}
         )
 
+        # Ensure fulltext indexes exist for custom edge types
+        if edge_types is not None:
+            for edge_type in edge_types.keys():
+                await self.driver.ensure_edge_type_index(edge_type)
+
         # Extract and resolve nodes
         extracted_nodes = await extract_nodes(
             self.clients,
@@ -905,6 +910,11 @@ async def add_episode_bulk(
             else {('Entity', 'Entity'): []}
         )
 
+        # Ensure fulltext indexes exist for custom edge types
+        if edge_types is not None:
+            for edge_type in edge_types.keys():
+                await self.driver.ensure_edge_type_index(edge_type)
+
         episodes = [
             await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
             if episode.uuid is not None
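
With these call sites, index creation is driven entirely by the `edge_types` mapping passed to `add_episode`/`add_episode_bulk`. A hedged sketch of the caller's side (the `edge_types`/`edge_type_map` parameters are the ones referenced in the hunks above; the Pydantic models and connection details are placeholders):

```python
import asyncio
from datetime import datetime, timezone

from pydantic import BaseModel

from graphiti_core import Graphiti


class Ownership(BaseModel):
    """One entity holding a stake in another."""


class Sanction(BaseModel):
    """A sanction imposed on an entity."""


async def main():
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    # Each key below is passed through driver.ensure_edge_type_index(...)
    # before node extraction starts.
    await graphiti.add_episode(
        name='filing',
        episode_body='Acme Corp acquired a 60% stake in Beta LLC.',
        source_description='SEC filing',
        reference_time=datetime.now(timezone.utc),
        edge_types={'OWNERSHIP': Ownership, 'SANCTION': Sanction},
        edge_type_map={('Entity', 'Entity'): ['OWNERSHIP', 'SANCTION']},
    )


asyncio.run(main())
```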
diff --git a/graphiti_core/prompts/extract_edges.py b/graphiti_core/prompts/extract_edges.py
index 6f653c3c5..b7e42ac63 100644
--- a/graphiti_core/prompts/extract_edges.py
+++ b/graphiti_core/prompts/extract_edges.py
@@ -102,10 +102,16 @@ def edge(context: dict[str, Any]) -> list[Message]:
 - are clearly stated or unambiguously implied in the CURRENT MESSAGE,
   and can be represented as edges in a knowledge graph.
 - Facts should include entity names rather than pronouns whenever possible.
-- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
-- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
-  of the FACT TYPES
-- The FACT TYPES each contain their fact_type_signature which represents the source and target entity types.
+## RELATION TYPE SELECTION (CRITICAL - READ CAREFULLY)
+Your ONLY valid options for `relation_type` are:
+1. One of the EXACT `fact_type_name` values from the FACT TYPES list above
+2. RELATES_TO (as a fallback when no fact type matches)
+
+STRICT RULES:
+- Copy the `fact_type_name` exactly as written (e.g., SPOUSE_OF, BORN_IN, DIRECTED, LOCATED_IN)
+- If a relationship doesn't match any FACT TYPE, you MUST use RELATES_TO - no exceptions
+- ANY invented type (e.g., NAMED_AFTER, FOUNDED, WORKS_AT, MARRIED_TO, ACQUIRED_BY) will be REJECTED
+- Do NOT modify names or add prefixes/suffixes (wrong: PERSON_SPOUSE_OF_PERSON; correct: SPOUSE_OF)
 
 You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
 
@@ -117,7 +123,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
 1. **Entity ID Validation**: `source_entity_id` and `target_entity_id` must use only the `id` values from the ENTITIES list provided above.
    - **CRITICAL**: Using IDs not in the list will cause the edge to be rejected
 2. Each fact must involve two **distinct** entities.
-3. Use a SCREAMING_SNAKE_CASE string as the `relation_type` (e.g., FOUNDED, WORKS_AT).
+3. **CRITICAL**: The `relation_type` MUST be one of the FACT TYPES above, or RELATES_TO if none matches.
 4. Do not emit duplicate or semantically redundant facts.
 5. The `fact` should closely paraphrase the original source sentence(s). Do not verbatim quote the original text.
 6. Use `REFERENCE_TIME` to resolve vague or relative temporal expressions (e.g., "last week").
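
The prompt can only ask for compliance; the hard guarantee lands in `resolve_extracted_edges` (see the `edge_operations.py` hunk below). A toy mirror of the rule the prompt states, useful for unit tests (names are illustrative):

```python
FACT_TYPES = {'SPOUSE_OF', 'BORN_IN', 'DIRECTED', 'LOCATED_IN'}  # schema shown in the prompt
DEFAULT_EDGE_NAME = 'RELATES_TO'


def normalize_relation_type(relation_type: str) -> str:
    """Coerce anything outside the schema back to the default edge name."""
    if relation_type in FACT_TYPES or relation_type == DEFAULT_EDGE_NAME:
        return relation_type
    return DEFAULT_EDGE_NAME


assert normalize_relation_type('SPOUSE_OF') == 'SPOUSE_OF'
assert normalize_relation_type('MARRIED_TO') == 'RELATES_TO'  # invented type is coerced
```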
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index af98f560a..8b87aee2e 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -110,13 +110,57 @@ async def search(
     # if group_ids is empty, set it to None
     group_ids = group_ids if group_ids and group_ids != [''] else None
 
-    (
-        (edges, edge_reranker_scores),
-        (nodes, node_reranker_scores),
-        (episodes, episode_reranker_scores),
-        (communities, community_reranker_scores),
-    ) = await semaphore_gather(
-        edge_search(
+    # Two-phase search when bfs_origin_node_uuids is not provided:
+    #   1. First, run node search to find relevant nodes
+    #   2. Then, run edge search using those nodes as BFS origins
+    # This ensures edge BFS can traverse from nodes found by BM25/cosine similarity
+    if bfs_origin_node_uuids is None and config.node_config is not None:
+        # Phase 1: Find nodes first (in parallel with episode/community search)
+        (
+            (nodes, node_reranker_scores),
+            (episodes, episode_reranker_scores),
+            (communities, community_reranker_scores),
+        ) = await semaphore_gather(
+            node_search(
+                driver,
+                cross_encoder,
+                query,
+                search_vector,
+                group_ids,
+                config.node_config,
+                search_filter,
+                center_node_uuid,
+                None,
+                config.limit,
+                config.reranker_min_score,
+            ),
+            episode_search(
+                driver,
+                cross_encoder,
+                query,
+                search_vector,
+                group_ids,
+                config.episode_config,
+                search_filter,
+                config.limit,
+                config.reranker_min_score,
+            ),
+            community_search(
+                driver,
+                cross_encoder,
+                query,
+                search_vector,
+                group_ids,
+                config.community_config,
+                config.limit,
+                config.reranker_min_score,
+            ),
+        )
+
+        # Phase 2: Run edge search with found nodes as BFS origins
+        node_uuids_for_edge_bfs = [node.uuid for node in nodes]
+        (edges, edge_reranker_scores) = await edge_search(
             driver,
             cross_encoder,
             query,
@@ -125,45 +169,66 @@ async def search(
             config.edge_config,
             search_filter,
             center_node_uuid,
-            bfs_origin_node_uuids,
-            config.limit,
-            config.reranker_min_score,
-        ),
-        node_search(
-            driver,
-            cross_encoder,
-            query,
-            search_vector,
-            group_ids,
-            config.node_config,
-            search_filter,
-            center_node_uuid,
-            bfs_origin_node_uuids,
-            config.limit,
-            config.reranker_min_score,
-        ),
-        episode_search(
-            driver,
-            cross_encoder,
-            query,
-            search_vector,
-            group_ids,
-            config.episode_config,
-            search_filter,
-            config.limit,
-            config.reranker_min_score,
-        ),
-        community_search(
-            driver,
-            cross_encoder,
-            query,
-            search_vector,
-            group_ids,
-            config.community_config,
+            node_uuids_for_edge_bfs,
             config.limit,
             config.reranker_min_score,
-        ),
-    )
+        )
+    else:
+        # Original parallel search when bfs_origin_node_uuids is provided or no node config
+        (
+            (edges, edge_reranker_scores),
+            (nodes, node_reranker_scores),
+            (episodes, episode_reranker_scores),
+            (communities, community_reranker_scores),
+        ) = await semaphore_gather(
+            edge_search(
+                driver,
+                cross_encoder,
+                query,
+                search_vector,
+                group_ids,
+                config.edge_config,
+                search_filter,
+                center_node_uuid,
+                bfs_origin_node_uuids,
+                config.limit,
+                config.reranker_min_score,
+            ),
+            node_search(
+                driver,
+                cross_encoder,
+                query,
+                search_vector,
+                group_ids,
+                config.node_config,
+                search_filter,
+                center_node_uuid,
+                bfs_origin_node_uuids,
+                config.limit,
+                config.reranker_min_score,
+            ),
+            episode_search(
+                driver,
+                cross_encoder,
+                query,
+                search_vector,
+                group_ids,
+                config.episode_config,
+                search_filter,
+                config.limit,
+                config.reranker_min_score,
+            ),
+            community_search(
+                driver,
+                cross_encoder,
+                query,
+                search_vector,
+                group_ids,
+                config.community_config,
+                config.limit,
+                config.reranker_min_score,
+            ),
+        )
 
     results = SearchResults(
         edges=edges,
@@ -203,7 +268,9 @@ async def edge_search(
     search_tasks = []
     if EdgeSearchMethod.bm25 in config.search_methods:
         search_tasks.append(
-            edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
+            edge_fulltext_search(
+                driver, query, search_filter, group_ids, 2 * limit, config.edge_types
+            )
         )
     if EdgeSearchMethod.cosine_similarity in config.search_methods:
         search_tasks.append(
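
The control flow above reduces to "await the node phase, then seed the edge phase with its output." A self-contained sketch of that ordering with stand-in coroutines (`find_nodes`/`find_edges` are hypothetical):

```python
import asyncio


async def find_nodes(query: str) -> list[str]:
    # Phase 1: stand-in for BM25/cosine-similarity node search
    return ['node-uuid-1', 'node-uuid-2']


async def find_edges(query: str, bfs_origins: list[str]) -> list[str]:
    # Phase 2: stand-in for edge search seeded with BFS origins
    return [f'edge-near-{uuid}' for uuid in bfs_origins]


async def two_phase(query: str) -> list[str]:
    nodes = await find_nodes(query)  # must complete before edge BFS can start
    return await find_edges(query, bfs_origins=nodes)


print(asyncio.run(two_phase('acme ownership')))
```

The trade-off is one extra round of latency for edge results in exchange for an edge BFS that can actually reach the relevant neighborhood.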
diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py
index 7e5714f5f..c825fd648 100644
--- a/graphiti_core/search/search_config.py
+++ b/graphiti_core/search/search_config.py
@@ -83,6 +83,11 @@ class EdgeSearchConfig(BaseModel):
     sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
     mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
     bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
+    edge_types: list[str] | None = Field(
+        default=None,
+        description='List of edge types to search. If None, defaults to RELATES_TO. '
+        'Custom edge types must have fulltext indexes created.',
+    )
 
 
 class NodeSearchConfig(BaseModel):
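
A sketch of how a caller might opt into the new field (assuming `SearchConfig` and the `EdgeSearchMethod` members referenced in `search.py` live in this module; treat the exact import paths as assumptions):

```python
from graphiti_core.search.search_config import (
    EdgeSearchConfig,
    EdgeSearchMethod,
    SearchConfig,
)

config = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        # Each listed type needs a fulltext index, e.g. via ensure_edge_type_index.
        edge_types=['OWNERSHIP', 'SANCTION'],
    ),
)
```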
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 4c0e84fa7..448a373e6 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -173,10 +173,11 @@ async def edge_fulltext_search(
     search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
+    edge_types: list[str] | None = None,
 ) -> list[EntityEdge]:
     if driver.search_interface:
         return await driver.search_interface.edge_fulltext_search(
-            driver, query, search_filter, group_ids, limit
+            driver, query, search_filter, group_ids, limit, edge_types
         )
 
     # fulltext search over facts
@@ -185,15 +186,9 @@ async def edge_fulltext_search(
     if fuzzy_query == '':
         return []
 
-    match_query = """
-    YIELD relationship AS rel, score
-    MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
-    """
-    if driver.provider == GraphProvider.KUZU:
-        match_query = """
-        YIELD node, score
-        MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
-        """
+    # Default to RELATES_TO if no edge types are specified
+    if edge_types is None:
+        edge_types = ['RELATES_TO']
 
     filter_queries, filter_params = edge_search_filter_query_constructor(
         search_filter, driver.provider
@@ -219,8 +214,8 @@ async def edge_fulltext_search(
             """
             UNWIND $ids as id
             MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
-            WHERE e.group_id IN $group_ids
-            AND id(e)=id
+            WHERE e.group_id IN $group_ids
+              AND id(e)=id
             """
             + filter_query
             + """
@@ -253,7 +248,68 @@ async def edge_fulltext_search(
             )
         else:
             return []
+    elif driver.provider == GraphProvider.FALKORDB:
+        # For FalkorDB, query each edge type's fulltext index and combine the results
+        all_records: list[Any] = []
+        for edge_type in edge_types:
+            match_query = f"""
+            YIELD relationship AS rel, score
+            MATCH (n:Entity)-[e:{edge_type} {{uuid: rel.uuid}}]->(m:Entity)
+            """
+
+            query_str = (
+                get_relationships_query(
+                    'edge_name_and_fact', limit=limit, provider=driver.provider, edge_type=edge_type
+                )
+                + match_query
+                + filter_query
+                + """
+                WITH e, score, n, m
+                RETURN
+                """
+                + get_entity_edge_return_query(driver.provider)
+                + """
+                ORDER BY score DESC
+                LIMIT $limit
+                """
+            )
+
+            try:
+                records, _, _ = await driver.execute_query(
+                    query_str,
+                    query=fuzzy_query,
+                    limit=limit,
+                    routing_='r',
+                    **filter_params,
+                )
+                all_records.extend(records)
+            except Exception as e:
+                # The index may not exist for this edge type - skip it silently
+                logger.debug(f'Fulltext search skipped for edge type {edge_type}: {e}')
+                continue
+
+        # Dedupe by uuid; each per-type query is already ordered by score
+        seen_uuids: set[str] = set()
+        unique_records = []
+        for record in all_records:
+            uuid = record.get('uuid') or record[0]
+            if uuid not in seen_uuids:
+                seen_uuids.add(uuid)
+                unique_records.append(record)
+
+        records = unique_records[:limit]
     else:
+        # For other providers (Neo4j, Kuzu), keep the existing behavior
+        match_query = """
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
+        """
+        if driver.provider == GraphProvider.KUZU:
+            match_query = """
+            YIELD node, score
+            MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
+            """
+
         query = (
             get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
             + match_query
@@ -495,13 +551,16 @@ async def edge_bfs_search(
             records.extend(sub_records)
     else:
         if driver.provider == GraphProvider.NEPTUNE:
+            # Use wildcard traversal to support custom edge types (LOCATED_IN, MEMBER_OF, etc.).
+            # The Entity-to-Entity match naturally excludes MENTIONS edges (which connect Episodic to Entity).
+            # NOTE: we traverse in BOTH directions (-[*1..N]-) to find edges regardless of direction.
             query = (
                 f"""
                 UNWIND $bfs_origin_node_uuids AS origin_uuid
-                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
+                MATCH path = (origin {{uuid: origin_uuid}})-[*1..{bfs_max_depth}]-(n:Entity)
                 WHERE origin:Entity OR origin:Episodic
                 UNWIND relationships(path) AS rel
-                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
+                MATCH (n:Entity)-[e {{uuid: rel.uuid}}]-(m:Entity)
                 """
                 + filter_query
                 + """
@@ -522,12 +581,15 @@ async def edge_bfs_search(
                 """
             )
         else:
+            # Use wildcard traversal to support custom edge types (LOCATED_IN, MEMBER_OF, etc.).
+            # The Entity-to-Entity match naturally excludes MENTIONS edges (which connect Episodic to Entity).
+            # NOTE: we traverse in BOTH directions (-[*1..N]-) to find edges regardless of direction.
             query = (
                 f"""
                 UNWIND $bfs_origin_node_uuids AS origin_uuid
-                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
+                MATCH path = (origin {{uuid: origin_uuid}})-[*1..{bfs_max_depth}]-(:Entity)
                 UNWIND relationships(path) AS rel
-                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
+                MATCH (n:Entity)-[e {{uuid: rel.uuid}}]-(m:Entity)
                 """
                 + filter_query
                 + """
@@ -788,10 +850,11 @@ async def node_bfs_search(
     if filter_queries:
         filter_query = ' AND ' + (' AND '.join(filter_queries))
 
+    # Use wildcard traversal to support custom edge types (LOCATED_IN, MEMBER_OF, etc.)
     match_queries = [
         f"""
         UNWIND $bfs_origin_node_uuids AS origin_uuid
-        MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
+        MATCH (origin {{uuid: origin_uuid}})-[*1..{bfs_max_depth}]->(n:Entity)
         WHERE n.group_id = origin.group_id
         """
     ]
@@ -800,7 +863,7 @@ async def node_bfs_search(
         match_queries = [
             f"""
             UNWIND $bfs_origin_node_uuids AS origin_uuid
-            MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
+            MATCH (origin {{uuid: origin_uuid}})-[*1..{bfs_max_depth}]->(n:Entity)
             WHERE origin:Entity OR origin.Episode
               AND n.group_id = origin.group_id
             """
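
Because each per-type query is independently `ORDER BY score DESC`, the merge step only needs first-wins deduplication. A stand-alone version of that step (the record shape is an assumption):

```python
def dedupe_by_uuid(records: list[dict]) -> list[dict]:
    """Keep the first record seen for each uuid, preserving input order."""
    seen: set[str] = set()
    out: list[dict] = []
    for rec in records:
        uuid = rec.get('uuid')
        if uuid and uuid not in seen:
            seen.add(uuid)
            out.append(rec)
    return out


merged = dedupe_by_uuid(
    [
        {'uuid': 'a', 'score': 0.9},
        {'uuid': 'a', 'score': 0.4},  # duplicate, dropped
        {'uuid': 'b', 'score': 0.7},
    ]
)
assert [r['uuid'] for r in merged] == ['a', 'b']
```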
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index 4a861b1b1..0ae8497de 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -181,6 +181,14 @@ async def add_nodes_and_edges_bulk_tx(
         if driver.provider == GraphProvider.KUZU:
             attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
             entity_data['attributes'] = json.dumps(attributes)
+        elif driver.provider == GraphProvider.FALKORDB:
+            # FalkorDB needs complex types JSON-serialized
+            if node.attributes:
+                for k, v in node.attributes.items():
+                    if isinstance(v, (dict, list)):
+                        entity_data[k] = json.dumps(v)
+                    else:
+                        entity_data[k] = v
         else:
             entity_data.update(node.attributes or {})
 
@@ -208,6 +216,14 @@ async def add_nodes_and_edges_bulk_tx(
         if driver.provider == GraphProvider.KUZU:
             attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
             edge_data['attributes'] = json.dumps(attributes)
+        elif driver.provider == GraphProvider.FALKORDB:
+            # FalkorDB needs complex types JSON-serialized
+            if edge.attributes:
+                for k, v in edge.attributes.items():
+                    if isinstance(v, (dict, list)):
+                        edge_data[k] = json.dumps(v)
+                    else:
+                        edge_data[k] = v
         else:
             edge_data.update(edge.attributes or {})
 
@@ -245,10 +261,35 @@ async def add_nodes_and_edges_bulk_tx(
         get_episodic_edge_save_bulk_query(driver.provider),
         episodic_edges=[edge.model_dump() for edge in episodic_edges],
     )
-    await tx.run(
-        get_entity_edge_save_bulk_query(driver.provider),
-        entity_edges=edges,
-    )
+    # FalkorDB: group edges by type and run separate queries to support custom edge types
+    if driver.provider == GraphProvider.FALKORDB:
+        from collections import defaultdict
+
+        edges_by_type: dict[str, list] = defaultdict(list)
+        for edge in edges:
+            edge_type = edge.get('name', 'RELATES_TO') or 'RELATES_TO'
+            edges_by_type[edge_type].append(edge)
+
+        for edge_type, typed_edges in edges_by_type.items():
+            # Sanitize the edge type to prevent injection
+            safe_edge_type = ''.join(c for c in edge_type if c.isalnum() or c == '_')
+            if not safe_edge_type:
+                safe_edge_type = 'RELATES_TO'
+            query = f"""
+            UNWIND $entity_edges AS edge
+            MATCH (source:Entity {{uuid: edge.source_node_uuid}})
+            MATCH (target:Entity {{uuid: edge.target_node_uuid}})
+            MERGE (source)-[r:{safe_edge_type} {{uuid: edge.uuid}}]->(target)
+            SET r = edge
+            SET r.fact_embedding = vecf32(edge.fact_embedding)
+            RETURN edge.uuid AS uuid
+            """
+            await tx.run(query, entity_edges=typed_edges)
+    else:
+        await tx.run(
+            get_entity_edge_save_bulk_query(driver.provider),
+            entity_edges=edges,
+        )
 
 
 async def extract_nodes_and_edges_bulk(
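
The grouping-plus-sanitization step is worth testing in isolation, since the label is interpolated directly into the Cypher string. A stand-alone version (edge dicts are simplified):

```python
from collections import defaultdict


def group_edges_by_safe_type(edges: list[dict]) -> dict[str, list[dict]]:
    """Group edges by name, keeping only [A-Za-z0-9_] characters in the label."""
    grouped: dict[str, list[dict]] = defaultdict(list)
    for edge in edges:
        name = edge.get('name') or 'RELATES_TO'
        safe = ''.join(c for c in name if c.isalnum() or c == '_') or 'RELATES_TO'
        grouped[safe].append(edge)
    return grouped


edges = [{'name': 'OWNERSHIP'}, {'name': 'DROP GRAPH;--'}, {'name': None}]
print(sorted(group_edges_by_safe_type(edges)))
# ['DROPGRAPH', 'OWNERSHIP', 'RELATES_TO'] - the injection attempt is neutered
```

Note that `str.isalnum()` also accepts non-ASCII letters; a stricter allowlist regex would be needed if labels must stay ASCII.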
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 41ce31d1d..140870666 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -339,6 +339,15 @@ async def resolve_extracted_edges(
     for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
         allowed_type_names = set(extracted_edge_types)
         is_custom_name = extracted_edge.name in custom_type_names
+        is_default_name = extracted_edge.name == DEFAULT_EDGE_NAME
+
+        # If custom types are defined, enforce strict type checking
+        if custom_type_names and not is_custom_name and not is_default_name:
+            # The LLM invented a type not in the schema - convert it to RELATES_TO
+            logger.debug(f'Edge type {extracted_edge.name} not in schema, converting to {DEFAULT_EDGE_NAME}')
+            extracted_edge.name = DEFAULT_EDGE_NAME
+            continue
+
         if not allowed_type_names:
             # No custom types are valid for this node pairing. Keep LLM generated
             # labels, but flip disallowed custom names back to the default.
diff --git a/mcp_server/config.yaml b/mcp_server/config.yaml
new file mode 100644
index 000000000..ba71a021f
--- /dev/null
+++ b/mcp_server/config.yaml
@@ -0,0 +1,51 @@
+# Graphiti MCP Server Configuration for FalkorDB
+
+server:
+  transport: "stdio"  # stdio for Claude Desktop
+  host: "127.0.0.1"
+  port: 8000
+
+llm:
+  provider: "openai"
+  model: "gpt-4o-mini"
+  max_tokens: 4096
+
+  providers:
+    openai:
+      api_key: ${OPENAI_API_KEY}
+      api_url: "https://api.openai.com/v1"
+
+embedder:
+  provider: "openai"
+  model: "text-embedding-3-small"
+  dimensions: 1536
+
+  providers:
+    openai:
+      api_key: ${OPENAI_API_KEY}
+      api_url: "https://api.openai.com/v1"
+
+database:
+  provider: "falkordb"
+
+  providers:
+    falkordb:
+      uri: "redis://localhost:6379"
+      username: "knowledge"
+      password: "knowledgeG3#"
+      database: "aviation-mode-hybrid"
+
+graphiti:
+  group_id: "\\_"
+  user_id: "claude_desktop"
+  entity_types:
+    - name: "Aircraft"
+      description: "Aircraft entity with registration and type information"
+    - name: "Occurrence"
+      description: "Aviation incident or occurrence"
+    - name: "Airport"
+      description: "Airport or aerodrome locations"
+    - name: "Airline"
+      description: "Airline operators"
+    - name: "Manufacturer"
+      description: "Aircraft manufacturers"
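
The `${OPENAI_API_KEY}` references assume the config loader expands environment variables. A minimal sketch of that expansion (the real loader lives in `config.schema.GraphitiConfig` and is not shown in this diff):

```python
import os
import re


def expand_env(value: str) -> str:
    """Replace ${VAR} with the value of VAR, or '' if unset."""
    return re.sub(r'\$\{(\w+)\}', lambda m: os.environ.get(m.group(1), ''), value)


os.environ['OPENAI_API_KEY'] = 'sk-test'
assert expand_env('api_key: ${OPENAI_API_KEY}') == 'api_key: sk-test'
```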
+ """ + + @staticmethod + def create_config(config: DatabaseConfig) -> dict: + """Create database configuration dictionary based on the configured provider.""" + provider = config.provider.lower() + + match provider: + case 'neo4j': + # Use Neo4j config if provided, otherwise use defaults + if config.providers.neo4j: + neo4j_config = config.providers.neo4j + else: + # Create default Neo4j configuration + from config.schema import Neo4jProviderConfig + + neo4j_config = Neo4jProviderConfig() + + # Check for environment variable overrides (for CI/CD compatibility) + import os + + uri = os.environ.get('NEO4J_URI', neo4j_config.uri) + username = os.environ.get('NEO4J_USER', neo4j_config.username) + password = os.environ.get('NEO4J_PASSWORD', neo4j_config.password) + + return { + 'uri': uri, + 'user': username, + 'password': password, + # Note: database and use_parallel_runtime would need to be passed + # to the driver after initialization if supported + } + + case 'falkordb': + if not HAS_FALKOR: + raise ValueError( + 'FalkorDB driver not available in current graphiti-core version' + ) + + # Use FalkorDB config if provided, otherwise use defaults + if config.providers.falkordb: + falkor_config = config.providers.falkordb + else: + # Create default FalkorDB configuration + from config.schema import FalkorDBProviderConfig + + falkor_config = FalkorDBProviderConfig() + + # Check for environment variable overrides (for CI/CD compatibility) + import os + from urllib.parse import urlparse + + uri = os.environ.get('FALKORDB_URI', falkor_config.uri) + password = os.environ.get('FALKORDB_PASSWORD', falkor_config.password) + + # Parse the URI to extract host and port + parsed = urlparse(uri) + host = parsed.hostname or 'localhost' + port = parsed.port or 6379 + + return { + 'driver': 'falkordb', + 'host': host, + 'port': port, + 'password': password, + 'database': falkor_config.database, + } + + case _: + raise ValueError(f'Unsupported Database provider: {provider}') diff --git a/mcp_server/src/services/factories.py.bak3 b/mcp_server/src/services/factories.py.bak3 new file mode 100644 index 000000000..030a8e37b --- /dev/null +++ b/mcp_server/src/services/factories.py.bak3 @@ -0,0 +1,439 @@ +"""Factory classes for creating LLM, Embedder, and Database clients.""" + +from openai import AsyncAzureOpenAI + +from config.schema import ( + DatabaseConfig, + EmbedderConfig, + LLMConfig, +) + +# Try to import FalkorDriver if available +try: + from graphiti_core.driver.falkordb_driver import FalkorDriver # noqa: F401 + + HAS_FALKOR = True +except ImportError: + HAS_FALKOR = False + +# Kuzu support removed - FalkorDB is now the default +from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder +from graphiti_core.llm_client import LLMClient, OpenAIClient +from graphiti_core.llm_client.config import LLMConfig as GraphitiLLMConfig + +# Try to import additional providers if available +try: + from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient + + HAS_AZURE_EMBEDDER = True +except ImportError: + HAS_AZURE_EMBEDDER = False + +try: + from graphiti_core.embedder.gemini import GeminiEmbedder + + HAS_GEMINI_EMBEDDER = True +except ImportError: + HAS_GEMINI_EMBEDDER = False + +try: + from graphiti_core.embedder.voyage import VoyageAIEmbedder + + HAS_VOYAGE_EMBEDDER = True +except ImportError: + HAS_VOYAGE_EMBEDDER = False + +try: + from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient + + HAS_AZURE_LLM = True +except ImportError: + HAS_AZURE_LLM = False + +try: + from 
+ """ + + @staticmethod + def create_config(config: DatabaseConfig) -> dict: + """Create database configuration dictionary based on the configured provider.""" + provider = config.provider.lower() + + match provider: + case 'neo4j': + # Use Neo4j config if provided, otherwise use defaults + if config.providers.neo4j: + neo4j_config = config.providers.neo4j + else: + # Create default Neo4j configuration + from config.schema import Neo4jProviderConfig + + neo4j_config = Neo4jProviderConfig() + + # Check for environment variable overrides (for CI/CD compatibility) + import os + + uri = os.environ.get('NEO4J_URI', neo4j_config.uri) + username = os.environ.get('NEO4J_USER', neo4j_config.username) + password = os.environ.get('NEO4J_PASSWORD', neo4j_config.password) + username = os.environ.get("FALKORDB_USERNAME", falkor_config.username) + + return { + 'uri': uri, + 'user': username, + 'password': password, + # Note: database and use_parallel_runtime would need to be passed + # to the driver after initialization if supported + } + + case 'falkordb': + if not HAS_FALKOR: + raise ValueError( + 'FalkorDB driver not available in current graphiti-core version' + ) + + # Use FalkorDB config if provided, otherwise use defaults + if config.providers.falkordb: + falkor_config = config.providers.falkordb + else: + # Create default FalkorDB configuration + from config.schema import FalkorDBProviderConfig + + falkor_config = FalkorDBProviderConfig() + + # Check for environment variable overrides (for CI/CD compatibility) + import os + from urllib.parse import urlparse + + uri = os.environ.get('FALKORDB_URI', falkor_config.uri) + password = os.environ.get('FALKORDB_PASSWORD', falkor_config.password) + username = os.environ.get("FALKORDB_USERNAME", falkor_config.username) + + # Parse the URI to extract host and port + parsed = urlparse(uri) + host = parsed.hostname or 'localhost' + port = parsed.port or 6379 + + return { + 'driver': 'falkordb', + 'host': host, + 'port': port, + 'password': password, + 'database': falkor_config.database, + } + + case _: + raise ValueError(f'Unsupported Database provider: {provider}') diff --git a/mcp_server/start_mcp_server.sh b/mcp_server/start_mcp_server.sh new file mode 100755 index 000000000..f88f16b60 --- /dev/null +++ b/mcp_server/start_mcp_server.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# Start Graphiti MCP Server + +# Set OpenAI API key (replace with your actual key) +export OPENAI_API_KEY="your-openai-api-key-here" + +# Start the server +cd ~/graphiti/mcp_server +uv run src/graphiti_mcp_server/server.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..d4839a6b1 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests package diff --git a/tests/test_graphiti_mock.py b/tests/test_graphiti_mock.py index a30c93921..d1f192d7f 100644 --- a/tests/test_graphiti_mock.py +++ b/tests/test_graphiti_mock.py @@ -1015,6 +1015,95 @@ async def test_edge_fulltext_search( assert edges[0].name == entity_edge_1.name +@pytest.mark.asyncio +async def test_edge_fulltext_search_custom_edge_types( + graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client +): + """Test edge_fulltext_search with custom edge types parameter.""" + if graph_driver.provider != GraphProvider.FALKORDB: + pytest.skip('Custom edge types only supported for FalkorDB') + + graphiti = Graphiti( + graph_driver=graph_driver, + llm_client=mock_llm_client, + embedder=mock_embedder, + cross_encoder=mock_cross_encoder_client, + ) + await graphiti.build_indices_and_constraints() + 
+    # Create entity nodes (using names that mock_embedder recognizes)
+    entity_node_1 = EntityNode(
+        name='test_entity_1',
+        labels=[],
+        created_at=datetime.now(),
+        group_id=group_id,
+    )
+    await entity_node_1.generate_name_embedding(mock_embedder)
+    entity_node_2 = EntityNode(
+        name='test_entity_2',
+        labels=[],
+        created_at=datetime.now(),
+        group_id=group_id,
+    )
+    await entity_node_2.generate_name_embedding(mock_embedder)
+
+    # Create entity edge with RELATES_TO (standard type)
+    # Using predefined fact that mock_embedder recognizes
+    entity_edge_1 = EntityEdge(
+        source_node_uuid=entity_node_1.uuid,
+        target_node_uuid=entity_node_2.uuid,
+        name='RELATES_TO',
+        fact='test_entity_1 relates to test_entity_2',
+        created_at=datetime.now(),
+        group_id=group_id,
+    )
+    await entity_edge_1.generate_embedding(mock_embedder)
+
+    # Save the graph
+    await entity_node_1.save(graph_driver)
+    await entity_node_2.save(graph_driver)
+    await entity_edge_1.save(graph_driver)
+
+    search_filters = SearchFilters(node_labels=['Entity'], edge_types=['RELATES_TO'])
+
+    # Test with explicit edge_types parameter (should find the edge)
+    edges = await edge_fulltext_search(
+        graph_driver,
+        'test_entity_1 relates to test_entity_2',
+        search_filters,
+        group_ids=[group_id],
+        edge_types=['RELATES_TO'],
+    )
+    assert len(edges) >= 1
+    assert any(e.name == 'RELATES_TO' for e in edges)
+
+    # Test with non-existent edge type (should find nothing)
+    edges = await edge_fulltext_search(
+        graph_driver,
+        'test_entity_1 relates to test_entity_2',
+        search_filters,
+        group_ids=[group_id],
+        edge_types=['NONEXISTENT_TYPE'],
+    )
+    assert len(edges) == 0
+
+
+@pytest.mark.asyncio
+async def test_ensure_edge_type_index(graph_driver):
+    """Test ensure_edge_type_index creates fulltext indexes for custom edge types."""
+    if graph_driver.provider != GraphProvider.FALKORDB:
+        pytest.skip('ensure_edge_type_index only implemented for FalkorDB')
+
+    # Test creating index for a custom edge type (should succeed)
+    await graph_driver.ensure_edge_type_index('CUSTOM_TYPE')
+
+    # Test calling it again is idempotent (should not fail)
+    await graph_driver.ensure_edge_type_index('CUSTOM_TYPE')
+
+    # Test that RELATES_TO is a no-op (index already exists from build_indices_and_constraints)
+    await graph_driver.ensure_edge_type_index('RELATES_TO')
+
+
 @pytest.mark.asyncio
 async def test_edge_similarity_search(graph_driver, mock_embedder):
     if graph_driver.provider == GraphProvider.FALKORDB:
diff --git a/tests/utils/maintenance/test_edge_operations.py b/tests/utils/maintenance/test_edge_operations.py
index b5d01e541..cb87cb137 100644
--- a/tests/utils/maintenance/test_edge_operations.py
+++ b/tests/utils/maintenance/test_edge_operations.py
@@ -237,7 +237,11 @@ async def immediate_gather(*aws, max_coroutines=None):
 
 
 @pytest.mark.asyncio
-async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
+async def test_resolve_extracted_edges_converts_unknown_names_to_default(monkeypatch):
+    """When custom edge_types are defined, unknown edge names should be converted to RELATES_TO.
+
+    This ensures strict schema enforcement - the LLM cannot invent arbitrary edge types.
+ """ from graphiti_core.utils.maintenance import edge_operations as edge_ops monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None)) @@ -312,7 +316,8 @@ async def immediate_gather(*aws, max_coroutines=None): edge_type_map, ) - assert resolved_edges[0].name == 'INTERACTED_WITH' + # Unknown edge types are converted to RELATES_TO when custom edge_types are defined + assert resolved_edges[0].name == DEFAULT_EDGE_NAME assert invalidated_edges == []