1+ import os
2+
13import aiohttp
24
35from astrbot import logger
79from ..register import register_provider_adapter
810
911
class BailianRerankError(Exception):
    """Base class for Bailian rerank service errors."""
16+
17+
class BailianAPIError(BailianRerankError):
    """Error reported by the Bailian API in its response."""
22+
23+
class BailianNetworkError(BailianRerankError):
    """Network request error while talking to Bailian."""
28+
29+
1030@register_provider_adapter (
1131 "bailian_rerank" , "阿里云百炼文本排序适配器" , provider_type = ProviderType .RERANK
1232)
@@ -19,7 +39,9 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
1939 self .provider_settings = provider_settings
2040
2141 # API配置
22- self .api_key = provider_config .get ("rerank_api_key" , "" )
42+ self .api_key = provider_config .get ("rerank_api_key" ) or os .getenv (
43+ "DASHSCOPE_API_KEY" , ""
44+ )
2345 if not self .api_key :
2446 raise ValueError ("阿里云百炼 API Key 不能为空。" )
2547
@@ -48,6 +70,96 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
4870
4971 logger .info (f"AstrBot 百炼 Rerank 初始化完成。模型: { self .model } " )
5072
73+ def _build_payload (
74+ self , query : str , documents : list [str ], top_n : int | None
75+ ) -> dict :
76+ """构建请求载荷
77+
78+ Args:
79+ query: 查询文本
80+ documents: 文档列表
81+ top_n: 返回前N个结果,如果为None则返回所有结果
82+
83+ Returns:
84+ 请求载荷字典
85+ """
86+ base = {"model" : self .model , "input" : {"query" : query , "documents" : documents }}
87+
88+ params = {
89+ k : v
90+ for k , v in [
91+ ("top_n" , top_n if top_n is not None and top_n > 0 else None ),
92+ ("return_documents" , True if self .return_documents else None ),
93+ (
94+ "instruct" ,
95+ self .instruct
96+ if self .instruct and self .model == "qwen3-rerank"
97+ else None ,
98+ ),
99+ ]
100+ if v is not None
101+ }
102+
103+ if params :
104+ base ["parameters" ] = params
105+
106+ return base
107+
def _parse_results(self, data: dict) -> list[RerankResult]:
    """Convert a Bailian rerank API response into RerankResult objects.

    Parsing is best-effort per entry: an entry that cannot be converted is
    logged and skipped rather than failing the whole response.

    Args:
        data: Decoded JSON response body.

    Returns:
        One RerankResult per parsable entry of ``output.results``, in the
        order they appear in the response. Empty list when the response
        carries no results.

    Raises:
        BailianAPIError: The response carries a "code" other than "200".
    """
    # A response without a "code" field is treated as success.
    if data.get("code", "200") != "200":
        raise BailianAPIError(
            f"百炼 API 错误: {data.get('code')} – {data.get('message', '')}"
        )

    # "output" / "results" may be absent or null; both mean "no results".
    output = data.get("output")
    results = output.get("results") if isinstance(output, dict) else None
    if not results:
        logger.warning(f"百炼 Rerank 返回空结果: {data}")
        return []

    rerank_results = []
    for idx, result in enumerate(results):
        try:
            index = result.get("index")
            if index is None:
                # The position in the result list is only a guess at the
                # original document index, so make the fallback visible.
                logger.warning(f"结果 {idx} 缺少 index,使用其在结果列表中的位置")
                index = idx

            # Treat a missing key and an explicit null the same way.
            relevance_score = result.get("relevance_score")
            if relevance_score is None:
                logger.warning(f"结果 {idx} 缺少 relevance_score,使用默认值 0.0")
                relevance_score = 0.0

            rerank_results.append(
                RerankResult(index=index, relevance_score=relevance_score)
            )
        except Exception as e:
            # Deliberately broad: covers non-dict entries and whatever
            # RerankResult's constructor may raise (not visible here).
            logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")

    return rerank_results
