Skip to content

Commit 0260d43

Browse files
authored
Merge pull request #3706 from piexian/master
2 parents 4e2154f + 2e608cd commit 0260d43

File tree

3 files changed

+263
-0
lines changed

3 files changed

+263
-0
lines changed

astrbot/core/config/default.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,19 @@
13081308
"timeout": 20,
13091309
"launch_model_if_not_running": False,
13101310
},
1311+
"阿里云百炼重排序": {
1312+
"id": "bailian_rerank",
1313+
"type": "bailian_rerank",
1314+
"provider": "bailian",
1315+
"provider_type": "rerank",
1316+
"enable": True,
1317+
"rerank_api_key": "",
1318+
"rerank_api_base": "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
1319+
"rerank_model": "qwen3-rerank",
1320+
"timeout": 30,
1321+
"return_documents": False,
1322+
"instruct": "",
1323+
},
13111324
"Xinference STT": {
13121325
"id": "xinference_stt",
13131326
"type": "xinference_stt",
@@ -1342,6 +1355,16 @@
13421355
"description": "重排序模型名称",
13431356
"type": "string",
13441357
},
1358+
"return_documents": {
1359+
"description": "是否在排序结果中返回文档原文",
1360+
"type": "bool",
1361+
"hint": "默认值false,以减少网络传输开销。",
1362+
},
1363+
"instruct": {
1364+
"description": "自定义排序任务类型说明",
1365+
"type": "string",
1366+
"hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。",
1367+
},
13451368
"launch_model_if_not_running": {
13461369
"description": "模型未运行时自动启动",
13471370
"type": "bool",

astrbot/core/provider/manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,10 @@ async def load_provider(self, provider_config: dict):
331331
from .sources.xinference_rerank_source import (
332332
XinferenceRerankProvider as XinferenceRerankProvider,
333333
)
334+
case "bailian_rerank":
335+
from .sources.bailian_rerank_source import (
336+
BailianRerankProvider as BailianRerankProvider,
337+
)
334338
except (ImportError, ModuleNotFoundError) as e:
335339
logger.critical(
336340
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import asyncio
import os

import aiohttp

from astrbot import logger

from ..entities import ProviderType, RerankResult
from ..provider import RerankProvider
from ..register import register_provider_adapter
10+
11+
12+
class BailianRerankError(Exception):
    """Base class for every error raised by the Bailian rerank adapter."""
16+
17+
18+
class BailianAPIError(BailianRerankError):
    """Raised when the Bailian API itself reports an error in its response."""
22+
23+
24+
class BailianNetworkError(BailianRerankError):
    """Raised when the HTTP request to the Bailian service fails."""
28+
29+
30+
@register_provider_adapter(
    "bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
class BailianRerankProvider(RerankProvider):
    """Rerank adapter for the Alibaba Cloud Bailian (DashScope) text-rerank API.

    The adapter owns one ``aiohttp.ClientSession`` that carries the bearer
    token and the request timeout; call :meth:`terminate` to close it.
    """

    # Endpoint used when the config does not provide a non-empty one.
    _DEFAULT_API_BASE = (
        "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
    )
    # Model used when the config does not provide a non-empty one.
    _DEFAULT_MODEL = "qwen3-rerank"
    # Documents beyond this count are dropped before the request is sent.
    _MAX_DOCUMENTS = 500

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Read the provider configuration and open the HTTP session.

        Args:
            provider_config: Provider entry; the keys read here are
                ``rerank_api_key``, ``rerank_model``, ``timeout``,
                ``return_documents``, ``instruct`` and ``rerank_api_base``.
            provider_settings: Forwarded to the base class and stored as-is.

        Raises:
            ValueError: If neither the config nor the ``DASHSCOPE_API_KEY``
                environment variable supplies an API key.
        """
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

        # The config value wins; the environment variable is the fallback.
        self.api_key = provider_config.get("rerank_api_key") or os.getenv(
            "DASHSCOPE_API_KEY", ""
        )
        if not self.api_key:
            raise ValueError("阿里云百炼 API Key 不能为空。")

        # `or` (rather than a .get default) so that an empty string left in
        # the config also falls back to the default instead of being used.
        self.model = provider_config.get("rerank_model") or self._DEFAULT_MODEL
        self.timeout = provider_config.get("timeout", 30)
        self.return_documents = provider_config.get("return_documents", False)
        self.instruct = provider_config.get("instruct", "")
        self.base_url = (
            provider_config.get("rerank_api_base") or self._DEFAULT_API_BASE
        )

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self.client = aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        )

        self.set_model(self.model)

        logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")

    def _build_payload(
        self, query: str, documents: list[str], top_n: int | None
    ) -> dict:
        """Build the JSON request body.

        Args:
            query: Query text.
            documents: Candidate documents to rank.
            top_n: Number of results to request; ``None`` or a non-positive
                value omits the parameter.

        Returns:
            A dict with ``model`` and ``input``, plus ``parameters`` when at
            least one optional parameter is set.
        """
        payload: dict = {
            "model": self.model,
            "input": {"query": query, "documents": documents},
        }

        parameters: dict = {}
        if top_n is not None and top_n > 0:
            parameters["top_n"] = top_n
        if self.return_documents:
            parameters["return_documents"] = True
        # The instruction is only sent for the qwen3-rerank model.
        if self.instruct and self.model == "qwen3-rerank":
            parameters["instruct"] = self.instruct

        if parameters:
            payload["parameters"] = parameters

        return payload

    def _parse_results(self, data: dict) -> list[RerankResult]:
        """Convert a decoded API response into ``RerankResult`` objects.

        Entries that cannot be converted are logged and skipped.

        Args:
            data: Decoded JSON response.

        Returns:
            The parsed results; empty when the response has none.

        Raises:
            BailianAPIError: If the response carries a ``code`` other than
                ``"200"``.
        """
        code = data.get("code")
        # An absent, null or empty code is not treated as an error; the
        # integer 200 is accepted alongside the string "200".
        if code not in (None, "", "200", 200):
            raise BailianAPIError(
                f"百炼 API 错误: {code} - {data.get('message', '')}"
            )

        # `or` guards against explicit nulls for "output" / "results".
        results = (data.get("output") or {}).get("results") or []
        if not results:
            logger.warning(f"百炼 Rerank 返回空结果: {data}")
            return []

        rerank_results = []
        for idx, result in enumerate(results):
            try:
                index = result.get("index", idx)
                # No default here, so a missing key and an explicit null both
                # reach the warning below.
                relevance_score = result.get("relevance_score")
                if relevance_score is None:
                    logger.warning(f"结果 {idx} 缺少 relevance_score,使用默认值 0.0")
                    relevance_score = 0.0

                rerank_results.append(
                    RerankResult(index=index, relevance_score=relevance_score)
                )
            except Exception as e:
                # Best effort: one malformed entry must not discard the rest.
                logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
                continue

        return rerank_results

    def _log_usage(self, data: dict) -> None:
        """Log the token usage reported by the API, if any.

        Never raises on a missing or malformed ``usage`` section, so that a
        logging problem cannot fail an otherwise successful request.

        Args:
            data: Decoded JSON response.
        """
        usage = data.get("usage")
        tokens = usage.get("total_tokens") if isinstance(usage, dict) else None
        if isinstance(tokens, (int, float)) and tokens > 0:
            logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Rerank ``documents`` against ``query``.

        Args:
            query: Query text; a blank query returns an empty list.
            documents: Documents to rank; an empty list returns an empty
                list, and anything beyond the first 500 is dropped.
            top_n: Number of results to request; ``None`` omits the
                parameter from the request.

        Returns:
            The parsed rerank results.

        Raises:
            BailianNetworkError: On an aiohttp client error, a non-success
                HTTP status, or a request timeout.
            BailianAPIError: If the response body reports an error code.
            BailianRerankError: On any other failure.
        """
        if not documents:
            logger.warning("文档列表为空,返回空结果")
            return []

        if not query.strip():
            logger.warning("查询文本为空,返回空结果")
            return []

        limit = self._MAX_DOCUMENTS
        if len(documents) > limit:
            logger.warning(
                f"文档数量({len(documents)})超过限制({limit}),将截断前{limit}个文档"
            )
            documents = documents[:limit]

        try:
            payload = self._build_payload(query, documents, top_n)

            logger.debug(
                f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
            )

            async with self.client.post(self.base_url, json=payload) as response:
                response.raise_for_status()
                response_data = await response.json()

            results = self._parse_results(response_data)
            self._log_usage(response_data)

            logger.debug(f"百炼 Rerank 成功返回 {len(results)} 个结果")

            return results

        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # A ClientTimeout expiry surfaces as asyncio.TimeoutError, which
            # is not an aiohttp.ClientError; it is still a network failure.
            # Its str() is empty, so fall back to the exception class name.
            detail = str(e) or type(e).__name__
            error_msg = f"网络请求失败: {detail}"
            logger.error(f"百炼 Rerank 网络请求失败: {detail}")
            raise BailianNetworkError(error_msg) from e
        except BailianRerankError:
            raise
        except Exception as e:
            error_msg = f"重排序失败: {e}"
            logger.error(f"百炼 Rerank 处理失败: {e}")
            raise BailianRerankError(error_msg) from e

    async def terminate(self) -> None:
        """Close the HTTP client session; errors while closing are logged."""
        if self.client:
            logger.info("关闭 百炼 Rerank 客户端会话")
            try:
                await self.client.close()
            except Exception as e:
                logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
            finally:
                # Dropped even if close() failed, so a second terminate()
                # call is a no-op.
                self.client = None

0 commit comments

Comments
 (0)