Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions graphiti_core/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@ def with_database(self, database: str) -> 'GraphDriver':
async def build_indices_and_constraints(self, delete_existing: bool = False):
raise NotImplementedError()

async def ensure_edge_type_index(self, edge_type: str) -> None:
"""
Ensure a fulltext index exists for a custom edge type.

This method should be called when custom edge types are used to enable
BM25 fulltext search on those relationship types. The default implementation
is a no-op; drivers that support fulltext indexes should override this.

Args:
edge_type: The relationship type name (e.g., 'SANCTION', 'OWNERSHIP')
"""
pass # Default no-op for drivers that don't support fulltext indexes

def clone(self, database: str) -> 'GraphDriver':
"""Clone the driver with a different database or graph name."""
return self
Expand Down
24 changes: 24 additions & 0 deletions graphiti_core/driver/falkordb_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,30 @@ async def build_indices_and_constraints(self, delete_existing=False):
for query in index_queries:
await self.execute_query(query)

async def ensure_edge_type_index(self, edge_type: str) -> None:
"""
Ensure a fulltext index exists for a custom edge type.

This method creates a fulltext index on (name, fact, group_id) for the
specified edge type if it doesn't already exist. This enables BM25
fulltext search on custom relationship types.

Args:
edge_type: The relationship type name (e.g., 'SANCTION', 'OWNERSHIP')
"""
if edge_type == 'RELATES_TO':
# RELATES_TO index is created by build_indices_and_constraints
return

query = f"""CREATE FULLTEXT INDEX FOR ()-[e:{edge_type}]-() ON (e.name, e.fact, e.group_id)"""
try:
await self.execute_query(query)
logger.info(f'Created fulltext index for edge type: {edge_type}')
except Exception as e:
# Index may already exist
if 'already indexed' not in str(e).lower() and 'already exists' not in str(e).lower():
logger.warning(f'Failed to create fulltext index for {edge_type}: {e}')

