Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,19 @@
"timeout": 20,
"launch_model_if_not_running": False,
},
"阿里云百炼重排序": {
"id": "bailian_rerank",
"type": "bailian_rerank",
"provider": "bailian",
"provider_type": "rerank",
"enable": True,
"rerank_api_key": "",
"rerank_api_base": "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
"rerank_model": "qwen3-rerank",
"timeout": 30,
"return_documents": False,
"instruct": "",
},
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
Expand Down Expand Up @@ -1342,6 +1355,16 @@
"description": "重排序模型名称",
"type": "string",
},
"return_documents": {
"description": "是否在排序结果中返回文档原文",
"type": "bool",
"hint": "默认值false,以减少网络传输开销。",
},
"instruct": {
"description": "自定义排序任务类型说明",
"type": "string",
"hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。",
},
"launch_model_if_not_running": {
"description": "模型未运行时自动启动",
"type": "bool",
Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/provider/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ async def load_provider(self, provider_config: dict):
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
case "bailian_rerank":
from .sources.bailian_rerank_source import (
BailianRerankProvider as BailianRerankProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
Expand Down
174 changes: 174 additions & 0 deletions astrbot/core/provider/sources/bailian_rerank_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import os

import aiohttp

from astrbot import logger

from ..entities import ProviderType, RerankResult
from ..provider import RerankProvider
from ..register import register_provider_adapter


@register_provider_adapter(
    "bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
class BailianRerankProvider(RerankProvider):
    """Rerank provider adapter for the Alibaba Cloud Bailian text-rerank HTTP API.

    The adapter keeps one ``aiohttp.ClientSession`` for its whole lifetime
    (created in ``__init__``, closed in ``terminate``) and sends every rerank
    request as a JSON POST to the configured endpoint.
    """

    # Fallbacks used when the corresponding config entry is missing or blank.
    _DEFAULT_API_BASE = (
        "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
    )
    _DEFAULT_MODEL = "qwen3-rerank"
    # Upper bound on the number of documents sent in a single request; longer
    # lists are truncated (with a warning) before the request is built.
    _MAX_DOCUMENTS = 500

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Read the provider configuration and open the HTTP session.

        Args:
            provider_config: Provider entry. Keys read here: ``rerank_api_key``,
                ``rerank_model``, ``timeout``, ``top_n``, ``return_documents``,
                ``instruct`` and ``rerank_api_base``.
            provider_settings: Global provider settings; ``kb_final_top_k`` is
                read as the preferred default ``top_n``.

        Raises:
            ValueError: If no API key is found in the config or in the
                ``DASHSCOPE_API_KEY`` environment variable.
        """
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

        # API key: the config value wins, the environment variable is the fallback.
        self.api_key = provider_config.get("rerank_api_key", "") or os.getenv(
            "DASHSCOPE_API_KEY", ""
        )
        if not self.api_key:
            raise ValueError(
                "阿里云百炼 API Key 不能为空,请在配置中设置 rerank_api_key 或设置环境变量 DASHSCOPE_API_KEY"
            )

        # `or` (rather than a .get default) so that a blank string saved in the
        # config also falls back to the default instead of producing an empty
        # model name / request URL.
        self.model = provider_config.get("rerank_model") or self._DEFAULT_MODEL
        self.base_url = provider_config.get("rerank_api_base") or self._DEFAULT_API_BASE

        self.timeout = provider_config.get("timeout", 30)
        # Prefer the knowledge-base setting kb_final_top_k; otherwise use the
        # provider's own top_n (5 when that is absent as well).
        self.default_top_n = provider_settings.get(
            "kb_final_top_k"
        ) or provider_config.get("top_n", 5)
        self.return_documents = provider_config.get("return_documents", False)
        self.instruct = provider_config.get("instruct", "")

        # One shared session carrying the auth header and the total timeout.
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self.client = aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        )

        # Register the model name with the base provider.
        self.set_model(self.model)

        logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Rerank ``documents`` by relevance to ``query``.

        Args:
            query: Query text. A blank query yields an empty result.
            documents: Candidate documents. An empty list yields an empty
                result; lists longer than the per-request limit are truncated.
            top_n: Number of results to request. When ``None`` the default
                resolved in ``__init__`` is used.

        Returns:
            One ``RerankResult`` (index and relevance score) per item in the
            response's ``output.results`` list, or ``[]`` when the response
            contains no results.

        Raises:
            Exception: On any network failure, API-reported error or malformed
                response. The original exception is attached as ``__cause__``.
        """
        if not documents:
            logger.warning("文档列表为空,返回空结果")
            return []

        if not query.strip():
            logger.warning("查询文本为空,返回空结果")
            return []

        # Enforce the per-request document limit.
        if len(documents) > self._MAX_DOCUMENTS:
            logger.warning(
                f"文档数量({len(documents)})超过限制({self._MAX_DOCUMENTS}),将截断前{self._MAX_DOCUMENTS}个文档"
            )
            documents = documents[: self._MAX_DOCUMENTS]

        # An explicit top_n from the caller takes precedence over the default.
        final_top_n = top_n if top_n is not None else self.default_top_n

        try:
            # Build the request payload.
            payload: dict = {
                "model": self.model,
                "input": {"query": query, "documents": documents},
            }

            # Optional parameters are only sent when they are actually set.
            parameters: dict = {}
            if final_top_n is not None:
                parameters["top_n"] = final_top_n
            if self.return_documents:
                parameters["return_documents"] = True
            # The instruct text is only forwarded for the qwen3-rerank model.
            if self.instruct and self.model == "qwen3-rerank":
                parameters["instruct"] = self.instruct

            if parameters:
                payload["parameters"] = parameters

            logger.debug(
                f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
            )

            # Send the request.
            async with self.client.post(self.base_url, json=payload) as response:
                response.raise_for_status()
                response_data = await response.json()

            # A "code" field other than "200" in the body is treated as an
            # API-level failure even though the HTTP status was successful.
            if "code" in response_data and response_data["code"] != "200":
                error_msg = response_data.get("message", "未知错误")
                raise Exception(
                    f"百炼 API 返回错误: {response_data['code']} - {error_msg}"
                )

            # Parse the result list.
            results = response_data.get("output", {}).get("results", [])
            if not results:
                logger.warning(f"百炼 Rerank 返回空结果: {response_data}")
                return []

            # Convert the raw entries into RerankResult objects.
            rerank_results = [
                RerankResult(
                    index=result["index"], relevance_score=result["relevance_score"]
                )
                for result in results
            ]

            logger.debug(f"百炼 Rerank 成功返回 {len(rerank_results)} 个结果")

            # Log token usage when the response reports it.
            total_tokens = response_data.get("usage", {}).get("total_tokens", 0)
            if total_tokens > 0:
                logger.debug(f"百炼 Rerank 消耗 Token 数量: {total_tokens}")

            return rerank_results

        except aiohttp.ClientError as e:
            logger.error(f"百炼 Rerank 网络请求失败: {e}")
            # Chain the cause so the underlying aiohttp error stays inspectable.
            raise Exception(f"网络请求失败: {e}") from e
        except Exception as e:
            logger.error(f"百炼 Rerank 处理失败: {e}")
            raise Exception(f"重排序失败: {e}") from e

    async def terminate(self) -> None:
        """Close the HTTP client session.

        Errors raised while closing are logged rather than propagated, and the
        session reference is cleared in every case.
        """
        if self.client:
            logger.info("关闭 百炼 Rerank 客户端会话")
            try:
                await self.client.close()
            except Exception as e:
                logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
            finally:
                self.client = None