Skip to content

Commit f6547fe

Browse files
committed
fix: 处理 embed_info 可能为字典或 EmbedModelInfo 对象的情况,确保获取模型名称和其他属性的兼容性
1 parent 59727d6 commit f6547fe

7 files changed

Lines changed: 404 additions & 19 deletions

File tree

docs/changelog/roadmap.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
## Bugs
99

10-
-
10+
- [x] 修复本地知识库的 metadata 和 向量数据库中不一致的情况。
1111

1212
## Next
1313

src/knowledge/base.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import json
22
import os
3+
import tempfile
4+
import shutil
35
from abc import ABC, abstractmethod
46
from typing import Any
57

@@ -524,6 +526,7 @@ async def retriever(query_text):
524526
def _load_metadata(self):
525527
"""加载元数据"""
526528
meta_file = os.path.join(self.work_dir, f"metadata_{self.kb_type}.json")
529+
527530
if os.path.exists(meta_file):
528531
try:
529532
with open(meta_file, encoding="utf-8") as f:
@@ -533,19 +536,74 @@ def _load_metadata(self):
533536
logger.info(f"Loaded {self.kb_type} metadata for {len(self.databases_meta)} databases")
534537
except Exception as e:
535538
logger.error(f"Failed to load {self.kb_type} metadata: {e}")
539+
# 尝试从备份恢复
540+
backup_file = f"{meta_file}.backup"
541+
if os.path.exists(backup_file):
542+
try:
543+
with open(backup_file, encoding="utf-8") as f:
544+
data = json.load(f)
545+
self.databases_meta = data.get("databases", {})
546+
self.files_meta = data.get("files", {})
547+
logger.info(f"Loaded {self.kb_type} metadata from backup")
548+
# 恢复备份文件
549+
shutil.copy2(backup_file, meta_file)
550+
return
551+
except Exception as backup_e:
552+
logger.error(f"Failed to load backup: {backup_e}")
553+
554+
# 如果加载失败,初始化为空状态
555+
logger.warning(f"Initializing empty {self.kb_type} metadata")
556+
self.databases_meta = {}
557+
self.files_meta = {}
558+
559+
def _serialize_metadata(self, obj):
560+
"""递归序列化元数据中的 Pydantic 模型"""
561+
if hasattr(obj, 'dict'):
562+
return obj.dict()
563+
elif isinstance(obj, dict):
564+
return {k: self._serialize_metadata(v) for k, v in obj.items()}
565+
elif isinstance(obj, list):
566+
return [self._serialize_metadata(item) for item in obj]
567+
else:
568+
return obj
536569

537570
def _save_metadata(self):
538571
"""保存元数据"""
539572
self._normalize_metadata_state()
540573
meta_file = os.path.join(self.work_dir, f"metadata_{self.kb_type}.json")
574+
backup_file = f"{meta_file}.backup"
575+
541576
try:
577+
# 创建简单备份
578+
if os.path.exists(meta_file):
579+
shutil.copy2(meta_file, backup_file)
580+
581+
# 准备数据并序列化 Pydantic 模型
542582
data = {
543-
"databases": self.databases_meta,
544-
"files": self.files_meta,
583+
"databases": self._serialize_metadata(self.databases_meta),
584+
"files": self._serialize_metadata(self.files_meta),
545585
"kb_type": self.kb_type,
546586
"updated_at": utc_isoformat(),
547587
}
548-
with open(meta_file, "w", encoding="utf-8") as f:
549-
json.dump(data, f, ensure_ascii=False, indent=2)
588+
589+
# 原子性写入(使用临时文件)
590+
with tempfile.NamedTemporaryFile(
591+
mode='w', dir=os.path.dirname(meta_file),
592+
prefix='.tmp_', suffix='.json', delete=False
593+
) as tmp_file:
594+
json.dump(data, tmp_file, ensure_ascii=False, indent=2)
595+
temp_path = tmp_file.name
596+
597+
os.replace(temp_path, meta_file)
598+
logger.debug(f"Saved {self.kb_type} metadata")
599+
550600
except Exception as e:
551601
logger.error(f"Failed to save {self.kb_type} metadata: {e}")
602+
# 尝试恢复备份
603+
if os.path.exists(backup_file):
604+
try:
605+
shutil.copy2(backup_file, meta_file)
606+
logger.info("Restored metadata from backup")
607+
except Exception as restore_e:
608+
logger.error(f"Failed to restore backup: {restore_e}")
609+
raise e

src/knowledge/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def _batch_set_embeddings(tx, entity_embedding_pairs):
299299
logger.info(f"Adding entity to {kgdb_name}")
300300
session.execute_write(_create_graph, triples)
301301
logger.info(f"Creating vector index for {kgdb_name} with {config.embed_model}")
302-
session.execute_write(_create_vector_index, cur_embed_info["dimension"])
302+
session.execute_write(_create_vector_index, getattr(cur_embed_info, 'dimension', 1024))
303303

304304
# 收集所有需要处理的实体名称,去重
305305
all_entities = []

src/knowledge/implementations/chroma.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,12 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
7272
logger.info(f"Retrieved existing collection: {collection_name}")
7373

7474
# 检查现有集合的配置是否匹配当前的 embed_info
75-
expected_model = embed_info.get("name") if embed_info else "default"
75+
expected_model = getattr(embed_info, 'name', None) if embed_info else None
76+
if expected_model is None and hasattr(embed_info, 'get'):
77+
expected_model = embed_info.get('name')
78+
elif embed_info and isinstance(embed_info, dict):
79+
expected_model = embed_info.get('name')
80+
expected_model = expected_model or "default"
7681
collection_metadata = collection.metadata or {}
7782
current_model = collection_metadata.get("embedding_model", "unknown")
7883

@@ -88,11 +93,18 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
8893

8994
except Exception:
9095
# 创建新集合
91-
logger.info(f"Creating new collection with embedding model: {embed_info.get('name', 'default')}")
96+
model_name = getattr(embed_info, 'name', None) if embed_info else None
97+
if model_name is None and hasattr(embed_info, 'get'):
98+
model_name = embed_info.get('name')
99+
elif embed_info and isinstance(embed_info, dict):
100+
model_name = embed_info.get('name')
101+
102+
model_name = model_name or 'default'
103+
logger.info(f"Creating new collection with embedding model: {model_name}")
92104
collection_metadata = {
93105
"db_id": db_id,
94106
"created_at": utc_isoformat(),
95-
"embedding_model": embed_info.get("name") if embed_info else "default",
107+
"embedding_model": model_name,
96108
}
97109
collection = self.chroma_client.create_collection(
98110
name=collection_name, embedding_function=embedding_function, metadata=collection_metadata

src/knowledge/implementations/milvus.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
103103

104104
# 检查嵌入模型是否匹配
105105
description = collection.description
106-
expected_model = embed_info.get("name") if embed_info else "default"
106+
expected_model = getattr(embed_info, 'name', 'default') if embed_info else "default"
107107

108108
if expected_model not in description:
109109
logger.warning(f"Collection {collection_name} model mismatch, recreating...")
@@ -116,8 +116,8 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
116116

117117
except Exception:
118118
# 创建新集合
119-
embedding_dim = embed_info.get("dimension", 1024) if embed_info else 1024
120-
model_name = embed_info.get("name", "default") if embed_info else "default"
119+
embedding_dim = getattr(embed_info, 'dimension', 1024) if embed_info else 1024
120+
model_name = getattr(embed_info, 'name', 'default') if embed_info else "default"
121121

122122
# 定义集合Schema
123123
fields = [

0 commit comments

Comments
 (0)