diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py
index bc4031e0a..0e1ed4d96 100644
--- a/backend/agents/create_agent_info.py
+++ b/backend/agents/create_agent_info.py
@@ -14,6 +14,7 @@
     ElasticSearchService,
     get_vector_db_core,
     get_embedding_model,
+    get_rerank_model,
 )
 from services.remote_mcp_service import get_remote_mcp_server_list
 from services.memory_config_service import build_memory_context
@@ -350,11 +351,32 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
                 tool_config.metadata = langchain_tool
                 break

-    # special logic for knowledge base search tool
+    # special logic for search tools that may use reranking models
     if tool_config.class_name == "KnowledgeBaseSearchTool":
-        tool_config.metadata = {
+        rerank = param_dict.get("rerank", False)
+        rerank_model_name = param_dict.get("rerank_model_name", "")
+        rerank_model = None
+        if rerank and rerank_model_name:
+            rerank_model = get_rerank_model(
+                tenant_id=tenant_id, model_name=rerank_model_name
+            )
+
+        tool_config.metadata = {
             "vdb_core": get_vector_db_core(),
             "embedding_model": get_embedding_model(tenant_id=tenant_id),
+            "rerank_model": rerank_model,
+        }
+    elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]:
+        rerank = param_dict.get("rerank", False)
+        rerank_model_name = param_dict.get("rerank_model_name", "")
+        rerank_model = None
+        if rerank and rerank_model_name:
+            rerank_model = get_rerank_model(
+                tenant_id=tenant_id, model_name=rerank_model_name
+            )
+
+        tool_config.metadata = {
+            "rerank_model": rerank_model,
         }
     elif tool_config.class_name == "AnalyzeTextFileTool":
         tool_config.metadata = {
diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py
index 78f6413ee..9214a1ffa 100644
--- a/backend/services/model_health_service.py
+++ b/backend/services/model_health_service.py
@@ -3,6 +3,7 @@
 from nexent.core import MessageObserver
 from nexent.core.models import OpenAIModel, OpenAIVLModel
 from nexent.core.models.embedding_model import JinaEmbedding, OpenAICompatibleEmbedding
+from nexent.core.models.rerank_model import OpenAICompatibleRerank
 from services.voice_service import get_voice_service
 from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST
@@ -102,7 +103,13 @@ async def _perform_connectivity_check(
             ssl_verify=ssl_verify
         ).check_connectivity()
     elif model_type == "rerank":
-        connectivity = False
+        rerank_model = OpenAICompatibleRerank(
+            model_name=model_name,
+            base_url=model_base_url,
+            api_key=model_api_key,
+            ssl_verify=ssl_verify,
+        )
+        connectivity = await rerank_model.connectivity_check()
     elif model_type == "vlm":
         observer = MessageObserver()
         connectivity = await OpenAIVLModel(
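Note on the two hunks above: `create_tool_config_list` now wires an optional `rerank_model` into the metadata of `KnowledgeBaseSearchTool`, `DifySearchTool`, and `DataMateSearchTool`, and the health check actually probes rerank endpoints instead of hard-coding `connectivity = False`. Both metadata branches repeat the same three-step lookup (read the flag, read the name, resolve the model), and the pattern recurs again in `tool_configuration_service.py` below. If that duplication ever becomes a maintenance concern, it could be folded into one helper along these lines; `_resolve_rerank_model` is hypothetical, while the `param_dict` keys and `get_rerank_model` come from this diff:

```python
# Hypothetical consolidation sketch; not part of this change.
from typing import Any, Optional

from services.vectordatabase_service import get_rerank_model


def _resolve_rerank_model(param_dict: dict, tenant_id: str) -> Optional[Any]:
    """Return a rerank model only when the tool opts in AND names a model."""
    rerank = param_dict.get("rerank", False)
    rerank_model_name = param_dict.get("rerank_model_name", "")
    if rerank and rerank_model_name:
        return get_rerank_model(tenant_id=tenant_id, model_name=rerank_model_name)
    return None
```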
model_dict["max_tokens"] = await embedding_dimension_check(model_dict) + elif model["model_type"] == "rerank": + if provider == ProviderEnum.DASHSCOPE.value: + model_dict["base_url"] = f"{model_url.replace('compatible-mode/v1','api/v1').rstrip('/')}/services/rerank/text-rerank/text-rerank" + else: + model_dict["base_url"] = f"{model_url.rstrip('/')}/rerank" else: # For non-embedding models if provider == ProviderEnum.MODELENGINE.value: diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py index 4ecbcbb1d..b9fb7ab7b 100644 --- a/backend/services/providers/dashscope_provider.py +++ b/backend/services/providers/dashscope_provider.py @@ -58,7 +58,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "chat": [], # Maps to "llm" "vlm": [], # Maps to "vlm" "embedding": [], # Maps to "embedding" / "multi_embedding" - "reranker": [], # Maps to "reranker" + "rerank": [], # Maps to "rerank" "tts": [], # Maps to "tts" "stt": [] # Maps to "stt" } @@ -88,10 +88,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: categorized_models['embedding'].append(cleaned_model) continue - # 2. Reranker + # 2. Rerank if 'rerank' in m_id.lower() or '重排序' in desc: - cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) - categorized_models['reranker'].append(cleaned_model) + cleaned_model.update({"model_tag": "rerank", "model_type": "rerank"}) + categorized_models['rerank'].append(cleaned_model) continue # 3. STT diff --git a/backend/services/providers/silicon_provider.py b/backend/services/providers/silicon_provider.py index 29de51fce..ea41cc95d 100644 --- a/backend/services/providers/silicon_provider.py +++ b/backend/services/providers/silicon_provider.py @@ -30,6 +30,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: silicon_url = f"{SILICON_GET_URL}?sub_type=chat" elif model_type in ("embedding", "multi_embedding"): silicon_url = f"{SILICON_GET_URL}?sub_type=embedding" + elif model_type == "rerank": + silicon_url = f"{SILICON_GET_URL}?sub_type=reranker" else: silicon_url = SILICON_GET_URL @@ -48,6 +50,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: for item in model_list: item["model_tag"] = "embedding" item["model_type"] = model_type + elif model_type == "rerank": + for item in model_list: + item["model_tag"] = "rerank" + item["model_type"] = model_type # Return empty list to indicate successful API call but no models if not model_list: diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index 42e5d178c..ab4446c1b 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -47,7 +47,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "chat": [], # Maps to "llm" "vlm": [], # Maps to "vlm" "embedding": [], # Maps to "embedding" / "multi_embedding" - "reranker": [], # Maps to "reranker" + "rerank": [], # Maps to "rerank" "tts": [], # Maps to "tts" "stt": [] # Maps to "stt" } @@ -66,10 +66,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "model_type": "", "max_tokens": DEFAULT_LLM_MAX_TOKENS } - # 1. reranker + # 1. rerank if 'rerank' in m_id: - cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"}) - categorized_models['reranker'].append(cleaned_model) + cleaned_model.update({"model_tag": "rerank", "model_type": "rerank"}) + categorized_models['rerank'].append(cleaned_model) #2. 
diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py
index 42e5d178c..ab4446c1b 100644
--- a/backend/services/providers/tokenpony_provider.py
+++ b/backend/services/providers/tokenpony_provider.py
@@ -47,7 +47,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
             "chat": [],  # Maps to "llm"
             "vlm": [],  # Maps to "vlm"
             "embedding": [],  # Maps to "embedding" / "multi_embedding"
-            "reranker": [],  # Maps to "reranker"
+            "rerank": [],  # Maps to "rerank"
             "tts": [],  # Maps to "tts"
             "stt": []  # Maps to "stt"
         }
@@ -66,10 +66,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
                 "model_type": "",
                 "max_tokens": DEFAULT_LLM_MAX_TOKENS
             }
-            # 1. reranker
+            # 1. rerank
             if 'rerank' in m_id:
-                cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"})
-                categorized_models['reranker'].append(cleaned_model)
+                cleaned_model.update({"model_tag": "rerank", "model_type": "rerank"})
+                categorized_models['rerank'].append(cleaned_model)
             #2. embedding
             elif 'embedding' in m_id or m_id.startswith('bge-'):
                 cleaned_model.update({"model_tag": "embedding", "model_type": "embedding"})
diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py
index a0f5b2399..9653b2e10 100644
--- a/backend/services/tool_configuration_service.py
+++ b/backend/services/tool_configuration_service.py
@@ -28,7 +28,7 @@
     check_tool_list_initialized,
 )
 from services.file_management_service import get_llm_model
-from services.vectordatabase_service import get_embedding_model, get_vector_db_core
+from services.vectordatabase_service import get_embedding_model, get_rerank_model, get_vector_db_core
 from database.client import minio_client
 from services.image_service import get_vlm_model
 from utils.tool_utils import get_local_tools_classes, get_local_tools_description_zh
@@ -694,10 +694,32 @@ def _validate_local_tool(
     if tool_name == "knowledge_base_search":
         embedding_model = get_embedding_model(tenant_id=tenant_id)
         vdb_core = get_vector_db_core()
+
+        # Get rerank configuration
+        rerank = instantiation_params.get("rerank", False)
+        rerank_model_name = instantiation_params.get("rerank_model_name", "")
+        rerank_model = None
+        if rerank and rerank_model_name:
+            rerank_model = get_rerank_model(tenant_id=tenant_id, model_name=rerank_model_name)
+
         params = {
             **instantiation_params,
             'vdb_core': vdb_core,
             'embedding_model': embedding_model,
+            'rerank_model': rerank_model,
+        }
+        tool_instance = tool_class(**params)
+    elif tool_name in ["dify_search", "datamate_search"]:
+        # Get rerank configuration for dify and datamate search tools
+        rerank = instantiation_params.get("rerank", False)
+        rerank_model_name = instantiation_params.get("rerank_model_name", "")
+        rerank_model = None
+        if rerank and rerank_model_name:
+            rerank_model = get_rerank_model(tenant_id=tenant_id, model_name=rerank_model_name)
+
+        params = {
+            **instantiation_params,
+            'rerank_model': rerank_model,
         }
         tool_instance = tool_class(**params)
     elif tool_name == "analyze_image":
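`_validate_local_tool` mirrors the agent-creation path: the same `rerank` / `rerank_model_name` pair is read from `instantiation_params` before the tool class is instantiated, for `knowledge_base_search` as well as the new `dify_search` / `datamate_search` branch. For illustration, a payload that would exercise the new branch might look like the following; only the two rerank keys are actually read by this diff, and every value is invented:

```python
# Invented example payload; only "rerank" and "rerank_model_name" are
# consumed by the new validation branch.
instantiation_params = {
    "top_k": 5,                                      # hypothetical tool parameter
    "rerank": True,                                  # opt-in flag
    "rerank_model_name": "BAAI/bge-reranker-v2-m3",  # "model_repo/model_name" form
}
```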
diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py
index de79c812c..cf8f7f98c 100644
--- a/backend/services/vectordatabase_service.py
+++ b/backend/services/vectordatabase_service.py
@@ -21,6 +21,7 @@
 from fastapi import Body, Depends, Path, Query
 from fastapi.responses import StreamingResponse
 from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, BaseEmbedding
+from nexent.core.models.rerank_model import OpenAICompatibleRerank, BaseRerank
 from nexent.vector_database.base import VectorDatabaseCore
 from nexent.vector_database.elasticsearch_core import ElasticSearchCore
 from nexent.vector_database.datamate_core import DataMateCore
@@ -241,6 +242,52 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None):
     return None


+def get_rerank_model(tenant_id: str, model_name: Optional[str] = None):
+    """
+    Get the rerank model for the tenant, optionally using a specific model name.
+
+    Args:
+        tenant_id: Tenant ID
+        model_name: Optional specific model name to use (format: "model_repo/model_name" or just "model_name").
+                    If provided, will try to find the model in the tenant's model list.
+
+    Returns:
+        Rerank model instance or None
+    """
+    # If model_name is provided, try to find it in the tenant's models
+    if model_name:
+        try:
+            models = get_model_records({"model_type": "rerank"}, tenant_id)
+            for model in models:
+                model_display_name = model.get("model_repo") + "/" + model["model_name"] if model.get("model_repo") else model["model_name"]
+                if model_display_name == model_name:
+                    # Found the model, create a rerank model instance
+                    return OpenAICompatibleRerank(
+                        model_name=get_model_name_from_config(model) or "",
+                        base_url=model.get("base_url", ""),
+                        api_key=model.get("api_key", ""),
+                        ssl_verify=model.get("ssl_verify", True),
+                    )
+        except Exception as e:
+            logger.warning(f"Failed to get rerank model by name {model_name}: {e}")
+
+    # Fall back to the default rerank model
+    model_config = tenant_config_manager.get_model_config(
+        key="RERANK_ID", tenant_id=tenant_id)
+
+    model_type = model_config.get("model_type", "")
+
+    if model_type == "rerank":
+        return OpenAICompatibleRerank(
+            model_name=get_model_name_from_config(model_config) or "",
+            base_url=model_config.get("base_url", ""),
+            api_key=model_config.get("api_key", ""),
+            ssl_verify=model_config.get("ssl_verify", True),
+        )
+    else:
+        return None
+
+
 class ElasticSearchService:
     @staticmethod
     async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCore, user_id: str):
diff --git a/doc/docs/zh/user-guide/agent-development.md b/doc/docs/zh/user-guide/agent-development.md
index cb4b4055d..67d3c8311 100644
--- a/doc/docs/zh/user-guide/agent-development.md
+++ b/doc/docs/zh/user-guide/agent-development.md
@@ -130,10 +130,11 @@
    - The retrieval mode `search_mode` (defaults to `hybrid`)
    - The list of knowledge bases to search, `index_names`, e.g. `["医疗", "维生素知识大全"]`
      - If `index_names` is not provided, all knowledge bases selected on the knowledge base page are searched by default
+   - Whether to enable the rerank model (defaults to `false`); once enabled, configure a rerank model to rerank and refine the retrieved results
 6. After completing the inputs, click "Run Test" to start the test and view the results below

(The documentation hunk above is translated from the Chinese source file; the knowledge base names in `index_names` are literal example values and are kept as-is.)
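`get_rerank_model` is the single resolution point the rest of the diff relies on, and its fallback behavior is worth spelling out: failures while resolving a named model are logged at warning level and fall through to the tenant default, so a stale tool configuration degrades gracefully instead of breaking search. A usage sketch under invented identifiers:

```python
# Usage sketch for the new helper. Resolution order, as implemented above:
#   1. a tenant model whose "model_repo/model_name" display name matches
#      the requested model_name,
#   2. otherwise the tenant default stored under the RERANK_ID config key,
#   3. otherwise None.
# The tenant ID and model name below are invented for illustration.
from services.vectordatabase_service import get_rerank_model

rerank_model = get_rerank_model(
    tenant_id="tenant-123",
    model_name="BAAI/bge-reranker-v2-m3",
)
if rerank_model is None:
    # No named match and no RERANK_ID default: the search tools receive
    # rerank_model=None and keep their plain vector-similarity ordering.
    pass
```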