Skip to content

Commit 2e608cd

Browse files
committed
refactor(bailian_rerank): 修复误删除并优化top_n参数处理
- 移除不合理的知识库配置读取逻辑
- 添加os模块导入(用于读取环境变量)
- 抽取辅助函数:_build_payload()、_parse_results()、_log_usage()
- 添加自定义异常类:BailianRerankError、BailianAPIError、BailianNetworkError
- 使用.get()安全访问API响应字段,避免KeyError
- 使用raise ... from e保持异常链
1 parent 234ce93 commit 2e608cd

File tree

1 file changed

+125
-52
lines changed

1 file changed

+125
-52
lines changed

astrbot/core/provider/sources/bailian_rerank_source.py

Lines changed: 125 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import aiohttp
24

35
from astrbot import logger
@@ -7,6 +9,24 @@
79
from ..register import register_provider_adapter
810

911

12+
class BailianRerankError(Exception):
    """Base class for errors raised by the Bailian rerank provider."""
16+
17+
18+
class BailianAPIError(BailianRerankError):
    """Raised when the Bailian API response reports an error."""
22+
23+
24+
class BailianNetworkError(BailianRerankError):
    """Raised when the network request to Bailian fails."""
28+
29+
1030
@register_provider_adapter(
1131
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
1232
)
@@ -19,7 +39,9 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
1939
self.provider_settings = provider_settings
2040

2141
# API配置
22-
self.api_key = provider_config.get("rerank_api_key", "")
42+
self.api_key = provider_config.get("rerank_api_key") or os.getenv(
43+
"DASHSCOPE_API_KEY", ""
44+
)
2345
if not self.api_key:
2446
raise ValueError("阿里云百炼 API Key 不能为空。")
2547

@@ -48,6 +70,96 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
4870

4971
logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")
5072

73+
def _build_payload(
74+
self, query: str, documents: list[str], top_n: int | None
75+
) -> dict:
76+
"""构建请求载荷
77+
78+
Args:
79+
query: 查询文本
80+
documents: 文档列表
81+
top_n: 返回前N个结果,如果为None则返回所有结果
82+
83+
Returns:
84+
请求载荷字典
85+
"""
86+
base = {"model": self.model, "input": {"query": query, "documents": documents}}
87+
88+
params = {
89+
k: v
90+
for k, v in [
91+
("top_n", top_n if top_n is not None and top_n > 0 else None),
92+
("return_documents", True if self.return_documents else None),
93+
(
94+
"instruct",
95+
self.instruct
96+
if self.instruct and self.model == "qwen3-rerank"
97+
else None,
98+
),
99+
]
100+
if v is not None
101+
}
102+
103+
if params:
104+
base["parameters"] = params
105+
106+
return base
107+
108+
def _parse_results(self, data: dict) -> list[RerankResult]:
    """Convert a Bailian rerank API response into ``RerankResult`` objects.

    Args:
        data: Decoded JSON body of the API response.

    Returns:
        One ``RerankResult`` per usable entry of ``output.results``, in the
        order they appear in the response. An empty list is returned when
        the response carries no results.

    Raises:
        BailianAPIError: The response carries a ``code`` other than "200".
    """
    # A missing or null code is treated as success; anything else must be
    # "200". Comparing as str tolerates a numeric 200.
    code = data.get("code")
    if code is not None and str(code) != "200":
        message = data.get("message") or "未知错误"
        raise BailianAPIError(f"百炼 API 错误: {code} - {message}")

    # "output" may be absent or an explicit null; both mean "no results".
    output = data.get("output")
    results = output.get("results") if isinstance(output, dict) else None
    if not results:
        logger.warning(f"百炼 Rerank 返回空结果: {data}")
        return []

    rerank_results = []
    for idx, result in enumerate(results):
        # Parsing is best-effort per entry: one malformed entry is logged
        # and skipped instead of failing the whole call.
        try:
            index = result.get("index")
            if index is None:
                # NOTE(review): presumably results are ordered by score, so
                # the position in this list is not a safe stand-in for the
                # document index — skip rather than guess. Confirm against
                # the API reference.
                logger.warning(f"结果 {idx} 缺少 index,已跳过: result={result}")
                continue

            # No default here, so that a missing key and an explicit null
            # both reach the warning below.
            relevance_score = result.get("relevance_score")
            if relevance_score is None:
                logger.warning(f"结果 {idx} 缺少 relevance_score,使用默认值 0.0")
                relevance_score = 0.0

            rerank_results.append(
                RerankResult(index=index, relevance_score=relevance_score)
            )
        except Exception as e:
            logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
            continue

    return rerank_results
