Skip to content

Commit 1249344

Browse files
authored
Merge pull request #523 from holenzzz/fix/embedding-400
Fix embedding 400 errors
2 parents: ee50666 + 9cd6470 · commit 1249344

4 files changed

Lines changed: 27 additions & 12 deletions

File tree

src/knowledge/implementations/milvus.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -191,13 +191,13 @@ def _get_async_embedding(self, embed_info: dict):
191 191       def _get_async_embedding_function(self, embed_info: dict):
192 192           """获取 embedding 函数"""
193 193           embedding_model = self._get_async_embedding(embed_info)
194     -         return partial(embedding_model.abatch_encode, batch_size=40)
    194 +         return partial(embedding_model.abatch_encode, batch_size=10)
195 195
196 196       def _get_embedding_function(self, embed_info: dict):
197 197           """获取 embedding 函数"""
198 198           embedding_model = self._get_async_embedding(embed_info)
199 199
200     -         return partial(embedding_model.batch_encode, batch_size=40)
    200 +         return partial(embedding_model.batch_encode, batch_size=10)
201 201
202 202       async def _get_milvus_collection(self, db_id: str):
203 203           """获取或创建 Milvus 集合"""

src/knowledge/services/upload_graph_service.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -488,7 +488,7 @@ def load_graph_info(self):
488 488               logger.error(f"加载图数据库信息失败:{e}")
489 489               return False
490 490
491     -     async def aget_embedding(self, text, batch_size=40):
    491 +     async def aget_embedding(self, text, batch_size=10):
492 492           if isinstance(text, list):
493 493               outputs = await self.embed_model.abatch_encode(text, batch_size=batch_size)
494 494               return outputs

src/models/embed.py

Lines changed: 23 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ async def aencode_queries(self, queries: list[str] | str) -> list[list[float]]:
 46  46           """等同于aencode"""
 47  47           return await self.aencode(queries)
 48  48
 49     -     def batch_encode(self, messages: list[str], batch_size: int = 40) -> list[list[float]]:
     49 +     def batch_encode(self, messages: list[str], batch_size: int = 10) -> list[list[float]]:
 50  50           # logger.info(f"Batch encoding {len(messages)} messages")
 51  51           data = []
 52  52           task_id = None
@@ -67,24 +67,39 @@ def batch_encode(self, messages: list[str], batch_size: int = 40) -> list[list[f
 67  67
 68  68           return data
 69  69
 70     -     async def abatch_encode(self, messages: list[str], batch_size: int = 40) -> list[list[float]]:
     70 +     async def abatch_encode(self, messages: list[str], batch_size: int = 10) -> list[list[float]]:
 71  71           data = []
 72  72           task_id = None
 73  73           if len(messages) > batch_size:
 74  74               task_id = hashstr(messages)
 75  75               self.embed_state[task_id] = {"status": "in-progress", "total": len(messages), "progress": 0}
 76  76
 77     -         tasks = []
     77 +         #保留原有逻辑:
     78 +         #使用 asyncio.gather 并发执行所有 embedding 批次请求:
     79 +         # tasks = []
     80 +         # for i in range(0, len(messages), batch_size):
     81 +         # group_msg = messages[i : i + batch_size]
     82 +         # tasks.append(self.aencode(group_msg))
     83 +
     84 +         # results = await asyncio.gather(*tasks)
     85 +         # for res in results:
     86 +         # data.extend(res)
     87 +
     88 +         # if task_id:
     89 +         # self.embed_state[task_id]["progress"] = len(messages)
     90 +         # self.embed_state[task_id]["status"] = "completed"
     91 +
     92 +         # return data
     93 +
 78  94           for i in range(0, len(messages), batch_size):
 79  95               group_msg = messages[i : i + batch_size]
 80     -             tasks.append(self.aencode(group_msg))
 81     -
 82     -         results = await asyncio.gather(*tasks)
 83     -         for res in results:
     96 +             logger.info(f"Async encoding [{i}/{len(messages)}] messages (bsz={batch_size})")
     97 +             res = await self.aencode(group_msg)
 84  98               data.extend(res)
     99 +             if task_id:
    100 +                 self.embed_state[task_id]["progress"] = i + len(group_msg)
 85 101
 86 102           if task_id:
 87     -             self.embed_state[task_id]["progress"] = len(messages)
 88 103               self.embed_state[task_id]["status"] = "completed"
 89 104
 90 105           return data

src/services/evaluation_service.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -336,7 +336,7 @@ async def _generate_benchmark_task(self, context: TaskContext):
336 336           # Currently, we re-calculate embeddings for ALL chunks in the KB for every benchmark generation.
337 337           # This is inefficient for large KBs (O(N) embedding calls).
338 338           # Optimization: Reuse existing embeddings from Vector DB if embedding_model_id matches the KB's embedding model.
339     -         embeddings = await embed_model.abatch_encode(contents, batch_size=40)
    339 +         embeddings = await embed_model.abatch_encode(contents, batch_size=10)
340 340           norms = [math.sqrt(sum(x * x for x in vec)) or 1.0 for vec in embeddings]
341 341
342 342           def cosine(a, b, na, nb):

0 commit comments

Comments (0)