@@ -46,7 +46,7 @@ async def aencode_queries(self, queries: list[str] | str) -> list[list[float]]:
4646 """等同于aencode"""
4747 return await self .aencode (queries )
4848
49- def batch_encode (self , messages : list [str ], batch_size : int = 40 ) -> list [list [float ]]:
49+ def batch_encode (self , messages : list [str ], batch_size : int = 10 ) -> list [list [float ]]:
5050 # logger.info(f"Batch encoding {len(messages)} messages")
5151 data = []
5252 task_id = None
@@ -67,24 +67,39 @@ def batch_encode(self, messages: list[str], batch_size: int = 40) -> list[list[f
6767
6868 return data
6969
async def abatch_encode(self, messages: list[str], batch_size: int = 10) -> list[list[float]]:
    """Embed ``messages`` in sequential batches and return the collected results.

    The messages are split into consecutive slices of at most ``batch_size``
    items and each slice is passed to ``self.aencode``. Slices are awaited one
    after another (not dispatched together via ``asyncio.gather``), so this
    call keeps at most one ``aencode`` request in flight and can report
    progress after every slice.
    NOTE(review): presumably serialised to limit load on the embedding
    backend -- confirm with the author of the change.

    When ``len(messages) > batch_size`` an entry keyed by ``hashstr(messages)``
    is stored in ``self.embed_state`` with the keys ``status``, ``total`` and
    ``progress``. ``progress`` is advanced after each slice; ``status`` ends as
    ``"completed"`` on success or ``"failed"`` if encoding is interrupted.

    Args:
        messages: Texts to embed.
        batch_size: Maximum number of texts per ``aencode`` call; must be >= 1.

    Returns:
        The concatenation, in input order, of what ``aencode`` returned for
        each slice (assumed to be one vector per message -- verify against
        ``aencode``).

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # A non-positive step would either crash inside range() (0) or silently
    # yield no batches at all (negative); reject it up front instead.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    data = []
    task_id = None
    total = len(messages)
    # Progress is only tracked for inputs that need more than one batch.
    if total > batch_size:
        task_id = hashstr(messages)
        self.embed_state[task_id] = {"status": "in-progress", "total": total, "progress": 0}

    try:
        for i in range(0, total, batch_size):
            group_msg = messages[i : i + batch_size]
            logger.info(f"Async encoding [{i}/{total}] messages (bsz={batch_size})")
            res = await self.aencode(group_msg)
            data.extend(res)
            if task_id:
                # The final slice may be shorter than batch_size.
                self.embed_state[task_id]["progress"] = i + len(group_msg)
    except BaseException:
        # Covers errors and task cancellation alike: never leave the tracked
        # task stuck at "in-progress" once this coroutine stops working on it.
        if task_id:
            self.embed_state[task_id]["status"] = "failed"
        raise

    if task_id:
        self.embed_state[task_id]["status"] = "completed"

    return data
0 commit comments