152+
153+
def _log_usage(self, data: dict) -> None:
    """Log the token usage reported in an API response, if any.

    This is purely informational, so a missing, null or malformed ``usage``
    section is ignored rather than raising.

    Args:
        data: Decoded JSON body of the API response.
    """
    usage = data.get("usage")
    if not isinstance(usage, dict):
        return

    tokens = usage.get("total_tokens")
    # Guard the comparison: a null or non-numeric value must not raise.
    if isinstance(tokens, (int, float)) and tokens > 0:
        logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")
162+
51163
async def rerank(
52164
self,
53165
query: str,
@@ -60,7 +172,7 @@ async def rerank(
60172
Args:
61173
query: 查询文本
62174
documents: 待排序的文档列表
63-
top_n: 返回前N个结果,如果为None则返回所有重排序结果
175+
top_n: 返回前N个结果,如果为None则使用配置中的默认值
64176
65177
Returns:
66178
重排序结果列表
@@ -81,23 +193,8 @@ async def rerank(
81193
documents = documents[:500]
82194

83195
try:
84-
# 构建请求载荷
85-
payload = {
86-
"model": self.model,
87-
"input": {"query": query, "documents": documents},
88-
}
89-
90-
# 添加可选参数
91-
parameters = {}
92-
if top_n is not None and top_n > 0:
93-
parameters["top_n"] = top_n
94-
if self.return_documents:
95-
parameters["return_documents"] = True
96-
if self.instruct and self.model == "qwen3-rerank":
97-
parameters["instruct"] = self.instruct
98-
99-
if parameters:
100-
payload["parameters"] = parameters
196+
# 构建请求载荷,如果top_n为None则返回所有重排序结果
197+
payload = self._build_payload(query, documents, top_n)
101198

102199
logger.debug(
103200
f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
@@ -108,48 +205,24 @@ async def rerank(
108205
response.raise_for_status()
109206
response_data = await response.json()
110207

111-
# 检查响应状态
112-
if "code" in response_data and response_data["code"] != "200":
113-
error_msg = response_data.get("message", "未知错误")
114-
api_error_msg = (
115-
f"百炼 API 返回错误: {response_data['code']} - {error_msg}"
116-
)
117-
raise RuntimeError(api_error_msg)
118-
119-
# 解析结果
120-
output = response_data.get("output", {})
121-
results = output.get("results", [])
122-
123-
if not results:
124-
logger.warning(f"百炼 Rerank 返回空结果: {response_data}")
125-
return []
126-
127-
# 转换为RerankResult对象
128-
rerank_results = []
129-
for result in results:
130-
rerank_result = RerankResult(
131-
index=result["index"], relevance_score=result["relevance_score"]
132-
)
133-
rerank_results.append(rerank_result)
134-
135-
logger.debug(f"百炼 Rerank 成功返回 {len(rerank_results)} 个结果")
208+
# 解析结果并记录使用量
209+
results = self._parse_results(response_data)
210+
self._log_usage(response_data)
136211

137-
# 记录使用量信息
138-
usage = response_data.get("usage", {})
139-
total_tokens = usage.get("total_tokens", 0)
140-
if total_tokens > 0:
141-
logger.debug(f"百炼 Rerank 消耗 Token 数量: {total_tokens}")
212+
logger.debug(f"百炼 Rerank 成功返回 {len(results)} 个结果")
142213

143-
return rerank_results
214+
return results
144215

145216
except aiohttp.ClientError as e:
146217
error_msg = f"网络请求失败: {e}"
147218
logger.error(f"百炼 Rerank 网络请求失败: {e}")
148-
raise RuntimeError(error_msg) from e
219+
raise BailianNetworkError(error_msg) from e
220+
except BailianRerankError:
221+
raise
149222
except Exception as e:
150223
error_msg = f"重排序失败: {e}"
151224
logger.error(f"百炼 Rerank 处理失败: {e}")
152-
raise RuntimeError(error_msg) from e
225+
raise BailianRerankError(error_msg) from e
153226

154227
async def terminate(self) -> None:
155228
"""关闭HTTP客户端会话."""

0 commit comments

Comments
 (0)