Skip to content

Commit 2ef4f91

Browse files
authored
Merge pull request #117 from xerrors/112-kg-retrieve-error
fix:检索前检查是否存在 entityEmbeddings 这个索引
2 parents f07a8fe + b432c1b commit 2ef4f91

3 files changed

Lines changed: 23 additions & 3 deletions

File tree

docker/docker-compose.dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ services:
6060
- NEO4J_AUTH=neo4j/0123456789
6161
- NEO4J_server_bolt_listen__address=0.0.0.0:7687
6262
- NEO4J_server_http_listen__address=0.0.0.0:7474
63+
- ENTITY_EMBEDDING=true
6364
networks:
6465
- app-network
6566

src/core/graphbase.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,19 @@ def query_node(self, entity_name, hops=2, **kwargs):
218218

219219
def query_by_vector(self, entity_name, threshold=0.9, kgdb_name='neo4j', hops=2, num_of_res=5):
220220
self.use_database(kgdb_name)
221+
def _index_exists(tx, index_name):
222+
"""检查索引是否存在"""
223+
result = tx.run("SHOW INDEXES")
224+
for record in result:
225+
if record["name"] == index_name:
226+
return True
227+
return False
228+
221229
def query(tx, text):
230+
# 首先检查索引是否存在
231+
if not _index_exists(tx, "entityEmbeddings"):
232+
raise Exception("向量索引不存在,请先创建索引")
233+
222234
embedding = self.get_embedding(text)
223235
result = tx.run("""
224236
CALL db.index.vector.queryNodes('entityEmbeddings', 10, $embedding)
@@ -227,9 +239,14 @@ def query(tx, text):
227239
""", embedding=embedding)
228240
return result.values()
229241

230-
with self.driver.session() as session:
231-
results = session.execute_read(query, entity_name)
232-
242+
try:
243+
with self.driver.session() as session:
244+
results = session.execute_read(query, entity_name)
245+
except Exception as e:
246+
if "向量索引不存在" in str(e):
247+
logger.error(f"向量索引不存在,请先创建索引: {e}, {traceback.format_exc()}")
248+
return []
249+
raise e
233250

234251
# 筛选出分数高于阈值的实体
235252
qualified_entities = [result[0] for result in results[:num_of_res] if result[1] > threshold]

src/core/retriever.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def query_graph(self, query, history, refs):
7676
results = []
7777
if refs["meta"].get("use_graph") and config.enable_knowledge_base:
7878
for entity in refs["entities"]:
79+
if entity == "":
80+
continue
7981
result = graph_base.query_by_vector(entity)
8082
if result != []:
8183
results.extend(result)

0 commit comments

Comments
 (0)