diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 0e1ed4d96..d6d79dcf1 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -353,6 +353,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int # special logic for search tools that may use reranking models if tool_config.class_name == "KnowledgeBaseSearchTool": + is_multimodal = tool_config.params.pop("multimodal", False) rerank = param_dict.get("rerank", False) rerank_model_name = param_dict.get("rerank_model_name", "") rerank_model = None @@ -363,7 +364,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int tool_config.metadata = { "vdb_core": get_vector_db_core(), - "embedding_model": get_embedding_model(tenant_id=tenant_id), + "embedding_model": get_embedding_model(tenant_id=tenant_id, is_multimodal=is_multimodal), "rerank_model": rerank_model, } elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]: diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 5b7c7bc3c..e0321237d 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -2,7 +2,7 @@ import re import base64 from http import HTTPStatus -from typing import List, Optional +from typing import Annotated, List, Optional from urllib.parse import urlparse, urlunparse, unquote, quote import httpx @@ -115,12 +115,13 @@ async def upload_files( @file_management_config_router.post("/process") async def process_files( - files: List[dict] = Body( - ..., description="List of file details to process, including path_or_url and filename"), - chunking_strategy: Optional[str] = Body("basic"), - index_name: str = Body(...), - destination: str = Body(...), - authorization: Optional[str] = Header(None) + files: Annotated[List[dict], Body( + ..., description="List of file details to process, including path_or_url and filename")], + index_name: 
Annotated[str, Body(...)], + destination: Annotated[str, Body(...)], + chunking_strategy: Annotated[Optional[str], Body(...)] = "basic", + is_multimodal: Annotated[Optional[bool], Body(...)] = False, + authorization: Annotated[Optional[str], Header()] = None ): """ Trigger data processing for a list of uploaded files. @@ -133,7 +134,8 @@ async def process_files( chunking_strategy=chunking_strategy, source_type=destination, index_name=index_name, - authorization=authorization + authorization=authorization, + is_multimodal=is_multimodal ) process_result = await trigger_data_process(files, process_params) @@ -639,4 +641,4 @@ async def preview_file( raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Failed to preview file: {str(e)}" - ) \ No newline at end of file + ) diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py index 0a5a04139..c04c577f5 100644 --- a/backend/apps/model_managment_app.py +++ b/backend/apps/model_managment_app.py @@ -33,7 +33,7 @@ from fastapi.responses import JSONResponse from fastapi.encoders import jsonable_encoder from http import HTTPStatus -from typing import List, Optional +from typing import Annotated, List, Optional from services.model_health_service import ( check_model_connectivity, verify_model_config_connectivity, @@ -297,7 +297,8 @@ async def get_llm_model_list(authorization: Optional[str] = Header(None)): @router.post("/healthcheck") async def check_model_health( - display_name: str = Query(..., description="Display name to check"), + display_name: Annotated[str, Query(..., description="Display name to check")], + model_type: Annotated[str, Query(..., description="...")], authorization: Optional[str] = Header(None) ): """Check and update model connectivity, returning the latest status. 
@@ -308,7 +309,7 @@ async def check_model_health( """ try: _, tenant_id = get_current_user_id(authorization) - result = await check_model_connectivity(display_name, tenant_id) + result = await check_model_connectivity(display_name, tenant_id, model_type) return JSONResponse(status_code=HTTPStatus.OK, content={ "message": "Successfully checked model connectivity", "data": result diff --git a/backend/apps/vectordatabase_app.py b/backend/apps/vectordatabase_app.py index 872b5387b..2dba4559b 100644 --- a/backend/apps/vectordatabase_app.py +++ b/backend/apps/vectordatabase_app.py @@ -65,11 +65,15 @@ def create_new_index( # Extract optional fields from request body ingroup_permission = None group_ids = None - embedding_model_name = None + is_multimodal = False + embedding_model_name: Optional[str] = None if request: ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") - embedding_model_name = request.get("embedding_model_name") + is_multimodal = request.get("is_multimodal") + embedding_model_name = request.get("embeddingModel") or request.get("embedding_model") + if isinstance(embedding_model_name, str): + embedding_model_name = embedding_model_name.strip() or None # Treat path parameter as user-facing knowledge base name for new creations return ElasticSearchService.create_knowledge_base( @@ -80,7 +84,7 @@ def create_new_index( tenant_id=tenant_id, ingroup_permission=ingroup_permission, group_ids=group_ids, - embedding_model_name=embedding_model_name, + is_multimodal=is_multimodal, ) except Exception as e: raise HTTPException( @@ -124,6 +128,7 @@ async def update_index( knowledge_name = request.get("knowledge_name") ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") + is_multimodal = request.get("is_multimodal") # Call service layer to update knowledge base result = ElasticSearchService.update_knowledge_base( @@ -131,6 +136,7 @@ async def update_index( knowledge_name=knowledge_name, 
ingroup_permission=ingroup_permission, group_ids=group_ids, + is_multimodal=is_multimodal, tenant_id=tenant_id, user_id=user_id, ) @@ -199,15 +205,10 @@ def create_index_documents( try: user_id, tenant_id = get_current_user_id(authorization) - # Get the knowledge base record to retrieve the saved embedding model knowledge_record = get_knowledge_record({'index_name': index_name}) - saved_embedding_model_name = None - if knowledge_record: - saved_embedding_model_name = knowledge_record.get('embedding_model_name') - - # Use the saved model from knowledge base, fallback to tenant default if not set - embedding_model = get_embedding_model(tenant_id, saved_embedding_model_name) - + is_multimodal = True if knowledge_record.get( + 'is_multimodal') == 'Y' else False + embedding_model = get_embedding_model(tenant_id, is_multimodal) return ElasticSearchService.index_documents( embedding_model=embedding_model, index_name=index_name, diff --git a/backend/consts/const.py b/backend/consts/const.py index 5bfd012ff..8c75d7afd 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -27,6 +27,10 @@ class VectorDatabaseType(str, Enum): # Data Processing Service Configuration DATA_PROCESS_SERVICE = os.getenv("DATA_PROCESS_SERVICE") CLIP_MODEL_PATH = os.getenv("CLIP_MODEL_PATH") +TABLE_TRANSFORMER_MODEL_PATH = os.getenv("TABLE_TRANSFORMER_MODEL_PATH") +UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = os.getenv( + "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" +) # Upload Configuration @@ -109,6 +113,7 @@ class VectorDatabaseType(str, Enum): MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY") MINIO_REGION = os.getenv("MINIO_REGION") MINIO_DEFAULT_BUCKET = os.getenv("MINIO_DEFAULT_BUCKET") +S3_URL_PREFIX = "s3://" # Postgres Configuration diff --git a/backend/consts/model.py b/backend/consts/model.py index 2728d95ca..128fe81d4 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -234,6 +234,7 @@ class ProcessParams(BaseModel): source_type: str 
index_name: str authorization: Optional[str] = None + is_multimodal: Optional[bool] = False class OpinionRequest(BaseModel): diff --git a/backend/data_process/ray_actors.py b/backend/data_process/ray_actors.py index 2fa590bec..b9fd982ae 100644 --- a/backend/data_process/ray_actors.py +++ b/backend/data_process/ray_actors.py @@ -1,11 +1,19 @@ +from io import BytesIO import logging import json from typing import Any, Dict, List, Optional import ray -from consts.const import RAY_ACTOR_NUM_CPUS, REDIS_BACKEND_URL, DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE -from database.attachment_db import get_file_stream +from consts.const import ( + RAY_ACTOR_NUM_CPUS, + REDIS_BACKEND_URL, + DEFAULT_EXPECTED_CHUNK_SIZE, + DEFAULT_MAXIMUM_CHUNK_SIZE, + TABLE_TRANSFORMER_MODEL_PATH, + UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH, +) +from database.attachment_db import build_s3_url, get_file_stream, upload_fileobj from database.model_management_db import get_model_by_model_id from nexent.data_process import DataProcessCore @@ -58,50 +66,137 @@ def process_file( if task_id: params['task_id'] = task_id - # Get chunk size parameters from embedding model if model_id is provided - if model_id and tenant_id: - try: - # Get embedding model details directly by model_id - model_record = get_model_by_model_id( - model_id=model_id, tenant_id=tenant_id) - if model_record: - expected_chunk_size = model_record.get( - 'expected_chunk_size', DEFAULT_EXPECTED_CHUNK_SIZE) - maximum_chunk_size = model_record.get( - 'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE) - model_name = model_record.get('display_name') - - # Pass chunk sizes to processing parameters - params['max_characters'] = maximum_chunk_size - params['new_after_n_chars'] = expected_chunk_size - - logger.info( - f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): " - f"max_characters={maximum_chunk_size}, new_after_n_chars={expected_chunk_size}") - else: - logger.warning( - 
f"[RayActor] Embedding model with ID {model_id} not found for tenant '{tenant_id}', using default chunk sizes") - except Exception as e: + self._apply_model_chunk_sizes( + model_id=model_id, tenant_id=tenant_id, params=params) + self._apply_model_paths(params) + file_data = self._read_file_bytes(source) + + result = self._processor.file_process( + file_data=file_data, + filename=source, + chunking_strategy=chunking_strategy, + **params + ) + chunks, images_info = self._normalize_processor_result(result) + if images_info: + self._append_image_chunks( + source=source, chunks=chunks, images_info=images_info) + + chunks = self._validate_chunks(chunks, source) + if not chunks: + return [] + + logger.info( + f"[RayActor] Processing done: produced {len(chunks)} chunks for source='{source}'") + return chunks + + def _apply_model_paths(self, params: Dict[str, Any]) -> None: + params["table_transformer_model_path"] = TABLE_TRANSFORMER_MODEL_PATH + params[ + "unstructured_default_model_initialize_params_json_path" + ] = UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH + + def _apply_model_chunk_sizes( + self, + model_id: Optional[int], + tenant_id: Optional[str], + params: Dict[str, Any], + ) -> None: + if not (model_id and tenant_id): + return + + try: + model_record = get_model_by_model_id( + model_id=model_id, tenant_id=tenant_id) + if not model_record: logger.warning( - f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. 
Using default chunk sizes") + f"[RayActor] Embedding model with ID {model_id} not found for tenant '{tenant_id}', using default chunk sizes") + return + + expected_chunk_size = model_record.get( + 'expected_chunk_size', DEFAULT_EXPECTED_CHUNK_SIZE) + maximum_chunk_size = model_record.get( + 'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE) + model_name = model_record.get('display_name') + model_type = model_record.get('model_type') + params['max_characters'] = maximum_chunk_size + params['new_after_n_chars'] = expected_chunk_size + if model_type: + params['model_type'] = model_type + + logger.info( + f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): " + f"max_characters={maximum_chunk_size}, new_after_n_chars={expected_chunk_size}") + except Exception as e: + logger.warning( + f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. Using default chunk sizes") + + def _read_file_bytes(self, source: str) -> bytes: try: file_stream = get_file_stream(source) if file_stream is None: raise FileNotFoundError( f"Unable to fetch file from URL: {source}") - file_data = file_stream.read() + return file_stream.read() except Exception as e: logger.error(f"Failed to fetch file from {source}: {e}") raise - chunks = self._processor.file_process( - file_data=file_data, - filename=source, - chunking_strategy=chunking_strategy, - **params - ) + def _normalize_processor_result( + self, result: Any + ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + if isinstance(result, tuple) and len(result) == 2: + chunks, images_info = result + return chunks or [], images_info or [] + return result or [], [] + + def _append_image_chunks( + self, + source: str, + chunks: List[Dict[str, Any]], + images_info: List[Dict[str, Any]], + ) -> None: + folder = "images_in_attachments" + for index, image_data in enumerate(images_info): + if not isinstance(image_data, dict): + logger.warning( + f"[RayActor] Skipping image entry at index 
{index}: unexpected type {type(image_data)}" + ) + continue + if "image_bytes" not in image_data: + logger.warning( + f"[RayActor] Skipping image entry at index {index}: missing image_bytes" + ) + continue + + img_obj = BytesIO(image_data["image_bytes"]) + result = upload_fileobj( + file_obj=img_obj, + file_name=f"{index}.{image_data['image_format']}", + prefix=folder) + image_url = build_s3_url(result.get("object_name", "")) + + image_data["source_file"] = source + image_data["image_url"] = image_url + chunks.append({ + "content": json.dumps({ + "source_file": source, + "position": image_data["position"], + "image_url": image_url, + }), + "filename": source, + "metadata": { + "chunk_index": len(chunks) + index, + "process_source": "UniversalImageExtractor", + "image_url": image_url, + } + }) + + def _validate_chunks( + self, chunks: Any, source: str + ) -> List[Dict[str, Any]]: if chunks is None: logger.warning( f"[RayActor] file_process returned None for source='{source}'") @@ -114,9 +209,6 @@ def process_file( logger.warning( f"[RayActor] file_process returned empty list for source='{source}'") return [] - - logger.info( - f"[RayActor] Processing done: produced {len(chunks)} chunks for source='{source}'") return chunks def store_chunks_in_redis(self, redis_key: str, chunks: List[Dict[str, Any]]) -> bool: diff --git a/backend/database/attachment_db.py b/backend/database/attachment_db.py index 2e6249468..a9ecf00f6 100644 --- a/backend/database/attachment_db.py +++ b/backend/database/attachment_db.py @@ -2,9 +2,62 @@ import os import uuid from datetime import datetime -from typing import Any, BinaryIO, Dict, List, Optional +from typing import Any, BinaryIO, Dict, List, Optional, Tuple from .client import minio_client +from consts.const import S3_URL_PREFIX + + +def _normalize_object_and_bucket(object_name: str, bucket: Optional[str] = None) -> Tuple[str, Optional[str]]: + """ + Normalize object_name + bucket from supported URL styles. 
+ + Supports: + - s3://bucket/key + - /bucket/key + - key (uses provided bucket or default bucket) + """ + if not object_name: + return object_name, bucket + + if object_name.startswith(S3_URL_PREFIX): + s3_path = object_name[len(S3_URL_PREFIX) :] + parts = s3_path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + return object_name, bucket + + +def build_s3_url(object_name: str, bucket: Optional[str] = None) -> str: + """ + Build an s3://bucket/key style URL from an object name (or passthrough if already s3://). + """ + if not object_name: + return "" + + if object_name.startswith(S3_URL_PREFIX): + return object_name + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + if len(parts) == 2: + return f"{S3_URL_PREFIX}{parts[0]}/{parts[1]}" + return f"{S3_URL_PREFIX}{parts[0]}/" + + resolved_bucket = bucket or minio_client.storage_config.default_bucket + if resolved_bucket: + return f"{S3_URL_PREFIX}{resolved_bucket}/{object_name}" + return f"{S3_URL_PREFIX}{object_name}" def generate_object_name(file_name: str, prefix: str = "attachments") -> str: @@ -165,6 +218,7 @@ def get_file_size_from_minio(object_name: str, bucket: Optional[str] = None) -> """ Get file size by object name """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) bucket = bucket or minio_client.storage_config.default_bucket return minio_client.get_file_size(object_name, bucket) @@ -181,6 +235,7 @@ def file_exists(object_name: str, bucket: Optional[str] = None) -> bool: bool: True if file exists, False otherwise """ try: + object_name, bucket = 
_normalize_object_and_bucket(object_name, bucket) return minio_client.file_exists(object_name, bucket) except Exception: return False @@ -198,6 +253,8 @@ def copy_file(source_object: str, dest_object: str, bucket: Optional[str] = None Returns: Dict[str, Any]: Result containing success flag and error message (if any) """ + source_object, bucket = _normalize_object_and_bucket(source_object, bucket) + dest_object, bucket = _normalize_object_and_bucket(dest_object, bucket) success, result = minio_client.copy_file(source_object, dest_object, bucket) if success: return {"success": True, "object_name": result} @@ -242,6 +299,7 @@ def delete_file(object_name: str, bucket: Optional[str] = None) -> Dict[str, Any Returns: Dict[str, Any]: Delete result, containing success flag and error message (if any) """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) if not bucket: bucket = minio_client.storage_config.default_bucket success, result = minio_client.delete_file(object_name, bucket) @@ -265,6 +323,7 @@ def get_file_stream(object_name: str, bucket: Optional[str] = None) -> Optional[ Returns: Optional[BinaryIO]: Standard BinaryIO stream object, or None if failed """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) success, result = minio_client.get_file_stream(object_name, bucket) if not success: return None diff --git a/backend/database/db_models.py b/backend/database/db_models.py index a1b28334c..faf58fce4 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -271,6 +271,7 @@ class KnowledgeRecord(TableBase): group_ids = Column(String, doc="Knowledge base group IDs list") ingroup_permission = Column( String(30), doc="In-group permission: EDIT, READ_ONLY, PRIVATE") + is_multimodal = Column(String(1), default="N", doc="Whether it is multimodal. 
Optional values: Y/N") class TenantConfig(TableBase): diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index df42e1888..40f4ca718 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -52,6 +52,7 @@ def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: "knowledge_name": knowledge_name, "group_ids": convert_list_to_string(group_ids) if isinstance(group_ids, list) else group_ids, "ingroup_permission": query.get("ingroup_permission"), + "is_multimodal": 'Y' if query.get("is_multimodal") else 'N' } # For backward compatibility: if caller explicitly provides index_name, @@ -178,6 +179,9 @@ def update_knowledge_record(query: Dict[str, Any]) -> bool: if query.get("group_ids") is not None: record.group_ids = query["group_ids"] + if query.get("is_multimodal"): + record.is_multimodal = 'Y' if query["is_multimodal"] else 'N' + # Update timestamp and user if query.get("user_id"): record.updated_by = query["user_id"] @@ -254,6 +258,11 @@ def get_knowledge_record(query: Optional[Dict[str, Any]] = None) -> Dict[str, An db_query = db_query.filter( KnowledgeRecord.tenant_id == query['tenant_id']) + if 'is_multimodal' in query: + db_query = db_query.filter( + KnowledgeRecord.is_multimodal == query['is_multimodal'] + ) + result = db_query.first() if result: diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index cb1c6c69f..61753f52f 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -170,7 +170,7 @@ def get_model_records(filters: Optional[Dict[str, Any]], tenant_id: str) -> List return result_list -def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dict[str, Any]]: +def get_model_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[Dict[str, Any]]: """ Get a model record by display name @@ -179,6 +179,11 @@ def get_model_by_display_name(display_name: 
str, tenant_id: str) -> Optional[Dic tenant_id: """ filters = {'display_name': display_name} + + if model_type in ["multiEmbedding", "multi_embedding"]: + filters['model_type'] = "multi_embedding" + elif model_type == "embedding": + filters['model_type'] = "embedding" records = get_model_records(filters, tenant_id) if not records: @@ -203,7 +208,7 @@ def get_models_by_display_name(display_name: str, tenant_id: str) -> List[Dict[s return get_model_records(filters, tenant_id) -def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[int]: +def get_model_id_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[int]: """ Get a model ID by display name @@ -214,7 +219,7 @@ def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[ Returns: Optional[int]: Model ID """ - model = get_model_by_display_name(display_name, tenant_id) + model = get_model_by_display_name(display_name, tenant_id, model_type) return model["model_id"] if model else None diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py index 9fe50813a..c484ca23f 100644 --- a/backend/services/config_sync_service.py +++ b/backend/services/config_sync_service.py @@ -99,7 +99,7 @@ async def save_config_impl(config, tenant_id, user_id): config_key = get_env_key(model_type) + "_ID" model_id = get_model_id_by_display_name( - model_display_name, tenant_id) + model_display_name, tenant_id, model_type=model_type) handle_model_config(tenant_id, user_id, config_key, model_id, tenant_config_dict) diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py index 8c44c15e6..dcd80424d 100644 --- a/backend/services/data_process_service.py +++ b/backend/services/data_process_service.py @@ -255,6 +255,17 @@ async def load_image(self, image_url: str) -> Optional[Image.Image]: async def _load_image(self, session: aiohttp.ClientSession, path: str) -> Optional[Image.Image]: 
"""Internal method to load an image from various sources""" try: + if path.startswith('s3://'): + # Fetch from MinIO using s3://bucket/key + file_stream = get_file_stream(object_name=path) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {path}") + file_data = file_stream.read() + image_based64_str = base64.b64encode( + file_data).decode('utf-8') + path = f"data:image/jpeg;base64,{image_based64_str}" + # Check if input is base64 encoded if path.startswith('data:image'): # Extract the base64 data after the comma @@ -463,6 +474,8 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B chunking_strategy = source_config.get('chunking_strategy') index_name = source_config.get('index_name') original_filename = source_config.get('original_filename') + embedding_model_id = source_config.get('embedding_model_id') + tenant_id = source_config.get('tenant_id') # Validate required fields if not source: @@ -481,7 +494,9 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B source_type=source_type, chunking_strategy=chunking_strategy, index_name=index_name, - original_filename=original_filename + original_filename=original_filename, + embedding_model_id=embedding_model_id, + tenant_id=tenant_id ).set(queue='process_q'), forward.s( index_name=index_name, @@ -559,7 +574,7 @@ async def process_uploaded_text_file(self, file_content: bytes, filename: str, c } async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: str) -> None: - """Full conversion pipeline: download → convert → upload → validate → cleanup. + """Full conversion pipeline: download -> convert -> upload -> validate -> cleanup. All five steps run inside data-process so that LibreOffice only needs to be installed in this container. 
diff --git a/backend/services/datamate_service.py b/backend/services/datamate_service.py index 776e0eb1d..26e777eba 100644 --- a/backend/services/datamate_service.py +++ b/backend/services/datamate_service.py @@ -51,7 +51,8 @@ async def _create_datamate_knowledge_records(knowledge_base_ids: List[str], "tenant_id": tenant_id, "user_id": user_id, # Use datamate as embedding model name - "embedding_model_name": embedding_model_names[i] + "embedding_model_name": embedding_model_names[i], + "is_multimodal": False, } # Run synchronous database operation in executor to avoid blocking diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index 9214a1ffa..5b8e27f07 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -128,10 +128,10 @@ async def _perform_connectivity_check( return connectivity -async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: +async def check_model_connectivity(display_name: str, tenant_id: str, model_type: str = None) -> dict: try: # Query the database using display_name and tenant context from app layer - model = get_model_by_display_name(display_name, tenant_id=tenant_id) + model = get_model_by_display_name(display_name, tenant_id=tenant_id, model_type=model_type) if not model: raise LookupError(f"Model configuration not found for {display_name}") diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index 9653b2e10..22e9dd1b7 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -140,6 +140,10 @@ def get_local_tools() -> List[ToolInfo]: else: param_info["default"] = param.default.default param_info["optional"] = True + if getattr(param.default, "json_schema_extra", None): + optional_override = param.default.json_schema_extra.get("optional") + if optional_override is not None: + param_info["optional"] = 
optional_override init_params_list.append(param_info) @@ -692,7 +696,8 @@ def _validate_local_tool( instantiation_params[param_name] = param.default if tool_name == "knowledge_base_search": - embedding_model = get_embedding_model(tenant_id=tenant_id) + is_multimodal = instantiation_params.pop("multimodal", False) + embedding_model = get_embedding_model(tenant_id=tenant_id, is_multimodal=is_multimodal) vdb_core = get_vector_db_core() # Get rerank configuration diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index cf8f7f98c..2e1a881b7 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -28,7 +28,7 @@ from consts.const import DATAMATE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType, IS_SPEED_MODE, PERMISSION_EDIT, PERMISSION_READ from consts.model import ChunkCreateRequest, ChunkUpdateRequest -from database.attachment_db import delete_file +from database.attachment_db import delete_file, get_file_stream from database.knowledge_db import ( create_knowledge_record, delete_knowledge_record, @@ -176,7 +176,76 @@ def check_knowledge_base_exist_impl(knowledge_name: str, vdb_core: VectorDatabas return {"status": "available"} -def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): +def _build_embedding_from_config(model_config: Dict[str, Any]) -> Optional[BaseEmbedding]: + model_type = model_config.get("model_type", "") + if model_type == "embedding": + return OpenAICompatibleEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + if model_type == "multi_embedding": + return JinaEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or 
"", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + return None + + +def _find_model_record( + tenant_id: str, + is_multimodal: bool, + model_name: str, +) -> Optional[Dict[str, Any]]: + model_type = "multi_embedding" if is_multimodal else "embedding" + models = get_model_records({"model_type": model_type}, tenant_id) + for model in models: + model_display_name = ( + f"{model.get('model_repo')}/{model['model_name']}" + if model.get("model_repo") + else model["model_name"] + ) + if model_display_name == model_name: + return model + return None + + +def _build_embedding_from_record( + model_record: Dict[str, Any], + is_multimodal: bool, +) -> BaseEmbedding: + model_config = { + "model_repo": model_record.get("model_repo", ""), + "model_name": model_record["model_name"], + "api_key": model_record.get("api_key", ""), + "base_url": model_record.get("base_url", ""), + "model_type": "embedding", + "max_tokens": model_record.get("max_tokens", 1024), + "ssl_verify": model_record.get("ssl_verify", True), + } + if not is_multimodal: + return OpenAICompatibleEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + return JinaEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + + +def get_embedding_model(tenant_id: str, is_multimodal: bool = False, model_name: Optional[str] = None): """ Get the embedding model for the tenant, optionally using a specific model name. 
@@ -188,58 +257,46 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): Returns: Embedding model instance or None """ + if model_name is None and (isinstance(is_multimodal, str) or is_multimodal is None): + model_name = is_multimodal + is_multimodal = False # If model_name is provided, try to find it in the tenant's models if model_name: try: - models = get_model_records({"model_type": "embedding"}, tenant_id) - for model in models: - model_display_name = model.get("model_repo") + "/" + model["model_name"] if model.get("model_repo") else model["model_name"] - if model_display_name == model_name: - # Found the model, create embedding instance - model_config = { - "model_repo": model.get("model_repo", ""), - "model_name": model["model_name"], - "api_key": model.get("api_key", ""), - "base_url": model.get("base_url", ""), - "model_type": "embedding", - "max_tokens": model.get("max_tokens", 1024), - "ssl_verify": model.get("ssl_verify", True), - } - return OpenAICompatibleEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), - ) + model_record = _find_model_record( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=model_name, + ) + if model_record: + return _build_embedding_from_record( + model_record=model_record, + is_multimodal=is_multimodal, + ) except Exception as e: - logger.warning(f"Failed to get embedding model by name {model_name}: {e}") + logger.warning( + f"Failed to get embedding model by name {model_name}: {e}") # Fall back to default embedding model (current behavior) model_config = tenant_config_manager.get_model_config( - key="EMBEDDING_ID", tenant_id=tenant_id) - - model_type = model_config.get("model_type", "") - - if model_type == "embedding": - # Get the es core - return OpenAICompatibleEmbedding( - 
api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), - ) - elif model_type == "multi_embedding": - return JinaEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), + key="MULTI_EMBEDDING_ID" if is_multimodal else "EMBEDDING_ID", + tenant_id=tenant_id, + ) + return _build_embedding_from_config(model_config) + + +def _resolve_embedding_model( + tenant_id: str, + is_multimodal: bool, + embedding_model_name: Optional[str], +) -> Optional[BaseEmbedding]: + if embedding_model_name: + return get_embedding_model( + tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, ) - else: - return None + return get_embedding_model(tenant_id, is_multimodal=is_multimodal) def get_rerank_model(tenant_id: str, model_name: Optional[str] = None): @@ -406,6 +463,7 @@ async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCo @staticmethod def create_index( + embedding_model: BaseEmbedding, index_name: str = Path(..., description="Name of the index to create"), embedding_dim: Optional[int] = Query( @@ -419,15 +477,22 @@ def create_index( try: if vdb_core.check_index_exists(index_name): raise Exception(f"Index {index_name} already exists") - embedding_model = get_embedding_model(tenant_id) + success = vdb_core.create_index(index_name, embedding_dim=embedding_dim or ( embedding_model.embedding_dim if embedding_model else 1024)) if not success: raise Exception(f"Failed to create index {index_name}") - knowledge_data = {"index_name": index_name, - "created_by": user_id, - "tenant_id": tenant_id, - "embedding_model_name": embedding_model.model} + 
is_multimodal = ( + True + if embedding_model and getattr(embedding_model, "model_type", None) == "multimodal" + else False + ) + knowledge_data = { + "index_name": index_name, + "created_by": user_id, + "tenant_id": tenant_id, + "is_multimodal": is_multimodal, + } create_knowledge_record(knowledge_data) return {"status": "success", "message": f"Index {index_name} created successfully"} except Exception as e: @@ -443,6 +508,7 @@ def create_knowledge_base( ingroup_permission: Optional[str] = None, group_ids: Optional[List[int]] = None, embedding_model_name: Optional[str] = None, + is_multimodal: bool = False, ): """ Create a new knowledge base with a user-facing name and an internal Elasticsearch index name. @@ -467,14 +533,14 @@ def create_knowledge_base( with an explicit index_name. """ try: - # Get embedding model - use user-selected model if provided, otherwise use tenant default - embedding_model = get_embedding_model(tenant_id, embedding_model_name) - - # Determine the embedding model name to save: use user-provided name if available, - # otherwise use the model's display name - saved_embedding_model_name = embedding_model_name - if not saved_embedding_model_name and embedding_model: - saved_embedding_model_name = embedding_model.model + embedding_model = _resolve_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + embedding_model_name=embedding_model_name, + ) + resolved_embedding_model_name = embedding_model_name + if resolved_embedding_model_name is None and embedding_model: + resolved_embedding_model_name = getattr(embedding_model, "model", None) # Create knowledge record first to obtain knowledge_id and generated index_name knowledge_data = { @@ -482,7 +548,8 @@ def create_knowledge_base( "knowledge_describe": "", "user_id": user_id, "tenant_id": tenant_id, - "embedding_model_name": saved_embedding_model_name, + "embedding_model_name": resolved_embedding_model_name, + "is_multimodal": is_multimodal, } # Add group permission and group 
IDs if provided @@ -519,6 +586,7 @@ def update_knowledge_base( knowledge_name: Optional[str] = None, ingroup_permission: Optional[str] = None, group_ids: Optional[List[int]] = None, + is_multimodal: bool = False, tenant_id: Optional[str] = None, user_id: Optional[str] = None, ) -> bool: @@ -549,6 +617,7 @@ def update_knowledge_base( update_data = { "index_name": index_name, "updated_by": user_id, + "is_multimodal": is_multimodal, } if knowledge_name is not None: @@ -781,6 +850,7 @@ def list_indices( "display_name": record.get("knowledge_name", index_name), "permission": record["permission"], "group_ids": record["group_ids"], + "is_multimodal": True if record.get("is_multimodal") == "Y" else False, # knowledge source and ingroup permission from DB record "knowledge_sources": record["knowledge_sources"], "ingroup_permission": record["ingroup_permission"], @@ -882,15 +952,30 @@ def index_documents( "author": author, "date": date, "content": text, - "process_source": "Unstructured", + "process_source": metadata.get("process_source", "Unstructured"), "file_size": file_size, "create_time": create_time, "languages": metadata.get("languages", []), "embedding_model_name": embedding_model_name } + image_url = metadata.get("image_url", "") + if len(image_url) > 0: + # Fetch image bytes from MinIO (supports s3://bucket/key or /bucket/key) + try: + file_stream = get_file_stream( + object_name=image_url) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {image_url}") + document["image_bytes"] = file_stream.read() + except Exception as e: + logger.error( + f"Failed to fetch file from {image_url}: {e}") + raise documents.append(document) + total_submitted = len(documents) if total_submitted == 0: return { @@ -908,8 +993,9 @@ def index_documents( 'tenant_id') if knowledge_record else None if tenant_id: + model_type = "EMBEDDING_ID" if embedding_model.model_type == "text" else "MULTI_EMBEDDING_ID" model_config = 
tenant_config_manager.get_model_config( - key="EMBEDDING_ID", tenant_id=tenant_id) + key=model_type, tenant_id=tenant_id) embedding_batch_size = model_config.get("chunk_batch", 10) if embedding_batch_size is None: embedding_batch_size = 10 diff --git a/backend/utils/file_management_utils.py b/backend/utils/file_management_utils.py index 57025e350..4103770f3 100644 --- a/backend/utils/file_management_utils.py +++ b/backend/utils/file_management_utils.py @@ -40,11 +40,13 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) # Get chunking size according to the embedding model embedding_model_id = None tenant_id = None + is_multimodal = process_params.is_multimodal try: _, tenant_id = get_current_user_id(process_params.authorization) # Get embedding model ID from tenant config tenant_config = tenant_config_manager.load_config(tenant_id) - embedding_model_id_str = tenant_config.get("EMBEDDING_ID") if tenant_config else None + embedding_id_key = "MULTI_EMBEDDING_ID" if is_multimodal else "EMBEDDING_ID" + embedding_model_id_str = tenant_config.get(embedding_id_key) if tenant_config else None if embedding_model_id_str: embedding_model_id = int(embedding_model_id_str) except Exception as e: @@ -66,6 +68,7 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) "index_name": process_params.index_name, "original_filename": file_details.get("filename"), "embedding_model_id": embedding_model_id, + "is_multimodal": is_multimodal, "tenant_id": tenant_id } @@ -97,6 +100,7 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) "index_name": process_params.index_name, "original_filename": file_details.get("filename"), "embedding_model_id": embedding_model_id, + "is_multimodal": is_multimodal, "tenant_id": tenant_id } sources.append(source) diff --git a/docker/deploy.sh b/docker/deploy.sh index e30e6e75a..233c14604 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -17,6 +17,7 @@ 
DEPLOY_OPTIONS_FILE="$SCRIPT_DIR/deploy.options" MODE_CHOICE_SAVED="" VERSION_CHOICE_SAVED="" IS_MAINLAND_SAVED="" +DOWNLOAD_MODELS="N" ENABLE_TERMINAL_SAVED="N" TERMINAL_MOUNT_DIR_SAVED="${TERMINAL_MOUNT_DIR:-}" APP_VERSION="" @@ -79,6 +80,58 @@ is_windows_env() { return 1 } +detect_os_type() { + # Return: windows | mac | linux | unknown + local os_name + os_name=$(uname -s 2>/dev/null | tr '[:upper:]' '[:lower:]') + case "$os_name" in + mingw*|msys*|cygwin*) + echo "windows" + ;; + darwin*) + echo "mac" + ;; + linux*) + echo "linux" + ;; + *) + echo "unknown" + ;; + esac + return 0 +} + +format_path_for_env() { + # Convert path to OS-specific format for .env values + local input_path="$1" + local os_type + os_type=$(detect_os_type) + + if [[ "$os_type" = "windows" ]]; then + if command -v cygpath >/dev/null 2>&1; then + cygpath -w "$input_path" + return 0 + fi + + if [[ "$input_path" =~ ^/([a-zA-Z])/(.*)$ ]]; then + local drive="${BASH_REMATCH[1]}" + local rest="${BASH_REMATCH[2]}" + rest="${rest//\//\\}" + printf "%s:\\%s" "$(echo "$drive" | tr '[:lower:]' '[:upper:]')" "$rest" + return 0 + fi + fi + + printf "%s" "$input_path" +} + +escape_backslashes() { + # Escape backslashes for safe writing into .env or JSON + local input_path="$1" + printf "%s" "$input_path" | sed 's/\\/\\\\/g' + return 0 +} + is_port_in_use() { # Check if a TCP port is already in use (Linux/macOS/Windows Git Bash) local port="$1" @@ -266,6 +319,7 @@ persist_deploy_options() { echo "MODE_CHOICE=\"${MODE_CHOICE_SAVED}\"" echo "VERSION_CHOICE=\"${VERSION_CHOICE_SAVED}\"" echo "IS_MAINLAND=\"${IS_MAINLAND_SAVED}\"" + echo "DOWNLOAD_MODELS=\"${DOWNLOAD_MODELS}\"" echo "ENABLE_TERMINAL=\"${ENABLE_TERMINAL_SAVED}\"" echo "TERMINAL_MOUNT_DIR=\"${TERMINAL_MOUNT_DIR_SAVED}\"" } > "$DEPLOY_OPTIONS_FILE" @@ -528,6 +582,229 @@ select_deployment_mode() { echo "" } + +# Model download selection +select_model_download() { + echo "" + + local input_choice="" + read -r -p "Do you want to download AI model 
files (table-transformer and yolox)? [Y/N] (default: N): " input_choice + echo "" + + if [[ $input_choice =~ ^[Yy]$ ]]; then + DOWNLOAD_MODELS="Y" + echo "INFO: Model download will be performed." + else + DOWNLOAD_MODELS="N" + echo "INFO: Skipping model download." + fi + echo "----------------------------------------" + echo "" + return 0 +} + +# kerry + +download_and_config_models() { + if [[ "$DOWNLOAD_MODELS" != "Y" ]]; then + echo "INFO: Model download skipped by user choice." + return 0 + fi + + echo "INFO: Downloading AI model files (this may take a while)..." + + local env_file_dir="$SCRIPT_DIR" + local env_file_path="$env_file_dir/.env" + local original_dir="$(pwd)" + + MODEL_ROOT="$ROOT_DIR/model" + mkdir -p "$MODEL_ROOT" + echo "INFO: Model directory: $MODEL_ROOT" + + export HF_ENDPOINT="https://hf-mirror.com" + + command -v git >/dev/null || { echo "ERROR: git is required but not found." >&2; return 1; } + + # ========================================== + # 1. Table Transformer (table-structure recognition) + echo "INFO: Downloading table-transformer-structure-recognition..." + + TT_MODEL_DIR_NAME="table-transformer-structure-recognition" + TT_MODEL_DIR_PATH="$MODEL_ROOT/$TT_MODEL_DIR_NAME" + MODEL_SAFETENSORS_FILE="model.safetensors" + TT_MODEL_FILE_CHECK="$TT_MODEL_DIR_PATH/$MODEL_SAFETENSORS_FILE" + + cd "$MODEL_ROOT" || return 1 + + if [[ -d "$TT_MODEL_DIR_PATH" ]] && [[ -f "$TT_MODEL_FILE_CHECK" ]]; then + FILE_SIZE=$(stat -c%s "$TT_MODEL_FILE_CHECK" 2>/dev/null || stat -f%z "$TT_MODEL_FILE_CHECK" 2>/dev/null) + if [[ "$FILE_SIZE" -gt 1000000 ]]; then + echo "INFO: Table Transformer already exists." + else + echo "WARN: Existing model file looks incomplete, re-downloading..." + rm -rf "$TT_MODEL_DIR_NAME" + fi + fi + + if [[ ! -f "$TT_MODEL_FILE_CHECK" ]]; then + if [[ -d "$TT_MODEL_DIR_NAME" ]]; then + echo "WARN: Removing existing directory before re-download..." 
+ rm -rf "$TT_MODEL_DIR_NAME" + fi + + echo "INFO: Step 1/2: Clone repo (skip LFS files)..." + if ! GIT_LFS_SKIP_SMUDGE=1 git clone "$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME" "$TT_MODEL_DIR_NAME"; then + echo "ERROR: Failed to clone repository." >&2 + cd "$original_dir" + return 1 + fi + + cd "$TT_MODEL_DIR_NAME" || return 1 + + echo "INFO: Step 2/2: Download model.safetensors..." + LARGE_FILE_URL="$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME/resolve/main/$MODEL_SAFETENSORS_FILE" + + if command -v curl &> /dev/null; then + curl -L -o "$MODEL_SAFETENSORS_FILE" "$LARGE_FILE_URL" --progress-bar + elif command -v wget &> /dev/null; then + wget "$LARGE_FILE_URL" -O "$MODEL_SAFETENSORS_FILE" + else + echo "ERROR: curl or wget is required to download model files." >&2 + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$original_dir"; return 1 + fi + + if [[ ! -f "$MODEL_SAFETENSORS_FILE" ]]; then + echo "ERROR: $MODEL_SAFETENSORS_FILE download failed." >&2 + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$original_dir"; return 1 + fi + + FILE_SIZE=$(stat -c%s "$MODEL_SAFETENSORS_FILE" 2>/dev/null || stat -f%z "$MODEL_SAFETENSORS_FILE" 2>/dev/null) + if [[ "$FILE_SIZE" -lt 1000000 ]]; then + echo "ERROR: $MODEL_SAFETENSORS_FILE seems too small (size: $FILE_SIZE bytes)." >&2 + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$original_dir"; return 1 + fi + + echo "INFO: $MODEL_SAFETENSORS_FILE downloaded (size: $(du -h "$MODEL_SAFETENSORS_FILE" | cut -f1))" + cd "$MODEL_ROOT" + fi + + echo "INFO: Table Transformer OK" + + # ========================================== + # 2. 
YOLOX (layout detection model) + echo "INFO: Downloading yolox_l0.05.onnx" + + YOLOX_MODEL_FILE="$MODEL_ROOT/yolox_l0.05.onnx" + MIN_YOLOX_SIZE=50000000 + + NEED_DOWNLOAD=false + + if [[ -f "$YOLOX_MODEL_FILE" ]]; then + CURRENT_SIZE=$(stat -c%s "$YOLOX_MODEL_FILE" 2>/dev/null || stat -f%z "$YOLOX_MODEL_FILE" 2>/dev/null) + if [[ "$CURRENT_SIZE" -lt "$MIN_YOLOX_SIZE" ]]; then + echo "WARN: Existing YOLOX file looks incomplete (size: $(numfmt --to=iec-i --suffix=B $CURRENT_SIZE 2>/dev/null || echo $CURRENT_SIZE)). Re-downloading..." + NEED_DOWNLOAD=true + else + echo "INFO: YOLOX already exists." + fi + else + NEED_DOWNLOAD=true + fi + + if [[ "$NEED_DOWNLOAD" = true ]]; then + ONNX_URL="$HF_ENDPOINT/unstructuredio/yolo_x_layout/resolve/main/yolox_l0.05.onnx" + + if command -v curl &> /dev/null; then + echo "INFO: Downloading with curl (supports resume -C -)..." + if curl -L -C - -o "$YOLOX_MODEL_FILE" "$ONNX_URL" --progress-bar; then + echo "INFO: curl download completed" + else + echo "ERROR: curl download failed." >&2 + cd "$original_dir" + return 1 + fi + elif command -v wget &> /dev/null; then + echo "INFO: Downloading with wget (supports resume -c)..." + wget -c "$ONNX_URL" -O "$YOLOX_MODEL_FILE" + else + echo "ERROR: curl or wget is required to download model files." >&2 + cd "$original_dir" + return 1 + fi + + if [[ -f "$YOLOX_MODEL_FILE" ]]; then + FINAL_SIZE=$(stat -c%s "$YOLOX_MODEL_FILE" 2>/dev/null || stat -f%z "$YOLOX_MODEL_FILE" 2>/dev/null) + if [[ "$FINAL_SIZE" -lt "$MIN_YOLOX_SIZE" ]]; then + echo "ERROR: YOLOX file seems too small (size: $FINAL_SIZE bytes)." >&2 + cd "$original_dir" + return 1 + else + echo "INFO: YOLOX downloaded (size: $(numfmt --to=iec-i --suffix=B $FINAL_SIZE 2>/dev/null || echo $FINAL_SIZE))" + fi + else + echo "ERROR: YOLOX download failed: file not found." >&2 + cd "$original_dir" + return 1 + fi + fi + + echo "INFO: YOLOX OK" + + # ========================================== + # 3. 
config.json + CONFIG_FILE="$MODEL_ROOT/config.json" + YOLOX_ABS_PATH=$(cd "$(dirname "$YOLOX_MODEL_FILE")" && pwd)/$(basename "$YOLOX_MODEL_FILE") + YOLOX_OS_PATH=$(format_path_for_env "$YOLOX_ABS_PATH") + YOLOX_CONFIG_PATH=$(escape_backslashes "$YOLOX_OS_PATH") + + cat > "$CONFIG_FILE" < /dev/null; then + update_env_var "TABLE_TRANSFORMER_MODEL_PATH" "$TT_MODEL_DIR_ENV_PATH" + update_env_var "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" "$CONFIG_FILE_ENV_PATH" + else + sed -i.bak "/^TABLE_TRANSFORMER_MODEL_PATH=/d" "$env_file_path" 2>/dev/null || true + echo "TABLE_TRANSFORMER_MODEL_PATH="$TT_MODEL_DIR_ENV_PATH"" >> "$env_file_path" + + sed -i.bak "/^UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH=/d" "$env_file_path" 2>/dev/null || true + echo "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH="$CONFIG_FILE_ENV_PATH"" >> "$env_file_path" + rm -f "$env_file_path.bak" 2>/dev/null + fi + + echo "INFO: Environment file updated" + cd "$original_dir" +} + clean() { export MINIO_ACCESS_KEY= export MINIO_SECRET_KEY= @@ -600,6 +877,13 @@ prepare_directory_and_data() { create_dir_with_permission "$ROOT_DIR/minio" 775 create_dir_with_permission "$ROOT_DIR/redis" 775 + echo "📦 Check the status of model configuration..." + download_and_config_models || { + echo "⚠️ A warning occurred during the model configuration step, but subsequent deployment will proceed..." + # Do not exit here; the user may choose N or prefer to continue after a download failure. + } + echo "" + cp -rn volumes $ROOT_DIR chmod -R 775 $ROOT_DIR/volumes echo " 📁 Directory $ROOT_DIR/volumes has been created and permissions set to 775." 
@@ -1057,6 +1341,8 @@ main_deploy() { select_terminal_tool || { echo "❌ Terminal tool container configuration failed"; exit 1; } choose_image_env || { echo "❌ Image environment setup failed"; exit 1; } + select_model_download || { echo "❌ Model download failed"; exit 1;} + # Set NEXENT_MCP_DOCKER_IMAGE in .env file if [ -n "${NEXENT_MCP_DOCKER_IMAGE:-}" ]; then update_env_var "NEXENT_MCP_DOCKER_IMAGE" "${NEXENT_MCP_DOCKER_IMAGE}" @@ -1142,7 +1428,7 @@ docker_compose_command="" case $version_type in "v1") echo "Detected Docker Compose V1, version: $version_number" - # The version ​​v1.28.0​​ is the minimum requirement in Docker Compose v1 that explicitly supports interpolation syntax with default values like ${VAR:-default} + # The version 1.28.0 is the minimum requirement in Docker Compose v1 for default interpolation syntax. if [[ $version_number < "1.28.0" ]]; then echo "Warning: V1 version is too old, consider upgrading to V2" exit 1 diff --git a/docker/init.sql b/docker/init.sql index 75e9a818f..f808d9c06 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -213,6 +213,7 @@ CREATE TABLE IF NOT EXISTS "knowledge_record_t" ( "embedding_model_name" varchar(200) COLLATE "pg_catalog"."default", "group_ids" varchar, "ingroup_permission" varchar(30), + "is_multimodal" varchar(1) COLLATE "pg_catalog"."default" DEFAULT 'N'::character varying, "create_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, "update_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, "delete_flag" varchar(1) COLLATE "pg_catalog"."default" DEFAULT 'N'::character varying, @@ -230,6 +231,7 @@ COMMENT ON COLUMN "knowledge_record_t"."knowledge_sources" IS 'Knowledge base so COMMENT ON COLUMN "knowledge_record_t"."embedding_model_name" IS 'Embedding model name, used to record the embedding model used by the knowledge base'; COMMENT ON COLUMN "knowledge_record_t"."group_ids" IS 'Knowledge base group IDs list'; COMMENT ON COLUMN "knowledge_record_t"."ingroup_permission" IS 'In-group permission: EDIT, READ_ONLY, 
PRIVATE'; +COMMENT ON COLUMN "knowledge_record_t"."is_multimodal" IS 'whether it is multimodal'; COMMENT ON COLUMN "knowledge_record_t"."create_time" IS 'Creation time, audit field'; COMMENT ON COLUMN "knowledge_record_t"."update_time" IS 'Update time, audit field'; COMMENT ON COLUMN "knowledge_record_t"."delete_flag" IS 'When deleted by user frontend, delete flag will be set to true, achieving soft delete effect. Optional values Y/N'; diff --git a/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql b/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql new file mode 100644 index 000000000..d5b14bfbb --- /dev/null +++ b/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql @@ -0,0 +1,5 @@ +-- Add is_multimodal column to knowledge_record_t table +ALTER TABLE nexent.knowledge_record_t +ADD COLUMN IF NOT EXISTS is_multimodal varchar(1) DEFAULT 'N'; + +COMMENT ON COLUMN nexent.knowledge_record_t.is_multimodal IS 'whether it is multimodal'; diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index c97536b92..481d71920 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -100,7 +100,8 @@ export default function ToolManagement({ // Use tool list hook for data management const { availableTools } = useToolList(); - const { isVlmAvailable, isEmbeddingAvailable } = useConfig(); + const { isVlmAvailable, isEmbeddingAvailable, isMultiEmbeddingAvailable } = useConfig(); + const isEmbeddingOrMultiAvailable = isEmbeddingAvailable || isMultiEmbeddingAvailable; // Prefetch knowledge bases for KB tools const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases(); @@ -363,7 +364,10 @@ export default function ToolManagement({ tool.id ); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const 
isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly @@ -468,7 +472,10 @@ export default function ToolManagement({ {group.tools.map((tool) => { const isSelected = originalSelectedToolIdsSet.has(tool.id); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index 4808bd765..d67a5a901 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -30,6 +30,10 @@ import { API_ENDPOINTS } from "@/services/api"; import knowledgeBaseService from "@/services/knowledgeBaseService"; import log from "@/lib/logger"; import { isZhLocale, getLocalizedDescription } from "@/lib/utils"; +import { + isEmbeddingModelCompatible as isEmbeddingModelCompatibleBase, + isMultimodalConstraintMismatch as isMultimodalConstraintMismatchBase, +} from "@/lib/knowledgeBaseCompatibility"; export interface ToolConfigModalProps { isOpen: boolean; @@ -459,6 +463,89 @@ export default function ToolConfigModal({ } }, [configData]); + const currentMultiEmbeddingModel = useMemo(() => { + try { + const modelConfig = 
configData?.models; + return ( + modelConfig?.multiEmbedding?.modelName || + modelConfig?.multiEmbedding?.displayName || + null + ); + } catch { + return null; + } + }, [configData]); + + const hasEmbeddingModel = Boolean(currentEmbeddingModel); + const hasMultiEmbeddingModel = Boolean(currentMultiEmbeddingModel); + const canToggleMultimodalParam = hasEmbeddingModel && hasMultiEmbeddingModel; + const forcedMultimodalValue = useMemo(() => { + if (!hasEmbeddingModel && hasMultiEmbeddingModel) { + return true; + } + if (hasEmbeddingModel && !hasMultiEmbeddingModel) { + return false; + } + return null; + }, [hasEmbeddingModel, hasMultiEmbeddingModel]); + + const toolMultimodal = useMemo(() => { + const multimodalParam = currentParams.find( + (param) => param.name === "multimodal" + ); + const value = multimodalParam?.value; + if (typeof value === "boolean") { + return value; + } + if (typeof value === "number") { + return value === 1; + } + if (typeof value === "string") { + const normalized = value.trim().toLowerCase(); + if (normalized === "true") return true; + if (normalized === "false") return false; + } + return null; + }, [currentParams]); + + useEffect(() => { + if (tool?.name !== "knowledge_base_search") return; + if (forcedMultimodalValue === null) return; + + const index = currentParams.findIndex( + (param) => param.name === "multimodal" + ); + if (index < 0) return; + + const param = currentParams[index]; + if (param.value === forcedMultimodalValue) return; + + const updatedParams = [...currentParams]; + updatedParams[index] = { ...param, value: forcedMultimodalValue }; + setCurrentParams(updatedParams); + + const fieldName = `param_${index}`; + form.setFieldValue(fieldName, forcedMultimodalValue); + }, [tool?.name, forcedMultimodalValue, currentParams, form]); + + const isMultimodalConstraintMismatch = useCallback( + (kb: KnowledgeBase) => { + return isMultimodalConstraintMismatchBase(kb, toolMultimodal); + }, + [toolMultimodal] + ); + + const 
isEmbeddingModelCompatible = useCallback( + (kb: KnowledgeBase) => { + return isEmbeddingModelCompatibleBase( + kb, + currentEmbeddingModel, + currentMultiEmbeddingModel + ); + }, + [currentEmbeddingModel, currentMultiEmbeddingModel] + ); + // Check if a knowledge base can be selected const canSelectKnowledgeBase = useCallback( (kb: KnowledgeBase): boolean => { @@ -469,9 +556,17 @@ export default function ToolConfigModal({ return false; } + // For nexent source, check model matching + if (kb.source === "nexent") { + if (isMultimodalConstraintMismatch(kb)) { + return false; + } + return isEmbeddingModelCompatible(kb); + } + return true; }, - [currentEmbeddingModel] + [isEmbeddingModelCompatible, isMultimodalConstraintMismatch] ); // Track whether this is the first time opening the modal (reset when modal closes) @@ -1290,7 +1385,7 @@ export default function ToolConfigModal({ })} options={options.map((option) => ({ value: option, - label: option, + label: String(option), }))} /> ); @@ -1306,8 +1401,13 @@ export default function ToolConfigModal({ /> ); - case TOOL_PARAM_TYPES.BOOLEAN: - return ; + case TOOL_PARAM_TYPES.BOOLEAN: { + const isMultimodalParam = + tool.name === "knowledge_base_search" && param.name === "multimodal"; + return ( + + ); + } case TOOL_PARAM_TYPES.STRING: case TOOL_PARAM_TYPES.ARRAY: @@ -1662,6 +1762,8 @@ export default function ToolConfigModal({ syncLoading={kbLoading} isSelectable={canSelectKnowledgeBase} currentEmbeddingModel={currentEmbeddingModel} + currentMultiEmbeddingModel={currentMultiEmbeddingModel} + toolMultimodal={toolMultimodal} difyConfig={ toolKbType === "dify_search" ? 
difyConfig diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index a5e7d52d1..954c3b82e 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -26,7 +26,6 @@ import knowledgeBaseService from "@/services/knowledgeBaseService"; import knowledgeBasePollingService from "@/services/knowledgeBasePollingService"; import { KnowledgeBase } from "@/types/knowledgeBase"; import { useConfig } from "@/hooks/useConfig"; -import { useModelList } from "@/hooks/model/useModelList"; import { SETUP_PAGE_CONTAINER, TWO_COLUMN_LAYOUT, @@ -128,9 +127,6 @@ function DataConfig({ isActive }: DataConfigProps) { const { modelConfig, data: configData, invalidateConfig, config, updateConfig, saveConfig } = useConfig(); const { token } = theme.useToken(); - // Get available embedding models for knowledge base creation - const { availableEmbeddingModels } = useModelList({ enabled: true }); - // Clear cache when component initializes useEffect(() => { localStorage.removeItem("preloaded_kb_data"); @@ -184,7 +180,7 @@ function DataConfig({ isActive }: DataConfigProps) { const [newKbName, setNewKbName] = useState(""); const [newKbIngroupPermission, setNewKbIngroupPermission] = useState("READ_ONLY"); const [newKbGroupIds, setNewKbGroupIds] = useState([]); - const [newKbEmbeddingModel, setNewKbEmbeddingModel] = useState(""); // Selected embedding model for new KB + const [isMultimodal, setIsMultimodal] = useState(false); const [uploadFiles, setUploadFiles] = useState([]); const [hasClickedUpload, setHasClickedUpload] = useState(false); const [showEmbeddingWarning, setShowEmbeddingWarning] = useState(false); @@ -197,11 +193,28 @@ function DataConfig({ isActive }: DataConfigProps) { const [modelFilter, setModelFilter] = useState([]); const contentRef = useRef(null); + const hasEmbeddingModel = Boolean( + 
modelConfig?.embedding?.modelName?.trim() + ); + const hasMultiEmbeddingModel = Boolean( + modelConfig?.multiEmbedding?.modelName?.trim() + ); + const canToggleMultimodal = hasEmbeddingModel && hasMultiEmbeddingModel; + // Open warning modal when single Embedding model is not configured (ignore multi-embedding) useEffect(() => { - const singleEmbeddingModelName = modelConfig?.embedding?.modelName; - setShowEmbeddingWarning(!singleEmbeddingModelName); - }, [modelConfig?.embedding?.modelName]); + setShowEmbeddingWarning(!hasEmbeddingModel && !hasMultiEmbeddingModel); + }, [hasEmbeddingModel, hasMultiEmbeddingModel]); + + useEffect(() => { + if (hasMultiEmbeddingModel && !hasEmbeddingModel) { + setIsMultimodal(true); + return; + } + if (hasEmbeddingModel && !hasMultiEmbeddingModel) { + setIsMultimodal(false); + } + }, [hasEmbeddingModel, hasMultiEmbeddingModel]); // Add event listener for selecting new knowledge base useEffect(() => { @@ -618,15 +631,11 @@ function DataConfig({ isActive }: DataConfigProps) { setNewKbName(defaultName); setNewKbIngroupPermission("READ_ONLY"); setNewKbGroupIds([]); - // Set default embedding model - prioritize config's default model, fall back to first available model - const configModel = modelConfig?.embedding?.modelName; - const defaultModel = configModel || (availableEmbeddingModels.length > 0 - ? 
availableEmbeddingModels[0].displayName - : ""); - setNewKbEmbeddingModel(defaultModel); setIsCreatingMode(true); setHasClickedUpload(false); // Reset upload button click state setUploadFiles([]); // Reset upload files array, clear all pending upload files + + setIsMultimodal(hasMultiEmbeddingModel && !hasEmbeddingModel) }; // Handle document deletion @@ -687,7 +696,7 @@ function DataConfig({ isActive }: DataConfigProps) { "elasticsearch", newKbIngroupPermission, newKbGroupIds, - newKbEmbeddingModel + isMultimodal ); if (!newKB) { @@ -702,7 +711,7 @@ function DataConfig({ isActive }: DataConfigProps) { setHasClickedUpload(false); setNewlyCreatedKbId(newKB.id); // Mark this KB as newly created - await uploadDocuments(newKB.id, filesToUpload); + await uploadDocuments(newKB.id, filesToUpload, isMultimodal); setUploadFiles([]); knowledgeBasePollingService @@ -738,7 +747,11 @@ function DataConfig({ isActive }: DataConfigProps) { } try { - await uploadDocuments(kbId, filesToUpload); + await uploadDocuments( + kbId, + filesToUpload, + kbState.activeKnowledgeBase?.is_multimodal + ); setUploadFiles([]); knowledgeBasePollingService.triggerKnowledgeBaseListUpdate(true); @@ -828,6 +841,11 @@ function DataConfig({ isActive }: DataConfigProps) { setNewKbName(name); }; + const handleSetMultimodal = (is_multimodal: boolean) => { + if (!canToggleMultimodal) return; + setIsMultimodal(is_multimodal) + } + // If Embedding model is not configured, show warning container instead of content if (showEmbeddingWarning) { return ( @@ -888,6 +906,9 @@ function DataConfig({ isActive }: DataConfigProps) { knowledgeBases={kbState.knowledgeBases} activeKnowledgeBase={kbState.activeKnowledgeBase} currentEmbeddingModel={kbState.currentEmbeddingModel} + currentMultiEmbeddingModel={ + modelConfig?.multiEmbedding?.modelName?.trim() || null + } isLoading={kbState.isLoading} syncLoading={kbState.syncLoading} onClick={handleKnowledgeBaseClick} @@ -948,10 +969,9 @@ function DataConfig({ isActive }: 
DataConfigProps) { onIngroupPermissionChange={setNewKbIngroupPermission} selectedGroupIds={newKbGroupIds} onSelectedGroupIdsChange={setNewKbGroupIds} - // Embedding model for create mode - availableEmbeddingModels={availableEmbeddingModels} - selectedEmbeddingModel={newKbEmbeddingModel} - onEmbeddingModelChange={setNewKbEmbeddingModel} + isMultimodal={isMultimodal} + onMultimodalChange={handleSetMultimodal} + canToggleMultimodal={canToggleMultimodal} // Upload related props isDragging={uiState.isDragging} onDragOver={handleDragOver} @@ -972,15 +992,31 @@ function DataConfig({ isActive }: DataConfigProps) { modelMismatch={hasKnowledgeBaseModelMismatch( kbState.activeKnowledgeBase )} - currentModel={kbState.currentEmbeddingModel || ""} + currentModel={ + kbState.activeKnowledgeBase.is_multimodal + ? modelConfig?.multiEmbedding?.modelName?.trim() || "" + : kbState.currentEmbeddingModel || "" + } knowledgeBaseModel={kbState.activeKnowledgeBase.embeddingModel} embeddingModelInfo={ hasKnowledgeBaseModelMismatch(kbState.activeKnowledgeBase) - ? t("document.modelMismatch.withModels", { - currentModel: kbState.currentEmbeddingModel || "", - knowledgeBaseModel: - kbState.activeKnowledgeBase.embeddingModel, - }) + ? (() => { + const currentModelName = kbState.activeKnowledgeBase.is_multimodal + ? 
modelConfig?.multiEmbedding?.modelName?.trim() || "" + : kbState.currentEmbeddingModel || ""; + const knowledgeBaseModel = + kbState.activeKnowledgeBase.embeddingModel; + if (!currentModelName) { + return t("document.chunk.tooltip.disabledDueToModelMismatch", { + currentModel: t("embedding.model.notConfigured"), + knowledgeBaseModel, + }); + } + return t("document.modelMismatch.withModels", { + currentModel: currentModelName, + knowledgeBaseModel, + }); + })() : undefined } containerHeight={SETUP_PAGE_CONTAINER.MAIN_CONTENT_HEIGHT} diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx index 5e963c545..2c8bcd4af 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentChunk.tsx @@ -128,16 +128,25 @@ const DocumentChunk: React.FC = ({ setTooltipResetKey((prev) => prev + 1); }, []); + const hasKnowledgeBaseModel = + Boolean(knowledgeBaseEmbeddingModel) && + knowledgeBaseEmbeddingModel !== "unknown"; + const hasCurrentModel = Boolean(currentEmbeddingModel); + const currentModelLabel = + currentEmbeddingModel ?? 
t("embedding.model.notConfigured"); + // Determine if embedding models mismatch (specific condition for tooltip) const isEmbeddingModelMismatch = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { + if (!hasKnowledgeBaseModel) { return false; } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]); + return !hasCurrentModel || currentEmbeddingModel !== knowledgeBaseEmbeddingModel; + }, [ + currentEmbeddingModel, + hasCurrentModel, + hasKnowledgeBaseModel, + knowledgeBaseEmbeddingModel, + ]); // Determine if in read-only mode (embedding model mismatch OR user has READ_ONLY permission) // Note: isReadOnlyMode is broader, includes model mismatch and other conditions @@ -146,37 +155,48 @@ const DocumentChunk: React.FC = ({ if (permission === "READ_ONLY") { return true; } - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { + if (!hasKnowledgeBaseModel) { return false; } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel, permission]); + return !hasCurrentModel || currentEmbeddingModel !== knowledgeBaseEmbeddingModel; + }, [ + currentEmbeddingModel, + hasCurrentModel, + hasKnowledgeBaseModel, + knowledgeBaseEmbeddingModel, + permission, + ]); // Determine if search should be disabled (only when embedding model mismatch, NOT for READ_ONLY permission) // This allows READ_ONLY users to still perform search const isSearchDisabled = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { + if (!hasKnowledgeBaseModel) { return false; } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, 
knowledgeBaseEmbeddingModel]); + return !hasCurrentModel || currentEmbeddingModel !== knowledgeBaseEmbeddingModel; + }, [ + currentEmbeddingModel, + hasCurrentModel, + hasKnowledgeBaseModel, + knowledgeBaseEmbeddingModel, + ]); // Disabled tooltip message when embedding model mismatch const disabledTooltipMessage = React.useMemo(() => { - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { + if (isEmbeddingModelMismatch && hasKnowledgeBaseModel) { return t("document.chunk.tooltip.disabledDueToModelMismatch", { - currentModel: currentEmbeddingModel, + currentModel: currentModelLabel, knowledgeBaseModel: knowledgeBaseEmbeddingModel }); } return ""; - }, [isEmbeddingModelMismatch, currentEmbeddingModel, knowledgeBaseEmbeddingModel, t]); + }, [ + currentModelLabel, + hasKnowledgeBaseModel, + isEmbeddingModelMismatch, + knowledgeBaseEmbeddingModel, + t, + ]); // Set active document when documents change useEffect(() => { @@ -322,11 +342,13 @@ const DocumentChunk: React.FC = ({ } // Check embedding model consistency before searching - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { - message.error(t("document.chunk.error.searchFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); + if (isEmbeddingModelMismatch && hasKnowledgeBaseModel) { + message.error( + t("document.chunk.tooltip.disabledDueToModelMismatch", { + currentModel: currentModelLabel, + knowledgeBaseModel: knowledgeBaseEmbeddingModel, + }) + ); return; } @@ -380,8 +402,9 @@ const DocumentChunk: React.FC = ({ resetChunkSearch, searchValue, t, + currentModelLabel, + hasKnowledgeBaseModel, isEmbeddingModelMismatch, - currentEmbeddingModel, knowledgeBaseEmbeddingModel, ]); @@ -465,14 +488,13 @@ const DocumentChunk: React.FC = ({ // Check embedding model consistency before creating chunk if 
(chunkModalMode === "create") { - if (knowledgeBaseEmbeddingModel && - knowledgeBaseEmbeddingModel !== "unknown" && - currentEmbeddingModel && - currentEmbeddingModel !== knowledgeBaseEmbeddingModel) { - message.error(t("document.chunk.error.createFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); + if (isEmbeddingModelMismatch && hasKnowledgeBaseModel) { + message.error( + t("document.chunk.tooltip.disabledDueToModelMismatch", { + currentModel: currentModelLabel, + knowledgeBaseModel: knowledgeBaseEmbeddingModel, + }) + ); return; } } diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx index bf0925369..256917d77 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx @@ -79,6 +79,10 @@ interface DocumentListProps { onEmbeddingModelChange?: (value: string) => void; permission?: string; // User's permission for this knowledge base (READ_ONLY, EDIT, etc.) 
+ isMultimodal?: boolean; + onMultimodalChange?: (is_multimodal: boolean) => void; + canToggleMultimodal?: boolean; + // Upload related props isDragging?: boolean; onDragOver?: (e: React.DragEvent) => void; @@ -122,6 +126,10 @@ const DocumentListContainer = forwardRef( onEmbeddingModelChange, permission, + isMultimodal = false, + onMultimodalChange, + canToggleMultimodal = true, + // Upload related props isDragging = false, onDragOver, @@ -524,6 +532,39 @@ const DocumentListContainer = forwardRef( options={permissionOptions} /> + + + ) : ( diff --git a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx index cbff0297b..20bb76161 100644 --- a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx +++ b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx @@ -32,6 +32,7 @@ interface KnowledgeBaseListProps { knowledgeBases: KnowledgeBase[]; activeKnowledgeBase: KnowledgeBase | null; currentEmbeddingModel: string | null; + currentMultiEmbeddingModel: string | null; isLoading?: boolean; syncLoading?: boolean; onClick: (kb: KnowledgeBase) => void; @@ -57,6 +58,7 @@ const KnowledgeBaseList: React.FC = ({ knowledgeBases, activeKnowledgeBase, currentEmbeddingModel, + currentMultiEmbeddingModel, isLoading = false, syncLoading = false, onClick, @@ -127,6 +129,19 @@ const KnowledgeBaseList: React.FC = ({ return `knowledgeBase.ingroup.permission.${permission || "DEFAULT"}`; }; + const isModelMismatch = (kb: KnowledgeBase) => { + if (kb.embeddingModel === "unknown") return false; + if (kb.source === "datamate") return false; + + if (kb.is_multimodal) { + if (!currentMultiEmbeddingModel) return true; + return kb.embeddingModel !== currentMultiEmbeddingModel; + } + + if (!currentEmbeddingModel) return false; + return kb.embeddingModel !== currentEmbeddingModel; + }; + // Search and filter states const [searchKeyword, setSearchKeyword] =
useState(""); const [selectedSources, setSelectedSources] = useState([]); @@ -579,6 +594,13 @@ const KnowledgeBaseList: React.FC = ({ })} )} + {isModelMismatch(kb) && ( + + {t("knowledgeBase.tag.modelMismatch")} + + )} {/* User group tags - only show when not PRIVATE */} @@ -592,6 +614,13 @@ const KnowledgeBaseList: React.FC = ({ ))} + {kb.is_multimodal && ( + + multimodal + + )} )} diff --git a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx index b956dd919..7a2dcfb2e 100644 --- a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx @@ -112,7 +112,7 @@ export const DocumentContext = createContext<{ state: DocumentState; dispatch: React.Dispatch; fetchDocuments: (kbId: string, forceRefresh?: boolean, kbSource?: string) => Promise; - uploadDocuments: (kbId: string, files: File[]) => Promise; + uploadDocuments: (kbId: string, files: File[], isMultimodal?: boolean) => Promise; deleteDocument: (kbId: string, docId: string) => Promise; }>({ state: { @@ -202,11 +202,11 @@ export const DocumentProvider: React.FC = ({ children }) }, [state.loadingKbIds, state.documentsMap, t]); // Upload documents to a knowledge base - const uploadDocuments = useCallback(async (kbId: string, files: File[]) => { + const uploadDocuments = useCallback(async (kbId: string, files: File[], isMultimodal?: boolean) => { dispatch({ type: DOCUMENT_ACTION_TYPES.SET_UPLOADING, payload: true }); try { - await knowledgeBaseService.uploadDocuments(kbId, files); + await knowledgeBaseService.uploadDocuments(kbId, files, undefined, isMultimodal); // Set loading state before fetching latest documents dispatch({ type: DOCUMENT_ACTION_TYPES.SET_LOADING_DOCUMENTS, payload: true }); @@ -265,4 +265,4 @@ export const DocumentProvider: React.FC = ({ children }) {children} ); -}; \ No newline at end of file +}; diff --git 
a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index 5985c4b08..0aa397863 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -110,7 +110,7 @@ export const KnowledgeBaseContext = createContext<{ source?: string, ingroup_permission?: string, group_ids?: number[], - embeddingModel?: string + is_multimodal?: boolean, ) => Promise; deleteKnowledgeBase: (id: string) => Promise; selectKnowledgeBase: (id: string) => void; @@ -125,6 +125,7 @@ export const KnowledgeBaseContext = createContext<{ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -159,6 +160,7 @@ export const KnowledgeBaseProvider: React.FC = ({ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -200,9 +202,29 @@ export const KnowledgeBaseProvider: React.FC = ({ // Note: Always return false to remove model mismatch restrictions const hasKnowledgeBaseModelMismatch = useCallback( (kb: KnowledgeBase): boolean => { - return false; + if (kb.embeddingModel === "unknown") { + return false; + } + // DataMate knowledge bases don't report model mismatch (they are always selectable) + if (kb.source === "datamate") { + return false; + } + + if (kb.is_multimodal) { + const multiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + if (!multiEmbeddingModel) { + return true; + } + return kb.embeddingModel !== multiEmbeddingModel; + } + + if (!state.currentEmbeddingModel) { + return false; + } + return kb.embeddingModel !== state.currentEmbeddingModel; }, - [] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel] ); // Load knowledge base data (supports force fetch from 
server and load selected status) - optimized with useCallback @@ -311,17 +333,22 @@ export const KnowledgeBaseProvider: React.FC = ({ source: string = "elasticsearch", ingroup_permission?: string, group_ids?: number[], - embeddingModel?: string + is_multimodal?: boolean ) => { try { + const multiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + const embeddingModel = is_multimodal + ? multiEmbeddingModel + : state.currentEmbeddingModel || "text-embedding-3-small"; const newKB = await knowledgeBaseService.createKnowledgeBase({ name, description, source, - // Use provided embeddingModel if available, otherwise fall back to current model or default - embeddingModel: embeddingModel || state.currentEmbeddingModel || "", + embeddingModel, ingroup_permission, group_ids, + is_multimodal, }); return newKB; } catch (error) { @@ -333,7 +360,7 @@ export const KnowledgeBaseProvider: React.FC = ({ return null; } }, - [state.currentEmbeddingModel, t] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel, t] ); // Delete knowledge base - memoized with useCallback diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 4c26da19d..8545e7658 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -277,7 +277,7 @@ export const ModelAddDialog = ({ let getModelList; let getProviderSelectedModalList; -// 2. 根据条件赋值 + // 2. 
Select provider-specific hooks if (form.provider === "silicon") { ({ getModelList, getProviderSelectedModalList } = siliconHook); } else if (form.provider === "dashscope") { @@ -454,7 +454,8 @@ export const ModelAddDialog = ({ // Call backend healthcheck API for tenant management const result = await modelService.checkManageTenantModelConnectivity( tenantId, - form.displayName || form.name + form.displayName || form.name, + modelType, ); // Set connectivity status diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index eeddec04f..75b8e2809 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -113,7 +113,8 @@ export const ModelEditDialog = ({ // Call backend healthcheck API for tenant management const result = await modelService.checkManageTenantModelConnectivity( tenantId, - form.displayName || form.name + form.displayName || form.name, + modelType, ); // Set connectivity status diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index e20e74876..5e91f71f1 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -527,6 +527,7 @@ export const ModelConfigSection = forwardRef< try { const isConnected = await modelService.verifyCustomModel( modelName, + modelType, signal ); @@ -603,7 +604,7 @@ export const ModelConfigSection = forwardRef< throttleTimerRef.current = setTimeout(async () => { try { // Use modelService to verify model - const isConnected = await modelService.verifyCustomModel(displayName); + const isConnected = await modelService.verifyCustomModel(displayName, modelType); // Update model status updateModelStatus( diff --git a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx 
b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx index 6de719127..e06ea0f58 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx @@ -92,7 +92,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { }; // Handle checking model connectivity - const handleCheckConnectivity = async (displayName: string) => { + const handleCheckConnectivity = async (displayName: string, modelType: string) => { if (!tenantId) { message.error(t("tenantResources.tenants.tenantIdRequired")); return; @@ -102,7 +102,8 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { try { const isConnected = await modelService.checkManageTenantModelConnectivity( tenantId, - displayName + displayName, + modelType ); if (isConnected) { message.success(t("tenantResources.models.connectivitySuccess")); @@ -197,7 +198,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) {