152+
def _log_usage(self, data: dict) -> None:
    """Log the token usage reported in an API response, if any.

    Best-effort: a missing, null or malformed "usage" section is ignored so
    that this logging helper never raises on an otherwise valid response.

    Args:
        data: Decoded JSON response body; ``usage.total_tokens`` is read.
    """
    usage = data.get("usage")
    if not isinstance(usage, dict):
        return

    tokens = usage.get("total_tokens")
    # Only compare real numbers; None or a string would raise on `> 0`.
    if isinstance(tokens, (int, float)) and tokens > 0:
        logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")
162+
51163 async def rerank (
52164 self ,
53165 query : str ,
@@ -60,7 +172,7 @@ async def rerank(
60172 Args:
61173 query: 查询文本
62174 documents: 待排序的文档列表
63- top_n: 返回前N个结果,如果为None则返回所有重排序结果
175+ top_n: 返回前N个结果,如果为None则使用配置中的默认值
64176
65177 Returns:
66178 重排序结果列表
@@ -81,23 +193,8 @@ async def rerank(
81193 documents = documents [:500 ]
82194
83195 try :
84- # 构建请求载荷
85- payload = {
86- "model" : self .model ,
87- "input" : {"query" : query , "documents" : documents },
88- }
89-
90- # 添加可选参数
91- parameters = {}
92- if top_n is not None and top_n > 0 :
93- parameters ["top_n" ] = top_n
94- if self .return_documents :
95- parameters ["return_documents" ] = True
96- if self .instruct and self .model == "qwen3-rerank" :
97- parameters ["instruct" ] = self .instruct
98-
99- if parameters :
100- payload ["parameters" ] = parameters
196+ # 构建请求载荷,如果top_n为None则返回所有重排序结果
197+ payload = self ._build_payload (query , documents , top_n )
101198
102199 logger .debug (
103200 f"百炼 Rerank 请求: query='{ query [:50 ]} ...', 文档数量={ len (documents )} "
@@ -108,48 +205,24 @@ async def rerank(
108205 response .raise_for_status ()
109206 response_data = await response .json ()
110207
111- # 检查响应状态
112- if "code" in response_data and response_data ["code" ] != "200" :
113- error_msg = response_data .get ("message" , "未知错误" )
114- api_error_msg = (
115- f"百炼 API 返回错误: { response_data ['code' ]} - { error_msg } "
116- )
117- raise RuntimeError (api_error_msg )
118-
119- # 解析结果
120- output = response_data .get ("output" , {})
121- results = output .get ("results" , [])
122-
123- if not results :
124- logger .warning (f"百炼 Rerank 返回空结果: { response_data } " )
125- return []
126-
127- # 转换为RerankResult对象
128- rerank_results = []
129- for result in results :
130- rerank_result = RerankResult (
131- index = result ["index" ], relevance_score = result ["relevance_score" ]
132- )
133- rerank_results .append (rerank_result )
134-
135- logger .debug (f"百炼 Rerank 成功返回 { len (rerank_results )} 个结果" )
208+ # 解析结果并记录使用量
209+ results = self ._parse_results (response_data )
210+ self ._log_usage (response_data )
136211
137- # 记录使用量信息
138- usage = response_data .get ("usage" , {})
139- total_tokens = usage .get ("total_tokens" , 0 )
140- if total_tokens > 0 :
141- logger .debug (f"百炼 Rerank 消耗 Token 数量: { total_tokens } " )
212+ logger .debug (f"百炼 Rerank 成功返回 { len (results )} 个结果" )
142213
143- return rerank_results
214+ return results
144215
145216 except aiohttp .ClientError as e :
146217 error_msg = f"网络请求失败: { e } "
147218 logger .error (f"百炼 Rerank 网络请求失败: { e } " )
148- raise RuntimeError (error_msg ) from e
219+ raise BailianNetworkError (error_msg ) from e
220+ except BailianRerankError :
221+ raise
149222 except Exception as e :
150223 error_msg = f"重排序失败: { e } "
151224 logger .error (f"百炼 Rerank 处理失败: { e } " )
152- raise RuntimeError (error_msg ) from e
225+ raise BailianRerankError (error_msg ) from e
153226
154227 async def terminate (self ) -> None :
155228 """关闭HTTP客户端会话."""
0 commit comments