def clone(self, database: str) -> 'GraphDriver':
"""
Returns a shallow copy of this driver with a different default database.
Expand Down
1 change: 1 addition & 0 deletions graphiti_core/driver/search_interface/search_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ async def edge_fulltext_search(
search_filter: Any,
group_ids: list[str] | None = None,
limit: int = 100,
edge_types: list[str] | None = None,
) -> list[Any]:
raise NotImplementedError

Expand Down
7 changes: 5 additions & 2 deletions graphiti_core/graph_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,12 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
return f'vector.similarity.cosine({vec1}, {vec2})'


def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
def get_relationships_query(
name: str, limit: int, provider: GraphProvider, edge_type: str | None = None
) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
# Use provided edge_type or fall back to mapping
label = edge_type or NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"

if provider == GraphProvider.KUZU:
Expand Down
10 changes: 10 additions & 0 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,11 @@ async def add_episode_endpoint(episode_data: EpisodeData):
else {('Entity', 'Entity'): []}
)

# Ensure fulltext indexes exist for custom edge types
if edge_types is not None:
for edge_type in edge_types.keys():
await self.driver.ensure_edge_type_index(edge_type)

# Extract and resolve nodes
extracted_nodes = await extract_nodes(
self.clients,
Expand Down Expand Up @@ -905,6 +910,11 @@ async def add_episode_bulk(
else {('Entity', 'Entity'): []}
)

# Ensure fulltext indexes exist for custom edge types
if edge_types is not None:
for edge_type in edge_types.keys():
await self.driver.ensure_edge_type_index(edge_type)

episodes = [
await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
if episode.uuid is not None
Expand Down
16 changes: 11 additions & 5 deletions graphiti_core/prompts/extract_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,16 @@ def edge(context: dict[str, Any]) -> list[Message]:
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
and can be represented as edges in a knowledge graph.
- Facts should include entity names rather than pronouns whenever possible.
- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
of the FACT TYPES
- The FACT TYPES each contain their fact_type_signature which represents the source and target entity types.
## RELATION TYPE SELECTION (CRITICAL - READ CAREFULLY)
Your ONLY valid options for `relation_type` are:
1. One of the EXACT `fact_type_name` values from the FACT TYPES list above
2. RELATES_TO (as fallback when no fact type matches)

STRICT RULES:
- Copy the `fact_type_name` exactly as written (e.g., SPOUSE_OF, BORN_IN, DIRECTED, LOCATED_IN)
- If a relationship doesn't match any FACT TYPE, you MUST use RELATES_TO - no exceptions
- ANY invented type (e.g., NAMED_AFTER, FOUNDED, WORKS_AT, MARRIED_TO, ACQUIRED_BY) will be REJECTED
- Do NOT modify names or add prefixes/suffixes (wrong: PERSON_SPOUSE_OF_PERSON, correct: SPOUSE_OF)

You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.

Expand All @@ -117,7 +123,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
1. **Entity ID Validation**: `source_entity_id` and `target_entity_id` must use only the `id` values from the ENTITIES list provided above.
- **CRITICAL**: Using IDs not in the list will cause the edge to be rejected
2. Each fact must involve two **distinct** entities.
3. Use a SCREAMING_SNAKE_CASE string as the `relation_type` (e.g., FOUNDED, WORKS_AT).
3. **CRITICAL**: The `relation_type` MUST be one of the FACT TYPES above, or RELATES_TO if no match.
4. Do not emit duplicate or semantically redundant facts.
5. The `fact` should closely paraphrase the original source sentence(s). Do not verbatim quote the original text.
6. Use `REFERENCE_TIME` to resolve vague or relative temporal expressions (e.g., "last week").
Expand Down
157 changes: 112 additions & 45 deletions graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,57 @@ async def search(

# if group_ids is empty, set it to None
group_ids = group_ids if group_ids and group_ids != [''] else None
(
(edges, edge_reranker_scores),
(nodes, node_reranker_scores),
(episodes, episode_reranker_scores),
(communities, community_reranker_scores),
) = await semaphore_gather(
edge_search(

# Two-phase search when bfs_origin_node_uuids is not provided:
# 1. First, run node search to find relevant nodes
# 2. Then, run edge search using those nodes as BFS origins
# This ensures edge BFS can traverse from nodes found by BM25/cosine similarity
if bfs_origin_node_uuids is None and config.node_config is not None:
# Phase 1: Find nodes first (in parallel with episode/community search)
(
(nodes, node_reranker_scores),
(episodes, episode_reranker_scores),
(communities, community_reranker_scores),
) = await semaphore_gather(
node_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.node_config,
search_filter,
center_node_uuid,
None,
config.limit,
config.reranker_min_score,
),
episode_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.episode_config,
search_filter,
config.limit,
config.reranker_min_score,
),
community_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.community_config,
config.limit,
config.reranker_min_score,
),
)

# Phase 2: Run edge search with found nodes as BFS origins
node_uuids_for_edge_bfs = [node.uuid for node in nodes]
(edges, edge_reranker_scores) = await edge_search(
driver,
cross_encoder,
query,
Expand All @@ -125,45 +169,66 @@ async def search(
config.edge_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
config.reranker_min_score,
),
node_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.node_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
config.reranker_min_score,
),
episode_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.episode_config,
search_filter,
config.limit,
config.reranker_min_score,
),
community_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.community_config,
node_uuids_for_edge_bfs,
config.limit,
config.reranker_min_score,
),
)
)
else:
# Original parallel search when bfs_origin_node_uuids is provided or no node config
(
(edges, edge_reranker_scores),
(nodes, node_reranker_scores),
(episodes, episode_reranker_scores),
(communities, community_reranker_scores),
) = await semaphore_gather(
edge_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.edge_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
config.reranker_min_score,
),
node_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.node_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
config.reranker_min_score,
),
episode_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.episode_config,
search_filter,
config.limit,
config.reranker_min_score,
),
community_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.community_config,
config.limit,
config.reranker_min_score,
),
)

results = SearchResults(
edges=edges,
Expand Down Expand Up @@ -203,7 +268,9 @@ async def edge_search(
search_tasks = []
if EdgeSearchMethod.bm25 in config.search_methods:
search_tasks.append(
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
edge_fulltext_search(
driver, query, search_filter, group_ids, 2 * limit, config.edge_types
)
)
if EdgeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append(
Expand Down
5 changes: 5 additions & 0 deletions graphiti_core/search/search_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ class EdgeSearchConfig(BaseModel):
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
edge_types: list[str] | None = Field(
default=None,
description='List of edge types to search. If None, defaults to RELATES_TO. '
'Custom edge types must have fulltext indexes created.',
)


class NodeSearchConfig(BaseModel):
Expand Down
Loading