From 0ce6f4f2566af6197075c4a395896c715df9f014 Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Tue, 24 Mar 2026 23:59:08 +0800 Subject: [PATCH 01/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/agents/create_agent_info.py | 3 +- backend/apps/model_managment_app.py | 3 +- backend/apps/vectordatabase_app.py | 17 +- backend/consts/const.py | 4 + backend/data_process/ray_actors.py | 62 ++- backend/database/attachment_db.py | 60 ++- backend/database/db_models.py | 1 + backend/database/knowledge_db.py | 9 + backend/database/model_management_db.py | 11 +- backend/services/config_sync_service.py | 2 +- backend/services/data_process_service.py | 13 +- backend/services/datamate_service.py | 3 +- backend/services/model_health_service.py | 4 +- .../services/tool_configuration_service.py | 7 +- backend/services/vectordatabase_service.py | 78 +++- docker/deploy.sh | 284 +++++++++++- docker/init.sql | 2 + ...dd_is_multimodal_to_knowledge_record_t.sql | 5 + .../components/agentConfig/ToolManagement.tsx | 13 +- .../agentConfig/tool/ToolConfigModal.tsx | 110 ++++- .../knowledges/KnowledgeBaseConfiguration.tsx | 68 ++- .../components/document/DocumentChunk.tsx | 92 ++-- .../components/document/DocumentList.tsx | 31 ++ .../knowledge/KnowledgeBaseList.tsx | 26 +- .../contexts/KnowledgeBaseContext.tsx | 36 +- .../components/model/ModelAddDialog.tsx | 5 +- .../components/model/ModelEditDialog.tsx | 3 +- .../models/components/modelConfig.tsx | 3 +- .../components/resources/ModelList.tsx | 7 +- .../KnowledgeBaseSelectorModal.tsx | 206 +++++---- frontend/const/agentConfig.ts | 4 +- frontend/const/knowledgeBaseLayout.ts | 2 + frontend/hooks/useConfig.ts | 4 + frontend/public/locales/en/common.json | 4 +- frontend/public/locales/zh/common.json | 4 +- frontend/services/api.ts | 4 +- frontend/services/knowledgeBaseService.ts | 20 + frontend/services/modelService.ts | 5 +- frontend/types/knowledgeBase.ts | 3 + sdk/nexent/core/models/embedding_model.py | 8 +- .../core/tools/knowledge_base_search_tool.py | 120 ++++- sdk/nexent/core/utils/favicon_extractor.py | 50 +-- sdk/nexent/data_process/core.py | 36 +- sdk/nexent/data_process/extract_image.py | 413 ++++++++++++++++++ .../vector_database/elasticsearch_core.py | 159 +++++-- test/backend/agents/test_create_agent_info.py | 57 ++- test/backend/app/test_model_managment_app.py | 8 +- test/backend/app/test_vectordatabase_app.py | 66 ++- test/backend/data_process/test_ray_actors.py | 80 ++++ test/backend/database/test_attachment_db.py | 48 +- test/backend/database/test_knowledge_db.py | 106 ++++- .../database/test_model_managment_db.py | 28 ++ .../services/test_config_sync_service.py | 32 +- .../services/test_data_process_service.py | 15 + .../backend/services/test_datamate_service.py | 1 + .../services/test_model_health_service.py | 85 +--- .../test_tool_configuration_service.py | 65 ++- .../services/test_vectordatabase_service.py | 179 +++++++- test/sdk/core/models/test_embedding_model.py | 16 + .../tools/test_knowledge_base_search_tool.py | 29 ++ test/sdk/core/utils/test_favicon_extractor.py | 38 ++ test/sdk/data_process/test_core.py | 106 ++++- test/sdk/data_process/test_extract_image.py | 118 +++++ .../test_elasticsearch_core.py | 107 ++++- 64 files changed, 2804 insertions(+), 384 deletions(-) create mode 100644 docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql create mode 100644 sdk/nexent/data_process/extract_image.py create mode 100644 test/sdk/core/utils/test_favicon_extractor.py create mode 100644 test/sdk/data_process/test_extract_image.py diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index faed9ce79..8fb3ec51c 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -239,9 +239,10 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int # special logic for knowledge base search tool if tool_config.class_name == "KnowledgeBaseSearchTool": + is_multimodal = tool_config.params.pop("multimodal", False) tool_config.metadata = { "vdb_core": get_vector_db_core(), - "embedding_model": get_embedding_model(tenant_id=tenant_id), + "embedding_model": get_embedding_model(tenant_id=tenant_id, is_multimodal=is_multimodal), } elif tool_config.class_name == "AnalyzeTextFileTool": tool_config.metadata = { diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py index 0a5a04139..9dfc951fc 100644 --- a/backend/apps/model_managment_app.py +++ b/backend/apps/model_managment_app.py @@ -298,6 +298,7 @@ async def get_llm_model_list(authorization: Optional[str] = Header(None)): @router.post("/healthcheck") async def check_model_health( display_name: str = Query(..., description="Display name to check"), + modelType: str = Query(..., description="..."), authorization: Optional[str] = Header(None) ): """Check and update model connectivity, returning the latest status. @@ -308,7 +309,7 @@ async def check_model_health( """ try: _, tenant_id = get_current_user_id(authorization) - result = await check_model_connectivity(display_name, tenant_id) + result = await check_model_connectivity(display_name, tenant_id, modelType) return JSONResponse(status_code=HTTPStatus.OK, content={ "message": "Successfully checked model connectivity", "data": result diff --git a/backend/apps/vectordatabase_app.py b/backend/apps/vectordatabase_app.py index 04ea9820f..177b90272 100644 --- a/backend/apps/vectordatabase_app.py +++ b/backend/apps/vectordatabase_app.py @@ -18,7 +18,7 @@ from services.redis_service import get_redis_service from utils.auth_utils import get_current_user_id from utils.file_management_utils import get_all_files_status -from database.knowledge_db import get_index_name_by_knowledge_name +from database.knowledge_db import get_index_name_by_knowledge_name, get_knowledge_record router = APIRouter(prefix="/indices") service = ElasticSearchService() @@ -65,9 +65,15 @@ def create_new_index( # Extract optional fields from request body ingroup_permission = None group_ids = None + is_multimodal = False + embedding_model_name: Optional[str] = None if request: ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") + is_multimodal = request.get("is_multimodal") + embedding_model_name = request.get("embeddingModel") or request.get("embedding_model") + if isinstance(embedding_model_name, str): + embedding_model_name = embedding_model_name.strip() or None # Treat path parameter as user-facing knowledge base name for new creations return ElasticSearchService.create_knowledge_base( @@ -78,6 +84,7 @@ def create_new_index( tenant_id=tenant_id, ingroup_permission=ingroup_permission, group_ids=group_ids, + is_multimodal=is_multimodal, ) except Exception as e: raise HTTPException( @@ -121,6 +128,7 @@ async def update_index( knowledge_name = request.get("knowledge_name") ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") + is_multimodal = request.get("is_multimodal") # Call service layer to update knowledge base result = ElasticSearchService.update_knowledge_base( @@ -128,6 +136,7 @@ async def update_index( knowledge_name=knowledge_name, ingroup_permission=ingroup_permission, group_ids=group_ids, + is_multimodal=is_multimodal, tenant_id=tenant_id, user_id=user_id, ) @@ -195,7 +204,11 @@ def create_index_documents( """ try: user_id, tenant_id = get_current_user_id(authorization) - embedding_model = get_embedding_model(tenant_id) + + knowledge_record = get_knowledge_record({'index_name': index_name}) + is_multimodal = True if knowledge_record.get( + 'is_multimodal') == 'Y' else False + embedding_model = get_embedding_model(tenant_id, is_multimodal) return ElasticSearchService.index_documents( embedding_model=embedding_model, index_name=index_name, diff --git a/backend/consts/const.py b/backend/consts/const.py index 2a48da8ce..fc26606f0 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -27,6 +27,10 @@ class VectorDatabaseType(str, Enum): # Data Processing Service Configuration DATA_PROCESS_SERVICE = os.getenv("DATA_PROCESS_SERVICE") CLIP_MODEL_PATH = os.getenv("CLIP_MODEL_PATH") +TABLE_TRANSFORMER_MODEL_PATH = os.getenv("TABLE_TRANSFORMER_MODEL_PATH") +UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = os.getenv( + "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" +) # Upload Configuration diff --git a/backend/data_process/ray_actors.py b/backend/data_process/ray_actors.py index 2fa590bec..c64c173fd 100644 --- a/backend/data_process/ray_actors.py +++ b/backend/data_process/ray_actors.py @@ -1,11 +1,19 @@ +from io import BytesIO import logging import json from typing import Any, Dict, List, Optional import ray -from consts.const import RAY_ACTOR_NUM_CPUS, REDIS_BACKEND_URL, DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE -from database.attachment_db import get_file_stream +from consts.const import ( + RAY_ACTOR_NUM_CPUS, + REDIS_BACKEND_URL, + DEFAULT_EXPECTED_CHUNK_SIZE, + DEFAULT_MAXIMUM_CHUNK_SIZE, + TABLE_TRANSFORMER_MODEL_PATH, + UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH, +) +from database.attachment_db import build_s3_url, get_file_stream, upload_fileobj from database.model_management_db import get_model_by_model_id from nexent.data_process import DataProcessCore @@ -58,6 +66,11 @@ def process_file( if task_id: params['task_id'] = task_id + params["table_transformer_model_path"] = TABLE_TRANSFORMER_MODEL_PATH + params[ + "unstructured_default_model_initialize_params_json_path" + ] = UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH + # Get chunk size parameters from embedding model if model_id is provided if model_id and tenant_id: try: @@ -95,12 +108,55 @@ def process_file( logger.error(f"Failed to fetch file from {source}: {e}") raise - chunks = self._processor.file_process( + result = self._processor.file_process( file_data=file_data, filename=source, chunking_strategy=chunking_strategy, **params ) + if isinstance(result, tuple) and len(result) == 2: + chunks, images_info = result + else: + chunks = result + images_info = [] + + if len(images_info) > 0: + folder = "images_in_attachments" + for index, image_data in enumerate(images_info): + if not isinstance(image_data, dict): + logger.warning( + f"[RayActor] Skipping image entry at index {index}: unexpected type {type(image_data)}" + ) + continue + if "image_bytes" not in image_data: + logger.warning( + f"[RayActor] Skipping image entry at index {index}: missing image_bytes" + ) + continue + + img_obj = BytesIO(image_data["image_bytes"]) + result = upload_fileobj( + file_obj=img_obj, + file_name=f"{index}.{image_data['image_format']}", + prefix=folder) + + image_data["source_file"] = source + image_data["image_url"] = build_s3_url(result.get("object_name", "")) + + + chunks.append({ + "content": json.dumps({ + "source_file": source, + "position": image_data["position"], + "image_url": build_s3_url(result.get("object_name", "")) + }), + "filename": source, + "metadata": { + "chunk_index": len(chunks) + index, + "process_source": "UniversalImageExtractor", + "image_url": build_s3_url(result.get("object_name", "")) + } + }) if chunks is None: logger.warning( diff --git a/backend/database/attachment_db.py b/backend/database/attachment_db.py index 2e6249468..88d0e9160 100644 --- a/backend/database/attachment_db.py +++ b/backend/database/attachment_db.py @@ -2,11 +2,63 @@ import os import uuid from datetime import datetime -from typing import Any, BinaryIO, Dict, List, Optional +from typing import Any, BinaryIO, Dict, List, Optional, Tuple from .client import minio_client +def _normalize_object_and_bucket(object_name: str, bucket: Optional[str] = None) -> Tuple[str, Optional[str]]: + """ + Normalize object_name + bucket from supported URL styles. + + Supports: + - s3://bucket/key + - /bucket/key + - key (uses provided bucket or default bucket) + """ + if not object_name: + return object_name, bucket + + if object_name.startswith("s3://"): + s3_path = object_name[len("s3://") :] + parts = s3_path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + return object_name, bucket + + +def build_s3_url(object_name: str, bucket: Optional[str] = None) -> str: + """ + Build an s3://bucket/key style URL from an object name (or passthrough if already s3://). + """ + if not object_name: + return "" + + if object_name.startswith("s3://"): + return object_name + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + if len(parts) == 2: + return f"s3://{parts[0]}/{parts[1]}" + return f"s3://{parts[0]}/" + + resolved_bucket = bucket or minio_client.storage_config.default_bucket + if resolved_bucket: + return f"s3://{resolved_bucket}/{object_name}" + return f"s3://{object_name}" + + def generate_object_name(file_name: str, prefix: str = "attachments") -> str: """ Generate a unique object name @@ -165,6 +217,7 @@ def get_file_size_from_minio(object_name: str, bucket: Optional[str] = None) -> """ Get file size by object name """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) bucket = bucket or minio_client.storage_config.default_bucket return minio_client.get_file_size(object_name, bucket) @@ -181,6 +234,7 @@ def file_exists(object_name: str, bucket: Optional[str] = None) -> bool: bool: True if file exists, False otherwise """ try: + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) return minio_client.file_exists(object_name, bucket) except Exception: return False @@ -198,6 +252,8 @@ def copy_file(source_object: str, dest_object: str, bucket: Optional[str] = None Returns: Dict[str, Any]: Result containing success flag and error message (if any) """ + source_object, bucket = _normalize_object_and_bucket(source_object, bucket) + dest_object, bucket = _normalize_object_and_bucket(dest_object, bucket) success, result = minio_client.copy_file(source_object, dest_object, bucket) if success: return {"success": True, "object_name": result} @@ -242,6 +298,7 @@ def delete_file(object_name: str, bucket: Optional[str] = None) -> Dict[str, Any Returns: Dict[str, Any]: Delete result, containing success flag and error message (if any) """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) if not bucket: bucket = minio_client.storage_config.default_bucket success, result = minio_client.delete_file(object_name, bucket) @@ -265,6 +322,7 @@ def get_file_stream(object_name: str, bucket: Optional[str] = None) -> Optional[ Returns: Optional[BinaryIO]: Standard BinaryIO stream object, or None if failed """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) success, result = minio_client.get_file_stream(object_name, bucket) if not success: return None diff --git a/backend/database/db_models.py b/backend/database/db_models.py index 80dcc87eb..24dfd26df 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -271,6 +271,7 @@ class KnowledgeRecord(TableBase): group_ids = Column(String, doc="Knowledge base group IDs list") ingroup_permission = Column( String(30), doc="In-group permission: EDIT, READ_ONLY, PRIVATE") + is_multimodal = Column(String(1), default="N", doc="Whether it is multimodal. Optional values: Y/N") class TenantConfig(TableBase): diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index df42e1888..40f4ca718 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -52,6 +52,7 @@ def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: "knowledge_name": knowledge_name, "group_ids": convert_list_to_string(group_ids) if isinstance(group_ids, list) else group_ids, "ingroup_permission": query.get("ingroup_permission"), + "is_multimodal": 'Y' if query.get("is_multimodal") else 'N' } # For backward compatibility: if caller explicitly provides index_name, @@ -178,6 +179,9 @@ def update_knowledge_record(query: Dict[str, Any]) -> bool: if query.get("group_ids") is not None: record.group_ids = query["group_ids"] + if query.get("is_multimodal"): + record.is_multimodal = 'Y' if query["is_multimodal"] else 'N' + # Update timestamp and user if query.get("user_id"): record.updated_by = query["user_id"] @@ -254,6 +258,11 @@ def get_knowledge_record(query: Optional[Dict[str, Any]] = None) -> Dict[str, An db_query = db_query.filter( KnowledgeRecord.tenant_id == query['tenant_id']) + if 'is_multimodal' in query: + db_query = db_query.filter( + KnowledgeRecord.is_multimodal == query['is_multimodal'] + ) + result = db_query.first() if result: diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index cb1c6c69f..61753f52f 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -170,7 +170,7 @@ def get_model_records(filters: Optional[Dict[str, Any]], tenant_id: str) -> List return result_list -def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dict[str, Any]]: +def get_model_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[Dict[str, Any]]: """ Get a model record by display name @@ -179,6 +179,11 @@ def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dic tenant_id: """ filters = {'display_name': display_name} + + if model_type in ["multiEmbedding", "multi_embedding"]: + filters['model_type'] = "multi_embedding" + elif model_type == "embedding": + filters['model_type'] = "embedding" records = get_model_records(filters, tenant_id) if not records: @@ -203,7 +208,7 @@ def get_models_by_display_name(display_name: str, tenant_id: str) -> List[Dict[s return get_model_records(filters, tenant_id) -def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[int]: +def get_model_id_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[int]: """ Get a model ID by display name @@ -214,7 +219,7 @@ def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[ Returns: Optional[int]: Model ID """ - model = get_model_by_display_name(display_name, tenant_id) + model = get_model_by_display_name(display_name, tenant_id, model_type) return model["model_id"] if model else None diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py index 9fe50813a..c484ca23f 100644 --- a/backend/services/config_sync_service.py +++ b/backend/services/config_sync_service.py @@ -99,7 +99,7 @@ async def save_config_impl(config, tenant_id, user_id): config_key = get_env_key(model_type) + "_ID" model_id = get_model_id_by_display_name( - model_display_name, tenant_id) + model_display_name, tenant_id, model_type=model_type) handle_model_config(tenant_id, user_id, config_key, model_id, tenant_config_dict) diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py index 8c44c15e6..9eae72407 100644 --- a/backend/services/data_process_service.py +++ b/backend/services/data_process_service.py @@ -255,6 +255,17 @@ async def load_image(self, image_url: str) -> Optional[Image.Image]: async def _load_image(self, session: aiohttp.ClientSession, path: str) -> Optional[Image.Image]: """Internal method to load an image from various sources""" try: + if path.startswith('s3://'): + # Fetch from MinIO using s3://bucket/key + file_stream = get_file_stream(object_name=path) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {path}") + file_data = file_stream.read() + image_based64_str = base64.b64encode( + file_data).decode('utf-8') + path = f"data:image/jpeg;base64,{image_based64_str}" + # Check if input is base64 encoded if path.startswith('data:image'): # Extract the base64 data after the comma @@ -559,7 +570,7 @@ async def process_uploaded_text_file(self, file_content: bytes, filename: str, c } async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: str) -> None: - """Full conversion pipeline: download → convert → upload → validate → cleanup. + """Full conversion pipeline: download -> convert -> upload -> validate -> cleanup. All five steps run inside data-process so that LibreOffice only needs to be installed in this container. diff --git a/backend/services/datamate_service.py b/backend/services/datamate_service.py index 776e0eb1d..26e777eba 100644 --- a/backend/services/datamate_service.py +++ b/backend/services/datamate_service.py @@ -51,7 +51,8 @@ async def _create_datamate_knowledge_records(knowledge_base_ids: List[str], "tenant_id": tenant_id, "user_id": user_id, # Use datamate as embedding model name - "embedding_model_name": embedding_model_names[i] + "embedding_model_name": embedding_model_names[i], + "is_multimodal": False, } # Run synchronous database operation in executor to avoid blocking diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index 78f6413ee..950bd3936 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -121,10 +121,10 @@ async def _perform_connectivity_check( return connectivity -async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: +async def check_model_connectivity(display_name: str, tenant_id: str, model_type: str = None) -> dict: try: # Query the database using display_name and tenant context from app layer - model = get_model_by_display_name(display_name, tenant_id=tenant_id) + model = get_model_by_display_name(display_name, tenant_id=tenant_id, model_type=model_type) if not model: raise LookupError(f"Model configuration not found for {display_name}") diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index 3e7b22d11..cc0cb28a1 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -120,6 +120,10 @@ def get_local_tools() -> List[ToolInfo]: else: param_info["default"] = param.default.default param_info["optional"] = True + if getattr(param.default, "json_schema_extra", None): + optional_override = param.default.json_schema_extra.get("optional") + if optional_override is not None: + param_info["optional"] = optional_override init_params_list.append(param_info) @@ -632,7 +636,8 @@ def _validate_local_tool( instantiation_params[param_name] = param.default if tool_name == "knowledge_base_search": - embedding_model = get_embedding_model(tenant_id=tenant_id) + is_multimodal = instantiation_params.pop("multimodal", False) + embedding_model = get_embedding_model(tenant_id=tenant_id, is_multimodal=is_multimodal) vdb_core = get_vector_db_core() params = { **instantiation_params, diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index e32f005a3..d798dd4a8 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -27,7 +27,7 @@ from consts.const import DATAMATE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType, IS_SPEED_MODE, PERMISSION_EDIT, PERMISSION_READ from consts.model import ChunkCreateRequest, ChunkUpdateRequest -from database.attachment_db import delete_file +from database.attachment_db import delete_file, get_file_stream from database.knowledge_db import ( create_knowledge_record, delete_knowledge_record, @@ -175,7 +175,7 @@ def check_knowledge_base_exist_impl(knowledge_name: str, vdb_core: VectorDatabas return {"status": "available"} -def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): +def get_embedding_model(tenant_id: str, is_multimodal: bool = False, model_name: Optional[str] = None): """ Get the embedding model for the tenant, optionally using a specific model name. @@ -190,7 +190,8 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): # If model_name is provided, try to find it in the tenant's models if model_name: try: - models = get_model_records({"model_type": "embedding"}, tenant_id) + model_type = "multi_embedding" if is_multimodal else "embedding" + models = get_model_records({"model_type": model_type}, tenant_id) for model in models: model_display_name = model.get("model_repo") + "/" + model["model_name"] if model.get("model_repo") else model["model_name"] if model_display_name == model_name: @@ -204,19 +205,29 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): "max_tokens": model.get("max_tokens", 1024), "ssl_verify": model.get("ssl_verify", True), } - return OpenAICompatibleEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), - ) + if not is_multimodal: + return OpenAICompatibleEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + else: + return JinaEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config( + model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) except Exception as e: logger.warning(f"Failed to get embedding model by name {model_name}: {e}") # Fall back to default embedding model (current behavior) model_config = tenant_config_manager.get_model_config( - key="EMBEDDING_ID", tenant_id=tenant_id) + key="MULTI_EMBEDDING_ID" if is_multimodal else "EMBEDDING_ID", tenant_id=tenant_id) model_type = model_config.get("model_type", "") @@ -359,6 +370,7 @@ async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCo @staticmethod def create_index( + embedding_model: BaseEmbedding, index_name: str = Path(..., description="Name of the index to create"), embedding_dim: Optional[int] = Query( @@ -372,15 +384,22 @@ def create_index( try: if vdb_core.check_index_exists(index_name): raise Exception(f"Index {index_name} already exists") - embedding_model = get_embedding_model(tenant_id) + success = vdb_core.create_index(index_name, embedding_dim=embedding_dim or ( embedding_model.embedding_dim if embedding_model else 1024)) if not success: raise Exception(f"Failed to create index {index_name}") - knowledge_data = {"index_name": index_name, - "created_by": user_id, - "tenant_id": tenant_id, - "embedding_model_name": embedding_model.model} + is_multimodal = ( + True + if embedding_model and getattr(embedding_model, "model_type", None) == "multimodal" + else False + ) + knowledge_data = { + "index_name": index_name, + "created_by": user_id, + "tenant_id": tenant_id, + "is_multimodal": is_multimodal, + } create_knowledge_record(knowledge_data) return {"status": "success", "message": f"Index {index_name} created successfully"} except Exception as e: @@ -395,6 +414,7 @@ def create_knowledge_base( tenant_id: Optional[str], ingroup_permission: Optional[str] = None, group_ids: Optional[List[int]] = None, + is_multimodal: bool = False, ): """ Create a new knowledge base with a user-facing name and an internal Elasticsearch index name. @@ -408,7 +428,7 @@ def create_knowledge_base( with an explicit index_name. """ try: - embedding_model = get_embedding_model(tenant_id) + embedding_model = get_embedding_model(tenant_id, is_multimodal,) # Create knowledge record first to obtain knowledge_id and generated index_name knowledge_data = { @@ -417,6 +437,7 @@ def create_knowledge_base( "user_id": user_id, "tenant_id": tenant_id, "embedding_model_name": embedding_model.model if embedding_model else None, + "is_multimodal": is_multimodal, } # Add group permission and group IDs if provided @@ -453,6 +474,7 @@ def update_knowledge_base( knowledge_name: Optional[str] = None, ingroup_permission: Optional[str] = None, group_ids: Optional[List[int]] = None, + is_multimodal: bool = False, tenant_id: Optional[str] = None, user_id: Optional[str] = None, ) -> bool: @@ -483,6 +505,7 @@ def update_knowledge_base( update_data = { "index_name": index_name, "updated_by": user_id, + "is_multimodal": is_multimodal, } if knowledge_name is not None: @@ -715,6 +738,7 @@ def list_indices( "display_name": record.get("knowledge_name", index_name), "permission": record["permission"], "group_ids": record["group_ids"], + "is_multimodal": True if record.get("is_multimodal") == "Y" else False, # knowledge source and ingroup permission from DB record "knowledge_sources": record["knowledge_sources"], "ingroup_permission": record["ingroup_permission"], @@ -816,15 +840,30 @@ def index_documents( "author": author, "date": date, "content": text, - "process_source": "Unstructured", + "process_source": metadata.get("process_source", "Unstructured"), "file_size": file_size, "create_time": create_time, "languages": metadata.get("languages", []), "embedding_model_name": embedding_model_name } + image_url = metadata.get("image_url", "") + if len(image_url) > 0: + # Fetch image bytes from MinIO (supports s3://bucket/key or /bucket/key) + try: + file_stream = get_file_stream( + object_name=image_url) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {image_url}") + document["image_bytes"] = file_stream.read() + except Exception as e: + logger.error( + f"Failed to fetch file from {image_url}: {e}") + raise documents.append(document) + total_submitted = len(documents) if total_submitted == 0: return { @@ -842,8 +881,9 @@ def index_documents( 'tenant_id') if knowledge_record else None if tenant_id: + model_type = "EMBEDDING_ID" if embedding_model.model_type == "text" else "MULTI_EMBEDDING_ID" model_config = tenant_config_manager.get_model_config( - key="EMBEDDING_ID", tenant_id=tenant_id) + key=model_type, tenant_id=tenant_id) embedding_batch_size = model_config.get("chunk_batch", 10) if embedding_batch_size is None: embedding_batch_size = 10 diff --git a/docker/deploy.sh b/docker/deploy.sh index e30e6e75a..f65dd240b 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -17,6 +17,7 @@ DEPLOY_OPTIONS_FILE="$SCRIPT_DIR/deploy.options" MODE_CHOICE_SAVED="" VERSION_CHOICE_SAVED="" IS_MAINLAND_SAVED="" +DOWNLOAD_MODELS="N" ENABLE_TERMINAL_SAVED="N" TERMINAL_MOUNT_DIR_SAVED="${TERMINAL_MOUNT_DIR:-}" APP_VERSION="" @@ -79,6 +80,56 @@ is_windows_env() { return 1 } +detect_os_type() { + # Return: windows | mac | linux | unknown + local os_name + os_name=$(uname -s 2>/dev/null | tr '[:upper:]' '[:lower:]') + case "$os_name" in + mingw*|msys*|cygwin*) + echo "windows" + ;; + darwin*) + echo "mac" + ;; + linux*) + echo "linux" + ;; + *) + echo "unknown" + ;; + esac +} + +format_path_for_env() { + # Convert path to OS-specific format for .env values + local input_path="$1" + local os_type + os_type=$(detect_os_type) + + if [ "$os_type" = "windows" ]; then + if command -v cygpath >/dev/null 2>&1; then + cygpath -w "$input_path" + return 0 + fi + + if [[ "$input_path" =~ ^/([a-zA-Z])/(.*)$ ]]; then + local drive="${BASH_REMATCH[1]}" + local rest="${BASH_REMATCH[2]}" + rest="${rest//\//\\}" + printf "%s:\\%s" "$(echo "$drive" | tr '[:lower:]' '[:upper:]')" "$rest" + return 0 + fi + fi + + printf "%s" "$input_path" +} + +escape_backslashes() { + # Escape backslashes for safe writing into .env or JSON + local input_path="$1" + printf "%s" "$input_path" | sed 's/\\/\\\\/g' +} + is_port_in_use() { # Check if a TCP port is already in use (Linux/macOS/Windows Git Bash) local port="$1" @@ -266,6 +317,7 @@ persist_deploy_options() { echo "MODE_CHOICE=\"${MODE_CHOICE_SAVED}\"" echo "VERSION_CHOICE=\"${VERSION_CHOICE_SAVED}\"" echo "IS_MAINLAND=\"${IS_MAINLAND_SAVED}\"" + echo "DOWNLOAD_MODELS=\"${DOWNLOAD_MODELS}\"" echo "ENABLE_TERMINAL=\"${ENABLE_TERMINAL_SAVED}\"" echo "TERMINAL_MOUNT_DIR=\"${TERMINAL_MOUNT_DIR_SAVED}\"" } > "$DEPLOY_OPTIONS_FILE" @@ -528,6 +580,227 @@ select_deployment_mode() { echo "" } + +# Model download selection +select_model_download() { + echo "" + + local input_choice="" + read -r -p "Do you want to download AI model files (table-transformer and yolox)? [Y/N] (default: N): " input_choice + echo "" + + if [[ $input_choice =~ ^[Yy]$ ]]; then + DOWNLOAD_MODELS="Y" + echo "INFO: Model download will be performed." + else + DOWNLOAD_MODELS="N" + echo "INFO: Skipping model download." + fi + echo "----------------------------------------" + echo "" +} + +# kerry + +download_and_config_models() { + if [ "$DOWNLOAD_MODELS" != "Y" ]; then + echo "INFO: Model download skipped by user choice." + return 0 + fi + + echo "INFO: Downloading AI model files (this may take a while)..." + + local ENV_FILE_DIR="$SCRIPT_DIR" + local ENV_FILE_PATH="$ENV_FILE_DIR/.env" + local ORIGINAL_DIR="$(pwd)" + + MODEL_ROOT="$ROOT_DIR/model" + mkdir -p "$MODEL_ROOT" + echo "INFO: Model directory: $MODEL_ROOT" + + export HF_ENDPOINT="https://hf-mirror.com" + + command -v git >/dev/null || { echo "ERROR: git is required but not found."; return 1; } + + # ========================================== + # 1. Table Transformer (table-structure recognition) + echo "INFO: Downloading table-transformer-structure-recognition..." + + TT_MODEL_DIR_NAME="table-transformer-structure-recognition" + TT_MODEL_DIR_PATH="$MODEL_ROOT/$TT_MODEL_DIR_NAME" + TT_MODEL_FILE_CHECK="$TT_MODEL_DIR_PATH/model.safetensors" + + cd "$MODEL_ROOT" || return 1 + + if [ -d "$TT_MODEL_DIR_PATH" ] && [ -f "$TT_MODEL_FILE_CHECK" ]; then + FILE_SIZE=$(stat -c%s "$TT_MODEL_FILE_CHECK" 2>/dev/null || stat -f%z "$TT_MODEL_FILE_CHECK" 2>/dev/null) + if [ "$FILE_SIZE" -gt 1000000 ]; then + echo "INFO: Table Transformer already exists." + else + echo "WARN: Existing model file looks incomplete, re-downloading..." + rm -rf "$TT_MODEL_DIR_NAME" + fi + fi + + if [ ! -f "$TT_MODEL_FILE_CHECK" ]; then + if [ -d "$TT_MODEL_DIR_NAME" ]; then + echo "WARN: Removing existing directory before re-download..." + rm -rf "$TT_MODEL_DIR_NAME" + fi + + echo "INFO: Step 1/2: Clone repo (skip LFS files)..." + if ! GIT_LFS_SKIP_SMUDGE=1 git clone "$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME" "$TT_MODEL_DIR_NAME"; then + echo "ERROR: Failed to clone repository." + cd "$ORIGINAL_DIR" + return 1 + fi + + cd "$TT_MODEL_DIR_NAME" || return 1 + + echo "INFO: Step 2/2: Download model.safetensors..." + LARGE_FILE_URL="$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME/resolve/main/model.safetensors" + + if command -v curl &> /dev/null; then + curl -L -o "model.safetensors" "$LARGE_FILE_URL" --progress-bar + elif command -v wget &> /dev/null; then + wget "$LARGE_FILE_URL" -O "model.safetensors" + else + echo "ERROR: curl or wget is required to download model files." + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$ORIGINAL_DIR"; return 1 + fi + + if [ ! -f "model.safetensors" ]; then + echo "ERROR: model.safetensors download failed." + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$ORIGINAL_DIR"; return 1 + fi + + FILE_SIZE=$(stat -c%s "model.safetensors" 2>/dev/null || stat -f%z "model.safetensors" 2>/dev/null) + if [ "$FILE_SIZE" -lt 1000000 ]; then + echo "ERROR: model.safetensors seems too small (size: $FILE_SIZE bytes)." + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$ORIGINAL_DIR"; return 1 + fi + + echo "INFO: model.safetensors downloaded (size: $(du -h model.safetensors | cut -f1))" + cd "$MODEL_ROOT" + fi + + echo "INFO: Table Transformer OK" + + # ========================================== + # 2. YOLOX (layout detection model) + echo "INFO: Downloading yolox_l0.05.onnx" + + YOLOX_MODEL_FILE="$MODEL_ROOT/yolox_l0.05.onnx" + MIN_YOLOX_SIZE=50000000 + + NEED_DOWNLOAD=false + + if [ -f "$YOLOX_MODEL_FILE" ]; then + CURRENT_SIZE=$(stat -c%s "$YOLOX_MODEL_FILE" 2>/dev/null || stat -f%z "$YOLOX_MODEL_FILE" 2>/dev/null) + if [ "$CURRENT_SIZE" -lt "$MIN_YOLOX_SIZE" ]; then + echo "WARN: Existing YOLOX file looks incomplete (size: $(numfmt --to=iec-i --suffix=B $CURRENT_SIZE 2>/dev/null || echo $CURRENT_SIZE)). Re-downloading..." + NEED_DOWNLOAD=true + else + echo "INFO: YOLOX already exists." + fi + else + NEED_DOWNLOAD=true + fi + + if [ "$NEED_DOWNLOAD" = true ]; then + ONNX_URL="$HF_ENDPOINT/unstructuredio/yolo_x_layout/resolve/main/yolox_l0.05.onnx" + + if command -v curl &> /dev/null; then + echo "INFO: Downloading with curl (supports resume -C -)..." + if curl -L -C - -o "$YOLOX_MODEL_FILE" "$ONNX_URL" --progress-bar; then + echo "INFO: curl download completed" + else + echo "ERROR: curl download failed." + cd "$ORIGINAL_DIR" + return 1 + fi + elif command -v wget &> /dev/null; then + echo "INFO: Downloading with wget (supports resume -c)..." + wget -c "$ONNX_URL" -O "$YOLOX_MODEL_FILE" + else + echo "ERROR: curl or wget is required to download model files." + cd "$ORIGINAL_DIR" + return 1 + fi + + if [ -f "$YOLOX_MODEL_FILE" ]; then + FINAL_SIZE=$(stat -c%s "$YOLOX_MODEL_FILE" 2>/dev/null || stat -f%z "$YOLOX_MODEL_FILE" 2>/dev/null) + if [ "$FINAL_SIZE" -lt "$MIN_YOLOX_SIZE" ]; then + echo "ERROR: YOLOX file seems too small (size: $FINAL_SIZE bytes)." + cd "$ORIGINAL_DIR" + return 1 + else + echo "INFO: YOLOX downloaded (size: $(numfmt --to=iec-i --suffix=B $FINAL_SIZE 2>/dev/null || echo $FINAL_SIZE))" + fi + else + echo "ERROR: YOLOX download failed: file not found." + cd "$ORIGINAL_DIR" + return 1 + fi + fi + + echo "INFO: YOLOX OK" + + # ========================================== + # 3. config.json + CONFIG_FILE="$MODEL_ROOT/config.json" + YOLOX_ABS_PATH=$(cd "$(dirname "$YOLOX_MODEL_FILE")" && pwd)/$(basename "$YOLOX_MODEL_FILE") + YOLOX_OS_PATH=$(format_path_for_env "$YOLOX_ABS_PATH") + YOLOX_CONFIG_PATH=$(escape_backslashes "$YOLOX_OS_PATH") + + cat > "$CONFIG_FILE" < /dev/null; then + update_env_var "TABLE_TRANSFORMER_MODEL_PATH" "$TT_MODEL_DIR_ENV_PATH" + update_env_var "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" "$CONFIG_FILE_ENV_PATH" + else + sed -i.bak "/^TABLE_TRANSFORMER_MODEL_PATH=/d" "$ENV_FILE_PATH" 2>/dev/null || true + echo "TABLE_TRANSFORMER_MODEL_PATH="$TT_MODEL_DIR_ENV_PATH"" >> "$ENV_FILE_PATH" + + sed -i.bak "/^UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH=/d" "$ENV_FILE_PATH" 2>/dev/null || true + echo "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH="$CONFIG_FILE_ENV_PATH"" >> "$ENV_FILE_PATH" + rm -f "$ENV_FILE_PATH.bak" 2>/dev/null + fi + + echo "INFO: Environment file updated" + cd "$ORIGINAL_DIR" +} + clean() { export MINIO_ACCESS_KEY= export MINIO_SECRET_KEY= @@ -600,6 +873,13 @@ prepare_directory_and_data() { create_dir_with_permission "$ROOT_DIR/minio" 775 create_dir_with_permission "$ROOT_DIR/redis" 775 + echo "📦 Check the status of model configuration..." + download_and_config_models || { + echo "⚠️ A warning occurred during the model configuration step, but subsequent deployment will proceed..." + # Do not exit here; the user may choose N or prefer to continue after a download failure. + } + echo "" + cp -rn volumes $ROOT_DIR chmod -R 775 $ROOT_DIR/volumes echo " 📁 Directory $ROOT_DIR/volumes has been created and permissions set to 775." @@ -1057,6 +1337,8 @@ main_deploy() { select_terminal_tool || { echo "❌ Terminal tool container configuration failed"; exit 1; } choose_image_env || { echo "❌ Image environment setup failed"; exit 1; } + select_model_download || { echo "❌ Model download failed"; exit 1;} + # Set NEXENT_MCP_DOCKER_IMAGE in .env file if [ -n "${NEXENT_MCP_DOCKER_IMAGE:-}" ]; then update_env_var "NEXENT_MCP_DOCKER_IMAGE" "${NEXENT_MCP_DOCKER_IMAGE}" @@ -1142,7 +1424,7 @@ docker_compose_command="" case $version_type in "v1") echo "Detected Docker Compose V1, version: $version_number" - # The version ​​v1.28.0​​ is the minimum requirement in Docker Compose v1 that explicitly supports interpolation syntax with default values like ${VAR:-default} + # The version 1.28.0 is the minimum requirement in Docker Compose v1 for default interpolation syntax. if [[ $version_number < "1.28.0" ]]; then echo "Warning: V1 version is too old, consider upgrading to V2" exit 1 diff --git a/docker/init.sql b/docker/init.sql index 02e99632c..5a07941fe 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -213,6 +213,7 @@ CREATE TABLE IF NOT EXISTS "knowledge_record_t" ( "embedding_model_name" varchar(200) COLLATE "pg_catalog"."default", "group_ids" varchar, "ingroup_permission" varchar(30), + "is_multimodal" varchar(1) COLLATE "pg_catalog"."default" DEFAULT 'N'::character varying, "create_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, "update_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, "delete_flag" varchar(1) COLLATE "pg_catalog"."default" DEFAULT 'N'::character varying, @@ -230,6 +231,7 @@ COMMENT ON COLUMN "knowledge_record_t"."knowledge_sources" IS 'Knowledge base so COMMENT ON COLUMN "knowledge_record_t"."embedding_model_name" IS 'Embedding model name, used to record the embedding model used by the knowledge base'; COMMENT ON COLUMN "knowledge_record_t"."group_ids" IS 'Knowledge base group IDs list'; COMMENT ON COLUMN "knowledge_record_t"."ingroup_permission" IS 'In-group permission: EDIT, READ_ONLY, PRIVATE'; +COMMENT ON COLUMN "knowledge_record_t"."is_multimodal" IS 'whether it is multimodal'; COMMENT ON COLUMN "knowledge_record_t"."create_time" IS 'Creation time, audit field'; COMMENT ON COLUMN "knowledge_record_t"."update_time" IS 'Update time, audit field'; COMMENT ON COLUMN "knowledge_record_t"."delete_flag" IS 'When deleted by user frontend, delete flag will be set to true, achieving soft delete effect. Optional values Y/N'; diff --git a/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql b/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql new file mode 100644 index 000000000..d5b14bfbb --- /dev/null +++ b/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql @@ -0,0 +1,5 @@ +-- Add is_multimodal column to knowledge_record_t table +ALTER TABLE nexent.knowledge_record_t +ADD COLUMN IF NOT EXISTS is_multimodal varchar(1) DEFAULT 'N'; + +COMMENT ON COLUMN nexent.knowledge_record_t.is_multimodal IS 'whether it is multimodal'; diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index c16ab969e..ed0e73d80 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -101,7 +101,8 @@ export default function ToolManagement({ // Use tool list hook for data management const { availableTools } = useToolList(); - const { isVlmAvailable, isEmbeddingAvailable } = useConfig(); + const { isVlmAvailable, isEmbeddingAvailable, isMultiEmbeddingAvailable } = useConfig(); + const isEmbeddingOrMultiAvailable = isEmbeddingAvailable || isMultiEmbeddingAvailable; // Prefetch knowledge bases for KB tools const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases(); @@ -364,7 +365,10 @@ export default function ToolManagement({ tool.id ); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly @@ -469,7 +473,10 @@ export default function ToolManagement({ {group.tools.map((tool) => { const isSelected = originalSelectedToolIdsSet.has(tool.id); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index fc927d51d..5a2ce4ce3 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -420,6 +420,68 @@ export default function ToolConfigModal({ } }, [configData]); + const currentMultiEmbeddingModel = useMemo(() => { + try { + const modelConfig = configData?.models; + return ( + modelConfig?.multiEmbedding?.modelName || + modelConfig?.multiEmbedding?.displayName || + null + ); + } catch { + return null; + } + }, [configData]); + + const hasEmbeddingModel = Boolean(currentEmbeddingModel); + const hasMultiEmbeddingModel = Boolean(currentMultiEmbeddingModel); + const canToggleMultimodalParam = hasEmbeddingModel && hasMultiEmbeddingModel; + const forcedMultimodalValue = + !hasEmbeddingModel && hasMultiEmbeddingModel + ? true + : hasEmbeddingModel && !hasMultiEmbeddingModel + ? false + : null; + + const toolMultimodal = useMemo(() => { + const multimodalParam = currentParams.find( + (param) => param.name === "multimodal" + ); + const value = multimodalParam?.value; + if (typeof value === "boolean") { + return value; + } + if (typeof value === "number") { + return value === 1; + } + if (typeof value === "string") { + const normalized = value.trim().toLowerCase(); + if (normalized === "true") return true; + if (normalized === "false") return false; + } + return null; + }, [currentParams]); + + useEffect(() => { + if (tool?.name !== "knowledge_base_search") return; + if (forcedMultimodalValue === null) return; + + const index = currentParams.findIndex( + (param) => param.name === "multimodal" + ); + if (index < 0) return; + + const param = currentParams[index]; + if (param.value === forcedMultimodalValue) return; + + const updatedParams = [...currentParams]; + updatedParams[index] = { ...param, value: forcedMultimodalValue }; + setCurrentParams(updatedParams); + + const fieldName = `param_${index}`; + form.setFieldValue(fieldName, forcedMultimodalValue); + }, [tool?.name, forcedMultimodalValue, currentParams, form]); + // Check if a knowledge base can be selected const canSelectKnowledgeBase = useCallback( (kb: KnowledgeBase): boolean => { @@ -431,19 +493,40 @@ export default function ToolConfigModal({ } // For nexent source, check model matching - if (kb.source === "nexent" && currentEmbeddingModel) { - if ( - kb.embeddingModel && - kb.embeddingModel !== "unknown" && - kb.embeddingModel !== currentEmbeddingModel - ) { + if (kb.source === "nexent") { + const hasMultimodalConstraintMismatch = + toolMultimodal !== null && + ((toolMultimodal && !kb.is_multimodal) || + (!toolMultimodal && kb.is_multimodal)); + if (hasMultimodalConstraintMismatch) { return false; } + + if (kb.is_multimodal) { + if (!currentMultiEmbeddingModel) { + return false; + } + if ( + kb.embeddingModel && + kb.embeddingModel !== "unknown" && + kb.embeddingModel !== currentMultiEmbeddingModel + ) { + return false; + } + } else if (currentEmbeddingModel) { + if ( + kb.embeddingModel && + kb.embeddingModel !== "unknown" && + kb.embeddingModel !== currentEmbeddingModel + ) { + return false; + } + } } return true; }, - [currentEmbeddingModel] + [currentEmbeddingModel, currentMultiEmbeddingModel, toolMultimodal] ); // Track whether this is the first time opening the modal (reset when modal closes) @@ -1227,7 +1310,7 @@ export default function ToolConfigModal({ })} options={options.map((option) => ({ value: option, - label: option, + label: String(option), }))} /> ); @@ -1243,8 +1326,13 @@ export default function ToolConfigModal({ /> ); - case TOOL_PARAM_TYPES.BOOLEAN: - return ; + case TOOL_PARAM_TYPES.BOOLEAN: { + const isMultimodalParam = + tool.name === "knowledge_base_search" && param.name === "multimodal"; + return ( + + ); + } case TOOL_PARAM_TYPES.STRING: case TOOL_PARAM_TYPES.ARRAY: @@ -1587,6 +1675,8 @@ export default function ToolConfigModal({ syncLoading={kbLoading} isSelectable={canSelectKnowledgeBase} currentEmbeddingModel={currentEmbeddingModel} + currentMultiEmbeddingModel={currentMultiEmbeddingModel} + toolMultimodal={toolMultimodal} difyConfig={ toolKbType === "dify_search" ? difyConfig diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index e1c6e2c2c..c9e79f149 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -180,6 +180,7 @@ function DataConfig({ isActive }: DataConfigProps) { const [newKbName, setNewKbName] = useState(""); const [newKbIngroupPermission, setNewKbIngroupPermission] = useState("READ_ONLY"); const [newKbGroupIds, setNewKbGroupIds] = useState([]); + const [isMultimodal, setIsMultimodal] = useState(false); const [uploadFiles, setUploadFiles] = useState([]); const [hasClickedUpload, setHasClickedUpload] = useState(false); const [showEmbeddingWarning, setShowEmbeddingWarning] = useState(false); @@ -192,11 +193,28 @@ function DataConfig({ isActive }: DataConfigProps) { const [modelFilter, setModelFilter] = useState([]); const contentRef = useRef(null); + const hasEmbeddingModel = Boolean( + modelConfig?.embedding?.modelName?.trim() + ); + const hasMultiEmbeddingModel = Boolean( + modelConfig?.multiEmbedding?.modelName?.trim() + ); + const canToggleMultimodal = hasEmbeddingModel && hasMultiEmbeddingModel; + // Open warning modal when single Embedding model is not configured (ignore multi-embedding) useEffect(() => { - const singleEmbeddingModelName = modelConfig?.embedding?.modelName; - setShowEmbeddingWarning(!singleEmbeddingModelName); - }, [modelConfig?.embedding?.modelName]); + setShowEmbeddingWarning(!hasEmbeddingModel && !hasMultiEmbeddingModel); + }, [hasEmbeddingModel, hasMultiEmbeddingModel]); + + useEffect(() => { + if (hasMultiEmbeddingModel && !hasEmbeddingModel) { + setIsMultimodal(true); + return; + } + if (hasEmbeddingModel && !hasMultiEmbeddingModel) { + setIsMultimodal(false); + } + }, [hasEmbeddingModel, hasMultiEmbeddingModel]); // Add event listener for selecting new knowledge base useEffect(() => { @@ -616,6 +634,8 @@ function DataConfig({ isActive }: DataConfigProps) { setIsCreatingMode(true); setHasClickedUpload(false); // Reset upload button click state setUploadFiles([]); // Reset upload files array, clear all pending upload files + + setIsMultimodal(hasMultiEmbeddingModel && !hasEmbeddingModel) }; // Handle document deletion @@ -675,7 +695,8 @@ function DataConfig({ isActive }: DataConfigProps) { t("knowledgeBase.description.default"), "elasticsearch", newKbIngroupPermission, - newKbGroupIds + newKbGroupIds, + isMultimodal ); if (!newKB) { @@ -816,6 +837,11 @@ function DataConfig({ isActive }: DataConfigProps) { setNewKbName(name); }; + const handleSetMultimodal = (is_multimodal: boolean) => { + if (!canToggleMultimodal) return; + setIsMultimodal(is_multimodal) + } + // If Embedding model is not configured, show warning container instead of content if (showEmbeddingWarning) { return ( @@ -876,6 +902,9 @@ function DataConfig({ isActive }: DataConfigProps) { knowledgeBases={kbState.knowledgeBases} activeKnowledgeBase={kbState.activeKnowledgeBase} currentEmbeddingModel={kbState.currentEmbeddingModel} + currentMultiEmbeddingModel={ + modelConfig?.multiEmbedding?.modelName?.trim() || null + } isLoading={kbState.isLoading} syncLoading={kbState.syncLoading} onClick={handleKnowledgeBaseClick} @@ -936,6 +965,9 @@ function DataConfig({ isActive }: DataConfigProps) { onIngroupPermissionChange={setNewKbIngroupPermission} selectedGroupIds={newKbGroupIds} onSelectedGroupIdsChange={setNewKbGroupIds} + isMultimodal={isMultimodal} + onMultimodalChange={handleSetMultimodal} + canToggleMultimodal={canToggleMultimodal} // Upload related props isDragging={uiState.isDragging} onDragOver={handleDragOver} @@ -956,15 +988,31 @@ function DataConfig({ isActive }: DataConfigProps) { modelMismatch={hasKnowledgeBaseModelMismatch( kbState.activeKnowledgeBase )} - currentModel={kbState.currentEmbeddingModel || ""} + currentModel={ + kbState.activeKnowledgeBase.is_multimodal + ? modelConfig?.multiEmbedding?.modelName?.trim() || "" + : kbState.currentEmbeddingModel || "" + } knowledgeBaseModel={kbState.activeKnowledgeBase.embeddingModel} embeddingModelInfo={ hasKnowledgeBaseModelMismatch(kbState.activeKnowledgeBase) - ? t("document.modelMismatch.withModels", { - currentModel: kbState.currentEmbeddingModel || "", - knowledgeBaseModel: - kbState.activeKnowledgeBase.embeddingModel, - }) + ? (() => { + const currentModelName = kbState.activeKnowledgeBase.is_multimodal + ? modelConfig?.multiEmbedding?.modelName?.trim() || "" + : kbState.currentEmbeddingModel || ""; + const knowledgeBaseModel = + kbState.activeKnowledgeBase.embeddingModel; + if (!currentModelName) { + return t("document.chunk.tooltip.disabledDueToModelMismatch", { + currentModel: t("embedding.model.notConfigured"), + knowledgeBaseModel, + }); + } + return t("document.modelMismatch.withModels", { + currentModel: currentModelName, + knowledgeBaseModel, + }); + })() : undefined } containerHeight={SETUP_PAGE_CONTAINER.MAIN_CONTENT_HEIGHT} diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx index 5e963c545..0e1c189af 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx @@ -128,16 +128,25 @@ const DocumentChunk: React.FC = ({ setTooltipResetKey((prev) => prev + 1); }, []); + const hasKnowledgeBaseModel = + Boolean(knowledgeBaseEmbeddingModel) && + knowledgeBaseEmbeddingModel !== "unknown"; + const hasCurrentModel = Boolean(currentEmbeddingModel); + const currentModelLabel = + currentEmbeddingModel || t("embedding.model.notConfigured"); + // Determine if embedding models mismatch (specific condition for tooltip) const isEmbeddingModelMismatch = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { + if (!hasKnowledgeBaseModel) { return false; } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]); + return !hasCurrentModel || currentEmbeddingModel !== knowledgeBaseEmbeddingModel; + }, [ + currentEmbeddingModel, + hasCurrentModel, + hasKnowledgeBaseModel, + knowledgeBaseEmbeddingModel, + ]); // Determine if in read-only mode (embedding model mismatch OR user has READ_ONLY permission) // Note: isReadOnlyMode is broader, includes model mismatch and other conditions @@ -146,37 +155,48 @@ const DocumentChunk: React.FC = ({ if (permission === "READ_ONLY") { return true; } - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { + if (!hasKnowledgeBaseModel) { return false; } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel, permission]); + return !hasCurrentModel || currentEmbeddingModel !== knowledgeBaseEmbeddingModel; + }, [ + currentEmbeddingModel, + hasCurrentModel, + hasKnowledgeBaseModel, + knowledgeBaseEmbeddingModel, + permission, + ]); // Determine if search should be disabled (only when embedding model mismatch, NOT for READ_ONLY permission) // This allows READ_ONLY users to still perform search const isSearchDisabled = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { + if (!hasKnowledgeBaseModel) { return false; } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]); + return !hasCurrentModel || currentEmbeddingModel !== knowledgeBaseEmbeddingModel; + }, [ + currentEmbeddingModel, + hasCurrentModel, + hasKnowledgeBaseModel, + knowledgeBaseEmbeddingModel, + ]); // Disabled tooltip message when embedding model mismatch const disabledTooltipMessage = React.useMemo(() => { - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { + if (isEmbeddingModelMismatch && hasKnowledgeBaseModel) { return t("document.chunk.tooltip.disabledDueToModelMismatch", { - currentModel: currentEmbeddingModel, + currentModel: currentModelLabel, knowledgeBaseModel: knowledgeBaseEmbeddingModel }); } return ""; - }, [isEmbeddingModelMismatch, currentEmbeddingModel, knowledgeBaseEmbeddingModel, t]); + }, [ + currentModelLabel, + hasKnowledgeBaseModel, + isEmbeddingModelMismatch, + knowledgeBaseEmbeddingModel, + t, + ]); // Set active document when documents change useEffect(() => { @@ -322,11 +342,13 @@ const DocumentChunk: React.FC = ({ } // Check embedding model consistency before searching - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { - message.error(t("document.chunk.error.searchFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); + if (isEmbeddingModelMismatch && hasKnowledgeBaseModel) { + message.error( + t("document.chunk.tooltip.disabledDueToModelMismatch", { + currentModel: currentModelLabel, + knowledgeBaseModel: knowledgeBaseEmbeddingModel, + }) + ); return; } @@ -380,8 +402,9 @@ const DocumentChunk: React.FC = ({ resetChunkSearch, searchValue, t, + currentModelLabel, + hasKnowledgeBaseModel, isEmbeddingModelMismatch, - currentEmbeddingModel, knowledgeBaseEmbeddingModel, ]); @@ -465,14 +488,13 @@ const DocumentChunk: React.FC = ({ // Check embedding model consistency before creating chunk if (chunkModalMode === "create") { - if (knowledgeBaseEmbeddingModel && - knowledgeBaseEmbeddingModel !== "unknown" && - currentEmbeddingModel && - currentEmbeddingModel !== knowledgeBaseEmbeddingModel) { - message.error(t("document.chunk.error.createFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); + if (isEmbeddingModelMismatch && hasKnowledgeBaseModel) { + message.error( + t("document.chunk.tooltip.disabledDueToModelMismatch", { + currentModel: currentModelLabel, + knowledgeBaseModel: knowledgeBaseEmbeddingModel, + }) + ); return; } } diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx index 3ce8ac803..bf71c3552 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx @@ -75,6 +75,10 @@ interface DocumentListProps { onSelectedGroupIdsChange?: (values: number[]) => void; permission?: string; // User's permission for this knowledge base (READ_ONLY, EDIT, etc.) + isMultimodal?: boolean; + onMultimodalChange?: (is_nultimodal: boolean) => void; + canToggleMultimodal?: boolean; + // Upload related props isDragging?: boolean; onDragOver?: (e: React.DragEvent) => void; @@ -114,6 +118,10 @@ const DocumentListContainer = forwardRef( onSelectedGroupIdsChange, permission, + isMultimodal = false, + onMultimodalChange, + canToggleMultimodal = true, + // Upload related props isDragging = false, onDragOver, @@ -502,6 +510,29 @@ const DocumentListContainer = forwardRef( options={permissionOptions} /> + + onMultimodalChange(!isMultimodal)} + onClick={() => + canToggleMultimodal && + onMultimodalChange && + onMultimodalChange(!isMultimodal) + } + style={{ + width: 80, // Keep width aligned with adjacent controls + display: 'inline-block', + textAlign: 'center', + cursor: canToggleMultimodal ? 'pointer' : 'default', + color: isMultimodal ? '#52c41a' : '#000000', // Success green when enabled, black when disabled + fontWeight: isMultimodal ? 500 : 400, // Slightly bolder when enabled + userSelect: 'none', // Prevent text selection on double-click + lineHeight: '32px' // Align with Select height + }} + title={isMultimodal ? "Multimodal: Enabled" : "Multimodal: Disabled"} + > + Multimodal + + ) : ( diff --git a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx index d5ec5cdb7..20bb76161 100644 --- a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx +++ b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx @@ -32,6 +32,7 @@ interface KnowledgeBaseListProps { knowledgeBases: KnowledgeBase[]; activeKnowledgeBase: KnowledgeBase | null; currentEmbeddingModel: string | null; + currentMultiEmbeddingModel: string | null; isLoading?: boolean; syncLoading?: boolean; onClick: (kb: KnowledgeBase) => void; @@ -57,6 +58,7 @@ const KnowledgeBaseList: React.FC = ({ knowledgeBases, activeKnowledgeBase, currentEmbeddingModel, + currentMultiEmbeddingModel, isLoading = false, syncLoading = false, onClick, @@ -127,6 +129,19 @@ const KnowledgeBaseList: React.FC = ({ return `knowledgeBase.ingroup.permission.${permission || "DEFAULT"}`; }; + const isModelMismatch = (kb: KnowledgeBase) => { + if (kb.embeddingModel === "unknown") return false; + if (kb.source === "datamate") return false; + + if (kb.is_multimodal) { + if (!currentMultiEmbeddingModel) return true; + return kb.embeddingModel !== currentMultiEmbeddingModel; + } + + if (!currentEmbeddingModel) return false; + return kb.embeddingModel !== currentEmbeddingModel; + }; + // Search and filter states const [searchKeyword, setSearchKeyword] = useState(""); const [selectedSources, setSelectedSources] = useState([]); @@ -579,9 +594,7 @@ const KnowledgeBaseList: React.FC = ({ })} )} - {kb.embeddingModel !== "unknown" && - kb.embeddingModel !== currentEmbeddingModel && - kb.source !== "datamate" && ( + {isModelMismatch(kb) && ( @@ -601,6 +614,13 @@ const KnowledgeBaseList: React.FC = ({ ))} + {kb.is_multimodal && ( + + multimodal + + )} )} diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index 0a07774be..1c6adf1da 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -109,7 +109,8 @@ export const KnowledgeBaseContext = createContext<{ description: string, source?: string, ingroup_permission?: string, - group_ids?: number[] + group_ids?: number[], + is_multiimodal?: boolean, ) => Promise; deleteKnowledgeBase: (id: string) => Promise; selectKnowledgeBase: (id: string) => void; @@ -124,6 +125,7 @@ export const KnowledgeBaseContext = createContext<{ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -158,6 +160,7 @@ export const KnowledgeBaseProvider: React.FC = ({ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -198,16 +201,29 @@ export const KnowledgeBaseProvider: React.FC = ({ // Check if knowledge base has model mismatch (for display purposes) const hasKnowledgeBaseModelMismatch = useCallback( (kb: KnowledgeBase): boolean => { - if (!state.currentEmbeddingModel || kb.embeddingModel === "unknown") { + if (kb.embeddingModel === "unknown") { return false; } // DataMate knowledge bases don't report model mismatch (they are always selectable) if (kb.source === "datamate") { return false; } + + if (kb.is_multimodal) { + const multiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + if (!multiEmbeddingModel) { + return true; + } + return kb.embeddingModel !== multiEmbeddingModel; + } + + if (!state.currentEmbeddingModel) { + return false; + } return kb.embeddingModel !== state.currentEmbeddingModel; }, - [state.currentEmbeddingModel] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel] ); // Load knowledge base data (supports force fetch from server and load selected status) - optimized with useCallback @@ -315,17 +331,23 @@ export const KnowledgeBaseProvider: React.FC = ({ description: string, source: string = "elasticsearch", ingroup_permission?: string, - group_ids?: number[] + group_ids?: number[], + is_multimodal?: boolean ) => { try { + const multiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + const embeddingModel = is_multimodal + ? multiEmbeddingModel + : state.currentEmbeddingModel || "text-embedding-3-small"; const newKB = await knowledgeBaseService.createKnowledgeBase({ name, description, source, - embeddingModel: - state.currentEmbeddingModel || "text-embedding-3-small", + embeddingModel, ingroup_permission, group_ids, + is_multimodal, }); return newKB; } catch (error) { @@ -337,7 +359,7 @@ export const KnowledgeBaseProvider: React.FC = ({ return null; } }, - [state.currentEmbeddingModel, t] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel, t] ); // Delete knowledge base - memoized with useCallback diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index cd258abc8..cb48cfd95 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -277,7 +277,7 @@ export const ModelAddDialog = ({ let getModelList; let getProviderSelectedModalList; -// 2. 根据条件赋值 + // 2. Select provider-specific hooks if (form.provider === "silicon") { ({ getModelList, getProviderSelectedModalList } = siliconHook); } else if (form.provider === "dashscope") { @@ -451,7 +451,8 @@ export const ModelAddDialog = ({ // Call backend healthcheck API for tenant management const result = await modelService.checkManageTenantModelConnectivity( tenantId, - form.displayName || form.name + form.displayName || form.name, + modelType, ); // Set connectivity status diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index ebd7097d8..903934a7d 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -112,7 +112,8 @@ export const ModelEditDialog = ({ // Call backend healthcheck API for tenant management const result = await modelService.checkManageTenantModelConnectivity( tenantId, - form.displayName || form.name + form.displayName || form.name, + modelType, ); // Set connectivity status diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index 9a686352d..e4aceb551 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -527,6 +527,7 @@ export const ModelConfigSection = forwardRef< try { const isConnected = await modelService.verifyCustomModel( modelName, + modelType, signal ); @@ -603,7 +604,7 @@ export const ModelConfigSection = forwardRef< throttleTimerRef.current = setTimeout(async () => { try { // Use modelService to verify model - const isConnected = await modelService.verifyCustomModel(displayName); + const isConnected = await modelService.verifyCustomModel(displayName, modelType); // Update model status updateModelStatus( diff --git a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx index 6de719127..e06ea0f58 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx @@ -92,7 +92,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { }; // Handle checking model connectivity - const handleCheckConnectivity = async (displayName: string) => { + const handleCheckConnectivity = async (displayName: string, modelType: string) => { if (!tenantId) { message.error(t("tenantResources.tenants.tenantIdRequired")); return; @@ -102,7 +102,8 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { try { const isConnected = await modelService.checkManageTenantModelConnectivity( tenantId, - displayName + displayName, + modelType ); if (isConnected) { message.success(t("tenantResources.models.connectivitySuccess")); @@ -197,7 +198,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { diff --git a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx index 6c78c279b..cdde3b2b2 100644 --- a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx +++ b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx @@ -263,37 +263,39 @@ export default function KnowledgeBaseSelectorModal({ ] ); - // Check if a knowledge base has model mismatch (for display purposes) - const checkModelMismatch = (kb: KnowledgeBase): boolean => { - if (kb.source !== "nexent") { - return false; - } + const getModelMismatch = useCallback( + (kb: KnowledgeBase): boolean => { + if (kb.source !== "nexent") { + return false; + } - const hasMultimodalConstraintMismatch = - toolMultimodal !== null && - ((toolMultimodal && !kb.is_multimodal) || - (!toolMultimodal && kb.is_multimodal)); - if (hasMultimodalConstraintMismatch) { - return true; - } + const hasMultimodalConstraintMismatch = + toolMultimodal !== null && + ((toolMultimodal && !kb.is_multimodal) || + (!toolMultimodal && kb.is_multimodal)); + if (hasMultimodalConstraintMismatch) { + return true; + } - const embeddingModel = kb.embeddingModel; - if (!embeddingModel || embeddingModel === "unknown") { - return false; - } + const embeddingModel = kb.embeddingModel; + if (!embeddingModel || embeddingModel === "unknown") { + return false; + } - if (kb.is_multimodal) { - if (!currentMultiEmbeddingModel) { - return true; + if (kb.is_multimodal) { + if (!currentMultiEmbeddingModel) { + return true; + } + return embeddingModel !== currentMultiEmbeddingModel; } - return embeddingModel !== currentMultiEmbeddingModel; - } - if (!currentEmbeddingModel) { - return false; - } - return embeddingModel !== currentEmbeddingModel; - }; + if (!currentEmbeddingModel) { + return false; + } + return embeddingModel !== currentEmbeddingModel; + }, + [currentEmbeddingModel, currentMultiEmbeddingModel, toolMultimodal] + ); // Filter knowledge bases based on tool type, search, and filters const filteredKnowledgeBases = useMemo(() => { @@ -714,6 +716,7 @@ export default function KnowledgeBaseSelectorModal({ String(selectedId).trim() === String(kb.id).trim() ); const canSelect = checkCanSelect(kb); + const hasModelMismatch = getModelMismatch(kb); return (
list: # kerry diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py index 0134a2733..8c40e9346 100644 --- a/sdk/nexent/data_process/extract_image.py +++ b/sdk/nexent/data_process/extract_image.py @@ -4,7 +4,7 @@ import hashlib import tempfile import subprocess -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional import zipfile from xml.etree import ElementTree @@ -171,6 +171,76 @@ def _extract_pdf(self, pdf_path: str, **params) -> List[Dict]: return results + def _excel_sheet_files(self, z: zipfile.ZipFile) -> List[str]: + return [f for f in z.namelist() if f.startswith("xl/worksheets/sheet")] + + + def _excel_drawing_file(self, z: zipfile.ZipFile, sheet_file: str) -> Optional[str]: + sheet_xml = ElementTree.fromstring(z.read(sheet_file)) + drawing = sheet_xml.find( + ".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing") + if drawing is None: + return None + + rel_id = drawing.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id") + rel_path = sheet_file.replace("worksheets", "worksheets/_rels") + ".rels" + if rel_path not in z.namelist(): + return None + + rel_xml = ElementTree.fromstring(z.read(rel_path)) + for rel in rel_xml: + if rel.get("Id") == rel_id: + return "xl/" + rel.get("Target").replace("../", "") + + return None + + + def _excel_rel_map(self, z: zipfile.ZipFile, drawing_file: str) -> Optional[Dict[str, str]]: + rel_file = drawing_file.replace("drawings/", "drawings/_rels/") + ".rels" + if rel_file not in z.namelist(): + return None + + rel_root = ElementTree.fromstring(z.read(rel_file)) + return { + rel.get("Id"): "xl/" + rel.get("Target").replace("../", "") + for rel in rel_root + } + + + def _excel_anchors(self, z: zipfile.ZipFile, drawing_file: str, ns: Dict[str, str]) -> List[Any]: + drawing_root = ElementTree.fromstring(z.read(drawing_file)) + return drawing_root.findall(".//xdr:twoCellAnchor", ns) + \ + drawing_root.findall(".//xdr:oneCellAnchor", ns) + + + def _excel_anchor_coords(self, anchor: Any, ns: Dict[str, str]) -> Optional[Dict[str, int]]: + from_node = anchor.find("xdr:from", ns) + if from_node is None: + return None + + row1 = int(from_node.find("xdr:row", ns).text) + 1 + col1 = int(from_node.find("xdr:col", ns).text) + 1 + + to_node = anchor.find("xdr:to", ns) + if to_node is not None: + row2 = int(to_node.find("xdr:row", ns).text) + 1 + col2 = int(to_node.find("xdr:col", ns).text) + 1 + else: + row2, col2 = row1, col1 + + return {"row1": row1, "col1": col1, "row2": row2, "col2": col2} + + + def _excel_anchor_embed_id(self, anchor: Any, ns: Dict[str, str]) -> Optional[str]: + blip = anchor.find(".//a:blip", ns) + if blip is None: + return None + + return blip.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") + + def _extract_excel(self, xlsx_path): results = [] seen = set() @@ -182,86 +252,35 @@ def _extract_excel(self, xlsx_path): "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", } - workbook = ElementTree.fromstring(z.read("xl/workbook.xml")) - sheets = {} - for s in workbook.findall(".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}sheet"): - sheets[s.get("r:id")] = s.get("name") - - sheet_files = [f for f in z.namelist( - ) if f.startswith("xl/worksheets/sheet")] + sheet_files = self._excel_sheet_files(z) for sheet_file in sheet_files: - sheet_xml = ElementTree.fromstring(z.read(sheet_file)) - drawing = sheet_xml.find( - ".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing") - - if drawing is None: - continue - - rel_id = drawing.get( - "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id") - rel_path = sheet_file.replace( - "worksheets", "worksheets/_rels") + ".rels" - - if rel_path not in z.namelist(): - continue - - rel_xml = ElementTree.fromstring(z.read(rel_path)) - drawing_file = None - - for r in rel_xml: - if r.get("Id") == rel_id: - drawing_file = "xl/" + \ - r.get("Target").replace("../", "") - break - + drawing_file = self._excel_drawing_file(z, sheet_file) if drawing_file is None: continue - sheet_name = os.path.basename(sheet_file) - drawing_root = ElementTree.fromstring(z.read(drawing_file)) - - rel_file = drawing_file.replace( - "drawings/", "drawings/_rels/") + ".rels" - if rel_file not in z.namelist(): + rel_map = self._excel_rel_map(z, drawing_file) + if not rel_map: continue - rel_root = ElementTree.fromstring(z.read(rel_file)) - rel_map = { - r.get("Id"): "xl/" + r.get("Target").replace("../", "") - for r in rel_root - } - - anchors = drawing_root.findall(".//xdr:twoCellAnchor", ns) + \ - drawing_root.findall(".//xdr:oneCellAnchor", ns) + anchors = self._excel_anchors(z, drawing_file, ns) + sheet_name = os.path.basename(sheet_file) for anchor in anchors: - from_node = anchor.find("xdr:from", ns) - if from_node is None: + coords = self._excel_anchor_coords(anchor, ns) + if coords is None: continue - row1 = int(from_node.find("xdr:row", ns).text) + 1 - col1 = int(from_node.find("xdr:col", ns).text) + 1 - - to_node = anchor.find("xdr:to", ns) - if to_node is not None: - row2 = int(to_node.find("xdr:row", ns).text) + 1 - col2 = int(to_node.find("xdr:col", ns).text) + 1 - else: - row2, col2 = row1, col1 - - blip = anchor.find(".//a:blip", ns) - if blip is None: + embed_rel_id = self._excel_anchor_embed_id(anchor, ns) + if not embed_rel_id: continue - embed_rel_id = blip.get( - "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") - if embed_rel_id not in rel_map: + target = rel_map.get(embed_rel_id) + if not target: continue - img_bytes = z.read(rel_map[embed_rel_id]) + img_bytes = z.read(target) h = self._hash(img_bytes) - if h in seen: continue seen.add(h) @@ -270,10 +289,10 @@ def _extract_excel(self, xlsx_path): "position": { "sheet_name": sheet_name, "coordinates": { - "x1": col1, - "x2": col2, - "y1": row1, - "y2": row2 + "x1": coords["col1"], + "x2": coords["col2"], + "y1": coords["row1"], + "y2": coords["row2"] } }, "image_format": self.detect_image_format(img_bytes), @@ -377,4 +396,4 @@ def process_file(self, file_bytes: bytes, chunking_strategy: str, filename: str, try: os.remove(f_path) except Exception: - pass \ No newline at end of file + pass diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py index 190a98554..4f230786c 100644 --- a/sdk/nexent/vector_database/elasticsearch_core.py +++ b/sdk/nexent/vector_database/elasticsearch_core.py @@ -1081,10 +1081,6 @@ def semantic_search( "_source": {"excludes": ["multi_embedding"]}, } raw_results = self.exec_query(index_pattern, search_text_query) + self.exec_query(index_pattern, search_image_query) - - # raw_results = raw_results + raw_results2 - print("raw_results: ", raw_results) - else: search_query = { "knn": { diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 4207c9e62..8f4bad20c 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -26,6 +26,7 @@ memory_pkg.__path__ = [] memory_service_stub = types.ModuleType("nexent.memory.memory_service") async def _clear_memory_stub(*_args, **_kwargs): + await asyncio.sleep(0) return None memory_service_stub.clear_memory = _clear_memory_stub sys.modules["nexent.memory.memory_service"] = memory_service_stub diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py index cbb4904c3..6726a1701 100644 --- a/test/backend/services/test_vectordatabase_service.py +++ b/test/backend/services/test_vectordatabase_service.py @@ -29,6 +29,7 @@ memory_pkg.__path__ = [] memory_service_stub = ModuleType("nexent.memory.memory_service") async def _clear_memory_stub(*_args, **_kwargs): + await asyncio.sleep(0) return None memory_service_stub.clear_memory = _clear_memory_stub sys.modules["nexent.memory.memory_service"] = memory_service_stub diff --git a/test/sdk/data_process/test_core.py b/test/sdk/data_process/test_core.py index 359204c29..af325b52f 100644 --- a/test/sdk/data_process/test_core.py +++ b/test/sdk/data_process/test_core.py @@ -346,9 +346,6 @@ def test_file_process_returns_images_when_extractor_available(self, core, mocker {"image_bytes": b"img", "image_format": "png", "position": {"page_number": 1}} ] core.processors["Unstructured"] = mock_processor - core.processors["UniversalImageExtractor"] = Mock( - process_file=Mock(return_value=[]) - ) core.processors["UniversalImageExtractor"] = mock_extractor result = core.file_process( From 6b427c0a62ed218c60934fb6993cab3bd0bedd0f Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Tue, 31 Mar 2026 11:53:05 +0800 Subject: [PATCH 04/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/vectordatabase_service.py | 42 ++++---- sdk/nexent/data_process/extract_image.py | 113 +++++++++++++-------- 2 files changed, 93 insertions(+), 62 deletions(-) diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index 38502afe3..c7e9b800a 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -284,6 +284,20 @@ def get_embedding_model(tenant_id: str, is_multimodal: bool = False, model_name: return _build_embedding_from_config(model_config) +def _resolve_embedding_model( + tenant_id: str, + is_multimodal: bool, + embedding_model_name: Optional[str], +) -> Optional[BaseEmbedding]: + if embedding_model_name: + return get_embedding_model( + tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + ) + return get_embedding_model(tenant_id, is_multimodal=is_multimodal) + + class ElasticSearchService: @staticmethod async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCore, user_id: str): @@ -472,20 +486,14 @@ def create_knowledge_base( with an explicit index_name. """ try: - if embedding_model_name is None: - if is_multimodal: - embedding_model = get_embedding_model(tenant_id, is_multimodal=True) - else: - embedding_model = get_embedding_model(tenant_id, None) - else: - if is_multimodal: - embedding_model = get_embedding_model( - tenant_id, - is_multimodal=True, - model_name=embedding_model_name, - ) - else: - embedding_model = get_embedding_model(tenant_id, embedding_model_name) + embedding_model = _resolve_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + embedding_model_name=embedding_model_name, + ) + resolved_embedding_model_name = embedding_model_name + if resolved_embedding_model_name is None and embedding_model: + resolved_embedding_model_name = getattr(embedding_model, "model", None) # Create knowledge record first to obtain knowledge_id and generated index_name knowledge_data = { @@ -493,11 +501,7 @@ def create_knowledge_base( "knowledge_describe": "", "user_id": user_id, "tenant_id": tenant_id, - "embedding_model_name": ( - embedding_model_name - if embedding_model_name is not None - else (embedding_model.model if embedding_model else None) - ), + "embedding_model_name": resolved_embedding_model_name, "is_multimodal": is_multimodal, } diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py index 8c40e9346..0b34987e8 100644 --- a/sdk/nexent/data_process/extract_image.py +++ b/sdk/nexent/data_process/extract_image.py @@ -241,6 +241,73 @@ def _excel_anchor_embed_id(self, anchor: Any, ns: Dict[str, str]) -> Optional[st "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") + def _extract_excel_anchors( + self, + z: zipfile.ZipFile, + anchors: List[Any], + rel_map: Dict[str, str], + sheet_name: str, + ns: Dict[str, str], + seen: set, + ) -> List[Dict[str, Any]]: + results = [] + for anchor in anchors: + coords = self._excel_anchor_coords(anchor, ns) + if coords is None: + continue + + embed_rel_id = self._excel_anchor_embed_id(anchor, ns) + if not embed_rel_id: + continue + + target = rel_map.get(embed_rel_id) + if not target: + continue + + img_bytes = z.read(target) + h = self._hash(img_bytes) + if h in seen: + continue + seen.add(h) + + results.append({ + "position": { + "sheet_name": sheet_name, + "coordinates": { + "x1": coords["col1"], + "x2": coords["col2"], + "y1": coords["row1"], + "y2": coords["row2"] + } + }, + "image_format": self.detect_image_format(img_bytes), + "image_bytes": img_bytes + }) + + return results + + + def _extract_excel_sheet( + self, + z: zipfile.ZipFile, + sheet_file: str, + ns: Dict[str, str], + seen: set, + ) -> List[Dict[str, Any]]: + drawing_file = self._excel_drawing_file(z, sheet_file) + if drawing_file is None: + return [] + + rel_map = self._excel_rel_map(z, drawing_file) + if not rel_map: + return [] + + anchors = self._excel_anchors(z, drawing_file, ns) + sheet_name = os.path.basename(sheet_file) + + return self._extract_excel_anchors(z, anchors, rel_map, sheet_name, ns, seen) + + def _extract_excel(self, xlsx_path): results = [] seen = set() @@ -255,49 +322,9 @@ def _extract_excel(self, xlsx_path): sheet_files = self._excel_sheet_files(z) for sheet_file in sheet_files: - drawing_file = self._excel_drawing_file(z, sheet_file) - if drawing_file is None: - continue - - rel_map = self._excel_rel_map(z, drawing_file) - if not rel_map: - continue - - anchors = self._excel_anchors(z, drawing_file, ns) - sheet_name = os.path.basename(sheet_file) - - for anchor in anchors: - coords = self._excel_anchor_coords(anchor, ns) - if coords is None: - continue - - embed_rel_id = self._excel_anchor_embed_id(anchor, ns) - if not embed_rel_id: - continue - - target = rel_map.get(embed_rel_id) - if not target: - continue - - img_bytes = z.read(target) - h = self._hash(img_bytes) - if h in seen: - continue - seen.add(h) - - results.append({ - "position": { - "sheet_name": sheet_name, - "coordinates": { - "x1": coords["col1"], - "x2": coords["col2"], - "y1": coords["row1"], - "y2": coords["row2"] - } - }, - "image_format": self.detect_image_format(img_bytes), - "image_bytes": img_bytes - }) + results.extend( + self._extract_excel_sheet(z, sheet_file, ns, seen) + ) return results From 2deeb8ae415963e99037af8d6c953bcae5b957f0 Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Tue, 31 Mar 2026 14:30:41 +0800 Subject: [PATCH 05/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agentConfig/tool/ToolConfigModal.tsx | 42 ++----- .../KnowledgeBaseSelectorModal.tsx | 46 ++----- frontend/lib/knowledgeBaseCompatibility.ts | 46 +++++++ frontend/services/knowledgeBaseService.ts | 11 +- .../core/tools/knowledge_base_search_tool.py | 119 ++++++++---------- sdk/nexent/data_process/extract_image.py | 52 +++++--- test/backend/data_process/test_ray_actors.py | 92 +++++++------- test/backend/database/test_knowledge_db.py | 63 +++------- .../tools/test_knowledge_base_search_tool.py | 6 +- 9 files changed, 228 insertions(+), 249 deletions(-) create mode 100644 frontend/lib/knowledgeBaseCompatibility.ts diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index 2dc55592c..5f0d969cf 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -30,6 +30,10 @@ import { API_ENDPOINTS } from "@/services/api"; import knowledgeBaseService from "@/services/knowledgeBaseService"; import log from "@/lib/logger"; import { isZhLocale, getLocalizedDescription } from "@/lib/utils"; +import { + isEmbeddingModelCompatible as isEmbeddingModelCompatibleBase, + isMultimodalConstraintMismatch as isMultimodalConstraintMismatchBase, +} from "@/lib/knowledgeBaseCompatibility"; export interface ToolConfigModalProps { isOpen: boolean; @@ -488,44 +492,18 @@ export default function ToolConfigModal({ const isMultimodalConstraintMismatch = useCallback( (kb: KnowledgeBase) => { - return ( - toolMultimodal !== null && - ((toolMultimodal && !kb.is_multimodal) || - (!toolMultimodal && kb.is_multimodal)) - ); + return isMultimodalConstraintMismatchBase(kb, toolMultimodal); }, [toolMultimodal] ); const isEmbeddingModelCompatible = useCallback( (kb: KnowledgeBase) => { - if (kb.is_multimodal) { - if (!currentMultiEmbeddingModel) { - return false; - } - if ( - kb.embeddingModel && - kb.embeddingModel !== "unknown" && - kb.embeddingModel !== currentMultiEmbeddingModel - ) { - return false; - } - return true; - } - - if (!currentEmbeddingModel) { - return true; - } - - if ( - kb.embeddingModel && - kb.embeddingModel !== "unknown" && - kb.embeddingModel !== currentEmbeddingModel - ) { - return false; - } - - return true; + return isEmbeddingModelCompatibleBase( + kb, + currentEmbeddingModel, + currentMultiEmbeddingModel + ); }, [currentEmbeddingModel, currentMultiEmbeddingModel] ); diff --git a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx index cdde3b2b2..a3c72e804 100644 --- a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx +++ b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx @@ -20,6 +20,10 @@ import { import { KnowledgeBase } from "@/types/knowledgeBase"; import { KB_LAYOUT, KB_TAG_VARIANTS } from "@/const/knowledgeBaseLayout"; +import { + isEmbeddingModelCompatible as isEmbeddingModelCompatibleBase, + isMultimodalConstraintMismatch as isMultimodalConstraintMismatchBase, +} from "@/lib/knowledgeBaseCompatibility"; interface KnowledgeBaseSelectorProps { isOpen: boolean; @@ -188,44 +192,18 @@ export default function KnowledgeBaseSelectorModal({ const isMultimodalConstraintMismatch = useCallback( (kb: KnowledgeBase) => { - return ( - toolMultimodal !== null && - ((toolMultimodal && !kb.is_multimodal) || - (!toolMultimodal && kb.is_multimodal)) - ); + return isMultimodalConstraintMismatchBase(kb, toolMultimodal); }, [toolMultimodal] ); const isEmbeddingModelCompatible = useCallback( (kb: KnowledgeBase) => { - if (kb.is_multimodal) { - if (!currentMultiEmbeddingModel) { - return false; - } - if ( - kb.embeddingModel && - kb.embeddingModel !== "unknown" && - kb.embeddingModel !== currentMultiEmbeddingModel - ) { - return false; - } - return true; - } - - if (!currentEmbeddingModel) { - return true; - } - - if ( - kb.embeddingModel && - kb.embeddingModel !== "unknown" && - kb.embeddingModel !== currentEmbeddingModel - ) { - return false; - } - - return true; + return isEmbeddingModelCompatibleBase( + kb, + currentEmbeddingModel, + currentMultiEmbeddingModel + ); }, [currentEmbeddingModel, currentMultiEmbeddingModel] ); @@ -270,9 +248,7 @@ export default function KnowledgeBaseSelectorModal({ } const hasMultimodalConstraintMismatch = - toolMultimodal !== null && - ((toolMultimodal && !kb.is_multimodal) || - (!toolMultimodal && kb.is_multimodal)); + isMultimodalConstraintMismatchBase(kb, toolMultimodal); if (hasMultimodalConstraintMismatch) { return true; } diff --git a/frontend/lib/knowledgeBaseCompatibility.ts b/frontend/lib/knowledgeBaseCompatibility.ts new file mode 100644 index 000000000..37381c048 --- /dev/null +++ b/frontend/lib/knowledgeBaseCompatibility.ts @@ -0,0 +1,46 @@ +import { KnowledgeBase } from "@/types/knowledgeBase"; + +export const isMultimodalConstraintMismatch = ( + kb: KnowledgeBase, + toolMultimodal: boolean | null +): boolean => { + return ( + toolMultimodal !== null && + ((toolMultimodal && !kb.is_multimodal) || + (!toolMultimodal && kb.is_multimodal)) + ); +}; + +export const isEmbeddingModelCompatible = ( + kb: KnowledgeBase, + currentEmbeddingModel: string | null, + currentMultiEmbeddingModel: string | null +): boolean => { + if (kb.is_multimodal) { + if (!currentMultiEmbeddingModel) { + return false; + } + if ( + kb.embeddingModel && + kb.embeddingModel !== "unknown" && + kb.embeddingModel !== currentMultiEmbeddingModel + ) { + return false; + } + return true; + } + + if (!currentEmbeddingModel) { + return true; + } + + if ( + kb.embeddingModel && + kb.embeddingModel !== "unknown" && + kb.embeddingModel !== currentEmbeddingModel + ) { + return false; + } + + return true; +}; diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts index 0c8262e9a..657160fc7 100644 --- a/frontend/services/knowledgeBaseService.ts +++ b/frontend/services/knowledgeBaseService.ts @@ -30,6 +30,9 @@ const normalizeIsMultimodal = (value: unknown): boolean => { return false; }; +const resolveIsMultimodal = (indexInfo: any, stats: any): boolean => + normalizeIsMultimodal(indexInfo.is_multimodal ?? stats.is_multimodal); + // Knowledge base service class class KnowledgeBaseService { // Check Elasticsearch health (force refresh, no caching for setup page) @@ -499,9 +502,7 @@ class KnowledgeBaseService { stats.creation_date || null, embeddingModel: stats.embedding_model || "unknown", - is_multimodal: normalizeIsMultimodal( - indexInfo.is_multimodal ?? stats.is_multimodal - ), + is_multimodal: resolveIsMultimodal(indexInfo, stats), knowledge_sources: indexInfo.knowledge_sources || "elasticsearch", ingroup_permission: indexInfo.ingroup_permission || "", @@ -569,9 +570,7 @@ class KnowledgeBaseService { createdAt: stats.creation_date || null, updatedAt: stats.update_date || stats.creation_date || null, embeddingModel: stats.embedding_model || "unknown", - is_multimodal: normalizeIsMultimodal( - indexInfo.is_multimodal ?? stats.is_multimodal - ), + is_multimodal: resolveIsMultimodal(indexInfo, stats), knowledge_sources: indexInfo.knowledge_sources || "datamate", ingroup_permission: indexInfo.ingroup_permission || "", diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index d47aad063..530d51765 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -39,6 +39,7 @@ class KnowledgeBaseSearchTool(Tool): "index_names": { "type": "array", "description": "The list of index names to search", + "nullable": True, "description_zh": "要索引的知识库" }, } @@ -106,9 +107,14 @@ def __init__( self.running_prompt_en = "Searching the knowledge base..." - def forward(self, query: str, index_names: List[str]) -> str: - # Parse index_names from string (always required) - search_index_names = index_names + def forward(self, query: str, index_names: List[str] | str | None = None) -> str: + # Parse index_names from string (optional) + if index_names is None: + search_index_names = self.index_names + elif isinstance(index_names, str): + search_index_names = [name.strip() for name in index_names.split(",") if name.strip()] + else: + search_index_names = index_names # Use the instance search_mode search_mode = self.search_mode @@ -260,82 +266,59 @@ def _record_search_results( self.observer.add_message( "", ProcessType.PICTURE_WEB, search_images_list_json) - def search_hybrid(self, query, index_names): + @staticmethod + def _format_vdb_results(results): + formatted_results = [] + for result in results: + doc = result["document"] + doc["score"] = result["score"] + # Include source index in results + doc["index"] = result["index"] + if "content" in result: + doc["content"] = result["content"] + if "process_source" in result: + doc["process_source"] = result["process_source"] + formatted_results.append(doc) + return formatted_results + + def _search_with(self, search_fn, query, index_names, label, **kwargs): try: - results = self.vdb_core.hybrid_search( - index_names=index_names, query_text=query, embedding_model=self.embedding_model, top_k=self.top_k + results = search_fn( + index_names=index_names, query_text=query, top_k=self.top_k, **kwargs ) - - # Format results - formatted_results = [] - for result in results: - doc = result["document"] - doc["score"] = result["score"] - # Include source index in results - doc["index"] = result["index"] - if "content" in result: - doc["content"] = result["content"] - if "process_source" in result: - doc["process_source"] = result["process_source"] - formatted_results.append(doc) - + formatted_results = self._format_vdb_results(results) return { "results": formatted_results, "total": len(formatted_results), } except Exception as e: - raise Exception(detail=f"Error during hybrid search: {str(e)}") + raise Exception(f"Error during {label} search: {str(e)}") - def search_accurate(self, query, index_names): - try: - results = self.vdb_core.accurate_search( - index_names=index_names, query_text=query, top_k=self.top_k) - - # Format results - formatted_results = [] - for result in results: - doc = result["document"] - doc["score"] = result["score"] - # Include source index in results - doc["index"] = result["index"] - if "content" in result: - doc["content"] = result["content"] - if "process_source" in result: - doc["process_source"] = result["process_source"] - formatted_results.append(doc) + def search_hybrid(self, query, index_names): + return self._search_with( + self.vdb_core.hybrid_search, + query, + index_names, + "hybrid", + embedding_model=self.embedding_model, + ) - return { - "results": formatted_results, - "total": len(formatted_results), - } - except Exception as e: - raise Exception(detail=f"Error during accurate search: {str(e)}") + def search_accurate(self, query, index_names): + return self._search_with( + self.vdb_core.accurate_search, + query, + index_names, + "accurate", + ) def search_semantic(self, query, index_names): - try: - results = self.vdb_core.semantic_search( - index_names=index_names, query_text=query, embedding_model=self.embedding_model, top_k=self.top_k - ) - - # Format results - formatted_results = [] - for result in results: - doc = result["document"] - doc["score"] = result["score"] - # Include source index in results - doc["index"] = result["index"] - if "content" in result: - doc["content"] = result["content"] - if "process_source" in result: - doc["process_source"] = result["process_source"] - formatted_results.append(doc) - - return { - "results": formatted_results, - "total": len(formatted_results), - } - except Exception as e: - raise Exception(detail=f"Error during semantic search: {str(e)}") + return self._search_with( + self.vdb_core.semantic_search, + query, + index_names, + "semantic", + embedding_model=self.embedding_model, + ) def _filter_images(self, images_list_url, query) -> list: # kerry diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py index 0b34987e8..3fbe377f1 100644 --- a/sdk/nexent/data_process/extract_image.py +++ b/sdk/nexent/data_process/extract_image.py @@ -43,7 +43,24 @@ class UniversalImageExtractor(FileProcessor): @staticmethod def _hash(data: bytes) -> str: - return hashlib.md5(data).hexdigest() + # Use a modern hash for safe, collision-resistant de-duplication. + return hashlib.sha256(data).hexdigest() + + @staticmethod + def _openxml_namespace_maps() -> List[Dict[str, str]]: + # Prefer https URIs, but retain http for compatibility with existing files. + return [ + { + "xdr": "https://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", + "a": "https://schemas.openxmlformats.org/drawingml/2006/main", + "r": "https://schemas.openxmlformats.org/officeDocument/2006/relationships", + }, + { + "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", + "a": "http://schemas.openxmlformats.org/drawingml/2006/main", + "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", + }, + ] def _write_temp_file(self, data: bytes, suffix: str) -> str: @@ -178,12 +195,18 @@ def _excel_sheet_files(self, z: zipfile.ZipFile) -> List[str]: def _excel_drawing_file(self, z: zipfile.ZipFile, sheet_file: str) -> Optional[str]: sheet_xml = ElementTree.fromstring(z.read(sheet_file)) drawing = sheet_xml.find( - ".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing") + ".//{https://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing") + if drawing is None: + drawing = sheet_xml.find( + ".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing") if drawing is None: return None rel_id = drawing.get( - "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id") + "{https://schemas.openxmlformats.org/officeDocument/2006/relationships}id") + if rel_id is None: + rel_id = drawing.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id") rel_path = sheet_file.replace("worksheets", "worksheets/_rels") + ".rels" if rel_path not in z.namelist(): return None @@ -237,8 +260,12 @@ def _excel_anchor_embed_id(self, anchor: Any, ns: Dict[str, str]) -> Optional[st if blip is None: return None - return blip.get( - "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") + embed_id = blip.get( + "{https://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") + if embed_id is None: + embed_id = blip.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") + return embed_id def _extract_excel_anchors( @@ -313,18 +340,15 @@ def _extract_excel(self, xlsx_path): seen = set() with zipfile.ZipFile(xlsx_path) as z: - ns = { - "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", - "a": "http://schemas.openxmlformats.org/drawingml/2006/main", - "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", - } - sheet_files = self._excel_sheet_files(z) for sheet_file in sheet_files: - results.extend( - self._extract_excel_sheet(z, sheet_file, ns, seen) - ) + extracted = [] + for ns in self._openxml_namespace_maps(): + extracted = self._extract_excel_sheet(z, sheet_file, ns, seen) + if extracted: + break + results.extend(extracted) return results diff --git a/test/backend/data_process/test_ray_actors.py b/test/backend/data_process/test_ray_actors.py index 41cf69c56..07fe6676d 100644 --- a/test/backend/data_process/test_ray_actors.py +++ b/test/backend/data_process/test_ray_actors.py @@ -53,6 +53,27 @@ def expire(self, key, seconds): self.expirations[key] = seconds +def make_temp_file(tmp_path, name: str, content: bytes = b"file-bytes") -> str: + path = tmp_path / name + path.write_bytes(content) + return str(path) + + +def stub_consts(monkeypatch): + fake_consts_pkg = types.ModuleType("consts") + fake_consts_const = types.ModuleType("consts.const") + fake_consts_const.RAY_ACTOR_NUM_CPUS = 1 + fake_consts_const.REDIS_BACKEND_URL = "" + # New defaults required by ray_actors import + fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024 + fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536 + fake_consts_const.TABLE_TRANSFORMER_MODEL_PATH = "/models/table" + fake_consts_const.UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = "/models/unstructured.json" + monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg) + monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const) + return fake_consts_const + + @pytest.fixture(autouse=True) def stub_ray_before_import(monkeypatch): # Ensure that when module under test imports ray, it gets our stub @@ -138,17 +159,7 @@ class _Redis: monkeypatch.setitem(sys.modules, "backend.data_process.tasks", fake_dp_tasks) # Stub consts.const needed by ray_actors imports - fake_consts_pkg = types.ModuleType("consts") - fake_consts_const = types.ModuleType("consts.const") - fake_consts_const.RAY_ACTOR_NUM_CPUS = 1 - fake_consts_const.REDIS_BACKEND_URL = "" - # New defaults required by ray_actors import - fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024 - fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536 - fake_consts_const.TABLE_TRANSFORMER_MODEL_PATH = "/models/table" - fake_consts_const.UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = "/models/unstructured.json" - monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg) - monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const) + stub_consts(monkeypatch) # Ensure model_management_db is stubbed to avoid importing real DB layer if "database.model_management_db" not in sys.modules: @@ -184,12 +195,13 @@ class _Redis: return ray_actors -def test_process_file_happy_path(monkeypatch): +def test_process_file_happy_path(monkeypatch, tmp_path): ray_actors = import_module(monkeypatch) actor = ray_actors.DataProcessorRayActor() + source_path = make_temp_file(tmp_path, "a.txt") chunks = actor.process_file( - source="/tmp/a.txt", + source=source_path, chunking_strategy="basic", destination="local", task_id="tid-1", @@ -201,7 +213,7 @@ def test_process_file_happy_path(monkeypatch): assert chunks[0]["content"] == "hello world" -def test_process_file_applies_chunk_sizes_from_model(monkeypatch): +def test_process_file_applies_chunk_sizes_from_model(monkeypatch, tmp_path): ray_actors = import_module(monkeypatch) # Recorder core to capture params @@ -229,8 +241,9 @@ def file_process(self, file_data, filename, chunking_strategy, **params): ) actor = ray_actors.DataProcessorRayActor() + source_path = make_temp_file(tmp_path, "a.txt") actor.process_file( - source="/tmp/a.txt", + source=source_path, chunking_strategy="basic", destination="local", model_id=9, @@ -246,7 +259,7 @@ def file_process(self, file_data, filename, chunking_strategy, **params): ) == "/models/unstructured.json" -def test_process_file_no_model_omits_chunk_params(monkeypatch): +def test_process_file_no_model_omits_chunk_params(monkeypatch, tmp_path): ray_actors = import_module(monkeypatch) class RecorderCore: @@ -268,8 +281,9 @@ def file_process(self, file_data, filename, chunking_strategy, **params): ) actor = ray_actors.DataProcessorRayActor() + source_path = make_temp_file(tmp_path, "b.txt") actor.process_file( - source="/tmp/b.txt", + source=source_path, chunking_strategy="basic", destination="local", model_id=10, @@ -285,7 +299,7 @@ def file_process(self, file_data, filename, chunking_strategy, **params): ) == "/models/unstructured.json" -def test_process_file_model_lookup_exception_uses_defaults(monkeypatch): +def test_process_file_model_lookup_exception_uses_defaults(monkeypatch, tmp_path): ray_actors = import_module(monkeypatch) class RecorderCore: @@ -308,8 +322,9 @@ def file_process(self, file_data, filename, chunking_strategy, **params): ) actor = ray_actors.DataProcessorRayActor() + source_path = make_temp_file(tmp_path, "c.txt") actor.process_file( - source="/tmp/c.txt", + source=source_path, chunking_strategy="basic", destination="local", model_id=11, @@ -392,17 +407,7 @@ class _Redis: fake_dp_tasks.process_sync = lambda *a, **k: None monkeypatch.setitem(sys.modules, "backend.data_process.tasks", fake_dp_tasks) # Stub consts.const again for reload path - fake_consts_pkg = types.ModuleType("consts") - fake_consts_const = types.ModuleType("consts.const") - fake_consts_const.RAY_ACTOR_NUM_CPUS = 1 - fake_consts_const.REDIS_BACKEND_URL = "" - # Provide defaults required by backend.data_process.ray_actors import - fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024 - fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536 - fake_consts_const.TABLE_TRANSFORMER_MODEL_PATH = "/models/table" - fake_consts_const.UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = "/models/unstructured.json" - monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg) - monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const) + stub_consts(monkeypatch) # Stub database.model_management_db and link to parent to avoid real DB import if "database.model_management_db" not in sys.modules: @@ -433,7 +438,7 @@ class _Redis: actor.process_file("url://missing", "basic", destination="minio") -def test_process_file_core_returns_none_list_variants(monkeypatch): +def test_process_file_core_returns_none_list_variants(monkeypatch, tmp_path): class CoreNone(FakeDataProcessCore): def file_process(self, *a, **k): return None @@ -505,17 +510,7 @@ class _Redis: fake_dp_tasks.process_sync = lambda *a, **k: None monkeypatch.setitem(sys.modules, "backend.data_process.tasks", fake_dp_tasks) # Stub consts.const for ray_actors imports - fake_consts_pkg = types.ModuleType("consts") - fake_consts_const = types.ModuleType("consts.const") - fake_consts_const.RAY_ACTOR_NUM_CPUS = 1 - fake_consts_const.REDIS_BACKEND_URL = "" - # Provide defaults required by backend.data_process.ray_actors import - fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024 - fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536 - fake_consts_const.TABLE_TRANSFORMER_MODEL_PATH = "/models/table" - fake_consts_const.UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = "/models/unstructured.json" - monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg) - monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const) + stub_consts(monkeypatch) # Ensure model_management_db is stubbed to avoid importing real DB layer if "database.model_management_db" not in sys.modules: @@ -530,7 +525,8 @@ class _Redis: import backend.data_process.ray_actors as ray_actors reload(ray_actors) actor = ray_actors.DataProcessorRayActor() - chunks = actor.process_file("/tmp/a.txt", "basic", destination="local") + source_path = make_temp_file(tmp_path, f"a_{core_cls.__name__}.txt") + chunks = actor.process_file(source_path, "basic", destination="local") assert chunks == [] @@ -575,7 +571,7 @@ def test_store_chunks_in_redis_no_url_returns_false(monkeypatch): assert actor.store_chunks_in_redis("k", [{"content": "x"}]) is False -def test_process_file_appends_image_chunks(monkeypatch): +def test_process_file_appends_image_chunks(monkeypatch, tmp_path): ray_actors = import_module(monkeypatch) class CoreWithImages: @@ -604,14 +600,15 @@ def file_process(self, *a, **k): ) actor = ray_actors.DataProcessorRayActor() - chunks = actor.process_file("/tmp/a.pdf", "basic", destination="local") + source_path = make_temp_file(tmp_path, "a.pdf", content=b"%PDF-1.4") + chunks = actor.process_file(source_path, "basic", destination="local") assert len(chunks) == 2 assert chunks[1]["metadata"]["process_source"] == "UniversalImageExtractor" assert "image_url" in chunks[1]["metadata"] -def test_process_file_skips_invalid_image_entries(monkeypatch): +def test_process_file_skips_invalid_image_entries(monkeypatch, tmp_path): ray_actors = import_module(monkeypatch) class CoreWithBadImages: @@ -623,7 +620,8 @@ def file_process(self, *a, **k): monkeypatch.setattr(ray_actors, "DataProcessCore", CoreWithBadImages) actor = ray_actors.DataProcessorRayActor() - chunks = actor.process_file("/tmp/a.pdf", "basic", destination="local") + source_path = make_temp_file(tmp_path, "a.pdf", content=b"%PDF-1.4") + chunks = actor.process_file(source_path, "basic", destination="local") assert chunks == [{"content": "text", "metadata": {}}] diff --git a/test/backend/database/test_knowledge_db.py b/test/backend/database/test_knowledge_db.py index dc5147503..219372994 100644 --- a/test/backend/database/test_knowledge_db.py +++ b/test/backend/database/test_knowledge_db.py @@ -198,27 +198,32 @@ def mock_session(): return mock_session, mock_query -def test_create_knowledge_record_success(monkeypatch, mock_session): - """Test successful creation of knowledge record""" - session, _ = mock_session - - # Create mock knowledge record - mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge") - mock_record.knowledge_id = 123 - mock_record.index_name = "test_knowledge" - - # Mock database session context +def setup_mock_db_session(monkeypatch, session): mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session - # Mock the context manager to call rollback on exception, like the real get_db_session does def mock_exit(exc_type, exc_val, exc_tb): if exc_type is not None: session.rollback() return None # Don't suppress the exception + mock_ctx.__exit__.side_effect = mock_exit monkeypatch.setattr( "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + return mock_ctx + + +def test_create_knowledge_record_success(monkeypatch, mock_session): + """Test successful creation of knowledge record""" + session, _ = mock_session + + # Create mock knowledge record + mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge") + mock_record.knowledge_id = 123 + mock_record.index_name = "test_knowledge" + + # Mock database session context + setup_mock_db_session(monkeypatch, session) # Prepare test data test_query = { @@ -305,17 +310,7 @@ def test_create_knowledge_record_sets_multimodal_flag(monkeypatch, mock_session) mock_record.knowledge_id = 123 mock_record.index_name = "test_knowledge" - mock_ctx = MagicMock() - mock_ctx.__enter__.return_value = session - - def mock_exit(exc_type, exc_val, exc_tb): - if exc_type is not None: - session.rollback() - return None - - mock_ctx.__exit__.side_effect = mock_exit - monkeypatch.setattr( - "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + setup_mock_db_session(monkeypatch, session) test_query = { "index_name": "test_knowledge", @@ -339,17 +334,7 @@ def test_create_knowledge_record_exception(monkeypatch, mock_session): session, _ = mock_session session.add.side_effect = MockSQLAlchemyError("Database error") - mock_ctx = MagicMock() - mock_ctx.__enter__.return_value = session - # Mock the context manager to call rollback on exception, like the real get_db_session does - - def mock_exit(exc_type, exc_val, exc_tb): - if exc_type is not None: - session.rollback() - return None # Don't suppress the exception - mock_ctx.__exit__.side_effect = mock_exit - monkeypatch.setattr( - "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + setup_mock_db_session(monkeypatch, session) test_query = { "index_name": "test_knowledge", @@ -374,17 +359,7 @@ def test_create_knowledge_record_generates_index_name(monkeypatch, mock_session) mock_record = MockKnowledgeRecord(knowledge_name="kb1") mock_record.knowledge_id = 7 - mock_ctx = MagicMock() - mock_ctx.__enter__.return_value = session - # Mock the context manager to call rollback on exception, like the real get_db_session does - - def mock_exit(exc_type, exc_val, exc_tb): - if exc_type is not None: - session.rollback() - return None # Don't suppress the exception - mock_ctx.__exit__.side_effect = mock_exit - monkeypatch.setattr( - "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + setup_mock_db_session(monkeypatch, session) # Deterministic index name monkeypatch.setattr( diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index 8f377aa74..478b2687e 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -185,7 +185,7 @@ def test_search_hybrid_error(self, knowledge_base_search_tool): with pytest.raises(Exception) as excinfo: knowledge_base_search_tool.search_hybrid("test query", ["test_index1"]) - assert "Error during semantic search" in str(excinfo.value) + assert "Error during hybrid search" in str(excinfo.value) def test_forward_accurate_mode_success(self, knowledge_base_search_tool): """Test forward method with accurate search mode""" @@ -305,8 +305,8 @@ def test_forward_title_fallback(self, knowledge_base_search_tool): def test_forward_adds_picture_web_for_images(self, knowledge_base_search_tool, monkeypatch): """Forward should add picture messages when image results are present.""" - monkeypatch.setenv("DATA_PROCESS_SERVICE", "http://data-process") - knowledge_base_search_tool.data_process_service = "http://data-process" + monkeypatch.setenv("DATA_PROCESS_SERVICE", "https://data-process") + knowledge_base_search_tool.data_process_service = "https://data-process" mock_results = [ { From a959469753dab81dcf344138f77af1ff51b58b9b Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Tue, 31 Mar 2026 14:55:00 +0800 Subject: [PATCH 06/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdk/nexent/data_process/extract_image.py | 28 +++++++----------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py index 3fbe377f1..2fd87dc5c 100644 --- a/sdk/nexent/data_process/extract_image.py +++ b/sdk/nexent/data_process/extract_image.py @@ -47,20 +47,12 @@ def _hash(data: bytes) -> str: return hashlib.sha256(data).hexdigest() @staticmethod - def _openxml_namespace_maps() -> List[Dict[str, str]]: - # Prefer https URIs, but retain http for compatibility with existing files. - return [ - { - "xdr": "https://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", - "a": "https://schemas.openxmlformats.org/drawingml/2006/main", - "r": "https://schemas.openxmlformats.org/officeDocument/2006/relationships", - }, - { - "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", - "a": "http://schemas.openxmlformats.org/drawingml/2006/main", - "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", - }, - ] + def _openxml_namespace_maps() -> Dict[str, str]: + return { + "xdr": "https://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", + "a": "https://schemas.openxmlformats.org/drawingml/2006/main", + "r": "https://schemas.openxmlformats.org/officeDocument/2006/relationships", + } def _write_temp_file(self, data: bytes, suffix: str) -> str: @@ -342,13 +334,9 @@ def _extract_excel(self, xlsx_path): with zipfile.ZipFile(xlsx_path) as z: sheet_files = self._excel_sheet_files(z) + ns = self._openxml_namespace_maps() for sheet_file in sheet_files: - extracted = [] - for ns in self._openxml_namespace_maps(): - extracted = self._extract_excel_sheet(z, sheet_file, ns, seen) - if extracted: - break - results.extend(extracted) + results.extend(self._extract_excel_sheet(z, sheet_file, ns, seen)) return results From 2abcdb3826be7a20887dcf128ec78398c7e0fed7 Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Tue, 31 Mar 2026 22:58:27 +0800 Subject: [PATCH 07/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../[locale]/knowledges/KnowledgeBaseConfiguration.tsx | 10 ---------- sdk/nexent/data_process/extract_image.py | 6 +++--- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index fc86ada16..c9e79f149 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -26,7 +26,6 @@ import knowledgeBaseService from "@/services/knowledgeBaseService"; import knowledgeBasePollingService from "@/services/knowledgeBasePollingService"; import { KnowledgeBase } from "@/types/knowledgeBase"; import { useConfig } from "@/hooks/useConfig"; -import { useModelList } from "@/hooks/model/useModelList"; import { SETUP_PAGE_CONTAINER, TWO_COLUMN_LAYOUT, @@ -128,9 +127,6 @@ function DataConfig({ isActive }: DataConfigProps) { const { modelConfig, data: configData, invalidateConfig, config, updateConfig, saveConfig } = useConfig(); const { token } = theme.useToken(); - // Get available embedding models for knowledge base creation - const { availableEmbeddingModels } = useModelList({ enabled: true }); - // Clear cache when component initializes useEffect(() => { localStorage.removeItem("preloaded_kb_data"); @@ -635,12 +631,6 @@ function DataConfig({ isActive }: DataConfigProps) { setNewKbName(defaultName); setNewKbIngroupPermission("READ_ONLY"); setNewKbGroupIds([]); - // Set default embedding model - prioritize config's default model, fall back to first available model - const configModel = modelConfig?.embedding?.modelName; - const defaultModel = configModel || (availableEmbeddingModels.length > 0 - ? availableEmbeddingModels[0].displayName - : ""); - setNewKbEmbeddingModel(defaultModel); setIsCreatingMode(true); setHasClickedUpload(false); // Reset upload button click state setUploadFiles([]); // Reset upload files array, clear all pending upload files diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py index 2fd87dc5c..6d5051132 100644 --- a/sdk/nexent/data_process/extract_image.py +++ b/sdk/nexent/data_process/extract_image.py @@ -49,9 +49,9 @@ def _hash(data: bytes) -> str: @staticmethod def _openxml_namespace_maps() -> Dict[str, str]: return { - "xdr": "https://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", - "a": "https://schemas.openxmlformats.org/drawingml/2006/main", - "r": "https://schemas.openxmlformats.org/officeDocument/2006/relationships", + "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", + "a": "http://schemas.openxmlformats.org/drawingml/2006/main", + "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", } From 5ed7e8adc8aacdc35dcc180a26074e7cffa53032 Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Wed, 1 Apr 2026 00:19:13 +0800 Subject: [PATCH 08/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdk/nexent/data_process/extract_image.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py index 6d5051132..38b452d6d 100644 --- a/sdk/nexent/data_process/extract_image.py +++ b/sdk/nexent/data_process/extract_image.py @@ -49,9 +49,9 @@ def _hash(data: bytes) -> str: @staticmethod def _openxml_namespace_maps() -> Dict[str, str]: return { - "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", - "a": "http://schemas.openxmlformats.org/drawingml/2006/main", - "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", + "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", # NOSONAR + "a": "http://schemas.openxmlformats.org/drawingml/2006/main", # NOSONAR + "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", # NOSONAR } From 39744e86315c58b21d0f5b55e2f67312e1b5f243 Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Sun, 5 Apr 2026 11:58:07 +0800 Subject: [PATCH 09/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/file_management_app.py | 6 ++++-- backend/consts/model.py | 1 + backend/data_process/ray_actors.py | 3 +++ backend/services/data_process_service.py | 6 +++++- backend/utils/file_management_utils.py | 6 +++++- .../knowledges/KnowledgeBaseConfiguration.tsx | 8 ++++++-- .../knowledges/contexts/DocumentContext.tsx | 8 ++++---- .../contexts/KnowledgeBaseContext.tsx | 2 +- frontend/services/knowledgeBaseService.ts | 6 ++++-- frontend/tsconfig.json | 2 +- sdk/nexent/data_process/core.py | 13 ++++++++----- .../vector_database/elasticsearch_core.py | 18 +++++++++--------- 12 files changed, 51 insertions(+), 28 deletions(-) diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 5b7c7bc3c..77af77650 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -120,6 +120,7 @@ async def process_files( chunking_strategy: Optional[str] = Body("basic"), index_name: str = Body(...), destination: str = Body(...), + is_multimodal: Optional[bool] = Body(False), authorization: Optional[str] = Header(None) ): """ @@ -133,7 +134,8 @@ async def process_files( chunking_strategy=chunking_strategy, source_type=destination, index_name=index_name, - authorization=authorization + authorization=authorization, + is_multimodal=is_multimodal ) process_result = await trigger_data_process(files, process_params) @@ -639,4 +641,4 @@ async def preview_file( raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Failed to preview file: {str(e)}" - ) \ No newline at end of file + ) diff --git a/backend/consts/model.py b/backend/consts/model.py index 2728d95ca..128fe81d4 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -234,6 +234,7 @@ class ProcessParams(BaseModel): source_type: str index_name: str authorization: Optional[str] = None + is_multimodal: Optional[bool] = False class OpinionRequest(BaseModel): diff --git a/backend/data_process/ray_actors.py b/backend/data_process/ray_actors.py index 934be1720..b9fd982ae 100644 --- a/backend/data_process/ray_actors.py +++ b/backend/data_process/ray_actors.py @@ -118,9 +118,12 @@ def _apply_model_chunk_sizes( maximum_chunk_size = model_record.get( 'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE) model_name = model_record.get('display_name') + model_type = model_record.get('model_type') params['max_characters'] = maximum_chunk_size params['new_after_n_chars'] = expected_chunk_size + if model_type: + params['model_type'] = model_type logger.info( f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): " diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py index 9eae72407..dcd80424d 100644 --- a/backend/services/data_process_service.py +++ b/backend/services/data_process_service.py @@ -474,6 +474,8 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B chunking_strategy = source_config.get('chunking_strategy') index_name = source_config.get('index_name') original_filename = source_config.get('original_filename') + embedding_model_id = source_config.get('embedding_model_id') + tenant_id = source_config.get('tenant_id') # Validate required fields if not source: @@ -492,7 +494,9 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B source_type=source_type, chunking_strategy=chunking_strategy, index_name=index_name, - original_filename=original_filename + original_filename=original_filename, + embedding_model_id=embedding_model_id, + tenant_id=tenant_id ).set(queue='process_q'), forward.s( index_name=index_name, diff --git a/backend/utils/file_management_utils.py b/backend/utils/file_management_utils.py index 57025e350..4103770f3 100644 --- a/backend/utils/file_management_utils.py +++ b/backend/utils/file_management_utils.py @@ -40,11 +40,13 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) # Get chunking size according to the embedding model embedding_model_id = None tenant_id = None + is_multimodal = process_params.is_multimodal try: _, tenant_id = get_current_user_id(process_params.authorization) # Get embedding model ID from tenant config tenant_config = tenant_config_manager.load_config(tenant_id) - embedding_model_id_str = tenant_config.get("EMBEDDING_ID") if tenant_config else None + embedding_id_key = "MULTI_EMBEDDING_ID" if is_multimodal else "EMBEDDING_ID" + embedding_model_id_str = tenant_config.get(embedding_id_key) if tenant_config else None if embedding_model_id_str: embedding_model_id = int(embedding_model_id_str) except Exception as e: @@ -66,6 +68,7 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) "index_name": process_params.index_name, "original_filename": file_details.get("filename"), "embedding_model_id": embedding_model_id, + "is_multimodal": is_multimodal, "tenant_id": tenant_id } @@ -97,6 +100,7 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) "index_name": process_params.index_name, "original_filename": file_details.get("filename"), "embedding_model_id": embedding_model_id, + "is_multimodal": is_multimodal, "tenant_id": tenant_id } sources.append(source) diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index c9e79f149..954c3b82e 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -711,7 +711,7 @@ function DataConfig({ isActive }: DataConfigProps) { setHasClickedUpload(false); setNewlyCreatedKbId(newKB.id); // Mark this KB as newly created - await uploadDocuments(newKB.id, filesToUpload); + await uploadDocuments(newKB.id, filesToUpload, isMultimodal); setUploadFiles([]); knowledgeBasePollingService @@ -747,7 +747,11 @@ function DataConfig({ isActive }: DataConfigProps) { } try { - await uploadDocuments(kbId, filesToUpload); + await uploadDocuments( + kbId, + filesToUpload, + kbState.activeKnowledgeBase?.is_multimodal + ); setUploadFiles([]); knowledgeBasePollingService.triggerKnowledgeBaseListUpdate(true); diff --git a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx index b956dd919..7a2dcfb2e 100644 --- a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx @@ -112,7 +112,7 @@ export const DocumentContext = createContext<{ state: DocumentState; dispatch: React.Dispatch; fetchDocuments: (kbId: string, forceRefresh?: boolean, kbSource?: string) => Promise; - uploadDocuments: (kbId: string, files: File[]) => Promise; + uploadDocuments: (kbId: string, files: File[], isMultimodal?: boolean) => Promise; deleteDocument: (kbId: string, docId: string) => Promise; }>({ state: { @@ -202,11 +202,11 @@ export const DocumentProvider: React.FC = ({ children }) }, [state.loadingKbIds, state.documentsMap, t]); // Upload documents to a knowledge base - const uploadDocuments = useCallback(async (kbId: string, files: File[]) => { + const uploadDocuments = useCallback(async (kbId: string, files: File[], isMultimodal?: boolean) => { dispatch({ type: DOCUMENT_ACTION_TYPES.SET_UPLOADING, payload: true }); try { - await knowledgeBaseService.uploadDocuments(kbId, files); + await knowledgeBaseService.uploadDocuments(kbId, files, undefined, isMultimodal); // Set loading state before fetching latest documents dispatch({ type: DOCUMENT_ACTION_TYPES.SET_LOADING_DOCUMENTS, payload: true }); @@ -265,4 +265,4 @@ export const DocumentProvider: React.FC = ({ children }) {children} ); -}; \ No newline at end of file +}; diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index 1a087a6a5..0aa397863 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -110,7 +110,7 @@ export const KnowledgeBaseContext = createContext<{ source?: string, ingroup_permission?: string, group_ids?: number[], - is_multiimodal?: boolean, + is_multimodal?: boolean, ) => Promise; deleteKnowledgeBase: (id: string) => Promise; selectKnowledgeBase: (id: string) => void; diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts index 657160fc7..b2cc4fce3 100644 --- a/frontend/services/knowledgeBaseService.ts +++ b/frontend/services/knowledgeBaseService.ts @@ -700,7 +700,7 @@ class KnowledgeBaseService { } = { name: params.name, description: params.description || "", - embeddingModel: params.embeddingModel || "", + embedding_model_name: params.embeddingModel || "", is_multimodal: params.is_multimodal || false }; @@ -846,7 +846,8 @@ class KnowledgeBaseService { async uploadDocuments( kbId: string, files: File[], - chunkingStrategy?: string + chunkingStrategy?: string, + isMultimodal?: boolean ): Promise { try { // Create FormData object @@ -908,6 +909,7 @@ class KnowledgeBaseService { files: filesToProcess, chunking_strategy: chunkingStrategy, destination: "minio", + is_multimodal: isMultimodal ?? false, }), }); diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json index d61634fac..75f792957 100644 --- a/frontend/tsconfig.json +++ b/frontend/tsconfig.json @@ -8,7 +8,7 @@ "noEmit": true, "esModuleInterop": true, "module": "esnext", - "moduleResolution": "node", + "moduleResolution": "bundler", "resolveJsonModule": true, "isolatedModules": true, "jsx": "preserve", diff --git a/sdk/nexent/data_process/core.py b/sdk/nexent/data_process/core.py index 84bff7c5a..b58e6fe03 100644 --- a/sdk/nexent/data_process/core.py +++ b/sdk/nexent/data_process/core.py @@ -1,6 +1,6 @@ import logging import os -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from .extract_image import UniversalImageExtractor @@ -86,10 +86,10 @@ def file_process( # Select appropriate processor if processor: processor_name = processor - _, extractor = self._select_processor_by_filename(filename) + _, extractor = self._select_processor_by_filename(filename, params) else: processor_name, extractor = self._select_processor_by_filename( - filename) + filename, params) processor_instance = self.processors.get(processor_name) extract_image_processor_instance = ( @@ -131,13 +131,16 @@ def _validate_parameters(self, chunking_strategy: str, processor: Optional[str]) logger.debug( f"Parameter validation passed: chunking_strategy={chunking_strategy}, processor={processor}") - def _select_processor_by_filename(self, filename: str) -> Tuple[str, Optional[str]]: + def _select_processor_by_filename( + self, filename: str, params: Optional[Dict[str, Any]] = None + ) -> Tuple[str, Optional[str]]: """Selects a processor based on the file extension.""" _, file_extension = os.path.splitext(filename) file_extension = file_extension.lower() extract_image = None - if file_extension in self.EXTRACT_IMAGE_EXTENSIONS: + model_type = params.get("model_type") + if model_type == "multi_embedding" and file_extension in self.EXTRACT_IMAGE_EXTENSIONS: extract_image = "UniversalImageExtractor" if file_extension in self.EXCEL_EXTENSIONS: return "OpenPyxl", extract_image diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py index 4f230786c..6d5953f0c 100644 --- a/sdk/nexent/vector_database/elasticsearch_core.py +++ b/sdk/nexent/vector_database/elasticsearch_core.py @@ -461,15 +461,15 @@ def _prepare_small_batch_embeddings( inputs.append({"text": doc[content_field]}) embeddings = embedding_model.get_multimodal_embeddings(inputs) return processed_docs, embeddings - - filtered_docs = [ - doc - for doc in processed_docs - if doc.get("process_source") != "UniversalImageExtractor" - ] - inputs = [doc[content_field] for doc in filtered_docs] - embeddings = embedding_model.get_embeddings(inputs) - return filtered_docs, embeddings + else: + filtered_docs = [ + doc + for doc in processed_docs + if doc.get("process_source") != "UniversalImageExtractor" + ] + inputs = [doc[content_field] for doc in filtered_docs] + embeddings = embedding_model.get_embeddings(inputs) + return filtered_docs, embeddings @staticmethod def _build_bulk_operations( From 845a369542d57f60a5cf978618a77e57e8efd8cb Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Sun, 5 Apr 2026 13:30:55 +0800 Subject: [PATCH 10/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/file_management_app.py | 16 ++++++++-------- test/backend/database/test_attachment_db.py | 2 ++ .../services/test_data_process_service.py | 8 ++++++-- .../services/test_vectordatabase_service.py | 6 ++++-- test/sdk/data_process/test_core.py | 11 ++++++++--- 5 files changed, 28 insertions(+), 15 deletions(-) diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 77af77650..3c9a95fbe 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -2,7 +2,7 @@ import re import base64 from http import HTTPStatus -from typing import List, Optional +from typing import Annotated, List, Optional from urllib.parse import urlparse, urlunparse, unquote, quote import httpx @@ -115,13 +115,13 @@ async def upload_files( @file_management_config_router.post("/process") async def process_files( - files: List[dict] = Body( - ..., description="List of file details to process, including path_or_url and filename"), - chunking_strategy: Optional[str] = Body("basic"), - index_name: str = Body(...), - destination: str = Body(...), - is_multimodal: Optional[bool] = Body(False), - authorization: Optional[str] = Header(None) + files: Annotated[List[dict], Body( + ..., description="List of file details to process, including path_or_url and filename")], + index_name: Annotated[str, Body(...)], + destination: Annotated[str, Body(...)], + chunking_strategy: Annotated[Optional[str], Body("basic")], + is_multimodal: Annotated[Optional[bool], Body(False)], + authorization: Annotated[Optional[str], Header(None)] ): """ Trigger data processing for a list of uploaded files. diff --git a/test/backend/database/test_attachment_db.py b/test/backend/database/test_attachment_db.py index afb080682..771b90b27 100644 --- a/test/backend/database/test_attachment_db.py +++ b/test/backend/database/test_attachment_db.py @@ -17,6 +17,8 @@ # Mock consts module consts_mock = MagicMock() consts_mock.const = MagicMock() +# Ensure constants are real strings to avoid startswith TypeError +consts_mock.const.S3_URL_PREFIX = "s3://" # Environment variables are now configured in conftest.py sys.modules['consts'] = consts_mock diff --git a/test/backend/services/test_data_process_service.py b/test/backend/services/test_data_process_service.py index 393bea339..f306c54a5 100644 --- a/test/backend/services/test_data_process_service.py +++ b/test/backend/services/test_data_process_service.py @@ -1667,14 +1667,18 @@ async def async_test_create_batch_tasks_impl_success(self, mock_process, mock_fo 'source_type': 'url', 'chunking_strategy': 'semantic', 'index_name': 'test_index_1', - 'original_filename': 'doc1.pdf' + 'original_filename': 'doc1.pdf', + 'embedding_model_id': None, + 'tenant_id': None }, { 'source': 'http://example.com/doc2.pdf', 'source_type': 'url', 'chunking_strategy': 'fixed', 'index_name': 'test_index_2', - 'original_filename': 'doc2.pdf' + 'original_filename': 'doc2.pdf', + 'embedding_model_id': None, + 'tenant_id': None } ] actual_process_calls = [kwargs for args, diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py index 66c7e8a7a..b583565bb 100644 --- a/test/backend/services/test_vectordatabase_service.py +++ b/test/backend/services/test_vectordatabase_service.py @@ -512,7 +512,9 @@ def test_create_knowledge_base_with_embedding_model_name(self, mock_get_embeddin self.assertEqual(result["knowledge_id"], 10) # Verify get_embedding_model was called with the model name - mock_get_embedding.assert_called_once_with("tenant-1", "text-embedding-3-small") + mock_get_embedding.assert_called_once_with( + "tenant-1", is_multimodal=False, model_name="text-embedding-3-small" + ) # Verify knowledge record was created with the embedding model name mock_create_knowledge.assert_called_once() @@ -559,7 +561,7 @@ def test_create_knowledge_base_without_embedding_model_name_uses_default(self, m self.assertEqual(result["status"], "success") # Verify get_embedding_model was called with None (no specific model) - mock_get_embedding.assert_called_once_with("tenant-1", None) + mock_get_embedding.assert_called_once_with("tenant-1", is_multimodal=False) # Verify knowledge record was created with the model's display name mock_create_knowledge.assert_called_once() diff --git a/test/sdk/data_process/test_core.py b/test/sdk/data_process/test_core.py index af325b52f..5dfff546f 100644 --- a/test/sdk/data_process/test_core.py +++ b/test/sdk/data_process/test_core.py @@ -207,7 +207,8 @@ def test_validate_parameters_invalid_processor(self, core): ) def test_select_processor_by_filename(self, core, filename, expected_processor, expected_extractor): """Test processor selection based on filename""" - processor_name, extractor = core._select_processor_by_filename(filename) + params = {"model_type": "multi_embedding"} if expected_extractor else {} + processor_name, extractor = core._select_processor_by_filename(filename, params) assert processor_name == expected_processor assert extractor == expected_extractor @@ -349,7 +350,7 @@ def test_file_process_returns_images_when_extractor_available(self, core, mocker core.processors["UniversalImageExtractor"] = mock_extractor result = core.file_process( - b"data", "sample.pdf", chunking_strategy="basic" + b"data", "sample.pdf", chunking_strategy="basic", model_type="multi_embedding" ) chunks = _unpack_chunks(result) @@ -366,7 +367,11 @@ def test_file_process_with_explicit_processor_still_extracts_images(self, core): ) result = core.file_process( - b"data", "report.pdf", chunking_strategy="basic", processor="Unstructured" + b"data", + "report.pdf", + chunking_strategy="basic", + processor="Unstructured", + model_type="multi_embedding", ) chunks = _unpack_chunks(result) From 46e71223d1bc5f524abf7064f294d24785a63fc0 Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Sun, 5 Apr 2026 17:53:46 +0800 Subject: [PATCH 11/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/file_management_app.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 3c9a95fbe..e0321237d 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -119,9 +119,9 @@ async def process_files( ..., description="List of file details to process, including path_or_url and filename")], index_name: Annotated[str, Body(...)], destination: Annotated[str, Body(...)], - chunking_strategy: Annotated[Optional[str], Body("basic")], - is_multimodal: Annotated[Optional[bool], Body(False)], - authorization: Annotated[Optional[str], Header(None)] + chunking_strategy: Annotated[Optional[str], Body(...)] = "basic", + is_multimodal: Annotated[Optional[bool], Body(...)] = False, + authorization: Annotated[Optional[str], Header()] = None ): """ Trigger data processing for a list of uploaded files.