diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index bc4031e0a..0e1ed4d96 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -14,6 +14,7 @@ ElasticSearchService, get_vector_db_core, get_embedding_model, + get_rerank_model, ) from services.remote_mcp_service import get_remote_mcp_server_list from services.memory_config_service import build_memory_context @@ -350,11 +351,32 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int tool_config.metadata = langchain_tool break - # special logic for knowledge base search tool + # special logic for search tools that may use reranking models if tool_config.class_name == "KnowledgeBaseSearchTool": - tool_config.metadata = { + rerank = param_dict.get("rerank", False) + rerank_model_name = param_dict.get("rerank_model_name", "") + rerank_model = None + if rerank and rerank_model_name: + rerank_model = get_rerank_model( + tenant_id=tenant_id, model_name=rerank_model_name + ) + + tool_config.metadata = { "vdb_core": get_vector_db_core(), "embedding_model": get_embedding_model(tenant_id=tenant_id), + "rerank_model": rerank_model, + } + elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]: + rerank = param_dict.get("rerank", False) + rerank_model_name = param_dict.get("rerank_model_name", "") + rerank_model = None + if rerank and rerank_model_name: + rerank_model = get_rerank_model( + tenant_id=tenant_id, model_name=rerank_model_name + ) + + tool_config.metadata = { + "rerank_model": rerank_model, } elif tool_config.class_name == "AnalyzeTextFileTool": tool_config.metadata = { diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index 78f6413ee..9214a1ffa 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -3,6 +3,7 @@ from nexent.core import MessageObserver from nexent.core.models import OpenAIModel, OpenAIVLModel from nexent.core.models.embedding_model import JinaEmbedding, OpenAICompatibleEmbedding +from nexent.core.models.rerank_model import OpenAICompatibleRerank from services.voice_service import get_voice_service from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST @@ -102,7 +103,13 @@ async def _perform_connectivity_check( ssl_verify=ssl_verify ).check_connectivity() elif model_type == "rerank": - connectivity = False + rerank_model = OpenAICompatibleRerank( + model_name=model_name, + base_url=model_base_url, + api_key=model_api_key, + ssl_verify=ssl_verify, + ) + connectivity = await rerank_model.connectivity_check() elif model_type == "vlm": observer = MessageObserver() connectivity = await OpenAIVLModel( diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index 8c397dc70..dbff17082 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -132,6 +132,11 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a model_dict["base_url"] = f"{model_url.rstrip('/')}/{MODEL_ENGINE_NORTH_PREFIX}/embeddings" # The embedding dimension might differ from the provided max_tokens. model_dict["max_tokens"] = await embedding_dimension_check(model_dict) + elif model["model_type"] == "rerank": + if provider == ProviderEnum.DASHSCOPE.value: + model_dict["base_url"] = f"{model_url.replace('compatible-mode/v1','api/v1').rstrip('/')}/services/rerank/text-rerank/text-rerank" + else: + model_dict["base_url"] = f"{model_url.rstrip('/')}/rerank" else: # For non-embedding models if provider == ProviderEnum.MODELENGINE.value: diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py index 4ecbcbb1d..b9fb7ab7b 100644 --- a/backend/services/providers/dashscope_provider.py +++ b/backend/services/providers/dashscope_provider.py @@ -58,7 +58,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "chat": [], # Maps to "llm" "vlm": [], # Maps to "vlm" "embedding": [], # Maps to "embedding" / "multi_embedding" - "reranker": [], # Maps to "reranker" + "rerank": [], # Maps to "rerank" "tts": [], # Maps to "tts" "stt": [] # Maps to "stt" } @@ -88,10 +88,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: categorized_models['embedding'].append(cleaned_model) continue - # 2. Reranker + # 2. Rerank if 'rerank' in m_id.lower() or '重排序' in desc: - cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) - categorized_models['reranker'].append(cleaned_model) + cleaned_model.update({"model_tag": "rerank", "model_type": "rerank"}) + categorized_models['rerank'].append(cleaned_model) continue # 3. STT diff --git a/backend/services/providers/silicon_provider.py b/backend/services/providers/silicon_provider.py index 29de51fce..ea41cc95d 100644 --- a/backend/services/providers/silicon_provider.py +++ b/backend/services/providers/silicon_provider.py @@ -30,6 +30,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: silicon_url = f"{SILICON_GET_URL}?sub_type=chat" elif model_type in ("embedding", "multi_embedding"): silicon_url = f"{SILICON_GET_URL}?sub_type=embedding" + elif model_type == "rerank": + silicon_url = f"{SILICON_GET_URL}?sub_type=reranker" else: silicon_url = SILICON_GET_URL @@ -48,6 +50,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: for item in model_list: item["model_tag"] = "embedding" item["model_type"] = model_type + elif model_type == "rerank": + for item in model_list: + item["model_tag"] = "rerank" + item["model_type"] = model_type # Return empty list to indicate successful API call but no models if not model_list: diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index 42e5d178c..ab4446c1b 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -47,7 +47,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "chat": [], # Maps to "llm" "vlm": [], # Maps to "vlm" "embedding": [], # Maps to "embedding" / "multi_embedding" - "reranker": [], # Maps to "reranker" + "rerank": [], # Maps to "rerank" "tts": [], # Maps to "tts" "stt": [] # Maps to "stt" } @@ -66,10 +66,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "model_type": "", "max_tokens": DEFAULT_LLM_MAX_TOKENS } - # 1. reranker + # 1. rerank if 'rerank' in m_id: - cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) - categorized_models['reranker'].append(cleaned_model) + cleaned_model.update({"model_tag": "rerank", "model_type": "rerank"}) + categorized_models['rerank'].append(cleaned_model) #2. embedding elif 'embedding' in m_id or m_id.startswith('bge-'): cleaned_model.update({"model_tag": "embedding", "model_type": "embedding"}) diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index a0f5b2399..9653b2e10 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -28,7 +28,7 @@ check_tool_list_initialized, ) from services.file_management_service import get_llm_model -from services.vectordatabase_service import get_embedding_model, get_vector_db_core +from services.vectordatabase_service import get_embedding_model, get_rerank_model, get_vector_db_core from database.client import minio_client from services.image_service import get_vlm_model from utils.tool_utils import get_local_tools_classes, get_local_tools_description_zh @@ -694,10 +694,32 @@ def _validate_local_tool( if tool_name == "knowledge_base_search": embedding_model = get_embedding_model(tenant_id=tenant_id) vdb_core = get_vector_db_core() + + # Get rerank configuration + rerank = instantiation_params.get("rerank", False) + rerank_model_name = instantiation_params.get("rerank_model_name", "") + rerank_model = None + if rerank and rerank_model_name: + rerank_model = get_rerank_model(tenant_id=tenant_id, model_name=rerank_model_name) + params = { **instantiation_params, 'vdb_core': vdb_core, 'embedding_model': embedding_model, + 'rerank_model': rerank_model, + } + tool_instance = tool_class(**params) + elif tool_name in ["dify_search", "datamate_search"]: + # Get rerank configuration for dify and datamate search tools + rerank = instantiation_params.get("rerank", False) + rerank_model_name = instantiation_params.get("rerank_model_name", "") + rerank_model = None + if rerank and rerank_model_name: + rerank_model = get_rerank_model(tenant_id=tenant_id, model_name=rerank_model_name) + + params = { + **instantiation_params, + 'rerank_model': rerank_model, } tool_instance = tool_class(**params) elif tool_name == "analyze_image": diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index de79c812c..cf8f7f98c 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -21,6 +21,7 @@ from fastapi import Body, Depends, Path, Query from fastapi.responses import StreamingResponse from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, BaseEmbedding +from nexent.core.models.rerank_model import OpenAICompatibleRerank, BaseRerank from nexent.vector_database.base import VectorDatabaseCore from nexent.vector_database.elasticsearch_core import ElasticSearchCore from nexent.vector_database.datamate_core import DataMateCore @@ -241,6 +242,52 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): return None +def get_rerank_model(tenant_id: str, model_name: Optional[str] = None): + """ + Get the rerank model for the tenant, optionally using a specific model name. + + Args: + tenant_id: Tenant ID + model_name: Optional specific model name to use (format: "model_repo/model_name" or just "model_name") + If provided, will try to find the model in the tenant's model list. + + Returns: + Rerank model instance or None + """ + # If model_name is provided, try to find it in the tenant's models + if model_name: + try: + models = get_model_records({"model_type": "rerank"}, tenant_id) + for model in models: + model_display_name = model.get("model_repo") + "/" + model["model_name"] if model.get("model_repo") else model["model_name"] + if model_display_name == model_name: + # Found the model, create rerank model instance + return OpenAICompatibleRerank( + model_name=get_model_name_from_config(model) or "", + base_url=model.get("base_url", ""), + api_key=model.get("api_key", ""), + ssl_verify=model.get("ssl_verify", True), + ) + except Exception as e: + logger.warning(f"Failed to get rerank model by name {model_name}: {e}") + + # Fall back to default rerank model + model_config = tenant_config_manager.get_model_config( + key="RERANK_ID", tenant_id=tenant_id) + + model_type = model_config.get("model_type", "") + + if model_type == "rerank": + return OpenAICompatibleRerank( + model_name=get_model_name_from_config(model_config) or "", + base_url=model_config.get("base_url", ""), + api_key=model_config.get("api_key", ""), + ssl_verify=model_config.get("ssl_verify", True), + ) + else: + return None + + class ElasticSearchService: @staticmethod async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCore, user_id: str): diff --git a/doc/docs/zh/user-guide/agent-development.md b/doc/docs/zh/user-guide/agent-development.md index cb4b4055d..67d3c8311 100644 --- a/doc/docs/zh/user-guide/agent-development.md +++ b/doc/docs/zh/user-guide/agent-development.md @@ -130,10 +130,11 @@ - 检索的模式 `search_mode`(默认为 `hybrid`) - 目标检索的知识库列表 `index_names`,如 `["医疗", "维生素知识大全"]` - 若不输入 `index_names`,则默认检索知识库页面所选中的全部知识库 + - 是否启用重排模型(默认为 `false`),启用后配置重排模型,实现对检索结果的重排优化 6. 输入完成后点击"执行测试"开始测试,并在下方查看测试结果
- +
## 📝 描述业务逻辑 diff --git a/doc/docs/zh/user-guide/assets/agent-development/tool-test-run-1.png b/doc/docs/zh/user-guide/assets/agent-development/tool-test-run-1.png new file mode 100644 index 000000000..e0cb534f2 Binary files /dev/null and b/doc/docs/zh/user-guide/assets/agent-development/tool-test-run-1.png differ diff --git a/doc/docs/zh/user-guide/assets/model-management/select-model-4.png b/doc/docs/zh/user-guide/assets/model-management/select-model-4.png new file mode 100644 index 000000000..78ed60633 Binary files /dev/null and b/doc/docs/zh/user-guide/assets/model-management/select-model-4.png differ diff --git a/doc/docs/zh/user-guide/local-tools/search-tools.md b/doc/docs/zh/user-guide/local-tools/search-tools.md index 4b71833c3..9c0ded771 100644 --- a/doc/docs/zh/user-guide/local-tools/search-tools.md +++ b/doc/docs/zh/user-guide/local-tools/search-tools.md @@ -31,6 +31,8 @@ title: 搜索工具 - `query`:检索问题,必填。 - `search_mode`:`hybrid`(默认,混合召回)、`accurate`(文本模糊匹配)、`semantic`(向量语义)。 - `index_names`:指定要搜索的知识库名称列表(可用用户侧名称或内部索引名),可选。 + - `enable_rerank`:是否启用重排序,默认 False。开启后会对检索结果进行二次排序,提升结果相关性。 + - `rerank_model`:重排序使用的模型,默认为系统配置的 rerank 模型。`enable_rerank` 为 True 时生效。 - 返回匹配片段的标题、路径/URL、来源类型、得分等。 - 若未选择知识库,会提示"无可用知识库"。 @@ -44,6 +46,8 @@ title: 搜索工具 - `threshold`:相似度阈值,默认 0.2。 - `index_names`:指定要搜索的知识库名称列表,可选。 - `kb_page` / `kb_page_size`:分页获取 DataMate 知识库列表。 + - `enable_rerank`:是否启用重排序,默认 False。开启后会对检索结果进行二次排序,提升结果相关性。 + - `rerank_model`:重排序使用的模型,默认为系统配置的 rerank 模型。`enable_rerank` 为 True 时生效。 - 返回包含文件名、下载链接、得分等结构化结果。 ### dify_search @@ -58,6 +62,8 @@ title: 搜索工具 - **检索参数**: - `query`:检索问题,必填。 - `search_method`:搜索方法,选项:`keyword_search`、`semantic_search`、`full_text_search`、`hybrid_search`,默认 `semantic_search`。 + - `enable_rerank`:是否启用重排序,默认 False。开启后会对检索结果进行二次排序,提升结果相关性。 + - `rerank_model`:重排序使用的模型,默认为系统配置的 rerank 模型。`enable_rerank` 为 True 时生效。 - 返回匹配片段的标题、内容、得分等。 ### exa_search / tavily_search / linkup_search @@ -79,7 +85,8 @@ title: 搜索工具 1. **选择数据源**:私有资料用 `knowledge_base_search`、`datamate_search` 或 `dify_search`;实时公开信息用 Exa/Tavily/Linkup。 2. **设置检索模式/数量**:知识库可在 `search_mode` 之间切换;公网搜索可调整 `max_results` 与是否启用图片过滤。 3. **限定范围**:需要特定知识库时填写 `index_names`,避免无关结果;DataMate 可通过阈值与 top_k 控制结果精度与数量。 -4. **结果利用**:返回为 JSON,可直接用于回答、摘要或后续引用;包含 cite 索引便于引用管理。 +4. **启用重排序(可选)**:如需提升检索结果相关性,可设置 `enable_rerank: true`,并通过 `rerank_top_n` 和 `rerank_model` 调整重排序效果。 +5. **结果利用**:返回为 JSON,可直接用于回答、摘要或后续引用;包含 cite 索引便于引用管理。 ## 🛡️ 安全与最佳实践 diff --git a/doc/docs/zh/user-guide/model-management.md b/doc/docs/zh/user-guide/model-management.md index b715ebc1a..46c1b25b4 100644 --- a/doc/docs/zh/user-guide/model-management.md +++ b/doc/docs/zh/user-guide/model-management.md @@ -52,7 +52,7 @@ Nexent支持与ModelEngine平台的无缝对接 1. **添加自定义模型** - 点击"添加自定义模型"按钮,进入添加模型弹窗。 2. **选择模型类型** - - 点击模型类型下拉框,选择要添加的模型类型(大语言模型/向量化模型/视觉语言模型)。 + - 点击模型类型下拉框,选择要添加的模型类型(大语言模型/向量化模型/视觉语言模型/重排模型)。 3. **配置模型参数** - **模型名称(必填)**:输入请求体中的模型名称。 - **展示名称**:可为模型设置一个展示名称,默认与模型名称相同。 @@ -82,7 +82,7 @@ Nexent支持与ModelEngine平台的无缝对接 2. **选择模型提供商** - 点击模型提供商下拉框,选择模型提供商。 3. **选择模型类型** - - 点击模型类型下拉框,选择要添加的模型类型(大语言模型/向量化模型/视觉语言模型)。 + - 点击模型类型下拉框,选择要添加的模型类型(大语言模型/向量化模型/视觉语言模型/重排模型)。 4. **输入API Key(必填)** - 输入您的API密钥。 5. **获取模型** @@ -150,6 +150,10 @@ Nexent支持与ModelEngine平台的无缝对接 +#### 重排模型 +重排模型用于初筛后的文档进行语义匹配与评分,确保最相关的核心答案能够排在首位,以提升检索的准确性和效率。配置合适的重排模型,可以显著提升知识库的检索效果。 + +- 点击重排模型下拉框,从已添加的重排模型中选择一个。 #### 多模态模型 @@ -161,6 +165,7 @@ Nexent支持与ModelEngine平台的无缝对接
+
@@ -215,6 +220,8 @@ Nexent 支持任何 **遵循OpenAI API规范** 的大语言模型供应商,包 使用与大语言模型相同的API Key,但模型URL一般会有所差异,一般以`/v1/embeddings`为结尾,同时指定向量模型名称,如硅基流动提供的**BAAI/bge-m3**。 +#### 🔃 重排模型 +使用与大语言模型相同的API Key,但模型URL一般会有所差异,一般以`/v1/rerank`为结尾。 #### 🎤 语音模型 目前仅支持火山引擎语音,且需要在`.env`中进行配置 diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index b458e6948..4808bd765 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -50,6 +50,44 @@ const TOOLS_REQUIRING_KB_SELECTION = [ "idata_search", ]; +const TOOLS_SUPPORTING_RERANK = [ + "knowledge_base_search", + "dify_search", + "datamate_search", +]; + +function withRerankParams(params: ToolParam[], toolName?: string): ToolParam[] { + if (!toolName || !TOOLS_SUPPORTING_RERANK.includes(toolName)) return params; + + const hasRerank = params.some((p) => p.name === "rerank"); + const hasRerankModelName = params.some((p) => p.name === "rerank_model_name"); + if (hasRerank && hasRerankModelName) return params; + + const next = [...params]; + + if (!hasRerank) { + next.push({ + name: "rerank", + type: "boolean", + required: false, + value: false, + description: "Whether to enable reranking for search results", + }); + } + + if (!hasRerankModelName) { + next.push({ + name: "rerank_model_name", + type: "string", + required: false, + value: "", + description: "The name of the rerank model to use", + }); + } + + return next; +} + export default function ToolConfigModal({ isOpen, onCancel, @@ -478,15 +516,16 @@ export default function ToolConfigModal({ // If server_url already has a saved value, use it if (serverUrlParam?.value) { // Initialize form with saved values (including server_url) - setCurrentParams(initialParams); + const paramsWithRerank = withRerankParams(initialParams, tool.name); + setCurrentParams(paramsWithRerank); const formValues: Record = {}; - initialParams.forEach((param, index) => { + paramsWithRerank.forEach((param, index) => { formValues[`param_${index}`] = param.value; }); form.setFieldsValue(formValues); // Parse initial index_names/dataset_ids value for knowledge base selection - const kbParam = initialParams.find( + const kbParam = paramsWithRerank.find( (p) => p.name === "index_names" || p.name === "dataset_ids" ); if (kbParam?.value) { @@ -521,18 +560,20 @@ export default function ToolConfigModal({ return param; }); - setCurrentParams(updatedParams); + const paramsWithRerank = withRerankParams(updatedParams, tool.name); + setCurrentParams(paramsWithRerank); const formValues: Record = {}; - updatedParams.forEach((param, index) => { + paramsWithRerank.forEach((param, index) => { formValues[`param_${index}`] = param.value; }); form.setFieldsValue(formValues); } else { // Either no default available OR user has modified the URL, initialize with initialParams - setCurrentParams(initialParams); + const paramsWithRerank = withRerankParams(initialParams, tool.name); + setCurrentParams(paramsWithRerank); const formValues: Record = {}; - initialParams.forEach((param, index) => { + paramsWithRerank.forEach((param, index) => { formValues[`param_${index}`] = param.value; }); form.setFieldsValue(formValues); @@ -607,10 +648,11 @@ export default function ToolConfigModal({ return param; }); - setCurrentParams(updatedParams); + const paramsWithRerank = withRerankParams(updatedParams, tool.name); + setCurrentParams(paramsWithRerank); const formValues: Record = {}; - updatedParams.forEach((param, index) => { + paramsWithRerank.forEach((param, index) => { formValues[`param_${index}`] = param.value; }); form.setFieldsValue(formValues); @@ -632,9 +674,10 @@ export default function ToolConfigModal({ } // Initialize form values - setCurrentParams(initialParams); + const paramsWithRerank = withRerankParams(initialParams, tool?.name); + setCurrentParams(paramsWithRerank); const formValues: Record = {}; - initialParams.forEach((param, index) => { + paramsWithRerank.forEach((param, index) => { formValues[`param_${index}`] = param.value; }); form.setFieldsValue(formValues); @@ -1176,6 +1219,36 @@ export default function ToolConfigModal({ // Determine if this parameter should be rendered as a select dropdown const isSelectType = options && options.length > 0; + // Special handling for rerank_model_name parameter - show model selector + if (param.name === "rerank_model_name") { + // First try to get the list of available rerank models from config + const rerankConfig = configData?.models?.rerank; + const hasRerankModel = rerankConfig?.modelName; + + if (hasRerankModel) { + // If rerank model is configured, show it as an option + const modelOptions = [{ value: rerankConfig.modelName, label: rerankConfig.displayName || rerankConfig.modelName }]; + return ( +