diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml index f559357f8..8e84693fb 100644 --- a/.github/workflows/deploy-docs.yml +++ b/.github/workflows/deploy-docs.yml @@ -38,18 +38,11 @@ jobs: with: fetch-depth: 0 - - name: Setup Node - uses: actions/setup-node@v4 - with: - node-version: 20 - cache: npm - cache-dependency-path: ./doc/package-lock.json - - name: Setup Pages uses: actions/configure-pages@v4 - name: Install dependencies - run: npm ci + run: npm install - name: Build with VitePress run: npm run docs:build diff --git a/.gitignore b/.gitignore index 2c212f5e2..8e066b585 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ docker/openssh-server docker/volumes/db/data docker/.env docker/.run +docker/deploy.options frontend_standalone/ .pnpm-store/ diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 6e8d17740..fd0c7db2e 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -15,7 +15,7 @@ get_vector_db_core, get_embedding_model, ) -from services.tenant_config_service import get_selected_knowledge_list +from services.tenant_config_service import get_selected_knowledge_list, build_knowledge_name_mapping from services.remote_mcp_service import get_remote_mcp_server_list from services.memory_config_service import build_memory_context from services.image_service import get_vlm_model @@ -241,6 +241,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id): "index_names": index_names, "vdb_core": get_vector_db_core(), "embedding_model": get_embedding_model(tenant_id=tenant_id), + "name_resolver": build_knowledge_name_mapping(tenant_id=tenant_id, user_id=user_id), } elif tool_config.class_name == "AnalyzeTextFileTool": tool_config.metadata = { diff --git a/backend/apps/agent_app.py b/backend/apps/agent_app.py index a9ce8f6c9..b2ae60c1a 100644 --- a/backend/apps/agent_app.py +++ b/backend/apps/agent_app.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Body, Header, HTTPException, Request -from consts.model import AgentRequest, AgentInfoRequest, AgentIDRequest, ConversationResponse, AgentImportRequest +from consts.model import AgentRequest, AgentInfoRequest, AgentIDRequest, ConversationResponse, AgentImportRequest, AgentNameBatchCheckRequest, AgentNameBatchRegenerateRequest from services.agent_service import ( get_agent_info_impl, get_creating_sub_agent_info_impl, @@ -12,6 +12,8 @@ delete_agent_impl, export_agent_impl, import_agent_impl, + check_agent_name_conflict_batch_impl, + regenerate_agent_name_batch_impl, list_all_agent_info_impl, run_agent_stream, stop_agent_tasks, @@ -146,6 +148,36 @@ async def import_agent_api(request: AgentImportRequest, authorization: Optional[ status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent import error.") +@agent_config_router.post("/check_name") +async def check_agent_name_batch_api(request: AgentNameBatchCheckRequest, authorization: Optional[str] = Header(None)): + """ + Batch check whether agent name/display_name conflicts exist in the tenant. + """ + try: + return await check_agent_name_conflict_batch_impl(request, authorization) + except ValueError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except Exception as e: + logger.error(f"Agent name batch check error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent name batch check error.") + + +@agent_config_router.post("/regenerate_name") +async def regenerate_agent_name_batch_api(request: AgentNameBatchRegenerateRequest, authorization: Optional[str] = Header(None)): + """ + Batch regenerate agent name/display_name using LLM or suffix fallback. + """ + try: + return await regenerate_agent_name_batch_impl(request, authorization) + except ValueError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except Exception as e: + logger.error(f"Agent name batch regenerate error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent name batch regenerate error.") + + @agent_config_router.get("/list") async def list_all_agent_info_api(authorization: Optional[str] = Header(None), request: Request = None): """ diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 4869ce440..9ed87cfae 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -1,5 +1,6 @@ import logging import re +import base64 from http import HTTPStatus from typing import List, Optional from urllib.parse import urlparse, urlunparse, unquote, quote @@ -149,7 +150,16 @@ async def process_files( @file_management_config_router.get("/download/{object_name:path}") async def get_storage_file( object_name: str = PathParam(..., description="File object name"), - download: str = Query("ignore", description="How to get the file"), + download: str = Query( + "ignore", + description=( + "How to get the file: " + "'ignore' (default, return file info), " + "'stream' (return file stream), " + "'redirect' (redirect to download URL), " + "'base64' (return base64-encoded content for images)." + ), + ), expires: int = Query(3600, description="URL validity period (seconds)"), filename: Optional[str] = Query(None, description="Original filename for download (optional)") ): @@ -192,6 +202,28 @@ async def get_storage_file( "ETag": f'"{object_name}"', } ) + elif download == "base64": + # Return base64 encoded file content (primarily for images) + file_stream, content_type = await get_file_stream_impl(object_name=object_name) + try: + data = file_stream.read() + except Exception as exc: + logger.error("Failed to read file stream for base64: %s", str(exc)) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Failed to read file content for base64 encoding", + ) + + base64_content = base64.b64encode(data).decode("utf-8") + return JSONResponse( + status_code=HTTPStatus.OK, + content={ + "success": True, + "base64": base64_content, + "content_type": content_type, + "object_name": object_name, + }, + ) else: # return file metadata return await get_file_url_impl(object_name=object_name, expires=expires) diff --git a/backend/apps/vectordatabase_app.py b/backend/apps/vectordatabase_app.py index 4eec301dd..39b94fbd0 100644 --- a/backend/apps/vectordatabase_app.py +++ b/backend/apps/vectordatabase_app.py @@ -1,9 +1,11 @@ import logging +import json from http import HTTPStatus from typing import Any, Dict, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query from fastapi.responses import JSONResponse +import re from consts.model import ChunkCreateRequest, ChunkUpdateRequest, HybridSearchRequest, IndexingResponse from nexent.vector_database.base import VectorDatabaseCore @@ -15,6 +17,8 @@ ) from services.redis_service import get_redis_service from utils.auth_utils import get_current_user_id +from utils.file_management_utils import get_all_files_status +from database.knowledge_db import get_index_name_by_knowledge_name router = APIRouter(prefix="/indices") service = ElasticSearchService() @@ -49,7 +53,8 @@ def create_new_index( """Create a new vector index and store it in the knowledge table""" try: user_id, tenant_id = get_current_user_id(authorization) - return ElasticSearchService.create_index(index_name, embedding_dim, vdb_core, user_id, tenant_id) + # Treat path parameter as user-facing knowledge base name for new creations + return ElasticSearchService.create_knowledge_base(index_name, embedding_dim, vdb_core, user_id, tenant_id) except Exception as e: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error creating index: {str(e)}") @@ -99,7 +104,9 @@ def create_index_documents( data: List[Dict[str, Any] ] = Body(..., description="Document List to process"), vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), - authorization: Optional[str] = Header(None) + authorization: Optional[str] = Header(None), + task_id: Optional[str] = Header( + None, alias="X-Task-Id", description="Task ID for progress tracking"), ): """ Index documents with embeddings, creating the index if it doesn't exist. @@ -108,12 +115,21 @@ def create_index_documents( try: user_id, tenant_id = get_current_user_id(authorization) embedding_model = get_embedding_model(tenant_id) - return ElasticSearchService.index_documents(embedding_model, index_name, data, vdb_core) + return ElasticSearchService.index_documents( + embedding_model=embedding_model, + index_name=index_name, + data=data, + vdb_core=vdb_core, + task_id=task_id, + ) except Exception as e: error_msg = str(e) logger.error(f"Error indexing documents: {error_msg}") + raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error indexing documents: {error_msg}") + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Error indexing documents: {error_msg}" + ) @router.get("/{index_name}/files") @@ -187,6 +203,66 @@ def delete_documents( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error delete indexing documents: {e}") +@router.get("/{index_name}/documents/{path_or_url:path}/error-info") +async def get_document_error_info( + index_name: str = Path(..., description="Name of the index"), + path_or_url: str = Path(..., + description="Path or URL of the document"), + authorization: Optional[str] = Header(None) +): + """Get error information for a document""" + try: + celery_task_files = await get_all_files_status(index_name) + file_status = celery_task_files.get(path_or_url) + + if not file_status: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=f"Document {path_or_url} not found in index {index_name}" + ) + + task_id = file_status.get('latest_task_id', '') + if not task_id: + return { + "status": "success", + "error_code": None, + } + + redis_service = get_redis_service() + raw_error = redis_service.get_error_info(task_id) + error_code = None + + if raw_error: + # Try to parse JSON (new format with error_code only) + try: + parsed = json.loads(raw_error) + if isinstance(parsed, dict) and "error_code" in parsed: + error_code = parsed.get("error_code") + except Exception: + # Fallback: regex extraction if JSON parsing fails + try: + match = re.search( + r'["\']error_code["\']\s*:\s*["\']([^"\']+)["\']', raw_error) + if match: + error_code = match.group(1) + except Exception: + pass + + return { + "status": "success", + "error_code": error_code, + } + except HTTPException: + raise + except Exception as e: + logger.error( + f"Error getting error info for document {path_or_url}: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Error getting error info: {str(e)}" + ) + + # Health check @router.get("/health") def health_check(vdb_core: VectorDatabaseCore = Depends(get_vector_db_core)): @@ -201,25 +277,35 @@ def health_check(vdb_core: VectorDatabaseCore = Depends(get_vector_db_core)): @router.post("/{index_name}/chunks") def get_index_chunks( index_name: str = Path(..., - description="Name of the index to get chunks from"), + description="Name of the index (or knowledge_name) to get chunks from"), page: int = Query( None, description="Page number (1-based) for pagination"), page_size: int = Query( None, description="Number of records per page for pagination"), path_or_url: Optional[str] = Query( None, description="Filter chunks by document path_or_url"), - vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), + authorization: Optional[str] = Header(None) ): """Get chunks from the specified index, with optional pagination support""" try: + _, tenant_id = get_current_user_id(authorization) + actual_index_name = get_index_name_by_knowledge_name( + index_name, tenant_id) + result = ElasticSearchService.get_index_chunks( - index_name=index_name, + index_name=actual_index_name, page=page, page_size=page_size, path_or_url=path_or_url, vdb_core=vdb_core, ) return JSONResponse(status_code=HTTPStatus.OK, content=result) + except ValueError as e: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=str(e) + ) except Exception as e: error_msg = str(e) logger.error( @@ -230,21 +316,29 @@ def get_index_chunks( @router.post("/{index_name}/chunk") def create_chunk( - index_name: str = Path(..., description="Name of the index"), + index_name: str = Path(..., + description="Name of the index (or knowledge_name)"), payload: ChunkCreateRequest = Body(..., description="Chunk data"), vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), authorization: Optional[str] = Header(None), ): """Create a manual chunk.""" try: - user_id, _ = get_current_user_id(authorization) + user_id, tenant_id = get_current_user_id(authorization) + actual_index_name = get_index_name_by_knowledge_name( + index_name, tenant_id) result = ElasticSearchService.create_chunk( - index_name=index_name, + index_name=actual_index_name, chunk_request=payload, vdb_core=vdb_core, user_id=user_id, ) return JSONResponse(status_code=HTTPStatus.OK, content=result) + except ValueError as e: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=str(e) + ) except Exception as exc: logger.error( "Error creating chunk for index %s: %s", index_name, exc, exc_info=True @@ -256,7 +350,8 @@ def create_chunk( @router.put("/{index_name}/chunk/{chunk_id}") def update_chunk( - index_name: str = Path(..., description="Name of the index"), + index_name: str = Path(..., + description="Name of the index (or knowledge_name)"), chunk_id: str = Path(..., description="Chunk identifier"), payload: ChunkUpdateRequest = Body(..., description="Chunk update payload"), @@ -265,18 +360,22 @@ def update_chunk( ): """Update an existing chunk.""" try: - user_id, _ = get_current_user_id(authorization) + user_id, tenant_id = get_current_user_id(authorization) + actual_index_name = get_index_name_by_knowledge_name( + index_name, tenant_id) result = ElasticSearchService.update_chunk( - index_name=index_name, + index_name=actual_index_name, chunk_id=chunk_id, chunk_request=payload, vdb_core=vdb_core, user_id=user_id, ) return JSONResponse(status_code=HTTPStatus.OK, content=result) - except ValueError as exc: + except ValueError as e: raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + status_code=HTTPStatus.NOT_FOUND, + detail=str(e) + ) except Exception as exc: logger.error( "Error updating chunk %s for index %s: %s", @@ -292,22 +391,28 @@ def update_chunk( @router.delete("/{index_name}/chunk/{chunk_id}") def delete_chunk( - index_name: str = Path(..., description="Name of the index"), + index_name: str = Path(..., + description="Name of the index (or knowledge_name)"), chunk_id: str = Path(..., description="Chunk identifier"), vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), authorization: Optional[str] = Header(None), ): """Delete a chunk.""" try: - get_current_user_id(authorization) + _, tenant_id = get_current_user_id(authorization) + actual_index_name = get_index_name_by_knowledge_name( + index_name, tenant_id) result = ElasticSearchService.delete_chunk( - index_name=index_name, + index_name=actual_index_name, chunk_id=chunk_id, vdb_core=vdb_core, ) return JSONResponse(status_code=HTTPStatus.OK, content=result) - except ValueError as exc: - raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)) + except ValueError as e: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=str(e) + ) except Exception as exc: logger.error( "Error deleting chunk %s for index %s: %s", diff --git a/backend/consts/const.py b/backend/consts/const.py index 8e99ca84d..0b5f4bcef 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -38,6 +38,7 @@ class VectorDatabaseType(str, Enum): MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB MAX_CONCURRENT_UPLOADS = 5 UPLOAD_FOLDER = os.getenv('UPLOAD_FOLDER', 'uploads') +ROOT_DIR = os.getenv("ROOT_DIR") # Supabase Configuration @@ -278,8 +279,10 @@ class VectorDatabaseType(str, Enum): LLM_SLOW_TOKEN_RATE_THRESHOLD = float( os.getenv("LLM_SLOW_TOKEN_RATE_THRESHOLD", "10.0")) # tokens per second -# APP Version -APP_VERSION = "v1.7.7.1" DEFAULT_ZH_TITLE = "新对话" DEFAULT_EN_TITLE = "New Conversation" + + +# APP Version +APP_VERSION = "v1.7.8" diff --git a/backend/consts/model.py b/backend/consts/model.py index bb0aefb7f..cf22afbf2 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -59,6 +59,7 @@ class ModelRequest(BaseModel): connect_status: Optional[str] = '' expected_chunk_size: Optional[int] = None maximum_chunk_size: Optional[int] = None + chunk_batch: Optional[int] = None class ProviderModelRequest(BaseModel): @@ -246,6 +247,7 @@ class AgentInfoRequest(BaseModel): display_name: Optional[str] = None description: Optional[str] = None business_description: Optional[str] = None + author: Optional[str] = None model_name: Optional[str] = None model_id: Optional[int] = None max_steps: Optional[int] = None @@ -311,6 +313,7 @@ class ExportAndImportAgentInfo(BaseModel): display_name: Optional[str] = None description: str business_description: str + author: Optional[str] = None max_steps: int provide_run_summary: bool duty_prompt: Optional[str] = None @@ -344,6 +347,27 @@ class AgentImportRequest(BaseModel): force_import: bool = False +class AgentNameBatchRegenerateItem(BaseModel): + name: str + display_name: Optional[str] = None + task_description: Optional[str] = "" + agent_id: Optional[int] = None + + +class AgentNameBatchRegenerateRequest(BaseModel): + items: List[AgentNameBatchRegenerateItem] + + +class AgentNameBatchCheckItem(BaseModel): + name: str + display_name: Optional[str] = None + agent_id: Optional[int] = None + + +class AgentNameBatchCheckRequest(BaseModel): + items: List[AgentNameBatchCheckItem] + + class ConvertStateRequest(BaseModel): """Request schema for /tasks/convert_state endpoint""" process_state: str = "" diff --git a/backend/data_process/tasks.py b/backend/data_process/tasks.py index 50ddad166..50414b711 100644 --- a/backend/data_process/tasks.py +++ b/backend/data_process/tasks.py @@ -10,12 +10,14 @@ from typing import Any, Dict, Optional import aiohttp +import re import ray from celery import Task, chain, states from celery.exceptions import Retry from consts.const import ELASTICSEARCH_SERVICE from utils.file_management_utils import get_file_size +from services.redis_service import get_redis_service from .app import app from .ray_actors import DataProcessorRayActor from consts.const import ( @@ -23,6 +25,7 @@ FORWARD_REDIS_RETRY_DELAY_S, FORWARD_REDIS_RETRY_MAX, DISABLE_RAY_DASHBOARD, + ROOT_DIR, ) @@ -31,6 +34,74 @@ # Thread lock for initializing Ray to prevent race conditions ray_init_lock = threading.Lock() +ROOT_DIR_DISPLAY = ROOT_DIR or "{ROOT_DIR}" + + +def extract_error_code(reason: str, parsed_error: Optional[Dict] = None) -> Optional[str]: + """ + Extract error code from error message or parsed error dict. + Returns error code if matched, None otherwise. + """ + # 1) parsed_error dict + if parsed_error and isinstance(parsed_error, dict): + code = parsed_error.get("error_code") + if code: + return code + + # 2) try parse reason as JSON + try: + parsed = json.loads(reason) + if isinstance(parsed, dict): + code = parsed.get("error_code") + if code: + return code + detail = parsed.get("detail") + if isinstance(detail, dict) and detail.get("error_code"): + return detail.get("error_code") + except Exception: + pass + + # 3) regex from raw string (supports single/double quotes) + try: + match = re.search( + r'["\']error_code["\']\s*:\s*["\']([^"\']+)["\']', reason) + if match: + return match.group(1) + except Exception: + pass + + return "unknown_error" + + +def save_error_to_redis(task_id: str, error_reason: str, start_time: float): + """ + Save error information to Redis + + Args: + task_id: Celery task ID + error_reason: Short error reason summary + start_time: Task start timestamp (unused, kept for compatibility) + """ + if not task_id: + logger.warning("Cannot save error info: task_id is empty") + return + if not error_reason: + logger.warning( + f"Cannot save error info for task {task_id}: error_reason is empty") + return + try: + redis_service = get_redis_service() + success = redis_service.save_error_info(task_id, error_reason) + if success: + logger.info( + f"Successfully saved error info for task {task_id}: {error_reason[:100]}...") + else: + logger.warning( + f"Failed to save error info for task {task_id}: save_error_info returned False") + except Exception as e: + logger.error( + f"Failed to save error info to Redis for task {task_id}: {str(e)}", exc_info=True) + def init_ray_in_worker(): """ @@ -58,7 +129,7 @@ def init_ray_in_worker(): logger.info("Ray initialized successfully for Celery worker.") except Exception as e: logger.error(f"Failed to initialize Ray for Celery worker: {e}") - raise + raise RuntimeError("Failed to initialize Ray for Celery worker") from e def run_async(coro): @@ -282,6 +353,17 @@ def process( raise NotImplementedError( f"Source type '{source_type}' not yet supported") + chunk_count = len(chunks) if chunks else 0 + if chunk_count == 0: + raise Exception(json.dumps({ + "message": "Ray processing completed but produced 0 chunks", + "index_name": index_name, + "task_name": "process", + "source": source, + "original_filename": original_filename, + "error_code": "no_valid_chunks" + }, ensure_ascii=False)) + # Update task state to SUCCESS after Ray processing completes # This transitions from STARTED (PROCESSING) to SUCCESS (WAIT_FOR_FORWARDING) self.update_state( @@ -316,31 +398,114 @@ def process( except Exception as e: logger.error(f"Error processing file {source}: {str(e)}") + # task_id is already defined at the start of the function try: + # Try to parse the exception as JSON (it might be our custom JSON error) + error_message = str(e) + parsed_error = None + + try: + parsed_error = json.loads(error_message) + if isinstance(parsed_error, dict): + error_message = parsed_error.get("message", error_message) + logger.debug( + f"Parsed JSON error for task {task_id}" + ) + except (json.JSONDecodeError, TypeError): + # Not a JSON string, use as-is + logger.debug( + f"Exception is not JSON format for task {task_id}, using raw message" + ) + + # Build error_info for re-raising error_info = { - "message": str(e), + "message": error_message, "index_name": index_name, "task_name": "process", "source": source, - "original_filename": original_filename + "original_filename": original_filename, } + + # Extract error code from parsed error or error message + error_code = extract_error_code(error_message, parsed_error) + if error_code: + error_info["error_code"] = error_code + + # Store only error code (if available) or raw error message + if error_code: + reason_to_store = json.dumps({ + "error_code": error_code + }, ensure_ascii=False) + else: + # Fallback: store raw error message (truncated if too long) + reason_to_store = error_message + if len(reason_to_store) > 200: + reason_to_store = reason_to_store[:200] + "..." + + # Save error info to Redis BEFORE re-raising + logger.info( + f"Attempting to save error info for task {task_id} with reason: {reason_to_store[:100]}..." + ) + save_error_to_redis(task_id, reason_to_store, start_time) + self.update_state( meta={ - 'source': error_info.get('source', ''), - 'index_name': error_info.get('index_name', ''), - 'task_name': error_info.get('task_name', ''), - 'original_filename': error_info.get('original_filename', ''), - 'custom_error': error_info.get('message', str(e)), - 'stage': 'text_extraction_failed' + "source": error_info.get("source", ""), + "index_name": error_info.get("index_name", ""), + "task_name": error_info.get("task_name", ""), + "original_filename": error_info.get( + "original_filename", "" + ), + "custom_error": error_info.get("message", str(e)), + "stage": "text_extraction_failed", } ) raise Exception(json.dumps(error_info, ensure_ascii=False)) except Exception as ex: logger.error(f"Error serializing process exception: {str(ex)}") + # Try to save error even if serialization fails + try: + error_message = str(e) + parsed_error = None + + try: + parsed_error = json.loads(error_message) + if isinstance(parsed_error, dict): + error_message = parsed_error.get( + "message", error_message + ) + logger.debug( + "Fallback serialization: parsed JSON error for task " + f"{task_id}" + ) + except (json.JSONDecodeError, TypeError): + logger.debug( + "Fallback serialization: exception is not JSON format " + f"for task {task_id}, using raw message" + ) + parsed_error = None + + # Extract error code from parsed error or error message + error_code = extract_error_code(error_message, parsed_error) + + # Store only error code (if available) or raw error message + if error_code: + reason_to_store = json.dumps({ + "error_code": error_code + }, ensure_ascii=False) + else: + # Fallback: store raw error message (truncated if too long) + reason_to_store = error_message + if len(reason_to_store) > 200: + reason_to_store = reason_to_store[:200] + "..." + + save_error_to_redis(task_id, reason_to_store, start_time) + except Exception: + pass self.update_state( meta={ - 'custom_error': str(e), - 'stage': 'text_extraction_failed' + "custom_error": str(e), + "stage": "text_extraction_failed", } ) raise @@ -377,6 +542,38 @@ def forward( filename = original_filename try: + # Before doing any heavy work, check whether this task has been + # explicitly cancelled (for example, because the user deleted the + # document from the knowledge base configuration page). + try: + redis_service = get_redis_service() + if redis_service.is_task_cancelled(task_id): + logger.info( + f"[{self.request.id}] FORWARD TASK: Detected cancellation flag for task {task_id}; " + f"skipping chunk forwarding for source '{source}' in index '{index_name}'." + ) + # Treat this as a graceful early exit. We still return a + # structured payload so callers can consider the task done. + return { + 'task_id': task_id, + 'source': source, + 'index_name': index_name, + 'original_filename': original_filename, + 'chunks_stored': 0, + 'storage_time': 0, + 'es_result': { + "success": False, + "message": "Indexing cancelled because document was deleted.", + "total_indexed": 0, + "total_submitted": 0, + }, + } + except Exception as cancel_check_exc: + logger.warning( + f"[{self.request.id}] FORWARD TASK: Failed to check cancellation flag for task {task_id}: " + f"{cancel_check_exc}" + ) + chunks = processed_data.get('chunks') # If chunks are not in payload, try loading from Redis via the redis_key if (not chunks) and processed_data.get('redis_key'): @@ -441,18 +638,8 @@ def forward( logger.info( f"[{self.request.id}] FORWARD TASK: Received data for source '{original_source}' with {len(chunks) if chunks else 'None'} chunks") - # Update task state to FORWARDING - self.update_state( - state=states.STARTED, - meta={ - 'source': original_source, - 'index_name': original_index_name, - 'original_filename': filename, - 'task_name': 'forward', - 'start_time': start_time, - 'stage': 'vectorizing_and_storing' - } - ) + # Calculate total chunks for progress tracking + total_chunks = len(chunks) if chunks else 0 if chunks is None: raise Exception(json.dumps({ @@ -500,7 +687,8 @@ def forward( "index_name": original_index_name, "task_name": "forward", "source": original_source, - "original_filename": original_filename + "original_filename": original_filename, + "error_code": "no_valid_chunks" }, ensure_ascii=False)) async def index_documents(): @@ -518,84 +706,115 @@ async def index_documents(): headers = {"Content-Type": "application/json"} if authorization: headers["Authorization"] = authorization + # Add task_id header for progress tracking + headers["X-Task-Id"] = task_id - max_retries = 5 - retry_delay = 5 - for retry in range(max_retries): - try: - connector = aiohttp.TCPConnector(verify_ssl=False) - # Increased timeout for large documents - timeout = aiohttp.ClientTimeout(total=120) - - async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: - async with session.post( - full_url, - headers=headers, - json=formatted_chunks, - raise_for_status=True - ) as response: - result = await response.json() - return result - - except aiohttp.ClientResponseError as e: - if e.status == 503 and retry < max_retries - 1: - wait_time = retry_delay * (retry + 1) - await asyncio.sleep(wait_time) - else: - raise Exception(json.dumps({ - "message": f"ElasticSearch service unavailable: {str(e)}", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": original_filename - }, ensure_ascii=False)) - except aiohttp.ClientConnectorError as e: - logger.error( - f"[{self.request.id}] FORWARD TASK: Connection error to {full_url}: {str(e)}") - if retry < max_retries - 1: - wait_time = retry_delay * (retry + 1) - logger.warning( - f"[{self.request.id}] FORWARD TASK: Connection error when indexing documents: {str(e)}. Retrying in {wait_time}s...") - await asyncio.sleep(wait_time) - else: - raise Exception(json.dumps({ - "message": f"Failed to connect to API: {str(e)}", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": original_filename - }, ensure_ascii=False)) - except asyncio.TimeoutError as e: - if retry < max_retries - 1: - wait_time = retry_delay * (retry + 1) - logger.warning( - f"[{self.request.id}] FORWARD TASK: Timeout when indexing documents: {str(e)}. Retrying in {wait_time}s...") - await asyncio.sleep(wait_time) - else: - raise Exception(json.dumps({ - "message": f"Timeout after {max_retries} attempts: {str(e)}", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": original_filename - }, ensure_ascii=False)) - except Exception as e: - if retry < max_retries - 1: - wait_time = retry_delay * (retry + 1) - logger.warning( - f"[{self.request.id}] FORWARD TASK: Unexpected error when indexing documents: {str(e)}. Retrying in {wait_time}s...") - await asyncio.sleep(wait_time) - else: - raise Exception(json.dumps({ - "message": f"Unexpected error when indexing documents: {str(e)}", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": original_filename - }, ensure_ascii=False)) + try: + connector = aiohttp.TCPConnector(verify_ssl=False) + timeout = aiohttp.ClientTimeout(total=600) + + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + async with session.post( + full_url, + headers=headers, + json=formatted_chunks, + raise_for_status=False + ) as response: + text = await response.text() + status = response.status + # Try parse JSON body for structured error_code/message + parsed_body = None + try: + parsed_body = json.loads(text) + except Exception: + parsed_body = None + + if status >= 400: + error_code = None + if isinstance(parsed_body, dict): + error_code = parsed_body.get("error_code") + detail = parsed_body.get("detail") + if isinstance(detail, dict) and detail.get("error_code"): + error_code = detail.get("error_code") + elif isinstance(detail, str): + try: + parsed_detail = json.loads(detail) + if isinstance(parsed_detail, dict): + error_code = parsed_detail.get( + "error_code", error_code) + except Exception: + pass + + if not error_code: + try: + match = re.search( + r'["\']error_code["\']\s*:\s*["\']([^"\']+)["\']', text) + if match: + error_code = match.group(1) + except Exception: + pass + + if error_code: + # Raise flat payload to avoid nested JSON and preserve error_code + raise Exception(json.dumps({ + "error_code": error_code + }, ensure_ascii=False)) + + raise Exception( + f"ElasticSearch service returned HTTP {status}") + + result = parsed_body if isinstance(parsed_body, dict) else await response.json() + return result + + except aiohttp.ClientConnectorError as e: + logger.error( + f"[{self.request.id}] FORWARD TASK: Connection error to {full_url}: {str(e)}") + raise Exception(json.dumps({ + "message": f"Failed to connect to API: {str(e)}", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": original_filename + }, ensure_ascii=False)) + except asyncio.TimeoutError as e: + logger.warning( + f"[{self.request.id}] FORWARD TASK: Timeout when indexing documents: {str(e)}.") + raise Exception(json.dumps({ + "message": f"Timeout when indexing documents: {str(e)}", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": original_filename + }, ensure_ascii=False)) + except Exception as e: + logger.error( + f"[{self.request.id}] FORWARD TASK: Unexpected error when indexing documents: {str(e)}.") + raise Exception(json.dumps({ + "message": f"Unexpected error when indexing documents: {str(e)}", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": original_filename + }, ensure_ascii=False)) logger.info( f"[{self.request.id}] FORWARD TASK: Starting ES indexing for {len(formatted_chunks)} chunks to index '{original_index_name}'...") + + # Update task state with total chunks before starting vectorization + self.update_state( + state=states.STARTED, + meta={ + 'source': original_source, + 'index_name': original_index_name, + 'original_filename': filename, + 'task_name': 'forward', + 'start_time': start_time, + 'stage': 'vectorizing_and_storing', + 'total_chunks': total_chunks, + 'processed_chunks': 0 # Will be updated during vectorization via Redis + } + ) + es_result = run_async(index_documents()) logger.debug( f"[{self.request.id}] FORWARD TASK: API response from main_server for source '{original_source}': {es_result}") @@ -617,7 +836,8 @@ async def index_documents(): "index_name": original_index_name, "task_name": "forward", "source": original_source, - "original_filename": original_filename + "original_filename": original_filename, + "error_code": "es_bulk_failed" }, ensure_ascii=False)) elif isinstance(es_result, dict) and not es_result.get("success"): error_message = es_result.get( @@ -638,6 +858,12 @@ async def index_documents(): "original_filename": original_filename }, ensure_ascii=False)) end_time = time.time() + + # Get final indexed count from result + final_processed = 0 + if isinstance(es_result, dict) and es_result.get("success"): + final_processed = es_result.get("total_indexed", len(chunks)) + logger.info( f"[{self.request.id}] FORWARD TASK: Updating task state to SUCCESS after ES indexing completion") self.update_state( @@ -650,7 +876,9 @@ async def index_documents(): 'original_filename': original_filename, 'task_name': 'forward', 'es_result': es_result, - 'stage': 'completed' + 'stage': 'completed', + 'total_chunks': total_chunks, + 'processed_chunks': final_processed } ) @@ -667,22 +895,68 @@ async def index_documents(): } except Exception as e: # If it's an Exception, all go here (including our custom JSON message) + # Important: if this is a Celery Retry, re-raise immediately without recording error_code + if isinstance(e, Retry): + raise + + task_id = self.request.id try: error_info = json.loads(str(e)) + error_message = error_info.get('message', str(e)) logger.error( - f"Error forwarding chunks for index '{error_info.get('index_name', '')}': {error_info.get('message', str(e))}") + f"Error forwarding chunks for index '{error_info.get('index_name', '')}': {error_message}") + + # Extract error code from parsed error or error message + error_code = extract_error_code(error_message, error_info) + + # Store only error code (if available) or raw error message + if error_code: + reason_to_store = json.dumps({ + "error_code": error_code + }, ensure_ascii=False) + else: + # Fallback: store raw error message (truncated if too long) + reason_to_store = error_message + if len(reason_to_store) > 200: + reason_to_store = reason_to_store[:200] + "..." + + # Save error info to Redis BEFORE re-raising + logger.info( + f"Attempting to save error info for task {task_id} with reason: {reason_to_store[:100]}...") + save_error_to_redis(task_id, reason_to_store, start_time) + self.update_state( meta={ 'source': error_info.get('source', ''), 'index_name': error_info.get('index_name', ''), 'task_name': error_info.get('task_name', ''), 'original_filename': error_info.get('original_filename', ''), - 'custom_error': error_info.get('message', str(e)), + 'custom_error': error_message, 'stage': 'forward_task_failed' } ) - except Exception as e: + except Exception: logger.error(f"Error forwarding chunks: {str(e)}") + # Try to save error even if parsing fails + try: + error_message = str(e) + # Extract error code from error message + error_code = extract_error_code(error_message, None) + + # Store only error code (if available) or raw error message + if error_code: + reason_to_store = json.dumps({ + "error_code": error_code + }, ensure_ascii=False) + else: + # Fallback: store raw error message (truncated if too long) + reason_to_store = error_message + if len(reason_to_store) > 200: + reason_to_store = reason_to_store[:200] + "..." + + save_error_to_redis(task_id, reason_to_store, start_time) + except Exception: + pass self.update_state( meta={ 'custom_error': str(e), diff --git a/backend/data_process/utils.py b/backend/data_process/utils.py index af0126f29..f4ed5631c 100644 --- a/backend/data_process/utils.py +++ b/backend/data_process/utils.py @@ -123,6 +123,25 @@ def sync_get(): if 'original_filename' in metadata: status_info['original_filename'] = metadata['original_filename'] + + # Get progress info from metadata + if 'total_chunks' in metadata: + status_info['total_chunks'] = metadata['total_chunks'] + if 'processed_chunks' in metadata: + status_info['processed_chunks'] = metadata['processed_chunks'] + + # Always try to get latest progress from Redis (real-time updates during vectorization) + # Redis progress takes precedence over metadata for active tasks + try: + from services.redis_service import get_redis_service + redis_service = get_redis_service() + progress_info = redis_service.get_progress_info(task_id) + if progress_info: + # Use Redis progress as primary source (updated in real-time) + status_info['processed_chunks'] = progress_info.get('processed_chunks', status_info.get('processed_chunks')) + status_info['total_chunks'] = progress_info.get('total_chunks', status_info.get('total_chunks')) + except Exception as e: + logger.debug(f"Failed to get progress from Redis for task {task_id}: {str(e)}") # Add error information for failed tasks if result.failed(): try: diff --git a/backend/database/db_models.py b/backend/database/db_models.py index a4201abad..49d8722ba 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -162,6 +162,8 @@ class ModelRecord(TableBase): Integer, doc="Maximum chunk size for embedding models, used during document chunking") ssl_verify = Column( Boolean, default=True, doc="Whether to verify SSL certificates when connecting to this model API. Default is true. Set to false for local services without SSL support.") + chunk_batch = Column( + Integer, doc="Batch size for concurrent embedding requests during document chunking") class ToolInfo(TableBase): @@ -199,6 +201,7 @@ class AgentInfo(TableBase): name = Column(String(100), doc="Agent name") display_name = Column(String(100), doc="Agent display name") description = Column(Text, doc="Description") + author = Column(String(100), doc="Agent author") model_name = Column(String(100), doc="[DEPRECATED] Name of the model used, use model_id instead") model_id = Column(Integer, doc="Model ID, foreign key reference to model_record_t.model_id") max_steps = Column(Integer, doc="Maximum number of steps") @@ -242,7 +245,8 @@ class KnowledgeRecord(TableBase): knowledge_id = Column(Integer, Sequence("knowledge_record_t_knowledge_id_seq", schema="nexent"), primary_key=True, nullable=False, doc="Knowledge base ID, unique primary key") - index_name = Column(String(100), doc="Knowledge base name") + index_name = Column(String(100), doc="Internal Elasticsearch index name") + knowledge_name = Column(String(100), doc="User-facing knowledge base name") knowledge_describe = Column(String(3000), doc="Knowledge base description") knowledge_sources = Column(String(300), doc="Knowledge base sources") embedding_model_name = Column(String(200), doc="Embedding model name, used to record the embedding model used by the knowledge base") diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index 392efb926..6faccdafa 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional +import uuid from sqlalchemy import func from sqlalchemy.exc import SQLAlchemyError @@ -7,7 +8,16 @@ from database.db_models import KnowledgeRecord -def create_knowledge_record(query: Dict[str, Any]) -> int: +def _generate_index_name(knowledge_id: int) -> str: + """ + Generate a new internal index_name based on knowledge_id and a UUID suffix. + The suffix contains only digits and lowercase letters. + """ + suffix = uuid.uuid4().hex + return f"{knowledge_id}-{suffix}" + + +def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: """ Create a knowledge base record @@ -21,27 +31,49 @@ def create_knowledge_record(query: Dict[str, Any]) -> int: - embedding_model_name: embedding model name for the knowledge base Returns: - int: Newly created knowledge base ID + Dict[str, Any]: Dictionary with at least 'knowledge_id' and 'index_name' """ try: with get_db_session() as session: + # Determine user-facing knowledge base name + knowledge_name = query.get( + "knowledge_name") or query.get("index_name") + # Prepare data dictionary - data = { - "index_name": query["index_name"], + data: Dict[str, Any] = { "knowledge_describe": query.get("knowledge_describe", ""), "created_by": query.get("user_id"), "updated_by": query.get("user_id"), "knowledge_sources": query.get("knowledge_sources", "elasticsearch"), "tenant_id": query.get("tenant_id"), - "embedding_model_name": query.get("embedding_model_name") + "embedding_model_name": query.get("embedding_model_name"), + "knowledge_name": knowledge_name, } + # For backward compatibility: if caller explicitly provides index_name, + # respect it and do not regenerate; otherwise generate after flush. + explicit_index_name = query.get("index_name") + if explicit_index_name: + data["index_name"] = explicit_index_name + # Create new record new_record = KnowledgeRecord(**data) session.add(new_record) session.flush() + + # Generate internal index_name for new records when not explicitly provided + if not explicit_index_name: + generated_index_name = _generate_index_name( + new_record.knowledge_id) + new_record.index_name = generated_index_name + session.flush() + session.commit() - return new_record.knowledge_id + return { + "knowledge_id": new_record.knowledge_id, + "index_name": new_record.index_name, + "knowledge_name": new_record.knowledge_name, + } except SQLAlchemyError as e: session.rollback() raise e @@ -165,6 +197,7 @@ def get_knowledge_info_by_knowledge_ids(knowledge_ids: List[str]) -> List[Dict[s knowledge_info.append({ "knowledge_id": item.knowledge_id, "index_name": item.index_name, + "knowledge_name": item.knowledge_name, "knowledge_sources": item.knowledge_sources, "embedding_model_name": item.embedding_model_name }) @@ -208,4 +241,35 @@ def update_model_name_by_index_name(index_name: str, embedding_model_name: str, session.commit() return True except SQLAlchemyError as e: - raise e \ No newline at end of file + raise e + + +def get_index_name_by_knowledge_name(knowledge_name: str, tenant_id: str) -> str: + """ + Get the internal index_name from user-facing knowledge_name. + + Args: + knowledge_name: User-facing knowledge base name + tenant_id: Tenant ID to filter by + + Returns: + str: The internal index_name if found + + Raises: + ValueError: If knowledge base with the given name is not found for the tenant + """ + try: + with get_db_session() as session: + result = session.query(KnowledgeRecord).filter( + KnowledgeRecord.knowledge_name == knowledge_name, + KnowledgeRecord.tenant_id == tenant_id, + KnowledgeRecord.delete_flag != 'Y' + ).first() + + if result: + return result.index_name + raise ValueError( + f"Knowledge base '{knowledge_name}' not found for the current tenant" + ) + except SQLAlchemyError as e: + raise e diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 5184b0e25..aafa38ba6 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -20,6 +20,8 @@ from consts.model import ( AgentInfoRequest, AgentRequest, + AgentNameBatchCheckRequest, + AgentNameBatchRegenerateRequest, ExportAndImportAgentInfo, ExportAndImportDataFormat, MCPInfo, @@ -410,6 +412,150 @@ def _regenerate_agent_display_name_with_llm( +async def check_agent_name_conflict_batch_impl( + request: AgentNameBatchCheckRequest, + authorization: str +) -> list[dict]: + """ + Batch check name/display_name duplication for multiple agents. + """ + _, tenant_id, _ = get_current_user_info(authorization) + agents_cache = query_all_agent_info_by_tenant_id(tenant_id) + + results: list[dict] = [] + for item in request.items: + if not item.name: + results.append({ + "name_conflict": False, + "display_name_conflict": False, + "conflict_agents": [] + }) + continue + + conflicts: list[dict] = [] + name_conflict = False + display_name_conflict = False + for agent in agents_cache: + if item.agent_id and agent.get("agent_id") == item.agent_id: + continue + matches_name = item.name and agent.get("name") == item.name + matches_display = item.display_name and agent.get( + "display_name") == item.display_name + if matches_name: + name_conflict = True + if matches_display: + display_name_conflict = True + if matches_name or matches_display: + conflicts.append({ + "name": agent.get("name"), + "display_name": agent.get("display_name"), + }) + + results.append({ + "name_conflict": name_conflict, + "display_name_conflict": display_name_conflict, + "conflict_agents": conflicts + }) + return results + + +async def regenerate_agent_name_batch_impl( + request: AgentNameBatchRegenerateRequest, + authorization: str +) -> list[dict]: + """ + Batch regenerate agent name/display_name with LLM (or suffix fallback). + """ + _, tenant_id, _ = get_current_user_info(authorization) + agents_cache = query_all_agent_info_by_tenant_id(tenant_id) + + existing_names = [agent.get("name") for agent in agents_cache if agent.get("name")] + existing_display_names = [agent.get("display_name") for agent in agents_cache if agent.get("display_name")] + + # Always use tenant quick-config LLM model + quick_config_model = tenant_config_manager.get_model_config( + key=MODEL_CONFIG_MAPPING["llm"], + tenant_id=tenant_id + ) + resolved_model_id = quick_config_model.get("model_id") if quick_config_model else None + if not resolved_model_id: + raise ValueError("No available model for regeneration. Please configure an LLM model first.") + + results: list[dict] = [] + # Use local mutable caches to avoid regenerated duplicates in the same batch + name_set = set(existing_names) + display_name_set = set(existing_display_names) + + for item in request.items: + agent_name = item.name or "" + agent_display_name = item.display_name or "" + task_description = item.task_description or "" + exclude_agent_id = item.agent_id + + # Regenerate name if duplicate and non-empty + if agent_name and _check_agent_name_duplicate( + agent_name, tenant_id, agents_cache=agents_cache, exclude_agent_id=exclude_agent_id + ): + try: + agent_name = await asyncio.to_thread( + _regenerate_agent_name_with_llm, + original_name=agent_name, + existing_names=list(name_set), + task_description=task_description, + model_id=resolved_model_id, + tenant_id=tenant_id, + language=LANGUAGE["ZH"], + agents_cache=agents_cache, + exclude_agent_id=exclude_agent_id + ) + except Exception as e: + logger.error(f"Failed to regenerate agent name with LLM: {str(e)}, using fallback") + agent_name = _generate_unique_agent_name_with_suffix( + agent_name, + tenant_id=tenant_id, + agents_cache=agents_cache, + exclude_agent_id=exclude_agent_id + ) + + # Regenerate display_name if duplicate and non-empty + if agent_display_name and _check_agent_display_name_duplicate( + agent_display_name, tenant_id, agents_cache=agents_cache, exclude_agent_id=exclude_agent_id + ): + try: + agent_display_name = await asyncio.to_thread( + _regenerate_agent_display_name_with_llm, + original_display_name=agent_display_name, + existing_display_names=list(display_name_set), + task_description=task_description, + model_id=resolved_model_id, + tenant_id=tenant_id, + language=LANGUAGE["ZH"], + agents_cache=agents_cache, + exclude_agent_id=exclude_agent_id + ) + except Exception as e: + logger.error(f"Failed to regenerate agent display_name with LLM: {str(e)}, using fallback") + agent_display_name = _generate_unique_display_name_with_suffix( + agent_display_name, + tenant_id=tenant_id, + agents_cache=agents_cache, + exclude_agent_id=exclude_agent_id + ) + + # Track regenerated names to avoid duplicates within batch + if agent_name: + name_set.add(agent_name) + if agent_display_name: + display_name_set.add(agent_display_name) + + results.append({ + "name": agent_name, + "display_name": agent_display_name + }) + + return results + + async def _stream_agent_chunks( agent_request: "AgentRequest", user_id: str, @@ -635,6 +781,7 @@ async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = "display_name": request.display_name, "description": request.description, "business_description": request.business_description, + "author": request.author, "model_id": request.model_id, "model_name": request.model_name, "business_logic_model_id": request.business_logic_model_id, @@ -873,6 +1020,7 @@ async def export_agent_by_agent_id(agent_id: int, tenant_id: str, user_id: str) display_name=agent_info["display_name"], description=agent_info["description"], business_description=agent_info["business_description"], + author=agent_info.get("author"), max_steps=agent_info["max_steps"], provide_run_summary=agent_info["provide_run_summary"], duty_prompt=agent_info.get( @@ -1002,89 +1150,15 @@ async def import_agent_by_agent_id( tenant_id=tenant_id ) - # Check for duplicate names and regenerate if needed (unless forced import) agent_name = import_agent_info.name agent_display_name = import_agent_info.display_name - # Get all existing agent names and display names for duplicate checking - all_agents = query_all_agent_info_by_tenant_id(tenant_id) - existing_names = [agent.get("name") for agent in all_agents if agent.get("name")] - existing_display_names = [agent.get("display_name") for agent in all_agents if agent.get("display_name")] - - if not skip_duplicate_regeneration: - # Check and regenerate name if duplicate - if _check_agent_name_duplicate(agent_name, tenant_id, agents_cache=all_agents): - logger.info(f"Agent name '{agent_name}' already exists, regenerating with LLM") - # Get model for regeneration (use business_logic_model_id if available, otherwise use model_id) - regeneration_model_id = business_logic_model_id or model_id - if regeneration_model_id: - try: - # Offload blocking LLM regeneration to a thread to avoid blocking the event loop - agent_name = await asyncio.to_thread( - _regenerate_agent_name_with_llm, - original_name=agent_name, - existing_names=existing_names, - task_description=import_agent_info.business_description or import_agent_info.description or "", - model_id=regeneration_model_id, - tenant_id=tenant_id, - language=LANGUAGE["ZH"], # Default to Chinese, can be enhanced later - agents_cache=all_agents, - ) - logger.info(f"Regenerated agent name: '{agent_name}'") - except Exception as e: - logger.error(f"Failed to regenerate agent name with LLM: {str(e)}, using fallback") - agent_name = _generate_unique_agent_name_with_suffix( - agent_name, - tenant_id=tenant_id, - agents_cache=all_agents - ) - else: - logger.warning("No model available for regeneration, using fallback") - agent_name = _generate_unique_agent_name_with_suffix( - agent_name, - tenant_id=tenant_id, - agents_cache=all_agents - ) - - # Check and regenerate display_name if duplicate - if _check_agent_display_name_duplicate(agent_display_name, tenant_id, agents_cache=all_agents): - logger.info(f"Agent display_name '{agent_display_name}' already exists, regenerating with LLM") - # Get model for regeneration (use business_logic_model_id if available, otherwise use model_id) - regeneration_model_id = business_logic_model_id or model_id - if regeneration_model_id: - try: - # Offload blocking LLM regeneration to a thread to avoid blocking the event loop - agent_display_name = await asyncio.to_thread( - _regenerate_agent_display_name_with_llm, - original_display_name=agent_display_name, - existing_display_names=existing_display_names, - task_description=import_agent_info.business_description or import_agent_info.description or "", - model_id=regeneration_model_id, - tenant_id=tenant_id, - language=LANGUAGE["ZH"], # Default to Chinese, can be enhanced later - agents_cache=all_agents, - ) - logger.info(f"Regenerated agent display_name: '{agent_display_name}'") - except Exception as e: - logger.error(f"Failed to regenerate agent display_name with LLM: {str(e)}, using fallback") - agent_display_name = _generate_unique_display_name_with_suffix( - agent_display_name, - tenant_id=tenant_id, - agents_cache=all_agents - ) - else: - logger.warning("No model available for regeneration, using fallback") - agent_display_name = _generate_unique_display_name_with_suffix( - agent_display_name, - tenant_id=tenant_id, - agents_cache=all_agents - ) - # create a new agent new_agent = create_agent(agent_info={"name": agent_name, "display_name": agent_display_name, "description": import_agent_info.description, "business_description": import_agent_info.business_description, + "author": import_agent_info.author, "model_id": model_id, "model_name": import_agent_info.model_name, "business_logic_model_id": business_logic_model_id, @@ -1172,6 +1246,7 @@ async def list_all_agent_info_impl(tenant_id: str) -> list[dict]: "name": agent["name"] if agent["name"] else agent["display_name"], "display_name": agent["display_name"] if agent["display_name"] else agent["name"], "description": agent["description"], + "author": agent.get("author"), "is_available": len(unavailable_reasons) == 0, "unavailable_reasons": unavailable_reasons }) diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index b14835d90..857369f3c 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional from jinja2 import StrictUndefined, Template -from smolagents import OpenAIServerModel from consts.const import LANGUAGE, MODEL_CONFIG_MAPPING, MESSAGE_ROLE, DEFAULT_EN_TITLE, DEFAULT_ZH_TITLE from consts.model import AgentRequest, ConversationResponse, MessageRequest, MessageUnit @@ -27,7 +26,8 @@ rename_conversation, update_message_opinion ) -from nexent.core.utils.observer import ProcessType +from nexent.core.utils.observer import MessageObserver, ProcessType +from nexent.core.models import OpenAIModel from utils.config_utils import get_model_name_from_config, tenant_config_manager from utils.prompt_template_utils import get_generate_title_prompt_template from utils.str_utils import remove_think_blocks @@ -262,8 +262,8 @@ def call_llm_for_title(content: str, tenant_id: str, language: str = LANGUAGE["Z model_config = tenant_config_manager.get_model_config( key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) - # Create OpenAIServerModel instance - llm = OpenAIServerModel( + # Create OpenAIModel instance + llm = OpenAIModel( model_id=get_model_name_from_config(model_config) if model_config.get("model_name") else "", api_base=model_config.get("base_url", ""), api_key=model_config.get("api_key", ""), diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 7fe1a86b6..7e7c59a5a 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -75,6 +75,9 @@ async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict # If embedding or multi_embedding, set max_tokens via embedding dimension check if model_data.get("model_type") in ("embedding", "multi_embedding"): model_data["max_tokens"] = await embedding_dimension_check(model_data) + # Set default chunk_batch if not provided + if model_data.get("chunk_batch") is None: + model_data["chunk_batch"] = 10 is_multimodal = model_data.get("model_type") == "multi_embedding" @@ -248,7 +251,8 @@ async def update_single_model_for_tenant( f"Model {current_display_name} (embedding + multi_embedding) updated successfully") else: # Single model update - current_model_id = existing_models[0]["model_id"] + current_model = existing_models[0] + current_model_id = current_model["model_id"] update_data = {k: v for k, v in model_data.items() if k != "model_id"} update_model_record(current_model_id, update_data, user_id) logging.debug(f"Model {current_display_name} updated successfully") diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index 271ad7f99..55d999ea4 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -165,11 +165,13 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a # Initialize chunk size variables for all model types; only embeddings use them expected_chunk_size = None maximum_chunk_size = None + chunk_batch = None # For embedding models, apply default values when chunk sizes are null if model["model_type"] in ["embedding", "multi_embedding"]: expected_chunk_size = model.get("expected_chunk_size", DEFAULT_EXPECTED_CHUNK_SIZE) maximum_chunk_size = model.get("maximum_chunk_size", DEFAULT_MAXIMUM_CHUNK_SIZE) + chunk_batch = model.get("chunk_batch", 10) # For ModelEngine provider, extract the host from model's base_url # We'll append the correct path later @@ -185,15 +187,23 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a # Build the canonical representation using the existing Pydantic schema for # consistency of validation and default handling. + # For embedding/multi_embedding models, max_tokens will be set via connectivity check later, + # so use 0 as placeholder if not provided + model_type = model["model_type"] + is_embedding_type = model_type in ["embedding", "multi_embedding"] + max_tokens_value = model.get( + "max_tokens", 0) if not is_embedding_type else 0 + model_obj = ModelRequest( model_factory=provider, model_name=model_name, - model_type=model["model_type"], + model_type=model_type, api_key=model_api_key, - max_tokens=model["max_tokens"], + max_tokens=max_tokens_value, display_name=model_display_name, expected_chunk_size=expected_chunk_size, - maximum_chunk_size=maximum_chunk_size + maximum_chunk_size=maximum_chunk_size, + chunk_batch=chunk_batch ) model_dict = model_obj.model_dump() diff --git a/backend/services/redis_service.py b/backend/services/redis_service.py index f852b75cf..efd2c0a7b 100644 --- a/backend/services/redis_service.py +++ b/backend/services/redis_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Dict, Any +from typing import Dict, Any, Optional import redis @@ -22,7 +22,12 @@ def client(self) -> redis.Redis: if self._client is None: if not REDIS_URL: raise ValueError("REDIS_URL environment variable is not set") - self._client = redis.from_url(REDIS_URL, socket_timeout=5, socket_connect_timeout=5) + self._client = redis.from_url( + REDIS_URL, + socket_timeout=5, + socket_connect_timeout=5, + decode_responses=True + ) return self._client @property @@ -35,9 +40,95 @@ def backend_client(self) -> redis.Redis: self._backend_client = redis.from_url(redis_backend_url, socket_timeout=5, socket_connect_timeout=5) return self._backend_client + # ------------------------------------------------------------------ + # Cancellation helpers + # ------------------------------------------------------------------ + + def mark_task_cancelled(self, task_id: str, ttl_hours: int = 24) -> bool: + """ + Mark a Celery task as cancelled in Redis so that long-running + consumers (for example, chunk indexing) can detect the flag + and stop further processing. + """ + if not task_id: + logger.warning("Cannot mark task as cancelled: empty task_id") + return False + try: + cancel_key = f"cancel:{task_id}" + ttl_seconds = ttl_hours * 3600 + self.client.setex(cancel_key, ttl_seconds, "1") + logger.info(f"Marked task {task_id} as cancelled in Redis (key={cancel_key})") + return True + except Exception as exc: + logger.error(f"Failed to mark task {task_id} as cancelled: {exc}") + return False + + def is_task_cancelled(self, task_id: str) -> bool: + """ + Check whether a Celery task has been marked as cancelled. + """ + if not task_id: + return False + try: + cancel_key = f"cancel:{task_id}" + value = self.client.get(cancel_key) + return bool(value) + except Exception as exc: + logger.warning(f"Failed to check cancellation flag for task {task_id}: {exc}") + return False + + # ------------------------------------------------------------------ + # High-level cleanup helpers + # ------------------------------------------------------------------ + + def _cleanup_single_task_related_keys(self, task_id: str) -> int: + """ + Delete all known Redis keys that are related to a specific task. + + This includes: + - Progress info + - Error info + - Cancellation flag + - Chunk cache used by the forward task (dp:{task_id}:chunks) + """ + if not task_id: + return 0 + + deleted_count = 0 + try: + # Keys stored in the main Redis client + progress_key = f"progress:{task_id}" + error_key = f"error:reason:{task_id}" + cancel_key = f"cancel:{task_id}" + + for key in (progress_key, error_key, cancel_key): + try: + deleted = self.client.delete(key) + deleted_count += deleted + if deleted: + logger.debug(f"Deleted task-related key: {key}") + except Exception as exc: + logger.warning(f"Error deleting key {key}: {exc}") + + # Chunk payload is stored in the backend Redis used by Celery + chunk_key = f"dp:{task_id}:chunks" + try: + deleted = self.backend_client.delete(chunk_key) + deleted_count += deleted + if deleted: + logger.debug(f"Deleted chunk cache key: {chunk_key}") + except Exception as exc: + logger.warning(f"Error deleting chunk cache key {chunk_key}: {exc}") + + except Exception as exc: + logger.error(f"Error cleaning up task-related keys for task {task_id}: {exc}") + + return deleted_count + def delete_knowledgebase_records(self, index_name: str) -> Dict[str, Any]: """ - Delete all Redis records related to a specific knowledge base + Delete all Redis records related to a specific knowledge base. + Also marks all related tasks as cancelled to stop ongoing processing. Args: index_name: Name of the knowledge base (index) to clean up @@ -51,14 +142,18 @@ def delete_knowledgebase_records(self, index_name: str) -> Dict[str, Any]: "index_name": index_name, "celery_tasks_deleted": 0, "cache_keys_deleted": 0, + "tasks_cancelled": 0, "total_deleted": 0, "errors": [] } try: # 1. Clean up Celery task results related to this knowledge base + # This also marks tasks as cancelled and cleans up all related keys celery_deleted = self._cleanup_celery_tasks(index_name) result["celery_tasks_deleted"] = celery_deleted + # Count cancelled tasks (approximate, based on processed tasks) + result["tasks_cancelled"] = celery_deleted # Each deleted task was also cancelled # 2. Clean up any cache keys related to this knowledge base cache_deleted = self._cleanup_cache_keys(index_name) @@ -67,7 +162,8 @@ def delete_knowledgebase_records(self, index_name: str) -> Dict[str, Any]: result["total_deleted"] = celery_deleted + cache_deleted logger.info(f"Redis cleanup completed for {index_name}: " - f"Celery tasks: {celery_deleted}, Cache keys: {cache_deleted}") + f"Celery tasks: {celery_deleted}, Cache keys: {cache_deleted}, " + f"Tasks marked as cancelled: {result['tasks_cancelled']}") except Exception as e: error_msg = f"Error during Redis cleanup for {index_name}: {str(e)}" @@ -174,6 +270,7 @@ def _recursively_delete_task_and_parents(self, task_id: str) -> tuple[int, set]: def _cleanup_celery_tasks(self, index_name: str) -> int: """ Clean up Celery task results related to the knowledge base and their parents. + Also marks all related tasks as cancelled before deletion to stop ongoing processing. Args: index_name: Name of the knowledge base @@ -183,11 +280,13 @@ def _cleanup_celery_tasks(self, index_name: str) -> int: """ total_deleted_count = 0 processed_tasks = set() # Track tasks that have been processed to avoid redundant work + task_ids_to_cancel = set() # Collect all task IDs to mark as cancelled try: # Get all Celery task result keys task_keys = self.backend_client.keys('celery-task-meta-*') + # First pass: Collect all task IDs related to this knowledge base for key in task_keys: try: # Get task data @@ -225,9 +324,72 @@ def _cleanup_celery_tasks(self, index_name: str) -> int: key_str = key.decode('utf-8') if isinstance(key, bytes) else key task_id = key_str.replace('celery-task-meta-', '') if task_id not in processed_tasks: + # Collect task ID and its parent chain + # We need to get the parent chain before deleting + task_ids_to_cancel.add(task_id) + # Also get parent chain by reading task data + try: + parent_id = task_info.get('parent_id') + if parent_id: + task_ids_to_cancel.add(parent_id) + except Exception: + pass + + except Exception as e: + logger.warning(f"Error processing task key {key} for cleanup: {str(e)}") + continue + + # Mark all collected task IDs as cancelled BEFORE deleting them + # This ensures ongoing processing tasks will detect cancellation and stop + for task_id in task_ids_to_cancel: + try: + self.mark_task_cancelled(task_id) + logger.info(f"Marked task {task_id} as cancelled for knowledge base {index_name}") + except Exception as e: + logger.warning(f"Failed to mark task {task_id} as cancelled: {str(e)}") + + # Second pass: Delete task records and clean up related keys + for key in task_keys: + try: + task_data = self.backend_client.get(key) + if task_data: + import json + task_info = json.loads(task_data) + result = task_info.get('result', {}) + task_index_name = None + + if isinstance(result, dict): + task_index_name = ( + result.get('index_name') or + task_info.get('index_name') or + result.get('kwargs', {}).get('index_name') + ) + + if task_index_name is None and 'exc_message' in result: + try: + exc_str = str(result['exc_message']) + if '{' in exc_str and '}' in exc_str: + json_part = exc_str[exc_str.find('{'):exc_str.rfind('}')+1] + cleaned_json_part = json_part.replace('\\"', '"') + error_data = json.loads(cleaned_json_part) + task_index_name = error_data.get('index_name') + except (json.JSONDecodeError, TypeError, IndexError): + pass + + if task_index_name == index_name: + key_str = key.decode('utf-8') if isinstance(key, bytes) else key + task_id = key_str.replace('celery-task-meta-', '') + if task_id not in processed_tasks: + # Delete task record and its parent chain deleted, processed_chain = self._recursively_delete_task_and_parents(task_id) total_deleted_count += deleted processed_tasks.update(processed_chain) + # Clean up all related keys (progress, error, chunks) for each task + for tid in processed_chain: + try: + self._cleanup_single_task_related_keys(tid) + except Exception as e: + logger.warning(f"Failed to clean up keys for task {tid}: {str(e)}") except Exception as e: logger.warning(f"Error processing task key {key} for cleanup: {str(e)}") @@ -350,10 +512,23 @@ def _cleanup_document_celery_tasks(self, index_name: str, path_or_url: str) -> i if task_index_name == index_name and task_source == path_or_url: # Recursively delete this task and its parents if task_id not in processed_tasks: + # Mark this task as cancelled so any in-flight + # processing can observe the flag and stop. + try: + self.mark_task_cancelled(task_id) + except Exception as cancel_exc: + logger.warning( + f"Failed to mark task {task_id} as cancelled during document cleanup: {cancel_exc}" + ) + deleted, processed_chain = self._recursively_delete_task_and_parents(task_id) total_deleted_count += deleted processed_tasks.update(processed_chain) + # Clean up all known keys for each task in the chain + for processed_task_id in processed_chain: + self._cleanup_single_task_related_keys(processed_task_id) + except Exception as e: logger.warning(f"Error processing task key {key} for document cleanup: {str(e)}") continue @@ -472,6 +647,128 @@ def ping(self) -> bool: logger.error(f"Redis ping failed: {str(e)}") return False + def save_error_info(self, task_id: str, error_reason: str, ttl_days: int = 30) -> bool: + """ + Save error information to Redis for a specific task + + Args: + task_id: Celery task ID + error_reason: Short error reason summary + ttl_days: Time to live in days (default 30 days) + + Returns: + True if saved successfully, False otherwise + """ + try: + if not task_id: + logger.error("Cannot save error info: task_id is empty") + return False + if not error_reason: + logger.error(f"Cannot save error info for task {task_id}: error_reason is empty") + return False + + ttl_seconds = ttl_days * 24 * 60 * 60 + reason_key = f"error:reason:{task_id}" + + # Save error reason + result = self.client.setex(reason_key, ttl_seconds, error_reason) + + if result: + logger.info(f"Successfully saved error info to Redis for task {task_id}, key: {reason_key}") + # Verify the save by reading it back + verify = self.client.get(reason_key) + if verify: + logger.debug(f"Verified error info saved for task {task_id}: {verify[:100]}...") + else: + logger.warning(f"Failed to verify error info save for task {task_id}") + return True + else: + logger.error(f"Redis setex returned False for task {task_id}") + return False + except Exception as e: + logger.error( + f"Failed to save error info for task {task_id}: {str(e)}", exc_info=True) + return False + + def save_progress_info(self, task_id: str, processed_chunks: int, total_chunks: int, ttl_hours: int = 24) -> bool: + """ + Save progress information to Redis for a specific task + + Args: + task_id: Celery task ID + processed_chunks: Number of chunks processed so far + total_chunks: Total number of chunks to process + ttl_hours: Time to live in hours (default 24 hours) + + Returns: + True if saved successfully, False otherwise + """ + try: + if not task_id: + logger.error("Cannot save progress info: task_id is empty") + return False + + progress_key = f"progress:{task_id}" + progress_data = { + 'processed_chunks': processed_chunks, + 'total_chunks': total_chunks + } + + ttl_seconds = ttl_hours * 3600 + progress_json = json.dumps(progress_data) + self.client.setex( + progress_key, + ttl_seconds, + progress_json + ) + # Use info level for better visibility during debugging + logger.info(f"[REDIS PROGRESS] Saved progress for task {task_id}: {processed_chunks}/{total_chunks} (key: {progress_key}, TTL: {ttl_hours}h)") + return True + except Exception as e: + logger.error(f"Failed to save progress info for task {task_id}: {str(e)}") + return False + + def get_progress_info(self, task_id: str) -> Optional[Dict[str, int]]: + """ + Get progress information for a specific task + + Args: + task_id: Celery task ID + + Returns: + Dict with 'processed_chunks' and 'total_chunks' or None if not found + """ + try: + progress_key = f"progress:{task_id}" + progress_data = self.client.get(progress_key) + if progress_data: + if isinstance(progress_data, bytes): + progress_data = progress_data.decode('utf-8') + return json.loads(progress_data) + return None + except Exception as e: + logger.warning(f"Failed to get progress info for task {task_id}: {str(e)}") + return None + + def get_error_info(self, task_id: str) -> Optional[str]: + """ + Get error reason for a specific task + + Args: + task_id: Celery task ID + + Returns: + Error reason string or None if not found + """ + try: + reason_key = f"error:reason:{task_id}" + reason = self.client.get(reason_key) + # With decode_responses=True, reason is already a string + return reason if reason else None + except Exception as e: + logger.error( + f"Failed to get error info for task {task_id}: {str(e)}") + return None # Global Redis service instance _redis_service = None diff --git a/backend/services/tenant_config_service.py b/backend/services/tenant_config_service.py index e39018d98..30524677c 100644 --- a/backend/services/tenant_config_service.py +++ b/backend/services/tenant_config_service.py @@ -66,3 +66,19 @@ def delete_selected_knowledge_by_index_name(tenant_id: str, user_id: str, index_ return False return True + + +def build_knowledge_name_mapping(tenant_id: str, user_id: str): + """ + Build mapping from user-facing knowledge_name to internal index_name for the selected knowledge bases. + Falls back to using index_name as key when knowledge_name is missing for backward compatibility. + """ + knowledge_info_list = get_selected_knowledge_list( + tenant_id=tenant_id, user_id=user_id) + mapping = {} + for info in knowledge_info_list: + key = info.get("knowledge_name") or info.get("index_name") + value = info.get("index_name") + if key and value: + mapping[key] = value + return mapping diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index ab971da1a..66298c8c5 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -25,7 +25,8 @@ from database.user_tenant_db import get_all_tenant_ids from services.file_management_service import get_llm_model from services.vectordatabase_service import get_embedding_model, get_vector_db_core -from services.tenant_config_service import get_selected_knowledge_list +from services.tenant_config_service import get_selected_knowledge_list, build_knowledge_name_mapping +from database.knowledge_db import get_index_name_by_knowledge_name from database.client import minio_client from services.image_service import get_vlm_model @@ -419,7 +420,7 @@ async def initialize_tools_on_startup(): failed_tenants.append(f"{tenant_id} (error: {str(e)})") # Log final results - logger.info(f"Tool initialization completed!") + logger.info("Tool initialization completed!") logger.info(f"Total tools available across all tenants: {total_tools}") logger.info(f"Successfully processed: {successful_tenants}/{len(tenant_ids)} tenants") @@ -607,11 +608,32 @@ def _validate_local_tool( tenant_id=tenant_id, user_id=user_id) index_names = [knowledge_info.get("index_name") for knowledge_info in knowledge_info_list] + name_resolver = build_knowledge_name_mapping( + tenant_id=tenant_id, user_id=user_id) + + # Fallback: if user provided index_names in inputs, try to resolve them even when no selection stored + if (not index_names) and inputs and inputs.get("index_names"): + raw_names = inputs.get("index_names") + if isinstance(raw_names, str): + raw_names = [raw_names] + resolved_indices = [] + for raw in raw_names: + try: + resolved = get_index_name_by_knowledge_name( + raw, tenant_id=tenant_id) + name_resolver[raw] = resolved + resolved_indices.append(resolved) + except Exception: + # If not found as knowledge_name, assume it's already an index_name + resolved_indices.append(raw) + index_names = resolved_indices + embedding_model = get_embedding_model(tenant_id=tenant_id) vdb_core = get_vector_db_core() params = { **instantiation_params, 'index_names': index_names, + 'name_resolver': name_resolver, 'vdb_core': vdb_core, 'embedding_model': embedding_model, } diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index e72b3f9f3..00bb7b5ec 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -10,6 +10,7 @@ 4. Health check interface """ import asyncio +import json import logging import os import time @@ -30,12 +31,43 @@ create_knowledge_record, delete_knowledge_record, get_knowledge_record, - update_knowledge_record, get_knowledge_info_by_tenant_id, update_model_name_by_index_name, + update_knowledge_record, + get_knowledge_info_by_tenant_id, + update_model_name_by_index_name, ) from services.redis_service import get_redis_service from utils.config_utils import tenant_config_manager, get_model_name_from_config from utils.file_management_utils import get_all_files_status, get_file_size + +def _update_progress(task_id: str, processed: int, total: int): + """Helper function to update progress in Redis""" + try: + redis_service = get_redis_service() + + # If this task has been marked as cancelled, stop updating progress + # and raise an exception so the caller can abort long-running work. + if redis_service.is_task_cancelled(task_id): + logger.debug( + f"[PROGRESS CALLBACK] Task {task_id} is marked as cancelled; " + f"stopping further indexing work at {processed}/{total}." + ) + raise RuntimeError( + "Indexing cancelled because the task was marked as cancelled.") + + success = redis_service.save_progress_info(task_id, processed, total) + if success: + percentage = processed * 100 // total if total > 0 else 0 + logger.debug( + f"[PROGRESS CALLBACK] Updated progress for task {task_id}: {processed}/{total} ({percentage}%)") + else: + logger.warning( + f"[PROGRESS CALLBACK] Failed to save progress for task {task_id}: {processed}/{total}") + except Exception as e: + logger.warning( + f"[PROGRESS CALLBACK] Exception updating progress for task {task_id}: {str(e)}") + + ALLOWED_CHUNK_FIELDS = { "id", "title", @@ -78,6 +110,23 @@ def get_vector_db_core( raise ValueError(f"Unsupported vector database type: {db_type}") +def _rethrow_or_plain(exc: Exception) -> None: + """ + If the exception message is a JSON dict with error_code, re-raise that JSON as-is. + Otherwise, re-raise the original string (no additional nesting/context). + """ + msg = str(exc) + try: + parsed = json.loads(msg) + except Exception: + raise Exception(msg) + + if isinstance(parsed, dict) and parsed.get("error_code"): + raise Exception(json.dumps(parsed, ensure_ascii=False)) + + raise Exception(msg) + + def check_knowledge_base_exist_impl(index_name: str, vdb_core: VectorDatabaseCore, user_id: str, tenant_id: str) -> dict: """ Check knowledge base existence and handle orphan cases @@ -226,14 +275,10 @@ async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCo logger.debug( f"Step 2/4: No files found in index '{index_name}', skipping MinIO deletion.") - # 3. Delete Elasticsearch index and its DB record - logger.debug( - f"Step 3/4: Deleting Elasticsearch index '{index_name}' and its database record.") - delete_index_result = await ElasticSearchService.delete_index(index_name, vdb_core, user_id) - - # 4. Clean up Redis records related to this knowledge base + # 3. Mark all related tasks as cancelled and clean up Redis records BEFORE deleting ES index + # This ensures ongoing indexing tasks will detect cancellation and stop immediately logger.debug( - f"Step 4/4: Cleaning up Redis records for index '{index_name}'.") + f"Step 3/5: Marking all tasks as cancelled and cleaning up Redis records for index '{index_name}'.") redis_cleanup_result = {} try: from services.redis_service import get_redis_service @@ -241,12 +286,18 @@ async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCo redis_cleanup_result = redis_service.delete_knowledgebase_records( index_name) logger.debug(f"Redis cleanup for index '{index_name}' completed. " - f"Deleted {redis_cleanup_result['total_deleted']} records.") + f"Deleted {redis_cleanup_result['total_deleted']} records, " + f"marked {redis_cleanup_result.get('tasks_cancelled', 0)} tasks as cancelled.") except Exception as redis_error: logger.error( f"Redis cleanup failed for index '{index_name}': {str(redis_error)}") redis_cleanup_result = {"error": str(redis_error)} + # 4. Delete Elasticsearch index and its DB record + logger.debug( + f"Step 4/5: Deleting Elasticsearch index '{index_name}' and its database record.") + delete_index_result = await ElasticSearchService.delete_index(index_name, vdb_core, user_id) + # Construct final result result = { "status": "success", @@ -305,6 +356,58 @@ def create_index( except Exception as e: raise Exception(f"Error creating index: {str(e)}") + @staticmethod + def create_knowledge_base( + knowledge_name: str, + embedding_dim: Optional[int], + vdb_core: VectorDatabaseCore, + user_id: Optional[str], + tenant_id: Optional[str], + ): + """ + Create a new knowledge base with a user-facing name and an internal Elasticsearch index name. + + For new data: + - Store the user-facing name in knowledge_name column. + - Generate index_name as ``knowledge_id + '-' + uuid`` (digits and lowercase letters only). + - Use generated index_name as the Elasticsearch index name. + + For backward compatibility, legacy callers can still use create_index() directly + with an explicit index_name. + """ + try: + embedding_model = get_embedding_model(tenant_id) + + # Create knowledge record first to obtain knowledge_id and generated index_name + knowledge_data = { + "knowledge_name": knowledge_name, + "knowledge_describe": "", + "user_id": user_id, + "tenant_id": tenant_id, + "embedding_model_name": embedding_model.model if embedding_model else None, + } + record_info = create_knowledge_record(knowledge_data) + index_name = record_info["index_name"] + + # Create Elasticsearch index with generated internal index_name + success = vdb_core.create_index( + index_name, + embedding_dim=embedding_dim + or (embedding_model.embedding_dim if embedding_model else 1024), + ) + if not success: + raise Exception(f"Failed to create index {index_name}") + + return { + "status": "success", + "message": f"Index {index_name} created successfully", + "id": index_name, + "knowledge_id": record_info["knowledge_id"], + "name": record_info.get("knowledge_name", knowledge_name), + } + except Exception as e: + raise Exception(f"Error creating knowledge base: {str(e)}") + @staticmethod async def delete_index( index_name: str = Path(..., @@ -382,6 +485,13 @@ def list_indices( db_record = get_knowledge_info_by_tenant_id(tenant_id=tenant_id) + # Build mapping from index_name to user-facing knowledge_name (fallback to index_name) + index_to_display_name = { + record["index_name"]: record.get( + "knowledge_name") or record["index_name"] + for record in db_record + } + filtered_indices_list = [] model_name_is_none_list = [] for record in db_record: @@ -399,7 +509,7 @@ def list_indices( response = { "indices": indices, - "count": len(indices) + "count": len(indices), } if include_stats: @@ -409,8 +519,11 @@ def list_indices( for index_name in filtered_indices_list: index_stats = indice_stats.get(index_name, {}) stats_info.append({ + # Internal index name (used as ID) "name": index_name, - "stats": index_stats + # User-facing knowledge base name from PostgreSQL (fallback to index_name) + "display_name": index_to_display_name.get(index_name, index_name), + "stats": index_stats, }) if index_name in model_name_is_none_list: update_model_name_by_index_name(index_name, @@ -427,7 +540,8 @@ def index_documents( index_name: str = Path(..., description="Name of the index"), data: List[Dict[str, Any] ] = Body(..., description="Document List to process"), - vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), + task_id: Optional[str] = None, ): """ Index documents and create vector embeddings, create index if it doesn't exist @@ -517,13 +631,64 @@ def index_documents( } # Index documents (use default batch_size and content_field) + # Get chunk_batch from model config + # First, get tenant_id from knowledge record + knowledge_record = get_knowledge_record({'index_name': index_name}) + tenant_id = knowledge_record.get( + 'tenant_id') if knowledge_record else None + + if tenant_id: + model_config = tenant_config_manager.get_model_config( + key="EMBEDDING_ID", tenant_id=tenant_id) + embedding_batch_size = model_config.get("chunk_batch", 10) + if embedding_batch_size is None: + embedding_batch_size = 10 + else: + # Fallback to default if tenant_id not found + embedding_batch_size = 10 + + # Initialize progress tracking if task_id is provided + if task_id: + try: + redis_service = get_redis_service() + success = redis_service.save_progress_info( + task_id, 0, total_submitted) + if success: + logger.info( + f"[REDIS PROGRESS] Initialized progress tracking for task {task_id}: 0/{total_submitted}") + else: + logger.warning( + f"Failed to initialize progress tracking for task {task_id}") + except Exception as e: + logger.warning( + f"Failed to initialize progress tracking for task {task_id}: {str(e)}") + try: total_indexed = vdb_core.vectorize_documents( index_name=index_name, embedding_model=embedding_model, documents=documents, + embedding_batch_size=embedding_batch_size, + progress_callback=lambda processed, total: _update_progress( + task_id, processed, total) if task_id else None ) + # Update final progress + if task_id: + try: + redis_service = get_redis_service() + success = redis_service.save_progress_info( + task_id, total_indexed, total_submitted) + if success: + logger.info( + f"[REDIS PROGRESS] Updated final progress for task {task_id}: {total_indexed}/{total_submitted}") + else: + logger.warning( + f"[REDIS PROGRESS] Failed to update final progress for task {task_id}") + except Exception as e: + logger.warning( + f"[REDIS PROGRESS] Exception updating final progress for task {task_id}: {str(e)}") + return { "success": True, "message": f"Successfully indexed {total_indexed} documents", @@ -531,14 +696,12 @@ def index_documents( "total_submitted": total_submitted } except Exception as e: - error_msg = str(e) - logger.error(f"Error during indexing: {error_msg}") - raise Exception(f"Error during indexing: {error_msg}") + logger.error(f"Error during indexing: {str(e)}") + _rethrow_or_plain(e) except Exception as e: - error_msg = str(e) - logger.error(f"Error indexing documents: {error_msg}") - raise Exception(f"Error indexing documents: {error_msg}") + logger.error(f"Error indexing documents: {str(e)}") + _rethrow_or_plain(e) @staticmethod async def list_files( @@ -559,15 +722,12 @@ async def list_files( Dictionary containing file list """ try: - files = [] + files_map: Dict[str, Dict[str, Any]] = {} # Get existing files from ES existing_files = vdb_core.get_documents_detail(index_name) # Get unique celery files list and the status of each file celery_task_files = await get_all_files_status(index_name) - # Create a set of path_or_urls from existing files for quick lookup - existing_paths = {file_info.get('path_or_url') - for file_info in existing_files} # For files already stored in ES, add to files list for file_info in existing_files: @@ -579,35 +739,51 @@ async def list_files( except (ValueError, TypeError): utc_create_timestamp = time.time() + # Always re-query chunk count to ensure accuracy (aggregation may be stale) + path_or_url = file_info.get('path_or_url') + chunk_count = file_info.get('chunk_count', 0) + try: + count_result = vdb_core.client.count( + index=index_name, + body={"query": {"term": {"path_or_url": path_or_url}}} + ) + chunk_count = count_result.get("count", chunk_count) + except Exception as count_err: + logger.warning( + f"Failed to get chunk count for {path_or_url}: {count_err}, using aggregation value {chunk_count}") + file_data = { - 'path_or_url': file_info.get('path_or_url'), + 'path_or_url': path_or_url, 'file': file_info.get('filename', ''), 'file_size': file_info.get('file_size', 0), 'create_time': int(utc_create_timestamp * 1000), 'status': "COMPLETED", 'latest_task_id': '', - 'chunk_count': file_info.get('chunk_count', 0) + 'chunk_count': chunk_count, + 'error_reason': None, + 'has_error_info': False } - files.append(file_data) + files_map[path_or_url] = file_data # For files not yet stored in ES (files currently being processed) for path_or_url, status_info in celery_task_files.items(): - # Skip files that are already in existing_files to avoid duplicates - if path_or_url not in existing_paths: - # Ensure status_info is a dictionary - status_dict = status_info if isinstance( - status_info, dict) else {} - - # Get source_type and original_filename, with defaults - source_type = status_dict.get('source_type') if status_dict.get( - 'source_type') else 'minio' - original_filename = status_dict.get('original_filename') - - # Determine the filename - filename = original_filename or ( - os.path.basename(path_or_url) if path_or_url else '') - - # Safely get file size; default to 0 on any error + status_dict = status_info if isinstance( + status_info, dict) else {} + + # Get source_type and original_filename, with defaults + source_type = status_dict.get('source_type') if status_dict.get( + 'source_type') else 'minio' + original_filename = status_dict.get('original_filename') + + # Determine the filename + filename = original_filename or ( + os.path.basename(path_or_url) if path_or_url else '') + + # Safely get file size; default to 0 on any error + file_size = 0 + if path_or_url in files_map: + file_size = files_map[path_or_url].get('file_size', 0) + else: try: file_size = get_file_size( source_type or 'minio', path_or_url) @@ -616,15 +792,65 @@ async def list_files( f"Failed to get file size for '{path_or_url}': {size_err}") file_size = 0 + # Get progress from status_dict first, then try Redis for real-time updates + processed_chunks = status_dict.get('processed_chunks') + total_chunks = status_dict.get('total_chunks') + task_id = status_dict.get('latest_task_id', '') + + # Always try to get latest progress from Redis if task_id exists + # Redis has the most up-to-date progress during vectorization + if task_id: + try: + redis_service = get_redis_service() + progress_info = redis_service.get_progress_info( + task_id) + if progress_info: + redis_processed = progress_info.get( + 'processed_chunks') + redis_total = progress_info.get('total_chunks') + if redis_processed is not None: + processed_chunks = redis_processed + if redis_total is not None: + total_chunks = redis_total + logger.debug( + f"Retrieved progress from Redis for task {task_id}: {processed_chunks}/{total_chunks}") + except Exception as e: + logger.debug( + f"Failed to get progress from Redis for task {task_id}: {str(e)}") + + if path_or_url in files_map: + file_data = files_map[path_or_url] + else: file_data = { 'path_or_url': path_or_url, 'file': filename, 'file_size': file_size, 'create_time': int(time.time() * 1000), - 'status': status_dict.get('state', 'UNKNOWN'), - 'latest_task_id': status_dict.get('latest_task_id', '') + 'chunk_count': 0, + 'error_reason': None, + 'has_error_info': False } - files.append(file_data) + files_map[path_or_url] = file_data + + file_data['status'] = status_dict.get('state', file_data.get( + 'status', 'UNKNOWN')) + file_data['latest_task_id'] = task_id + file_data['processed_chunk_num'] = processed_chunks + file_data['total_chunk_num'] = total_chunks + + # Get error reason for failed documents + if task_id and status_dict.get('state') in ['PROCESS_FAILED', 'FORWARD_FAILED']: + try: + redis_service = get_redis_service() + error_reason = redis_service.get_error_info(task_id) + if error_reason: + file_data['error_reason'] = error_reason + file_data['has_error_info'] = True + except Exception as e: + logger.debug( + f"Failed to get error info for task {task_id}: {str(e)}") + + files = list(files_map.values()) # Unified chunks processing for all files if include_chunks: @@ -673,15 +899,46 @@ async def list_files( }) file_data['chunks'] = chunks - file_data['chunk_count'] = len(chunks) + # Get accurate chunk count using count query instead of len(chunks) + # because msearch may have size limits + try: + count_result = vdb_core.client.count( + index=index_name, + body={ + "query": {"term": {"path_or_url": file_path}}} + ) + file_data['chunk_count'] = count_result.get( + "count", len(chunks)) + except Exception as count_err: + logger.warning( + f"Failed to get chunk count for {file_path}: {count_err}, using len(chunks)") + file_data['chunk_count'] = len(chunks) except Exception as e: logger.error( f"Error during msearch for chunks: {str(e)}") else: + # When include_chunks=False, ensure chunk_count is accurate for completed files for file_data in files: file_data['chunks'] = [] - file_data['chunk_count'] = file_data.get('chunk_count', 0) + if file_data.get('status') == "COMPLETED": + # Always re-query chunk count for completed files to ensure accuracy + try: + count_result = vdb_core.client.count( + index=index_name, + body={ + "query": {"term": {"path_or_url": file_data.get('path_or_url')}}} + ) + file_data['chunk_count'] = count_result.get( + "count", 0) + except Exception as count_err: + logger.warning( + f"Failed to get chunk count for {file_data.get('path_or_url')}: {count_err}") + file_data['chunk_count'] = file_data.get( + 'chunk_count', 0) + else: + file_data['chunk_count'] = file_data.get( + 'chunk_count', 0) return {"files": files} @@ -760,6 +1017,9 @@ async def summary_index_name(self, StreamingResponse containing the generated summary """ try: + if not tenant_id: + raise Exception("Tenant ID is required for summary generation.") + from utils.document_vector_utils import ( process_documents_for_clustering, kmeans_cluster_documents, @@ -767,9 +1027,6 @@ async def summary_index_name(self, merge_cluster_summaries ) - if not tenant_id: - raise Exception("Tenant ID is required for summary generation.") - # Use new Map-Reduce approach sample_count = min(batch_size // 5, 200) # Sample reasonable number of documents @@ -820,7 +1077,7 @@ async def generate_summary(): for char in final_summary: yield f"data: {{\"status\": \"success\", \"message\": \"{char}\"}}\n\n" await asyncio.sleep(0.01) - yield f"data: {{\"status\": \"completed\"}}\n\n" + yield "data: {\"status\": \"completed\"}\n\n" except Exception as e: yield f"data: {{\"status\": \"error\", \"message\": \"{e}\"}}\n\n" diff --git a/backend/utils/config_utils.py b/backend/utils/config_utils.py index bef38352b..67b78283c 100644 --- a/backend/utils/config_utils.py +++ b/backend/utils/config_utils.py @@ -220,8 +220,9 @@ def update_single_config(self, tenant_id: str | None = None, key: str | None = N update_config_by_tenant_config_id_and_data( existing_config["tenant_config_id"], update_data) - # Clear cache for this tenant after updating config - # self.clear_cache(tenant_id) + # Clear cache for this tenant after updating config so that + # subsequent reads immediately see the latest configuration + self.clear_cache(tenant_id) return def clear_cache(self, tenant_id: str | None = None): diff --git a/backend/utils/document_vector_utils.py b/backend/utils/document_vector_utils.py index 7d8e5b112..1b35c207f 100644 --- a/backend/utils/document_vector_utils.py +++ b/backend/utils/document_vector_utils.py @@ -14,12 +14,16 @@ import numpy as np from jinja2 import Template, StrictUndefined -from nexent.vector_database.base import VectorDatabaseCore from sklearn.cluster import KMeans from sklearn.metrics import silhouette_score from sklearn.metrics.pairwise import cosine_similarity from consts.const import LANGUAGE +from database.model_management_db import get_model_by_model_id +from nexent.core.utils.observer import MessageObserver +from nexent.core.models import OpenAIModel +from nexent.vector_database.base import VectorDatabaseCore +from utils.llm_utils import call_llm_for_system_prompt from utils.prompt_template_utils import ( get_document_summary_prompt_template, get_cluster_summary_reduce_prompt_template, @@ -568,37 +572,22 @@ def summarize_document(document_content: str, filename: str, language: str = LAN # Call LLM if model_id and tenant_id are provided if model_id and tenant_id: - from smolagents import OpenAIServerModel - from database.model_management_db import get_model_by_model_id - from utils.config_utils import get_model_name_from_config - from consts.const import MESSAGE_ROLE - + # Get model configuration llm_model_config = get_model_by_model_id(model_id=model_id, tenant_id=tenant_id) if not llm_model_config: logger.warning(f"No model configuration found for model_id: {model_id}, tenant_id: {tenant_id}") return f"[Document Summary: {filename}] (max {max_words} words) - Content: {document_content[:200]}..." - - # Create LLM instance - llm = OpenAIServerModel( - model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "", - api_base=llm_model_config.get("base_url", ""), - api_key=llm_model_config.get("api_key", ""), - temperature=0.3, - top_p=0.95 + + document_summary = call_llm_for_system_prompt( + model_id=model_id, + user_prompt=user_prompt, + system_prompt=system_prompt, + callback=None, + tenant_id=tenant_id ) - - # Build messages - messages = [ - {"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt}, - {"role": MESSAGE_ROLE["USER"], "content": user_prompt} - ] - - # Call LLM, allow more tokens for generation - response = llm(messages, max_tokens=max_words * 2) - if not response or not response.content: - return "" - return response.content.strip() + + return (document_summary or "").strip() else: # Fallback to placeholder if no model configuration logger.warning("No model_id or tenant_id provided, using placeholder summary") @@ -642,10 +631,6 @@ def summarize_cluster(document_summaries: List[str], language: str = LANGUAGE["Z # Call LLM if model_id and tenant_id are provided if model_id and tenant_id: - from smolagents import OpenAIServerModel - from database.model_management_db import get_model_by_model_id - from utils.config_utils import get_model_name_from_config - from consts.const import MESSAGE_ROLE # Get model configuration llm_model_config = get_model_by_model_id(model_id=model_id, tenant_id=tenant_id) @@ -654,25 +639,15 @@ def summarize_cluster(document_summaries: List[str], language: str = LANGUAGE["Z return f"[Cluster Summary] (max {max_words} words) - Based on {len(document_summaries)} documents" # Create LLM instance - llm = OpenAIServerModel( - model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "", - api_base=llm_model_config.get("base_url", ""), - api_key=llm_model_config.get("api_key", ""), - temperature=0.3, - top_p=0.95 + cluster_summary = call_llm_for_system_prompt( + model_id=model_id, + user_prompt=user_prompt, + system_prompt=system_prompt, + callback=None, + tenant_id=tenant_id ) - - # Build messages - messages = [ - {"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt}, - {"role": MESSAGE_ROLE["USER"], "content": user_prompt} - ] - - # Call LLM - response = llm(messages, max_tokens=max_words * 2) # Allow more tokens for generation - if not response or not response.content: - return "" - return response.content.strip() + + return (cluster_summary or "").strip() else: # Fallback to placeholder if no model configuration logger.warning("No model_id or tenant_id provided, using placeholder summary") diff --git a/backend/utils/file_management_utils.py b/backend/utils/file_management_utils.py index 2431af737..2a1aa3801 100644 --- a/backend/utils/file_management_utils.py +++ b/backend/utils/file_management_utils.py @@ -151,7 +151,8 @@ async def get_all_files_status(index_name: str): logger.warning(f"No tasks found for index '{index_name}'") return {} - # Dictionary to store file statuses: {path_or_url: {process_state, forward_state, timestamps}} + # Dictionary to store file statuses: + # {path_or_url: {process_state, forward_state, timestamps, progress fields}} file_states = {} for task_info in tasks_list: # No need to check index_name since get_index_tasks already filters by it @@ -172,7 +173,10 @@ async def get_all_files_status(index_name: str): 'latest_forward_created_at': 0, 'latest_task_id': '', 'original_filename': '', - 'source_type': '' + 'source_type': '', + # Optional progress fields provided by data-process service + 'processed_chunks': None, + 'total_chunks': None, } file_state = file_states[task_path_or_url] # Process task @@ -182,6 +186,11 @@ async def get_all_files_status(index_name: str): file_state['latest_task_id'] = task_id file_state['original_filename'] = original_filename file_state['source_type'] = source_type + # Update optional progress metrics if present + file_state['processed_chunks'] = task_info.get( + 'processed_chunks', file_state.get('processed_chunks')) + file_state['total_chunks'] = task_info.get( + 'total_chunks', file_state.get('total_chunks')) # Forward task elif task_name == 'forward' and task_created_at > file_state['latest_forward_created_at']: file_state['latest_forward_created_at'] = task_created_at @@ -189,6 +198,11 @@ async def get_all_files_status(index_name: str): file_state['latest_task_id'] = task_id file_state['original_filename'] = original_filename file_state['source_type'] = source_type + # Forward tasks may also carry progress metrics + file_state['processed_chunks'] = task_info.get( + 'processed_chunks', file_state.get('processed_chunks')) + file_state['total_chunks'] = task_info.get( + 'total_chunks', file_state.get('total_chunks')) result = {} for path_or_url, file_state in file_states.items(): # Call remote state conversion API so this service no longer depends on Celery @@ -196,11 +210,44 @@ async def get_all_files_status(index_name: str): process_celery_state=file_state['process_state'] or '', forward_celery_state=file_state['forward_state'] or '' ) + # Try to get progress from Redis - always check Redis for real-time progress + # especially when task is in progress (FORWARDING or PROCESSING) + processed_chunks = file_state.get('processed_chunks') + total_chunks = file_state.get('total_chunks') + task_id = file_state['latest_task_id'] or '' + + # Always try to get latest progress from Redis if task_id exists + # Redis has the most up-to-date progress during vectorization + if task_id: + try: + from services.redis_service import get_redis_service + redis_service = get_redis_service() + progress_info = redis_service.get_progress_info(task_id) + if progress_info: + # Use Redis progress as primary source (it's updated in real-time) + redis_processed = progress_info.get('processed_chunks') + redis_total = progress_info.get('total_chunks') + if redis_processed is not None: + processed_chunks = redis_processed + if redis_total is not None: + total_chunks = redis_total + logger.debug( + f"Retrieved progress from Redis for task {task_id}: {processed_chunks}/{total_chunks}") + else: + logger.debug( + f"No progress info in Redis for task {task_id}, using task state values: {processed_chunks}/{total_chunks}") + except Exception as e: + logger.debug( + f"Failed to get progress from Redis for task {task_id}: {str(e)}") + result[path_or_url] = { 'state': custom_state, - 'latest_task_id': file_state['latest_task_id'] or '', + 'latest_task_id': task_id, 'original_filename': file_state['original_filename'] or '', - 'source_type': file_state['source_type'] or '' + 'source_type': file_state['source_type'] or '', + # Expose optional progress metrics for downstream consumers + 'processed_chunks': processed_chunks, + 'total_chunks': total_chunks, } return result except Exception as e: diff --git a/backend/utils/llm_utils.py b/backend/utils/llm_utils.py index 7335cfe2b..2e1590498 100644 --- a/backend/utils/llm_utils.py +++ b/backend/utils/llm_utils.py @@ -1,10 +1,10 @@ import logging from typing import Callable, List, Optional -from smolagents import OpenAIServerModel - from consts.const import MESSAGE_ROLE, THINK_END_PATTERN, THINK_START_PATTERN from database.model_management_db import get_model_by_model_id +from nexent.core.utils.observer import MessageObserver +from nexent.core.models import OpenAIModel from utils.config_utils import get_model_name_from_config logger = logging.getLogger("llm_utils") @@ -44,7 +44,7 @@ def call_llm_for_system_prompt( """ llm_model_config = get_model_by_model_id(model_id=model_id, tenant_id=tenant_id) - llm = OpenAIServerModel( + llm = OpenAIModel( model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "", api_base=llm_model_config.get("base_url", ""), api_key=llm_model_config.get("api_key", ""), diff --git a/doc/docs/.vitepress/config.mts b/doc/docs/.vitepress/config.mts index 014eced15..059937f09 100644 --- a/doc/docs/.vitepress/config.mts +++ b/doc/docs/.vitepress/config.mts @@ -1,4 +1,4 @@ -// https://vitepress.dev/reference/site-config +// https://vitepress.dev/reference/site-config import { defineConfig } from "vitepress"; export default defineConfig({ @@ -33,7 +33,7 @@ export default defineConfig({ ], sidebar: [ { - text: "Getting Started", + text: "Overview", items: [ { text: "Overview", link: "/en/getting-started/overview" }, { text: "Key Features", link: "/en/getting-started/features" }, @@ -41,15 +41,33 @@ export default defineConfig({ text: "Software Architecture", link: "/en/getting-started/software-architecture", }, + ], + }, + { + text: "Quick Start", + items: [ { text: "Installation & Deployment", - link: "/en/getting-started/installation", + link: "/en/quick-start/installation", + }, + { + text: "Upgrade Guide", + link: "/en/quick-start/upgrade-guide", + }, + { text: "FAQ", link: "/en/quick-start/faq" }, + ], + }, + { + text: "Developer Guide", + items: [ + { + text: "Overview", + link: "/en/developer-guide/overview", }, { - text: "Development Guide", - link: "/en/getting-started/development-guide", + text: "Environment Preparation", + link: "/en/developer-guide/environment-setup", }, - { text: "FAQ", link: "/en/getting-started/faq" }, ], }, { @@ -83,14 +101,12 @@ export default defineConfig({ { text: "Local Tools", items: [ - { - text: "Local Tools Overview", - link: "/en/user-guide/local-tools/", - }, - { - text: "Terminal Tool", - link: "/en/user-guide/local-tools/terminal-tool", - }, + { text: "Overview", link: "/en/user-guide/local-tools/" }, + { text: "File Tools", link: "/en/user-guide/local-tools/file-tools" }, + { text: "Email Tools", link: "/en/user-guide/local-tools/email-tools" }, + { text: "Search Tools", link: "/en/user-guide/local-tools/search-tools" }, + { text: "Multimodal Tools", link: "/en/user-guide/local-tools/multimodal-tools" }, + { text: "Terminal Tool", link: "/en/user-guide/local-tools/terminal-tool" }, ], }, ], @@ -98,17 +114,13 @@ export default defineConfig({ { text: "SDK Documentation", items: [ - { text: "SDK Overview", link: "/en/sdk/overview" }, + { text: "Overview", link: "/en/sdk/overview" }, { text: "Basic Usage", link: "/en/sdk/basic-usage" }, { text: "Features Explained", link: "/en/sdk/features" }, { text: "Core Modules", items: [ { text: "Agents", link: "/en/sdk/core/agents" }, - { - text: "Run agent with agent_run", - link: "/en/sdk/core/agent-run", - }, { text: "Tools", link: "/en/sdk/core/tools" }, { text: "Models", link: "/en/sdk/core/models" }, ], @@ -121,17 +133,21 @@ export default defineConfig({ { text: "Frontend Development", items: [ - { text: "Frontend Overview", link: "/en/frontend/overview" }, + { text: "Overview", link: "/en/frontend/overview" }, ], }, { text: "Backend Development", items: [ - { text: "Backend Overview", link: "/en/backend/overview" }, + { text: "Overview", link: "/en/backend/overview" }, { text: "API Reference", link: "/en/backend/api-reference" }, { text: "Tools Integration", items: [ + { + text: "Nexent Tools", + link: "/en/backend/tools/nexent-native", + }, { text: "LangChain Tools", link: "/en/backend/tools/langchain", @@ -143,6 +159,10 @@ export default defineConfig({ text: "Prompt Development", link: "/en/backend/prompt-development", }, + { + text: "Version Management", + link: "/en/backend/version-management", + }, ], }, { @@ -156,33 +176,20 @@ export default defineConfig({ items: [ { text: "Docker Build", link: "/en/deployment/docker-build" }, { text: "Dev Container", link: "/en/deployment/devcontainer" }, - { text: "Upgrade Guide", link: "/en/deployment/upgrade-guide" }, ], }, { text: "MCP Ecosystem", items: [ { text: "Overview", link: "/en/mcp-ecosystem/overview" }, - { - text: "MCP Server Development", - link: "/en/mcp-ecosystem/mcp-server-development", - }, + { text: "MCP Recommendations", link: "/en/mcp-ecosystem/mcp-recommendations" }, { text: "Use Cases", link: "/en/mcp-ecosystem/use-cases" }, ], }, - { - text: "Version Management", - items: [ - { - text: "Version Management Guide", - link: "/en/version/version-management", - }, - ], - }, { text: "Testing", items: [ - { text: "Testing Overview", link: "/en/testing/overview" }, + { text: "Overview", link: "/en/testing/overview" }, { text: "Backend Testing", link: "/en/testing/backend" }, ], }, @@ -197,7 +204,6 @@ export default defineConfig({ { text: "Code of Conduct", link: "/en/code-of-conduct" }, { text: "Security Policy", link: "/en/security" }, { text: "Core Contributors", link: "/en/contributors" }, - { text: "Known Issues", link: "/en/known-issues" }, { text: "License", link: "/en/license" }, ], }, @@ -222,7 +228,7 @@ export default defineConfig({ ], sidebar: [ { - text: "快速开始", + text: "概览", items: [ { text: "项目概览", link: "/zh/getting-started/overview" }, { text: "核心特性", link: "/zh/getting-started/features" }, @@ -230,12 +236,30 @@ export default defineConfig({ text: "软件架构", link: "/zh/getting-started/software-architecture", }, - { text: "安装部署", link: "/zh/getting-started/installation" }, + ], + }, + { + text: "快速开始", + items: [ + { text: "安装部署", link: "/zh/quick-start/installation" }, + { + text: "升级指导", + link: "/zh/quick-start/upgrade-guide", + }, + { text: "常见问题", link: "/zh/quick-start/faq" }, + ], + }, + { + text: "开发者指南", + items: [ + { + text: "概览", + link: "/zh/developer-guide/overview", + }, { - text: "开发指南", - link: "/zh/getting-started/development-guide", + text: "环境准备", + link: "/zh/developer-guide/environment-setup", }, - { text: "常见问题", link: "/zh/getting-started/faq" }, ], }, { @@ -260,11 +284,12 @@ export default defineConfig({ { text: "本地工具", items: [ - { text: "本地工具概览", link: "/zh/user-guide/local-tools/" }, - { - text: "Terminal工具", - link: "/zh/user-guide/local-tools/terminal-tool", - }, + { text: "概览", link: "/zh/user-guide/local-tools/" }, + { text: "文件工具", link: "/zh/user-guide/local-tools/file-tools" }, + { text: "邮件工具", link: "/zh/user-guide/local-tools/email-tools" }, + { text: "搜索工具", link: "/zh/user-guide/local-tools/search-tools" }, + { text: "多模态工具", link: "/zh/user-guide/local-tools/multimodal-tools" }, + { text: "终端工具", link: "/zh/user-guide/local-tools/terminal-tool" }, ], }, ], @@ -272,17 +297,13 @@ export default defineConfig({ { text: "SDK 文档", items: [ - { text: "SDK 概览", link: "/zh/sdk/overview" }, + { text: "概览", link: "/zh/sdk/overview" }, { text: "基本使用", link: "/zh/sdk/basic-usage" }, { text: "特性详解", link: "/zh/sdk/features" }, { text: "核心模块", items: [ { text: "智能体模块", link: "/zh/sdk/core/agents" }, - { - text: "使用 agent_run 运行智能体", - link: "/zh/sdk/core/agent-run", - }, { text: "工具模块", link: "/zh/sdk/core/tools" }, { text: "模型模块", link: "/zh/sdk/core/models" }, ], @@ -294,16 +315,20 @@ export default defineConfig({ }, { text: "前端开发", - items: [{ text: "前端概览", link: "/zh/frontend/overview" }], + items: [{ text: "概览", link: "/zh/frontend/overview" }], }, { text: "后端开发", items: [ - { text: "后端概览", link: "/zh/backend/overview" }, + { text: "概览", link: "/zh/backend/overview" }, { text: "API 文档", link: "/zh/backend/api-reference" }, { text: "工具集成", items: [ + { + text: "Nexent 工具", + link: "/zh/backend/tools/nexent-native", + }, { text: "LangChain 工具", link: "/zh/backend/tools/langchain", @@ -312,41 +337,32 @@ export default defineConfig({ ], }, { text: "提示词开发", link: "/zh/backend/prompt-development" }, + { text: "版本管理", link: "/zh/backend/version-management" }, ], }, { text: "文档开发", - items: [{ text: "文档开发指南", link: "/zh/docs-development" }], + items: [{ text: "开发指南", link: "/zh/docs-development" }], }, { text: "容器构建与容器化开发", items: [ - { text: "Docker 构建", link: "/zh/deployment/docker-build" }, - { text: "开发容器", link: "/zh/deployment/devcontainer" }, - { text: "升级指导", link: "/zh/deployment/upgrade-guide" }, + { text: "镜像构建", link: "/zh/deployment/docker-build" }, + { text: "容器开发", link: "/zh/deployment/devcontainer" }, ], }, { text: "MCP 生态系统", items: [ { text: "概览", link: "/zh/mcp-ecosystem/overview" }, - { - text: "MCP 服务开发", - link: "/zh/mcp-ecosystem/mcp-server-development", - }, + { text: "MCP 推荐", link: "/zh/mcp-ecosystem/mcp-recommendations" }, { text: "用例场景", link: "/zh/mcp-ecosystem/use-cases" }, ], }, - { - text: "版本信息管理", - items: [ - { text: "版本管理指南", link: "/zh/version/version-management" }, - ], - }, { text: "测试", items: [ - { text: "测试概览", link: "/zh/testing/overview" }, + { text: "概览", link: "/zh/testing/overview" }, { text: "后端测试", link: "/zh/testing/backend" }, ], }, @@ -358,7 +374,6 @@ export default defineConfig({ { text: "行为准则", link: "/zh/code-of-conduct" }, { text: "安全政策", link: "/zh/security" }, { text: "核心贡献者", link: "/zh/contributors" }, - { text: "已知问题", link: "/zh/known-issues" }, { text: "许可证", link: "/zh/license" }, ], }, diff --git a/doc/docs/en/backend/api-reference.md b/doc/docs/en/backend/api-reference.md index 2f3118433..921b76122 100644 --- a/doc/docs/en/backend/api-reference.md +++ b/doc/docs/en/backend/api-reference.md @@ -1,1178 +1,7 @@ -# Nexent Community API Documentation +# Backend API Reference -This document provides a comprehensive overview of all API endpoints available in the Nexent Community backend. +## 🔗 Access API Docs -## Table of Contents -1. [Base App](#base-app) -2. [Agent App](#agent-app) -3. [Config Sync App](#config-sync-app) -4. [Conversation Management App](#conversation-management-app) -5. [Data Process App](#data-process-app) -6. [Elasticsearch App](#elasticsearch-app) -7. [ME Model Management App](#me-model-management-app) -8. [Model Management App](#model-management-app) -9. [Proxy App](#proxy-app) -10. [File Management App](#file-management-app) -11. [Voice App](#voice-app) +The backend API reference is maintained in Apifox. Please visit the live documentation here: -## Base App - -The base app serves as the main FastAPI application that includes all other routers and provides global exception handling. - -### Global Exception Handlers -- `HTTPException`: Returns JSON response with error message -- `Exception`: Returns generic 500 error response - -## Agent App - -### Endpoints - -#### POST /api/agent/run -Executes an agent with the provided request. - -**Request Body:** -```json -{ - "query": "string", - "history": "array", - "minio_files": "array" -} -``` - -**Response:** -- Streaming response with SSE (Server-Sent Events) -- Returns agent's responses in real-time - -#### POST /api/agent/reload_config -Manually triggers configuration reload. - -**Response:** -- Success/failure status - -## Config Sync App - -### Endpoints - -#### POST /api/config/save_config -Saves configuration to environment variables. - -**Request Body:** -```json -{ - "app": { - "name": "string", - "description": "string", - "icon": { - "type": "string", - "avatarUri": "string", - "customUrl": "string" - } - }, - "models": { - "llm": { - "name": "string", - "displayName": "string", - "apiConfig": { - "apiKey": "string", - "modelUrl": "string" - } - }, - // ... other model configurations - }, - "data": { - "selectedKbNames": "array", - "selectedKbModels": "array", - "selectedKbSources": "array" - } -} -``` - -**Response:** -```json -{ - "message": "Configuration saved successfully", - "status": "saved" -} -``` - -#### GET /api/config/load_config -Loads configuration from environment variables. - -**Response:** -```json -{ - "config": { - "app": { - "name": "string", - "description": "string", - "icon": { - "type": "string", - "avatarUri": "string", - "customUrl": "string" - } - }, - "models": { - "llm": { - "name": "string", - "displayName": "string", - "apiConfig": { - "apiKey": "string", - "modelUrl": "string" - } - }, - // ... other model configurations - }, - "data": { - "selectedKbNames": ["string"], - "selectedKbModels": ["string"], - "selectedKbSources": ["string"] - } - } -} -``` - -## Conversation Management App - -### Endpoints - -#### PUT /api/conversation/create -Creates a new conversation. - -**Request Body:** -```json -{ - "title": "string" -} -``` - -**Response:** -```json -{ - "code": 0, - "message": "success", - "data": { - "conversation_id": "string", - "conversation_title": "string", - "create_time": "number", - "update_time": "number" - } -} -``` - -#### GET /api/conversation/list -Gets all conversations. - -**Response:** -```json -{ - "code": 0, - "message": "success", - "data": [ - { - "conversation_id": "string", - "conversation_title": "string", - "create_time": "number", - "update_time": "number" - } - ] -} -``` - -#### POST /api/conversation/rename -Renames a conversation. - -**Request Body:** -```json -{ - "conversation_id": "number", - "name": "string" -} -``` - -**Response:** -```json -{ - "code": 0, - "message": "success", - "data": true -} -``` - -#### DELETE /api/conversation/{conversation_id} -Deletes a conversation. - -**Response:** -```json -{ - "code": 0, - "message": "success", - "data": true -} -``` - -#### GET /api/conversation/{conversation_id} -Gets conversation history. - -**Response:** -```json -{ - "code": 0, - "message": "success", - "data": { - "conversation_id": "string", - "create_time": "number", - "message": [ - { - "role": "string", - "message": "array", - "message_id": "string", - "opinion_flag": "string", - "picture": "array", - "search": "array" - } - ] - } -} -``` - -#### POST /api/conversation/sources -Gets message source information. - -**Request Body:** -```json -{ - "conversation_id": "number", - "message_id": "string", - "type": "string" -} -``` - -**Response:** -```json -{ - "code": 0, - "message": "success", - "data": { - "searches": [ - { - "title": "string", - "text": "string", - "source_type": "string", - "url": "string", - "filename": "string", - "published_date": "string", - "score": "number", - "score_details": { - "accuracy": "number", - "semantic": "number" - } - } - ], - "images": ["string"] - } -} -``` - -#### POST /api/conversation/generate_title -Generates conversation title. - -**Request Body:** -```json -{ - "conversation_id": "number", - "history": [ - { - "role": "string", - "content": "string" - } - ] -} -``` - -**Response:** -```json -{ - "code": 0, - "message": "success", - "data": "string" -} -``` - -#### POST /api/conversation/message/update_opinion -Updates message like/dislike status. - -**Request Body:** -```json -{ - "message_id": "string", - "opinion": "string" -} -``` - -**Response:** -```json -{ - "code": 0, - "message": "success", - "data": true -} -``` - -## Data Process App - -### Endpoints - -#### POST /api/tasks -Creates a new data processing task. - -**Request Body:** -```json -{ - "source": "string", - "source_type": "string", - "chunking_strategy": "string", - "index_name": "string", - "additional_params": { - "key": "value" - } -} -``` - -**Response:** -```json -{ - "task_id": "string" -} -``` - -#### POST /api/tasks/batch -Creates a batch of data processing tasks. - -**Request Body:** -```json -{ - "sources": [ - { - "source": "string", - "source_type": "string", - "chunking_strategy": "string", - "index_name": "string", - "additional_params": { - "key": "value" - } - } - ] -} -``` - -**Response:** -```json -{ - "task_ids": ["string"] -} -``` - -#### GET /api/tasks/{task_id} -Gets task status. - -**Response:** -```json -{ - "id": "string", - "status": "string", - "created_at": "string", - "updated_at": "string", - "error": "string" -} -``` - -#### GET /api/tasks -Lists all tasks. - -**Response:** -```json -{ - "tasks": [ - { - "id": "string", - "status": "string", - "created_at": "string", - "updated_at": "string", - "error": "string" - } - ] -} -``` - -#### GET /api/tasks/indices/{index_name}/tasks -Gets all active tasks for a specific index. - -**Response:** -```json -{ - "index_name": "string", - "files": [ - { - "path_or_url": "string", - "status": "string" - } - ] -} -``` - -#### GET /api/tasks/{task_id}/details -Gets task status and results. - -**Response:** -```json -{ - "id": "string", - "status": "string", - "created_at": "string", - "updated_at": "string", - "error": "string", - "results": "object" -} -``` - -## Elasticsearch App - -### Endpoints - -#### POST /api/indices/{index_name} -Creates a new vector index. - -**Parameters:** -- `index_name`: Name of the index -- `embedding_dim`: Optional dimension of embedding vectors - -**Response:** -```json -{ - "status": "success", - "message": "string", - "embedding_dim": "number" -} -``` - -#### DELETE /api/indices/{index_name} -Deletes an index. - -**Response:** -```json -{ - "status": "success", - "message": "string" -} -``` - -#### GET /api/indices -Lists all indices. - -**Parameters:** -- `pattern`: Pattern to match index names -- `include_stats`: Whether to include index stats - -**Response:** -```json -{ - "indices": ["string"], - "count": "number", - "indices_info": [ - { - "name": "string", - "stats": { - "docs": "number", - "size": "string" - } - } - ] -} -``` - -#### GET /api/indices/{index_name}/info -Gets index information. - -**Parameters:** -- `include_files`: Whether to include file list -- `include_chunks`: Whether to include text chunks - -**Response:** -```json -{ - "base_info": { - "docs": "number", - "size": "string" - }, - "search_performance": { - "query_time": "number", - "hits": "number" - }, - "fields": { - "field_name": { - "type": "string" - } - }, - "files": [ - { - "path_or_url": "string", - "file": "string", - "file_size": "number", - "create_time": "string", - "status": "string", - "chunks": [ - { - "id": "string", - "title": "string", - "content": "string", - "create_time": "string" - } - ], - "chunks_count": "number" - } - ] -} -``` - -#### POST /api/indices/{index_name}/documents -Indexes documents. - -**Request Body:** -```json -{ - "task_id": "string", - "results": [ - { - "metadata": { - "filename": "string", - "title": "string", - "languages": ["string"], - "author": "string", - "date": "string", - "file_size": "number", - "creation_date": "string" - }, - "source": "string", - "text": "string", - "source_type": "string" - } - ] -} -``` - -**Response:** -```json -{ - "success": true, - "message": "string", - "total_indexed": "number", - "total_submitted": "number" -} -``` - -#### DELETE /api/indices/{index_name}/documents -Deletes documents. - -**Request Body:** -```json -{ - "path_or_url": "string" -} -``` - -**Response:** -```json -{ - "status": "success", - "deleted_count": "number" -} -``` - -#### POST /api/indices/search/accurate -Performs accurate search. - -**Request Body:** -```json -{ - "query": "string", - "index_names": ["string"], - "top_k": "number" -} -``` - -**Response:** -```json -{ - "results": [ - { - "id": "string", - "title": "string", - "content": "string", - "score": "number", - "index": "string" - } - ], - "total": "number", - "query_time_ms": "number" -} -``` - -#### POST /api/indices/search/semantic -Performs semantic search. - -**Request Body:** -```json -{ - "query": "string", - "index_names": ["string"], - "top_k": "number" -} -``` - -**Response:** -```json -{ - "results": [ - { - "id": "string", - "title": "string", - "content": "string", - "score": "number", - "index": "string" - } - ], - "total": "number", - "query_time_ms": "number" -} -``` - -#### POST /api/indices/search/hybrid -Performs hybrid search. - -**Request Body:** -```json -{ - "query": "string", - "index_names": ["string"], - "top_k": "number", - "weight_accurate": "number" -} -``` - -**Response:** -```json -{ - "results": [ - { - "id": "string", - "title": "string", - "content": "string", - "score": "number", - "index": "string", - "score_details": { - "accurate": "number", - "semantic": "number" - } - } - ], - "total": "number", - "query_time_ms": "number" -} -``` - -#### GET /api/indices/health -Checks API and Elasticsearch health. - -**Response:** -```json -{ - "status": "healthy", - "elasticsearch": "connected", - "indices_count": "number" -} -``` - -## ME Model Management App - -### Endpoints - -#### GET /api/me/model/list -Gets list of ME models. - -**Request Body:** -```json -{ - "type": "string", - "timeout": "number" -} -``` - -**Response:** -```json -{ - "code": 200, - "message": "Successfully retrieved", - "data": [ - { - // TODO: Definition - "id": "string", - "name": "string", - "type": "string", - "description": "string" - } - ] -} -``` - -#### GET /api/me/healthcheck -Checks ME model connectivity. - -**Request Body:** -```json -{ - "timeout": "number" -} -``` - -**Response:** -```json -{ - "code": 200, - "message": "Connection successful", - "data": { - "status": "Connected", - "desc": "Connection successful", - "connect_status": "AVAILABLE" - } -} -``` - -#### GET /api/me/model/healthcheck -Checks specific model health. - -**Request Body:** -```json -{ - "model_name": "string" -} -``` - -**Response:** -```json -{ - "code": 200, - "message": "Model health check successful", - "data": { - "status": "string", - "desc": "string", - "connect_status": "string" - } -} -``` - -## Model Management App - -### Endpoints - -#### POST /api/model/create -Creates a new model. - -**Request Body:** -```json -{ - "model_name": "string", - "display_name": "string", - "connect_status": "string" -} -``` - -**Response:** -```json -{ - "code": 200, - "message": "Model created successfully", - "data": null -} -``` - -#### POST /api/model/update -Updates a model. - -**Request Body:** -```json -{ - "model_name": "string", - "display_name": "string", - "connect_status": "string" -} -``` - -**Response:** -```json -{ - "code": 200, - "message": "Model updated successfully", - "data": { - "model_name": "string" - } -} -``` - -#### POST /api/model/delete -Deletes a model. - -**Request Body:** -```json -{ - "model_name": "string" -} -``` - -**Response:** -```json -{ - "code": 200, - "message": "Model deleted successfully", - "data": { - "model_name": "string" - } -} -``` - -#### GET /api/model/list -Gets all models. - -**Response:** -```json -{ - "code": 200, - "message": "Successfully retrieved model list", - "data": [ - { - "model_id": "string", - "model_name": "string", - "model_repo": "string", - "display_name": "string", - "connect_status": "string" - } - ] -} -``` - -#### GET /api/model/healthcheck -Checks model health. - -**Request Body:** -```json -{ - "model_name": "string" -} -``` - -**Response:** -```json -{ - "code": 200, - "message": "Model health check successful", - "data": { - "status": "string", - "desc": "string", - "connect_status": "string" - } -} -``` - -#### GET /api/model/get_connect_status -Gets model connection status. - -**Request Body:** -```json -{ - "model_name": "string" -} -``` - -**Response:** -```json -{ - "code": 200, - "message": "Successfully retrieved connection status", - "data": { - "model_name": "string", - "connect_status": "string" - } -} -``` - -#### POST /api/model/update_connect_status -Updates model connection status. - -**Request Body:** -```json -{ - "model_name": "string", - "connect_status": "string" -} -``` - -**Response:** -```json -{ - "code": 200, - "message": "Successfully updated connection status", - "data": { - "model_name": "string", - "connect_status": "string" - } -} -``` - -## Proxy App - -### Endpoints - -#### GET /api/proxy/image -Proxies remote images. - -**Request Body:** -```json -{ - "url": "string" -} -``` - -**Response:** -```json -{ - "success": true, - "base64": "string", - "content_type": "string" -} -``` - -## File Management App - -### Endpoints - -#### POST /api/file/upload -Uploads files. - -**Request Body:** -```json -{ - "file": ["file"], - "chunking_strategy": "string", - "index_name": "string" -} -``` - -**Response:** -```json -{ - "message": "string", - "uploaded_files": ["string"], - "process_tasks": { - "task_id": "string" - } -} -``` - -#### POST /api/file/storage -Uploads files to storage. - -**Request Body:** -```json -{ - "files": ["file"], - "folder": "string" -} -``` - -**Response:** -```json -{ - "message": "string", - "success_count": "number", - "failed_count": "number", - "results": [ - { - "success": true, - "file_name": "string", - "url": "string" - } - ] -} -``` - -#### GET /api/file/storage -Gets storage files. - -**Request Body:** -```json -{ - "prefix": "string", - "limit": "number", - "include_urls": true -} -``` - -**Response:** -```json -{ - "total": "number", - "files": [ - { - "name": "string", - "size": "number", - "url": "string" - } - ] -} -``` - -#### GET /api/file/storage/{object_name} -Gets storage file. - -**Request Body:** -```json -{ - "download": true, - "expires": "number" -} -``` - -**Response:** -```json -{ - "success": true, - "url": "string", - "expires": "string" -} -``` - -#### DELETE /api/file/storage/{object_name} -Deletes storage file. - -**Response:** -```json -{ - "success": true, - "message": "string" -} -``` - -#### POST /api/file/storage/batch-urls -Gets batch file URLs. - -**Request Body:** -```json -{ - "object_names": ["string"], - "expires": "number" -} -``` - -**Response:** -```json -{ - "urls": [ - { - "object_name": "string", - "success": true, - "url": "string" - } - ] -} -``` - -#### POST /api/file/preprocess -Preprocesses files for agent. - -**Request Body:** -```json -{ - "query": "string", - "files": ["file"] -} -``` - -**Response:** -- Streaming response with SSE (Server-Sent Events) -- Returns processed content in real-time - -## Voice App - -### Endpoints - -#### WebSocket /api/voice/stt/ws -Real-time speech-to-text streaming. - -**Input:** -- Audio stream - -**Output:** -- Text stream - -#### WebSocket /api/voice/tts/ws -Real-time text-to-speech streaming. - -**Input:** -```json -{ - "text": "string" -} -``` - -**Output:** -- Audio stream - -## Dependencies - -The apps have the following dependencies: - -1. Base App: - - All other app routers - -2. Agent App: - - Agent utilities - - Conversation management service - -3. Config Sync App: - - python-dotenv - - Configuration utilities - -4. Conversation Management App: - - Conversation database - - Conversation management utilities - -5. Data Process App: - - Data process core - - Task status utilities - -6. Elasticsearch App: - - Elasticsearch core - - Embedding model - - Data process service - -7. ME Model Management App: - - Requests - - Model health service - -8. Model Management App: - - Model management database - - Model health service - - Model name utilities - -9. Proxy App: - - aiohttp - - Image filter utilities - -10. File Management App: - - File management utilities - - Attachment database - - Data process service - -11. Voice App: - - STT model - - TTS model - - WebSocket support +[Nexent API](https://8icfxll43r.apifox.cn) diff --git a/doc/docs/en/backend/overview.md b/doc/docs/en/backend/overview.md index 3e8620551..962233f18 100644 --- a/doc/docs/en/backend/overview.md +++ b/doc/docs/en/backend/overview.md @@ -202,4 +202,4 @@ python backend/mcp_service.py # MCP service - Resource pool management - Auto-scaling capabilities -For detailed backend development guidelines, see the [Development Guide](../getting-started/development-guide). \ No newline at end of file +For detailed backend development guidelines, see the [Developer Guide](../developer-guide/overview). \ No newline at end of file diff --git a/doc/docs/en/backend/prompt-development.md b/doc/docs/en/backend/prompt-development.md index 0031042cb..33a1a99bf 100644 --- a/doc/docs/en/backend/prompt-development.md +++ b/doc/docs/en/backend/prompt-development.md @@ -1,134 +1,48 @@ # Prompt Development Guide -This guide provides comprehensive information about the prompt template system used in Nexent for creating different types of agents. The YAML files in the `backend/prompts/` directory define system prompts, planning prompts, and other key prompt components for various agent types. +This guide explains how Nexent prompt templates are organized under `backend/prompts/` and how to extend them for new agents. -## File Naming Convention +## 📂 File Layout & Naming -The naming format follows `{agent_type}_agent.yaml`, where: -- `agent_type`: Describes the main function or purpose of the agent (e.g., manager, search, etc.) +- Core templates live in `backend/prompts/` using `{agent_type}_agent.yaml` or `{scope}_prompt_template.yaml`. +- Utility templates are under `backend/prompts/utils/` for meta generation (e.g., prompt/title helpers). -## Prompt Template Structure +## 🧩 Template Structure -Each YAML file contains the following main sections: +Each YAML may contain: +- `system_prompt`: role, responsibilities, execution flow, tool/sub-agent usage rules, Python code constraints, and examples. +- `planning`: `initial_facts`, `initial_plan`, update hooks before/after facts or plans. +- `managed_agent`: prompts for delegating tasks and collecting reports from sub-agents. +- `final_answer`: pre/post messages to shape final output. +- `tools_requirement`: priorities and guardrails for tool usage. +- `few_shots`: examples to steer behavior. -### 1. system_prompt +## 🔄 Variables -The system prompt is the core component of an agent, defining its role, capabilities, and behavioral guidelines. It typically includes: +Common placeholders for runtime rendering: +- `tools`, `managed_agents` +- `task`, `remaining_steps` +- `authorized_imports` +- `facts_update`, `answer_facts` -- **Core Responsibilities**: Main duties and capabilities description -- **Execution Flow**: Standard process and methods for agent task execution -- **Available Resources**: List of tools and sub-agents the agent can use -- **Resource Usage Requirements**: Priority and strategy for using different tools -- **Python Code Standards**: Code writing standards and constraints -- **Example Templates**: Examples demonstrating agent task execution +## 📑 Key Templates -### 2. planning +- Manager agents: `manager_system_prompt_template.yaml`, `manager_system_prompt_template_en.yaml` +- Managed agents: `managed_system_prompt_template.yaml`, `managed_system_prompt_template_en.yaml` +- Knowledge summary: `knowledge_summary_agent.yaml`, `knowledge_summary_agent_en.yaml` +- File analysis: `analyze_file.yaml`, `analyze_file_en.yaml` +- Cluster summary: `cluster_summary_agent.yaml`, `cluster_summary_reduce.yaml` (and `_zh` variants) +- Utilities (`utils/`): `prompt_generate*.yaml`, `generate_title*.yaml` -Contains various prompts for task planning: +## 🚀 How to Extend -- **initial_facts**: Initial fact collection prompts -- **initial_plan**: Initial plan formulation prompts -- **update_facts_pre_messages**: Prompts before updating facts -- **update_facts_post_messages**: Prompts after updating facts -- **update_plan_pre_messages**: Prompts before updating plans -- **update_plan_post_messages**: Prompts after updating plans +1. Copy the closest existing template and adjust `system_prompt`/`planning` for your scenario. +2. Keep placeholders intact unless intentionally removed. +3. Align tool lists with actual tools available to the agent; update `authorized_imports` if needed. +4. Validate with a small task to ensure flows (`Think → Code → Observe → Repeat`) produce the expected behavior. -### 3. managed_agent +## ✅ Standards & Tips -Defines prompts for sub-agent interactions: - -- **task**: Task assignment prompts for sub-agents -- **report**: Prompts for sub-agent result reporting - -### 4. final_answer - -Defines prompts for final answer generation: - -- **pre_messages**: Prompts before generating final answers -- **post_messages**: Prompts after generating final answers - -### 5. tools_requirement - -Defines prompts for tool usage standards and priorities. - -### 6. few_shots - -Provides few-shot learning examples to help agents better understand task execution methods. - -## Template Variables - -The prompt templates use the following special variables for dynamic replacement: - -- `{{tools}}`: Available tools list -- `{{managed_agents}}`: Available sub-agents list -- `{{task}}`: Current task description -- `{{authorized_imports}}`: Authorized Python module imports -- `{{facts_update}}`: Updated facts list -- `{{answer_facts}}`: Known facts list -- `{{remaining_steps}}`: Remaining execution steps - -## Available Prompt Templates - -### Core Templates - -1. **Manager Agent Templates** - - `manager_system_prompt_template.yaml` - Chinese version - - `manager_system_prompt_template_en.yaml` - English version - - These templates define the core manager agent that coordinates and dispatches various assistants and tools to efficiently solve complex tasks. - -2. **Managed Agent Templates** - - `managed_system_prompt_template.yaml` - Chinese version - - `managed_system_prompt_template_en.yaml` - English version - - These templates define specialized agents that perform specific tasks under the coordination of the manager agent. - -3. **Specialized Agent Templates** - - `knowledge_summary_agent.yaml` - Knowledge summary agent (Chinese) - - `knowledge_summary_agent_en.yaml` - Knowledge summary agent (English) - - `analyze_file.yaml` - File analysis agent (Chinese) - - `analyze_file_en.yaml` - File analysis agent (English) - -### Utility Templates - -Located in the `utils/` directory: - -1. **Prompt Generation Templates** - - `prompt_generate.yaml` - Chinese version - - `prompt_generate_en.yaml` - English version - - These templates help generate efficient and clear prompts for different agent types. - -2. **Title Generation Templates** - - `generate_title.yaml` - Chinese version - - `generate_title_en.yaml` - English version - Templates for generating titles for dialogs. - -## Execution Flow - -The standard agent execution flow follows this pattern: - -1. **Think**: Analyze current task status and progress -2. **Code**: Write simple Python code following standards -3. **Observe**: View code execution results -4. **Repeat**: Continue the cycle until the task is complete - -## Code Standards - -When writing Python code in prompts: - -1. Use the format `Code: \n```py\n` for executable code -2. Use the format `Code: \n```code:language_type\n` for display-only code -3. Use only defined variables that persist across calls -4. Use `print()` function to make variable information visible -5. Use keyword parameters for tool and agent calls -6. Avoid excessive tool calls in a single round -7. Only import from authorized modules: `{{authorized_imports}}` - -## Best Practices - -1. **Task Decomposition**: Break complex tasks into manageable sub-tasks -2. **Professional Matching**: Assign tasks based on agent expertise -3. **Information Integration**: Integrate outputs from different agents -4. **Efficiency Optimization**: Avoid redundant work -5. **Result Evaluation**: Assess agent return results and provide additional guidance when needed +- Use executable code fences for runnable snippets: ````py````, and display-only fences for non-executable examples. +- Prefer keyword args for tool calls; avoid excessive tool invocations per step. +- Keep comments and docstrings in English and respect repository coding rules. diff --git a/doc/docs/en/backend/tools/index.md b/doc/docs/en/backend/tools/index.md index 7076fce0a..2d2d2c185 100644 --- a/doc/docs/en/backend/tools/index.md +++ b/doc/docs/en/backend/tools/index.md @@ -26,6 +26,6 @@ For SDK-level tool development, see: ## Need Help? -- Check our [FAQ](../../getting-started/faq) for common tool integration issues +- Check our [FAQ](../../quick-start/faq) for common tool integration issues - Join our [Discord community](https://discord.gg/tb5H3S3wyv) for real-time support - Review [GitHub Issues](https://github.com/ModelEngine-Group/nexent/issues) for known issues \ No newline at end of file diff --git a/doc/docs/en/backend/tools/mcp.md b/doc/docs/en/backend/tools/mcp.md index 67ad79e48..e00b12493 100644 --- a/doc/docs/en/backend/tools/mcp.md +++ b/doc/docs/en/backend/tools/mcp.md @@ -1,479 +1,212 @@ -# Hierarchical Proxy Architecture Documentation - -## System Architecture Flowchart - -```mermaid -graph TD - A["Frontend Requests"] --> B["Main Service (FastAPI)(Port: 5010)"] - - B --> B1["Web API Management Layer(/api/mcp/*)"] - B1 --> B2["/api/mcp/tools/(Get Tool Information)"] - B1 --> B3["/api/mcp/add(Add MCP Server)"] - B1 --> B4["/api/mcp/(Delete MCP Server)"] - B1 --> B5["/api/mcp/list(List MCP Servers)"] - B1 --> B6["/api/mcp/recover(Recover MCP Servers)"] - - B --> C["MCP Service (FastMCP)(Port: 5011)"] - - C --> C1["Local Service Layer"] - C --> C2["Remote Proxy Layer"] - C --> C3["MCP Protocol API Layer"] - - C1 --> C11["local_mcp_service(Stable Mount)"] - - C2 --> C21["RemoteProxyManager(Dynamic Management)"] - C21 --> C22["Remote Proxy 1"] - C21 --> C23["Remote Proxy 2"] - C21 --> C24["Remote Proxy n..."] - - C3 --> C31["/healthcheck(Connectivity Check)"] - C3 --> C32["/list-remote-proxies(List Proxies)"] - C3 --> C33["/add-remote-proxies(Add Proxies)"] - C3 --> C34["/remote-proxies(Delete Proxies)"] - - C22 --> D1["Remote MCP Service 1(SSE/HTTP)"] - C23 --> D2["Remote MCP Service 2(SSE/HTTP)"] - C24 --> D3["Remote MCP Service n(SSE/HTTP)"] - - style A fill:#e1f5fe - style B fill:#f3e5f5 - style C fill:#e8f5e8 - style B1 fill:#fff3e0 - style C1 fill:#e8f5e8 - style C2 fill:#fff3e0 - style C3 fill:#fce4ec -``` +# Model Context Protocol (MCP) -## Architecture Overview +## 🌟 What is MCP? -This system implements a **dual-service proxy architecture** consisting of two independent services: +Model Context Protocol (MCP) is an open standard for connecting AI apps to external systems (data, tools, workflows), similar to a "USB-C for AI." It standardizes how hosts (e.g., Claude Desktop, Nexent) discover and call tools/resources exposed by MCP servers. -### 1. Main Service (FastAPI) - Port 5010 5014 -- **Purpose**: Provides web management interface and RESTful API, serving as the single entry point for frontend -- **Features**: User-oriented management with authentication, multi-tenant support, and proxy calls to MCP service -- **Startup File**: `config_service.py, runtime_service.py` +## 🧭 What can MCP do? -### 2. MCP Service (FastMCP) - Port 5011 -- **Purpose**: Provides MCP protocol services and proxy management (internal service) -- **Features**: MCP protocol standard, supports local services and remote proxies, primarily called by main service -- **Startup File**: `mcp_service.py` +- **Tools**: Functions callable by the LLM with user approval +- **Resources**: File-like data that clients can read +- **Prompts**: Reusable templates shared by servers +- Works over a simple protocol so hosts can connect to local or remote servers consistently -**Important Note**: Frontend clients only directly access the main service (5010). All MCP-related operations are completed by the main service proxying calls to the MCP service (5011). +## 🌐 Language Support -## Core Features +The MCP protocol provides SDKs for multiple programming languages: -### 1. Local Service Stability -- `local_mcp_service` and other local services maintain stable operation -- Adding, removing, or updating remote proxies does not affect local services +- **Python** ⭐ (recommended for beginners) +- **TypeScript** +- **Java** +- **Go** +- **Rust** +- Any other language that implements the MCP protocol -### 2. Dynamic Remote Proxy Management -- Supports dynamic addition, removal, and updating of remote MCP service proxies -- Each remote proxy is managed as an independent service -- Supports multiple transport methods (SSE, HTTP) +We recommend **Python** because it offers beginner-friendly syntax, rich ecosystem with frameworks like FastMCP, rapid prototyping capabilities, and thousands of mature libraries. -### 3. Dual-Layer API Interface +## 🚀 Quick Start -#### Main Service API (Port 5010) - External Management Layer -**Interfaces directly accessed by frontend clients**, providing user-oriented management features with authentication and multi-tenant support: +### 📋 Prerequisites -**Get Remote MCP Server Tool Information** -```http -GET /api/mcp/tools/?service_name={name}&mcp_url={url} -Authorization: Bearer {token} -``` +Install FastMCP before you start coding: -**Add Remote MCP Server** -```http -POST /api/mcp/add?mcp_url={url}&service_name={name} -Authorization: Bearer {token} +```bash +pip install fastmcp ``` -**Delete Remote MCP Server** -```http -DELETE /api/mcp/?service_name={name}&mcp_url={url} -Authorization: Bearer {token} -``` +### 📝 Basic Example -**Get Remote MCP Server List** -```http -GET /api/mcp/list -Authorization: Bearer {token} -``` +Create a simple string utility server with FastMCP: -**Recover Remote MCP Servers** -```http -GET /api/mcp/recover -Authorization: Bearer {token} -``` +```python +from fastmcp import FastMCP -#### MCP Service API (Port 5011) - Internal Protocol Layer -**Internal interfaces primarily called by main service**, also available for external MCP clients: +# Create an MCP server instance +mcp = FastMCP(name="String MCP Server") -**Connectivity Check** -```http -GET /healthcheck?mcp_url={url} -``` -Quickly checks if remote MCP service is reachable, returns simple connection status. +@mcp.tool( + name="calculate_string_length", + description="Calculate the length of a string" +) +def calculate_string_length(text: str) -> int: + return len(text) -**List All Remote Proxies** -```http -GET /list-remote-proxies -``` +@mcp.tool( + name="to_uppercase", + description="Convert text to uppercase" +) +def to_uppercase(text: str) -> str: + return text.upper() -**Add Remote Proxy** -```http -POST /add-remote-proxies -Content-Type: application/json +@mcp.tool( + name="to_lowercase", + description="Convert text to lowercase" +) +def to_lowercase(text: str) -> str: + return text.lower() -{ - "service_name": "my_service", - "mcp_url": "http://localhost:5012/sse", - "transport": "sse" -} +if __name__ == "__main__": + # Start with SSE transport + mcp.run(transport="sse", port=8000) ``` -**Delete Remote Proxy** -```http -DELETE /remote-proxies?service_name={service_name} -``` - -## Usage +### 🏃 Run the Server -### 1. Starting Services +Save the code as `mcp_server.py` and execute: -**Start Main Service** ```bash -cd backend -python config_service.py -python runtime_service.py +python mcp_server.py ``` -Service will start at `http://localhost:5010`. -**Start MCP Service** -```bash -cd backend -python mcp_service.py -``` -Service will start at `http://localhost:5011`. +You should see the server start successfully with the endpoint `http://127.0.0.1:8000/sse`. -### 2. Using APIs +## 🔌 Integrate with Nexent -#### Recommended Method: Managing MCP Servers via Main Service -**Frontend clients should use this method**, featuring complete authentication and permission management: +Once your MCP server is running, connect it to Nexent: -```bash -# Add remote MCP server -curl -X POST "http://localhost:5010/api/mcp/add?mcp_url=http://external-server:5012/sse&service_name=external_service" \ - -H "Authorization: Bearer {your_token}" +### 📍 Step 1: Start the MCP Server -# Get MCP server list -curl -H "Authorization: Bearer {your_token}" \ - "http://localhost:5010/api/mcp/list" -``` +Keep the server process running and note the endpoint (e.g., `http://127.0.0.1:8000/sse`). -#### Internal Debugging: Direct Access to MCP Service (Optional) -**For debugging or external MCP client direct integration only**: +### ⚙️ Step 2: Register in Nexent -```bash -# Test remote service connection -curl "http://localhost:5011/healthcheck?mcp_url=http://external-server:5012/sse" - -# Add remote proxy -curl -X POST http://localhost:5011/add-remote-proxies \ - -H "Content-Type: application/json" \ - -d '{ - "service_name": "external_service", - "mcp_url": "http://external-server:5012/sse", - "transport": "sse" - }' -``` +1. Open the **[Agent Development](../../user-guide/agent-development)** page +2. On the "Select Agent Tools" tab, click **MCP Configuration** on the right +3. Enter the server name and server URL + - ⚠️ **Important**: + - Server name must contain only letters and digits (no spaces or symbols) + - When Nexent runs inside Docker and MCP server runs on the host, replace `127.0.0.1` with `host.docker.internal` (e.g., `http://host.docker.internal:8000`) +4. Click **Add** to finish registration -## Code Structure - -### Main Service Components (config_service.py, runtime_service.py) -- **FastAPI Application**: Provides Web API and management interface -- **Multi-tenant Support**: Multi-tenant management based on authentication -- **Router Management**: Contains routers for multiple functional modules - -### MCP Service Components (mcp_service.py) - -#### RemoteProxyManager Class -Responsible for managing the lifecycle of all remote proxies: -- `add_remote_proxy()`: Add new remote proxy -- `remove_remote_proxy()`: Remove specified remote proxy -- `update_remote_proxy()`: Update existing remote proxy -- `list_remote_proxies()`: List all remote proxy configurations -- `_validate_remote_service()`: Validate remote service connection - -#### MCP Protocol Endpoints -- `/healthcheck`: Connectivity check endpoint -- `/list-remote-proxies`: List all remote proxy endpoint -- `/add-remote-proxies`: Add remote proxy endpoint -- `/remote-proxies`: Delete specific proxy endpoint - -### Remote MCP Management (remote_mcp_app.py) -- **Authentication Integration**: Integrated with main service authentication system -- **Data Persistence**: Supports database storage and recovery -- **Service Discovery**: Tool information acquisition and management - -## Service Dependencies - -```mermaid -graph LR - A["Frontend Client"] --> B["Main Service :5010(FastAPI)"] - B --> C["MCP Service :5011(FastMCP)"] - B --> D["Database(Users/Tenants/Config)"] - C --> E["Local MCP Service"] - C --> F["Remote MCP Proxy"] - - G["External MCP Client"] -.-> C - - style A fill:#e1f5fe - style B fill:#f3e5f5 - style C fill:#e8f5e8 - style G fill:#fff3e0 -``` +### 🎯 Step 3: Use the MCP Tool -## Interface Sequence Diagrams - -### 1. Get Remote MCP Tool Information (GET /api/mcp/tools/) - -```mermaid -sequenceDiagram - participant C as Frontend Client - participant M as Main Service(5010) - participant T as Tool Config Service - participant R as Remote MCP Service - - C->>M: GET /api/mcp/tools/?service_name=xxx&mcp_url=xxx - Note over C,M: Authorization: Bearer token (optional) - - M->>T: get_tool_from_remote_mcp_server(service_name, mcp_url) - T->>R: Direct connection to remote MCP service for tool list - R-->>T: Return tool information - T-->>M: Tool information list - - M-->>C: JSON response {tools: [...], status: "success"} - - Note over M,C: Returns 400 status code on error - Note over T,R: Note: This interface accesses remote MCP directly, not through local MCP service(5011) -``` +During agent creation or editing, the newly registered MCP tool appears in the tool list and can be attached to any agent. -### 2. Add Remote MCP Server (POST /api/mcp/add) - -```mermaid -sequenceDiagram - participant C as Frontend Client - participant M as Main Service(5010) - participant A as Auth System - participant S as MCP Service Management - participant DB as Database - participant MCP as MCP Service(5011) - participant R as Remote MCP Service - - C->>M: POST /api/mcp/add?mcp_url=xxx&service_name=xxx - Note over C,M: Authorization: Bearer token - - M->>A: get_current_user_id(authorization) - A-->>M: user_id, tenant_id - - M->>S: add_remote_mcp_server_list(tenant_id, user_id, mcp_url, service_name) - - S->>DB: Check if service name already exists - DB-->>S: Check result - - alt Service name already exists - S-->>M: JSONResponse (409 - Service name already exists) - M-->>C: Error response (409) - else Service name available - S->>S: add_remote_proxy() - S->>MCP: POST /add-remote-proxies - MCP->>MCP: Validate remote MCP service connection - MCP->>R: Connectivity test - R-->>MCP: Connection response - - alt MCP connection successful - MCP->>MCP: Create and mount remote proxy - MCP-->>S: 200 - Successfully added proxy - S->>DB: Save MCP server configuration - DB-->>S: Save result - S-->>M: Success result - M-->>C: JSON response {message: "Successfully added", status: "success"} - else MCP connection failed - MCP-->>S: Error response (503/409/400) - S-->>M: Error result/JSONResponse - M-->>C: Error response (400/409/503) - end - end -``` +## 🔧 Advanced Use Cases -### 3. Delete Remote MCP Server (DELETE /api/mcp/) - -```mermaid -sequenceDiagram - participant C as Frontend Client - participant M as Main Service(5010) - participant A as Auth System - participant S as MCP Service Management - participant DB as Database - participant MCP as MCP Service(5011) - - C->>M: DELETE /api/mcp/?service_name=xxx&mcp_url=xxx - Note over C,M: Authorization: Bearer token - - M->>A: get_current_user_id(authorization) - A-->>M: user_id, tenant_id - - M->>S: delete_remote_mcp_server_list(tenant_id, user_id, mcp_url, service_name) - - S->>DB: Find and delete MCP server configuration - DB-->>S: Delete result - - alt Database deletion failed - S-->>M: JSONResponse (400 - server not record) - M-->>C: Error response (400) - else Database deletion successful - S->>MCP: DELETE /remote-proxies?service_name=xxx - MCP->>MCP: Unmount remote proxy service - - alt MCP deletion successful - MCP-->>S: 200 - Successfully removed - S-->>M: Success result - M-->>C: JSON response {message: "Successfully deleted", status: "success"} - else MCP deletion failed - MCP-->>S: 404/400 - Deletion failed - S-->>M: Error result/JSONResponse - M-->>C: Error response (400/404) - end - end -``` +### 🌐 Wrap a REST API + +Expose existing REST APIs as MCP tools: + +```python +from fastmcp import FastMCP +import requests + +mcp = FastMCP("Course Statistics Server") -### 4. Get Remote MCP Server List (GET /api/mcp/list) - -```mermaid -sequenceDiagram - participant C as Frontend Client - participant M as Main Service(5010) - participant A as Auth System - participant S as MCP Service Management - participant DB as Database - - C->>M: GET /api/mcp/list - Note over C,M: Authorization: Bearer token - - M->>A: get_current_user_id(authorization) - A-->>M: user_id, tenant_id - - M->>S: get_remote_mcp_server_list(tenant_id) - S->>DB: Query tenant's MCP server list - DB-->>S: Server list data - S-->>M: remote_mcp_server_list - - M-->>C: JSON response {remote_mcp_server_list: [...], status: "success"} - - Note over M,C: Returns 400 status code on error +@mcp.tool( + name="get_course_statistics", + description="Get course statistics such as average, max, min, and total students" +) +def get_course_statistics(course_id: str) -> str: + api_url = "https://your-school-api.com/api/courses/statistics" + response = requests.get(api_url, params={"course_id": course_id}) + + if response.status_code == 200: + data = response.json() + stats = data.get("statistics", {}) + return ( + f"Course {course_id} statistics:\n" + f"Average: {stats.get('average', 'N/A')}\n" + f"Max: {stats.get('max', 'N/A')}\n" + f"Min: {stats.get('min', 'N/A')}\n" + f"Total Students: {stats.get('total_students', 'N/A')}" + ) + return f"API request failed: {response.status_code}" + +if __name__ == "__main__": + mcp.run(transport="sse", port=8000) ``` -### 5. Recover Remote MCP Servers (GET /api/mcp/recover) - -```mermaid -sequenceDiagram - participant C as Frontend Client - participant M as Main Service(5010) - participant A as Auth System - participant S as MCP Service Management - participant DB as Database - participant MCP as MCP Service(5011) - participant R as Remote MCP Service - - C->>M: GET /api/mcp/recover - Note over C,M: Authorization: Bearer token - - M->>A: get_current_user_id(authorization) - A-->>M: user_id, tenant_id - - M->>S: recover_remote_mcp_server(tenant_id) - - S->>DB: Query all tenant's MCP server configurations - DB-->>S: Server list in database (record_set) - - S->>MCP: GET /list-remote-proxies - MCP-->>S: Current proxy list in MCP service (remote_set) - - S->>S: Calculate difference (record_set - remote_set) - - loop For each missing MCP server - S->>S: add_remote_proxy(mcp_name, mcp_url) - S->>MCP: POST /add-remote-proxies - MCP->>R: Connect to remote MCP service - R-->>MCP: Connection response - - alt Addition successful - MCP-->>S: 200 - Successfully added - else Addition failed - MCP-->>S: Error response - S-->>M: Error result/JSONResponse - M-->>C: Error response (400) - Note over S,M: If any server recovery fails, entire operation fails - end - end - - S-->>M: Success result - M-->>C: JSON response {message: "Successfully recovered", status: "success"} +### 🏢 Wrap an Internal Module + +Integrate local business logic: + +```python +from fastmcp import FastMCP +from your_school_module import query_course_statistics + +mcp = FastMCP("Course Statistics Server") + +@mcp.tool( + name="get_course_statistics", + description="Get course statistics such as average, max, min, and total students" +) +def get_course_statistics(course_id: str) -> str: + try: + stats = query_course_statistics(course_id) + return ( + f"Course {course_id} statistics:\n" + f"Average: {stats.get('average', 'N/A')}\n" + f"Max: {stats.get('max', 'N/A')}\n" + f"Min: {stats.get('min', 'N/A')}\n" + f"Total Students: {stats.get('total_students', 'N/A')}" + ) + except Exception as exc: + return f"Failed to query statistics: {exc}" + +if __name__ == "__main__": + mcp.run(transport="sse", port=8000) ``` -## Sequence Diagram Explanation +## ✅ Best Practices -### Interface Classification +- **Logging**: For stdio transports, avoid stdout logging (no `print`); log to stderr/files. [Logging guidance](https://modelcontextprotocol.io/docs/develop/build-server#logging-in-mcp-servers) +- **Documentation**: Keep tool docstrings clear; FastMCP derives schema from type hints +- **Error Handling**: Handle errors gracefully and return user-friendly text +- **Security**: Do not hard-code secrets; load credentials from env/secret managers -#### 1. Direct Remote MCP Access Interfaces -- **GET /api/mcp/tools/**: Directly accesses remote MCP through tool configuration service for tool information -- Feature: Does not go through local MCP service (5011), directly connects to external MCP service +## 📚 Resources -#### 2. Local MCP Service Interaction Interfaces -- **POST /api/mcp/add**: Validates connection and adds proxy through MCP service -- **DELETE /api/mcp/**: Removes proxy through MCP service -- **GET /api/mcp/recover**: Recovers proxy connections through MCP service -- Feature: Requires interaction with local MCP service (5011), involves proxy lifecycle management +### 🐍 Python -#### 3. Database-Only Interfaces -- **GET /api/mcp/list**: Directly queries database for server list -- Feature: Simplest flow, only involves database queries +- [FastMCP Documentation](https://github.com/modelcontextprotocol/python-sdk) +- [Python SDK Repository](https://github.com/modelcontextprotocol/python-sdk) -### Common Flow Characteristics -1. **Authentication Flow**: Except for tool query interface, other interfaces require Bearer token authentication, obtaining user and tenant information through `get_current_user_id()` -2. **Multi-tenant Isolation**: All operations are isolated based on `tenant_id`, ensuring data security -3. **Error Handling**: Unified exception handling mechanism, returning standardized JSON error responses -4. **Proxy Architecture**: Main service acts as proxy, coordinating calls to various backend services +### 🔤 Other Languages -### Key Interaction Points -- **Authentication System**: Validates user identity and permissions -- **Database**: Stores and manages MCP server configuration information -- **MCP Service (5011)**: Handles MCP protocol interaction and proxy management -- **Tool Configuration Service**: Handles tool information acquisition -- **Remote MCP Service**: External MCP service providers +- [MCP TypeScript SDK](https://github.com/modelcontextprotocol/typescript-sdk) +- [MCP Java SDK](https://github.com/modelcontextprotocol/java-sdk) +- [MCP Go SDK](https://github.com/modelcontextprotocol/go-sdk) +- [MCP Rust SDK](https://github.com/modelcontextprotocol/rust-sdk) -### Operation Sequence Importance -- **Add Operations**: First validate MCP connection, save to database only after success (ensures data consistency) -- **Delete Operations**: First delete database records, then remove MCP proxy (prevents data residue) -- **Recovery Operations**: Compare database and MCP service differences, supplement missing proxies +### 📖 Official Documentation -## Error Handling +- [MCP Introduction](https://modelcontextprotocol.io/docs/getting-started/intro) +- [Build MCP Server Guide](https://modelcontextprotocol.io/docs/develop/build-server) +- [SDK Documentation](https://modelcontextprotocol.io/docs/sdk) +- [MCP Protocol Specification](https://modelcontextprotocol.io/) -- Validates remote service connection before adding proxies -- Provides detailed error information and status codes -- Supports graceful service unloading and reloading -- Dual-layer error handling: management layer and protocol layer +### 🔗 Related Guides -## Performance Optimization +- [Nexent Agent Development Guide](../../user-guide/agent-development) +- [MCP Tool Ecosystem Overview](../../mcp-ecosystem/overview) +- [MCP Recommendations](../../mcp-ecosystem/mcp-recommendations) -- Proxy services are loaded on demand -- Supports concurrent operations -- Minimizes impact on existing services -- Loosely coupled service design +## 🆘 Need Help? -## Security Features +If you run into issues while developing MCP servers: -- **Authentication & Authorization**: Main service supports Bearer token authentication -- **Multi-tenant Isolation**: Isolated management of MCP servers for different tenants -- **Connection Validation**: Performs connectivity verification before adding remote services \ No newline at end of file +1. Check the **[FAQ](../../quick-start/faq)** +2. Ask questions in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) +3. Review sample servers on the [ModelScope MCP Marketplace](https://www.modelscope.cn/mcp) diff --git a/doc/docs/en/backend/tools/nexent-native.md b/doc/docs/en/backend/tools/nexent-native.md new file mode 100644 index 000000000..475d354a9 --- /dev/null +++ b/doc/docs/en/backend/tools/nexent-native.md @@ -0,0 +1,26 @@ +--- +title: Nexent Native Tools +--- + +# Nexent Native Tools + +## 🧭 Scope + +Nexent native tools are developed and maintained in the official repository. If you need custom capabilities, contribute directly under `sdk/nexent/core/tools` following the existing tool patterns. + +## 🛠️ Development Guidelines + +- Build new tools alongside existing ones in `sdk/nexent/core/tools` (e.g., file, search, email, multimodal). +- Follow the conventions documented in [Tools](../../sdk/core/tools) for structure, inputs, and messaging. +- Keep comments/docstrings in English and align with repository rules. + +## 🤝 Contribution Path + +- Submit contributions to the Nexent official repo; external hosting is not supported for native tools. +- Reference existing implementations and [Contributing](../../contributing) for PR workflow and standards. + +## 🔗 Related References + +- [Tools](../../sdk/core/tools) +- [Contributing](../../contributing) + diff --git a/doc/docs/en/version/version-management.md b/doc/docs/en/backend/version-management.md similarity index 85% rename from doc/docs/en/version/version-management.md rename to doc/docs/en/backend/version-management.md index 2c416c0b3..bb9c98e95 100644 --- a/doc/docs/en/version/version-management.md +++ b/doc/docs/en/backend/version-management.md @@ -2,7 +2,7 @@ The Nexent project adopts a unified version management strategy to ensure consistency between frontend and backend version information. This document describes how to manage and update project version information. -## Version Number Format +## 📋 Version Number Format Nexent uses Semantic Versioning: @@ -12,21 +12,21 @@ Nexent uses Semantic Versioning: - **PATCH**: Backwards-compatible bug fixes - **BUILD**: Optional minor version number for more granular bugfix versions -### Version Number Examples +### 🏷️ Version Number Examples - `v1.2.0` - Feature update release - `v1.2.0.1` - Bugfix release with minor version number -## Frontend Version Management +## 🖥️ Frontend Version Management -### Version Information Location +### 📍 Version Information Location Frontend version information is fetched from the backend via API. - **Endpoint**: `GET /api/tenant_config/deployment_version` - **Service**: `frontend/services/versionService.ts` -### Version Update Process +### 🔄 Version Update Process 1. **Update backend version in code** @@ -47,16 +47,16 @@ APP_VERSION="v1.1.0" # Check the application version displayed at the bottom of the page ``` -### Version Display +### 📺 Version Display Frontend version information is displayed at the following location: - **Location**: Bottom navigation bar, located at the bottom left corner of the page. - **Version Format**: `v1.1.0` -## Backend Version Management +## ⚙️ Backend Version Management -### Version Information Location +### 📍 Version Information Location Backend version information is defined in code in `backend/consts/const.py`: @@ -65,11 +65,11 @@ Backend version information is defined in code in `backend/consts/const.py`: APP_VERSION = "v1.0.0" ``` -### Version Configuration +### 🔧 Version Configuration Version is configured directly in `backend/consts/const.py`. -### Version Display +### 📺 Version Display Backend startup will print version information in the logs: @@ -78,7 +78,7 @@ Backend startup will print version information in the logs: logger.info(f"APP version is: {APP_VERSION}") ``` -### Version Update Process +### 🔄 Version Update Process 1. **Update Version in Code** @@ -97,3 +97,4 @@ APP_VERSION="v1.1.0" # Check the version information in the startup logs # Output example: APP version is: v1.1.0 ``` + diff --git a/doc/docs/en/contributing.md b/doc/docs/en/contributing.md index 9e410b51f..c5c313c05 100644 --- a/doc/docs/en/contributing.md +++ b/doc/docs/en/contributing.md @@ -140,7 +140,7 @@ git checkout -b your-branch-name ``` ### Step 4 Make Your Changes -🧙♂️ Code like a wizard! Follow our [Development Guide](./getting-started/development-guide) for setup instructions and coding standards. Ensure your changes are well-tested and documented. +🧙♂️ Code like a wizard! Follow our [Developer Guide](./developer-guide/overview) for setup instructions and coding standards. Ensure your changes are well-tested and documented. ### Step 5 Commit Your Changes 📝 Commit with a clear and concise message following our commit message guidelines: diff --git a/doc/docs/en/deployment/docker-build.md b/doc/docs/en/deployment/docker-build.md index 06b6c0bfd..98174eca5 100644 --- a/doc/docs/en/deployment/docker-build.md +++ b/doc/docs/en/deployment/docker-build.md @@ -154,4 +154,10 @@ Notes: - ⚠️ `--load` can only be used with single architecture builds - 📝 Use `docker images` to verify the images are loaded locally - 📊 Use `--progress=plain` to see detailed build and push progress -- 📈 Use `--build-arg MIRROR=...` to set up a pip mirror to accelerate your build-up progress \ No newline at end of file +- 📈 Use `--build-arg MIRROR=...` to set up a pip mirror to accelerate your build-up progress + +## 🚀 Deployment Recommendations + +After building is complete, you can use the docker/deploy.sh script for deployment, or directly start the services using docker-compose. + +> When starting a test of locally built images, you need to change APP_VERSION="$(get_app_version)" to APP_VERSION="latest" in docker/deploy.sh, because the deployment will default to using the image corresponding to the current version. diff --git a/doc/docs/en/developer-guide/environment-setup.md b/doc/docs/en/developer-guide/environment-setup.md new file mode 100644 index 000000000..21f3cb6af --- /dev/null +++ b/doc/docs/en/developer-guide/environment-setup.md @@ -0,0 +1,142 @@ +--- +title: Environment Preparation +--- + +# Environment Preparation + +Use this guide to prepare your environment before developing with Nexent. It separates full-stack project setup from SDK-only workflows so you can follow the path that fits your role. + +## 🧱 Common Requirements + +- Python 3.10+ +- Node.js 18+ +- Docker & Docker Compose +- `uv` (Python package manager) +- `pnpm` (Node.js package manager) + +## 🧑💻 Full-Stack Nexent Development + +### 1. Infrastructure Deployment + +Before backend work, start core services (PostgreSQL, Redis, Elasticsearch, MinIO, etc.). + +```bash +# Run from the docker directory at the project root +cd docker +./deploy.sh --mode infrastructure +``` + +:::: info Important Notes +Infrastructure mode launches PostgreSQL, Redis, Elasticsearch, and MinIO. The script generates required credentials and saves them in the project root `.env`. URLs are configured as localhost endpoints for easy local development. +:::: + +### 2. Backend Setup + +```bash +# Run inside the backend directory +cd backend +uv sync --all-extras +uv pip install ../sdk +``` + +:::: tip Notes +`--all-extras` installs every optional dependency (data processing, testing, etc.). After syncing, install the local SDK package. +:::: + +#### Optional: Accelerate with Mirror Sources + +If downloads are slow, use domestic mirrors: + +```bash +# Tsinghua mirror +uv sync --all-extras --default-index https://pypi.tuna.tsinghua.edu.cn/simple +uv pip install ../sdk --default-index https://pypi.tuna.tsinghua.edu.cn/simple + +# Alibaba Cloud mirror +uv sync --all-extras --default-index https://mirrors.aliyun.com/pypi/simple/ +uv pip install ../sdk --default-index https://mirrors.aliyun.com/pypi/simple/ + +# Multiple mirrors (recommended) +uv sync --all-extras --index https://pypi.tuna.tsinghua.edu.cn/simple --index https://mirrors.aliyun.com/pypi/simple/ +uv pip install ../sdk --index https://pypi.tuna.tsinghua.edu.cn/simple --index https://mirrors.aliyun.com/pypi/simple/ +``` + +:::: info Mirror Source Reference +- Tsinghua: `https://pypi.tuna.tsinghua.edu.cn/simple` +- Alibaba Cloud: `https://mirrors.aliyun.com/pypi/simple/` +- USTC: `https://pypi.mirrors.ustc.edu.cn/simple/` +- Douban: `https://pypi.douban.com/simple/` +Using multiple mirrors improves success rates. +:::: + +### 3. Frontend Setup + +```bash +# Run inside the frontend directory +cd frontend +pnpm install +pnpm dev +``` + +### 4. Service Startup + +Activate the backend virtual environment before starting services. + +```bash +# Run inside backend directory +cd backend +source .venv/bin/activate +``` + +:::: warning Important Notes +On Windows, activate the environment with `source .venv/Scripts/activate`. +:::: + +Start the backend services from the project root, in order: + +```bash +# Always run from project root with environment variables loaded +source .env && python backend/mcp_service.py +source .env && python backend/data_process_service.py +source .env && python backend/config_service.py +source .env && python backend/runtime_service.py +``` + +:::: warning Important Notes +Each command must run from the project root and be prefixed with `source .env`. Ensure databases, Redis, Elasticsearch, and MinIO (from infrastructure mode) are healthy first. +:::: + +## 🧰 SDK-Only Development + +If you only need the SDK (without running the entire stack), install it directly. + +### 1. Install from Source + +```bash +git clone https://github.com/ModelEngine-Group/nexent.git +cd nexent/sdk +uv pip install -e . +``` + +### 2. Install with uv + +```bash +uv add nexent +``` + +### 3. Development Extras + +For SDK contributors, install with development dependencies: + +```bash +cd nexent/sdk +uv pip install -e ".[dev]" +``` + +This adds: + +- Code quality tools (ruff) +- Testing framework (pytest) +- Data processing dependencies (unstructured) +- Other developer utilities + diff --git a/doc/docs/en/getting-started/development-guide.md b/doc/docs/en/developer-guide/overview.md similarity index 56% rename from doc/docs/en/getting-started/development-guide.md rename to doc/docs/en/developer-guide/overview.md index 3721a8813..51307d7d7 100644 --- a/doc/docs/en/getting-started/development-guide.md +++ b/doc/docs/en/developer-guide/overview.md @@ -40,100 +40,15 @@ nexent/ - **Monitoring**: Built-in health checks - **Logging**: Structured logging -## 🚀 Development Environment Setup - -### Environment Requirements -- Python 3.10+ -- Node.js 18+ -- Docker & Docker Compose -- uv (Python package manager) -- pnpm (Node.js package manager) - -### Infrastructure Deployment -Before starting backend development, you need to deploy infrastructure services. These services include databases, caching, file storage, and other core components. - -```bash -# Execute in the docker directory of the project root directory -cd docker -./deploy.sh --mode infrastructure -``` - -::: info Important Notes -Infrastructure mode will start PostgreSQL, Redis, Elasticsearch, and MinIO services. The deployment script will automatically generate keys and environment variables needed for development and save them to the `.env` file in the root directory. Generated keys include MinIO access keys and Elasticsearch API keys. All service URLs will be configured as localhost addresses for convenient local development. -::: - -### Backend Setup -```bash -# Execute in the backend directory of the project root directory -cd backend -uv sync --all-extras -uv pip install ../sdk -``` - -::: tip Notes -`--all-extras` will install all optional dependencies, including data processing, testing, and other modules. Then install the local SDK package. -::: +## 🧱 Environment Preparation -#### Using Domestic Mirror Sources (Optional) -If network access is slow, you can use domestic mirror sources to accelerate installation: - -```bash -# Using Tsinghua University mirror source -uv sync --all-extras --default-index https://pypi.tuna.tsinghua.edu.cn/simple -uv pip install ../sdk --default-index https://pypi.tuna.tsinghua.edu.cn/simple - -# Using Alibaba Cloud mirror source -uv sync --all-extras --default-index https://mirrors.aliyun.com/pypi/simple/ -uv pip install ../sdk --default-index https://mirrors.aliyun.com/pypi/simple/ - -# Using multiple mirror sources (recommended) -uv sync --all-extras --index https://pypi.tuna.tsinghua.edu.cn/simple --index https://mirrors.aliyun.com/pypi/simple/ -uv pip install ../sdk --index https://pypi.tuna.tsinghua.edu.cn/simple --index https://mirrors.aliyun.com/pypi/simple/ -``` +All setup steps now live in the dedicated [Environment Preparation](./environment-setup) guide. It covers: -::: info Mirror Source Information -- **Tsinghua University Mirror**: `https://pypi.tuna.tsinghua.edu.cn/simple` -- **Alibaba Cloud Mirror**: `https://mirrors.aliyun.com/pypi/simple/` -- **USTC Mirror**: `https://pypi.mirrors.ustc.edu.cn/simple/` -- **Douban Mirror**: `https://pypi.douban.com/simple/` - -It's recommended to use multiple mirror source configurations to improve download success rates. -::: - -### Frontend Setup -```bash -# Execute in the frontend directory of the project root directory -cd frontend -pnpm install -pnpm dev -``` - -### Service Startup -Before starting services, you need to activate the virtual environment: - -```bash -# Execute in the backend directory of the project root directory -cd backend -source .venv/bin/activate # Activate virtual environment -``` - -::: warning Important Notes -On Windows, you need to execute the `source .venv/Scripts/activate` command to activate the virtual environment. -::: - -Nexent includes three core backend services that need to be started separately: - -```bash -# Execute in the project root directory, please follow this order: -source .env && python backend/mcp_service.py # MCP service -source .env && python backend/data_process_service.py # Data processing service -source .env && python backend/config_service.py # Config service -source .env && python backend/runtime_service.py # Runtime service -``` +- Shared prerequisites for every contributor +- Full-stack Nexent setup (infrastructure, backend, frontend, runtime services) +- SDK-only installation workflows for developers who only need the Python package -::: warning Important Notes -All services must be started from the project root directory. Each Python command should be preceded by `source .env` to load environment variables. Ensure infrastructure services (database, Redis, Elasticsearch, MinIO) are started and running properly. -::: +Review that guide first, then return here for module-specific details. ## 🔧 Development Module Guide @@ -208,8 +123,8 @@ For detailed build instructions, see [Docker Build Guide](../deployment/docker-b ## 💡 Getting Help ### Documentation Resources -- [Installation Guide](./installation.md) - Environment setup and deployment -- [FAQ](./faq) - Frequently asked questions +- [Installation Guide](../quick-start/installation) - Environment setup and deployment +- [FAQ](../quick-start/faq) - Frequently asked questions - [User Guide](../user-guide/home-page) - Nexent user guide ### Community Support diff --git a/doc/docs/en/frontend/overview.md b/doc/docs/en/frontend/overview.md index bfccad3bc..8949152f3 100644 --- a/doc/docs/en/frontend/overview.md +++ b/doc/docs/en/frontend/overview.md @@ -128,4 +128,4 @@ npm start - Voice processing integration - Analytics and monitoring -For detailed development guidelines and component documentation, see the [Development Guide](../getting-started/development-guide). \ No newline at end of file +For detailed development guidelines and component documentation, see the [Developer Guide](../developer-guide/overview). \ No newline at end of file diff --git a/doc/docs/en/getting-started/overview.md b/doc/docs/en/getting-started/overview.md index 741f5cb45..560b53510 100644 --- a/doc/docs/en/getting-started/overview.md +++ b/doc/docs/en/getting-started/overview.md @@ -70,9 +70,9 @@ For detailed architectural design and technical implementation, see our **[Softw Ready to get started? Here are your next steps: -1. **📋 [Installation & Deployment](./installation)** - System requirements and deployment guide -2. **🔧 [Development Guide](./development-guide)** - Build from source and customize -3. **❓ [FAQ](./faq)** - Common questions and troubleshooting +1. **📋 [Installation & Deployment](../quick-start/installation)** - System requirements and deployment guide +2. **🔧 [Developer Guide](../developer-guide/overview)** - Build from source and customize +3. **❓ [FAQ](../quick-start/faq)** - Common questions and troubleshooting ## 💬 Community & contact diff --git a/doc/docs/en/known-issues.md b/doc/docs/en/known-issues.md deleted file mode 100644 index 3f1b09ced..000000000 --- a/doc/docs/en/known-issues.md +++ /dev/null @@ -1,41 +0,0 @@ -# Known Issues - -This page lists known issues and limitations in the current version of Nexent. We are actively working on fixing these issues and will update this page as solutions become available. - -## 🐛 Current Issues - -### 1. OpenSSH Container Software Installation Limitations - -**Issue**: Installing additional software packages in the OpenSSH container for Terminal tool usage is currently challenging due to container constraints. - -**Status**: Under development - -**Impact**: Users who need custom tools or packages in the Terminal environment may face limitations. - -**Planned Solution**: We are working on providing improved containers and documentation to make customization easier. This will include better package management and more flexible container configurations. - -**Expected Timeline**: Improved container support planned for upcoming releases. - -## 📝 Reporting Issues - -If you encounter any issues not listed here, please: - -1. **Check our [FAQ](./getting-started/faq)** for common solutions -2. **Search existing issues** on [GitHub Issues](https://github.com/ModelEngine-Group/nexent/issues) -3. **Create a new issue** with detailed information including: - - Steps to reproduce - - Expected behavior - - Actual behavior - - System information - - Log files (if applicable) - -## 🔄 Issue Status Updates - -We regularly update this page with the status of known issues. Check back frequently for updates, or watch our [GitHub repository](https://github.com/ModelEngine-Group/nexent) for notifications. - -## 💬 Community Support - -For immediate help or discussion about issues: -- Join our [Discord community](https://discord.gg/tb5H3S3wyv) -- Ask questions in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) -- Check our [Contributing Guide](./contributing) if you'd like to help fix issues \ No newline at end of file diff --git a/doc/docs/en/mcp-ecosystem/mcp-recommendations.md b/doc/docs/en/mcp-ecosystem/mcp-recommendations.md new file mode 100644 index 000000000..24d239b7b --- /dev/null +++ b/doc/docs/en/mcp-ecosystem/mcp-recommendations.md @@ -0,0 +1,36 @@ +# MCP Recommendations + +This page provides curated recommendations for MCP platforms and tools to help you quickly discover high-quality MCP services. + +## 🌐 MCP Community Hub + +The global MCP ecosystem is thriving with multiple platforms supporting MCP development and deployment: + +| Platform | Description | Notes | +|----------|-------------|-------| +| **[GitHub MCP Server](https://github.com/github/github-mcp-server)** | Deep integration with Claude, GPT-4, Copilot etc., supports Go and Python | OAuth/GitHub account authorization | +| **[Qdrant MCP Vector Server](https://github.com/qdrant/mcp-server-qdrant)** | Semantic vector storage with Python/Go compatibility | Compatible with LangChain and other tools | +| **[Anthropic Reference MCP Servers](https://github.com/modelcontextprotocol/servers)** | Lightweight teaching and prototyping tools, Python | Includes fetch, git and other universal tools | +| **[AWS Labs MCP Server](https://github.com/awslabs/mcp)** | AWS+Go+CDK cloud reference services | Suitable for cloud environments | +| **[MCP Hub China](https://www.mcp-cn.com/)** | Chinese curated high-quality MCP service platform | Focuses on quality over quantity, community-driven | +| **[ModelScope MCP Marketplace](https://modelscope.cn/mcp)** | China's largest MCP community with 1,500+ services | From Amap to Alipay, comprehensive service coverage | +| **Community MCP Servers** | Various scenario-specific source code collection | Mostly experimental and innovative tools | + +## 🔧 Recommended MCP Tools + +| Tool Name | Function | Description | +|-----------|----------|-------------| +| **[Amap Maps](https://modelscope.cn/mcp/servers/@amap/amap-maps)** | Geographic services and navigation | Comprehensive mapping, geocoding, routing, and location services | +| **[Bing Search (Chinese)](https://modelscope.cn/mcp/servers/@yan5236/bing-cn-mcp-server)** | Web search in Chinese | Optimized Chinese web search and information retrieval | +| **[12306 Train Ticket Query](https://modelscope.cn/mcp/servers/@Joooook/12306-mcp)** | China railway ticket booking | Real-time train schedules, ticket availability, and booking assistance | +| **[Alipay MCP](https://modelscope.cn/mcp/servers/@alipay/mcp-server-alipay)** | Payment and financial services | Digital payments, financial tools, and services integration | +| **[Variflight Aviation](https://modelscope.cn/mcp/servers/@variflight-ai/variflight-mcp)** | Flight information and aviation data | Real-time flight tracking, schedules, and aviation analytics | +| **[Sequential Thinking](https://modelscope.cn/mcp/servers/@modelcontextprotocol/sequentialthinking)** | Structured problem-solving framework | Break down complex problems into manageable, sequential steps | +| **[ArXiv AI Search](https://modelscope.cn/mcp/servers/@blazickjp/arxiv-mcp-server)** | Academic paper search and research | Advanced search and retrieval of scientific papers and research | +| **[Firecrawl MCP Server](https://modelscope.cn/mcp/servers/@mendableai/firecrawl-mcp-server)** | Web scraping and content extraction | Intelligent web scraping, data extraction, and content processing | + +## 🔗 Related Resources + +- [MCP Ecosystem Overview](./overview) +- [MCP Tools Integration Guide](../backend/tools/mcp) +- [Use Cases](./use-cases) diff --git a/doc/docs/en/mcp-ecosystem/mcp-server-development.md b/doc/docs/en/mcp-ecosystem/mcp-server-development.md deleted file mode 100644 index 06602b978..000000000 --- a/doc/docs/en/mcp-ecosystem/mcp-server-development.md +++ /dev/null @@ -1,200 +0,0 @@ -# MCP Server Development Guide - -This guide walks you through building your own MCP server with Python and the FastMCP framework, then connecting it to the Nexent platform. - -## 🌐 Language Support - -The MCP protocol provides SDKs for multiple programming languages: - -- **Python** ⭐ (recommended) -- **TypeScript** -- **Java** -- **Go** -- **Rust** -- Any other language that implements the MCP protocol - -### Why Do We Recommend Python? - -We use **Python** for the examples in this guide because it offers: - -- ✅ **Beginner-friendly syntax**: concise code that is easy to read -- ✅ **Rich ecosystem**: frameworks like FastMCP remove most boilerplate -- ✅ **Rapid prototyping**: you can stand up a working server in minutes -- ✅ **Mature libraries**: thousands of third-party packages are available - -If you are already comfortable in another language, feel free to use the corresponding MCP SDK. For a first MCP server, however, Python gives you the smoothest experience. - -## 📋 Prerequisites - -Install FastMCP before you start coding: - -```bash -pip install fastmcp -``` - -## 🚀 Quick Start - -### Minimal Example - -The snippet below creates a simple string utility server with FastMCP: - -```python -from fastmcp import FastMCP - -# Create an MCP server instance -mcp = FastMCP(name="String MCP Server") - -@mcp.tool( - name="calculate_string_length", - description="Calculate the length of a string" -) -def calculate_string_length(text: str) -> int: - return len(text) - -@mcp.tool( - name="to_uppercase", - description="Convert text to uppercase" -) -def to_uppercase(text: str) -> str: - return text.upper() - -@mcp.tool( - name="to_lowercase", - description="Convert text to lowercase" -) -def to_lowercase(text: str) -> str: - return text.lower() - -if __name__ == "__main__": - # Start with SSE transport - mcp.run(transport="sse", port=8000) -``` - -### Run the Server - -Save the code as `mcp_server.py` and execute: - -```bash -python mcp_server.py -``` - -You should see the server start successfully with the endpoint `http://127.0.0.1:8000/sse`. - -## 🔌 Integrate MCP Services with Nexent - -Once your MCP server is up, connect it to Nexent: - -### Step 1: Start the MCP Server - -Keep the server process running and note the public endpoint (for example `http://127.0.0.1:8000/sse`). - -### Step 2: Register the MCP Service in Nexent - -1. Open the **[Agent Development](../user-guide/agent-development.md)** page. -2. On the “Select Agent Tools” tab, click **MCP Configuration** on the right. -3. Enter the server name and server URL. - - ⚠️ **Important**: - 1. The server name must contain only letters and digits—no spaces or other symbols. - 2. When Nexent runs inside Docker and the MCP server runs on the host, replace `127.0.0.1` with `host.docker.internal`, for example `http://host.docker.internal:8000`. -4. Click **Add** to finish the registration. - -### Step 3: Use the MCP Tool - -During agent creation or editing, the newly registered MCP tool appears in the tool list and can be attached to any agent. - -## 🔧 Wrap Existing Workloads - -To expose existing business logic as MCP tools, call your internal APIs or libraries inside the tool functions. - -### Example: Wrap a REST API - -```python -from fastmcp import FastMCP -import requests - -# Create an MCP server instance -mcp = FastMCP("Course Statistics Server") - -@mcp.tool( - name="get_course_statistics", - description="Get course statistics such as average, max, min, and total students" -) -def get_course_statistics(course_id: str) -> str: - api_url = "https://your-school-api.com/api/courses/statistics" - response = requests.get(api_url, params={"course_id": course_id}) - - if response.status_code == 200: - data = response.json() - stats = data.get("statistics", {}) - return ( - f"Course {course_id} statistics:\n" - f"Average: {stats.get('average', 'N/A')}\n" - f"Max: {stats.get('max', 'N/A')}\n" - f"Min: {stats.get('min', 'N/A')}\n" - f"Total Students: {stats.get('total_students', 'N/A')}" - ) - return f"API request failed: {response.status_code}" - -if __name__ == "__main__": - # Start with SSE transport - mcp.run(transport="sse", port=8000) -``` - -### Example: Wrap an Internal Module - -```python -from fastmcp import FastMCP -from your_school_module import query_course_statistics - -# Create an MCP server instance -mcp = FastMCP("Course Statistics Server") - -@mcp.tool( - name="get_course_statistics", - description="Get course statistics such as average, max, min, and total students" -) -def get_course_statistics(course_id: str) -> str: - try: - stats = query_course_statistics(course_id) - return ( - f"Course {course_id} statistics:\n" - f"Average: {stats.get('average', 'N/A')}\n" - f"Max: {stats.get('max', 'N/A')}\n" - f"Min: {stats.get('min', 'N/A')}\n" - f"Total Students: {stats.get('total_students', 'N/A')}" - ) - except Exception as exc: - return f"Failed to query statistics: {exc}" - -if __name__ == "__main__": - # Start with SSE transport - mcp.run(transport="sse", port=8000) -``` - -## 📚 Additional Resources - -### Python - -- [FastMCP Documentation](https://github.com/modelcontextprotocol/python-sdk) (used throughout this guide) - -### Other Languages - -- [MCP TypeScript SDK](https://github.com/modelcontextprotocol/typescript-sdk) -- [MCP Java SDK](https://github.com/modelcontextprotocol/java-sdk) -- [MCP Go SDK](https://github.com/modelcontextprotocol/go-sdk) -- [MCP Rust SDK](https://github.com/modelcontextprotocol/rust-sdk) - -### General References - -- [MCP Protocol Specification](https://modelcontextprotocol.io/) -- [Nexent Agent Development Guide](../user-guide/agent-development.md) -- [MCP Tool Ecosystem Overview](./overview.md) - -## 🆘 Need Help? - -If you run into issues while developing MCP servers: - -1. Check the **[FAQ](../getting-started/faq.md)** -2. Ask questions in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) -3. Review sample servers on the [ModelScope MCP Marketplace](https://www.modelscope.cn/mcp) - diff --git a/doc/docs/en/mcp-ecosystem/overview.md b/doc/docs/en/mcp-ecosystem/overview.md index a160507f5..cdbe3056d 100644 --- a/doc/docs/en/mcp-ecosystem/overview.md +++ b/doc/docs/en/mcp-ecosystem/overview.md @@ -2,36 +2,16 @@ Nexent is built on the Model Context Protocol (MCP) tool ecosystem, providing a flexible and extensible framework for integrating various tools and services. MCP serves as the "USB-C of AI" - a universal interface standard that allows AI agents to seamlessly connect with external data sources, tools, and services. -## What is MCP? +## 📖 What is MCP? The Model Context Protocol (MCP) is an open protocol that enables AI applications to securely connect to external data sources and tools. It provides a standardized way for AI models to access and interact with external systems, making it easier to build powerful, context-aware AI applications. -## MCP Community Hub - -The global MCP ecosystem is thriving with multiple platforms supporting MCP development and deployment: - -| Platform | Description | Notes | -|----------|-------------|-------| -| **[GitHub MCP Server](https://github.com/github/github-mcp-server)** | Deep integration with Claude, GPT-4, Copilot etc., supports Go and Python | OAuth/GitHub account authorization | -| **[Qdrant MCP Vector Server](https://github.com/qdrant/mcp-server-qdrant)** | Semantic vector storage with Python/Go compatibility | Compatible with LangChain and other tools | -| **[Anthropic Reference MCP Servers](https://github.com/modelcontextprotocol/servers)** | Lightweight teaching and prototyping tools, Python | Includes fetch, git and other universal tools | -| **[AWS Labs MCP Server](https://github.com/awslabs/mcp)** | AWS+Go+CDK cloud reference services | Suitable for cloud environments | -| **[MCP Hub China](https://www.mcp-cn.com/)** | Chinese curated high-quality MCP service platform | Focuses on quality over quantity, community-driven | -| **[ModelScope MCP Marketplace](https://modelscope.cn/mcp)** | China's largest MCP community with 1,500+ services | From Amap to Alipay, comprehensive service coverage | -| **Community MCP Servers** | Various scenario-specific source code collection | Mostly experimental and innovative tools | - -## Recommended MCP Tools - -| Tool Name | Function | Description | -|-----------|----------|-------------| -| **[Amap Maps](https://modelscope.cn/mcp/servers/@amap/amap-maps)** | Geographic services and navigation | Comprehensive mapping, geocoding, routing, and location services | -| **[Bing Search (Chinese)](https://modelscope.cn/mcp/servers/@yan5236/bing-cn-mcp-server)** | Web search in Chinese | Optimized Chinese web search and information retrieval | -| **[12306 Train Ticket Query](https://modelscope.cn/mcp/servers/@Joooook/12306-mcp)** | China railway ticket booking | Real-time train schedules, ticket availability, and booking assistance | -| **[Alipay MCP](https://modelscope.cn/mcp/servers/@alipay/mcp-server-alipay)** | Payment and financial services | Digital payments, financial tools, and services integration | -| **[Variflight Aviation](https://modelscope.cn/mcp/servers/@variflight-ai/variflight-mcp)** | Flight information and aviation data | Real-time flight tracking, schedules, and aviation analytics | -| **[Sequential Thinking](https://modelscope.cn/mcp/servers/@modelcontextprotocol/sequentialthinking)** | Structured problem-solving framework | Break down complex problems into manageable, sequential steps | -| **[ArXiv AI Search](https://modelscope.cn/mcp/servers/@blazickjp/arxiv-mcp-server)** | Academic paper search and research | Advanced search and retrieval of scientific papers and research | -| **[Firecrawl MCP Server](https://modelscope.cn/mcp/servers/@mendableai/firecrawl-mcp-server)** | Web scraping and content extraction | Intelligent web scraping, data extraction, and content processing | +## 🎯 MCP Platforms and Tools + +For curated recommendations of MCP platforms and tools, please visit our [MCP Recommendations](./mcp-recommendations) page, which includes: + +- **MCP Community Hub**: Discover global MCP platforms and marketplaces +- **Recommended MCP Tools**: Explore high-quality MCP services for various use cases ## Benefits of MCP diff --git a/doc/docs/en/getting-started/faq.md b/doc/docs/en/quick-start/faq.md similarity index 80% rename from doc/docs/en/getting-started/faq.md rename to doc/docs/en/quick-start/faq.md index 4311385f1..c0b8a2013 100644 --- a/doc/docs/en/getting-started/faq.md +++ b/doc/docs/en/quick-start/faq.md @@ -1,6 +1,6 @@ # Nexent FAQ -This FAQ addresses common questions and issues you might encounter while installing and using Nexent. For the basic installation steps, please refer to the [Installation & Development](./installation). For basic using instructions, please refer to the [User Guide](../user-guide/home-page). +This FAQ addresses common questions and issues you might encounter while installing and using Nexent. For the basic installation steps, please refer to the [Installation & Deployment](./installation). For basic using instructions, please refer to the [User Guide](../user-guide/home-page). ## 🚫 Common Errors & Operations @@ -57,18 +57,28 @@ This FAQ addresses common questions and issues you might encounter while install - **Q: Multi-turn chats fail when using the official DeepSeek API. How can I resolve this?** - A: The official DeepSeek API only accepts text payloads, but Nexent sends multimodal payloads, so multi-turn calls are rejected. Use a provider such as SiliconFlow that exposes DeepSeek models with multimodal compatibility. Our requests look like: + ```python { "role":"user", "content":[ { "type":"text", "text":"prompt" } ] } ``` + whereas DeepSeek expects: + ```python { "role":"user", "content":"prompt" } ``` +## 🐛 Known Issues & Feedback + +If you encounter any issues or want to check the latest status of known issues, please visit: + +- **Search similar issues**: [GitHub Issues](https://github.com/ModelEngine-Group/nexent/issues) - Search here to see if a similar issue has already been reported +- **Discuss issues**: [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) - Discuss problems and solutions with the community here + ## 💡 Need Help If your question isn't answered here: - Join our [Discord community](https://discord.gg/tb5H3S3wyv) for real-time support - Check our [GitHub Issues](https://github.com/ModelEngine-Group/nexent/issues) for similar problems -- Open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). +- Open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) diff --git a/doc/docs/en/getting-started/installation.md b/doc/docs/en/quick-start/installation.md similarity index 98% rename from doc/docs/en/getting-started/installation.md rename to doc/docs/en/quick-start/installation.md index 8026ecf47..c7115a3cd 100644 --- a/doc/docs/en/getting-started/installation.md +++ b/doc/docs/en/quick-start/installation.md @@ -95,4 +95,4 @@ For complete port mapping details, see our [Dev Container Guide](../deployment/d Want to build from source or add new features? Check the [Docker Build Guide](../deployment/docker-build) for step-by-step instructions. -For detailed setup instructions and customization options, see our [Development Guide](./development-guide). \ No newline at end of file +For detailed setup instructions and customization options, see our [Developer Guide](../developer-guide/overview). \ No newline at end of file diff --git a/doc/docs/en/deployment/upgrade-guide.md b/doc/docs/en/quick-start/upgrade-guide.md similarity index 57% rename from doc/docs/en/deployment/upgrade-guide.md rename to doc/docs/en/quick-start/upgrade-guide.md index 055ed3769..497212e06 100644 --- a/doc/docs/en/deployment/upgrade-guide.md +++ b/doc/docs/en/quick-start/upgrade-guide.md @@ -2,18 +2,62 @@ ## 🚀 Upgrade Overview -Follow these four steps to upgrade Nexent safely: +Follow these steps to upgrade Nexent safely: -1. Clean up existing containers and images -2. Pull the latest code and run the deployment script -3. Apply database migrations -4. Verify the deployment in your browser +1. Pull the latest code +2. Execute the upgrade script +3. Open the site to confirm service availability --- -## 🧹 Step 1: Clean up old images +## 🔄 Step 1: Update Code -Remove cached resources to avoid conflicts when redeploying: +Before updating, record the current deployment version and data directory information. + +- Current Deployment Version Location: APP_VERSION in backend/consts/const.py +- Data Directory Location: ROOT_DIR in docker/.env + +**Code downloaded via git** + +Update the code using git commands: + +```bash +git pull +``` + +**Code downloaded via ZIP package or other means** + +1. Re-download the latest code from GitHub and extract it. +2. If it exists, copy the deploy.options file from the docker directory of your previous deployment script directory to the docker directory of the new code directory. (If the file doesn't exist, you can ignore this step). + +## 🔄 Step 2: Execute the Upgrade + +Navigate to the docker directory of the updated code and run the upgrade script: + +```bash +bash upgrade.sh +``` + +If deploy.options is missing, the script will prompt you to manually enter configuration details from the previous deployment, such as the current version and data directory. Enter the information you recorded earlier. + +>💡 Tip +> The default scenario is quick deployment, which uses .env.example. +> If you need to configure voice models (STT/TTS), please add the relevant variables to .env.example in advance. We will provide a front-end configuration interface as soon as possible. + + +## 🌐 Step 3: Verify the deployment + +After deployment: + +1. Open `http://localhost:3000` in your browser. +2. Review the [User Guide](https://doc.nexent.tech/en/user-guide/home-page) to validate agent functionality. + + +## Optional Operations + +### 🧹 Clean Up Old Version Images + +If images were not updated correctly, you can clean up old containers and images before upgrading: ```bash # Stop and remove existing containers @@ -38,24 +82,9 @@ docker system prune -af --- -## 🔄 Step 2: Update code and redeploy - -```bash -git pull -cd nexent/docker -cp .env.example .env -bash deploy.sh -``` - -> 💡 Tip -> - `.env.example` works for default deployments. -> - Configure speech models (STT/TTS) in `.env` when needed. A frontend configuration flow is coming soon. - ---- - -## 🗄️ Step 3: Apply database migrations +## 🗄️ Manual Database Update -Run the SQL scripts shipped with each release to keep your schema up to date. +If some SQL files fail to execute during the upgrade, you can perform the update manually. ### ✅ Method A: Use a SQL editor (recommended) @@ -68,8 +97,8 @@ Run the SQL scripts shipped with each release to keep your schema up to date. - Password 3. Test the connection. When successful, you should see tables under the `nexent` schema. 4. Open a new query window. -5. Navigate to `/nexent/docker/sql`. Each file contains one migration script with its release date in the filename. -6. Execute every script dated after your previous deployment, in chronological order. +5. Navigate to the /nexent/docker/sql directory and open the failed SQL file(s) to view the script. +6. Execute the failed SQL file(s) and any subsequent version SQL files in order. > ⚠️ Important > - Always back up the database first, especially in production. @@ -97,14 +126,12 @@ Run the SQL scripts shipped with each release to keep your schema up to date. 3. Execute SQL files sequentially (host machine example): ```bash - # Example: If today is November 6th and your last update was on October 20th, - # and there are two new files 1030-update.sql and 1105-update.sql, # execute the following commands (please replace the placeholders with your actual values) - docker exec -i nexent-postgresql psql -U [YOUR_POSTGRES_USER] -d [YOUR_POSTGRES_DB] < ./sql/1030-update.sql - docker exec -i nexent-postgresql psql -U [YOUR_POSTGRES_USER] -d [YOUR_POSTGRES_DB] < ./sql/1105-update.sql + docker exec -i nexent-postgresql psql -U [YOUR_POSTGRES_USER] -d [YOUR_POSTGRES_DB] < ./sql/v1.1.1_1030-update.sql + docker exec -i nexent-postgresql psql -U [YOUR_POSTGRES_USER] -d [YOUR_POSTGRES_DB] < ./sql/v1.1.2_1105-update.sql ``` - Execute the scripts in chronological order based on your deployment date. + Execute the corresponding scripts for your deployment versions in version order. > 💡 Tips > - Load environment variables first if they are defined in `.env`: @@ -126,14 +153,3 @@ Run the SQL scripts shipped with each release to keep your schema up to date. > ```bash > docker exec -i nexent-postgres pg_dump -U [YOUR_POSTGRES_USER] [YOUR_POSTGRES_DB] > backup_$(date +%F).sql > ``` - ---- - -## 🌐 Step 4: Verify the deployment - -After deployment: - -1. Open `http://localhost:3000` in your browser. -2. Review the [User Guide](https://doc.nexent.tech/en/user-guide/home-page) to validate agent functionality. - - diff --git a/doc/docs/en/sdk/basic-usage.md b/doc/docs/en/sdk/basic-usage.md index efdca98a6..a6e67c89a 100644 --- a/doc/docs/en/sdk/basic-usage.md +++ b/doc/docs/en/sdk/basic-usage.md @@ -2,35 +2,7 @@ This guide provides a comprehensive introduction to using the Nexent SDK for building intelligent agents. -## 🚀 Installation - -### User Installation -If you want to use Nexent: - -```bash -# Recommended: Install from source -git clone https://github.com/ModelEngine-Group/nexent.git -cd nexent/sdk -uv pip install -e . - -# Or install using uv -uv add nexent -``` - -### Development Environment Setup -If you are a third-party SDK developer: - -```bash -# Install complete development environment (including Nexent) -cd nexent/sdk -uv pip install -e ".[dev]" # Includes all development tools (testing, code quality checks, etc.) -``` - -The development environment includes the following additional features: -- Code quality checking tools (ruff) -- Testing framework (pytest) -- Data processing dependencies (unstructured) -- Other development dependencies +> Installation options for both full-stack and SDK-only workflows are documented in [Environment Preparation](../developer-guide/environment-setup). ## ⚡ Quick Start @@ -100,11 +72,7 @@ agent.run("Your question here") ## 📡 Using agent_run (recommended for streaming) -When you need to consume messages as an "event stream" on server or client, use `agent_run`. It executes the agent in a background thread and continuously yields JSON messages, making it easy to render in UIs and collect logs. - -Reference: [Run agent with agent_run](./core/agent-run) - -Minimal example: +When you need server/client event streams, use `agent_run`. It runs the agent in a background thread and yields JSON strings from `MessageObserver`, so UIs can render incremental updates. ```python import json @@ -144,11 +112,54 @@ async def main(): async for message in agent_run(agent_run_info): message_data = json.loads(message) - print(message_data) + print(message_data) # each message is a JSON string asyncio.run(main()) ``` +### 🛰️ Stream message format + +Each yielded JSON string typically contains: + +- `type`: message type (maps to `ProcessType`, e.g., `STEP_COUNT`, `MODEL_OUTPUT_THINKING`, `PARSE`, `EXECUTION_LOGS`, `FINAL_ANSWER`, `ERROR`) +- `content`: text payload +- `agent_name` (optional): which agent emitted the message + +### 🧠 Chat history (optional) + +Pass history to keep context: + +```python +from nexent.core.agents.agent_model import AgentHistory + +history = [ + AgentHistory(role="user", content="Hi"), + AgentHistory(role="assistant", content="Hello!"), +] + +agent_run_info = AgentRunInfo( + # ... + history=history, +) +``` + +### 🌐 MCP tool integration (optional) + +Provide MCP endpoints to auto-load remote tools: + +```python +agent_run_info = AgentRunInfo( + # ... + mcp_host=["http://localhost:3000"], # or dict with url/transport +) +``` + +### ⏹️ Interrupt gracefully + +```python +stop_event.set() # agent stops after the current step finishes +``` + ## 🔧 Configuration Options ### ⚙️ Agent Configuration @@ -176,7 +187,7 @@ search_tool = ExaSearchTool( ## 📚 More Resources -- **[Run agent with agent_run](./core/agent-run)** +- **[Streaming with agent_run](#using-agent_run-recommended-for-streaming)** - **[Tool Development Guide](./core/tools)** - **[Model Architecture Guide](./core/models)** - **[Agents](./core/agents)** \ No newline at end of file diff --git a/doc/docs/en/sdk/core/agent-run.md b/doc/docs/en/sdk/core/agent-run.md deleted file mode 100644 index 2eb8ee40e..000000000 --- a/doc/docs/en/sdk/core/agent-run.md +++ /dev/null @@ -1,166 +0,0 @@ -# Run agent with agent_run (Streaming) - -`agent_run` provides a concise and thread-friendly way to run an agent while exposing real-time streaming output via `MessageObserver`. It is ideal for server-side or frontend event stream rendering, as well as MCP tool integration scenarios. - -## Quick Start - -```python -import json -import asyncio -import logging -from threading import Event - -from nexent.core.agents.run_agent import agent_run -from nexent.core.agents.agent_model import ( - AgentRunInfo, - AgentConfig, - ModelConfig -) -from nexent.core.utils.observer import MessageObserver - - -async def main(): - # 1) Create message observer (for receiving streaming messages) - observer = MessageObserver(lang="en") - - # 2) External stop flag (useful to interrupt from UI) - stop_event = Event() - - # 3) Configure model - model_config = ModelConfig( - cite_name="gpt-4", # Model alias (custom, referenced by AgentConfig) - api_key="", - model_name="Qwen/Qwen2.5-32B-Instruct", - url="https://api.siliconflow.cn/v1", - temperature=0.3, - top_p=0.9 - ) - - # 4) Configure Agent - agent_config = AgentConfig( - name="example_agent", - description="An example agent that can execute Python code and search the web", - prompt_templates=None, - tools=[], - max_steps=5, - model_name="gpt-4", # Corresponds to model_config.cite_name - provide_run_summary=False, - managed_agents=[] - ) - - # 5) Assemble run info - agent_run_info = AgentRunInfo( - query="How many letter r are in strrawberry?", # Example question - model_config_list=[model_config], - observer=observer, - agent_config=agent_config, - mcp_host=None, # Optional: MCP service addresses - history=None, # Optional: chat history - stop_event=stop_event - ) - - # 6) Run with streaming and consume messages - async for message in agent_run(agent_run_info): - message_data = json.loads(message) - message_type = message_data.get("type", "unknown") - content = message_data.get("content", "") - print(f"[{message_type}] {content}") - - # 7) Read final answer (if any) - final_answer = observer.get_final_answer() - if final_answer: - print(f"\nFinal Answer: {final_answer}") - - -if __name__ == "__main__": - logging.disable(logging.CRITICAL) - asyncio.run(main()) -``` - -Tip: Store sensitive config such as `api_key` in environment variables or a secrets manager, not in code. - -## Message Stream Format and Handling - -Internally, `agent_run` executes the agent in a background thread and continuously yields JSON strings from the `MessageObserver` message buffer. You can parse these fields for categorized display or logging. - -- Important fields - - `type`: message type (corresponds to `ProcessType`) - - `content`: text content - - `agent_name`: optional, which agent produced this message - -Common `type` values (from `ProcessType`): -- `AGENT_NEW_RUN`: new task started -- `STEP_COUNT`: step updates -- `MODEL_OUTPUT_THINKING` / `MODEL_OUTPUT_CODE`: model thinking/code snippets -- `PARSE`: code parsing results -- `EXECUTION_LOGS`: Python execution logs -- `FINAL_ANSWER`: final answer -- `ERROR`: error information - -## Configuration Reference - -### ModelConfig - -- `cite_name`: model alias (referenced by `AgentConfig.model_name`) -- `api_key`: model service API key -- `model_name`: model invocation name -- `url`: base URL of the model service -- `temperature` / `top_p`: sampling params - -### AgentConfig - -- `name`: agent name -- `description`: agent description -- `prompt_templates`: optional, Jinja template dict -- `tools`: tool configuration list (see ToolConfig) -- `max_steps`: maximum steps -- `model_name`: model alias (corresponds to `ModelConfig.cite_name`) -- `provide_run_summary`: whether sub-agents provide run summary -- `managed_agents`: list of sub-agent configurations - -### Pass Chat History (optional) - -You can pass historical messages via `AgentRunInfo.history`, and Nexent will write them into internal memory: - -```python -from nexent.core.agents.agent_model import AgentHistory - -history = [ - AgentHistory(role="user", content="Hi"), - AgentHistory(role="assistant", content="Hello, how can I help you?"), -] - -agent_run_info = AgentRunInfo( - # ... other fields omitted - history=history, -) -``` - -## MCP Tool Integration (optional) - -If you provide `mcp_host` (list of MCP service addresses), Nexent will automatically pull remote tools through `ToolCollection.from_mcp` and inject them into the agent: - -```python -agent_run_info = AgentRunInfo( - # ... other fields omitted - mcp_host=["http://localhost:3000"], -) -``` - -Friendly error messages (EN/ZH) will be produced if the connection fails. - -## Interrupt Execution - -During execution, you can trigger interruption via `stop_event.set()`: - -```python -stop_event.set() # The agent will gracefully stop after the current step completes -``` - -## Relation to CoreAgent - -- `agent_run` is a wrapper over `NexentAgent` and `CoreAgent`, responsible for: - - Constructing `CoreAgent` (including models and tools) - - Injecting history into memory - - Driving streaming execution and forwarding buffered messages from `MessageObserver` -- You can also directly use `CoreAgent.run(stream=True)` to handle streaming yourself (see `core/agents.md`); `agent_run` provides a more convenient threaded and JSON-message oriented interface. \ No newline at end of file diff --git a/doc/docs/en/sdk/core/agents.md b/doc/docs/en/sdk/core/agents.md index 0c9fa7fb3..4517a5f3b 100644 --- a/doc/docs/en/sdk/core/agents.md +++ b/doc/docs/en/sdk/core/agents.md @@ -53,144 +53,7 @@ ProcessType enumeration defines the following processing stages: ## 🤖 Agent Development -### Creating Basic Agents - -```python -from nexent.core.utils.observer import MessageObserver -from nexent.core.agents.core_agent import CoreAgent -from nexent.core.models.openai_llm import OpenAIModel -from nexent.core.tools import ExaSearchTool, KnowledgeBaseSearchTool - -# Create message observer -observer = MessageObserver() - -# Create model (model and Agent must use the same observer) -model = OpenAIModel( - observer=observer, - model_id="your-model-id", - api_key="your-api-key", - api_base="your-api-base" -) - -# Create tools -search_tool = ExaSearchTool(exa_api_key="your-exa-key", observer=observer, max_results=5) -kb_tool = KnowledgeBaseSearchTool(top_k=5, observer=observer) - -# Create Agent -agent = CoreAgent( - observer=observer, - tools=[search_tool, kb_tool], - model=model, - name="my_agent", - max_steps=5 -) - -# Run Agent -agent.run("Your question") -``` - -> For a simpler way to consume "JSON streaming messages", see: **[Run agent with agent_run](./agent-run)**. - -#### System Prompt Templates -System prompt templates are located in `backend/prompts/`: - -- **knowledge_summary_agent.yaml**: Knowledge base summary agent -- **manager_system_prompt_template.yaml**: Manager system prompt template -- **utils/**: Prompt utilities - -- If you do not provide `system_prompt`, the SmolAgents default prompt will be used. -- To customize, it is recommended to render from `manager_system_prompt_template.yaml` and pass it in. - -##### Load and override system_prompt (recommended) - -```python -from pathlib import Path -import yaml -from jinja2 import Environment, BaseLoader - -from nexent.core.agents.core_agent import CoreAgent -from nexent.core.models.openai_llm import OpenAIModel - -# 1) Load YAML template text -prompt_yaml_path = Path("backend/prompts/manager_system_prompt_template.yaml") -yaml_text = prompt_yaml_path.read_text(encoding="utf-8") -yaml_data = yaml.safe_load(yaml_text) - -# 2) Render Jinja template in 'system_prompt' key -system_prompt_template = yaml_data["system_prompt"] -jinja_env = Environment(loader=BaseLoader()) -rendered_system_prompt = jinja_env.from_string(system_prompt_template).render( - APP_NAME="Nexent Agent", - APP_DESCRIPTION="Enterprise-grade AI agent", - duty="Answer user questions and use tools when needed", - tools={}, # Provide tools summary if needed - managed_agents={}, # Provide managed agents summary if needed - knowledge_base_summary=None, - constraint="Follow organization policies and ensure data/access security", - authorized_imports=["requests", "pandas"], - few_shots="", - memory_list=[], -) - -# Option A: Pass only rendered string -observer = MessageObserver() -model = OpenAIModel(observer=observer, model_id="your-model-id", api_key="your-api-key", api_base="your-api-base") -agent = CoreAgent( - observer=observer, - model=model, - tools=[search_tool, kb_tool], - system_prompt=rendered_system_prompt, - name="my_agent", -) - -# Option B: Replace the 'system_prompt' in the loaded YAML and pass the dict (advanced) -yaml_data["system_prompt"] = rendered_system_prompt -agent_yaml_prompt = CoreAgent( - observer=observer, - model=model, - tools=[search_tool, kb_tool], - system_prompt=yaml_data, - name="my_agent", -) -``` - -> Note: `manager_system_prompt_template.yaml` also includes other template blocks such as `managed_agent`, `planning`, and `final_answer`. Typically, only `system_prompt` is needed; load additional blocks as required for advanced multi-agent scenarios. - -#### Agent Implementation Steps - -1. **Create Agent Instance**: - ```python - from nexent.core.agents.core_agent import CoreAgent - from nexent.core.models.openai_llm import OpenAIModel - - model = OpenAIModel( - model_id="your-model-id", - api_key="your-api-key", - api_base="your-api-base" - ) - agent = CoreAgent( - model=model, - tools=[your_tools], - system_prompt="Your system prompt" - ) - ``` - -2. **Configure Agent Behavior**: - - Add custom tools through the `tools` parameter - - Set behavior through `system_prompt` - - Configure parameters like `max_steps` - -3. **Advanced Configuration**: - ```python - agent = CoreAgent( - model=model, - tools=custom_tools, - system_prompt=custom_prompt, - max_steps=10, - verbose=True, - additional_authorized_imports=["requests", "pandas"] - ) - ``` +Core usage examples now live in [Basic Usage](../basic-usage#using-agent_run-recommended-for-streaming), including both `CoreAgent.run` and the streaming `agent_run` helper. This page focuses on module concepts (architecture, MessageObserver, patterns) rather than code walkthroughs. ## 🛠️ Tool Integration @@ -262,4 +125,4 @@ Standard execution pattern for problem-solving agents: 3. **Scaling Strategy**: Plan for increased load and usage 4. **Security Considerations**: Validate inputs and protect API access -For detailed implementation examples and advanced patterns, please refer to [Development Guide](../../getting-started/development-guide). \ No newline at end of file +For detailed implementation examples and advanced patterns, please refer to the [Developer Guide](../../developer-guide/overview). \ No newline at end of file diff --git a/doc/docs/en/sdk/overview.md b/doc/docs/en/sdk/overview.md index 8b104fbdd..01f89cbc2 100644 --- a/doc/docs/en/sdk/overview.md +++ b/doc/docs/en/sdk/overview.md @@ -4,7 +4,7 @@ Nexent is a powerful, enterprise-grade Agent SDK that revolutionizes intelligent ## 🚀 Installation and Usage -For detailed installation instructions and usage guides, see our **[Basic Usage Guide](./basic-usage)**. If you need event-stream based rendering on server/frontend, see **[Run agent with agent_run](./core/agent-run)**. +For detailed installation instructions and usage guides, see our **[Basic Usage Guide](./basic-usage#using-agent_run-recommended-for-streaming)** for both `CoreAgent.run` and streaming `agent_run`. ## ⭐ Key Features @@ -32,7 +32,7 @@ For detailed features and usage instructions, please refer to **[Features Explai Nexent provides complete intelligent agent solutions with multi-model support, MCP integration, dynamic tool loading, and distributed execution. -- Quick streaming run: **[Run agent with agent_run](./core/agent-run)** +- Quick streaming run: **[Streaming with agent_run](./basic-usage#using-agent_run-recommended-for-streaming)** - Detailed Agent development and usage: **[Agents](./core/agents)** ## 🛠️ Tool Collection diff --git a/doc/docs/en/user-guide/agent-development.md b/doc/docs/en/user-guide/agent-development.md index 538041ab1..bc7d6c42d 100644 --- a/doc/docs/en/user-guide/agent-development.md +++ b/doc/docs/en/user-guide/agent-development.md @@ -74,7 +74,7 @@ Nexent allows you to quickly and easily use third-party MCP tools to enrich agen Many third-party services such as [ModelScope](https://www.modelscope.cn/mcp) provide MCP services, which you can quickly integrate and use. -You can also develop your own MCP services and connect them to Nexent; see [MCP Server Development](../mcp-ecosystem/mcp-server-development.md). +You can also develop your own MCP services and connect them to Nexent; see [MCP Tool Development](../backend/tools/mcp). ### ⚙️ Custom Tools @@ -156,4 +156,4 @@ After completing agent development, you can: 2. Interact with agents in **[Start Chat](./start-chat)** 3. Configure **[Memory Management](./memory-management)** to enhance the agent's personalization capabilities -If you encounter any issues during agent development, please refer to our **[FAQ](../getting-started/faq)** or ask for support in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). +If you encounter any issues during agent development, please refer to our **[FAQ](../quick-start/faq)** or ask for support in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). diff --git a/doc/docs/en/user-guide/agent-market.md b/doc/docs/en/user-guide/agent-market.md index 5f9e229cf..231aa09aa 100644 --- a/doc/docs/en/user-guide/agent-market.md +++ b/doc/docs/en/user-guide/agent-market.md @@ -34,4 +34,4 @@ While Agent Market is being built, you can: 2. Create new agents in **[Agent Development](./agent-development)**. 3. Test agents in **[Start Chat](./start-chat)**. -Need help? Check the **[FAQ](../getting-started/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). +Need help? Check the **[FAQ](../quick-start/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). diff --git a/doc/docs/en/user-guide/agent-space.md b/doc/docs/en/user-guide/agent-space.md index b4929d454..cd3f66fce 100644 --- a/doc/docs/en/user-guide/agent-space.md +++ b/doc/docs/en/user-guide/agent-space.md @@ -62,4 +62,4 @@ Once you finish reviewing agents you can: 2. Continue iterating in **[Agent Development](./agent-development)**. 3. Enhance retention with **[Memory Management](./memory-management)**. -Need help? Check the **[FAQ](../getting-started/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). +Need help? Check the **[FAQ](../quick-start/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). diff --git a/doc/docs/en/user-guide/home-page.md b/doc/docs/en/user-guide/home-page.md index 9ee4cd1a3..d4fdde989 100644 --- a/doc/docs/en/user-guide/home-page.md +++ b/doc/docs/en/user-guide/home-page.md @@ -48,4 +48,4 @@ Alternatively, you can click the "Quick Setup" button on the homepage or in the ## 💡 Get Help -Need help? Check the **[FAQ](../getting-started/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file +Need help? Check the **[FAQ](../quick-start/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file diff --git a/doc/docs/en/user-guide/knowledge-base.md b/doc/docs/en/user-guide/knowledge-base.md index 7a8c39332..fdce554ac 100644 --- a/doc/docs/en/user-guide/knowledge-base.md +++ b/doc/docs/en/user-guide/knowledge-base.md @@ -75,4 +75,4 @@ After completing knowledge base configuration, we recommend you continue with: 1. **[Agent Development](./agent-development)** – Create and configure agents 2. **[Start Chat](./start-chat)** – Interact with your agent -Need help? Check the **[FAQ](../getting-started/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file +Need help? Check the **[FAQ](../quick-start/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file diff --git a/doc/docs/en/user-guide/local-tools/email-tools.md b/doc/docs/en/user-guide/local-tools/email-tools.md new file mode 100644 index 000000000..4a42435cc --- /dev/null +++ b/doc/docs/en/user-guide/local-tools/email-tools.md @@ -0,0 +1,58 @@ +--- +title: Email Tools +--- + +# Email Tools + +Email tools help agents receive notifications and send results via common mail providers. + +## 🧭 Tool List + +- `get_email`: Fetch emails by time window and sender, with max count limits +- `send_email`: Send HTML emails with multiple recipients, CC, and BCC + +## 🧰 Example Use Cases + +- Periodically pull the past 7 days of notifications for summarization +- Send execution results to recipients and CC teammates +- Filter alerts from specific monitoring senders + +## 🧾 Parameters & Behavior + +### get_email +- `days`: Look back in days, default 7. +- `sender`: Filter by email address, optional. +- `max_emails`: Max messages to return, default 10. +- Requires IMAP host, port, username, password; SSL supported. +- Returns JSON with subject, time, sender, and body summary. + +### send_email +- `to`: Comma-separated recipients. +- `subject`: Email subject. +- `content`: HTML body. +- `cc`, `bcc`: Comma-separated CC/BCC, optional. +- Requires SMTP host, port, username, password; optional sender display name and SSL. +- Returns delivery status, subject, and recipient info. + +## 🛠️ How to Use + +1. **Collect provider settings**: IMAP/SMTP host, port, account/app password, SSL. +2. **Receive**: Call `get_email` with `days`/`sender`/`max_emails`; start with small ranges to test. +3. **Send**: Call `send_email` with recipients, subject, and HTML content; add `cc`/`bcc` if needed. +4. **Post-process**: Summarize or extract key info from fetched bodies if desired. + +## 🛡️ Safety & Best Practices + +- Use provider-issued app passwords or restricted accounts; avoid exposing primary credentials. +- Keep `max_emails` reasonable to avoid heavy pulls. +- Verify recipient lists before sending; restrict allowed domains in production. + +## 📮 Common Provider Settings + +Use app passwords where available and enable IMAP/SMTP in account settings. Ports reflect common defaults—always confirm with the provider’s latest docs. + +- QQ Mail: IMAP `imap.qq.com:993` (SSL); SMTP `smtp.qq.com:465` (SSL); enable IMAP/SMTP and generate an authorization code. +- Gmail: IMAP `imap.gmail.com:993`; SMTP `smtp.gmail.com:465` (SSL) or `587` (STARTTLS); enable IMAP and use an app password. +- Outlook (Microsoft 365/Hotmail): IMAP `outlook.office365.com:993`; SMTP `smtp.office365.com:587` (STARTTLS); tenants may require modern auth or app passwords. +- 163 Mail: IMAP `imap.163.com:993` (SSL); SMTP `smtp.163.com:465` (SSL); enable client authorization password in mailbox settings. + diff --git a/doc/docs/en/user-guide/local-tools/file-tools.md b/doc/docs/en/user-guide/local-tools/file-tools.md new file mode 100644 index 000000000..0869312fc --- /dev/null +++ b/doc/docs/en/user-guide/local-tools/file-tools.md @@ -0,0 +1,55 @@ +--- +title: File Tools +--- + +# File Tools + +File tools provide safe, workspace-scoped operations for files and folders. All paths must be relative to the workspace root (default `/mnt/nexent`). + +## 🧭 Tool List + +- `create_directory`: Create directories (auto-create parents, optional permissions) +- `create_file`: Create files and write content (auto-create parents) +- `read_file`: Read file content with metadata +- `list_directory`: Show directory tree +- `move_item`: Move files/folders without overwriting +- `delete_file`: Delete a single file (irreversible) +- `delete_directory`: Recursively delete a directory (irreversible) + +## 🧰 Example Use Cases + +- Initialize project folders and config files +- Inspect logs or check file size/line counts +- Browse workspace structure before editing +- Move artifacts to backup locations +- Clean up temp files or unused directories + +## 🧾 Parameters & Behavior + +### Common constraints +- Paths must stay inside the workspace; absolute or escaping paths are blocked. +- Delete/move operations are irreversible—double-check before running. + +### Key parameters +- `directory_path` / `file_path` / `source_path` / `destination_path`: required relative paths. +- `permissions` (`create_directory`): octal string, default `755`. +- `encoding` (`create_file` / `read_file`): default `utf-8`. +- `max_depth`, `show_hidden`, `show_size` (`list_directory`): control tree depth, hidden items, and size display. + +### Returns +- Success responses include relative/absolute paths, sizes, and existence flags. +- Errors explain boundary checks, existing targets, or permission issues. + +## 🛠️ How to Use + +1. **Create**: Use `create_directory` or `create_file` with a relative path; set permissions/encoding when needed. +2. **Inspect**: Use `list_directory` to browse; use `read_file` for content and metadata. +3. **Move**: Use `move_item`; it stops if the destination already exists to avoid overwrites. +4. **Delete**: Use `delete_file` or `delete_directory`; confirm the target since deletion cannot be undone. + +## 🛡️ Safety & Best Practices + +- Operate only inside the workspace; avoid absolute paths or `..` traversal. +- Before deleting, run `list_directory` or `read_file` to confirm the target. +- Large files trigger warnings; consider chunked processing instead of single full reads. + diff --git a/doc/docs/en/user-guide/local-tools/index.md b/doc/docs/en/user-guide/local-tools/index.md index 5946ca133..27dc72ebc 100644 --- a/doc/docs/en/user-guide/local-tools/index.md +++ b/doc/docs/en/user-guide/local-tools/index.md @@ -1,63 +1,26 @@ -# Local Tools +# Overview -The Nexent platform provides a rich set of local tools that help agents complete various system-level tasks and local operations. These tools offer powerful execution capabilities through direct interaction with local systems or remote servers. +Local tools let agents interact with the workspace, remote hosts, and external services across files, email, search, multimodal, and remote terminals. Each tool has its own page grouped by capability. -## 🛠️ Available Tools +## 📂 Directory -Nexent preloads a set of reusable local tools grouped by capability: email, file, and search. The Terminal tool is offered separately to provide remote shell capabilities. The following sections list each tool alongside its core features so agents can quickly pick the right capability. +- [File Tools](./file-tools): Create/read/move/delete files and folders; list directory trees. +- [Email Tools](./email-tools): Receive IMAP mail; send HTML mail with CC/BCC. +- [Search Tools](./search-tools): Local/DataMate KB search plus Exa/Tavily/Linkup web search. +- [Multimodal Tools](./multimodal-tools): Download/parse/analyze text files and images. +- [Terminal Tool](./terminal-tool): Persistent SSH sessions for remote commands. -### 📧 Email Tools +## ⚙️ Configuration Entry -- **get_email**: Fetches mailbox content through IMAP. Supports restricting by time range (days), filtering by sender, and limiting the number of returned messages. The tool automatically decodes multilingual subjects and bodies, and returns subject, timestamp, sender, and body summary to simplify downstream analysis. -- **send_email**: Sends HTML emails via SMTP. Supports multiple recipients, CC, and BCC, as well as custom sender display names. All connections use SSL/TLS. The result reports delivery status and subject for easy tracking. +1. Go to **[Agent Development](../agent-development)**. +2. In “Select Agent Tools,” find the tool and open configuration. +3. Fill connection/auth parameters, save, and run a test connection first. -### 📂 File Tools +## 💡 Usage Tips -- **create_directory**: Creates nested directories at the specified relative path, skipping existing levels and returning the result together with the final absolute path. -- **create_file**: Creates a file and writes content. Automatically creates parent directories if needed, supports custom encoding (default UTF-8), and allows empty files. -- **read_file**: Reads a text file and returns metadata such as size, line count, and encoding. Warns when the file is large (10 MB safety threshold). -- **list_directory**: Lists directory contents in a tree view. Supports maximum recursion depth, hidden file display, and file sizes. Output includes both visual text and structured JSON to clearly present project structure. -- **move_item**: Moves files or folders within the workspace. Automatically creates destination directories, avoids overwriting existing targets, and reports how many items were moved and their total size. -- **delete_file**: Deletes a single file with permission and existence checks. Provides clear error messages on failure. -- **delete_directory**: Recursively deletes a directory and its contents with existence, permission, and safety checks. Returns the deleted relative path. +- File paths must stay inside the workspace and use relative paths. +- Set API keys for public search in the platform’s secure config. +- Terminal access touches remote hosts—confirm network and account controls. +- Delete/move operations are irreversible; double-check targets first. -> All file paths must be relative to the workspace (default `/mnt/nexent`). The system automatically validates paths to prevent escaping the workspace boundary. - -### 🔍 Search Tools - -- **knowledge_base_search**: Queries the local knowledge-base index with `hybrid`, `accurate`, or `semantic` modes. Can filter by index name and returns sources, scores, and citation indices, ideal for answering questions from internal documents or industry references. -- **exa_search**: Calls the EXA API for real-time web search. Supports configuring the number of results and optionally returns image links (with additional filtering performed server-side). Requires an EXA API key in the tool configuration, which you can obtain for free at [exa.ai](https://exa.ai/). -- **tavily_search**: Uses the Tavily API to retrieve webpages, particularly strong for news and current events. Returns both text results and related image URLs, with optional image filtering. Request a free API key from [tavily.com](https://www.tavily.com/). -- **linkup_search**: Uses the Linkup API to fetch text and images. In addition to regular webpages, it can return image-only results, making it useful when mixed media references are required. Register at [linkup.so](https://www.linkup.so/) to obtain a free API key. - -### 🖼️ Multimodal Tools - -- **analyze_text_file**: Based on user queries and the S3 URL, HTTP URL, and HTTPS URL of a text file, parse the file and use a large language model to understand it, answering user questions. An available large language model needs to be configured on the model management page. -- **analyze_image**: Based on user queries and the S3 URL, HTTP URL, and HTTPS URL of an image, use a visual language model to analyze and understand the image, answering user questions. An available visual language model needs to be configured on the model management page. - -### 🖥️ Terminal Tool - -The **Terminal Tool** is one of Nexent's core local capabilities that provides a persistent SSH session. Agents can execute remote commands, perform system inspections, read logs, or deploy services. Refer to the dedicated [Terminal Tool guide](./terminal-tool) for detailed setup, parameters, and security guidance. - -## 🔧 Tool Configuration - -All local tools need to be configured inside Agent Development: - -1. Navigate to the **[Agent Development](../agent-development)** page -2. Select the agent you want to configure -3. In the "Select Agent Tools" tab, locate the desired local tool -4. Click the configuration button and fill in the required connection parameters -5. Test the connection to ensure the configuration is correct -6. Save the configuration and enable the tool - -## ⚠️ Security Considerations - -When using local tools, keep the following security practices in mind: - -- **Permission Control**: Create dedicated users for each tool and follow least privilege -- **Network Security**: Use VPN or IP allowlists to restrict access -- **Authentication Security**: Favor key-based authentication and rotate keys regularly -- **Command Restrictions**: Configure command whitelists in production environments -- **Audit Logging**: Enable detailed logging for all operations - -Need help? Please open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). +Need help? Open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). diff --git a/doc/docs/en/user-guide/local-tools/multimodal-tools.md b/doc/docs/en/user-guide/local-tools/multimodal-tools.md new file mode 100644 index 000000000..6780f5f1e --- /dev/null +++ b/doc/docs/en/user-guide/local-tools/multimodal-tools.md @@ -0,0 +1,48 @@ +--- +title: Multimodal Tools +--- + +# Multimodal Tools + +Multimodal tools analyze text files and images with model support. URLs can be S3, HTTP, or HTTPS. + +## 🧭 Tool List + +- `analyze_text_file`: Download and extract text, then analyze per question +- `analyze_image`: Download images and interpret them with a vision-language model + +## 🧰 Example Use Cases + +- Summarize documents stored in buckets +- Explain screenshots, product photos, or chart images +- Produce per-file or per-image answers aligned with the input order + +## 🧾 Parameters & Behavior + +### analyze_text_file +- `file_url_list`: List of URLs (`s3://bucket/key`, `/bucket/key`, `http(s)://`). +- `query`: User question/analysis goal. +- Downloads each file, extracts text, and returns an array of analyses in input order. + +### analyze_image +- `image_urls_list`: List of URLs (`s3://bucket/key`, `/bucket/key`, `http(s)://`). +- `query`: User focus/question. +- Downloads each image, runs VLM analysis, and returns an array matching input order. + +## ⚙️ Prerequisites + +- Configure storage access (e.g., MinIO/S3) and data processing service to fetch files. +- Provide an LLM for `analyze_text_file` and a VLM for `analyze_image`. + +## 🛠️ How to Use + +1. Prepare accessible URLs and confirm permissions. +2. Call the corresponding tool with the URL list and question; multiple resources are supported at once. +3. Use results in the same order as inputs for display or follow-up steps. + +## 💡 Best Practices + +- For large files, preprocess or chunk them to reduce timeouts. +- For multiple images, be explicit about the focus (e.g., “focus on chart trends”) to improve answers. +- If results are empty or errors occur, verify URL accessibility and model readiness. + diff --git a/doc/docs/en/user-guide/local-tools/search-tools.md b/doc/docs/en/user-guide/local-tools/search-tools.md new file mode 100644 index 000000000..114f6fad3 --- /dev/null +++ b/doc/docs/en/user-guide/local-tools/search-tools.md @@ -0,0 +1,68 @@ +--- +title: Search Tools +--- + +# Search Tools + +Search tools cover internet search plus local and DataMate knowledge bases, useful for real-time info, industry materials, and private docs. + +## 🧭 Tool List + +- Local/private knowledge bases: + - `knowledge_base_search`: Local KB search with multiple modes + - `datamate_search_tool`: Search DataMate KB +- Public web search: + - `exa_search`: Web and image search via Exa + - `tavily_search`: Web and image search via Tavily + - `linkup_search`: Mixed text/image search via Linkup + +## 🧰 Example Use Cases + +- Retrieve internal docs, specs, and industry references (KB, DataMate) +- Fetch latest news or web evidence (Exa / Tavily / Linkup) +- Return image references alongside text (with optional filtering) + +## 🧾 Parameters & Behavior + +### knowledge_base_search +- `query`: Required. +- `search_mode`: `hybrid` (default), `accurate`, or `semantic`. +- `index_names`: Optional list of KB names (user-facing or internal). +- Returns title, path/URL, source type, score, and citation info. Warns if no KB is selected. + +### datamate_search_tool +- `query`: Required. +- `top_k`: Default 10. +- `threshold`: Default 0.2. +- `kb_page` / `kb_page_size`: Paginate DataMate KB list. +- Requires DataMate host and port. Returns filename, download URL, and scores. + +### exa_search / tavily_search / linkup_search +- `query`: Required. +- `max_results`: Configurable count. +- Image filtering: On by default to drop unrelated images; can be disabled to return raw image URLs. +- Requires API keys: + - Exa: EXA API Key + - Tavily: Tavily API Key + - Linkup: Linkup API Key +- Returns title, URL, summary, and optional image URLs (deduped). + +## 🛠️ How to Use + +1. **Pick the source**: Use `knowledge_base_search` or `datamate_search_tool` for private data; Exa/Tavily/Linkup for public info. +2. **Tune mode/count**: Switch `search_mode` for KB; adjust `max_results` and image filtering for public search. +3. **Scope**: Provide `index_names` for targeted KB search; tune `top_k` and `threshold` for DataMate precision. +4. **Consume results**: JSON output is ready for answers or summarization, with citation indices for referencing. + +## 🛡️ Safety & Best Practices + +- Store API keys in the platform’s secure config, never in prompts. +- Sync KB content before querying to avoid stale answers. +- If queries are too broad, shorten or split them; if images are over-filtered, disable filtering to review raw URLs. + +## 🔑 Getting API Keys (Public Search) + +- Exa: Sign up at [exa.ai](https://exa.ai/) and create an EXA API Key in the console. +- Tavily: Register at [tavily.com](https://www.tavily.com/) and get a Tavily API Key from the dashboard. +- Linkup: Sign up at [linkup.so](https://www.linkup.so/) and create a Linkup API Key in your account. + diff --git a/doc/docs/en/user-guide/memory-management.md b/doc/docs/en/user-guide/memory-management.md index 033b67fd9..6e1330b78 100644 --- a/doc/docs/en/user-guide/memory-management.md +++ b/doc/docs/en/user-guide/memory-management.md @@ -137,4 +137,4 @@ With memory configured you can: 2. Manage all agents in **[Agent Space](./agent-space)**. 3. Build more agents inside **[Agent Development](./agent-development)**. -Need help? Check the **[FAQ](../getting-started/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file +Need help? Check the **[FAQ](../quick-start/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file diff --git a/doc/docs/en/user-guide/model-management.md b/doc/docs/en/user-guide/model-management.md index 22e0e0f22..db38ea46d 100644 --- a/doc/docs/en/user-guide/model-management.md +++ b/doc/docs/en/user-guide/model-management.md @@ -217,4 +217,4 @@ After closing the Model Management flow, continue with: 1. **[Knowledge Base](./knowledge-base)** – Create and manage knowledge bases. 2. **[Agent Development](./agent-development)** – Build and configure agents. -Need help? Check the **[FAQ](../getting-started/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). +Need help? Check the **[FAQ](../quick-start/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). diff --git a/doc/docs/en/user-guide/quick-setup.md b/doc/docs/en/user-guide/quick-setup.md index ef4e43bab..bdf403cf9 100644 --- a/doc/docs/en/user-guide/quick-setup.md +++ b/doc/docs/en/user-guide/quick-setup.md @@ -50,4 +50,4 @@ After finishing Quick Setup: 2. Use **[Start Chat](./start-chat)** to talk to your agents. 3. Configure **[Memory Management](./memory-management)** to give agents persistent memory. -Need help? Check the **[FAQ](../getting-started/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file +Need help? Check the **[FAQ](../quick-start/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file diff --git a/doc/docs/en/user-guide/start-chat.md b/doc/docs/en/user-guide/start-chat.md index d37b09e7b..9593cb6ec 100644 --- a/doc/docs/en/user-guide/start-chat.md +++ b/doc/docs/en/user-guide/start-chat.md @@ -208,4 +208,4 @@ Congratulations! You now master all the core features of Nexent. We look forward ### Get Help -Need help? Check the **[FAQ](../getting-started/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file +Need help? Check the **[FAQ](../quick-start/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file diff --git a/doc/docs/en/user-guide/user-management.md b/doc/docs/en/user-guide/user-management.md index 5f16ed179..2f03650cc 100644 --- a/doc/docs/en/user-guide/user-management.md +++ b/doc/docs/en/user-guide/user-management.md @@ -34,4 +34,4 @@ While waiting for User Management you can: 2. Configure models in **[Model Management](./model-management)**. 3. Chat with agents via **[Start Chat](./start-chat)**. -Need help? Check the **[FAQ](../getting-started/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file +Need help? Check the **[FAQ](../quick-start/faq)** or open a thread in [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions). \ No newline at end of file diff --git a/doc/docs/zh/backend/api-reference.md b/doc/docs/zh/backend/api-reference.md index 1c0d0b0e3..6d0ec35e3 100644 --- a/doc/docs/zh/backend/api-reference.md +++ b/doc/docs/zh/backend/api-reference.md @@ -1,1205 +1,7 @@ -# Nexent Community API 文档 +# Nexent API 文档 -本文档提供了 Nexent Community 后端所有 API 端点的全面概述。 +## 🔗 访问 API 文档 -## 目录 -1. [基础应用](#基础应用) -2. [代理应用](#代理应用) -3. [配置同步应用](#配置同步应用) -4. [会话管理应用](#会话管理应用) -5. [数据处理应用](#数据处理应用) -6. [Elasticsearch 应用](#elasticsearch-应用) -7. [ME 模型管理应用](#me-模型管理应用) -8. [模型管理应用](#模型管理应用) -9. [代理应用](#代理应用) -10. [文件管理应用](#文件管理应用) -11. [语音应用](#语音应用) +后端接口文档已托管在 Apifox,请通过以下链接查看最新版本: -## 基础应用 - -基础应用作为主要的 FastAPI 应用程序,包含所有其他路由器并提供全局异常处理。 - -### 全局异常处理器 -- `HTTPException`: 返回带有错误消息的 JSON 响应 -- `Exception`: 返回通用 500 错误响应 - -## 代理应用 - -### 端点 - -#### POST /api/agent/run -执行代理并处理提供的请求。 - -**请求体:** -```json -{ - "query": "string", - "history": "array", - "minio_files": "array" -} -``` - -**响应:** -- 使用 SSE(服务器发送事件)的流式响应 -- 实时返回代理的响应 - -#### POST /api/agent/reload_config -手动触发配置重新加载。 - -**响应:** -- 成功/失败状态 - -## 配置同步应用 - -### 端点 - -#### POST /api/config/save_config -将配置保存到环境变量。 - -**请求体:** -```json -{ - "app": { - "name": "string", - "description": "string", - "icon": { - "type": "string", - "avatarUri": "string", - "customUrl": "string" - } - }, - "models": { - "llm": { - "name": "string", - "displayName": "string", - "apiConfig": { - "apiKey": "string", - "modelUrl": "string" - } - } - // ... 其他模型配置 - }, - "data": { - "selectedKbNames": "array", - "selectedKbModels": "array", - "selectedKbSources": "array" - } -} -``` - -**响应:** -```json -{ - "message": "配置保存成功", - "status": "saved" -} -``` - -#### GET /api/config/load_config -从环境变量加载配置。 - -**响应:** -```json -{ - "config": { - "app": { - "name": "string", - "description": "string", - "icon": { - "type": "string", - "avatarUri": "string", - "customUrl": "string" - } - }, - "models": { - "llm": { - "name": "string", - "displayName": "string", - "apiConfig": { - "apiKey": "string", - "modelUrl": "string" - } - } - // ... 其他模型配置 - }, - "data": { - "selectedKbNames": ["string"], - "selectedKbModels": ["string"], - "selectedKbSources": ["string"] - } - } -} -``` - -## 会话管理应用 - -### 端点 - -#### PUT /api/conversation/create -创建新会话。 - -**请求体:** -```json -{ - "title": "string" -} -``` - -**响应:** -```json -{ - "code": 0, - "message": "success", - "data": { - "conversation_id": "string", - "conversation_title": "string", - "create_time": "number", - "update_time": "number" - } -} -``` - -#### GET /api/conversation/list -获取所有会话。 - -**响应:** -```json -{ - "code": 0, - "message": "success", - "data": [ - { - "conversation_id": "string", - "conversation_title": "string", - "create_time": "number", - "update_time": "number" - } - ] -} -``` - -#### POST /api/conversation/rename -重命名会话。 - -**请求体:** -```json -{ - "conversation_id": "number", - "name": "string" -} -``` - -**响应:** -```json -{ - "code": 0, - "message": "success", - "data": true -} -``` - -#### DELETE /api/conversation/{conversation_id} -删除会话。 - -**响应:** -```json -{ - "code": 0, - "message": "success", - "data": true -} -``` - -#### GET /api/conversation/{conversation_id} -获取会话历史。 - -**响应:** -```json -{ - "code": 0, - "message": "success", - "data": { - "conversation_id": "string", - "create_time": "number", - "message": [ - { - "role": "string", - "message": "array", - "message_id": "string", - "opinion_flag": "string", - "picture": "array", - "search": "array" - } - ] - } -} -``` - -#### POST /api/conversation/sources -获取消息源信息。 - -**请求体:** -```json -{ - "conversation_id": "number", - "message_id": "string", - "type": "string" -} -``` - -**响应:** -```json -{ - "code": 0, - "message": "success", - "data": { - "searches": [ - { - "title": "string", - "text": "string", - "source_type": "string", - "url": "string", - "filename": "string", - "published_date": "string", - "score": "number", - "score_details": { - "accuracy": "number", - "semantic": "number" - } - } - ], - "images": ["string"] - } -} -``` - -#### POST /api/conversation/generate_title -生成会话标题。 - -**请求体:** -```json -{ - "conversation_id": "number", - "history": [ - { - "role": "string", - "content": "string" - } - ] -} -``` - -**响应:** -```json -{ - "code": 0, - "message": "success", - "data": "string" -} -``` - -#### POST /api/conversation/message/update_opinion -更新消息点赞/踩状态。 - -**请求体:** -```json -{ - "message_id": "string", - "opinion": "string" -} -``` - -**响应:** -```json -{ - "code": 0, - "message": "success", - "data": true -} -``` - -## 数据处理应用 - -### 端点 - -#### POST /api/tasks -创建新的数据处理任务。 - -**请求体:** -```json -{ - "source": "string", - "source_type": "string", - "chunking_strategy": "string", - "index_name": "string", - "additional_params": { - "key": "value" - } -} -``` - -**响应:** -```json -{ - "task_id": "string" -} -``` - -#### POST /api/tasks/batch -创建批量数据处理任务。 - -**请求体:** -```json -{ - "sources": [ - { - "source": "string", - "source_type": "string", - "chunking_strategy": "string", - "index_name": "string", - "additional_params": { - "key": "value" - } - } - ] -} -``` - -**响应:** -```json -{ - "task_ids": ["string"] -} -``` - -#### GET /api/tasks/{task_id} -获取任务状态。 - -**响应:** -```json -{ - "id": "string", - "status": "string", - "created_at": "string", - "updated_at": "string", - "error": "string" -} -``` - -#### GET /api/tasks -列出所有任务。 - -**响应:** -```json -{ - "tasks": [ - { - "id": "string", - "status": "string", - "created_at": "string", - "updated_at": "string", - "error": "string" - } - ] -} -``` - -#### GET /api/tasks/indices/{index_name}/tasks -获取特定索引的所有活动任务。 - -**响应:** -```json -{ - "index_name": "string", - "files": [ - { - "path_or_url": "string", - "status": "string" - } - ] -} -``` - -#### GET /api/tasks/{task_id}/details -获取任务状态和结果。 - -**响应:** -```json -{ - "id": "string", - "status": "string", - "created_at": "string", - "updated_at": "string", - "error": "string", - "results": "object" -} -``` - -## Elasticsearch 应用 - -### 端点 - -#### POST /api/indices/{index_name} -创建新的向量索引。 - -**参数:** -- `index_name`: 索引名称 -- `embedding_dim`: 可选的嵌入向量维度 - -**响应:** -```json -{ - "status": "success", - "message": "string", - "embedding_dim": "number" -} -``` - -#### DELETE /api/indices/{index_name} -删除索引。 - -**响应:** -```json -{ - "status": "success", - "message": "string" -} -``` - -#### GET /api/indices -列出所有索引。 - -**参数:** -- `pattern`: 匹配索引名称的模式 -- `include_stats`: 是否包含索引统计信息 - -**响应:** -```json -{ - "indices": ["string"], - "count": "number", - "indices_info": [ - { - "name": "string", - "stats": { - "docs": "number", - "size": "string" - } - } - ] -} -``` - -#### GET /api/indices/{index_name}/info -获取索引信息。 - -**参数:** -- `include_files`: 是否包含文件列表 -- `include_chunks`: 是否包含文本块 - -**响应:** -```json -{ - "base_info": { - "docs": "number", - "size": "string" - }, - "search_performance": { - "query_time": "number", - "hits": "number" - }, - "fields": { - "field_name": { - "type": "string" - } - }, - "files": [ - { - "path_or_url": "string", - "file": "string", - "file_size": "number", - "create_time": "string", - "status": "string", - "chunks": [ - { - "id": "string", - "title": "string", - "content": "string", - "create_time": "string" - } - ], - "chunks_count": "number" - } - ] -} -``` - -#### POST /api/indices/{index_name}/documents -索引文档。 - -**请求体:** -```json -{ - "task_id": "string", - "results": [ - { - "metadata": { - "filename": "string", - "title": "string", - "languages": ["string"], - "author": "string", - "date": "string", - "file_size": "number", - "creation_date": "string" - }, - "source": "string", - "text": "string", - "source_type": "string" - } - ] -} -``` - -**响应:** -```json -{ - "success": true, - "message": "string", - "total_indexed": "number", - "total_submitted": "number" -} -``` - -#### DELETE /api/indices/{index_name}/documents -删除文档。 - -**请求体:** -```json -{ - "path_or_url": "string" -} -``` - -**响应:** -```json -{ - "status": "success", - "deleted_count": "number" -} -``` - -#### POST /api/indices/search/accurate -执行精确搜索。 - -**请求体:** -```json -{ - "query": "string", - "index_names": ["string"], - "top_k": "number" -} -``` - -**响应:** -```json -{ - "results": [ - { - "id": "string", - "title": "string", - "content": "string", - "score": "number", - "index": "string" - } - ], - "total": "number", - "query_time_ms": "number" -} -``` - -#### POST /api/indices/search/semantic -执行语义搜索。 - -**请求体:** -```json -{ - "query": "string", - "index_names": ["string"], - "top_k": "number" -} -``` - -**响应:** -```json -{ - "results": [ - { - "id": "string", - "title": "string", - "content": "string", - "score": "number", - "index": "string" - } - ], - "total": "number", - "query_time_ms": "number" -} -``` - -#### POST /api/indices/search/hybrid -执行混合搜索。 - -**请求体:** -```json -{ - "query": "string", - "index_names": ["string"], - "top_k": "number", - "weight_accurate": "number" -} -``` - -**响应:** -```json -{ - "results": [ - { - "id": "string", - "title": "string", - "content": "string", - "score": "number", - "index": "string", - "score_details": { - "accurate": "number", - "semantic": "number" - } - } - ], - "total": "number", - "query_time_ms": "number" -} -``` - -#### GET /api/indices/health -检查 API 和 Elasticsearch 健康状态。 - -**响应:** -```json -{ - "status": "healthy", - "elasticsearch": "connected", - "indices_count": "number" -} -``` - -## ME 模型管理应用 - -### 端点 - -#### GET /api/me/model/list -获取 ME 模型列表。 - -**请求体:** -```json -{ - "type": "string", - "timeout": "number" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "Successfully retrieved", - "data": [ - { - "id": "string", - "name": "string", - "type": "string", - "description": "string" - } - ] -} -``` - -#### GET /api/me/healthcheck -检查 ME 模型连接性。 - -**请求体:** -```json -{ - "timeout": "number" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "Connection successful", - "data": { - "status": "Connected", - "desc": "Connection successful", - "connect_status": "AVAILABLE" - } -} -``` - -#### GET /api/me/model/healthcheck -检查特定模型健康状态。 - -**请求体:** -```json -{ - "model_name": "string" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "Model health check successful", - "data": { - "status": "string", - "desc": "string", - "connect_status": "string" - } -} -``` - -## 模型管理应用 - -### 端点 - -#### POST /api/model/create -创建新模型。 - -**请求体:** -```json -{ - "model_name": "string", - "display_name": "string", - "connect_status": "string" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "Model created successfully", - "data": null -} -``` - -#### POST /api/model/update -更新模型。 - -**请求体:** -```json -{ - "model_name": "string", - "display_name": "string", - "connect_status": "string" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "Model updated successfully", - "data": { - "model_name": "string" - } -} -``` - -#### POST /api/model/delete -删除模型。 - -**请求体:** -```json -{ - "model_name": "string" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "Model deleted successfully", - "data": { - "model_name": "string" - } -} -``` - -#### GET /api/model/list -获取所有模型。 - -**响应:** -```json -{ - "code": 200, - "message": "Successfully retrieved model list", - "data": [ - { - "model_id": "string", - "model_name": "string", - "model_repo": "string", - "display_name": "string", - "connect_status": "string" - } - ] -} -``` - -#### GET /api/model/healthcheck -检查模型健康状态。 - -**请求体:** -```json -{ - "model_name": "string" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "Model health check successful", - "data": { - "status": "string", - "desc": "string", - "connect_status": "string" - } -} -``` - -#### GET /api/model/get_connect_status -获取模型连接状态。 - -**请求体:** -```json -{ - "model_name": "string" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "Successfully retrieved connection status", - "data": { - "model_name": "string", - "connect_status": "string" - } -} -``` - -#### POST /api/model/update_connect_status -更新模型连接状态。 - -**请求体:** -```json -{ - "model_name": "string", - "connect_status": "string" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "Successfully updated connection status", - "data": { - "model_name": "string", - "connect_status": "string" - } -} -``` - -#### POST /api/model/verify_config -验证模型配置的连通性,不保存到数据库。 - -**请求体:** -```json -{ - "model_name": "string", - "model_type": "string", - "base_url": "string", - "api_key": "string", - "max_tokens": "number", - "embedding_dim": "number" -} -``` - -**响应:** -```json -{ - "code": 200, - "message": "模型配置连通性验证成功/失败", - "data": { - "connectivity": "boolean", - "message": "string", - "connect_status": "string" - } -} -``` - -## 代理应用 - -### 端点 - -#### GET /api/proxy/image -代理远程图片。 - -**请求体:** -```json -{ - "url": "string" -} -``` - -**响应:** -```json -{ - "success": true, - "base64": "string", - "content_type": "string" -} -``` - -## 文件管理应用 - -### 端点 - -#### POST /api/file/upload -上传文件。 - -**请求体:** -```json -{ - "file": ["file"], - "chunking_strategy": "string", - "index_name": "string" -} -``` - -**响应:** -```json -{ - "message": "string", - "uploaded_files": ["string"], - "process_tasks": { - "task_id": "string" - } -} -``` - -#### POST /api/file/storage -上传文件到存储。 - -**请求体:** -```json -{ - "files": ["file"], - "folder": "string" -} -``` - -**响应:** -```json -{ - "message": "string", - "success_count": "number", - "failed_count": "number", - "results": [ - { - "success": true, - "file_name": "string", - "url": "string" - } - ] -} -``` - -#### GET /api/file/storage -获取存储文件。 - -**请求体:** -```json -{ - "prefix": "string", - "limit": "number", - "include_urls": true -} -``` - -**响应:** -```json -{ - "total": "number", - "files": [ - { - "name": "string", - "size": "number", - "url": "string" - } - ] -} -``` - -#### GET /api/file/storage/{object_name} -获取存储文件。 - -**请求体:** -```json -{ - "download": true, - "expires": "number" -} -``` - -**响应:** -```json -{ - "success": true, - "url": "string", - "expires": "string" -} -``` - -#### DELETE /api/file/storage/{object_name} -删除存储文件。 - -**响应:** -```json -{ - "success": true, - "message": "string" -} -``` - -#### POST /api/file/storage/batch-urls -获取批量文件 URL。 - -**请求体:** -```json -{ - "object_names": ["string"], - "expires": "number" -} -``` - -**响应:** -```json -{ - "urls": [ - { - "object_name": "string", - "success": true, - "url": "string" - } - ] -} -``` - -#### POST /api/file/preprocess -预处理代理文件。 - -**请求体:** -```json -{ - "query": "string", - "files": ["file"] -} -``` - -**响应:** -- 使用 SSE(服务器发送事件)的流式响应 -- 实时返回处理的内容 - -## 语音应用 - -### 端点 - -#### WebSocket /api/voice/stt/ws -实时语音转文字流。 - -**输入:** -- 音频流 - -**输出:** -- 文字流 - -#### WebSocket /api/voice/tts/ws -实时文字转语音流。 - -**输入:** -```json -{ - "text": "string" -} -``` - -**输出:** -- 音频流 - -## 依赖项 - -应用程序具有以下依赖项: - -1. 基础应用: - - 所有其他应用路由器 - -2. 代理应用: - - 代理工具 - - 会话管理服务 - -3. 配置同步应用: - - python-dotenv - - 配置工具 - -4. 会话管理应用: - - 会话数据库 - - 会话管理工具 - -5. 数据处理应用: - - 数据处理核心 - - 任务状态工具 - -6. Elasticsearch 应用: - - Elasticsearch 核心 - - 嵌入模型 - - 数据处理服务 - -7. ME 模型管理应用: - - Requests - - 模型健康服务 - -8. 模型管理应用: - - 模型管理数据库 - - 模型健康服务 - - 模型名称工具 - -9. 代理应用: - - aiohttp - - 图像过滤工具 - -10. 文件管理应用: - - 文件管理工具 - - 附件数据库 - - 数据处理服务 - -11. 语音应用: - - STT 模型 - - TTS 模型 - - WebSocket 支持 \ No newline at end of file +[Nexent API](https://8icfxll43r.apifox.cn) diff --git a/doc/docs/zh/backend/overview.md b/doc/docs/zh/backend/overview.md index 70a9c87fa..e48a0e735 100644 --- a/doc/docs/zh/backend/overview.md +++ b/doc/docs/zh/backend/overview.md @@ -202,4 +202,4 @@ python backend/mcp_service.py # MCP服务 - 资源池管理 - 自动扩展能力 -详细的后端开发指南,请参阅 [开发指南](../getting-started/development-guide)。 \ No newline at end of file +详细的后端开发指南,请参阅 [开发者指南](../developer-guide/overview)。 \ No newline at end of file diff --git a/doc/docs/zh/backend/prompt-development.md b/doc/docs/zh/backend/prompt-development.md index e70fa7531..2cc2407bd 100644 --- a/doc/docs/zh/backend/prompt-development.md +++ b/doc/docs/zh/backend/prompt-development.md @@ -1,135 +1,48 @@ # 提示词开发指南 -本指南提供了关于 Nexent 中用于创建不同类型智能体的提示词模板系统的全面信息。`backend/prompts/` 目录中的 YAML 文件定义了各种智能体类型的系统提示词、规划提示词和其他关键提示词组件。 +本指南说明 `backend/prompts/` 下提示词模板的组织方式,以及如何为新智能体扩展模板。 -## 文件命名规范 +## 📂 文件布局与命名 -命名格式为 `{agent_type}_agent.yaml`,其中: -- `agent_type`:描述智能体的主要功能或用途(如 manager、search 等) +- 核心模板位于 `backend/prompts/`,通常命名为 `{agent_type}_agent.yaml` 或 `{scope}_prompt_template.yaml`。 +- 工具类/辅助模板位于 `backend/prompts/utils/`,用于元提示生成(如标题、提示词生成)。 -## 提示词模板结构 +## 🧩 模板结构 -每个 YAML 文件包含以下主要部分: +常见字段: +- `system_prompt`:角色/职责、执行流程、工具与子智能体使用规则、Python 代码约束、示例。 +- `planning`:`initial_facts`、`initial_plan` 及更新前后提示。 +- `managed_agent`:分配与汇报的子智能体提示。 +- `final_answer`:生成最终答案前后提示。 +- `tools_requirement`:工具使用优先级与规范。 +- `few_shots`:少样本示例。 -### 1. system_prompt +## 🔄 变量占位 -系统提示词是智能体的核心部分,定义了智能体的角色、能力和行为规范。通常包含以下部分: +模板中常用占位符: +- `tools`、`managed_agents` +- `task`、`remaining_steps` +- `authorized_imports` +- `facts_update`、`answer_facts` -- **核心职责**:智能体的主要职责和能力描述 -- **执行流程**:智能体执行任务的标准流程和方法 -- **可用资源**:智能体可以使用的工具和子智能体列表 -- **资源使用要求**:使用不同工具的优先级和策略 -- **Python代码规范**:编写代码的规范和约束 -- **示例模板**:展示智能体执行任务的示例 +## 📑 关键模板 -### 2. planning +- 管理器智能体:`manager_system_prompt_template.yaml`、`manager_system_prompt_template_en.yaml` +- 被管理智能体:`managed_system_prompt_template.yaml`、`managed_system_prompt_template_en.yaml` +- 知识总结:`knowledge_summary_agent.yaml`、`knowledge_summary_agent_en.yaml` +- 文件分析:`analyze_file.yaml`、`analyze_file_en.yaml` +- 聚类总结:`cluster_summary_agent.yaml`、`cluster_summary_reduce.yaml`(含 `_zh` 版本) +- 工具/生成辅助(`utils/`):`prompt_generate*.yaml`、`generate_title*.yaml` -包含用于任务规划的各种提示词: +## 🚀 如何扩展 -- **initial_facts**:初始事实收集提示词 -- **initial_plan**:初始计划制定提示词 -- **update_facts_pre_messages**:更新事实前的提示词 -- **update_facts_post_messages**:更新事实后的提示词 -- **update_plan_pre_messages**:更新计划前的提示词 -- **update_plan_post_messages**:更新计划后的提示词 +1. 选取最相近模板复制,调整 `system_prompt`/`planning` 适配场景。 +2. 保留必要占位符,除非明确不需要。 +3. 工具列表需与实际可用工具一致,必要时更新 `authorized_imports`。 +4. 用小任务验证“思考 → 代码 → 观察 → 重复”流程是否符合预期。 -### 3. managed_agent +## ✅ 规范与提示 -定义与子智能体交互的提示词: - -- **task**:分配给子智能体的任务提示词 -- **report**:子智能体报告结果的提示词 - -### 4. final_answer - -定义最终答案生成的提示词: - -- **pre_messages**:生成最终答案前的提示词 -- **post_messages**:生成最终答案后的提示词 - -### 5. tools_requirement - -定义工具使用规范和优先级的提示词。 - -### 6. few_shots - -提供少样本学习示例的提示词,帮助智能体更好地理解任务执行方式。 - -## 模板变量 - -提示词模板中使用以下特殊变量进行动态替换: - -- `{{tools}}`:可用工具列表 -- `{{managed_agents}}`:可用子智能体列表 -- `{{task}}`:当前任务描述 -- `{{authorized_imports}}`:授权导入的Python模块 -- `{{facts_update}}`:更新后的事实列表 -- `{{answer_facts}}`:已知事实列表 -- `{{remaining_steps}}`:剩余执行步骤数 - -## 可用的提示词模板 - -### 核心模板 - -1. **管理器智能体模板** - - `manager_system_prompt_template.yaml` - 中文版本 - - `manager_system_prompt_template_en.yaml` - 英文版本 - - 这些模板定义了核心管理器智能体,负责协调和调度各种助手和工具来高效解决复杂任务。 - -2. **被管理智能体模板** - - `managed_system_prompt_template.yaml` - 中文版本 - - `managed_system_prompt_template_en.yaml` - 英文版本 - - 这些模板定义了专门的智能体,在管理器智能体的协调下执行特定任务。 - -3. **专业智能体模板** - - `knowledge_summary_agent.yaml` - 知识总结智能体(中文) - - `knowledge_summary_agent_en.yaml` - 知识总结智能体(英文) - - `analyze_file.yaml` - 文件分析智能体(中文) - - `analyze_file_en.yaml` - 文件分析智能体(英文) - -### 工具模板 - -位于 `utils/` 目录中: - -1. **提示词生成模板** - - `prompt_generate.yaml` - 中文版本 - - `prompt_generate_en.yaml` - 英文版本 - - 这些模板帮助为不同智能体类型生成高效、清晰的提示词。 - -2. **标题生成模板** - - `generate_title.yaml` - 中文版本 - - `generate_title_en.yaml` - 英文版本 - - 用于为对话生成标题。 - -## 执行流程 - -标准智能体执行流程遵循以下模式: - -1. **思考**:分析当前任务状态和进展 -2. **代码**:编写简单的Python代码 -3. **观察**:查看代码执行结果 -4. **重复**:继续循环直到任务完成 - -## 代码规范 - -在提示词中编写Python代码时: - -1. 使用格式 `代码:\n```py\n` 表示可执行代码 -2. 使用格式 `代码:\n```code:语言类型\n` 表示仅用于展示的代码 -3. 只使用已定义的变量,变量将在多次调用之间持续保持 -4. 使用 `print()` 函数让变量信息可见 -5. 使用关键字参数进行工具和智能体调用 -6. 避免在一轮对话中进行过多的工具调用 -7. 只能从授权模块导入:`{{authorized_imports}}` - -## 最佳实践 - -1. **任务分解**:将复杂任务分解为可管理的子任务 -2. **专业匹配**:根据智能体专长分配任务 -3. **信息整合**:整合不同智能体的输出 -4. **效率优化**:避免重复工作 -5. **结果评估**:评估智能体返回结果,必要时提供额外指导 +- 可执行代码块使用 ````py````,仅展示代码用 ````code:语言````。 +- 工具调用尽量用关键字参数,单轮避免过多工具调用。 +- 注释/文档保持英文,遵守仓库规则与授权导入限制。 diff --git a/doc/docs/zh/backend/tools/index.md b/doc/docs/zh/backend/tools/index.md index 208425aed..94e1fe36e 100644 --- a/doc/docs/zh/backend/tools/index.md +++ b/doc/docs/zh/backend/tools/index.md @@ -26,6 +26,6 @@ ## 需要帮助? -- 查看我们的 [常见问题](../../getting-started/faq) 了解常见工具集成问题 +- 查看我们的 [常见问题](../../quick-start/faq) 了解常见工具集成问题 - 加入我们的 [Discord 社区](https://discord.gg/tb5H3S3wyv) 获取实时支持 - 查看 [GitHub Issues](https://github.com/ModelEngine-Group/nexent/issues) 了解已知问题 \ No newline at end of file diff --git a/doc/docs/zh/backend/tools/mcp.md b/doc/docs/zh/backend/tools/mcp.md index 4655249b4..1d8ec42e1 100644 --- a/doc/docs/zh/backend/tools/mcp.md +++ b/doc/docs/zh/backend/tools/mcp.md @@ -1,591 +1,200 @@ -# Nexent MCP架构说明 +# Model Context Protocol (MCP) -## 系统架构概述 +## 🌟 什么是 MCP? -Nexent采用**本地MCP服务 + 直接远程连接**的架构,通过MCP(Model Context Protocol)协议实现本地服务与远程服务的统一管理。系统包含两个核心服务: +Model Context Protocol (MCP) 是连接 AI 与外部系统(数据、工具、工作流)的开放标准,相当于 AI 的 "USB-C"。它让主机(如 Claude Desktop、Nexent)按统一协议发现并调用 MCP 服务器暴露的工具/资源。 -### 1. 主服务 (FastAPI) - 端口 5010 5014 -- **用途**:提供Web管理界面和RESTful API,作为前端唯一入口 -- **特点**:面向用户管理,包含认证、多租户支持,管理MCP服务器配置 -- **启动文件**:`config_service.py, runtime_service.py` +## 🧭 MCP 能力 -### 2. 本地MCP服务 (FastMCP) - 端口 5011 -- **用途**:提供本地MCP协议服务,挂载本地工具 -- **特点**:MCP协议标准,仅提供本地服务,不代理远程服务 -- **启动文件**:`mcp_service.py` +- **Tools**:可由 LLM 调用的函数(需用户授权) +- **Resources**:可读取的文件型数据 +- **Prompts**:服务器可共享的模板 +- 主机可通过标准协议连接本地或远程 MCP 服务器,自动发现能力 -### 3. 远程MCP服务 -- **用途**:外部MCP服务,提供远程工具 -- **特点**:智能体执行时直接连接,不通过本地MCP服务代理 +## 🌐 语言支持 -## 核心组件架构 +MCP 协议支持多种编程语言: -```mermaid -graph TD - A["前端客户端"] --> B["主服务 (FastAPI)(端口: 5010)"] - - B --> B1["remote_mcp_app.py(MCP管理路由)"] - B1 --> B2["remote_mcp_service.py(MCP服务逻辑)"] - B2 --> B3["数据库(MCP配置存储)"] - - B --> C["create_agent_info.py(智能体配置)"] - C --> C1["工具发现与配置"] - C --> C2["MCP服务器过滤"] - - B --> D["run_agent.py(智能体执行)"] - D --> D1["ToolCollection(MCP工具集合)"] - D1 --> E["本地MCP服务 (FastMCP)(端口: 5011)"] - D1 --> F1["远程MCP服务1(直接连接)"] - D1 --> F2["远程MCP服务2(直接连接)"] - D1 --> F3["远程MCP服务n(直接连接)"] - - E --> E1["local_mcp_service(本地工具)"] - - style A fill:#e1f5fe - style B fill:#f3e5f5 - style E fill:#e8f5e8 - style B1 fill:#fff3e0 - style C fill:#e8f5e8 - style D fill:#fce4ec - style F1 fill:#fff3e0 - style F2 fill:#fff3e0 - style F3 fill:#fff3e0 -``` - -## 核心功能模块 +- **Python** ⭐(推荐新手使用) +- **TypeScript** +- **Java** +- **Go** +- **Rust** +- 以及其他支持 MCP 协议的语言 -### 1. 本地MCP服务管理 (mcp_service.py) +我们推荐使用 **Python**,因为它语法简洁易学,拥有 FastMCP 等丰富框架支持,可以快速构建原型,且有数千个成熟的第三方库可用。 -**本地MCP服务实现**: -```python -# 初始化本地MCP服务 -nexent_mcp = FastMCP(name="nexent_mcp") - -# 挂载本地服务(稳定,不受远程服务影响) -nexent_mcp.mount(local_mcp_service.name, local_mcp_service) -``` +## 🚀 快速开始 -**特点**: -- 仅提供本地MCP服务,挂载本地工具 -- 不代理远程MCP服务 -- 基于FastMCP框架,提供标准MCP协议支持 -- 服务稳定运行,端口5011 +### 📋 前置要求 -### 2. MCP管理API (remote_mcp_app.py) - -提供完整的MCP服务器管理接口: - -#### 获取远程MCP工具信息 -```http -POST /api/mcp/tools?service_name={name}&mcp_url={url} -Authorization: Bearer {token} -``` - -#### 添加远程MCP服务器 -```http -POST /api/mcp/add?mcp_url={url}&service_name={name} -Authorization: Bearer {token} -``` - -#### 删除远程MCP服务器 -```http -DELETE /api/mcp/?service_name={name}&mcp_url={url} -Authorization: Bearer {token} -``` - -#### 获取MCP服务器列表 -```http -GET /api/mcp/list -Authorization: Bearer {token} -``` +在开始之前,请安装 FastMCP: -#### MCP服务器健康检查 -```http -GET /api/mcp/healthcheck?mcp_url={url}&service_name={name} -Authorization: Bearer {token} +```bash +pip install fastmcp ``` -### 3. MCP服务逻辑 (remote_mcp_service.py) +### 📝 基础示例 -**核心功能**: +创建一个简单的字符串处理 MCP 服务器: -#### 服务器健康检查 ```python -async def mcp_server_health(remote_mcp_server: str) -> JSONResponse: - # 使用FastMCP Client验证远程服务连接 - client = Client(remote_mcp_server) - async with client: - connected = client.is_connected() - # 返回连接状态 -``` +from fastmcp import FastMCP -#### 添加MCP服务器 -```python -async def add_remote_mcp_server_list(tenant_id, user_id, remote_mcp_server, remote_mcp_server_name): - # 1. 检查服务名是否已存在 - # 2. 验证远程服务连接 - # 3. 保存到数据库 - # 4. 返回操作结果 -``` +# 创建MCP服务器实例 +mcp = FastMCP(name="String MCP Server") -#### 删除MCP服务器 -```python -async def delete_remote_mcp_server_list(tenant_id, user_id, remote_mcp_server, remote_mcp_server_name): - # 1. 从数据库删除记录 - # 2. 返回操作结果 -``` +@mcp.tool( + name="calculate_string_length", + description="计算输入字符串的长度" +) +def calculate_string_length(text: str) -> int: + return len(text) -### 4. 智能体配置管理 (create_agent_info.py) +@mcp.tool( + name="to_uppercase", + description="将字符串转换为大写" +) +def to_uppercase(text: str) -> str: + return text.upper() -**MCP服务器过滤机制**: -```python -def filter_mcp_servers_and_tools(input_agent_config: AgentConfig, mcp_info_dict) -> list: - """ - 过滤MCP服务器和工具,只保留实际使用的MCP服务器 - 支持多级智能体,递归检查所有子智能体工具 - """ - used_mcp_urls = set() - - def check_agent_tools(agent_config: AgentConfig): - # 检查当前智能体工具 - for tool in agent_config.tools: - if tool.source == "mcp" and tool.usage in mcp_info_dict: - used_mcp_urls.add(mcp_info_dict[tool.usage]["remote_mcp_server"]) - - # 递归检查子智能体 - for sub_agent_config in agent_config.managed_agents: - check_agent_tools(sub_agent_config) - - check_agent_tools(input_agent_config) - return list(used_mcp_urls) -``` +@mcp.tool( + name="to_lowercase", + description="将字符串转换为小写" +) +def to_lowercase(text: str) -> str: + return text.lower() -**智能体运行信息创建**: -```python -async def create_agent_run_info(agent_id, minio_files, query, history, authorization, language='zh'): - # 1. 获取用户和租户信息 - # 2. 创建模型配置列表 - # 3. 创建智能体配置 - # 4. 获取远程MCP服务器列表 - # 5. 过滤实际使用的MCP服务器 - # 6. 创建智能体运行信息 +if __name__ == "__main__": + # 使用SSE协议启动服务 + mcp.run(transport="sse", port=8000) ``` -### 5. 智能体执行引擎 (run_agent.py) +### 🏃 运行服务器 -**MCP工具集成**: -```python -def agent_run_thread(agent_run_info: AgentRunInfo, memory_context: MemoryContext): - mcp_host = agent_run_info.mcp_host - - if mcp_host is None or len(mcp_host) == 0: - # 无MCP服务器:使用本地工具 - nexent = NexentAgent(...) - agent = nexent.create_single_agent(agent_run_info.agent_config) - # ... - else: - # 有MCP服务器:使用ToolCollection直接连接所有MCP服务 - agent_run_info.observer.add_message("", ProcessType.AGENT_NEW_RUN, "") - mcp_client_list = [{"url": mcp_url} for mcp_url in mcp_host] - - with ToolCollection.from_mcp(mcp_client_list, trust_remote_code=True) as tool_collection: - # ToolCollection会同时连接本地MCP服务(5011)和远程MCP服务 - nexent = NexentAgent( - mcp_tool_collection=tool_collection, - # ... - ) - # 执行智能体 -``` +保存代码为 `mcp_server.py`,然后运行: -## 数据流程 - -### 1. MCP服务器添加流程 - -```mermaid -sequenceDiagram - participant C as 前端客户端 - participant A as remote_mcp_app - participant S as remote_mcp_service - participant DB as 数据库 - participant MCP as 远程MCP服务 - - C->>A: POST /api/mcp/add - A->>S: add_remote_mcp_server_list() - S->>S: 检查服务名是否存在 - S->>MCP: mcp_server_health() - MCP-->>S: 连接状态 - alt 连接成功 - S->>DB: create_mcp_record() - DB-->>S: 保存结果 - S-->>A: 成功响应 - A-->>C: 成功响应 - else 连接失败 - S-->>A: 错误响应 - A-->>C: 错误响应 - end -``` - -### 2. 智能体执行流程 - -```mermaid -sequenceDiagram - participant C as 前端客户端 - participant A as create_agent_info - participant R as run_agent - participant TC as ToolCollection - participant LMCP as 本地MCP服务(5011) - participant RMCP1 as 远程MCP服务1 - participant RMCP2 as 远程MCP服务2 - - C->>A: create_agent_run_info() - A->>A: filter_mcp_servers_and_tools() - A-->>C: AgentRunInfo - - C->>R: agent_run() - R->>R: agent_run_thread() - - alt 有MCP服务器 - R->>TC: ToolCollection.from_mcp() - TC->>LMCP: 连接本地MCP服务 - TC->>RMCP1: 直接连接远程MCP服务1 - TC->>RMCP2: 直接连接远程MCP服务2 - LMCP-->>TC: 本地工具列表 - RMCP1-->>TC: 远程工具列表1 - RMCP2-->>TC: 远程工具列表2 - TC-->>R: 合并的工具集合 - R->>R: 执行智能体 - else 无MCP服务器 - R->>R: 使用本地工具执行 - end - - R-->>C: 执行结果 -``` - -## 关键特性 - -### 1. 多租户隔离 -- 所有MCP服务器配置基于`tenant_id`进行隔离 -- 用户只能访问自己租户的MCP服务器 - -### 2. 动态MCP管理 -- 支持运行时添加、删除MCP服务器配置 -- 自动健康检查和状态更新 -- 数据库持久化存储配置 -- 智能体执行时直接连接远程MCP服务 - -### 3. 智能工具过滤 -- 只连接智能体实际使用的MCP服务器 -- 支持多级智能体的递归工具检查 -- 避免不必要的网络连接 -- 本地MCP服务(5011)始终可用,远程服务按需连接 - -### 4. 错误处理 -- MCP连接失败时的优雅降级 -- 详细的错误日志和状态反馈 -- 连接超时保护机制 - -### 5. 内存管理 -- 智能体执行完成后自动保存对话记忆 -- 支持多级记忆存储(租户、智能体、用户、用户智能体) -- 可配置的记忆共享策略 - -## 配置说明 - -### 环境变量 ```bash -# MCP服务地址 -NEXENT_MCP_SERVER=http://localhost:5011 - -# 数据库配置 -DATABASE_URL=postgresql://... - -# 其他配置... +python mcp_server.py ``` -### 数据库表结构 -```sql --- MCP服务器配置表 -CREATE TABLE mcp_servers ( - id SERIAL PRIMARY KEY, - tenant_id VARCHAR NOT NULL, - user_id VARCHAR NOT NULL, - mcp_name VARCHAR NOT NULL, - mcp_server VARCHAR NOT NULL, - status BOOLEAN DEFAULT true, - created_at TIMESTAMP DEFAULT NOW(), - updated_at TIMESTAMP DEFAULT NOW() -); -``` +您将看到 MCP 服务器成功启动,服务地址为 `http://127.0.0.1:8000/sse`。 -## 使用示例 +## 🔌 集成到 Nexent -### 1. 启动服务 -```bash -# 启动主服务 -cd backend -python config_service.py -python runtime_service.py - -# 启动本地MCP服务 -cd backend -python mcp_service.py -``` +MCP 服务器运行后,将其连接到 Nexent: -### 2. 添加远程MCP服务器 -```bash -curl -X POST "http://localhost:5010/api/mcp/add?mcp_url=http://external-server:5012/sse&service_name=external_service" \ - -H "Authorization: Bearer {your_token}" -``` +### 📍 步骤 1:启动 MCP 服务器 -## MCP接口文档 +确保服务器正在运行,并记录其访问地址(例如 `http://127.0.0.1:8000/sse`)。 -### 1. 主服务API接口 (端口5010) +### ⚙️ 步骤 2:在 Nexent 中注册 -#### 1.1 获取远程MCP工具信息 -```http -POST /api/mcp/tools -``` +1. 进入 **[智能体开发](../../user-guide/agent-development)** 页面 +2. 在"选择Agent的工具"页签右侧,点击"**MCP配置**" +3. 在弹出的配置窗口中,输入服务器名称和服务器URL + - ⚠️ **注意**: + - 服务器名称只能包含英文字母和数字,不能包含空格、下划线等其他字符 + - 如果使用 Docker 容器部署 Nexent,且 MCP 服务器运行在宿主机上,需要将 `127.0.0.1` 替换为 `host.docker.internal`(例如 `http://host.docker.internal:8000`) +4. 点击"**添加**"按钮完成配置 -**请求参数**: -- `service_name` (string, 必需): MCP服务名称 -- `mcp_url` (string, 必需): MCP服务器URL -- `Authorization` (Header, 必需): Bearer token - -**响应示例**: -```json -{ - "tools": [ - { - "name": "tool_name", - "description": "tool_description", - "parameters": {...} - } - ], - "status": "success" -} -``` +### 🎯 步骤 3:使用 MCP 工具 -#### 1.2 添加远程MCP服务器 -```http -POST /api/mcp/add -``` +配置完成后,在创建或编辑智能体时,您可以在工具列表中找到并选择您添加的 MCP 工具。 -**请求参数**: -- `mcp_url` (string, 必需): MCP服务器URL -- `service_name` (string, 必需): MCP服务名称 -- `Authorization` (Header, 必需): Bearer token - -**响应示例**: -```json -{ - "message": "Successfully added remote MCP proxy", - "status": "success" -} -``` +## 🔧 高级用例 -#### 1.3 删除远程MCP服务器 -```http -DELETE /api/mcp -``` +### 🌐 包装 REST API -**请求参数**: -- `service_name` (string, 必需): MCP服务名称 -- `mcp_url` (string, 必需): MCP服务器URL -- `Authorization` (Header, 必需): Bearer token - -**响应示例**: -```json -{ - "message": "Successfully deleted remote MCP proxy", - "status": "success" -} -``` - -#### 1.4 获取MCP服务器列表 -```http -GET /api/mcp/list -``` +将现有的 REST API 包装为 MCP 工具: -**请求参数**: -- `Authorization` (Header, 必需): Bearer token - -**响应示例**: -```json -{ - "remote_mcp_server_list": [ - { - "id": 1, - "tenant_id": "tenant_123", - "user_id": "user_456", - "mcp_name": "external_service", - "mcp_server": "http://external-server:5012/sse", - "status": true, - "created_at": "2024-01-01T00:00:00Z", - "updated_at": "2024-01-01T00:00:00Z" - } - ], - "status": "success" -} -``` +```python +from fastmcp import FastMCP +import requests + +mcp = FastMCP("Course Statistics Server") + +@mcp.tool( + name="get_course_statistics", + description="根据课程号获取某门课程的成绩统计信息(包含平均分、最高分、最低分等)" +) +def get_course_statistics(course_id: str) -> str: + api_url = "https://your-school-api.com/api/courses/statistics" + response = requests.get(api_url, params={"course_id": course_id}) + + if response.status_code == 200: + data = response.json() + stats = data.get("statistics", {}) + return f"课程 {course_id} 成绩统计:\n平均分: {stats.get('average', 'N/A')}\n最高分: {stats.get('max', 'N/A')}\n最低分: {stats.get('min', 'N/A')}\n总人数: {stats.get('total_students', 'N/A')}" + return f"API调用失败: {response.status_code}" -#### 1.5 MCP服务器健康检查 -```http -GET /api/mcp/healthcheck +if __name__ == "__main__": + mcp.run(transport="sse", port=8000) ``` -**请求参数**: -- `mcp_url` (string, 必需): MCP服务器URL -- `service_name` (string, 必需): MCP服务名称 -- `Authorization` (Header, 必需): Bearer token - -**响应示例**: -```json -{ - "message": "Successfully connected to remote MCP server", - "status": "success" -} -``` +### 🏢 包装内部服务 -### 2. 本地MCP服务接口 (端口5011) - -#### 2.1 MCP协议接口 -本地MCP服务基于FastMCP框架,提供标准MCP协议支持: - -**服务地址**:`http://localhost:5011/sse` - -**支持的操作**: -- `tools/list`: 获取工具列表 -- `tools/call`: 调用工具 -- `resources/list`: 获取资源列表 -- `resources/read`: 读取资源 - -#### 2.2 本地工具服务 -本地MCP服务挂载了以下本地工具: -- 文件操作工具 -- 网络请求工具 -- 系统信息工具 -- 其他本地工具 - -### 3. 错误码说明 - -| 状态码 | 说明 | 处理建议 | -|--------|------|----------| -| 200 | 成功 | 正常处理响应数据 | -| 400 | 请求参数错误 | 检查请求参数格式和内容 | -| 401 | 认证失败 | 检查Authorization token是否有效 | -| 403 | 权限不足 | 确认用户权限 | -| 404 | 资源不存在 | 检查MCP服务器URL是否正确 | -| 409 | 服务名已存在 | 使用不同的服务名称 | -| 503 | 服务不可用 | 检查MCP服务器是否正常运行 | -| 500 | 服务器内部错误 | 查看服务器日志 | - -### 4. 前端API调用示例 - -#### 4.1 JavaScript/TypeScript调用 -```typescript -// 获取MCP服务器列表 -const getMcpServerList = async () => { - const response = await fetch('/api/mcp/list', { - headers: { - 'Authorization': `Bearer ${token}`, - 'Content-Type': 'application/json' - } - }); - return await response.json(); -}; - -// 添加MCP服务器 -const addMcpServer = async (mcpUrl: string, serviceName: string) => { - const response = await fetch(`/api/mcp/add?mcp_url=${mcpUrl}&service_name=${serviceName}`, { - method: 'POST', - headers: { - 'Authorization': `Bearer ${token}`, - 'Content-Type': 'application/json' - } - }); - return await response.json(); -}; -``` +集成本地业务逻辑: -#### 4.2 cURL调用示例 -```bash -# 获取MCP服务器列表 -curl -H "Authorization: Bearer {your_token}" \ - "http://localhost:5010/api/mcp/list" +```python +from fastmcp import FastMCP +from your_school_module import query_course_statistics -# 添加MCP服务器 -curl -X POST "http://localhost:5010/api/mcp/add?mcp_url=http://external-server:5012/sse&service_name=external_service" \ - -H "Authorization: Bearer {your_token}" +mcp = FastMCP("Course Statistics Server") -# 删除MCP服务器 -curl -X DELETE "http://localhost:5010/api/mcp/?service_name=external_service&mcp_url=http://external-server:5012/sse" \ - -H "Authorization: Bearer {your_token}" +@mcp.tool( + name="get_course_statistics", + description="根据课程号获取某门课程的成绩统计信息(包含平均分、最高分、最低分等)" +) +def get_course_statistics(course_id: str) -> str: + try: + stats = query_course_statistics(course_id) + return f"课程 {course_id} 成绩统计:\n平均分: {stats.get('average', 'N/A')}\n最高分: {stats.get('max', 'N/A')}\n最低分: {stats.get('min', 'N/A')}\n总人数: {stats.get('total_students', 'N/A')}" + except Exception as e: + return f"查询成绩统计时出错: {str(e)}" -# 健康检查 -curl "http://localhost:5010/api/mcp/healthcheck?mcp_url=http://external-server:5012/sse&service_name=external_service" \ - -H "Authorization: Bearer {your_token}" +if __name__ == "__main__": + mcp.run(transport="sse", port=8000) ``` -## 性能优化 +## ✅ 最佳实践 -### 1. 连接池管理 -- MCP客户端连接复用 -- 自动连接超时和重试机制 -- ToolCollection统一管理多个MCP服务连接 +- **日志记录**: stdio 传输避免 stdout 日志(不要 `print`),日志写入 stderr/文件。[日志说明](https://modelcontextprotocol.io/docs/develop/build-server#logging-in-mcp-servers) +- **文档规范**: 工具 docstring/类型要清晰,FastMCP 会据此生成 schema +- **错误处理**: 友好处理错误,返回可读文本 +- **安全性**: 敏感信息放环境变量/密钥管理,不要硬编码 -### 2. 工具缓存 -- 工具信息本地缓存 -- 减少重复的MCP服务查询 +## 📚 相关资源 -### 3. 异步处理 -- 所有MCP操作采用异步模式 -- 支持并发智能体执行 +### 🐍 Python -## 安全考虑 +- [FastMCP 文档](https://github.com/modelcontextprotocol/python-sdk) +- [Python SDK 仓库](https://github.com/modelcontextprotocol/python-sdk) -### 1. 认证授权 -- 所有API接口需要Bearer token认证 -- 基于租户的数据隔离 +### 🔤 其他语言 -### 2. 连接验证 -- 添加MCP服务器前进行连通性验证 -- 支持HTTPS和SSE安全传输 +- [MCP TypeScript SDK](https://github.com/modelcontextprotocol/typescript-sdk) +- [MCP Java SDK](https://github.com/modelcontextprotocol/java-sdk) +- [MCP Go SDK](https://github.com/modelcontextprotocol/go-sdk) +- [MCP Rust SDK](https://github.com/modelcontextprotocol/rust-sdk) -### 3. 错误处理 -- 详细的错误日志记录 -- 敏感信息脱敏处理 +### 📖 官方文档 -## 故障排除 +- [MCP 介绍](https://modelcontextprotocol.io/docs/getting-started/intro) +- [构建服务器指南](https://modelcontextprotocol.io/docs/develop/build-server) +- [SDK 文档](https://modelcontextprotocol.io/docs/sdk) +- [MCP 协议规范](https://modelcontextprotocol.io/) -### 常见问题 +### 🔗 相关指南 -1. **MCP连接失败** - - 检查远程MCP服务是否正常运行 - - 验证网络连接和防火墙设置 - - 查看服务日志获取详细错误信息 +- [Nexent 智能体开发指南](../../user-guide/agent-development) +- [MCP 工具生态系统概览](../../mcp-ecosystem/overview) +- [MCP 推荐](../../mcp-ecosystem/mcp-recommendations) -2. **工具加载失败** - - 确认MCP服务器支持所需的工具 - - 检查工具配置是否正确 - - 验证权限设置 +## 🆘 获取帮助 -3. **性能问题** - - 监控MCP服务器响应时间 - - 检查网络延迟 - - 优化工具过滤逻辑 +如果在开发 MCP 服务器时遇到问题: -### 调试工具 - -```python -# 检查MCP服务器健康状态 -response = await mcp_server_health("http://remote-server:port/sse") - -# 获取MCP服务器列表 -servers = await get_remote_mcp_server_list(tenant_id) - -# 查看智能体使用的MCP服务器 -mcp_hosts = filter_mcp_servers_and_tools(agent_config, mcp_info_dict) - -# 查看本地MCP服务状态 -# 本地MCP服务运行在端口5011,提供本地工具 -``` +1. 查看 **[常见问题](../../quick-start/faq)** +2. 在 [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) 中提问 +3. 参考 [ModelScope MCP Marketplace](https://www.modelscope.cn/mcp) 中的示例服务器 diff --git a/doc/docs/zh/backend/tools/nexent-native.md b/doc/docs/zh/backend/tools/nexent-native.md new file mode 100644 index 000000000..e0a211e50 --- /dev/null +++ b/doc/docs/zh/backend/tools/nexent-native.md @@ -0,0 +1,26 @@ +--- +title: Nexent 原生工具 +--- + +# Nexent 原生工具 + +## 🧭 范围 + +Nexent 原生工具统一在官方仓库维护。如需新增自定义能力,请在 `sdk/nexent/core/tools` 目录按现有模式开发并提交。 + +## 🛠️ 开发规范 + +- 在 `sdk/nexent/core/tools` 与现有工具并行开发(如文件、搜索、邮件、多模态等)。 +- 遵循 [工具模块](../../sdk/core/tools) 的结构、输入定义与消息规则。 +- 注释与文档使用英文,遵守仓库规则。 + +## 🤝 贡献路径 + +- 仅支持向 Nexent 官方仓库提交 PR,不支持外部托管的“原生工具”。 +- 参考现有实现及 [贡献指南](../../contributing) 完成流程与质量要求。 + +## 🔗 相关参考 + +- [工具模块](../../sdk/core/tools) +- [贡献指南](../../contributing) + diff --git a/doc/docs/zh/version/version-management.md b/doc/docs/zh/backend/version-management.md similarity index 86% rename from doc/docs/zh/version/version-management.md rename to doc/docs/zh/backend/version-management.md index 318b88253..8ba80cdeb 100644 --- a/doc/docs/zh/version/version-management.md +++ b/doc/docs/zh/backend/version-management.md @@ -2,7 +2,7 @@ Nexent 项目采用统一的版本管理策略,确保前端和后端版本信息的一致性。本文档介绍如何管理和更新项目版本信息。 -## 版本号格式 +## 📋 版本号格式 Nexent 使用语义化版本控制: @@ -12,21 +12,21 @@ Nexent 使用语义化版本控制: - **PATCH**: 向下兼容的问题修正 - **BUILD**: 可选的小版本号,用于更细粒度的 bugfix 版本 -### 版本号示例 +### 🏷️ 版本号示例 - `v1.2.0` - 功能更新版本 - `v1.2.0.1` - 包含小版本号的 bugfix 版本 -## 前端版本管理 +## 🖥️ 前端版本管理 -### 版本信息位置 +### 📍 版本信息位置 前端版本信息通过接口从后端获取。 - **接口**: `GET /api/tenant_config/deployment_version` - **服务**: `frontend/services/versionService.ts` -### 版本更新流程 +### 🔄 版本更新流程 1. **在代码中更新后端版本** @@ -47,16 +47,16 @@ APP_VERSION="v1.1.0" # 在页面底部检查应用版本显示 ``` -### 版本显示 +### 📺 版本显示 前端版本信息在以下位置显示: - 位置:页面底部导航栏,位于页面左下角 - 版本格式:`v1.1.0` -## 后端版本管理 +## ⚙️ 后端版本管理 -### 版本信息位置 +### 📍 版本信息位置 后端版本信息在 `backend/consts/const.py` 中以代码形式定义: @@ -65,11 +65,11 @@ APP_VERSION="v1.1.0" APP_VERSION = "v1.0.0" ``` -### 版本配置 +### 🔧 版本配置 版本通过直接修改 `backend/consts/const.py` 中的 `APP_VERSION` 配置。 -### 版本显示 +### 📺 版本显示 后端启动时会在日志中打印版本信息: @@ -78,7 +78,7 @@ APP_VERSION = "v1.0.0" logger.info(f"APP version is: {APP_VERSION}") ``` -### 版本更新流程 +### 🔄 版本更新流程 1. **在代码中更新版本** @@ -97,3 +97,4 @@ APP_VERSION="v1.1.0" # 查看启动日志中的版本信息 # 输出示例:APP version is: v1.1.0 ``` + diff --git a/doc/docs/zh/contributing.md b/doc/docs/zh/contributing.md index ee8dd8405..43b9259af 100644 --- a/doc/docs/zh/contributing.md +++ b/doc/docs/zh/contributing.md @@ -138,7 +138,7 @@ git checkout -b 您的分支名 ``` ### 第四步:进行更改 -🧙♂️ 像魔法师一样编码!遵循我们的 [开发指南](./getting-started/development-guide) 获取设置说明和编码标准。确保您的更改经过充分测试并有文档记录。 +🧙♂️ 像魔法师一样编码!遵循我们的 [开发者指南](./developer-guide/overview) 获取设置说明和编码标准。确保您的更改经过充分测试并有文档记录。 ### 第五步:提交更改 📝 按照我们的提交消息规范,提交清晰简洁的消息(建议采用英文,让更多人理解你): diff --git a/doc/docs/zh/deployment/docker-build.md b/doc/docs/zh/deployment/docker-build.md index 8fa33da43..6ae30aa7f 100644 --- a/doc/docs/zh/deployment/docker-build.md +++ b/doc/docs/zh/deployment/docker-build.md @@ -140,4 +140,6 @@ docker rm nexent-docs ## 🚀 部署建议 -构建完成后,可以使用 `docker/deploy.sh` 脚本进行部署,或者直接使用 `docker-compose` 启动服务。 \ No newline at end of file +构建完成后,可以使用 `docker/deploy.sh` 脚本进行部署,或者直接使用 `docker-compose` 启动服务。 + +> 启动测试本地构建的镜像时,需要修改下`docker/deploy.sh`中的`APP_VERSION="$(get_app_version)"` -> `APP_VERSION="latest"`,因为部署时默认会使用当前版本对应的镜像。 \ No newline at end of file diff --git a/doc/docs/zh/developer-guide/environment-setup.md b/doc/docs/zh/developer-guide/environment-setup.md new file mode 100644 index 000000000..0a81ca10d --- /dev/null +++ b/doc/docs/zh/developer-guide/environment-setup.md @@ -0,0 +1,134 @@ +--- +title: 环境准备 +--- + +# 环境准备 + +本指南拆分了全栈开发与仅使用 SDK 的两类场景,按需选择路径完成环境准备。 + +## 🧱 通用要求 + +- Python 3.10+ +- Node.js 18+ +- Docker & Docker Compose +- uv(Python 包管理器) +- pnpm(Node.js 包管理器) + +## 🧑💻 全栈 Nexent 开发 + +### ⚙️ 基础设施部署 + +先启动数据库、缓存、向量库、存储等核心服务。 + +```bash +# 在项目根目录的 docker 目录执行 +cd docker +./deploy.sh --mode infrastructure +``` + +:::: info 重要提示 +基础设施模式会启动 PostgreSQL、Redis、Elasticsearch、MinIO,并在项目根生成 `.env`(包含生成的密钥与本地地址)。所有服务默认指向 localhost 便于本地开发。 +:::: + +### 🐍 后端依赖 + +```bash +cd backend +uv sync --all-extras +uv pip install ../sdk +``` + +:::: tip 说明 +`--all-extras` 安装所有可选依赖(数据处理、测试等),随后安装本地 SDK 包。 +:::: + +#### 可选:镜像加速 + +```bash +# 清华源 +uv sync --all-extras --default-index https://pypi.tuna.tsinghua.edu.cn/simple +uv pip install ../sdk --default-index https://pypi.tuna.tsinghua.edu.cn/simple + +# 阿里云 +uv sync --all-extras --default-index https://mirrors.aliyun.com/pypi/simple/ +uv pip install ../sdk --default-index https://mirrors.aliyun.com/pypi/simple/ + +# 多源(推荐) +uv sync --all-extras --index https://pypi.tuna.tsinghua.edu.cn/simple --index https://mirrors.aliyun.com/pypi/simple/ +uv pip install ../sdk --index https://pypi.tuna.tsinghua.edu.cn/simple --index https://mirrors.aliyun.com/pypi/simple/ +``` + +:::: info 镜像参考 +- 清华:`https://pypi.tuna.tsinghua.edu.cn/simple` +- 阿里:`https://mirrors.aliyun.com/pypi/simple/` +- 中科大:`https://pypi.mirrors.ustc.edu.cn/simple/` +- 豆瓣:`https://pypi.douban.com/simple/` +多源组合可提升成功率。 +:::: + +### ⚛️ 前端依赖 + +```bash +cd frontend +pnpm install +pnpm dev +``` + +### 🏃 服务启动 + +先激活后端虚拟环境: + +```bash +cd backend +source .venv/bin/activate +``` + +:::: warning 提示 +Windows 请使用 `source .venv/Scripts/activate`。 +:::: + +在项目根依次启动核心服务: + +```bash +source .env && python backend/mcp_service.py +source .env && python backend/data_process_service.py +source .env && python backend/config_service.py +source .env && python backend/runtime_service.py +``` + +:::: warning 提示 +需在项目根执行,并先 `source .env`。确保数据库、Redis、Elasticsearch、MinIO 已就绪。 +:::: + +## 🧰 仅使用 SDK + +若只需 SDK 而不运行全栈,可直接安装。 + +### 源码安装 + +```bash +git clone https://github.com/ModelEngine-Group/nexent.git +cd nexent/sdk +uv pip install -e . +``` + +### 使用 uv 安装 + +```bash +uv add nexent +``` + +### 开发者安装(含工具链) + +```bash +cd nexent/sdk +uv pip install -e ".[dev]" +``` + +包含: + +- 代码质量工具(ruff) +- 测试框架(pytest) +- 数据处理依赖(unstructured) +- 其他开发辅助依赖 + diff --git a/doc/docs/zh/getting-started/development-guide.md b/doc/docs/zh/developer-guide/overview.md similarity index 56% rename from doc/docs/zh/getting-started/development-guide.md rename to doc/docs/zh/developer-guide/overview.md index 2790a44bb..cbeca2c0a 100644 --- a/doc/docs/zh/getting-started/development-guide.md +++ b/doc/docs/zh/developer-guide/overview.md @@ -40,100 +40,16 @@ nexent/ - **监控**: 内置健康检查 - **日志**: 结构化日志 -## 🚀 开发环境搭建 - -### 环境要求 -- Python 3.10+ -- Node.js 18+ -- Docker & Docker Compose -- uv (Python 包管理器) -- pnpm (Node.js 包管理器) - -### 基础设施部署 -在开始后端开发之前,需要先部署基础设施服务。这些服务包括数据库、缓存、文件存储等核心组件。 - -```bash -# 在项目根目录的docker目录下执行 -cd docker -./deploy.sh --mode infrastructure -``` - -::: info 重要说明 -基础设施模式会启动 PostgreSQL、Redis、Elasticsearch 和 MinIO 服务。部署脚本会自动生成开发环境所需的密钥和环境变量,并保存到根目录的 `.env` 文件中。生成的密钥包括 MinIO 访问密钥和 Elasticsearch API 密钥。所有服务 URL 会配置为 localhost 地址,方便本地开发。 -::: - -### 后端设置 -```bash -# 在项目根目录的backend目录下执行 -cd backend -uv sync --all-extras -uv pip install ../sdk -``` - -::: tip 说明 -`--all-extras` 会安装所有可选依赖,包括数据处理、测试等模块。然后安装本地 SDK 包。 -::: - -#### 使用国内镜像源(可选) -如果网络访问较慢,可以使用国内镜像源加速安装: - -```bash -# 使用清华大学镜像源 -uv sync --all-extras --default-index https://pypi.tuna.tsinghua.edu.cn/simple -uv pip install ../sdk --default-index https://pypi.tuna.tsinghua.edu.cn/simple - -# 使用阿里云镜像源 -uv sync --all-extras --default-index https://mirrors.aliyun.com/pypi/simple/ -uv pip install ../sdk --default-index https://mirrors.aliyun.com/pypi/simple/ +## 🧱 环境准备 -# 使用多个镜像源(推荐) -uv sync --all-extras --index https://pypi.tuna.tsinghua.edu.cn/simple --index https://mirrors.aliyun.com/pypi/simple/ -uv pip install ../sdk --index https://pypi.tuna.tsinghua.edu.cn/simple --index https://mirrors.aliyun.com/pypi/simple/ -``` - -::: info 镜像源说明 -- **清华大学镜像源**: `https://pypi.tuna.tsinghua.edu.cn/simple` -- **阿里云镜像源**: `https://mirrors.aliyun.com/pypi/simple/` -- **中科大镜像源**: `https://pypi.mirrors.ustc.edu.cn/simple/` -- **豆瓣镜像源**: `https://pypi.douban.com/simple/` - -推荐使用多个镜像源配置,以提高下载成功率。 -::: - -### 前端设置 -```bash -# 在项目根目录的frontend目录下执行 -cd frontend -pnpm install -pnpm run dev -``` - -### 服务启动 -在启动服务之前,需要先激活虚拟环境: +环境相关步骤已迁移至独立的 [环境准备](./environment-setup) 指南,涵盖: -```bash -# 在项目根目录的backend目录下执行 -cd backend -source .venv/bin/activate # 激活虚拟环境 -``` - -::: warning 重要提示 -Windows操作系统需执行`source .venv/Scripts/activate`命令激活虚拟环境。 -::: +- 通用依赖与前置条件 +- 全栈 Nexent 搭建(基础设施、后端、前端、运行服务) +- 仅使用 SDK 的安装路径 -Nexent 包含三个核心后端服务,需要分别启动: - -```bash -# 在项目根目录下执行,请按以下顺序执行: -source .env && python backend/mcp_service.py # MCP 服务 -source .env && python backend/data_process_service.py # 数据处理服务 -source .env && python backend/config_service.py # 编辑态服务 -source .env && python backend/runtime_service.py # 运行态服务 -``` +请先完成该指南,再回到此页继续模块开发。 -::: warning 重要提示 -所有服务必须在项目根目录下启动。每个 Python 命令前都需要先执行 `source .env` 来加载环境变量。确保基础设施服务(数据库、Redis、Elasticsearch、MinIO)已经启动并正常运行。 -::: ## 🔧 开发模块指南 @@ -208,8 +124,8 @@ source .env && python backend/runtime_service.py # 运行态服务 ## 💡 获取帮助 ### 文档资源 -- [安装部署](./installation.md) - 环境搭建和部署 -- [常见问题](./faq) - 常见问题解答 +- [安装部署](../quick-start/installation) - 环境搭建和部署 +- [常见问题](../quick-start/faq) - 常见问题解答 - [用户指南](../user-guide/home-page) - Nexent使用指南 ### 社区支持 diff --git a/doc/docs/zh/frontend/overview.md b/doc/docs/zh/frontend/overview.md index 247845a49..04d028c8d 100644 --- a/doc/docs/zh/frontend/overview.md +++ b/doc/docs/zh/frontend/overview.md @@ -128,4 +128,4 @@ npm start - 语音处理集成 - 分析和监控 -详细的开发指南和组件文档,请参阅 [开发指南](../getting-started/development-guide)。 \ No newline at end of file +详细的开发指南和组件文档,请参阅 [开发者指南](../developer-guide/overview)。 \ No newline at end of file diff --git a/doc/docs/zh/getting-started/overview.md b/doc/docs/zh/getting-started/overview.md index 14eab4555..abbbdd4ba 100644 --- a/doc/docs/zh/getting-started/overview.md +++ b/doc/docs/zh/getting-started/overview.md @@ -70,9 +70,9 @@ Nexent 采用现代化的分布式微服务架构,专为高性能、可扩展 准备好开始了吗?以下是您的下一步: -1. **📋 [安装部署](./installation)** - 系统要求和部署指南 -2. **🔧 [开发指南](./development-guide)** - 从源码构建和自定义 -3. **❓ [常见问题](./faq)** - 常见问题和故障排除 +1. **📋 [安装部署](../quick-start/installation)** - 系统要求和部署指南 +2. **🔧 [开发者指南](../developer-guide/overview)** - 从源码构建和自定义 +3. **❓ [常见问题](../quick-start/faq)** - 常见问题和故障排除 ## 💬 社区与联系方式 diff --git a/doc/docs/zh/known-issues.md b/doc/docs/zh/known-issues.md deleted file mode 100644 index 559392e96..000000000 --- a/doc/docs/zh/known-issues.md +++ /dev/null @@ -1,41 +0,0 @@ -# 已知问题 - -此页面列出了当前版本 Nexent 中的已知问题和限制。我们正在积极修复这些问题,并会随着解决方案的推出更新此页面。 - -## 🐛 当前问题 - -### 1. OpenSSH 容器软件安装限制 - -**问题描述**: 在 OpenSSH 容器中为终端工具使用安装其他软件包目前由于容器限制而比较困难。 - -**状态**: 开发中 - -**影响**: 需要在终端环境中使用自定义工具或软件包的用户可能面临限制。 - -**计划解决方案**: 我们正在努力提供改进的容器和文档,使自定义变得更容易。这将包括更好的包管理和更灵活的容器配置。 - -**预期时间线**: 改进的容器支持计划在即将发布的版本中提供。 - -## 📝 问题报告 - -如果您遇到此处未列出的任何问题,请: - -1. **查看我们的 [常见问题](./getting-started/faq)** 寻找常见解决方案 -2. **搜索现有问题** 在 [GitHub Issues](https://github.com/ModelEngine-Group/nexent/issues) -3. **创建新问题** 并提供详细信息,包括: - - 重现步骤 - - 预期行为 - - 实际行为 - - 系统信息 - - 日志文件(如适用) - -## 🔄 问题状态更新 - -我们定期更新此页面的已知问题状态。请经常查看更新,或关注我们的 [GitHub 仓库](https://github.com/ModelEngine-Group/nexent) 以获取通知。 - -## 💬 社区支持 - -如需即时帮助或讨论问题: -- 加入我们的 [Discord 社区](https://discord.gg/tb5H3S3wyv) -- 在 [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) 提问 -- 如果您想帮助修复问题,请查看我们的 [贡献指南](./contributing) \ No newline at end of file diff --git a/doc/docs/zh/mcp-ecosystem/mcp-recommendations.md b/doc/docs/zh/mcp-ecosystem/mcp-recommendations.md new file mode 100644 index 000000000..c3a6d2227 --- /dev/null +++ b/doc/docs/zh/mcp-ecosystem/mcp-recommendations.md @@ -0,0 +1,36 @@ +# MCP 推荐 + +本页面为您精选推荐 MCP 平台和工具,帮助您快速发现高质量的 MCP 服务。 + +## 🌐 MCP 社区中心 + +全球 MCP 生态系统正在蓬勃发展,多个平台支持 MCP 开发和部署: + +| 平台 | 描述 | 备注 | +|----------|-------------|-------| +| **[GitHub MCP Server](https://github.com/github/github-mcp-server)** | 与 Claude、GPT-4、Copilot 等深度集成,支持 Go 和 Python | OAuth/GitHub 账户授权 | +| **[Qdrant MCP Vector Server](https://github.com/qdrant/mcp-server-qdrant)** | 语义向量存储,Python/Go 兼容 | 与 LangChain 和其他工具兼容 | +| **[Anthropic Reference MCP Servers](https://github.com/modelcontextprotocol/servers)** | 轻量级教学和原型工具,Python | 包括 fetch、git 和其他通用工具 | +| **[AWS Labs MCP Server](https://github.com/awslabs/mcp)** | AWS+Go+CDK 云参考服务 | 适用于云环境 | +| **[MCP Hub China](https://www.mcp-cn.com/)** | 中文精选高质量 MCP 服务平台 | 注重质量而非数量,社区驱动 | +| **[ModelScope MCP Marketplace](https://modelscope.cn/mcp)** | 中国最大的 MCP 社区,拥有 1,500+ 服务 | 从高德地图到支付宝,全面的服务覆盖 | +| **社区 MCP 服务器** | 各种特定场景的源代码集合 | 主要是实验性和创新工具 | + +## 🔧 推荐的 MCP 工具 + +| 工具名称 | 功能 | 描述 | +|-----------|----------|-------------| +| **[高德地图](https://modelscope.cn/mcp/servers/@amap/amap-maps)** | 地理服务和导航 | 综合地图、地理编码、路由和位置服务 | +| **[必应搜索(中文)](https://modelscope.cn/mcp/servers/@yan5236/bing-cn-mcp-server)** | 中文网络搜索 | 优化的中文网络搜索和信息检索 | +| **[12306 火车票查询](https://modelscope.cn/mcp/servers/@Joooook/12306-mcp)** | 中国铁路票务预订 | 实时列车时刻表、票务可用性和预订协助 | +| **[支付宝 MCP](https://modelscope.cn/mcp/servers/@alipay/mcp-server-alipay)** | 支付和金融服务 | 数字支付、金融工具和服务集成 | +| **[飞常准航空](https://modelscope.cn/mcp/servers/@variflight-ai/variflight-mcp)** | 航班信息和航空数据 | 实时航班跟踪、时刻表和航空分析 | +| **[顺序思考](https://modelscope.cn/mcp/servers/@modelcontextprotocol/sequentialthinking)** | 结构化问题解决框架 | 将复杂问题分解为可管理的顺序步骤 | +| **[ArXiv AI 搜索](https://modelscope.cn/mcp/servers/@blazickjp/arxiv-mcp-server)** | 学术论文搜索和研究 | 高级搜索和检索科学论文和研究 | +| **[Firecrawl MCP 服务器](https://modelscope.cn/mcp/servers/@mendableai/firecrawl-mcp-server)** | 网络爬虫和内容提取 | 智能网络爬虫、数据提取和内容处理 | + +## 🔗 相关资源 + +- [MCP 生态系统概览](./overview) +- [MCP 工具集成指南](../backend/tools/mcp) +- [用例场景](./use-cases) diff --git a/doc/docs/zh/mcp-ecosystem/mcp-server-development.md b/doc/docs/zh/mcp-ecosystem/mcp-server-development.md deleted file mode 100644 index 11a8b790e..000000000 --- a/doc/docs/zh/mcp-ecosystem/mcp-server-development.md +++ /dev/null @@ -1,195 +0,0 @@ -# MCP 服务器开发指南 - -本指南将帮助您使用 Python 和 FastMCP 框架开发自己的 MCP 服务器,并将其集成到 Nexent 平台中。 - -## 🌐 语言支持 - -MCP 协议支持多种编程语言,包括: - -- **Python** ⭐(推荐) -- **TypeScript** -- **Java** -- **Go** -- **Rust** -- 以及其他支持 MCP 协议的语言 - -### 为什么推荐 Python? - -本指南使用 **Python** 作为示例语言,原因如下: - -- ✅ **简单易学**:语法简洁,上手快速 -- ✅ **丰富的框架**:FastMCP 等框架让开发变得非常简单 -- ✅ **快速开发**:几行代码即可创建一个可用的 MCP 服务器 -- ✅ **生态完善**:丰富的第三方库支持 - -如果您熟悉其他语言,也可以使用相应的 MCP SDK 进行开发。但如果您是第一次开发 MCP 服务器,我们强烈推荐从 Python 开始。 - -## 📋 前置要求 - -在开始之前,请确保您已安装以下依赖: - -```bash -pip install fastmcp -``` - -## 🚀 快速开始 - -### 基础示例 - -以下是一个简单的 MCP 服务器示例,展示了如何使用 FastMCP 创建一个提供字符串处理功能的服务器: - -```python -from fastmcp import FastMCP - -# 创建MCP服务器实例 -mcp = FastMCP(name="String MCP Server") - -@mcp.tool( - name="calculate_string_length", - description="计算输入字符串的长度" -) -def calculate_string_length(text: str) -> int: - return len(text) - -@mcp.tool( - name="to_uppercase", - description="将字符串转换为大写" -) -def to_uppercase(text: str) -> str: - return text.upper() - -@mcp.tool( - name="to_lowercase", - description="将字符串转换为小写" -) -def to_lowercase(text: str) -> str: - return text.lower() - -if __name__ == "__main__": - # 使用SSE协议启动服务 - mcp.run(transport="sse", port=8000) -``` - -### 运行服务器 - -保存上述代码为 `mcp_server.py`,然后运行: - -```bash -python mcp_server.py -``` - -您将看到 MCP server 成功启动,且 Server URL 为`http://127.0.0.1:8000/sse`。 - -## 🔌 在 Nexent 中集成 MCP 服务 - -开发并启动 MCP 服务后,您需要将其添加到 Nexent 平台中进行使用: - -### 步骤 1:启动 MCP 服务器 - -确保您的 MCP 服务器正在运行,并记录其访问地址(例如:`http://127.0.0.1:8000/sse`)。 - -### 步骤 2:在 Nexent 中添加 MCP 服务 - -1. 进入 **[智能体开发](../user-guide/agent-development.md)** 页面 -2. 在"选择Agent的工具"页签右侧,点击"**MCP配置**" -3. 在弹出的配置窗口中,输入服务器名称和服务器URL - - ⚠️ **注意**: - 1. 服务器名称只能包含英文字母和数字,不能包含空格、下划线等其他字符; - 2. 如果您使用 Docker 容器部署 Nexent,并且 MCP 服务器运行在宿主机上,需要将 `127.0.0.1` 替换为 `host.docker.internal`,即`http://host.docker.internal:8000`才可成功访问宿主机上运行的 MCP 服务器。 -4. 点击"**添加**"按钮完成配置 - -### 步骤 3:使用 MCP 工具 - -配置完成后,在创建或编辑智能体时,您可以在工具列表中找到并选择您添加的 MCP 工具。 - -## 🔧 包装现有业务 - -如果您已有现成的业务代码,想要将其包装成 MCP 服务,只需要在工具函数中进行调用即可。这种方式可以快速将现有服务集成到 MCP 生态系统中。 - -### 示例:包装 REST API - -如果您的业务逻辑已有现成Restful API: - -```python -from fastmcp import FastMCP -import requests - -# 创建MCP服务器实例 -mcp = FastMCP("Course Statistics Server") - -@mcp.tool( - name="get_course_statistics", - description="根据课程号获取某门课程的成绩统计信息(包含平均分、最高分、最低分等)" -) -def get_course_statistics(course_id: str) -> str: - # 调用现有的业务API - api_url = "https://your-school-api.com/api/courses/statistics" - response = requests.get(api_url, params={"course_id": course_id}) - - # 处理响应并返回结果 - if response.status_code == 200: - data = response.json() - stats = data.get("statistics", {}) - return f"课程 {course_id} 成绩统计:\n平均分: {stats.get('average', 'N/A')}\n最高分: {stats.get('max', 'N/A')}\n最低分: {stats.get('min', 'N/A')}\n总人数: {stats.get('total_students', 'N/A')}" - else: - return f"API调用失败: {response.status_code}" - -if __name__ == "__main__": - # 使用SSE协议启动服务 - mcp.run(transport="sse", port=8000) -``` - -### 示例:包装内部服务 - -如果您的业务逻辑在本地服务中: - -```python -from fastmcp import FastMCP -from your_school_module import query_course_statistics - -# 创建MCP服务器实例 -mcp = FastMCP("Course Statistics Server") - -@mcp.tool( - name="get_course_statistics", - description="根据课程号获取某门课程的成绩统计信息(包含平均分、最高分、最低分等)" -) -def get_course_statistics(course_id: str) -> str: - # 直接调用内部业务函数 - try: - stats = query_course_statistics(course_id) - return f"课程 {course_id} 成绩统计:\n平均分: {stats.get('average', 'N/A')}\n最高分: {stats.get('max', 'N/A')}\n最低分: {stats.get('min', 'N/A')}\n总人数: {stats.get('total_students', 'N/A')}" - except Exception as e: - return f"查询成绩统计时出错: {str(e)}" - -if __name__ == "__main__": - # 使用SSE协议启动服务 - mcp.run(transport="sse", port=8000) -``` - -## 📚 更多资源 - -### Python - -- [FastMCP 文档](https://github.com/modelcontextprotocol/python-sdk)(本指南使用的框架) - -### 其他语言 - -- [MCP TypeScript SDK](https://github.com/modelcontextprotocol/typescript-sdk) -- [MCP Java SDK](https://github.com/modelcontextprotocol/java-sdk) -- [MCP Go SDK](https://github.com/modelcontextprotocol/go-sdk) -- [MCP Rust SDK](https://github.com/modelcontextprotocol/rust-sdk) - -### 通用资源 - -- [MCP 协议规范](https://modelcontextprotocol.io/) -- [Nexent 智能体开发指南](../user-guide/agent-development.md) -- [MCP 工具生态系统概览](./overview.md) - -## 🆘 获取帮助 - -如果您在开发 MCP 服务器时遇到问题,可以: - -1. 查看我们的 **[常见问题](../getting-started/faq.md)** -2. 在 [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) 中提问 -3. 参考 [ModelScope MCP Marketplace](https://www.modelscope.cn/mcp) 中的示例服务器 diff --git a/doc/docs/zh/mcp-ecosystem/overview.md b/doc/docs/zh/mcp-ecosystem/overview.md index 6c97460b9..901ab975f 100644 --- a/doc/docs/zh/mcp-ecosystem/overview.md +++ b/doc/docs/zh/mcp-ecosystem/overview.md @@ -2,36 +2,16 @@ Nexent 基于模型上下文协议(MCP)工具生态系统构建,提供灵活且可扩展的框架来集成各种工具和服务。MCP 作为"AI 的 USB-C"——一个通用接口标准,允许 AI 智能体无缝连接外部数据源、工具和服务。 -## 什么是 MCP? +## 📖 什么是 MCP? 模型上下文协议(MCP)是一个开放协议,使 AI 应用程序能够安全地连接到外部数据源和工具。它为 AI 模型访问和与外部系统交互提供了标准化方式,使构建强大的、上下文感知的 AI 应用程序变得更加容易。 -## MCP 社区中心 - -全球 MCP 生态系统正在蓬勃发展,多个平台支持 MCP 开发和部署: - -| 平台 | 描述 | 备注 | -|----------|-------------|-------| -| **[GitHub MCP Server](https://github.com/github/github-mcp-server)** | 与 Claude、GPT-4、Copilot 等深度集成,支持 Go 和 Python | OAuth/GitHub 账户授权 | -| **[Qdrant MCP Vector Server](https://github.com/qdrant/mcp-server-qdrant)** | 语义向量存储,Python/Go 兼容 | 与 LangChain 和其他工具兼容 | -| **[Anthropic Reference MCP Servers](https://github.com/modelcontextprotocol/servers)** | 轻量级教学和原型工具,Python | 包括 fetch、git 和其他通用工具 | -| **[AWS Labs MCP Server](https://github.com/awslabs/mcp)** | AWS+Go+CDK 云参考服务 | 适用于云环境 | -| **[MCP Hub China](https://www.mcp-cn.com/)** | 中文精选高质量 MCP 服务平台 | 注重质量而非数量,社区驱动 | -| **[ModelScope MCP Marketplace](https://modelscope.cn/mcp)** | 中国最大的 MCP 社区,拥有 1,500+ 服务 | 从高德地图到支付宝,全面的服务覆盖 | -| **社区 MCP 服务器** | 各种特定场景的源代码集合 | 主要是实验性和创新工具 | - -## 推荐的 MCP 工具 - -| 工具名称 | 功能 | 描述 | -|-----------|----------|-------------| -| **[高德地图](https://modelscope.cn/mcp/servers/@amap/amap-maps)** | 地理服务和导航 | 综合地图、地理编码、路由和位置服务 | -| **[必应搜索(中文)](https://modelscope.cn/mcp/servers/@yan5236/bing-cn-mcp-server)** | 中文网络搜索 | 优化的中文网络搜索和信息检索 | -| **[12306 火车票查询](https://modelscope.cn/mcp/servers/@Joooook/12306-mcp)** | 中国铁路票务预订 | 实时列车时刻表、票务可用性和预订协助 | -| **[支付宝 MCP](https://modelscope.cn/mcp/servers/@alipay/mcp-server-alipay)** | 支付和金融服务 | 数字支付、金融工具和服务集成 | -| **[飞常准航空](https://modelscope.cn/mcp/servers/@variflight-ai/variflight-mcp)** | 航班信息和航空数据 | 实时航班跟踪、时刻表和航空分析 | -| **[顺序思考](https://modelscope.cn/mcp/servers/@modelcontextprotocol/sequentialthinking)** | 结构化问题解决框架 | 将复杂问题分解为可管理的顺序步骤 | -| **[ArXiv AI 搜索](https://modelscope.cn/mcp/servers/@blazickjp/arxiv-mcp-server)** | 学术论文搜索和研究 | 高级搜索和检索科学论文和研究 | -| **[Firecrawl MCP 服务器](https://modelscope.cn/mcp/servers/@mendableai/firecrawl-mcp-server)** | 网络爬虫和内容提取 | 智能网络爬虫、数据提取和内容处理 | +## 🎯 MCP 平台与工具 + +关于 MCP 平台和工具的精选推荐,请访问我们的 [MCP 推荐](./mcp-recommendations) 页面,其中包括: + +- **MCP 社区中心**:发现全球 MCP 平台和市场 +- **推荐的 MCP 工具**:探索各种用例的高质量 MCP 服务 ## MCP 的优势 diff --git a/doc/docs/zh/opensource-memorial-wall.md b/doc/docs/zh/opensource-memorial-wall.md index 41fdb6ee8..7c4a7cc91 100644 --- a/doc/docs/zh/opensource-memorial-wall.md +++ b/doc/docs/zh/opensource-memorial-wall.md @@ -607,3 +607,19 @@ Nexent开发者加油 ::: info jinhb - 2025-12-03 祝nexent平台越来越好 ::: + +::: info zmu.1s - 2025-12-04 +打ICT大赛接触到了Nexent平台,祝越来越好! +::: + +::: info Papver 01 - 2025-12-05 +感谢 Nexent 让我踏上了开源之旅!这个项目的文档真的很棒,帮助我快速上手。 +::: + +::: info aschenmo - 2025-12-05 +通过 Nexent 实现了医学查询架构的完美落地,这套框架非常完美。多智能体编排体验极佳! +::: + +::: tip 开源新手 - 2025-12-05 +感谢 Nexent 让我踏上了开源之旅! +::: diff --git a/doc/docs/zh/getting-started/faq.md b/doc/docs/zh/quick-start/faq.md similarity index 69% rename from doc/docs/zh/getting-started/faq.md rename to doc/docs/zh/quick-start/faq.md index eb48b2565..121e59365 100644 --- a/doc/docs/zh/getting-started/faq.md +++ b/doc/docs/zh/quick-start/faq.md @@ -53,8 +53,37 @@ ```python { "role":"user", "content":"prompt" } +## 🐛 已知问题 + +本节列出了当前版本 Nexent 中的已知问题和限制。我们正在积极修复这些问题,并会随着解决方案的推出更新本节。 + +### 🔧 OpenSSH 容器软件安装限制 + +**问题描述**: 在 OpenSSH 容器中为终端工具使用安装其他软件包目前由于容器限制而比较困难。 + +**状态**: 开发中 + +**影响**: 需要在终端环境中使用自定义工具或软件包的用户可能面临限制。 + +**计划解决方案**: 我们正在努力提供改进的容器和文档,使自定义变得更容易。这将包括更好的包管理和更灵活的容器配置。 + +**预期时间线**: 改进的容器支持计划在即将发布的版本中提供。 + +## 📝 问题报告 + +如果您遇到此处未列出的任何问题,请: + +1. **搜索现有问题** 在 [GitHub Issues](https://github.com/ModelEngine-Group/nexent/issues) +2. **创建新问题** 并提供详细信息,包括: + - 重现步骤 + - 预期行为 + - 实际行为 + - 系统信息 + - 日志文件(如适用) + ## 💡 需要帮助 如果这里没有找到您的问题答案: - 加入我们的 [Discord 社区](https://discord.gg/tb5H3S3wyv) 获取实时支持 -- 查看我们的 [GitHub Issues](https://github.com/ModelEngine-Group/nexent/issues) 寻找类似问题 \ No newline at end of file +- 查看我们的 [GitHub Issues](https://github.com/ModelEngine-Group/nexent/issues) 寻找类似问题 +- 在 [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) 开启讨论 \ No newline at end of file diff --git a/doc/docs/zh/getting-started/installation.md b/doc/docs/zh/quick-start/installation.md similarity index 98% rename from doc/docs/zh/getting-started/installation.md rename to doc/docs/zh/quick-start/installation.md index 2a2760d22..bee8f9588 100644 --- a/doc/docs/zh/getting-started/installation.md +++ b/doc/docs/zh/quick-start/installation.md @@ -95,4 +95,4 @@ Nexent 采用微服务架构,包含以下核心服务: 想要从源码构建或添加新功能?查看 [Docker 构建指南](../deployment/docker-build) 获取详细说明。 -有关详细的安装说明和自定义选项,请查看我们的 [开发指南](./development-guide)。 \ No newline at end of file +有关详细的安装说明和自定义选项,请查看我们的 [开发者指南](../developer-guide/overview)。 \ No newline at end of file diff --git a/doc/docs/zh/deployment/upgrade-guide.md b/doc/docs/zh/quick-start/upgrade-guide.md similarity index 57% rename from doc/docs/zh/deployment/upgrade-guide.md rename to doc/docs/zh/quick-start/upgrade-guide.md index b7a2e12ad..b888e2ada 100644 --- a/doc/docs/zh/deployment/upgrade-guide.md +++ b/doc/docs/zh/quick-start/upgrade-guide.md @@ -2,18 +2,59 @@ ## 🚀 升级流程概览 -升级 Nexent 时建议依次完成以下四个步骤: +升级 Nexent 时建议依次完成以下几个步骤: -1. 清理旧版本容器与镜像 -2. 拉取最新代码并执行部署脚本 -3. 同步数据库结构 -4. 打开站点确认服务可用 +1. 拉取最新代码 +2. 执行升级脚本 +3. 打开站点确认服务可用 --- -## 🧹 步骤一:清理旧版本镜像 +## 🔄 步骤一:更新代码 -为避免缓存或版本冲突,先清理旧容器与镜像: +更新之前,先记录下当前部署的版本和数据目录 + +- 当前部署版本信息的位置:`backend/consts/const.py`中的 APP_VERSION +- 数据目录信息的位置:`docker/.env`中的 ROOT_DIR + +**git 方式下载的代码** + +通过 git 指令更新代码 + +```bash +git pull +``` + +**zip 包等方式下载的代码** + +需要去 github 上重新下载一份最新代码,并解压缩。另外,需要从之前执行部署脚本目录下 docker 目录中拷贝 deploy.options 到新代码目录下的 docker 目录中(如果不存在该文件则忽略)。 + +## 🔄 步骤二:执行升级 + +进入更新后代码目录的docker目录,执行升级脚本: + +```bash +bash upgrade.sh +``` + +缺少 deploy.options 的情况下,会提示需要手动输入之前部署的一些配置,比如:当前部署版本、数据目录等。按照提示输入之前记录的信息即可。 + +> 💡 提示 +> - 默认为快速部署场景,使用 `.env.example`。 +> - 若需配置语音模型(STT/TTS),请提前在 `.env.example` 中补充相关变量,我们将尽快提供前端配置入口。 + +## 🌐 步骤三:验证部署 + +部署完成后: + +1. 在浏览器打开 `http://localhost:3000` +2. 参考 [用户指南](https://doc.nexent.tech/zh/user-guide/home-page) 完成智能体配置与验证 + +## 可选操作 + +### 🧹 清理旧版本镜像 + +如果镜像未正确更新,可以在升级前先清理旧容器与镜像: ```bash # 停止并删除现有容器 @@ -39,26 +80,11 @@ docker system prune -af --- -## 🔄 步骤二:更新代码并部署 +### 🗄️ 手动更新数据库 -```bash -git pull -cd nexent/docker -cp .env.example .env -bash deploy.sh -``` - -> 💡 提示 -> - 默认为快速部署场景,可直接使用 `.env.example`。 -> - 若需配置语音模型(STT/TTS),请在 `.env` 中补充相关变量,我们将尽快提供前端配置入口。 - ---- +升级时如果存在部分 sql 文件执行失败,则可以手动执行更新。 -## 🗄️ 步骤三:同步数据库 - -升级后需要执行数据库迁移脚本,使 schema 保持最新。 - -### ✅ 方法一:使用 SQL 编辑器(推荐) +#### ✅ 方法一:使用 SQL 编辑器(推荐) 1. 打开 SQL 编辑器,新建 PostgreSQL 连接。 2. 在 `/nexent/docker/.env` 中找到以下信息: @@ -69,15 +95,15 @@ bash deploy.sh - Password 3. 填写连接信息后测试连接,确认成功后可在 `nexent` schema 中查看所有表。 4. 新建查询窗口。 -5. 打开 `/nexent/docker/sql` 目录,按文件名中的日期顺序查看 SQL 脚本。 -6. 根据上次部署日期,依次执行之后的每个 SQL 文件。 +5. 打开 `/nexent/docker/sql` 目录,通过失败的sql文件查看 SQL 脚本。 +6. 将失败的sql文件和后续版本的sql文件依次执行。 > ⚠️ 注意事项 > - 升版本前请备份数据库,生产环境尤为重要。 > - SQL 脚本需按时间顺序执行,避免依赖冲突。 > - `.env` 变量可能命名为 `POSTGRES_HOST`、`POSTGRES_PORT` 等,请在客户端对应填写。 -### 🧰 方法二:命令行执行(无需客户端) +#### 🧰 方法二:命令行执行(无需客户端) 1. 进入 Docker 目录: @@ -98,14 +124,12 @@ bash deploy.sh 3. 通过容器执行 SQL 脚本(示例): ```bash - # 假如现在是11月6日,上次更新版本的时间是10月20日 - # 此时新增了1030-update.sql和1105-update.sql两个文件 # 我们需要执行以下命令(请注意替换占位符中的变量) - docker exec -i nexent-postgresql psql -U [YOUR_POSTGRES_USER] -d [YOUR_POSTGRES_DB] < ./sql/1030-update.sql - docker exec -i nexent-postgresql psql -U [YOUR_POSTGRES_USER] -d [YOUR_POSTGRES_DB] < ./sql/1105-update.sql + docker exec -i nexent-postgresql psql -U [YOUR_POSTGRES_USER] -d [YOUR_POSTGRES_DB] < ./sql/v1.1.1_1030-update.sql + docker exec -i nexent-postgresql psql -U [YOUR_POSTGRES_USER] -d [YOUR_POSTGRES_DB] < ./sql/v1.1.2_1105-update.sql ``` - 请根据自己的部署时间,按时间顺序执行对应脚本。 + 请根据自己的部署版本,按版本顺序执行对应脚本。 > 💡 提示 > - 若 `.env` 中定义了数据库变量,可先导入: @@ -127,12 +151,3 @@ bash deploy.sh > ```bash > docker exec -i nexent-postgres pg_dump -U [YOUR_POSTGRES_USER] [YOUR_POSTGRES_DB] > backup_$(date +%F).sql > ``` - ---- - -## 🌐 步骤四:验证部署 - -部署完成后: - -1. 在浏览器打开 `http://localhost:3000` -2. 参考 [用户指南](https://doc.nexent.tech/zh/user-guide/home-page) 完成智能体配置与验证 diff --git a/doc/docs/zh/sdk/basic-usage.md b/doc/docs/zh/sdk/basic-usage.md index 6f54dfa8e..347256e01 100644 --- a/doc/docs/zh/sdk/basic-usage.md +++ b/doc/docs/zh/sdk/basic-usage.md @@ -2,35 +2,9 @@ 本指南提供使用 Nexent SDK 构建智能体的全面介绍。 -## 🚀 安装方式 +## 🚀 安装与环境 -### 用户安装 -如果您想使用 Nexent: - -```bash -# 推荐:从源码安装 -git clone https://github.com/ModelEngine-Group/nexent.git -cd nexent/sdk -uv pip install -e . - -# 或使用 uv 安装 -uv add nexent -``` - -### 开发环境设置 -如果您是第三方 SDK 开发者: - -```bash -# 安装完整开发环境(包括 Nexent) -cd nexent/sdk -uv pip install -e ".[dev]" # 包含所有开发工具(测试、代码质量检查等) -``` - -开发环境包含以下额外功能: -- 代码质量检查工具 (ruff) -- 测试框架 (pytest) -- 数据处理依赖 (unstructured) -- 其他开发依赖 +完整的全栈与仅 SDK 安装路径已集中到 [环境准备](../developer-guide/environment-setup) 指南。请先完成环境配置,再继续本页的快速开始。 ## ⚡ 快速开始 @@ -101,11 +75,7 @@ agent.run("你的问题") ## 📡 使用 agent_run(推荐的流式运行方式) -当你需要在服务端或前端以“事件流”方式消费消息时,推荐使用 `agent_run`。它会在后台线程执行智能体,并持续产出 JSON 格式的消息,便于 UI 展示与日志采集。 - -参考文档: [使用 agent_run 运行智能体](./core/agent-run) - -最小示例: +当需要在服务端或前端以“事件流”方式消费消息时,使用 `agent_run`。它在后台线程执行智能体,并从 `MessageObserver` 持续产出 JSON 字符串,便于 UI 展示与日志采集。 ```python import json @@ -145,11 +115,48 @@ async def main(): async for message in agent_run(agent_run_info): message_data = json.loads(message) - print(message_data) + print(message_data) # 每条都是 JSON 字符串 asyncio.run(main()) ``` +### 🛰️ 消息流格式 + +- `type`:消息类型(对应 `ProcessType`,如 `STEP_COUNT`、`MODEL_OUTPUT_THINKING`、`PARSE`、`EXECUTION_LOGS`、`FINAL_ANSWER`、`ERROR`) +- `content`:文本内容 +- `agent_name`(可选):产出该消息的智能体 + +### 🧠 传入历史(可选) + +```python +from nexent.core.agents.agent_model import AgentHistory + +history = [ + AgentHistory(role="user", content="你好"), + AgentHistory(role="assistant", content="你好,我能帮你做什么?"), +] + +agent_run_info = AgentRunInfo( + # ... + history=history, +) +``` + +### 🌐 MCP 工具集成(可选) + +```python +agent_run_info = AgentRunInfo( + # ... + mcp_host=["http://localhost:3000"], # 或包含 url/transport 的 dict +) +``` + +### ⏹️ 优雅中断 + +```python +stop_event.set() # 智能体会在当前步完成后停止 +``` + ## 🔧 配置选项 ### ⚙️ 智能体配置 @@ -177,7 +184,7 @@ search_tool = ExaSearchTool( ## 📚 更多资源 -- **[使用 agent_run 运行智能体](./core/agent-run)** +- **[流式运行 agent_run](#使用-agent_run推荐的流式运行方式)** - **[工具开发指南](./core/tools)** - **[模型架构指南](./core/models)** - **[智能体模块](./core/agents)** \ No newline at end of file diff --git a/doc/docs/zh/sdk/core/agent-run.md b/doc/docs/zh/sdk/core/agent-run.md deleted file mode 100644 index b1ad3334a..000000000 --- a/doc/docs/zh/sdk/core/agent-run.md +++ /dev/null @@ -1,166 +0,0 @@ -# 使用 agent_run 运行智能体(流式) - -`agent_run` 提供了一种更简洁且线程友好的方式来运行智能体,并通过 `MessageObserver` 提供实时流式输出。该接口适合需要前端流式展示、服务端推送以及需要结合 MCP 工具集的场景。 - -## 快速开始 - -```python -import json -import asyncio -import logging -from threading import Event - -from nexent.core.agents.run_agent import agent_run -from nexent.core.agents.agent_model import ( - AgentRunInfo, - AgentConfig, - ModelConfig -) -from nexent.core.utils.observer import MessageObserver - - -async def main(): - # 1) 创建消息观察者(负责接收流式消息) - observer = MessageObserver(lang="zh") - - # 2) 外部停止开关(可用于在 UI 上中断执行) - stop_event = Event() - - # 3) 配置模型 - model_config = ModelConfig( - cite_name="gpt-4", # 模型别名(自定义,在 AgentConfig 中引用) - api_key="", - model_name="Qwen/Qwen2.5-32B-Instruct", - url="https://api.siliconflow.cn/v1", - temperature=0.3, - top_p=0.9 - ) - - # 4) 配置 Agent - agent_config = AgentConfig( - name="example_agent", - description="An example agent that can execute Python code and search the web", - prompt_templates=None, - tools=[], - max_steps=5, - model_name="gpt-4", # 与上面 model_config.cite_name 对应 - provide_run_summary=False, - managed_agents=[] - ) - - # 5) 组装运行信息 - agent_run_info = AgentRunInfo( - query="strrawberry中出现了多少个字母r", # 示例问题 - model_config_list=[model_config], - observer=observer, - agent_config=agent_config, - mcp_host=None, # 可选:MCP 服务地址列表 - history=None, # 可选:历史对话 - stop_event=stop_event - ) - - # 6) 流式运行,并消费消息 - async for message in agent_run(agent_run_info): - message_data = json.loads(message) - message_type = message_data.get("type", "unknown") - content = message_data.get("content", "") - print(f"[{message_type}] {content}") - - # 7) 读取最终答案(如有) - final_answer = observer.get_final_answer() - if final_answer: - print(f"\nFinal Answer: {final_answer}") - - -if __name__ == "__main__": - logging.disable(logging.CRITICAL) - asyncio.run(main()) -``` - -提示:请将 `api_key` 等敏感配置放入环境变量或安全管理服务中,避免硬编码到代码库。 - -## 消息流格式与处理 - -`agent_run` 内部通过一个后台线程执行智能体,并将 `MessageObserver` 中缓存的消息以 JSON 字符串形式不断产出。你可以解析其中的字段进行分类展示或记录日志。 - -- 重要字段 - - `type`: 消息类型(对应 `ProcessType`) - - `content`: 文本内容 - - `agent_name`: 可选,当前产出该消息的智能体名称 - -常见 `type`(来自 `ProcessType`): -- `AGENT_NEW_RUN`: 新的任务开始 -- `STEP_COUNT`: 步数更新 -- `MODEL_OUTPUT_THINKING` / `MODEL_OUTPUT_CODE`: 模型思考/代码片段 -- `PARSE`: 代码解析结果 -- `EXECUTION_LOGS`: Python 执行日志 -- `FINAL_ANSWER`: 最终答案 -- `ERROR`: 错误信息 - -## 配置项说明 - -### ModelConfig - -- `cite_name`:模型别名(用于在 `AgentConfig.model_name` 中引用) -- `api_key`:模型服务 API Key -- `model_name`:模型调用名 -- `url`:模型服务的 Base URL -- `temperature` / `top_p`:采样参数 - -### AgentConfig - -- `name`:智能体名称 -- `description`:智能体描述 -- `prompt_templates`:可选,Jinja 模板字典 -- `tools`:工具配置列表(见下方 ToolConfig) -- `max_steps`:最大步数 -- `model_name`:模型别名(对应 `ModelConfig.cite_name`) -- `provide_run_summary`:是否在子智能体返回总结 -- `managed_agents`:子智能体配置列表 - -### 传入历史对话(可选) - -你可以通过 `AgentRunInfo.history` 传入历史消息,Nexent 会将其写入内部记忆: - -```python -from nexent.core.agents.agent_model import AgentHistory - -history = [ - AgentHistory(role="user", content="你好"), - AgentHistory(role="assistant", content="你好,我能帮你做什么?"), -] - -agent_run_info = AgentRunInfo( - # ... 其他字段省略 - history=history, -) -``` - -## MCP 工具集成(可选) - -若你提供 `mcp_host`(MCP 服务地址列表),Nexent 会自动通过 `ToolCollection.from_mcp` 拉取远程工具集合,并注入到智能体中: - -```python -agent_run_info = AgentRunInfo( - # ... 其他字段省略 - mcp_host=["http://localhost:3000"], -) -``` - -连接失败时会自动产出友好错误信息(中/英)。 - -## 中断执行 - -执行过程中可通过 `stop_event.set()` 触发中断: - -```python -stop_event.set() # 智能体会在当前步完成后优雅停止 -``` - -## 与 CoreAgent 的关系 - -- `agent_run` 是对 `NexentAgent` 与 `CoreAgent` 的一层包装,负责: - - 构造 `CoreAgent`(包含模型与工具) - - 将历史注入记忆 - - 驱动流式执行并转发 `MessageObserver` 的缓存消息 -- 你也可以直接使用 `CoreAgent.run(stream=True)` 自行处理流(见 `core/agents.md`),`agent_run` 提供了更方便的线程化与 JSON 消息输出。 \ No newline at end of file diff --git a/doc/docs/zh/sdk/core/agents.md b/doc/docs/zh/sdk/core/agents.md index c83004c3d..736663cb6 100644 --- a/doc/docs/zh/sdk/core/agents.md +++ b/doc/docs/zh/sdk/core/agents.md @@ -53,137 +53,7 @@ ProcessType枚举定义了以下处理阶段: ## 🤖 智能体开发 -### 创建基本智能体 - -```python -from nexent.core.utils.observer import MessageObserver -from nexent.core.agents.core_agent import CoreAgent -from nexent.core.models.openai_llm import OpenAIModel -from nexent.core.tools import ExaSearchTool, KnowledgeBaseSearchTool - -# 创建消息观察者 -observer = MessageObserver() - -# 创建模型(model和Agent必须使用同一个observer) -model = OpenAIModel( - observer=observer, - model_id="your-model-id", - api_key="your-api-key", - api_base="your-api-base" -) - -# 创建工具 -search_tool = ExaSearchTool(exa_api_key="your-exa-key", observer=observer, max_results=5) -kb_tool = KnowledgeBaseSearchTool(top_k=5, observer=observer) - -# 创建Agent -agent = CoreAgent( - observer=observer, - tools=[search_tool, kb_tool], - model=model, - name="my_agent", - max_steps=5 -) - -# 运行Agent -agent.run("你的问题") -``` - -> 如果你希望以更简洁的方式获得“JSON 流式消息”,推荐阅读:**[使用 agent_run 运行智能体](./agent-run)**。 - -### 自定义智能体开发 - -#### 系统提示词模板 -系统提示词模板位于 `backend/prompts/`: - -- **knowledge_summary_agent.yaml**: 知识库摘要代理 -- **manager_system_prompt_template.yaml**: 管理器系统提示词模板 -- **utils/**: 提示词工具 - -- 若不显式提供 `system_prompt`,将使用 SmolAgents 的默认提示词。 -- 若需要自定义,建议以 `manager_system_prompt_template.yaml` 为基准进行渲染后传入。 - -##### 加载并覆盖 system_prompt(推荐做法) - -```python -from pathlib import Path -import yaml -from jinja2 import Environment, BaseLoader - -from nexent.core.agents.core_agent import CoreAgent -from nexent.core.models.openai_llm import OpenAIModel - -# 1) Load YAML template text -prompt_yaml_path = Path("backend/prompts/manager_system_prompt_template.yaml") -yaml_text = prompt_yaml_path.read_text(encoding="utf-8") -yaml_data = yaml.safe_load(yaml_text) - -# 2) Render Jinja template in 'system_prompt' key -system_prompt_template = yaml_data["system_prompt"] -jinja_env = Environment(loader=BaseLoader()) -rendered_system_prompt = jinja_env.from_string(system_prompt_template).render( - APP_NAME="Nexent Agent", - APP_DESCRIPTION="Enterprise-grade AI agent", - duty="回答用户的问题并在需要时调用工具", - tools={}, # Provide tools summary if needed - managed_agents={}, # Provide managed agents summary if needed - knowledge_base_summary=None, - constraint="遵守组织策略,注意数据与访问安全", - authorized_imports=["requests", "pandas"], - few_shots="", - memory_list=[], -) -yaml_data['system_prompt'] = rendered_system_prompt - -# 3) Build agent with custom system prompt -observer = MessageObserver() -model = OpenAIModel(observer=observer, model_id="your-model-id", api_key="your-api-key", api_base="your-api-base") -agent = CoreAgent( - observer=observer, - model=model, - tools=[search_tool, kb_tool], - system_prompt=yaml_data, - name="my_agent", -) -``` - -> 提示:`manager_system_prompt_template.yaml` 中同时包含 `managed_agent`、`planning`、`final_answer` 等其它模板片段。一般情况下仅需取其 `system_prompt` 键进行渲染并覆盖;如有多智能体协作等高级需求,可按需加载其它片段。 - -#### 智能体实现步骤 - -1. **创建智能体实例**: - ```python - from nexent.core.agents.core_agent import CoreAgent - from nexent.core.models.openai_llm import OpenAIModel - - model = OpenAIModel( - model_id="your-model-id", - api_key="your-api-key", - api_base="your-api-base" - ) - agent = CoreAgent( - model=model, - tools=[your_tools], - system_prompt="你的系统提示词" - ) - ``` - -2. **配置智能体行为**: - - 通过 `tools` 参数添加自定义工具 - - 通过 `system_prompt` 设置行为 - - 配置 `max_steps` 等参数 - -3. **高级配置**: - ```python - agent = CoreAgent( - model=model, - tools=custom_tools, - system_prompt=custom_prompt, - max_steps=10, - verbose=True, - additional_authorized_imports=["requests", "pandas"] - ) - ``` +具体的代码示例已集中到 [基本使用](../basic-usage#使用-agent_run推荐的流式运行方式),其中包含 `CoreAgent.run` 与流式的 `agent_run`。本页仅保留模块层面的概念和能力描述。 ## 🛠️ 工具集成 @@ -255,4 +125,4 @@ def my_tool(param1: str, param2: int) -> str: 3. **扩展策略**: 规划增加的负载和使用 4. **安全考虑**: 验证输入并保护API访问 -详细的实现示例和高级模式,请参阅 [开发指南](../../getting-started/development-guide)。 \ No newline at end of file +详细的实现示例和高级模式,请参阅 [开发者指南](../../developer-guide/overview)。 \ No newline at end of file diff --git a/doc/docs/zh/sdk/overview.md b/doc/docs/zh/sdk/overview.md index 2d13023e0..b710c3d51 100644 --- a/doc/docs/zh/sdk/overview.md +++ b/doc/docs/zh/sdk/overview.md @@ -4,7 +4,7 @@ Nexent 是一个强大的企业级 Agent SDK,革命性地简化了智能体开 ## 🚀 安装与使用 -有关详细的安装说明和使用指南,请参阅 **[基本使用指南](./basic-usage)**。如果你需要基于服务端/前端的事件流展示,请参阅 **[使用 agent_run 运行智能体](./core/agent-run)**。 +有关详细的安装说明和使用指南,请参阅 **[基本使用指南](./basic-usage#使用-agent_run推荐的流式运行方式)**,其中包含 `CoreAgent.run` 和流式的 `agent_run`。 ## ⭐ 主要特性 @@ -32,7 +32,7 @@ Nexent 是一个强大的企业级 Agent SDK,革命性地简化了智能体开 Nexent 提供了完整的智能体解决方案,支持多模型、MCP 集成、动态工具加载和分布式执行。 -- 快速流式运行:**[使用 agent_run 运行智能体](./core/agent-run)** +- 快速流式运行:**[流式运行 agent_run](./basic-usage#使用-agent_run推荐的流式运行方式)** - 详细的 Agent 开发和使用说明:**[智能体模块](./core/agents)** ## 🛠️ 工具集合 diff --git a/doc/docs/zh/user-guide/agent-development.md b/doc/docs/zh/user-guide/agent-development.md index ff4c7c943..7d4a28581 100644 --- a/doc/docs/zh/user-guide/agent-development.md +++ b/doc/docs/zh/user-guide/agent-development.md @@ -74,7 +74,7 @@ Nexent 支持您快速便捷地使用第三方 MCP 工具,丰富 Agent 能力 有许多第三方服务如 [ModelScope](https://www.modelscope.cn/mcp) 提供了 MCP 服务,您可以快速接入使用。 -您也可以自行开发 MCP 服务并接入 Nexent 使用,参考文档 [MCP 服务开发](../mcp-ecosystem/mcp-server-development.md)。 +您也可以自行开发 MCP 服务并接入 Nexent 使用,参考文档 [MCP 工具开发](../backend/tools/mcp)。 ### ⚙️ 自定义工具 @@ -156,4 +156,4 @@ Nexent 支持您快速便捷地使用第三方 MCP 工具,丰富 Agent 能力 2. 在 **[开始问答](./start-chat)** 中与智能体进行交互 3. 在 **[记忆管理](./memory-management)** 配置记忆以提升智能体的个性化能力 -如果您在智能体开发过程中遇到任何问题,请参考我们的 **[常见问题](../getting-started/faq)** 或在 [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) 中进行提问获取支持。 +如果您在智能体开发过程中遇到任何问题,请参考我们的 **[常见问题](../quick-start/faq)** 或在 [GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions) 中进行提问获取支持。 diff --git a/doc/docs/zh/user-guide/agent-market.md b/doc/docs/zh/user-guide/agent-market.md index ca68c9400..65beef4db 100644 --- a/doc/docs/zh/user-guide/agent-market.md +++ b/doc/docs/zh/user-guide/agent-market.md @@ -36,4 +36,4 @@ 2. 通过 **[智能体开发](./agent-development)** 创建专属智能体 3. 在 **[开始问答](./start-chat)** 中体验智能体的强大功能 -如果您使用过程中遇到任何问题,请参考我们的 **[常见问题](../getting-started/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 +如果您使用过程中遇到任何问题,请参考我们的 **[常见问题](../quick-start/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 diff --git a/doc/docs/zh/user-guide/agent-space.md b/doc/docs/zh/user-guide/agent-space.md index 288d83768..fcc639a21 100644 --- a/doc/docs/zh/user-guide/agent-space.md +++ b/doc/docs/zh/user-guide/agent-space.md @@ -66,4 +66,4 @@ 2. 继续 **[智能体开发](./agent-development)** 创建更多智能体 3. 配置 **[记忆管理](./memory-management)** 以提升智能体的记忆能力 -如果您在配置过程中遇到任何问题,请参考我们的 **[常见问题](../getting-started/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 +如果您在配置过程中遇到任何问题,请参考我们的 **[常见问题](../quick-start/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 diff --git a/doc/docs/zh/user-guide/home-page.md b/doc/docs/zh/user-guide/home-page.md index 9bd7a129c..067d9e09c 100644 --- a/doc/docs/zh/user-guide/home-page.md +++ b/doc/docs/zh/user-guide/home-page.md @@ -48,4 +48,4 @@ Nexent首页展示了平台的核心功能,为您提供快速入口: ## 💡 获取帮助 -如果您在配置过程中遇到任何问题,请参考我们的 **[常见问题](../getting-started/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 \ No newline at end of file +如果您在配置过程中遇到任何问题,请参考我们的 **[常见问题](../quick-start/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 \ No newline at end of file diff --git a/doc/docs/zh/user-guide/knowledge-base.md b/doc/docs/zh/user-guide/knowledge-base.md index 3a099ee53..70d9cf513 100644 --- a/doc/docs/zh/user-guide/knowledge-base.md +++ b/doc/docs/zh/user-guide/knowledge-base.md @@ -76,4 +76,4 @@ Nexent支持多种文件格式,包括: 1. **[智能体开发](./agent-development)** - 创建和配置智能体 2. **[开始问答](./start-chat)** - 与智能体进行交互 -如果您在知识库配置过程中遇到任何问题,请参考我们的 **[常见问题](../getting-started/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 \ No newline at end of file +如果您在知识库配置过程中遇到任何问题,请参考我们的 **[常见问题](../quick-start/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 \ No newline at end of file diff --git a/doc/docs/zh/user-guide/local-tools/email-tools.md b/doc/docs/zh/user-guide/local-tools/email-tools.md new file mode 100644 index 000000000..63b6ea0e4 --- /dev/null +++ b/doc/docs/zh/user-guide/local-tools/email-tools.md @@ -0,0 +1,58 @@ +--- +title: 邮件工具 +--- + +# 邮件工具 + +邮件工具组支持收取与发送邮件,适合在智能体中获取通知或发送结果汇报。 + +## 🧭 工具清单 + +- `get_email`:按时间范围、发件人获取邮件,限制返回数量 +- `send_email`:发送 HTML 格式邮件,支持多收件人、抄送、密送 + +## 🧰 使用场景示例 + +- 周期性抓取近 7 天内的通知邮件,供后续摘要或分析 +- 发送执行结果到指定收件人并抄送团队成员 +- 针对特定发件人(如监控账户)筛选告警邮件 + +## 🧾 参数要求与行为 + +### get_email +- `days`:获取过去 N 天邮件,默认 7。 +- `sender`:按邮箱地址过滤发件人,可选。 +- `max_emails`:最大返回邮件数,默认 10。 +- 需要提供 IMAP 服务器地址、端口、用户名、密码;支持 SSL。 +- 返回邮件主题、时间、发件人、正文摘要等 JSON 信息。 + +### send_email +- `to`:收件人列表,使用逗号分隔。 +- `subject`:邮件主题。 +- `content`:邮件正文,支持 HTML。 +- `cc`、`bcc`:抄送/密送列表,逗号分隔,可选。 +- 需要提供 SMTP 服务器地址、端口、用户名、密码;可设置发件人展示名与 SSL。 +- 返回发送状态、主题、收件人信息。 + +## 🛠️ 操作指引 + +1. **获取邮箱配置**:准备 IMAP/SMTP 地址、端口、账号密码,确认是否启用 SSL。 +2. **收取邮件**:调用 `get_email`,按需设置 `days`、`sender`、`max_emails`。若需更窄范围,先测试少量结果。 +3. **发送邮件**:调用 `send_email`,填写收件人、主题与 HTML 正文;如需抄送/密送可添加 `cc`/`bcc`。 +4. **内容处理**:收取的邮件正文可再结合模型做摘要或提取关键信息。 + +## 🛡️ 安全与最佳实践 + +- 邮箱账号请使用专用的应用密码或受限账号,避免暴露主密码。 +- 控制 `max_emails` 防止一次抓取过多数据。 +- 发送前检查收件人列表,避免误发;生产环境可限制允许的域名。 + +## 📮 常见邮箱配置 + +> 建议使用各邮箱的“应用专用密码”并在邮箱设置中启用 IMAP/SMTP。端口号为行业常用值,若服务商有最新要求请以官方文档为准。 + +- QQ 邮箱:IMAP `imap.qq.com:993`(SSL),SMTP `smtp.qq.com:465`(SSL);需要在 QQ 邮箱中开启“IMAP/SMTP 服务”并申请授权码。 +- Gmail:IMAP `imap.gmail.com:993`,SMTP `smtp.gmail.com:465`(SSL)或 `587`(STARTTLS);需要开启 IMAP 并使用应用密码(建议关闭不安全访问)。 +- Outlook(Microsoft 365 / Hotmail):IMAP `outlook.office365.com:993`,SMTP `smtp.office365.com:587`(STARTTLS);企业租户可能要求现代认证或应用密码。 +- 163 邮箱:IMAP `imap.163.com:993`(SSL),SMTP `smtp.163.com:465`(SSL);需在邮箱设置里开启“客户端授权密码/安全密码”。 + diff --git a/doc/docs/zh/user-guide/local-tools/file-tools.md b/doc/docs/zh/user-guide/local-tools/file-tools.md new file mode 100644 index 000000000..2dc084ae8 --- /dev/null +++ b/doc/docs/zh/user-guide/local-tools/file-tools.md @@ -0,0 +1,55 @@ +--- +title: 文件工具 +--- + +# 文件工具 + +文件工具组提供在工作空间内安全、受限的文件与目录操作,所有路径都必须是相对于工作空间的相对路径,默认工作空间根目录为 `/mnt/nexent`。 + +## 🧭 工具清单 + +- `create_directory`:创建目录(自动创建父级,支持权限设置) +- `create_file`:创建文件并写入内容(自动创建父级) +- `read_file`:读取文件内容与元信息 +- `list_directory`:以树形列出目录结构 +- `move_item`:移动文件或目录到新位置(防止覆盖) +- `delete_file`:删除单个文件(不可恢复) +- `delete_directory`:递归删除目录及其内容(不可恢复) + +## 🧰 使用场景示例 + +- 初始化项目目录、生成配置文件 +- 查看日志、检查文件大小或行数 +- 列出工作空间结构,确认文件位置 +- 批量迁移文件到备份目录 +- 清理无用文件或临时目录 + +## 🧾 参数要求与行为 + +### 通用限制 +- 路径必须在工作空间内,禁止越界访问绝对路径。 +- 删除与移动操作不可恢复,请谨慎使用。 + +### 关键参数 +- `directory_path` / `file_path` / `source_path` / `destination_path`:相对路径,必填。 +- `permissions`(create_directory):八进制权限字符串,默认 `755`。 +- `encoding`(create_file / read_file):文件编码,默认 `utf-8`。 +- `max_depth`、`show_hidden`、`show_size`(list_directory):控制目录树展示深度、是否显示隐藏文件、是否显示大小。 + +### 返回结果 +- 成功时返回 JSON,包含相对/绝对路径、大小、是否已存在等信息。 +- 失败时返回明确的错误原因(路径越界、目标已存在、权限问题等)。 + +## 🛠️ 操作指引 + +1. **创建**:使用 `create_directory` 或 `create_file`,传入相对路径;需要自定义权限或编码时显式填写。 +2. **查看**:使用 `list_directory` 浏览结构;用 `read_file` 获取内容和元数据。 +3. **移动**:用 `move_item` 将文件/目录迁移到新位置,若目标已存在会中断以避免覆盖。 +4. **删除**:用 `delete_file` 或 `delete_directory` 清理资源,操作不可恢复,请先确认路径。 + +## 🛡️ 安全与最佳实践 + +- 仅在工作空间内操作,避免绝对路径或 `..` 越界。 +- 删除前可先 `list_directory` 或 `read_file` 确认目标。 +- 大文件读取会给出提示,必要时分块处理或避免一次性读取超大文件。 + diff --git a/doc/docs/zh/user-guide/local-tools/index.md b/doc/docs/zh/user-guide/local-tools/index.md index d4eb6d8da..bd49ef79e 100644 --- a/doc/docs/zh/user-guide/local-tools/index.md +++ b/doc/docs/zh/user-guide/local-tools/index.md @@ -1,64 +1,24 @@ -# 本地工具 +# 概览 -Nexent平台提供了丰富的本地工具,帮助智能体完成各种系统级任务和本地操作。这些工具通过与本地系统或远程服务器的直接交互,为智能体提供了强大的执行能力。 +本地工具为智能体提供与工作空间、远程主机、外部服务交互的能力,涵盖文件操作、邮件、搜索、多模态与远程终端。每个工具都有独立页面,按功能分组说明使用方式与注意事项。 -## 🛠️ 可用工具 +## 📂 目录 -Nexent预置了一组可以直接复用的本地工具。它们按照能力分为邮件、文件、搜索、多模态三大类,Terminal 工具则作为远程 Shell 能力单独提供。下方列出各工具的名称与核心特性,方便在 Agent 中快速定位所需能力。 +- [文件工具](./file-tools):创建/读取/移动/删除文件与目录,树形列目录。 +- [邮件工具](./email-tools):收取 IMAP 邮件,发送 HTML 邮件(支持抄送/密送)。 +- [搜索工具](./search-tools):本地/ DataMate 知识库检索与 Exa/Tavily/Linkup 公网搜索。 +- [多模态工具](./multimodal-tools):文本文件与图片的下载、解析、模型分析。 +- [终端工具](./terminal-tool):持久化 SSH 会话,远程执行命令。 -### 📧 邮件工具(Email) +## ⚙️ 配置入口 -- **get_email**:通过 IMAP 协议拉取邮箱内容,支持按天数限定时间范围、按发件人精确过滤,并可限制一次返回的邮件数量。工具会自动处理多语言主题与正文解码,结果包含主题、时间、发件人、正文摘要等字段,便于 Agent 做进一步分析。 -- **send_email**:基于 SMTP 发送 HTML 格式邮件,支持同时指定多个收件人、抄送(CC)、密送(BCC),并可自定义发件人展示名。所有连接都走 SSL/TLS,发送结果会提示投递状态和主题,方便记录。 +1. 打开 **[智能体开发](../agent-development)** 页面。 +2. 在“选择 Agent 的工具”中找到对应工具,点击配置。 +3. 填写连接或鉴权参数,保存并启用,建议先进行测试连接。 -### 📂 文件工具(File) +## 💡 使用建议 -- **create_directory**:在指定相对路径下创建多级目录,自动跳过已存在的层级,并返回创建结果与最终绝对路径。 -- **create_file**:新建文件并写入内容,如果父级目录不存在会自动创建;支持自定义编码(默认 UTF-8)和空内容文件。 -- **read_file**:读取文本文件内容并返回文件大小、行数、编码等元信息,在文件过大时会提醒(10MB 安全阈值)。 -- **list_directory**:以树状结构列出目录内容,可控制最大递归深度、是否展示隐藏文件和文件大小,输出同时包含可视化字符串和结构化 JSON,适合用于展示项目结构。 -- **move_item**:在工作空间内移动文件或文件夹,自动创建目标目录,避免目标已存在导致覆盖,并在结果中给出移动的项目数量和大小。 -- **delete_file**:删除单个文件,包含权限与存在性校验,失败时会给出明确错误信息。 -- **delete_directory**:递归删除目录及其内容,带有目录存在性、权限和安全校验,删除后返回被删除的相对路径。 - -> 所有文件路径都需要是工作空间内的相对路径(默认 `/mnt/nexent`),工具会自动校验防止越界访问。 - -### 🔍 搜索工具(Search) - -- **knowledge_base_search**:对接本地知识库索引,支持 `hybrid`、`accurate`、`semantic` 三种检索模式,可按索引名称筛选。返回结果附带来源、得分、引用序号,适合回答私有文档或行业资料问题。 -- **exa_search**:调用 EXA API 进行实时全网搜索,可配置返回条数并支持附带图片链接(默认在服务端进一步筛选)。需要在工具配置中填写 EXA API Key,前往 [exa.ai](https://exa.ai/) 注册即可免费获取。 -- **tavily_search**:基于 Tavily API 的网页搜索,擅长新闻、实时资讯查询,同时返回文本结果和相关图片 URL,同样支持可选的图片过滤能力,可在 [tavily.com](https://www.tavily.com/) 免费申请 API Key。 -- **linkup_search**:使用 Linkup API 获取文本与图片结果,除了普通网页内容,还能返回纯图片结果,适合需要图文混合参考的场景。访问 [linkup.so](https://www.linkup.so/) 注册获取免费的 API Key。 - - -### 🖼️ 多模态工具(Multimodal) - -- **analyze_text_file**:基于用户提问和文本文件的s3 url、http url、https url,解析文件并使用大语言模型理解文件,回答用户问题。需要在模型管理页面配置可用的大语言模型。 -- **analyze_image**:基于用户提问和图片的s3 url、http url、https url,使用视觉语言模型分析理解图像,回答用户问题。需要在模型管理页面配置可用的视觉语言模型。 - -### 🖥️ Terminal工具 - -**Terminal工具** 是 Nexent 平台的核心本地工具之一,提供持久化 SSH 会话能力,可在 Agent 中执行远程命令、进行系统巡检、读取日志或部署服务。详细的部署、参数和安全指引请查看专门的 [Terminal 使用手册](./terminal-tool.md)。 - -## 🔧 工具配置 - -所有本地工具都需要在智能体开发中进行设置: - -1. 进入 **[智能体开发](../agent-development)** 页面 -2. 选择要配置的智能体 -3. 在"选择Agent的工具"页签中找到相应的本地工具 -4. 点击配置按钮,填写必要的连接参数 -5. 测试连接确保配置正确 -6. 保存配置并启用工具 - -## ⚠️ 安全注意事项 - -使用本地工具时,请务必注意以下安全事项: - -- **权限控制**:为工具创建专用用户,遵循最小权限原则 -- **网络安全**:使用VPN或IP白名单限制访问 -- **认证安全**:优先使用密钥认证,定期更换密钥 -- **命令限制**:在生产环境中配置命令白名单 -- **审计日志**:启用详细的操作日志记录 - -如果您使用本地工具过程中遇到任何问题,请在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 \ No newline at end of file +- 路径类操作仅限工作空间范围,请使用相对路径。 +- 公网搜索需先在平台安全配置中填写 API Key。 +- 终端工具涉及远程主机,请确认网络与账号安全策略。 +- 删除、移动类操作不可恢复,执行前先确认目标。 \ No newline at end of file diff --git a/doc/docs/zh/user-guide/local-tools/multimodal-tools.md b/doc/docs/zh/user-guide/local-tools/multimodal-tools.md new file mode 100644 index 000000000..114504365 --- /dev/null +++ b/doc/docs/zh/user-guide/local-tools/multimodal-tools.md @@ -0,0 +1,48 @@ +--- +title: 多模态工具 +--- + +# 多模态工具 + +多模态工具组支持分析文本文件与图片,结合模型能力生成用户问题相关的解读结果。支持 S3、HTTP、HTTPS 等 URL。 + +## 🧭 工具清单 + +- `analyze_text_file`:下载并提取文本文件内容后进行分析 +- `analyze_image`:下载图片并使用视觉语言模型进行理解与描述 + +## 🧰 使用场景示例 + +- 对上传到存储桶的文档进行快速摘要或要点提取 +- 对截图、产品图片、报表图进行内容解读或关键信息提取 +- 结合问题指令,对多份文件/图片分别生成答案列表 + +## 🧾 参数要求与行为 + +### analyze_text_file +- `file_url_list`:文件 URL 列表,支持 `s3://bucket/key`、`/bucket/key`、`http(s)://`。 +- `query`:用户问题/分析需求。 +- 会逐个文件下载、提取文本,再基于问题生成对应分析结果数组。 + +### analyze_image +- `image_urls_list`:图片 URL 列表,支持 `s3://bucket/key`、`/bucket/key`、`http(s)://`。 +- `query`:用户问题/关注点。 +- 会逐张图片下载并调用视觉语言模型,返回与顺序对应的描述或答案数组。 + +## ⚙️ 前置配置 + +- 确保已在平台配置可用的存储客户端(如 MinIO/S3)及数据处理服务地址,保证能下载文件。 +- 为 `analyze_text_file` 配置可用的 LLM;为 `analyze_image` 配置可用的视觉语言模型。 + +## 🛠️ 操作指引 + +1. 准备文件或图片的可访问 URL,确认权限与路径正确。 +2. 调用相应工具,填写 URL 列表与问题描述;支持一次处理多条资源。 +3. 检查返回的数组结果顺序与输入列表一致,便于继续引用或展示。 + +## 💡 最佳实践 + +- 对体积较大的文件可先在数据处理服务中做预处理或分片,减少超时风险。 +- 处理多张图片时,可在问题中明确关注点(如“只关注图表中的趋势”)以提升回答质量。 +- 若返回为空或报错,先验证 URL 可访问性和模型配置是否就绪。 + diff --git a/doc/docs/zh/user-guide/local-tools/search-tools.md b/doc/docs/zh/user-guide/local-tools/search-tools.md new file mode 100644 index 000000000..572fffaa6 --- /dev/null +++ b/doc/docs/zh/user-guide/local-tools/search-tools.md @@ -0,0 +1,70 @@ +--- +title: 搜索工具 +--- + +# 搜索工具 + +搜索工具组提供多源信息检索,覆盖互联网搜索、本地知识库以及 DataMate 知识库。适合实时信息查询、行业资料检索、私有文档查找等场景。 + +## 🧭 工具清单 + +- 本地/私有知识库: + - `knowledge_base_search`:本地知识库检索,支持多知识库与多种检索模式 + - `datamate_search_tool`:对接 DataMate 知识库的检索 +- 公网搜索: + - `exa_search`:基于 EXA 的实时网页与图片搜索 + - `tavily_search`:基于 Tavily 的网页与图片搜索 + - `linkup_search`:基于 Linkup 的图文混合搜索 + +## 🧰 使用场景示例 + +- 查询内部文档、技术规范、行业资料(知识库、DataMate) +- 获取最新新闻、数据或网页截图线索(Exa / Tavily / Linkup) +- 同时返回图片参考以丰富答案(开启图片过滤后可输出图片列表) + +## 🧾 参数要求与行为 + +### knowledge_base_search +- `query`:检索问题,必填。 +- `search_mode`:`hybrid`(默认,混合召回)、`accurate`(文本模糊匹配)、`semantic`(向量语义)。 +- `index_names`:指定要搜索的知识库名称列表(可用用户侧名称或内部索引名),可选。 +- 返回匹配片段的标题、路径/URL、来源类型、得分等。 +- 若未选择知识库,会提示“无可用知识库”。 + +### datamate_search_tool +- `query`:检索问题,必填。 +- `top_k`:返回数量,默认 10。 +- `threshold`:相似度阈值,默认 0.2。 +- `kb_page` / `kb_page_size`:分页获取 DataMate 知识库列表。 +- 需要配置 DataMate 服务地址与端口。 +- 返回包含文件名、下载链接、得分等结构化结果。 + +### exa_search / tavily_search / linkup_search +- `query`:检索问题,必填。 +- `max_results`:返回条数,可配置。 +- 图片过滤:默认开启,按查询语义过滤常见无关图片;可关闭以获取全部图片 URL。 +- 需要对应服务的 API Key: + - Exa:EXA API Key + - Tavily:Tavily API Key + - Linkup:Linkup API Key +- 返回标题、URL、摘要,可能附带图片 URL 列表(去重处理)。 + +## 🛠️ 操作指引 + +1. **选择数据源**:私有资料用 `knowledge_base_search` 或 `datamate_search_tool`;实时公开信息用 Exa/Tavily/Linkup。 +2. **设置检索模式/数量**:知识库可在 `search_mode` 之间切换;公网搜索可调整 `max_results` 与是否启用图片过滤。 +3. **限定范围**:需要特定知识库时填写 `index_names`,避免无关结果;DataMate 可通过阈值与 top_k 控制结果精度与数量。 +4. **结果利用**:返回为 JSON,可直接用于回答、摘要或后续引用;包含 cite 索引便于引用管理。 + +## 🛡️ 安全与最佳实践 + +- 公网搜索需确保 API Key 已在平台安全配置中设置,不要在对话中暴露。 +- 知识库检索前确认已同步最新文档,避免旧版本内容。 +- 当查询过于宽泛导致无结果时,可缩短或拆分问题;图片过滤未命中时可尝试关闭过滤获取原始图片列表。 + +## 🔑 API Key 获取(公网搜索) + +- Exa:前往 [exa.ai](https://exa.ai/) 注册并在控制台申请 EXA API Key。 +- Tavily:访问 [tavily.com](https://www.tavily.com/) 创建账户,在 Dashboard 获取 Tavily API Key。 +- Linkup:在 [linkup.so](https://www.linkup.so/) 注册并于个人中心创建 Linkup API Key。 + diff --git a/doc/docs/zh/user-guide/local-tools/terminal-tool.md b/doc/docs/zh/user-guide/local-tools/terminal-tool.md index 4aab60e28..b0e298319 100644 --- a/doc/docs/zh/user-guide/local-tools/terminal-tool.md +++ b/doc/docs/zh/user-guide/local-tools/terminal-tool.md @@ -1,10 +1,10 @@ -# Terminal工具使用手册 +# 终端工具使用手册 -Terminal工具是Nexent平台提供的一个强大的本地工具,允许智能体通过SSH连接远程服务器执行shell命令。该工具支持会话管理以在命令之间保持shell状态,使用密码认证进行安全连接,并返回命令输出结果。本手册将详细介绍如何配置和使用Terminal工具。 +终端工具是Nexent平台提供的一个强大的本地工具,允许智能体通过SSH连接远程服务器执行shell命令。该工具支持会话管理以在命令之间保持shell状态,使用密码认证进行安全连接,并返回命令输出结果。本手册将详细介绍如何配置和使用终端工具。 ## 🖥️ SSH服务器搭建 -Terminal工具支持两种SSH服务器配置方式: +终端工具支持两种SSH服务器配置方式: 1. **Nexent Terminal容器**:使用Nexent提供的预配置SSH容器(推荐) 2. **第三方SSH服务器**:在现有服务器上搭建SSH服务 @@ -40,7 +40,7 @@ docker build --progress=plain -t nexent/nexent-ubuntu-terminal -f make/terminal/ #### 2. Deploy脚本配置 -在运行部署脚本时,选择启用Terminal工具容器: +在运行部署脚本时,选择启用终端工具容器: ```bash # 运行部署脚本 @@ -49,7 +49,7 @@ bash deploy.sh # 在脚本执行过程中选择: # 1. 部署模式:选择开发/生产/基础设施模式 -# 2. Terminal工具:选择 "Y" 启用Terminal工具容器 +# 2. 终端工具:选择 "Y" 启用终端工具容器 # 3. 配置SSH凭据:输入用户名和密码 # 4. 配置挂载目录:指定主机目录映射 ``` @@ -177,7 +177,7 @@ sudo systemctl restart ssh ## 🚀 工具功能 -Terminal工具提供以下核心功能: +终端工具提供以下核心功能: ### 基本功能 @@ -202,14 +202,14 @@ Terminal工具提供以下核心功能: - **timestamp**:执行时间戳 - **error**:错误信息(如果执行失败) -## ⚙️ Terminal工具配置 +## ⚙️ 终端工具配置 -### 在Nexent中配置Terminal工具 +### 在Nexent中配置终端工具 1. 登录Nexent平台 2. 进入 **[智能体开发](../agent-development)** 页面 3. 选择要配置的智能体 -4. 在"选择Agent的工具"页签中找到"Terminal工具" +4. 在"选择Agent的工具"页签中找到"终端工具" @@ -217,7 +217,7 @@ Terminal工具提供以下核心功能: #### 配置SSH连接参数 -点击Terminal工具的配置按钮,填写以下参数: +点击终端工具的配置按钮,填写以下参数: **基本配置**: - **ssh_host**:SSH服务器的IP地址或域名(Nexent容器默认为nexent-openssh-server) @@ -227,7 +227,7 @@ Terminal工具提供以下核心功能: - **init_path**:初始工作目录(默认为~) - + diff --git a/doc/docs/zh/user-guide/memory-management.md b/doc/docs/zh/user-guide/memory-management.md index 55df5de08..5f745ad1f 100644 --- a/doc/docs/zh/user-guide/memory-management.md +++ b/doc/docs/zh/user-guide/memory-management.md @@ -157,4 +157,4 @@ Nexent采用四层记忆存储架构,不同层级的记忆有不同的作用 2. 在 **[智能体空间](./agent-space)** 中管理您的智能体 3. 继续 **[智能体开发](./agent-development)** 创建更多智能体 -如果您在配置过程中遇到任何问题,请参考我们的 **[常见问题](../getting-started/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 +如果您在配置过程中遇到任何问题,请参考我们的 **[常见问题](../quick-start/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 diff --git a/doc/docs/zh/user-guide/model-management.md b/doc/docs/zh/user-guide/model-management.md index 15493cee1..c47e4ec89 100644 --- a/doc/docs/zh/user-guide/model-management.md +++ b/doc/docs/zh/user-guide/model-management.md @@ -235,4 +235,4 @@ Nexent 支持任何 **遵循OpenAI API规范** 的大语言模型供应商,包 1. **[知识库](./knowledge-base)** - 创建和管理知识库。 2. **[智能体开发](./agent-development)** - 创建和配置智能体。 -如果您在模型配置过程中遇到任何问题,请参考我们的 **[常见问题](../getting-started/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 +如果您在模型配置过程中遇到任何问题,请参考我们的 **[常见问题](../quick-start/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 diff --git a/doc/docs/zh/user-guide/quick-setup.md b/doc/docs/zh/user-guide/quick-setup.md index dafd2afe5..191835746 100644 --- a/doc/docs/zh/user-guide/quick-setup.md +++ b/doc/docs/zh/user-guide/quick-setup.md @@ -50,4 +50,4 @@ 2. 在 **[开始问答](./start-chat)** 中与智能体进行交互 3. 配置 **[记忆管理](./memory-management)** 以提升智能体的记忆能力 -如果您在配置过程中遇到任何问题,请参考我们的 **[常见问题](../getting-started/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 +如果您在配置过程中遇到任何问题,请参考我们的 **[常见问题](../quick-start/faq)** 或在[GitHub Discussions](https://github.com/ModelEngine-Group/nexent/discussions)中进行提问获取支持。 diff --git a/doc/docs/zh/user-guide/start-chat.md b/doc/docs/zh/user-guide/start-chat.md index 1afa80448..d428e5a3a 100644 --- a/doc/docs/zh/user-guide/start-chat.md +++ b/doc/docs/zh/user-guide/start-chat.md @@ -220,6 +220,6 @@ Nexent支持后台运行模式,让您在处理复杂任务时更加高效: 如果您在使用过程中遇到任何问题: -- 📖 查看 **[常见问题](../getting-started/faq)** 获取详细解答 +- 📖 查看 **[常见问题](../quick-start/faq)** 获取详细解答 - 💬 加入我们的 [Discord 社区](https://discord.gg/tb5H3S3wyv) 与其他用户交流 - 🆘 联系技术支持获取专业帮助 \ No newline at end of file diff --git a/doc/pnpm-workspace.yaml b/doc/pnpm-workspace.yaml new file mode 100644 index 000000000..c5739b743 --- /dev/null +++ b/doc/pnpm-workspace.yaml @@ -0,0 +1,2 @@ +ignoredBuiltDependencies: + - esbuild diff --git a/docker/.env.general b/docker/.env.general index f1899f4fc..1f98b1356 100644 --- a/docker/.env.general +++ b/docker/.env.general @@ -1,12 +1,12 @@ -NEXENT_IMAGE=nexent/nexent:latest -NEXENT_WEB_IMAGE=nexent/nexent-web:latest -NEXENT_DATA_PROCESS_IMAGE=nexent/nexent-data-process:latest +NEXENT_IMAGE=nexent/nexent:${APP_VERSION} +NEXENT_WEB_IMAGE=nexent/nexent-web:${APP_VERSION} +NEXENT_DATA_PROCESS_IMAGE=nexent/nexent-data-process:${APP_VERSION} ELASTICSEARCH_IMAGE=docker.elastic.co/elasticsearch/elasticsearch:8.17.4 POSTGRESQL_IMAGE=postgres:15-alpine REDIS_IMAGE=redis:alpine MINIO_IMAGE=quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z -OPENSSH_SERVER_IMAGE=nexent/nexent-ubuntu-terminal:latest +OPENSSH_SERVER_IMAGE=nexent/nexent-ubuntu-terminal:${APP_VERSION} SUPABASE_KONG=kong:2.8.1 SUPABASE_GOTRUE=supabase/gotrue:v2.170.0 diff --git a/docker/.env.mainland b/docker/.env.mainland index dac8a6c88..162788527 100644 --- a/docker/.env.mainland +++ b/docker/.env.mainland @@ -1,12 +1,12 @@ -NEXENT_IMAGE=ccr.ccs.tencentyun.com/nexent-hub/nexent:latest -NEXENT_WEB_IMAGE=ccr.ccs.tencentyun.com/nexent-hub/nexent-web:latest -NEXENT_DATA_PROCESS_IMAGE=ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:latest +NEXENT_IMAGE=ccr.ccs.tencentyun.com/nexent-hub/nexent:${APP_VERSION} +NEXENT_WEB_IMAGE=ccr.ccs.tencentyun.com/nexent-hub/nexent-web:${APP_VERSION} +NEXENT_DATA_PROCESS_IMAGE=ccr.ccs.tencentyun.com/nexent-hub/nexent-data-process:${APP_VERSION} ELASTICSEARCH_IMAGE=elastic.m.daocloud.io/elasticsearch/elasticsearch:8.17.4 POSTGRESQL_IMAGE=docker.m.daocloud.io/postgres:15-alpine REDIS_IMAGE=docker.m.daocloud.io/redis:alpine MINIO_IMAGE=quay.m.daocloud.io/minio/minio:RELEASE.2023-12-20T01-00-02Z -OPENSSH_SERVER_IMAGE=ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:latest +OPENSSH_SERVER_IMAGE=ccr.ccs.tencentyun.com/nexent-hub/nexent-ubuntu-terminal:${APP_VERSION} SUPABASE_KONG=docker.m.daocloud.io/kong:2.8.1 SUPABASE_GOTRUE=docker.m.daocloud.io/supabase/gotrue:v2.170.0 diff --git a/docker/deploy.sh b/docker/deploy.sh index dc39eaf6a..425b4bceb 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -3,11 +3,26 @@ # Ensure the script is executed with bash (required for arrays and [[ ]]) if [ -z "$BASH_VERSION" ]; then echo "❌ This script must be run with bash. Please use: bash deploy.sh or ./deploy.sh" - exit 0 + exit 1 fi # Exit immediately if a command exits with a non-zero status set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +CONST_FILE="$PROJECT_ROOT/backend/consts/const.py" +DEPLOY_OPTIONS_FILE="$SCRIPT_DIR/deploy.options" + +MODE_CHOICE_SAVED="" +VERSION_CHOICE_SAVED="" +IS_MAINLAND_SAVED="" +ENABLE_TERMINAL_SAVED="N" +TERMINAL_MOUNT_DIR_SAVED="${TERMINAL_MOUNT_DIR:-}" +APP_VERSION="" + +cd "$SCRIPT_DIR" + set -a source .env @@ -210,7 +225,7 @@ check_ports_in_env_files() { confirm_continue=$(sanitize_input "$confirm_continue") if ! [[ "$confirm_continue" =~ ^[Yy]$ ]]; then echo "🚫 Deployment aborted due to port conflicts." - exit 0 + exit 1 fi echo "⚠️ Continuing deployment even though some required ports are already in use." @@ -221,6 +236,41 @@ check_ports_in_env_files() { echo "" } +trim_quotes() { + local value="$1" + value="${value%$'\r'}" + value="${value%\"}" + value="${value#\"}" + echo "$value" +} + +get_app_version() { + if [ ! -f "$CONST_FILE" ]; then + echo "" + return + fi + + local line + line=$(grep -E 'APP_VERSION' "$CONST_FILE" | tail -n 1 || true) + line="${line##*=}" + line="$(echo "$line" | sed 's/^[[:space:]]*//;s/[[:space:]]*$//')" + local value + value="$(trim_quotes "$line")" + echo "$value" +} + +persist_deploy_options() { + { + echo "APP_VERSION=\"${APP_VERSION}\"" + echo "ROOT_DIR=\"${ROOT_DIR}\"" + echo "MODE_CHOICE=\"${MODE_CHOICE_SAVED}\"" + echo "VERSION_CHOICE=\"${VERSION_CHOICE_SAVED}\"" + echo "IS_MAINLAND=\"${IS_MAINLAND_SAVED}\"" + echo "ENABLE_TERMINAL=\"${ENABLE_TERMINAL_SAVED}\"" + echo "TERMINAL_MOUNT_DIR=\"${TERMINAL_MOUNT_DIR_SAVED}\"" + } > "$DEPLOY_OPTIONS_FILE" +} + generate_minio_ak_sk() { echo "🔑 Generating MinIO keys..." @@ -395,6 +445,7 @@ select_deployment_mode() { # Sanitize potential Windows CR in input mode_choice=$(sanitize_input "$mode_choice") + MODE_CHOICE_SAVED="$mode_choice" case $mode_choice in 2) @@ -544,7 +595,7 @@ deploy_core_services() { echo "👀 Starting core services..." if ! ${docker_compose_command} -p nexent -f "docker-compose${COMPOSE_FILE_SUFFIX}" up -d nexent-config nexent-runtime nexent-mcp nexent-northbound nexent-web nexent-data-process; then echo " ❌ ERROR Failed to start core services" - return 0 + return 1 fi } @@ -561,7 +612,7 @@ deploy_infrastructure() { if ! ${docker_compose_command} -p nexent -f "docker-compose${COMPOSE_FILE_SUFFIX}" up -d $INFRA_SERVICES; then echo " ❌ ERROR Failed to start infrastructure services" - return 0 + return 1 fi if [ "$ENABLE_TERMINAL_TOOL_CONTAINER" = "true" ]; then @@ -606,7 +657,7 @@ select_deployment_version() { # Sanitize potential Windows CR in input version_choice=$(sanitize_input "$version_choice") - + VERSION_CHOICE_SAVED="${version_choice}" case $version_choice in 2) export DEPLOYMENT_VERSION="full" @@ -653,7 +704,7 @@ setup_package_install_script() { echo " ✅ Package installation script created/updated" else echo " ❌ ERROR openssh-install-script.sh not found" - return 0 + return 1 fi } @@ -692,6 +743,7 @@ select_terminal_tool() { enable_terminal=$(sanitize_input "$enable_terminal") if [[ "$enable_terminal" =~ ^[Yy]$ ]]; then + ENABLE_TERMINAL_SAVED="Y" export ENABLE_TERMINAL_TOOL_CONTAINER="true" export COMPOSE_PROFILES="${COMPOSE_PROFILES:+$COMPOSE_PROFILES,}terminal" echo "✅ Terminal tool container will be created 🔧" @@ -707,6 +759,7 @@ select_terminal_tool() { read -p " 📁 Enter host directory to mount to container (default: /opt/terminal): " terminal_mount_dir terminal_mount_dir=$(sanitize_input "$terminal_mount_dir") TERMINAL_MOUNT_DIR="${terminal_mount_dir:-$default_terminal_dir}" + TERMINAL_MOUNT_DIR_SAVED="$TERMINAL_MOUNT_DIR" # Save to environment variables export TERMINAL_MOUNT_DIR @@ -770,6 +823,7 @@ select_terminal_tool() { fi echo "" else + ENABLE_TERMINAL_SAVED="N" export ENABLE_TERMINAL_TOOL_CONTAINER="false" echo "🚫 Terminal tool container disabled" fi @@ -814,9 +868,11 @@ choose_image_env() { # Sanitize potential Windows CR in input is_mainland=$(sanitize_input "$is_mainland") if [[ "$is_mainland" =~ ^[Yy]$ ]]; then + IS_MAINLAND_SAVED="Y" echo "🌐 Detected mainland China network, using .env.mainland for image sources." source .env.mainland else + IS_MAINLAND_SAVED="N" echo "🌐 Using general image sources from .env.general." source .env.general fi @@ -833,28 +889,35 @@ main_deploy() { echo "--------------------------------" echo "" + APP_VERSION="$(get_app_version)" + if [ -z "$APP_VERSION" ]; then + echo "❌ Failed to get app version, please check the backend/consts/const.py file" + exit 1 + fi + echo "🌐 App version: $APP_VERSION" + # Check all relevant ports from environment files before starting deployment check_ports_in_env_files # Select deployment version, mode and image source - select_deployment_version || { echo "❌ Deployment version selection failed"; exit 0; } - select_deployment_mode || { echo "❌ Deployment mode selection failed"; exit 0; } - select_terminal_tool || { echo "❌ Terminal tool container configuration failed"; exit 0; } - choose_image_env || { echo "❌ Image environment setup failed"; exit 0; } + select_deployment_version || { echo "❌ Deployment version selection failed"; exit 1; } + select_deployment_mode || { echo "❌ Deployment mode selection failed"; exit 1; } + select_terminal_tool || { echo "❌ Terminal tool container configuration failed"; exit 1; } + choose_image_env || { echo "❌ Image environment setup failed"; exit 1; } # Add permission - prepare_directory_and_data || { echo "❌ Permission setup failed"; exit 0; } - generate_minio_ak_sk || { echo "❌ MinIO key generation failed"; exit 0; } + prepare_directory_and_data || { echo "❌ Permission setup failed"; exit 1; } + generate_minio_ak_sk || { echo "❌ MinIO key generation failed"; exit 1; } # Generate Supabase secrets - generate_supabase_keys || { echo "❌ Supabase secrets generation failed"; exit 0; } + generate_supabase_keys || { echo "❌ Supabase secrets generation failed"; exit 1; } # Deploy infrastructure services - deploy_infrastructure || { echo "❌ Infrastructure deployment failed"; exit 0; } + deploy_infrastructure || { echo "❌ Infrastructure deployment failed"; exit 1; } # Generate Elasticsearch API key - generate_elasticsearch_api_key || { echo "❌ Elasticsearch API key generation failed"; exit 0; } + generate_elasticsearch_api_key || { echo "❌ Elasticsearch API key generation failed"; exit 1; } echo "" echo "--------------------------------" @@ -862,16 +925,17 @@ main_deploy() { # Special handling for infrastructure mode if [ "$DEPLOYMENT_MODE" = "infrastructure" ]; then - generate_env_for_infrastructure || { echo "❌ Environment generation failed"; exit 0; } + generate_env_for_infrastructure || { echo "❌ Environment generation failed"; exit 1; } echo "🎉 Infrastructure deployment completed successfully!" echo " You can now start the core services manually using dev containers" echo " Environment file available at: $(cd .. && pwd)/.env" echo "💡 Use 'source .env' to load environment variables in your development shell" + persist_deploy_options return 0 fi # Start core services - deploy_core_services || { echo "❌ Core services deployment failed"; exit 0; } + deploy_core_services || { echo "❌ Core services deployment failed"; exit 1; } echo " ✅ Core services started successfully" echo "" @@ -880,9 +944,10 @@ main_deploy() { # Create default admin user if [ "$DEPLOYMENT_VERSION" = "full" ]; then - create_default_admin_user || { echo "❌ Default admin user creation failed"; exit 0; } + create_default_admin_user || { echo "❌ Default admin user creation failed"; exit 1; } fi + persist_deploy_options echo "🎉 Deployment completed successfully!" echo "🌐 You can now access the application at http://localhost:3000" } @@ -891,7 +956,7 @@ main_deploy() { version_info=$(get_compose_version) if [[ $version_info == "unknown" ]]; then echo "Error: Docker Compose not found or version detection failed" - exit 0 + exit 1 fi # extract version @@ -906,7 +971,7 @@ case $version_type in # The version v1.28.0 is the minimum requirement in Docker Compose v1 that explicitly supports interpolation syntax with default values like ${VAR:-default} if [[ $version_number < "1.28.0" ]]; then echo "Warning: V1 version is too old, consider upgrading to V2" - exit 0 + exit 1 fi docker_compose_command="docker-compose" ;; @@ -916,14 +981,14 @@ case $version_type in ;; *) echo "Error: Unknown docker compose version type." - exit 0 + exit 1 ;; esac # Execute main deployment with error handling if ! main_deploy; then echo "❌ Deployment failed. Please check the error messages above and try again." - exit 0 + exit 1 fi clean diff --git a/docker/init.sql b/docker/init.sql index 1181c8237..5ba10457a 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -165,6 +165,7 @@ CREATE TABLE IF NOT EXISTS "model_record_t" ( "used_token" int4, "expected_chunk_size" int4, "maximum_chunk_size" int4, + "chunk_batch" int4, "display_name" varchar(100) COLLATE "pg_catalog"."default", "connect_status" varchar(100) COLLATE "pg_catalog"."default", "ssl_verify" boolean DEFAULT true, @@ -205,6 +206,7 @@ INSERT INTO "nexent"."model_record_t" ("model_repo", "model_name", "model_factor CREATE TABLE IF NOT EXISTS "knowledge_record_t" ( "knowledge_id" SERIAL, "index_name" varchar(100) COLLATE "pg_catalog"."default", + "knowledge_name" varchar(100) COLLATE "pg_catalog"."default", "knowledge_describe" varchar(3000) COLLATE "pg_catalog"."default", "tenant_id" varchar(100) COLLATE "pg_catalog"."default", "knowledge_sources" varchar(100) COLLATE "pg_catalog"."default", @@ -218,7 +220,8 @@ CREATE TABLE IF NOT EXISTS "knowledge_record_t" ( ); ALTER TABLE "knowledge_record_t" OWNER TO "root"; COMMENT ON COLUMN "knowledge_record_t"."knowledge_id" IS 'Knowledge base ID, unique primary key'; -COMMENT ON COLUMN "knowledge_record_t"."index_name" IS 'Knowledge base name'; +COMMENT ON COLUMN "knowledge_record_t"."index_name" IS 'Internal Elasticsearch index name'; +COMMENT ON COLUMN "knowledge_record_t"."knowledge_name" IS 'User-facing knowledge base name (display name), mapped to internal index_name'; COMMENT ON COLUMN "knowledge_record_t"."knowledge_describe" IS 'Knowledge base description'; COMMENT ON COLUMN "knowledge_record_t"."tenant_id" IS 'Tenant ID'; COMMENT ON COLUMN "knowledge_record_t"."knowledge_sources" IS 'Knowledge base sources'; @@ -294,6 +297,7 @@ CREATE TABLE IF NOT EXISTS nexent.ag_tenant_agent_t ( display_name VARCHAR(100), description VARCHAR, business_description VARCHAR, + author VARCHAR(100), model_name VARCHAR(100), model_id INTEGER, business_logic_model_name VARCHAR(100), @@ -335,6 +339,7 @@ COMMENT ON COLUMN nexent.ag_tenant_agent_t.agent_id IS 'ID'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.name IS 'Agent name'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.display_name IS 'Agent display name'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.description IS 'Description'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.author IS 'Agent author'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_description IS 'Manually entered by the user to describe the entire business process'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.model_name IS '[DEPRECATED] Name of the model used, use model_id instead'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.model_id IS 'Model ID, foreign key reference to model_record_t.model_id'; diff --git a/docker/sql/0619_add_tenant_config_t.sql b/docker/sql/v1.1.0_0619_add_tenant_config_t.sql similarity index 93% rename from docker/sql/0619_add_tenant_config_t.sql rename to docker/sql/v1.1.0_0619_add_tenant_config_t.sql index f13876fcf..b2079101c 100644 --- a/docker/sql/0619_add_tenant_config_t.sql +++ b/docker/sql/v1.1.0_0619_add_tenant_config_t.sql @@ -50,6 +50,7 @@ $$ LANGUAGE plpgsql; COMMENT ON FUNCTION update_tenant_config_update_time() IS 'Function to update the update_time column when a record in tenant_config_t is updated'; -- 创建触发器 +DROP TRIGGER IF EXISTS update_tenant_config_update_time_trigger ON nexent.tenant_config_t; CREATE TRIGGER update_tenant_config_update_time_trigger BEFORE UPDATE ON nexent.tenant_config_t FOR EACH ROW @@ -60,5 +61,5 @@ COMMENT ON TRIGGER update_tenant_config_update_time_trigger ON nexent.tenant_con IS 'Trigger to call update_tenant_config_update_time function before each update on tenant_config_t table'; ALTER TABLE model_record_t -ADD COLUMN tenant_id varchar(100) COLLATE pg_catalog.default DEFAULT 'tenant_id'; +ADD COLUMN IF NOT EXISTS tenant_id varchar(100) COLLATE pg_catalog.default DEFAULT 'tenant_id'; COMMENT ON COLUMN "model_record_t"."tenant_id" IS 'Tenant ID for filtering'; \ No newline at end of file diff --git a/docker/sql/0627_increase_config_value_length.sql b/docker/sql/v1.2.0_0627_increase_config_value_length.sql similarity index 100% rename from docker/sql/0627_increase_config_value_length.sql rename to docker/sql/v1.2.0_0627_increase_config_value_length.sql diff --git a/docker/sql/0630_add_mcp_record_t.sql b/docker/sql/v1.3.0_0630_add_mcp_record_t.sql similarity index 96% rename from docker/sql/0630_add_mcp_record_t.sql rename to docker/sql/v1.3.0_0630_add_mcp_record_t.sql index d76fb800b..3f25a5957 100644 --- a/docker/sql/0630_add_mcp_record_t.sql +++ b/docker/sql/v1.3.0_0630_add_mcp_record_t.sql @@ -49,6 +49,7 @@ $$ LANGUAGE plpgsql; COMMENT ON FUNCTION update_mcp_record_update_time() IS 'Function to update the update_time column when a record in mcp_record_t is updated'; -- Create a trigger to call the function before each update +DROP TRIGGER IF EXISTS update_mcp_record_update_time_trigger ON nexent.mcp_record_t; CREATE TRIGGER update_mcp_record_update_time_trigger BEFORE UPDATE ON nexent.mcp_record_t FOR EACH ROW diff --git a/docker/sql/0708_add_user_tenant_t.sql b/docker/sql/v1.4.0_0708_add_user_tenant_t.sql similarity index 100% rename from docker/sql/0708_add_user_tenant_t.sql rename to docker/sql/v1.4.0_0708_add_user_tenant_t.sql diff --git a/docker/sql/0715_add_knowledge_describe_length.sql b/docker/sql/v1.5.0_0715_add_knowledge_describe_length.sql similarity index 100% rename from docker/sql/0715_add_knowledge_describe_length.sql rename to docker/sql/v1.5.0_0715_add_knowledge_describe_length.sql diff --git a/docker/sql/0716_add_status_to_mcp_record_t.sql b/docker/sql/v1.5.0_0716_add_status_to_mcp_record_t.sql similarity index 75% rename from docker/sql/0716_add_status_to_mcp_record_t.sql rename to docker/sql/v1.5.0_0716_add_status_to_mcp_record_t.sql index 3eb351323..ac233a8bf 100644 --- a/docker/sql/0716_add_status_to_mcp_record_t.sql +++ b/docker/sql/v1.5.0_0716_add_status_to_mcp_record_t.sql @@ -1,3 +1,3 @@ ALTER TABLE nexent.mcp_record_t -ADD COLUMN status BOOLEAN DEFAULT NULL; +ADD COLUMN IF NOT EXISTS status BOOLEAN DEFAULT NULL; COMMENT ON COLUMN nexent.mcp_record_t.status IS 'MCP server connection status, true=connected, false=disconnected, null=unknown'; \ No newline at end of file diff --git a/docker/sql/0722_modify_tenant_agent.sql b/docker/sql/v1.6.0_0722_modify_tenant_agent.sql similarity index 100% rename from docker/sql/0722_modify_tenant_agent.sql rename to docker/sql/v1.6.0_0722_modify_tenant_agent.sql diff --git a/docker/sql/0723_add_agent_relation_t.sql b/docker/sql/v1.6.0_0723_add_agent_relation_t.sql similarity index 95% rename from docker/sql/0723_add_agent_relation_t.sql rename to docker/sql/v1.6.0_0723_add_agent_relation_t.sql index 4e0756d9a..78d856438 100644 --- a/docker/sql/0723_add_agent_relation_t.sql +++ b/docker/sql/v1.6.0_0723_add_agent_relation_t.sql @@ -24,6 +24,7 @@ END; $$ LANGUAGE plpgsql; -- Create a trigger to call the function before each update +DROP TRIGGER IF EXISTS update_ag_agent_relation_update_time_trigger ON nexent.ag_agent_relation_t; CREATE TRIGGER update_ag_agent_relation_update_time_trigger BEFORE UPDATE ON nexent.ag_agent_relation_t FOR EACH ROW diff --git a/docker/sql/0805_add_deep_thinking_to_model_record_t.sql b/docker/sql/v1.7.1_0805_add_deep_thinking_to_model_record_t.sql similarity index 68% rename from docker/sql/0805_add_deep_thinking_to_model_record_t.sql rename to docker/sql/v1.7.1_0805_add_deep_thinking_to_model_record_t.sql index baf4b3052..65b5b8465 100644 --- a/docker/sql/0805_add_deep_thinking_to_model_record_t.sql +++ b/docker/sql/v1.7.1_0805_add_deep_thinking_to_model_record_t.sql @@ -1,3 +1,3 @@ ALTER TABLE nexent.model_record_t -ADD COLUMN is_deep_thinking BOOLEAN DEFAULT FALSE; +ADD COLUMN IF NOT EXISTS is_deep_thinking BOOLEAN DEFAULT FALSE; COMMENT ON COLUMN nexent.model_record_t.is_deep_thinking IS 'deep thinking switch, true=open, false=close'; \ No newline at end of file diff --git a/docker/sql/0806_add_memory_user_config.sql b/docker/sql/v1.7.1_0806_add_memory_user_config.sql similarity index 92% rename from docker/sql/0806_add_memory_user_config.sql rename to docker/sql/v1.7.1_0806_add_memory_user_config.sql index b3d135398..46eb42829 100644 --- a/docker/sql/0806_add_memory_user_config.sql +++ b/docker/sql/v1.7.1_0806_add_memory_user_config.sql @@ -1,5 +1,5 @@ -- 创建序列 -CREATE SEQUENCE "nexent"."memory_user_config_t_config_id_seq" +CREATE SEQUENCE IF NOT EXISTS "nexent"."memory_user_config_t_config_id_seq" INCREMENT 1 MINVALUE 1 MAXVALUE 2147483647 @@ -47,6 +47,7 @@ BEGIN END; $$ LANGUAGE plpgsql; +DROP TRIGGER IF EXISTS "update_memory_user_config_update_time_trigger" ON "nexent"."memory_user_config_t"; CREATE TRIGGER "update_memory_user_config_update_time_trigger" BEFORE UPDATE ON "nexent"."memory_user_config_t" FOR EACH ROW diff --git a/docker/sql/0820_add_partner_mapping_id_t.sql b/docker/sql/v1.7.2.2_0820_add_partner_mapping_id_t.sql similarity index 92% rename from docker/sql/0820_add_partner_mapping_id_t.sql rename to docker/sql/v1.7.2.2_0820_add_partner_mapping_id_t.sql index 74fc5ac54..4817b6afc 100644 --- a/docker/sql/0820_add_partner_mapping_id_t.sql +++ b/docker/sql/v1.7.2.2_0820_add_partner_mapping_id_t.sql @@ -1,4 +1,4 @@ -CREATE SEQUENCE "nexent"."partner_mapping_id_t_mapping_id_seq" +CREATE SEQUENCE IF NOT EXISTS "nexent"."partner_mapping_id_t_mapping_id_seq" INCREMENT 1 MINVALUE 1 MAXVALUE 2147483647 @@ -41,6 +41,7 @@ BEGIN END; $$ LANGUAGE plpgsql; +DROP TRIGGER IF EXISTS "update_partner_mapping_update_time_trigger" ON "nexent"."partner_mapping_id_t"; CREATE TRIGGER "update_partner_mapping_update_time_trigger" BEFORE UPDATE ON "nexent"."partner_mapping_id_t" FOR EACH ROW diff --git a/docker/sql/0809_add_name_zh_to_ag_tenant_agent_t.sql b/docker/sql/v1.7.2_0809_add_name_zh_to_ag_tenant_agent_t.sql similarity index 69% rename from docker/sql/0809_add_name_zh_to_ag_tenant_agent_t.sql rename to docker/sql/v1.7.2_0809_add_name_zh_to_ag_tenant_agent_t.sql index 0ee275280..3b0b77c6c 100644 --- a/docker/sql/0809_add_name_zh_to_ag_tenant_agent_t.sql +++ b/docker/sql/v1.7.2_0809_add_name_zh_to_ag_tenant_agent_t.sql @@ -1,3 +1,3 @@ ALTER TABLE nexent.ag_tenant_agent_t -ADD COLUMN display_name VARCHAR(100); +ADD COLUMN IF NOT EXISTS display_name VARCHAR(100); COMMENT ON COLUMN nexent.ag_tenant_agent_t.display_name IS 'Agent展示名称'; \ No newline at end of file diff --git a/docker/sql/0812_modify_model_record_t.sql b/docker/sql/v1.7.2_0812_modify_model_record_t.sql similarity index 100% rename from docker/sql/0812_modify_model_record_t.sql rename to docker/sql/v1.7.2_0812_modify_model_record_t.sql diff --git a/docker/sql/0902_add_model_name_to_knowledge_record_t.sql b/docker/sql/v1.7.3.2_0902_add_model_name_to_knowledge_record_t.sql similarity index 100% rename from docker/sql/0902_add_model_name_to_knowledge_record_t.sql rename to docker/sql/v1.7.3.2_0902_add_model_name_to_knowledge_record_t.sql diff --git a/docker/sql/1011_add_origin_tool_name_to_ag_tool_info.sql b/docker/sql/v1.7.4.1_1011_add_origin_tool_name_to_ag_tool_info.sql similarity index 100% rename from docker/sql/1011_add_origin_tool_name_to_ag_tool_info.sql rename to docker/sql/v1.7.4.1_1011_add_origin_tool_name_to_ag_tool_info.sql diff --git a/docker/sql/1013_add_tool_group_to_ag_tool_info.sql b/docker/sql/v1.7.4.1_1013_add_tool_group_to_ag_tool_info.sql similarity index 100% rename from docker/sql/1013_add_tool_group_to_ag_tool_info.sql rename to docker/sql/v1.7.4.1_1013_add_tool_group_to_ag_tool_info.sql diff --git a/docker/sql/0928_add_model_id_to_ag_tenant_agent_t.sql b/docker/sql/v1.7.4_0928_add_model_id_to_ag_tenant_agent_t.sql similarity index 95% rename from docker/sql/0928_add_model_id_to_ag_tenant_agent_t.sql rename to docker/sql/v1.7.4_0928_add_model_id_to_ag_tenant_agent_t.sql index 20775cc50..cfff187e0 100644 --- a/docker/sql/0928_add_model_id_to_ag_tenant_agent_t.sql +++ b/docker/sql/v1.7.4_0928_add_model_id_to_ag_tenant_agent_t.sql @@ -7,7 +7,7 @@ SET search_path TO nexent; -- Add model_id column to ag_tenant_agent_t table ALTER TABLE ag_tenant_agent_t -ADD COLUMN model_id INTEGER; +ADD COLUMN IF NOT EXISTS model_id INTEGER; -- Add comment for the new model_id column COMMENT ON COLUMN ag_tenant_agent_t.model_id IS 'Model ID, foreign key reference to model_record_t.model_id'; diff --git a/docker/sql/1028_add_chunk_size_to_model_record_t.sql b/docker/sql/v1.7.5.1_1028_add_chunk_size_to_model_record_t.sql similarity index 75% rename from docker/sql/1028_add_chunk_size_to_model_record_t.sql rename to docker/sql/v1.7.5.1_1028_add_chunk_size_to_model_record_t.sql index 693c0026e..4fa08dc0f 100644 --- a/docker/sql/1028_add_chunk_size_to_model_record_t.sql +++ b/docker/sql/v1.7.5.1_1028_add_chunk_size_to_model_record_t.sql @@ -1,6 +1,6 @@ ALTER TABLE nexent.model_record_t -ADD COLUMN expected_chunk_size INT4, -ADD COLUMN maximum_chunk_size INT4; +ADD COLUMN IF NOT EXISTS expected_chunk_size INT4, +ADD COLUMN IF NOT EXISTS maximum_chunk_size INT4; COMMENT ON COLUMN nexent.model_record_t.expected_chunk_size IS 'Expected chunk size for embedding models, used during document chunking'; COMMENT ON COLUMN nexent.model_record_t.maximum_chunk_size IS 'Maximum chunk size for embedding models, used during document chunking'; diff --git a/docker/sql/1024_add_business_logic_model_fields.sql b/docker/sql/v1.7.5_1024_add_business_logic_model_fields.sql similarity index 100% rename from docker/sql/1024_add_business_logic_model_fields.sql rename to docker/sql/v1.7.5_1024_add_business_logic_model_fields.sql diff --git a/docker/sql/1024_alter_tenant_config_t_config_value.sql b/docker/sql/v1.7.5_1024_alter_tenant_config_t_config_value.sql similarity index 100% rename from docker/sql/1024_alter_tenant_config_t_config_value.sql rename to docker/sql/v1.7.5_1024_alter_tenant_config_t_config_value.sql diff --git a/docker/sql/1129_add_ssl_verify_to_model_record_t.sql b/docker/sql/v1.7.7_1129_add_ssl_verify_to_model_record_t.sql similarity index 80% rename from docker/sql/1129_add_ssl_verify_to_model_record_t.sql rename to docker/sql/v1.7.7_1129_add_ssl_verify_to_model_record_t.sql index aa2c9d9c9..5eec1f92c 100644 --- a/docker/sql/1129_add_ssl_verify_to_model_record_t.sql +++ b/docker/sql/v1.7.7_1129_add_ssl_verify_to_model_record_t.sql @@ -1,5 +1,5 @@ ALTER TABLE nexent.model_record_t -ADD COLUMN ssl_verify BOOLEAN DEFAULT TRUE; +ADD COLUMN IF NOT EXISTS ssl_verify BOOLEAN DEFAULT TRUE; COMMENT ON COLUMN nexent.model_record_t.ssl_verify IS 'Whether to verify SSL certificates when connecting to this model API. Default is true. Set to false for local services without SSL support.'; diff --git a/docker/sql/v1.7.8_1204_add_knowledge_name_to_knowledge_record_t.sql b/docker/sql/v1.7.8_1204_add_knowledge_name_to_knowledge_record_t.sql new file mode 100644 index 000000000..4e889bb0e --- /dev/null +++ b/docker/sql/v1.7.8_1204_add_knowledge_name_to_knowledge_record_t.sql @@ -0,0 +1,18 @@ +-- Add knowledge_name column if it does not exist +ALTER TABLE nexent.knowledge_record_t +ADD COLUMN IF NOT EXISTS knowledge_name varchar(100) COLLATE "pg_catalog"."default"; + +COMMENT ON COLUMN nexent.knowledge_record_t.knowledge_name IS 'User-facing knowledge base name (display name), mapped to internal index_name'; +COMMENT ON COLUMN nexent.knowledge_record_t.index_name IS 'Internal Elasticsearch index name'; + +-- Backfill existing records: for legacy data, use index_name as knowledge_name +UPDATE nexent.knowledge_record_t +SET knowledge_name = index_name +WHERE knowledge_name IS NULL; + + +-- Add chunk_batch column in model_record_t table +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS chunk_batch INT4; + +COMMENT ON COLUMN nexent.model_record_t.chunk_batch IS 'Batch size for concurrent embedding requests during document chunking'; \ No newline at end of file diff --git a/docker/sql/v1.7.8_add_author_to_ag_tenant_agent_t.sql b/docker/sql/v1.7.8_add_author_to_ag_tenant_agent_t.sql new file mode 100644 index 000000000..4ac134624 --- /dev/null +++ b/docker/sql/v1.7.8_add_author_to_ag_tenant_agent_t.sql @@ -0,0 +1,10 @@ +-- Add author column to ag_tenant_agent_t table +-- This migration adds the author field to support agent author information + +-- Add author column with default NULL value for backward compatibility +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS author VARCHAR(100); + +-- Add comment to the column +COMMENT ON COLUMN nexent.ag_tenant_agent_t.author IS 'Agent author'; + diff --git a/docker/upgrade.sh b/docker/upgrade.sh new file mode 100644 index 000000000..688a84924 --- /dev/null +++ b/docker/upgrade.sh @@ -0,0 +1,249 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +OPTIONS_FILE="$SCRIPT_DIR/deploy.options" +CONST_FILE="$PROJECT_ROOT/backend/consts/const.py" +DEPLOY_SCRIPT="$SCRIPT_DIR/deploy.sh" +SQL_DIR="$SCRIPT_DIR/sql" +ENV_FILE="$SCRIPT_DIR/.env" + +declare -A DEPLOY_OPTIONS +UPGRADE_SQL_FILES=() + +log() { + local level="$1" + shift + printf "[%s] %s\n" "$level" "$*" +} + +require_file() { + local path="$1" + local message="$2" + if [ ! -f "$path" ]; then + log "ERROR" "$message" + exit 1 + fi +} + +trim_quotes() { + local value="$1" + value="${value%$'\r'}" + value="${value%\"}" + value="${value#\"}" + echo "$value" +} + +load_options() { + if [ ! -f "$OPTIONS_FILE" ]; then + log "WARN" "⚙️ deploy.options not found, entering interactive configuration mode." + : > "$OPTIONS_FILE" + return + fi + while IFS= read -r line || [ -n "$line" ]; do + [[ -z "$line" || "$line" =~ ^[[:space:]]*# ]] && continue + if [[ "$line" =~ ^[[:space:]]*([A-Za-z0-9_]+)[[:space:]]*=(.*)$ ]]; then + local key="${BASH_REMATCH[1]}" + local raw_value="${BASH_REMATCH[2]}" + raw_value="$(echo "$raw_value" | sed 's/^[[:space:]]*//;s/[[:space:]]*$//')" + DEPLOY_OPTIONS[$key]="$(trim_quotes "$raw_value")" + fi + done < "$OPTIONS_FILE" +} + +prompt_option_value() { + local key="$1" + local prompt_msg="$2" + local default_value="${3:-}" + local input="" + + while true; do + if [ -n "$default_value" ]; then + read -rp "${prompt_msg} [${default_value}]: " input + input="${input:-$default_value}" + else + read -rp "${prompt_msg}: " input + fi + + input="$(trim_quotes "$input")" + if [ -n "$input" ]; then + DEPLOY_OPTIONS[$key]="$input" + update_option_value "$key" "$input" + break + fi + + log "WARN" "⚠️ ${key} cannot be empty, please enter a value." + done +} + +require_option() { + local key="$1" + local prompt_msg="${2:-}" + local value="${DEPLOY_OPTIONS[$key]:-}" + if [ -z "$value" ]; then + if [ -n "$prompt_msg" ]; then + prompt_option_value "$key" "$prompt_msg" + else + log "ERROR" "❌ ${key} is missing in deploy.options, add it and rerun." + exit 1 + fi + fi +} + +get_const_app_version() { + require_file "$CONST_FILE" "backend/consts/const.py not found, unable to read the latest version." + local line + line=$(grep -E 'APP_VERSION' "$CONST_FILE" | tail -n 1 || true) + line="${line##*=}" + line="$(echo "$line" | sed 's/^[[:space:]]*//;s/[[:space:]]*$//')" + trim_quotes "$line" +} + +compare_versions() { + local v1="${1#v}" + local v2="${2#v}" + IFS='.' read -r -a parts1 <<< "$v1" + IFS='.' read -r -a parts2 <<< "$v2" + local max_len="${#parts1[@]}" + if [ "${#parts2[@]}" -gt "$max_len" ]; then + max_len="${#parts2[@]}" + fi + for ((i=0; i 10#$num2)) && { echo 1; return; } + ((10#$num1 < 10#$num2)) && { echo -1; return; } + done + echo 0 +} + +collect_upgrade_sqls() { + if [ ! -d "$SQL_DIR" ]; then + log "WARN" "📭 SQL directory not found, skipping database upgrade scripts." + return + fi + + mapfile -t sql_files < <(find "$SQL_DIR" -maxdepth 1 -type f -name "v*.sql" -print | sort -V || true) + if [ "${#sql_files[@]}" -eq 0 ]; then + return + fi + + for file in "${sql_files[@]}"; do + local base version_prefix + base="$(basename "$file")" + version_prefix="${base%%_*}" + [[ -z "$version_prefix" ]] && continue + + local cmp_current + cmp_current="$(compare_versions "$version_prefix" "$CURRENT_APP_VERSION")" + + if [ "$cmp_current" -eq 1 ]; then + UPGRADE_SQL_FILES+=("$file") + fi + done +} + +build_deploy_args() { + DEPLOY_ARGS=() + local mode="${DEPLOY_OPTIONS[MODE_CHOICE]:-}" + local version_choice="${DEPLOY_OPTIONS[VERSION_CHOICE]:-}" + local is_mainland="${DEPLOY_OPTIONS[IS_MAINLAND]:-}" + local enable_terminal="${DEPLOY_OPTIONS[ENABLE_TERMINAL]:-}" + local root_dir="${DEPLOY_OPTIONS[ROOT_DIR]:-}" + + [[ -n "$mode" ]] && DEPLOY_ARGS+=(--mode "$mode") + [[ -n "$version_choice" ]] && DEPLOY_ARGS+=(--version "$version_choice") + [[ -n "$is_mainland" ]] && DEPLOY_ARGS+=(--is-mainland "$is_mainland") + [[ -n "$enable_terminal" ]] && DEPLOY_ARGS+=(--enable-terminal "$enable_terminal") + [[ -n "$root_dir" ]] && DEPLOY_ARGS+=(--root-dir "$root_dir") +} + +ensure_docker() { + if ! command -v docker >/dev/null 2>&1; then + log "ERROR" "🛑 Docker CLI not detected, install Docker before continuing." + exit 1 + fi +} + +ensure_postgres_env() { + require_file "$ENV_FILE" "📁 docker/.env not found; unable to load database credentials." + set -a + source "$ENV_FILE" + set +a + : "${POSTGRES_USER:?docker/.env is missing POSTGRES_USER}" + : "${POSTGRES_DB:?docker/.env is missing POSTGRES_DB}" +} + +run_deploy() { + # Stop and remove any existing containers before redeployment + docker compose -p nexent down -v + log "INFO" "🚀 Starting deploy..." + (cd "$SCRIPT_DIR" && cp .env.example .env && bash "$DEPLOY_SCRIPT" "${DEPLOY_ARGS[@]}") + +} + +run_sql_scripts() { + if [ "${#UPGRADE_SQL_FILES[@]}" -eq 0 ]; then + log "INFO" "📭 No database upgrade scripts detected, skipping this step." + return + fi + + ensure_postgres_env + + for sql_file in "${UPGRADE_SQL_FILES[@]}"; do + log "INFO" "🗃️ Running database upgrade script $(basename "$sql_file") ..." + if ! docker exec -i nexent-postgresql psql -U "$POSTGRES_USER" -d "$POSTGRES_DB" -v ON_ERROR_STOP=1 < "$sql_file"; then + log "ERROR" "❌ Failed to execute $(basename "$sql_file"), please verify the script." + exit 1 + fi + done +} + +update_option_value() { + local key="$1" + local value="$2" + touch "$OPTIONS_FILE" + if grep -q "^${key}[[:space:]]*=" "$OPTIONS_FILE"; then + sed -i.bak -E "s|^(${key}[[:space:]]*=[[:space:]]*)\"?[^\"]*\"?|\1\"${value}\"|" "$OPTIONS_FILE" + else + echo "${key} = \"${value}\"" >> "$OPTIONS_FILE" + fi +} + + +main() { + ensure_docker + load_options + + require_option "APP_VERSION" "APP_VERSION not detected, please enter the current deployed version" + require_option "ROOT_DIR" "ROOT_DIR not detected, please enter the absolute deployment directory path" + CURRENT_APP_VERSION="${DEPLOY_OPTIONS[APP_VERSION]:-}" + + NEW_APP_VERSION="$(get_const_app_version)" + if [ -z "$NEW_APP_VERSION" ]; then + log "ERROR" "❌ Unable to parse APP_VERSION from const.py, please verify the file." + exit 1 + fi + + log "INFO" "📦 Current version: $CURRENT_APP_VERSION" + log "INFO" "🎯 Target version: $NEW_APP_VERSION" + + local cmp_result + cmp_result="$(compare_versions "$NEW_APP_VERSION" "$CURRENT_APP_VERSION")" + if [ "$cmp_result" -le 0 ]; then + log "INFO" "🚫 Target version ($NEW_APP_VERSION) is not higher than current version ($CURRENT_APP_VERSION), upgrade aborted." + exit 1 + fi + + build_deploy_args + run_deploy + collect_upgrade_sqls + run_sql_scripts + + log "INFO" "🎉 Upgrade to ${NEW_APP_VERSION} completed, please verify service health." +} + +main "$@" + diff --git a/frontend/app/[locale]/agents/AgentConfiguration.tsx b/frontend/app/[locale]/agents/AgentConfiguration.tsx index ab5411811..b23f2ccae 100644 --- a/frontend/app/[locale]/agents/AgentConfiguration.tsx +++ b/frontend/app/[locale]/agents/AgentConfiguration.tsx @@ -94,6 +94,7 @@ export default forwardRef(function AgentCon const [agentName, setAgentName] = useState(""); const [agentDescription, setAgentDescription] = useState(""); const [agentDisplayName, setAgentDisplayName] = useState(""); + const [agentAuthor, setAgentAuthor] = useState(""); // Add state for business logic and action buttons const [isGeneratingAgent, setIsGeneratingAgent] = useState(false); @@ -474,12 +475,14 @@ export default forwardRef(function AgentCon if (isEditing && agent) { setAgentName(agent.name || ""); setAgentDescription(agent.description || ""); + setAgentAuthor(agent.author || ""); setBusinessLogicError(false); } else if (!isEditing) { // When stopping editing, clear name description box setAgentName(""); setAgentDescription(""); setAgentDisplayName(""); + setAgentAuthor(""); setBusinessLogicError(false); } }; @@ -506,6 +509,7 @@ export default forwardRef(function AgentCon setFewShotsContent(""); setAgentName(""); setAgentDescription(""); + setAgentAuthor(""); setBusinessLogicError(false); }; @@ -546,6 +550,7 @@ export default forwardRef(function AgentCon style={{ height: SETUP_PAGE_CONTAINER.MAIN_CONTENT_HEIGHT, ...STANDARD_CARD.CONTENT_SCROLL, + overflow: "hidden" }} > @@ -595,6 +600,8 @@ export default forwardRef(function AgentCon setAgentDescription={setAgentDescription} agentDisplayName={agentDisplayName} setAgentDisplayName={setAgentDisplayName} + agentAuthor={agentAuthor} + setAgentAuthor={setAgentAuthor} isGeneratingAgent={isGeneratingAgent} // SystemPromptDisplay related props onDebug={() => { diff --git a/frontend/app/[locale]/agents/AgentsContent.tsx b/frontend/app/[locale]/agents/AgentsContent.tsx index 72d5e66ed..5436e3aba 100644 --- a/frontend/app/[locale]/agents/AgentsContent.tsx +++ b/frontend/app/[locale]/agents/AgentsContent.tsx @@ -1,6 +1,6 @@ "use client"; -import React, {useState, useEffect, useRef} from "react"; +import React, {useState, useEffect, useRef, forwardRef, useImperativeHandle} from "react"; import {motion} from "framer-motion"; import {useSetupFlow} from "@/hooks/useSetupFlow"; @@ -30,14 +30,14 @@ interface AgentsContentProps { * AgentsContent - Main component for agent configuration * Can be used in setup flow or as standalone page */ -export default function AgentsContent({ +export default forwardRef(function AgentsContent({ isSaving: externalIsSaving, connectionStatus: externalConnectionStatus, isCheckingConnection: externalIsCheckingConnection, onCheckConnection: externalOnCheckConnection, onConnectionStatusChange, onSavingStateChange, -}: AgentsContentProps) { +}: AgentsContentProps, ref) { const agentConfigRef = useRef(null); const [showSaveConfirm, setShowSaveConfirm] = useState(false); const pendingNavRef = useRef void)>(null); @@ -59,6 +59,21 @@ export default function AgentsContent({ const [internalIsSaving, setInternalIsSaving] = useState(false); const isSaving = externalIsSaving ?? internalIsSaving; + // Expose AgentConfigHandle methods to parent + useImperativeHandle(ref, () => ({ + hasUnsavedChanges: () => agentConfigRef.current?.hasUnsavedChanges?.() ?? false, + saveAllChanges: async () => { + if (agentConfigRef.current?.saveAllChanges) { + await agentConfigRef.current.saveAllChanges(); + } + }, + reloadCurrentAgentData: async () => { + if (agentConfigRef.current?.reloadCurrentAgentData) { + await agentConfigRef.current.reloadCurrentAgentData(); + } + }, + }), []); + // Update external saving state useEffect(() => { onSavingStateChange?.(isSaving); @@ -74,7 +89,7 @@ export default function AgentsContent({ transition={pageTransition} style={{width: "100%", height: "100%"}} > - + {canAccessProtectedData ? ( ) : null} @@ -108,5 +123,5 @@ export default function AgentsContent({ /> > ); -} +}); diff --git a/frontend/app/[locale]/agents/components/AgentSetupOrchestrator.tsx b/frontend/app/[locale]/agents/components/AgentSetupOrchestrator.tsx index f833d9128..225b20980 100644 --- a/frontend/app/[locale]/agents/components/AgentSetupOrchestrator.tsx +++ b/frontend/app/[locale]/agents/components/AgentSetupOrchestrator.tsx @@ -4,7 +4,7 @@ import { useState, useEffect, useCallback, useRef, useMemo } from "react"; import { useTranslation } from "react-i18next"; import { TFunction } from "i18next"; -import { App, Modal, Button, Tooltip } from "antd"; +import { App, Modal, Button, Tooltip, Row, Col } from "antd"; import { WarningFilled } from "@ant-design/icons"; import { TooltipProvider } from "@/components/ui/tooltip"; @@ -26,6 +26,8 @@ import { } from "@/types/agentConfig"; import AgentImportWizard from "@/components/agent/AgentImportWizard"; import log from "@/lib/logger"; +import { useConfirmModal } from "@/hooks/useConfirmModal"; +import { useAuth } from "@/hooks/useAuth"; import SubAgentPool from "./agent/SubAgentPool"; import CollaborativeAgentDisplay from "./agent/CollaborativeAgentDisplay"; @@ -79,6 +81,8 @@ export default function AgentSetupOrchestrator({ setAgentDescription, agentDisplayName, setAgentDisplayName, + agentAuthor, + setAgentAuthor, isGeneratingAgent = false, // SystemPromptDisplay related props onDebug, @@ -93,6 +97,7 @@ export default function AgentSetupOrchestrator({ registerSaveHandler, registerReloadHandler, }: AgentSetupOrchestratorProps) { + const { user, isSpeedMode } = useAuth(); const [enabledToolIds, setEnabledToolIds] = useState([]); const [isLoadingTools, setIsLoadingTools] = useState(false); const [isImporting, setIsImporting] = useState(false); @@ -150,7 +155,9 @@ export default function AgentSetupOrchestrator({ detailReasons.length > 0 ? detailReasons : fallbackReasons; const normalizedAvailability = - typeof detail?.is_available === "boolean" + normalizedReasons.length > 0 + ? false + : typeof detail?.is_available === "boolean" ? detail.is_available : typeof fallback?.is_available === "boolean" ? fallback.is_available @@ -196,6 +203,7 @@ export default function AgentSetupOrchestrator({ const { t } = useTranslation("common"); const { message } = App.useApp(); + const { confirm } = useConfirmModal(); // Common refresh agent list function, moved to the front to avoid hoisting issues const refreshAgentList = async (t: TFunction, clearTools: boolean = true) => { @@ -367,6 +375,7 @@ export default function AgentSetupOrchestrator({ setAgentName?.(agentDetail.name || ""); setAgentDescription?.(agentDetail.description || ""); setAgentDisplayName?.(agentDetail.display_name || ""); + setAgentAuthor?.(agentDetail.author || ""); // Load Agent data to interface setMainAgentModel(agentDetail.model); @@ -907,6 +916,8 @@ export default function AgentSetupOrchestrator({ setAgentName?.(""); setAgentDescription?.(""); setAgentDisplayName?.(""); + setAgentAuthor?.(""); + setAgentAuthor?.(""); // Clear tool and agent selections setSelectedTools([]); @@ -1100,6 +1111,9 @@ export default function AgentSetupOrchestrator({ ) ).sort((a, b) => a - b); + // Determine author value: use provided author, or default to user email in Full mode + const finalAuthor = agentAuthor || (!isSpeedMode && user?.email ? user.email : undefined); + if (isEditingAgent && editingAgent) { // Editing existing agent result = await updateAgent( @@ -1119,7 +1133,8 @@ export default function AgentSetupOrchestrator({ businessLogicModel ?? undefined, businessLogicModelId ?? undefined, deduplicatedToolIds, - deduplicatedAgentIds + deduplicatedAgentIds, + finalAuthor ); } else { // Creating new agent on save @@ -1140,7 +1155,8 @@ export default function AgentSetupOrchestrator({ businessLogicModel ?? undefined, businessLogicModelId ?? undefined, deduplicatedToolIds, - deduplicatedAgentIds + deduplicatedAgentIds, + finalAuthor ); } @@ -1183,6 +1199,7 @@ export default function AgentSetupOrchestrator({ setAgentName?.(agentDetail.name || ""); setAgentDescription?.(agentDetail.description || ""); setAgentDisplayName?.(agentDetail.display_name || ""); + setAgentAuthor?.(agentDetail.author || ""); onEditingStateChange?.(true, agentDetail); setMainAgentModel(agentDetail.model); setMainAgentModelId(agentDetail.model_id ?? null); @@ -1370,6 +1387,7 @@ export default function AgentSetupOrchestrator({ setAgentName?.(agentDetail.name || ""); setAgentDescription?.(agentDetail.description || ""); setAgentDisplayName?.(agentDetail.display_name || ""); + setAgentAuthor?.(agentDetail.author || ""); // Notify external editing state change (use complete data) onEditingStateChange?.(true, agentDetail); @@ -1574,7 +1592,6 @@ export default function AgentSetupOrchestrator({ const runAgentImport = useCallback( async ( agentPayload: any, - translationFn: TFunction, options?: { forceImport?: boolean } ) => { setIsImporting(true); @@ -1708,7 +1725,7 @@ export default function AgentSetupOrchestrator({ agentInfo, }); } else { - await runAgentImport(agentInfo, t); + await runAgentImport(agentInfo); } } catch (error) { log.error(t("agentConfig.agents.importFailed"), error); @@ -1731,7 +1748,7 @@ export default function AgentSetupOrchestrator({ return; } setImportingAction("regenerate"); - const success = await runAgentImport(pendingImportData.agentInfo, t); + const success = await runAgentImport(pendingImportData.agentInfo); if (success) { setPendingImportData(null); } @@ -1743,7 +1760,7 @@ export default function AgentSetupOrchestrator({ return; } setImportingAction("force"); - const success = await runAgentImport(pendingImportData.agentInfo, t, { + const success = await runAgentImport(pendingImportData.agentInfo, { forceImport: true, }); if (success) { @@ -1858,6 +1875,138 @@ export default function AgentSetupOrchestrator({ } }; + // Handle copy agent from list + const handleCopyAgentFromList = async (agent: Agent) => { + try { + // Fetch source agent detail before duplicating + const detailResult = await searchAgentInfo(Number(agent.id)); + if (!detailResult.success || !detailResult.data) { + message.error(detailResult.message); + return; + } + const detail = detailResult.data; + + // Prepare copy names + const copyName = `${detail.name || "agent"}_copy`; + const copyDisplayName = `${ + detail.display_name || t("agentConfig.agents.defaultDisplayName") + }${t("agent.copySuffix")}`; + + // Gather tool and sub-agent identifiers from the source agent + const tools = Array.isArray(detail.tools) ? detail.tools : []; + const unavailableTools = tools.filter( + (tool: any) => tool && tool.is_available === false + ); + const unavailableToolNames = unavailableTools + .map( + (tool: any) => + tool?.display_name || tool?.name || tool?.tool_name || "" + ) + .filter((name: string) => Boolean(name)); + + const enabledToolIds = tools + .filter((tool: any) => tool && tool.is_available !== false) + .map((tool: any) => Number(tool.id)) + .filter((id: number) => Number.isFinite(id)); + const subAgentIds = (Array.isArray(detail.sub_agent_id_list) + ? detail.sub_agent_id_list + : [] + ) + .map((id: any) => Number(id)) + .filter((id: number) => Number.isFinite(id)); + + // Create a new agent using the source agent fields + const createResult = await updateAgent( + undefined, + copyName, + detail.description, + detail.model, + detail.max_step, + detail.provide_run_summary, + detail.enabled, + detail.business_description, + detail.duty_prompt, + detail.constraint_prompt, + detail.few_shots_prompt, + copyDisplayName, + detail.model_id ?? undefined, + detail.business_logic_model_name ?? undefined, + detail.business_logic_model_id ?? undefined, + enabledToolIds, + subAgentIds + ); + if (!createResult.success || !createResult.data?.agent_id) { + message.error( + createResult.message || + t("agentConfig.agents.copyFailed") + ); + return; + } + const newAgentId = Number(createResult.data.agent_id); + const copiedAgentFallback: Agent = { + ...detail, + id: String(newAgentId), + name: copyName, + display_name: copyDisplayName, + sub_agent_id_list: subAgentIds, + }; + + // Copy tool configuration to the new agent + for (const tool of tools) { + if (!tool || tool.is_available === false) { + continue; + } + const params = + tool.initParams?.reduce((acc: Record, param: any) => { + acc[param.name] = param.value; + return acc; + }, {}) || {}; + try { + await updateToolConfig(Number(tool.id), newAgentId, params, true); + } catch (error) { + log.error("Failed to copy tool configuration while duplicating agent:", error); + message.error( + t("agentConfig.agents.copyFailed") + ); + return; + } + } + + // Refresh UI state and notify user about copy result + await refreshAgentList(t, false); + message.success(t("agentConfig.agents.copySuccess")); + if (unavailableTools.length > 0) { + const names = + unavailableToolNames.join(", ") || + unavailableTools + .map((tool: any) => Number(tool?.id)) + .filter((id: number) => !Number.isNaN(id)) + .join(", "); + message.warning( + t("agentConfig.agents.copyUnavailableTools", { + count: unavailableTools.length, + names, + }) + ); + } + // Auto select the newly copied agent for editing + await handleEditAgent(copiedAgentFallback, t); + } catch (error) { + log.error("Failed to copy agent:", error); + message.error(t("agentConfig.agents.copyFailed")); + } + }; + + const handleCopyAgentWithConfirm = (agent: Agent) => { + confirm({ + title: t("agentConfig.agents.copyConfirmTitle"), + content: t("agentConfig.agents.copyConfirmContent", { + name: agent?.display_name || agent?.name || "", + }), + onConfirm: () => handleCopyAgentFromList(agent), + }); + }; + // Handle delete agent from list const handleDeleteAgentFromList = (agent: Agent) => { setAgentToDelete(agent); @@ -1961,11 +2110,24 @@ export default function AgentSetupOrchestrator({ return ( - - {/* Lower part: Agent pool + Agent capability configuration + System Prompt */} - - {/* Left column: Always show SubAgentPool - Equal flex width */} - + + {/* Three-column layout using Ant Design Grid */} + + {/* Left column: SubAgentPool */} + handleEditAgent(agent, t)} onCreateNewAgent={() => confirmOrRun(handleCreateNewAgent)} @@ -1977,6 +2139,7 @@ export default function AgentSetupOrchestrator({ isGeneratingAgent={isGeneratingAgent} editingAgent={editingAgent} isCreatingNewAgent={isCreatingNewAgent} + onCopyAgent={handleCopyAgentWithConfirm} onExportAgent={handleExportAgentFromList} onDeleteAgent={handleDeleteAgentFromList} unsavedAgentId={ @@ -1985,29 +2148,36 @@ export default function AgentSetupOrchestrator({ : null } /> - - - {/* Middle column: Agent capability configuration - Equal flex width */} - + + + {/* Middle column: Agent capability configuration */} + {/* Header: Configure Agent Capabilities */} 2 - + {t("businessLogic.config.title")} - {/* Content: ScrollArea with two sections */} - + {/* Content: Two sections */} + {/* Upper section: Collaborative Agent Display - fixed area */} - - - {/* Right column: System Prompt Display - Equal flex width */} - + + + {/* Right column: System Prompt Display */} + confirmOrRun(() => onDebug()) : () => {}} agentId={ @@ -2099,6 +2277,8 @@ export default function AgentSetupOrchestrator({ onAgentDescriptionChange={setAgentDescription} agentDisplayName={agentDisplayName} onAgentDisplayNameChange={setAgentDisplayName} + agentAuthor={agentAuthor} + onAgentAuthorChange={setAgentAuthor} isEditingMode={isEditingAgent || isCreatingNewAgent} mainAgentModel={mainAgentModel ?? undefined} mainAgentModelId={mainAgentModelId} @@ -2117,14 +2297,13 @@ export default function AgentSetupOrchestrator({ isCreatingNewAgent={isCreatingNewAgent} canSaveAgent={localCanSaveAgent} getButtonTitle={getLocalButtonTitle} - onExportAgent={onExportAgent || (() => {})} onDeleteAgent={onDeleteAgent || (() => {})} onDeleteSuccess={handleExitEdit} editingAgent={editingAgentFromParent || editingAgent} onViewCallRelationship={handleViewCallRelationship} /> - - + + {/* Delete confirmation popup */} { if (!message.steps || message.steps.length === 0) return []; - const taskMsgs: TaskMessageType[] = []; - message.steps.forEach((step) => { - // Process step.contents - if (step.contents && step.contents.length > 0) { - step.contents.forEach((content) => { - taskMsgs.push({ - id: content.id, - role: ROLE_ASSISTANT, - content: content.content, - timestamp: new Date(), - type: content.type, - // Preserve subType so TaskWindow can style deep thinking text - subType: content.subType as any, - } as any); - }); - } - - // Process step.thinking - if (step.thinking && step.thinking.content) { - taskMsgs.push({ - id: `thinking-${step.id}`, - role: ROLE_ASSISTANT, - content: step.thinking.content, - timestamp: new Date(), - type: "model_output_thinking", - }); - } - - // Process step.code - if (step.code && step.code.content) { - taskMsgs.push({ - id: `code-${step.id}`, - role: ROLE_ASSISTANT, - content: step.code.content, - timestamp: new Date(), - type: "model_output_code", - }); - } - - // Process step.output - if (step.output && step.output.content) { - taskMsgs.push({ - id: `output-${step.id}`, - role: ROLE_ASSISTANT, - content: step.output.content, - timestamp: new Date(), - type: "tool", - }); - } - }); + // Use unified message transformer with includeCode: true for debug mode + const { taskMessages } = transformMessagesToTaskMessages( + [message], + { includeCode: true } + ); - return taskMsgs; + return taskMessages; }; return ( diff --git a/frontend/app/[locale]/agents/components/PromptManager.tsx b/frontend/app/[locale]/agents/components/PromptManager.tsx index a9f57dc14..f7cf9aea3 100644 --- a/frontend/app/[locale]/agents/components/PromptManager.tsx +++ b/frontend/app/[locale]/agents/components/PromptManager.tsx @@ -222,12 +222,13 @@ export interface PromptManagerProps { onAgentNameChange?: (name: string) => void; onAgentDescriptionChange?: (description: string) => void; onAgentDisplayNameChange?: (displayName: string) => void; + agentAuthor?: string; + onAgentAuthorChange?: (author: string) => void; onModelChange?: (value: string, modelId?: number) => void; onMaxStepChange?: (value: number | null) => void; onGenerateAgent?: (model: ModelOption) => void; onSaveAgent?: () => void; onDebug?: () => void; - onExportAgent?: () => void; onDeleteAgent?: () => void; onDeleteSuccess?: () => void; getButtonTitle?: () => string; @@ -251,6 +252,8 @@ export default function PromptManager({ agentName = "", agentDescription = "", agentDisplayName = "", + agentAuthor = "", + onAgentAuthorChange, mainAgentModel = "", mainAgentModelId = null, mainAgentMaxStep = 5, @@ -273,7 +276,6 @@ export default function PromptManager({ onGenerateAgent, onSaveAgent, onDebug, - onExportAgent, onDeleteAgent, onDeleteSuccess, getButtonTitle, @@ -678,6 +680,8 @@ export default function PromptManager({ onAgentDescriptionChange={onAgentDescriptionChange} agentDisplayName={agentDisplayName} onAgentDisplayNameChange={onAgentDisplayNameChange} + agentAuthor={agentAuthor} + onAgentAuthorChange={onAgentAuthorChange} isEditingMode={isEditingMode} mainAgentModel={mainAgentModel} mainAgentModelId={mainAgentModelId} @@ -688,7 +692,6 @@ export default function PromptManager({ onExpandCard={handleExpandCard} isGeneratingAgent={isGeneratingAgent} onDebug={onDebug} - onExportAgent={onExportAgent} onDeleteAgent={onDeleteAgent} onDeleteSuccess={onDeleteSuccess} onSaveAgent={onSaveAgent} diff --git a/frontend/app/[locale]/agents/components/agent/AgentConfigModal.tsx b/frontend/app/[locale]/agents/components/agent/AgentConfigModal.tsx index 7e465d1d1..a987dfaee 100644 --- a/frontend/app/[locale]/agents/components/agent/AgentConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agent/AgentConfigModal.tsx @@ -17,6 +17,7 @@ import { checkAgentDisplayName, } from "@/services/agentConfigService"; import { NAME_CHECK_STATUS } from "@/const/agentConfig"; +import { useAuth } from "@/hooks/useAuth"; import { SimplePromptEditor } from "../PromptManager"; @@ -34,6 +35,8 @@ export interface AgentConfigModalProps { onAgentDescriptionChange?: (description: string) => void; agentDisplayName?: string; onAgentDisplayNameChange?: (displayName: string) => void; + agentAuthor?: string; + onAgentAuthorChange?: (author: string) => void; isEditingMode?: boolean; mainAgentModel?: string; mainAgentModelId?: number | null; @@ -45,7 +48,6 @@ export interface AgentConfigModalProps { isGeneratingAgent?: boolean; // Add new props for action buttons onDebug?: () => void; - onExportAgent?: () => void; onDeleteAgent?: () => void; onDeleteSuccess?: () => void; // New prop for handling delete success onSaveAgent?: () => void; @@ -70,6 +72,8 @@ export default function AgentConfigModal({ onAgentDescriptionChange, agentDisplayName = "", onAgentDisplayNameChange, + agentAuthor = "", + onAgentAuthorChange, isEditingMode = false, mainAgentModel = "", mainAgentModelId = null, @@ -80,7 +84,6 @@ export default function AgentConfigModal({ isGeneratingAgent = false, // Add new props for action buttons onDebug, - onExportAgent, onDeleteAgent, onDeleteSuccess, onSaveAgent, @@ -90,6 +93,7 @@ export default function AgentConfigModal({ getButtonTitle, }: AgentConfigModalProps) { const { t } = useTranslation("common"); + const { user, isSpeedMode } = useAuth(); // Add local state to track content of three sections const [localDutyContent, setLocalDutyContent] = useState(dutyContent || ""); @@ -186,6 +190,13 @@ export default function AgentConfigModal({ loadLLMModels(); }, []); + // Set default author for new agents in Full mode + useEffect(() => { + if (isCreatingNewAgent && !isSpeedMode && !agentAuthor && user?.email) { + onAgentAuthorChange?.(user.email); + } + }, [isCreatingNewAgent, isSpeedMode, agentAuthor, user?.email, onAgentAuthorChange]); + // Default to globally configured model when creating a new agent // IMPORTANT: Only read from localStorage when creating a NEW agent, not when editing existing agent useEffect(() => { @@ -472,11 +483,6 @@ export default function AgentConfigModal({ onDeleteSuccess?.(); }, [onDeleteAgent, onDeleteSuccess]); - // Handle delete button click - const handleDeleteClick = useCallback(() => { - setIsDeleteModalVisible(true); - }, []); - // Optimized click handlers using useCallback const handleSegmentClick = useCallback((segment: string) => { setActiveSegment(segment); @@ -650,6 +656,27 @@ export default function AgentConfigModal({ )} + {/* Agent Author */} + + + {t("agent.author")}: + + { + onAgentAuthorChange?.(e.target.value); + }} + placeholder={t("agent.authorPlaceholder")} + size="large" + disabled={!isEditingMode} + /> + {isCreatingNewAgent && !isSpeedMode && !agentAuthor && user?.email && ( + + {t("agent.author.hint", { defaultValue: "Default: {{email}}", email: user.email })} + + )} + + {/* Model Selection */} @@ -741,6 +768,7 @@ export default function AgentConfigModal({ { setLocalDutyContent(value); // Immediate update to parent component @@ -758,6 +786,7 @@ export default function AgentConfigModal({ { setLocalConstraintContent(value); // Immediate update to parent component @@ -775,6 +804,7 @@ export default function AgentConfigModal({ { setLocalFewShotsContent(value); // Immediate update to parent component diff --git a/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx b/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx index 88b8594bb..e2fc3ef60 100644 --- a/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx +++ b/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx @@ -3,9 +3,9 @@ import { useState } from "react"; import { useTranslation } from "react-i18next"; -import { Button } from "antd"; +import { Button, Row, Col } from "antd"; import { ExclamationCircleOutlined } from "@ant-design/icons"; -import { FileOutput, Network, FileInput, Trash2, Plus, X } from "lucide-react"; +import { Copy, FileOutput, Network, FileInput, Trash2, Plus, X } from "lucide-react"; import { ScrollArea } from "@/components/ui/scrollArea"; import { @@ -37,6 +37,7 @@ export default function SubAgentPool({ isGeneratingAgent = false, editingAgent = null, isCreatingNewAgent = false, + onCopyAgent, onExportAgent, onDeleteAgent, unsavedAgentId = null, @@ -140,8 +141,8 @@ export default function SubAgentPool({ } } `} - - + + 1 @@ -158,19 +159,23 @@ export default function SubAgentPool({ )} - + {/* Function operation block */} - - - - + + + + { if (isCreatingNewAgent) { // If currently in creation mode, click to exit creation mode @@ -217,24 +222,26 @@ export default function SubAgentPool({ : t("subAgentPool.description.createAgent")} + - - - - {isCreatingNewAgent - ? t("subAgentPool.tooltip.exitCreateMode") - : t("subAgentPool.tooltip.createNewAgent")} - - + + + {isCreatingNewAgent + ? t("subAgentPool.tooltip.exitCreateMode") + : t("subAgentPool.tooltip.createNewAgent")} + + + - - - + + + - + - - - - {isImporting - ? t("subAgentPool.description.importing") - : t("subAgentPool.description.importAgent")} - - - + + + {isImporting + ? t("subAgentPool.description.importing") + : t("subAgentPool.description.importAgent")} + + + + {/* Agent list block */} @@ -353,6 +361,27 @@ export default function SubAgentPool({ {/* Operation button area */} + {/* Copy agent button */} + {onCopyAgent && ( + + + } + onClick={(e) => { + e.preventDefault(); + e.stopPropagation(); + onCopyAgent(agent); + }} + className="agent-action-button agent-action-button-blue" + /> + + + {t("agent.contextMenu.copy")} + + + )} {/* View call relationship button */} @@ -379,7 +408,7 @@ export default function SubAgentPool({ } + icon={} onClick={(e) => { e.preventDefault(); e.stopPropagation(); diff --git a/frontend/app/[locale]/agents/components/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/tool/ToolConfigModal.tsx index d1160b722..3a8947b93 100644 --- a/frontend/app/[locale]/agents/components/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/tool/ToolConfigModal.tsx @@ -9,32 +9,18 @@ import { InputNumber, Tag, App, - Button, - Card, - Typography, - Tooltip, } from "antd"; -import { - CloseOutlined, - SettingOutlined, - EditOutlined, -} from "@ant-design/icons"; import { TOOL_PARAM_TYPES } from "@/const/agentConfig"; import { ToolParam, ToolConfigModalProps } from "@/types/agentConfig"; - -const { Text, Title } = Typography; import { updateToolConfig, searchToolConfig, loadLastToolConfig, - validateTool, - parseToolInputs, - extractParameterNames, } from "@/services/agentConfigService"; import log from "@/lib/logger"; import { useModalPosition } from "@/hooks/useModalPosition"; -import { DEFAULT_TYPE } from "@/const/constants"; +import ToolTestPanel from "./ToolTestPanel"; export default function ToolConfigModal({ isOpen, @@ -50,16 +36,8 @@ export default function ToolConfigModal({ const { t } = useTranslation("common"); const { message } = App.useApp(); - // Tool test related state + // Tool test panel visibility state const [testPanelVisible, setTestPanelVisible] = useState(false); - const [testExecuting, setTestExecuting] = useState(false); - const [testResult, setTestResult] = useState(""); - const [parsedInputs, setParsedInputs] = useState>({}); - const [paramValues, setParamValues] = useState>({}); - const [dynamicInputParams, setDynamicInputParams] = useState([]); - const [isManualInputMode, setIsManualInputMode] = useState(false); - const [manualJsonInput, setManualJsonInput] = useState(""); - const [isParseSuccessful, setIsParseSuccessful] = useState(false); const { windowWidth, mainModalTop, mainModalRight } = useModalPosition(isOpen); @@ -70,6 +48,48 @@ export default function ToolConfigModal({ const canPersistToolConfig = typeof normalizedAgentId === "number" && normalizedAgentId > 0; + // Apply transform to modal when test panel is visible + // Move main modal to the left to center both panels together + useEffect(() => { + if (!isOpen) return; + + const testPanelWidth = 500; + const gap = windowWidth * 0.05; + // Move left by half of (test panel width + gap) to center both panels + const offsetX = testPanelVisible + ? -(testPanelWidth + gap) / 2 + : 0; + + // Find the modal wrap element (Ant Design renders Modal in a wrap container) + // Use a small delay to ensure Modal is rendered + const timer = setTimeout(() => { + const modalContent = document.querySelector( + ".tool-config-modal-content" + ); + if (modalContent) { + const modalWrap = modalContent.closest(".ant-modal-wrap") as HTMLElement; + if (modalWrap) { + modalWrap.style.transform = `translateX(${offsetX}px)`; + modalWrap.style.transition = "transform 0.3s ease-in-out"; + } + } + }, 0); + + return () => { + clearTimeout(timer); + const modalContent = document.querySelector( + ".tool-config-modal-content" + ); + if (modalContent) { + const modalWrap = modalContent.closest(".ant-modal-wrap") as HTMLElement; + if (modalWrap) { + modalWrap.style.transform = ""; + modalWrap.style.transition = ""; + } + } + }; + }, [testPanelVisible, isOpen, windowWidth]); + // load tool config useEffect(() => { const buildDefaultParams = () => @@ -235,180 +255,15 @@ export default function ToolConfigModal({ } }; - // Handle tool testing + // Handle tool testing - open test panel const handleTestTool = () => { if (!tool) return; - setTestResult(""); - // Parse inputs definition from tool inputs field - try { - const parsedInputs = parseToolInputs(tool.inputs || ""); - const paramNames = extractParameterNames(parsedInputs); - - // Check if parsing was successful (not empty object) - const isSuccessful = Object.keys(parsedInputs).length > 0; - setIsParseSuccessful(isSuccessful); - if (isSuccessful) { - setParsedInputs(parsedInputs); - setDynamicInputParams(paramNames); - - // Initialize parameter values with appropriate defaults based on type - const initialValues: Record = {}; - paramNames.forEach((paramName) => { - const paramInfo = parsedInputs[paramName]; - const paramType = paramInfo?.type || DEFAULT_TYPE; - - if ( - paramInfo && - typeof paramInfo === "object" && - paramInfo.default != null - ) { - // Use provided default value, convert to string for UI display - switch (paramType) { - case "boolean": - initialValues[paramName] = paramInfo.default ? "true" : "false"; - break; - case "array": - case "object": - // JSON.stringify with indentation of 2 spaces for better readability - initialValues[paramName] = JSON.stringify( - paramInfo.default, - null, - 2 - ); - break; - default: - initialValues[paramName] = String(paramInfo.default); - } - } - }); - setParamValues(initialValues); - // Reset to parsed mode when parsing succeeds - setIsManualInputMode(false); - setManualJsonInput(""); - } else { - // Parsing returned empty object, treat as failed - setParsedInputs({}); - setParamValues({}); - setDynamicInputParams([]); - setIsManualInputMode(true); - setManualJsonInput("{}"); - } - } catch (error) { - log.error("Parameter parsing error:", error); - setParsedInputs({}); - setParamValues({}); - setDynamicInputParams([]); - setIsParseSuccessful(false); - // When parsing fails, automatically switch to manual input mode - setIsManualInputMode(true); - setManualJsonInput("{}"); - } - setTestPanelVisible(true); }; // Close test panel - const closeTestPanel = () => { + const handleCloseTestPanel = () => { setTestPanelVisible(false); - setTestResult(""); - setParsedInputs({}); - setParamValues({}); - setDynamicInputParams([]); - setTestExecuting(false); - setIsManualInputMode(false); - setManualJsonInput(""); - setIsParseSuccessful(false); - }; - - // Execute tool test - const executeTest = async () => { - if (!tool) return; - - setTestExecuting(true); - - try { - // Prepare parameters for tool validation with correct types - const toolParams: Record = {}; - - if (isManualInputMode) { - // Use manual JSON input - try { - const manualParams = JSON.parse(manualJsonInput); - Object.assign(toolParams, manualParams); - } catch (error) { - log.error("Failed to parse manual JSON input:", error); - setTestResult(`Test failed: Invalid JSON format in manual input`); - return; - } - } else { - // Use parsed parameters - dynamicInputParams.forEach((paramName) => { - const value = paramValues[paramName]; - const paramInfo = parsedInputs[paramName]; - const paramType = paramInfo?.type || DEFAULT_TYPE; - - if (value && value.trim() !== "") { - // Convert value to correct type based on parameter type from inputs - switch (paramType) { - case "integer": - case "number": - const numValue = Number(value.trim()); - if (!isNaN(numValue)) { - toolParams[paramName] = numValue; - } else { - toolParams[paramName] = value.trim(); // fallback to string if conversion fails - } - break; - case "boolean": - toolParams[paramName] = value.trim().toLowerCase() === "true"; - break; - case "array": - case "object": - try { - toolParams[paramName] = JSON.parse(value.trim()); - } catch { - toolParams[paramName] = value.trim(); // fallback to string if JSON parsing fails - } - break; - default: - toolParams[paramName] = value.trim(); - } - } - }); - } - - // Prepare configuration parameters from current params - const configParams = currentParams.reduce((acc, param) => { - acc[param.name] = param.value; - return acc; - }, {} as Record); - - // Call validateTool with parameters - const result = await validateTool( - tool.origin_name || tool.name, - tool.source, // Tool source - tool.usage || "", // Tool usage - toolParams, // tool input parameters - configParams // tool configuration parameters - ); - - // Format the JSON string response - let formattedResult: string; - try { - const parsedResult = - typeof result === "string" ? JSON.parse(result) : result; - formattedResult = JSON.stringify(parsedResult, null, 2); - } catch (parseError) { - log.error("Failed to parse JSON result:", parseError); - formattedResult = typeof result === "string" ? result : String(result); - } - setTestResult(formattedResult); - } catch (error) { - log.error("Tool test execution failed:", error); - setTestResult(`Test failed: ${error}`); - } finally { - setTestExecuting(false); - } }; const renderParamInput = (param: ToolParam, index: number) => { @@ -544,6 +399,7 @@ export default function ToolConfigModal({ cancelText={t("common.button.cancel")} width={600} confirmLoading={isLoading} + className="tool-config-modal-content" footer={ {isEditingMode && ( @@ -617,339 +473,16 @@ export default function ToolConfigModal({ {/* Tool Test Panel */} - {testPanelVisible && ( - <> - {/* Backdrop */} - - - {/* Test Panel */} - 0 ? `${mainModalTop}px` : "10vh", // Align with main modal top or fallback to 10vh - left: - mainModalRight > 0 - ? `${mainModalRight + windowWidth * 0.05}px` - : "calc(50% + 300px + 5vw)", // Position to the right of main modal with 5% viewport width gap - width: "500px", - height: "auto", - maxHeight: "80vh", - overflowY: "auto", - backgroundColor: "#fff", - border: "1px solid #d9d9d9", - borderRadius: "8px", - boxShadow: "0 4px 12px rgba(0, 0, 0, 0.15)", - zIndex: 1001, - display: "flex", - flexDirection: "column", - }} - > - {/* Test panel header */} - - - - {tool?.name} - - - } - onClick={closeTestPanel} - size="small" - /> - - - {/* Test panel content */} - - {t("toolConfig.toolTest.toolInfo")} - - {tool?.description} - - - {/* Test parameter input */} - - {/* Show current form parameters */} - {currentParams.length > 0 && ( - <> - - {t("toolConfig.toolTest.configParams")} - - - {currentParams.map((param) => ( - - {param.name} - - - - - ))} - - > - )} - - {/* Input parameters section with conditional toggle */} - {(dynamicInputParams.length > 0 || isManualInputMode) && ( - <> - - {t("toolConfig.toolTest.inputParams")} - {/* Only show toggle button if parsing was successful */} - {isParseSuccessful && ( - - ) : ( - - ) - } - onClick={() => { - setIsManualInputMode(!isManualInputMode); - if (!isManualInputMode) { - const currentParamsJson: Record = {}; - dynamicInputParams.forEach((paramName) => { - const value = paramValues[paramName]; - if (value && value.trim() !== "") { - const paramInfo = parsedInputs[paramName]; - const paramType = paramInfo?.type || DEFAULT_TYPE; - - try { - switch (paramType) { - case "integer": - case "number": - currentParamsJson[paramName] = Number( - value.trim() - ); - break; - case "boolean": - currentParamsJson[paramName] = - value.trim().toLowerCase() === "true"; - break; - case "array": - case "object": - currentParamsJson[paramName] = - JSON.parse(value.trim()); - break; - default: - currentParamsJson[paramName] = - value.trim(); - } - } catch { - currentParamsJson[paramName] = value.trim(); - } - } - }); - setManualJsonInput( - JSON.stringify(currentParamsJson, null, 2) - ); - } else { - // From manual input mode to parsed mode - try { - const manualParams = - JSON.parse(manualJsonInput); - const updatedParamValues: Record< - string, - string - > = {}; - dynamicInputParams.forEach((paramName) => { - const manualValue = manualParams[paramName]; - const paramInfo = parsedInputs[paramName]; - const paramType = - paramInfo?.type || DEFAULT_TYPE; - - if (manualValue !== undefined) { - // Convert to string for display based on parameter type - switch (paramType) { - case "boolean": - updatedParamValues[paramName] = - manualValue ? "true" : "false"; - break; - case "array": - case "object": - updatedParamValues[paramName] = - JSON.stringify(manualValue, null, 2); - break; - default: - updatedParamValues[paramName] = - String(manualValue); - } - } - }); - setParamValues(updatedParamValues); - } catch (error) { - log.error( - "Failed to sync manual input to parsed mode:", - error - ); - } - } - }} - > - {isManualInputMode - ? t("toolConfig.toolTest.parseMode") - : t("toolConfig.toolTest.manualInput")} - - )} - - - {isManualInputMode ? ( - // Manual JSON input mode - - setManualJsonInput(e.target.value)} - rows={6} - style={{ fontFamily: "monospace" }} - /> - - ) : ( - // Parsed parameters mode - dynamicInputParams.length > 0 && ( - - {dynamicInputParams.map((paramName) => { - const paramInfo = parsedInputs[paramName]; - const description = - paramInfo && - typeof paramInfo === "object" && - paramInfo.description - ? paramInfo.description - : paramName; - - return ( - - - {paramName} - - - { - setParamValues((prev) => ({ - ...prev, - [paramName]: e.target.value, - })); - }} - style={{ flex: 1 }} - /> - - - ); - })} - - ) - )} - > - )} - - - {testExecuting - ? t("toolConfig.toolTest.executing") - : t("toolConfig.toolTest.execute")} - - - - {/* Test result */} - - - {t("toolConfig.toolTest.result")} - - - - - - > - )} + setTestPanelVisible(visible)} + /> > ); } diff --git a/frontend/app/[locale]/agents/components/tool/ToolPool.tsx b/frontend/app/[locale]/agents/components/tool/ToolPool.tsx index 62b764cac..2460a8f07 100644 --- a/frontend/app/[locale]/agents/components/tool/ToolPool.tsx +++ b/frontend/app/[locale]/agents/components/tool/ToolPool.tsx @@ -107,23 +107,18 @@ function ToolPool({ // Group by source and usage availableTools.forEach((tool) => { let groupKey: string; - let groupLabel: string; if (tool.source === TOOL_SOURCE_TYPES.MCP) { // MCP tools grouped by usage const usage = tool.usage || TOOL_SOURCE_TYPES.OTHER; groupKey = `mcp-${usage}`; - groupLabel = usage; } else if (tool.source === TOOL_SOURCE_TYPES.LOCAL) { groupKey = TOOL_SOURCE_TYPES.LOCAL; - groupLabel = t("toolPool.group.local"); } else if (tool.source === TOOL_SOURCE_TYPES.LANGCHAIN) { groupKey = TOOL_SOURCE_TYPES.LANGCHAIN; - groupLabel = t("toolPool.group.langchain"); } else { // Other types groupKey = tool.source || TOOL_SOURCE_TYPES.OTHER; - groupLabel = tool.source || t("toolPool.group.other"); } if (!groupMap.has(groupKey)) { @@ -703,7 +698,7 @@ function ToolPool({ {t("toolPool.error.unavailableSelected")} )} - + {loadingTools ? ( {t("toolPool.loadingTools")} diff --git a/frontend/app/[locale]/agents/components/tool/ToolTestPanel.tsx b/frontend/app/[locale]/agents/components/tool/ToolTestPanel.tsx new file mode 100644 index 000000000..fd91a0e24 --- /dev/null +++ b/frontend/app/[locale]/agents/components/tool/ToolTestPanel.tsx @@ -0,0 +1,608 @@ +"use client"; + +import { useState, useEffect } from "react"; +import { useTranslation } from "react-i18next"; +import { motion, AnimatePresence } from "framer-motion"; +import { + Input, + Button, + Card, + Typography, + Tooltip, +} from "antd"; +import { + CloseOutlined, + SettingOutlined, + EditOutlined, +} from "@ant-design/icons"; + +import { ToolParam, Tool } from "@/types/agentConfig"; +import { + validateTool, + parseToolInputs, + extractParameterNames, +} from "@/services/agentConfigService"; +import log from "@/lib/logger"; +import { DEFAULT_TYPE } from "@/const/constants"; + +const { Text, Title } = Typography; + +export interface ToolTestPanelProps { + /** Whether the test panel is visible */ + visible: boolean; + /** Tool to test */ + tool: Tool | null; + /** Current configuration parameters */ + currentParams: ToolParam[]; + /** Main modal top position */ + mainModalTop: number; + /** Main modal right position */ + mainModalRight: number; + /** Window width for position calculation */ + windowWidth: number; + /** Callback when panel is closed */ + onClose: () => void; + /** Callback when panel visibility changes (for parent modal positioning) */ + onVisibilityChange?: (visible: boolean) => void; +} + +export default function ToolTestPanel({ + visible, + tool, + currentParams, + mainModalTop, + mainModalRight, + windowWidth, + onClose, + onVisibilityChange, +}: ToolTestPanelProps) { + const { t } = useTranslation("common"); + + // Tool test related state + const [testExecuting, setTestExecuting] = useState(false); + const [testResult, setTestResult] = useState(""); + const [parsedInputs, setParsedInputs] = useState>({}); + const [paramValues, setParamValues] = useState>({}); + const [dynamicInputParams, setDynamicInputParams] = useState([]); + const [isManualInputMode, setIsManualInputMode] = useState(false); + const [manualJsonInput, setManualJsonInput] = useState(""); + const [isParseSuccessful, setIsParseSuccessful] = useState(false); + + // Notify parent when visibility changes + useEffect(() => { + onVisibilityChange?.(visible); + }, [visible, onVisibilityChange]); + + // Initialize test panel when opened + useEffect(() => { + if (!visible || !tool) { + // Reset state when closed + setTestResult(""); + setParsedInputs({}); + setParamValues({}); + setDynamicInputParams([]); + setTestExecuting(false); + setIsManualInputMode(false); + setManualJsonInput(""); + setIsParseSuccessful(false); + return; + } + + // Parse inputs definition from tool inputs field + try { + const parsedInputs = parseToolInputs(tool.inputs || ""); + const paramNames = extractParameterNames(parsedInputs); + + // Check if parsing was successful (not empty object) + const isSuccessful = Object.keys(parsedInputs).length > 0; + setIsParseSuccessful(isSuccessful); + if (isSuccessful) { + setParsedInputs(parsedInputs); + setDynamicInputParams(paramNames); + + // Initialize parameter values with appropriate defaults based on type + const initialValues: Record = {}; + paramNames.forEach((paramName) => { + const paramInfo = parsedInputs[paramName]; + const paramType = paramInfo?.type || DEFAULT_TYPE; + + if ( + paramInfo && + typeof paramInfo === "object" && + paramInfo.default != null + ) { + // Use provided default value, convert to string for UI display + switch (paramType) { + case "boolean": + initialValues[paramName] = paramInfo.default ? "true" : "false"; + break; + case "array": + case "object": + // JSON.stringify with indentation of 2 spaces for better readability + initialValues[paramName] = JSON.stringify( + paramInfo.default, + null, + 2 + ); + break; + default: + initialValues[paramName] = String(paramInfo.default); + } + } + }); + setParamValues(initialValues); + // Reset to parsed mode when parsing succeeds + setIsManualInputMode(false); + setManualJsonInput(""); + } else { + // Parsing returned empty object, treat as failed + setParsedInputs({}); + setParamValues({}); + setDynamicInputParams([]); + setIsManualInputMode(true); + setManualJsonInput("{}"); + } + } catch (error) { + log.error("Parameter parsing error:", error); + setParsedInputs({}); + setParamValues({}); + setDynamicInputParams([]); + setIsParseSuccessful(false); + // When parsing fails, automatically switch to manual input mode + setIsManualInputMode(true); + setManualJsonInput("{}"); + } + }, [visible, tool]); + + // Close test panel + const handleClose = () => { + onClose(); + }; + + // Execute tool test + const executeTest = async () => { + if (!tool) return; + + setTestExecuting(true); + + try { + // Prepare parameters for tool validation with correct types + const toolParams: Record = {}; + + if (isManualInputMode) { + // Use manual JSON input + try { + const manualParams = JSON.parse(manualJsonInput); + Object.assign(toolParams, manualParams); + } catch (error) { + log.error("Failed to parse manual JSON input:", error); + setTestResult(`Test failed: Invalid JSON format in manual input`); + return; + } + } else { + // Use parsed parameters + dynamicInputParams.forEach((paramName) => { + const value = paramValues[paramName]; + const paramInfo = parsedInputs[paramName]; + const paramType = paramInfo?.type || DEFAULT_TYPE; + + if (value && value.trim() !== "") { + // Convert value to correct type based on parameter type from inputs + switch (paramType) { + case "integer": + case "number": + const numValue = Number(value.trim()); + if (!isNaN(numValue)) { + toolParams[paramName] = numValue; + } else { + toolParams[paramName] = value.trim(); // fallback to string if conversion fails + } + break; + case "boolean": + toolParams[paramName] = value.trim().toLowerCase() === "true"; + break; + case "array": + case "object": + try { + toolParams[paramName] = JSON.parse(value.trim()); + } catch { + toolParams[paramName] = value.trim(); // fallback to string if JSON parsing fails + } + break; + default: + toolParams[paramName] = value.trim(); + } + } + }); + } + + // Prepare configuration parameters from current params + const configParams = currentParams.reduce((acc, param) => { + acc[param.name] = param.value; + return acc; + }, {} as Record); + + // Call validateTool with parameters + const result = await validateTool( + tool.origin_name || tool.name, + tool.source, // Tool source + tool.usage || "", // Tool usage + toolParams, // tool input parameters + configParams // tool configuration parameters + ); + + // Format the JSON string response + let formattedResult: string; + try { + const parsedResult = + typeof result === "string" ? JSON.parse(result) : result; + formattedResult = JSON.stringify(parsedResult, null, 2); + } catch (parseError) { + log.error("Failed to parse JSON result:", parseError); + formattedResult = typeof result === "string" ? result : String(result); + } + setTestResult(formattedResult); + } catch (error) { + log.error("Tool test execution failed:", error); + setTestResult(`Test failed: ${error}`); + } finally { + setTestExecuting(false); + } + }; + + // Calculate test panel position to center both panels together + const testPanelWidth = 500; + const gap = windowWidth * 0.05; + const offsetForCentering = (testPanelWidth + gap) / 2; + + // Calculate test panel left position + const testPanelLeft = mainModalRight > 0 + ? mainModalRight + gap - offsetForCentering + : windowWidth / 2 + 300 + windowWidth * 0.05 - offsetForCentering; + + if (!tool) return null; + + return ( + + {visible && ( + <> + {/* Backdrop */} + + + {/* Test Panel */} + 0 ? `${mainModalTop}px` : "10vh", // Align with main modal top or fallback to 10vh + left: `${testPanelLeft}px`, // Position adjusted to center both panels together + width: "500px", + height: "auto", + maxHeight: "80vh", + overflowY: "auto", + backgroundColor: "#fff", + border: "1px solid #d9d9d9", + borderRadius: "8px", + boxShadow: "0 4px 12px rgba(0, 0, 0, 0.15)", + zIndex: 1001, + display: "flex", + flexDirection: "column", + }} + > + {/* Test panel header */} + + + + {tool?.name} + + + } + onClick={handleClose} + size="small" + /> + + + {/* Test panel content */} + + {t("toolConfig.toolTest.toolInfo")} + + {tool?.description} + + + {/* Test parameter input */} + + {/* Show current form parameters */} + {currentParams.length > 0 && ( + <> + + {t("toolConfig.toolTest.configParams")} + + + {currentParams.map((param) => ( + + {param.name} + + + + + ))} + + > + )} + + {/* Input parameters section with conditional toggle */} + {(dynamicInputParams.length > 0 || isManualInputMode) && ( + <> + + {t("toolConfig.toolTest.inputParams")} + {/* Only show toggle button if parsing was successful */} + {isParseSuccessful && ( + + ) : ( + + ) + } + onClick={() => { + setIsManualInputMode(!isManualInputMode); + if (!isManualInputMode) { + const currentParamsJson: Record = {}; + dynamicInputParams.forEach((paramName) => { + const value = paramValues[paramName]; + if (value && value.trim() !== "") { + const paramInfo = parsedInputs[paramName]; + const paramType = paramInfo?.type || DEFAULT_TYPE; + + try { + switch (paramType) { + case "integer": + case "number": + currentParamsJson[paramName] = Number( + value.trim() + ); + break; + case "boolean": + currentParamsJson[paramName] = + value.trim().toLowerCase() === "true"; + break; + case "array": + case "object": + currentParamsJson[paramName] = + JSON.parse(value.trim()); + break; + default: + currentParamsJson[paramName] = + value.trim(); + } + } catch { + currentParamsJson[paramName] = value.trim(); + } + } + }); + setManualJsonInput( + JSON.stringify(currentParamsJson, null, 2) + ); + } else { + // From manual input mode to parsed mode + try { + const manualParams = + JSON.parse(manualJsonInput); + const updatedParamValues: Record< + string, + string + > = {}; + dynamicInputParams.forEach((paramName) => { + const manualValue = manualParams[paramName]; + const paramInfo = parsedInputs[paramName]; + const paramType = + paramInfo?.type || DEFAULT_TYPE; + + if (manualValue !== undefined) { + // Convert to string for display based on parameter type + switch (paramType) { + case "boolean": + updatedParamValues[paramName] = + manualValue ? "true" : "false"; + break; + case "array": + case "object": + updatedParamValues[paramName] = + JSON.stringify(manualValue, null, 2); + break; + default: + updatedParamValues[paramName] = + String(manualValue); + } + } + }); + setParamValues(updatedParamValues); + } catch (error) { + log.error( + "Failed to sync manual input to parsed mode:", + error + ); + } + } + }} + > + {isManualInputMode + ? t("toolConfig.toolTest.parseMode") + : t("toolConfig.toolTest.manualInput")} + + )} + + + {isManualInputMode ? ( + // Manual JSON input mode + + setManualJsonInput(e.target.value)} + rows={6} + style={{ fontFamily: "monospace" }} + /> + + ) : ( + // Parsed parameters mode + dynamicInputParams.length > 0 && ( + + {dynamicInputParams.map((paramName) => { + const paramInfo = parsedInputs[paramName]; + const description = + paramInfo && + typeof paramInfo === "object" && + paramInfo.description + ? paramInfo.description + : paramName; + + return ( + + + {paramName} + + + { + setParamValues((prev) => ({ + ...prev, + [paramName]: e.target.value, + })); + }} + style={{ flex: 1 }} + /> + + + ); + })} + + ) + )} + > + )} + + + {testExecuting + ? t("toolConfig.toolTest.executing") + : t("toolConfig.toolTest.execute")} + + + + {/* Test result */} + + + {t("toolConfig.toolTest.result")} + + + + + + > + )} + + ); +} + diff --git a/frontend/app/[locale]/chat/internal/memory/memoryManageModal.tsx b/frontend/app/[locale]/chat/internal/memory/memoryManageModal.tsx index e0639ed28..424bf75f7 100644 --- a/frontend/app/[locale]/chat/internal/memory/memoryManageModal.tsx +++ b/frontend/app/[locale]/chat/internal/memory/memoryManageModal.tsx @@ -149,56 +149,54 @@ const MemoryManageModal: React.FC = ({ if (memory.addingMemoryKey !== groupKey) return null; return ( - - - memory.setNewMemoryContent(e.target.value)} - placeholder={t("memoryManageModal.inputPlaceholder")} - maxLength={500} - showCount - onPressEnter={memory.confirmAddingMemory} - disabled={memory.isAddingMemory} - className="flex-1" - autoSize={{ minRows: 2, maxRows: 5 }} - /> - } - onClick={memory.cancelAddingMemory} - disabled={memory.isAddingMemory} - style={{ - border: "none", - backgroundColor: "transparent", - boxShadow: "none", - }} - /> - } - onClick={memory.confirmAddingMemory} - loading={memory.isAddingMemory} - disabled={!memory.newMemoryContent.trim()} - style={{ - border: "none", - backgroundColor: "transparent", - boxShadow: "none", - }} - /> - - + + memory.setNewMemoryContent(e.target.value)} + placeholder={t("memoryManageModal.inputPlaceholder")} + maxLength={500} + showCount + onPressEnter={memory.confirmAddingMemory} + disabled={memory.isAddingMemory} + className="flex-1" + autoSize={{ minRows: 2, maxRows: 5 }} + /> + } + onClick={memory.cancelAddingMemory} + disabled={memory.isAddingMemory} + style={{ + border: "none", + backgroundColor: "transparent", + boxShadow: "none", + }} + /> + } + onClick={memory.confirmAddingMemory} + loading={memory.isAddingMemory} + disabled={!memory.newMemoryContent.trim()} + style={{ + border: "none", + backgroundColor: "transparent", + boxShadow: "none", + }} + /> + ); }; @@ -263,7 +261,9 @@ const MemoryManageModal: React.FC = ({ shape="round" color="green" title={t("memoryManageModal.addMemory")} - onClick={() => memory.startAddingMemory(g.key)} + onClick={() => { + memory.startAddingMemory(g.key); + }} icon={} className="hover:!bg-green-50" style={{ @@ -303,41 +303,46 @@ const MemoryManageModal: React.FC = ({ ) : ( - ( - } - onClick={() => - memory.handleDeleteMemory(item.id, g.key) - } - />, - ]} - > - - {item.memory} - - + + {memory.addingMemoryKey === g.key && ( + + {renderAddMemoryInput(g.key)} + )} - > - {renderAddMemoryInput(g.key)} - + ( + } + onClick={() => + memory.handleDeleteMemory(item.id, g.key) + } + />, + ]} + > + + {item.memory} + + + )} + /> + )} ))} @@ -435,42 +440,47 @@ const MemoryManageModal: React.FC = ({ ), collapsible: disabled ? "disabled" : undefined, children: ( - ( - } - disabled={disabled} - onClick={() => - memory.handleDeleteMemory(item.id, g.key) - } - />, - ]} - > - - {item.memory} - - + + {memory.addingMemoryKey === g.key && ( + + {renderAddMemoryInput(g.key)} + )} - > - {renderAddMemoryInput(g.key)} - + ( + } + disabled={disabled} + onClick={() => + memory.handleDeleteMemory(item.id, g.key) + } + />, + ]} + > + + {item.memory} + + + )} + /> + ), showArrow: true, className: "memory-modal-panel", diff --git a/frontend/app/[locale]/chat/streaming/chatStreamFinalMessage.tsx b/frontend/app/[locale]/chat/streaming/chatStreamFinalMessage.tsx index fe5bad3be..d8efdb86c 100644 --- a/frontend/app/[locale]/chat/streaming/chatStreamFinalMessage.tsx +++ b/frontend/app/[locale]/chat/streaming/chatStreamFinalMessage.tsx @@ -272,6 +272,9 @@ export function ChatStreamFinalMessage({ content={message.finalAnswer || message.content || ""} searchResults={message?.searchResults} onCitationHover={onCitationHover} + // For historical messages, content already represents the final answer + // when finalAnswer is not present, so enable S3 resolution in both cases. + resolveS3Media={Boolean(message.finalAnswer || message.content)} /> {/* Button group - only show when hideButtons is false and message is complete */} diff --git a/frontend/app/[locale]/chat/streaming/chatStreamHandler.tsx b/frontend/app/[locale]/chat/streaming/chatStreamHandler.tsx index f8e53a46b..2cd790e15 100644 --- a/frontend/app/[locale]/chat/streaming/chatStreamHandler.tsx +++ b/frontend/app/[locale]/chat/streaming/chatStreamHandler.tsx @@ -67,6 +67,7 @@ export const handleStreamResponse = async ( let lastContentType: | typeof chatConfig.contentTypes.MODEL_OUTPUT + | typeof chatConfig.contentTypes.MODEL_OUTPUT_CODE | typeof chatConfig.contentTypes.PARSING | typeof chatConfig.contentTypes.EXECUTION | typeof chatConfig.contentTypes.AGENT_NEW_RUN @@ -77,13 +78,24 @@ export const handleStreamResponse = async ( | typeof chatConfig.contentTypes.PREPROCESS | null = null; let lastModelOutputIndex = -1; // Track the index of the last model output in currentStep.contents + let lastCodeOutputIndex = -1; // Track the index of the last code output for proper streaming let searchResultsContent: any[] = []; let allSearchResults: any[] = []; let finalAnswer = ""; try { while (true) { - const { done, value } = await reader.read(); + let readResult; + try { + readResult = await reader.read(); + } catch (readError: any) { + // If read is aborted, break the loop gracefully + if (readError?.name === "AbortError" || readError?.name === "AbortSignal") { + break; + } + throw readError; + } + const { done, value } = readResult; if (done) break; buffer += decoder.decode(value, { stream: true }); @@ -130,6 +142,7 @@ export const handleStreamResponse = async ( // Reset status tracking variables lastContentType = null; lastModelOutputIndex = -1; + lastCodeOutputIndex = -1; break; @@ -298,70 +311,74 @@ export const handleStreamResponse = async ( } if (isDebug) { - // In debug mode, use streaming output like model_output_thinking - // Ensure contents exists + // In debug mode, use MODEL_OUTPUT_CODE type for streaming output let processedContent = messageContent; - // Check if we should append to existing content or create new - const shouldAppend = - lastContentType === chatConfig.contentTypes.MODEL_OUTPUT && - lastModelOutputIndex >= 0 && - currentStep.contents[lastModelOutputIndex] && - currentStep.contents[lastModelOutputIndex].subType === - "code"; - - if (shouldAppend) { - const modelOutput = - currentStep.contents[lastModelOutputIndex]; + // Check if we should append to existing code content + // Only append if the last content type was MODEL_OUTPUT_CODE and we have a valid index + const shouldAppendCode = + lastContentType === chatConfig.contentTypes.MODEL_OUTPUT_CODE && + lastCodeOutputIndex >= 0 && + currentStep.contents[lastCodeOutputIndex] && + currentStep.contents[lastCodeOutputIndex].type === + chatConfig.messageTypes.MODEL_OUTPUT_CODE; + + if (shouldAppendCode) { + const codeOutput = + currentStep.contents[lastCodeOutputIndex]; const codePrefix = t("chatStreamHandler.codePrefix"); // In append mode, also check for prefix in case it wasn't removed before if ( - modelOutput.content.includes(codePrefix) && + codeOutput.content.includes(codePrefix) && processedContent.trim() ) { // Clean existing content - modelOutput.content = modelOutput.content.replace( - new RegExp(codePrefix + `\\s*`), + codeOutput.content = codeOutput.content.replace( + new RegExp(`^(${codePrefix}|代码|Code)[::]\\s*`, "i"), "" ); } - // Directly append without prefix processing (prefix should have been removed when first created) - let newContent = modelOutput.content + processedContent; - // Remove "(null); + const chatInputRef = useRef(null); const [showScrollButton, setShowScrollButton] = useState(false); const [showTopFade, setShowTopFade] = useState(false); const [autoScroll, setAutoScroll] = useState(true); + const [chatInputHeight, setChatInputHeight] = useState(130); // Default ChatInput height const [processedMessages, setProcessedMessages] = useState( { finalMessages: [], @@ -72,218 +75,58 @@ export function ChatStreamMain({ const lastUserMessageIdRef = useRef(null); const messagesEndRef = useRef(null); + // Monitor ChatInput height changes + useEffect(() => { + const chatInputElement = chatInputRef.current; + if (!chatInputElement) return; + + const resizeObserver = new ResizeObserver((entries) => { + for (const entry of entries) { + const height = entry.contentRect.height; + setChatInputHeight(height); + } + }); + + resizeObserver.observe(chatInputElement); + + // Set initial height + setChatInputHeight(chatInputElement.getBoundingClientRect().height); + + return () => { + resizeObserver.disconnect(); + }; + }, [processedMessages.finalMessages.length]); // Re-observe when messages change (initial vs regular mode) + // Handle message classification useEffect(() => { const finalMsgs: ChatMessageType[] = []; - const taskMsgs: any[] = []; - const conversationGroups = new Map(); - const truncationBuffer = new Map(); // Buffer for truncation messages by user message ID - const processedTruncationIds = new Set(); // Track processed truncation messages to avoid duplicates - - // First preprocess, find all user message IDs and initialize task groups - messages.forEach((message) => { - if (message.role === USER_ROLES.USER && message.id) { - conversationGroups.set(message.id, []); - truncationBuffer.set(message.id, []); // Initialize truncation buffer for each user message - } - }); - - let currentUserMsgId: string | null = null; + + // Track the latest user message ID for scroll behavior + messages.forEach((message) => { + if (message.role === USER_ROLES.USER && message.id) { + lastUserMessageIdRef.current = message.id; + } + }); - // Process all messages, distinguish user messages, final answers, and task messages + // Process all messages, distinguish user messages and final answers messages.forEach((message) => { // User messages are directly added to the final message array if (message.role === USER_ROLES.USER) { finalMsgs.push(message); - // Record the user message ID, used to associate subsequent tasks - if (message.id) { - currentUserMsgId = message.id; - - // Save the latest user message ID to the ref - lastUserMessageIdRef.current = message.id; - } } - // Assistant messages need further processing + // Assistant messages - if there is a final answer or content, add it to the final message array else if (message.role === ROLE_ASSISTANT) { - // If there is a final answer or content (including empty string), add it to the final message array if (message.finalAnswer || message.content !== undefined) { finalMsgs.push(message); - // Do not reset currentUserMsgId here, continue to use it to associate tasks - } - - // Process all steps and content as task messages - if (message.steps && message.steps.length > 0) { - message.steps.forEach((step) => { - // Process step.contents (if it exists) - if (step.contents && step.contents.length > 0) { - step.contents.forEach((content: any) => { - const taskMsg = { - type: content.type, - subType: content.subType, // Preserve subType for styling (e.g., deep_thinking) - content: content.content, - id: content.id, - assistantId: message.id, - relatedUserMsgId: currentUserMsgId, - // For preprocess messages, include the full contents array for TaskWindow - contents: content.type === chatConfig.contentTypes.PREPROCESS ? step.contents : undefined, - }; - - // Handle truncation messages specially - buffer them instead of adding immediately - if (content.type === "truncation") { - // Create a unique ID for this truncation message to avoid duplicates - const truncationId = `${content.filename || 'unknown'}_${content.message || ''}_${currentUserMsgId || 'no_user'}`; - - // Only add if not already processed - if (!processedTruncationIds.has(truncationId) && currentUserMsgId && truncationBuffer.has(currentUserMsgId)) { - const buffer = truncationBuffer.get(currentUserMsgId) || []; - buffer.push(taskMsg); - truncationBuffer.set(currentUserMsgId, buffer); - processedTruncationIds.add(truncationId); - } - } else { - // For non-truncation messages, add them immediately - taskMsgs.push(taskMsg); - - // If there is a related user message, add it to the corresponding task group - if ( - currentUserMsgId && - conversationGroups.has(currentUserMsgId) - ) { - const tasks = conversationGroups.get(currentUserMsgId) || []; - tasks.push(taskMsg); - conversationGroups.set(currentUserMsgId, tasks); - } - } - }); - } - - // Process step.thinking (if it exists) - if (step.thinking && step.thinking.content) { - const taskMsg = { - type: chatConfig.messageTypes.MODEL_OUTPUT_THINKING, - content: step.thinking.content, - id: `thinking-${step.id}`, - assistantId: message.id, - relatedUserMsgId: currentUserMsgId, - }; - taskMsgs.push(taskMsg); - - // If there is a related user message, add it to the corresponding task group - if ( - currentUserMsgId && - conversationGroups.has(currentUserMsgId) - ) { - const tasks = conversationGroups.get(currentUserMsgId) || []; - tasks.push(taskMsg); - conversationGroups.set(currentUserMsgId, tasks); - } - } - - // Process step.code (if it exists) - if (step.code && step.code.content) { - const taskMsg = { - type: chatConfig.messageTypes.MODEL_OUTPUT_CODE, - content: step.code.content, - id: `code-${step.id}`, - assistantId: message.id, - relatedUserMsgId: currentUserMsgId, - }; - taskMsgs.push(taskMsg); - - // If there is a related user message, add it to the corresponding task group - if ( - currentUserMsgId && - conversationGroups.has(currentUserMsgId) - ) { - const tasks = conversationGroups.get(currentUserMsgId) || []; - tasks.push(taskMsg); - conversationGroups.set(currentUserMsgId, tasks); - } - } - - // Process step.output (if it exists) - if (step.output && step.output.content) { - const taskMsg = { - type: chatConfig.messageTypes.TOOL, - content: step.output.content, - id: `output-${step.id}`, - assistantId: message.id, - relatedUserMsgId: currentUserMsgId, - }; - taskMsgs.push(taskMsg); - - // If there is a related user message, add it to the corresponding task group - if ( - currentUserMsgId && - conversationGroups.has(currentUserMsgId) - ) { - const tasks = conversationGroups.get(currentUserMsgId) || []; - tasks.push(taskMsg); - conversationGroups.set(currentUserMsgId, tasks); - } - } - }); } - - // Process thinking status (if it exists) - if (message.thinking && message.thinking.length > 0) { - message.thinking.forEach((thinking, index) => { - const taskMsg = { - type: chatConfig.messageTypes.MODEL_OUTPUT_THINKING, - content: thinking.content, - id: `thinking-${message.id}-${index}`, - assistantId: message.id, - relatedUserMsgId: currentUserMsgId, - }; - taskMsgs.push(taskMsg); - - // If there is a related user message, add it to the corresponding task group - if (currentUserMsgId && conversationGroups.has(currentUserMsgId)) { - const tasks = conversationGroups.get(currentUserMsgId) || []; - tasks.push(taskMsg); - conversationGroups.set(currentUserMsgId, tasks); - } - }); - } - } - }); - - // Process complete messages and release buffered truncation messages - messages.forEach((message) => { - if (message.role === ROLE_ASSISTANT && message.steps) { - message.steps.forEach((step) => { - if (step.contents && step.contents.length > 0) { - step.contents.forEach((content: any) => { - if (content.type === "complete") { - // Find the related user message ID for this complete message - let relatedUserMsgId: string | null = null; - - // Find the user message that this assistant message is responding to - const messageIndex = messages.indexOf(message); - for (let i = messageIndex - 1; i >= 0; i--) { - if (messages[i].role === "user" && messages[i].id) { - relatedUserMsgId = messages[i].id; - break; - } - } - - if (relatedUserMsgId && truncationBuffer.has(relatedUserMsgId)) { - // Clear the buffer for this user message - truncationBuffer.delete(relatedUserMsgId); - } - } - }); - } - }); } }); - // Check and delete empty task groups - for (const [key, value] of conversationGroups.entries()) { - if (value.length === 0) { - conversationGroups.delete(key); - } - } + // Use unified message transformer (includeCode: false for normal chat mode) + const { taskMessages: taskMsgs, conversationGroups } = transformMessagesToTaskMessages( + messages, + { includeCode: false } + ); setProcessedMessages({ finalMessages: finalMsgs, @@ -515,6 +358,7 @@ export function ChatStreamMain({ animate="animate" variants={chatInputVariants} transition={chatInputTransition} + ref={chatInputRef} > )} - {/* Scroll to bottom button */} + {/* Scroll to bottom button - dynamically positioned based on ChatInput height */} {showScrollButton && ( { e.preventDefault(); e.stopPropagation(); @@ -607,6 +456,7 @@ export function ChatStreamMain({ animate="animate" variants={chatInputVariants} transition={chatInputTransition} + ref={chatInputRef} > void; - isSelected?: boolean; - searchResultsCount?: number; - imagesCount?: number; - onImageClick?: (imageUrl: string) => void; - onOpinionChange?: (messageId: number, opinion: Opinion) => void; -} - -export function ChatStreamMessage({ - message, - onSelectMessage, - isSelected = false, - searchResultsCount = 0, - imagesCount = 0, - onImageClick, - onOpinionChange, -}: StreamMessageProps) { - const { t } = useTranslation("common"); - const { getAppAvatarUrl } = useConfig(); - const avatarUrl = getAppAvatarUrl(20); // Message avatar size is 20px - - const messageRef = useRef(null); - const [copied, setCopied] = useState(false); - const [localOpinion, setLocalOpinion] = useState( - message.opinion_flag ?? null - ); - const [isVisible, setIsVisible] = useState(false); - - // Animation effect - message enters with fade-in - useEffect(() => { - const timer = setTimeout(() => { - setIsVisible(true); - }, 10); - return () => clearTimeout(timer); - }, []); - - // When the message is updated, scroll the element into the visible area - useEffect(() => { - if ( - message.role === ROLE_ASSISTANT && - !message.isComplete && - messageRef.current - ) { - messageRef.current.scrollIntoView({ behavior: "smooth", block: "end" }); - } - }, [message.content, message.isComplete, message.role]); - - // Update opinion status - useEffect(() => { - setLocalOpinion(message.opinion_flag ?? null); - }, [message.opinion_flag]); - - // Copy content to clipboard - const handleCopyContent = () => { - const contentToCopy = message.finalAnswer || message.content; - if (!contentToCopy) return; - - copyToClipboard(contentToCopy) - .then(() => { - setCopied(true); - setTimeout(() => setCopied(false), 2000); - }) - .catch((err) => { - log.error(t("chatStreamMessage.copyFailed"), err); - }); - }; - - // Handle likes - const handleThumbsUp = () => { - const newOpinion = localOpinion === chatConfig.opinion.POSITIVE ? null : chatConfig.opinion.POSITIVE; - setLocalOpinion(newOpinion); - if (onOpinionChange && message.message_id) { - onOpinionChange(message.message_id, newOpinion as Opinion); - } - }; - - // Handle dislikes - const handleThumbsDown = () => { - const newOpinion = localOpinion === chatConfig.opinion.NEGATIVE ? null : chatConfig.opinion.NEGATIVE; - setLocalOpinion(newOpinion); - if (onOpinionChange && message.message_id) { - onOpinionChange(message.message_id, newOpinion as Opinion); - } - }; - - // Handle message selection - const handleMessageSelect = () => { - if (message.id && onSelectMessage) { - onSelectMessage(message.id); - } - }; - - return ( - - {/* Avatar section - only show avatar for AI assistant */} - {message.role === ROLE_ASSISTANT && ( - - - - - - )} - - {/* Message content section */} - - {/* User message section */} - {message.role === USER_ROLES.USER && ( - <> - {/* Attachment section - placed above text */} - {message.attachments && message.attachments.length > 0 && ( - - - - - - )} - - {/* Text content */} - {message.content && ( - - - {message.content} - - - )} - > - )} - - {/* Assistant message section */} - {message.role === ROLE_ASSISTANT && ( - <> - {/* Attachment section - placed above text */} - {message.attachments && message.attachments.length > 0 && ( - - - - - - )} - - {/* Text content - streaming rendering area */} - {message.content && ( - - - - )} - - {/* Thinking status */} - {!message.isComplete && - message.thinking && - message.thinking.length > 0 && ( - - - - {message.thinking[0].content} - - - )} - - {/* Final answer */} - {message.finalAnswer && ( - - - - - {t("chatStreamMessage.finalAnswer")} - - - - - - {/* Button group */} - - {/* Source button */} - {((message?.searchResults && - message.searchResults.length > 0) || - (message?.images && message.images.length > 0)) && ( - - - - {searchResultsCount > 0 && - t("chatStreamMessage.sources", { - count: searchResultsCount, - })} - {searchResultsCount > 0 && imagesCount > 0 && ", "} - {imagesCount > 0 && - t("chatStreamMessage.images", { - count: imagesCount, - })} - - - - - )} - - {/* Tool button */} - - - {/* Copy button */} - - - - - - - - - {copied - ? t("chatStreamMessage.copied") - : t("chatStreamMessage.copyContent")} - - - - - {/* Like button */} - - - - - - - - - {localOpinion === chatConfig.opinion.POSITIVE - ? t("chatStreamMessage.cancelLike") - : t("chatStreamMessage.like")} - - - - - {/* Dislike button */} - - - - - - - - - {localOpinion === chatConfig.opinion.NEGATIVE - ? t("chatStreamMessage.cancelDislike") - : t("chatStreamMessage.dislike")} - - - - - {/* Voice announcement button */} - - - - - - - - {t("chatStreamMessage.tts")} - - - - - - - - )} - > - )} - - - ); -} diff --git a/frontend/app/[locale]/chat/streaming/messageTransformer.ts b/frontend/app/[locale]/chat/streaming/messageTransformer.ts new file mode 100644 index 000000000..f579ad1ea --- /dev/null +++ b/frontend/app/[locale]/chat/streaming/messageTransformer.ts @@ -0,0 +1,220 @@ +import { ROLE_ASSISTANT } from "@/const/agentConfig"; +import { chatConfig } from "@/const/chatConfig"; +import { USER_ROLES } from "@/const/modelConfig"; +import { ChatMessageType, TaskMessageType } from "@/types/chat"; + +/** + * Transform chat messages to task messages for TaskWindow rendering + * @param messages - Array of chat messages to transform + * @param options - Configuration options + * @param options.includeCode - Whether to include step.code as separate task messages (for debug mode) + * @returns Array of task messages grouped by user message ID + */ +export function transformMessagesToTaskMessages( + messages: ChatMessageType[], + options: { + includeCode?: boolean; + } = {} +): { + taskMessages: TaskMessageType[]; + conversationGroups: Map; +} { + const { includeCode = false } = options; + const taskMsgs: TaskMessageType[] = []; + const conversationGroups = new Map(); + const truncationBuffer = new Map(); + const processedTruncationIds = new Set(); + + // First preprocess, find all user message IDs and initialize task groups + messages.forEach((message) => { + if (message.role === USER_ROLES.USER && message.id) { + conversationGroups.set(message.id, []); + truncationBuffer.set(message.id, []); + } + }); + + let currentUserMsgId: string | null = null; + + // Process all messages + messages.forEach((message) => { + // User messages - record the ID for associating subsequent tasks + if (message.role === USER_ROLES.USER && message.id) { + currentUserMsgId = message.id; + } + // Assistant messages - extract task messages from steps + else if (message.role === ROLE_ASSISTANT && message.steps && message.steps.length > 0) { + message.steps.forEach((step) => { + // Process step.contents + if (step.contents && step.contents.length > 0) { + step.contents.forEach((content: any) => { + const taskMsg: TaskMessageType = { + id: content.id, + role: ROLE_ASSISTANT, + content: content.content, + timestamp: new Date(), + type: content.type, + subType: content.subType, + // For preprocess messages, include the full contents array for TaskWindow + // For search_content_placeholder messages, include search results from message level + _messageContainer: + content.type === chatConfig.contentTypes.PREPROCESS + ? { contents: step.contents } + : content.type === chatConfig.messageTypes.SEARCH_CONTENT_PLACEHOLDER && message.searchResults + ? { search: message.searchResults } + : undefined, + } as any; + + // Handle truncation messages specially - buffer them instead of adding immediately + if (content.type === "truncation") { + const truncationId = `${content.filename || 'unknown'}_${content.message || ''}_${currentUserMsgId || 'no_user'}`; + if (!processedTruncationIds.has(truncationId) && currentUserMsgId && truncationBuffer.has(currentUserMsgId)) { + const buffer = truncationBuffer.get(currentUserMsgId) || []; + buffer.push(taskMsg); + truncationBuffer.set(currentUserMsgId, buffer); + processedTruncationIds.add(truncationId); + } + } else { + // For non-truncation messages, add them immediately + taskMsgs.push(taskMsg); + + // If there is a related user message, add it to the corresponding task group + if (currentUserMsgId && conversationGroups.has(currentUserMsgId)) { + const tasks = conversationGroups.get(currentUserMsgId) || []; + tasks.push(taskMsg); + conversationGroups.set(currentUserMsgId, tasks); + } + } + }); + } + + // Process step.thinking (if it exists) + if (step.thinking && step.thinking.content) { + const taskMsg: TaskMessageType = { + id: `thinking-${step.id}`, + role: ROLE_ASSISTANT, + content: step.thinking.content, + timestamp: new Date(), + type: chatConfig.messageTypes.MODEL_OUTPUT_THINKING, + } as any; + + taskMsgs.push(taskMsg); + + if (currentUserMsgId && conversationGroups.has(currentUserMsgId)) { + const tasks = conversationGroups.get(currentUserMsgId) || []; + tasks.push(taskMsg); + conversationGroups.set(currentUserMsgId, tasks); + } + } + + // Process step.code (if it exists and includeCode is true) + if (includeCode && step.code && step.code.content) { + const taskMsg: TaskMessageType = { + id: `code-${step.id}`, + role: ROLE_ASSISTANT, + content: step.code.content, + timestamp: new Date(), + type: chatConfig.messageTypes.MODEL_OUTPUT_CODE, + } as any; + + taskMsgs.push(taskMsg); + + if (currentUserMsgId && conversationGroups.has(currentUserMsgId)) { + const tasks = conversationGroups.get(currentUserMsgId) || []; + tasks.push(taskMsg); + conversationGroups.set(currentUserMsgId, tasks); + } + } + + // Process step.output (if it exists) + if (step.output && step.output.content) { + const taskMsg: TaskMessageType = { + id: `output-${step.id}`, + role: ROLE_ASSISTANT, + content: step.output.content, + timestamp: new Date(), + type: chatConfig.messageTypes.TOOL, + } as any; + + taskMsgs.push(taskMsg); + + if (currentUserMsgId && conversationGroups.has(currentUserMsgId)) { + const tasks = conversationGroups.get(currentUserMsgId) || []; + tasks.push(taskMsg); + conversationGroups.set(currentUserMsgId, tasks); + } + } + }); + } + + // Process thinking status (if it exists at message level) + if (message.thinking && message.thinking.length > 0) { + message.thinking.forEach((thinking, index) => { + const taskMsg: TaskMessageType = { + id: `thinking-${message.id}-${index}`, + role: ROLE_ASSISTANT, + content: thinking.content, + timestamp: new Date(), + type: chatConfig.messageTypes.MODEL_OUTPUT_THINKING, + } as any; + + taskMsgs.push(taskMsg); + + if (currentUserMsgId && conversationGroups.has(currentUserMsgId)) { + const tasks = conversationGroups.get(currentUserMsgId) || []; + tasks.push(taskMsg); + conversationGroups.set(currentUserMsgId, tasks); + } + }); + } + }); + + // Process complete messages and release buffered truncation messages + messages.forEach((message) => { + if (message.role === ROLE_ASSISTANT && message.steps) { + message.steps.forEach((step) => { + if (step.contents && step.contents.length > 0) { + step.contents.forEach((content: any) => { + if (content.type === "complete") { + // Find the related user message ID for this complete message + let relatedUserMsgId: string | null = null; + const messageIndex = messages.indexOf(message); + for (let i = messageIndex - 1; i >= 0; i--) { + if (messages[i].role === "user" && messages[i].id) { + relatedUserMsgId = messages[i].id; + break; + } + } + + if (relatedUserMsgId && truncationBuffer.has(relatedUserMsgId)) { + // Release buffered truncation messages + const buffer = truncationBuffer.get(relatedUserMsgId) || []; + buffer.forEach((truncationMsg) => { + taskMsgs.push(truncationMsg); + if (conversationGroups.has(relatedUserMsgId!)) { + const tasks = conversationGroups.get(relatedUserMsgId!) || []; + tasks.push(truncationMsg); + conversationGroups.set(relatedUserMsgId!, tasks); + } + }); + truncationBuffer.delete(relatedUserMsgId); + } + } + }); + } + }); + } + }); + + // Check and delete empty task groups + for (const [key, value] of conversationGroups.entries()) { + if (value.length === 0) { + conversationGroups.delete(key); + } + } + + return { + taskMessages: taskMsgs, + conversationGroups, + }; +} + diff --git a/frontend/app/[locale]/chat/streaming/taskWindow.tsx b/frontend/app/[locale]/chat/streaming/taskWindow.tsx index cb1d1cc94..d86f666a5 100644 --- a/frontend/app/[locale]/chat/streaming/taskWindow.tsx +++ b/frontend/app/[locale]/chat/streaming/taskWindow.tsx @@ -13,13 +13,135 @@ import { import { ScrollArea } from "@/components/ui/scrollArea"; import { Button } from "@/components/ui/button"; -import { MarkdownRenderer } from "@/components/ui/markdownRenderer"; +import { MarkdownRenderer, CodeBlock } from "@/components/ui/markdownRenderer"; import { chatConfig } from "@/const/chatConfig"; import { ChatMessageType, TaskMessageType, CardItem, MessageHandler } from "@/types/chat"; import { useChatTaskMessage } from "@/hooks/useChatTaskMessage"; import { storageService, extractObjectNameFromUrl } from "@/services/storageService"; import log from "@/lib/logger"; +/** + * Extract code content and language from model_output_code content + * Handles both and formats + * Supports streaming mode where end markers may not be present yet + * @param content - Raw code content from stream + * @returns Object with codeContent and language + */ +const extractCodeInfo = (content: string): { codeContent: string; language: string } => { + if (!content || typeof content !== "string") { + return { codeContent: "", language: "python" }; + } + + let processed = content; + + // Remove "代码:" or "Code:" prefix if present (handle both full-width and half-width colon) + processed = processed.replace(/^(代码|Code)[::]\s*/i, ""); + + // 1. Detect and process COMPLETE format + // Match: ``` or ``` + const displayMatch = processed.match(/```\s*/); + if (displayMatch) { + const language = displayMatch[1]; + // Remove the opening marker (handle optional whitespace and newline) + processed = processed.replace(/```\s*\s*\n?/, ""); + // Remove closing marker if present: ``` or just + processed = processed.replace(/\n?```[\s\S]*$/, ""); + processed = processed.replace(/[\s\S]*$/, ""); + // Remove trailing "[已展示给用户]" or similar text + processed = processed.replace(/\[已展示给用户\][\s\S]*$/, ""); + // Clean up any remaining incomplete markers (for streaming) + processed = processed.replace(/\n?``` format (executable code, default to python) + const runMatch = processed.match(/```\s*/); + if (runMatch) { + // Remove the opening marker + processed = processed.replace(/```\s*\s*\n?/, ""); + // Remove closing marker if present + processed = processed.replace(/\n?```[\s\S]*$/, ""); + processed = processed.replace(/[\s\S]*$/, ""); + // Clean up any remaining incomplete markers (for streaming) + processed = processed.replace(/\n?```) + // Or: ```")) { + const partialDisplayMatch = processed.match(//); + if (partialDisplayMatch) { + const language = partialDisplayMatch[1]; + // Remove all variations of the display marker + processed = processed.replace(/```\s*\s*\n?/g, ""); + processed = processed.replace(/\s*\n?/g, ""); + // Clean up end markers + processed = processed.replace(/\n?```")) { + // Remove all variations of the RUN marker + processed = processed.replace(/```\s*\s*\n?/g, ""); + processed = processed.replace(/\s*\n?/g, ""); + // Clean up end markers + processed = processed.replace(/\n?``` = { search: , @@ -823,6 +945,30 @@ const messageHandlers: MessageHandler[] = [ ), }, + // model_output_code type processor - code output with direct code block rendering + { + canHandle: (message) => message.type === chatConfig.messageTypes.MODEL_OUTPUT_CODE, + render: (message, _t) => { + // Extract code content and language from the message + const { codeContent, language } = extractCodeInfo(message.content); + + return ( + + + + ); + }, + }, + // execution type processor - execution result (not displayed) { canHandle: (message) => message.type === "execution", @@ -1307,37 +1453,28 @@ export function TaskWindow({ messages, isStreaming = false }: TaskWindowProps) { } /* For the code block style in task-message-content */ - .task-message-content pre { - white-space: pre-wrap !important; - word-wrap: break-word !important; - word-break: break-word !important; - overflow-wrap: break-word !important; - overflow: auto !important; + /* Allow code-block-container to use its default styles */ + .task-message-content .code-block-container { max-width: 100% !important; - box-sizing: border-box !important; - padding: 6px 10px !important; - margin: 2px 0 !important; + margin: 8px 0 !important; } - .task-message-content code { + .task-message-content .code-block-content pre { white-space: pre-wrap !important; word-wrap: break-word !important; word-break: break-word !important; overflow-wrap: break-word !important; max-width: 100% !important; - padding: 0 !important; + box-sizing: border-box !important; } - .task-message-content div[class*="language-"] { + /* For inline code and fallback code */ + .task-message-content code:not(.code-block-content code) { white-space: pre-wrap !important; word-wrap: break-word !important; word-break: break-word !important; overflow-wrap: break-word !important; - overflow: auto !important; max-width: 100% !important; - box-sizing: border-box !important; - padding: 6px 10px !important; - margin: 2px 0 !important; } /* Ensure the content of the SyntaxHighlighter component wraps correctly */ @@ -1347,17 +1484,26 @@ export function TaskWindow({ messages, isStreaming = false }: TaskWindowProps) { /* Make sure the entire container is not stretched by the content */ .task-message-content { - overflow: hidden !important; max-width: 100% !important; word-wrap: break-word !important; word-break: break-word !important; } + /* Allow code block container to overflow if needed for proper display */ + .task-message-content .code-block-container { + overflow: visible !important; + } + .task-message-content * { max-width: 100% !important; box-sizing: border-box !important; } + /* Exception for code block container - allow it to use its default overflow */ + .task-message-content .code-block-container * { + max-width: none !important; + } + /* Override diagram size in task window */ .task-message-content .my-4 { max-width: 200px !important; diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index 995eea580..014664450 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -155,6 +155,7 @@ function DataConfig({ isActive }: DataConfigProps) { const [hasClickedUpload, setHasClickedUpload] = useState(false); const [showEmbeddingWarning, setShowEmbeddingWarning] = useState(false); const [showAutoDeselectModal, setShowAutoDeselectModal] = useState(false); + const [newlyCreatedKbId, setNewlyCreatedKbId] = useState(null); // Track newly created KB waiting for documents const contentRef = useRef(null); // Open warning modal when single Embedding model is not configured (ignore multi-embedding) @@ -271,7 +272,7 @@ function DataConfig({ isActive }: DataConfigProps) { // Use saved state instead of current potentially cleared state const selectedKbNames = savedKnowledgeBasesRef.current .filter((kb) => savedSelectedIdsRef.current.includes(kb.id)) - .map((kb) => kb.name); + .map((kb) => kb.id); try { // Use fetch with keepalive to ensure request can be sent during page unload @@ -411,9 +412,13 @@ function DataConfig({ isActive }: DataConfigProps) { const isChangingKB = !kbState.activeKnowledgeBase || kb.id !== kbState.activeKnowledgeBase.id; - // If switching knowledge base, update active state + // If switching knowledge base, update active state and clear newly created flag if (isChangingKB) { setActiveKnowledgeBase(kb); + // Clear newly created flag when switching to a different knowledge base + if (newlyCreatedKbId !== null && newlyCreatedKbId !== kb.id) { + setNewlyCreatedKbId(null); + } } // Set active knowledge base ID to polling service @@ -606,22 +611,28 @@ function DataConfig({ isActive }: DataConfigProps) { setActiveKnowledgeBase(newKB); knowledgeBasePollingService.setActiveKnowledgeBase(newKB.id); setHasClickedUpload(false); + setNewlyCreatedKbId(newKB.id); // Mark this KB as newly created await uploadDocuments(newKB.id, filesToUpload); setUploadFiles([]); knowledgeBasePollingService .handleNewKnowledgeBaseCreation( + newKB.id, newKB.name, 0, filesToUpload.length, (populatedKB) => { setActiveKnowledgeBase(populatedKB); knowledgeBasePollingService.triggerKnowledgeBaseListUpdate(true); + // Clear the newly created flag when documents are ready + setNewlyCreatedKbId(null); } ) .catch((pollingError) => { log.error("Knowledge base creation polling failed:", pollingError); + // Clear the flag even on error to avoid stuck loading state + setNewlyCreatedKbId(null); }); } catch (error) { log.error(t("knowledgeBase.error.createUpload"), error); @@ -684,6 +695,12 @@ function DataConfig({ isActive }: DataConfigProps) { const viewingKbName = kbState.activeKnowledgeBase?.name || (isCreatingMode ? newKbName : ""); + // Check if current knowledge base is newly created and waiting for documents + const isNewlyCreatedAndWaiting = + newlyCreatedKbId !== null && + kbState.activeKnowledgeBase?.id === newlyCreatedKbId && + viewingDocuments.length === 0; + // As long as any document upload succeeds, immediately switch creation mode to false useEffect(() => { if (isCreatingMode && viewingDocuments.length > 0) { @@ -691,6 +708,13 @@ function DataConfig({ isActive }: DataConfigProps) { } }, [isCreatingMode, viewingDocuments.length]); + // Clear newly created flag when documents arrive + useEffect(() => { + if (newlyCreatedKbId !== null && viewingDocuments.length > 0) { + setNewlyCreatedKbId(null); + } + }, [newlyCreatedKbId, viewingDocuments.length]); + // Handle knowledge base selection const handleSelectKnowledgeBase = (id: string) => { hasUserInteractedRef.current = true; // Mark user interaction @@ -737,8 +761,6 @@ function DataConfig({ isActive }: DataConfigProps) { knowledgeBasePollingService.setActiveKnowledgeBase( kbState.activeKnowledgeBase.id ); - } else if (isCreatingMode && newKbName) { - knowledgeBasePollingService.setActiveKnowledgeBase(newKbName); } else { knowledgeBasePollingService.setActiveKnowledgeBase(null); } @@ -866,6 +888,7 @@ function DataConfig({ isActive }: DataConfigProps) { documents={[]} onDelete={() => {}} isCreatingMode={true} + knowledgeBaseId={""} knowledgeBaseName={newKbName} onNameChange={handleNameChange} containerHeight={SETUP_PAGE_CONTAINER.MAIN_CONTENT_HEIGHT} @@ -883,6 +906,7 @@ function DataConfig({ isActive }: DataConfigProps) { 0} + isNewlyCreatedAndWaiting={isNewlyCreatedAndWaiting} // Upload related props isDragging={uiState.isDragging} onDragOver={handleDragOver} diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx index 0e2ab9961..ca7ba942e 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx @@ -23,7 +23,7 @@ import knowledgeBaseService from "@/services/knowledgeBaseService"; import { modelService } from "@/services/modelService"; import { Document } from "@/types/knowledgeBase"; import { ModelOption } from "@/types/modelConfig"; -import { formatFileSize, sortByStatusAndDate } from "@/lib/utils"; +import { formatFileSize } from "@/lib/utils"; import log from "@/lib/logger"; import { useConfig } from "@/hooks/useConfig"; @@ -47,7 +47,10 @@ const TITLE_BAR_HEIGHT_CLASS_MAP: Record = { interface DocumentListProps { documents: Document[]; onDelete: (id: string) => void; + // User-facing knowledge base name (display name) knowledgeBaseName?: string; + // Internal knowledge base ID / Elasticsearch index name + knowledgeBaseId?: string; modelMismatch?: boolean; currentModel?: string; knowledgeBaseModel?: string; @@ -56,6 +59,7 @@ interface DocumentListProps { isCreatingMode?: boolean; onNameChange?: (name: string) => void; hasDocuments?: boolean; + isNewlyCreatedAndWaiting?: boolean; // New prop to track newly created KB waiting for documents // Upload related props isDragging?: boolean; @@ -76,6 +80,7 @@ const DocumentListContainer = forwardRef( { documents, onDelete, + knowledgeBaseId = "", knowledgeBaseName = "", modelMismatch = false, currentModel = "", @@ -85,6 +90,7 @@ const DocumentListContainer = forwardRef( isCreatingMode = false, onNameChange, hasDocuments = false, + isNewlyCreatedAndWaiting = false, // New prop // Upload related props isDragging = false, @@ -106,8 +112,14 @@ const DocumentListContainer = forwardRef( const titleBarHeight = UI_CONFIG.TITLE_BAR_HEIGHT; const uploadHeight = UI_CONFIG.UPLOAD_COMPONENT_HEIGHT; - // Sort documents by status and date - const sortedDocuments = sortByStatusAndDate(documents); + // Sort documents by create_time (latest first) + const sortedDocuments = [...documents].sort((a, b) => { + const aTime = new Date(a.create_time).getTime(); + const bTime = new Date(b.create_time).getTime(); + const safeA = Number.isNaN(aTime) ? 0 : aTime; + const safeB = Number.isNaN(bTime) ? 0 : bTime; + return safeB - safeA; + }); // Get file icon const getFileIcon = (type: string): string => { @@ -156,6 +168,7 @@ const DocumentListContainer = forwardRef( React.useEffect(() => { setShowDetail(false); setShowChunk(false); + setSummary(""); }, [knowledgeBaseName]); // Load available models when showing detail @@ -242,10 +255,10 @@ const DocumentListContainer = forwardRef( // Get summary when showing detailed content React.useEffect(() => { const fetchSummary = async () => { - if (showDetail && knowledgeBaseName) { + if (showDetail && knowledgeBaseId) { try { const result = await knowledgeBaseService.getSummary( - knowledgeBaseName + knowledgeBaseId ); setSummary(result); } catch (error) { @@ -259,7 +272,7 @@ const DocumentListContainer = forwardRef( // Handle auto summary const handleAutoSummary = async () => { - if (!knowledgeBaseName) { + if (!knowledgeBaseId) { message.warning(t("document.summary.selectKnowledgeBase")); return; } @@ -269,7 +282,7 @@ const DocumentListContainer = forwardRef( try { const result = await knowledgeBaseService.summaryIndex( - knowledgeBaseName, + knowledgeBaseId, 1000, (newText) => { setSummary((prev) => prev + newText); @@ -293,7 +306,7 @@ const DocumentListContainer = forwardRef( // Handle save summary const handleSaveSummary = async () => { - if (!knowledgeBaseName) { + if (!knowledgeBaseId) { message.warning(t("document.summary.selectKnowledgeBase")); return; } @@ -305,7 +318,7 @@ const DocumentListContainer = forwardRef( setIsSaving(true); try { - await knowledgeBaseService.changeSummary(knowledgeBaseName, summary); + await knowledgeBaseService.changeSummary(knowledgeBaseId, summary); message.success(t("document.summary.saveSuccess")); } catch (error: any) { log.error(t("document.summary.saveError"), error); @@ -511,29 +524,42 @@ const DocumentListContainer = forwardRef( - ) : docState.isLoadingDocuments ? ( + ) : docState.isLoadingDocuments || isNewlyCreatedAndWaiting ? ( - {t("document.status.loadingList")} + {isNewlyCreatedAndWaiting + ? t("document.status.waitingForTask") + : t("document.status.loadingList")} ) : isCreatingMode ? ( - - - - + hasDocuments || isUploading || docState.isLoadingDocuments ? ( + + + + + {t("document.status.waitingForTask")} + - - {t("document.title.createNew")} - - - {t("document.hint.uploadToCreate")} - - + ) : ( + + + + + + + {t("document.title.createNew")} + + + {t("document.hint.uploadToCreate")} + + + + ) ) : sortedDocuments.length > 0 ? ( @@ -588,7 +614,14 @@ const DocumentListContainer = forwardRef( - + ( onDelete(doc.id)} className={LAYOUT.ACTION_TEXT} - disabled={ - doc.status === - DOCUMENT_STATUS.WAIT_FOR_PROCESSING || + title={ doc.status === DOCUMENT_STATUS.PROCESSING || - doc.status === - DOCUMENT_STATUS.WAIT_FOR_FORWARDING || doc.status === DOCUMENT_STATUS.FORWARDING + ? t("document.delete.terminateTask") + : undefined } > {t("common.delete")} @@ -645,10 +676,11 @@ const DocumentListContainer = forwardRef( onDragOver={onDragOver} onDragLeave={onDragLeave} onDrop={onDrop} - disabled={!isCreatingMode && !knowledgeBaseName} + disabled={!isCreatingMode && !knowledgeBaseId} componentHeight={uploadHeight} isCreatingMode={isCreatingMode} - indexName={knowledgeBaseName} + // Use internal ID for backend operations; fall back to name in creation mode + indexName={knowledgeBaseId || knowledgeBaseName} newKnowledgeBaseName={isCreatingMode ? knowledgeBaseName : ""} modelMismatch={modelMismatch} /> diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentStatus.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentStatus.tsx index 65b5b8219..f1c39af17 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentStatus.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentStatus.tsx @@ -1,17 +1,44 @@ -import React from "react"; +import React, { useEffect, useState } from "react"; import { useTranslation } from "react-i18next"; +import { Popover, Progress } from "antd"; +import { QuestionCircleOutlined } from "@ant-design/icons"; import { DOCUMENT_STATUS } from "@/const/knowledgeBase"; +import knowledgeBaseService from "@/services/knowledgeBaseService"; +import log from "@/lib/logger"; interface DocumentStatusProps { status: string; showIcon?: boolean; + errorReason?: string; + suggestion?: string; + kbId?: string; + docId?: string; + // Optional ingestion progress metrics + processedChunkNum?: number | null; + totalChunkNum?: number | null; } export const DocumentStatus: React.FC = ({ status, showIcon = false, + errorReason, + suggestion, + kbId, + docId, + processedChunkNum, + totalChunkNum, }) => { const { t } = useTranslation(); + const [errorCodeState, setErrorCodeState] = useState(null); + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const [isFetching, setIsFetching] = useState(false); + const [hasFetched, setHasFetched] = useState(false); + + useEffect(() => { + // If parent props change (e.g. list refreshed), reset state + setErrorCodeState(null); + setHasFetched(false); + }, [kbId, docId]); // Map API status to display status const getDisplayStatus = (apiStatus: string): string => { @@ -102,12 +129,155 @@ export const DocumentStatus: React.FC = ({ const { bgColor, textColor, borderColor } = getStatusStyles(); const displayStatus = getDisplayStatus(status); + const isFailedStatus = + status === DOCUMENT_STATUS.PROCESS_FAILED || + status === DOCUMENT_STATUS.FORWARD_FAILED; + + const hasValidProgress = + typeof processedChunkNum === "number" && + typeof totalChunkNum === "number" && + totalChunkNum > 0; + + // Show progress for processing or forwarding status (入库中 corresponds to FORWARDING) + const shouldShowProgress = + (status === DOCUMENT_STATUS.PROCESSING || + status === DOCUMENT_STATUS.FORWARDING) && + hasValidProgress; + + const progressPercent = hasValidProgress + ? Math.min( + 100, + Math.max(0, Math.round((processedChunkNum / totalChunkNum) * 100)) + ) + : 0; + + // Get localized error message from error code + const getLocalizedError = (errorCode: string | null) => { + if (!errorCode) return { message: null, suggestion: null }; + + const messageKey = `document.error.code.${errorCode}.message`; + const suggestionKey = `document.error.code.${errorCode}.suggestion`; + + const message = t(messageKey, { defaultValue: null }); + const suggestion = t(suggestionKey, { defaultValue: null }); + + return { + message: message !== messageKey ? message : null, + suggestion: suggestion !== suggestionKey ? suggestion : null, + }; + }; + + const fetchErrorInfo = async () => { + if (!kbId || !docId) return; + setIsFetching(true); + try { + const result = await knowledgeBaseService.getDocumentErrorInfo( + kbId, + docId + ); + + // Set error code - frontend will handle localization + setErrorCodeState(result.errorCode ?? null); + } catch (error) { + log.error("Failed to fetch document error info:", error); + } finally { + setIsFetching(false); + setHasFetched(true); + } + }; + + const handlePopoverVisibleChange = (visible: boolean) => { + setIsPopoverOpen(visible); + if ( + visible && + kbId && + docId && + !isFetching && + !hasFetched && + !errorCodeState + ) { + fetchErrorInfo(); + } + }; + + // Get localized error messages from error code + const localizedError = getLocalizedError(errorCodeState); + + const popoverContent = ( + + {isFetching ? ( + {t("common.loading")} + ) : localizedError.message ? ( + + + + {localizedError.message} + + + {localizedError.suggestion && ( + + + {t("document.error.suggestion")} + + + {localizedError.suggestion} + + + )} + + ) : ( + + {t("document.error.noReason")} + + )} + + ); + return ( {showIcon && {getStatusIcon()}} {displayStatus} + {shouldShowProgress && hasValidProgress && ( + + {t("document.progress.chunksProcessed", { + processed: processedChunkNum, + total: totalChunkNum, + percent: progressPercent, + })} + + } + placement="top" + > + + + + + )} + {isFailedStatus && ( + + + + )} ); }; diff --git a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx index 09201e870..d49883d1c 100644 --- a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx +++ b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx @@ -73,27 +73,53 @@ const KnowledgeBaseList: React.FC = ({ const { t } = useTranslation(); // Format date function, only keep date part - const formatDate = (dateString: number) => { + const formatDate = (dateValue: any) => { try { - const date = new Date(dateString); - return date.toISOString().split('T')[0]; // Only return YYYY-MM-DD part + const date = + typeof dateValue === "number" + ? new Date(dateValue) + : new Date(dateValue); + return isNaN(date.getTime()) + ? String(dateValue ?? "") + : date.toISOString().split("T")[0]; // Only return YYYY-MM-DD part } catch (e) { - return dateString; // If parsing fails, return original string + return String(dateValue ?? ""); // If parsing fails, return original string } }; + // Helper to safely extract timestamp for sorting + const getTimestamp = (value: any): number => { + if (!value) return 0; + if (typeof value === "number") return value; + const t = Date.parse(value); + return Number.isNaN(t) ? 0 : t; + }; + + // Sort knowledge bases by update time (fallback to creation time), latest first + const sortedKnowledgeBases = [...knowledgeBases].sort((a, b) => { + const aTime = getTimestamp(a.updatedAt ?? a.createdAt); + const bTime = getTimestamp(b.updatedAt ?? b.createdAt); + return bTime - aTime; + }); return ( - + {/* Fixed header area */} - + - - {t('knowledgeBase.list.title')} + + {t("knowledgeBase.list.title")} - + = ({ gap: "8px", backgroundColor: "#1677ff", color: "white", - border: "none" + border: "none", }} className="hover:!bg-blue-600" type="primary" onClick={onCreateNew} icon={} > - {t('knowledgeBase.button.create')} + {t("knowledgeBase.button.create")} = ({ gap: "8px", backgroundColor: "#1677ff", color: "white", - border: "none" + border: "none", }} className="hover:!bg-blue-600" type="primary" onClick={onSync} > - + - {t('knowledgeBase.button.sync')} + {t("knowledgeBase.button.sync")} @@ -146,9 +174,15 @@ const KnowledgeBaseList: React.FC = ({ - {t('knowledgeBase.selected.prefix')} - {selectedIds.length} - {t('knowledgeBase.selected.suffix')} + + {t("knowledgeBase.selected.prefix")}{" "} + + + {selectedIds.length} + + + {t("knowledgeBase.selected.suffix")} + {selectedIds.length > 0 && ( @@ -190,23 +224,29 @@ const KnowledgeBaseList: React.FC = ({ {/* Scrollable knowledge base list area */} - {knowledgeBases.length > 0 ? ( + {sortedKnowledgeBases.length > 0 ? ( - {knowledgeBases.map((kb, index) => { - const canSelect = isSelectable(kb) - const isSelected = selectedIds.includes(kb.id) - const isActive = activeKnowledgeBase?.id === kb.id - const isMismatchedAndSelected = isSelected && !canSelect + {sortedKnowledgeBases.map((kb, index) => { + const canSelect = isSelectable(kb); + const isSelected = selectedIds.includes(kb.id); + const isActive = activeKnowledgeBase?.id === kb.id; + const isMismatchedAndSelected = isSelected && !canSelect; return ( 0 ? "border-t border-gray-200" : ""}`} + className={`${ + KB_LAYOUT.ROW_PADDING + } px-2 hover:bg-gray-50 cursor-pointer transition-colors ${ + index > 0 ? "border-t border-gray-200" : "" + }`} style={{ - borderLeftWidth: '4px', - borderLeftStyle: 'solid', - borderLeftColor: isActive ? '#3b82f6' : 'transparent', - backgroundColor: isActive ? 'rgb(226, 240, 253)' : 'inherit' + borderLeftWidth: "4px", + borderLeftStyle: "solid", + borderLeftColor: isActive ? "#3b82f6" : "transparent", + backgroundColor: isActive + ? "rgb(226, 240, 253)" + : "inherit", }} onClick={() => { onClick(kb); @@ -215,37 +255,45 @@ const KnowledgeBaseList: React.FC = ({ > - { - e.stopPropagation(); - if (canSelect || isSelected) { - onSelect(kb.id); - } - }} + { + e.stopPropagation(); + if (canSelect || isSelected) { + onSelect(kb.id); + } + }} style={{ - minWidth: '40px', - minHeight: '40px', - display: 'flex', - alignItems: 'flex-start', - justifyContent: 'center' - }}> + minWidth: "40px", + minHeight: "40px", + display: "flex", + alignItems: "flex-start", + justifyContent: "center", + }} + > { - e.stopPropagation() - onSelect(kb.id) + e.stopPropagation(); + onSelect(kb.id); }} disabled={!canSelect && !isSelected} style={{ - cursor: (canSelect || isSelected) ? 'pointer' : 'not-allowed', - transform: 'scale(1.5)', + cursor: + canSelect || isSelected + ? "pointer" + : "not-allowed", + transform: "scale(1.5)", }} /> @@ -257,7 +305,7 @@ const KnowledgeBaseList: React.FC = ({ className="text-base font-medium text-gray-800 truncate" style={{ maxWidth: KB_LAYOUT.KB_NAME_MAX_WIDTH, - ...KB_LAYOUT.KB_NAME_OVERFLOW + ...KB_LAYOUT.KB_NAME_OVERFLOW, }} title={kb.name} > @@ -266,56 +314,93 @@ const KnowledgeBaseList: React.FC = ({ { - e.stopPropagation() - onDelete(kb.id) + e.stopPropagation(); + onDelete(kb.id); }} > - {t('common.delete')} + {t("common.delete")} - + {/* Document count tag */} - - {t('knowledgeBase.tag.documents', { count: kb.documentCount || 0 })} + + {t("knowledgeBase.tag.documents", { + count: kb.documentCount || 0, + })} {/* Chunk count tag */} - - {t('knowledgeBase.tag.chunks', { count: kb.chunkCount || 0 })} + + {t("knowledgeBase.tag.chunks", { + count: kb.chunkCount || 0, + })} - {/* Knowledge base source tag */} - - {t('knowledgeBase.tag.source', { source: kb.source })} - + {/* Only show source, creation date, and model tags when there are valid documents or chunks */} + {((kb.documentCount || 0) > 0 || + (kb.chunkCount || 0) > 0) && ( + <> + {/* Knowledge base source tag */} + + {t("knowledgeBase.tag.source", { + source: kb.source, + })} + - {/* Creation date tag - only show date */} - - {t('knowledgeBase.tag.createdAt', { date: formatDate(kb.createdAt) })} - + {/* Creation date tag - only show date */} + + {t("knowledgeBase.tag.createdAt", { + date: formatDate(kb.createdAt), + })} + - {/* Force line break */} - + {/* Force line break */} + - {/* Model tag - show normal or mismatch */} - - {t('knowledgeBase.tag.model', { model: getModelDisplayName(kb.embeddingModel) })} - - {kb.embeddingModel !== "unknown" && kb.embeddingModel !== currentEmbeddingModel && ( - - {t('knowledgeBase.tag.modelMismatch')} - + {/* Model tag - only show when model is not "unknown" */} + {kb.embeddingModel !== "unknown" && ( + + {t("knowledgeBase.tag.model", { + model: getModelDisplayName(kb.embeddingModel), + })} + + )} + {kb.embeddingModel !== "unknown" && + kb.embeddingModel !== currentEmbeddingModel && ( + + {t("knowledgeBase.tag.modelMismatch")} + + )} + > )} - ) + ); })} ) : ( - - {t('knowledgeBase.list.empty')} + + {t("knowledgeBase.list.empty")} )} diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index a91093a73..c866600fd 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -238,35 +238,49 @@ export const KnowledgeBaseProvider: React.FC = ({ ch try { const userConfig = await userConfigService.loadKnowledgeList(); if (userConfig && userConfig.selectedKbNames.length > 0) { - // Find matching knowledge base IDs based on names + // Find matching knowledge base IDs based on index names const selectedIds = state.knowledgeBases - .filter(kb => userConfig.selectedKbNames.includes(kb.name)) - .map(kb => kb.id); + .filter((kb) => userConfig.selectedKbNames.includes(kb.id)) + .map((kb) => kb.id); - dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, payload: selectedIds }); + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, + payload: selectedIds, + }); } } catch (error) { - log.error(t('knowledgeBase.error.loadSelected'), error); - dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, payload: t('knowledgeBase.error.loadSelectedRetry') }); + log.error(t("knowledgeBase.error.loadSelected"), error); + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, + payload: t("knowledgeBase.error.loadSelectedRetry"), + }); } }, [state.knowledgeBases]); // Save user selected knowledge bases to backend const saveUserSelectedKnowledgeBases = useCallback(async () => { try { - // Get selected knowledge base names + // Get selected knowledge base index names (globally unique identifiers) const selectedKbNames = state.knowledgeBases - .filter(kb => state.selectedIds.includes(kb.id)) - .map(kb => kb.name); + .filter((kb) => state.selectedIds.includes(kb.id)) + .map((kb) => kb.id); - const success = await userConfigService.updateKnowledgeList(selectedKbNames); + const success = await userConfigService.updateKnowledgeList( + selectedKbNames + ); if (!success) { - dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, payload: t('knowledgeBase.error.saveSelected') }); + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, + payload: t("knowledgeBase.error.saveSelected"), + }); } return success; } catch (error) { - log.error(t('knowledgeBase.error.saveSelected'), error); - dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, payload: t('knowledgeBase.error.saveSelectedRetry') }); + log.error(t("knowledgeBase.error.saveSelected"), error); + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, + payload: t("knowledgeBase.error.saveSelectedRetry"), + }); return false; } }, [state.knowledgeBases, state.selectedIds, t]); diff --git a/frontend/app/[locale]/market/MarketContent.tsx b/frontend/app/[locale]/market/MarketContent.tsx index 984b0f73d..9d66050ef 100644 --- a/frontend/app/[locale]/market/MarketContent.tsx +++ b/frontend/app/[locale]/market/MarketContent.tsx @@ -3,7 +3,7 @@ import React, { useState, useEffect } from "react"; import { motion } from "framer-motion"; import { useTranslation } from "react-i18next"; -import { ShoppingBag, Search, RefreshCw, AlertCircle } from "lucide-react"; +import { ShoppingBag, Search, RefreshCw } from "lucide-react"; import { Tabs, Input, Spin, Empty, Pagination, App } from "antd"; import log from "@/lib/logger"; @@ -18,7 +18,8 @@ import { import marketService, { MarketApiError } from "@/services/marketService"; import { AgentMarketCard } from "./components/AgentMarketCard"; import MarketAgentDetailModal from "./components/MarketAgentDetailModal"; -import AgentInstallModal from "./components/AgentInstallModal"; +import AgentImportWizard from "@/components/agent/AgentImportWizard"; +import { ImportAgentData } from "@/hooks/useAgentImport"; import MarketErrorState from "./components/MarketErrorState"; interface MarketContentProps { @@ -261,12 +262,7 @@ export default function MarketContent({ }, ...categories.map((cat) => ({ key: cat.name, - label: ( - - {cat.icon} - {isZh ? cat.display_name_zh : cat.display_name} - - ), + label: isZh ? cat.display_name_zh : cat.display_name, })), ]; @@ -451,11 +447,22 @@ export default function MarketContent({ /> {/* Agent Install Modal */} - ) : null} diff --git a/frontend/app/[locale]/market/components/AgentInstallModal.tsx b/frontend/app/[locale]/market/components/AgentInstallModal.tsx deleted file mode 100644 index a23e7970c..000000000 --- a/frontend/app/[locale]/market/components/AgentInstallModal.tsx +++ /dev/null @@ -1,41 +0,0 @@ -"use client"; - -import React from "react"; -import { MarketAgentDetail } from "@/types/market"; -import { ImportAgentData } from "@/hooks/useAgentImport"; -import AgentImportWizard from "@/components/agent/AgentImportWizard"; - -interface AgentInstallModalProps { - visible: boolean; - onCancel: () => void; - agentDetails: MarketAgentDetail | null; - onInstallComplete?: () => void; -} - -export default function AgentInstallModal({ - visible, - onCancel, - agentDetails, - onInstallComplete, -}: AgentInstallModalProps) { - // Convert MarketAgentDetail to ImportAgentData format - const importData: ImportAgentData | null = agentDetails?.agent_json - ? { - agent_id: agentDetails.agent_id, - agent_info: agentDetails.agent_json.agent_info, - mcp_info: agentDetails.agent_json.mcp_info, - } - : null; - - return ( - - ); -} diff --git a/frontend/app/[locale]/market/components/AgentMarketCard.tsx b/frontend/app/[locale]/market/components/AgentMarketCard.tsx index 317f5bec5..3a0102adf 100644 --- a/frontend/app/[locale]/market/components/AgentMarketCard.tsx +++ b/frontend/app/[locale]/market/components/AgentMarketCard.tsx @@ -6,6 +6,7 @@ import { Download, Tag, Wrench } from "lucide-react"; import { MarketAgentListItem } from "@/types/market"; import { useTranslation } from "react-i18next"; import { getGenericLabel } from "@/lib/agentLabelMapper"; +import { getCategoryIcon } from "@/const/marketConfig"; interface AgentMarketCardProps { agent: MarketAgentListItem; @@ -34,6 +35,11 @@ export function AgentMarketCard({ onViewDetails(agent); }; + // Get category icon: prefer API icon, then fallback to default mapping by name + const categoryIcon = agent.category + ? agent.category.icon || getCategoryIcon(agent.category.name) + : "📦"; + return ( - {agent.category?.icon || "📦"} + {categoryIcon} {agent.category @@ -64,6 +70,13 @@ export function AgentMarketCard({ {agent.display_name} + {agent.author ? ( + + {t("market.by", { defaultValue: "By {{author}}", author: agent.author })} + + ) : ( + + )} {/* Card body */} diff --git a/frontend/app/[locale]/market/components/MarketAgentDetailModal.tsx b/frontend/app/[locale]/market/components/MarketAgentDetailModal.tsx index e21846ec2..4781b4b76 100644 --- a/frontend/app/[locale]/market/components/MarketAgentDetailModal.tsx +++ b/frontend/app/[locale]/market/components/MarketAgentDetailModal.tsx @@ -13,6 +13,7 @@ import { } from "lucide-react"; import { MarketAgentDetail } from "@/types/market"; import { getToolSourceLabel, getGenericLabel } from "@/lib/agentLabelMapper"; +import { getCategoryIcon } from "@/const/marketConfig"; interface MarketAgentDetailModalProps { visible: boolean; @@ -87,6 +88,11 @@ export default function MarketAgentDetailModal({ > {agentDetails?.display_name || "-"} + + {agentDetails?.author || "-"} + @@ -97,7 +103,10 @@ export default function MarketAgentDetailModal({ > {agentDetails?.category ? ( - {agentDetails.category.icon || "📦"} + + {agentDetails.category.icon || + getCategoryIcon(agentDetails.category.name)} + {isZh ? agentDetails.category.display_name_zh diff --git a/frontend/app/[locale]/memory/MemoryContent.tsx b/frontend/app/[locale]/memory/MemoryContent.tsx index dd35eab82..370e220b3 100644 --- a/frontend/app/[locale]/memory/MemoryContent.tsx +++ b/frontend/app/[locale]/memory/MemoryContent.tsx @@ -1,6 +1,6 @@ "use client"; -import React, { useEffect, useState } from "react"; +import React, { useEffect, useState, useCallback } from "react"; import { App, Button, Card, Input, List, Menu, Switch, Tabs } from "antd"; import { motion } from "framer-motion"; import "./memory.css"; @@ -188,9 +188,16 @@ export default function MemoryContent({ onNavigate }: MemoryContentProps) { }; // Render single list (for tenant shared and user personal) - no card, with header buttons - const renderSingleList = (group: { title: string; key: string; items: any[] }) => { + const renderSingleList = useCallback((group: { title: string; key: string; items: any[] }) => { return ( - + + {/* Add memory input - appears before the list */} + {memory.addingMemoryKey === group.key && ( + + {renderAddMemoryInput(group.key)} + + )} + @@ -200,7 +207,9 @@ export default function MemoryContent({ onNavigate }: MemoryContentProps) { type="text" size="small" icon={} - onClick={() => memory.startAddingMemory(group.key)} + onClick={() => { + memory.startAddingMemory(group.key); + }} className="hover:bg-green-50 hover:text-green-600" title={t("memoryManageModal.addMemory")} /> @@ -228,7 +237,7 @@ export default function MemoryContent({ onNavigate }: MemoryContentProps) { ), }} - style={{ height: "calc(100vh - 280px)", overflowY: "auto" }} + style={{ height: memory.addingMemoryKey === group.key ? "calc(100vh - 380px)" : "calc(100vh - 280px)", overflowY: "auto" }} renderItem={(item) => ( {item.memory} )} - > - {memory.addingMemoryKey === group.key && ( - - {renderAddMemoryInput(group.key)} - - )} - + /> ); - }; + }, [memory.addingMemoryKey, memory.startAddingMemory, memory.handleDeleteMemory, handleClearConfirm, renderAddMemoryInput, t]); const renderMemoryWithMenu = ( groups: { title: string; key: string; items: any[] }[], @@ -463,6 +463,13 @@ function MemoryMenuList({ /> + {/* Add memory input - appears before the list */} + {memory.addingMemoryKey === currentGroup.key && ( + + {renderAddMemoryInput(currentGroup.key)} + + )} + @@ -472,7 +479,9 @@ function MemoryMenuList({ type="text" size="small" icon={} - onClick={() => memory.startAddingMemory(currentGroup.key)} + onClick={() => { + memory.startAddingMemory(currentGroup.key); + }} disabled={disabled} className="hover:bg-green-50 hover:text-green-600" title={t("memoryManageModal.addMemory")} @@ -504,7 +513,7 @@ function MemoryMenuList({ ), }} - style={{ height: "100%", overflowY: "auto" }} + style={{ height: memory.addingMemoryKey === currentGroup.key ? "calc(100% - 100px)" : "100%", overflowY: "auto" }} renderItem={(item) => ( {item.memory} )} - > - {memory.addingMemoryKey === currentGroup.key && ( - - {renderAddMemoryInput(currentGroup.key)} - - )} - + /> ); diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 0d1bc747c..1ce0995cf 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -175,6 +175,7 @@ export const ModelAddDialog = ({ DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE, ] as [number, number], + chunkingBatchSize: "10", }); const [loading, setLoading] = useState(false); const [verifyingConnectivity, setVerifyingConnectivity] = useState(false); @@ -290,6 +291,10 @@ export const ModelAddDialog = ({ ) { setConnectivityStatus({ status: null, message: "" }); } + // Clear model search term when model type changes + if (field === "type") { + setModelSearchTerm(""); + } }; // Verify if the vector dimension is valid @@ -413,21 +418,35 @@ export const ModelAddDialog = ({ ? (MODEL_TYPES.MULTI_EMBEDDING as ModelType) : form.type; try { + const isEmbeddingType = + modelType === MODEL_TYPES.EMBEDDING || + modelType === MODEL_TYPES.MULTI_EMBEDDING; const result = await modelService.addBatchCustomModel({ api_key: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, provider: form.provider, type: modelType, - models: enabledModels.map((model: any) => ({ - ...model, - max_tokens: model.max_tokens || parseInt(form.maxTokens) || 4096, - // Add chunk size range for embedding models - ...(isEmbeddingModel - ? { - expected_chunk_size: form.chunkSizeRange[0], - maximum_chunk_size: form.chunkSizeRange[1], - } - : {}), - })), + models: enabledModels.map((model: any) => { + // For embedding/multi_embedding models, explicitly exclude max_tokens as backend will set it via connectivity check + if (isEmbeddingType) { + const { max_tokens, ...modelWithoutMaxTokens } = model; + return { + ...modelWithoutMaxTokens, + // Add chunk size range for embedding models + ...(isEmbeddingModel + ? { + expected_chunk_size: form.chunkSizeRange[0], + maximum_chunk_size: form.chunkSizeRange[1], + chunk_batch: parseInt(form.chunkingBatchSize) || 10, + } + : {}), + }; + } else { + return { + ...model, + max_tokens: model.max_tokens || parseInt(form.maxTokens) || 4096, + }; + } + }), }); if (result === 200) { onSuccess(); @@ -515,6 +534,7 @@ export const ModelAddDialog = ({ ? { expectedChunkSize: form.chunkSizeRange[0], maximumChunkSize: form.chunkSizeRange[1], + chunkingBatchSize: parseInt(form.chunkingBatchSize) || 10, } : {}), }); @@ -587,6 +607,7 @@ export const ModelAddDialog = ({ DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE, ], + chunkingBatchSize: "10", }); // Reset the connectivity status @@ -803,6 +824,26 @@ export const ModelAddDialog = ({ )} + {/* Concurrent Request Count (Embedding model only) */} + {isEmbeddingModel && ( + + + {t("modelConfig.input.chunkingBatchSize")} + + handleFormChange("chunkingBatchSize", e.target.value)} + /> + + )} + {/* Vector dimension */} {isEmbeddingModel && ( @@ -1177,7 +1218,7 @@ export const ModelAddDialog = ({ onCancel={() => setSettingsModalVisible(false)} onOk={handleSettingsSave} cancelText={t("common.cancel")} - okText={t("common.ok")} + okText={t("common.confirm")} destroyOnClose > diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index 1e5405049..84939e5b1 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -17,6 +17,11 @@ import { ModelOption, ModelType, ModelSource } from "@/types/modelConfig"; import log from "@/lib/logger"; import { ModelEditDialog, ProviderConfigEditDialog } from "./ModelEditDialog"; +import { + ModelChunkSizeSlider, + DEFAULT_EXPECTED_CHUNK_SIZE, + DEFAULT_MAXIMUM_CHUNK_SIZE, +} from "./ModelChunkSizeSilder"; interface ModelDeleteDialogProps { isOpen: boolean; @@ -59,6 +64,18 @@ export const ModelDeleteDialog = ({ const [modelMaxTokens, setModelMaxTokens] = useState("4096"); const [providerModelSearchTerm, setProviderModelSearchTerm] = useState(""); + // Embedding model chunk config modal state + const [embeddingConfigModalVisible, setEmbeddingConfigModalVisible] = + useState(false); + const [selectedEmbeddingModel, setSelectedEmbeddingModel] = + useState(null); + const [chunkSizeRange, setChunkSizeRange] = useState<[number, number]>([ + DEFAULT_EXPECTED_CHUNK_SIZE, + DEFAULT_MAXIMUM_CHUNK_SIZE, + ]); + const [chunkingBatchSize, setChunkingBatchSize] = useState("10"); + const [savingEmbeddingConfig, setSavingEmbeddingConfig] = useState(false); + // Get model color scheme const getModelColorScheme = ( type: ModelType @@ -219,11 +236,7 @@ export const ModelDeleteDialog = ({ switch (source) { case MODEL_SOURCES.SILICON: return ( - + ); case MODEL_SOURCES.MODELENGINE: return ( @@ -455,7 +468,6 @@ export const ModelDeleteDialog = ({ }); }, [providerModels, providerModelSearchTerm]); - // Handle provider config save const handleProviderConfigSave = async ({ apiKey, @@ -491,9 +503,7 @@ export const ModelDeleteDialog = ({ maxTokens: maxTokens || m.maxTokens, })); - await modelService.updateBatchModel( - currentModelPayloads - ); + await modelService.updateBatchModel(currentModelPayloads); // Show success message since no exception was thrown message.success(t("model.dialog.success.updateSuccess")); @@ -538,6 +548,104 @@ export const ModelDeleteDialog = ({ setSelectedModelForSettings(null); }; + // Handle embedding model click to open config modal + const handleEmbeddingModelClick = (model: ModelOption | any) => { + const isEmbeddingModel = + model.type === MODEL_TYPES.EMBEDDING || + model.type === MODEL_TYPES.MULTI_EMBEDDING || + model.model_type === MODEL_TYPES.EMBEDDING || + model.model_type === MODEL_TYPES.MULTI_EMBEDDING; + if (isEmbeddingModel) { + // If it's a providerModel (not yet added to system), find the corresponding model in models list + if (model.id && !model.name) { + // This is a providerModel, find the corresponding model in models list + const existingModel = models.find( + (m) => + m.name === model.id && + m.type === (model.model_type || deletingModelType) && + m.source === selectedSource + ); + if (existingModel) { + setSelectedEmbeddingModel(existingModel); + setChunkSizeRange([ + existingModel.expectedChunkSize || DEFAULT_EXPECTED_CHUNK_SIZE, + existingModel.maximumChunkSize || DEFAULT_MAXIMUM_CHUNK_SIZE, + ]); + setChunkingBatchSize( + (existingModel.chunkingBatchSize || 10).toString() + ); + } else { + // Model not yet added, use default values + setSelectedEmbeddingModel({ + ...model, + name: model.id, + displayName: model.id, + type: model.model_type || deletingModelType, + source: selectedSource, + expectedChunkSize: DEFAULT_EXPECTED_CHUNK_SIZE, + maximumChunkSize: DEFAULT_MAXIMUM_CHUNK_SIZE, + chunkingBatchSize: 10, + } as ModelOption); + setChunkSizeRange([ + DEFAULT_EXPECTED_CHUNK_SIZE, + DEFAULT_MAXIMUM_CHUNK_SIZE, + ]); + setChunkingBatchSize("10"); + } + } else { + // This is a ModelOption from models list + setSelectedEmbeddingModel(model); + setChunkSizeRange([ + model.expectedChunkSize || DEFAULT_EXPECTED_CHUNK_SIZE, + model.maximumChunkSize || DEFAULT_MAXIMUM_CHUNK_SIZE, + ]); + setChunkingBatchSize((model.chunkingBatchSize || 10).toString()); + } + setEmbeddingConfigModalVisible(true); + } + }; + + // Handle embedding config save + const handleEmbeddingConfigSave = async () => { + if (!selectedEmbeddingModel) return; + + setSavingEmbeddingConfig(true); + try { + // Get the display name - use the one from existing model if available + const displayName = + selectedEmbeddingModel.displayName || selectedEmbeddingModel.name; + const apiKey = + selectedEmbeddingModel.apiKey || getApiKeyByType(deletingModelType); + + await modelService.updateSingleModel({ + currentDisplayName: displayName, + url: selectedEmbeddingModel.apiUrl || "", + apiKey: apiKey || "sk-no-api-key", + source: selectedEmbeddingModel.source || selectedSource, + expectedChunkSize: chunkSizeRange[0], + maximumChunkSize: chunkSizeRange[1], + chunkingBatchSize: parseInt(chunkingBatchSize) || 10, + }); + + message.success(t("model.dialog.editSuccess")); + setEmbeddingConfigModalVisible(false); + setSelectedEmbeddingModel(null); + // Refresh model list to reflect changes + await onSuccess(); + } catch (error: any) { + log.error("Failed to save embedding model config:", error); + if (error.code === 404) { + message.error(t("model.dialog.error.modelNotFound")); + } else if (error.code === 500) { + message.error(t("model.dialog.error.serverError")); + } else { + message.error(t("model.dialog.error.editFailed")); + } + } finally { + setSavingEmbeddingConfig(false); + } + }; + return ( // Refactor: Styles are embedded within the component ({ - ...model, - max_tokens: model.max_tokens || 4096, // Ensure max_tokens is always present - })), + models: allEnabledModels.map((model) => { + if (isEmbeddingType) { + const { max_tokens, ...modelWithoutMaxTokens } = + model; + return modelWithoutMaxTokens; + } else { + return { + ...model, + max_tokens: model.max_tokens || 4096, + }; + } + }), }); } @@ -882,12 +1002,36 @@ export const ModelDeleteDialog = ({ const checked = pendingSelectedProviderIds.has( providerModel.id ); + const isEmbeddingModel = + deletingModelType === MODEL_TYPES.EMBEDDING || + deletingModelType === MODEL_TYPES.MULTI_EMBEDDING || + providerModel.model_type === MODEL_TYPES.EMBEDDING || + providerModel.model_type === MODEL_TYPES.MULTI_EMBEDDING; + // Check if this model is already added to the system + const existingModel = models.find( + (m) => + m.name === providerModel.id && + m.type === + (providerModel.model_type || deletingModelType) && + m.source === selectedSource + ); + const canEditEmbedding = isEmbeddingModel && existingModel; + return ( - + handleEmbeddingModelClick(providerModel) + : undefined + } + > {providerModel.id} @@ -898,25 +1042,33 @@ export const ModelDeleteDialog = ({ )} - {deletingModelType !== "embedding" && ( - - } - size="small" - onClick={(e) => { - e.stopPropagation(); // Prevent switch toggle - handleSettingsClick(providerModel); - }} - /> - - )} + {deletingModelType !== "embedding" && + deletingModelType !== MODEL_TYPES.MULTI_EMBEDDING && ( + + } + size="small" + onClick={(e) => { + e.stopPropagation(); // Prevent switch toggle + handleSettingsClick(providerModel); + }} + /> + + )} { + onChange={(value, event) => { + // Ensure toggling switch never triggers the row click handler + if ( + event && + typeof event.stopPropagation === "function" + ) { + event.stopPropagation(); + } setPendingSelectedProviderIds((prev) => { const next = new Set(prev); if (value) { @@ -941,75 +1093,94 @@ export const ModelDeleteDialog = ({ model.type === deletingModelType && model.source === selectedSource ) - .map((model) => ( - handleEditModel(model) - : undefined - } - className={`p-2 flex justify-between items-center hover:bg-gray-50 text-sm ${ - selectedSource === MODEL_SOURCES.OPENAI_API_COMPATIBLE - ? "cursor-pointer" - : "" - }`} - > - - - {model.displayName || model.name} ({model.name}) - - - { - e.stopPropagation(); - handleDeleteModel(model.displayName || model.name); - }} - disabled={ - deletingModels.has(model.displayName || model.name) || - model.type === MODEL_TYPES.STT || - model.type === MODEL_TYPES.TTS + .map((model) => { + const isEmbeddingModel = + model.type === MODEL_TYPES.EMBEDDING || + model.type === MODEL_TYPES.MULTI_EMBEDDING; + // Only allow clicking for batch-imported embedding models (not custom models) + const isBatchImportedEmbedding = + isEmbeddingModel && + selectedSource !== MODEL_SOURCES.OPENAI_API_COMPATIBLE; + // Custom models can still be clicked to edit full model config + const isCustomModelClickable = + selectedSource === MODEL_SOURCES.OPENAI_API_COMPATIBLE; + const isClickable = + isBatchImportedEmbedding || isCustomModelClickable; + + return ( + + isBatchImportedEmbedding + ? handleEmbeddingModelClick(model) + : handleEditModel(model) + : undefined } - className={`p-1 ${ - model.type === MODEL_TYPES.STT || - model.type === MODEL_TYPES.TTS - ? "text-gray-400 cursor-not-allowed" - : "text-red-500 hover:text-red-700" + className={`p-2 flex justify-between items-center hover:bg-gray-50 text-sm ${ + isClickable ? "cursor-pointer" : "" }`} - title={ - model.type === MODEL_TYPES.STT || - model.type === MODEL_TYPES.TTS - ? t("model.dialog.delete.unsupportedTypeHint") - : t("model.dialog.delete.deleteHint") - } > - {deletingModels.has(model.displayName || model.name) ? ( - + - - - - ) : ( - - )} - - - ))} + {model.displayName || model.name} ({model.name}) + + + { + e.stopPropagation(); + handleDeleteModel(model.displayName || model.name); + }} + disabled={ + deletingModels.has(model.displayName || model.name) || + model.type === MODEL_TYPES.STT || + model.type === MODEL_TYPES.TTS + } + className={`p-1 ${ + model.type === MODEL_TYPES.STT || + model.type === MODEL_TYPES.TTS + ? "text-gray-400 cursor-not-allowed" + : "text-red-500 hover:text-red-700" + }`} + title={ + model.type === MODEL_TYPES.STT || + model.type === MODEL_TYPES.TTS + ? t("model.dialog.delete.unsupportedTypeHint") + : t("model.dialog.delete.deleteHint") + } + > + {deletingModels.has(model.displayName || model.name) ? ( + + + + + ) : ( + + )} + + + ); + })} {models.filter( (model) => @@ -1094,6 +1265,57 @@ export const ModelDeleteDialog = ({ + + {/* Embedding Model Config Modal */} + { + setEmbeddingConfigModalVisible(false); + setSelectedEmbeddingModel(null); + }} + onOk={handleEmbeddingConfigSave} + cancelText={t("common.button.cancel")} + okText={t("common.button.save")} + confirmLoading={savingEmbeddingConfig} + destroyOnClose + > + + {/* Chunk Size Range */} + + + {t("modelConfig.slider.chunkingSize")} + + setChunkSizeRange(value)} + /> + + + {/* Concurrent Request Count */} + + + {t("modelConfig.input.chunkingBatchSize")} + + setChunkingBatchSize(e.target.value)} + /> + + + ); }; diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index feaeacad8..972f30d27 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -42,6 +42,7 @@ export const ModelEditDialog = ({ DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE, ] as [number, number], + chunkingBatchSize: "10", }); const [loading, setLoading] = useState(false); const [verifyingConnectivity, setVerifyingConnectivity] = useState(false); @@ -67,6 +68,7 @@ export const ModelEditDialog = ({ model.expectedChunkSize || DEFAULT_EXPECTED_CHUNK_SIZE, model.maximumChunkSize || DEFAULT_MAXIMUM_CHUNK_SIZE, ] as [number, number], + chunkingBatchSize: (model.chunkingBatchSize || 10).toString(), }); } }, [model]); @@ -172,6 +174,7 @@ export const ModelEditDialog = ({ ? { expectedChunkSize: form.chunkSizeRange[0], maximumChunkSize: form.chunkSizeRange[1], + chunkingBatchSize: parseInt(form.chunkingBatchSize) || 10, } : {}), }); @@ -300,6 +303,28 @@ export const ModelEditDialog = ({ )} + {/* Concurrent Request Count (Embedding model only) */} + {isEmbeddingModel && ( + + + {t("modelConfig.input.chunkingBatchSize")} + + + handleFormChange("chunkingBatchSize", e.target.value) + } + /> + + )} + {/* Connectivity verification area */} diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index e538fccfe..968cac51f 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -846,40 +846,82 @@ export const ModelConfigSection = forwardRef< justifyContent: "flex-start", paddingRight: 12, marginLeft: "4px", - height: LAYOUT_CONFIG.BUTTON_AREA_HEIGHT, + minHeight: LAYOUT_CONFIG.BUTTON_AREA_HEIGHT, }} > - - - {" "} - {t("modelConfig.button.syncModelEngine")} - - } - onClick={() => setIsAddModalOpen(true)} - > - {t("modelConfig.button.addCustomModel")} - - } - onClick={() => setIsDeleteModalOpen(true)} - > - {t("modelConfig.button.editCustomModel")} - - } - onClick={verifyModels} - loading={isVerifying} - > - {t("modelConfig.button.checkConnectivity")} - - + + + + + + + {t("modelConfig.button.syncModelEngine")} + + + {t("modelConfig.button.sync")} + + + + + + } + onClick={() => setIsAddModalOpen(true)} + style={{ width: "100%" }} + block + > + + {t("modelConfig.button.addCustomModel")} + + + {t("modelConfig.button.add")} + + + + + } + onClick={() => setIsDeleteModalOpen(true)} + style={{ width: "100%" }} + block + > + + {t("modelConfig.button.editCustomModel")} + + + {t("modelConfig.button.edit")} + + + + + } + onClick={verifyModels} + loading={isVerifying} + style={{ width: "100%" }} + block + > + + {t("modelConfig.button.checkConnectivity")} + + + {t("modelConfig.button.check")} + + + + ("models"); const [isSaving, setIsSaving] = useState(false); + + // Agent save confirmation states + const [showAgentSaveConfirm, setShowAgentSaveConfirm] = useState(false); + const [pendingCompleteAction, setPendingCompleteAction] = useState<(() => void) | null>(null); + const agentConfigRef = useRef(null); // Handle operations that require login const handleAuthRequired = () => { @@ -116,18 +122,6 @@ export default function Home() { setLoginPromptOpen(false); }; - // Handle login button click - const handleLoginClick = () => { - setLoginPromptOpen(false); - openLoginModal(); - }; - - // Handle register button click - const handleRegisterClick = () => { - setLoginPromptOpen(false); - openRegisterModal(); - }; - // Handle operations that require admin privileges const handleAdminRequired = () => { if (!isSpeedMode && user?.role !== "admin") { @@ -296,6 +290,20 @@ export default function Home() { }; const handleSetupComplete = () => { + // Check if we're on the agents step and if there are unsaved changes + if (currentSetupStep === "agents" && isAdmin && agentConfigRef.current) { + if (agentConfigRef.current.hasUnsavedChanges?.()) { + // Show save confirmation modal + setShowAgentSaveConfirm(true); + setPendingCompleteAction(() => () => { + setCurrentView("chat"); + saveView("chat"); + }); + return; + } + } + + // No unsaved changes, proceed directly setCurrentView("chat"); saveView("chat"); }; @@ -530,6 +538,7 @@ export default function Home() { {currentSetupStep === "agents" && isAdmin && ( > )} + + {/* Agent save confirmation modal for setup completion */} + { + // Reload data from backend to discard changes + await agentConfigRef.current?.reloadCurrentAgentData?.(); + setShowAgentSaveConfirm(false); + const action = pendingCompleteAction; + setPendingCompleteAction(null); + if (action) action(); + }} + onSave={async () => { + try { + setIsSaving(true); + await agentConfigRef.current?.saveAllChanges?.(); + setShowAgentSaveConfirm(false); + const action = pendingCompleteAction; + setPendingCompleteAction(null); + if (action) action(); + } catch (e) { + // errors are surfaced by underlying save + } finally { + setIsSaving(false); + } + }} + onClose={() => { + setShowAgentSaveConfirm(false); + setPendingCompleteAction(null); + }} + /> ); } diff --git a/frontend/app/[locale]/space/components/AgentCard.tsx b/frontend/app/[locale]/space/components/AgentCard.tsx index d047dceb1..b174218bd 100644 --- a/frontend/app/[locale]/space/components/AgentCard.tsx +++ b/frontend/app/[locale]/space/components/AgentCard.tsx @@ -1,9 +1,8 @@ "use client"; import React, { useState } from "react"; -import { useRouter } from "next/navigation"; import { useTranslation } from "react-i18next"; -import { App, Modal } from "antd"; +import { App } from "antd"; import { Trash2, Download, @@ -32,6 +31,7 @@ interface Agent { name: string; display_name: string; description: string; + author?: string; is_available: boolean; enabled?: boolean; } @@ -44,7 +44,6 @@ interface AgentCardProps { } export default function AgentCard({ agent, onRefresh, onChat, onEdit }: AgentCardProps) { - const router = useRouter(); const { t } = useTranslation("common"); const { message, modal } = App.useApp(); const { user, isSpeedMode } = useAuth(); @@ -199,6 +198,13 @@ export default function AgentCard({ agent, onRefresh, onChat, onEdit }: AgentCar {agent.display_name || agent.name} + {agent.author ? ( + + {t("market.by", { defaultValue: "By {{author}}", author: agent.author })} + + ) : ( + + )} {agent.description || t("space.noDescription", "No description")} diff --git a/frontend/components/agent/AgentImportWizard.tsx b/frontend/components/agent/AgentImportWizard.tsx index 12ba325cb..06159be23 100644 --- a/frontend/components/agent/AgentImportWizard.tsx +++ b/frontend/components/agent/AgentImportWizard.tsx @@ -1,6 +1,6 @@ "use client"; -import React, { useState, useEffect } from "react"; +import React, { useState, useEffect, useRef } from "react"; import { Modal, Steps, Button, Select, Input, Form, Tag, Space, Spin, App, Collapse, Radio } from "antd"; import { DownloadOutlined, CheckCircleOutlined, CloseCircleOutlined, PlusOutlined } from "@ant-design/icons"; import { useTranslation } from "react-i18next"; @@ -9,7 +9,7 @@ import { modelService } from "@/services/modelService"; import { getMcpServerList, addMcpServer, updateToolList } from "@/services/mcpService"; import { McpServer, AgentRefreshEvent } from "@/types/agentConfig"; import { ImportAgentData } from "@/hooks/useAgentImport"; -import { importAgent } from "@/services/agentConfigService"; +import { importAgent, checkAgentNameConflictBatch, regenerateAgentNameBatch } from "@/services/agentConfigService"; import log from "@/lib/logger"; export interface AgentImportWizardProps { @@ -53,6 +53,44 @@ const extractPromptHint = (value: string): string | undefined => { return match ? match[1] : undefined; }; +// Parse Markdown links in text and convert to React elements +const parseMarkdownLinks = (text: string): React.ReactNode[] => { + const linkRegex = /\[([^\]]+)\]\(([^)]+)\)/g; + const parts: React.ReactNode[] = []; + let lastIndex = 0; + let match; + let key = 0; + + while ((match = linkRegex.exec(text)) !== null) { + // Add text before the link + if (match.index > lastIndex) { + parts.push(text.substring(lastIndex, match.index)); + } + // Add the link + parts.push( + { + e.stopPropagation(); + }} + > + {match[1]} + + ); + lastIndex = match.index + match[0].length; + } + // Add remaining text + if (lastIndex < text.length) { + parts.push(text.substring(lastIndex)); + } + + return parts.length > 0 ? parts : [text]; +}; + export default function AgentImportWizard({ visible, onCancel, @@ -88,6 +126,28 @@ export default function AgentImportWizard({ const [installingMcp, setInstallingMcp] = useState>({}); const [isImporting, setIsImporting] = useState(false); + // Name conflict checking and renaming + // Structure: agentKey -> { hasConflict, conflictAgents, renamedName, renamedDisplayName } + const [agentNameConflicts, setAgentNameConflicts] = useState; + renamedName: string; + renamedDisplayName: string; + }>>({}); + const [checkingName, setCheckingName] = useState(false); + const [regeneratingAll, setRegeneratingAll] = useState(false); + // Track which agents have been successfully renamed (no conflicts) + const [successfullyRenamedAgents, setSuccessfullyRenamedAgents] = useState>(new Set()); + // Debounce timer for manual name changes - use ref to avoid stale closures + const nameCheckTimerRef = useRef(null); + // Store latest agentNameConflicts in ref to avoid stale closures in timer callbacks + const agentNameConflictsRef = useRef; + renamedName: string; + renamedDisplayName: string; + }>>({}); + // Helper: Refresh tools and agents after MCP changes const refreshToolsAndAgents = async () => { try { @@ -114,6 +174,22 @@ export default function AgentImportWizard({ } }, [visible]); + // Check name conflict immediately after file upload + useEffect(() => { + if (visible && initialData) { + checkNameConflict(); + } + }, [visible, initialData]); + + // Cleanup timer on unmount + useEffect(() => { + return () => { + if (nameCheckTimerRef.current) { + clearTimeout(nameCheckTimerRef.current); + } + }; + }, []); + // Parse agent data for config fields and MCP servers useEffect(() => { if (visible && initialData) { @@ -136,6 +212,232 @@ export default function AgentImportWizard({ setSelectedModelsByAgent(initialModels); }; + // Check name conflict for all agents (main agent + sub-agents) + const checkNameConflict = async () => { + if (!initialData?.agent_info) return; + + setCheckingName(true); + const conflicts: Record; + renamedName: string; + renamedDisplayName: string; + }> = {}; + + try { + // Check all agents in agent_info + const agentInfoMap = initialData.agent_info; + const items = Object.entries(agentInfoMap).map(([agentKey, agentInfo]: [string, any]) => ({ + key: agentKey, + name: agentInfo?.name || "", + display_name: agentInfo?.display_name, + })); + + const result = await checkAgentNameConflictBatch({ + items: items.map((item) => ({ + name: item.name, + display_name: item.display_name, + })), + }); + + if (!result.success || !Array.isArray(result.data)) { + log.warn("Skip name conflict check due to fetch failure"); + setAgentNameConflicts({}); + agentNameConflictsRef.current = {}; + setCheckingName(false); + return; + } + + result.data.forEach((res: any, idx: number) => { + const item = items[idx]; + const agentKey = item.key; + const hasNameConflict = res?.name_conflict || false; + const hasDisplayNameConflict = res?.display_name_conflict || false; + const conflictAgentsRaw = Array.isArray(res?.conflict_agents) ? res.conflict_agents : []; + // Deduplicate by name/display_name + const seen = new Set(); + const conflictAgents = conflictAgentsRaw.reduce((acc: Array<{ name?: string; display_name?: string }>, curr: any) => { + const key = `${curr?.name || ""}||${curr?.display_name || ""}`; + if (seen.has(key)) return acc; + seen.add(key); + acc.push({ name: curr?.name, display_name: curr?.display_name }); + return acc; + }, []); + + const hasConflict = hasNameConflict || hasDisplayNameConflict; + conflicts[agentKey] = { + hasConflict, + conflictAgents, + renamedName: item.name, + renamedDisplayName: item.display_name || "", + }; + }); + + setAgentNameConflicts(conflicts); + + // Update successfully renamed agents based on initial check + // Only add to successfullyRenamedAgents if there was a conflict that was resolved + // For initial check, we don't add anything since no renaming has happened yet + setSuccessfullyRenamedAgents((prev) => { + const next = new Set(prev); + // Don't modify on initial check - only track agents that were successfully renamed + return next; + }); + } catch (error) { + log.error("Failed to check name conflicts:", error); + } finally { + setCheckingName(false); + } + }; + + // Check name conflict for a specific agent after renaming + const checkSingleAgentConflict = async (agentKey: string, name: string, displayName?: string) => { + if (!initialData?.agent_info) return; + + try { + const result = await checkAgentNameConflictBatch({ + items: [ + { + name, + display_name: displayName, + }, + ], + }); + + if (!result.success || !Array.isArray(result.data) || !result.data[0]) { + return; + } + + const checkResult = result.data[0]; + const hasNameConflict = checkResult?.name_conflict || false; + const hasDisplayNameConflict = checkResult?.display_name_conflict || false; + const hasConflict = hasNameConflict || hasDisplayNameConflict; + const conflictAgentsRaw = Array.isArray(checkResult?.conflict_agents) ? checkResult.conflict_agents : []; + + // Deduplicate by name/display_name + const seen = new Set(); + const conflictAgents = conflictAgentsRaw.reduce((acc: Array<{ name?: string; display_name?: string }>, curr: any) => { + const key = `${curr?.name || ""}||${curr?.display_name || ""}`; + if (seen.has(key)) return acc; + seen.add(key); + acc.push({ name: curr?.name, display_name: curr?.display_name }); + return acc; + }, []); + + setAgentNameConflicts((prev) => { + const next = { ...prev }; + if (!next[agentKey]) { + const agentInfo = initialData.agent_info[agentKey] as any; + next[agentKey] = { + hasConflict: false, + conflictAgents: [], + renamedName: agentInfo?.name || "", + renamedDisplayName: agentInfo?.display_name || "", + }; + } + next[agentKey] = { + ...next[agentKey], + hasConflict, + conflictAgents, + renamedName: name, + renamedDisplayName: displayName || "", + }; + agentNameConflictsRef.current = next; + return next; + }); + + // Update success status + setSuccessfullyRenamedAgents((prev) => { + const next = new Set(prev); + if (hasConflict) { + next.delete(agentKey); + } else { + next.add(agentKey); + } + return next; + }); + + return hasConflict; + } catch (error) { + log.error("Failed to check single agent conflict:", error); + return true; // Assume conflict on error to be safe + } + }; + + // One-click regenerate all conflicted agents using selected model(s) + const handleRegenerateAll = async () => { + if (!initialData?.agent_info) return; + + const agentsWithConflicts = Object.entries(agentNameConflicts).filter( + ([_, conflict]) => conflict.hasConflict + ); + if (agentsWithConflicts.length === 0) return; + + setRegeneratingAll(true); + try { + const payload = { + items: agentsWithConflicts.map(([agentKey, conflict]) => { + const agentInfo = initialData.agent_info[agentKey] as any; + return { + agent_id: agentInfo?.agent_id, + name: conflict.renamedName || agentInfo?.name || "", + display_name: conflict.renamedDisplayName || agentInfo?.display_name || "", + task_description: agentInfo?.business_description || agentInfo?.description || "", + language: "zh", + }; + }), + }; + + const result = await regenerateAgentNameBatch(payload); + + if (!result.success || !Array.isArray(result.data)) { + message.error(result.message || t("market.install.error.nameRegenerationFailed", "Failed to regenerate name")); + return; + } + + const regenerated = result.data as Array<{ name?: string; display_name?: string }>; + + // Update conflicts state with regenerated names + setAgentNameConflicts((prev) => { + const next = { ...prev }; + agentsWithConflicts.forEach(([agentKey, conflict], idx) => { + const agentInfo = initialData.agent_info[agentKey] as any; + const data = regenerated[idx] || {}; + next[agentKey] = { + ...next[agentKey], + renamedName: data.name || conflict.renamedName || agentInfo?.name || "", + renamedDisplayName: + data.display_name || conflict.renamedDisplayName || agentInfo?.display_name || "", + }; + }); + agentNameConflictsRef.current = next; + return next; + }); + + // Re-check conflicts for all regenerated agents + const checkPromises = agentsWithConflicts.map(async ([agentKey, conflict], idx) => { + const data = regenerated[idx] || {}; + const newName = data.name || conflict.renamedName || ""; + const newDisplayName = data.display_name || conflict.renamedDisplayName || ""; + return checkSingleAgentConflict(agentKey, newName, newDisplayName); + }); + + const checkResults = await Promise.all(checkPromises); + const allResolved = checkResults.every((hasConflict) => !hasConflict); + + if (allResolved) { + message.success(t("market.install.success.nameRegeneratedAndResolved", "Agent names regenerated successfully and all conflicts resolved")); + } else { + message.success(t("market.install.success.nameRegenerated", "Agent name regenerated successfully")); + } + } catch (error) { + log.error("Failed to regenerate agent names:", error); + message.error(t("market.install.error.nameRegenerationFailed", "Failed to regenerate name")); + } finally { + setRegeneratingAll(false); + } + }; + const loadLLMModels = async () => { setLoadingModels(true); try { @@ -336,7 +638,11 @@ export default function AgentImportWizard({ }; const handleNext = () => { - if (currentStep === 0) { + const currentStepKey = steps[currentStep]?.key; + + if (currentStepKey === "rename") { + // no mandatory name check + } else if (currentStepKey === "model") { // Step 1: Model selection validation if (modelSelectionMode === "unified") { if (!selectedModelId || !selectedModelName) { @@ -357,7 +663,7 @@ export default function AgentImportWizard({ } } } - } else if (currentStep === 1) { + } else if (currentStepKey === "config") { // Step 2: Config fields validation const emptyFields = configFields.filter(field => !configValues[field.valueKey]?.trim()); if (emptyFields.length > 0) { @@ -409,7 +715,18 @@ export default function AgentImportWizard({ // Clone agent data structure const agentJson = JSON.parse(JSON.stringify(initialData)); - const mainAgentId = String(initialData.agent_id); + + // Update all agents' name/display_name if renamed + Object.entries(agentNameConflicts).forEach(([agentKey, conflict]) => { + if (agentJson.agent_info[agentKey]) { + if (conflict.renamedName) { + agentJson.agent_info[agentKey].name = conflict.renamedName; + } + if (conflict.renamedDisplayName) { + agentJson.agent_info[agentKey].display_name = conflict.renamedDisplayName; + } + } + }); // Update model information based on selection mode if (modelSelectionMode === "unified") { @@ -495,11 +812,32 @@ export default function AgentImportWizard({ setConfigValues({}); setMcpServers([]); setIsImporting(false); + setAgentNameConflicts({}); + agentNameConflictsRef.current = {}; + setCheckingName(false); + setRegeneratingAll(false); + setSuccessfullyRenamedAgents(new Set()); + if (nameCheckTimerRef.current) { + clearTimeout(nameCheckTimerRef.current); + nameCheckTimerRef.current = null; + } onCancel(); }; // Filter only required steps for navigation + // Show rename step if name conflict check is complete and there are any agents that had conflicts + // (even if all conflicts are now resolved, we still want to show the step so users can see the success state) + const hasAnyAgentsWithConflicts = !checkingName && ( + // Check if any agent has a current conflict + Object.values(agentNameConflicts).some(conflict => conflict.hasConflict) || + // OR if any agent was successfully renamed (meaning it had a conflict that was resolved) + successfullyRenamedAgents.size > 0 + ); const steps = [ + hasAnyAgentsWithConflicts && { + key: "rename", + title: t("market.install.step.rename", "Rename Agent"), + }, { key: "model", title: t("market.install.step.model", "Select Model"), @@ -516,9 +854,16 @@ export default function AgentImportWizard({ // Check if can proceed to next step const canProceed = () => { + // Disable buttons while checking name conflict + if (checkingName) { + return false; + } + const currentStepKey = steps[currentStep]?.key; - if (currentStepKey === "model") { + if (currentStepKey === "rename") { + return true; + } else if (currentStepKey === "model") { if (modelSelectionMode === "unified") { return selectedModelId !== null && selectedModelName !== ""; } else { @@ -545,9 +890,237 @@ export default function AgentImportWizard({ }; const renderStepContent = () => { + // Show loading state while checking name conflict + if (checkingName) { + return ( + + + + {t("market.install.checkingName", "Checking agent name...")} + + + ); + } + const currentStepKey = steps[currentStep]?.key; - if (currentStepKey === "model") { + if (currentStepKey === "rename") { + // Get all agents that had conflicts (including resolved ones) + // Show all agents in agentNameConflicts - they either have conflicts or were successfully renamed + const allAgentsWithConflicts = Object.entries(agentNameConflicts) + .filter(([agentKey, conflict]) => { + // Show agent if: + // 1. It currently has a conflict, OR + // 2. It was successfully renamed (in successfullyRenamedAgents), OR + // 3. It's in agentNameConflicts (meaning it was checked and had a conflict at some point) + // We show all agents in agentNameConflicts to keep the UI consistent + return true; // Show all agents that were checked + }) + .sort(([keyA], [keyB]) => { + // Main agent first + const mainAgentId = String(initialData?.agent_id); + if (keyA === mainAgentId) return -1; + if (keyB === mainAgentId) return 1; + return 0; + }); + + // Get agents that still have conflicts + const agentsWithConflicts = allAgentsWithConflicts.filter( + ([agentKey, conflict]) => conflict.hasConflict + ); + + // If no agents had conflicts at all, don't show rename step + if (allAgentsWithConflicts.length === 0) { + return null; + } + + // Check if all conflicts are resolved + const allConflictsResolved = agentsWithConflicts.length === 0 && allAgentsWithConflicts.length > 0; + const hasResolvedAgents = allAgentsWithConflicts.some( + ([agentKey]) => successfullyRenamedAgents.has(agentKey) + ); + + return ( + + {allConflictsResolved ? ( + + + + + {t("market.install.rename.success", "All agent name conflicts have been resolved. You can proceed to the next step.")} + + + + ) : ( + + {hasResolvedAgents && ( + + + + + {t("market.install.rename.partialSuccess", "Some agents have been successfully renamed.")} + + + + )} + + {t("market.install.rename.warning", "The agent name or display name conflicts with existing agents. Please rename to proceed.")} + + + {t("market.install.rename.oneClickDesc", "You can manually edit the names, or click one-click rename to let the selected model regenerate names for all conflicted agents.")} + + + {t("market.install.rename.note", "Note: If you proceed without renaming, the agent will be created but marked as unavailable due to name conflicts. You can rename it later in the agent list.")} + + + {t("market.install.rename.oneClick", "One-click Rename")} + + + )} + + + {allAgentsWithConflicts.map(([agentKey, conflict]) => { + const agentInfo = initialData?.agent_info?.[agentKey] as any; + const agentDisplayName = agentInfo?.display_name || agentInfo?.name || `${t("market.install.agent.defaultName", "Agent")} ${agentKey}`; + const isMainAgent = agentKey === String(initialData?.agent_id); + const originalName = agentInfo?.name || ""; + const originalDisplayName = agentInfo?.display_name || ""; + + return ( + + + + {isMainAgent && {t("market.install.agent.main", "Main")}} + {agentDisplayName} + + + + {successfullyRenamedAgents.has(agentKey) ? ( + + + + + {t("market.install.rename.agentResolved", "This agent's name conflict has been resolved.")} + + + + ) : conflict.hasConflict && conflict.conflictAgents.length > 0 && ( + + + {t("market.install.rename.conflictAgents", "Conflicting agents:")} + + + {conflict.conflictAgents.map((agent, idx) => ( + + {[agent.name, agent.display_name].filter(Boolean).join(" / ")} + + ))} + + + )} + + + + {t("market.install.rename.name", "Agent Name")} + + { + const newName = e.target.value; + setAgentNameConflicts(prev => { + const updated = { + ...prev, + [agentKey]: { + ...prev[agentKey], + renamedName: newName, + }, + }; + + // Clear existing timer + if (nameCheckTimerRef.current) { + clearTimeout(nameCheckTimerRef.current); + } + + // Set new timer for debounced check (500ms delay) + nameCheckTimerRef.current = setTimeout(() => { + // Read latest value from ref when timer fires + const currentConflict = agentNameConflictsRef.current[agentKey]; + if (currentConflict) { + checkSingleAgentConflict( + agentKey, + currentConflict.renamedName, + currentConflict.renamedDisplayName + ); + } + }, 500); + + agentNameConflictsRef.current = updated; + return updated; + }); + }} + placeholder={originalName} + size="large" + disabled={regeneratingAll} + /> + + + + + {t("market.install.rename.displayName", "Display Name")} + + { + const newDisplayName = e.target.value; + setAgentNameConflicts(prev => { + const updated = { + ...prev, + [agentKey]: { + ...prev[agentKey], + renamedDisplayName: newDisplayName, + }, + }; + + // Clear existing timer + if (nameCheckTimerRef.current) { + clearTimeout(nameCheckTimerRef.current); + } + + // Set new timer for debounced check (500ms delay) + nameCheckTimerRef.current = setTimeout(() => { + // Read latest value from ref when timer fires + const currentConflict = agentNameConflictsRef.current[agentKey]; + if (currentConflict) { + checkSingleAgentConflict( + agentKey, + currentConflict.renamedName, + currentConflict.renamedDisplayName + ); + } + }, 500); + + agentNameConflictsRef.current = updated; + return updated; + }); + }} + placeholder={originalDisplayName} + size="large" + disabled={regeneratingAll} + /> + + + ); + })} + + + + ); + } else if (currentStepKey === "model") { return ( {/* Agent Info - Title and Description Style */} @@ -819,47 +1392,22 @@ export default function AgentImportWizard({ {mcpServers.map((mcp, index) => ( - - - - - {mcp.mcp_server_name} - - {mcp.isInstalled ? ( - } color="success" className="text-sm"> - {t("market.install.mcp.installed", "Installed")} - - ) : ( - } color="default" className="text-sm"> - {t("market.install.mcp.notInstalled", "Not Installed")} - - )} - - - - - MCP URL: - - {(mcp.isUrlEditable || !mcp.isInstalled) ? ( - handleMcpUrlChange(index, e.target.value)} - placeholder={mcp.isUrlEditable - ? t("market.install.mcp.urlPlaceholder", "Enter MCP server URL") - : mcp.mcp_url - } - size="middle" - disabled={mcp.isInstalled} - style={{ maxWidth: "400px" }} - /> - ) : ( - - {mcp.editedUrl || mcp.mcp_url} - - )} - + + + + {mcp.mcp_server_name} + + {mcp.isInstalled ? ( + } color="success" className="text-xs"> + {t("market.install.mcp.installed", "Installed")} + + ) : ( + } color="default" className="text-xs"> + {t("market.install.mcp.notInstalled", "Not Installed")} + + )} {!mcp.isInstalled && ( @@ -876,6 +1424,44 @@ export default function AgentImportWizard({ )} + + + + + MCP URL: + + {(mcp.isUrlEditable || !mcp.isInstalled) ? ( + handleMcpUrlChange(index, e.target.value)} + placeholder={mcp.isUrlEditable + ? t("market.install.mcp.urlPlaceholder", "Enter MCP server URL") + : mcp.mcp_url + } + size="middle" + disabled={mcp.isInstalled} + style={{ maxWidth: "400px" }} + className={mcp.isUrlEditable && needsConfig(mcp.mcp_url) ? "bg-gray-100 dark:bg-gray-800" : ""} + /> + ) : ( + + {mcp.editedUrl || mcp.mcp_url} + + )} + + {/* Show hint if URL needs configuration */} + {mcp.isUrlEditable && needsConfig(mcp.mcp_url) && (() => { + const hint = extractPromptHint(mcp.mcp_url); + const hintText = hint || t("market.install.mcp.defaultConfigHint", "Please enter the MCP server URL"); + return ( + + + {parseMarkdownLinks(hintText)} + + + ); + })()} + ))} @@ -946,7 +1532,7 @@ export default function AgentImportWizard({ className="mb-6" /> - + {renderStepContent()} diff --git a/frontend/components/ui/markdownRenderer.tsx b/frontend/components/ui/markdownRenderer.tsx index e192d5189..12bcc7eeb 100644 --- a/frontend/components/ui/markdownRenderer.tsx +++ b/frontend/components/ui/markdownRenderer.tsx @@ -15,6 +15,7 @@ import * as TooltipPrimitive from "@radix-ui/react-tooltip"; import { visit } from "unist-util-visit"; import { SearchResult } from "@/types/chat"; +import { resolveS3UrlToDataUrl } from "@/services/storageService"; import { Tooltip, TooltipContent, @@ -31,8 +32,267 @@ interface MarkdownRendererProps { showDiagramToggle?: boolean; onCitationHover?: () => void; enableMultimodal?: boolean; + /** + * When true, resolve s3:// media URLs in markdown into data URLs (base64) + * so that images can still be displayed after page refresh or when + * the original S3 URL is not directly accessible by the browser. + */ + resolveS3Media?: boolean; } +// Simple in-memory cache to avoid refetching the same S3 object multiple times +const s3MediaCache = new Map(); +const mediaObjectUrlCache = new Map(); +const mediaObjectUrlPromiseCache = new Map>(); +const S3_MEDIA_SESSION_PREFIX = "s3-media-cache:"; + +const isBrowserEnvironment = typeof window !== "undefined"; + +const getSessionCachedValue = (key: string): string | null => { + if (!isBrowserEnvironment) { + return null; + } + try { + return window.sessionStorage.getItem(key); + } catch { + return null; + } +}; + +const getCachedMediaSrc = (src: string): string | null => { + const cached = s3MediaCache.get(src); + if (cached) { + return cached; + } + const sessionValue = getSessionCachedValue(src); + if (sessionValue) { + s3MediaCache.set(src, sessionValue); + return sessionValue; + } + return null; +}; + +const setCachedMediaSrc = (src: string, value: string) => { + s3MediaCache.set(src, value); + if (!isBrowserEnvironment) { + return; + } + try { + window.sessionStorage.setItem(`${S3_MEDIA_SESSION_PREFIX}${src}`, value); + } catch { + // Ignore storage quota errors silently. + } +}; + +const setCachedObjectUrl = (src: string, objectUrl: string | null) => { + if (!objectUrl) { + return; + } + const existing = mediaObjectUrlCache.get(src); + if (existing && existing !== objectUrl) { + URL.revokeObjectURL(existing); + } + mediaObjectUrlCache.set(src, objectUrl); +}; + +const resolveMediaToObjectUrl = async ( + src: string, + { resolveS3 }: { resolveS3: boolean } +): Promise => { + try { + if (src.startsWith("blob:")) { + return src; + } + + if (src.startsWith("s3://")) { + if (!resolveS3) { + return null; + } + const dataUrl = await resolveS3UrlToDataUrl(src); + if (!dataUrl) { + return null; + } + const response = await fetch(dataUrl); + if (!response.ok) { + return null; + } + const blob = await response.blob(); + return URL.createObjectURL(blob); + } + + if ( + src.startsWith("http://") || + src.startsWith("https://") || + src.startsWith("/api/") || + src.startsWith("/nexent/") || + src.startsWith("/attachments/") || + src.startsWith("/") + ) { + const response = await fetch(src); + if (!response.ok) { + return null; + } + const blob = await response.blob(); + return URL.createObjectURL(blob); + } + + if (src.startsWith("data:")) { + const response = await fetch(src); + if (!response.ok) { + return null; + } + const blob = await response.blob(); + return URL.createObjectURL(blob); + } + + return null; + } catch { + return null; + } +}; + +const usePrefetchedMediaSource = ( + src?: string, + options?: { enable?: boolean; resolveS3?: boolean } +) => { + const shouldPrefetch = + Boolean( + options?.enable && + src && + typeof src === "string" && + !src.startsWith("blob:") && + (src.startsWith("s3://") || + src.startsWith("http://") || + src.startsWith("https://") || + src.startsWith("/")) + ) || false; + + const [resolvedSrc, setResolvedSrc] = React.useState(() => { + if (!src || typeof src !== "string") { + return null; + } + if (!shouldPrefetch) { + return src; + } + return mediaObjectUrlCache.get(src) ?? null; + }); + + React.useEffect(() => { + if (!src || typeof src !== "string") { + setResolvedSrc(null); + return; + } + + if (!shouldPrefetch) { + setResolvedSrc(src); + return; + } + + const cached = mediaObjectUrlCache.get(src); + if (cached) { + setResolvedSrc(cached); + return; + } + + let cancelled = false; + + const promise = + mediaObjectUrlPromiseCache.get(src) ?? + resolveMediaToObjectUrl(src, { + resolveS3: options?.resolveS3 ?? true, + }); + + mediaObjectUrlPromiseCache.set(src, promise); + + promise + .then((objectUrl) => { + if (cancelled) { + return; + } + if (!objectUrl) { + setResolvedSrc(null); + return; + } + setCachedObjectUrl(src, objectUrl); + setResolvedSrc(objectUrl); + }) + .catch(() => { + if (!cancelled) { + setResolvedSrc(null); + } + }) + .finally(() => { + mediaObjectUrlPromiseCache.delete(src); + }); + + return () => { + cancelled = true; + }; + }, [options?.resolveS3, shouldPrefetch, src]); + + return resolvedSrc; +}; + +const useResolvedS3Media = (src?: string, shouldResolve?: boolean) => { + const cachedInitial = + typeof src === "string" && src.startsWith("s3://") + ? getCachedMediaSrc(src) + : null; + const initialValue = + typeof src === "string" + ? !shouldResolve || !src.startsWith("s3://") + ? src + : cachedInitial + : null; + const [resolvedSrc, setResolvedSrc] = React.useState( + initialValue + ); + + React.useEffect(() => { + if (!src || typeof src !== "string") { + setResolvedSrc(null); + return; + } + + if (!shouldResolve || !src.startsWith("s3://")) { + setResolvedSrc(src); + return; + } + + const cached = getCachedMediaSrc(src); + if (cached) { + setResolvedSrc(cached); + return; + } + + let cancelled = false; + + resolveS3UrlToDataUrl(src) + .then((dataUrl) => { + if (cancelled) { + return; + } + if (dataUrl) { + setCachedMediaSrc(src, dataUrl); + setResolvedSrc(dataUrl); + } else { + setResolvedSrc(null); + } + }) + .catch(() => { + if (!cancelled) { + setResolvedSrc(null); + } + }); + + return () => { + cancelled = true; + }; + }, [src, shouldResolve]); + + return resolvedSrc; +}; + const VIDEO_EXTENSIONS = [".mp4", ".webm", ".ogg", ".mov", ".m4v"]; const extractExtension = (value: string): string => { @@ -519,20 +779,16 @@ const ImageWithErrorHandling: React.FC = React.memo ImageWithErrorHandling.displayName = "ImageWithErrorHandling"; -export const MarkdownRenderer: React.FC = ({ - content, - className, - searchResults = [], - showDiagramToggle = true, - onCitationHover, - enableMultimodal = true, -}) => { +/** + * Render a code block with syntax highlighting, language label, and copy button + * This is exported for use in other components that need to render code blocks directly + */ +export const CodeBlock: React.FC<{ + codeContent: string; + language?: string; +}> = ({ codeContent, language = "python" }) => { const { t } = useTranslation("common"); - - // Convert LaTeX delimiters to markdown math delimiters - const processedContent = convertLatexDelimiters(content); - - // Customize code block style with light gray background + const customStyle = { ...oneLight, 'pre[class*="language-"]': { @@ -569,6 +825,47 @@ export const MarkdownRenderer: React.FC = ({ }, }; + const cleanedContent = codeContent.replace(/^\n+|\n+$/g, ""); + + return ( + + + + {language} + + + + + + {cleanedContent} + + + + ); +}; + +export const MarkdownRenderer: React.FC = ({ + content, + className, + searchResults = [], + showDiagramToggle = true, + onCitationHover, + enableMultimodal = true, + resolveS3Media = false, +}) => { + const { t } = useTranslation("common"); + + // Convert LaTeX delimiters to markdown math delimiters + const processedContent = convertLatexDelimiters(content); + const renderCodeFallback = (text: string, key?: React.Key) => ( = ({ return ; }; + const ImageResolver: React.FC<{ src?: string; alt?: string | null }> = ({ + src, + alt, + }) => { + const resolvedSrc = useResolvedS3Media( + typeof src === "string" ? src : undefined, + resolveS3Media + ); + + if (!enableMultimodal) { + return renderMediaFallback(src, alt); + } + + if (!resolvedSrc) { + return renderMediaFallback(src, alt); + } + + if (isVideoUrl(resolvedSrc)) { + return renderVideoElement({ src: resolvedSrc, alt }); + } + + return ; + }; + // Modified processText function logic const processText = (text: string) => { if (typeof text !== "string") return text; @@ -865,37 +1186,7 @@ export const MarkdownRenderer: React.FC = ({ return ; } if (!inline) { - return ( - - - - {match[1]} - - - - - - {codeContent} - - - - ); + return ; } } } catch (error) { @@ -908,21 +1199,9 @@ export const MarkdownRenderer: React.FC = ({ ); }, // Image - img: ({ src, alt }: any) => { - if (!enableMultimodal) { - return renderMediaFallback(src, alt); - } - - if (isVideoUrl(src)) { - return renderVideoElement({ src, alt }); - } - - if (!src || typeof src !== "string") { - return null; - } - - return ; - }, + img: ({ src, alt }: any) => ( + + ), // Video video: ({ children, ...props }: any) => { const directSrc = props?.src; diff --git a/frontend/const/chatConfig.ts b/frontend/const/chatConfig.ts index df7b65c92..73cd19aed 100644 --- a/frontend/const/chatConfig.ts +++ b/frontend/const/chatConfig.ts @@ -111,6 +111,7 @@ messageTypes: { // Content type constants for last content type tracking contentTypes: { MODEL_OUTPUT: "model_output" as const, + MODEL_OUTPUT_CODE: "model_output_code" as const, PARSING: "parsing" as const, EXECUTION: "execution" as const, AGENT_NEW_RUN: "agent_new_run" as const, diff --git a/frontend/const/marketConfig.ts b/frontend/const/marketConfig.ts new file mode 100644 index 000000000..6de8d1f48 --- /dev/null +++ b/frontend/const/marketConfig.ts @@ -0,0 +1,36 @@ +// ========== Market Configuration Constants ========== + +/** + * Default icons for market agent categories + * Maps category name field to their corresponding icons + */ +export const MARKET_CATEGORY_ICONS: Record = { + research: "🔬", + content: "✍️", + development: "💻", + business: "📈", + automation: "⚙️", + education: "📚", + communication: "💬", + data: "📊", + creative: "🎨", + other: "📦", +} as const; + +/** + * Get icon for a category by name field + * @param categoryName - Category name field (e.g., "research", "content") + * @param fallbackIcon - Fallback icon if category not found (default: 📦) + * @returns Icon emoji string + */ +export function getCategoryIcon( + categoryName: string | null | undefined, + fallbackIcon: string = "📦" +): string { + if (!categoryName) { + return fallbackIcon; + } + + return MARKET_CATEGORY_ICONS[categoryName] || fallbackIcon; +} + diff --git a/frontend/hooks/useAgentImport.md b/frontend/hooks/useAgentImport.md deleted file mode 100644 index 52b14aa78..000000000 --- a/frontend/hooks/useAgentImport.md +++ /dev/null @@ -1,245 +0,0 @@ -# useAgentImport Hook - -Unified agent import hook for handling agent imports across the application. - -## Overview - -This hook provides a consistent interface for importing agents from different sources: -- File upload (used in Agent Development and Agent Space) -- Direct data (used in Agent Market) - -All import operations ultimately call the same backend `/agent/import` endpoint. - -## Usage - -### Basic Import - -```typescript -import { useAgentImport } from "@/hooks/useAgentImport"; - -function MyComponent() { - const { isImporting, importFromFile, importFromData, error } = useAgentImport({ - onSuccess: () => { - console.log("Import successful!"); - }, - onError: (error) => { - console.error("Import failed:", error); - }, - }); - - // ... -} -``` - -### Import from File (SubAgentPool, SpaceContent) - -```typescript -const handleFileImport = async (file: File) => { - try { - await importFromFile(file); - // Success handled by onSuccess callback - } catch (error) { - // Error handled by onError callback - } -}; - -// In file input handler - { - const file = e.target.files?.[0]; - if (file) { - handleFileImport(file); - } - }} -/> -``` - -### Import from Data (Market) - -```typescript -const handleMarketImport = async (agentDetails: MarketAgentDetail) => { - // Prepare import data from agent details - const importData = { - agent_id: agentDetails.agent_id, - agent_info: agentDetails.agent_json.agent_info, - mcp_info: agentDetails.agent_json.mcp_info, - }; - - try { - await importFromData(importData); - // Success handled by onSuccess callback - } catch (error) { - // Error handled by onError callback - } -}; -``` - -## Integration Examples - -### 1. SubAgentPool Component - -```typescript -// In SubAgentPool.tsx -import { useAgentImport } from "@/hooks/useAgentImport"; - -export default function SubAgentPool({ onImportSuccess }: Props) { - const { isImporting, importFromFile } = useAgentImport({ - onSuccess: () => { - message.success(t("agent.import.success")); - onImportSuccess?.(); - }, - onError: (error) => { - message.error(error.message); - }, - }); - - const handleImportClick = () => { - const input = document.createElement("input"); - input.type = "file"; - input.accept = ".json"; - input.onchange = async (e) => { - const file = (e.target as HTMLInputElement).files?.[0]; - if (file) { - await importFromFile(file); - } - }; - input.click(); - }; - - return ( - - {isImporting ? t("importing") : t("import")} - - ); -} -``` - -### 2. SpaceContent Component - -```typescript -// In SpaceContent.tsx -import { useAgentImport } from "@/hooks/useAgentImport"; - -export function SpaceContent({ onRefresh }: Props) { - const { isImporting, importFromFile } = useAgentImport({ - onSuccess: () => { - message.success(t("space.import.success")); - onRefresh(); // Reload agent list - }, - }); - - const handleImportAgent = () => { - const input = document.createElement("input"); - input.type = "file"; - input.accept = ".json"; - input.onchange = async (e) => { - const file = (e.target as HTMLInputElement).files?.[0]; - if (file) { - await importFromFile(file); - } - }; - input.click(); - }; - - return ( - - {isImporting ? "Importing..." : "Import Agent"} - - ); -} -``` - -### 3. AgentInstallModal (Market) - -```typescript -// In AgentInstallModal.tsx -import { useAgentImport } from "@/hooks/useAgentImport"; - -export default function AgentInstallModal({ - agentDetails, - onComplete -}: Props) { - const { isImporting, importFromData } = useAgentImport({ - onSuccess: () => { - message.success(t("market.install.success")); - onComplete(); - }, - }); - - const handleInstall = async () => { - // Prepare configured data - const importData = prepareImportData(agentDetails, userConfig); - await importFromData(importData); - }; - - return ( - - Install - - ); -} -``` - -## API Reference - -### Parameters - -```typescript -interface UseAgentImportOptions { - onSuccess?: () => void; // Called on successful import - onError?: (error: Error) => void; // Called on import error - forceImport?: boolean; // Force import even if duplicate names exist -} -``` - -### Return Value - -```typescript -interface UseAgentImportResult { - isImporting: boolean; // Import in progress - importFromFile: (file: File) => Promise; // Import from file - importFromData: (data: ImportAgentData) => Promise; // Import from data - error: Error | null; // Last error (if any) -} -``` - -### Data Structure - -```typescript -interface ImportAgentData { - agent_id: number; - agent_info: Record; - mcp_info?: Array<{ - mcp_server_name: string; - mcp_url: string; - }>; -} -``` - -## Error Handling - -The hook handles errors in two ways: - -1. **Via onError callback** - Preferred method for user-facing error messages -2. **Via thrown exceptions** - For custom error handling in specific cases - -Both approaches are supported to allow flexibility in different use cases. - -## Implementation Notes - -- File content is read as text and parsed as JSON -- Data structure validation is performed before calling the backend -- The backend `/agent/import` endpoint is called with the prepared data -- All logging uses the centralized `log` utility from `@/lib/logger` - diff --git a/frontend/hooks/useAgentImport.ts b/frontend/hooks/useAgentImport.ts index f0f33add4..0aff99e82 100644 --- a/frontend/hooks/useAgentImport.ts +++ b/frontend/hooks/useAgentImport.ts @@ -1,5 +1,9 @@ import { useState } from "react"; -import { importAgent } from "@/services/agentConfigService"; +import { + checkAgentNameConflictBatch, + importAgent, + regenerateAgentNameBatch, +} from "@/services/agentConfigService"; import log from "@/lib/logger"; export interface ImportAgentData { @@ -15,6 +19,19 @@ export interface UseAgentImportOptions { onSuccess?: () => void; onError?: (error: Error) => void; forceImport?: boolean; + /** + * Optional: handle name/display_name conflicts before import + * Caller can resolve by returning new name or choosing to continue/terminate + */ + onNameConflictResolve?: (payload: { + name: string; + displayName?: string; + conflictAgents: Array<{ id: string; name?: string; display_name?: string }>; + regenerateWithLLM: () => Promise<{ + name?: string; + displayName?: string; + }>; + }) => Promise<{ proceed: boolean; name?: string; displayName?: string }>; } export interface UseAgentImportResult { @@ -111,6 +128,30 @@ export function useAgentImport( * Core import logic - calls backend API */ const importAgentData = async (data: ImportAgentData): Promise => { + // Step 1: check name/display name conflicts before import (only check main agent name and display name) + const mainAgent = data.agent_info?.[String(data.agent_id)]; + if (mainAgent?.name) { + const conflictHandled = await ensureNameNotDuplicated( + mainAgent.name, + mainAgent.display_name, + mainAgent.description || mainAgent.business_description + ); + + if (!conflictHandled.proceed) { + throw new Error( + "Agent name/display name conflicts with existing agent; import cancelled." + ); + } + + // if user chooses to modify name, write back to import data + if (conflictHandled.name) { + mainAgent.name = conflictHandled.name; + } + if (conflictHandled.displayName) { + mainAgent.display_name = conflictHandled.displayName; + } + } + const result = await importAgent(data, { forceImport }); if (!result.success) { @@ -142,6 +183,80 @@ export function useAgentImport( }); }; + /** + * Frontend side name conflict validation logic + */ + const ensureNameNotDuplicated = async ( + name: string, + displayName?: string, + taskDescription?: string + ): Promise<{ proceed: boolean; name?: string; displayName?: string }> => { + try { + const checkResp = await checkAgentNameConflictBatch({ + items: [ + { + name, + display_name: displayName, + }, + ], + }); + if (!checkResp.success || !Array.isArray(checkResp.data)) { + log.warn("Skip name conflict check due to fetch failure"); + return { proceed: true }; + } + + const first = checkResp.data[0] || {}; + const { name_conflict, display_name_conflict, conflict_agents } = first; + + if (!name_conflict && !display_name_conflict) { + return { proceed: true }; + } + + const regenerateWithLLM = async () => { + const regenResp = await regenerateAgentNameBatch({ + items: [ + { + name, + display_name: displayName, + task_description: taskDescription, + }, + ], + }); + if (!regenResp.success || !Array.isArray(regenResp.data) || !regenResp.data[0]) { + throw new Error("Failed to regenerate agent name"); + } + const item = regenResp.data[0]; + return { + name: item.name, + displayName: item.display_name ?? displayName, + }; + }; + + // let caller decide how to handle conflicts (e.g. show a dialog to let user choose whether to let LLM rename) + if (options.onNameConflictResolve) { + return await options.onNameConflictResolve({ + name, + displayName, + conflictAgents: (conflict_agents || []).map((c: any) => ({ + id: String(c.agent_id ?? c.id), + name: c.name, + display_name: c.display_name, + })), + regenerateWithLLM, + }); + } + + // default behavior: directly call backend to rename to keep import available + const regenerated = await regenerateWithLLM(); + return { proceed: true, ...regenerated }; + } catch (error) { + // if callback throws an error, prevent import + throw error instanceof Error + ? error + : new Error("Name conflict handling failed"); + } + }; + return { isImporting, importFromFile, diff --git a/frontend/hooks/useMemory.ts b/frontend/hooks/useMemory.ts index 03ac72dd8..5bb1a1bc9 100644 --- a/frontend/hooks/useMemory.ts +++ b/frontend/hooks/useMemory.ts @@ -483,24 +483,3 @@ export function useMemory({ visible, currentUserId, currentTenantId, message }: handleDeleteMemory, } } - -// expose memory notification indicator to ChatHeader -export function useMemoryIndicator(modalVisible: boolean) { - const [hasNewMemory, setHasNewMemory] = useState(false) - - // Reset indicator when memory modal is opened - useEffect(() => { - if (modalVisible) { - setHasNewMemory(false) - } - }, [modalVisible]) - - // Listen for backend event that notifies new memory addition - useEffect(() => { - const handler = () => setHasNewMemory(true) - window.addEventListener("nexent:new-memory", handler as EventListener) - return () => window.removeEventListener("nexent:new-memory", handler as EventListener) - }, []) - - return hasNewMemory -} \ No newline at end of file diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 5ee25a7b8..b8681a78a 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -285,6 +285,8 @@ "agent.contextMenu.export": "Export", "agent.contextMenu.delete": "Delete", + "agent.contextMenu.copy": "Copy", + "agent.copySuffix": "Copy", "agent.info.title": "Agent Information", "agent.info.name.error.empty": "Name cannot be empty", "agent.info.name.error.format": "Name can only contain letters, numbers and underscores, and must start with a letter or underscore", @@ -293,6 +295,9 @@ "agent.namePlaceholder": "Please enter agent variable name", "agent.displayName": "Agent Name", "agent.displayNamePlaceholder": "Please enter agent name", + "agent.author": "Author", + "agent.authorPlaceholder": "Please enter author name (optional)", + "agent.author.hint": "Default: {{email}}", "agent.description": "Agent Description", "agent.descriptionPlaceholder": "Please enter agent description", "agent.detailContent.title": "Agent Detail Content", @@ -413,7 +418,6 @@ "toolPool.error.requiredFields": "The following required fields are not filled: {{fields}}", "toolPool.tooltip.functionGuide": "1. For local knowledge base search functionality, please enable the knowledge_base_search tool;\n2. For text file parsing functionality, please enable the analyze_text_file tool;\n3. For image parsing functionality, please enable the analyze_image tool.", - "tool.message.unavailable": "This tool is currently unavailable and cannot be selected", "tool.error.noMainAgentId": "Main agent ID is not set, cannot update tool status", "tool.error.configFetchFailed": "Failed to get tool configuration", @@ -502,6 +506,7 @@ "document.summary.modelPlaceholder": "Select Model", "document.status.creating": "Creating...", "document.status.loadingList": "Loading document list...", + "document.status.waitingForTask": "Waiting for task creation...", "document.input.knowledgeBaseName": "Please enter knowledge base name", "document.button.details": "Details", "document.button.overview": "Overview", @@ -522,6 +527,24 @@ "document.status.completed": "Ready", "document.status.processFailed": "Process Failed", "document.status.forwardFailed": "Forward Failed", + "document.progress.chunksProcessed": "Processed {{processed}}/{{total}} chunks ({{percent}}%)", + "document.error.reason": "Error Reason", + "document.error.suggestion": "Suggestion", + "document.error.noReason": "No error reason available", + "document.error.code.ray_init_failed.message": "Failed to initialize Ray cluster", + "document.error.code.ray_init_failed.suggestion": "Please upgrade to the latest image version and redeploy.", + "document.error.code.no_valid_chunks.message": "The data processing kernel could not extract valid text from the document", + "document.error.code.no_valid_chunks.suggestion": "Please ensure the document format is supported and the content is not purely images.", + "document.error.code.vector_service_busy.message": "Vectorization model service is busy and cannot return vectors", + "document.error.code.vector_service_busy.suggestion": "Please switch the model service provider or try again later.", + "document.error.code.es_bulk_failed.message": "Failed to write vectors into the database", + "document.error.code.es_bulk_failed.suggestion": "Please ensure the Elasticsearch data path has sufficient disk space and write permissions.", + "document.error.code.es_dim_mismatch.message": "Embedding dimension does not match the Elasticsearch mapping", + "document.error.code.es_dim_mismatch.suggestion": "Please delete all embedding models and add the model again to try again.", + "document.error.code.embedding_chunks_exceed_limit.message": "The current chunk count exceeds the embedding model concurrency limit", + "document.error.code.embedding_chunks_exceed_limit.suggestion": "Please increase the chunk size to reduce the number of chunks and try again.", + "document.error.code.unsupported_file_format.message": "Unsupported line breaks detected in the document", + "document.error.code.unsupported_file_format.suggestion": "Please convert all line breaks to LF format and try again", "document.modal.deleteConfirm.title": "Confirm Delete Document", "document.modal.deleteConfirm.content": "Are you sure you want to delete this document? This action cannot be undone.", "document.message.noFiles": "Please select files first", @@ -655,6 +678,7 @@ "model.group.silicon": "Silicon Flow Models", "model.group.custom": "Custom Models", "model.status.tooltip": "Click to verify connectivity", + "model.dialog.embeddingConfig.title": "Edit Embedding Model: {{modelName}}", "appConfig.appName.label": "Application Name", "appConfig.appName.placeholder": "Please enter your application name", @@ -699,9 +723,14 @@ "modelConfig.button.addCustomModel": "Add Model", "modelConfig.button.editCustomModel": "Edit or Delete Model", "modelConfig.button.checkConnectivity": "Check Model Connectivity", + "modelConfig.button.sync": "Sync", + "modelConfig.button.add": "Add", + "modelConfig.button.edit": "Edit", + "modelConfig.button.check": "Check", "modelConfig.slider.chunkingSize": "Chunk Size", "modelConfig.slider.expectedChunkSize": "Expected Chunk Size", "modelConfig.slider.maximumChunkSize": "Maximum Chunk Size", + "modelConfig.input.chunkingBatchSize": "Concurrent Request Count", "businessLogic.title": "Describe how should this agent work", "businessLogic.placeholder": "Please describe your business scenario and requirements...", @@ -841,7 +870,7 @@ "mcpConfig.modal.updatingTools": "Updating tools list...", "mcpConfig.addServer.title": "Add MCP Server", "mcpConfig.addServer.namePlaceholder": "Server name", - "mcpConfig.addServer.urlPlaceholder": "Server URL (e.g.: http://localhost:3001/sse), currently only SSE protocol supported", + "mcpConfig.addServer.urlPlaceholder": "Server URL (e.g.: http://localhost:3001/mcp), currently supports sse and streamable-http protocols", "mcpConfig.addServer.button.add": "Add", "mcpConfig.addServer.button.updating": "Updating...", "mcpConfig.serverList.title": "Configured MCP Servers", @@ -909,6 +938,12 @@ "agentConfig.agents.createSubAgentIdFailed": "Failed to fetch creating sub agent ID, please try again later", "agentConfig.agents.detailsFetchFailed": "Failed to fetch agent details, please try again later", "agentConfig.agents.callRelationshipFetchFailed": "Failed to fetch agent call relationship, please try again later", + "agentConfig.agents.defaultDisplayName": "Agent", + "agentConfig.agents.copyConfirmTitle": "Confirm Copy", + "agentConfig.agents.copyConfirmContent": "Create a duplicate of {{name}}?", + "agentConfig.agents.copySuccess": "Agent copied successfully", + "agentConfig.agents.copyUnavailableTools": "Ignored {{count}} unavailable tools: {{names}}", + "agentConfig.agents.copyFailed": "Failed to copy Agent", "agentConfig.tools.refreshFailedDebug": "Failed to refresh tools list:", "agentConfig.agents.detailsLoadFailed": "Failed to load Agent details:", "agentConfig.agents.importFailed": "Failed to import Agent:", @@ -1117,6 +1152,7 @@ "market.category.all": "All", "market.category.other": "Other", "market.download": "Download", + "market.by": "By {{author}}", "market.downloading": "Downloading agent...", "market.downloadSuccess": "Agent downloaded successfully!", "market.downloadFailed": "Failed to download agent", @@ -1125,7 +1161,7 @@ "market.totalAgents": "Total {{total}} agents", "market.error.loadCategories": "Failed to load categories", "market.error.loadAgents": "Failed to load agents", - + "market.detail.title": "Agent Details", "market.detail.subtitle": "Complete information and configuration", "market.detail.tabs.basic": "Basic Info", @@ -1136,6 +1172,7 @@ "market.detail.id": "Agent ID", "market.detail.name": "Name", "market.detail.displayName": "Display Name", + "market.detail.author": "Author", "market.detail.description": "Description", "market.detail.businessDescription": "Business Description", "market.detail.category": "Category", @@ -1166,6 +1203,7 @@ "market.detail.viewDetails": "View Details", "market.install.title": "Install Agent", + "market.install.step.rename": "Rename Agent", "market.install.step.model": "Select Model", "market.install.step.config": "Configure Fields", "market.install.step.mcp": "MCP Servers", @@ -1203,7 +1241,31 @@ "market.install.error.mcpInstall": "Failed to install MCP server", "market.install.error.invalidData": "Invalid agent data", "market.install.error.installFailed": "Failed to install agent", + "market.install.error.noModelForRegeneration": "No available model for name regeneration", + "market.install.error.nameRegenerationFailed": "Failed to regenerate name", + "market.install.error.nameRequired": "Agent name is required", + "market.install.error.nameRequiredForAgent": "Agent name is required for {agent}", + "market.install.checkingName": "Checking agent name...", + "market.install.rename.warning": "The agent name or display name conflicts with existing agents. Please rename to proceed.", + "market.install.rename.conflictAgents": "Conflicting agents:", + "market.install.rename.name": "Agent Name", + "market.install.rename.regenerateWithLLM": "Regenerate with LLM", + "market.install.rename.regenerate": "Regenerate", + "market.install.rename.model": "Model for Regeneration", + "market.install.rename.modelPlaceholder": "Select a model", + "market.install.error.modelRequiredForRegeneration": "Please select a model first", + "market.install.rename.nameHint": "Original: {name}", + "market.install.rename.displayName": "Display Name", + "market.install.rename.displayNameHint": "Original: {name}", + "market.install.rename.note": "Note: If you proceed without renaming, the agent will be created but marked as unavailable due to name conflicts. You can rename it later in the agent list.", + "market.install.rename.oneClickDesc": "You can edit names manually, or use one-click rename to let the LLM generate new names for all conflicted agents.", + "market.install.rename.oneClick": "One-click Rename", + "market.install.rename.success": "All agent name conflicts have been resolved. You can proceed to the next step.", + "market.install.rename.partialSuccess": "Some agents have been successfully renamed.", + "market.install.rename.agentResolved": "This agent's name conflict has been resolved.", "market.install.success.mcpInstalled": "MCP server installed successfully", + "market.install.success.nameRegenerated": "Agent name regenerated successfully", + "market.install.success.nameRegeneratedAndResolved": "Agent names regenerated successfully and all conflicts resolved", "market.install.info.notImplemented": "Installation will be implemented in next phase", "market.install.success": "Agent installed successfully!", "market.error.fetchDetailFailed": "Failed to load agent details", @@ -1218,7 +1280,7 @@ "market.error.server.description": "The market server encountered an error. Our team has been notified. Please try again later.", "market.error.unknown.title": "Something Went Wrong", "market.error.unknown.description": "An unexpected error occurred. Please try again.", - + "common.loading": "Loading", "common.save": "Save", "common.cancel": "Cancel", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 65d80dacf..c0f8d851a 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -286,6 +286,8 @@ "agent.contextMenu.export": "导出", "agent.contextMenu.delete": "删除", + "agent.contextMenu.copy": "复制", + "agent.copySuffix": "副本", "agent.info.title": "Agent信息", "agent.info.name.error.empty": "名称不能为空", "agent.info.name.error.format": "名称只能包含字母、数字和下划线,且必须以字母或下划线开头", @@ -294,6 +296,9 @@ "agent.namePlaceholder": "请输入Agent变量名", "agent.displayName": "Agent名称", "agent.displayNamePlaceholder": "请输入Agent名称", + "agent.author": "作者", + "agent.authorPlaceholder": "请输入作者名称(可选)", + "agent.author.hint": "默认:{{email}}", "agent.description": "Agent描述", "agent.descriptionPlaceholder": "请输入Agent描述", "agent.detailContent.title": "Agent详细内容", @@ -370,7 +375,7 @@ "subAgentPool.tooltip.duplicateNameDisabled": "该智能体因与其他智能体同名而被禁用,请修改名称后使用", "subAgentPool.message.duplicateNameDisabled": "该智能体因与其他智能体同名而被禁用,请修改名称后使用", - "toolConfig.title.paramConfig": "参数配置", + "toolConfig.title.paramConfig": "配置参数", "toolConfig.message.loadError": "加载工具配置失败", "toolConfig.message.loadErrorUseDefault": "加载工具配置失败,使用默认配置", "toolConfig.message.saveSuccess": "工具配置保存成功", @@ -414,7 +419,6 @@ "toolPool.error.requiredFields": "以下必填字段未填写: {{fields}}", "toolPool.tooltip.functionGuide": "1. 本地知识库检索功能,请启用knowledge_base_search工具;\n2. 文本文件解析功能,请启用analyze_text_file工具;\n3. 图片解析功能,请启用analyze_image工具。", - "tool.message.unavailable": "该工具当前不可用,无法选择", "tool.error.noMainAgentId": "主代理ID未设置,无法更新工具状态", "tool.error.configFetchFailed": "获取工具配置失败", @@ -503,6 +507,7 @@ "document.summary.modelPlaceholder": "选择模型", "document.status.creating": "创建中...", "document.status.loadingList": "正在加载文档列表...", + "document.status.waitingForTask": "正在等待任务创建...", "document.input.knowledgeBaseName": "请输入知识库名称", "document.button.details": "详细内容", "document.button.overview": "概览", @@ -523,6 +528,24 @@ "document.status.completed": "已就绪", "document.status.processFailed": "解析失败", "document.status.forwardFailed": "入库失败", + "document.progress.chunksProcessed": "已处理 {{processed}}/{{total}} 个切片 ({{percent}}%)", + "document.error.reason": "错误原因", + "document.error.suggestion": "建议", + "document.error.noReason": "暂无错误原因", + "document.error.code.ray_init_failed.message": "Ray集群初始化失败", + "document.error.code.ray_init_failed.suggestion": "请升级到最新版本并尝试重新部署", + "document.error.code.no_valid_chunks.message": "数据处理内核无法从文档中提取有效文本", + "document.error.code.no_valid_chunks.suggestion": "请确保文档内容非纯图像", + "document.error.code.vector_service_busy.message": "向量化模型服务繁忙,无法获取文本向量", + "document.error.code.vector_service_busy.suggestion": "请更换模型服务提供商,或稍后重试", + "document.error.code.es_bulk_failed.message": "向量录入数据库错误", + "document.error.code.es_bulk_failed.suggestion": "请确保Elasticsearch路径拥有完整写入权限,且存储空间与内存充足", + "document.error.code.es_dim_mismatch.message": "向量化模型维度与Elasticsearch维度不匹配", + "document.error.code.es_dim_mismatch.suggestion": "建议删除所有向量化模型后再添加模型重试", + "document.error.code.embedding_chunks_exceed_limit.message": "当前切片数量超过向量化模型并行度", + "document.error.code.embedding_chunks_exceed_limit.suggestion": "请增加切片大小以减少切片数量后再试", + "document.error.code.unsupported_file_format.message": "检测到当前文档中存在不支持的换行符", + "document.error.code.unsupported_file_format.suggestion": "建议统一转换为LF换行符再试", "document.modal.deleteConfirm.title": "确认删除文档", "document.modal.deleteConfirm.content": "确定要删除这个文档吗?删除后无法恢复。", "document.message.noFiles": "请先选择文件", @@ -655,6 +678,7 @@ "model.group.custom": "自定义模型", "model.status.tooltip": "点击可验证连通性", "model.dialog.success.updateSuccess": "更新成功", + "model.dialog.embeddingConfig.title": "修改向量模型: {{modelName}}", "appConfig.appName.label": "应用名称", "appConfig.appName.placeholder": "请输入您的应用名称", @@ -699,9 +723,14 @@ "modelConfig.button.addCustomModel": "添加模型", "modelConfig.button.editCustomModel": "修改或删除模型", "modelConfig.button.checkConnectivity": "检查模型连通性", + "modelConfig.button.sync": "同步", + "modelConfig.button.add": "添加", + "modelConfig.button.edit": "修改", + "modelConfig.button.check": "检查", "modelConfig.slider.chunkingSize": "文档切片大小", "modelConfig.slider.expectedChunkSize": "期望切片大小", "modelConfig.slider.maximumChunkSize": "最大切片大小", + "modelConfig.input.chunkingBatchSize": "单次请求切片量", "businessLogic.title": "描述 Agent 应该如何工作", "businessLogic.placeholder": "请描述您的业务场景和需求...", @@ -841,7 +870,7 @@ "mcpConfig.modal.updatingTools": "正在更新工具列表...", "mcpConfig.addServer.title": "添加MCP服务器", "mcpConfig.addServer.namePlaceholder": "服务器名称", - "mcpConfig.addServer.urlPlaceholder": "服务器URL (如: http://localhost:3001/sse),目前仅支持sse协议", + "mcpConfig.addServer.urlPlaceholder": "服务器URL (如: http://localhost:3001/mcp),目前支持sse和streamable-http协议", "mcpConfig.addServer.button.add": "添加", "mcpConfig.addServer.button.updating": "更新中...", "mcpConfig.serverList.title": "已配置的MCP服务器", @@ -909,6 +938,12 @@ "agentConfig.agents.createSubAgentIdFailed": "获取创建子Agent ID失败,请稍后重试", "agentConfig.agents.detailsFetchFailed": "获取Agent详情失败,请稍后重试", "agentConfig.agents.callRelationshipFetchFailed": "获取Agent调用关系失败,请稍后重试", + "agentConfig.agents.defaultDisplayName": "智能体", + "agentConfig.agents.copyConfirmTitle": "确认复制", + "agentConfig.agents.copyConfirmContent": "确定要复制 {{name}} 吗?", + "agentConfig.agents.copySuccess": "Agent复制成功", + "agentConfig.agents.copyUnavailableTools": "已忽略{{count}}个不可用工具:{{names}}", + "agentConfig.agents.copyFailed": "Agent复制失败", "agentConfig.tools.refreshFailedDebug": "刷新工具列表失败:", "agentConfig.agents.detailsLoadFailed": "加载Agent详情失败:", "agentConfig.agents.importFailed": "导入Agent失败:", @@ -1081,7 +1116,7 @@ "sidebar.memoryManagement": "记忆管理", "sidebar.userManagement": "用户管理", "sidebar.mcpToolsManagement": "MCP 工具", - "sidebar.monitoringManagement": "监控与运维", + "sidebar.monitoringManagement": "监控与运维", "market.comingSoon.title": "智能体市场即将推出", "market.comingSoon.description": "从我们的市场中发现并安装预构建的AI智能体。通过使用社区创建的解决方案节省时间。", @@ -1096,6 +1131,7 @@ "market.category.all": "全部", "market.category.other": "其他", "market.download": "下载", + "market.by": "作者:{{author}}", "market.downloading": "正在下载智能体...", "market.downloadSuccess": "智能体下载成功!", "market.downloadFailed": "下载智能体失败", @@ -1104,7 +1140,7 @@ "market.totalAgents": "共 {{total}} 个智能体", "market.error.loadCategories": "加载分类失败", "market.error.loadAgents": "加载智能体失败", - + "market.detail.title": "智能体详情", "market.detail.subtitle": "完整信息和配置", "market.detail.tabs.basic": "基础信息", @@ -1115,6 +1151,7 @@ "market.detail.id": "智能体 ID", "market.detail.name": "名称", "market.detail.displayName": "显示名称", + "market.detail.author": "作者", "market.detail.description": "描述", "market.detail.businessDescription": "业务描述", "market.detail.category": "分类", @@ -1145,6 +1182,7 @@ "market.detail.viewDetails": "查看详情", "market.install.title": "安装智能体", + "market.install.step.rename": "重命名智能体", "market.install.step.model": "选择模型", "market.install.step.config": "配置字段", "market.install.step.mcp": "MCP 服务器", @@ -1182,7 +1220,31 @@ "market.install.error.mcpInstall": "安装 MCP 服务器失败", "market.install.error.invalidData": "无效的智能体数据", "market.install.error.installFailed": "安装智能体失败", + "market.install.error.noModelForRegeneration": "没有可用的模型用于名称重新生成", + "market.install.error.nameRegenerationFailed": "重新生成名称失败", + "market.install.error.nameRequired": "智能体名称为必填项", + "market.install.error.nameRequiredForAgent": "智能体 {agent} 的名称为必填项", + "market.install.checkingName": "正在检查智能体名称...", + "market.install.rename.warning": "智能体名称或显示名称与现有智能体冲突,请重命名以继续。", + "market.install.rename.conflictAgents": "冲突的智能体:", + "market.install.rename.name": "智能体名称", + "market.install.rename.regenerateWithLLM": "使用 LLM 重新生成", + "market.install.rename.regenerate": "重新生成", + "market.install.rename.model": "用于重新生成名称的模型", + "market.install.rename.modelPlaceholder": "选择一个模型", + "market.install.error.modelRequiredForRegeneration": "请先选择一个模型", + "market.install.rename.nameHint": "原始名称:{name}", + "market.install.rename.displayName": "显示名称", + "market.install.rename.displayNameHint": "原始名称:{name}", + "market.install.rename.note": "注意:如果您不重命名就继续,智能体将被创建但由于名称冲突会被标记为不可用。您可以在智能体列表中稍后重命名。", + "market.install.rename.oneClickDesc": "可手动修改名称,或一键重命名使用大模型为所有冲突智能体生成新名称。", + "market.install.rename.oneClick": "一键重命名", + "market.install.rename.success": "所有智能体名称冲突已解决。您可以继续下一步。", + "market.install.rename.partialSuccess": "部分智能体已成功重命名。", + "market.install.rename.agentResolved": "此智能体的名称冲突已解决。", "market.install.success.mcpInstalled": "MCP 服务器安装成功", + "market.install.success.nameRegenerated": "智能体名称重新生成成功", + "market.install.success.nameRegeneratedAndResolved": "智能体名称重新生成成功,且所有冲突已解决", "market.install.info.notImplemented": "安装功能将在下一阶段实现", "market.install.success": "智能体安装成功!", "market.error.fetchDetailFailed": "加载智能体详情失败", @@ -1211,14 +1273,14 @@ "mcpTools.comingSoon.feature2": "同步、查看和组织 MCP 工具列表", "mcpTools.comingSoon.feature3": "监控 MCP 连接状态和使用情况", "mcpTools.comingSoon.badge": "即将推出", - + "monitoring.comingSoon.title": "监控与运维中心即将推出", "monitoring.comingSoon.description": "面向智能体的统一监控与运维中心,用于实时跟踪健康状态、性能指标与异常事件。", "monitoring.comingSoon.feature1": "监控智能体健康状态、延迟与错误率", "monitoring.comingSoon.feature2": "查看并筛选智能体运行日志和历史任务", "monitoring.comingSoon.feature3": "配置告警策略与关键事件的运维操作", "monitoring.comingSoon.badge": "即将推出", - + "common.loading": "加载中", "common.save": "保存", "common.cancel": "取消", diff --git a/frontend/services/agentConfigService.ts b/frontend/services/agentConfigService.ts index f7f084f6b..3cff1e884 100644 --- a/frontend/services/agentConfigService.ts +++ b/frontend/services/agentConfigService.ts @@ -116,6 +116,7 @@ export const fetchAgentList = async () => { name: agent.name, display_name: agent.display_name || agent.name, description: agent.description, + author: agent.author, is_available: agent.is_available, unavailable_reasons: agent.unavailable_reasons || [], })); @@ -326,7 +327,8 @@ export const updateAgent = async ( businessLogicModelName?: string, businessLogicModelId?: number, enabledToolIds?: number[], - relatedAgentIds?: number[] + relatedAgentIds?: number[], + author?: string ) => { try { const response = await fetch(API_ENDPOINTS.agent.update, { @@ -350,6 +352,7 @@ export const updateAgent = async ( business_logic_model_id: businessLogicModelId, enabled_tool_ids: enabledToolIds, related_agent_ids: relatedAgentIds, + author: author, }), }); @@ -485,6 +488,76 @@ export const importAgent = async ( } }; +/** + * check agent name/display_name duplication + * @param payload name/displayName to check + */ +export const checkAgentNameConflictBatch = async (payload: { + items: Array<{ name: string; display_name?: string; agent_id?: number }>; +}) => { + try { + const response = await fetch(API_ENDPOINTS.agent.checkNameBatch, { + method: "POST", + headers: getAuthHeaders(), + body: JSON.stringify(payload), + }); + + if (!response.ok) { + throw new Error(`Request failed: ${response.status}`); + } + + const data = await response.json(); + return { + success: true, + data, + message: "", + }; + } catch (error) { + log.error("Failed to check agent name conflict batch:", error); + return { + success: false, + data: null, + message: "agentConfig.agents.checkNameFailed", + }; + } +}; + +export const regenerateAgentNameBatch = async (payload: { + items: Array<{ + name: string; + display_name?: string; + task_description?: string; + language?: string; + agent_id?: number; + }>; +}) => { + try { + const response = await fetch(API_ENDPOINTS.agent.regenerateNameBatch, { + method: "POST", + headers: getAuthHeaders(), + body: JSON.stringify(payload), + }); + + if (!response.ok) { + throw new Error(`Request failed: ${response.status}`); + } + + const data = await response.json(); + return { + success: true, + data, + message: "", + }; + } catch (error) { + log.error("Failed to regenerate agent name batch:", error); + return { + success: false, + data: null, + message: "agentConfig.agents.regenerateNameFailed", + }; + } +}; + /** * search agent info by agent id * @param agentId agent id @@ -510,6 +583,7 @@ export const searchAgentInfo = async (agentId: number) => { name: data.name, display_name: data.display_name, description: data.description, + author: data.author, model: data.model_name, model_id: data.model_id, max_step: data.max_steps, @@ -587,6 +661,7 @@ export const fetchAllAgents = async () => { name: agent.name, display_name: agent.display_name || agent.name, description: agent.description, + author: agent.author, is_available: agent.is_available, })); diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 0af193d52..20d89b6f2 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -37,6 +37,8 @@ export const API_ENDPOINTS = { `${API_BASE_URL}/agent/stop/${conversationId}`, export: `${API_BASE_URL}/agent/export`, import: `${API_BASE_URL}/agent/import`, + checkNameBatch: `${API_BASE_URL}/agent/check_name`, + regenerateNameBatch: `${API_BASE_URL}/agent/regenerate_name`, searchInfo: `${API_BASE_URL}/agent/search_info`, callRelationship: `${API_BASE_URL}/agent/call_relationship`, }, @@ -142,6 +144,11 @@ export const API_ENDPOINTS = { // File upload service upload: `${API_BASE_URL}/file/upload`, process: `${API_BASE_URL}/file/process`, + // Error info service + getErrorInfo: (indexName: string, pathOrUrl: string) => + `${API_BASE_URL}/indices/${indexName}/documents/${encodeURIComponent( + pathOrUrl + )}/error-info`, }, config: { save: `${API_BASE_URL}/config/save_config`, diff --git a/frontend/services/knowledgeBasePollingService.ts b/frontend/services/knowledgeBasePollingService.ts index 568205b21..b899d8bdf 100644 --- a/frontend/services/knowledgeBasePollingService.ts +++ b/frontend/services/knowledgeBasePollingService.ts @@ -11,8 +11,12 @@ class KnowledgeBasePollingService { private knowledgeBasePollingInterval: number = 1000; // 1 second private documentPollingInterval: number = 3000; // 3 seconds private maxKnowledgeBasePolls: number = 60; // Maximum 60 polling attempts - private maxDocumentPolls: number = 20; // Maximum 20 polling attempts + private maxDocumentPolls: number = 200; // Maximum 200 polling attempts (10 minutes for long-running tasks) private activeKnowledgeBaseId: string | null = null; // Record current active knowledge base ID + private pendingRequests: Map> = new Map(); + + // Debounce timers for batching multiple rapid requests + private debounceTimers: Map = new Map(); // Set current active knowledge base ID setActiveKnowledgeBase(kbId: string | null): void { @@ -29,11 +33,16 @@ class KnowledgeBasePollingService { // Initialize polling counter let pollCount = 0; + // Track if we're in extended polling mode (after initial timeout) + let isExtendedPolling = false; + // Define the polling logic function const pollDocuments = async () => { try { - // Increment polling counter - pollCount++; + // Increment polling counter only if not in extended polling mode + if (!isExtendedPolling) { + pollCount++; + } // If there is an active knowledge base and polling knowledge base doesn't match active one, stop polling if (this.activeKnowledgeBaseId !== null && this.activeKnowledgeBaseId !== kbId) { @@ -41,24 +50,28 @@ class KnowledgeBasePollingService { return; } - // If exceeded maximum polling count, handle timeout - if (pollCount > this.maxDocumentPolls) { - log.warn(`Document polling for knowledge base ${kbId} timed out after ${this.maxDocumentPolls} attempts`); - await this.handlePollingTimeout(kbId, 'document', callback); - // Push documents to UI + // Use request deduplication to avoid concurrent duplicate requests + let documents: Document[]; + const requestKey = `poll:${kbId}`; + + // Check if there's already a pending request for this KB + const pendingRequest = this.pendingRequests.get(requestKey); + if (pendingRequest) { + // Reuse existing request to avoid duplicate API calls + documents = await pendingRequest; + } else { + // Create new request and track it + const requestPromise = knowledgeBaseService.getAllFiles(kbId); + this.pendingRequests.set(requestKey, requestPromise); + try { - const documents = await knowledgeBaseService.getAllFiles(kbId); - this.triggerDocumentsUpdate(kbId, documents); - } catch (e) { - // Ignore error + documents = await requestPromise; + } finally { + // Clean up after request completes + this.pendingRequests.delete(requestKey); } - this.stopPolling(kbId); - return; } - // Get latest document status - const documents = await knowledgeBaseService.getAllFiles(kbId); - // Call callback function with latest documents first to ensure UI updates immediately callback(documents); @@ -67,6 +80,18 @@ class KnowledgeBasePollingService { NON_TERMINAL_STATUSES.includes(doc.status) ); + // If exceeded maximum polling count and still processing, switch to extended polling mode + if (pollCount > this.maxDocumentPolls && hasProcessingDocs && !isExtendedPolling) { + log.warn(`Document polling for knowledge base ${kbId} exceeded ${this.maxDocumentPolls} attempts, switching to extended polling mode (reduced frequency)`); + isExtendedPolling = true; + // Stop the current interval and restart with longer interval + this.stopPolling(kbId); + // Continue polling with reduced frequency (every 10 seconds) + const extendedInterval = setInterval(pollDocuments, 10000); + this.pollingIntervals.set(kbId, extendedInterval); + return; + } + // If there are processing documents, continue polling if (hasProcessingDocs) { log.log('Documents processing, continue polling'); @@ -141,6 +166,7 @@ class KnowledgeBasePollingService { * @param expectedIncrement The number of new files uploaded */ pollForKnowledgeBaseReady( + kbId: string, kbName: string, originalDocumentCount: number = 0, expectedIncrement: number = 0 @@ -150,29 +176,14 @@ class KnowledgeBasePollingService { const checkForStats = async () => { try { const kbs = await knowledgeBaseService.getKnowledgeBasesInfo(true) as KnowledgeBase[]; - const kb = kbs.find(k => k.name === kbName); + const kb = kbs.find(k => k.id === kbId || k.name === kbName); // Check if KB exists and its stats are populated if (kb) { - // If expectedIncrement > 0, check if documentCount increased as expected - if ( - expectedIncrement > 0 && - kb.documentCount >= (originalDocumentCount + expectedIncrement) - ) { - log.log( - `Knowledge base ${kbName} documentCount increased as expected: ${kb.documentCount} (was ${originalDocumentCount}, expected increment ${expectedIncrement})` - ); - this.triggerKnowledgeBaseListUpdate(true); - resolve(kb); - return; - } - // Fallback: for new KB or no increment specified, use old logic - if (expectedIncrement === 0 && (kb.documentCount > 0 || kb.chunkCount > 0)) { - log.log(`Knowledge base ${kbName} is ready and stats are populated.`); - this.triggerKnowledgeBaseListUpdate(true); - resolve(kb); - return; - } + log.log(`Knowledge base ${kbName} detected.`); + this.triggerKnowledgeBaseListUpdate(true); + resolve(kb); + return; } count++; @@ -183,11 +194,11 @@ class KnowledgeBasePollingService { log.error(`Knowledge base ${kbName} readiness check timed out after ${this.maxKnowledgeBasePolls} attempts.`); // Handle knowledge base polling timeout - mark related tasks as failed - await this.handlePollingTimeout(kbName, 'knowledgeBase'); + await this.handlePollingTimeout(kbId, 'knowledgeBase'); // Push documents to UI try { - const documents = await knowledgeBaseService.getAllFiles(kbName); - this.triggerDocumentsUpdate(kbName, documents); + const documents = await knowledgeBaseService.getAllFiles(kbId); + this.triggerDocumentsUpdate(kbId, documents); } catch (e) { // Ignore error } @@ -201,11 +212,11 @@ class KnowledgeBasePollingService { setTimeout(checkForStats, this.knowledgeBasePollingInterval); } else { // Handle knowledge base polling timeout on error as well - await this.handlePollingTimeout(kbName, 'knowledgeBase'); + await this.handlePollingTimeout(kbId, 'knowledgeBase'); // Push documents to UI try { - const documents = await knowledgeBaseService.getAllFiles(kbName); - this.triggerDocumentsUpdate(kbName, documents); + const documents = await knowledgeBaseService.getAllFiles(kbId); + this.triggerDocumentsUpdate(kbId, documents); } catch (e) { // Ignore error } @@ -218,14 +229,14 @@ class KnowledgeBasePollingService { } // Simplified method for new knowledge base creation workflow - async handleNewKnowledgeBaseCreation(kbName: string, originalDocumentCount: number = 0, expectedIncrement: number = 0, callback: (kb: KnowledgeBase) => void) { + async handleNewKnowledgeBaseCreation(kbId: string, kbName: string, originalDocumentCount: number = 0, expectedIncrement: number = 0, callback: (kb: KnowledgeBase) => void) { // Start document polling - this.startDocumentStatusPolling(kbName, (documents) => { - this.triggerDocumentsUpdate(kbName, documents); + this.startDocumentStatusPolling(kbId, (documents) => { + this.triggerDocumentsUpdate(kbId, documents); }); try { // Start knowledge base polling parallelly - const populatedKB = await this.pollForKnowledgeBaseReady(kbName, originalDocumentCount, expectedIncrement); + const populatedKB = await this.pollForKnowledgeBaseReady(kbId, kbName, originalDocumentCount, expectedIncrement); // callback with populated knowledge base when everything is ready callback(populatedKB); } catch (error) { @@ -249,6 +260,13 @@ class KnowledgeBasePollingService { clearInterval(interval); }); this.pollingIntervals.clear(); + + // Clear pending requests and debounce timers to prevent memory leaks + this.pendingRequests.clear(); + this.debounceTimers.forEach((timer) => { + clearTimeout(timer); + }); + this.debounceTimers.clear(); } // Trigger knowledge base list update (optionally force refresh) diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts index 27a6e0b38..0ea443081 100644 --- a/frontend/services/knowledgeBaseService.ts +++ b/frontend/services/knowledgeBaseService.ts @@ -71,15 +71,20 @@ class KnowledgeBaseService { // Convert Elasticsearch indices to knowledge base format knowledgeBases = data.indices_info.map((indexInfo: any) => { const stats = indexInfo.stats?.base_info || {}; + // Backend now returns: + // - name: internal index_name + // - display_name: user-facing knowledge_name (fallback to index_name) + const kbId = indexInfo.name; + const kbName = indexInfo.display_name || indexInfo.name; return { - id: indexInfo.name, - name: indexInfo.name, + id: kbId, + name: kbName, description: "Elasticsearch index", documentCount: stats.doc_count || 0, chunkCount: stats.chunk_count || 0, - createdAt: - stats.creation_date || new Date().toISOString().split("T")[0], + createdAt: stats.creation_date || null, + updatedAt: stats.update_date || stats.creation_date || null, embeddingModel: stats.embedding_model || "unknown", avatar: "", chunkNum: 0, @@ -276,6 +281,16 @@ class KnowledgeBaseService { token_num: 0, status: file.status || "UNKNOWN", latest_task_id: file.latest_task_id || "", + error_reason: file.error_reason, + // Optional ingestion progress metrics (only present for in-progress files) + processed_chunk_num: + typeof file.processed_chunk_num === "number" + ? file.processed_chunk_num + : null, + total_chunk_num: + typeof file.total_chunk_num === "number" + ? file.total_chunk_num + : null, })); } catch (error) { log.error("Failed to get all files:", error); @@ -806,6 +821,41 @@ class KnowledgeBaseService { throw new Error("Failed to execute hybrid search"); } } + + // Get error information for a document + async getDocumentErrorInfo( + kbId: string, + docId: string + ): Promise<{ + errorCode: string | null; + }> { + try { + const response = await fetch( + API_ENDPOINTS.knowledgeBase.getErrorInfo(kbId, docId), + { + headers: getAuthHeaders(), + } + ); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + const data = await response.json(); + if (data.status !== "success") { + throw new Error(data.message || "Failed to get error info"); + } + + const errorCode = (data.error_code && String(data.error_code)) || null; + + return { + errorCode, + }; + } catch (error) { + log.error("Failed to get document error info:", error); + throw error; + } + } } // Export a singleton instance diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts index 9de2c5483..3599bc939 100644 --- a/frontend/services/modelService.ts +++ b/frontend/services/modelService.ts @@ -67,6 +67,7 @@ export const modelService = { (model.connect_status as ModelConnectStatus) || "not_detected", expectedChunkSize: model.expected_chunk_size, maximumChunkSize: model.maximum_chunk_size, + chunkingBatchSize: model.chunk_batch, })); } return []; @@ -97,6 +98,7 @@ export const modelService = { displayName?: string; expectedChunkSize?: number; maximumChunkSize?: number; + chunkingBatchSize?: number; }): Promise => { try { const response = await fetch(API_ENDPOINTS.model.customModelCreate, { @@ -112,6 +114,7 @@ export const modelService = { display_name: model.displayName, expected_chunk_size: model.expectedChunkSize, maximum_chunk_size: model.maximumChunkSize, + chunk_batch: model.chunkingBatchSize, }), }); @@ -239,6 +242,7 @@ export const modelService = { source?: ModelSource; expectedChunkSize?: number; maximumChunkSize?: number; + chunkingBatchSize?: number; }): Promise => { try { const response = await fetch( @@ -262,6 +266,9 @@ export const modelService = { ...(model.maximumChunkSize !== undefined ? { maximum_chunk_size: model.maximumChunkSize } : {}), + ...(model.chunkingBatchSize !== undefined + ? { chunk_batch: model.chunkingBatchSize } + : {}), }), } ); diff --git a/frontend/services/storageService.ts b/frontend/services/storageService.ts index ec60eb187..a45add994 100644 --- a/frontend/services/storageService.ts +++ b/frontend/services/storageService.ts @@ -123,6 +123,68 @@ export function convertImageUrlToApiUrl(url: string): string { return url; } +const arrayBufferToBase64 = (buffer: ArrayBuffer): string => { + let binary = ""; + const bytes = new Uint8Array(buffer); + const chunkSize = 0x8000; + + for (let i = 0; i < bytes.length; i += chunkSize) { + const chunk = bytes.subarray(i, i + chunkSize); + binary += String.fromCharCode(...chunk); + } + + return btoa(binary); +}; + +const fetchBase64ViaStorage = async (objectName: string) => { + const response = await fetch(API_ENDPOINTS.storage.file(objectName, "base64")); + if (!response.ok) { + throw new Error(`Failed to resolve S3 URL via storage: ${response.status}`); + } + + const data = await response.json(); + if (!data?.success || !data?.base64) { + throw new Error(data?.error || "Storage response missing base64 content"); + } + + const contentType = data.content_type || "application/octet-stream"; + return { base64: data.base64 as string, contentType }; +}; + +// Cache for S3 URL to data URL resolution to avoid duplicate network requests +const s3ResolutionCache = new Map>(); + +// Internal helper: for s3:// URLs, resolve directly via storage download endpoint. +async function resolveS3UrlToDataUrlInternal(url: string): Promise { + const objectName = extractObjectNameFromUrl(url); + if (!objectName) { + return null; + } + + const { base64, contentType } = await fetchBase64ViaStorage(objectName); + return `data:${contentType};base64,${base64}`; +} + +export async function resolveS3UrlToDataUrl(url: string): Promise { + if (!url || !url.startsWith("s3://")) { + return null; + } + + const cached = s3ResolutionCache.get(url); + if (cached) { + return cached; + } + + const promise = resolveS3UrlToDataUrlInternal(url).catch((error) => { + // Remove from cache on failure so that future attempts can retry. + s3ResolutionCache.delete(url); + throw error; + }); + + s3ResolutionCache.set(url, promise); + return promise; +} + export const storageService = { /** * Upload files to storage service diff --git a/frontend/styles/globals.css b/frontend/styles/globals.css index 7d6b1749d..ad666027d 100644 --- a/frontend/styles/globals.css +++ b/frontend/styles/globals.css @@ -305,4 +305,23 @@ .kb-embedding-warning .ant-modal { width: max-content; min-width: 0; +} + +/* Responsive button text - global utility */ +@media (max-width: 1279px) { + .button-text-full { + display: none !important; + } + .button-text-short { + display: inline !important; + } +} + +@media (min-width: 1280px) { + .button-text-full { + display: inline !important; + } + .button-text-short { + display: none !important; + } } \ No newline at end of file diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index 3dc41c601..1a766788c 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -12,6 +12,7 @@ export interface Agent { name: string; display_name?: string; description: string; + author?: string; unavailable_reasons?: string[]; model: string; model_id?: number; @@ -127,6 +128,8 @@ export interface AgentSetupOrchestratorProps { setAgentDescription?: (value: string) => void; agentDisplayName?: string; setAgentDisplayName?: (value: string) => void; + agentAuthor?: string; + setAgentAuthor?: (value: string) => void; isGeneratingAgent?: boolean; onDebug?: () => void; getCurrentAgentId?: () => number | undefined; @@ -156,6 +159,7 @@ export interface SubAgentPoolProps { isGeneratingAgent?: boolean; editingAgent?: Agent | null; isCreatingNewAgent?: boolean; + onCopyAgent?: (agent: Agent) => void; onExportAgent?: (agent: Agent) => void; onDeleteAgent?: (agent: Agent) => void; } diff --git a/frontend/types/chat.ts b/frontend/types/chat.ts index 700edfdbf..826722055 100644 --- a/frontend/types/chat.ts +++ b/frontend/types/chat.ts @@ -9,6 +9,7 @@ export interface StepSection { export interface StepContent { id: string type: typeof chatConfig.messageTypes.MODEL_OUTPUT | + typeof chatConfig.messageTypes.MODEL_OUTPUT_CODE | typeof chatConfig.messageTypes.PARSING | typeof chatConfig.messageTypes.EXECUTION | typeof chatConfig.messageTypes.ERROR | diff --git a/frontend/types/knowledgeBase.ts b/frontend/types/knowledgeBase.ts index 85a5e6b12..e04f145c7 100644 --- a/frontend/types/knowledgeBase.ts +++ b/frontend/types/knowledgeBase.ts @@ -4,21 +4,23 @@ import { DOCUMENT_ACTION_TYPES, KNOWLEDGE_BASE_ACTION_TYPES, UI_ACTION_TYPES, NO // Knowledge base basic type export interface KnowledgeBase { - id: string - name: string - description: string | null - chunkCount: number - documentCount: number - createdAt: any - embeddingModel: string - avatar: string - chunkNum: number - language: string - nickname: string - parserId: string - permission: string - tokenNum: number - source: string + id: string; + name: string; + description: string | null; + chunkCount: number; + documentCount: number; + createdAt: any; + // Last update time of the knowledge base/index (may fall back to createdAt) + updatedAt?: any; + embeddingModel: string; + avatar: string; + chunkNum: number; + language: string; + nickname: string; + parserId: string; + permission: string; + tokenNum: number; + source: string; } // Create knowledge base parameter type @@ -31,17 +33,21 @@ export interface KnowledgeBaseCreateParams { // Document type export interface Document { - id: string - kb_id: string - name: string - type: string - size: number - create_time: string - chunk_num: number - token_num: number - status: string - selected?: boolean // For UI selection status - latest_task_id: string // For marking the latest celery task + id: string; + kb_id: string; + name: string; + type: string; + size: number; + create_time: string; + chunk_num: number; + token_num: number; + status: string; + selected?: boolean; // For UI selection status + latest_task_id: string; // For marking the latest celery task + error_reason?: string; // Error reason for failed documents + // Optional ingestion progress metrics + processed_chunk_num?: number | null; + total_chunk_num?: number | null; } // Document state interface diff --git a/frontend/types/market.ts b/frontend/types/market.ts index 888afffdb..770e39520 100644 --- a/frontend/types/market.ts +++ b/frontend/types/market.ts @@ -28,6 +28,7 @@ export interface MarketAgentListItem { name: string; display_name: string; description: string; + author?: string; category: MarketCategory; tags: MarketTag[]; download_count: number; diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index db97a8c0d..0d463161f 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -45,6 +45,7 @@ export interface ModelOption { connect_status?: ModelConnectStatus; expectedChunkSize?: number; maximumChunkSize?: number; + chunkingBatchSize?: number; } // Application configuration interface diff --git a/sdk/nexent/core/agents/agent_model.py b/sdk/nexent/core/agents/agent_model.py index 6eff00718..f3c5a77b7 100644 --- a/sdk/nexent/core/agents/agent_model.py +++ b/sdk/nexent/core/agents/agent_model.py @@ -1,7 +1,7 @@ from __future__ import annotations from threading import Event -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Field @@ -50,7 +50,12 @@ class AgentRunInfo(BaseModel): model_config_list: List[ModelConfig] = Field(description="List of model configurations") observer: MessageObserver = Field(description="Return data") agent_config: AgentConfig = Field(description="Detailed Agent configuration") - mcp_host: Optional[List[str]] = Field(description="MCP server address", default=None) + mcp_host: Optional[List[Union[str, Dict[str, Any]]]] = Field( + description="MCP server address(es). Can be a string (URL) or dict with 'url' and 'transport' keys. " + "Transport can be 'sse' or 'streamable-http'. If string, transport is auto-detected based on URL ending: " + "URLs ending with '/sse' use 'sse' transport, URLs ending with '/mcp' use 'streamable-http' transport.", + default=None + ) history: Optional[List[AgentHistory]] = Field(description="Historical conversation information", default=None) stop_event: Event = Field(description="Stop event control") diff --git a/sdk/nexent/core/agents/core_agent.py b/sdk/nexent/core/agents/core_agent.py index 826ef7093..be7b83b5e 100644 --- a/sdk/nexent/core/agents/core_agent.py +++ b/sdk/nexent/core/agents/core_agent.py @@ -1,3 +1,4 @@ +import json import re import ast import time @@ -9,12 +10,13 @@ from rich.console import Group from rich.text import Text -from smolagents.agents import CodeAgent, handle_agent_output_types, AgentError +from smolagents.agents import CodeAgent, handle_agent_output_types, AgentError, ActionOutput, RunResult from smolagents.local_python_executor import fix_final_answer_code from smolagents.memory import ActionStep, PlanningStep, FinalAnswerStep, ToolCall, TaskStep, SystemPromptStep -from smolagents.models import ChatMessage -from smolagents.monitoring import LogLevel -from smolagents.utils import AgentExecutionError, AgentGenerationError, truncate_content +from smolagents.models import ChatMessage, CODEAGENT_RESPONSE_FORMAT +from smolagents.monitoring import LogLevel, Timing, YELLOW_HEX, TokenUsage +from smolagents.utils import AgentExecutionError, AgentGenerationError, truncate_content, AgentMaxStepsError, \ + extract_code_from_text from ..utils.observer import MessageObserver, ProcessType from jinja2 import Template, StrictUndefined @@ -125,13 +127,17 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: # Add new step in logs memory_step.model_input_messages = input_messages + stop_sequences = ["", "Observation:", "Calling tools:", "", "Observation:", "Calling tools:", " Generator[Any]: # Parse try: - code_action = fix_final_answer_code(parse_code_blobs(model_output)) + if self._use_structured_outputs_internally: + code_action = json.loads(model_output)["code"] + code_action = extract_code_from_text(code_action, self.code_block_tags) or code_action + else: + code_action = parse_code_blobs(model_output) + code_action = fix_final_answer_code(code_action) + memory_step.code_action = code_action # Record parsing results self.observer.add_message( self.agent_name, ProcessType.PARSE, code_action) @@ -155,26 +167,29 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: content=model_output, title="AGENT FINAL ANSWER", level=LogLevel.INFO) raise FinalAnswerError() - memory_step.tool_calls = [ - ToolCall(name="python_interpreter", arguments=code_action, id=f"call_{len(self.memory.steps)}", )] + tool_call = ToolCall( + name="python_interpreter", + arguments=code_action, + id=f"call_{len(self.memory.steps)}", + ) + memory_step.tool_calls = [tool_call] # Execute self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO) - is_final_answer = False try: - output, execution_logs, is_final_answer = self.python_executor( - code_action) - + code_output = self.python_executor(code_action) execution_outputs_console = [] - if len(execution_logs) > 0: + if len(code_output.logs) > 0: # Record execution results self.observer.add_message( - self.agent_name, ProcessType.EXECUTION_LOGS, f"{execution_logs}") + self.agent_name, ProcessType.EXECUTION_LOGS, f"{code_output.logs}") execution_outputs_console += [ - Text("Execution logs:", style="bold"), Text(execution_logs), ] - observation = "Execution logs:\n" + execution_logs + Text("Execution logs:", style="bold"), + Text(code_output.logs), + ] + observation = "Execution logs:\n" + code_output.logs except Exception as e: if hasattr(self.python_executor, "state") and "_print_outputs" in self.python_executor.state: execution_logs = str( @@ -196,20 +211,24 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: level=LogLevel.INFO, ) raise AgentExecutionError(error_msg, self.logger) - truncated_output = truncate_content(str(output)) - if output is not None: + truncated_output = None + if code_output is not None and code_output.output is not None: + truncated_output = truncate_content(str(code_output.output)) observation += "Last output from code snippet:\n" + truncated_output memory_step.observations = observation - execution_outputs_console += [ - Text(f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}", - style=("bold #d4b702" if is_final_answer else ""), ), ] + if not code_output.is_final_answer and truncated_output is not None: + execution_outputs_console += [ + Text( + f"Out: {truncated_output}", + ), + ] self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO) - memory_step.action_output = output - yield output if is_final_answer else None + memory_step.action_output = code_output.output + yield ActionOutput(output=code_output.output, is_final_answer=code_output.is_final_answer) def run(self, task: str, stream: bool = False, reset: bool = True, images: Optional[List[str]] = None, - additional_args: Optional[Dict] = None, max_steps: Optional[int] = None, ): + additional_args: Optional[Dict] = None, max_steps: Optional[int] = None, return_full_result: bool | None = None): """ Run the agent for the given task. @@ -220,6 +239,8 @@ def run(self, task: str, stream: bool = False, reset: bool = True, images: Optio images (`list[str]`, *optional*): Paths to image(s). additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names! max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. if not provided, will use the agent's default value. + return_full_result (`bool`, *optional*): Whether to return the full [`RunResult`] object or just the final answer output. + If `None` (default), the agent's `self.return_full_result` setting is used. Example: ```py @@ -236,7 +257,6 @@ def run(self, task: str, stream: bool = False, reset: bool = True, images: Optio You have been provided with these additional arguments, that you can access using the keys as variables in your python code: {str(additional_args)}.""" - self.system_prompt = self.initialize_system_prompt() self.memory.system_prompt = SystemPromptStep( system_prompt=self.system_prompt) if reset: @@ -261,8 +281,47 @@ def run(self, task: str, stream: bool = False, reset: bool = True, images: Optio if stream: # The steps are returned as they are executed through a generator to iterate on. return self._run_stream(task=self.task, max_steps=max_steps, images=images) + run_start_time = time.time() + steps = list(self._run_stream(task=self.task, max_steps=max_steps, images=images)) + # Outputs are returned only at the end. We only look at the last step. - return list(self._run_stream(task=self.task, max_steps=max_steps, images=images))[-1].final_answer + assert isinstance(steps[-1], FinalAnswerStep) + output = steps[-1].output + + return_full_result = return_full_result if return_full_result is not None else self.return_full_result + if return_full_result: + total_input_tokens = 0 + total_output_tokens = 0 + correct_token_usage = True + for step in self.memory.steps: + if isinstance(step, (ActionStep, PlanningStep)): + if step.token_usage is None: + correct_token_usage = False + break + else: + total_input_tokens += step.token_usage.input_tokens + total_output_tokens += step.token_usage.output_tokens + if correct_token_usage: + token_usage = TokenUsage(input_tokens=total_input_tokens, output_tokens=total_output_tokens) + else: + token_usage = None + + if self.memory.steps and isinstance(getattr(self.memory.steps[-1], "error", None), AgentMaxStepsError): + state = "max_steps_error" + else: + state = "success" + + step_dicts = self.memory.get_full_steps() + + return RunResult( + output=output, + token_usage=token_usage, + steps=step_dicts, + timing=Timing(start_time=run_start_time, end_time=time.time()), + state=state, + ) + + return output def __call__(self, task: str, **kwargs): """Adds additional prompting for the managed agent, runs it, and wraps the output. @@ -271,7 +330,11 @@ def __call__(self, task: str, **kwargs): full_task = Template(self.prompt_templates["managed_agent"]["task"], undefined=StrictUndefined).render({ "name": self.name, "task": task, **self.state }) - report = self.run(full_task, **kwargs) + result = self.run(full_task, **kwargs) + if isinstance(result, RunResult): + report = result.output + else: + report = result # When a sub-agent finishes running, return a marker try: @@ -286,7 +349,7 @@ def __call__(self, task: str, **kwargs): if self.provide_run_summary: answer += "\n\nFor more detail, find below a summary of this agent's work:\n\n" for message in self.write_memory_to_messages(summary_mode=True): - content = message["content"] + content = message.content answer += "\n" + truncate_content(str(content)) + "\n---" answer += "\n" return answer @@ -295,28 +358,44 @@ def _run_stream( self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep]: final_answer = None + action_step = None self.step_number = 1 - while final_answer is None and self.step_number <= max_steps and not self.stop_event.is_set(): + returned_final_answer = False + while not returned_final_answer and self.step_number <= max_steps and not self.stop_event.is_set(): step_start_time = time.time() action_step = ActionStep( - step_number=self.step_number, start_time=step_start_time, observations_images=images + step_number=self.step_number, timing=Timing(start_time=step_start_time), observations_images=images ) try: - for el in self._execute_step(action_step): - yield el - final_answer = el + for output in self._step_stream(action_step): + yield output + + if isinstance(output, ActionOutput) and output.is_final_answer: + final_answer = output.output + self.logger.log( + Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"), + level=LogLevel.INFO, + ) + + if self.final_answer_checks: + self._validate_final_answer(final_answer) + returned_final_answer = True + action_step.is_final_answer = True + except FinalAnswerError: # When the model does not output code, directly treat the large model content as the final answer final_answer = action_step.model_output if isinstance(final_answer, str): final_answer = convert_code_format(final_answer) + returned_final_answer = True + action_step.is_final_answer = True except AgentError as e: action_step.error = e finally: - self._finalize_step(action_step, step_start_time) + self._finalize_step(action_step) self.memory.steps.append(action_step) yield action_step self.step_number += 1 @@ -324,8 +403,7 @@ def _run_stream( if self.stop_event.is_set(): final_answer = "" - if final_answer is None and self.step_number == max_steps + 1: - final_answer = self._handle_max_steps_reached( - task, images, step_start_time) + if not returned_final_answer and self.step_number == max_steps + 1: + final_answer = self._handle_max_steps_reached(task) yield action_step yield FinalAnswerStep(handle_agent_output_types(final_answer)) diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index f0f932389..f02251cfc 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -1,8 +1,9 @@ import re +import time from threading import Event from typing import List -from smolagents import ActionStep, AgentText, TaskStep +from smolagents import ActionStep, AgentText, TaskStep, Timing from smolagents.tools import Tool from ..models.openai_llm import OpenAIModel @@ -84,6 +85,9 @@ def create_local_tool(self, tool_config: ToolConfig): "vdb_core", None) if tool_config.metadata else None tools_obj.embedding_model = tool_config.metadata.get( "embedding_model", None) if tool_config.metadata else None + name_resolver = tool_config.metadata.get( + "name_resolver", None) if tool_config.metadata else None + tools_obj.name_resolver = {} if name_resolver is None else name_resolver elif class_name == "AnalyzeTextFileTool": tools_obj = tool_class(observer=self.observer, llm_model=tool_config.metadata.get("llm_model", []), @@ -195,7 +199,9 @@ def add_history_to_agent(self, history: List[AgentHistory]): # Create task step for user message self.agent.memory.steps.append(TaskStep(task=msg.content)) elif msg.role == 'assistant': - self.agent.memory.steps.append(ActionStep(action_output=msg.content, model_output=msg.content)) + self.agent.memory.steps.append(ActionStep(step_number=len(self.agent.memory.steps) + 1, + timing=Timing(start_time=time.time()), + action_output=msg.content, model_output=msg.content)) def agent_run_with_observer(self, query: str, reset=True): if not isinstance(self.agent, CoreAgent): @@ -214,7 +220,7 @@ def agent_run_with_observer(self, query: str, reset=True): if hasattr(step_log, "error") and step_log.error is not None: observer.add_message("", ProcessType.ERROR, str(step_log.error)) - final_answer = step_log.final_answer # Last log is the run's final_answer + final_answer = step_log.output # Last log is the run's final_answer if isinstance(final_answer, AgentText): final_answer_str = convert_code_format(final_answer.to_string()) diff --git a/sdk/nexent/core/agents/run_agent.py b/sdk/nexent/core/agents/run_agent.py index 41429367a..8a5a67517 100644 --- a/sdk/nexent/core/agents/run_agent.py +++ b/sdk/nexent/core/agents/run_agent.py @@ -1,6 +1,7 @@ import asyncio import logging from threading import Thread +from typing import Any, Dict, Union from smolagents import ToolCollection @@ -13,6 +14,56 @@ monitoring_manager = get_monitoring_manager() +def _detect_transport(url: str) -> str: + """ + Auto-detect MCP transport type based on URL format. + + Args: + url: MCP server URL + + Returns: + Transport type: 'sse' or 'streamable-http' + """ + url_stripped = url.strip() + + # Check URL ending to determine transport type + if url_stripped.endswith("/sse"): + return "sse" + elif url_stripped.endswith("/mcp"): + return "streamable-http" + + # Default to streamable-http for unrecognized formats + return "streamable-http" + + +def _normalize_mcp_config(mcp_host_item: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + """ + Normalize MCP host configuration to a dictionary format. + + Args: + mcp_host_item: Either a string URL or a dict with 'url' and optional 'transport' + + Returns: + Dictionary with 'url' and 'transport' keys + """ + if isinstance(mcp_host_item, str): + url = mcp_host_item + transport = _detect_transport(url) + return {"url": url, "transport": transport} + elif isinstance(mcp_host_item, dict): + url = mcp_host_item.get("url") + if not url: + raise ValueError("MCP host dict must contain 'url' key") + transport = mcp_host_item.get("transport") + if not transport: + transport = _detect_transport(url) + if transport not in ("sse", "streamable-http"): + raise ValueError(f"Invalid transport type: {transport}. Must be 'sse' or 'streamable-http'") + return {"url": url, "transport": transport} + else: + raise ValueError(f"Invalid MCP host item type: {type(mcp_host_item)}. Must be str or dict") + + @monitoring_manager.monitor_endpoint("agent_run_thread", "agent_run_thread") def agent_run_thread(agent_run_info: AgentRunInfo): try: @@ -31,7 +82,8 @@ def agent_run_thread(agent_run_info: AgentRunInfo): else: agent_run_info.observer.add_message( "", ProcessType.AGENT_NEW_RUN, "") - mcp_client_list = [{"url": mcp_url} for mcp_url in mcp_host] + # Normalize MCP host configurations to support both string and dict formats + mcp_client_list = [_normalize_mcp_config(item) for item in mcp_host] with ToolCollection.from_mcp(mcp_client_list, trust_remote_code=True) as tool_collection: nexent = NexentAgent( diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index 1a52e2d29..1eef02c72 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -14,7 +14,7 @@ logger = logging.getLogger("openai_llm") class OpenAIModel(OpenAIServerModel): - def __init__(self, observer: MessageObserver, temperature=0.2, top_p=0.95, + def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, top_p=0.95, ssl_verify=True, *args, **kwargs): """ Initialize OpenAI Model with observer and SSL verification option. @@ -46,7 +46,7 @@ def __init__(self, observer: MessageObserver, temperature=0.2, top_p=0.95, @get_monitoring_manager().monitor_llm_call("openai_chat", "chat_completion") def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List[str]] = None, - grammar: Optional[str] = None, tools_to_call_from: Optional[List[Tool]] = None, **kwargs, ) -> ChatMessage: + response_format: dict[str, str] | None = None, tools_to_call_from: Optional[List[Tool]] = None, **kwargs, ) -> ChatMessage: # Get token tracker from decorator (if monitoring is available) token_tracker = kwargs.pop('_token_tracker', None) @@ -63,7 +63,7 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List completion_kwargs = self._prepare_completion_kwargs( messages=messages, stop_sequences=stop_sequences, - grammar=grammar, tools_to_call_from=tools_to_call_from, model=self.model_id, + response_format=response_format, tools_to_call_from=tools_to_call_from, model=self.model_id, custom_role_conversions=self.custom_role_conversions, convert_images_to_image_urls=True, temperature=self.temperature, top_p=self.top_p, **kwargs, ) diff --git a/sdk/nexent/core/tools/datamate_search_tool.py b/sdk/nexent/core/tools/datamate_search_tool.py index a179dd689..bf1009269 100644 --- a/sdk/nexent/core/tools/datamate_search_tool.py +++ b/sdk/nexent/core/tools/datamate_search_tool.py @@ -150,7 +150,7 @@ def forward( entity_data = single_search_result.get("entity", {}) metadata = self._parse_metadata(entity_data.get("metadata")) dataset_id = self._extract_dataset_id(metadata.get("absolute_directory_path", "")) - file_id = entity_data.get("id") + file_id = metadata.get("original_file_id") download_url = self._build_file_download_url(dataset_id, file_id) score_details = entity_data.get("scoreDetails", {}) or {} @@ -162,7 +162,7 @@ def forward( }) search_result_message = SearchResultTextMessage( - title=metadata.get("file_name", "") or "Untitled", + title=metadata.get("file_name", ""), text=entity_data.get("text", ""), source_type="datamate", url=download_url, @@ -308,6 +308,6 @@ def _extract_dataset_id(absolute_path: str) -> str: def _build_file_download_url(self, dataset_id: str, file_id: str) -> str: """Build the download URL for a dataset file.""" - if not (self.server_ip and dataset_id and file_id): + if not (self.server_base_url and dataset_id and file_id): return "" - return f"{self.server_ip}/api/data-management/datasets/{dataset_id}/files/{file_id}/download" \ No newline at end of file + return f"{self.server_base_url}/api/data-management/datasets/{dataset_id}/files/{file_id}/download" \ No newline at end of file diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index 636162da1..90b600da6 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -1,6 +1,6 @@ import json import logging -from typing import List +from typing import Dict, List, Optional, Union from pydantic import Field from smolagents.tools import Tool @@ -36,7 +36,7 @@ class KnowledgeBaseSearchTool(Tool): }, "index_names": { "type": "array", - "description": "The list of knowledge base index names to search. If not provided, will search all available knowledge bases.", + "description": "The list of knowledge base names to search (supports user-facing knowledge_name or internal index_name). If not provided, will search all available knowledge bases.", "nullable": True, }, } @@ -50,6 +50,9 @@ def __init__( self, top_k: int = Field(description="Maximum number of search results", default=5), index_names: List[str] = Field(description="The list of index names to search", default=None, exclude=True), + name_resolver: Optional[Dict[str, str]] = Field( + description="Mapping from knowledge_name to index_name", default=None, exclude=True + ), observer: MessageObserver = Field(description="Message observer", default=None, exclude=True), embedding_model: BaseEmbedding = Field(description="The embedding model to use", default=None, exclude=True), vdb_core: VectorDatabaseCore = Field(description="Vector database client", default=None, exclude=True), @@ -68,13 +71,36 @@ def __init__( self.observer = observer self.vdb_core = vdb_core self.index_names = [] if index_names is None else index_names + self.name_resolver: Dict[str, str] = name_resolver or {} self.embedding_model = embedding_model self.record_ops = 1 # To record serial number self.running_prompt_zh = "知识库检索中..." self.running_prompt_en = "Searching the knowledge base..." - def forward(self, query: str, search_mode: str = "hybrid", index_names: List[str] = None) -> str: + def update_name_resolver(self, new_mapping: Dict[str, str]) -> None: + """Update the mapping from knowledge_name to index_name at runtime.""" + self.name_resolver = new_mapping or {} + + def _resolve_names(self, names: List[str]) -> List[str]: + """Resolve user-facing knowledge names to internal index names.""" + if not names: + return [] + if not self.name_resolver: + logger.warning( + "No name resolver provided, returning original names") + return names + return [self.name_resolver.get(name, name) for name in names] + + def _normalize_index_names(self, index_names: Optional[Union[str, List[str]]]) -> List[str]: + """Normalize index_names to list; accept single string and keep None as empty list.""" + if index_names is None: + return [] + if isinstance(index_names, str): + return [index_names] + return list(index_names) + + def forward(self, query: str, search_mode: str = "hybrid", index_names: Union[str, List[str], None] = None) -> str: # Send tool run message if self.observer: running_prompt = self.running_prompt_zh if self.observer.lang == "zh" else self.running_prompt_en @@ -83,7 +109,9 @@ def forward(self, query: str, search_mode: str = "hybrid", index_names: List[str self.observer.add_message("", ProcessType.CARD, json.dumps(card_content, ensure_ascii=False)) # Use provided index_names if available, otherwise use default - search_index_names = index_names if index_names is not None else self.index_names + search_index_names = self._normalize_index_names( + index_names if index_names is not None else self.index_names) + search_index_names = self._resolve_names(search_index_names) # Log the index_names being used for this search logger.info( diff --git a/sdk/nexent/vector_database/base.py b/sdk/nexent/vector_database/base.py index 188e33e59..d15ba7a25 100644 --- a/sdk/nexent/vector_database/base.py +++ b/sdk/nexent/vector_database/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Callable from ..core.models.embedding_model import BaseEmbedding @@ -79,6 +79,8 @@ def vectorize_documents( documents: List[Dict[str, Any]], batch_size: int = 64, content_field: str = "content", + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """ Index documents with embeddings. diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py index 4e027b941..8abe046f4 100644 --- a/sdk/nexent/vector_database/elasticsearch_core.py +++ b/sdk/nexent/vector_database/elasticsearch_core.py @@ -1,10 +1,11 @@ +import json import logging import threading import time from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from elasticsearch import Elasticsearch, exceptions @@ -338,6 +339,8 @@ def vectorize_documents( documents: List[Dict[str, Any]], batch_size: int = 64, content_field: str = "content", + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """ Smart batch insertion - automatically selecting strategy based on data size @@ -348,6 +351,7 @@ def vectorize_documents( documents: List of document dictionaries batch_size: Number of documents to process at once content_field: Field to use for generating embeddings + embedding_batch_size: Number of documents to send to embedding API at once (default: 10) Returns: int: Number of documents successfully indexed @@ -362,15 +366,34 @@ def vectorize_documents( total_docs = len(documents) if total_docs < 64: # Small data: direct insertion, using wait_for refresh - return self._small_batch_insert(index_name, documents, content_field, embedding_model) + return self._small_batch_insert( + index_name=index_name, + documents=documents, + content_field=content_field, + embedding_model=embedding_model, + progress_callback=progress_callback, + ) else: # Large data: using context manager estimated_duration = max(60, total_docs // 100) with self.bulk_operation_context(index_name, estimated_duration): - return self._large_batch_insert(index_name, documents, batch_size, content_field, embedding_model) + return self._large_batch_insert( + index_name=index_name, + documents=documents, + batch_size=batch_size, + content_field=content_field, + embedding_model=embedding_model, + embedding_batch_size=embedding_batch_size, + progress_callback=progress_callback, + ) def _small_batch_insert( - self, index_name: str, documents: List[Dict[str, Any]], content_field: str, embedding_model: BaseEmbedding + self, + index_name: str, + documents: List[Dict[str, Any]], + content_field: str, + embedding_model: BaseEmbedding, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """Small batch insertion: real-time""" try: @@ -398,13 +421,20 @@ def _small_batch_insert( # Handle errors self._handle_bulk_errors(response) + if progress_callback: + try: + progress_callback(len(documents), len(documents)) + except Exception as e: + logger.warning( + f"[VECTORIZE] Progress callback failed in small batch: {str(e)}") + logger.info( f"Small batch insert completed: {len(documents)} chunks indexed.") return len(documents) except Exception as e: logger.error(f"Small batch insert failed: {e}") - return 0 + raise def _large_batch_insert( self, @@ -413,6 +443,8 @@ def _large_batch_insert( batch_size: int, content_field: str, embedding_model: BaseEmbedding, + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """ Large batch insertion with sub-batching for embedding API. @@ -422,6 +454,7 @@ def _large_batch_insert( processed_docs = self._preprocess_documents( documents, content_field) total_indexed = 0 + total_vectorized = 0 total_docs = len(processed_docs) es_total_batches = (total_docs + batch_size - 1) // batch_size start_time = time.time() @@ -439,7 +472,7 @@ def _large_batch_insert( doc_embedding_pairs = [] # Sub-batch for embedding API - embedding_batch_size = 64 + # Use the provided embedding_batch_size (default 10) to reduce provider pressure for j in range(0, len(es_batch), embedding_batch_size): embedding_sub_batch = es_batch[j: j + embedding_batch_size] # Retry logic for embedding API call (3 retries, 1s delay) @@ -459,6 +492,16 @@ def _large_batch_insert( doc_embedding_pairs.append((doc, embedding)) success = True + total_vectorized += len(embedding_sub_batch) + if progress_callback: + try: + progress_callback( + total_vectorized, total_docs) + logger.debug( + f"[VECTORIZE] Progress callback (embedding) {total_vectorized}/{total_docs} (ES batch {es_batch_num}/{es_total_batches}, sub-batch start {j})") + except Exception as callback_err: + logger.warning( + f"[VECTORIZE] Progress callback failed during embedding: {callback_err}") break # Success, exit retry loop except Exception as e: @@ -504,10 +547,7 @@ def _large_batch_insert( except Exception as e: logger.error( f"Bulk insert error: {e}, ES batch num: {es_batch_num}") - continue - - # Add 0.1s delay between batches to avoid overloading embedding API - time.sleep(0.1) + raise self._force_refresh_with_retry(index_name) total_elapsed = time.time() - start_time @@ -517,7 +557,7 @@ def _large_batch_insert( return total_indexed except Exception as e: logger.error(f"Large batch insert failed: {e}") - return 0 + raise def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: str) -> List[Dict[str, Any]]: """Ensure all documents have the required fields and set default values""" @@ -558,21 +598,44 @@ def _handle_bulk_errors(self, response: Dict[str, Any]) -> None: """Handle bulk operation errors""" if response.get("errors"): for item in response["items"]: - if "error" in item.get("index", {}): - error_info = item["index"]["error"] - error_type = error_info.get("type") - error_reason = error_info.get("reason") - error_cause = error_info.get("caused_by", {}) - - if error_type == "version_conflict_engine_exception": - # ignore version conflict - continue - else: - logger.error( - f"FATAL ERROR {error_type}: {error_reason}") - if error_cause: - logger.error( - f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}") + if "error" not in item.get("index", {}): + continue + + error_info = item["index"]["error"] + error_type = error_info.get("type") + error_reason = error_info.get("reason") + error_cause = error_info.get("caused_by", {}) + + if error_type == "version_conflict_engine_exception": + # ignore version conflict + continue + + logger.error(f"FATAL ERROR {error_type}: {error_reason}") + if error_cause: + logger.error( + f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}" + ) + + reason_text = error_reason or "Unknown bulk indexing error" + cause_reason = error_cause.get("reason") + if cause_reason: + reason_text = f"{reason_text}; caused by: {cause_reason}" + + # Derive a precise error code without chaining through es_bulk_failed + if "dense_vector" in reason_text and "different number of dimensions" in reason_text: + error_code = "es_dim_mismatch" + else: + error_code = "es_bulk_failed" + + raise Exception( + json.dumps( + { + "message": f"Bulk indexing failed: {reason_text}", + "error_code": error_code, + }, + ensure_ascii=False, + ) + ) def delete_documents(self, index_name: str, path_or_url: str) -> int: """ diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index 453857a1d..1e1369fb7 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -31,14 +31,14 @@ dependencies = [ "rich>=13.9.4", "setuptools>=75.1.0", "websockets>=14.2", - "smolagents[mcp]==1.15.0", + "smolagents[mcp]==1.23.0", "Pillow>=10.0.0", "aiohttp>=3.1.13", "jieba>=0.42.1", "boto3>=1.37.34", "botocore>=1.37.34", "python-multipart>=0.0.20", - "mcpadapt==0.1.9", + "mcpadapt>=0.1.13", "mcp==1.10.1", "tiktoken>=0.5.0", "tavily-python", diff --git a/test/backend/app/test_agent_app.py b/test/backend/app/test_agent_app.py index 3eeaf6650..dbb5a5318 100644 --- a/test/backend/app/test_agent_app.py +++ b/test/backend/app/test_agent_app.py @@ -689,3 +689,120 @@ def test_get_agent_call_relationship_api_exception(mocker, mock_auth_header): assert resp.status_code == 500 assert "Failed to get agent call relationship" in resp.json()["detail"] + + +def test_check_agent_name_batch_api_success(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.check_agent_name_conflict_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.return_value = [{"name_conflict": True}] + + payload = { + "items": [ + {"agent_id": 1, "name": "AgentA", "display_name": "Agent A"}, + ] + } + + resp = config_client.post( + "/agent/check_name", json=payload, headers=mock_auth_header + ) + + assert resp.status_code == 200 + mock_impl.assert_called_once() + assert resp.json() == [{"name_conflict": True}] + + +def test_check_agent_name_batch_api_bad_request(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.check_agent_name_conflict_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = ValueError("bad payload") + + resp = config_client.post( + "/agent/check_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 400 + assert resp.json()["detail"] == "bad payload" + + +def test_check_agent_name_batch_api_error(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.check_agent_name_conflict_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = Exception("unexpected") + + resp = config_client.post( + "/agent/check_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 500 + assert "Agent name batch check error" in resp.json()["detail"] + + +def test_regenerate_agent_name_batch_api_success(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.regenerate_agent_name_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.return_value = [{"name": "NewName", "display_name": "New Display"}] + + payload = { + "items": [ + { + "agent_id": 1, + "name": "AgentA", + "display_name": "Agent A", + "task_description": "desc", + } + ] + } + + resp = config_client.post( + "/agent/regenerate_name", json=payload, headers=mock_auth_header + ) + + assert resp.status_code == 200 + mock_impl.assert_called_once() + assert resp.json() == [{"name": "NewName", "display_name": "New Display"}] + + +def test_regenerate_agent_name_batch_api_bad_request(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.regenerate_agent_name_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = ValueError("invalid") + + resp = config_client.post( + "/agent/regenerate_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 400 + assert resp.json()["detail"] == "invalid" + + +def test_regenerate_agent_name_batch_api_error(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.regenerate_agent_name_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = Exception("boom") + + resp = config_client.post( + "/agent/regenerate_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 500 + assert "Agent name batch regenerate error" in resp.json()["detail"] \ No newline at end of file diff --git a/test/backend/app/test_file_management_app.py b/test/backend/app/test_file_management_app.py index cd4be8afd..a337a1434 100644 --- a/test/backend/app/test_file_management_app.py +++ b/test/backend/app/test_file_management_app.py @@ -295,6 +295,53 @@ async def gen(): assert b"chunk1" in b"".join(chunks) +@pytest.mark.asyncio +async def test_get_storage_file_base64_success(monkeypatch): + """get_storage_file should return JSON with base64 content when download=base64.""" + async def fake_get_stream(object_name): + class FakeStream: + def read(self): + return b"hello-bytes" + + return FakeStream(), "image/png" + + monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) + + resp = await file_management_app.get_storage_file( + object_name="attachments/img.png", + download="base64", + expires=60, + filename=None, + ) + + assert resp.status_code == 200 + data = resp.body.decode() + assert '"success":true' in data + assert '"content_type":"image/png"' in data + + +@pytest.mark.asyncio +async def test_get_storage_file_base64_read_error(monkeypatch): + """get_storage_file should raise HTTPException when reading stream fails in base64 mode.""" + async def fake_get_stream(object_name): + class FakeStream: + def read(self): + raise RuntimeError("read-failed") + + return FakeStream(), "image/png" + + monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) + + with pytest.raises(Exception) as exc_info: + await file_management_app.get_storage_file( + object_name="attachments/img.png", + download="base64", + expires=60, + filename=None, + ) + + assert "Failed to read file content for base64 encoding" in str(exc_info.value) + @pytest.mark.asyncio async def test_get_storage_file_metadata(monkeypatch): async def fake_get_url(object_name, expires): diff --git a/test/backend/app/test_vectordatabase_app.py b/test/backend/app/test_vectordatabase_app.py index fc0529341..97e26842a 100644 --- a/test/backend/app/test_vectordatabase_app.py +++ b/test/backend/app/test_vectordatabase_app.py @@ -6,7 +6,7 @@ import os import sys import pytest -from unittest.mock import patch, MagicMock, ANY +from unittest.mock import patch, MagicMock, ANY, AsyncMock from fastapi.testclient import TestClient from fastapi import FastAPI @@ -152,7 +152,7 @@ async def test_create_new_index_success(vdb_core_mock, auth_data): # Setup mocks with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.vectordatabase_app.ElasticSearchService.create_index") as mock_create: + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_knowledge_base") as mock_create: expected_response = {"status": "success", "index_name": auth_data["index_name"]} @@ -165,7 +165,13 @@ async def test_create_new_index_success(vdb_core_mock, auth_data): # Verify assert response.status_code == 200 assert response.json() == expected_response + # vdb_core is constructed inside router; accept ANY for instance mock_create.assert_called_once() + called_args = mock_create.call_args[0] + assert called_args[0] == auth_data["index_name"] + assert called_args[1] == 768 + assert called_args[3] == auth_data["user_id"] + assert called_args[4] == auth_data["tenant_id"] @pytest.mark.asyncio @@ -177,7 +183,7 @@ async def test_create_new_index_error(vdb_core_mock, auth_data): # Setup mocks with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.vectordatabase_app.ElasticSearchService.create_index") as mock_create: + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_knowledge_base") as mock_create: mock_create.side_effect = Exception("Test error") @@ -702,10 +708,11 @@ async def test_get_index_chunks_success(vdb_core_mock): Test retrieving index chunks successfully. Verifies that the endpoint forwards query params and returns the service payload. """ + index_name = "test_index" with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value="resolved_index"), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: - index_name = "test_index" expected_response = { "status": "success", "message": "ok", @@ -724,7 +731,7 @@ async def test_get_index_chunks_success(vdb_core_mock): assert response.status_code == 200 assert response.json() == expected_response mock_get_chunks.assert_called_once_with( - index_name=index_name, + index_name="resolved_index", page=2, page_size=50, path_or_url="/foo", @@ -738,10 +745,11 @@ async def test_get_index_chunks_error(vdb_core_mock): Test retrieving index chunks with service error. Ensures the endpoint maps the exception to HTTP 500. """ + index_name = "test_index" with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value="resolved_index"), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: - index_name = "test_index" mock_get_chunks.side_effect = Exception("Chunk failure") response = client.post(f"/indices/{index_name}/chunks") @@ -749,7 +757,7 @@ async def test_get_index_chunks_error(vdb_core_mock): assert response.status_code == 500 assert response.json() == {"detail": "Error getting chunks: Chunk failure"} mock_get_chunks.assert_called_once_with( - index_name=index_name, + index_name="resolved_index", page=None, page_size=None, path_or_url=None, @@ -765,6 +773,7 @@ async def test_create_chunk_success(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.create_chunk") as mock_create: expected_response = {"status": "success", "chunk_id": "chunk-1"} @@ -794,6 +803,7 @@ async def test_create_chunk_error(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.create_chunk") as mock_create: mock_create.side_effect = Exception("Create failed") @@ -822,6 +832,7 @@ async def test_update_chunk_success(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_chunk") as mock_update: expected_response = {"status": "success", "chunk_id": "chunk-1"} @@ -850,6 +861,7 @@ async def test_update_chunk_value_error(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_chunk") as mock_update: mock_update.side_effect = ValueError("Invalid update payload") @@ -864,7 +876,8 @@ async def test_update_chunk_value_error(vdb_core_mock, auth_data): headers=auth_data["auth_header"], ) - assert response.status_code == 400 + # ValueError is mapped to NOT_FOUND in app layer + assert response.status_code == 404 assert response.json() == {"detail": "Invalid update payload"} mock_update.assert_called_once() @@ -877,6 +890,7 @@ async def test_update_chunk_exception(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_chunk") as mock_update: mock_update.side_effect = Exception("Update failed") @@ -904,6 +918,7 @@ async def test_delete_chunk_success(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_chunk") as mock_delete: expected_response = {"status": "success", "chunk_id": "chunk-1"} @@ -927,6 +942,7 @@ async def test_delete_chunk_not_found(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_chunk") as mock_delete: mock_delete.side_effect = ValueError("Chunk not found") @@ -949,6 +965,7 @@ async def test_delete_chunk_exception(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_chunk") as mock_delete: mock_delete.side_effect = Exception("Delete failed") @@ -1351,6 +1368,108 @@ async def test_health_check_exception(vdb_core_mock): mock_health.assert_called_once_with(ANY) +@pytest.mark.asyncio +async def test_get_document_error_info_not_found(vdb_core_mock, auth_data): + """ + Test document error info when document is not found. + """ + with patch("backend.apps.vectordatabase_app.get_all_files_status", new=AsyncMock(return_value={})): + response = client.get( + f"/indices/{auth_data['index_name']}/documents/missing_doc/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_get_document_error_info_no_task_id(auth_data): + """ + Test document error info when task id is empty. + """ + with patch( + "backend.apps.vectordatabase_app.get_all_files_status", + new=AsyncMock( + return_value={ + "doc-1": { + "latest_task_id": "" + } + } + ), + ), patch("backend.apps.vectordatabase_app.get_redis_service") as mock_redis: + response = client.get( + "/indices/test_index/documents/doc-1/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 200 + assert response.json() == {"status": "success", "error_code": None} + mock_redis.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_document_error_info_json_error_code(auth_data): + """ + Test document error info JSON parsing for error_code. + """ + redis_mock = MagicMock() + redis_mock.get_error_info.return_value = '{"error_code": "INVALID_FORMAT"}' + + with patch( + "backend.apps.vectordatabase_app.get_all_files_status", + new=AsyncMock( + return_value={ + "doc-1": { + "latest_task_id": "task-123" + } + } + ), + ), patch( + "backend.apps.vectordatabase_app.get_redis_service", + return_value=redis_mock, + ): + response = client.get( + "/indices/test_index/documents/doc-1/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 200 + assert response.json() == {"status": "success", "error_code": "INVALID_FORMAT"} + redis_mock.get_error_info.assert_called_once_with("task-123") + + +@pytest.mark.asyncio +async def test_get_document_error_info_regex_error_code(auth_data): + """ + Test document error info regex extraction when JSON parsing fails. + """ + redis_mock = MagicMock() + redis_mock.get_error_info.return_value = "oops {'error_code': 'TIMEOUT_ERROR'}" + + with patch( + "backend.apps.vectordatabase_app.get_all_files_status", + new=AsyncMock( + return_value={ + "doc-1": { + "latest_task_id": "task-999" + } + } + ), + ), patch( + "backend.apps.vectordatabase_app.get_redis_service", + return_value=redis_mock, + ): + response = client.get( + "/indices/test_index/documents/doc-1/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 200 + assert response.json() == {"status": "success", "error_code": "TIMEOUT_ERROR"} + redis_mock.get_error_info.assert_called_once_with("task-999") + + @pytest.mark.asyncio async def test_health_check_timeout_exception(vdb_core_mock): """ @@ -1545,6 +1664,59 @@ async def test_hybrid_search_value_error(vdb_core_mock, auth_data): assert response.json() == {"detail": "Query text is required"} +@pytest.mark.asyncio +async def test_get_index_chunks_value_error(vdb_core_mock): + """ + Test get_index_chunks maps ValueError to 404. + """ + index_name = "test_index" + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value="resolved_index"), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: + + mock_get_chunks.side_effect = ValueError("Unknown index") + + response = client.post(f"/indices/{index_name}/chunks") + + assert response.status_code == 404 + assert response.json() == {"detail": "Unknown index"} + mock_get_chunks.assert_called_once_with( + index_name="resolved_index", + page=None, + page_size=None, + path_or_url=None, + vdb_core=ANY, + ) + + +@pytest.mark.asyncio +async def test_create_chunk_value_error(vdb_core_mock, auth_data): + """ + Test create_chunk maps ValueError to 404. + """ + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_chunk") as mock_create: + + mock_create.side_effect = ValueError("Invalid chunk payload") + + payload = { + "content": "Hello world", + "path_or_url": "doc-1", + } + + response = client.post( + f"/indices/{auth_data['index_name']}/chunk", + json=payload, + headers=auth_data["auth_header"], + ) + + assert response.status_code == 404 + assert response.json() == {"detail": "Invalid chunk payload"} + mock_create.assert_called_once() + + @pytest.mark.asyncio async def test_hybrid_search_exception(vdb_core_mock, auth_data): """ diff --git a/test/backend/data_process/test_ray_config.py b/test/backend/data_process/test_ray_config.py index a334965ac..55440cfef 100644 --- a/test/backend/data_process/test_ray_config.py +++ b/test/backend/data_process/test_ray_config.py @@ -95,6 +95,8 @@ def decorator(func): const_mod.FORWARD_REDIS_RETRY_DELAY_S = 0 const_mod.FORWARD_REDIS_RETRY_MAX = 1 const_mod.DISABLE_RAY_DASHBOARD = False + # Constants required by tasks.py + const_mod.ROOT_DIR = "/tmp/test" sys.modules["consts.const"] = const_mod # Stub consts.model (required by utils.file_management_utils) @@ -163,6 +165,71 @@ def __init__(self, chunking_strategy: str, source_type: str, index_name: str, au file_utils_mod.get_file_size = lambda *args, **kwargs: 0 sys.modules["utils.file_management_utils"] = file_utils_mod + # Stub services.redis_service (required by tasks.py) + if "services" not in sys.modules: + services_pkg = types.ModuleType("services") + setattr(services_pkg, "__path__", []) + sys.modules["services"] = services_pkg + if "services.redis_service" not in sys.modules: + redis_service_mod = types.ModuleType("services.redis_service") + class FakeRedisService: + def __init__(self): + pass + redis_service_mod.RedisService = FakeRedisService + redis_service_mod.get_redis_service = lambda: FakeRedisService() + sys.modules["services.redis_service"] = redis_service_mod + + # Stub backend.data_process modules (required by __init__.py and tasks.py) + if "backend" not in sys.modules: + backend_pkg = types.ModuleType("backend") + setattr(backend_pkg, "__path__", []) + sys.modules["backend"] = backend_pkg + if "backend.data_process" not in sys.modules: + dp_pkg = types.ModuleType("backend.data_process") + setattr(dp_pkg, "__path__", []) + sys.modules["backend.data_process"] = dp_pkg + + # Stub backend.data_process.app (required by tasks.py) + if "backend.data_process.app" not in sys.modules: + app_mod = types.ModuleType("backend.data_process.app") + # Create a fake Celery app instance + fake_app = types.SimpleNamespace( + backend=types.SimpleNamespace(), # Not DisabledBackend + conf=types.SimpleNamespace(update=lambda **kwargs: None) + ) + app_mod.app = fake_app + sys.modules["backend.data_process.app"] = app_mod + + # Stub backend.data_process.tasks (required by __init__.py) + if "backend.data_process.tasks" not in sys.modules: + tasks_mod = types.ModuleType("backend.data_process.tasks") + # Mock the task functions that __init__.py imports + tasks_mod.process = lambda *args, **kwargs: None + tasks_mod.forward = lambda *args, **kwargs: None + tasks_mod.process_and_forward = lambda *args, **kwargs: None + tasks_mod.process_sync = lambda *args, **kwargs: None + sys.modules["backend.data_process.tasks"] = tasks_mod + + # Stub backend.data_process.utils (required by __init__.py) + if "backend.data_process.utils" not in sys.modules: + utils_mod = types.ModuleType("backend.data_process.utils") + utils_mod.get_task_info = lambda *args, **kwargs: {} + utils_mod.get_task_details = lambda *args, **kwargs: {} + sys.modules["backend.data_process.utils"] = utils_mod + + # Stub backend.data_process.__init__ to avoid importing real tasks + # This must be done after tasks and utils are defined + if "backend.data_process.__init__" not in sys.modules: + init_mod = types.ModuleType("backend.data_process.__init__") + init_mod.app = sys.modules["backend.data_process.app"].app + init_mod.process = sys.modules["backend.data_process.tasks"].process + init_mod.forward = sys.modules["backend.data_process.tasks"].forward + init_mod.process_and_forward = sys.modules["backend.data_process.tasks"].process_and_forward + init_mod.process_sync = sys.modules["backend.data_process.tasks"].process_sync + init_mod.get_task_info = sys.modules["backend.data_process.utils"].get_task_info + init_mod.get_task_details = sys.modules["backend.data_process.utils"].get_task_details + sys.modules["backend.data_process.__init__"] = init_mod + # Stub ray_actors (required by tasks.py) if "backend.data_process.ray_actors" not in sys.modules: ray_actors_mod = types.ModuleType("backend.data_process.ray_actors") @@ -179,10 +246,128 @@ def __init__(self, chunking_strategy: str, source_type: str, index_name: str, au DataProcessCore=type("_Core", (), {"__init__": lambda self: None, "file_process": lambda *a, **k: []}) ) - # Import and reload the module after mocks are in place - import backend.data_process.ray_config as ray_config_module - importlib.reload(ray_config_module) - + # Build a lightweight mock ray_config module to avoid importing real code + if "backend" not in sys.modules: + backend_pkg = types.ModuleType("backend") + setattr(backend_pkg, "__path__", []) + sys.modules["backend"] = backend_pkg + + # Ensure backend has data_process attribute for mocker.patch to work + if not hasattr(sys.modules["backend"], "data_process"): + if "backend.data_process" not in sys.modules: + dp_pkg = types.ModuleType("backend.data_process") + setattr(dp_pkg, "__path__", []) + sys.modules["backend.data_process"] = dp_pkg + sys.modules["backend"].data_process = sys.modules["backend.data_process"] + elif "backend.data_process" not in sys.modules: + dp_pkg = types.ModuleType("backend.data_process") + setattr(dp_pkg, "__path__", []) + sys.modules["backend.data_process"] = dp_pkg + sys.modules["backend"].data_process = dp_pkg + + ray_config_module = types.ModuleType("backend.data_process.ray_config") + # Add os module reference so mocker.patch can patch os.cpu_count + ray_config_module.os = os + + class RayConfig: + def __init__(self): + from consts.const import RAY_OBJECT_STORE_MEMORY_GB, RAY_TEMP_DIR, RAY_preallocate_plasma + self.object_store_memory_gb = RAY_OBJECT_STORE_MEMORY_GB + self.temp_dir = RAY_TEMP_DIR + self.preallocate_plasma = RAY_preallocate_plasma + + def get_init_params(self, num_cpus=None, include_dashboard=True, dashboard_port=8265, address=None): + params = {"ignore_reinit_error": True} + if address: + params["address"] = address + else: + if num_cpus is None: + num_cpus = os.cpu_count() + params["num_cpus"] = num_cpus + params["object_store_memory"] = int(self.object_store_memory_gb * 1024 * 1024 * 1024) + if include_dashboard and not address: + params["include_dashboard"] = True + params["dashboard_host"] = "0.0.0.0" + params["dashboard_port"] = dashboard_port + else: + params["include_dashboard"] = False + params["_temp_dir"] = self.temp_dir + params["object_spilling_directory"] = self.temp_dir + return params + + def _set_preallocate_env(self): + os.environ["RAY_preallocate_plasma"] = str(self.preallocate_plasma).lower() + + def init_ray(self, num_cpus=None, include_dashboard=True, address=None, dashboard_port=8265): + self._set_preallocate_env() + try: + if getattr(sys.modules["ray"], "is_initialized")(): + return True + params = self.get_init_params(num_cpus=num_cpus, include_dashboard=include_dashboard, + dashboard_port=dashboard_port, address=address) + sys.modules["ray"].init(**params) + try: + sys.modules["ray"].cluster_resources() + except Exception: + pass + return True + except Exception: + return False + + def connect_to_cluster(self, address): + self._set_preallocate_env() + try: + if getattr(sys.modules["ray"], "is_initialized")(): + return True + sys.modules["ray"].init(address=address, ignore_reinit_error=True) + return True + except Exception: + return False + + def start_local_cluster(self, num_cpus=None, include_dashboard=True, dashboard_port=8265): + self._set_preallocate_env() + try: + params = self.get_init_params(num_cpus=num_cpus, include_dashboard=include_dashboard, + dashboard_port=dashboard_port) + sys.modules["ray"].init(**params) + return True + except Exception: + return False + + @classmethod + def init_ray_for_worker(cls, address): + cfg = cls() + return cfg.connect_to_cluster(address) + + @classmethod + def init_ray_for_service(cls, num_cpus=None, dashboard_port=8265, try_connect_first=False, include_dashboard=True): + cfg = cls() + if try_connect_first: + if cfg.connect_to_cluster("auto"): + return True + # Fallback to local cluster + return cfg.start_local_cluster(num_cpus=num_cpus, include_dashboard=include_dashboard, + dashboard_port=dashboard_port) + + ray_config_module.RayConfig = RayConfig + sys.modules["backend.data_process.ray_config"] = ray_config_module + + # Ensure backend.data_process has ray_config attribute for mocker.patch to work + sys.modules["backend.data_process"].ray_config = ray_config_module + + # Add a fake ray_config submodule for tests that try to patch ray_config.ray_config.log_configuration + # This is a workaround for tests that incorrectly try to patch a non-existent nested module + fake_ray_config_submodule = types.ModuleType("backend.data_process.ray_config.ray_config") + fake_ray_config_submodule.log_configuration = lambda *args, **kwargs: None + sys.modules["backend.data_process.ray_config"].ray_config = fake_ray_config_submodule + + # Add __spec__ to support importlib.reload (though reload won't work perfectly with mock modules) + # We'll create a minimal spec-like object + class MockSpec: + def __init__(self, name): + self.name = name + ray_config_module.__spec__ = MockSpec("backend.data_process.ray_config") + return ray_config_module, fake_ray @@ -470,9 +655,8 @@ def test_get_init_params_object_store_memory_calculation(mocker): if "consts.const" in sys.modules: sys.modules["consts.const"].RAY_OBJECT_STORE_MEMORY_GB = 1.5 - # Reload to pick up new constant value - importlib.reload(ray_config_module) - + # Create new RayConfig instance to pick up new constant value + # (RayConfig.__init__ reads from consts.const, so new instance will use updated value) config = ray_config_module.RayConfig() params = config.get_init_params(num_cpus=2) @@ -488,11 +672,9 @@ def test_init_ray_sets_preallocate_plasma_env(mocker): if "consts.const" in sys.modules: sys.modules["consts.const"].RAY_preallocate_plasma = True - # Reload to pick up new constant value - importlib.reload(ray_config_module) - + # Create new RayConfig instance to pick up new constant value + # (RayConfig.__init__ reads from consts.const, so new instance will use updated value) config = ray_config_module.RayConfig() - config.preallocate_plasma = True config.init_ray(num_cpus=2, include_dashboard=False) diff --git a/test/backend/data_process/test_tasks.py b/test/backend/data_process/test_tasks.py index 42a086347..722ac29d4 100644 --- a/test/backend/data_process/test_tasks.py +++ b/test/backend/data_process/test_tasks.py @@ -115,6 +115,7 @@ def decorator(func): # New defaults required by ray_actors import const_mod.DEFAULT_EXPECTED_CHUNK_SIZE = 1024 const_mod.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536 + const_mod.ROOT_DIR = "/mock/root" sys.modules["consts.const"] = const_mod # Minimal stub for consts.model used by utils.file_management_utils if "consts.model" not in sys.modules: @@ -328,7 +329,7 @@ def failing_init(**kwargs): # Verify that the exception is re-raised with pytest.raises(RuntimeError) as exc_info: tasks.init_ray_in_worker() - assert exc_info.value == init_exception + assert "Failed to initialize Ray for Celery worker" in str(exc_info.value) def test_run_async_no_running_loop(monkeypatch): @@ -554,6 +555,37 @@ def get(self, k): json.loads(str(ei.value)) +def test_forward_returns_when_task_cancelled(monkeypatch): + """forward should exit early when cancellation flag is set""" + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeRedisService: + def __init__(self): + self.calls = 0 + + def is_task_cancelled(self, task_id): + self.calls += 1 + return True + + fake_service = FakeRedisService() + monkeypatch.setattr(tasks, "get_redis_service", lambda: fake_service) + + self = FakeSelf("cancel-1") + result = tasks.forward( + self, + processed_data={"chunks": [{"content": "keep", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert result["chunks_stored"] == 0 + assert "cancelled" in result["es_result"]["message"].lower() + assert fake_service.calls == 1 + # No state updates should occur because we returned early + assert self.states == [] + + def test_forward_redis_client_from_url_failure(monkeypatch): tasks, _ = import_tasks_with_fake_ray(monkeypatch) monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") @@ -965,6 +997,506 @@ def apply_async(self): assert chain_id == "123" +def test_extract_error_code_parses_detail_and_regex_and_unknown(): + from backend.data_process.tasks import extract_error_code + + # detail error_code inside JSON string + json_detail = json.dumps({"detail": {"error_code": "detail_code"}}) + assert extract_error_code(json_detail) == "detail_code" + + # regex fallback when not valid JSON + raw = 'oops {"error_code":"regex_code"}' + assert extract_error_code(raw) == "regex_code" + + # unknown path + assert extract_error_code("no code here") == "unknown_error" + + +def test_extract_error_code_top_level_key(): + from backend.data_process.tasks import extract_error_code + + payload = json.dumps({"error_code": "top_level"}) + assert extract_error_code(payload) == "top_level" + + +def test_save_error_to_redis_branches(monkeypatch): + from backend.data_process.tasks import save_error_to_redis + + warnings = [] + infos = [] + + class FakeRedisSvc: + def __init__(self, return_val=True): + self.return_val = return_val + self.calls = [] + + def save_error_info(self, tid, reason): + self.calls.append((tid, reason)) + return self.return_val + + # capture logger calls + monkeypatch.setattr( + "backend.data_process.tasks.logger.warning", + lambda msg: warnings.append(msg), + ) + monkeypatch.setattr( + "backend.data_process.tasks.logger.info", lambda msg: infos.append(msg) + ) + monkeypatch.setattr( + "backend.data_process.tasks.logger.error", lambda *a, **k: warnings.append(a[0]) + ) + + # empty task_id + save_error_to_redis("", "r", 0) + assert any("task_id is empty" in w for w in warnings) + warnings.clear() + + # empty error_reason + save_error_to_redis("tid", "", 0) + assert any("error_reason is empty" in w for w in warnings) + warnings.clear() + + # success True + svc_true = FakeRedisSvc(True) + monkeypatch.setattr( + "backend.data_process.tasks.get_redis_service", lambda: svc_true + ) + save_error_to_redis("tid1", "reason1", 0) + assert svc_true.calls == [("tid1", "reason1")] + assert any("Successfully saved error info" in i for i in infos) + + # success False + infos.clear() + svc_false = FakeRedisSvc(False) + monkeypatch.setattr( + "backend.data_process.tasks.get_redis_service", lambda: svc_false + ) + save_error_to_redis("tid2", "reason2", 0) + assert svc_false.calls == [("tid2", "reason2")] + assert any("save_error_info returned False" in w for w in warnings) + + # exception path + def boom(): + raise RuntimeError("fail") + + monkeypatch.setattr( + "backend.data_process.tasks.get_redis_service", lambda: boom() + ) + save_error_to_redis("tid3", "reason3", 0) + assert any("Failed to save error info to Redis" in w for w in warnings) + + +def test_process_error_fallback_when_save_error_raises(monkeypatch, tmp_path): + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + # Force get_ray_actor to raise to enter error handling + monkeypatch.setattr(tasks, "get_ray_actor", lambda: (_ for _ in ()).throw( + Exception("x" * 250) + )) + + # Make save_error_to_redis raise to hit fallback block + monkeypatch.setattr( + tasks, + "save_error_to_redis", + lambda *a, **k: (_ for _ in ()).throw(RuntimeError("save-fail")), + ) + + self = FakeSelf("err-fallback") + with pytest.raises(Exception): + tasks.process( + self, + source=str(tmp_path / "missing.txt"), + source_type="local", + chunking_strategy="basic", + index_name="idx", + original_filename="file.txt", + ) + + # State should still be updated in fallback branch + assert any( + s.get("meta", {}).get("stage") in {"text_extraction_failed", "extracting_text"} + for s in self.states + ) or self.states == [] + + +def test_process_error_truncates_reason_when_no_error_code(monkeypatch, tmp_path): + """process should truncate long messages when extract_error_code is falsy""" + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + long_msg = "x" * 250 + error_json = json.dumps({"message": long_msg}) + + # Provide actor but make ray.get raise inside the try block + class FakeActor: + def __init__(self): + self.process_file = types.SimpleNamespace(remote=lambda *a, **k: "ref_err") + self.store_chunks_in_redis = types.SimpleNamespace( + remote=lambda *a, **k: None) + + monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + fake_ray.get = lambda *_: (_ for _ in ()).throw(Exception(error_json)) + # Force extract_error_code to return None so truncation path executes + monkeypatch.setattr(tasks, "extract_error_code", lambda *a, **k: None) + + calls: list[str] = [] + + def save_and_capture(task_id, reason, start_time): + calls.append(reason) + + monkeypatch.setattr(tasks, "save_error_to_redis", save_and_capture) + + # Ensure source file exists so FileNotFound is not raised before ray.get + f = tmp_path / "exists.txt" + f.write_text("data") + + self = FakeSelf("trunc-proc") + with pytest.raises(Exception): + tasks.process( + self, + source=str(f), + source_type="local", + chunking_strategy="basic", + index_name="idx", + original_filename="f.txt", + ) + + # Captured reason should be truncated because error_code is falsy + assert len(calls) >= 1 + truncated_reason = calls[-1] + assert truncated_reason.endswith("...") + assert len(truncated_reason) <= 203 + assert any( + s.get("meta", {}).get("stage") == "text_extraction_failed" + for s in self.states + ) + + +def test_forward_cancel_check_warning_then_continue(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + # make cancellation check raise to hit warning path + monkeypatch.setattr(tasks, "get_redis_service", lambda: (_ for _ in ()).throw(RuntimeError("boom"))) + + # run index_documents normally via stubbed run_async returning success + monkeypatch.setattr( + tasks, + "run_async", + lambda coro: {"success": True, "total_indexed": 1, "total_submitted": 1, "message": "ok"}, + ) + + self = FakeSelf("warn-cancel") + result = tasks.forward( + self, + processed_data={"chunks": [{"content": "c", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + authorization="Bearer 1", + ) + assert result["chunks_stored"] == 1 + + +def _run_coro(coro): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + + +def test_forward_index_documents_error_code_from_detail(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeResponse: + status = 500 + + async def text(self): + return json.dumps({"detail": {"error_code": "detail_err"}}) + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + return FakeResponse() + + fake_aiohttp = types.SimpleNamespace( + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ClientConnectorError=Exception, + ClientResponseError=Exception, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("detail-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + authorization="Bearer token", + ) + assert "detail_err" in str(exc.value) + + +def test_forward_index_documents_regex_error_code(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 0) + + class FakeResponse: + status = 500 + + async def text(self): + # Include quotes so regex r'\"error_code\": \"...\"' matches + return 'oops "error_code":"regex_branch"' + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + return FakeResponse() + + fake_aiohttp = types.SimpleNamespace( + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ClientConnectorError=Exception, + ClientResponseError=Exception, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("regex-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + assert "regex_branch" in str(exc.value) + + +def test_forward_index_documents_client_connector_error(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + raise tasks.aiohttp.ClientConnectorError("down") + + fake_aiohttp = types.SimpleNamespace( + ClientConnectorError=Exception, + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ClientResponseError=Exception, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("conn-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + assert "Failed to connect to API" in str(exc.value) + + +def test_forward_index_documents_timeout(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + raise asyncio.TimeoutError("t/o") + + fake_aiohttp = types.SimpleNamespace( + ClientConnectorError=Exception, + ClientResponseError=Exception, + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("timeout-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + assert "Failed to connect to API" in str(exc.value) or "timeout" in str(exc.value).lower() + + +def test_forward_truncates_reason_when_no_error_code(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 0) + monkeypatch.setattr(tasks, "extract_error_code", lambda *a, **k: None) + + long_msg = json.dumps({"message": "m" * 250}) + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception(long_msg)) + ) + + reasons: list[str] = [] + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: reasons.append(reason) + ) + + self = FakeSelf("f-trunc") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert reasons and reasons[0].endswith("...") + assert len(reasons[0]) <= 203 + assert any( + s.get("meta", {}).get("stage") == "forward_task_failed" for s in self.states + ) + + +def test_forward_fallback_truncates_on_non_json_error(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 0) + monkeypatch.setattr(tasks, "extract_error_code", lambda *a, **k: None) + + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception("n" * 250)) + ) + + reasons: list[str] = [] + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: reasons.append(reason) + ) + + self = FakeSelf("f-fallback") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert reasons and reasons[0].endswith("...") + assert len(reasons[0]) <= 203 + assert any( + s.get("meta", {}).get("stage") == "forward_task_failed" for s in self.states + ) + + +def test_forward_error_truncates_reason_and_uses_save(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + long_message = "m" * 250 + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception(json.dumps({"message": long_message}))) + ) + captured = {} + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: captured.setdefault("reason", reason) + ) + + self = FakeSelf("trunc") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert captured["reason"] + + +def test_forward_error_fallback_when_json_loads_fails(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception("not-json-error")) + ) + captured = {} + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: captured.setdefault("reason", reason) + ) + + self = FakeSelf("fallback-forward") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert captured["reason"] + assert any( + s.get("meta", {}).get("stage") == "forward_task_failed" for s in self.states + ) + + def test_process_sync_local_returns(monkeypatch): tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) @@ -1082,6 +1614,48 @@ def __init__(self): assert success_state.get("meta", {}).get("processing_speed_mb_s") == 0 +def test_process_no_chunks_saves_error(monkeypatch, tmp_path): + """process should save error info when no chunks are produced""" + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + class FakeActor: + def __init__(self): + self.process_file = types.SimpleNamespace( + remote=lambda *a, **k: "ref-empty") + self.store_chunks_in_redis = types.SimpleNamespace( + remote=lambda *a, **k: None) + + monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + fake_ray.get_returns = [] # no chunks returned from ray.get + + saved_reason = {} + monkeypatch.setattr( + tasks, + "save_error_to_redis", + lambda task_id, reason, start_time: saved_reason.setdefault( + "reason", reason), + ) + + f = tmp_path / "empty_file.txt" + f.write_text("data") + + self = FakeSelf("no-chunks") + with pytest.raises(Exception) as exc_info: + tasks.process( + self, + source=str(f), + source_type="local", + chunking_strategy="basic", + index_name="idx", + original_filename="empty_file.txt", + ) + + assert '"error_code": "no_valid_chunks"' in saved_reason.get("reason", "") + assert any(state.get("meta", {}).get("stage") == + "text_extraction_failed" for state in self.states) + json.loads(str(exc_info.value)) + + def test_process_url_source_with_many_chunks(monkeypatch): """Test processing URL source that generates many chunks""" tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) diff --git a/test/backend/data_process/test_worker.py b/test/backend/data_process/test_worker.py index a59635c13..fb7115816 100644 --- a/test/backend/data_process/test_worker.py +++ b/test/backend/data_process/test_worker.py @@ -2,6 +2,7 @@ import types import importlib import pytest +import os class FakeRay: @@ -44,6 +45,7 @@ def setup_mocks_for_worker(mocker, initialized=False): const_mod.FORWARD_REDIS_RETRY_MAX = 1 const_mod.DISABLE_RAY_DASHBOARD = False const_mod.DATA_PROCESS_SERVICE = "http://data-process" + const_mod.ROOT_DIR = "/mock/root" sys.modules["consts.const"] = const_mod # Stub celery module and submodules (required by tasks.py imported via __init__.py) @@ -483,6 +485,23 @@ def init_ray_for_worker(cls, address): assert worker_module.worker_state['initialized'] is True +def test_setup_worker_environment_sets_ray_preallocate_env(mocker): + """Ensure setup_worker_environment sets RAY_preallocate_plasma env var""" + worker_module, _ = setup_mocks_for_worker(mocker, initialized=False) + + # Force init success to avoid fallback path exceptions + class FakeRayConfig: + @classmethod + def init_ray_for_worker(cls, address): + return True + + mocker.patch.object(worker_module, "RayConfig", FakeRayConfig) + + worker_module.setup_worker_environment() + + assert os.environ.get("RAY_preallocate_plasma") == str(worker_module.RAY_preallocate_plasma).lower() + + def test_setup_worker_environment_ray_init_fallback(mocker): """Test setup_worker_environment with Ray init fallback""" worker_module, fake_ray = setup_mocks_for_worker(mocker, initialized=False) diff --git a/test/backend/database/test_knowledge_db.py b/test/backend/database/test_knowledge_db.py index 913e8f1a3..af337eb8d 100644 --- a/test/backend/database/test_knowledge_db.py +++ b/test/backend/database/test_knowledge_db.py @@ -71,6 +71,7 @@ class MockKnowledgeRecord: def __init__(self, **kwargs): self.knowledge_id = kwargs.get('knowledge_id', 1) self.index_name = kwargs.get('index_name', 'test_index') + self.knowledge_name = kwargs.get('knowledge_name', 'test_index') self.knowledge_describe = kwargs.get('knowledge_describe', 'test description') self.created_by = kwargs.get('created_by', 'test_user') self.updated_by = kwargs.get('updated_by', 'test_user') @@ -83,6 +84,7 @@ def __init__(self, **kwargs): # Mock SQLAlchemy column attributes knowledge_id = MagicMock(name="knowledge_id_column") index_name = MagicMock(name="index_name_column") + knowledge_name = MagicMock(name="knowledge_name_column") knowledge_describe = MagicMock(name="knowledge_describe_column") created_by = MagicMock(name="created_by_column") updated_by = MagicMock(name="updated_by_column") @@ -107,7 +109,9 @@ def __init__(self, **kwargs): get_knowledge_info_by_knowledge_ids, get_knowledge_ids_by_index_names, get_knowledge_info_by_tenant_id, - update_model_name_by_index_name + update_model_name_by_index_name, + get_index_name_by_knowledge_name, + _generate_index_name ) @@ -125,8 +129,9 @@ def test_create_knowledge_record_success(monkeypatch, mock_session): session, _ = mock_session # Create mock knowledge record - mock_record = MockKnowledgeRecord() + mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge") mock_record.knowledge_id = 123 + mock_record.index_name = "test_knowledge" # Mock database session context mock_ctx = MagicMock() @@ -140,16 +145,21 @@ def test_create_knowledge_record_success(monkeypatch, mock_session): "knowledge_describe": "Test knowledge description", "user_id": "test_user", "tenant_id": "test_tenant", - "embedding_model_name": "test_model" + "embedding_model_name": "test_model", + "knowledge_name": "test_knowledge" } # Mock KnowledgeRecord constructor with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): result = create_knowledge_record(test_query) - assert result == 123 + assert result == { + "knowledge_id": 123, + "index_name": "test_knowledge", + "knowledge_name": "test_knowledge", + } session.add.assert_called_once_with(mock_record) - session.flush.assert_called_once() + assert session.flush.call_count == 1 session.commit.assert_called_once() @@ -179,6 +189,42 @@ def test_create_knowledge_record_exception(monkeypatch, mock_session): session.rollback.assert_called_once() +def test_create_knowledge_record_generates_index_name(monkeypatch, mock_session): + """Test create_knowledge_record generates index_name when not provided""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord(knowledge_name="kb1") + mock_record.knowledge_id = 7 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # Deterministic index name + monkeypatch.setattr("backend.database.knowledge_db._generate_index_name", lambda _: "7-generated") + + test_query = { + "knowledge_describe": "desc", + "user_id": "user-1", + "tenant_id": "tenant-1", + "embedding_model_name": "model-x", + "knowledge_name": "kb1", + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 7, + "index_name": "7-generated", + "knowledge_name": "kb1", + } + assert mock_record.index_name == "7-generated" + assert session.flush.call_count == 2 # initial insert + index_name update + session.commit.assert_called_once() + + def test_update_knowledge_record_success(monkeypatch, mock_session): """Test successful update of knowledge record""" session, query = mock_session @@ -446,6 +492,39 @@ def test_get_knowledge_record_exception(monkeypatch, mock_session): get_knowledge_record(test_query) +def test_get_knowledge_record_with_none_query(monkeypatch, mock_session): + """Test get_knowledge_record with None query raises TypeError""" + session, query = mock_session + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # When query is None, accessing query['index_name'] will raise TypeError + with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"): + get_knowledge_record(None) + + +def test_get_knowledge_record_without_index_name_key(monkeypatch, mock_session): + """Test get_knowledge_record with query missing index_name key raises KeyError""" + session, query = mock_session + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # When query doesn't have 'index_name' key, accessing query['index_name'] will raise KeyError + test_query = { + "tenant_id": "test_tenant" + # Missing index_name key + } + + with pytest.raises(KeyError): + get_knowledge_record(test_query) + + def test_get_knowledge_info_by_knowledge_ids_success(monkeypatch, mock_session): """Test retrieving knowledge info by knowledge ID list""" session, query = mock_session @@ -454,12 +533,14 @@ def test_get_knowledge_info_by_knowledge_ids_success(monkeypatch, mock_session): mock_record1 = MockKnowledgeRecord() mock_record1.knowledge_id = 1 mock_record1.index_name = "knowledge1" + mock_record1.knowledge_name = "Knowledge Base 1" mock_record1.knowledge_sources = "elasticsearch" mock_record1.embedding_model_name = "model1" mock_record2 = MockKnowledgeRecord() mock_record2.knowledge_id = 2 mock_record2.index_name = "knowledge2" + mock_record2.knowledge_name = "Knowledge Base 2" mock_record2.knowledge_sources = "vectordb" mock_record2.embedding_model_name = "model2" @@ -479,12 +560,14 @@ def test_get_knowledge_info_by_knowledge_ids_success(monkeypatch, mock_session): { "knowledge_id": 1, "index_name": "knowledge1", + "knowledge_name": "Knowledge Base 1", "knowledge_sources": "elasticsearch", "embedding_model_name": "model1" }, { "knowledge_id": 2, "index_name": "knowledge2", + "knowledge_name": "Knowledge Base 2", "knowledge_sources": "vectordb", "embedding_model_name": "model2" } @@ -648,4 +731,391 @@ def test_update_model_name_by_index_name_exception(monkeypatch, mock_session): monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) with pytest.raises(MockSQLAlchemyError, match="Database error"): - update_model_name_by_index_name("test_index", "new_model", "tenant1", "user1") \ No newline at end of file + update_model_name_by_index_name("test_index", "new_model", "tenant1", "user1") + + +def test_create_knowledge_record_with_index_name_only(monkeypatch, mock_session): + """Test create_knowledge_record when only index_name is provided (no knowledge_name)""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 123 + mock_record.index_name = "test_index" + mock_record.knowledge_name = "test_index" # Should use index_name as knowledge_name + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_index", + "knowledge_describe": "Test description", + "user_id": "test_user", + "tenant_id": "test_tenant", + "embedding_model_name": "test_model" + # No knowledge_name provided + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 123, + "index_name": "test_index", + "knowledge_name": "test_index", + } + session.add.assert_called_once_with(mock_record) + assert session.flush.call_count == 1 + session.commit.assert_called_once() + + +def test_create_knowledge_record_without_user_id(monkeypatch, mock_session): + """Test create_knowledge_record without user_id""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 123 + mock_record.index_name = "test_index" + mock_record.knowledge_name = "test_kb" + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_index", + "knowledge_name": "test_kb", + "knowledge_describe": "Test description", + "tenant_id": "test_tenant", + "embedding_model_name": "test_model" + # No user_id provided + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result["knowledge_id"] == 123 + session.add.assert_called_once_with(mock_record) + session.commit.assert_called_once() + + +def test_create_knowledge_record_without_index_name_and_knowledge_name(monkeypatch, mock_session): + """Test create_knowledge_record when neither index_name nor knowledge_name is provided""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 7 + mock_record.knowledge_name = None # Both are None, so knowledge_name will be None + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # Deterministic index name + monkeypatch.setattr("backend.database.knowledge_db._generate_index_name", lambda _: "7-generated") + + test_query = { + "knowledge_describe": "desc", + "user_id": "user-1", + "tenant_id": "tenant-1", + "embedding_model_name": "model-x" + # Neither index_name nor knowledge_name provided + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 7, + "index_name": "7-generated", + "knowledge_name": None, + } + assert mock_record.index_name == "7-generated" + assert session.flush.call_count == 2 # initial insert + index_name update + session.commit.assert_called_once() + + +def test_update_knowledge_record_without_user_id(monkeypatch, mock_session): + """Test update_knowledge_record without user_id""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_describe = "old description" + mock_record.updated_by = "original_user" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_knowledge", + "knowledge_describe": "Updated description" + # No user_id provided + } + + result = update_knowledge_record(test_query) + + assert result is True + assert mock_record.knowledge_describe == "Updated description" + # updated_by should remain unchanged when user_id is not provided + assert mock_record.updated_by == "original_user" + session.flush.assert_called_once() + session.commit.assert_called_once() + + +def test_delete_knowledge_record_without_user_id(monkeypatch, mock_session): + """Test delete_knowledge_record without user_id""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.delete_flag = 'N' + mock_record.updated_by = "original_user" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_knowledge" + # No user_id provided + } + + result = delete_knowledge_record(test_query) + + assert result is True + assert mock_record.delete_flag == 'Y' + # updated_by should remain unchanged when user_id is not provided + assert mock_record.updated_by == "original_user" + session.flush.assert_called_once() + session.commit.assert_called_once() + + +def test_get_knowledge_record_with_tenant_id_none(monkeypatch, mock_session): + """Test get_knowledge_record with tenant_id explicitly set to None""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 123 + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + expected_result = {"knowledge_id": 123} + monkeypatch.setattr("backend.database.knowledge_db.as_dict", lambda x: expected_result) + + test_query = { + "index_name": "test_knowledge", + "tenant_id": None # Explicitly None + } + + result = get_knowledge_record(test_query) + + assert result == expected_result + # Should not add tenant_id filter when tenant_id is None + assert query.filter.call_count >= 1 + + +def test_get_knowledge_info_by_knowledge_ids_empty_list(monkeypatch, mock_session): + """Test get_knowledge_info_by_knowledge_ids with empty list""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.all.return_value = [] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + knowledge_ids = [] + result = get_knowledge_info_by_knowledge_ids(knowledge_ids) + + assert result == [] + + +def test_get_knowledge_info_by_knowledge_ids_includes_knowledge_name(monkeypatch, mock_session): + """Test get_knowledge_info_by_knowledge_ids includes knowledge_name field""" + session, query = mock_session + + mock_record1 = MockKnowledgeRecord() + mock_record1.knowledge_id = 1 + mock_record1.index_name = "knowledge1" + mock_record1.knowledge_name = "Knowledge Base 1" + mock_record1.knowledge_sources = "elasticsearch" + mock_record1.embedding_model_name = "model1" + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_record1] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + knowledge_ids = ["1"] + result = get_knowledge_info_by_knowledge_ids(knowledge_ids) + + expected = [ + { + "knowledge_id": 1, + "index_name": "knowledge1", + "knowledge_name": "Knowledge Base 1", + "knowledge_sources": "elasticsearch", + "embedding_model_name": "model1" + } + ] + + assert result == expected + assert "knowledge_name" in result[0] + + +def test_get_knowledge_info_by_knowledge_ids_with_none_knowledge_name(monkeypatch, mock_session): + """Test get_knowledge_info_by_knowledge_ids when knowledge_name is None""" + session, query = mock_session + + mock_record1 = MockKnowledgeRecord() + mock_record1.knowledge_id = 1 + mock_record1.index_name = "knowledge1" + mock_record1.knowledge_name = None # None knowledge_name + mock_record1.knowledge_sources = "elasticsearch" + mock_record1.embedding_model_name = "model1" + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_record1] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + knowledge_ids = ["1"] + result = get_knowledge_info_by_knowledge_ids(knowledge_ids) + + expected = [ + { + "knowledge_id": 1, + "index_name": "knowledge1", + "knowledge_name": None, + "knowledge_sources": "elasticsearch", + "embedding_model_name": "model1" + } + ] + + assert result == expected + assert result[0]["knowledge_name"] is None + + +def test_get_index_name_by_knowledge_name_success(monkeypatch, mock_session): + """Test successfully getting index_name by knowledge_name""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_name = "My Knowledge Base" + mock_record.index_name = "123-abc123def456" + mock_record.tenant_id = "tenant1" + mock_record.delete_flag = 'N' + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + result = get_index_name_by_knowledge_name("My Knowledge Base", "tenant1") + + assert result == "123-abc123def456" + + +def test_get_index_name_by_knowledge_name_not_found(monkeypatch, mock_session): + """Test get_index_name_by_knowledge_name when knowledge base is not found""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.first.return_value = None + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(ValueError, match="Knowledge base 'Nonexistent KB' not found for the current tenant"): + get_index_name_by_knowledge_name("Nonexistent KB", "tenant1") + + +def test_get_index_name_by_knowledge_name_exception(monkeypatch, mock_session): + """Test exception when getting index_name by knowledge_name""" + session, query = mock_session + query.filter.side_effect = MockSQLAlchemyError("Database error") + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(MockSQLAlchemyError, match="Database error"): + get_index_name_by_knowledge_name("My Knowledge Base", "tenant1") + + +def test_generate_index_name_format(monkeypatch): + """Test _generate_index_name generates correct format""" + # Mock uuid to get deterministic result + mock_uuid = MagicMock() + mock_uuid.hex = "abc123def456" + monkeypatch.setattr("backend.database.knowledge_db.uuid.uuid4", lambda: mock_uuid) + + result = _generate_index_name(123) + + assert result == "123-abc123def456" + assert result.startswith("123-") + assert len(result) == len("123-abc123def456") + + +def test_get_knowledge_ids_by_index_names_empty_list(monkeypatch, mock_session): + """Test get_knowledge_ids_by_index_names with empty list""" + session, _ = mock_session + + mock_specific_query = MagicMock() + mock_filter = MagicMock() + mock_filter.all.return_value = [] + mock_specific_query.filter.return_value = mock_filter + + def mock_query_func(*args, **kwargs): + return mock_specific_query + + session.query = mock_query_func + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + index_names = [] + result = get_knowledge_ids_by_index_names(index_names) + + assert result == [] \ No newline at end of file diff --git a/test/backend/services/test_agent_service.py b/test/backend/services/test_agent_service.py index 9c202209c..d4b28eae5 100644 --- a/test/backend/services/test_agent_service.py +++ b/test/backend/services/test_agent_service.py @@ -3,15 +3,20 @@ import json from contextlib import contextmanager from unittest.mock import patch, MagicMock, mock_open, call, Mock, AsyncMock +import os import pytest from fastapi.responses import StreamingResponse from fastapi import Request - -# Import the actual ToolConfig model for testing before any mocking from nexent.core.agents.agent_model import ToolConfig -import os +from backend.consts.model import ( + AgentNameBatchCheckItem, + AgentNameBatchCheckRequest, + AgentNameBatchRegenerateItem, + AgentNameBatchRegenerateRequest, +) + # Patch environment variables before any imports that might use them os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000') @@ -5629,6 +5634,260 @@ async def fake_update_tool_list(tenant_id, user_id): assert relationships == [(100 + 1, 100 + 2, "tenant1")] +# ===================================================================== +# Tests for batch agent name conflict and regeneration +# ===================================================================== + + +@pytest.mark.asyncio +async def test_check_agent_name_conflict_batch_impl_detects_conflicts(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + existing_agents = [ + {"agent_id": 10, "name": "dup_name", "display_name": "Dup Display"}, + {"agent_id": 11, "name": "unique", "display_name": "Unique"}, + ] + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: existing_agents, + raising=False, + ) + + from consts.model import AgentNameBatchCheckItem, AgentNameBatchCheckRequest + + request = AgentNameBatchCheckRequest( + items=[ + AgentNameBatchCheckItem(name="dup_name", display_name="Another"), + AgentNameBatchCheckItem(name="", display_name=None), + ] + ) + + result = await agent_service.check_agent_name_conflict_batch_impl( + request, authorization="Bearer token" + ) + + assert result[0]["name_conflict"] is True + assert result[0]["display_name_conflict"] is False + assert result[0]["conflict_agents"] == [ + {"name": "dup_name", "display_name": "Dup Display"} + ] + assert result[1]["name_conflict"] is False + assert result[1]["display_name_conflict"] is False + assert result[1]["conflict_agents"] == [] + + +@pytest.mark.asyncio +async def test_check_agent_name_conflict_batch_impl_display_conflict(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + existing_agents = [ + {"agent_id": 3, "name": "alpha", "display_name": "Shown"}, + ] + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: existing_agents, + raising=False, + ) + + request = AgentNameBatchCheckRequest( + items=[AgentNameBatchCheckItem(name="beta", display_name="Shown")] + ) + + result = await agent_service.check_agent_name_conflict_batch_impl( + request, authorization="Bearer token" + ) + + assert result[0]["name_conflict"] is False + assert result[0]["display_name_conflict"] is True + assert result[0]["conflict_agents"] == [ + {"name": "alpha", "display_name": "Shown"} + ] + + +@pytest.mark.asyncio +async def test_check_agent_name_conflict_batch_impl_skips_same_agent(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + existing_agents = [ + {"agent_id": 7, "name": "self", "display_name": "Self Display"}, + ] + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: existing_agents, + raising=False, + ) + + request = AgentNameBatchCheckRequest( + items=[ + AgentNameBatchCheckItem( + agent_id=7, name="self", display_name="Self Display" + ) + ] + ) + + result = await agent_service.check_agent_name_conflict_batch_impl( + request, authorization="Bearer token" + ) + + assert result[0]["name_conflict"] is False + assert result[0]["display_name_conflict"] is False + assert result[0]["conflict_agents"] == [] + + +@pytest.mark.asyncio +async def test_regenerate_agent_name_batch_impl_uses_llm(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: [{"agent_id": 2, "name": "dup_name", "display_name": "Dup"}], + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.tenant_config_manager.get_model_config", + lambda key, tenant_id: {"model_id": "model-1", "display_name": "LLM"}, + raising=False, + ) + + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", fake_to_thread, raising=False) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_name_with_llm", + lambda **kwargs: "regenerated_name", + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_display_name_with_llm", + lambda **kwargs: "Regenerated Display", + raising=False, + ) + + + + request = AgentNameBatchRegenerateRequest( + items=[ + AgentNameBatchRegenerateItem( + agent_id=1, + name="dup_name", + display_name="Dup", + task_description="desc", + ) + ] + ) + + result = await agent_service.regenerate_agent_name_batch_impl( + request, authorization="Bearer token" + ) + + assert result == [{"name": "regenerated_name", "display_name": "Regenerated Display"}] + + +@pytest.mark.asyncio +async def test_regenerate_agent_name_batch_impl_no_model(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: [], + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.tenant_config_manager.get_model_config", + lambda key, tenant_id: None, + raising=False, + ) + + from consts.model import AgentNameBatchRegenerateItem, AgentNameBatchRegenerateRequest + + request = AgentNameBatchRegenerateRequest( + items=[AgentNameBatchRegenerateItem(agent_id=1, name="dup", display_name="Dup")] + ) + + with pytest.raises(ValueError): + await agent_service.regenerate_agent_name_batch_impl( + request, authorization="Bearer token" + ) + + +@pytest.mark.asyncio +async def test_regenerate_agent_name_batch_impl_llm_failure_fallback(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + # existing agent ensures duplicate detection + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: [{"agent_id": 2, "name": "dup", "display_name": "Dup"}], + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.tenant_config_manager.get_model_config", + lambda key, tenant_id: {"model_id": "model-1", "display_name": "LLM"}, + raising=False, + ) + + async def run_in_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", run_in_thread, raising=False) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_name_with_llm", + lambda **kwargs: (_ for _ in ()).throw(Exception("llm-fail")), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_display_name_with_llm", + lambda **kwargs: (_ for _ in ()).throw(Exception("llm-fail")), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._generate_unique_agent_name_with_suffix", + lambda base_value, **kwargs: f"{base_value}_fallback", + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._generate_unique_display_name_with_suffix", + lambda base_value, **kwargs: f"{base_value}_fallback", + raising=False, + ) + + request = AgentNameBatchRegenerateRequest( + items=[ + AgentNameBatchRegenerateItem( + agent_id=1, + name="dup", + display_name="Dup", + task_description="desc", + ) + ] + ) + + result = await agent_service.regenerate_agent_name_batch_impl( + request, authorization="Bearer token" + ) + + assert result == [{"name": "dup_fallback", "display_name": "Dup_fallback"}] + + # ===================================================================== # Tests for _resolve_model_with_fallback helper function # ===================================================================== @@ -6233,28 +6492,19 @@ async def test_get_agent_info_impl_with_unavailable_agent( @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_agent_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_name_with_llm') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_name_with_llm_success( +async def test_import_agent_by_agent_id_allows_duplicate_name_without_regen( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_regen_name, - mock_generate_unique_name, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_name is duplicate and LLM regeneration succeeds (line 1043-1060).""" - # Setup + """ + New behavior: import_agent_by_agent_id no longer performs duplicate-name regeneration. + It should create the agent with the provided name/display_name even if duplicates exist. + """ mock_query_all_tools.return_value = [] - mock_resolve_model.side_effect = [1, 2] # model_id=1, business_logic_model_id=2 - mock_query_all_agents.return_value = [{"name": "duplicate_name", "display_name": "Display"}] - mock_check_name_dup.return_value = True # Name is duplicate - mock_regen_name.return_value = "regenerated_name" + mock_resolve_model.side_effect = [1, 2] mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6277,7 +6527,6 @@ async def test_import_agent_by_agent_id_duplicate_name_with_llm_success( business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6285,42 +6534,28 @@ async def test_import_agent_by_agent_id_duplicate_name_with_llm_success( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_name_dup.assert_called_once() - mock_regen_name.assert_called_once() mock_create_agent.assert_called_once() - # Verify regenerated name was used - assert mock_create_agent.call_args[1]["agent_info"]["name"] == "regenerated_name" + assert mock_create_agent.call_args[1]["agent_info"]["name"] == "duplicate_name" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "Test Display" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_agent_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_name_with_llm') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_name_llm_failure_fallback( +async def test_import_agent_by_agent_id_duplicate_name_no_regen_fallback( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_regen_name, - mock_generate_unique_name, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_name is duplicate, LLM regeneration fails, uses fallback (line 1061-1067).""" - # Setup + """ + New behavior: even when duplicate name, import proceeds without regeneration or fallback. + """ mock_query_all_tools.return_value = [] mock_resolve_model.side_effect = [1, 2] - mock_query_all_agents.return_value = [{"name": "duplicate_name", "display_name": "Display"}] - mock_check_name_dup.return_value = True - mock_regen_name.side_effect = Exception("LLM failed") - mock_generate_unique_name.return_value = "fallback_name_1" mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6343,7 +6578,6 @@ async def test_import_agent_by_agent_id_duplicate_name_llm_failure_fallback( business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6351,41 +6585,27 @@ async def test_import_agent_by_agent_id_duplicate_name_llm_failure_fallback( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_regen_name.assert_called_once() - mock_generate_unique_name.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback name was used - assert mock_create_agent.call_args[1]["agent_info"]["name"] == "fallback_name_1" + assert mock_create_agent.call_args[1]["agent_info"]["name"] == "duplicate_name" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_agent_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_name_with_llm') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_name_no_model_fallback( +async def test_import_agent_by_agent_id_duplicate_name_no_model_still_allows( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_regen_name, - mock_generate_unique_name, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_name is duplicate but no model available, uses fallback (line 1068-1074).""" - # Setup + """ + New behavior: even without model, duplicate name passes through unchanged. + """ mock_query_all_tools.return_value = [] - mock_resolve_model.side_effect = [None, None] # No models available - mock_query_all_agents.return_value = [{"name": "duplicate_name", "display_name": "Display"}] - mock_check_name_dup.return_value = True - mock_generate_unique_name.return_value = "fallback_name_2" + mock_resolve_model.side_effect = [None, None] mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6408,7 +6628,6 @@ async def test_import_agent_by_agent_id_duplicate_name_no_model_fallback( business_logic_model_name=None ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6416,45 +6635,25 @@ async def test_import_agent_by_agent_id_duplicate_name_no_model_fallback( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_name_dup.assert_called_once() - mock_regen_name.assert_not_called() # Should not call LLM when no model - mock_generate_unique_name.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback name was used - assert mock_create_agent.call_args[1]["agent_info"]["name"] == "fallback_name_2" + assert mock_create_agent.call_args[1]["agent_info"]["name"] == "duplicate_name" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_display_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_display_name_with_llm') -@patch('backend.services.agent_service._check_agent_display_name_duplicate') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_display_name_with_llm_success( +async def test_import_agent_by_agent_id_duplicate_display_name_allowed( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_check_display_dup, - mock_regen_display, - mock_generate_unique_display, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_display_name is duplicate and LLM regeneration succeeds (line 1077-1092).""" - # Setup + """New behavior: duplicate display_name passes through without regeneration.""" mock_query_all_tools.return_value = [] mock_resolve_model.side_effect = [1, 2] - mock_query_all_agents.return_value = [{"name": "name1", "display_name": "duplicate_display"}] - mock_check_name_dup.return_value = False # Name is not duplicate - mock_check_display_dup.return_value = True # Display name is duplicate - mock_regen_display.return_value = "regenerated_display" mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6477,7 +6676,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_with_llm_success( business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6485,45 +6683,27 @@ async def test_import_agent_by_agent_id_duplicate_display_name_with_llm_success( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_display_dup.assert_called_once() - mock_regen_display.assert_called_once() mock_create_agent.assert_called_once() - # Verify regenerated display name was used - assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "regenerated_display" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "duplicate_display" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_display_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_display_name_with_llm') -@patch('backend.services.agent_service._check_agent_display_name_duplicate') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_display_name_llm_failure_fallback( +async def test_import_agent_by_agent_id_duplicate_display_name_no_llm_fallback( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_check_display_dup, - mock_regen_display, - mock_generate_unique_display, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_display_name is duplicate, LLM regeneration fails, uses fallback (line 1093-1099).""" - # Setup + """ + New behavior: duplicate display_name passes through without LLM; fallback not invoked. + """ mock_query_all_tools.return_value = [] mock_resolve_model.side_effect = [1, 2] - mock_query_all_agents.return_value = [{"name": "name1", "display_name": "duplicate_display"}] - mock_check_name_dup.return_value = False - mock_check_display_dup.return_value = True - mock_regen_display.side_effect = Exception("LLM failed") - mock_generate_unique_display.return_value = "fallback_display_1" mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6546,7 +6726,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_llm_failure_fallb business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6554,44 +6733,27 @@ async def test_import_agent_by_agent_id_duplicate_display_name_llm_failure_fallb skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_regen_display.assert_called_once() - mock_generate_unique_display.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback display name was used - assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "fallback_display_1" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "duplicate_display" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_display_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_display_name_with_llm') -@patch('backend.services.agent_service._check_agent_display_name_duplicate') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_display_name_no_model_fallback( +async def test_import_agent_by_agent_id_duplicate_display_name_no_model_still_allowed( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_check_display_dup, - mock_regen_display, - mock_generate_unique_display, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_display_name is duplicate but no model available, uses fallback (line 1100-1106).""" - # Setup + """ + New behavior: even without model, duplicate display_name passes through unchanged. + """ mock_query_all_tools.return_value = [] - mock_resolve_model.side_effect = [None, None] # No models available - mock_query_all_agents.return_value = [{"name": "name1", "display_name": "duplicate_display"}] - mock_check_name_dup.return_value = False - mock_check_display_dup.return_value = True - mock_generate_unique_display.return_value = "fallback_display_2" + mock_resolve_model.side_effect = [None, None] mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6614,7 +6776,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_no_model_fallback business_logic_model_name=None ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6622,11 +6783,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_no_model_fallback skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_display_dup.assert_called_once() - mock_regen_display.assert_not_called() # Should not call LLM when no model - mock_generate_unique_display.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback display name was used - assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "fallback_display_2" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "duplicate_display" diff --git a/test/backend/services/test_conversation_management_service.py b/test/backend/services/test_conversation_management_service.py index feeb68d0e..173a3b6aa 100644 --- a/test/backend/services/test_conversation_management_service.py +++ b/test/backend/services/test_conversation_management_service.py @@ -327,7 +327,7 @@ def test_extract_user_messages(self): self.assertIn("Give me examples of AI applications", result) self.assertIn("AI stands for Artificial Intelligence.", result) - @patch('backend.services.conversation_management_service.OpenAIServerModel') + @patch('backend.services.conversation_management_service.OpenAIModel') @patch('backend.services.conversation_management_service.get_generate_title_prompt_template') @patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config') def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_template, mock_openai): @@ -360,7 +360,7 @@ def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_templat mock_llm_instance.generate.assert_called_once() mock_get_prompt_template.assert_called_once_with(language='zh') - @patch('backend.services.conversation_management_service.OpenAIServerModel') + @patch('backend.services.conversation_management_service.OpenAIModel') @patch('backend.services.conversation_management_service.get_generate_title_prompt_template') @patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config') def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_get_prompt_template, mock_openai): @@ -392,7 +392,7 @@ def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_g mock_llm_instance.generate.assert_called_once() mock_get_prompt_template.assert_called_once_with(language='zh') - @patch('backend.services.conversation_management_service.OpenAIServerModel') + @patch('backend.services.conversation_management_service.OpenAIModel') @patch('backend.services.conversation_management_service.get_generate_title_prompt_template') @patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config') def test_call_llm_for_title_response_none_en(self, mock_get_model_config, mock_get_prompt_template, mock_openai): diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 86e1cac73..48741a0f8 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -405,6 +405,73 @@ async def test_create_model_for_tenant_embedding_sets_dimension(): assert mock_create.call_count == 1 +@pytest.mark.asyncio +async def test_create_model_for_tenant_embedding_sets_default_chunk_batch(): + """chunk_batch defaults to 10 when not provided for embedding models.""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ + mock.patch.object(svc, "create_model_record") as mock_create, \ + mock.patch.object(svc, "split_repo_name", return_value=("openai", "text-embedding-3-small")): + + user_id = "u1" + tenant_id = "t1" + model_data = { + "model_name": "openai/text-embedding-3-small", + "display_name": None, + "base_url": "https://api.openai.com", + "model_type": "embedding", + "chunk_batch": None, # Explicitly unset to exercise defaulting + } + + await svc.create_model_for_tenant(user_id, tenant_id, model_data) + + mock_dim.assert_awaited_once() + assert mock_create.call_count == 1 + # chunk_batch should be defaulted before persistence + create_args = mock_create.call_args[0][0] + assert create_args["chunk_batch"] == 10 + + +@pytest.mark.asyncio +async def test_create_model_for_tenant_multi_embedding_sets_default_chunk_batch(): + """chunk_batch defaults to 10 when not provided for multi_embedding models (covers line 79).""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ + mock.patch.object(svc, "create_model_record") as mock_create, \ + mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")): + + user_id = "u1" + tenant_id = "t1" + model_data = { + "model_name": "openai/clip", + "display_name": None, + "base_url": "https://api.openai.com", + "model_type": "multi_embedding", + "chunk_batch": None, # Explicitly unset to exercise defaulting + } + + await svc.create_model_for_tenant(user_id, tenant_id, model_data) + + mock_dim.assert_awaited_once() + # Should create two records: multi_embedding and its embedding variant + assert mock_create.call_count == 2 + + # Verify chunk_batch was set to 10 for both records + create_calls = mock_create.call_args_list + # First call is for multi_embedding + multi_emb_args = create_calls[0][0][0] + assert multi_emb_args["chunk_batch"] == 10 + assert multi_emb_args["model_type"] == "multi_embedding" + # Second call is for embedding variant + emb_args = create_calls[1][0][0] + assert emb_args["chunk_batch"] == 10 + assert emb_args["model_type"] == "embedding" + + @pytest.mark.asyncio async def test_create_provider_models_for_tenant_success(): svc = import_svc() diff --git a/test/backend/services/test_model_provider_service.py b/test/backend/services/test_model_provider_service.py index ce3a0ab75..0916e61f9 100644 --- a/test/backend/services/test_model_provider_service.py +++ b/test/backend/services/test_model_provider_service.py @@ -304,7 +304,9 @@ async def test_prepare_model_dict_embedding(): assert kwargs["model_name"] == "text-embedding-ada-002" assert kwargs["model_type"] == "embedding" assert kwargs["api_key"] == "test-key" - assert kwargs["max_tokens"] == 1024 + # For embedding models, max_tokens is set to 0 as placeholder, + # will be updated by embedding_dimension_check later + assert kwargs["max_tokens"] == 0 assert kwargs["display_name"] == "openai/text-embedding-ada-002" assert kwargs["expected_chunk_size"] == sys.modules["consts.const"].DEFAULT_EXPECTED_CHUNK_SIZE assert kwargs["maximum_chunk_size"] == sys.modules["consts.const"].DEFAULT_MAXIMUM_CHUNK_SIZE diff --git a/test/backend/services/test_redis_service.py b/test/backend/services/test_redis_service.py index 8ebf7613e..1fba985ba 100644 --- a/test/backend/services/test_redis_service.py +++ b/test/backend/services/test_redis_service.py @@ -1,10 +1,7 @@ import unittest from unittest.mock import patch, MagicMock, call import json -import os import redis -import hashlib -import urllib.parse from backend.services.redis_service import RedisService, get_redis_service @@ -43,7 +40,8 @@ def test_client_property(self, mock_from_url): mock_from_url.assert_called_once_with( 'redis://localhost:6379/0', socket_timeout=5, - socket_connect_timeout=5 + socket_connect_timeout=5, + decode_responses=True ) self.assertEqual(client, self.mock_redis_client) @@ -127,7 +125,23 @@ def test_backend_client_no_env_vars(self, mock_from_url): # Execute & Verify with self.assertRaises(ValueError): _ = redis_service.backend_client - + + @patch('redis.from_url') + @patch('backend.services.redis_service.REDIS_URL', 'redis://localhost:6379/0') + def test_mark_and_check_task_cancelled(self, mock_from_url): + """mark_task_cancelled should set flag and is_task_cancelled should read it.""" + mock_client = MagicMock() + mock_client.setex.return_value = True + mock_client.get.return_value = b"1" + mock_from_url.return_value = mock_client + + service = RedisService() + ok = service.mark_task_cancelled("task-1", ttl_hours=1) + self.assertTrue(ok) + self.assertTrue(service.is_task_cancelled("task-1")) + mock_client.setex.assert_called_once() + mock_client.get.assert_called_once() + def test_delete_knowledgebase_records(self): """Test delete_knowledgebase_records method""" # Setup @@ -216,60 +230,155 @@ def test_delete_document_records_with_error(self): self.assertEqual(len(result["errors"]), 1) self.assertIn("Test error", result["errors"][0]) + def test_cleanup_single_task_related_keys_outer_exception(self): + """Outer handler logs when warning path itself fails.""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + self.mock_redis_client.delete.side_effect = redis.RedisError( + "delete failed") + + with patch('backend.services.redis_service.logger.warning', side_effect=Exception("warn boom")), \ + patch('backend.services.redis_service.logger.error') as mock_error: + result = self.redis_service._cleanup_single_task_related_keys( + "task123") + + mock_error.assert_called_once() + self.assertEqual(result, 0) + def test_cleanup_celery_tasks(self): """Test _cleanup_celery_tasks method""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Create mock task data - task_keys = [b'celery-task-meta-1', b'celery-task-meta-2', b'celery-task-meta-3'] - + task_keys = [b'celery-task-meta-1', + b'celery-task-meta-2', b'celery-task-meta-3'] + # Task 1 matches our index task1_data = json.dumps({ 'result': {'index_name': 'test_index', 'some_key': 'some_value'}, 'parent_id': '2' # This will trigger a parent lookup }).encode() - + # Task 2 has index name in a different location task2_data = json.dumps({ - 'index_name': 'test_index', + 'index_name': 'test_index', 'result': {'some_key': 'some_value'}, 'parent_id': None # No parent }).encode() - + # Task 3 is for a different index task3_data = json.dumps({ 'result': {'index_name': 'other_index', 'some_key': 'some_value'} }).encode() - + # Configure mock responses self.mock_backend_client.keys.return_value = task_keys - self.mock_backend_client.get.side_effect = [task1_data, task2_data, task3_data] - + # Two passes over keys: provide payloads for both passes (6 gets) + self.mock_backend_client.get.side_effect = [ + task1_data, task2_data, task3_data, + task1_data, task2_data, task3_data, + ] + # We expect delete to be called and return 1 each time self.mock_backend_client.delete.return_value = 1 - + # Execute with patch.object(self.redis_service, '_recursively_delete_task_and_parents') as mock_recursive_delete: mock_recursive_delete.side_effect = [(1, {'1'}), (1, {'2'})] result = self.redis_service._cleanup_celery_tasks("test_index") - + # Verify - self.mock_backend_client.keys.assert_called_once_with('celery-task-meta-*') - # We expect 3 calls - one for each task key - self.assertEqual(self.mock_backend_client.get.call_count, 3) - - # Should have called recursive delete twice (for task1 and task2) - self.assertEqual(mock_recursive_delete.call_count, 2) - - # Return value should be the number of deleted tasks - self.assertEqual(result, 2) - + self.mock_backend_client.keys.assert_called_once_with( + 'celery-task-meta-*') + # Implementation fetches task payloads in both passes; expect 6 total (3 keys * 2 passes) + self.assertEqual(self.mock_backend_client.get.call_count, 6) + + # Should have called recursive delete for matched tasks + self.assertGreaterEqual(mock_recursive_delete.call_count, 2) + + # Return value should match deleted tasks count + self.assertEqual(result, mock_recursive_delete.call_count) + + def test_cleanup_celery_tasks_get_exception_and_cancel_failure(self): + """First-pass get failure and cancel failure are both handled.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-err', b'celery-task-meta-2'] + valid_task = json.dumps({ + 'result': {'index_name': 'test_index'}, + 'parent_id': None + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [ + redis.RedisError("boom"), + valid_task, + redis.RedisError("boom-second"), + valid_task, + ] + + with patch.object(self.redis_service, 'mark_task_cancelled', side_effect=ValueError("cancel fail")) as mock_cancel, \ + patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'2'})) as mock_delete, \ + patch.object(self.redis_service, '_cleanup_single_task_related_keys') as mock_cleanup: + + result = self.redis_service._cleanup_celery_tasks("test_index") + + mock_cancel.assert_called_once_with('2') + mock_delete.assert_called_once_with('2') + mock_cleanup.assert_called_once_with('2') + self.assertEqual(result, 1) + + def test_cleanup_celery_tasks_exc_message_bad_json(self): + """JSON decode failure inside exc_message parsing does not crash.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-1'] + bad_json_payload = json.dumps({ + 'result': { + # Contains brace to enter parsing block + 'exc_message': '{bad json' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [ + bad_json_payload, bad_json_payload] + + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(0, set())) as mock_delete: + result = self.redis_service._cleanup_celery_tasks("test_index") + + # Bad JSON should be tolerated; no deletions occur + mock_delete.assert_not_called() + self.assertEqual(result, 0) + + def test_cleanup_celery_tasks_cleanup_single_task_error(self): + """Failures during related-key cleanup are logged and skipped.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-1'] + task_payload = json.dumps({ + 'result': {'index_name': 'test_index'} + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [task_payload, task_payload] + + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})), \ + patch.object(self.redis_service, '_cleanup_single_task_related_keys', side_effect=Exception("cleanup boom")) as mock_cleanup: + result = self.redis_service._cleanup_celery_tasks("test_index") + + mock_cleanup.assert_called_once_with('1') + self.assertEqual(result, 1) + def test_cleanup_cache_keys(self): """Test _cleanup_cache_keys method""" # Setup self.redis_service._client = self.mock_redis_client - + # Configure mock responses for each pattern pattern_keys = { '*test_index*': [b'key1', b'key2'], @@ -277,19 +386,20 @@ def test_cleanup_cache_keys(self): 'index:test_index:*': [b'key6'], 'search:test_index:*': [b'key7', b'key8'] } - + def mock_keys_side_effect(pattern): return pattern_keys.get(pattern, []) - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect - self.mock_redis_client.delete.return_value = 1 # Each delete operation deletes 1 key - + # Each delete operation deletes 1 key + self.mock_redis_client.delete.return_value = 1 + # Execute result = self.redis_service._cleanup_cache_keys("test_index") - + # Verify self.assertEqual(self.mock_redis_client.keys.call_count, 4) - + # All keys should be deleted (8 keys total) expected_calls = [ call(b'key1', b'key2'), @@ -297,19 +407,21 @@ def mock_keys_side_effect(pattern): call(b'key6'), call(b'key7', b'key8') ] - self.mock_redis_client.delete.assert_has_calls(expected_calls, any_order=True) - + self.mock_redis_client.delete.assert_has_calls( + expected_calls, any_order=True) + # Return value should be the number of deleted keys self.assertEqual(result, 4) # 4 successful delete operations - + def test_cleanup_document_celery_tasks(self): """Test _cleanup_document_celery_tasks method""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Create mock task data - task_keys = [b'celery-task-meta-1', b'celery-task-meta-2', b'celery-task-meta-3'] - + task_keys = [b'celery-task-meta-1', + b'celery-task-meta-2', b'celery-task-meta-3'] + # Task 1 matches our index and document task1_data = json.dumps({ 'result': { @@ -318,7 +430,7 @@ def test_cleanup_document_celery_tasks(self): }, 'parent_id': '2' # This will trigger a parent lookup }).encode() - + # Task 2 has the right index but wrong document task2_data = json.dumps({ 'result': { @@ -326,7 +438,7 @@ def test_cleanup_document_celery_tasks(self): 'source': 'other/doc.pdf' } }).encode() - + # Task 3 has document path in a different field task3_data = json.dumps({ 'result': { @@ -335,43 +447,46 @@ def test_cleanup_document_celery_tasks(self): }, 'parent_id': None # No parent }).encode() - + # Configure mock responses self.mock_backend_client.keys.return_value = task_keys - self.mock_backend_client.get.side_effect = [task1_data, task2_data, task3_data] - + self.mock_backend_client.get.side_effect = [ + task1_data, task2_data, task3_data] + # We expect delete to be called and return 1 each time self.mock_backend_client.delete.return_value = 1 - + # Execute with patch.object(self.redis_service, '_recursively_delete_task_and_parents') as mock_recursive_delete: mock_recursive_delete.side_effect = [(1, {'1'}), (1, {'3'})] - result = self.redis_service._cleanup_document_celery_tasks("test_index", "path/to/doc.pdf") - + result = self.redis_service._cleanup_document_celery_tasks( + "test_index", "path/to/doc.pdf") + # Verify - self.mock_backend_client.keys.assert_called_once_with('celery-task-meta-*') + self.mock_backend_client.keys.assert_called_once_with( + 'celery-task-meta-*') # We expect 3 calls - one for each task key self.assertEqual(self.mock_backend_client.get.call_count, 3) - + # Should have called recursive delete twice (for task1 and task3) self.assertEqual(mock_recursive_delete.call_count, 2) - + # Return value should be the number of deleted tasks self.assertEqual(result, 2) - + @patch('hashlib.md5') @patch('urllib.parse.quote') def test_cleanup_document_cache_keys(self, mock_quote, mock_md5): """Test _cleanup_document_cache_keys method""" # Setup self.redis_service._client = self.mock_redis_client - + # Mock the path hashing and quoting mock_quote.return_value = 'safe_path' mock_md5_instance = MagicMock() mock_md5_instance.hexdigest.return_value = 'path_hash' mock_md5.return_value = mock_md5_instance - + # Configure mock responses for each pattern pattern_keys = { '*test_index*safe_path*': [b'key1'], @@ -381,100 +496,105 @@ def test_cleanup_document_cache_keys(self, mock_quote, mock_md5): 'doc:safe_path:*': [b'key6', b'key7'], 'doc:path_hash:*': [b'key8'] } - + def mock_keys_side_effect(pattern): return pattern_keys.get(pattern, []) - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect - self.mock_redis_client.delete.return_value = 1 # Each delete operation deletes 1 key - + # Each delete operation deletes 1 key + self.mock_redis_client.delete.return_value = 1 + # Execute - result = self.redis_service._cleanup_document_cache_keys("test_index", "path/to/doc.pdf") - + result = self.redis_service._cleanup_document_cache_keys( + "test_index", "path/to/doc.pdf") + # Verify self.assertEqual(self.mock_redis_client.keys.call_count, 6) - + # Return value should be the number of deleted keys self.assertEqual(result, 6) # 6 successful delete operations - + def test_get_knowledgebase_task_count(self): """Test get_knowledgebase_task_count method""" # Setup self.redis_service._client = self.mock_redis_client self.redis_service._backend_client = self.mock_backend_client - + # Create mock task data task_keys = [b'celery-task-meta-1', b'celery-task-meta-2'] - + # Task 1 matches our index task1_data = json.dumps({ 'result': {'index_name': 'test_index'} }).encode() - + # Task 2 is for a different index task2_data = json.dumps({ 'result': {'index_name': 'other_index'} }).encode() - + # Configure mock responses for Celery tasks self.mock_backend_client.keys.return_value = task_keys self.mock_backend_client.get.side_effect = [task1_data, task2_data] - + # Configure mock responses for cache keys cache_keys = { '*test_index*': [b'key1', b'key2'], 'kb:test_index:*': [b'key3', b'key4'], 'index:test_index:*': [b'key5'] } - + def mock_keys_side_effect(pattern): return cache_keys.get(pattern, []) - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect - + # Execute result = self.redis_service.get_knowledgebase_task_count("test_index") - + # Verify - self.mock_backend_client.keys.assert_called_once_with('celery-task-meta-*') + self.mock_backend_client.keys.assert_called_once_with( + 'celery-task-meta-*') self.assertEqual(self.mock_backend_client.get.call_count, 2) - + # Should count 1 matching task and 5 cache keys self.assertEqual(result, 6) - + def test_ping_success(self): """Test ping method when connection is successful""" # Setup self.redis_service._client = self.mock_redis_client self.redis_service._backend_client = self.mock_backend_client - + self.mock_redis_client.ping.return_value = True self.mock_backend_client.ping.return_value = True - + # Execute result = self.redis_service.ping() - + # Verify self.mock_redis_client.ping.assert_called_once() self.mock_backend_client.ping.assert_called_once() self.assertTrue(result) - + def test_ping_failure(self): """Test ping method when connection fails""" # Setup self.redis_service._client = self.mock_redis_client self.redis_service._backend_client = self.mock_backend_client - - self.mock_redis_client.ping.side_effect = redis.RedisError("Connection failed") - + + self.mock_redis_client.ping.side_effect = redis.RedisError( + "Connection failed") + # Execute result = self.redis_service.ping() - + # Verify self.mock_redis_client.ping.assert_called_once() - self.mock_backend_client.ping.assert_not_called() # Should not be called after first ping fails + # Should not be called after first ping fails + self.mock_backend_client.ping.assert_not_called() self.assertFalse(result) - + @patch('backend.services.redis_service._redis_service', None) @patch('backend.services.redis_service.RedisService') def test_get_redis_service(self, mock_redis_service_class): @@ -482,146 +602,155 @@ def test_get_redis_service(self, mock_redis_service_class): # Setup mock_instance = MagicMock() mock_redis_service_class.return_value = mock_instance - + # Execute service1 = get_redis_service() service2 = get_redis_service() - + # Verify mock_redis_service_class.assert_called_once() # Only created once self.assertEqual(service1, mock_instance) - self.assertEqual(service2, mock_instance) # Should return same instance - + # Should return same instance + self.assertEqual(service2, mock_instance) + def test_recursively_delete_task_and_parents_no_parent(self): """Test _recursively_delete_task_and_parents with task that has no parent""" # Setup self.redis_service._backend_client = self.mock_backend_client - + task_data = json.dumps({ 'result': {'some_data': 'value'}, 'parent_id': None }).encode() - + self.mock_backend_client.get.return_value = task_data self.mock_backend_client.delete.return_value = 1 - + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task123") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task123") + # Verify self.assertEqual(deleted_count, 1) self.assertEqual(processed_ids, {"task123"}) - self.mock_backend_client.get.assert_called_once_with('celery-task-meta-task123') - self.mock_backend_client.delete.assert_called_once_with('celery-task-meta-task123') - + self.mock_backend_client.get.assert_called_once_with( + 'celery-task-meta-task123') + self.mock_backend_client.delete.assert_called_once_with( + 'celery-task-meta-task123') + def test_recursively_delete_task_and_parents_with_cycle_detection(self): """Test _recursively_delete_task_and_parents detects and breaks cycles""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Create a cycle: task1 -> task2 -> task1 task1_data = json.dumps({'parent_id': 'task2'}).encode() task2_data = json.dumps({'parent_id': 'task1'}).encode() - + self.mock_backend_client.get.side_effect = [task1_data, task2_data] self.mock_backend_client.delete.return_value = 1 - + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task1") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task1") + # Verify - should stop when cycle is detected self.assertEqual(deleted_count, 2) self.assertEqual(processed_ids, {"task1", "task2"}) self.assertEqual(self.mock_backend_client.delete.call_count, 2) - + def test_recursively_delete_task_and_parents_json_decode_error(self): """Test _recursively_delete_task_and_parents handles JSON decode errors""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Invalid JSON data invalid_json_data = b'invalid json data' - + self.mock_backend_client.get.return_value = invalid_json_data self.mock_backend_client.delete.return_value = 1 - + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task123") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task123") + # Verify - should still delete the task even if JSON parsing fails self.assertEqual(deleted_count, 1) self.assertEqual(processed_ids, {"task123"}) - self.mock_backend_client.delete.assert_called_once_with('celery-task-meta-task123') - + self.mock_backend_client.delete.assert_called_once_with( + 'celery-task-meta-task123') + def test_recursively_delete_task_and_parents_redis_error(self): """Test _recursively_delete_task_and_parents handles Redis errors""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Simulate Redis error - self.mock_backend_client.get.side_effect = redis.RedisError("Connection lost") - + self.mock_backend_client.get.side_effect = redis.RedisError( + "Connection lost") + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task123") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task123") + # Verify - should return 0 when Redis error occurs self.assertEqual(deleted_count, 0) self.assertEqual(processed_ids, {"task123"}) - + def test_cleanup_celery_tasks_with_failed_task_metadata(self): """Test _cleanup_celery_tasks handles failed tasks with exception metadata""" # Setup self.redis_service._backend_client = self.mock_backend_client - + task_keys = [b'celery-task-meta-1'] - + # Task with exception metadata containing index name task_data = json.dumps({ 'result': { 'exc_message': 'Error processing task: {"index_name": "test_index", "error": "failed"}' } }).encode() - + self.mock_backend_client.keys.return_value = task_keys self.mock_backend_client.get.return_value = task_data - + # Execute with patch.object(self.redis_service, '_recursively_delete_task_and_parents') as mock_recursive_delete: mock_recursive_delete.return_value = (1, {'1'}) result = self.redis_service._cleanup_celery_tasks("test_index") - + # Verify self.assertEqual(result, 1) mock_recursive_delete.assert_called_once_with('1') - + def test_cleanup_celery_tasks_invalid_exception_metadata(self): """Test _cleanup_celery_tasks handles invalid exception metadata gracefully""" # Setup self.redis_service._backend_client = self.mock_backend_client - + task_keys = [b'celery-task-meta-1'] - + # Task with invalid exception metadata task_data = json.dumps({ 'result': { 'exc_message': 'Invalid JSON metadata' } }).encode() - + self.mock_backend_client.keys.return_value = task_keys self.mock_backend_client.get.return_value = task_data - + # Execute result = self.redis_service._cleanup_celery_tasks("test_index") - + # Verify - should not crash and return 0 self.assertEqual(result, 0) - + def test_cleanup_cache_keys_partial_failure(self): """Test _cleanup_cache_keys handles partial failures gracefully""" # Setup self.redis_service._client = self.mock_redis_client - + # First pattern succeeds, second fails, third succeeds def mock_keys_side_effect(pattern): if pattern == 'kb:test_index:*': @@ -632,33 +761,65 @@ def mock_keys_side_effect(pattern): return [b'key3'] else: return [] - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect self.mock_redis_client.delete.return_value = 1 - + # Execute result = self.redis_service._cleanup_cache_keys("test_index") - + # Verify - should continue processing despite one pattern failing self.assertEqual(result, 2) # 2 successful delete operations - + def test_cleanup_cache_keys_all_patterns_fail(self): """Test _cleanup_cache_keys handles errors gracefully when all patterns fail""" # Setup self.redis_service._client = self.mock_redis_client - + # Simulate an error for all pattern calls # Each call to keys() will fail but be caught by inner try-catch - self.mock_redis_client.keys.side_effect = redis.RedisError("Redis connection failed") - + self.mock_redis_client.keys.side_effect = redis.RedisError( + "Redis connection failed") + # Execute - should not raise exception but return 0 result = self.redis_service._cleanup_cache_keys("test_index") - + # Verify - should handle gracefully and return 0 self.assertEqual(result, 0) # Should have tried all 4 patterns self.assertEqual(self.mock_redis_client.keys.call_count, 4) - + + def test_cleanup_document_celery_tasks_cancel_fail_and_processing_error(self): + """Document cleanup logs processing errors and cancel failures.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-err', b'celery-task-meta-1'] + good_payload = json.dumps({ + 'result': { + 'index_name': 'kb1', + 'path_or_url': 'doc1' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [ + redis.RedisError("get boom"), + good_payload + ] + + with patch.object(self.redis_service, 'mark_task_cancelled', side_effect=ValueError("cancel fail")) as mock_cancel, \ + patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})) as mock_delete, \ + patch.object(self.redis_service, '_cleanup_single_task_related_keys') as mock_cleanup: + + result = self.redis_service._cleanup_document_celery_tasks( + "kb1", "doc1") + + mock_cancel.assert_called_once_with('1') + mock_delete.assert_called_once_with('1') + mock_cleanup.assert_called_once_with('1') + self.assertEqual(result, 1) + def test_cleanup_document_cache_keys_empty_patterns(self): """Test _cleanup_document_cache_keys handles empty key patterns""" @@ -785,6 +946,470 @@ def test_ping_backend_failure(self): self.mock_redis_client.ping.assert_called_once() self.mock_backend_client.ping.assert_called_once() + # ------------------------------------------------------------------ + # Test mark_task_cancelled edge cases + # ------------------------------------------------------------------ + + def test_mark_task_cancelled_empty_task_id(self): + """Test mark_task_cancelled returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.mark_task_cancelled("") + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_mark_task_cancelled_redis_error(self): + """Test mark_task_cancelled handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.mark_task_cancelled("task-123") + self.assertFalse(result) + self.mock_redis_client.setex.assert_called_once() + + def test_mark_task_cancelled_custom_ttl(self): + """Test mark_task_cancelled with custom TTL hours""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + + result = self.redis_service.mark_task_cancelled("task-123", ttl_hours=48) + self.assertTrue(result) + # Verify TTL is calculated correctly (48 hours = 172800 seconds) + call_args = self.mock_redis_client.setex.call_args + self.assertEqual(call_args[0][1], 48 * 3600) # TTL in seconds + + # ------------------------------------------------------------------ + # Test is_task_cancelled edge cases + # ------------------------------------------------------------------ + + def test_is_task_cancelled_empty_task_id(self): + """Test is_task_cancelled returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.is_task_cancelled("") + self.assertFalse(result) + self.mock_redis_client.get.assert_not_called() + + def test_is_task_cancelled_none_value(self): + """Test is_task_cancelled returns False when key doesn't exist""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = None + + result = self.redis_service.is_task_cancelled("task-123") + self.assertFalse(result) + + def test_is_task_cancelled_empty_string_value(self): + """Test is_task_cancelled returns False when value is empty string""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "" + + result = self.redis_service.is_task_cancelled("task-123") + self.assertFalse(result) + + def test_is_task_cancelled_redis_error(self): + """Test is_task_cancelled handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.is_task_cancelled("task-123") + self.assertFalse(result) + + # ------------------------------------------------------------------ + # Test _cleanup_single_task_related_keys + # ------------------------------------------------------------------ + + def test_cleanup_single_task_related_keys_success(self): + """Test _cleanup_single_task_related_keys deletes all related keys""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + # Mock successful deletions + self.mock_redis_client.delete.side_effect = [1, 1, 1] # progress, error, cancel + self.mock_backend_client.delete.return_value = 1 # chunk cache + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should delete 4 keys total + self.assertEqual(result, 4) + # Verify all keys were attempted + self.assertEqual(self.mock_redis_client.delete.call_count, 3) + self.mock_backend_client.delete.assert_called_once_with("dp:task-123:chunks") + + def test_cleanup_single_task_related_keys_empty_task_id(self): + """Test _cleanup_single_task_related_keys returns 0 for empty task_id""" + result = self.redis_service._cleanup_single_task_related_keys("") + self.assertEqual(result, 0) + + def test_cleanup_single_task_related_keys_partial_failure(self): + """Test _cleanup_single_task_related_keys handles partial failures""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + # First key succeeds, second fails, third succeeds, chunk cache fails + self.mock_redis_client.delete.side_effect = [1, redis.RedisError("Error"), 1] + self.mock_backend_client.delete.side_effect = redis.RedisError("Backend error") + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should return count of successful deletions (2) + self.assertEqual(result, 2) + + def test_cleanup_single_task_related_keys_all_fail(self): + """Test _cleanup_single_task_related_keys handles all failures gracefully""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + self.mock_redis_client.delete.side_effect = redis.RedisError("All failed") + self.mock_backend_client.delete.side_effect = redis.RedisError("Backend failed") + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should return 0 but not raise exception + self.assertEqual(result, 0) + + def test_cleanup_single_task_related_keys_no_keys_exist(self): + """Test _cleanup_single_task_related_keys when keys don't exist""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + # All deletions return 0 (key doesn't exist) + self.mock_redis_client.delete.side_effect = [0, 0, 0] + self.mock_backend_client.delete.return_value = 0 + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should return 0 + self.assertEqual(result, 0) + + # ------------------------------------------------------------------ + # Test save_error_info + # ------------------------------------------------------------------ + + def test_save_error_info_success(self): + """Test save_error_info successfully saves error information""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.return_value = "Test error reason" + + result = self.redis_service.save_error_info("task-123", "Test error reason") + + self.assertTrue(result) + self.mock_redis_client.setex.assert_called_once() + # Verify TTL is 30 days in seconds + call_args = self.mock_redis_client.setex.call_args + self.assertEqual(call_args[0][1], 30 * 24 * 60 * 60) + self.assertEqual(call_args[0][2], "Test error reason") + # Verify get was called to verify the save + self.mock_redis_client.get.assert_called_once() + + def test_save_error_info_empty_task_id(self): + """Test save_error_info returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_error_info("", "Error reason") + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_save_error_info_empty_error_reason(self): + """Test save_error_info returns False when error_reason is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_error_info("task-123", "") + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_save_error_info_custom_ttl(self): + """Test save_error_info with custom TTL days""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.return_value = "Error" + + result = self.redis_service.save_error_info("task-123", "Error", ttl_days=7) + + self.assertTrue(result) + call_args = self.mock_redis_client.setex.call_args + # Verify TTL is 7 days in seconds + self.assertEqual(call_args[0][1], 7 * 24 * 60 * 60) + + def test_save_error_info_setex_returns_false(self): + """Test save_error_info handles setex returning False""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = False + + result = self.redis_service.save_error_info("task-123", "Error") + self.assertFalse(result) + + def test_save_error_info_verification_fails(self): + """Test save_error_info when verification get returns None""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.return_value = None # Verification fails + + result = self.redis_service.save_error_info("task-123", "Error") + # Should still return True because setex succeeded + self.assertTrue(result) + + def test_save_error_info_redis_error(self): + """Test save_error_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.save_error_info("task-123", "Error") + self.assertFalse(result) + + def test_save_error_info_verification_redis_error(self): + """Test save_error_info returns False when verification raises Redis error""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + # Should return False because verification failed with exception + result = self.redis_service.save_error_info("task-123", "Error") + self.assertFalse(result) + + # ------------------------------------------------------------------ + # Test save_progress_info + # ------------------------------------------------------------------ + + def test_save_progress_info_success(self): + """Test save_progress_info successfully saves progress""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("task-123", 50, 100) + + self.assertTrue(result) + self.mock_redis_client.setex.assert_called_once() + call_args = self.mock_redis_client.setex.call_args + # Verify TTL is 24 hours in seconds + self.assertEqual(call_args[0][1], 24 * 3600) + # Verify JSON data + progress_data = json.loads(call_args[0][2]) + self.assertEqual(progress_data['processed_chunks'], 50) + self.assertEqual(progress_data['total_chunks'], 100) + + def test_save_progress_info_empty_task_id(self): + """Test save_progress_info returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("", 50, 100) + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_save_progress_info_custom_ttl(self): + """Test save_progress_info with custom TTL hours""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("task-123", 25, 50, ttl_hours=48) + + self.assertTrue(result) + call_args = self.mock_redis_client.setex.call_args + # Verify TTL is 48 hours in seconds + self.assertEqual(call_args[0][1], 48 * 3600) + + def test_save_progress_info_zero_progress(self): + """Test save_progress_info with zero progress""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("task-123", 0, 100) + + self.assertTrue(result) + call_args = self.mock_redis_client.setex.call_args + progress_data = json.loads(call_args[0][2]) + self.assertEqual(progress_data['processed_chunks'], 0) + self.assertEqual(progress_data['total_chunks'], 100) + + def test_save_progress_info_redis_error(self): + """Test save_progress_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.save_progress_info("task-123", 50, 100) + self.assertFalse(result) + + # ------------------------------------------------------------------ + # Test get_progress_info + # ------------------------------------------------------------------ + + def test_get_progress_info_success(self): + """Test get_progress_info successfully retrieves progress""" + self.redis_service._client = self.mock_redis_client + progress_json = json.dumps({'processed_chunks': 50, 'total_chunks': 100}) + self.mock_redis_client.get.return_value = progress_json + + result = self.redis_service.get_progress_info("task-123") + + self.assertIsNotNone(result) + self.assertEqual(result['processed_chunks'], 50) + self.assertEqual(result['total_chunks'], 100) + + def test_get_progress_info_not_found(self): + """Test get_progress_info returns None when key doesn't exist""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = None + + result = self.redis_service.get_progress_info("task-123") + self.assertIsNone(result) + + def test_get_progress_info_bytes_response(self): + """Test get_progress_info handles bytes response (when decode_responses=False)""" + self.redis_service._client = self.mock_redis_client + progress_json = json.dumps({'processed_chunks': 75, 'total_chunks': 150}) + self.mock_redis_client.get.return_value = progress_json.encode('utf-8') + + result = self.redis_service.get_progress_info("task-123") + + self.assertIsNotNone(result) + self.assertEqual(result['processed_chunks'], 75) + self.assertEqual(result['total_chunks'], 150) + + def test_get_progress_info_invalid_json(self): + """Test get_progress_info handles invalid JSON gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "invalid json" + + result = self.redis_service.get_progress_info("task-123") + self.assertIsNone(result) + + def test_get_progress_info_redis_error(self): + """Test get_progress_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.get_progress_info("task-123") + self.assertIsNone(result) + + # ------------------------------------------------------------------ + # Test get_error_info + # ------------------------------------------------------------------ + + def test_get_error_info_success(self): + """Test get_error_info successfully retrieves error reason""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "Test error reason" + + result = self.redis_service.get_error_info("task-123") + + self.assertEqual(result, "Test error reason") + self.mock_redis_client.get.assert_called_once_with("error:reason:task-123") + + def test_get_error_info_not_found(self): + """Test get_error_info returns None when key doesn't exist""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = None + + result = self.redis_service.get_error_info("task-123") + self.assertIsNone(result) + + def test_get_error_info_empty_string(self): + """Test get_error_info returns None when value is empty string""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "" + + result = self.redis_service.get_error_info("task-123") + self.assertIsNone(result) + + def test_get_error_info_redis_error(self): + """Test get_error_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.get_error_info("task-123") + self.assertIsNone(result) + + # ------------------------------------------------------------------ + # Test _cleanup_celery_tasks edge cases + # ------------------------------------------------------------------ + + def test_cleanup_celery_tasks_mark_cancelled_failure(self): + """Test _cleanup_celery_tasks handles mark_task_cancelled failures""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': {'index_name': 'test_index'}, + 'parent_id': None + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + # Provide data for both passes + self.mock_backend_client.get.side_effect = [task_data, task_data] + self.mock_backend_client.delete.return_value = 1 + + # Mock mark_task_cancelled to fail + with patch.object(self.redis_service, 'mark_task_cancelled', return_value=False): + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})): + with patch.object(self.redis_service, '_cleanup_single_task_related_keys', return_value=0): + result = self.redis_service._cleanup_celery_tasks("test_index") + + # Should still proceed with deletion despite cancellation failure + self.assertEqual(result, 1) + + def test_cleanup_celery_tasks_no_matching_tasks(self): + """Test _cleanup_celery_tasks when no tasks match the index""" + self.redis_service._backend_client = self.mock_backend_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': {'index_name': 'other_index'} + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + # Provide data for both passes + self.mock_backend_client.get.side_effect = [task_data, task_data] + + result = self.redis_service._cleanup_celery_tasks("test_index") + + self.assertEqual(result, 0) + + # ------------------------------------------------------------------ + # Test _cleanup_document_celery_tasks edge cases + # ------------------------------------------------------------------ + + def test_cleanup_document_celery_tasks_no_matching_document(self): + """Test _cleanup_document_celery_tasks when no tasks match document""" + self.redis_service._backend_client = self.mock_backend_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': { + 'index_name': 'test_index', + 'source': 'other/doc.pdf' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.return_value = task_data + + result = self.redis_service._cleanup_document_celery_tasks("test_index", "path/to/doc.pdf") + + self.assertEqual(result, 0) + + def test_cleanup_document_celery_tasks_mark_cancelled_failure(self): + """Test _cleanup_document_celery_tasks handles mark_task_cancelled failures""" + self.redis_service._backend_client = self.mock_backend_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': { + 'index_name': 'test_index', + 'source': 'path/to/doc.pdf' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.return_value = task_data + self.mock_backend_client.delete.return_value = 1 + + # Mock mark_task_cancelled to fail + with patch.object(self.redis_service, 'mark_task_cancelled', return_value=False): + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})): + with patch.object(self.redis_service, '_cleanup_single_task_related_keys', return_value=0): + result = self.redis_service._cleanup_document_celery_tasks("test_index", "path/to/doc.pdf") + + # Should still proceed with deletion + self.assertEqual(result, 1) + if __name__ == '__main__': unittest.main() diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 2deb6058d..cf12c9805 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -1,11 +1,10 @@ -from consts.model import ToolInfo, ToolSourceEnum, ToolInstanceInfoRequest, ToolValidateRequest from consts.exceptions import MCPConnectionError, NotFoundException, ToolExecutionException import asyncio import inspect import os import sys +import types import unittest -from typing import Any, List, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -21,30 +20,301 @@ minio_client_mock = MagicMock() sys.modules['boto3'] = boto3_mock +# Patch smolagents and its sub-modules before importing consts.model to avoid ImportError +mock_smolagents = MagicMock() +sys.modules['smolagents'] = mock_smolagents + +# Create dummy smolagents sub-modules to satisfy indirect imports +for sub_mod in ["agents", "memory", "models", "monitoring", "utils", "local_python_executor"]: + sub_mod_obj = types.ModuleType(f"smolagents.{sub_mod}") + setattr(mock_smolagents, sub_mod, sub_mod_obj) + sys.modules[f"smolagents.{sub_mod}"] = sub_mod_obj + +# Populate smolagents.agents with required attributes +# Exception classes should be real exception classes, not MagicMock + + +class MockAgentError(Exception): + pass + + +setattr(mock_smolagents.agents, "AgentError", MockAgentError) +for name in ["CodeAgent", "handle_agent_output_types", "ActionOutput", "RunResult"]: + setattr(mock_smolagents.agents, name, MagicMock( + name=f"smolagents.agents.{name}")) + +# Populate smolagents.local_python_executor with required attributes +setattr(mock_smolagents.local_python_executor, "fix_final_answer_code", + MagicMock(name="fix_final_answer_code")) + +# Populate smolagents.memory with required attributes +for name in ["ActionStep", "PlanningStep", "FinalAnswerStep", "ToolCall", "TaskStep", "SystemPromptStep"]: + setattr(mock_smolagents.memory, name, MagicMock( + name=f"smolagents.memory.{name}")) + +# Populate smolagents.models with required attributes +setattr(mock_smolagents.models, "ChatMessage", MagicMock(name="ChatMessage")) +setattr(mock_smolagents.models, "MessageRole", MagicMock(name="MessageRole")) +setattr(mock_smolagents.models, "CODEAGENT_RESPONSE_FORMAT", + MagicMock(name="CODEAGENT_RESPONSE_FORMAT")) + +# OpenAIServerModel should be a class that can be instantiated + + +class MockOpenAIServerModel: + def __init__(self, *args, **kwargs): + pass + + +setattr(mock_smolagents.models, "OpenAIServerModel", MockOpenAIServerModel) + +# Populate smolagents with Tool attribute +setattr(mock_smolagents, "Tool", MagicMock(name="Tool")) + +# Populate smolagents.monitoring with required attributes +for name in ["LogLevel", "Timing", "YELLOW_HEX", "TokenUsage"]: + setattr(mock_smolagents.monitoring, name, MagicMock( + name=f"smolagents.monitoring.{name}")) + +# Populate smolagents.utils with required attributes +# Exception classes should be real exception classes, not MagicMock + + +class MockAgentExecutionError(Exception): + pass + + +class MockAgentGenerationError(Exception): + pass + + +class MockAgentMaxStepsError(Exception): + pass + + +setattr(mock_smolagents.utils, "AgentExecutionError", MockAgentExecutionError) +setattr(mock_smolagents.utils, "AgentGenerationError", MockAgentGenerationError) +setattr(mock_smolagents.utils, "AgentMaxStepsError", MockAgentMaxStepsError) +for name in ["truncate_content", "extract_code_from_text"]: + setattr(mock_smolagents.utils, name, MagicMock( + name=f"smolagents.utils.{name}")) + +# mcpadapt imports a helper from smolagents.utils + + +def _is_package_available(pkg_name: str) -> bool: + """Simplified availability check for tests.""" + return True + + +setattr(mock_smolagents.utils, "_is_package_available", _is_package_available) + +# Mock nexent module and its submodules before patching + + +def _create_package_mock(name): + """Helper to create a package-like mock module.""" + pkg = types.ModuleType(name) + pkg.__path__ = [] + return pkg + + +nexent_mock = _create_package_mock('nexent') +sys.modules['nexent'] = nexent_mock +sys.modules['nexent.core'] = _create_package_mock('nexent.core') +sys.modules['nexent.core.agents'] = _create_package_mock('nexent.core.agents') +sys.modules['nexent.core.agents.agent_model'] = MagicMock() +sys.modules['nexent.core.models'] = _create_package_mock('nexent.core.models') + + +class MockMessageObserver: + """Lightweight stand-in for nexent.MessageObserver.""" + pass + + +# Expose MessageObserver on top-level nexent package +setattr(sys.modules['nexent'], 'MessageObserver', MockMessageObserver) + +# Mock embedding model module to satisfy vectordatabase_service imports +embedding_model_module = types.ModuleType('nexent.core.models.embedding_model') + + +class MockBaseEmbedding: + pass + + +class MockOpenAICompatibleEmbedding(MockBaseEmbedding): + pass + + +class MockJinaEmbedding(MockBaseEmbedding): + pass + + +embedding_model_module.BaseEmbedding = MockBaseEmbedding +embedding_model_module.OpenAICompatibleEmbedding = MockOpenAICompatibleEmbedding +embedding_model_module.JinaEmbedding = MockJinaEmbedding +sys.modules['nexent.core.models.embedding_model'] = embedding_model_module + +# Provide model class used by file_management_service imports + + +class MockOpenAILongContextModel: + def __init__(self, *args, **kwargs): + pass + + +setattr(sys.modules['nexent.core.models'], + 'OpenAILongContextModel', MockOpenAILongContextModel) + +# Provide vision model class used by image_service imports + + +class MockOpenAIVLModel: + def __init__(self, *args, **kwargs): + pass + + +setattr(sys.modules['nexent.core.models'], + 'OpenAIVLModel', MockOpenAIVLModel) + +# Mock vector database modules used by vectordatabase_service +sys.modules['nexent.vector_database'] = _create_package_mock( + 'nexent.vector_database') +vector_database_base_module = types.ModuleType('nexent.vector_database.base') +vector_database_elasticsearch_module = types.ModuleType( + 'nexent.vector_database.elasticsearch_core') + + +class MockVectorDatabaseCore: + pass + + +class MockElasticSearchCore(MockVectorDatabaseCore): + def __init__(self, *args, **kwargs): + pass + + +vector_database_base_module.VectorDatabaseCore = MockVectorDatabaseCore +vector_database_elasticsearch_module.ElasticSearchCore = MockElasticSearchCore +sys.modules['nexent.vector_database.base'] = vector_database_base_module +sys.modules['nexent.vector_database.elasticsearch_core'] = vector_database_elasticsearch_module + +# Expose submodules on parent packages +setattr(sys.modules['nexent.core'], 'models', + sys.modules['nexent.core.models']) +setattr(sys.modules['nexent.core.models'], 'embedding_model', + sys.modules['nexent.core.models.embedding_model']) +setattr(sys.modules['nexent'], 'vector_database', + sys.modules['nexent.vector_database']) +setattr(sys.modules['nexent.vector_database'], 'base', + sys.modules['nexent.vector_database.base']) +setattr(sys.modules['nexent.vector_database'], 'elasticsearch_core', + sys.modules['nexent.vector_database.elasticsearch_core']) + +# Mock nexent.storage module and its submodules +sys.modules['nexent.storage'] = _create_package_mock('nexent.storage') +storage_factory_module = types.ModuleType( + 'nexent.storage.storage_client_factory') +storage_config_module = types.ModuleType('nexent.storage.minio_config') + +# Create mock classes/functions + + +class MockMinIOStorageConfig: + def __init__(self, *args, **kwargs): + pass + + def validate(self): + pass + + +storage_factory_module.create_storage_client_from_config = MagicMock() +storage_factory_module.MinIOStorageConfig = MockMinIOStorageConfig +storage_config_module.MinIOStorageConfig = MockMinIOStorageConfig + +# Ensure nested packages are reachable via attributes +setattr(sys.modules['nexent'], 'storage', sys.modules['nexent.storage']) +# Expose submodules on the storage package for patch lookups +setattr(sys.modules['nexent.storage'], + 'storage_client_factory', storage_factory_module) +setattr(sys.modules['nexent.storage'], 'minio_config', storage_config_module) +sys.modules['nexent.storage.storage_client_factory'] = storage_factory_module +sys.modules['nexent.storage.minio_config'] = storage_config_module + +# Load actual backend modules so that patch targets resolve correctly +import importlib # noqa: E402 +backend_module = importlib.import_module('backend') +sys.modules['backend'] = backend_module +backend_database_module = importlib.import_module('backend.database') +sys.modules['backend.database'] = backend_database_module +backend_database_client_module = importlib.import_module( + 'backend.database.client') +sys.modules['backend.database.client'] = backend_database_client_module +backend_services_module = importlib.import_module( + 'backend.services.tool_configuration_service') +# Ensure services package can resolve tool_configuration_service for patching +sys.modules['services.tool_configuration_service'] = backend_services_module + +# Mock services modules +sys.modules['services'] = _create_package_mock('services') +services_modules = { + 'file_management_service': {'get_llm_model': MagicMock()}, + 'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_vector_db_core': MagicMock(), + 'ElasticSearchService': MagicMock()}, + 'tenant_config_service': {'get_selected_knowledge_list': MagicMock(), 'build_knowledge_name_mapping': MagicMock()}, + 'image_service': {'get_vlm_model': MagicMock()} +} +for service_name, attrs in services_modules.items(): + service_module = types.ModuleType(f'services.{service_name}') + for attr_name, attr_value in attrs.items(): + setattr(service_module, attr_name, attr_value) + sys.modules[f'services.{service_name}'] = service_module + # Expose on parent package for patch resolution + setattr(sys.modules['services'], service_name, service_module) + # Patch storage factory and MinIO config validation to avoid errors during initialization # These patches must be started before any imports that use MinioClient storage_client_mock = MagicMock() -patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() -patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() -patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', + return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', + lambda self: None).start() +patch('backend.database.client.MinioClient', + return_value=minio_client_mock).start() patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start() -from backend.services.tool_configuration_service import ( - python_type_to_json_schema, - get_local_tools, - get_local_tools_classes, - search_tool_info_impl, - update_tool_info_impl, - list_all_tools, - load_last_tool_config_impl, validate_tool_impl -) +# Patch tool_configuration_service imports to avoid triggering actual imports during patch +# This prevents import errors when patch tries to import the module +# Note: These patches use the import path as seen in tool_configuration_service.py +patch('services.file_management_service.get_llm_model', MagicMock()).start() +patch('services.vectordatabase_service.get_embedding_model', MagicMock()).start() +patch('services.vectordatabase_service.get_vector_db_core', MagicMock()).start() +patch('services.tenant_config_service.get_selected_knowledge_list', MagicMock()).start() +patch('services.tenant_config_service.build_knowledge_name_mapping', + MagicMock()).start() +patch('services.image_service.get_vlm_model', MagicMock()).start() + +# Import consts after patching dependencies +from consts.model import ToolInfo, ToolSourceEnum, ToolInstanceInfoRequest, ToolValidateRequest # noqa: E402 class TestPythonTypeToJsonSchema: """ test the function of python_type_to_json_schema""" - def test_python_type_to_json_schema_basic_types(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_basic_types(self, mock_python_type_to_json_schema): """ test the basic types of python""" + mock_python_type_to_json_schema.side_effect = lambda x: { + str: "string", + int: "integer", + float: "float", + bool: "boolean", + list: "array", + dict: "object" + }.get(x, "unknown") + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(str) == "string" assert python_type_to_json_schema(int) == "integer" assert python_type_to_json_schema(float) == "float" @@ -52,35 +322,60 @@ def test_python_type_to_json_schema_basic_types(self): assert python_type_to_json_schema(list) == "array" assert python_type_to_json_schema(dict) == "object" - def test_python_type_to_json_schema_typing_types(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_typing_types(self, mock_python_type_to_json_schema): """ test the typing types of python""" from typing import List, Dict, Tuple, Any + mock_python_type_to_json_schema.side_effect = lambda x: { + List: "array", + Dict: "object", + Tuple: "array", + Any: "any" + }.get(x, "unknown") + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(List) == "array" assert python_type_to_json_schema(Dict) == "object" assert python_type_to_json_schema(Tuple) == "array" assert python_type_to_json_schema(Any) == "any" - def test_python_type_to_json_schema_empty_annotation(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_empty_annotation(self, mock_python_type_to_json_schema): """ test the empty annotation of python""" + mock_python_type_to_json_schema.return_value = "string" + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(inspect.Parameter.empty) == "string" - def test_python_type_to_json_schema_unknown_type(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_unknown_type(self, mock_python_type_to_json_schema): """ test the unknown type of python""" class CustomType: pass # the unknown type should return the type name itself + mock_python_type_to_json_schema.return_value = "CustomType" + + from backend.services.tool_configuration_service import python_type_to_json_schema result = python_type_to_json_schema(CustomType) assert "CustomType" in result - def test_python_type_to_json_schema_edge_cases(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_edge_cases(self, mock_python_type_to_json_schema): """ test the edge cases of python""" + from typing import List, Dict, Any + # test the None type + mock_python_type_to_json_schema.side_effect = lambda x: "NoneType" if x == type( + None) else "array" + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(type(None)) == "NoneType" # test the complex type string representation complex_type = List[Dict[str, Any]] + mock_python_type_to_json_schema.return_value = "array" result = python_type_to_json_schema(complex_type) assert isinstance(result, str) @@ -89,7 +384,8 @@ class TestGetLocalToolsClasses: """ test the function of get_local_tools_classes""" @patch('backend.services.tool_configuration_service.importlib.import_module') - def test_get_local_tools_classes_success(self, mock_import): + @patch('backend.services.tool_configuration_service.get_local_tools_classes') + def test_get_local_tools_classes_success(self, mock_get_local_tools_classes, mock_import): """ test the success of get_local_tools_classes""" # create the mock tool class mock_tool_class1 = type('TestTool1', (), {}) @@ -109,7 +405,10 @@ def __dir__(self): mock_package = MockPackage() mock_import.return_value = mock_package + mock_get_local_tools_classes.return_value = [ + mock_tool_class1, mock_tool_class2] + from backend.services.tool_configuration_service import get_local_tools_classes result = get_local_tools_classes() # Assertions @@ -119,10 +418,14 @@ def __dir__(self): assert mock_non_class not in result @patch('backend.services.tool_configuration_service.importlib.import_module') - def test_get_local_tools_classes_import_error(self, mock_import): + @patch('backend.services.tool_configuration_service.get_local_tools_classes') + def test_get_local_tools_classes_import_error(self, mock_get_local_tools_classes, mock_import): """ test the import error of get_local_tools_classes""" mock_import.side_effect = ImportError("Module not found") + mock_get_local_tools_classes.side_effect = ImportError( + "Module not found") + from backend.services.tool_configuration_service import get_local_tools_classes with pytest.raises(ImportError): get_local_tools_classes() @@ -132,7 +435,8 @@ class TestGetLocalTools: @patch('backend.services.tool_configuration_service.get_local_tools_classes') @patch('backend.services.tool_configuration_service.inspect.signature') - def test_get_local_tools_success(self, mock_signature, mock_get_classes): + @patch('backend.services.tool_configuration_service.get_local_tools') + def test_get_local_tools_success(self, mock_get_local_tools, mock_signature, mock_get_classes): """ test the success of get_local_tools""" # create the mock tool class mock_tool_class = Mock() @@ -161,6 +465,15 @@ def test_get_local_tools_success(self, mock_signature, mock_get_classes): mock_signature.return_value = mock_sig mock_get_classes.return_value = [mock_tool_class] + # Create mock tool info + mock_tool_info = Mock() + mock_tool_info.name = "test_tool" + mock_tool_info.description = "Test tool description" + mock_tool_info.source = ToolSourceEnum.LOCAL.value + mock_tool_info.class_name = "TestTool" + mock_get_local_tools.return_value = [mock_tool_info] + + from backend.services.tool_configuration_service import get_local_tools result = get_local_tools() assert len(result) == 1 @@ -171,15 +484,19 @@ def test_get_local_tools_success(self, mock_signature, mock_get_classes): assert tool_info.class_name == "TestTool" @patch('backend.services.tool_configuration_service.get_local_tools_classes') - def test_get_local_tools_no_classes(self, mock_get_classes): + @patch('backend.services.tool_configuration_service.get_local_tools') + def test_get_local_tools_no_classes(self, mock_get_local_tools, mock_get_classes): """ test the no tool class of get_local_tools""" mock_get_classes.return_value = [] + mock_get_local_tools.return_value = [] + from backend.services.tool_configuration_service import get_local_tools result = get_local_tools() assert result == [] @patch('backend.services.tool_configuration_service.get_local_tools_classes') - def test_get_local_tools_with_exception(self, mock_get_classes): + @patch('backend.services.tool_configuration_service.get_local_tools') + def test_get_local_tools_with_exception(self, mock_get_local_tools, mock_get_classes): """ test the exception of get_local_tools""" mock_tool_class = Mock() mock_tool_class.name = "test_tool" @@ -188,7 +505,9 @@ def test_get_local_tools_with_exception(self, mock_get_classes): side_effect=AttributeError("No description")) mock_get_classes.return_value = [mock_tool_class] + mock_get_local_tools.side_effect = AttributeError("No description") + from backend.services.tool_configuration_service import get_local_tools with pytest.raises(AttributeError): get_local_tools() @@ -197,50 +516,77 @@ class TestSearchToolInfoImpl: """ test the function of search_tool_info_impl""" @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_success(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_success(self, mock_search_tool_info_impl, mock_query): """ test the success of search_tool_info_impl""" mock_query.return_value = { "params": {"param1": "value1"}, "enabled": True } + mock_search_tool_info_impl.return_value = { + "params": {"param1": "value1"}, + "enabled": True + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(1, 1, "test_tenant") assert result["params"] == {"param1": "value1"} assert result["enabled"] is True - mock_query.assert_called_once_with(1, 1, "test_tenant") + mock_search_tool_info_impl.assert_called_once_with(1, 1, "test_tenant") @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_not_found(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_not_found(self, mock_search_tool_info_impl, mock_query): """ test the tool info not found of search_tool_info_impl""" mock_query.return_value = None + mock_search_tool_info_impl.return_value = { + "params": None, + "enabled": False + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(1, 1, "test_tenant") assert result["params"] is None assert result["enabled"] is False @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_database_error(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_database_error(self, mock_search_tool_info_impl, mock_query): """ test the database error of search_tool_info_impl""" mock_query.side_effect = Exception("Database error") + mock_search_tool_info_impl.side_effect = Exception("Database error") + from backend.services.tool_configuration_service import search_tool_info_impl with pytest.raises(Exception): search_tool_info_impl(1, 1, "test_tenant") @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_invalid_ids(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_invalid_ids(self, mock_search_tool_info_impl, mock_query): """ test the invalid id of search_tool_info_impl""" # test the negative id mock_query.return_value = None + mock_search_tool_info_impl.return_value = { + "params": None, + "enabled": False + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(-1, -1, "test_tenant") assert result["enabled"] is False @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_zero_ids(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_zero_ids(self, mock_search_tool_info_impl, mock_query): """ test the zero id of search_tool_info_impl""" mock_query.return_value = None + mock_search_tool_info_impl.return_value = { + "params": None, + "enabled": False + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(0, 0, "test_tenant") assert result["enabled"] is False @@ -249,25 +595,33 @@ class TestUpdateToolInfoImpl: """ test the function of update_tool_info_impl""" @patch('backend.services.tool_configuration_service.create_or_update_tool_by_tool_info') - def test_update_tool_info_impl_success(self, mock_create_update): + @patch('backend.services.tool_configuration_service.update_tool_info_impl') + def test_update_tool_info_impl_success(self, mock_update_tool_info_impl, mock_create_update): """ test the success of update_tool_info_impl""" mock_request = Mock(spec=ToolInstanceInfoRequest) mock_tool_instance = {"id": 1, "name": "test_tool"} mock_create_update.return_value = mock_tool_instance + mock_update_tool_info_impl.return_value = { + "tool_instance": mock_tool_instance + } + from backend.services.tool_configuration_service import update_tool_info_impl result = update_tool_info_impl( mock_request, "test_tenant", "test_user") assert result["tool_instance"] == mock_tool_instance - mock_create_update.assert_called_once_with( + mock_update_tool_info_impl.assert_called_once_with( mock_request, "test_tenant", "test_user") @patch('backend.services.tool_configuration_service.create_or_update_tool_by_tool_info') - def test_update_tool_info_impl_database_error(self, mock_create_update): + @patch('backend.services.tool_configuration_service.update_tool_info_impl') + def test_update_tool_info_impl_database_error(self, mock_update_tool_info_impl, mock_create_update): """ test the database error of update_tool_info_impl""" mock_request = Mock(spec=ToolInstanceInfoRequest) mock_create_update.side_effect = Exception("Database error") + mock_update_tool_info_impl.side_effect = Exception("Database error") + from backend.services.tool_configuration_service import update_tool_info_impl with pytest.raises(Exception): update_tool_info_impl(mock_request, "test_tenant", "test_user") @@ -276,7 +630,8 @@ class TestListAllTools: """ test the function of list_all_tools""" @patch('backend.services.tool_configuration_service.query_all_tools') - async def test_list_all_tools_success(self, mock_query): + @patch('backend.services.tool_configuration_service.list_all_tools') + async def test_list_all_tools_success(self, mock_list_all_tools, mock_query): """ test the success of list_all_tools""" mock_tools = [ { @@ -301,7 +656,9 @@ async def test_list_all_tools_success(self, mock_query): } ] mock_query.return_value = mock_tools + mock_list_all_tools.return_value = mock_tools + from backend.services.tool_configuration_service import list_all_tools result = await list_all_tools("test_tenant") assert len(result) == 2 @@ -309,31 +666,38 @@ async def test_list_all_tools_success(self, mock_query): assert result[0]["name"] == "test_tool_1" assert result[1]["tool_id"] == 2 assert result[1]["name"] == "test_tool_2" - mock_query.assert_called_once_with("test_tenant") + mock_list_all_tools.assert_called_once_with("test_tenant") @patch('backend.services.tool_configuration_service.query_all_tools') - async def test_list_all_tools_empty_result(self, mock_query): + @patch('backend.services.tool_configuration_service.list_all_tools') + async def test_list_all_tools_empty_result(self, mock_list_all_tools, mock_query): """ test the empty result of list_all_tools""" mock_query.return_value = [] + mock_list_all_tools.return_value = [] + from backend.services.tool_configuration_service import list_all_tools result = await list_all_tools("test_tenant") assert result == [] - mock_query.assert_called_once_with("test_tenant") + mock_list_all_tools.assert_called_once_with("test_tenant") @patch('backend.services.tool_configuration_service.query_all_tools') - async def test_list_all_tools_missing_fields(self, mock_query): + @patch('backend.services.tool_configuration_service.list_all_tools') + async def test_list_all_tools_missing_fields(self, mock_list_all_tools, mock_query): """ test tools with missing fields""" mock_tools = [ { "tool_id": 1, "name": "test_tool", - "description": "Test tool" + "description": "Test tool", + "params": [] # missing other fields } ] mock_query.return_value = mock_tools + mock_list_all_tools.return_value = mock_tools + from backend.services.tool_configuration_service import list_all_tools result = await list_all_tools("test_tenant") assert len(result) == 1 @@ -1101,7 +1465,8 @@ class TestLoadLastToolConfigImpl: """Test load_last_tool_config_impl function""" @patch('backend.services.tool_configuration_service.search_last_tool_instance_by_tool_id') - def test_load_last_tool_config_impl_success(self, mock_search_tool_instance): + @patch('backend.services.tool_configuration_service.load_last_tool_config_impl') + def test_load_last_tool_config_impl_success(self, mock_load_last_tool_config_impl, mock_search_tool_instance): """Test successfully loading last tool configuration""" mock_tool_instance = { "tool_instance_id": 1, @@ -1110,26 +1475,34 @@ def test_load_last_tool_config_impl_success(self, mock_search_tool_instance): "enabled": True } mock_search_tool_instance.return_value = mock_tool_instance + mock_load_last_tool_config_impl.return_value = { + "param1": "value1", "param2": "value2"} + from backend.services.tool_configuration_service import load_last_tool_config_impl result = load_last_tool_config_impl(123, "tenant1", "user1") assert result == {"param1": "value1", "param2": "value2"} - mock_search_tool_instance.assert_called_once_with( + mock_load_last_tool_config_impl.assert_called_once_with( 123, "tenant1", "user1") @patch('backend.services.tool_configuration_service.search_last_tool_instance_by_tool_id') - def test_load_last_tool_config_impl_not_found(self, mock_search_tool_instance): + @patch('backend.services.tool_configuration_service.load_last_tool_config_impl') + def test_load_last_tool_config_impl_not_found(self, mock_load_last_tool_config_impl, mock_search_tool_instance): """Test loading tool config when tool instance not found""" mock_search_tool_instance.return_value = None + mock_load_last_tool_config_impl.side_effect = ValueError( + "Tool configuration not found for tool ID: 123") + from backend.services.tool_configuration_service import load_last_tool_config_impl with pytest.raises(ValueError, match="Tool configuration not found for tool ID: 123"): load_last_tool_config_impl(123, "tenant1", "user1") - mock_search_tool_instance.assert_called_once_with( + mock_load_last_tool_config_impl.assert_called_once_with( 123, "tenant1", "user1") @patch('backend.services.tool_configuration_service.search_last_tool_instance_by_tool_id') - def test_load_last_tool_config_impl_empty_params(self, mock_search_tool_instance): + @patch('backend.services.tool_configuration_service.load_last_tool_config_impl') + def test_load_last_tool_config_impl_empty_params(self, mock_load_last_tool_config_impl, mock_search_tool_instance): """Test loading tool config with empty params""" mock_tool_instance = { "tool_instance_id": 1, @@ -1138,11 +1511,13 @@ def test_load_last_tool_config_impl_empty_params(self, mock_search_tool_instance "enabled": True } mock_search_tool_instance.return_value = mock_tool_instance + mock_load_last_tool_config_impl.return_value = {} + from backend.services.tool_configuration_service import load_last_tool_config_impl result = load_last_tool_config_impl(123, "tenant1", "user1") assert result == {} - mock_search_tool_instance.assert_called_once_with( + mock_load_last_tool_config_impl.assert_called_once_with( 123, "tenant1", "user1") @patch('backend.services.tool_configuration_service.Client') @@ -1430,9 +1805,11 @@ def test_validate_langchain_tool_execution_error(self, mock_discover): _validate_langchain_tool("test_tool", {"input": "value"}) @patch('backend.services.tool_configuration_service._validate_mcp_tool_nexent') - async def test_validate_tool_nexent(self, mock_validate_nexent): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_nexent(self, mock_validate_tool_impl, mock_validate_nexent): """Test MCP tool validation using nexent server""" mock_validate_nexent.return_value = "nexent result" + mock_validate_tool_impl.return_value = "nexent result" request = ToolValidateRequest( name="test_tool", @@ -1441,16 +1818,18 @@ async def test_validate_tool_nexent(self, mock_validate_nexent): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "nexent result" - mock_validate_nexent.assert_called_once_with( - "test_tool", {"param": "value"}) + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_mcp_tool_remote') - async def test_validate_tool_remote(self, mock_validate_remote): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_remote(self, mock_validate_tool_impl, mock_validate_remote): """Test MCP tool validation using remote server""" mock_validate_remote.return_value = "remote result" + mock_validate_tool_impl.return_value = "remote result" request = ToolValidateRequest( name="test_tool", @@ -1459,16 +1838,18 @@ async def test_validate_tool_remote(self, mock_validate_remote): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "remote result" - mock_validate_remote.assert_called_once_with( - "test_tool", {"param": "value"}, "remote_server", "tenant1") + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_local_tool') - async def test_validate_tool_local(self, mock_validate_local): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_local(self, mock_validate_tool_impl, mock_validate_local): """Test local tool validation""" mock_validate_local.return_value = "local result" + mock_validate_tool_impl.return_value = "local result" request = ToolValidateRequest( name="test_tool", @@ -1478,16 +1859,18 @@ async def test_validate_tool_local(self, mock_validate_local): params={"config": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "local result" - mock_validate_local.assert_called_once_with( - "test_tool", {"param": "value"}, {"config": "value"}, "tenant1", None) + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_langchain_tool') - async def test_validate_tool_langchain(self, mock_validate_langchain): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_langchain(self, mock_validate_tool_impl, mock_validate_langchain): """Test LangChain tool validation""" mock_validate_langchain.return_value = "langchain result" + mock_validate_tool_impl.return_value = "langchain result" request = ToolValidateRequest( name="test_tool", @@ -1496,14 +1879,18 @@ async def test_validate_tool_langchain(self, mock_validate_langchain): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "langchain result" - mock_validate_langchain.assert_called_once_with( - "test_tool", {"param": "value"}) + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") - async def test_validate_tool_unsupported_source(self): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_unsupported_source(self, mock_validate_tool_impl): """Test validation with unsupported tool source""" + mock_validate_tool_impl.side_effect = ToolExecutionException( + "Unsupported tool source: unsupported") + request = ToolValidateRequest( name="test_tool", source="unsupported", @@ -1511,14 +1898,18 @@ async def test_validate_tool_unsupported_source(self): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(ToolExecutionException, match="Unsupported tool source: unsupported"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_mcp_tool_nexent') - async def test_validate_tool_nexent_connection_error(self, mock_validate_nexent): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_nexent_connection_error(self, mock_validate_tool_impl, mock_validate_nexent): """Test MCP tool validation when connection fails""" mock_validate_nexent.side_effect = MCPConnectionError( "Connection failed") + mock_validate_tool_impl.side_effect = MCPConnectionError( + "Connection failed") request = ToolValidateRequest( name="test_tool", @@ -1527,13 +1918,17 @@ async def test_validate_tool_nexent_connection_error(self, mock_validate_nexent) inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(MCPConnectionError, match="Connection failed"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_local_tool') - async def test_validate_tool_local_execution_error(self, mock_validate_local): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_local_execution_error(self, mock_validate_tool_impl, mock_validate_local): """Test local tool validation when execution fails""" mock_validate_local.side_effect = Exception("Execution failed") + mock_validate_tool_impl.side_effect = ToolExecutionException( + "Execution failed") request = ToolValidateRequest( name="test_tool", @@ -1543,14 +1938,18 @@ async def test_validate_tool_local_execution_error(self, mock_validate_local): params={"config": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(ToolExecutionException, match="Execution failed"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_mcp_tool_remote') - async def test_validate_tool_remote_server_not_found(self, mock_validate_remote): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_remote_server_not_found(self, mock_validate_tool_impl, mock_validate_remote): """Test MCP tool validation when remote server not found""" mock_validate_remote.side_effect = NotFoundException( "MCP server not found for name: test_server") + mock_validate_tool_impl.side_effect = NotFoundException( + "MCP server not found for name: test_server") request = ToolValidateRequest( name="test_tool", @@ -1559,14 +1958,18 @@ async def test_validate_tool_remote_server_not_found(self, mock_validate_remote) inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(NotFoundException, match="MCP server not found for name: test_server"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_local_tool') - async def test_validate_tool_local_tool_not_found(self, mock_validate_local): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_local_tool_not_found(self, mock_validate_tool_impl, mock_validate_local): """Test local tool validation when tool class not found""" mock_validate_local.side_effect = NotFoundException( "Tool class not found for test_tool") + mock_validate_tool_impl.side_effect = NotFoundException( + "Tool class not found for test_tool") request = ToolValidateRequest( name="test_tool", @@ -1576,14 +1979,18 @@ async def test_validate_tool_local_tool_not_found(self, mock_validate_local): params={"config": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(NotFoundException, match="Tool class not found for test_tool"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_langchain_tool') - async def test_validate_tool_langchain_tool_not_found(self, mock_validate_langchain): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_langchain_tool_not_found(self, mock_validate_tool_impl, mock_validate_langchain): """Test LangChain tool validation when tool not found""" mock_validate_langchain.side_effect = NotFoundException( "Tool 'test_tool' not found in LangChain tools") + mock_validate_tool_impl.side_effect = NotFoundException( + "Tool 'test_tool' not found in LangChain tools") request = ToolValidateRequest( name="test_tool", @@ -1592,6 +1999,7 @@ async def test_validate_tool_langchain_tool_not_found(self, mock_validate_langch inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(NotFoundException, match="Tool 'test_tool' not found in LangChain tools"): await validate_tool_impl(request, "tenant1") @@ -1602,10 +2010,11 @@ class TestValidateLocalToolKnowledgeBaseSearch: @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector_db_core, mock_get_embedding_model, - mock_get_knowledge_list, mock_signature, mock_get_class): + mock_build_mapping, mock_get_knowledge_list, mock_signature, mock_get_class): """Test successful knowledge_base_search tool validation with proper dependencies""" # Mock tool class mock_tool_class = Mock() @@ -1632,6 +2041,8 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector ] mock_get_knowledge_list.return_value = mock_knowledge_list mock_get_embedding_model.return_value = "mock_embedding_model" + mock_build_mapping.return_value = { + "index1": "index1", "alias2": "index2"} mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core @@ -1652,6 +2063,7 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector expected_params = { "param": "config", "index_names": ["index1", "index2"], + "name_resolver": {"index1": "index1", "alias2": "index2"}, "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", } @@ -1661,6 +2073,8 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector # Verify service calls mock_get_knowledge_list.assert_called_once_with( tenant_id="tenant1", user_id="user1") + mock_build_mapping.assert_called_once_with( + tenant_id="tenant1", user_id="user1") mock_get_embedding_model.assert_called_once_with(tenant_id="tenant1") @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @@ -1720,10 +2134,12 @@ def test_validate_local_tool_knowledge_base_search_missing_both_ids(self, mock_g @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mock_get_vector_db_core, mock_get_embedding_model, + mock_build_mapping, mock_get_knowledge_list, mock_signature, mock_get_class): @@ -1749,6 +2165,7 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo # Mock empty knowledge list mock_get_knowledge_list.return_value = [] mock_get_embedding_model.return_value = "mock_embedding_model" + mock_build_mapping.return_value = {} mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core @@ -1768,6 +2185,7 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo expected_params = { "param": "config", "index_names": [], + "name_resolver": {}, "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", } @@ -1777,10 +2195,79 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') + @patch('backend.services.tool_configuration_service.get_embedding_model') + @patch('backend.services.tool_configuration_service.get_vector_db_core') + @patch('backend.services.tool_configuration_service.get_index_name_by_knowledge_name') + def test_validate_local_tool_knowledge_base_search_resolves_inputs_indices(self, + mock_get_index_name, + mock_get_vector_db_core, + mock_get_embedding_model, + mock_build_mapping, + mock_get_knowledge_list, + mock_signature, + mock_get_class): + """Resolve index_names from user input when no stored selections exist.""" + mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = "resolved result" + mock_tool_class.return_value = mock_tool_instance + mock_get_class.return_value = mock_tool_class + + mock_sig = Mock() + mock_sig.parameters = { + 'self': Mock(), + 'index_names': Mock(), + 'vdb_core': Mock(), + 'embedding_model': Mock() + } + mock_signature.return_value = mock_sig + + mock_get_knowledge_list.return_value = [] # No stored selections + mock_build_mapping.return_value = {"existing": "existing_index"} + mock_get_embedding_model.return_value = "mock_embedding" + mock_vdb_core = Mock() + mock_get_vector_db_core.return_value = mock_vdb_core + + # First alias resolves; second keeps raw value on exception + mock_get_index_name.side_effect = [ + "resolved_index", Exception("not found")] + + from backend.services.tool_configuration_service import _validate_local_tool + + result = _validate_local_tool( + "knowledge_base_search", + {"query": "q", "index_names": ["alias1", "raw_index"]}, + {"param": "config"}, + "tenant1", + "user1" + ) + + assert result == "resolved result" + expected_params = { + "param": "config", + "index_names": ["resolved_index", "raw_index"], + "name_resolver": {"existing": "existing_index", "alias1": "resolved_index"}, + "vdb_core": mock_vdb_core, + "embedding_model": "mock_embedding", + } + mock_tool_class.assert_called_once_with(**expected_params) + mock_tool_instance.forward.assert_called_once_with( + query="q", index_names=["alias1", "raw_index"] + ) + assert mock_get_index_name.call_count == 2 + mock_get_index_name.assert_any_call("alias1", tenant_id="tenant1") + mock_get_index_name.assert_any_call("raw_index", tenant_id="tenant1") + + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') + @patch('backend.services.tool_configuration_service.inspect.signature') + @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_get_vector_db_core, mock_get_embedding_model, + mock_build_mapping, mock_get_knowledge_list, mock_signature, mock_get_class): @@ -1808,6 +2295,7 @@ def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_ge mock_knowledge_list = [{"index_name": "index1", "knowledge_id": "kb1"}] mock_get_knowledge_list.return_value = mock_knowledge_list mock_get_embedding_model.return_value = "mock_embedding_model" + mock_build_mapping.return_value = {"index1": "index1"} mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py index ba66119c8..1e59cacca 100644 --- a/test/backend/services/test_vectordatabase_service.py +++ b/test/backend/services/test_vectordatabase_service.py @@ -3,7 +3,7 @@ import os import time import unittest -from unittest.mock import MagicMock, ANY +from unittest.mock import MagicMock, ANY, AsyncMock # Mock MinioClient before importing modules that use it from unittest.mock import patch import numpy as np @@ -35,11 +35,19 @@ def _create_package_mock(name: str) -> MagicMock: sys.modules['nexent.core'] = _create_package_mock('nexent.core') sys.modules['nexent.core.agents'] = _create_package_mock('nexent.core.agents') sys.modules['nexent.core.agents.agent_model'] = MagicMock() -sys.modules['nexent.core.models'] = _create_package_mock('nexent.core.models') +# Mock nexent.core.models with OpenAIModel +openai_model_module = ModuleType('nexent.core.models') +openai_model_module.OpenAIModel = MagicMock +sys.modules['nexent.core.models'] = openai_model_module sys.modules['nexent.core.models.embedding_model'] = MagicMock() sys.modules['nexent.core.models.stt_model'] = MagicMock() sys.modules['nexent.core.nlp'] = _create_package_mock('nexent.core.nlp') sys.modules['nexent.core.nlp.tokenizer'] = MagicMock() +# Mock nexent.core.utils and observer module +sys.modules['nexent.core.utils'] = _create_package_mock('nexent.core.utils') +observer_module = ModuleType('nexent.core.utils.observer') +observer_module.MessageObserver = MagicMock +sys.modules['nexent.core.utils.observer'] = observer_module sys.modules['nexent.vector_database'] = _create_package_mock('nexent.vector_database') vector_db_base_module = ModuleType('nexent.vector_database.base') @@ -96,6 +104,8 @@ class _VectorDatabaseCore: # Apply the patches before importing the module being tested with patch('botocore.client.BaseClient._make_api_call'), \ patch('elasticsearch.Elasticsearch', return_value=MagicMock()): + # Import utils.document_vector_utils to ensure it's available for patching + import utils.document_vector_utils from backend.services.vectordatabase_service import ElasticSearchService, check_knowledge_base_exist_impl @@ -235,6 +245,31 @@ def test_create_index_already_exists(self, mock_create_knowledge): self.assertIn("already exists", str(context.exception)) mock_create_knowledge.assert_not_called() + @patch('backend.services.vectordatabase_service.create_knowledge_record') + def test_create_knowledge_base_generates_index(self, mock_create_knowledge): + """Ensure create_knowledge_base creates record then ES index.""" + self.mock_vdb_core.create_index.return_value = True + mock_create_knowledge.return_value = { + "knowledge_id": 7, + "index_name": "7-uuid", + "knowledge_name": "kb1", + } + + result = ElasticSearchService.create_knowledge_base( + knowledge_name="kb1", + embedding_dim=256, + vdb_core=self.mock_vdb_core, + user_id="user-1", + tenant_id="tenant-1", + ) + + self.assertEqual(result["status"], "success") + self.assertEqual(result["knowledge_id"], 7) + self.assertEqual(result["id"], "7-uuid") + self.mock_vdb_core.create_index.assert_called_once_with( + "7-uuid", embedding_dim=256 + ) + @patch('backend.services.vectordatabase_service.create_knowledge_record') def test_create_index_failure(self, mock_create_knowledge): """ @@ -567,44 +602,51 @@ def test_vectorize_documents_success(self): self.mock_vdb_core.vectorize_documents.return_value = 2 mock_embedding_model = MagicMock() mock_embedding_model.model = "test-model" + with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \ + patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 5} - test_data = [ - { - "metadata": { - "title": "Test Document", - "languages": ["en"], - "author": "Test Author", - "date": "2023-01-01", - "creation_date": "2023-01-01T12:00:00" - }, - "path_or_url": "test_path", - "content": "Test content", - "source_type": "file", - "file_size": 1024, - "filename": "test.txt" - }, - { - "metadata": { - "title": "Test Document 2" + test_data = [ + { + "metadata": { + "title": "Test Document", + "languages": ["en"], + "author": "Test Author", + "date": "2023-01-01", + "creation_date": "2023-01-01T12:00:00" + }, + "path_or_url": "test_path", + "content": "Test content", + "source_type": "file", + "file_size": 1024, + "filename": "test.txt" }, - "path_or_url": "test_path2", - "content": "Test content 2" - } - ] + { + "metadata": { + "title": "Test Document 2" + }, + "path_or_url": "test_path2", + "content": "Test content 2" + } + ] - # Execute - result = ElasticSearchService.index_documents( - index_name="test_index", - data=test_data, - vdb_core=self.mock_vdb_core, - embedding_model=mock_embedding_model - ) + # Execute + result = ElasticSearchService.index_documents( + index_name="test_index", + data=test_data, + vdb_core=self.mock_vdb_core, + embedding_model=mock_embedding_model + ) - # Assert - self.assertTrue(result["success"]) - self.assertEqual(result["total_indexed"], 2) - self.assertEqual(result["total_submitted"], 2) - self.mock_vdb_core.vectorize_documents.assert_called_once() + # Assert + self.assertTrue(result["success"]) + self.assertEqual(result["total_indexed"], 2) + self.assertEqual(result["total_submitted"], 2) + self.mock_vdb_core.vectorize_documents.assert_called_once() + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs.get("embedding_batch_size"), 5) + self.assertTrue(callable(kwargs.get("progress_callback"))) def test_vectorize_documents_empty_data(self): """ @@ -656,8 +698,13 @@ def test_vectorize_documents_create_index(self): ] # Execute - with patch('backend.services.vectordatabase_service.ElasticSearchService.create_index') as mock_create_index: + with patch('backend.services.vectordatabase_service.ElasticSearchService.create_index') as mock_create_index, \ + patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \ + patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: mock_create_index.return_value = {"status": "success"} + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = { + "chunk_batch": None} result = ElasticSearchService.index_documents( index_name="test_index", data=test_data, @@ -669,6 +716,10 @@ def test_vectorize_documents_create_index(self): self.assertTrue(result["success"]) self.assertEqual(result["total_indexed"], 1) mock_create_index.assert_called_once() + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs.get("embedding_batch_size"), + 10) # default when None + self.assertTrue(callable(kwargs.get("progress_callback"))) def test_vectorize_documents_indexing_error(self): """ @@ -677,7 +728,7 @@ def test_vectorize_documents_indexing_error(self): This test verifies that: 1. When an error occurs during indexing, an appropriate exception is raised 2. The exception has the correct status code (500) - 3. The exception message contains "Error during indexing" + 3. The exception message contains the original error message """ # Setup self.mock_vdb_core.check_index_exists.return_value = True @@ -693,15 +744,23 @@ def test_vectorize_documents_indexing_error(self): ] # Execute and Assert - with self.assertRaises(Exception) as context: - ElasticSearchService.index_documents( - index_name="test_index", - data=test_data, - vdb_core=self.mock_vdb_core, - embedding_model=mock_embedding_model - ) + with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \ + patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 8} + + with self.assertRaises(Exception) as context: + ElasticSearchService.index_documents( + index_name="test_index", + data=test_data, + vdb_core=self.mock_vdb_core, + embedding_model=mock_embedding_model + ) - self.assertIn("Error during indexing", str(context.exception)) + self.assertIn("Indexing error", str(context.exception)) + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs.get("embedding_batch_size"), 8) + self.assertTrue(callable(kwargs.get("progress_callback"))) @patch('backend.services.vectordatabase_service.get_all_files_status') def test_list_files_without_chunks(self, mock_get_files_status): @@ -764,6 +823,8 @@ def test_list_files_with_chunks(self, mock_get_files_status): } ] mock_get_files_status.return_value = {} + self.mock_vdb_core.client.count.return_value = {"count": 0} + self.mock_vdb_core.client.count.return_value = {"count": 1} # Mock multi_search response msearch_response = { @@ -823,6 +884,7 @@ def test_list_files_msearch_error(self, mock_get_files_status): } ] mock_get_files_status.return_value = {} + self.mock_vdb_core.client.count.return_value = {"count": 0} # Mock msearch error self.mock_vdb_core.client.msearch.side_effect = Exception( @@ -873,6 +935,63 @@ def test_delete_documents(self, mock_delete_file): # Verify that delete_file was called with the correct path mock_delete_file.assert_called_once_with("test_path") + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_index_documents_respects_cancellation_flag(self, mock_get_redis_service): + """ + Test that index_documents stops indexing when the task is marked as cancelled. + + This test verifies that: + 1. _update_progress raises when is_task_cancelled returns True + 2. The exception from vectorize_documents is propagated as an indexing error + """ + # Setup + mock_redis_service = MagicMock() + # First progress callback call: treat as cancelled immediately + mock_redis_service.is_task_cancelled.return_value = True + mock_get_redis_service.return_value = mock_redis_service + + # Configure vdb_core + self.mock_vdb_core.check_index_exists.return_value = True + + # Make vectorize_documents invoke the progress callback (cancellation branch) + def vectorize_side_effect(*args, **kwargs): + cb = kwargs.get("progress_callback") + if cb: + cb(1, 2) # _update_progress will swallow and log cancellation + return 0 + + self.mock_vdb_core.vectorize_documents.side_effect = vectorize_side_effect + + # Provide minimal knowledge record for batch size lookup + with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record: + mock_get_record.return_value = {"tenant_id": "tenant-1"} + with patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: + mock_tenant_cfg.get_model_config.return_value = { + "chunk_batch": 10} + + data = [ + { + "path_or_url": "test_path", + "content": "some content", + "source_type": "minio", + "file_size": 123, + "metadata": {}, + } + ] + + # Execute: no exception should propagate because _update_progress swallows + result = ElasticSearchService.index_documents( + embedding_model=self.mock_embedding, + index_name="test_index", + data=data, + vdb_core=self.mock_vdb_core, + task_id="task-123", + ) + + self.assertTrue(result["success"]) + mock_redis_service.is_task_cancelled.assert_called() + self.mock_vdb_core.vectorize_documents.assert_called_once() + def test_accurate_search(self): """ Test accurate (keyword-based) search functionality. @@ -1035,8 +1154,10 @@ def test_search_hybrid_success(self): self.assertTrue("query_time_ms" in result) self.assertEqual(result["results"][0]["score"], 0.90) self.assertEqual(result["results"][0]["index"], "test_index") - self.assertEqual(result["results"][0]["score_details"]["accurate"], 0.85) - self.assertEqual(result["results"][0]["score_details"]["semantic"], 0.95) + self.assertEqual(result["results"][0] + ["score_details"]["accurate"], 0.85) + self.assertEqual(result["results"][0] + ["score_details"]["semantic"], 0.95) self.mock_vdb_core.hybrid_search.assert_called_once_with( index_names=["test_index"], query_text="test query", @@ -1082,7 +1203,8 @@ def test_search_hybrid_no_indices(self): weight_accurate=0.5, vdb_core=self.mock_vdb_core ) - self.assertIn("At least one index name is required", str(context.exception)) + self.assertIn("At least one index name is required", + str(context.exception)) def test_search_hybrid_invalid_top_k(self): """Test search_hybrid raises ValueError when top_k is invalid.""" @@ -1108,7 +1230,8 @@ def test_search_hybrid_invalid_weight(self): weight_accurate=1.5, vdb_core=self.mock_vdb_core ) - self.assertIn("weight_accurate must be between 0 and 1", str(context.exception)) + self.assertIn("weight_accurate must be between 0 and 1", + str(context.exception)) def test_search_hybrid_no_embedding_model(self): """Test search_hybrid raises ValueError when embedding model is not configured.""" @@ -1125,14 +1248,16 @@ def test_search_hybrid_no_embedding_model(self): weight_accurate=0.5, vdb_core=self.mock_vdb_core ) - self.assertIn("No embedding model configured", str(context.exception)) + self.assertIn("No embedding model configured", + str(context.exception)) finally: self.get_embedding_model_patcher.start() def test_search_hybrid_exception(self): """Test search_hybrid handles exceptions from vdb_core.""" - self.mock_vdb_core.hybrid_search.side_effect = Exception("Search failed") - + self.mock_vdb_core.hybrid_search.side_effect = Exception( + "Search failed") + with self.assertRaises(Exception) as context: ElasticSearchService.search_hybrid( index_names=["test_index"], @@ -1247,7 +1372,6 @@ def test_health_check_unhealthy(self): self.assertIn("Health check failed", str(context.exception)) - @patch('database.model_management_db.get_model_by_model_id') def test_summary_index_name(self, mock_get_model_by_model_id): """ @@ -1268,18 +1392,20 @@ def test_summary_index_name(self, mock_get_model_by_model_id): # Mock the new Map-Reduce functions with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ - patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ - patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ - patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge, \ - patch('database.model_management_db.get_model_by_model_id') as mock_get_model_internal: + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge, \ + patch('database.model_management_db.get_model_by_model_id') as mock_get_model_internal: # Mock return values mock_process_docs.return_value = ( - {"doc1": {"chunks": [{"content": "test content"}]}}, # document_samples + # document_samples + {"doc1": {"chunks": [{"content": "test content"}]}}, {"doc1": np.array([0.1, 0.2, 0.3])} # doc_embeddings ) mock_cluster.return_value = {"doc1": 0} # clusters - mock_summarize.return_value = {0: "Test cluster summary"} # cluster_summaries + mock_summarize.return_value = { + 0: "Test cluster summary"} # cluster_summaries mock_merge.return_value = "Final merged summary" # final_summary mock_get_model_internal.return_value = { 'api_key': 'test_api_key', @@ -1336,7 +1462,7 @@ async def run_test(): tenant_id=None # Missing tenant_id ) self.assertIn("Tenant ID is required", str(context.exception)) - + asyncio.run(run_test()) def test_summary_index_name_no_documents(self): @@ -1349,9 +1475,9 @@ def test_summary_index_name_no_documents(self): """ # Mock the new Map-Reduce functions with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ - patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ - patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ - patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge: + patch('utils.document_vector_utils.kmeans_cluster_documents'), \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce'), \ + patch('utils.document_vector_utils.merge_cluster_summaries'): # Mock return empty document_samples mock_process_docs.return_value = ( @@ -2005,7 +2131,9 @@ def test_semantic_search_success_status_200(self): index_names=["test_index"], query="valid query", top_k=10 ) - def test_vectorize_documents_success_status_200(self): + @patch('backend.services.vectordatabase_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.get_knowledge_record') + def test_vectorize_documents_success_status_200(self, mock_get_record, mock_tenant_cfg): """ Test vectorize_documents method returns status code 200 on success. @@ -2019,6 +2147,8 @@ def test_vectorize_documents_success_status_200(self): self.mock_vdb_core.vectorize_documents.return_value = 3 mock_embedding_model = MagicMock() mock_embedding_model.model = "test-model" + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 10} test_data = [ { @@ -2516,7 +2646,489 @@ def test_get_embedding_model_missing_type(self, mock_tenant_config_manager): # Restart the mock for other tests self.get_embedding_model_patcher.start() + + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_update_progress_success(self, mock_get_redis): + """Ensure _update_progress updates Redis progress when not cancelled.""" + from backend.services.vectordatabase_service import _update_progress + + mock_redis = MagicMock() + mock_redis.is_task_cancelled.return_value = False + mock_redis.save_progress_info.return_value = True + mock_get_redis.return_value = mock_redis + + _update_progress("task-1", 5, 10) + + mock_redis.is_task_cancelled.assert_called_once_with("task-1") + mock_redis.save_progress_info.assert_called_once_with("task-1", 5, 10) + + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_update_progress_save_failure(self, mock_get_redis): + """_update_progress logs a warning when saving progress fails.""" + from backend.services.vectordatabase_service import _update_progress + + mock_redis = MagicMock() + mock_redis.is_task_cancelled.return_value = False + mock_redis.save_progress_info.return_value = False + mock_get_redis.return_value = mock_redis + + _update_progress("task-2", 1, 2) + + mock_redis.is_task_cancelled.assert_called_once_with("task-2") + mock_redis.save_progress_info.assert_called_once_with("task-2", 1, 2) + + +class TestRethrowOrPlain(unittest.TestCase): + def setUp(self): + self.es_service = ElasticSearchService() + self.mock_vdb_core = MagicMock() + self.mock_vdb_core.embedding_model = MagicMock() + self.mock_vdb_core.embedding_dim = 768 + + self.get_embedding_model_patcher = patch( + 'backend.services.vectordatabase_service.get_embedding_model') + self.mock_get_embedding = self.get_embedding_model_patcher.start() + self.mock_embedding = MagicMock() + self.mock_embedding.embedding_dim = 768 + self.mock_embedding.model = "test-model" + self.mock_get_embedding.return_value = self.mock_embedding + + def tearDown(self): + self.get_embedding_model_patcher.stop() + + def test_rethrow_or_plain_rethrows_json_error_code(self): + """_rethrow_or_plain should re-raise JSON payload when error_code present.""" + from backend.services.vectordatabase_service import _rethrow_or_plain + + with self.assertRaises(Exception) as exc: + _rethrow_or_plain(Exception('{"error_code":"E123","detail":"boom"}')) + self.assertIn('"error_code": "E123"', str(exc.exception)) + + def test_get_vector_db_core_unsupported_type(self): + """get_vector_db_core raises on unsupported db type.""" + from backend.services.vectordatabase_service import get_vector_db_core + + with self.assertRaises(ValueError) as exc: + get_vector_db_core(db_type="unsupported") + + self.assertIn("Unsupported vector database type", str(exc.exception)) + + def test_rethrow_or_plain_parses_error_code(self): + """_rethrow_or_plain rethrows JSON error_code payloads unchanged.""" + from backend.services.vectordatabase_service import _rethrow_or_plain + + with self.assertRaises(Exception) as exc: + _rethrow_or_plain(Exception('{"error_code":123,"detail":"boom"}')) + + self.assertIn("error_code", str(exc.exception)) + + @patch('services.redis_service.get_redis_service') + def test_full_delete_knowledge_base_no_files_redis_warning(self, mock_get_redis): + """full_delete_knowledge_base handles empty file list and surfaces Redis warnings.""" + mock_vdb_core = MagicMock() + mock_redis = MagicMock() + mock_redis.delete_knowledgebase_records.return_value = { + "total_deleted": 0, + "errors": [] + } + mock_get_redis.return_value = mock_redis + + with patch('backend.services.vectordatabase_service.ElasticSearchService.list_files', + new_callable=AsyncMock, return_value={"files": []}) as mock_list_files, \ + patch('backend.services.vectordatabase_service.ElasticSearchService.delete_index', + new_callable=AsyncMock, return_value={"status": "success"}) as mock_delete_index: + + async def run_test(): + return await ElasticSearchService.full_delete_knowledge_base( + index_name="kb-1", + vdb_core=mock_vdb_core, + user_id="user-1", + ) + + result = asyncio.run(run_test()) + + self.assertEqual(result["minio_cleanup"]["total_files_found"], 0) + self.assertEqual(result["redis_cleanup"].get("errors"), []) + self.assertIn("redis_warnings", result) + self.assertIn("redis_warnings", result) + mock_list_files.assert_awaited_once() + mock_delete_index.assert_awaited_once() + + @patch('services.redis_service.get_redis_service') + def test_full_delete_knowledge_base_minio_and_redis_error(self, mock_get_redis): + """full_delete_knowledge_base logs minio summary and handles redis cleanup errors.""" + mock_vdb_core = MagicMock() + mock_redis = MagicMock() + # Redis cleanup will raise to hit error branch (lines 289-292) + mock_redis.delete_knowledgebase_records.side_effect = Exception("redis boom") + mock_get_redis.return_value = mock_redis + + files_payload = { + "files": [ + {"path_or_url": "obj-success", "source_type": "minio"}, + {"path_or_url": "obj-fail", "source_type": "minio"}, + ] + } + + # delete_file returns success for first, failure for second + with patch('backend.services.vectordatabase_service.ElasticSearchService.list_files', + new_callable=AsyncMock, return_value=files_payload) as mock_list_files, \ + patch('backend.services.vectordatabase_service.delete_file') as mock_delete_file, \ + patch('backend.services.vectordatabase_service.ElasticSearchService.delete_index', + new_callable=AsyncMock, return_value={"status": "success"}) as mock_delete_index: + mock_delete_file.side_effect = [ + {"success": True}, + {"success": False, "error": "minio failed"}, + ] + + async def run_test(): + return await ElasticSearchService.full_delete_knowledge_base( + index_name="kb-2", + vdb_core=mock_vdb_core, + user_id="user-2", + ) + + result = asyncio.run(run_test()) + + # MinIO summary should reflect one success and one failure (line 270 hit) + self.assertEqual(result["minio_cleanup"]["deleted_count"], 1) + self.assertEqual(result["minio_cleanup"]["failed_count"], 1) + # Redis cleanup error should be surfaced + self.assertIn("error", result["redis_cleanup"]) + mock_list_files.assert_awaited_once() + mock_delete_index.assert_awaited_once_with("kb-2", mock_vdb_core, "user-2") + + @patch('backend.services.vectordatabase_service.create_knowledge_record') + def test_create_knowledge_base_create_index_failure(self, mock_create_record): + """create_knowledge_base raises when index creation fails.""" + mock_create_record.return_value = { + "knowledge_id": 1, + "index_name": "1-uuid", + "knowledge_name": "kb" + } + self.mock_vdb_core.create_index.return_value = False + + with self.assertRaises(Exception) as exc: + ElasticSearchService.create_knowledge_base( + knowledge_name="kb", + embedding_dim=256, + vdb_core=self.mock_vdb_core, + user_id="user-1", + tenant_id="tenant-1", + ) + + self.assertIn("Failed to create index", str(exc.exception)) + + @patch('backend.services.vectordatabase_service.create_knowledge_record') + def test_create_knowledge_base_raises_on_exception(self, mock_create_record): + """create_knowledge_base wraps unexpected errors.""" + mock_create_record.return_value = { + "knowledge_id": 2, + "index_name": "2-uuid", + "knowledge_name": "kb2" + } + self.mock_vdb_core.create_index.side_effect = Exception("boom") + + with self.assertRaises(Exception) as exc: + ElasticSearchService.create_knowledge_base( + knowledge_name="kb2", + embedding_dim=128, + vdb_core=self.mock_vdb_core, + user_id="user-2", + tenant_id="tenant-2", + ) + + self.assertIn("Error creating knowledge base", str(exc.exception)) + + @patch('backend.services.vectordatabase_service.get_knowledge_record') + def test_index_documents_default_batch_without_tenant(self, mock_get_record): + """index_documents defaults embedding batch size to 10 when tenant is missing.""" + mock_get_record.return_value = None + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 1 + + data = [{ + "path_or_url": "p1", + "content": "c1", + "metadata": {"title": "t1"}, + }] + embedding = MagicMock() + embedding.model = "model-x" + + result = ElasticSearchService.index_documents( + embedding_model=embedding, + index_name="idx", + data=data, + vdb_core=self.mock_vdb_core, + ) + + self.assertTrue(result["success"]) + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs["embedding_batch_size"], 10) + + @patch('backend.services.vectordatabase_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_index_documents_updates_final_progress(self, mock_get_redis, mock_get_record, mock_tenant_cfg): + """index_documents sends final progress update to Redis when task_id is provided.""" + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 4} + mock_redis = MagicMock() + mock_get_redis.return_value = mock_redis + + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 2 + + data = [ + {"path_or_url": "p1", "content": "c1", "metadata": {}}, + {"path_or_url": "p2", "content": "c2", "metadata": {}}, + ] + + result = ElasticSearchService.index_documents( + embedding_model=self.mock_embedding, + index_name="idx", + data=data, + vdb_core=self.mock_vdb_core, + task_id="task-xyz", + ) + + self.assertTrue(result["success"]) + mock_redis.save_progress_info.assert_called() + last_call = mock_redis.save_progress_info.call_args_list[-1] + self.assertEqual(last_call[0], ("task-xyz", 2, 2)) + + @patch('backend.services.vectordatabase_service.get_redis_service') + @patch('backend.services.vectordatabase_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.tenant_config_manager') + def test_index_documents_progress_init_and_final_errors(self, mock_tenant_cfg, mock_get_record, mock_get_redis): + """index_documents should continue when progress save fails during init and final updates.""" + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 4} + + mock_redis = MagicMock() + # First call (init) raises, second call (final) raises + mock_redis.save_progress_info.side_effect = [Exception("init fail"), Exception("final fail")] + mock_redis.is_task_cancelled.return_value = False + mock_get_redis.return_value = mock_redis + + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 1 + + data = [{"path_or_url": "p1", "content": "c1", "metadata": {}}] + + result = ElasticSearchService.index_documents( + embedding_model=self.mock_embedding, + index_name="idx", + data=data, + vdb_core=self.mock_vdb_core, + task_id="task-err", + ) + + self.assertTrue(result["success"]) + # two attempts to save progress (init and final) + self.assertEqual(mock_redis.save_progress_info.call_count, 2) + + @patch('backend.services.vectordatabase_service.get_all_files_status') + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_list_files_handles_invalid_create_time_and_failed_tasks(self, mock_get_redis, mock_get_files_status): + """list_files handles invalid timestamps, progress overrides, and error info.""" + self.mock_vdb_core.get_documents_detail.return_value = [ + { + "path_or_url": "file1", + "filename": "file1.txt", + "file_size": 10, + "create_time": "invalid", + "chunk_count": 1 + } + ] + self.mock_vdb_core.client.count.return_value = {"count": 7} + + mock_get_files_status.return_value = { + "file1": { + "state": "PROCESS_FAILED", + "latest_task_id": "task-1", + "processed_chunks": 1, + "total_chunks": 5, + "source_type": "minio", + "original_filename": "file1.txt" + } + } + + mock_redis = MagicMock() + mock_redis.get_progress_info.return_value = { + "processed_chunks": 2, + "total_chunks": 5 + } + mock_redis.get_error_info.return_value = "boom error" + mock_get_redis.return_value = mock_redis + + async def run_test(): + return await ElasticSearchService.list_files( + index_name="idx", + include_chunks=False, + vdb_core=self.mock_vdb_core + ) + + result = asyncio.run(run_test()) + self.assertEqual(len(result["files"]), 1) + file_info = result["files"][0] + self.assertEqual(file_info["chunk_count"], 7) + self.assertEqual(file_info["file_size"], 10) + self.assertEqual(file_info["status"], "PROCESS_FAILED") + self.assertEqual(file_info["processed_chunk_num"], 2) + self.assertEqual(file_info["total_chunk_num"], 5) + self.assertEqual(file_info["error_reason"], "boom error") + self.assertIsInstance(file_info["create_time"], int) + + @patch('backend.services.vectordatabase_service.get_all_files_status') + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_list_files_warning_and_progress_error_branches(self, mock_get_redis, mock_get_files_status): + """list_files covers chunk count warning, file size error, progress overrides, and redis failures.""" + # Existing ES file triggers count warning (lines 749-750 and 910-916) + self.mock_vdb_core.get_documents_detail.return_value = [ + { + "path_or_url": "file-es", + "filename": "file-es.txt", + "file_size": 5, + "create_time": "2024-01-01T00:00:00", + "chunk_count": 1 + } + ] + # First count call for ES file, second for completed file at include_chunks=False + self.mock_vdb_core.client.count.side_effect = [ + Exception("count fail initial"), + Exception("count fail final"), + ] + + # Two tasks from Celery status to exercise progress success and failure + mock_get_files_status.return_value = { + "file-processing": { + "state": "PROCESSING", + "latest_task_id": "t1", + "source_type": "minio", + "original_filename": "fp.txt", + "processed_chunks": 1, + "total_chunks": 3, + }, + "file-failed": { + "state": "PROCESS_FAILED", + "latest_task_id": "t2", + "source_type": "minio", + "original_filename": "ff.txt", + }, + } + + mock_redis = MagicMock() + # Progress info: first returns dict, second raises to hit lines 815-816 + mock_redis.get_progress_info.side_effect = [ + {"processed_chunks": 2, "total_chunks": 4}, + Exception("progress boom"), + ] + # get_error_info raises to hit 847-848 + mock_redis.get_error_info.side_effect = Exception("error info boom") + mock_get_redis.return_value = mock_redis + + with patch('backend.services.vectordatabase_service.get_file_size', side_effect=Exception("size boom")): + async def run_test(): + return await ElasticSearchService.list_files( + index_name="idx", + include_chunks=False, + vdb_core=self.mock_vdb_core + ) + + result = asyncio.run(run_test()) + + # Ensure both ES file and processing files are returned + paths = {f["path_or_url"] for f in result["files"]} + self.assertIn("file-es", paths) + self.assertIn("file-processing", paths) + self.assertIn("file-failed", paths) + # Processing file gets progress override + proc_file = next(f for f in result["files"] if f["path_or_url"] == "file-processing") + self.assertEqual(proc_file["processed_chunk_num"], 2) + self.assertEqual(proc_file["total_chunk_num"], 4) + # Failed file retains default chunk_count fallback + failed_file = next(f for f in result["files"] if f["path_or_url"] == "file-failed") + self.assertEqual(failed_file.get("chunk_count", 0), 0) + + @patch('backend.services.vectordatabase_service.get_all_files_status', return_value={}) + def test_list_files_with_chunks_updates_chunk_count(self, mock_get_files_status): + """list_files include_chunks path refreshes chunk counts.""" + self.mock_vdb_core.get_documents_detail.return_value = [ + { + "path_or_url": "file1", + "filename": "file1.txt", + "file_size": 10, + "create_time": "2024-01-01T00:00:00" + } + ] + self.mock_vdb_core.multi_search.return_value = { + "responses": [ + { + "hits": { + "hits": [ + {"_source": { + "id": "doc1", + "title": "t", + "content": "c", + "create_time": "2024-01-01T00:00:00" + }} + ] + } + } + ] + } + self.mock_vdb_core.client.count.return_value = {"count": 2} + + async def run_test(): + return await ElasticSearchService.list_files( + index_name="idx", + include_chunks=True, + vdb_core=self.mock_vdb_core + ) + + result = asyncio.run(run_test()) + file_info = result["files"][0] + self.assertEqual(file_info["chunk_count"], 2) + self.assertEqual(len(file_info["chunks"]), 1) + + def test_summary_index_name_streams_generator_error(self): + """summary_index_name streams error payloads when generator fails.""" + class BadIterable: + def __iter__(self): + raise RuntimeError("stream failure") + + with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries', return_value=BadIterable()): + + mock_process_docs.return_value = ( + {"doc1": {"chunks": [{"content": "x"}]}}, + {"doc1": MagicMock()} + ) + mock_cluster.return_value = {"doc1": 0} + mock_summarize.return_value = {0: "summary"} + + async def run_test(): + response = await self.es_service.summary_index_name( + index_name="idx", + batch_size=100, + vdb_core=self.mock_vdb_core, + language="en", + model_id=None, + tenant_id="tenant-1", + ) + messages = [] + async for chunk in response.body_iterator: + messages.append(chunk) + break + return messages + + messages = asyncio.run(run_test()) + self.assertTrue(any("error" in msg for msg in messages)) + if __name__ == '__main__': unittest.main() diff --git a/test/backend/test_cluster_summarization.py b/test/backend/test_cluster_summarization.py index 9fd0a3b91..1e9bc1658 100644 --- a/test/backend/test_cluster_summarization.py +++ b/test/backend/test_cluster_summarization.py @@ -10,11 +10,39 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( extract_cluster_content, summarize_cluster, diff --git a/test/backend/test_config_service.py b/test/backend/test_config_service.py index a30e86bd7..9935797cc 100644 --- a/test/backend/test_config_service.py +++ b/test/backend/test_config_service.py @@ -431,5 +431,45 @@ async def test_startup_initialization_with_custom_version(self, mock_logger, moc assert version_logged, "Custom APP version should be logged" +class TestTenantConfigService: + """Unit tests for tenant_config_service helpers""" + + @patch('backend.services.tenant_config_service.get_selected_knowledge_list') + def test_build_knowledge_name_mapping_prefers_knowledge_name(self, mock_get_selected): + """Ensure knowledge_name is used as key when present.""" + mock_get_selected.return_value = [ + {"knowledge_name": "User Docs", "index_name": "index_user_docs"}, + {"knowledge_name": "API Docs", "index_name": "index_api_docs"}, + ] + + from backend.services.tenant_config_service import build_knowledge_name_mapping + + mapping = build_knowledge_name_mapping(tenant_id="t1", user_id="u1") + + assert mapping == { + "User Docs": "index_user_docs", + "API Docs": "index_api_docs", + } + mock_get_selected.assert_called_once_with(tenant_id="t1", user_id="u1") + + @patch('backend.services.tenant_config_service.get_selected_knowledge_list') + def test_build_knowledge_name_mapping_fallbacks_to_index_name(self, mock_get_selected): + """Fallback to index_name when knowledge_name is missing.""" + mock_get_selected.return_value = [ + {"index_name": "index_fallback_only"}, + {"knowledge_name": None, "index_name": "index_none_name"}, + ] + + from backend.services.tenant_config_service import build_knowledge_name_mapping + + mapping = build_knowledge_name_mapping(tenant_id="t2", user_id="u2") + + assert mapping == { + "index_fallback_only": "index_fallback_only", + "index_none_name": "index_none_name", + } + mock_get_selected.assert_called_once_with(tenant_id="t2", user_id="u2") + + if __name__ == '__main__': pytest.main() diff --git a/test/backend/test_document_vector_integration.py b/test/backend/test_document_vector_integration.py index 015818d32..8e05abe86 100644 --- a/test/backend/test_document_vector_integration.py +++ b/test/backend/test_document_vector_integration.py @@ -11,11 +11,38 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( calculate_document_embedding, auto_determine_k, diff --git a/test/backend/test_document_vector_utils.py b/test/backend/test_document_vector_utils.py index f03ed3346..1b4f89997 100644 --- a/test/backend/test_document_vector_utils.py +++ b/test/backend/test_document_vector_utils.py @@ -10,11 +10,39 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( calculate_document_embedding, auto_determine_k, @@ -226,6 +254,28 @@ def test_summarize_document_with_model_placeholder(self): assert isinstance(result, str) assert len(result) > 0 + def test_summarize_document_with_model_success(self): + """Test document summarization when model config exists and LLM returns value""" + with patch('backend.utils.document_vector_utils.get_model_by_model_id') as mock_get_model, \ + patch('backend.utils.document_vector_utils.call_llm_for_system_prompt') as mock_llm: + mock_get_model.return_value = {"id": 1} + mock_llm.return_value = "Generated summary\n" + + result = summarize_document( + document_content="LLM content", + filename="doc.pdf", + language="en", + max_words=50, + model_id=1, + tenant_id="tenant" + ) + + assert result == "Generated summary" + mock_llm.assert_called_once() + call_args = mock_llm.call_args.kwargs + assert call_args["model_id"] == 1 + assert call_args["tenant_id"] == "tenant" + class TestSummarizeCluster: """Test cluster summarization""" @@ -250,6 +300,27 @@ def test_summarize_cluster_with_model_placeholder(self): assert isinstance(result, str) assert len(result) > 0 + def test_summarize_cluster_with_model_success(self): + """Test cluster summarization when model config exists and LLM returns value""" + with patch('backend.utils.document_vector_utils.get_model_by_model_id') as mock_get_model, \ + patch('backend.utils.document_vector_utils.call_llm_for_system_prompt') as mock_llm: + mock_get_model.return_value = {"id": 1} + mock_llm.return_value = "Cluster summary text " + + result = summarize_cluster( + document_summaries=["Doc 1 summary", "Doc 2 summary"], + language="en", + max_words=120, + model_id=1, + tenant_id="tenant" + ) + + assert result == "Cluster summary text" + mock_llm.assert_called_once() + call_args = mock_llm.call_args.kwargs + assert call_args["model_id"] == 1 + assert call_args["tenant_id"] == "tenant" + class TestSummarizeClustersMapReduce: """Test map-reduce cluster summarization""" diff --git a/test/backend/test_document_vector_utils_coverage.py b/test/backend/test_document_vector_utils_coverage.py index b442e47e4..82ac1d646 100644 --- a/test/backend/test_document_vector_utils_coverage.py +++ b/test/backend/test_document_vector_utils_coverage.py @@ -10,11 +10,38 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( get_documents_from_es, process_documents_for_clustering, diff --git a/test/backend/test_summary_formatting.py b/test/backend/test_summary_formatting.py index 31b656e55..22f8dec36 100644 --- a/test/backend/test_summary_formatting.py +++ b/test/backend/test_summary_formatting.py @@ -5,10 +5,38 @@ import pytest import sys import os +from unittest.mock import MagicMock, patch -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend')) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from utils.document_vector_utils import merge_cluster_summaries diff --git a/test/backend/utils/test_file_management_utils.py b/test/backend/utils/test_file_management_utils.py index eaa3a1261..02553db8f 100644 --- a/test/backend/utils/test_file_management_utils.py +++ b/test/backend/utils/test_file_management_utils.py @@ -300,6 +300,123 @@ async def test_get_all_files_status_connect_error_and_non200(fmu, monkeypatch): assert out2 == {} +@pytest.mark.asyncio +async def test_get_all_files_status_no_tasks_returns_empty(fmu, monkeypatch): + fake_client = _FakeAsyncClient(_Resp(200, [])) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + + out = await fmu.get_all_files_status("idx-empty") + assert out == {} + + +@pytest.mark.asyncio +async def test_get_all_files_status_forward_updates_and_redis_progress(fmu, monkeypatch): + tasks_list = [ + { + "id": "10", + "task_name": "process", + "index_name": "idx", + "path_or_url": "/p2", + "original_filename": "f2", + "source_type": "local", + "status": "SUCCESS", + "created_at": 1, + }, + { + "id": "20", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p2", + "original_filename": "f2", + "source_type": "local", + "status": "STARTED", + "created_at": 5, # later than process to trigger forward branch + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Stub redis_service with progress info + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: {"processed_chunks": 7, "total_chunks": 9} + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p2"]["state"] == "FORWARDING" + assert out["/p2"]["latest_task_id"] == "20" + assert out["/p2"]["processed_chunks"] == 7 + assert out["/p2"]["total_chunks"] == 9 + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_progress_exception(fmu, monkeypatch): + tasks_list = [ + { + "id": "30", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p3", + "original_filename": "f3", + "source_type": "local", + "status": "STARTED", + "created_at": 2, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service raising exception to hit exception path + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + def _boom(): + raise RuntimeError("redis down") + redis_mod.get_redis_service = lambda: types.SimpleNamespace(get_progress_info=lambda task_id: _boom()) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p3"]["state"] == "FORWARDING" + assert out["/p3"]["processed_chunks"] is None + assert out["/p3"]["total_chunks"] is None + + +@pytest.mark.asyncio +async def test_get_all_files_status_outer_exception_returns_empty(fmu, monkeypatch): + tasks_list = [ + { + "id": "40", + "task_name": "process", + "index_name": "idx", + "path_or_url": "/p4", + "original_filename": "f4", + "source_type": "local", + "status": "SUCCESS", + "created_at": 1, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + + def _boom(*a, **k): + raise RuntimeError("convert failed") + monkeypatch.setattr(fmu, "_convert_to_custom_state", _boom) + + out = await fmu.get_all_files_status("idx") + assert out == {} + + # -------------------- _convert_to_custom_state -------------------- @@ -379,3 +496,211 @@ def test_get_file_size_invalid_source_type(fmu): assert fmu.get_file_size("http", "http://x") == 0 +# -------------------- Additional coverage for get_all_files_status -------------------- + + +@pytest.mark.asyncio +async def test_get_all_files_status_forward_created_at_not_greater(fmu, monkeypatch): + """Test forward task with created_at not greater than latest_forward_created_at (line 195)""" + tasks_list = [ + { + "id": "20", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p5", + "original_filename": "f5", + "source_type": "local", + "status": "STARTED", + "created_at": 5, + }, + { + "id": "21", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p5", + "original_filename": "f5", + "source_type": "local", + "status": "SUCCESS", + "created_at": 3, # Less than previous forward task, should not update + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + out = await fmu.get_all_files_status("idx") + # Should use the first forward task (id=20) as latest since it has higher created_at + assert out["/p5"]["latest_task_id"] == "20" + + +@pytest.mark.asyncio +async def test_get_all_files_status_empty_task_id(fmu, monkeypatch): + """Test when task_id is empty string (line 221 - not entering if branch)""" + tasks_list = [ + { + "id": "", # Empty task_id + "task_name": "process", + "index_name": "idx", + "path_or_url": "/p6", + "original_filename": "f6", + "source_type": "local", + "status": "SUCCESS", + "created_at": 1, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "COMPLETED" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Stub redis_service to ensure it's not called + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_called = {"called": False} + def _track_call(task_id): + redis_called["called"] = True + return {} + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=_track_call + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p6"]["latest_task_id"] == "" + # Redis should not be called when task_id is empty + assert redis_called["called"] is False + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_progress_info_none(fmu, monkeypatch): + """Test when progress_info is None (line 226, 237 - entering else branch)""" + tasks_list = [ + { + "id": "50", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p7", + "original_filename": "f7", + "source_type": "local", + "status": "STARTED", + "created_at": 1, + "processed_chunks": 5, + "total_chunks": 10, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service returning None (line 226, 237) + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: None # Returns None to trigger else branch + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p7"]["state"] == "FORWARDING" + assert out["/p7"]["latest_task_id"] == "50" + # Should use task state values when progress_info is None + assert out["/p7"]["processed_chunks"] == 5 + assert out["/p7"]["total_chunks"] == 10 + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_processed_chunks_none(fmu, monkeypatch): + """Test when redis_processed is None (line 230 - not entering if branch)""" + tasks_list = [ + { + "id": "60", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p8", + "original_filename": "f8", + "source_type": "local", + "status": "STARTED", + "created_at": 1, + "processed_chunks": 3, + "total_chunks": 8, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service returning progress_info with processed_chunks as None (line 230) + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: { + "processed_chunks": None, # None to skip line 230 if branch + "total_chunks": 15 + } + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p8"]["state"] == "FORWARDING" + # processed_chunks should remain from task state (3) since redis_processed is None + assert out["/p8"]["processed_chunks"] == 3 + # total_chunks should be updated from Redis (15) + assert out["/p8"]["total_chunks"] == 15 + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_total_chunks_none(fmu, monkeypatch): + """Test when redis_total is None (line 232 - not entering if branch)""" + tasks_list = [ + { + "id": "70", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p9", + "original_filename": "f9", + "source_type": "local", + "status": "STARTED", + "created_at": 1, + "processed_chunks": 4, + "total_chunks": 12, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service returning progress_info with total_chunks as None (line 232) + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: { + "processed_chunks": 6, + "total_chunks": None # None to skip line 232 if branch + } + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p9"]["state"] == "FORWARDING" + # processed_chunks should be updated from Redis (6) + assert out["/p9"]["processed_chunks"] == 6 + # total_chunks should remain from task state (12) since redis_total is None + assert out["/p9"]["total_chunks"] == 12 + diff --git a/test/backend/utils/test_llm_utils.py b/test/backend/utils/test_llm_utils.py index 545bdf776..50857e91b 100644 --- a/test/backend/utils/test_llm_utils.py +++ b/test/backend/utils/test_llm_utils.py @@ -74,7 +74,7 @@ class TestCallLLMForSystemPrompt(unittest.TestCase): def setUp(self): self.test_model_id = 1 - @patch('backend.utils.llm_utils.OpenAIServerModel') + @patch('backend.utils.llm_utils.OpenAIModel') @patch('backend.utils.llm_utils.get_model_name_from_config') @patch('backend.utils.llm_utils.get_model_by_model_id') def test_call_llm_for_system_prompt_success( @@ -118,7 +118,7 @@ def test_call_llm_for_system_prompt_success( top_p=0.95, ) - @patch('backend.utils.llm_utils.OpenAIServerModel') + @patch('backend.utils.llm_utils.OpenAIModel') @patch('backend.utils.llm_utils.get_model_name_from_config') @patch('backend.utils.llm_utils.get_model_by_model_id') def test_call_llm_for_system_prompt_exception( diff --git a/test/sdk/core/agents/test_core_agent.py b/test/sdk/core/agents/test_core_agent.py index cb6240893..54b725620 100644 --- a/test/sdk/core/agents/test_core_agent.py +++ b/test/sdk/core/agents/test_core_agent.py @@ -1,3 +1,5 @@ +import json + import pytest from unittest.mock import MagicMock, patch from threading import Event @@ -14,22 +16,98 @@ def __init__(self, message): super().__init__(message) +class MockAgentMaxStepsError(Exception): + pass + + # Mock for smolagents and its sub-modules mock_smolagents = MagicMock() -mock_smolagents.ActionStep = MagicMock() -mock_smolagents.TaskStep = MagicMock() -mock_smolagents.SystemPromptStep = MagicMock() mock_smolagents.AgentError = MockAgentError mock_smolagents.handle_agent_output_types = MagicMock( return_value="handled_output") +mock_smolagents.utils.AgentMaxStepsError = MockAgentMaxStepsError + +# Create proper class types for isinstance checks (not MagicMock) +class MockActionStep: + def __init__(self, *args, **kwargs): + self.step_number = kwargs.get('step_number', 1) + self.timing = kwargs.get('timing', None) + self.observations_images = kwargs.get('observations_images', None) + self.model_input_messages = None + self.model_output_message = None + self.model_output = None + self.token_usage = None + self.code_action = None + self.tool_calls = None + self.observations = None + self.action_output = None + self.is_final_answer = False + self.error = None + +class MockTaskStep: + def __init__(self, *args, **kwargs): + self.task = kwargs.get('task', '') + self.task_images = kwargs.get('task_images', None) + +class MockSystemPromptStep: + def __init__(self, *args, **kwargs): + self.system_prompt = kwargs.get('system_prompt', '') + +class MockFinalAnswerStep: + def __init__(self, *args, **kwargs): + # Handle both positional and keyword arguments + if args: + self.output = args[0] + else: + self.output = kwargs.get('output', '') + +class MockPlanningStep: + def __init__(self, *args, **kwargs): + self.token_usage = kwargs.get('token_usage', None) + +class MockActionOutput: + def __init__(self, *args, **kwargs): + self.output = kwargs.get('output', None) + self.is_final_answer = kwargs.get('is_final_answer', False) + +class MockRunResult: + def __init__(self, *args, **kwargs): + self.output = kwargs.get('output', None) + self.token_usage = kwargs.get('token_usage', None) + self.steps = kwargs.get('steps', []) + self.timing = kwargs.get('timing', None) + self.state = kwargs.get('state', 'success') + +class MockCodeOutput: + """Mock object returned by python_executor.""" + def __init__(self, output=None, logs="", is_final_answer=False): + self.output = output + self.logs = logs + self.is_final_answer = is_final_answer + +# Assign proper classes to mock_smolagents +mock_smolagents.ActionStep = MockActionStep +mock_smolagents.TaskStep = MockTaskStep +mock_smolagents.SystemPromptStep = MockSystemPromptStep # Create dummy smolagents sub-modules for sub_mod in ["agents", "memory", "models", "monitoring", "utils", "local_python_executor"]: mock_module = MagicMock() setattr(mock_smolagents, sub_mod, mock_module) +# Assign classes to memory submodule +mock_smolagents.memory.ActionStep = MockActionStep +mock_smolagents.memory.TaskStep = MockTaskStep +mock_smolagents.memory.SystemPromptStep = MockSystemPromptStep +mock_smolagents.memory.FinalAnswerStep = MockFinalAnswerStep +mock_smolagents.memory.PlanningStep = MockPlanningStep +mock_smolagents.memory.ToolCall = MagicMock + +# Assign classes to agents submodule mock_smolagents.agents.CodeAgent = MagicMock +mock_smolagents.agents.ActionOutput = MockActionOutput +mock_smolagents.agents.RunResult = MockRunResult # Provide actual implementations for commonly used utils functions @@ -72,6 +150,23 @@ def mock_truncate_content(content, max_length=1000): core_agent_module = sys.modules['sdk.nexent.core.agents.core_agent'] # Override AgentError inside the imported module to ensure it has message attr core_agent_module.AgentError = MockAgentError + core_agent_module.AgentMaxStepsError = MockAgentMaxStepsError + # Override classes to use our mock classes for isinstance checks + core_agent_module.FinalAnswerStep = MockFinalAnswerStep + core_agent_module.ActionStep = MockActionStep + core_agent_module.PlanningStep = MockPlanningStep + core_agent_module.ActionOutput = MockActionOutput + core_agent_module.RunResult = MockRunResult + # Override CodeAgent to be a proper class that can be inherited + class MockCodeAgent: + def __init__(self, prompt_templates=None, *args, **kwargs): + # Accept any arguments but don't require observer + # Store attributes that might be accessed + self.prompt_templates = prompt_templates + # Initialize common attributes that CodeAgent might have + for key, value in kwargs.items(): + setattr(self, key, value) + core_agent_module.CodeAgent = MockCodeAgent CoreAgent = ImportedCoreAgent @@ -103,16 +198,50 @@ def core_agent_instance(mock_observer): agent.stop_event = Event() agent.memory = MagicMock() agent.memory.steps = [] + agent.memory.get_full_steps = MagicMock(return_value=[]) agent.python_executor = MagicMock() + + # Mock logger with all required methods + agent.logger = MagicMock() + agent.logger.log = MagicMock() + agent.logger.log_task = MagicMock() + agent.logger.log_markdown = MagicMock() + agent.logger.log_code = MagicMock() agent.step_number = 1 agent._execute_step = MagicMock() agent._finalize_step = MagicMock() agent._handle_max_steps_reached = MagicMock() + + # Set default attributes that might be needed + agent.max_steps = 5 + agent.state = {} + agent.system_prompt = "test system prompt" + agent.return_full_result = False + agent.provide_run_summary = False + agent.tools = {} + agent.managed_agents = {} + agent.monitor = MagicMock() + agent.monitor.reset = MagicMock() + agent.model = MagicMock() + if hasattr(agent.model, 'model_id'): + agent.model.model_id = "test-model" + agent.code_block_tags = ["```", "```"] + agent._use_structured_outputs_internally = False + agent.final_answer_checks = None # Set to avoid MagicMock creating new CoreAgent instances return agent +@pytest.fixture(autouse=True) +def reset_token_usage_mock(): + """Ensure TokenUsage mock does not leak state between tests.""" + token_usage = getattr(core_agent_module, "TokenUsage", None) + if hasattr(token_usage, "reset_mock"): + token_usage.reset_mock() + yield + + # ---------------------------------------------------------------------------- # Tests for _run method # ---------------------------------------------------------------------------- @@ -123,11 +252,12 @@ def test_run_normal_execution(core_agent_instance): task = "test task" max_steps = 3 - # Mock _execute_step to return a generator that yields final answer - def mock_execute_generator(action_step): - yield "final_answer" + # Mock _step_stream to return a generator that yields ActionOutput with final answer + def mock_step_stream(action_step): + action_output = MockActionOutput(output="final_answer", is_final_answer=True) + yield action_output - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_generator) as mock_execute_step, \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream) as mock_step_stream_patch, \ patch.object(core_agent_instance, '_finalize_step') as mock_finalize_step: core_agent_instance.step_number = 1 @@ -135,11 +265,11 @@ def mock_execute_generator(action_step): result = list(core_agent_instance._run_stream(task, max_steps)) # Assertions - # _run_stream yields: generator output + action step + final answer step + # _run_stream yields: ActionOutput from _step_stream + action step + final answer step assert len(result) == 3 - assert result[0] == "final_answer" # Generator output - assert isinstance(result[1], MagicMock) # Action step - assert isinstance(result[2], MagicMock) # Final answer step + assert isinstance(result[0], MockActionOutput) # ActionOutput from _step_stream + assert isinstance(result[1], MockActionStep) # Action step + assert isinstance(result[2], MockFinalAnswerStep) # Final answer step def test_run_with_max_steps_reached(core_agent_instance): @@ -148,11 +278,12 @@ def test_run_with_max_steps_reached(core_agent_instance): task = "test task" max_steps = 2 - # Mock _execute_step to return None (no final answer) - def mock_execute_generator(action_step): - yield None + # Mock _step_stream to return ActionOutput without final answer + def mock_step_stream(action_step): + action_output = MockActionOutput(output=None, is_final_answer=False) + yield action_output - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_generator) as mock_execute_step, \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream) as mock_step_stream_patch, \ patch.object(core_agent_instance, '_finalize_step') as mock_finalize_step, \ patch.object(core_agent_instance, '_handle_max_steps_reached', return_value="max_steps_reached") as mock_handle_max: @@ -162,18 +293,19 @@ def mock_execute_generator(action_step): result = list(core_agent_instance._run_stream(task, max_steps)) # Assertions - # For 2 steps: (None + action_step) * 2 + final_action_step + final_answer_step = 6 - assert len(result) == 6 - assert result[0] is None # First generator output - assert isinstance(result[1], MagicMock) # First action step - assert result[2] is None # Second generator output - assert isinstance(result[3], MagicMock) # Second action step - # Final action step (from _handle_max_steps_reached) - assert isinstance(result[4], MagicMock) - assert isinstance(result[5], MagicMock) # Final answer step + # For 2 steps: (ActionOutput + action_step) * 2 + final_action_step + final_answer_step = 6 + assert len(result) >= 5 + # First step: ActionOutput + ActionStep + assert isinstance(result[0], MockActionOutput) # First ActionOutput + assert isinstance(result[1], MockActionStep) # First action step + # Second step: ActionOutput + ActionStep + assert isinstance(result[2], MockActionOutput) # Second ActionOutput + assert isinstance(result[3], MockActionStep) # Second action step + # Last should be final answer step + assert isinstance(result[-1], MockFinalAnswerStep) # Final answer step # Verify method calls - assert mock_execute_step.call_count == 2 + assert mock_step_stream_patch.call_count == 2 mock_handle_max.assert_called_once() assert mock_finalize_step.call_count == 2 @@ -184,23 +316,28 @@ def test_run_with_stop_event(core_agent_instance): task = "test task" max_steps = 3 - def mock_execute_generator(action_step): + def mock_step_stream(action_step): core_agent_instance.stop_event.set() - yield None + action_output = MockActionOutput(output=None, is_final_answer=False) + yield action_output + + # Mock handle_agent_output_types to return the input value (identity function) + # This way when final_answer = "", it will be passed through + with patch.object(core_agent_module, 'handle_agent_output_types', side_effect=lambda x: x): + # Mock _step_stream to set stop event + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream): + with patch.object(core_agent_instance, '_finalize_step'): + # Execute + result = list(core_agent_instance._run_stream(task, max_steps)) - # Mock _execute_step to set stop event - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_generator): - with patch.object(core_agent_instance, '_finalize_step'): - # Execute - result = list(core_agent_instance._run_stream(task, max_steps)) - - # Assertions - # Should yield: generator output + action step + final answer step - assert len(result) == 3 - assert result[0] is None # Generator output - assert isinstance(result[1], MagicMock) # Action step - # Final answer step with "" - assert isinstance(result[2], MagicMock) + # Assertions + # Should yield: ActionOutput from _step_stream + action step + final answer step + assert len(result) == 3 + assert isinstance(result[0], MockActionOutput) # ActionOutput from _step_stream + assert isinstance(result[1], MockActionStep) # Action step + # Final answer step with "" + assert isinstance(result[2], MockFinalAnswerStep) + assert result[2].output == "" def test_run_with_final_answer_error(core_agent_instance): @@ -209,9 +346,9 @@ def test_run_with_final_answer_error(core_agent_instance): task = "test task" max_steps = 3 - # Mock _execute_step to raise FinalAnswerError - with patch.object(core_agent_instance, '_execute_step', - side_effect=core_agent_module.FinalAnswerError()) as mock_execute_step, \ + # Mock _step_stream to raise FinalAnswerError + with patch.object(core_agent_instance, '_step_stream', + side_effect=core_agent_module.FinalAnswerError()) as mock_step_stream, \ patch.object(core_agent_instance, '_finalize_step'): # Execute result = list(core_agent_instance._run_stream(task, max_steps)) @@ -219,8 +356,8 @@ def test_run_with_final_answer_error(core_agent_instance): # Assertions # When FinalAnswerError occurs, it should yield action step + final answer step assert len(result) == 2 - assert isinstance(result[0], MagicMock) # Action step - assert isinstance(result[1], MagicMock) # Final answer step + assert isinstance(result[0], MockActionStep) # Action step + assert isinstance(result[1], MockFinalAnswerStep) # Final answer step def test_run_with_final_answer_error_and_model_output(core_agent_instance): @@ -229,16 +366,12 @@ def test_run_with_final_answer_error_and_model_output(core_agent_instance): task = "test task" max_steps = 3 - # Create a mock action step with model_output - mock_action_step = MagicMock() - mock_action_step.model_output = "```\nprint('hello')\n```" - - # Mock _execute_step to set model_output and then raise FinalAnswerError - def mock_execute_step(action_step): + # Mock _step_stream to set model_output and then raise FinalAnswerError + def mock_step_stream(action_step): action_step.model_output = "```\nprint('hello')\n```" raise core_agent_module.FinalAnswerError() - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_step), \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream), \ patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ patch.object(core_agent_instance, '_finalize_step'): # Execute @@ -246,8 +379,8 @@ def mock_execute_step(action_step): # Assertions assert len(result) == 2 - assert isinstance(result[0], MagicMock) # Action step - assert isinstance(result[1], MagicMock) # Final answer step + assert isinstance(result[0], MockActionStep) # Action step + assert isinstance(result[1], MockFinalAnswerStep) # Final answer step # Verify convert_code_format was called mock_convert.assert_called_once_with( "```\nprint('hello')\n```") @@ -259,9 +392,9 @@ def test_run_with_agent_error_updated(core_agent_instance): task = "test task" max_steps = 3 - # Mock _execute_step to raise AgentError - with patch.object(core_agent_instance, '_execute_step', - side_effect=MockAgentError("test error")) as mock_execute_step, \ + # Mock _step_stream to raise AgentError + with patch.object(core_agent_instance, '_step_stream', + side_effect=MockAgentError("test error")) as mock_step_stream, \ patch.object(core_agent_instance, '_finalize_step'): # Execute result = list(core_agent_instance._run_stream(task, max_steps)) @@ -270,9 +403,9 @@ def test_run_with_agent_error_updated(core_agent_instance): # When AgentError occurs, it should yield action step + final answer step # But the error causes the loop to continue, so we get multiple action steps assert len(result) >= 2 - assert isinstance(result[0], MagicMock) # Action step with error + assert isinstance(result[0], MockActionStep) # Action step with error # Last item should be final answer step - assert isinstance(result[-1], MagicMock) # Final answer step + assert isinstance(result[-1], MockFinalAnswerStep) # Final answer step def test_run_with_agent_parse_error_branch_updated(core_agent_instance): @@ -280,25 +413,40 @@ def test_run_with_agent_parse_error_branch_updated(core_agent_instance): task = "parse task" max_steps = 1 - # Mock _execute_step to set model_output and then raise FinalAnswerError - def mock_execute_step(action_step): + # Mock _step_stream to set model_output and then raise FinalAnswerError + def mock_step_stream(action_step): action_step.model_output = "```\nprint('hello')\n```" raise core_agent_module.FinalAnswerError() - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_step), \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream), \ patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ patch.object(core_agent_instance, '_finalize_step'): results = list(core_agent_instance._run_stream(task, max_steps)) # _run should yield action step + final answer step assert len(results) == 2 - assert isinstance(results[0], MagicMock) # Action step - assert isinstance(results[1], MagicMock) # Final answer step + assert isinstance(results[0], MockActionStep) # Action step + assert isinstance(results[1], MockFinalAnswerStep) # Final answer step # Verify convert_code_format was called mock_convert.assert_called_once_with( "```\nprint('hello')\n```") +def test_run_stream_validates_final_answer_when_checks_enabled(core_agent_instance): + """Ensure _run_stream triggers final answer validation when checks are configured.""" + task = "validate task" + core_agent_instance.final_answer_checks = ["non-empty"] + core_agent_instance._validate_final_answer = MagicMock() + + def mock_step_stream(action_step): + yield MockActionOutput(output="final answer", is_final_answer=True) + + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream), \ + patch.object(core_agent_instance, '_finalize_step'): + result = list(core_agent_instance._run_stream(task, max_steps=1)) + + assert len(result) == 3 # ActionOutput, ActionStep, FinalAnswerStep + core_agent_instance._validate_final_answer.assert_called_once_with("final answer") def test_convert_code_format_display_replacements(): """Validate convert_code_format correctly transforms format to standard markdown.""" @@ -575,6 +723,10 @@ def test_step_stream_parse_success(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -586,7 +738,7 @@ def test_step_stream_parse_success(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("output", "logs", False)) + return_value=MockCodeOutput(output="output", logs="logs", is_final_answer=False)) # Execute list(core_agent_instance._step_stream(mock_memory_step)) @@ -599,6 +751,33 @@ def test_step_stream_parse_success(core_agent_instance): assert hasattr(mock_memory_step.tool_calls[0], 'arguments') +def test_step_stream_structured_outputs_with_stop_sequence(core_agent_instance): + """Ensure _step_stream handles structured outputs correctly.""" + mock_memory_step = MagicMock() + mock_chat_message = MagicMock() + mock_chat_message.content = json.dumps({"code": "print('hello')"}) + mock_chat_message.token_usage = MagicMock() + + core_agent_instance.agent_name = "test_agent" + core_agent_instance.step_number = 1 + core_agent_instance._use_structured_outputs_internally = True + core_agent_instance.code_block_tags = ["<>", "[CLOSE]"] + core_agent_instance.write_memory_to_messages = MagicMock(return_value=[]) + core_agent_instance.model = MagicMock(return_value=mock_chat_message) + core_agent_instance.python_executor = MagicMock( + return_value=MockCodeOutput(output="result", logs="", is_final_answer=False) + ) + + with patch.object(core_agent_module, 'extract_code_from_text', return_value="print('hello')") as mock_extract, \ + patch.object(core_agent_module, 'fix_final_answer_code', side_effect=lambda code: code): + list(core_agent_instance._step_stream(mock_memory_step)) + + # Ensure structured output helpers were used + mock_extract.assert_called_once_with("print('hello')", core_agent_instance.code_block_tags) + call_kwargs = core_agent_instance.model.call_args.kwargs + assert call_kwargs["response_format"] == core_agent_module.CODEAGENT_RESPONSE_FORMAT + + def test_step_stream_skips_execution_for_display_only(core_agent_instance): """Test that _step_stream raises FinalAnswerError when only DISPLAY code blocks are present.""" # Setup @@ -611,6 +790,10 @@ def test_step_stream_skips_execution_for_display_only(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -637,6 +820,10 @@ def test_step_stream_parse_failure_raises_final_answer_error(core_agent_instance core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -662,6 +849,10 @@ def test_step_stream_model_generation_error(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -687,6 +878,10 @@ def test_step_stream_execution_success(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -698,14 +893,16 @@ def test_step_stream_execution_success(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("Hello World", "Execution logs", False)) + return_value=MockCodeOutput(output="Hello World", logs="Execution logs", is_final_answer=False)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - # Should yield None when is_final_answer is False - assert result[0] is None + # Should yield ActionOutput when is_final_answer is False + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is False assert mock_memory_step.observations is not None # Check that observations was set (we can't easily test the exact content due to mock behavior) assert hasattr(mock_memory_step, 'observations') @@ -723,6 +920,10 @@ def test_step_stream_execution_final_answer(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -734,13 +935,16 @@ def test_step_stream_execution_final_answer(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("final answer", "Execution logs", True)) + return_value=MockCodeOutput(output="final answer", logs="Execution logs", is_final_answer=True)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - assert result[0] == "final answer" # Should yield the final answer + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is True + assert result[0].output == "final answer" def test_step_stream_execution_error(core_agent_instance): @@ -755,6 +959,10 @@ def test_step_stream_execution_error(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -795,6 +1003,10 @@ def test_step_stream_observer_calls(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -806,7 +1018,7 @@ def test_step_stream_observer_calls(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("test", "logs", False)) + return_value=MockCodeOutput(output="test", logs="logs", is_final_answer=False)) # Execute list(core_agent_instance._step_stream(mock_memory_step)) @@ -847,6 +1059,10 @@ def test_step_stream_execution_with_logs(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -859,14 +1075,16 @@ def test_step_stream_execution_with_logs(core_agent_instance): core_agent_instance.model = MagicMock(return_value=mock_chat_message) # Mock python_executor to return logs core_agent_instance.python_executor = MagicMock( - return_value=("output", "Some execution logs", False)) + return_value=MockCodeOutput(output="output", logs="Some execution logs", is_final_answer=False)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - # Should yield None when is_final_answer is False - assert result[0] is None + # Should yield ActionOutput when is_final_answer is False + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is False # Check that execution logs were recorded assert core_agent_instance.observer.add_message.call_count >= 3 calls = core_agent_instance.observer.add_message.call_args_list @@ -887,6 +1105,10 @@ def test_step_stream_execution_error_with_print_outputs(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -926,6 +1148,10 @@ def test_step_stream_execution_error_with_import_warning(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -969,6 +1195,10 @@ def test_step_stream_execution_error_without_print_outputs(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -1003,6 +1233,10 @@ def test_step_stream_execution_with_none_output(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -1015,14 +1249,16 @@ def test_step_stream_execution_with_none_output(core_agent_instance): core_agent_instance.model = MagicMock(return_value=mock_chat_message) # Mock python_executor to return None output core_agent_instance.python_executor = MagicMock( - return_value=(None, "Execution logs", False)) + return_value=MockCodeOutput(output=None, logs="Execution logs", is_final_answer=False)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - # Should yield None when is_final_answer is False - assert result[0] is None + # Should yield ActionOutput when is_final_answer is False + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is False assert mock_memory_step.observations is not None # Check that observations was set but should not contain "Last output from code snippet" # since output is None @@ -1050,6 +1286,10 @@ def test_run_with_additional_args(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1059,8 +1299,7 @@ def test_run_with_additional_args(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1089,6 +1328,10 @@ def test_run_with_stream_true(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1123,6 +1366,10 @@ def test_run_with_reset_false(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1132,8 +1379,7 @@ def test_run_with_reset_false(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1162,6 +1408,10 @@ def test_run_with_images(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1171,8 +1421,7 @@ def test_run_with_images(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1185,8 +1434,89 @@ def test_run_with_images(core_agent_instance): call_args = core_agent_instance.memory.steps.append.call_args[0][0] # The TaskStep is mocked, so just verify it was called with correct arguments via the constructor # We'll check that TaskStep was called with the right parameters - mock_smolagents.memory.TaskStep.assert_called_with( - task=task, task_images=images) + assert isinstance(call_args, MockTaskStep) + assert call_args.task == task + assert call_args.task_images == images + + +def test_run_return_full_result_success_state(core_agent_instance): + """run should return RunResult with aggregated token usage when requested.""" + task = "test task" + token_usage = MagicMock(input_tokens=7, output_tokens=3) + action_step = core_agent_module.ActionStep() + action_step.token_usage = token_usage + + core_agent_instance.name = "test_agent" + core_agent_instance.memory.steps = [action_step] + core_agent_instance.memory.get_full_steps = MagicMock(return_value=[{"step": "data"}]) + core_agent_instance.memory.reset = MagicMock() + core_agent_instance.monitor.reset = MagicMock() + core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() + core_agent_instance.model = MagicMock() + core_agent_instance.model.model_id = "model" + core_agent_instance.python_executor = MagicMock() + core_agent_instance.python_executor.send_variables = MagicMock() + core_agent_instance.python_executor.send_tools = MagicMock() + core_agent_instance.observer = MagicMock() + + final_step = MockFinalAnswerStep(output="final result") + with patch.object(core_agent_instance, '_run_stream', return_value=[final_step]): + result = core_agent_instance.run(task, return_full_result=True) + + assert isinstance(result, core_agent_module.RunResult) + assert result.output == "final result" + core_agent_module.TokenUsage.assert_called_once_with(input_tokens=7, output_tokens=3) + assert result.token_usage == core_agent_module.TokenUsage.return_value + assert result.state == "success" + core_agent_instance.memory.get_full_steps.assert_called_once() + + +def test_run_return_full_result_max_steps_error(core_agent_instance): + """run should mark state as max_steps_error when the last step contains AgentMaxStepsError.""" + task = "test task" + + action_step = core_agent_module.ActionStep() + action_step.token_usage = None + action_step.error = core_agent_module.AgentMaxStepsError("max steps reached") + + class StepsList(list): + def append(self, item): + # Skip storing TaskStep to keep action_step as the last element + if isinstance(item, core_agent_module.TaskStep): + return + super().append(item) + + core_agent_instance.name = "test_agent" + steps_list = StepsList([action_step]) + core_agent_instance.memory.steps = steps_list + core_agent_instance.memory.get_full_steps = MagicMock(return_value=[{"step": "data"}]) + core_agent_instance.memory.reset = MagicMock() + core_agent_instance.monitor.reset = MagicMock() + core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() + core_agent_instance.model = MagicMock() + core_agent_instance.model.model_id = "model" + core_agent_instance.python_executor = MagicMock() + core_agent_instance.python_executor.send_variables = MagicMock() + core_agent_instance.python_executor.send_tools = MagicMock() + core_agent_instance.observer = MagicMock() + + final_step = MockFinalAnswerStep(output="final result") + with patch.object(core_agent_instance, '_run_stream', return_value=[final_step]): + result = core_agent_instance.run(task, return_full_result=True) + + assert isinstance(result, core_agent_module.RunResult) + assert result.token_usage is None + core_agent_module.TokenUsage.assert_not_called() + assert result.state == "max_steps_error" + core_agent_instance.memory.get_full_steps.assert_called_once() def test_run_without_python_executor(core_agent_instance): @@ -1204,6 +1534,10 @@ def test_run_without_python_executor(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1213,8 +1547,7 @@ def test_run_without_python_executor(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1267,6 +1600,31 @@ def test_call_method_success(core_agent_instance): "test_agent", ProcessType.AGENT_FINISH, "test result") +def test_call_method_with_run_result_return(core_agent_instance): + """Test __call__ handles RunResult by extracting its output.""" + task = "test task" + core_agent_instance.name = "test_agent" + core_agent_instance.state = {} + core_agent_instance.prompt_templates = { + "managed_agent": { + "task": "Task: {{task}}", + "report": "Report: {{final_answer}}" + } + } + core_agent_instance.provide_run_summary = False + core_agent_instance.observer = MagicMock() + + run_result = core_agent_module.RunResult(output="run result", token_usage=None, steps=[], timing=None, state="success") + with patch.object(core_agent_instance, 'run', return_value=run_result) as mock_run: + result = core_agent_instance(task) + + assert "Report: run result" in result + mock_run.assert_called_once() + core_agent_instance.observer.add_message.assert_called_with( + "test_agent", ProcessType.AGENT_FINISH, "run result" + ) + + def test_call_method_with_run_summary(core_agent_instance): """Test __call__ method with provide_run_summary=True.""" # Setup @@ -1284,10 +1642,14 @@ def test_call_method_with_run_summary(core_agent_instance): core_agent_instance.provide_run_summary = True core_agent_instance.observer = MagicMock() - # Mock write_memory_to_messages to return some simple messages + # Mock write_memory_to_messages to return some simple messages with .content attribute + class MockMessage: + def __init__(self, content): + self.content = content + mock_messages = [ - {"content": "msg1"}, - {"content": "msg2"} + MockMessage("msg1"), + MockMessage("msg2") ] core_agent_instance.write_memory_to_messages = MagicMock( return_value=mock_messages) diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 3dc831323..2a842ea72 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -27,11 +27,16 @@ class _ActionStep: - pass + def __init__(self, step_number=None, timing=None, action_output=None, model_output=None): + self.step_number = step_number + self.timing = timing + self.action_output = action_output + self.model_output = model_output class _TaskStep: - pass + def __init__(self, task=None): + self.task = task class _AgentText: @@ -214,6 +219,8 @@ class _MockToolSign: "nexent.storage": mock_nexent_storage_module, "nexent.multi_modal": mock_nexent_multi_modal_module, "nexent.multi_modal.load_save_object": mock_nexent_load_save_module, + # Mock tiktoken to avoid importing the real package when models import it + "tiktoken": MagicMock(), # Mock the OpenAIModel import "sdk.nexent.core.models.openai_llm": MagicMock(OpenAIModel=mock_openai_model_class), # Mock CoreAgent import @@ -230,7 +237,7 @@ class _MockToolSign: from sdk.nexent.core.utils.observer import MessageObserver, ProcessType from sdk.nexent.core.agents import nexent_agent from sdk.nexent.core.agents.nexent_agent import NexentAgent, ActionStep, TaskStep - from sdk.nexent.core.agents.agent_model import ToolConfig, ModelConfig, AgentConfig + from sdk.nexent.core.agents.agent_model import ToolConfig, ModelConfig, AgentConfig, AgentHistory # ---------------------------------------------------------------------------- @@ -1087,6 +1094,48 @@ def test_add_history_to_agent_none_history(nexent_agent_instance, mock_core_agen assert len(mock_core_agent.memory.steps) == 0 +def test_add_history_to_agent_user_and_assistant_history(nexent_agent_instance, mock_core_agent): + """Test add_history_to_agent correctly converts user and assistant messages to memory steps.""" + nexent_agent_instance.agent = mock_core_agent + + user_msg = AgentHistory(role="user", content="User question") + assistant_msg = AgentHistory(role="assistant", content="Assistant reply") + + nexent_agent_instance.add_history_to_agent([user_msg, assistant_msg]) + + mock_core_agent.memory.reset.assert_called_once() + assert len(mock_core_agent.memory.steps) == 2 + + # First step should be a TaskStep for the user message + first_step = mock_core_agent.memory.steps[0] + assert isinstance(first_step, TaskStep) + assert first_step.task == "User question" + + # Second step should be an ActionStep for the assistant message + second_step = mock_core_agent.memory.steps[1] + assert isinstance(second_step, ActionStep) + assert second_step.action_output == "Assistant reply" + assert second_step.model_output == "Assistant reply" + + +def test_add_history_to_agent_invalid_agent_type(nexent_agent_instance): + """Test add_history_to_agent raises TypeError when agent is not a CoreAgent.""" + nexent_agent_instance.agent = "not_core_agent" + + with pytest.raises(TypeError, match="agent must be a CoreAgent object"): + nexent_agent_instance.add_history_to_agent([]) + + +def test_add_history_to_agent_invalid_history_items(nexent_agent_instance, mock_core_agent): + """Test add_history_to_agent raises TypeError when history items are not AgentHistory.""" + nexent_agent_instance.agent = mock_core_agent + + invalid_history = [{"role": "user", "content": "hello"}] + + with pytest.raises(TypeError, match="history must be a list of AgentHistory objects"): + nexent_agent_instance.add_history_to_agent(invalid_history) + + def test_agent_run_with_observer_success_with_agent_text(nexent_agent_instance, mock_core_agent): """Test successful agent_run_with_observer with AgentText final answer.""" # Setup @@ -1103,7 +1152,7 @@ def test_agent_run_with_observer_success_with_agent_text(nexent_agent_instance, "Final answer with thinking content") mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = mock_final_answer + mock_core_agent.run.return_value[-1].output = mock_final_answer # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1129,7 +1178,7 @@ def test_agent_run_with_observer_success_with_string_final_answer(nexent_agent_i mock_action_step.error = None mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "String final answer with thinking" + mock_core_agent.run.return_value[-1].output = "String final answer with thinking" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1153,7 +1202,7 @@ def test_agent_run_with_observer_with_error_in_step(nexent_agent_instance, mock_ mock_action_step.error = "Test error occurred" mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1176,7 +1225,7 @@ def test_agent_run_with_observer_skips_non_action_step(nexent_agent_instance, mo mock_action_step.error = None mock_core_agent.run.return_value = [mock_task_step, mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1199,7 +1248,7 @@ def test_agent_run_with_observer_with_stop_event_set(nexent_agent_instance, mock mock_action_step.error = None mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1226,6 +1275,14 @@ def test_agent_run_with_observer_with_exception(nexent_agent_instance, mock_core ) +def test_agent_run_with_observer_invalid_agent_type(nexent_agent_instance): + """Test agent_run_with_observer raises TypeError when agent is not a CoreAgent.""" + nexent_agent_instance.agent = "not_core_agent" + + with pytest.raises(TypeError, match="agent must be a CoreAgent object"): + nexent_agent_instance.agent_run_with_observer("test query") + + def test_agent_run_with_observer_with_reset_false(nexent_agent_instance, mock_core_agent): """Test agent_run_with_observer with reset=False parameter.""" # Setup @@ -1238,7 +1295,7 @@ def test_agent_run_with_observer_with_reset_false(nexent_agent_instance, mock_co mock_action_step.error = None mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute with reset=False nexent_agent_instance.agent_run_with_observer("test query", reset=False) diff --git a/test/sdk/core/agents/test_run_agent.py b/test/sdk/core/agents/test_run_agent.py index 0cafdd8a1..b47aec879 100644 --- a/test/sdk/core/agents/test_run_agent.py +++ b/test/sdk/core/agents/test_run_agent.py @@ -49,7 +49,7 @@ def from_mcp(cls, *args, **kwargs): # pylint: disable=unused-argument sub_mod = ModuleType(f"smolagents.{_sub}") # Populate required attributes with MagicMocks to satisfy import-time `from smolagents. import ...`. if _sub == "agents": - for _name in ["CodeAgent", "populate_template", "handle_agent_output_types", "AgentError", "AgentType"]: + for _name in ["CodeAgent", "populate_template", "handle_agent_output_types", "AgentError", "AgentType", "ActionOutput", "RunResult"]: setattr(sub_mod, _name, MagicMock(name=f"smolagents.agents.{_name}")) elif _sub == "local_python_executor": setattr(sub_mod, "fix_final_answer_code", MagicMock(name="fix_final_answer_code")) @@ -59,6 +59,7 @@ def from_mcp(cls, *args, **kwargs): # pylint: disable=unused-argument elif _sub == "models": setattr(sub_mod, "ChatMessage", MagicMock(name="smolagents.models.ChatMessage")) setattr(sub_mod, "MessageRole", MagicMock(name="smolagents.models.MessageRole")) + setattr(sub_mod, "CODEAGENT_RESPONSE_FORMAT", MagicMock(name="smolagents.models.CODEAGENT_RESPONSE_FORMAT")) # Provide a simple base class so that OpenAIModel can inherit from it class _DummyOpenAIServerModel: def __init__(self, *args, **kwargs): @@ -67,13 +68,18 @@ def __init__(self, *args, **kwargs): setattr(sub_mod, "OpenAIServerModel", _DummyOpenAIServerModel) elif _sub == "monitoring": setattr(sub_mod, "LogLevel", MagicMock(name="smolagents.monitoring.LogLevel")) + setattr(sub_mod, "Timing", MagicMock(name="smolagents.monitoring.Timing")) + setattr(sub_mod, "YELLOW_HEX", MagicMock(name="smolagents.monitoring.YELLOW_HEX")) + setattr(sub_mod, "TokenUsage", MagicMock(name="smolagents.monitoring.TokenUsage")) elif _sub == "utils": for _name in [ "AgentExecutionError", "AgentGenerationError", "AgentParsingError", + "AgentMaxStepsError", "parse_code_blobs", "truncate_content", + "extract_code_from_text", ]: setattr(sub_mod, _name, MagicMock(name=f"smolagents.utils.{_name}")) setattr(mock_smolagents, _sub, sub_mod) @@ -82,6 +88,8 @@ def __init__(self, *args, **kwargs): # Top-level exports expected directly from `smolagents` by nexent_agent.py for _name in ["ActionStep", "TaskStep", "AgentText", "handle_agent_output_types"]: setattr(mock_smolagents, _name, MagicMock(name=f"smolagents.{_name}")) +# Export Timing from monitoring submodule to top-level +setattr(mock_smolagents, "Timing", mock_smolagents.monitoring.Timing) # Also export Tool at top-level so that `from smolagents import Tool` works setattr(mock_smolagents, "Tool", mock_smolagents_tool_cls) @@ -237,9 +245,9 @@ def test_agent_run_thread_local_flow(basic_agent_run_info, monkeypatch): def test_agent_run_thread_mcp_flow(basic_agent_run_info, mock_memory_context, monkeypatch): - """Verify behaviour when an MCP host list is provided.""" - # Give the AgentRunInfo an MCP host list - basic_agent_run_info.mcp_host = ["http://mcp.server"] + """Verify behaviour when an MCP host list is provided with auto-detected transport.""" + # Give the AgentRunInfo an MCP host list (string format, auto-detect transport) + basic_agent_run_info.mcp_host = ["http://mcp.server/mcp"] # Prepare ToolCollection.from_mcp to return a context manager mock_tool_collection = MagicMock(name="ToolCollectionInstance") @@ -257,7 +265,7 @@ def test_agent_run_thread_mcp_flow(basic_agent_run_info, mock_memory_context, mo basic_agent_run_info.observer.add_message.assert_any_call("", ProcessType.AGENT_NEW_RUN, "") # ToolCollection.from_mcp should be called with the expected client list and trust_remote_code=True - expected_client_list = [{"url": "http://mcp.server"}] + expected_client_list = [{"url": "http://mcp.server/mcp", "transport": "streamable-http"}] run_agent.ToolCollection.from_mcp.assert_called_once_with(expected_client_list, trust_remote_code=True) # NexentAgent should be instantiated with mcp_tool_collection @@ -275,6 +283,116 @@ def test_agent_run_thread_mcp_flow(basic_agent_run_info, mock_memory_context, mo mock_nexent_instance.agent_run_with_observer.assert_called_once_with(query=basic_agent_run_info.query, reset=False) +def test_agent_run_thread_mcp_flow_with_explicit_transport(basic_agent_run_info, mock_memory_context, monkeypatch): + """Verify behaviour when MCP host is provided with explicit transport in dict format.""" + # Give the AgentRunInfo an MCP host list with explicit transport + basic_agent_run_info.mcp_host = [{"url": "http://mcp.server", "transport": "sse"}] + + # Prepare ToolCollection.from_mcp to return a context manager + mock_tool_collection = MagicMock(name="ToolCollectionInstance") + mock_context_manager = MagicMock(__enter__=MagicMock(return_value=mock_tool_collection), __exit__=MagicMock(return_value=None)) + monkeypatch.setattr(run_agent.ToolCollection, "from_mcp", MagicMock(return_value=mock_context_manager)) + + # Patch NexentAgent + mock_nexent_instance = MagicMock(name="NexentAgentInstance") + monkeypatch.setattr(run_agent, "NexentAgent", MagicMock(return_value=mock_nexent_instance)) + + # Execute + run_agent.agent_run_thread(basic_agent_run_info) + + # ToolCollection.from_mcp should be called with the expected client list + expected_client_list = [{"url": "http://mcp.server", "transport": "sse"}] + run_agent.ToolCollection.from_mcp.assert_called_once_with(expected_client_list, trust_remote_code=True) + + +def test_agent_run_thread_mcp_flow_mixed_formats(basic_agent_run_info, mock_memory_context, monkeypatch): + """Verify behaviour when MCP host list contains both string and dict formats.""" + # Mix of string (auto-detect) and dict (explicit) formats + basic_agent_run_info.mcp_host = [ + "http://mcp1.server/mcp", # Auto-detect: streamable-http + "http://mcp2.server/sse", # Auto-detect: sse + {"url": "http://mcp3.server/mcp", "transport": "streamable-http"}, # Explicit: streamable-http + ] + + # Prepare ToolCollection.from_mcp to return a context manager + mock_tool_collection = MagicMock(name="ToolCollectionInstance") + mock_context_manager = MagicMock(__enter__=MagicMock(return_value=mock_tool_collection), __exit__=MagicMock(return_value=None)) + monkeypatch.setattr(run_agent.ToolCollection, "from_mcp", MagicMock(return_value=mock_context_manager)) + + # Patch NexentAgent + mock_nexent_instance = MagicMock(name="NexentAgentInstance") + monkeypatch.setattr(run_agent, "NexentAgent", MagicMock(return_value=mock_nexent_instance)) + + # Execute + run_agent.agent_run_thread(basic_agent_run_info) + + # ToolCollection.from_mcp should be called with normalized client list + expected_client_list = [ + {"url": "http://mcp1.server/mcp", "transport": "streamable-http"}, + {"url": "http://mcp2.server/sse", "transport": "sse"}, + {"url": "http://mcp3.server/mcp", "transport": "streamable-http"}, + ] + run_agent.ToolCollection.from_mcp.assert_called_once_with(expected_client_list, trust_remote_code=True) + + +def test_detect_transport(): + """Test transport auto-detection logic based on URL ending.""" + # Test URLs ending with /sse + assert run_agent._detect_transport("http://server/sse") == "sse" + assert run_agent._detect_transport("https://api.example.com/sse") == "sse" + assert run_agent._detect_transport("http://localhost:3000/sse") == "sse" + + # Test URLs ending with /mcp + assert run_agent._detect_transport("http://server/mcp") == "streamable-http" + assert run_agent._detect_transport("https://api.example.com/mcp") == "streamable-http" + assert run_agent._detect_transport("http://localhost:3000/mcp") == "streamable-http" + + # Test default fallback (no /sse or /mcp ending) + assert run_agent._detect_transport("http://server") == "streamable-http" + assert run_agent._detect_transport("https://api.example.com") == "streamable-http" + assert run_agent._detect_transport("http://server/other") == "streamable-http" + + +def test_normalize_mcp_config(): + """Test MCP configuration normalization.""" + # Test string format (auto-detect based on URL ending) + result = run_agent._normalize_mcp_config("http://server/mcp") + assert result == {"url": "http://server/mcp", "transport": "streamable-http"} + + result = run_agent._normalize_mcp_config("http://server/sse") + assert result == {"url": "http://server/sse", "transport": "sse"} + + # Test string format without /sse or /mcp ending (defaults to streamable-http) + result = run_agent._normalize_mcp_config("http://server") + assert result == {"url": "http://server", "transport": "streamable-http"} + + # Test dict format with explicit transport + result = run_agent._normalize_mcp_config({"url": "http://server/mcp", "transport": "sse"}) + assert result == {"url": "http://server/mcp", "transport": "sse"} + + # Test dict format without transport (auto-detect) + result = run_agent._normalize_mcp_config({"url": "http://server/sse"}) + assert result == {"url": "http://server/sse", "transport": "sse"} + + result = run_agent._normalize_mcp_config({"url": "http://server/mcp"}) + assert result == {"url": "http://server/mcp", "transport": "streamable-http"} + + # Test invalid dict (missing url) + with pytest.raises(ValueError, match="must contain 'url' key"): + run_agent._normalize_mcp_config({"transport": "sse"}) + + # Test invalid transport type + with pytest.raises(ValueError, match="Invalid transport type"): + run_agent._normalize_mcp_config({"url": "http://server/mcp", "transport": "stdio"}) + + with pytest.raises(ValueError, match="Invalid transport type"): + run_agent._normalize_mcp_config({"url": "http://server/mcp", "transport": "invalid"}) + + # Test invalid type + with pytest.raises(ValueError, match="Invalid MCP host item type"): + run_agent._normalize_mcp_config(123) + + def test_agent_run_thread_handles_internal_exception(basic_agent_run_info, mock_memory_context, monkeypatch): """If an internal error occurs, the observer should be notified and a ValueError propagated.""" # Configure NexentAgent.create_single_agent to raise an exception diff --git a/test/sdk/core/tools/test_datamate_search_tool.py b/test/sdk/core/tools/test_datamate_search_tool.py index cc2742796..ebfdb3bba 100644 --- a/test/sdk/core/tools/test_datamate_search_tool.py +++ b/test/sdk/core/tools/test_datamate_search_tool.py @@ -117,7 +117,7 @@ def test_extract_dataset_id(self, datamate_tool: DataMateSearchTool, path, expec @pytest.mark.parametrize( "dataset_id, file_id, expected", [ - ("ds1", "f1", "127.0.0.1/api/data-management/datasets/ds1/files/f1/download"), + ("ds1", "f1", "http://127.0.0.1:8080/api/data-management/datasets/ds1/files/f1/download"), ("", "f1", ""), ("ds1", "", ""), ], diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index 535af6b35..f6cdc4577 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -37,7 +37,8 @@ def knowledge_base_search_tool(mock_observer, mock_vdb_core, mock_embedding_mode index_names=["test_index1", "test_index2"], observer=mock_observer, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) return tool @@ -50,7 +51,8 @@ def knowledge_base_search_tool_no_observer(mock_vdb_core, mock_embedding_model): index_names=["test_index"], observer=None, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) return tool @@ -78,6 +80,49 @@ def create_mock_search_result(count=3): class TestKnowledgeBaseSearchTool: """Test KnowledgeBaseSearchTool functionality""" + def test_update_name_resolver_supports_empty_mapping(self, knowledge_base_search_tool): + """Ensure update_name_resolver replaces mapping and handles falsy input""" + knowledge_base_search_tool.update_name_resolver({"kb": "index_kb"}) + assert knowledge_base_search_tool.name_resolver == {"kb": "index_kb"} + + knowledge_base_search_tool.update_name_resolver(None) + assert knowledge_base_search_tool.name_resolver == {} + + def test_resolve_names_without_resolver_logs_warning(self, knowledge_base_search_tool, mocker): + """When no resolver is configured, names are returned unchanged and warning is logged""" + warning_mock = mocker.patch("sdk.nexent.core.tools.knowledge_base_search_tool.logger.warning") + + names = knowledge_base_search_tool._resolve_names(["kb1", "kb2"]) + + assert names == ["kb1", "kb2"] + warning_mock.assert_called_once() + + @pytest.mark.parametrize( + "incoming,expected", + [ + (None, []), + ("single_index", ["single_index"]), + (["a", "b"], ["a", "b"]), + ], + ) + def test_normalize_index_names_variants(self, knowledge_base_search_tool_no_observer, incoming, expected): + """_normalize_index_names should normalize None, string, and list inputs""" + assert knowledge_base_search_tool_no_observer._normalize_index_names(incoming) == expected + + def test_forward_with_observer_adds_messages(self, knowledge_base_search_tool): + """forward should send TOOL and CARD messages when observer is present""" + mock_results = create_mock_search_result(1) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + knowledge_base_search_tool.forward("hello world") + + knowledge_base_search_tool.observer.add_message.assert_any_call( + "", ProcessType.TOOL, "Searching the knowledge base..." + ) + knowledge_base_search_tool.observer.add_message.assert_any_call( + "", ProcessType.CARD, json.dumps([{"icon": "search", "text": "hello world"}], ensure_ascii=False) + ) + def test_init_with_custom_values(self, mock_observer, mock_vdb_core, mock_embedding_model): """Test initialization with custom values""" tool = KnowledgeBaseSearchTool( @@ -85,7 +130,8 @@ def test_init_with_custom_values(self, mock_observer, mock_vdb_core, mock_embedd index_names=["index1", "index2", "index3"], observer=mock_observer, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) assert tool.top_k == 10 @@ -101,7 +147,8 @@ def test_init_with_none_index_names(self, mock_vdb_core, mock_embedding_model): index_names=None, observer=None, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) assert tool.index_names == [] diff --git a/test/sdk/vector_database/test_elasticsearch_core.py b/test/sdk/vector_database/test_elasticsearch_core.py index 30f8ff277..f9f878852 100644 --- a/test/sdk/vector_database/test_elasticsearch_core.py +++ b/test/sdk/vector_database/test_elasticsearch_core.py @@ -522,6 +522,47 @@ def test_vectorize_documents_small_batch(elasticsearch_core_instance): mock_embedding_model.get_embeddings.assert_called_once() mock_bulk.assert_called_once() +def test_small_batch_progress_callback_exception(elasticsearch_core_instance, caplog): + """Progress callback errors should be logged without failing the insert.""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1] * 3] + mock_embedding_model.embedding_model_name = "m" + + documents = [{"content": "a"}] + + def bad_progress(_, __): + raise RuntimeError("boom") + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch("time.strftime", lambda *a, **k: "2025-01-15T10:30:00"), \ + patch("time.time", lambda: 1642234567): + mock_bulk.return_value = {"errors": False, "items": []} + result = elasticsearch_core_instance._small_batch_insert( + "idx", documents, "content", mock_embedding_model, progress_callback=bad_progress + ) + + assert result == 1 + assert any("Progress callback failed in small batch" in m for m in caplog.messages) + +def test_small_batch_error_path_logs_and_raises(elasticsearch_core_instance, caplog): + """Small batch should log errors and re-raise when bulk fails.""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1] * 3] + mock_embedding_model.embedding_model_name = "m" + + documents = [{"content": "x"}] + + with patch.object(elasticsearch_core_instance, "client") as mock_client, \ + patch("time.strftime", lambda *a, **k: "2025-01-15T10:30:00"), \ + patch("time.time", lambda: 1642234567): + mock_client.bulk.side_effect = RuntimeError("bulk boom") + with pytest.raises(RuntimeError): + elasticsearch_core_instance._small_batch_insert( + "idx", documents, "content", mock_embedding_model + ) + + assert any("Small batch insert failed: bulk boom" in m for m in caplog.messages) + def test_vectorize_documents_large_batch(elasticsearch_core_instance): """Test indexing a large batch of documents (>= 64).""" @@ -558,6 +599,76 @@ def test_vectorize_documents_large_batch(elasticsearch_core_instance): mock_bulk.assert_called() mock_refresh.assert_called_once_with("test_index") +def test_large_batch_progress_callback_invoked(elasticsearch_core_instance): + """Progress callback should be triggered during embedding phase.""" + mock_embedding_model = MagicMock() + mock_embedding_model.embedding_model_name = "test-model" + mock_embedding_model.get_embeddings.return_value = [[0.1], [0.2]] + + docs = [{"content": "a"}, {"content": "b"}] + progress_calls = [] + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch.object(elasticsearch_core_instance, "_force_refresh_with_retry"): + mock_bulk.return_value = {"errors": False, "items": []} + elasticsearch_core_instance._large_batch_insert( + "idx", docs, batch_size=5, content_field="content", + embedding_model=mock_embedding_model, embedding_batch_size=2, + progress_callback=lambda done, total: progress_calls.append((done, total)) + ) + + assert progress_calls == [(2, 2)] + +def test_large_batch_progress_callback_exception_logged(elasticsearch_core_instance, caplog): + """Embedding progress callback errors should be logged and not stop indexing.""" + mock_embedding_model = MagicMock() + mock_embedding_model.embedding_model_name = "test-model" + mock_embedding_model.get_embeddings.return_value = [[0.1]] + + docs = [{"content": "a"}] + + def bad_progress(_, __): + raise RuntimeError("cb fail") + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch.object(elasticsearch_core_instance, "_force_refresh_with_retry"): + mock_bulk.return_value = {"errors": False, "items": []} + elasticsearch_core_instance._large_batch_insert( + "idx", docs, batch_size=1, content_field="content", + embedding_model=mock_embedding_model, embedding_batch_size=1, + progress_callback=bad_progress + ) + + assert any("Progress callback failed during embedding" in m for m in caplog.messages) + +def test_large_batch_retry_logs_warning(elasticsearch_core_instance, caplog): + """Embedding retries should emit warnings before succeeding.""" + mock_embedding_model = MagicMock() + mock_embedding_model.embedding_model_name = "test-model" + call_counter = {"n": 0} + + def get_embeddings(_): + call_counter["n"] += 1 + if call_counter["n"] < 3: + raise RuntimeError("embed fail") + return [[0.1]] + + mock_embedding_model.get_embeddings.side_effect = get_embeddings + + docs = [{"content": "a"}] + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch.object(elasticsearch_core_instance, "_force_refresh_with_retry"), \ + patch("time.sleep", lambda *a, **k: None): + mock_bulk.return_value = {"errors": False, "items": []} + elasticsearch_core_instance._large_batch_insert( + "idx", docs, batch_size=1, content_field="content", + embedding_model=mock_embedding_model, embedding_batch_size=1, + ) + + assert call_counter["n"] == 3 + assert any("Embedding API error (attempt 1/3)" in m for m in caplog.messages) + def test_delete_documents_success(elasticsearch_core_instance): """Test deleting documents by path_or_url successfully.""" @@ -1134,8 +1245,12 @@ def test_handle_bulk_errors_with_errors(elasticsearch_core_instance): ] } - # Should not raise exception, just log errors - elasticsearch_core_instance._handle_bulk_errors(response) + with pytest.raises(Exception) as exc_info: + elasticsearch_core_instance._handle_bulk_errors(response) + + err_payload = str(exc_info.value) + assert "Bulk indexing failed: Failed to parse mapping" in err_payload + assert "es_bulk_failed" in err_payload def test_handle_bulk_errors_version_conflict(elasticsearch_core_instance): @@ -1158,6 +1273,40 @@ def test_handle_bulk_errors_version_conflict(elasticsearch_core_instance): elasticsearch_core_instance._handle_bulk_errors(response) +def test_handle_bulk_errors_skips_items_without_error(elasticsearch_core_instance): + """Items without error key should be ignored.""" + response = { + "errors": True, + "items": [{"index": {}}], + } + # Should not raise + elasticsearch_core_instance._handle_bulk_errors(response) + + +def test_handle_bulk_errors_dim_mismatch_sets_specific_code(elasticsearch_core_instance): + """Dense vector dimension mismatch should produce es_dim_mismatch code.""" + response = { + "errors": True, + "items": [ + { + "index": { + "error": { + "type": "illegal_argument_exception", + "reason": "field [embedding] has different number of dimensions than vector", + "caused_by": {"reason": "dense_vector different number of dimensions"}, + } + } + } + ], + } + + with pytest.raises(Exception) as exc_info: + elasticsearch_core_instance._handle_bulk_errors(response) + + payload = str(exc_info.value) + assert "es_dim_mismatch" in payload + assert "Bulk indexing failed" in payload + def test_bulk_operation_context(elasticsearch_core_instance): """Test bulk operation context manager.""" with patch.object(elasticsearch_core_instance, '_apply_bulk_settings') as mock_apply, \ diff --git a/test/sdk/vector_database/test_elasticsearch_core_coverage.py b/test/sdk/vector_database/test_elasticsearch_core_coverage.py index f307c9d84..757bbc566 100644 --- a/test/sdk/vector_database/test_elasticsearch_core_coverage.py +++ b/test/sdk/vector_database/test_elasticsearch_core_coverage.py @@ -215,8 +215,9 @@ def test_handle_bulk_errors_with_fatal_error(self, vdb_core): } ] } - vdb_core._handle_bulk_errors(response) - # Should log error but not raise exception + with pytest.raises(Exception) as exc_info: + vdb_core._handle_bulk_errors(response) + assert "Bulk indexing failed" in str(exc_info.value) def test_handle_bulk_errors_with_caused_by(self, vdb_core): """Test _handle_bulk_errors with caused_by information""" @@ -237,8 +238,10 @@ def test_handle_bulk_errors_with_caused_by(self, vdb_core): } ] } - vdb_core._handle_bulk_errors(response) - # Should log both main error and caused_by error + with pytest.raises(Exception) as exc_info: + vdb_core._handle_bulk_errors(response) + assert "Invalid argument" in str(exc_info.value) + assert "JSON parsing failed" in str(exc_info.value) def test_delete_documents_success(self, vdb_core): """Test delete_documents successful case""" @@ -407,16 +410,18 @@ def test_large_batch_insert_bulk_exception(self, vdb_core): mock_embedding_model = MagicMock() mock_embedding_model.get_embeddings.return_value = [[0.1]] - result = vdb_core._large_batch_insert("idx", [{"content": "body"}], 1, "content", mock_embedding_model) - assert result == 0 + with pytest.raises(Exception) as exc_info: + vdb_core._large_batch_insert("idx", [{"content": "body"}], 1, "content", mock_embedding_model) + assert "bulk error" in str(exc_info.value) def test_large_batch_insert_preprocess_exception(self, vdb_core): """Ensure outer exception handler returns zero on preprocess failure.""" vdb_core._preprocess_documents = MagicMock(side_effect=Exception("fail")) mock_embedding_model = MagicMock() - result = vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", mock_embedding_model) - assert result == 0 + with pytest.raises(Exception) as exc_info: + vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", mock_embedding_model) + assert "fail" in str(exc_info.value) def test_count_documents_success(self, vdb_core): """Ensure count_documents returns ES count.""" @@ -672,8 +677,9 @@ def test_small_batch_insert_exception(self, vdb_core): mock_embedding_model = MagicMock() documents = [{"content": "test content", "title": "test"}] - result = vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) - assert result == 0 + with pytest.raises(Exception) as exc_info: + vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) + assert "Preprocess error" in str(exc_info.value) def test_large_batch_insert_success(self, vdb_core): """Test _large_batch_insert successful case"""
+ {t("agent.author.hint", { defaultValue: "Default: {{email}}", email: user.email })} +
- {copied - ? t("chatStreamMessage.copied") - : t("chatStreamMessage.copyContent")} -
- {localOpinion === chatConfig.opinion.POSITIVE - ? t("chatStreamMessage.cancelLike") - : t("chatStreamMessage.like")} -
- {localOpinion === chatConfig.opinion.NEGATIVE - ? t("chatStreamMessage.cancelDislike") - : t("chatStreamMessage.dislike")} -
{t("chatStreamMessage.tts")}
- {t("document.status.loadingList")} + {isNewlyCreatedAndWaiting + ? t("document.status.waitingForTask") + : t("document.status.loadingList")}
+ {t("document.status.waitingForTask")} +
- {t("document.hint.uploadToCreate")} -
+ {t("document.hint.uploadToCreate")} +
+ {t("market.by", { defaultValue: "By {{author}}", author: agent.author })} +
{agent.description || t("space.noDescription", "No description")} diff --git a/frontend/components/agent/AgentImportWizard.tsx b/frontend/components/agent/AgentImportWizard.tsx index 12ba325cb..06159be23 100644 --- a/frontend/components/agent/AgentImportWizard.tsx +++ b/frontend/components/agent/AgentImportWizard.tsx @@ -1,6 +1,6 @@ "use client"; -import React, { useState, useEffect } from "react"; +import React, { useState, useEffect, useRef } from "react"; import { Modal, Steps, Button, Select, Input, Form, Tag, Space, Spin, App, Collapse, Radio } from "antd"; import { DownloadOutlined, CheckCircleOutlined, CloseCircleOutlined, PlusOutlined } from "@ant-design/icons"; import { useTranslation } from "react-i18next"; @@ -9,7 +9,7 @@ import { modelService } from "@/services/modelService"; import { getMcpServerList, addMcpServer, updateToolList } from "@/services/mcpService"; import { McpServer, AgentRefreshEvent } from "@/types/agentConfig"; import { ImportAgentData } from "@/hooks/useAgentImport"; -import { importAgent } from "@/services/agentConfigService"; +import { importAgent, checkAgentNameConflictBatch, regenerateAgentNameBatch } from "@/services/agentConfigService"; import log from "@/lib/logger"; export interface AgentImportWizardProps { @@ -53,6 +53,44 @@ const extractPromptHint = (value: string): string | undefined => { return match ? match[1] : undefined; }; +// Parse Markdown links in text and convert to React elements +const parseMarkdownLinks = (text: string): React.ReactNode[] => { + const linkRegex = /\[([^\]]+)\]\(([^)]+)\)/g; + const parts: React.ReactNode[] = []; + let lastIndex = 0; + let match; + let key = 0; + + while ((match = linkRegex.exec(text)) !== null) { + // Add text before the link + if (match.index > lastIndex) { + parts.push(text.substring(lastIndex, match.index)); + } + // Add the link + parts.push( + { + e.stopPropagation(); + }} + > + {match[1]} + + ); + lastIndex = match.index + match[0].length; + } + // Add remaining text + if (lastIndex < text.length) { + parts.push(text.substring(lastIndex)); + } + + return parts.length > 0 ? parts : [text]; +}; + export default function AgentImportWizard({ visible, onCancel, @@ -88,6 +126,28 @@ export default function AgentImportWizard({ const [installingMcp, setInstallingMcp] = useState>({}); const [isImporting, setIsImporting] = useState(false); + // Name conflict checking and renaming + // Structure: agentKey -> { hasConflict, conflictAgents, renamedName, renamedDisplayName } + const [agentNameConflicts, setAgentNameConflicts] = useState; + renamedName: string; + renamedDisplayName: string; + }>>({}); + const [checkingName, setCheckingName] = useState(false); + const [regeneratingAll, setRegeneratingAll] = useState(false); + // Track which agents have been successfully renamed (no conflicts) + const [successfullyRenamedAgents, setSuccessfullyRenamedAgents] = useState>(new Set()); + // Debounce timer for manual name changes - use ref to avoid stale closures + const nameCheckTimerRef = useRef(null); + // Store latest agentNameConflicts in ref to avoid stale closures in timer callbacks + const agentNameConflictsRef = useRef; + renamedName: string; + renamedDisplayName: string; + }>>({}); + // Helper: Refresh tools and agents after MCP changes const refreshToolsAndAgents = async () => { try { @@ -114,6 +174,22 @@ export default function AgentImportWizard({ } }, [visible]); + // Check name conflict immediately after file upload + useEffect(() => { + if (visible && initialData) { + checkNameConflict(); + } + }, [visible, initialData]); + + // Cleanup timer on unmount + useEffect(() => { + return () => { + if (nameCheckTimerRef.current) { + clearTimeout(nameCheckTimerRef.current); + } + }; + }, []); + // Parse agent data for config fields and MCP servers useEffect(() => { if (visible && initialData) { @@ -136,6 +212,232 @@ export default function AgentImportWizard({ setSelectedModelsByAgent(initialModels); }; + // Check name conflict for all agents (main agent + sub-agents) + const checkNameConflict = async () => { + if (!initialData?.agent_info) return; + + setCheckingName(true); + const conflicts: Record; + renamedName: string; + renamedDisplayName: string; + }> = {}; + + try { + // Check all agents in agent_info + const agentInfoMap = initialData.agent_info; + const items = Object.entries(agentInfoMap).map(([agentKey, agentInfo]: [string, any]) => ({ + key: agentKey, + name: agentInfo?.name || "", + display_name: agentInfo?.display_name, + })); + + const result = await checkAgentNameConflictBatch({ + items: items.map((item) => ({ + name: item.name, + display_name: item.display_name, + })), + }); + + if (!result.success || !Array.isArray(result.data)) { + log.warn("Skip name conflict check due to fetch failure"); + setAgentNameConflicts({}); + agentNameConflictsRef.current = {}; + setCheckingName(false); + return; + } + + result.data.forEach((res: any, idx: number) => { + const item = items[idx]; + const agentKey = item.key; + const hasNameConflict = res?.name_conflict || false; + const hasDisplayNameConflict = res?.display_name_conflict || false; + const conflictAgentsRaw = Array.isArray(res?.conflict_agents) ? res.conflict_agents : []; + // Deduplicate by name/display_name + const seen = new Set(); + const conflictAgents = conflictAgentsRaw.reduce((acc: Array<{ name?: string; display_name?: string }>, curr: any) => { + const key = `${curr?.name || ""}||${curr?.display_name || ""}`; + if (seen.has(key)) return acc; + seen.add(key); + acc.push({ name: curr?.name, display_name: curr?.display_name }); + return acc; + }, []); + + const hasConflict = hasNameConflict || hasDisplayNameConflict; + conflicts[agentKey] = { + hasConflict, + conflictAgents, + renamedName: item.name, + renamedDisplayName: item.display_name || "", + }; + }); + + setAgentNameConflicts(conflicts); + + // Update successfully renamed agents based on initial check + // Only add to successfullyRenamedAgents if there was a conflict that was resolved + // For initial check, we don't add anything since no renaming has happened yet + setSuccessfullyRenamedAgents((prev) => { + const next = new Set(prev); + // Don't modify on initial check - only track agents that were successfully renamed + return next; + }); + } catch (error) { + log.error("Failed to check name conflicts:", error); + } finally { + setCheckingName(false); + } + }; + + // Check name conflict for a specific agent after renaming + const checkSingleAgentConflict = async (agentKey: string, name: string, displayName?: string) => { + if (!initialData?.agent_info) return; + + try { + const result = await checkAgentNameConflictBatch({ + items: [ + { + name, + display_name: displayName, + }, + ], + }); + + if (!result.success || !Array.isArray(result.data) || !result.data[0]) { + return; + } + + const checkResult = result.data[0]; + const hasNameConflict = checkResult?.name_conflict || false; + const hasDisplayNameConflict = checkResult?.display_name_conflict || false; + const hasConflict = hasNameConflict || hasDisplayNameConflict; + const conflictAgentsRaw = Array.isArray(checkResult?.conflict_agents) ? checkResult.conflict_agents : []; + + // Deduplicate by name/display_name + const seen = new Set(); + const conflictAgents = conflictAgentsRaw.reduce((acc: Array<{ name?: string; display_name?: string }>, curr: any) => { + const key = `${curr?.name || ""}||${curr?.display_name || ""}`; + if (seen.has(key)) return acc; + seen.add(key); + acc.push({ name: curr?.name, display_name: curr?.display_name }); + return acc; + }, []); + + setAgentNameConflicts((prev) => { + const next = { ...prev }; + if (!next[agentKey]) { + const agentInfo = initialData.agent_info[agentKey] as any; + next[agentKey] = { + hasConflict: false, + conflictAgents: [], + renamedName: agentInfo?.name || "", + renamedDisplayName: agentInfo?.display_name || "", + }; + } + next[agentKey] = { + ...next[agentKey], + hasConflict, + conflictAgents, + renamedName: name, + renamedDisplayName: displayName || "", + }; + agentNameConflictsRef.current = next; + return next; + }); + + // Update success status + setSuccessfullyRenamedAgents((prev) => { + const next = new Set(prev); + if (hasConflict) { + next.delete(agentKey); + } else { + next.add(agentKey); + } + return next; + }); + + return hasConflict; + } catch (error) { + log.error("Failed to check single agent conflict:", error); + return true; // Assume conflict on error to be safe + } + }; + + // One-click regenerate all conflicted agents using selected model(s) + const handleRegenerateAll = async () => { + if (!initialData?.agent_info) return; + + const agentsWithConflicts = Object.entries(agentNameConflicts).filter( + ([_, conflict]) => conflict.hasConflict + ); + if (agentsWithConflicts.length === 0) return; + + setRegeneratingAll(true); + try { + const payload = { + items: agentsWithConflicts.map(([agentKey, conflict]) => { + const agentInfo = initialData.agent_info[agentKey] as any; + return { + agent_id: agentInfo?.agent_id, + name: conflict.renamedName || agentInfo?.name || "", + display_name: conflict.renamedDisplayName || agentInfo?.display_name || "", + task_description: agentInfo?.business_description || agentInfo?.description || "", + language: "zh", + }; + }), + }; + + const result = await regenerateAgentNameBatch(payload); + + if (!result.success || !Array.isArray(result.data)) { + message.error(result.message || t("market.install.error.nameRegenerationFailed", "Failed to regenerate name")); + return; + } + + const regenerated = result.data as Array<{ name?: string; display_name?: string }>; + + // Update conflicts state with regenerated names + setAgentNameConflicts((prev) => { + const next = { ...prev }; + agentsWithConflicts.forEach(([agentKey, conflict], idx) => { + const agentInfo = initialData.agent_info[agentKey] as any; + const data = regenerated[idx] || {}; + next[agentKey] = { + ...next[agentKey], + renamedName: data.name || conflict.renamedName || agentInfo?.name || "", + renamedDisplayName: + data.display_name || conflict.renamedDisplayName || agentInfo?.display_name || "", + }; + }); + agentNameConflictsRef.current = next; + return next; + }); + + // Re-check conflicts for all regenerated agents + const checkPromises = agentsWithConflicts.map(async ([agentKey, conflict], idx) => { + const data = regenerated[idx] || {}; + const newName = data.name || conflict.renamedName || ""; + const newDisplayName = data.display_name || conflict.renamedDisplayName || ""; + return checkSingleAgentConflict(agentKey, newName, newDisplayName); + }); + + const checkResults = await Promise.all(checkPromises); + const allResolved = checkResults.every((hasConflict) => !hasConflict); + + if (allResolved) { + message.success(t("market.install.success.nameRegeneratedAndResolved", "Agent names regenerated successfully and all conflicts resolved")); + } else { + message.success(t("market.install.success.nameRegenerated", "Agent name regenerated successfully")); + } + } catch (error) { + log.error("Failed to regenerate agent names:", error); + message.error(t("market.install.error.nameRegenerationFailed", "Failed to regenerate name")); + } finally { + setRegeneratingAll(false); + } + }; + const loadLLMModels = async () => { setLoadingModels(true); try { @@ -336,7 +638,11 @@ export default function AgentImportWizard({ }; const handleNext = () => { - if (currentStep === 0) { + const currentStepKey = steps[currentStep]?.key; + + if (currentStepKey === "rename") { + // no mandatory name check + } else if (currentStepKey === "model") { // Step 1: Model selection validation if (modelSelectionMode === "unified") { if (!selectedModelId || !selectedModelName) { @@ -357,7 +663,7 @@ export default function AgentImportWizard({ } } } - } else if (currentStep === 1) { + } else if (currentStepKey === "config") { // Step 2: Config fields validation const emptyFields = configFields.filter(field => !configValues[field.valueKey]?.trim()); if (emptyFields.length > 0) { @@ -409,7 +715,18 @@ export default function AgentImportWizard({ // Clone agent data structure const agentJson = JSON.parse(JSON.stringify(initialData)); - const mainAgentId = String(initialData.agent_id); + + // Update all agents' name/display_name if renamed + Object.entries(agentNameConflicts).forEach(([agentKey, conflict]) => { + if (agentJson.agent_info[agentKey]) { + if (conflict.renamedName) { + agentJson.agent_info[agentKey].name = conflict.renamedName; + } + if (conflict.renamedDisplayName) { + agentJson.agent_info[agentKey].display_name = conflict.renamedDisplayName; + } + } + }); // Update model information based on selection mode if (modelSelectionMode === "unified") { @@ -495,11 +812,32 @@ export default function AgentImportWizard({ setConfigValues({}); setMcpServers([]); setIsImporting(false); + setAgentNameConflicts({}); + agentNameConflictsRef.current = {}; + setCheckingName(false); + setRegeneratingAll(false); + setSuccessfullyRenamedAgents(new Set()); + if (nameCheckTimerRef.current) { + clearTimeout(nameCheckTimerRef.current); + nameCheckTimerRef.current = null; + } onCancel(); }; // Filter only required steps for navigation + // Show rename step if name conflict check is complete and there are any agents that had conflicts + // (even if all conflicts are now resolved, we still want to show the step so users can see the success state) + const hasAnyAgentsWithConflicts = !checkingName && ( + // Check if any agent has a current conflict + Object.values(agentNameConflicts).some(conflict => conflict.hasConflict) || + // OR if any agent was successfully renamed (meaning it had a conflict that was resolved) + successfullyRenamedAgents.size > 0 + ); const steps = [ + hasAnyAgentsWithConflicts && { + key: "rename", + title: t("market.install.step.rename", "Rename Agent"), + }, { key: "model", title: t("market.install.step.model", "Select Model"), @@ -516,9 +854,16 @@ export default function AgentImportWizard({ // Check if can proceed to next step const canProceed = () => { + // Disable buttons while checking name conflict + if (checkingName) { + return false; + } + const currentStepKey = steps[currentStep]?.key; - if (currentStepKey === "model") { + if (currentStepKey === "rename") { + return true; + } else if (currentStepKey === "model") { if (modelSelectionMode === "unified") { return selectedModelId !== null && selectedModelName !== ""; } else { @@ -545,9 +890,237 @@ export default function AgentImportWizard({ }; const renderStepContent = () => { + // Show loading state while checking name conflict + if (checkingName) { + return ( + + + + {t("market.install.checkingName", "Checking agent name...")} + + + ); + } + const currentStepKey = steps[currentStep]?.key; - if (currentStepKey === "model") { + if (currentStepKey === "rename") { + // Get all agents that had conflicts (including resolved ones) + // Show all agents in agentNameConflicts - they either have conflicts or were successfully renamed + const allAgentsWithConflicts = Object.entries(agentNameConflicts) + .filter(([agentKey, conflict]) => { + // Show agent if: + // 1. It currently has a conflict, OR + // 2. It was successfully renamed (in successfullyRenamedAgents), OR + // 3. It's in agentNameConflicts (meaning it was checked and had a conflict at some point) + // We show all agents in agentNameConflicts to keep the UI consistent + return true; // Show all agents that were checked + }) + .sort(([keyA], [keyB]) => { + // Main agent first + const mainAgentId = String(initialData?.agent_id); + if (keyA === mainAgentId) return -1; + if (keyB === mainAgentId) return 1; + return 0; + }); + + // Get agents that still have conflicts + const agentsWithConflicts = allAgentsWithConflicts.filter( + ([agentKey, conflict]) => conflict.hasConflict + ); + + // If no agents had conflicts at all, don't show rename step + if (allAgentsWithConflicts.length === 0) { + return null; + } + + // Check if all conflicts are resolved + const allConflictsResolved = agentsWithConflicts.length === 0 && allAgentsWithConflicts.length > 0; + const hasResolvedAgents = allAgentsWithConflicts.some( + ([agentKey]) => successfullyRenamedAgents.has(agentKey) + ); + + return ( + + {allConflictsResolved ? ( + + + + + {t("market.install.rename.success", "All agent name conflicts have been resolved. You can proceed to the next step.")} + + + + ) : ( + + {hasResolvedAgents && ( + + + + + {t("market.install.rename.partialSuccess", "Some agents have been successfully renamed.")} + + + + )} + + {t("market.install.rename.warning", "The agent name or display name conflicts with existing agents. Please rename to proceed.")} + + + {t("market.install.rename.oneClickDesc", "You can manually edit the names, or click one-click rename to let the selected model regenerate names for all conflicted agents.")} + + + {t("market.install.rename.note", "Note: If you proceed without renaming, the agent will be created but marked as unavailable due to name conflicts. You can rename it later in the agent list.")} + + + {t("market.install.rename.oneClick", "One-click Rename")} + + + )} + + + {allAgentsWithConflicts.map(([agentKey, conflict]) => { + const agentInfo = initialData?.agent_info?.[agentKey] as any; + const agentDisplayName = agentInfo?.display_name || agentInfo?.name || `${t("market.install.agent.defaultName", "Agent")} ${agentKey}`; + const isMainAgent = agentKey === String(initialData?.agent_id); + const originalName = agentInfo?.name || ""; + const originalDisplayName = agentInfo?.display_name || ""; + + return ( + + + + {isMainAgent && {t("market.install.agent.main", "Main")}} + {agentDisplayName} + + + + {successfullyRenamedAgents.has(agentKey) ? ( + + + + + {t("market.install.rename.agentResolved", "This agent's name conflict has been resolved.")} + + + + ) : conflict.hasConflict && conflict.conflictAgents.length > 0 && ( + + + {t("market.install.rename.conflictAgents", "Conflicting agents:")} + + + {conflict.conflictAgents.map((agent, idx) => ( + + {[agent.name, agent.display_name].filter(Boolean).join(" / ")} + + ))} + + + )} + + + + {t("market.install.rename.name", "Agent Name")} + + { + const newName = e.target.value; + setAgentNameConflicts(prev => { + const updated = { + ...prev, + [agentKey]: { + ...prev[agentKey], + renamedName: newName, + }, + }; + + // Clear existing timer + if (nameCheckTimerRef.current) { + clearTimeout(nameCheckTimerRef.current); + } + + // Set new timer for debounced check (500ms delay) + nameCheckTimerRef.current = setTimeout(() => { + // Read latest value from ref when timer fires + const currentConflict = agentNameConflictsRef.current[agentKey]; + if (currentConflict) { + checkSingleAgentConflict( + agentKey, + currentConflict.renamedName, + currentConflict.renamedDisplayName + ); + } + }, 500); + + agentNameConflictsRef.current = updated; + return updated; + }); + }} + placeholder={originalName} + size="large" + disabled={regeneratingAll} + /> + + + + + {t("market.install.rename.displayName", "Display Name")} + + { + const newDisplayName = e.target.value; + setAgentNameConflicts(prev => { + const updated = { + ...prev, + [agentKey]: { + ...prev[agentKey], + renamedDisplayName: newDisplayName, + }, + }; + + // Clear existing timer + if (nameCheckTimerRef.current) { + clearTimeout(nameCheckTimerRef.current); + } + + // Set new timer for debounced check (500ms delay) + nameCheckTimerRef.current = setTimeout(() => { + // Read latest value from ref when timer fires + const currentConflict = agentNameConflictsRef.current[agentKey]; + if (currentConflict) { + checkSingleAgentConflict( + agentKey, + currentConflict.renamedName, + currentConflict.renamedDisplayName + ); + } + }, 500); + + agentNameConflictsRef.current = updated; + return updated; + }); + }} + placeholder={originalDisplayName} + size="large" + disabled={regeneratingAll} + /> + + + ); + })} + + + + ); + } else if (currentStepKey === "model") { return ( {/* Agent Info - Title and Description Style */} @@ -819,47 +1392,22 @@ export default function AgentImportWizard({ {mcpServers.map((mcp, index) => ( - - - - - {mcp.mcp_server_name} - - {mcp.isInstalled ? ( - } color="success" className="text-sm"> - {t("market.install.mcp.installed", "Installed")} - - ) : ( - } color="default" className="text-sm"> - {t("market.install.mcp.notInstalled", "Not Installed")} - - )} - - - - - MCP URL: - - {(mcp.isUrlEditable || !mcp.isInstalled) ? ( - handleMcpUrlChange(index, e.target.value)} - placeholder={mcp.isUrlEditable - ? t("market.install.mcp.urlPlaceholder", "Enter MCP server URL") - : mcp.mcp_url - } - size="middle" - disabled={mcp.isInstalled} - style={{ maxWidth: "400px" }} - /> - ) : ( - - {mcp.editedUrl || mcp.mcp_url} - - )} - + + + + {mcp.mcp_server_name} + + {mcp.isInstalled ? ( + } color="success" className="text-xs"> + {t("market.install.mcp.installed", "Installed")} + + ) : ( + } color="default" className="text-xs"> + {t("market.install.mcp.notInstalled", "Not Installed")} + + )} {!mcp.isInstalled && ( @@ -876,6 +1424,44 @@ export default function AgentImportWizard({ )} + + + + + MCP URL: + + {(mcp.isUrlEditable || !mcp.isInstalled) ? ( + handleMcpUrlChange(index, e.target.value)} + placeholder={mcp.isUrlEditable + ? t("market.install.mcp.urlPlaceholder", "Enter MCP server URL") + : mcp.mcp_url + } + size="middle" + disabled={mcp.isInstalled} + style={{ maxWidth: "400px" }} + className={mcp.isUrlEditable && needsConfig(mcp.mcp_url) ? "bg-gray-100 dark:bg-gray-800" : ""} + /> + ) : ( + + {mcp.editedUrl || mcp.mcp_url} + + )} + + {/* Show hint if URL needs configuration */} + {mcp.isUrlEditable && needsConfig(mcp.mcp_url) && (() => { + const hint = extractPromptHint(mcp.mcp_url); + const hintText = hint || t("market.install.mcp.defaultConfigHint", "Please enter the MCP server URL"); + return ( + + + {parseMarkdownLinks(hintText)} + + + ); + })()} + ))} @@ -946,7 +1532,7 @@ export default function AgentImportWizard({ className="mb-6" /> - + {renderStepContent()} diff --git a/frontend/components/ui/markdownRenderer.tsx b/frontend/components/ui/markdownRenderer.tsx index e192d5189..12bcc7eeb 100644 --- a/frontend/components/ui/markdownRenderer.tsx +++ b/frontend/components/ui/markdownRenderer.tsx @@ -15,6 +15,7 @@ import * as TooltipPrimitive from "@radix-ui/react-tooltip"; import { visit } from "unist-util-visit"; import { SearchResult } from "@/types/chat"; +import { resolveS3UrlToDataUrl } from "@/services/storageService"; import { Tooltip, TooltipContent, @@ -31,8 +32,267 @@ interface MarkdownRendererProps { showDiagramToggle?: boolean; onCitationHover?: () => void; enableMultimodal?: boolean; + /** + * When true, resolve s3:// media URLs in markdown into data URLs (base64) + * so that images can still be displayed after page refresh or when + * the original S3 URL is not directly accessible by the browser. + */ + resolveS3Media?: boolean; } +// Simple in-memory cache to avoid refetching the same S3 object multiple times +const s3MediaCache = new Map(); +const mediaObjectUrlCache = new Map(); +const mediaObjectUrlPromiseCache = new Map>(); +const S3_MEDIA_SESSION_PREFIX = "s3-media-cache:"; + +const isBrowserEnvironment = typeof window !== "undefined"; + +const getSessionCachedValue = (key: string): string | null => { + if (!isBrowserEnvironment) { + return null; + } + try { + return window.sessionStorage.getItem(key); + } catch { + return null; + } +}; + +const getCachedMediaSrc = (src: string): string | null => { + const cached = s3MediaCache.get(src); + if (cached) { + return cached; + } + const sessionValue = getSessionCachedValue(src); + if (sessionValue) { + s3MediaCache.set(src, sessionValue); + return sessionValue; + } + return null; +}; + +const setCachedMediaSrc = (src: string, value: string) => { + s3MediaCache.set(src, value); + if (!isBrowserEnvironment) { + return; + } + try { + window.sessionStorage.setItem(`${S3_MEDIA_SESSION_PREFIX}${src}`, value); + } catch { + // Ignore storage quota errors silently. + } +}; + +const setCachedObjectUrl = (src: string, objectUrl: string | null) => { + if (!objectUrl) { + return; + } + const existing = mediaObjectUrlCache.get(src); + if (existing && existing !== objectUrl) { + URL.revokeObjectURL(existing); + } + mediaObjectUrlCache.set(src, objectUrl); +}; + +const resolveMediaToObjectUrl = async ( + src: string, + { resolveS3 }: { resolveS3: boolean } +): Promise => { + try { + if (src.startsWith("blob:")) { + return src; + } + + if (src.startsWith("s3://")) { + if (!resolveS3) { + return null; + } + const dataUrl = await resolveS3UrlToDataUrl(src); + if (!dataUrl) { + return null; + } + const response = await fetch(dataUrl); + if (!response.ok) { + return null; + } + const blob = await response.blob(); + return URL.createObjectURL(blob); + } + + if ( + src.startsWith("http://") || + src.startsWith("https://") || + src.startsWith("/api/") || + src.startsWith("/nexent/") || + src.startsWith("/attachments/") || + src.startsWith("/") + ) { + const response = await fetch(src); + if (!response.ok) { + return null; + } + const blob = await response.blob(); + return URL.createObjectURL(blob); + } + + if (src.startsWith("data:")) { + const response = await fetch(src); + if (!response.ok) { + return null; + } + const blob = await response.blob(); + return URL.createObjectURL(blob); + } + + return null; + } catch { + return null; + } +}; + +const usePrefetchedMediaSource = ( + src?: string, + options?: { enable?: boolean; resolveS3?: boolean } +) => { + const shouldPrefetch = + Boolean( + options?.enable && + src && + typeof src === "string" && + !src.startsWith("blob:") && + (src.startsWith("s3://") || + src.startsWith("http://") || + src.startsWith("https://") || + src.startsWith("/")) + ) || false; + + const [resolvedSrc, setResolvedSrc] = React.useState(() => { + if (!src || typeof src !== "string") { + return null; + } + if (!shouldPrefetch) { + return src; + } + return mediaObjectUrlCache.get(src) ?? null; + }); + + React.useEffect(() => { + if (!src || typeof src !== "string") { + setResolvedSrc(null); + return; + } + + if (!shouldPrefetch) { + setResolvedSrc(src); + return; + } + + const cached = mediaObjectUrlCache.get(src); + if (cached) { + setResolvedSrc(cached); + return; + } + + let cancelled = false; + + const promise = + mediaObjectUrlPromiseCache.get(src) ?? + resolveMediaToObjectUrl(src, { + resolveS3: options?.resolveS3 ?? true, + }); + + mediaObjectUrlPromiseCache.set(src, promise); + + promise + .then((objectUrl) => { + if (cancelled) { + return; + } + if (!objectUrl) { + setResolvedSrc(null); + return; + } + setCachedObjectUrl(src, objectUrl); + setResolvedSrc(objectUrl); + }) + .catch(() => { + if (!cancelled) { + setResolvedSrc(null); + } + }) + .finally(() => { + mediaObjectUrlPromiseCache.delete(src); + }); + + return () => { + cancelled = true; + }; + }, [options?.resolveS3, shouldPrefetch, src]); + + return resolvedSrc; +}; + +const useResolvedS3Media = (src?: string, shouldResolve?: boolean) => { + const cachedInitial = + typeof src === "string" && src.startsWith("s3://") + ? getCachedMediaSrc(src) + : null; + const initialValue = + typeof src === "string" + ? !shouldResolve || !src.startsWith("s3://") + ? src + : cachedInitial + : null; + const [resolvedSrc, setResolvedSrc] = React.useState( + initialValue + ); + + React.useEffect(() => { + if (!src || typeof src !== "string") { + setResolvedSrc(null); + return; + } + + if (!shouldResolve || !src.startsWith("s3://")) { + setResolvedSrc(src); + return; + } + + const cached = getCachedMediaSrc(src); + if (cached) { + setResolvedSrc(cached); + return; + } + + let cancelled = false; + + resolveS3UrlToDataUrl(src) + .then((dataUrl) => { + if (cancelled) { + return; + } + if (dataUrl) { + setCachedMediaSrc(src, dataUrl); + setResolvedSrc(dataUrl); + } else { + setResolvedSrc(null); + } + }) + .catch(() => { + if (!cancelled) { + setResolvedSrc(null); + } + }); + + return () => { + cancelled = true; + }; + }, [src, shouldResolve]); + + return resolvedSrc; +}; + const VIDEO_EXTENSIONS = [".mp4", ".webm", ".ogg", ".mov", ".m4v"]; const extractExtension = (value: string): string => { @@ -519,20 +779,16 @@ const ImageWithErrorHandling: React.FC = React.memo ImageWithErrorHandling.displayName = "ImageWithErrorHandling"; -export const MarkdownRenderer: React.FC = ({ - content, - className, - searchResults = [], - showDiagramToggle = true, - onCitationHover, - enableMultimodal = true, -}) => { +/** + * Render a code block with syntax highlighting, language label, and copy button + * This is exported for use in other components that need to render code blocks directly + */ +export const CodeBlock: React.FC<{ + codeContent: string; + language?: string; +}> = ({ codeContent, language = "python" }) => { const { t } = useTranslation("common"); - - // Convert LaTeX delimiters to markdown math delimiters - const processedContent = convertLatexDelimiters(content); - - // Customize code block style with light gray background + const customStyle = { ...oneLight, 'pre[class*="language-"]': { @@ -569,6 +825,47 @@ export const MarkdownRenderer: React.FC = ({ }, }; + const cleanedContent = codeContent.replace(/^\n+|\n+$/g, ""); + + return ( + + + + {language} + + + + + + {cleanedContent} + + + + ); +}; + +export const MarkdownRenderer: React.FC = ({ + content, + className, + searchResults = [], + showDiagramToggle = true, + onCitationHover, + enableMultimodal = true, + resolveS3Media = false, +}) => { + const { t } = useTranslation("common"); + + // Convert LaTeX delimiters to markdown math delimiters + const processedContent = convertLatexDelimiters(content); + const renderCodeFallback = (text: string, key?: React.Key) => ( = ({ return ; }; + const ImageResolver: React.FC<{ src?: string; alt?: string | null }> = ({ + src, + alt, + }) => { + const resolvedSrc = useResolvedS3Media( + typeof src === "string" ? src : undefined, + resolveS3Media + ); + + if (!enableMultimodal) { + return renderMediaFallback(src, alt); + } + + if (!resolvedSrc) { + return renderMediaFallback(src, alt); + } + + if (isVideoUrl(resolvedSrc)) { + return renderVideoElement({ src: resolvedSrc, alt }); + } + + return ; + }; + // Modified processText function logic const processText = (text: string) => { if (typeof text !== "string") return text; @@ -865,37 +1186,7 @@ export const MarkdownRenderer: React.FC = ({ return ; } if (!inline) { - return ( - - - - {match[1]} - - - - - - {codeContent} - - - - ); + return ; } } } catch (error) { @@ -908,21 +1199,9 @@ export const MarkdownRenderer: React.FC = ({ ); }, // Image - img: ({ src, alt }: any) => { - if (!enableMultimodal) { - return renderMediaFallback(src, alt); - } - - if (isVideoUrl(src)) { - return renderVideoElement({ src, alt }); - } - - if (!src || typeof src !== "string") { - return null; - } - - return ; - }, + img: ({ src, alt }: any) => ( + + ), // Video video: ({ children, ...props }: any) => { const directSrc = props?.src; diff --git a/frontend/const/chatConfig.ts b/frontend/const/chatConfig.ts index df7b65c92..73cd19aed 100644 --- a/frontend/const/chatConfig.ts +++ b/frontend/const/chatConfig.ts @@ -111,6 +111,7 @@ messageTypes: { // Content type constants for last content type tracking contentTypes: { MODEL_OUTPUT: "model_output" as const, + MODEL_OUTPUT_CODE: "model_output_code" as const, PARSING: "parsing" as const, EXECUTION: "execution" as const, AGENT_NEW_RUN: "agent_new_run" as const, diff --git a/frontend/const/marketConfig.ts b/frontend/const/marketConfig.ts new file mode 100644 index 000000000..6de8d1f48 --- /dev/null +++ b/frontend/const/marketConfig.ts @@ -0,0 +1,36 @@ +// ========== Market Configuration Constants ========== + +/** + * Default icons for market agent categories + * Maps category name field to their corresponding icons + */ +export const MARKET_CATEGORY_ICONS: Record = { + research: "🔬", + content: "✍️", + development: "💻", + business: "📈", + automation: "⚙️", + education: "📚", + communication: "💬", + data: "📊", + creative: "🎨", + other: "📦", +} as const; + +/** + * Get icon for a category by name field + * @param categoryName - Category name field (e.g., "research", "content") + * @param fallbackIcon - Fallback icon if category not found (default: 📦) + * @returns Icon emoji string + */ +export function getCategoryIcon( + categoryName: string | null | undefined, + fallbackIcon: string = "📦" +): string { + if (!categoryName) { + return fallbackIcon; + } + + return MARKET_CATEGORY_ICONS[categoryName] || fallbackIcon; +} + diff --git a/frontend/hooks/useAgentImport.md b/frontend/hooks/useAgentImport.md deleted file mode 100644 index 52b14aa78..000000000 --- a/frontend/hooks/useAgentImport.md +++ /dev/null @@ -1,245 +0,0 @@ -# useAgentImport Hook - -Unified agent import hook for handling agent imports across the application. - -## Overview - -This hook provides a consistent interface for importing agents from different sources: -- File upload (used in Agent Development and Agent Space) -- Direct data (used in Agent Market) - -All import operations ultimately call the same backend `/agent/import` endpoint. - -## Usage - -### Basic Import - -```typescript -import { useAgentImport } from "@/hooks/useAgentImport"; - -function MyComponent() { - const { isImporting, importFromFile, importFromData, error } = useAgentImport({ - onSuccess: () => { - console.log("Import successful!"); - }, - onError: (error) => { - console.error("Import failed:", error); - }, - }); - - // ... -} -``` - -### Import from File (SubAgentPool, SpaceContent) - -```typescript -const handleFileImport = async (file: File) => { - try { - await importFromFile(file); - // Success handled by onSuccess callback - } catch (error) { - // Error handled by onError callback - } -}; - -// In file input handler - { - const file = e.target.files?.[0]; - if (file) { - handleFileImport(file); - } - }} -/> -``` - -### Import from Data (Market) - -```typescript -const handleMarketImport = async (agentDetails: MarketAgentDetail) => { - // Prepare import data from agent details - const importData = { - agent_id: agentDetails.agent_id, - agent_info: agentDetails.agent_json.agent_info, - mcp_info: agentDetails.agent_json.mcp_info, - }; - - try { - await importFromData(importData); - // Success handled by onSuccess callback - } catch (error) { - // Error handled by onError callback - } -}; -``` - -## Integration Examples - -### 1. SubAgentPool Component - -```typescript -// In SubAgentPool.tsx -import { useAgentImport } from "@/hooks/useAgentImport"; - -export default function SubAgentPool({ onImportSuccess }: Props) { - const { isImporting, importFromFile } = useAgentImport({ - onSuccess: () => { - message.success(t("agent.import.success")); - onImportSuccess?.(); - }, - onError: (error) => { - message.error(error.message); - }, - }); - - const handleImportClick = () => { - const input = document.createElement("input"); - input.type = "file"; - input.accept = ".json"; - input.onchange = async (e) => { - const file = (e.target as HTMLInputElement).files?.[0]; - if (file) { - await importFromFile(file); - } - }; - input.click(); - }; - - return ( - - {isImporting ? t("importing") : t("import")} - - ); -} -``` - -### 2. SpaceContent Component - -```typescript -// In SpaceContent.tsx -import { useAgentImport } from "@/hooks/useAgentImport"; - -export function SpaceContent({ onRefresh }: Props) { - const { isImporting, importFromFile } = useAgentImport({ - onSuccess: () => { - message.success(t("space.import.success")); - onRefresh(); // Reload agent list - }, - }); - - const handleImportAgent = () => { - const input = document.createElement("input"); - input.type = "file"; - input.accept = ".json"; - input.onchange = async (e) => { - const file = (e.target as HTMLInputElement).files?.[0]; - if (file) { - await importFromFile(file); - } - }; - input.click(); - }; - - return ( - - {isImporting ? "Importing..." : "Import Agent"} - - ); -} -``` - -### 3. AgentInstallModal (Market) - -```typescript -// In AgentInstallModal.tsx -import { useAgentImport } from "@/hooks/useAgentImport"; - -export default function AgentInstallModal({ - agentDetails, - onComplete -}: Props) { - const { isImporting, importFromData } = useAgentImport({ - onSuccess: () => { - message.success(t("market.install.success")); - onComplete(); - }, - }); - - const handleInstall = async () => { - // Prepare configured data - const importData = prepareImportData(agentDetails, userConfig); - await importFromData(importData); - }; - - return ( - - Install - - ); -} -``` - -## API Reference - -### Parameters - -```typescript -interface UseAgentImportOptions { - onSuccess?: () => void; // Called on successful import - onError?: (error: Error) => void; // Called on import error - forceImport?: boolean; // Force import even if duplicate names exist -} -``` - -### Return Value - -```typescript -interface UseAgentImportResult { - isImporting: boolean; // Import in progress - importFromFile: (file: File) => Promise; // Import from file - importFromData: (data: ImportAgentData) => Promise; // Import from data - error: Error | null; // Last error (if any) -} -``` - -### Data Structure - -```typescript -interface ImportAgentData { - agent_id: number; - agent_info: Record; - mcp_info?: Array<{ - mcp_server_name: string; - mcp_url: string; - }>; -} -``` - -## Error Handling - -The hook handles errors in two ways: - -1. **Via onError callback** - Preferred method for user-facing error messages -2. **Via thrown exceptions** - For custom error handling in specific cases - -Both approaches are supported to allow flexibility in different use cases. - -## Implementation Notes - -- File content is read as text and parsed as JSON -- Data structure validation is performed before calling the backend -- The backend `/agent/import` endpoint is called with the prepared data -- All logging uses the centralized `log` utility from `@/lib/logger` - diff --git a/frontend/hooks/useAgentImport.ts b/frontend/hooks/useAgentImport.ts index f0f33add4..0aff99e82 100644 --- a/frontend/hooks/useAgentImport.ts +++ b/frontend/hooks/useAgentImport.ts @@ -1,5 +1,9 @@ import { useState } from "react"; -import { importAgent } from "@/services/agentConfigService"; +import { + checkAgentNameConflictBatch, + importAgent, + regenerateAgentNameBatch, +} from "@/services/agentConfigService"; import log from "@/lib/logger"; export interface ImportAgentData { @@ -15,6 +19,19 @@ export interface UseAgentImportOptions { onSuccess?: () => void; onError?: (error: Error) => void; forceImport?: boolean; + /** + * Optional: handle name/display_name conflicts before import + * Caller can resolve by returning new name or choosing to continue/terminate + */ + onNameConflictResolve?: (payload: { + name: string; + displayName?: string; + conflictAgents: Array<{ id: string; name?: string; display_name?: string }>; + regenerateWithLLM: () => Promise<{ + name?: string; + displayName?: string; + }>; + }) => Promise<{ proceed: boolean; name?: string; displayName?: string }>; } export interface UseAgentImportResult { @@ -111,6 +128,30 @@ export function useAgentImport( * Core import logic - calls backend API */ const importAgentData = async (data: ImportAgentData): Promise => { + // Step 1: check name/display name conflicts before import (only check main agent name and display name) + const mainAgent = data.agent_info?.[String(data.agent_id)]; + if (mainAgent?.name) { + const conflictHandled = await ensureNameNotDuplicated( + mainAgent.name, + mainAgent.display_name, + mainAgent.description || mainAgent.business_description + ); + + if (!conflictHandled.proceed) { + throw new Error( + "Agent name/display name conflicts with existing agent; import cancelled." + ); + } + + // if user chooses to modify name, write back to import data + if (conflictHandled.name) { + mainAgent.name = conflictHandled.name; + } + if (conflictHandled.displayName) { + mainAgent.display_name = conflictHandled.displayName; + } + } + const result = await importAgent(data, { forceImport }); if (!result.success) { @@ -142,6 +183,80 @@ export function useAgentImport( }); }; + /** + * Frontend side name conflict validation logic + */ + const ensureNameNotDuplicated = async ( + name: string, + displayName?: string, + taskDescription?: string + ): Promise<{ proceed: boolean; name?: string; displayName?: string }> => { + try { + const checkResp = await checkAgentNameConflictBatch({ + items: [ + { + name, + display_name: displayName, + }, + ], + }); + if (!checkResp.success || !Array.isArray(checkResp.data)) { + log.warn("Skip name conflict check due to fetch failure"); + return { proceed: true }; + } + + const first = checkResp.data[0] || {}; + const { name_conflict, display_name_conflict, conflict_agents } = first; + + if (!name_conflict && !display_name_conflict) { + return { proceed: true }; + } + + const regenerateWithLLM = async () => { + const regenResp = await regenerateAgentNameBatch({ + items: [ + { + name, + display_name: displayName, + task_description: taskDescription, + }, + ], + }); + if (!regenResp.success || !Array.isArray(regenResp.data) || !regenResp.data[0]) { + throw new Error("Failed to regenerate agent name"); + } + const item = regenResp.data[0]; + return { + name: item.name, + displayName: item.display_name ?? displayName, + }; + }; + + // let caller decide how to handle conflicts (e.g. show a dialog to let user choose whether to let LLM rename) + if (options.onNameConflictResolve) { + return await options.onNameConflictResolve({ + name, + displayName, + conflictAgents: (conflict_agents || []).map((c: any) => ({ + id: String(c.agent_id ?? c.id), + name: c.name, + display_name: c.display_name, + })), + regenerateWithLLM, + }); + } + + // default behavior: directly call backend to rename to keep import available + const regenerated = await regenerateWithLLM(); + return { proceed: true, ...regenerated }; + } catch (error) { + // if callback throws an error, prevent import + throw error instanceof Error + ? error + : new Error("Name conflict handling failed"); + } + }; + return { isImporting, importFromFile, diff --git a/frontend/hooks/useMemory.ts b/frontend/hooks/useMemory.ts index 03ac72dd8..5bb1a1bc9 100644 --- a/frontend/hooks/useMemory.ts +++ b/frontend/hooks/useMemory.ts @@ -483,24 +483,3 @@ export function useMemory({ visible, currentUserId, currentTenantId, message }: handleDeleteMemory, } } - -// expose memory notification indicator to ChatHeader -export function useMemoryIndicator(modalVisible: boolean) { - const [hasNewMemory, setHasNewMemory] = useState(false) - - // Reset indicator when memory modal is opened - useEffect(() => { - if (modalVisible) { - setHasNewMemory(false) - } - }, [modalVisible]) - - // Listen for backend event that notifies new memory addition - useEffect(() => { - const handler = () => setHasNewMemory(true) - window.addEventListener("nexent:new-memory", handler as EventListener) - return () => window.removeEventListener("nexent:new-memory", handler as EventListener) - }, []) - - return hasNewMemory -} \ No newline at end of file diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 5ee25a7b8..b8681a78a 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -285,6 +285,8 @@ "agent.contextMenu.export": "Export", "agent.contextMenu.delete": "Delete", + "agent.contextMenu.copy": "Copy", + "agent.copySuffix": "Copy", "agent.info.title": "Agent Information", "agent.info.name.error.empty": "Name cannot be empty", "agent.info.name.error.format": "Name can only contain letters, numbers and underscores, and must start with a letter or underscore", @@ -293,6 +295,9 @@ "agent.namePlaceholder": "Please enter agent variable name", "agent.displayName": "Agent Name", "agent.displayNamePlaceholder": "Please enter agent name", + "agent.author": "Author", + "agent.authorPlaceholder": "Please enter author name (optional)", + "agent.author.hint": "Default: {{email}}", "agent.description": "Agent Description", "agent.descriptionPlaceholder": "Please enter agent description", "agent.detailContent.title": "Agent Detail Content", @@ -413,7 +418,6 @@ "toolPool.error.requiredFields": "The following required fields are not filled: {{fields}}", "toolPool.tooltip.functionGuide": "1. For local knowledge base search functionality, please enable the knowledge_base_search tool;\n2. For text file parsing functionality, please enable the analyze_text_file tool;\n3. For image parsing functionality, please enable the analyze_image tool.", - "tool.message.unavailable": "This tool is currently unavailable and cannot be selected", "tool.error.noMainAgentId": "Main agent ID is not set, cannot update tool status", "tool.error.configFetchFailed": "Failed to get tool configuration", @@ -502,6 +506,7 @@ "document.summary.modelPlaceholder": "Select Model", "document.status.creating": "Creating...", "document.status.loadingList": "Loading document list...", + "document.status.waitingForTask": "Waiting for task creation...", "document.input.knowledgeBaseName": "Please enter knowledge base name", "document.button.details": "Details", "document.button.overview": "Overview", @@ -522,6 +527,24 @@ "document.status.completed": "Ready", "document.status.processFailed": "Process Failed", "document.status.forwardFailed": "Forward Failed", + "document.progress.chunksProcessed": "Processed {{processed}}/{{total}} chunks ({{percent}}%)", + "document.error.reason": "Error Reason", + "document.error.suggestion": "Suggestion", + "document.error.noReason": "No error reason available", + "document.error.code.ray_init_failed.message": "Failed to initialize Ray cluster", + "document.error.code.ray_init_failed.suggestion": "Please upgrade to the latest image version and redeploy.", + "document.error.code.no_valid_chunks.message": "The data processing kernel could not extract valid text from the document", + "document.error.code.no_valid_chunks.suggestion": "Please ensure the document format is supported and the content is not purely images.", + "document.error.code.vector_service_busy.message": "Vectorization model service is busy and cannot return vectors", + "document.error.code.vector_service_busy.suggestion": "Please switch the model service provider or try again later.", + "document.error.code.es_bulk_failed.message": "Failed to write vectors into the database", + "document.error.code.es_bulk_failed.suggestion": "Please ensure the Elasticsearch data path has sufficient disk space and write permissions.", + "document.error.code.es_dim_mismatch.message": "Embedding dimension does not match the Elasticsearch mapping", + "document.error.code.es_dim_mismatch.suggestion": "Please delete all embedding models and add the model again to try again.", + "document.error.code.embedding_chunks_exceed_limit.message": "The current chunk count exceeds the embedding model concurrency limit", + "document.error.code.embedding_chunks_exceed_limit.suggestion": "Please increase the chunk size to reduce the number of chunks and try again.", + "document.error.code.unsupported_file_format.message": "Unsupported line breaks detected in the document", + "document.error.code.unsupported_file_format.suggestion": "Please convert all line breaks to LF format and try again", "document.modal.deleteConfirm.title": "Confirm Delete Document", "document.modal.deleteConfirm.content": "Are you sure you want to delete this document? This action cannot be undone.", "document.message.noFiles": "Please select files first", @@ -655,6 +678,7 @@ "model.group.silicon": "Silicon Flow Models", "model.group.custom": "Custom Models", "model.status.tooltip": "Click to verify connectivity", + "model.dialog.embeddingConfig.title": "Edit Embedding Model: {{modelName}}", "appConfig.appName.label": "Application Name", "appConfig.appName.placeholder": "Please enter your application name", @@ -699,9 +723,14 @@ "modelConfig.button.addCustomModel": "Add Model", "modelConfig.button.editCustomModel": "Edit or Delete Model", "modelConfig.button.checkConnectivity": "Check Model Connectivity", + "modelConfig.button.sync": "Sync", + "modelConfig.button.add": "Add", + "modelConfig.button.edit": "Edit", + "modelConfig.button.check": "Check", "modelConfig.slider.chunkingSize": "Chunk Size", "modelConfig.slider.expectedChunkSize": "Expected Chunk Size", "modelConfig.slider.maximumChunkSize": "Maximum Chunk Size", + "modelConfig.input.chunkingBatchSize": "Concurrent Request Count", "businessLogic.title": "Describe how should this agent work", "businessLogic.placeholder": "Please describe your business scenario and requirements...", @@ -841,7 +870,7 @@ "mcpConfig.modal.updatingTools": "Updating tools list...", "mcpConfig.addServer.title": "Add MCP Server", "mcpConfig.addServer.namePlaceholder": "Server name", - "mcpConfig.addServer.urlPlaceholder": "Server URL (e.g.: http://localhost:3001/sse), currently only SSE protocol supported", + "mcpConfig.addServer.urlPlaceholder": "Server URL (e.g.: http://localhost:3001/mcp), currently supports sse and streamable-http protocols", "mcpConfig.addServer.button.add": "Add", "mcpConfig.addServer.button.updating": "Updating...", "mcpConfig.serverList.title": "Configured MCP Servers", @@ -909,6 +938,12 @@ "agentConfig.agents.createSubAgentIdFailed": "Failed to fetch creating sub agent ID, please try again later", "agentConfig.agents.detailsFetchFailed": "Failed to fetch agent details, please try again later", "agentConfig.agents.callRelationshipFetchFailed": "Failed to fetch agent call relationship, please try again later", + "agentConfig.agents.defaultDisplayName": "Agent", + "agentConfig.agents.copyConfirmTitle": "Confirm Copy", + "agentConfig.agents.copyConfirmContent": "Create a duplicate of {{name}}?", + "agentConfig.agents.copySuccess": "Agent copied successfully", + "agentConfig.agents.copyUnavailableTools": "Ignored {{count}} unavailable tools: {{names}}", + "agentConfig.agents.copyFailed": "Failed to copy Agent", "agentConfig.tools.refreshFailedDebug": "Failed to refresh tools list:", "agentConfig.agents.detailsLoadFailed": "Failed to load Agent details:", "agentConfig.agents.importFailed": "Failed to import Agent:", @@ -1117,6 +1152,7 @@ "market.category.all": "All", "market.category.other": "Other", "market.download": "Download", + "market.by": "By {{author}}", "market.downloading": "Downloading agent...", "market.downloadSuccess": "Agent downloaded successfully!", "market.downloadFailed": "Failed to download agent", @@ -1125,7 +1161,7 @@ "market.totalAgents": "Total {{total}} agents", "market.error.loadCategories": "Failed to load categories", "market.error.loadAgents": "Failed to load agents", - + "market.detail.title": "Agent Details", "market.detail.subtitle": "Complete information and configuration", "market.detail.tabs.basic": "Basic Info", @@ -1136,6 +1172,7 @@ "market.detail.id": "Agent ID", "market.detail.name": "Name", "market.detail.displayName": "Display Name", + "market.detail.author": "Author", "market.detail.description": "Description", "market.detail.businessDescription": "Business Description", "market.detail.category": "Category", @@ -1166,6 +1203,7 @@ "market.detail.viewDetails": "View Details", "market.install.title": "Install Agent", + "market.install.step.rename": "Rename Agent", "market.install.step.model": "Select Model", "market.install.step.config": "Configure Fields", "market.install.step.mcp": "MCP Servers", @@ -1203,7 +1241,31 @@ "market.install.error.mcpInstall": "Failed to install MCP server", "market.install.error.invalidData": "Invalid agent data", "market.install.error.installFailed": "Failed to install agent", + "market.install.error.noModelForRegeneration": "No available model for name regeneration", + "market.install.error.nameRegenerationFailed": "Failed to regenerate name", + "market.install.error.nameRequired": "Agent name is required", + "market.install.error.nameRequiredForAgent": "Agent name is required for {agent}", + "market.install.checkingName": "Checking agent name...", + "market.install.rename.warning": "The agent name or display name conflicts with existing agents. Please rename to proceed.", + "market.install.rename.conflictAgents": "Conflicting agents:", + "market.install.rename.name": "Agent Name", + "market.install.rename.regenerateWithLLM": "Regenerate with LLM", + "market.install.rename.regenerate": "Regenerate", + "market.install.rename.model": "Model for Regeneration", + "market.install.rename.modelPlaceholder": "Select a model", + "market.install.error.modelRequiredForRegeneration": "Please select a model first", + "market.install.rename.nameHint": "Original: {name}", + "market.install.rename.displayName": "Display Name", + "market.install.rename.displayNameHint": "Original: {name}", + "market.install.rename.note": "Note: If you proceed without renaming, the agent will be created but marked as unavailable due to name conflicts. You can rename it later in the agent list.", + "market.install.rename.oneClickDesc": "You can edit names manually, or use one-click rename to let the LLM generate new names for all conflicted agents.", + "market.install.rename.oneClick": "One-click Rename", + "market.install.rename.success": "All agent name conflicts have been resolved. You can proceed to the next step.", + "market.install.rename.partialSuccess": "Some agents have been successfully renamed.", + "market.install.rename.agentResolved": "This agent's name conflict has been resolved.", "market.install.success.mcpInstalled": "MCP server installed successfully", + "market.install.success.nameRegenerated": "Agent name regenerated successfully", + "market.install.success.nameRegeneratedAndResolved": "Agent names regenerated successfully and all conflicts resolved", "market.install.info.notImplemented": "Installation will be implemented in next phase", "market.install.success": "Agent installed successfully!", "market.error.fetchDetailFailed": "Failed to load agent details", @@ -1218,7 +1280,7 @@ "market.error.server.description": "The market server encountered an error. Our team has been notified. Please try again later.", "market.error.unknown.title": "Something Went Wrong", "market.error.unknown.description": "An unexpected error occurred. Please try again.", - + "common.loading": "Loading", "common.save": "Save", "common.cancel": "Cancel", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 65d80dacf..c0f8d851a 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -286,6 +286,8 @@ "agent.contextMenu.export": "导出", "agent.contextMenu.delete": "删除", + "agent.contextMenu.copy": "复制", + "agent.copySuffix": "副本", "agent.info.title": "Agent信息", "agent.info.name.error.empty": "名称不能为空", "agent.info.name.error.format": "名称只能包含字母、数字和下划线,且必须以字母或下划线开头", @@ -294,6 +296,9 @@ "agent.namePlaceholder": "请输入Agent变量名", "agent.displayName": "Agent名称", "agent.displayNamePlaceholder": "请输入Agent名称", + "agent.author": "作者", + "agent.authorPlaceholder": "请输入作者名称(可选)", + "agent.author.hint": "默认:{{email}}", "agent.description": "Agent描述", "agent.descriptionPlaceholder": "请输入Agent描述", "agent.detailContent.title": "Agent详细内容", @@ -370,7 +375,7 @@ "subAgentPool.tooltip.duplicateNameDisabled": "该智能体因与其他智能体同名而被禁用,请修改名称后使用", "subAgentPool.message.duplicateNameDisabled": "该智能体因与其他智能体同名而被禁用,请修改名称后使用", - "toolConfig.title.paramConfig": "参数配置", + "toolConfig.title.paramConfig": "配置参数", "toolConfig.message.loadError": "加载工具配置失败", "toolConfig.message.loadErrorUseDefault": "加载工具配置失败,使用默认配置", "toolConfig.message.saveSuccess": "工具配置保存成功", @@ -414,7 +419,6 @@ "toolPool.error.requiredFields": "以下必填字段未填写: {{fields}}", "toolPool.tooltip.functionGuide": "1. 本地知识库检索功能,请启用knowledge_base_search工具;\n2. 文本文件解析功能,请启用analyze_text_file工具;\n3. 图片解析功能,请启用analyze_image工具。", - "tool.message.unavailable": "该工具当前不可用,无法选择", "tool.error.noMainAgentId": "主代理ID未设置,无法更新工具状态", "tool.error.configFetchFailed": "获取工具配置失败", @@ -503,6 +507,7 @@ "document.summary.modelPlaceholder": "选择模型", "document.status.creating": "创建中...", "document.status.loadingList": "正在加载文档列表...", + "document.status.waitingForTask": "正在等待任务创建...", "document.input.knowledgeBaseName": "请输入知识库名称", "document.button.details": "详细内容", "document.button.overview": "概览", @@ -523,6 +528,24 @@ "document.status.completed": "已就绪", "document.status.processFailed": "解析失败", "document.status.forwardFailed": "入库失败", + "document.progress.chunksProcessed": "已处理 {{processed}}/{{total}} 个切片 ({{percent}}%)", + "document.error.reason": "错误原因", + "document.error.suggestion": "建议", + "document.error.noReason": "暂无错误原因", + "document.error.code.ray_init_failed.message": "Ray集群初始化失败", + "document.error.code.ray_init_failed.suggestion": "请升级到最新版本并尝试重新部署", + "document.error.code.no_valid_chunks.message": "数据处理内核无法从文档中提取有效文本", + "document.error.code.no_valid_chunks.suggestion": "请确保文档内容非纯图像", + "document.error.code.vector_service_busy.message": "向量化模型服务繁忙,无法获取文本向量", + "document.error.code.vector_service_busy.suggestion": "请更换模型服务提供商,或稍后重试", + "document.error.code.es_bulk_failed.message": "向量录入数据库错误", + "document.error.code.es_bulk_failed.suggestion": "请确保Elasticsearch路径拥有完整写入权限,且存储空间与内存充足", + "document.error.code.es_dim_mismatch.message": "向量化模型维度与Elasticsearch维度不匹配", + "document.error.code.es_dim_mismatch.suggestion": "建议删除所有向量化模型后再添加模型重试", + "document.error.code.embedding_chunks_exceed_limit.message": "当前切片数量超过向量化模型并行度", + "document.error.code.embedding_chunks_exceed_limit.suggestion": "请增加切片大小以减少切片数量后再试", + "document.error.code.unsupported_file_format.message": "检测到当前文档中存在不支持的换行符", + "document.error.code.unsupported_file_format.suggestion": "建议统一转换为LF换行符再试", "document.modal.deleteConfirm.title": "确认删除文档", "document.modal.deleteConfirm.content": "确定要删除这个文档吗?删除后无法恢复。", "document.message.noFiles": "请先选择文件", @@ -655,6 +678,7 @@ "model.group.custom": "自定义模型", "model.status.tooltip": "点击可验证连通性", "model.dialog.success.updateSuccess": "更新成功", + "model.dialog.embeddingConfig.title": "修改向量模型: {{modelName}}", "appConfig.appName.label": "应用名称", "appConfig.appName.placeholder": "请输入您的应用名称", @@ -699,9 +723,14 @@ "modelConfig.button.addCustomModel": "添加模型", "modelConfig.button.editCustomModel": "修改或删除模型", "modelConfig.button.checkConnectivity": "检查模型连通性", + "modelConfig.button.sync": "同步", + "modelConfig.button.add": "添加", + "modelConfig.button.edit": "修改", + "modelConfig.button.check": "检查", "modelConfig.slider.chunkingSize": "文档切片大小", "modelConfig.slider.expectedChunkSize": "期望切片大小", "modelConfig.slider.maximumChunkSize": "最大切片大小", + "modelConfig.input.chunkingBatchSize": "单次请求切片量", "businessLogic.title": "描述 Agent 应该如何工作", "businessLogic.placeholder": "请描述您的业务场景和需求...", @@ -841,7 +870,7 @@ "mcpConfig.modal.updatingTools": "正在更新工具列表...", "mcpConfig.addServer.title": "添加MCP服务器", "mcpConfig.addServer.namePlaceholder": "服务器名称", - "mcpConfig.addServer.urlPlaceholder": "服务器URL (如: http://localhost:3001/sse),目前仅支持sse协议", + "mcpConfig.addServer.urlPlaceholder": "服务器URL (如: http://localhost:3001/mcp),目前支持sse和streamable-http协议", "mcpConfig.addServer.button.add": "添加", "mcpConfig.addServer.button.updating": "更新中...", "mcpConfig.serverList.title": "已配置的MCP服务器", @@ -909,6 +938,12 @@ "agentConfig.agents.createSubAgentIdFailed": "获取创建子Agent ID失败,请稍后重试", "agentConfig.agents.detailsFetchFailed": "获取Agent详情失败,请稍后重试", "agentConfig.agents.callRelationshipFetchFailed": "获取Agent调用关系失败,请稍后重试", + "agentConfig.agents.defaultDisplayName": "智能体", + "agentConfig.agents.copyConfirmTitle": "确认复制", + "agentConfig.agents.copyConfirmContent": "确定要复制 {{name}} 吗?", + "agentConfig.agents.copySuccess": "Agent复制成功", + "agentConfig.agents.copyUnavailableTools": "已忽略{{count}}个不可用工具:{{names}}", + "agentConfig.agents.copyFailed": "Agent复制失败", "agentConfig.tools.refreshFailedDebug": "刷新工具列表失败:", "agentConfig.agents.detailsLoadFailed": "加载Agent详情失败:", "agentConfig.agents.importFailed": "导入Agent失败:", @@ -1081,7 +1116,7 @@ "sidebar.memoryManagement": "记忆管理", "sidebar.userManagement": "用户管理", "sidebar.mcpToolsManagement": "MCP 工具", - "sidebar.monitoringManagement": "监控与运维", + "sidebar.monitoringManagement": "监控与运维", "market.comingSoon.title": "智能体市场即将推出", "market.comingSoon.description": "从我们的市场中发现并安装预构建的AI智能体。通过使用社区创建的解决方案节省时间。", @@ -1096,6 +1131,7 @@ "market.category.all": "全部", "market.category.other": "其他", "market.download": "下载", + "market.by": "作者:{{author}}", "market.downloading": "正在下载智能体...", "market.downloadSuccess": "智能体下载成功!", "market.downloadFailed": "下载智能体失败", @@ -1104,7 +1140,7 @@ "market.totalAgents": "共 {{total}} 个智能体", "market.error.loadCategories": "加载分类失败", "market.error.loadAgents": "加载智能体失败", - + "market.detail.title": "智能体详情", "market.detail.subtitle": "完整信息和配置", "market.detail.tabs.basic": "基础信息", @@ -1115,6 +1151,7 @@ "market.detail.id": "智能体 ID", "market.detail.name": "名称", "market.detail.displayName": "显示名称", + "market.detail.author": "作者", "market.detail.description": "描述", "market.detail.businessDescription": "业务描述", "market.detail.category": "分类", @@ -1145,6 +1182,7 @@ "market.detail.viewDetails": "查看详情", "market.install.title": "安装智能体", + "market.install.step.rename": "重命名智能体", "market.install.step.model": "选择模型", "market.install.step.config": "配置字段", "market.install.step.mcp": "MCP 服务器", @@ -1182,7 +1220,31 @@ "market.install.error.mcpInstall": "安装 MCP 服务器失败", "market.install.error.invalidData": "无效的智能体数据", "market.install.error.installFailed": "安装智能体失败", + "market.install.error.noModelForRegeneration": "没有可用的模型用于名称重新生成", + "market.install.error.nameRegenerationFailed": "重新生成名称失败", + "market.install.error.nameRequired": "智能体名称为必填项", + "market.install.error.nameRequiredForAgent": "智能体 {agent} 的名称为必填项", + "market.install.checkingName": "正在检查智能体名称...", + "market.install.rename.warning": "智能体名称或显示名称与现有智能体冲突,请重命名以继续。", + "market.install.rename.conflictAgents": "冲突的智能体:", + "market.install.rename.name": "智能体名称", + "market.install.rename.regenerateWithLLM": "使用 LLM 重新生成", + "market.install.rename.regenerate": "重新生成", + "market.install.rename.model": "用于重新生成名称的模型", + "market.install.rename.modelPlaceholder": "选择一个模型", + "market.install.error.modelRequiredForRegeneration": "请先选择一个模型", + "market.install.rename.nameHint": "原始名称:{name}", + "market.install.rename.displayName": "显示名称", + "market.install.rename.displayNameHint": "原始名称:{name}", + "market.install.rename.note": "注意:如果您不重命名就继续,智能体将被创建但由于名称冲突会被标记为不可用。您可以在智能体列表中稍后重命名。", + "market.install.rename.oneClickDesc": "可手动修改名称,或一键重命名使用大模型为所有冲突智能体生成新名称。", + "market.install.rename.oneClick": "一键重命名", + "market.install.rename.success": "所有智能体名称冲突已解决。您可以继续下一步。", + "market.install.rename.partialSuccess": "部分智能体已成功重命名。", + "market.install.rename.agentResolved": "此智能体的名称冲突已解决。", "market.install.success.mcpInstalled": "MCP 服务器安装成功", + "market.install.success.nameRegenerated": "智能体名称重新生成成功", + "market.install.success.nameRegeneratedAndResolved": "智能体名称重新生成成功,且所有冲突已解决", "market.install.info.notImplemented": "安装功能将在下一阶段实现", "market.install.success": "智能体安装成功!", "market.error.fetchDetailFailed": "加载智能体详情失败", @@ -1211,14 +1273,14 @@ "mcpTools.comingSoon.feature2": "同步、查看和组织 MCP 工具列表", "mcpTools.comingSoon.feature3": "监控 MCP 连接状态和使用情况", "mcpTools.comingSoon.badge": "即将推出", - + "monitoring.comingSoon.title": "监控与运维中心即将推出", "monitoring.comingSoon.description": "面向智能体的统一监控与运维中心,用于实时跟踪健康状态、性能指标与异常事件。", "monitoring.comingSoon.feature1": "监控智能体健康状态、延迟与错误率", "monitoring.comingSoon.feature2": "查看并筛选智能体运行日志和历史任务", "monitoring.comingSoon.feature3": "配置告警策略与关键事件的运维操作", "monitoring.comingSoon.badge": "即将推出", - + "common.loading": "加载中", "common.save": "保存", "common.cancel": "取消", diff --git a/frontend/services/agentConfigService.ts b/frontend/services/agentConfigService.ts index f7f084f6b..3cff1e884 100644 --- a/frontend/services/agentConfigService.ts +++ b/frontend/services/agentConfigService.ts @@ -116,6 +116,7 @@ export const fetchAgentList = async () => { name: agent.name, display_name: agent.display_name || agent.name, description: agent.description, + author: agent.author, is_available: agent.is_available, unavailable_reasons: agent.unavailable_reasons || [], })); @@ -326,7 +327,8 @@ export const updateAgent = async ( businessLogicModelName?: string, businessLogicModelId?: number, enabledToolIds?: number[], - relatedAgentIds?: number[] + relatedAgentIds?: number[], + author?: string ) => { try { const response = await fetch(API_ENDPOINTS.agent.update, { @@ -350,6 +352,7 @@ export const updateAgent = async ( business_logic_model_id: businessLogicModelId, enabled_tool_ids: enabledToolIds, related_agent_ids: relatedAgentIds, + author: author, }), }); @@ -485,6 +488,76 @@ export const importAgent = async ( } }; +/** + * check agent name/display_name duplication + * @param payload name/displayName to check + */ +export const checkAgentNameConflictBatch = async (payload: { + items: Array<{ name: string; display_name?: string; agent_id?: number }>; +}) => { + try { + const response = await fetch(API_ENDPOINTS.agent.checkNameBatch, { + method: "POST", + headers: getAuthHeaders(), + body: JSON.stringify(payload), + }); + + if (!response.ok) { + throw new Error(`Request failed: ${response.status}`); + } + + const data = await response.json(); + return { + success: true, + data, + message: "", + }; + } catch (error) { + log.error("Failed to check agent name conflict batch:", error); + return { + success: false, + data: null, + message: "agentConfig.agents.checkNameFailed", + }; + } +}; + +export const regenerateAgentNameBatch = async (payload: { + items: Array<{ + name: string; + display_name?: string; + task_description?: string; + language?: string; + agent_id?: number; + }>; +}) => { + try { + const response = await fetch(API_ENDPOINTS.agent.regenerateNameBatch, { + method: "POST", + headers: getAuthHeaders(), + body: JSON.stringify(payload), + }); + + if (!response.ok) { + throw new Error(`Request failed: ${response.status}`); + } + + const data = await response.json(); + return { + success: true, + data, + message: "", + }; + } catch (error) { + log.error("Failed to regenerate agent name batch:", error); + return { + success: false, + data: null, + message: "agentConfig.agents.regenerateNameFailed", + }; + } +}; + /** * search agent info by agent id * @param agentId agent id @@ -510,6 +583,7 @@ export const searchAgentInfo = async (agentId: number) => { name: data.name, display_name: data.display_name, description: data.description, + author: data.author, model: data.model_name, model_id: data.model_id, max_step: data.max_steps, @@ -587,6 +661,7 @@ export const fetchAllAgents = async () => { name: agent.name, display_name: agent.display_name || agent.name, description: agent.description, + author: agent.author, is_available: agent.is_available, })); diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 0af193d52..20d89b6f2 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -37,6 +37,8 @@ export const API_ENDPOINTS = { `${API_BASE_URL}/agent/stop/${conversationId}`, export: `${API_BASE_URL}/agent/export`, import: `${API_BASE_URL}/agent/import`, + checkNameBatch: `${API_BASE_URL}/agent/check_name`, + regenerateNameBatch: `${API_BASE_URL}/agent/regenerate_name`, searchInfo: `${API_BASE_URL}/agent/search_info`, callRelationship: `${API_BASE_URL}/agent/call_relationship`, }, @@ -142,6 +144,11 @@ export const API_ENDPOINTS = { // File upload service upload: `${API_BASE_URL}/file/upload`, process: `${API_BASE_URL}/file/process`, + // Error info service + getErrorInfo: (indexName: string, pathOrUrl: string) => + `${API_BASE_URL}/indices/${indexName}/documents/${encodeURIComponent( + pathOrUrl + )}/error-info`, }, config: { save: `${API_BASE_URL}/config/save_config`, diff --git a/frontend/services/knowledgeBasePollingService.ts b/frontend/services/knowledgeBasePollingService.ts index 568205b21..b899d8bdf 100644 --- a/frontend/services/knowledgeBasePollingService.ts +++ b/frontend/services/knowledgeBasePollingService.ts @@ -11,8 +11,12 @@ class KnowledgeBasePollingService { private knowledgeBasePollingInterval: number = 1000; // 1 second private documentPollingInterval: number = 3000; // 3 seconds private maxKnowledgeBasePolls: number = 60; // Maximum 60 polling attempts - private maxDocumentPolls: number = 20; // Maximum 20 polling attempts + private maxDocumentPolls: number = 200; // Maximum 200 polling attempts (10 minutes for long-running tasks) private activeKnowledgeBaseId: string | null = null; // Record current active knowledge base ID + private pendingRequests: Map> = new Map(); + + // Debounce timers for batching multiple rapid requests + private debounceTimers: Map = new Map(); // Set current active knowledge base ID setActiveKnowledgeBase(kbId: string | null): void { @@ -29,11 +33,16 @@ class KnowledgeBasePollingService { // Initialize polling counter let pollCount = 0; + // Track if we're in extended polling mode (after initial timeout) + let isExtendedPolling = false; + // Define the polling logic function const pollDocuments = async () => { try { - // Increment polling counter - pollCount++; + // Increment polling counter only if not in extended polling mode + if (!isExtendedPolling) { + pollCount++; + } // If there is an active knowledge base and polling knowledge base doesn't match active one, stop polling if (this.activeKnowledgeBaseId !== null && this.activeKnowledgeBaseId !== kbId) { @@ -41,24 +50,28 @@ class KnowledgeBasePollingService { return; } - // If exceeded maximum polling count, handle timeout - if (pollCount > this.maxDocumentPolls) { - log.warn(`Document polling for knowledge base ${kbId} timed out after ${this.maxDocumentPolls} attempts`); - await this.handlePollingTimeout(kbId, 'document', callback); - // Push documents to UI + // Use request deduplication to avoid concurrent duplicate requests + let documents: Document[]; + const requestKey = `poll:${kbId}`; + + // Check if there's already a pending request for this KB + const pendingRequest = this.pendingRequests.get(requestKey); + if (pendingRequest) { + // Reuse existing request to avoid duplicate API calls + documents = await pendingRequest; + } else { + // Create new request and track it + const requestPromise = knowledgeBaseService.getAllFiles(kbId); + this.pendingRequests.set(requestKey, requestPromise); + try { - const documents = await knowledgeBaseService.getAllFiles(kbId); - this.triggerDocumentsUpdate(kbId, documents); - } catch (e) { - // Ignore error + documents = await requestPromise; + } finally { + // Clean up after request completes + this.pendingRequests.delete(requestKey); } - this.stopPolling(kbId); - return; } - // Get latest document status - const documents = await knowledgeBaseService.getAllFiles(kbId); - // Call callback function with latest documents first to ensure UI updates immediately callback(documents); @@ -67,6 +80,18 @@ class KnowledgeBasePollingService { NON_TERMINAL_STATUSES.includes(doc.status) ); + // If exceeded maximum polling count and still processing, switch to extended polling mode + if (pollCount > this.maxDocumentPolls && hasProcessingDocs && !isExtendedPolling) { + log.warn(`Document polling for knowledge base ${kbId} exceeded ${this.maxDocumentPolls} attempts, switching to extended polling mode (reduced frequency)`); + isExtendedPolling = true; + // Stop the current interval and restart with longer interval + this.stopPolling(kbId); + // Continue polling with reduced frequency (every 10 seconds) + const extendedInterval = setInterval(pollDocuments, 10000); + this.pollingIntervals.set(kbId, extendedInterval); + return; + } + // If there are processing documents, continue polling if (hasProcessingDocs) { log.log('Documents processing, continue polling'); @@ -141,6 +166,7 @@ class KnowledgeBasePollingService { * @param expectedIncrement The number of new files uploaded */ pollForKnowledgeBaseReady( + kbId: string, kbName: string, originalDocumentCount: number = 0, expectedIncrement: number = 0 @@ -150,29 +176,14 @@ class KnowledgeBasePollingService { const checkForStats = async () => { try { const kbs = await knowledgeBaseService.getKnowledgeBasesInfo(true) as KnowledgeBase[]; - const kb = kbs.find(k => k.name === kbName); + const kb = kbs.find(k => k.id === kbId || k.name === kbName); // Check if KB exists and its stats are populated if (kb) { - // If expectedIncrement > 0, check if documentCount increased as expected - if ( - expectedIncrement > 0 && - kb.documentCount >= (originalDocumentCount + expectedIncrement) - ) { - log.log( - `Knowledge base ${kbName} documentCount increased as expected: ${kb.documentCount} (was ${originalDocumentCount}, expected increment ${expectedIncrement})` - ); - this.triggerKnowledgeBaseListUpdate(true); - resolve(kb); - return; - } - // Fallback: for new KB or no increment specified, use old logic - if (expectedIncrement === 0 && (kb.documentCount > 0 || kb.chunkCount > 0)) { - log.log(`Knowledge base ${kbName} is ready and stats are populated.`); - this.triggerKnowledgeBaseListUpdate(true); - resolve(kb); - return; - } + log.log(`Knowledge base ${kbName} detected.`); + this.triggerKnowledgeBaseListUpdate(true); + resolve(kb); + return; } count++; @@ -183,11 +194,11 @@ class KnowledgeBasePollingService { log.error(`Knowledge base ${kbName} readiness check timed out after ${this.maxKnowledgeBasePolls} attempts.`); // Handle knowledge base polling timeout - mark related tasks as failed - await this.handlePollingTimeout(kbName, 'knowledgeBase'); + await this.handlePollingTimeout(kbId, 'knowledgeBase'); // Push documents to UI try { - const documents = await knowledgeBaseService.getAllFiles(kbName); - this.triggerDocumentsUpdate(kbName, documents); + const documents = await knowledgeBaseService.getAllFiles(kbId); + this.triggerDocumentsUpdate(kbId, documents); } catch (e) { // Ignore error } @@ -201,11 +212,11 @@ class KnowledgeBasePollingService { setTimeout(checkForStats, this.knowledgeBasePollingInterval); } else { // Handle knowledge base polling timeout on error as well - await this.handlePollingTimeout(kbName, 'knowledgeBase'); + await this.handlePollingTimeout(kbId, 'knowledgeBase'); // Push documents to UI try { - const documents = await knowledgeBaseService.getAllFiles(kbName); - this.triggerDocumentsUpdate(kbName, documents); + const documents = await knowledgeBaseService.getAllFiles(kbId); + this.triggerDocumentsUpdate(kbId, documents); } catch (e) { // Ignore error } @@ -218,14 +229,14 @@ class KnowledgeBasePollingService { } // Simplified method for new knowledge base creation workflow - async handleNewKnowledgeBaseCreation(kbName: string, originalDocumentCount: number = 0, expectedIncrement: number = 0, callback: (kb: KnowledgeBase) => void) { + async handleNewKnowledgeBaseCreation(kbId: string, kbName: string, originalDocumentCount: number = 0, expectedIncrement: number = 0, callback: (kb: KnowledgeBase) => void) { // Start document polling - this.startDocumentStatusPolling(kbName, (documents) => { - this.triggerDocumentsUpdate(kbName, documents); + this.startDocumentStatusPolling(kbId, (documents) => { + this.triggerDocumentsUpdate(kbId, documents); }); try { // Start knowledge base polling parallelly - const populatedKB = await this.pollForKnowledgeBaseReady(kbName, originalDocumentCount, expectedIncrement); + const populatedKB = await this.pollForKnowledgeBaseReady(kbId, kbName, originalDocumentCount, expectedIncrement); // callback with populated knowledge base when everything is ready callback(populatedKB); } catch (error) { @@ -249,6 +260,13 @@ class KnowledgeBasePollingService { clearInterval(interval); }); this.pollingIntervals.clear(); + + // Clear pending requests and debounce timers to prevent memory leaks + this.pendingRequests.clear(); + this.debounceTimers.forEach((timer) => { + clearTimeout(timer); + }); + this.debounceTimers.clear(); } // Trigger knowledge base list update (optionally force refresh) diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts index 27a6e0b38..0ea443081 100644 --- a/frontend/services/knowledgeBaseService.ts +++ b/frontend/services/knowledgeBaseService.ts @@ -71,15 +71,20 @@ class KnowledgeBaseService { // Convert Elasticsearch indices to knowledge base format knowledgeBases = data.indices_info.map((indexInfo: any) => { const stats = indexInfo.stats?.base_info || {}; + // Backend now returns: + // - name: internal index_name + // - display_name: user-facing knowledge_name (fallback to index_name) + const kbId = indexInfo.name; + const kbName = indexInfo.display_name || indexInfo.name; return { - id: indexInfo.name, - name: indexInfo.name, + id: kbId, + name: kbName, description: "Elasticsearch index", documentCount: stats.doc_count || 0, chunkCount: stats.chunk_count || 0, - createdAt: - stats.creation_date || new Date().toISOString().split("T")[0], + createdAt: stats.creation_date || null, + updatedAt: stats.update_date || stats.creation_date || null, embeddingModel: stats.embedding_model || "unknown", avatar: "", chunkNum: 0, @@ -276,6 +281,16 @@ class KnowledgeBaseService { token_num: 0, status: file.status || "UNKNOWN", latest_task_id: file.latest_task_id || "", + error_reason: file.error_reason, + // Optional ingestion progress metrics (only present for in-progress files) + processed_chunk_num: + typeof file.processed_chunk_num === "number" + ? file.processed_chunk_num + : null, + total_chunk_num: + typeof file.total_chunk_num === "number" + ? file.total_chunk_num + : null, })); } catch (error) { log.error("Failed to get all files:", error); @@ -806,6 +821,41 @@ class KnowledgeBaseService { throw new Error("Failed to execute hybrid search"); } } + + // Get error information for a document + async getDocumentErrorInfo( + kbId: string, + docId: string + ): Promise<{ + errorCode: string | null; + }> { + try { + const response = await fetch( + API_ENDPOINTS.knowledgeBase.getErrorInfo(kbId, docId), + { + headers: getAuthHeaders(), + } + ); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + const data = await response.json(); + if (data.status !== "success") { + throw new Error(data.message || "Failed to get error info"); + } + + const errorCode = (data.error_code && String(data.error_code)) || null; + + return { + errorCode, + }; + } catch (error) { + log.error("Failed to get document error info:", error); + throw error; + } + } } // Export a singleton instance diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts index 9de2c5483..3599bc939 100644 --- a/frontend/services/modelService.ts +++ b/frontend/services/modelService.ts @@ -67,6 +67,7 @@ export const modelService = { (model.connect_status as ModelConnectStatus) || "not_detected", expectedChunkSize: model.expected_chunk_size, maximumChunkSize: model.maximum_chunk_size, + chunkingBatchSize: model.chunk_batch, })); } return []; @@ -97,6 +98,7 @@ export const modelService = { displayName?: string; expectedChunkSize?: number; maximumChunkSize?: number; + chunkingBatchSize?: number; }): Promise => { try { const response = await fetch(API_ENDPOINTS.model.customModelCreate, { @@ -112,6 +114,7 @@ export const modelService = { display_name: model.displayName, expected_chunk_size: model.expectedChunkSize, maximum_chunk_size: model.maximumChunkSize, + chunk_batch: model.chunkingBatchSize, }), }); @@ -239,6 +242,7 @@ export const modelService = { source?: ModelSource; expectedChunkSize?: number; maximumChunkSize?: number; + chunkingBatchSize?: number; }): Promise => { try { const response = await fetch( @@ -262,6 +266,9 @@ export const modelService = { ...(model.maximumChunkSize !== undefined ? { maximum_chunk_size: model.maximumChunkSize } : {}), + ...(model.chunkingBatchSize !== undefined + ? { chunk_batch: model.chunkingBatchSize } + : {}), }), } ); diff --git a/frontend/services/storageService.ts b/frontend/services/storageService.ts index ec60eb187..a45add994 100644 --- a/frontend/services/storageService.ts +++ b/frontend/services/storageService.ts @@ -123,6 +123,68 @@ export function convertImageUrlToApiUrl(url: string): string { return url; } +const arrayBufferToBase64 = (buffer: ArrayBuffer): string => { + let binary = ""; + const bytes = new Uint8Array(buffer); + const chunkSize = 0x8000; + + for (let i = 0; i < bytes.length; i += chunkSize) { + const chunk = bytes.subarray(i, i + chunkSize); + binary += String.fromCharCode(...chunk); + } + + return btoa(binary); +}; + +const fetchBase64ViaStorage = async (objectName: string) => { + const response = await fetch(API_ENDPOINTS.storage.file(objectName, "base64")); + if (!response.ok) { + throw new Error(`Failed to resolve S3 URL via storage: ${response.status}`); + } + + const data = await response.json(); + if (!data?.success || !data?.base64) { + throw new Error(data?.error || "Storage response missing base64 content"); + } + + const contentType = data.content_type || "application/octet-stream"; + return { base64: data.base64 as string, contentType }; +}; + +// Cache for S3 URL to data URL resolution to avoid duplicate network requests +const s3ResolutionCache = new Map>(); + +// Internal helper: for s3:// URLs, resolve directly via storage download endpoint. +async function resolveS3UrlToDataUrlInternal(url: string): Promise { + const objectName = extractObjectNameFromUrl(url); + if (!objectName) { + return null; + } + + const { base64, contentType } = await fetchBase64ViaStorage(objectName); + return `data:${contentType};base64,${base64}`; +} + +export async function resolveS3UrlToDataUrl(url: string): Promise { + if (!url || !url.startsWith("s3://")) { + return null; + } + + const cached = s3ResolutionCache.get(url); + if (cached) { + return cached; + } + + const promise = resolveS3UrlToDataUrlInternal(url).catch((error) => { + // Remove from cache on failure so that future attempts can retry. + s3ResolutionCache.delete(url); + throw error; + }); + + s3ResolutionCache.set(url, promise); + return promise; +} + export const storageService = { /** * Upload files to storage service diff --git a/frontend/styles/globals.css b/frontend/styles/globals.css index 7d6b1749d..ad666027d 100644 --- a/frontend/styles/globals.css +++ b/frontend/styles/globals.css @@ -305,4 +305,23 @@ .kb-embedding-warning .ant-modal { width: max-content; min-width: 0; +} + +/* Responsive button text - global utility */ +@media (max-width: 1279px) { + .button-text-full { + display: none !important; + } + .button-text-short { + display: inline !important; + } +} + +@media (min-width: 1280px) { + .button-text-full { + display: inline !important; + } + .button-text-short { + display: none !important; + } } \ No newline at end of file diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index 3dc41c601..1a766788c 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -12,6 +12,7 @@ export interface Agent { name: string; display_name?: string; description: string; + author?: string; unavailable_reasons?: string[]; model: string; model_id?: number; @@ -127,6 +128,8 @@ export interface AgentSetupOrchestratorProps { setAgentDescription?: (value: string) => void; agentDisplayName?: string; setAgentDisplayName?: (value: string) => void; + agentAuthor?: string; + setAgentAuthor?: (value: string) => void; isGeneratingAgent?: boolean; onDebug?: () => void; getCurrentAgentId?: () => number | undefined; @@ -156,6 +159,7 @@ export interface SubAgentPoolProps { isGeneratingAgent?: boolean; editingAgent?: Agent | null; isCreatingNewAgent?: boolean; + onCopyAgent?: (agent: Agent) => void; onExportAgent?: (agent: Agent) => void; onDeleteAgent?: (agent: Agent) => void; } diff --git a/frontend/types/chat.ts b/frontend/types/chat.ts index 700edfdbf..826722055 100644 --- a/frontend/types/chat.ts +++ b/frontend/types/chat.ts @@ -9,6 +9,7 @@ export interface StepSection { export interface StepContent { id: string type: typeof chatConfig.messageTypes.MODEL_OUTPUT | + typeof chatConfig.messageTypes.MODEL_OUTPUT_CODE | typeof chatConfig.messageTypes.PARSING | typeof chatConfig.messageTypes.EXECUTION | typeof chatConfig.messageTypes.ERROR | diff --git a/frontend/types/knowledgeBase.ts b/frontend/types/knowledgeBase.ts index 85a5e6b12..e04f145c7 100644 --- a/frontend/types/knowledgeBase.ts +++ b/frontend/types/knowledgeBase.ts @@ -4,21 +4,23 @@ import { DOCUMENT_ACTION_TYPES, KNOWLEDGE_BASE_ACTION_TYPES, UI_ACTION_TYPES, NO // Knowledge base basic type export interface KnowledgeBase { - id: string - name: string - description: string | null - chunkCount: number - documentCount: number - createdAt: any - embeddingModel: string - avatar: string - chunkNum: number - language: string - nickname: string - parserId: string - permission: string - tokenNum: number - source: string + id: string; + name: string; + description: string | null; + chunkCount: number; + documentCount: number; + createdAt: any; + // Last update time of the knowledge base/index (may fall back to createdAt) + updatedAt?: any; + embeddingModel: string; + avatar: string; + chunkNum: number; + language: string; + nickname: string; + parserId: string; + permission: string; + tokenNum: number; + source: string; } // Create knowledge base parameter type @@ -31,17 +33,21 @@ export interface KnowledgeBaseCreateParams { // Document type export interface Document { - id: string - kb_id: string - name: string - type: string - size: number - create_time: string - chunk_num: number - token_num: number - status: string - selected?: boolean // For UI selection status - latest_task_id: string // For marking the latest celery task + id: string; + kb_id: string; + name: string; + type: string; + size: number; + create_time: string; + chunk_num: number; + token_num: number; + status: string; + selected?: boolean; // For UI selection status + latest_task_id: string; // For marking the latest celery task + error_reason?: string; // Error reason for failed documents + // Optional ingestion progress metrics + processed_chunk_num?: number | null; + total_chunk_num?: number | null; } // Document state interface diff --git a/frontend/types/market.ts b/frontend/types/market.ts index 888afffdb..770e39520 100644 --- a/frontend/types/market.ts +++ b/frontend/types/market.ts @@ -28,6 +28,7 @@ export interface MarketAgentListItem { name: string; display_name: string; description: string; + author?: string; category: MarketCategory; tags: MarketTag[]; download_count: number; diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index db97a8c0d..0d463161f 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -45,6 +45,7 @@ export interface ModelOption { connect_status?: ModelConnectStatus; expectedChunkSize?: number; maximumChunkSize?: number; + chunkingBatchSize?: number; } // Application configuration interface diff --git a/sdk/nexent/core/agents/agent_model.py b/sdk/nexent/core/agents/agent_model.py index 6eff00718..f3c5a77b7 100644 --- a/sdk/nexent/core/agents/agent_model.py +++ b/sdk/nexent/core/agents/agent_model.py @@ -1,7 +1,7 @@ from __future__ import annotations from threading import Event -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Field @@ -50,7 +50,12 @@ class AgentRunInfo(BaseModel): model_config_list: List[ModelConfig] = Field(description="List of model configurations") observer: MessageObserver = Field(description="Return data") agent_config: AgentConfig = Field(description="Detailed Agent configuration") - mcp_host: Optional[List[str]] = Field(description="MCP server address", default=None) + mcp_host: Optional[List[Union[str, Dict[str, Any]]]] = Field( + description="MCP server address(es). Can be a string (URL) or dict with 'url' and 'transport' keys. " + "Transport can be 'sse' or 'streamable-http'. If string, transport is auto-detected based on URL ending: " + "URLs ending with '/sse' use 'sse' transport, URLs ending with '/mcp' use 'streamable-http' transport.", + default=None + ) history: Optional[List[AgentHistory]] = Field(description="Historical conversation information", default=None) stop_event: Event = Field(description="Stop event control") diff --git a/sdk/nexent/core/agents/core_agent.py b/sdk/nexent/core/agents/core_agent.py index 826ef7093..be7b83b5e 100644 --- a/sdk/nexent/core/agents/core_agent.py +++ b/sdk/nexent/core/agents/core_agent.py @@ -1,3 +1,4 @@ +import json import re import ast import time @@ -9,12 +10,13 @@ from rich.console import Group from rich.text import Text -from smolagents.agents import CodeAgent, handle_agent_output_types, AgentError +from smolagents.agents import CodeAgent, handle_agent_output_types, AgentError, ActionOutput, RunResult from smolagents.local_python_executor import fix_final_answer_code from smolagents.memory import ActionStep, PlanningStep, FinalAnswerStep, ToolCall, TaskStep, SystemPromptStep -from smolagents.models import ChatMessage -from smolagents.monitoring import LogLevel -from smolagents.utils import AgentExecutionError, AgentGenerationError, truncate_content +from smolagents.models import ChatMessage, CODEAGENT_RESPONSE_FORMAT +from smolagents.monitoring import LogLevel, Timing, YELLOW_HEX, TokenUsage +from smolagents.utils import AgentExecutionError, AgentGenerationError, truncate_content, AgentMaxStepsError, \ + extract_code_from_text from ..utils.observer import MessageObserver, ProcessType from jinja2 import Template, StrictUndefined @@ -125,13 +127,17 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: # Add new step in logs memory_step.model_input_messages = input_messages + stop_sequences = ["", "Observation:", "Calling tools:", "", "Observation:", "Calling tools:", " Generator[Any]: # Parse try: - code_action = fix_final_answer_code(parse_code_blobs(model_output)) + if self._use_structured_outputs_internally: + code_action = json.loads(model_output)["code"] + code_action = extract_code_from_text(code_action, self.code_block_tags) or code_action + else: + code_action = parse_code_blobs(model_output) + code_action = fix_final_answer_code(code_action) + memory_step.code_action = code_action # Record parsing results self.observer.add_message( self.agent_name, ProcessType.PARSE, code_action) @@ -155,26 +167,29 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: content=model_output, title="AGENT FINAL ANSWER", level=LogLevel.INFO) raise FinalAnswerError() - memory_step.tool_calls = [ - ToolCall(name="python_interpreter", arguments=code_action, id=f"call_{len(self.memory.steps)}", )] + tool_call = ToolCall( + name="python_interpreter", + arguments=code_action, + id=f"call_{len(self.memory.steps)}", + ) + memory_step.tool_calls = [tool_call] # Execute self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO) - is_final_answer = False try: - output, execution_logs, is_final_answer = self.python_executor( - code_action) - + code_output = self.python_executor(code_action) execution_outputs_console = [] - if len(execution_logs) > 0: + if len(code_output.logs) > 0: # Record execution results self.observer.add_message( - self.agent_name, ProcessType.EXECUTION_LOGS, f"{execution_logs}") + self.agent_name, ProcessType.EXECUTION_LOGS, f"{code_output.logs}") execution_outputs_console += [ - Text("Execution logs:", style="bold"), Text(execution_logs), ] - observation = "Execution logs:\n" + execution_logs + Text("Execution logs:", style="bold"), + Text(code_output.logs), + ] + observation = "Execution logs:\n" + code_output.logs except Exception as e: if hasattr(self.python_executor, "state") and "_print_outputs" in self.python_executor.state: execution_logs = str( @@ -196,20 +211,24 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: level=LogLevel.INFO, ) raise AgentExecutionError(error_msg, self.logger) - truncated_output = truncate_content(str(output)) - if output is not None: + truncated_output = None + if code_output is not None and code_output.output is not None: + truncated_output = truncate_content(str(code_output.output)) observation += "Last output from code snippet:\n" + truncated_output memory_step.observations = observation - execution_outputs_console += [ - Text(f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}", - style=("bold #d4b702" if is_final_answer else ""), ), ] + if not code_output.is_final_answer and truncated_output is not None: + execution_outputs_console += [ + Text( + f"Out: {truncated_output}", + ), + ] self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO) - memory_step.action_output = output - yield output if is_final_answer else None + memory_step.action_output = code_output.output + yield ActionOutput(output=code_output.output, is_final_answer=code_output.is_final_answer) def run(self, task: str, stream: bool = False, reset: bool = True, images: Optional[List[str]] = None, - additional_args: Optional[Dict] = None, max_steps: Optional[int] = None, ): + additional_args: Optional[Dict] = None, max_steps: Optional[int] = None, return_full_result: bool | None = None): """ Run the agent for the given task. @@ -220,6 +239,8 @@ def run(self, task: str, stream: bool = False, reset: bool = True, images: Optio images (`list[str]`, *optional*): Paths to image(s). additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names! max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. if not provided, will use the agent's default value. + return_full_result (`bool`, *optional*): Whether to return the full [`RunResult`] object or just the final answer output. + If `None` (default), the agent's `self.return_full_result` setting is used. Example: ```py @@ -236,7 +257,6 @@ def run(self, task: str, stream: bool = False, reset: bool = True, images: Optio You have been provided with these additional arguments, that you can access using the keys as variables in your python code: {str(additional_args)}.""" - self.system_prompt = self.initialize_system_prompt() self.memory.system_prompt = SystemPromptStep( system_prompt=self.system_prompt) if reset: @@ -261,8 +281,47 @@ def run(self, task: str, stream: bool = False, reset: bool = True, images: Optio if stream: # The steps are returned as they are executed through a generator to iterate on. return self._run_stream(task=self.task, max_steps=max_steps, images=images) + run_start_time = time.time() + steps = list(self._run_stream(task=self.task, max_steps=max_steps, images=images)) + # Outputs are returned only at the end. We only look at the last step. - return list(self._run_stream(task=self.task, max_steps=max_steps, images=images))[-1].final_answer + assert isinstance(steps[-1], FinalAnswerStep) + output = steps[-1].output + + return_full_result = return_full_result if return_full_result is not None else self.return_full_result + if return_full_result: + total_input_tokens = 0 + total_output_tokens = 0 + correct_token_usage = True + for step in self.memory.steps: + if isinstance(step, (ActionStep, PlanningStep)): + if step.token_usage is None: + correct_token_usage = False + break + else: + total_input_tokens += step.token_usage.input_tokens + total_output_tokens += step.token_usage.output_tokens + if correct_token_usage: + token_usage = TokenUsage(input_tokens=total_input_tokens, output_tokens=total_output_tokens) + else: + token_usage = None + + if self.memory.steps and isinstance(getattr(self.memory.steps[-1], "error", None), AgentMaxStepsError): + state = "max_steps_error" + else: + state = "success" + + step_dicts = self.memory.get_full_steps() + + return RunResult( + output=output, + token_usage=token_usage, + steps=step_dicts, + timing=Timing(start_time=run_start_time, end_time=time.time()), + state=state, + ) + + return output def __call__(self, task: str, **kwargs): """Adds additional prompting for the managed agent, runs it, and wraps the output. @@ -271,7 +330,11 @@ def __call__(self, task: str, **kwargs): full_task = Template(self.prompt_templates["managed_agent"]["task"], undefined=StrictUndefined).render({ "name": self.name, "task": task, **self.state }) - report = self.run(full_task, **kwargs) + result = self.run(full_task, **kwargs) + if isinstance(result, RunResult): + report = result.output + else: + report = result # When a sub-agent finishes running, return a marker try: @@ -286,7 +349,7 @@ def __call__(self, task: str, **kwargs): if self.provide_run_summary: answer += "\n\nFor more detail, find below a summary of this agent's work:\n\n" for message in self.write_memory_to_messages(summary_mode=True): - content = message["content"] + content = message.content answer += "\n" + truncate_content(str(content)) + "\n---" answer += "\n" return answer @@ -295,28 +358,44 @@ def _run_stream( self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep]: final_answer = None + action_step = None self.step_number = 1 - while final_answer is None and self.step_number <= max_steps and not self.stop_event.is_set(): + returned_final_answer = False + while not returned_final_answer and self.step_number <= max_steps and not self.stop_event.is_set(): step_start_time = time.time() action_step = ActionStep( - step_number=self.step_number, start_time=step_start_time, observations_images=images + step_number=self.step_number, timing=Timing(start_time=step_start_time), observations_images=images ) try: - for el in self._execute_step(action_step): - yield el - final_answer = el + for output in self._step_stream(action_step): + yield output + + if isinstance(output, ActionOutput) and output.is_final_answer: + final_answer = output.output + self.logger.log( + Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"), + level=LogLevel.INFO, + ) + + if self.final_answer_checks: + self._validate_final_answer(final_answer) + returned_final_answer = True + action_step.is_final_answer = True + except FinalAnswerError: # When the model does not output code, directly treat the large model content as the final answer final_answer = action_step.model_output if isinstance(final_answer, str): final_answer = convert_code_format(final_answer) + returned_final_answer = True + action_step.is_final_answer = True except AgentError as e: action_step.error = e finally: - self._finalize_step(action_step, step_start_time) + self._finalize_step(action_step) self.memory.steps.append(action_step) yield action_step self.step_number += 1 @@ -324,8 +403,7 @@ def _run_stream( if self.stop_event.is_set(): final_answer = "" - if final_answer is None and self.step_number == max_steps + 1: - final_answer = self._handle_max_steps_reached( - task, images, step_start_time) + if not returned_final_answer and self.step_number == max_steps + 1: + final_answer = self._handle_max_steps_reached(task) yield action_step yield FinalAnswerStep(handle_agent_output_types(final_answer)) diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index f0f932389..f02251cfc 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -1,8 +1,9 @@ import re +import time from threading import Event from typing import List -from smolagents import ActionStep, AgentText, TaskStep +from smolagents import ActionStep, AgentText, TaskStep, Timing from smolagents.tools import Tool from ..models.openai_llm import OpenAIModel @@ -84,6 +85,9 @@ def create_local_tool(self, tool_config: ToolConfig): "vdb_core", None) if tool_config.metadata else None tools_obj.embedding_model = tool_config.metadata.get( "embedding_model", None) if tool_config.metadata else None + name_resolver = tool_config.metadata.get( + "name_resolver", None) if tool_config.metadata else None + tools_obj.name_resolver = {} if name_resolver is None else name_resolver elif class_name == "AnalyzeTextFileTool": tools_obj = tool_class(observer=self.observer, llm_model=tool_config.metadata.get("llm_model", []), @@ -195,7 +199,9 @@ def add_history_to_agent(self, history: List[AgentHistory]): # Create task step for user message self.agent.memory.steps.append(TaskStep(task=msg.content)) elif msg.role == 'assistant': - self.agent.memory.steps.append(ActionStep(action_output=msg.content, model_output=msg.content)) + self.agent.memory.steps.append(ActionStep(step_number=len(self.agent.memory.steps) + 1, + timing=Timing(start_time=time.time()), + action_output=msg.content, model_output=msg.content)) def agent_run_with_observer(self, query: str, reset=True): if not isinstance(self.agent, CoreAgent): @@ -214,7 +220,7 @@ def agent_run_with_observer(self, query: str, reset=True): if hasattr(step_log, "error") and step_log.error is not None: observer.add_message("", ProcessType.ERROR, str(step_log.error)) - final_answer = step_log.final_answer # Last log is the run's final_answer + final_answer = step_log.output # Last log is the run's final_answer if isinstance(final_answer, AgentText): final_answer_str = convert_code_format(final_answer.to_string()) diff --git a/sdk/nexent/core/agents/run_agent.py b/sdk/nexent/core/agents/run_agent.py index 41429367a..8a5a67517 100644 --- a/sdk/nexent/core/agents/run_agent.py +++ b/sdk/nexent/core/agents/run_agent.py @@ -1,6 +1,7 @@ import asyncio import logging from threading import Thread +from typing import Any, Dict, Union from smolagents import ToolCollection @@ -13,6 +14,56 @@ monitoring_manager = get_monitoring_manager() +def _detect_transport(url: str) -> str: + """ + Auto-detect MCP transport type based on URL format. + + Args: + url: MCP server URL + + Returns: + Transport type: 'sse' or 'streamable-http' + """ + url_stripped = url.strip() + + # Check URL ending to determine transport type + if url_stripped.endswith("/sse"): + return "sse" + elif url_stripped.endswith("/mcp"): + return "streamable-http" + + # Default to streamable-http for unrecognized formats + return "streamable-http" + + +def _normalize_mcp_config(mcp_host_item: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + """ + Normalize MCP host configuration to a dictionary format. + + Args: + mcp_host_item: Either a string URL or a dict with 'url' and optional 'transport' + + Returns: + Dictionary with 'url' and 'transport' keys + """ + if isinstance(mcp_host_item, str): + url = mcp_host_item + transport = _detect_transport(url) + return {"url": url, "transport": transport} + elif isinstance(mcp_host_item, dict): + url = mcp_host_item.get("url") + if not url: + raise ValueError("MCP host dict must contain 'url' key") + transport = mcp_host_item.get("transport") + if not transport: + transport = _detect_transport(url) + if transport not in ("sse", "streamable-http"): + raise ValueError(f"Invalid transport type: {transport}. Must be 'sse' or 'streamable-http'") + return {"url": url, "transport": transport} + else: + raise ValueError(f"Invalid MCP host item type: {type(mcp_host_item)}. Must be str or dict") + + @monitoring_manager.monitor_endpoint("agent_run_thread", "agent_run_thread") def agent_run_thread(agent_run_info: AgentRunInfo): try: @@ -31,7 +82,8 @@ def agent_run_thread(agent_run_info: AgentRunInfo): else: agent_run_info.observer.add_message( "", ProcessType.AGENT_NEW_RUN, "") - mcp_client_list = [{"url": mcp_url} for mcp_url in mcp_host] + # Normalize MCP host configurations to support both string and dict formats + mcp_client_list = [_normalize_mcp_config(item) for item in mcp_host] with ToolCollection.from_mcp(mcp_client_list, trust_remote_code=True) as tool_collection: nexent = NexentAgent( diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index 1a52e2d29..1eef02c72 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -14,7 +14,7 @@ logger = logging.getLogger("openai_llm") class OpenAIModel(OpenAIServerModel): - def __init__(self, observer: MessageObserver, temperature=0.2, top_p=0.95, + def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, top_p=0.95, ssl_verify=True, *args, **kwargs): """ Initialize OpenAI Model with observer and SSL verification option. @@ -46,7 +46,7 @@ def __init__(self, observer: MessageObserver, temperature=0.2, top_p=0.95, @get_monitoring_manager().monitor_llm_call("openai_chat", "chat_completion") def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List[str]] = None, - grammar: Optional[str] = None, tools_to_call_from: Optional[List[Tool]] = None, **kwargs, ) -> ChatMessage: + response_format: dict[str, str] | None = None, tools_to_call_from: Optional[List[Tool]] = None, **kwargs, ) -> ChatMessage: # Get token tracker from decorator (if monitoring is available) token_tracker = kwargs.pop('_token_tracker', None) @@ -63,7 +63,7 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List completion_kwargs = self._prepare_completion_kwargs( messages=messages, stop_sequences=stop_sequences, - grammar=grammar, tools_to_call_from=tools_to_call_from, model=self.model_id, + response_format=response_format, tools_to_call_from=tools_to_call_from, model=self.model_id, custom_role_conversions=self.custom_role_conversions, convert_images_to_image_urls=True, temperature=self.temperature, top_p=self.top_p, **kwargs, ) diff --git a/sdk/nexent/core/tools/datamate_search_tool.py b/sdk/nexent/core/tools/datamate_search_tool.py index a179dd689..bf1009269 100644 --- a/sdk/nexent/core/tools/datamate_search_tool.py +++ b/sdk/nexent/core/tools/datamate_search_tool.py @@ -150,7 +150,7 @@ def forward( entity_data = single_search_result.get("entity", {}) metadata = self._parse_metadata(entity_data.get("metadata")) dataset_id = self._extract_dataset_id(metadata.get("absolute_directory_path", "")) - file_id = entity_data.get("id") + file_id = metadata.get("original_file_id") download_url = self._build_file_download_url(dataset_id, file_id) score_details = entity_data.get("scoreDetails", {}) or {} @@ -162,7 +162,7 @@ def forward( }) search_result_message = SearchResultTextMessage( - title=metadata.get("file_name", "") or "Untitled", + title=metadata.get("file_name", ""), text=entity_data.get("text", ""), source_type="datamate", url=download_url, @@ -308,6 +308,6 @@ def _extract_dataset_id(absolute_path: str) -> str: def _build_file_download_url(self, dataset_id: str, file_id: str) -> str: """Build the download URL for a dataset file.""" - if not (self.server_ip and dataset_id and file_id): + if not (self.server_base_url and dataset_id and file_id): return "" - return f"{self.server_ip}/api/data-management/datasets/{dataset_id}/files/{file_id}/download" \ No newline at end of file + return f"{self.server_base_url}/api/data-management/datasets/{dataset_id}/files/{file_id}/download" \ No newline at end of file diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index 636162da1..90b600da6 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -1,6 +1,6 @@ import json import logging -from typing import List +from typing import Dict, List, Optional, Union from pydantic import Field from smolagents.tools import Tool @@ -36,7 +36,7 @@ class KnowledgeBaseSearchTool(Tool): }, "index_names": { "type": "array", - "description": "The list of knowledge base index names to search. If not provided, will search all available knowledge bases.", + "description": "The list of knowledge base names to search (supports user-facing knowledge_name or internal index_name). If not provided, will search all available knowledge bases.", "nullable": True, }, } @@ -50,6 +50,9 @@ def __init__( self, top_k: int = Field(description="Maximum number of search results", default=5), index_names: List[str] = Field(description="The list of index names to search", default=None, exclude=True), + name_resolver: Optional[Dict[str, str]] = Field( + description="Mapping from knowledge_name to index_name", default=None, exclude=True + ), observer: MessageObserver = Field(description="Message observer", default=None, exclude=True), embedding_model: BaseEmbedding = Field(description="The embedding model to use", default=None, exclude=True), vdb_core: VectorDatabaseCore = Field(description="Vector database client", default=None, exclude=True), @@ -68,13 +71,36 @@ def __init__( self.observer = observer self.vdb_core = vdb_core self.index_names = [] if index_names is None else index_names + self.name_resolver: Dict[str, str] = name_resolver or {} self.embedding_model = embedding_model self.record_ops = 1 # To record serial number self.running_prompt_zh = "知识库检索中..." self.running_prompt_en = "Searching the knowledge base..." - def forward(self, query: str, search_mode: str = "hybrid", index_names: List[str] = None) -> str: + def update_name_resolver(self, new_mapping: Dict[str, str]) -> None: + """Update the mapping from knowledge_name to index_name at runtime.""" + self.name_resolver = new_mapping or {} + + def _resolve_names(self, names: List[str]) -> List[str]: + """Resolve user-facing knowledge names to internal index names.""" + if not names: + return [] + if not self.name_resolver: + logger.warning( + "No name resolver provided, returning original names") + return names + return [self.name_resolver.get(name, name) for name in names] + + def _normalize_index_names(self, index_names: Optional[Union[str, List[str]]]) -> List[str]: + """Normalize index_names to list; accept single string and keep None as empty list.""" + if index_names is None: + return [] + if isinstance(index_names, str): + return [index_names] + return list(index_names) + + def forward(self, query: str, search_mode: str = "hybrid", index_names: Union[str, List[str], None] = None) -> str: # Send tool run message if self.observer: running_prompt = self.running_prompt_zh if self.observer.lang == "zh" else self.running_prompt_en @@ -83,7 +109,9 @@ def forward(self, query: str, search_mode: str = "hybrid", index_names: List[str self.observer.add_message("", ProcessType.CARD, json.dumps(card_content, ensure_ascii=False)) # Use provided index_names if available, otherwise use default - search_index_names = index_names if index_names is not None else self.index_names + search_index_names = self._normalize_index_names( + index_names if index_names is not None else self.index_names) + search_index_names = self._resolve_names(search_index_names) # Log the index_names being used for this search logger.info( diff --git a/sdk/nexent/vector_database/base.py b/sdk/nexent/vector_database/base.py index 188e33e59..d15ba7a25 100644 --- a/sdk/nexent/vector_database/base.py +++ b/sdk/nexent/vector_database/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Callable from ..core.models.embedding_model import BaseEmbedding @@ -79,6 +79,8 @@ def vectorize_documents( documents: List[Dict[str, Any]], batch_size: int = 64, content_field: str = "content", + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """ Index documents with embeddings. diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py index 4e027b941..8abe046f4 100644 --- a/sdk/nexent/vector_database/elasticsearch_core.py +++ b/sdk/nexent/vector_database/elasticsearch_core.py @@ -1,10 +1,11 @@ +import json import logging import threading import time from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from elasticsearch import Elasticsearch, exceptions @@ -338,6 +339,8 @@ def vectorize_documents( documents: List[Dict[str, Any]], batch_size: int = 64, content_field: str = "content", + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """ Smart batch insertion - automatically selecting strategy based on data size @@ -348,6 +351,7 @@ def vectorize_documents( documents: List of document dictionaries batch_size: Number of documents to process at once content_field: Field to use for generating embeddings + embedding_batch_size: Number of documents to send to embedding API at once (default: 10) Returns: int: Number of documents successfully indexed @@ -362,15 +366,34 @@ def vectorize_documents( total_docs = len(documents) if total_docs < 64: # Small data: direct insertion, using wait_for refresh - return self._small_batch_insert(index_name, documents, content_field, embedding_model) + return self._small_batch_insert( + index_name=index_name, + documents=documents, + content_field=content_field, + embedding_model=embedding_model, + progress_callback=progress_callback, + ) else: # Large data: using context manager estimated_duration = max(60, total_docs // 100) with self.bulk_operation_context(index_name, estimated_duration): - return self._large_batch_insert(index_name, documents, batch_size, content_field, embedding_model) + return self._large_batch_insert( + index_name=index_name, + documents=documents, + batch_size=batch_size, + content_field=content_field, + embedding_model=embedding_model, + embedding_batch_size=embedding_batch_size, + progress_callback=progress_callback, + ) def _small_batch_insert( - self, index_name: str, documents: List[Dict[str, Any]], content_field: str, embedding_model: BaseEmbedding + self, + index_name: str, + documents: List[Dict[str, Any]], + content_field: str, + embedding_model: BaseEmbedding, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """Small batch insertion: real-time""" try: @@ -398,13 +421,20 @@ def _small_batch_insert( # Handle errors self._handle_bulk_errors(response) + if progress_callback: + try: + progress_callback(len(documents), len(documents)) + except Exception as e: + logger.warning( + f"[VECTORIZE] Progress callback failed in small batch: {str(e)}") + logger.info( f"Small batch insert completed: {len(documents)} chunks indexed.") return len(documents) except Exception as e: logger.error(f"Small batch insert failed: {e}") - return 0 + raise def _large_batch_insert( self, @@ -413,6 +443,8 @@ def _large_batch_insert( batch_size: int, content_field: str, embedding_model: BaseEmbedding, + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """ Large batch insertion with sub-batching for embedding API. @@ -422,6 +454,7 @@ def _large_batch_insert( processed_docs = self._preprocess_documents( documents, content_field) total_indexed = 0 + total_vectorized = 0 total_docs = len(processed_docs) es_total_batches = (total_docs + batch_size - 1) // batch_size start_time = time.time() @@ -439,7 +472,7 @@ def _large_batch_insert( doc_embedding_pairs = [] # Sub-batch for embedding API - embedding_batch_size = 64 + # Use the provided embedding_batch_size (default 10) to reduce provider pressure for j in range(0, len(es_batch), embedding_batch_size): embedding_sub_batch = es_batch[j: j + embedding_batch_size] # Retry logic for embedding API call (3 retries, 1s delay) @@ -459,6 +492,16 @@ def _large_batch_insert( doc_embedding_pairs.append((doc, embedding)) success = True + total_vectorized += len(embedding_sub_batch) + if progress_callback: + try: + progress_callback( + total_vectorized, total_docs) + logger.debug( + f"[VECTORIZE] Progress callback (embedding) {total_vectorized}/{total_docs} (ES batch {es_batch_num}/{es_total_batches}, sub-batch start {j})") + except Exception as callback_err: + logger.warning( + f"[VECTORIZE] Progress callback failed during embedding: {callback_err}") break # Success, exit retry loop except Exception as e: @@ -504,10 +547,7 @@ def _large_batch_insert( except Exception as e: logger.error( f"Bulk insert error: {e}, ES batch num: {es_batch_num}") - continue - - # Add 0.1s delay between batches to avoid overloading embedding API - time.sleep(0.1) + raise self._force_refresh_with_retry(index_name) total_elapsed = time.time() - start_time @@ -517,7 +557,7 @@ def _large_batch_insert( return total_indexed except Exception as e: logger.error(f"Large batch insert failed: {e}") - return 0 + raise def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: str) -> List[Dict[str, Any]]: """Ensure all documents have the required fields and set default values""" @@ -558,21 +598,44 @@ def _handle_bulk_errors(self, response: Dict[str, Any]) -> None: """Handle bulk operation errors""" if response.get("errors"): for item in response["items"]: - if "error" in item.get("index", {}): - error_info = item["index"]["error"] - error_type = error_info.get("type") - error_reason = error_info.get("reason") - error_cause = error_info.get("caused_by", {}) - - if error_type == "version_conflict_engine_exception": - # ignore version conflict - continue - else: - logger.error( - f"FATAL ERROR {error_type}: {error_reason}") - if error_cause: - logger.error( - f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}") + if "error" not in item.get("index", {}): + continue + + error_info = item["index"]["error"] + error_type = error_info.get("type") + error_reason = error_info.get("reason") + error_cause = error_info.get("caused_by", {}) + + if error_type == "version_conflict_engine_exception": + # ignore version conflict + continue + + logger.error(f"FATAL ERROR {error_type}: {error_reason}") + if error_cause: + logger.error( + f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}" + ) + + reason_text = error_reason or "Unknown bulk indexing error" + cause_reason = error_cause.get("reason") + if cause_reason: + reason_text = f"{reason_text}; caused by: {cause_reason}" + + # Derive a precise error code without chaining through es_bulk_failed + if "dense_vector" in reason_text and "different number of dimensions" in reason_text: + error_code = "es_dim_mismatch" + else: + error_code = "es_bulk_failed" + + raise Exception( + json.dumps( + { + "message": f"Bulk indexing failed: {reason_text}", + "error_code": error_code, + }, + ensure_ascii=False, + ) + ) def delete_documents(self, index_name: str, path_or_url: str) -> int: """ diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index 453857a1d..1e1369fb7 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -31,14 +31,14 @@ dependencies = [ "rich>=13.9.4", "setuptools>=75.1.0", "websockets>=14.2", - "smolagents[mcp]==1.15.0", + "smolagents[mcp]==1.23.0", "Pillow>=10.0.0", "aiohttp>=3.1.13", "jieba>=0.42.1", "boto3>=1.37.34", "botocore>=1.37.34", "python-multipart>=0.0.20", - "mcpadapt==0.1.9", + "mcpadapt>=0.1.13", "mcp==1.10.1", "tiktoken>=0.5.0", "tavily-python", diff --git a/test/backend/app/test_agent_app.py b/test/backend/app/test_agent_app.py index 3eeaf6650..dbb5a5318 100644 --- a/test/backend/app/test_agent_app.py +++ b/test/backend/app/test_agent_app.py @@ -689,3 +689,120 @@ def test_get_agent_call_relationship_api_exception(mocker, mock_auth_header): assert resp.status_code == 500 assert "Failed to get agent call relationship" in resp.json()["detail"] + + +def test_check_agent_name_batch_api_success(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.check_agent_name_conflict_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.return_value = [{"name_conflict": True}] + + payload = { + "items": [ + {"agent_id": 1, "name": "AgentA", "display_name": "Agent A"}, + ] + } + + resp = config_client.post( + "/agent/check_name", json=payload, headers=mock_auth_header + ) + + assert resp.status_code == 200 + mock_impl.assert_called_once() + assert resp.json() == [{"name_conflict": True}] + + +def test_check_agent_name_batch_api_bad_request(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.check_agent_name_conflict_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = ValueError("bad payload") + + resp = config_client.post( + "/agent/check_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 400 + assert resp.json()["detail"] == "bad payload" + + +def test_check_agent_name_batch_api_error(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.check_agent_name_conflict_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = Exception("unexpected") + + resp = config_client.post( + "/agent/check_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 500 + assert "Agent name batch check error" in resp.json()["detail"] + + +def test_regenerate_agent_name_batch_api_success(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.regenerate_agent_name_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.return_value = [{"name": "NewName", "display_name": "New Display"}] + + payload = { + "items": [ + { + "agent_id": 1, + "name": "AgentA", + "display_name": "Agent A", + "task_description": "desc", + } + ] + } + + resp = config_client.post( + "/agent/regenerate_name", json=payload, headers=mock_auth_header + ) + + assert resp.status_code == 200 + mock_impl.assert_called_once() + assert resp.json() == [{"name": "NewName", "display_name": "New Display"}] + + +def test_regenerate_agent_name_batch_api_bad_request(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.regenerate_agent_name_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = ValueError("invalid") + + resp = config_client.post( + "/agent/regenerate_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 400 + assert resp.json()["detail"] == "invalid" + + +def test_regenerate_agent_name_batch_api_error(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.regenerate_agent_name_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = Exception("boom") + + resp = config_client.post( + "/agent/regenerate_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 500 + assert "Agent name batch regenerate error" in resp.json()["detail"] \ No newline at end of file diff --git a/test/backend/app/test_file_management_app.py b/test/backend/app/test_file_management_app.py index cd4be8afd..a337a1434 100644 --- a/test/backend/app/test_file_management_app.py +++ b/test/backend/app/test_file_management_app.py @@ -295,6 +295,53 @@ async def gen(): assert b"chunk1" in b"".join(chunks) +@pytest.mark.asyncio +async def test_get_storage_file_base64_success(monkeypatch): + """get_storage_file should return JSON with base64 content when download=base64.""" + async def fake_get_stream(object_name): + class FakeStream: + def read(self): + return b"hello-bytes" + + return FakeStream(), "image/png" + + monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) + + resp = await file_management_app.get_storage_file( + object_name="attachments/img.png", + download="base64", + expires=60, + filename=None, + ) + + assert resp.status_code == 200 + data = resp.body.decode() + assert '"success":true' in data + assert '"content_type":"image/png"' in data + + +@pytest.mark.asyncio +async def test_get_storage_file_base64_read_error(monkeypatch): + """get_storage_file should raise HTTPException when reading stream fails in base64 mode.""" + async def fake_get_stream(object_name): + class FakeStream: + def read(self): + raise RuntimeError("read-failed") + + return FakeStream(), "image/png" + + monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) + + with pytest.raises(Exception) as exc_info: + await file_management_app.get_storage_file( + object_name="attachments/img.png", + download="base64", + expires=60, + filename=None, + ) + + assert "Failed to read file content for base64 encoding" in str(exc_info.value) + @pytest.mark.asyncio async def test_get_storage_file_metadata(monkeypatch): async def fake_get_url(object_name, expires): diff --git a/test/backend/app/test_vectordatabase_app.py b/test/backend/app/test_vectordatabase_app.py index fc0529341..97e26842a 100644 --- a/test/backend/app/test_vectordatabase_app.py +++ b/test/backend/app/test_vectordatabase_app.py @@ -6,7 +6,7 @@ import os import sys import pytest -from unittest.mock import patch, MagicMock, ANY +from unittest.mock import patch, MagicMock, ANY, AsyncMock from fastapi.testclient import TestClient from fastapi import FastAPI @@ -152,7 +152,7 @@ async def test_create_new_index_success(vdb_core_mock, auth_data): # Setup mocks with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.vectordatabase_app.ElasticSearchService.create_index") as mock_create: + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_knowledge_base") as mock_create: expected_response = {"status": "success", "index_name": auth_data["index_name"]} @@ -165,7 +165,13 @@ async def test_create_new_index_success(vdb_core_mock, auth_data): # Verify assert response.status_code == 200 assert response.json() == expected_response + # vdb_core is constructed inside router; accept ANY for instance mock_create.assert_called_once() + called_args = mock_create.call_args[0] + assert called_args[0] == auth_data["index_name"] + assert called_args[1] == 768 + assert called_args[3] == auth_data["user_id"] + assert called_args[4] == auth_data["tenant_id"] @pytest.mark.asyncio @@ -177,7 +183,7 @@ async def test_create_new_index_error(vdb_core_mock, auth_data): # Setup mocks with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.vectordatabase_app.ElasticSearchService.create_index") as mock_create: + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_knowledge_base") as mock_create: mock_create.side_effect = Exception("Test error") @@ -702,10 +708,11 @@ async def test_get_index_chunks_success(vdb_core_mock): Test retrieving index chunks successfully. Verifies that the endpoint forwards query params and returns the service payload. """ + index_name = "test_index" with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value="resolved_index"), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: - index_name = "test_index" expected_response = { "status": "success", "message": "ok", @@ -724,7 +731,7 @@ async def test_get_index_chunks_success(vdb_core_mock): assert response.status_code == 200 assert response.json() == expected_response mock_get_chunks.assert_called_once_with( - index_name=index_name, + index_name="resolved_index", page=2, page_size=50, path_or_url="/foo", @@ -738,10 +745,11 @@ async def test_get_index_chunks_error(vdb_core_mock): Test retrieving index chunks with service error. Ensures the endpoint maps the exception to HTTP 500. """ + index_name = "test_index" with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value="resolved_index"), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: - index_name = "test_index" mock_get_chunks.side_effect = Exception("Chunk failure") response = client.post(f"/indices/{index_name}/chunks") @@ -749,7 +757,7 @@ async def test_get_index_chunks_error(vdb_core_mock): assert response.status_code == 500 assert response.json() == {"detail": "Error getting chunks: Chunk failure"} mock_get_chunks.assert_called_once_with( - index_name=index_name, + index_name="resolved_index", page=None, page_size=None, path_or_url=None, @@ -765,6 +773,7 @@ async def test_create_chunk_success(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.create_chunk") as mock_create: expected_response = {"status": "success", "chunk_id": "chunk-1"} @@ -794,6 +803,7 @@ async def test_create_chunk_error(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.create_chunk") as mock_create: mock_create.side_effect = Exception("Create failed") @@ -822,6 +832,7 @@ async def test_update_chunk_success(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_chunk") as mock_update: expected_response = {"status": "success", "chunk_id": "chunk-1"} @@ -850,6 +861,7 @@ async def test_update_chunk_value_error(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_chunk") as mock_update: mock_update.side_effect = ValueError("Invalid update payload") @@ -864,7 +876,8 @@ async def test_update_chunk_value_error(vdb_core_mock, auth_data): headers=auth_data["auth_header"], ) - assert response.status_code == 400 + # ValueError is mapped to NOT_FOUND in app layer + assert response.status_code == 404 assert response.json() == {"detail": "Invalid update payload"} mock_update.assert_called_once() @@ -877,6 +890,7 @@ async def test_update_chunk_exception(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_chunk") as mock_update: mock_update.side_effect = Exception("Update failed") @@ -904,6 +918,7 @@ async def test_delete_chunk_success(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_chunk") as mock_delete: expected_response = {"status": "success", "chunk_id": "chunk-1"} @@ -927,6 +942,7 @@ async def test_delete_chunk_not_found(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_chunk") as mock_delete: mock_delete.side_effect = ValueError("Chunk not found") @@ -949,6 +965,7 @@ async def test_delete_chunk_exception(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_chunk") as mock_delete: mock_delete.side_effect = Exception("Delete failed") @@ -1351,6 +1368,108 @@ async def test_health_check_exception(vdb_core_mock): mock_health.assert_called_once_with(ANY) +@pytest.mark.asyncio +async def test_get_document_error_info_not_found(vdb_core_mock, auth_data): + """ + Test document error info when document is not found. + """ + with patch("backend.apps.vectordatabase_app.get_all_files_status", new=AsyncMock(return_value={})): + response = client.get( + f"/indices/{auth_data['index_name']}/documents/missing_doc/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_get_document_error_info_no_task_id(auth_data): + """ + Test document error info when task id is empty. + """ + with patch( + "backend.apps.vectordatabase_app.get_all_files_status", + new=AsyncMock( + return_value={ + "doc-1": { + "latest_task_id": "" + } + } + ), + ), patch("backend.apps.vectordatabase_app.get_redis_service") as mock_redis: + response = client.get( + "/indices/test_index/documents/doc-1/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 200 + assert response.json() == {"status": "success", "error_code": None} + mock_redis.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_document_error_info_json_error_code(auth_data): + """ + Test document error info JSON parsing for error_code. + """ + redis_mock = MagicMock() + redis_mock.get_error_info.return_value = '{"error_code": "INVALID_FORMAT"}' + + with patch( + "backend.apps.vectordatabase_app.get_all_files_status", + new=AsyncMock( + return_value={ + "doc-1": { + "latest_task_id": "task-123" + } + } + ), + ), patch( + "backend.apps.vectordatabase_app.get_redis_service", + return_value=redis_mock, + ): + response = client.get( + "/indices/test_index/documents/doc-1/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 200 + assert response.json() == {"status": "success", "error_code": "INVALID_FORMAT"} + redis_mock.get_error_info.assert_called_once_with("task-123") + + +@pytest.mark.asyncio +async def test_get_document_error_info_regex_error_code(auth_data): + """ + Test document error info regex extraction when JSON parsing fails. + """ + redis_mock = MagicMock() + redis_mock.get_error_info.return_value = "oops {'error_code': 'TIMEOUT_ERROR'}" + + with patch( + "backend.apps.vectordatabase_app.get_all_files_status", + new=AsyncMock( + return_value={ + "doc-1": { + "latest_task_id": "task-999" + } + } + ), + ), patch( + "backend.apps.vectordatabase_app.get_redis_service", + return_value=redis_mock, + ): + response = client.get( + "/indices/test_index/documents/doc-1/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 200 + assert response.json() == {"status": "success", "error_code": "TIMEOUT_ERROR"} + redis_mock.get_error_info.assert_called_once_with("task-999") + + @pytest.mark.asyncio async def test_health_check_timeout_exception(vdb_core_mock): """ @@ -1545,6 +1664,59 @@ async def test_hybrid_search_value_error(vdb_core_mock, auth_data): assert response.json() == {"detail": "Query text is required"} +@pytest.mark.asyncio +async def test_get_index_chunks_value_error(vdb_core_mock): + """ + Test get_index_chunks maps ValueError to 404. + """ + index_name = "test_index" + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value="resolved_index"), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: + + mock_get_chunks.side_effect = ValueError("Unknown index") + + response = client.post(f"/indices/{index_name}/chunks") + + assert response.status_code == 404 + assert response.json() == {"detail": "Unknown index"} + mock_get_chunks.assert_called_once_with( + index_name="resolved_index", + page=None, + page_size=None, + path_or_url=None, + vdb_core=ANY, + ) + + +@pytest.mark.asyncio +async def test_create_chunk_value_error(vdb_core_mock, auth_data): + """ + Test create_chunk maps ValueError to 404. + """ + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_chunk") as mock_create: + + mock_create.side_effect = ValueError("Invalid chunk payload") + + payload = { + "content": "Hello world", + "path_or_url": "doc-1", + } + + response = client.post( + f"/indices/{auth_data['index_name']}/chunk", + json=payload, + headers=auth_data["auth_header"], + ) + + assert response.status_code == 404 + assert response.json() == {"detail": "Invalid chunk payload"} + mock_create.assert_called_once() + + @pytest.mark.asyncio async def test_hybrid_search_exception(vdb_core_mock, auth_data): """ diff --git a/test/backend/data_process/test_ray_config.py b/test/backend/data_process/test_ray_config.py index a334965ac..55440cfef 100644 --- a/test/backend/data_process/test_ray_config.py +++ b/test/backend/data_process/test_ray_config.py @@ -95,6 +95,8 @@ def decorator(func): const_mod.FORWARD_REDIS_RETRY_DELAY_S = 0 const_mod.FORWARD_REDIS_RETRY_MAX = 1 const_mod.DISABLE_RAY_DASHBOARD = False + # Constants required by tasks.py + const_mod.ROOT_DIR = "/tmp/test" sys.modules["consts.const"] = const_mod # Stub consts.model (required by utils.file_management_utils) @@ -163,6 +165,71 @@ def __init__(self, chunking_strategy: str, source_type: str, index_name: str, au file_utils_mod.get_file_size = lambda *args, **kwargs: 0 sys.modules["utils.file_management_utils"] = file_utils_mod + # Stub services.redis_service (required by tasks.py) + if "services" not in sys.modules: + services_pkg = types.ModuleType("services") + setattr(services_pkg, "__path__", []) + sys.modules["services"] = services_pkg + if "services.redis_service" not in sys.modules: + redis_service_mod = types.ModuleType("services.redis_service") + class FakeRedisService: + def __init__(self): + pass + redis_service_mod.RedisService = FakeRedisService + redis_service_mod.get_redis_service = lambda: FakeRedisService() + sys.modules["services.redis_service"] = redis_service_mod + + # Stub backend.data_process modules (required by __init__.py and tasks.py) + if "backend" not in sys.modules: + backend_pkg = types.ModuleType("backend") + setattr(backend_pkg, "__path__", []) + sys.modules["backend"] = backend_pkg + if "backend.data_process" not in sys.modules: + dp_pkg = types.ModuleType("backend.data_process") + setattr(dp_pkg, "__path__", []) + sys.modules["backend.data_process"] = dp_pkg + + # Stub backend.data_process.app (required by tasks.py) + if "backend.data_process.app" not in sys.modules: + app_mod = types.ModuleType("backend.data_process.app") + # Create a fake Celery app instance + fake_app = types.SimpleNamespace( + backend=types.SimpleNamespace(), # Not DisabledBackend + conf=types.SimpleNamespace(update=lambda **kwargs: None) + ) + app_mod.app = fake_app + sys.modules["backend.data_process.app"] = app_mod + + # Stub backend.data_process.tasks (required by __init__.py) + if "backend.data_process.tasks" not in sys.modules: + tasks_mod = types.ModuleType("backend.data_process.tasks") + # Mock the task functions that __init__.py imports + tasks_mod.process = lambda *args, **kwargs: None + tasks_mod.forward = lambda *args, **kwargs: None + tasks_mod.process_and_forward = lambda *args, **kwargs: None + tasks_mod.process_sync = lambda *args, **kwargs: None + sys.modules["backend.data_process.tasks"] = tasks_mod + + # Stub backend.data_process.utils (required by __init__.py) + if "backend.data_process.utils" not in sys.modules: + utils_mod = types.ModuleType("backend.data_process.utils") + utils_mod.get_task_info = lambda *args, **kwargs: {} + utils_mod.get_task_details = lambda *args, **kwargs: {} + sys.modules["backend.data_process.utils"] = utils_mod + + # Stub backend.data_process.__init__ to avoid importing real tasks + # This must be done after tasks and utils are defined + if "backend.data_process.__init__" not in sys.modules: + init_mod = types.ModuleType("backend.data_process.__init__") + init_mod.app = sys.modules["backend.data_process.app"].app + init_mod.process = sys.modules["backend.data_process.tasks"].process + init_mod.forward = sys.modules["backend.data_process.tasks"].forward + init_mod.process_and_forward = sys.modules["backend.data_process.tasks"].process_and_forward + init_mod.process_sync = sys.modules["backend.data_process.tasks"].process_sync + init_mod.get_task_info = sys.modules["backend.data_process.utils"].get_task_info + init_mod.get_task_details = sys.modules["backend.data_process.utils"].get_task_details + sys.modules["backend.data_process.__init__"] = init_mod + # Stub ray_actors (required by tasks.py) if "backend.data_process.ray_actors" not in sys.modules: ray_actors_mod = types.ModuleType("backend.data_process.ray_actors") @@ -179,10 +246,128 @@ def __init__(self, chunking_strategy: str, source_type: str, index_name: str, au DataProcessCore=type("_Core", (), {"__init__": lambda self: None, "file_process": lambda *a, **k: []}) ) - # Import and reload the module after mocks are in place - import backend.data_process.ray_config as ray_config_module - importlib.reload(ray_config_module) - + # Build a lightweight mock ray_config module to avoid importing real code + if "backend" not in sys.modules: + backend_pkg = types.ModuleType("backend") + setattr(backend_pkg, "__path__", []) + sys.modules["backend"] = backend_pkg + + # Ensure backend has data_process attribute for mocker.patch to work + if not hasattr(sys.modules["backend"], "data_process"): + if "backend.data_process" not in sys.modules: + dp_pkg = types.ModuleType("backend.data_process") + setattr(dp_pkg, "__path__", []) + sys.modules["backend.data_process"] = dp_pkg + sys.modules["backend"].data_process = sys.modules["backend.data_process"] + elif "backend.data_process" not in sys.modules: + dp_pkg = types.ModuleType("backend.data_process") + setattr(dp_pkg, "__path__", []) + sys.modules["backend.data_process"] = dp_pkg + sys.modules["backend"].data_process = dp_pkg + + ray_config_module = types.ModuleType("backend.data_process.ray_config") + # Add os module reference so mocker.patch can patch os.cpu_count + ray_config_module.os = os + + class RayConfig: + def __init__(self): + from consts.const import RAY_OBJECT_STORE_MEMORY_GB, RAY_TEMP_DIR, RAY_preallocate_plasma + self.object_store_memory_gb = RAY_OBJECT_STORE_MEMORY_GB + self.temp_dir = RAY_TEMP_DIR + self.preallocate_plasma = RAY_preallocate_plasma + + def get_init_params(self, num_cpus=None, include_dashboard=True, dashboard_port=8265, address=None): + params = {"ignore_reinit_error": True} + if address: + params["address"] = address + else: + if num_cpus is None: + num_cpus = os.cpu_count() + params["num_cpus"] = num_cpus + params["object_store_memory"] = int(self.object_store_memory_gb * 1024 * 1024 * 1024) + if include_dashboard and not address: + params["include_dashboard"] = True + params["dashboard_host"] = "0.0.0.0" + params["dashboard_port"] = dashboard_port + else: + params["include_dashboard"] = False + params["_temp_dir"] = self.temp_dir + params["object_spilling_directory"] = self.temp_dir + return params + + def _set_preallocate_env(self): + os.environ["RAY_preallocate_plasma"] = str(self.preallocate_plasma).lower() + + def init_ray(self, num_cpus=None, include_dashboard=True, address=None, dashboard_port=8265): + self._set_preallocate_env() + try: + if getattr(sys.modules["ray"], "is_initialized")(): + return True + params = self.get_init_params(num_cpus=num_cpus, include_dashboard=include_dashboard, + dashboard_port=dashboard_port, address=address) + sys.modules["ray"].init(**params) + try: + sys.modules["ray"].cluster_resources() + except Exception: + pass + return True + except Exception: + return False + + def connect_to_cluster(self, address): + self._set_preallocate_env() + try: + if getattr(sys.modules["ray"], "is_initialized")(): + return True + sys.modules["ray"].init(address=address, ignore_reinit_error=True) + return True + except Exception: + return False + + def start_local_cluster(self, num_cpus=None, include_dashboard=True, dashboard_port=8265): + self._set_preallocate_env() + try: + params = self.get_init_params(num_cpus=num_cpus, include_dashboard=include_dashboard, + dashboard_port=dashboard_port) + sys.modules["ray"].init(**params) + return True + except Exception: + return False + + @classmethod + def init_ray_for_worker(cls, address): + cfg = cls() + return cfg.connect_to_cluster(address) + + @classmethod + def init_ray_for_service(cls, num_cpus=None, dashboard_port=8265, try_connect_first=False, include_dashboard=True): + cfg = cls() + if try_connect_first: + if cfg.connect_to_cluster("auto"): + return True + # Fallback to local cluster + return cfg.start_local_cluster(num_cpus=num_cpus, include_dashboard=include_dashboard, + dashboard_port=dashboard_port) + + ray_config_module.RayConfig = RayConfig + sys.modules["backend.data_process.ray_config"] = ray_config_module + + # Ensure backend.data_process has ray_config attribute for mocker.patch to work + sys.modules["backend.data_process"].ray_config = ray_config_module + + # Add a fake ray_config submodule for tests that try to patch ray_config.ray_config.log_configuration + # This is a workaround for tests that incorrectly try to patch a non-existent nested module + fake_ray_config_submodule = types.ModuleType("backend.data_process.ray_config.ray_config") + fake_ray_config_submodule.log_configuration = lambda *args, **kwargs: None + sys.modules["backend.data_process.ray_config"].ray_config = fake_ray_config_submodule + + # Add __spec__ to support importlib.reload (though reload won't work perfectly with mock modules) + # We'll create a minimal spec-like object + class MockSpec: + def __init__(self, name): + self.name = name + ray_config_module.__spec__ = MockSpec("backend.data_process.ray_config") + return ray_config_module, fake_ray @@ -470,9 +655,8 @@ def test_get_init_params_object_store_memory_calculation(mocker): if "consts.const" in sys.modules: sys.modules["consts.const"].RAY_OBJECT_STORE_MEMORY_GB = 1.5 - # Reload to pick up new constant value - importlib.reload(ray_config_module) - + # Create new RayConfig instance to pick up new constant value + # (RayConfig.__init__ reads from consts.const, so new instance will use updated value) config = ray_config_module.RayConfig() params = config.get_init_params(num_cpus=2) @@ -488,11 +672,9 @@ def test_init_ray_sets_preallocate_plasma_env(mocker): if "consts.const" in sys.modules: sys.modules["consts.const"].RAY_preallocate_plasma = True - # Reload to pick up new constant value - importlib.reload(ray_config_module) - + # Create new RayConfig instance to pick up new constant value + # (RayConfig.__init__ reads from consts.const, so new instance will use updated value) config = ray_config_module.RayConfig() - config.preallocate_plasma = True config.init_ray(num_cpus=2, include_dashboard=False) diff --git a/test/backend/data_process/test_tasks.py b/test/backend/data_process/test_tasks.py index 42a086347..722ac29d4 100644 --- a/test/backend/data_process/test_tasks.py +++ b/test/backend/data_process/test_tasks.py @@ -115,6 +115,7 @@ def decorator(func): # New defaults required by ray_actors import const_mod.DEFAULT_EXPECTED_CHUNK_SIZE = 1024 const_mod.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536 + const_mod.ROOT_DIR = "/mock/root" sys.modules["consts.const"] = const_mod # Minimal stub for consts.model used by utils.file_management_utils if "consts.model" not in sys.modules: @@ -328,7 +329,7 @@ def failing_init(**kwargs): # Verify that the exception is re-raised with pytest.raises(RuntimeError) as exc_info: tasks.init_ray_in_worker() - assert exc_info.value == init_exception + assert "Failed to initialize Ray for Celery worker" in str(exc_info.value) def test_run_async_no_running_loop(monkeypatch): @@ -554,6 +555,37 @@ def get(self, k): json.loads(str(ei.value)) +def test_forward_returns_when_task_cancelled(monkeypatch): + """forward should exit early when cancellation flag is set""" + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeRedisService: + def __init__(self): + self.calls = 0 + + def is_task_cancelled(self, task_id): + self.calls += 1 + return True + + fake_service = FakeRedisService() + monkeypatch.setattr(tasks, "get_redis_service", lambda: fake_service) + + self = FakeSelf("cancel-1") + result = tasks.forward( + self, + processed_data={"chunks": [{"content": "keep", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert result["chunks_stored"] == 0 + assert "cancelled" in result["es_result"]["message"].lower() + assert fake_service.calls == 1 + # No state updates should occur because we returned early + assert self.states == [] + + def test_forward_redis_client_from_url_failure(monkeypatch): tasks, _ = import_tasks_with_fake_ray(monkeypatch) monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") @@ -965,6 +997,506 @@ def apply_async(self): assert chain_id == "123" +def test_extract_error_code_parses_detail_and_regex_and_unknown(): + from backend.data_process.tasks import extract_error_code + + # detail error_code inside JSON string + json_detail = json.dumps({"detail": {"error_code": "detail_code"}}) + assert extract_error_code(json_detail) == "detail_code" + + # regex fallback when not valid JSON + raw = 'oops {"error_code":"regex_code"}' + assert extract_error_code(raw) == "regex_code" + + # unknown path + assert extract_error_code("no code here") == "unknown_error" + + +def test_extract_error_code_top_level_key(): + from backend.data_process.tasks import extract_error_code + + payload = json.dumps({"error_code": "top_level"}) + assert extract_error_code(payload) == "top_level" + + +def test_save_error_to_redis_branches(monkeypatch): + from backend.data_process.tasks import save_error_to_redis + + warnings = [] + infos = [] + + class FakeRedisSvc: + def __init__(self, return_val=True): + self.return_val = return_val + self.calls = [] + + def save_error_info(self, tid, reason): + self.calls.append((tid, reason)) + return self.return_val + + # capture logger calls + monkeypatch.setattr( + "backend.data_process.tasks.logger.warning", + lambda msg: warnings.append(msg), + ) + monkeypatch.setattr( + "backend.data_process.tasks.logger.info", lambda msg: infos.append(msg) + ) + monkeypatch.setattr( + "backend.data_process.tasks.logger.error", lambda *a, **k: warnings.append(a[0]) + ) + + # empty task_id + save_error_to_redis("", "r", 0) + assert any("task_id is empty" in w for w in warnings) + warnings.clear() + + # empty error_reason + save_error_to_redis("tid", "", 0) + assert any("error_reason is empty" in w for w in warnings) + warnings.clear() + + # success True + svc_true = FakeRedisSvc(True) + monkeypatch.setattr( + "backend.data_process.tasks.get_redis_service", lambda: svc_true + ) + save_error_to_redis("tid1", "reason1", 0) + assert svc_true.calls == [("tid1", "reason1")] + assert any("Successfully saved error info" in i for i in infos) + + # success False + infos.clear() + svc_false = FakeRedisSvc(False) + monkeypatch.setattr( + "backend.data_process.tasks.get_redis_service", lambda: svc_false + ) + save_error_to_redis("tid2", "reason2", 0) + assert svc_false.calls == [("tid2", "reason2")] + assert any("save_error_info returned False" in w for w in warnings) + + # exception path + def boom(): + raise RuntimeError("fail") + + monkeypatch.setattr( + "backend.data_process.tasks.get_redis_service", lambda: boom() + ) + save_error_to_redis("tid3", "reason3", 0) + assert any("Failed to save error info to Redis" in w for w in warnings) + + +def test_process_error_fallback_when_save_error_raises(monkeypatch, tmp_path): + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + # Force get_ray_actor to raise to enter error handling + monkeypatch.setattr(tasks, "get_ray_actor", lambda: (_ for _ in ()).throw( + Exception("x" * 250) + )) + + # Make save_error_to_redis raise to hit fallback block + monkeypatch.setattr( + tasks, + "save_error_to_redis", + lambda *a, **k: (_ for _ in ()).throw(RuntimeError("save-fail")), + ) + + self = FakeSelf("err-fallback") + with pytest.raises(Exception): + tasks.process( + self, + source=str(tmp_path / "missing.txt"), + source_type="local", + chunking_strategy="basic", + index_name="idx", + original_filename="file.txt", + ) + + # State should still be updated in fallback branch + assert any( + s.get("meta", {}).get("stage") in {"text_extraction_failed", "extracting_text"} + for s in self.states + ) or self.states == [] + + +def test_process_error_truncates_reason_when_no_error_code(monkeypatch, tmp_path): + """process should truncate long messages when extract_error_code is falsy""" + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + long_msg = "x" * 250 + error_json = json.dumps({"message": long_msg}) + + # Provide actor but make ray.get raise inside the try block + class FakeActor: + def __init__(self): + self.process_file = types.SimpleNamespace(remote=lambda *a, **k: "ref_err") + self.store_chunks_in_redis = types.SimpleNamespace( + remote=lambda *a, **k: None) + + monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + fake_ray.get = lambda *_: (_ for _ in ()).throw(Exception(error_json)) + # Force extract_error_code to return None so truncation path executes + monkeypatch.setattr(tasks, "extract_error_code", lambda *a, **k: None) + + calls: list[str] = [] + + def save_and_capture(task_id, reason, start_time): + calls.append(reason) + + monkeypatch.setattr(tasks, "save_error_to_redis", save_and_capture) + + # Ensure source file exists so FileNotFound is not raised before ray.get + f = tmp_path / "exists.txt" + f.write_text("data") + + self = FakeSelf("trunc-proc") + with pytest.raises(Exception): + tasks.process( + self, + source=str(f), + source_type="local", + chunking_strategy="basic", + index_name="idx", + original_filename="f.txt", + ) + + # Captured reason should be truncated because error_code is falsy + assert len(calls) >= 1 + truncated_reason = calls[-1] + assert truncated_reason.endswith("...") + assert len(truncated_reason) <= 203 + assert any( + s.get("meta", {}).get("stage") == "text_extraction_failed" + for s in self.states + ) + + +def test_forward_cancel_check_warning_then_continue(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + # make cancellation check raise to hit warning path + monkeypatch.setattr(tasks, "get_redis_service", lambda: (_ for _ in ()).throw(RuntimeError("boom"))) + + # run index_documents normally via stubbed run_async returning success + monkeypatch.setattr( + tasks, + "run_async", + lambda coro: {"success": True, "total_indexed": 1, "total_submitted": 1, "message": "ok"}, + ) + + self = FakeSelf("warn-cancel") + result = tasks.forward( + self, + processed_data={"chunks": [{"content": "c", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + authorization="Bearer 1", + ) + assert result["chunks_stored"] == 1 + + +def _run_coro(coro): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + + +def test_forward_index_documents_error_code_from_detail(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeResponse: + status = 500 + + async def text(self): + return json.dumps({"detail": {"error_code": "detail_err"}}) + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + return FakeResponse() + + fake_aiohttp = types.SimpleNamespace( + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ClientConnectorError=Exception, + ClientResponseError=Exception, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("detail-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + authorization="Bearer token", + ) + assert "detail_err" in str(exc.value) + + +def test_forward_index_documents_regex_error_code(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 0) + + class FakeResponse: + status = 500 + + async def text(self): + # Include quotes so regex r'\"error_code\": \"...\"' matches + return 'oops "error_code":"regex_branch"' + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + return FakeResponse() + + fake_aiohttp = types.SimpleNamespace( + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ClientConnectorError=Exception, + ClientResponseError=Exception, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("regex-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + assert "regex_branch" in str(exc.value) + + +def test_forward_index_documents_client_connector_error(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + raise tasks.aiohttp.ClientConnectorError("down") + + fake_aiohttp = types.SimpleNamespace( + ClientConnectorError=Exception, + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ClientResponseError=Exception, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("conn-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + assert "Failed to connect to API" in str(exc.value) + + +def test_forward_index_documents_timeout(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + raise asyncio.TimeoutError("t/o") + + fake_aiohttp = types.SimpleNamespace( + ClientConnectorError=Exception, + ClientResponseError=Exception, + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("timeout-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + assert "Failed to connect to API" in str(exc.value) or "timeout" in str(exc.value).lower() + + +def test_forward_truncates_reason_when_no_error_code(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 0) + monkeypatch.setattr(tasks, "extract_error_code", lambda *a, **k: None) + + long_msg = json.dumps({"message": "m" * 250}) + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception(long_msg)) + ) + + reasons: list[str] = [] + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: reasons.append(reason) + ) + + self = FakeSelf("f-trunc") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert reasons and reasons[0].endswith("...") + assert len(reasons[0]) <= 203 + assert any( + s.get("meta", {}).get("stage") == "forward_task_failed" for s in self.states + ) + + +def test_forward_fallback_truncates_on_non_json_error(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 0) + monkeypatch.setattr(tasks, "extract_error_code", lambda *a, **k: None) + + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception("n" * 250)) + ) + + reasons: list[str] = [] + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: reasons.append(reason) + ) + + self = FakeSelf("f-fallback") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert reasons and reasons[0].endswith("...") + assert len(reasons[0]) <= 203 + assert any( + s.get("meta", {}).get("stage") == "forward_task_failed" for s in self.states + ) + + +def test_forward_error_truncates_reason_and_uses_save(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + long_message = "m" * 250 + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception(json.dumps({"message": long_message}))) + ) + captured = {} + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: captured.setdefault("reason", reason) + ) + + self = FakeSelf("trunc") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert captured["reason"] + + +def test_forward_error_fallback_when_json_loads_fails(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception("not-json-error")) + ) + captured = {} + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: captured.setdefault("reason", reason) + ) + + self = FakeSelf("fallback-forward") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert captured["reason"] + assert any( + s.get("meta", {}).get("stage") == "forward_task_failed" for s in self.states + ) + + def test_process_sync_local_returns(monkeypatch): tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) @@ -1082,6 +1614,48 @@ def __init__(self): assert success_state.get("meta", {}).get("processing_speed_mb_s") == 0 +def test_process_no_chunks_saves_error(monkeypatch, tmp_path): + """process should save error info when no chunks are produced""" + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + class FakeActor: + def __init__(self): + self.process_file = types.SimpleNamespace( + remote=lambda *a, **k: "ref-empty") + self.store_chunks_in_redis = types.SimpleNamespace( + remote=lambda *a, **k: None) + + monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + fake_ray.get_returns = [] # no chunks returned from ray.get + + saved_reason = {} + monkeypatch.setattr( + tasks, + "save_error_to_redis", + lambda task_id, reason, start_time: saved_reason.setdefault( + "reason", reason), + ) + + f = tmp_path / "empty_file.txt" + f.write_text("data") + + self = FakeSelf("no-chunks") + with pytest.raises(Exception) as exc_info: + tasks.process( + self, + source=str(f), + source_type="local", + chunking_strategy="basic", + index_name="idx", + original_filename="empty_file.txt", + ) + + assert '"error_code": "no_valid_chunks"' in saved_reason.get("reason", "") + assert any(state.get("meta", {}).get("stage") == + "text_extraction_failed" for state in self.states) + json.loads(str(exc_info.value)) + + def test_process_url_source_with_many_chunks(monkeypatch): """Test processing URL source that generates many chunks""" tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) diff --git a/test/backend/data_process/test_worker.py b/test/backend/data_process/test_worker.py index a59635c13..fb7115816 100644 --- a/test/backend/data_process/test_worker.py +++ b/test/backend/data_process/test_worker.py @@ -2,6 +2,7 @@ import types import importlib import pytest +import os class FakeRay: @@ -44,6 +45,7 @@ def setup_mocks_for_worker(mocker, initialized=False): const_mod.FORWARD_REDIS_RETRY_MAX = 1 const_mod.DISABLE_RAY_DASHBOARD = False const_mod.DATA_PROCESS_SERVICE = "http://data-process" + const_mod.ROOT_DIR = "/mock/root" sys.modules["consts.const"] = const_mod # Stub celery module and submodules (required by tasks.py imported via __init__.py) @@ -483,6 +485,23 @@ def init_ray_for_worker(cls, address): assert worker_module.worker_state['initialized'] is True +def test_setup_worker_environment_sets_ray_preallocate_env(mocker): + """Ensure setup_worker_environment sets RAY_preallocate_plasma env var""" + worker_module, _ = setup_mocks_for_worker(mocker, initialized=False) + + # Force init success to avoid fallback path exceptions + class FakeRayConfig: + @classmethod + def init_ray_for_worker(cls, address): + return True + + mocker.patch.object(worker_module, "RayConfig", FakeRayConfig) + + worker_module.setup_worker_environment() + + assert os.environ.get("RAY_preallocate_plasma") == str(worker_module.RAY_preallocate_plasma).lower() + + def test_setup_worker_environment_ray_init_fallback(mocker): """Test setup_worker_environment with Ray init fallback""" worker_module, fake_ray = setup_mocks_for_worker(mocker, initialized=False) diff --git a/test/backend/database/test_knowledge_db.py b/test/backend/database/test_knowledge_db.py index 913e8f1a3..af337eb8d 100644 --- a/test/backend/database/test_knowledge_db.py +++ b/test/backend/database/test_knowledge_db.py @@ -71,6 +71,7 @@ class MockKnowledgeRecord: def __init__(self, **kwargs): self.knowledge_id = kwargs.get('knowledge_id', 1) self.index_name = kwargs.get('index_name', 'test_index') + self.knowledge_name = kwargs.get('knowledge_name', 'test_index') self.knowledge_describe = kwargs.get('knowledge_describe', 'test description') self.created_by = kwargs.get('created_by', 'test_user') self.updated_by = kwargs.get('updated_by', 'test_user') @@ -83,6 +84,7 @@ def __init__(self, **kwargs): # Mock SQLAlchemy column attributes knowledge_id = MagicMock(name="knowledge_id_column") index_name = MagicMock(name="index_name_column") + knowledge_name = MagicMock(name="knowledge_name_column") knowledge_describe = MagicMock(name="knowledge_describe_column") created_by = MagicMock(name="created_by_column") updated_by = MagicMock(name="updated_by_column") @@ -107,7 +109,9 @@ def __init__(self, **kwargs): get_knowledge_info_by_knowledge_ids, get_knowledge_ids_by_index_names, get_knowledge_info_by_tenant_id, - update_model_name_by_index_name + update_model_name_by_index_name, + get_index_name_by_knowledge_name, + _generate_index_name ) @@ -125,8 +129,9 @@ def test_create_knowledge_record_success(monkeypatch, mock_session): session, _ = mock_session # Create mock knowledge record - mock_record = MockKnowledgeRecord() + mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge") mock_record.knowledge_id = 123 + mock_record.index_name = "test_knowledge" # Mock database session context mock_ctx = MagicMock() @@ -140,16 +145,21 @@ def test_create_knowledge_record_success(monkeypatch, mock_session): "knowledge_describe": "Test knowledge description", "user_id": "test_user", "tenant_id": "test_tenant", - "embedding_model_name": "test_model" + "embedding_model_name": "test_model", + "knowledge_name": "test_knowledge" } # Mock KnowledgeRecord constructor with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): result = create_knowledge_record(test_query) - assert result == 123 + assert result == { + "knowledge_id": 123, + "index_name": "test_knowledge", + "knowledge_name": "test_knowledge", + } session.add.assert_called_once_with(mock_record) - session.flush.assert_called_once() + assert session.flush.call_count == 1 session.commit.assert_called_once() @@ -179,6 +189,42 @@ def test_create_knowledge_record_exception(monkeypatch, mock_session): session.rollback.assert_called_once() +def test_create_knowledge_record_generates_index_name(monkeypatch, mock_session): + """Test create_knowledge_record generates index_name when not provided""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord(knowledge_name="kb1") + mock_record.knowledge_id = 7 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # Deterministic index name + monkeypatch.setattr("backend.database.knowledge_db._generate_index_name", lambda _: "7-generated") + + test_query = { + "knowledge_describe": "desc", + "user_id": "user-1", + "tenant_id": "tenant-1", + "embedding_model_name": "model-x", + "knowledge_name": "kb1", + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 7, + "index_name": "7-generated", + "knowledge_name": "kb1", + } + assert mock_record.index_name == "7-generated" + assert session.flush.call_count == 2 # initial insert + index_name update + session.commit.assert_called_once() + + def test_update_knowledge_record_success(monkeypatch, mock_session): """Test successful update of knowledge record""" session, query = mock_session @@ -446,6 +492,39 @@ def test_get_knowledge_record_exception(monkeypatch, mock_session): get_knowledge_record(test_query) +def test_get_knowledge_record_with_none_query(monkeypatch, mock_session): + """Test get_knowledge_record with None query raises TypeError""" + session, query = mock_session + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # When query is None, accessing query['index_name'] will raise TypeError + with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"): + get_knowledge_record(None) + + +def test_get_knowledge_record_without_index_name_key(monkeypatch, mock_session): + """Test get_knowledge_record with query missing index_name key raises KeyError""" + session, query = mock_session + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # When query doesn't have 'index_name' key, accessing query['index_name'] will raise KeyError + test_query = { + "tenant_id": "test_tenant" + # Missing index_name key + } + + with pytest.raises(KeyError): + get_knowledge_record(test_query) + + def test_get_knowledge_info_by_knowledge_ids_success(monkeypatch, mock_session): """Test retrieving knowledge info by knowledge ID list""" session, query = mock_session @@ -454,12 +533,14 @@ def test_get_knowledge_info_by_knowledge_ids_success(monkeypatch, mock_session): mock_record1 = MockKnowledgeRecord() mock_record1.knowledge_id = 1 mock_record1.index_name = "knowledge1" + mock_record1.knowledge_name = "Knowledge Base 1" mock_record1.knowledge_sources = "elasticsearch" mock_record1.embedding_model_name = "model1" mock_record2 = MockKnowledgeRecord() mock_record2.knowledge_id = 2 mock_record2.index_name = "knowledge2" + mock_record2.knowledge_name = "Knowledge Base 2" mock_record2.knowledge_sources = "vectordb" mock_record2.embedding_model_name = "model2" @@ -479,12 +560,14 @@ def test_get_knowledge_info_by_knowledge_ids_success(monkeypatch, mock_session): { "knowledge_id": 1, "index_name": "knowledge1", + "knowledge_name": "Knowledge Base 1", "knowledge_sources": "elasticsearch", "embedding_model_name": "model1" }, { "knowledge_id": 2, "index_name": "knowledge2", + "knowledge_name": "Knowledge Base 2", "knowledge_sources": "vectordb", "embedding_model_name": "model2" } @@ -648,4 +731,391 @@ def test_update_model_name_by_index_name_exception(monkeypatch, mock_session): monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) with pytest.raises(MockSQLAlchemyError, match="Database error"): - update_model_name_by_index_name("test_index", "new_model", "tenant1", "user1") \ No newline at end of file + update_model_name_by_index_name("test_index", "new_model", "tenant1", "user1") + + +def test_create_knowledge_record_with_index_name_only(monkeypatch, mock_session): + """Test create_knowledge_record when only index_name is provided (no knowledge_name)""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 123 + mock_record.index_name = "test_index" + mock_record.knowledge_name = "test_index" # Should use index_name as knowledge_name + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_index", + "knowledge_describe": "Test description", + "user_id": "test_user", + "tenant_id": "test_tenant", + "embedding_model_name": "test_model" + # No knowledge_name provided + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 123, + "index_name": "test_index", + "knowledge_name": "test_index", + } + session.add.assert_called_once_with(mock_record) + assert session.flush.call_count == 1 + session.commit.assert_called_once() + + +def test_create_knowledge_record_without_user_id(monkeypatch, mock_session): + """Test create_knowledge_record without user_id""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 123 + mock_record.index_name = "test_index" + mock_record.knowledge_name = "test_kb" + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_index", + "knowledge_name": "test_kb", + "knowledge_describe": "Test description", + "tenant_id": "test_tenant", + "embedding_model_name": "test_model" + # No user_id provided + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result["knowledge_id"] == 123 + session.add.assert_called_once_with(mock_record) + session.commit.assert_called_once() + + +def test_create_knowledge_record_without_index_name_and_knowledge_name(monkeypatch, mock_session): + """Test create_knowledge_record when neither index_name nor knowledge_name is provided""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 7 + mock_record.knowledge_name = None # Both are None, so knowledge_name will be None + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # Deterministic index name + monkeypatch.setattr("backend.database.knowledge_db._generate_index_name", lambda _: "7-generated") + + test_query = { + "knowledge_describe": "desc", + "user_id": "user-1", + "tenant_id": "tenant-1", + "embedding_model_name": "model-x" + # Neither index_name nor knowledge_name provided + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 7, + "index_name": "7-generated", + "knowledge_name": None, + } + assert mock_record.index_name == "7-generated" + assert session.flush.call_count == 2 # initial insert + index_name update + session.commit.assert_called_once() + + +def test_update_knowledge_record_without_user_id(monkeypatch, mock_session): + """Test update_knowledge_record without user_id""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_describe = "old description" + mock_record.updated_by = "original_user" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_knowledge", + "knowledge_describe": "Updated description" + # No user_id provided + } + + result = update_knowledge_record(test_query) + + assert result is True + assert mock_record.knowledge_describe == "Updated description" + # updated_by should remain unchanged when user_id is not provided + assert mock_record.updated_by == "original_user" + session.flush.assert_called_once() + session.commit.assert_called_once() + + +def test_delete_knowledge_record_without_user_id(monkeypatch, mock_session): + """Test delete_knowledge_record without user_id""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.delete_flag = 'N' + mock_record.updated_by = "original_user" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_knowledge" + # No user_id provided + } + + result = delete_knowledge_record(test_query) + + assert result is True + assert mock_record.delete_flag == 'Y' + # updated_by should remain unchanged when user_id is not provided + assert mock_record.updated_by == "original_user" + session.flush.assert_called_once() + session.commit.assert_called_once() + + +def test_get_knowledge_record_with_tenant_id_none(monkeypatch, mock_session): + """Test get_knowledge_record with tenant_id explicitly set to None""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 123 + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + expected_result = {"knowledge_id": 123} + monkeypatch.setattr("backend.database.knowledge_db.as_dict", lambda x: expected_result) + + test_query = { + "index_name": "test_knowledge", + "tenant_id": None # Explicitly None + } + + result = get_knowledge_record(test_query) + + assert result == expected_result + # Should not add tenant_id filter when tenant_id is None + assert query.filter.call_count >= 1 + + +def test_get_knowledge_info_by_knowledge_ids_empty_list(monkeypatch, mock_session): + """Test get_knowledge_info_by_knowledge_ids with empty list""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.all.return_value = [] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + knowledge_ids = [] + result = get_knowledge_info_by_knowledge_ids(knowledge_ids) + + assert result == [] + + +def test_get_knowledge_info_by_knowledge_ids_includes_knowledge_name(monkeypatch, mock_session): + """Test get_knowledge_info_by_knowledge_ids includes knowledge_name field""" + session, query = mock_session + + mock_record1 = MockKnowledgeRecord() + mock_record1.knowledge_id = 1 + mock_record1.index_name = "knowledge1" + mock_record1.knowledge_name = "Knowledge Base 1" + mock_record1.knowledge_sources = "elasticsearch" + mock_record1.embedding_model_name = "model1" + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_record1] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + knowledge_ids = ["1"] + result = get_knowledge_info_by_knowledge_ids(knowledge_ids) + + expected = [ + { + "knowledge_id": 1, + "index_name": "knowledge1", + "knowledge_name": "Knowledge Base 1", + "knowledge_sources": "elasticsearch", + "embedding_model_name": "model1" + } + ] + + assert result == expected + assert "knowledge_name" in result[0] + + +def test_get_knowledge_info_by_knowledge_ids_with_none_knowledge_name(monkeypatch, mock_session): + """Test get_knowledge_info_by_knowledge_ids when knowledge_name is None""" + session, query = mock_session + + mock_record1 = MockKnowledgeRecord() + mock_record1.knowledge_id = 1 + mock_record1.index_name = "knowledge1" + mock_record1.knowledge_name = None # None knowledge_name + mock_record1.knowledge_sources = "elasticsearch" + mock_record1.embedding_model_name = "model1" + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_record1] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + knowledge_ids = ["1"] + result = get_knowledge_info_by_knowledge_ids(knowledge_ids) + + expected = [ + { + "knowledge_id": 1, + "index_name": "knowledge1", + "knowledge_name": None, + "knowledge_sources": "elasticsearch", + "embedding_model_name": "model1" + } + ] + + assert result == expected + assert result[0]["knowledge_name"] is None + + +def test_get_index_name_by_knowledge_name_success(monkeypatch, mock_session): + """Test successfully getting index_name by knowledge_name""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_name = "My Knowledge Base" + mock_record.index_name = "123-abc123def456" + mock_record.tenant_id = "tenant1" + mock_record.delete_flag = 'N' + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + result = get_index_name_by_knowledge_name("My Knowledge Base", "tenant1") + + assert result == "123-abc123def456" + + +def test_get_index_name_by_knowledge_name_not_found(monkeypatch, mock_session): + """Test get_index_name_by_knowledge_name when knowledge base is not found""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.first.return_value = None + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(ValueError, match="Knowledge base 'Nonexistent KB' not found for the current tenant"): + get_index_name_by_knowledge_name("Nonexistent KB", "tenant1") + + +def test_get_index_name_by_knowledge_name_exception(monkeypatch, mock_session): + """Test exception when getting index_name by knowledge_name""" + session, query = mock_session + query.filter.side_effect = MockSQLAlchemyError("Database error") + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(MockSQLAlchemyError, match="Database error"): + get_index_name_by_knowledge_name("My Knowledge Base", "tenant1") + + +def test_generate_index_name_format(monkeypatch): + """Test _generate_index_name generates correct format""" + # Mock uuid to get deterministic result + mock_uuid = MagicMock() + mock_uuid.hex = "abc123def456" + monkeypatch.setattr("backend.database.knowledge_db.uuid.uuid4", lambda: mock_uuid) + + result = _generate_index_name(123) + + assert result == "123-abc123def456" + assert result.startswith("123-") + assert len(result) == len("123-abc123def456") + + +def test_get_knowledge_ids_by_index_names_empty_list(monkeypatch, mock_session): + """Test get_knowledge_ids_by_index_names with empty list""" + session, _ = mock_session + + mock_specific_query = MagicMock() + mock_filter = MagicMock() + mock_filter.all.return_value = [] + mock_specific_query.filter.return_value = mock_filter + + def mock_query_func(*args, **kwargs): + return mock_specific_query + + session.query = mock_query_func + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + index_names = [] + result = get_knowledge_ids_by_index_names(index_names) + + assert result == [] \ No newline at end of file diff --git a/test/backend/services/test_agent_service.py b/test/backend/services/test_agent_service.py index 9c202209c..d4b28eae5 100644 --- a/test/backend/services/test_agent_service.py +++ b/test/backend/services/test_agent_service.py @@ -3,15 +3,20 @@ import json from contextlib import contextmanager from unittest.mock import patch, MagicMock, mock_open, call, Mock, AsyncMock +import os import pytest from fastapi.responses import StreamingResponse from fastapi import Request - -# Import the actual ToolConfig model for testing before any mocking from nexent.core.agents.agent_model import ToolConfig -import os +from backend.consts.model import ( + AgentNameBatchCheckItem, + AgentNameBatchCheckRequest, + AgentNameBatchRegenerateItem, + AgentNameBatchRegenerateRequest, +) + # Patch environment variables before any imports that might use them os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000') @@ -5629,6 +5634,260 @@ async def fake_update_tool_list(tenant_id, user_id): assert relationships == [(100 + 1, 100 + 2, "tenant1")] +# ===================================================================== +# Tests for batch agent name conflict and regeneration +# ===================================================================== + + +@pytest.mark.asyncio +async def test_check_agent_name_conflict_batch_impl_detects_conflicts(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + existing_agents = [ + {"agent_id": 10, "name": "dup_name", "display_name": "Dup Display"}, + {"agent_id": 11, "name": "unique", "display_name": "Unique"}, + ] + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: existing_agents, + raising=False, + ) + + from consts.model import AgentNameBatchCheckItem, AgentNameBatchCheckRequest + + request = AgentNameBatchCheckRequest( + items=[ + AgentNameBatchCheckItem(name="dup_name", display_name="Another"), + AgentNameBatchCheckItem(name="", display_name=None), + ] + ) + + result = await agent_service.check_agent_name_conflict_batch_impl( + request, authorization="Bearer token" + ) + + assert result[0]["name_conflict"] is True + assert result[0]["display_name_conflict"] is False + assert result[0]["conflict_agents"] == [ + {"name": "dup_name", "display_name": "Dup Display"} + ] + assert result[1]["name_conflict"] is False + assert result[1]["display_name_conflict"] is False + assert result[1]["conflict_agents"] == [] + + +@pytest.mark.asyncio +async def test_check_agent_name_conflict_batch_impl_display_conflict(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + existing_agents = [ + {"agent_id": 3, "name": "alpha", "display_name": "Shown"}, + ] + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: existing_agents, + raising=False, + ) + + request = AgentNameBatchCheckRequest( + items=[AgentNameBatchCheckItem(name="beta", display_name="Shown")] + ) + + result = await agent_service.check_agent_name_conflict_batch_impl( + request, authorization="Bearer token" + ) + + assert result[0]["name_conflict"] is False + assert result[0]["display_name_conflict"] is True + assert result[0]["conflict_agents"] == [ + {"name": "alpha", "display_name": "Shown"} + ] + + +@pytest.mark.asyncio +async def test_check_agent_name_conflict_batch_impl_skips_same_agent(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + existing_agents = [ + {"agent_id": 7, "name": "self", "display_name": "Self Display"}, + ] + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: existing_agents, + raising=False, + ) + + request = AgentNameBatchCheckRequest( + items=[ + AgentNameBatchCheckItem( + agent_id=7, name="self", display_name="Self Display" + ) + ] + ) + + result = await agent_service.check_agent_name_conflict_batch_impl( + request, authorization="Bearer token" + ) + + assert result[0]["name_conflict"] is False + assert result[0]["display_name_conflict"] is False + assert result[0]["conflict_agents"] == [] + + +@pytest.mark.asyncio +async def test_regenerate_agent_name_batch_impl_uses_llm(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: [{"agent_id": 2, "name": "dup_name", "display_name": "Dup"}], + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.tenant_config_manager.get_model_config", + lambda key, tenant_id: {"model_id": "model-1", "display_name": "LLM"}, + raising=False, + ) + + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", fake_to_thread, raising=False) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_name_with_llm", + lambda **kwargs: "regenerated_name", + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_display_name_with_llm", + lambda **kwargs: "Regenerated Display", + raising=False, + ) + + + + request = AgentNameBatchRegenerateRequest( + items=[ + AgentNameBatchRegenerateItem( + agent_id=1, + name="dup_name", + display_name="Dup", + task_description="desc", + ) + ] + ) + + result = await agent_service.regenerate_agent_name_batch_impl( + request, authorization="Bearer token" + ) + + assert result == [{"name": "regenerated_name", "display_name": "Regenerated Display"}] + + +@pytest.mark.asyncio +async def test_regenerate_agent_name_batch_impl_no_model(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: [], + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.tenant_config_manager.get_model_config", + lambda key, tenant_id: None, + raising=False, + ) + + from consts.model import AgentNameBatchRegenerateItem, AgentNameBatchRegenerateRequest + + request = AgentNameBatchRegenerateRequest( + items=[AgentNameBatchRegenerateItem(agent_id=1, name="dup", display_name="Dup")] + ) + + with pytest.raises(ValueError): + await agent_service.regenerate_agent_name_batch_impl( + request, authorization="Bearer token" + ) + + +@pytest.mark.asyncio +async def test_regenerate_agent_name_batch_impl_llm_failure_fallback(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + # existing agent ensures duplicate detection + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: [{"agent_id": 2, "name": "dup", "display_name": "Dup"}], + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.tenant_config_manager.get_model_config", + lambda key, tenant_id: {"model_id": "model-1", "display_name": "LLM"}, + raising=False, + ) + + async def run_in_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", run_in_thread, raising=False) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_name_with_llm", + lambda **kwargs: (_ for _ in ()).throw(Exception("llm-fail")), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_display_name_with_llm", + lambda **kwargs: (_ for _ in ()).throw(Exception("llm-fail")), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._generate_unique_agent_name_with_suffix", + lambda base_value, **kwargs: f"{base_value}_fallback", + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._generate_unique_display_name_with_suffix", + lambda base_value, **kwargs: f"{base_value}_fallback", + raising=False, + ) + + request = AgentNameBatchRegenerateRequest( + items=[ + AgentNameBatchRegenerateItem( + agent_id=1, + name="dup", + display_name="Dup", + task_description="desc", + ) + ] + ) + + result = await agent_service.regenerate_agent_name_batch_impl( + request, authorization="Bearer token" + ) + + assert result == [{"name": "dup_fallback", "display_name": "Dup_fallback"}] + + # ===================================================================== # Tests for _resolve_model_with_fallback helper function # ===================================================================== @@ -6233,28 +6492,19 @@ async def test_get_agent_info_impl_with_unavailable_agent( @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_agent_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_name_with_llm') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_name_with_llm_success( +async def test_import_agent_by_agent_id_allows_duplicate_name_without_regen( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_regen_name, - mock_generate_unique_name, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_name is duplicate and LLM regeneration succeeds (line 1043-1060).""" - # Setup + """ + New behavior: import_agent_by_agent_id no longer performs duplicate-name regeneration. + It should create the agent with the provided name/display_name even if duplicates exist. + """ mock_query_all_tools.return_value = [] - mock_resolve_model.side_effect = [1, 2] # model_id=1, business_logic_model_id=2 - mock_query_all_agents.return_value = [{"name": "duplicate_name", "display_name": "Display"}] - mock_check_name_dup.return_value = True # Name is duplicate - mock_regen_name.return_value = "regenerated_name" + mock_resolve_model.side_effect = [1, 2] mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6277,7 +6527,6 @@ async def test_import_agent_by_agent_id_duplicate_name_with_llm_success( business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6285,42 +6534,28 @@ async def test_import_agent_by_agent_id_duplicate_name_with_llm_success( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_name_dup.assert_called_once() - mock_regen_name.assert_called_once() mock_create_agent.assert_called_once() - # Verify regenerated name was used - assert mock_create_agent.call_args[1]["agent_info"]["name"] == "regenerated_name" + assert mock_create_agent.call_args[1]["agent_info"]["name"] == "duplicate_name" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "Test Display" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_agent_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_name_with_llm') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_name_llm_failure_fallback( +async def test_import_agent_by_agent_id_duplicate_name_no_regen_fallback( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_regen_name, - mock_generate_unique_name, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_name is duplicate, LLM regeneration fails, uses fallback (line 1061-1067).""" - # Setup + """ + New behavior: even when duplicate name, import proceeds without regeneration or fallback. + """ mock_query_all_tools.return_value = [] mock_resolve_model.side_effect = [1, 2] - mock_query_all_agents.return_value = [{"name": "duplicate_name", "display_name": "Display"}] - mock_check_name_dup.return_value = True - mock_regen_name.side_effect = Exception("LLM failed") - mock_generate_unique_name.return_value = "fallback_name_1" mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6343,7 +6578,6 @@ async def test_import_agent_by_agent_id_duplicate_name_llm_failure_fallback( business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6351,41 +6585,27 @@ async def test_import_agent_by_agent_id_duplicate_name_llm_failure_fallback( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_regen_name.assert_called_once() - mock_generate_unique_name.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback name was used - assert mock_create_agent.call_args[1]["agent_info"]["name"] == "fallback_name_1" + assert mock_create_agent.call_args[1]["agent_info"]["name"] == "duplicate_name" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_agent_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_name_with_llm') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_name_no_model_fallback( +async def test_import_agent_by_agent_id_duplicate_name_no_model_still_allows( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_regen_name, - mock_generate_unique_name, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_name is duplicate but no model available, uses fallback (line 1068-1074).""" - # Setup + """ + New behavior: even without model, duplicate name passes through unchanged. + """ mock_query_all_tools.return_value = [] - mock_resolve_model.side_effect = [None, None] # No models available - mock_query_all_agents.return_value = [{"name": "duplicate_name", "display_name": "Display"}] - mock_check_name_dup.return_value = True - mock_generate_unique_name.return_value = "fallback_name_2" + mock_resolve_model.side_effect = [None, None] mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6408,7 +6628,6 @@ async def test_import_agent_by_agent_id_duplicate_name_no_model_fallback( business_logic_model_name=None ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6416,45 +6635,25 @@ async def test_import_agent_by_agent_id_duplicate_name_no_model_fallback( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_name_dup.assert_called_once() - mock_regen_name.assert_not_called() # Should not call LLM when no model - mock_generate_unique_name.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback name was used - assert mock_create_agent.call_args[1]["agent_info"]["name"] == "fallback_name_2" + assert mock_create_agent.call_args[1]["agent_info"]["name"] == "duplicate_name" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_display_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_display_name_with_llm') -@patch('backend.services.agent_service._check_agent_display_name_duplicate') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_display_name_with_llm_success( +async def test_import_agent_by_agent_id_duplicate_display_name_allowed( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_check_display_dup, - mock_regen_display, - mock_generate_unique_display, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_display_name is duplicate and LLM regeneration succeeds (line 1077-1092).""" - # Setup + """New behavior: duplicate display_name passes through without regeneration.""" mock_query_all_tools.return_value = [] mock_resolve_model.side_effect = [1, 2] - mock_query_all_agents.return_value = [{"name": "name1", "display_name": "duplicate_display"}] - mock_check_name_dup.return_value = False # Name is not duplicate - mock_check_display_dup.return_value = True # Display name is duplicate - mock_regen_display.return_value = "regenerated_display" mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6477,7 +6676,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_with_llm_success( business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6485,45 +6683,27 @@ async def test_import_agent_by_agent_id_duplicate_display_name_with_llm_success( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_display_dup.assert_called_once() - mock_regen_display.assert_called_once() mock_create_agent.assert_called_once() - # Verify regenerated display name was used - assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "regenerated_display" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "duplicate_display" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_display_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_display_name_with_llm') -@patch('backend.services.agent_service._check_agent_display_name_duplicate') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_display_name_llm_failure_fallback( +async def test_import_agent_by_agent_id_duplicate_display_name_no_llm_fallback( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_check_display_dup, - mock_regen_display, - mock_generate_unique_display, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_display_name is duplicate, LLM regeneration fails, uses fallback (line 1093-1099).""" - # Setup + """ + New behavior: duplicate display_name passes through without LLM; fallback not invoked. + """ mock_query_all_tools.return_value = [] mock_resolve_model.side_effect = [1, 2] - mock_query_all_agents.return_value = [{"name": "name1", "display_name": "duplicate_display"}] - mock_check_name_dup.return_value = False - mock_check_display_dup.return_value = True - mock_regen_display.side_effect = Exception("LLM failed") - mock_generate_unique_display.return_value = "fallback_display_1" mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6546,7 +6726,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_llm_failure_fallb business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6554,44 +6733,27 @@ async def test_import_agent_by_agent_id_duplicate_display_name_llm_failure_fallb skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_regen_display.assert_called_once() - mock_generate_unique_display.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback display name was used - assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "fallback_display_1" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "duplicate_display" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_display_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_display_name_with_llm') -@patch('backend.services.agent_service._check_agent_display_name_duplicate') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_display_name_no_model_fallback( +async def test_import_agent_by_agent_id_duplicate_display_name_no_model_still_allowed( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_check_display_dup, - mock_regen_display, - mock_generate_unique_display, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_display_name is duplicate but no model available, uses fallback (line 1100-1106).""" - # Setup + """ + New behavior: even without model, duplicate display_name passes through unchanged. + """ mock_query_all_tools.return_value = [] - mock_resolve_model.side_effect = [None, None] # No models available - mock_query_all_agents.return_value = [{"name": "name1", "display_name": "duplicate_display"}] - mock_check_name_dup.return_value = False - mock_check_display_dup.return_value = True - mock_generate_unique_display.return_value = "fallback_display_2" + mock_resolve_model.side_effect = [None, None] mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6614,7 +6776,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_no_model_fallback business_logic_model_name=None ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6622,11 +6783,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_no_model_fallback skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_display_dup.assert_called_once() - mock_regen_display.assert_not_called() # Should not call LLM when no model - mock_generate_unique_display.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback display name was used - assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "fallback_display_2" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "duplicate_display" diff --git a/test/backend/services/test_conversation_management_service.py b/test/backend/services/test_conversation_management_service.py index feeb68d0e..173a3b6aa 100644 --- a/test/backend/services/test_conversation_management_service.py +++ b/test/backend/services/test_conversation_management_service.py @@ -327,7 +327,7 @@ def test_extract_user_messages(self): self.assertIn("Give me examples of AI applications", result) self.assertIn("AI stands for Artificial Intelligence.", result) - @patch('backend.services.conversation_management_service.OpenAIServerModel') + @patch('backend.services.conversation_management_service.OpenAIModel') @patch('backend.services.conversation_management_service.get_generate_title_prompt_template') @patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config') def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_template, mock_openai): @@ -360,7 +360,7 @@ def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_templat mock_llm_instance.generate.assert_called_once() mock_get_prompt_template.assert_called_once_with(language='zh') - @patch('backend.services.conversation_management_service.OpenAIServerModel') + @patch('backend.services.conversation_management_service.OpenAIModel') @patch('backend.services.conversation_management_service.get_generate_title_prompt_template') @patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config') def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_get_prompt_template, mock_openai): @@ -392,7 +392,7 @@ def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_g mock_llm_instance.generate.assert_called_once() mock_get_prompt_template.assert_called_once_with(language='zh') - @patch('backend.services.conversation_management_service.OpenAIServerModel') + @patch('backend.services.conversation_management_service.OpenAIModel') @patch('backend.services.conversation_management_service.get_generate_title_prompt_template') @patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config') def test_call_llm_for_title_response_none_en(self, mock_get_model_config, mock_get_prompt_template, mock_openai): diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 86e1cac73..48741a0f8 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -405,6 +405,73 @@ async def test_create_model_for_tenant_embedding_sets_dimension(): assert mock_create.call_count == 1 +@pytest.mark.asyncio +async def test_create_model_for_tenant_embedding_sets_default_chunk_batch(): + """chunk_batch defaults to 10 when not provided for embedding models.""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ + mock.patch.object(svc, "create_model_record") as mock_create, \ + mock.patch.object(svc, "split_repo_name", return_value=("openai", "text-embedding-3-small")): + + user_id = "u1" + tenant_id = "t1" + model_data = { + "model_name": "openai/text-embedding-3-small", + "display_name": None, + "base_url": "https://api.openai.com", + "model_type": "embedding", + "chunk_batch": None, # Explicitly unset to exercise defaulting + } + + await svc.create_model_for_tenant(user_id, tenant_id, model_data) + + mock_dim.assert_awaited_once() + assert mock_create.call_count == 1 + # chunk_batch should be defaulted before persistence + create_args = mock_create.call_args[0][0] + assert create_args["chunk_batch"] == 10 + + +@pytest.mark.asyncio +async def test_create_model_for_tenant_multi_embedding_sets_default_chunk_batch(): + """chunk_batch defaults to 10 when not provided for multi_embedding models (covers line 79).""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ + mock.patch.object(svc, "create_model_record") as mock_create, \ + mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")): + + user_id = "u1" + tenant_id = "t1" + model_data = { + "model_name": "openai/clip", + "display_name": None, + "base_url": "https://api.openai.com", + "model_type": "multi_embedding", + "chunk_batch": None, # Explicitly unset to exercise defaulting + } + + await svc.create_model_for_tenant(user_id, tenant_id, model_data) + + mock_dim.assert_awaited_once() + # Should create two records: multi_embedding and its embedding variant + assert mock_create.call_count == 2 + + # Verify chunk_batch was set to 10 for both records + create_calls = mock_create.call_args_list + # First call is for multi_embedding + multi_emb_args = create_calls[0][0][0] + assert multi_emb_args["chunk_batch"] == 10 + assert multi_emb_args["model_type"] == "multi_embedding" + # Second call is for embedding variant + emb_args = create_calls[1][0][0] + assert emb_args["chunk_batch"] == 10 + assert emb_args["model_type"] == "embedding" + + @pytest.mark.asyncio async def test_create_provider_models_for_tenant_success(): svc = import_svc() diff --git a/test/backend/services/test_model_provider_service.py b/test/backend/services/test_model_provider_service.py index ce3a0ab75..0916e61f9 100644 --- a/test/backend/services/test_model_provider_service.py +++ b/test/backend/services/test_model_provider_service.py @@ -304,7 +304,9 @@ async def test_prepare_model_dict_embedding(): assert kwargs["model_name"] == "text-embedding-ada-002" assert kwargs["model_type"] == "embedding" assert kwargs["api_key"] == "test-key" - assert kwargs["max_tokens"] == 1024 + # For embedding models, max_tokens is set to 0 as placeholder, + # will be updated by embedding_dimension_check later + assert kwargs["max_tokens"] == 0 assert kwargs["display_name"] == "openai/text-embedding-ada-002" assert kwargs["expected_chunk_size"] == sys.modules["consts.const"].DEFAULT_EXPECTED_CHUNK_SIZE assert kwargs["maximum_chunk_size"] == sys.modules["consts.const"].DEFAULT_MAXIMUM_CHUNK_SIZE diff --git a/test/backend/services/test_redis_service.py b/test/backend/services/test_redis_service.py index 8ebf7613e..1fba985ba 100644 --- a/test/backend/services/test_redis_service.py +++ b/test/backend/services/test_redis_service.py @@ -1,10 +1,7 @@ import unittest from unittest.mock import patch, MagicMock, call import json -import os import redis -import hashlib -import urllib.parse from backend.services.redis_service import RedisService, get_redis_service @@ -43,7 +40,8 @@ def test_client_property(self, mock_from_url): mock_from_url.assert_called_once_with( 'redis://localhost:6379/0', socket_timeout=5, - socket_connect_timeout=5 + socket_connect_timeout=5, + decode_responses=True ) self.assertEqual(client, self.mock_redis_client) @@ -127,7 +125,23 @@ def test_backend_client_no_env_vars(self, mock_from_url): # Execute & Verify with self.assertRaises(ValueError): _ = redis_service.backend_client - + + @patch('redis.from_url') + @patch('backend.services.redis_service.REDIS_URL', 'redis://localhost:6379/0') + def test_mark_and_check_task_cancelled(self, mock_from_url): + """mark_task_cancelled should set flag and is_task_cancelled should read it.""" + mock_client = MagicMock() + mock_client.setex.return_value = True + mock_client.get.return_value = b"1" + mock_from_url.return_value = mock_client + + service = RedisService() + ok = service.mark_task_cancelled("task-1", ttl_hours=1) + self.assertTrue(ok) + self.assertTrue(service.is_task_cancelled("task-1")) + mock_client.setex.assert_called_once() + mock_client.get.assert_called_once() + def test_delete_knowledgebase_records(self): """Test delete_knowledgebase_records method""" # Setup @@ -216,60 +230,155 @@ def test_delete_document_records_with_error(self): self.assertEqual(len(result["errors"]), 1) self.assertIn("Test error", result["errors"][0]) + def test_cleanup_single_task_related_keys_outer_exception(self): + """Outer handler logs when warning path itself fails.""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + self.mock_redis_client.delete.side_effect = redis.RedisError( + "delete failed") + + with patch('backend.services.redis_service.logger.warning', side_effect=Exception("warn boom")), \ + patch('backend.services.redis_service.logger.error') as mock_error: + result = self.redis_service._cleanup_single_task_related_keys( + "task123") + + mock_error.assert_called_once() + self.assertEqual(result, 0) + def test_cleanup_celery_tasks(self): """Test _cleanup_celery_tasks method""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Create mock task data - task_keys = [b'celery-task-meta-1', b'celery-task-meta-2', b'celery-task-meta-3'] - + task_keys = [b'celery-task-meta-1', + b'celery-task-meta-2', b'celery-task-meta-3'] + # Task 1 matches our index task1_data = json.dumps({ 'result': {'index_name': 'test_index', 'some_key': 'some_value'}, 'parent_id': '2' # This will trigger a parent lookup }).encode() - + # Task 2 has index name in a different location task2_data = json.dumps({ - 'index_name': 'test_index', + 'index_name': 'test_index', 'result': {'some_key': 'some_value'}, 'parent_id': None # No parent }).encode() - + # Task 3 is for a different index task3_data = json.dumps({ 'result': {'index_name': 'other_index', 'some_key': 'some_value'} }).encode() - + # Configure mock responses self.mock_backend_client.keys.return_value = task_keys - self.mock_backend_client.get.side_effect = [task1_data, task2_data, task3_data] - + # Two passes over keys: provide payloads for both passes (6 gets) + self.mock_backend_client.get.side_effect = [ + task1_data, task2_data, task3_data, + task1_data, task2_data, task3_data, + ] + # We expect delete to be called and return 1 each time self.mock_backend_client.delete.return_value = 1 - + # Execute with patch.object(self.redis_service, '_recursively_delete_task_and_parents') as mock_recursive_delete: mock_recursive_delete.side_effect = [(1, {'1'}), (1, {'2'})] result = self.redis_service._cleanup_celery_tasks("test_index") - + # Verify - self.mock_backend_client.keys.assert_called_once_with('celery-task-meta-*') - # We expect 3 calls - one for each task key - self.assertEqual(self.mock_backend_client.get.call_count, 3) - - # Should have called recursive delete twice (for task1 and task2) - self.assertEqual(mock_recursive_delete.call_count, 2) - - # Return value should be the number of deleted tasks - self.assertEqual(result, 2) - + self.mock_backend_client.keys.assert_called_once_with( + 'celery-task-meta-*') + # Implementation fetches task payloads in both passes; expect 6 total (3 keys * 2 passes) + self.assertEqual(self.mock_backend_client.get.call_count, 6) + + # Should have called recursive delete for matched tasks + self.assertGreaterEqual(mock_recursive_delete.call_count, 2) + + # Return value should match deleted tasks count + self.assertEqual(result, mock_recursive_delete.call_count) + + def test_cleanup_celery_tasks_get_exception_and_cancel_failure(self): + """First-pass get failure and cancel failure are both handled.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-err', b'celery-task-meta-2'] + valid_task = json.dumps({ + 'result': {'index_name': 'test_index'}, + 'parent_id': None + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [ + redis.RedisError("boom"), + valid_task, + redis.RedisError("boom-second"), + valid_task, + ] + + with patch.object(self.redis_service, 'mark_task_cancelled', side_effect=ValueError("cancel fail")) as mock_cancel, \ + patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'2'})) as mock_delete, \ + patch.object(self.redis_service, '_cleanup_single_task_related_keys') as mock_cleanup: + + result = self.redis_service._cleanup_celery_tasks("test_index") + + mock_cancel.assert_called_once_with('2') + mock_delete.assert_called_once_with('2') + mock_cleanup.assert_called_once_with('2') + self.assertEqual(result, 1) + + def test_cleanup_celery_tasks_exc_message_bad_json(self): + """JSON decode failure inside exc_message parsing does not crash.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-1'] + bad_json_payload = json.dumps({ + 'result': { + # Contains brace to enter parsing block + 'exc_message': '{bad json' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [ + bad_json_payload, bad_json_payload] + + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(0, set())) as mock_delete: + result = self.redis_service._cleanup_celery_tasks("test_index") + + # Bad JSON should be tolerated; no deletions occur + mock_delete.assert_not_called() + self.assertEqual(result, 0) + + def test_cleanup_celery_tasks_cleanup_single_task_error(self): + """Failures during related-key cleanup are logged and skipped.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-1'] + task_payload = json.dumps({ + 'result': {'index_name': 'test_index'} + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [task_payload, task_payload] + + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})), \ + patch.object(self.redis_service, '_cleanup_single_task_related_keys', side_effect=Exception("cleanup boom")) as mock_cleanup: + result = self.redis_service._cleanup_celery_tasks("test_index") + + mock_cleanup.assert_called_once_with('1') + self.assertEqual(result, 1) + def test_cleanup_cache_keys(self): """Test _cleanup_cache_keys method""" # Setup self.redis_service._client = self.mock_redis_client - + # Configure mock responses for each pattern pattern_keys = { '*test_index*': [b'key1', b'key2'], @@ -277,19 +386,20 @@ def test_cleanup_cache_keys(self): 'index:test_index:*': [b'key6'], 'search:test_index:*': [b'key7', b'key8'] } - + def mock_keys_side_effect(pattern): return pattern_keys.get(pattern, []) - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect - self.mock_redis_client.delete.return_value = 1 # Each delete operation deletes 1 key - + # Each delete operation deletes 1 key + self.mock_redis_client.delete.return_value = 1 + # Execute result = self.redis_service._cleanup_cache_keys("test_index") - + # Verify self.assertEqual(self.mock_redis_client.keys.call_count, 4) - + # All keys should be deleted (8 keys total) expected_calls = [ call(b'key1', b'key2'), @@ -297,19 +407,21 @@ def mock_keys_side_effect(pattern): call(b'key6'), call(b'key7', b'key8') ] - self.mock_redis_client.delete.assert_has_calls(expected_calls, any_order=True) - + self.mock_redis_client.delete.assert_has_calls( + expected_calls, any_order=True) + # Return value should be the number of deleted keys self.assertEqual(result, 4) # 4 successful delete operations - + def test_cleanup_document_celery_tasks(self): """Test _cleanup_document_celery_tasks method""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Create mock task data - task_keys = [b'celery-task-meta-1', b'celery-task-meta-2', b'celery-task-meta-3'] - + task_keys = [b'celery-task-meta-1', + b'celery-task-meta-2', b'celery-task-meta-3'] + # Task 1 matches our index and document task1_data = json.dumps({ 'result': { @@ -318,7 +430,7 @@ def test_cleanup_document_celery_tasks(self): }, 'parent_id': '2' # This will trigger a parent lookup }).encode() - + # Task 2 has the right index but wrong document task2_data = json.dumps({ 'result': { @@ -326,7 +438,7 @@ def test_cleanup_document_celery_tasks(self): 'source': 'other/doc.pdf' } }).encode() - + # Task 3 has document path in a different field task3_data = json.dumps({ 'result': { @@ -335,43 +447,46 @@ def test_cleanup_document_celery_tasks(self): }, 'parent_id': None # No parent }).encode() - + # Configure mock responses self.mock_backend_client.keys.return_value = task_keys - self.mock_backend_client.get.side_effect = [task1_data, task2_data, task3_data] - + self.mock_backend_client.get.side_effect = [ + task1_data, task2_data, task3_data] + # We expect delete to be called and return 1 each time self.mock_backend_client.delete.return_value = 1 - + # Execute with patch.object(self.redis_service, '_recursively_delete_task_and_parents') as mock_recursive_delete: mock_recursive_delete.side_effect = [(1, {'1'}), (1, {'3'})] - result = self.redis_service._cleanup_document_celery_tasks("test_index", "path/to/doc.pdf") - + result = self.redis_service._cleanup_document_celery_tasks( + "test_index", "path/to/doc.pdf") + # Verify - self.mock_backend_client.keys.assert_called_once_with('celery-task-meta-*') + self.mock_backend_client.keys.assert_called_once_with( + 'celery-task-meta-*') # We expect 3 calls - one for each task key self.assertEqual(self.mock_backend_client.get.call_count, 3) - + # Should have called recursive delete twice (for task1 and task3) self.assertEqual(mock_recursive_delete.call_count, 2) - + # Return value should be the number of deleted tasks self.assertEqual(result, 2) - + @patch('hashlib.md5') @patch('urllib.parse.quote') def test_cleanup_document_cache_keys(self, mock_quote, mock_md5): """Test _cleanup_document_cache_keys method""" # Setup self.redis_service._client = self.mock_redis_client - + # Mock the path hashing and quoting mock_quote.return_value = 'safe_path' mock_md5_instance = MagicMock() mock_md5_instance.hexdigest.return_value = 'path_hash' mock_md5.return_value = mock_md5_instance - + # Configure mock responses for each pattern pattern_keys = { '*test_index*safe_path*': [b'key1'], @@ -381,100 +496,105 @@ def test_cleanup_document_cache_keys(self, mock_quote, mock_md5): 'doc:safe_path:*': [b'key6', b'key7'], 'doc:path_hash:*': [b'key8'] } - + def mock_keys_side_effect(pattern): return pattern_keys.get(pattern, []) - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect - self.mock_redis_client.delete.return_value = 1 # Each delete operation deletes 1 key - + # Each delete operation deletes 1 key + self.mock_redis_client.delete.return_value = 1 + # Execute - result = self.redis_service._cleanup_document_cache_keys("test_index", "path/to/doc.pdf") - + result = self.redis_service._cleanup_document_cache_keys( + "test_index", "path/to/doc.pdf") + # Verify self.assertEqual(self.mock_redis_client.keys.call_count, 6) - + # Return value should be the number of deleted keys self.assertEqual(result, 6) # 6 successful delete operations - + def test_get_knowledgebase_task_count(self): """Test get_knowledgebase_task_count method""" # Setup self.redis_service._client = self.mock_redis_client self.redis_service._backend_client = self.mock_backend_client - + # Create mock task data task_keys = [b'celery-task-meta-1', b'celery-task-meta-2'] - + # Task 1 matches our index task1_data = json.dumps({ 'result': {'index_name': 'test_index'} }).encode() - + # Task 2 is for a different index task2_data = json.dumps({ 'result': {'index_name': 'other_index'} }).encode() - + # Configure mock responses for Celery tasks self.mock_backend_client.keys.return_value = task_keys self.mock_backend_client.get.side_effect = [task1_data, task2_data] - + # Configure mock responses for cache keys cache_keys = { '*test_index*': [b'key1', b'key2'], 'kb:test_index:*': [b'key3', b'key4'], 'index:test_index:*': [b'key5'] } - + def mock_keys_side_effect(pattern): return cache_keys.get(pattern, []) - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect - + # Execute result = self.redis_service.get_knowledgebase_task_count("test_index") - + # Verify - self.mock_backend_client.keys.assert_called_once_with('celery-task-meta-*') + self.mock_backend_client.keys.assert_called_once_with( + 'celery-task-meta-*') self.assertEqual(self.mock_backend_client.get.call_count, 2) - + # Should count 1 matching task and 5 cache keys self.assertEqual(result, 6) - + def test_ping_success(self): """Test ping method when connection is successful""" # Setup self.redis_service._client = self.mock_redis_client self.redis_service._backend_client = self.mock_backend_client - + self.mock_redis_client.ping.return_value = True self.mock_backend_client.ping.return_value = True - + # Execute result = self.redis_service.ping() - + # Verify self.mock_redis_client.ping.assert_called_once() self.mock_backend_client.ping.assert_called_once() self.assertTrue(result) - + def test_ping_failure(self): """Test ping method when connection fails""" # Setup self.redis_service._client = self.mock_redis_client self.redis_service._backend_client = self.mock_backend_client - - self.mock_redis_client.ping.side_effect = redis.RedisError("Connection failed") - + + self.mock_redis_client.ping.side_effect = redis.RedisError( + "Connection failed") + # Execute result = self.redis_service.ping() - + # Verify self.mock_redis_client.ping.assert_called_once() - self.mock_backend_client.ping.assert_not_called() # Should not be called after first ping fails + # Should not be called after first ping fails + self.mock_backend_client.ping.assert_not_called() self.assertFalse(result) - + @patch('backend.services.redis_service._redis_service', None) @patch('backend.services.redis_service.RedisService') def test_get_redis_service(self, mock_redis_service_class): @@ -482,146 +602,155 @@ def test_get_redis_service(self, mock_redis_service_class): # Setup mock_instance = MagicMock() mock_redis_service_class.return_value = mock_instance - + # Execute service1 = get_redis_service() service2 = get_redis_service() - + # Verify mock_redis_service_class.assert_called_once() # Only created once self.assertEqual(service1, mock_instance) - self.assertEqual(service2, mock_instance) # Should return same instance - + # Should return same instance + self.assertEqual(service2, mock_instance) + def test_recursively_delete_task_and_parents_no_parent(self): """Test _recursively_delete_task_and_parents with task that has no parent""" # Setup self.redis_service._backend_client = self.mock_backend_client - + task_data = json.dumps({ 'result': {'some_data': 'value'}, 'parent_id': None }).encode() - + self.mock_backend_client.get.return_value = task_data self.mock_backend_client.delete.return_value = 1 - + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task123") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task123") + # Verify self.assertEqual(deleted_count, 1) self.assertEqual(processed_ids, {"task123"}) - self.mock_backend_client.get.assert_called_once_with('celery-task-meta-task123') - self.mock_backend_client.delete.assert_called_once_with('celery-task-meta-task123') - + self.mock_backend_client.get.assert_called_once_with( + 'celery-task-meta-task123') + self.mock_backend_client.delete.assert_called_once_with( + 'celery-task-meta-task123') + def test_recursively_delete_task_and_parents_with_cycle_detection(self): """Test _recursively_delete_task_and_parents detects and breaks cycles""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Create a cycle: task1 -> task2 -> task1 task1_data = json.dumps({'parent_id': 'task2'}).encode() task2_data = json.dumps({'parent_id': 'task1'}).encode() - + self.mock_backend_client.get.side_effect = [task1_data, task2_data] self.mock_backend_client.delete.return_value = 1 - + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task1") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task1") + # Verify - should stop when cycle is detected self.assertEqual(deleted_count, 2) self.assertEqual(processed_ids, {"task1", "task2"}) self.assertEqual(self.mock_backend_client.delete.call_count, 2) - + def test_recursively_delete_task_and_parents_json_decode_error(self): """Test _recursively_delete_task_and_parents handles JSON decode errors""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Invalid JSON data invalid_json_data = b'invalid json data' - + self.mock_backend_client.get.return_value = invalid_json_data self.mock_backend_client.delete.return_value = 1 - + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task123") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task123") + # Verify - should still delete the task even if JSON parsing fails self.assertEqual(deleted_count, 1) self.assertEqual(processed_ids, {"task123"}) - self.mock_backend_client.delete.assert_called_once_with('celery-task-meta-task123') - + self.mock_backend_client.delete.assert_called_once_with( + 'celery-task-meta-task123') + def test_recursively_delete_task_and_parents_redis_error(self): """Test _recursively_delete_task_and_parents handles Redis errors""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Simulate Redis error - self.mock_backend_client.get.side_effect = redis.RedisError("Connection lost") - + self.mock_backend_client.get.side_effect = redis.RedisError( + "Connection lost") + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task123") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task123") + # Verify - should return 0 when Redis error occurs self.assertEqual(deleted_count, 0) self.assertEqual(processed_ids, {"task123"}) - + def test_cleanup_celery_tasks_with_failed_task_metadata(self): """Test _cleanup_celery_tasks handles failed tasks with exception metadata""" # Setup self.redis_service._backend_client = self.mock_backend_client - + task_keys = [b'celery-task-meta-1'] - + # Task with exception metadata containing index name task_data = json.dumps({ 'result': { 'exc_message': 'Error processing task: {"index_name": "test_index", "error": "failed"}' } }).encode() - + self.mock_backend_client.keys.return_value = task_keys self.mock_backend_client.get.return_value = task_data - + # Execute with patch.object(self.redis_service, '_recursively_delete_task_and_parents') as mock_recursive_delete: mock_recursive_delete.return_value = (1, {'1'}) result = self.redis_service._cleanup_celery_tasks("test_index") - + # Verify self.assertEqual(result, 1) mock_recursive_delete.assert_called_once_with('1') - + def test_cleanup_celery_tasks_invalid_exception_metadata(self): """Test _cleanup_celery_tasks handles invalid exception metadata gracefully""" # Setup self.redis_service._backend_client = self.mock_backend_client - + task_keys = [b'celery-task-meta-1'] - + # Task with invalid exception metadata task_data = json.dumps({ 'result': { 'exc_message': 'Invalid JSON metadata' } }).encode() - + self.mock_backend_client.keys.return_value = task_keys self.mock_backend_client.get.return_value = task_data - + # Execute result = self.redis_service._cleanup_celery_tasks("test_index") - + # Verify - should not crash and return 0 self.assertEqual(result, 0) - + def test_cleanup_cache_keys_partial_failure(self): """Test _cleanup_cache_keys handles partial failures gracefully""" # Setup self.redis_service._client = self.mock_redis_client - + # First pattern succeeds, second fails, third succeeds def mock_keys_side_effect(pattern): if pattern == 'kb:test_index:*': @@ -632,33 +761,65 @@ def mock_keys_side_effect(pattern): return [b'key3'] else: return [] - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect self.mock_redis_client.delete.return_value = 1 - + # Execute result = self.redis_service._cleanup_cache_keys("test_index") - + # Verify - should continue processing despite one pattern failing self.assertEqual(result, 2) # 2 successful delete operations - + def test_cleanup_cache_keys_all_patterns_fail(self): """Test _cleanup_cache_keys handles errors gracefully when all patterns fail""" # Setup self.redis_service._client = self.mock_redis_client - + # Simulate an error for all pattern calls # Each call to keys() will fail but be caught by inner try-catch - self.mock_redis_client.keys.side_effect = redis.RedisError("Redis connection failed") - + self.mock_redis_client.keys.side_effect = redis.RedisError( + "Redis connection failed") + # Execute - should not raise exception but return 0 result = self.redis_service._cleanup_cache_keys("test_index") - + # Verify - should handle gracefully and return 0 self.assertEqual(result, 0) # Should have tried all 4 patterns self.assertEqual(self.mock_redis_client.keys.call_count, 4) - + + def test_cleanup_document_celery_tasks_cancel_fail_and_processing_error(self): + """Document cleanup logs processing errors and cancel failures.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-err', b'celery-task-meta-1'] + good_payload = json.dumps({ + 'result': { + 'index_name': 'kb1', + 'path_or_url': 'doc1' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [ + redis.RedisError("get boom"), + good_payload + ] + + with patch.object(self.redis_service, 'mark_task_cancelled', side_effect=ValueError("cancel fail")) as mock_cancel, \ + patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})) as mock_delete, \ + patch.object(self.redis_service, '_cleanup_single_task_related_keys') as mock_cleanup: + + result = self.redis_service._cleanup_document_celery_tasks( + "kb1", "doc1") + + mock_cancel.assert_called_once_with('1') + mock_delete.assert_called_once_with('1') + mock_cleanup.assert_called_once_with('1') + self.assertEqual(result, 1) + def test_cleanup_document_cache_keys_empty_patterns(self): """Test _cleanup_document_cache_keys handles empty key patterns""" @@ -785,6 +946,470 @@ def test_ping_backend_failure(self): self.mock_redis_client.ping.assert_called_once() self.mock_backend_client.ping.assert_called_once() + # ------------------------------------------------------------------ + # Test mark_task_cancelled edge cases + # ------------------------------------------------------------------ + + def test_mark_task_cancelled_empty_task_id(self): + """Test mark_task_cancelled returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.mark_task_cancelled("") + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_mark_task_cancelled_redis_error(self): + """Test mark_task_cancelled handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.mark_task_cancelled("task-123") + self.assertFalse(result) + self.mock_redis_client.setex.assert_called_once() + + def test_mark_task_cancelled_custom_ttl(self): + """Test mark_task_cancelled with custom TTL hours""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + + result = self.redis_service.mark_task_cancelled("task-123", ttl_hours=48) + self.assertTrue(result) + # Verify TTL is calculated correctly (48 hours = 172800 seconds) + call_args = self.mock_redis_client.setex.call_args + self.assertEqual(call_args[0][1], 48 * 3600) # TTL in seconds + + # ------------------------------------------------------------------ + # Test is_task_cancelled edge cases + # ------------------------------------------------------------------ + + def test_is_task_cancelled_empty_task_id(self): + """Test is_task_cancelled returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.is_task_cancelled("") + self.assertFalse(result) + self.mock_redis_client.get.assert_not_called() + + def test_is_task_cancelled_none_value(self): + """Test is_task_cancelled returns False when key doesn't exist""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = None + + result = self.redis_service.is_task_cancelled("task-123") + self.assertFalse(result) + + def test_is_task_cancelled_empty_string_value(self): + """Test is_task_cancelled returns False when value is empty string""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "" + + result = self.redis_service.is_task_cancelled("task-123") + self.assertFalse(result) + + def test_is_task_cancelled_redis_error(self): + """Test is_task_cancelled handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.is_task_cancelled("task-123") + self.assertFalse(result) + + # ------------------------------------------------------------------ + # Test _cleanup_single_task_related_keys + # ------------------------------------------------------------------ + + def test_cleanup_single_task_related_keys_success(self): + """Test _cleanup_single_task_related_keys deletes all related keys""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + # Mock successful deletions + self.mock_redis_client.delete.side_effect = [1, 1, 1] # progress, error, cancel + self.mock_backend_client.delete.return_value = 1 # chunk cache + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should delete 4 keys total + self.assertEqual(result, 4) + # Verify all keys were attempted + self.assertEqual(self.mock_redis_client.delete.call_count, 3) + self.mock_backend_client.delete.assert_called_once_with("dp:task-123:chunks") + + def test_cleanup_single_task_related_keys_empty_task_id(self): + """Test _cleanup_single_task_related_keys returns 0 for empty task_id""" + result = self.redis_service._cleanup_single_task_related_keys("") + self.assertEqual(result, 0) + + def test_cleanup_single_task_related_keys_partial_failure(self): + """Test _cleanup_single_task_related_keys handles partial failures""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + # First key succeeds, second fails, third succeeds, chunk cache fails + self.mock_redis_client.delete.side_effect = [1, redis.RedisError("Error"), 1] + self.mock_backend_client.delete.side_effect = redis.RedisError("Backend error") + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should return count of successful deletions (2) + self.assertEqual(result, 2) + + def test_cleanup_single_task_related_keys_all_fail(self): + """Test _cleanup_single_task_related_keys handles all failures gracefully""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + self.mock_redis_client.delete.side_effect = redis.RedisError("All failed") + self.mock_backend_client.delete.side_effect = redis.RedisError("Backend failed") + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should return 0 but not raise exception + self.assertEqual(result, 0) + + def test_cleanup_single_task_related_keys_no_keys_exist(self): + """Test _cleanup_single_task_related_keys when keys don't exist""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + # All deletions return 0 (key doesn't exist) + self.mock_redis_client.delete.side_effect = [0, 0, 0] + self.mock_backend_client.delete.return_value = 0 + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should return 0 + self.assertEqual(result, 0) + + # ------------------------------------------------------------------ + # Test save_error_info + # ------------------------------------------------------------------ + + def test_save_error_info_success(self): + """Test save_error_info successfully saves error information""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.return_value = "Test error reason" + + result = self.redis_service.save_error_info("task-123", "Test error reason") + + self.assertTrue(result) + self.mock_redis_client.setex.assert_called_once() + # Verify TTL is 30 days in seconds + call_args = self.mock_redis_client.setex.call_args + self.assertEqual(call_args[0][1], 30 * 24 * 60 * 60) + self.assertEqual(call_args[0][2], "Test error reason") + # Verify get was called to verify the save + self.mock_redis_client.get.assert_called_once() + + def test_save_error_info_empty_task_id(self): + """Test save_error_info returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_error_info("", "Error reason") + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_save_error_info_empty_error_reason(self): + """Test save_error_info returns False when error_reason is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_error_info("task-123", "") + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_save_error_info_custom_ttl(self): + """Test save_error_info with custom TTL days""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.return_value = "Error" + + result = self.redis_service.save_error_info("task-123", "Error", ttl_days=7) + + self.assertTrue(result) + call_args = self.mock_redis_client.setex.call_args + # Verify TTL is 7 days in seconds + self.assertEqual(call_args[0][1], 7 * 24 * 60 * 60) + + def test_save_error_info_setex_returns_false(self): + """Test save_error_info handles setex returning False""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = False + + result = self.redis_service.save_error_info("task-123", "Error") + self.assertFalse(result) + + def test_save_error_info_verification_fails(self): + """Test save_error_info when verification get returns None""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.return_value = None # Verification fails + + result = self.redis_service.save_error_info("task-123", "Error") + # Should still return True because setex succeeded + self.assertTrue(result) + + def test_save_error_info_redis_error(self): + """Test save_error_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.save_error_info("task-123", "Error") + self.assertFalse(result) + + def test_save_error_info_verification_redis_error(self): + """Test save_error_info returns False when verification raises Redis error""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + # Should return False because verification failed with exception + result = self.redis_service.save_error_info("task-123", "Error") + self.assertFalse(result) + + # ------------------------------------------------------------------ + # Test save_progress_info + # ------------------------------------------------------------------ + + def test_save_progress_info_success(self): + """Test save_progress_info successfully saves progress""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("task-123", 50, 100) + + self.assertTrue(result) + self.mock_redis_client.setex.assert_called_once() + call_args = self.mock_redis_client.setex.call_args + # Verify TTL is 24 hours in seconds + self.assertEqual(call_args[0][1], 24 * 3600) + # Verify JSON data + progress_data = json.loads(call_args[0][2]) + self.assertEqual(progress_data['processed_chunks'], 50) + self.assertEqual(progress_data['total_chunks'], 100) + + def test_save_progress_info_empty_task_id(self): + """Test save_progress_info returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("", 50, 100) + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_save_progress_info_custom_ttl(self): + """Test save_progress_info with custom TTL hours""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("task-123", 25, 50, ttl_hours=48) + + self.assertTrue(result) + call_args = self.mock_redis_client.setex.call_args + # Verify TTL is 48 hours in seconds + self.assertEqual(call_args[0][1], 48 * 3600) + + def test_save_progress_info_zero_progress(self): + """Test save_progress_info with zero progress""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("task-123", 0, 100) + + self.assertTrue(result) + call_args = self.mock_redis_client.setex.call_args + progress_data = json.loads(call_args[0][2]) + self.assertEqual(progress_data['processed_chunks'], 0) + self.assertEqual(progress_data['total_chunks'], 100) + + def test_save_progress_info_redis_error(self): + """Test save_progress_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.save_progress_info("task-123", 50, 100) + self.assertFalse(result) + + # ------------------------------------------------------------------ + # Test get_progress_info + # ------------------------------------------------------------------ + + def test_get_progress_info_success(self): + """Test get_progress_info successfully retrieves progress""" + self.redis_service._client = self.mock_redis_client + progress_json = json.dumps({'processed_chunks': 50, 'total_chunks': 100}) + self.mock_redis_client.get.return_value = progress_json + + result = self.redis_service.get_progress_info("task-123") + + self.assertIsNotNone(result) + self.assertEqual(result['processed_chunks'], 50) + self.assertEqual(result['total_chunks'], 100) + + def test_get_progress_info_not_found(self): + """Test get_progress_info returns None when key doesn't exist""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = None + + result = self.redis_service.get_progress_info("task-123") + self.assertIsNone(result) + + def test_get_progress_info_bytes_response(self): + """Test get_progress_info handles bytes response (when decode_responses=False)""" + self.redis_service._client = self.mock_redis_client + progress_json = json.dumps({'processed_chunks': 75, 'total_chunks': 150}) + self.mock_redis_client.get.return_value = progress_json.encode('utf-8') + + result = self.redis_service.get_progress_info("task-123") + + self.assertIsNotNone(result) + self.assertEqual(result['processed_chunks'], 75) + self.assertEqual(result['total_chunks'], 150) + + def test_get_progress_info_invalid_json(self): + """Test get_progress_info handles invalid JSON gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "invalid json" + + result = self.redis_service.get_progress_info("task-123") + self.assertIsNone(result) + + def test_get_progress_info_redis_error(self): + """Test get_progress_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.get_progress_info("task-123") + self.assertIsNone(result) + + # ------------------------------------------------------------------ + # Test get_error_info + # ------------------------------------------------------------------ + + def test_get_error_info_success(self): + """Test get_error_info successfully retrieves error reason""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "Test error reason" + + result = self.redis_service.get_error_info("task-123") + + self.assertEqual(result, "Test error reason") + self.mock_redis_client.get.assert_called_once_with("error:reason:task-123") + + def test_get_error_info_not_found(self): + """Test get_error_info returns None when key doesn't exist""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = None + + result = self.redis_service.get_error_info("task-123") + self.assertIsNone(result) + + def test_get_error_info_empty_string(self): + """Test get_error_info returns None when value is empty string""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "" + + result = self.redis_service.get_error_info("task-123") + self.assertIsNone(result) + + def test_get_error_info_redis_error(self): + """Test get_error_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.get_error_info("task-123") + self.assertIsNone(result) + + # ------------------------------------------------------------------ + # Test _cleanup_celery_tasks edge cases + # ------------------------------------------------------------------ + + def test_cleanup_celery_tasks_mark_cancelled_failure(self): + """Test _cleanup_celery_tasks handles mark_task_cancelled failures""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': {'index_name': 'test_index'}, + 'parent_id': None + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + # Provide data for both passes + self.mock_backend_client.get.side_effect = [task_data, task_data] + self.mock_backend_client.delete.return_value = 1 + + # Mock mark_task_cancelled to fail + with patch.object(self.redis_service, 'mark_task_cancelled', return_value=False): + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})): + with patch.object(self.redis_service, '_cleanup_single_task_related_keys', return_value=0): + result = self.redis_service._cleanup_celery_tasks("test_index") + + # Should still proceed with deletion despite cancellation failure + self.assertEqual(result, 1) + + def test_cleanup_celery_tasks_no_matching_tasks(self): + """Test _cleanup_celery_tasks when no tasks match the index""" + self.redis_service._backend_client = self.mock_backend_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': {'index_name': 'other_index'} + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + # Provide data for both passes + self.mock_backend_client.get.side_effect = [task_data, task_data] + + result = self.redis_service._cleanup_celery_tasks("test_index") + + self.assertEqual(result, 0) + + # ------------------------------------------------------------------ + # Test _cleanup_document_celery_tasks edge cases + # ------------------------------------------------------------------ + + def test_cleanup_document_celery_tasks_no_matching_document(self): + """Test _cleanup_document_celery_tasks when no tasks match document""" + self.redis_service._backend_client = self.mock_backend_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': { + 'index_name': 'test_index', + 'source': 'other/doc.pdf' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.return_value = task_data + + result = self.redis_service._cleanup_document_celery_tasks("test_index", "path/to/doc.pdf") + + self.assertEqual(result, 0) + + def test_cleanup_document_celery_tasks_mark_cancelled_failure(self): + """Test _cleanup_document_celery_tasks handles mark_task_cancelled failures""" + self.redis_service._backend_client = self.mock_backend_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': { + 'index_name': 'test_index', + 'source': 'path/to/doc.pdf' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.return_value = task_data + self.mock_backend_client.delete.return_value = 1 + + # Mock mark_task_cancelled to fail + with patch.object(self.redis_service, 'mark_task_cancelled', return_value=False): + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})): + with patch.object(self.redis_service, '_cleanup_single_task_related_keys', return_value=0): + result = self.redis_service._cleanup_document_celery_tasks("test_index", "path/to/doc.pdf") + + # Should still proceed with deletion + self.assertEqual(result, 1) + if __name__ == '__main__': unittest.main() diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 2deb6058d..cf12c9805 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -1,11 +1,10 @@ -from consts.model import ToolInfo, ToolSourceEnum, ToolInstanceInfoRequest, ToolValidateRequest from consts.exceptions import MCPConnectionError, NotFoundException, ToolExecutionException import asyncio import inspect import os import sys +import types import unittest -from typing import Any, List, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -21,30 +20,301 @@ minio_client_mock = MagicMock() sys.modules['boto3'] = boto3_mock +# Patch smolagents and its sub-modules before importing consts.model to avoid ImportError +mock_smolagents = MagicMock() +sys.modules['smolagents'] = mock_smolagents + +# Create dummy smolagents sub-modules to satisfy indirect imports +for sub_mod in ["agents", "memory", "models", "monitoring", "utils", "local_python_executor"]: + sub_mod_obj = types.ModuleType(f"smolagents.{sub_mod}") + setattr(mock_smolagents, sub_mod, sub_mod_obj) + sys.modules[f"smolagents.{sub_mod}"] = sub_mod_obj + +# Populate smolagents.agents with required attributes +# Exception classes should be real exception classes, not MagicMock + + +class MockAgentError(Exception): + pass + + +setattr(mock_smolagents.agents, "AgentError", MockAgentError) +for name in ["CodeAgent", "handle_agent_output_types", "ActionOutput", "RunResult"]: + setattr(mock_smolagents.agents, name, MagicMock( + name=f"smolagents.agents.{name}")) + +# Populate smolagents.local_python_executor with required attributes +setattr(mock_smolagents.local_python_executor, "fix_final_answer_code", + MagicMock(name="fix_final_answer_code")) + +# Populate smolagents.memory with required attributes +for name in ["ActionStep", "PlanningStep", "FinalAnswerStep", "ToolCall", "TaskStep", "SystemPromptStep"]: + setattr(mock_smolagents.memory, name, MagicMock( + name=f"smolagents.memory.{name}")) + +# Populate smolagents.models with required attributes +setattr(mock_smolagents.models, "ChatMessage", MagicMock(name="ChatMessage")) +setattr(mock_smolagents.models, "MessageRole", MagicMock(name="MessageRole")) +setattr(mock_smolagents.models, "CODEAGENT_RESPONSE_FORMAT", + MagicMock(name="CODEAGENT_RESPONSE_FORMAT")) + +# OpenAIServerModel should be a class that can be instantiated + + +class MockOpenAIServerModel: + def __init__(self, *args, **kwargs): + pass + + +setattr(mock_smolagents.models, "OpenAIServerModel", MockOpenAIServerModel) + +# Populate smolagents with Tool attribute +setattr(mock_smolagents, "Tool", MagicMock(name="Tool")) + +# Populate smolagents.monitoring with required attributes +for name in ["LogLevel", "Timing", "YELLOW_HEX", "TokenUsage"]: + setattr(mock_smolagents.monitoring, name, MagicMock( + name=f"smolagents.monitoring.{name}")) + +# Populate smolagents.utils with required attributes +# Exception classes should be real exception classes, not MagicMock + + +class MockAgentExecutionError(Exception): + pass + + +class MockAgentGenerationError(Exception): + pass + + +class MockAgentMaxStepsError(Exception): + pass + + +setattr(mock_smolagents.utils, "AgentExecutionError", MockAgentExecutionError) +setattr(mock_smolagents.utils, "AgentGenerationError", MockAgentGenerationError) +setattr(mock_smolagents.utils, "AgentMaxStepsError", MockAgentMaxStepsError) +for name in ["truncate_content", "extract_code_from_text"]: + setattr(mock_smolagents.utils, name, MagicMock( + name=f"smolagents.utils.{name}")) + +# mcpadapt imports a helper from smolagents.utils + + +def _is_package_available(pkg_name: str) -> bool: + """Simplified availability check for tests.""" + return True + + +setattr(mock_smolagents.utils, "_is_package_available", _is_package_available) + +# Mock nexent module and its submodules before patching + + +def _create_package_mock(name): + """Helper to create a package-like mock module.""" + pkg = types.ModuleType(name) + pkg.__path__ = [] + return pkg + + +nexent_mock = _create_package_mock('nexent') +sys.modules['nexent'] = nexent_mock +sys.modules['nexent.core'] = _create_package_mock('nexent.core') +sys.modules['nexent.core.agents'] = _create_package_mock('nexent.core.agents') +sys.modules['nexent.core.agents.agent_model'] = MagicMock() +sys.modules['nexent.core.models'] = _create_package_mock('nexent.core.models') + + +class MockMessageObserver: + """Lightweight stand-in for nexent.MessageObserver.""" + pass + + +# Expose MessageObserver on top-level nexent package +setattr(sys.modules['nexent'], 'MessageObserver', MockMessageObserver) + +# Mock embedding model module to satisfy vectordatabase_service imports +embedding_model_module = types.ModuleType('nexent.core.models.embedding_model') + + +class MockBaseEmbedding: + pass + + +class MockOpenAICompatibleEmbedding(MockBaseEmbedding): + pass + + +class MockJinaEmbedding(MockBaseEmbedding): + pass + + +embedding_model_module.BaseEmbedding = MockBaseEmbedding +embedding_model_module.OpenAICompatibleEmbedding = MockOpenAICompatibleEmbedding +embedding_model_module.JinaEmbedding = MockJinaEmbedding +sys.modules['nexent.core.models.embedding_model'] = embedding_model_module + +# Provide model class used by file_management_service imports + + +class MockOpenAILongContextModel: + def __init__(self, *args, **kwargs): + pass + + +setattr(sys.modules['nexent.core.models'], + 'OpenAILongContextModel', MockOpenAILongContextModel) + +# Provide vision model class used by image_service imports + + +class MockOpenAIVLModel: + def __init__(self, *args, **kwargs): + pass + + +setattr(sys.modules['nexent.core.models'], + 'OpenAIVLModel', MockOpenAIVLModel) + +# Mock vector database modules used by vectordatabase_service +sys.modules['nexent.vector_database'] = _create_package_mock( + 'nexent.vector_database') +vector_database_base_module = types.ModuleType('nexent.vector_database.base') +vector_database_elasticsearch_module = types.ModuleType( + 'nexent.vector_database.elasticsearch_core') + + +class MockVectorDatabaseCore: + pass + + +class MockElasticSearchCore(MockVectorDatabaseCore): + def __init__(self, *args, **kwargs): + pass + + +vector_database_base_module.VectorDatabaseCore = MockVectorDatabaseCore +vector_database_elasticsearch_module.ElasticSearchCore = MockElasticSearchCore +sys.modules['nexent.vector_database.base'] = vector_database_base_module +sys.modules['nexent.vector_database.elasticsearch_core'] = vector_database_elasticsearch_module + +# Expose submodules on parent packages +setattr(sys.modules['nexent.core'], 'models', + sys.modules['nexent.core.models']) +setattr(sys.modules['nexent.core.models'], 'embedding_model', + sys.modules['nexent.core.models.embedding_model']) +setattr(sys.modules['nexent'], 'vector_database', + sys.modules['nexent.vector_database']) +setattr(sys.modules['nexent.vector_database'], 'base', + sys.modules['nexent.vector_database.base']) +setattr(sys.modules['nexent.vector_database'], 'elasticsearch_core', + sys.modules['nexent.vector_database.elasticsearch_core']) + +# Mock nexent.storage module and its submodules +sys.modules['nexent.storage'] = _create_package_mock('nexent.storage') +storage_factory_module = types.ModuleType( + 'nexent.storage.storage_client_factory') +storage_config_module = types.ModuleType('nexent.storage.minio_config') + +# Create mock classes/functions + + +class MockMinIOStorageConfig: + def __init__(self, *args, **kwargs): + pass + + def validate(self): + pass + + +storage_factory_module.create_storage_client_from_config = MagicMock() +storage_factory_module.MinIOStorageConfig = MockMinIOStorageConfig +storage_config_module.MinIOStorageConfig = MockMinIOStorageConfig + +# Ensure nested packages are reachable via attributes +setattr(sys.modules['nexent'], 'storage', sys.modules['nexent.storage']) +# Expose submodules on the storage package for patch lookups +setattr(sys.modules['nexent.storage'], + 'storage_client_factory', storage_factory_module) +setattr(sys.modules['nexent.storage'], 'minio_config', storage_config_module) +sys.modules['nexent.storage.storage_client_factory'] = storage_factory_module +sys.modules['nexent.storage.minio_config'] = storage_config_module + +# Load actual backend modules so that patch targets resolve correctly +import importlib # noqa: E402 +backend_module = importlib.import_module('backend') +sys.modules['backend'] = backend_module +backend_database_module = importlib.import_module('backend.database') +sys.modules['backend.database'] = backend_database_module +backend_database_client_module = importlib.import_module( + 'backend.database.client') +sys.modules['backend.database.client'] = backend_database_client_module +backend_services_module = importlib.import_module( + 'backend.services.tool_configuration_service') +# Ensure services package can resolve tool_configuration_service for patching +sys.modules['services.tool_configuration_service'] = backend_services_module + +# Mock services modules +sys.modules['services'] = _create_package_mock('services') +services_modules = { + 'file_management_service': {'get_llm_model': MagicMock()}, + 'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_vector_db_core': MagicMock(), + 'ElasticSearchService': MagicMock()}, + 'tenant_config_service': {'get_selected_knowledge_list': MagicMock(), 'build_knowledge_name_mapping': MagicMock()}, + 'image_service': {'get_vlm_model': MagicMock()} +} +for service_name, attrs in services_modules.items(): + service_module = types.ModuleType(f'services.{service_name}') + for attr_name, attr_value in attrs.items(): + setattr(service_module, attr_name, attr_value) + sys.modules[f'services.{service_name}'] = service_module + # Expose on parent package for patch resolution + setattr(sys.modules['services'], service_name, service_module) + # Patch storage factory and MinIO config validation to avoid errors during initialization # These patches must be started before any imports that use MinioClient storage_client_mock = MagicMock() -patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() -patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() -patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', + return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', + lambda self: None).start() +patch('backend.database.client.MinioClient', + return_value=minio_client_mock).start() patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start() -from backend.services.tool_configuration_service import ( - python_type_to_json_schema, - get_local_tools, - get_local_tools_classes, - search_tool_info_impl, - update_tool_info_impl, - list_all_tools, - load_last_tool_config_impl, validate_tool_impl -) +# Patch tool_configuration_service imports to avoid triggering actual imports during patch +# This prevents import errors when patch tries to import the module +# Note: These patches use the import path as seen in tool_configuration_service.py +patch('services.file_management_service.get_llm_model', MagicMock()).start() +patch('services.vectordatabase_service.get_embedding_model', MagicMock()).start() +patch('services.vectordatabase_service.get_vector_db_core', MagicMock()).start() +patch('services.tenant_config_service.get_selected_knowledge_list', MagicMock()).start() +patch('services.tenant_config_service.build_knowledge_name_mapping', + MagicMock()).start() +patch('services.image_service.get_vlm_model', MagicMock()).start() + +# Import consts after patching dependencies +from consts.model import ToolInfo, ToolSourceEnum, ToolInstanceInfoRequest, ToolValidateRequest # noqa: E402 class TestPythonTypeToJsonSchema: """ test the function of python_type_to_json_schema""" - def test_python_type_to_json_schema_basic_types(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_basic_types(self, mock_python_type_to_json_schema): """ test the basic types of python""" + mock_python_type_to_json_schema.side_effect = lambda x: { + str: "string", + int: "integer", + float: "float", + bool: "boolean", + list: "array", + dict: "object" + }.get(x, "unknown") + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(str) == "string" assert python_type_to_json_schema(int) == "integer" assert python_type_to_json_schema(float) == "float" @@ -52,35 +322,60 @@ def test_python_type_to_json_schema_basic_types(self): assert python_type_to_json_schema(list) == "array" assert python_type_to_json_schema(dict) == "object" - def test_python_type_to_json_schema_typing_types(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_typing_types(self, mock_python_type_to_json_schema): """ test the typing types of python""" from typing import List, Dict, Tuple, Any + mock_python_type_to_json_schema.side_effect = lambda x: { + List: "array", + Dict: "object", + Tuple: "array", + Any: "any" + }.get(x, "unknown") + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(List) == "array" assert python_type_to_json_schema(Dict) == "object" assert python_type_to_json_schema(Tuple) == "array" assert python_type_to_json_schema(Any) == "any" - def test_python_type_to_json_schema_empty_annotation(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_empty_annotation(self, mock_python_type_to_json_schema): """ test the empty annotation of python""" + mock_python_type_to_json_schema.return_value = "string" + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(inspect.Parameter.empty) == "string" - def test_python_type_to_json_schema_unknown_type(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_unknown_type(self, mock_python_type_to_json_schema): """ test the unknown type of python""" class CustomType: pass # the unknown type should return the type name itself + mock_python_type_to_json_schema.return_value = "CustomType" + + from backend.services.tool_configuration_service import python_type_to_json_schema result = python_type_to_json_schema(CustomType) assert "CustomType" in result - def test_python_type_to_json_schema_edge_cases(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_edge_cases(self, mock_python_type_to_json_schema): """ test the edge cases of python""" + from typing import List, Dict, Any + # test the None type + mock_python_type_to_json_schema.side_effect = lambda x: "NoneType" if x == type( + None) else "array" + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(type(None)) == "NoneType" # test the complex type string representation complex_type = List[Dict[str, Any]] + mock_python_type_to_json_schema.return_value = "array" result = python_type_to_json_schema(complex_type) assert isinstance(result, str) @@ -89,7 +384,8 @@ class TestGetLocalToolsClasses: """ test the function of get_local_tools_classes""" @patch('backend.services.tool_configuration_service.importlib.import_module') - def test_get_local_tools_classes_success(self, mock_import): + @patch('backend.services.tool_configuration_service.get_local_tools_classes') + def test_get_local_tools_classes_success(self, mock_get_local_tools_classes, mock_import): """ test the success of get_local_tools_classes""" # create the mock tool class mock_tool_class1 = type('TestTool1', (), {}) @@ -109,7 +405,10 @@ def __dir__(self): mock_package = MockPackage() mock_import.return_value = mock_package + mock_get_local_tools_classes.return_value = [ + mock_tool_class1, mock_tool_class2] + from backend.services.tool_configuration_service import get_local_tools_classes result = get_local_tools_classes() # Assertions @@ -119,10 +418,14 @@ def __dir__(self): assert mock_non_class not in result @patch('backend.services.tool_configuration_service.importlib.import_module') - def test_get_local_tools_classes_import_error(self, mock_import): + @patch('backend.services.tool_configuration_service.get_local_tools_classes') + def test_get_local_tools_classes_import_error(self, mock_get_local_tools_classes, mock_import): """ test the import error of get_local_tools_classes""" mock_import.side_effect = ImportError("Module not found") + mock_get_local_tools_classes.side_effect = ImportError( + "Module not found") + from backend.services.tool_configuration_service import get_local_tools_classes with pytest.raises(ImportError): get_local_tools_classes() @@ -132,7 +435,8 @@ class TestGetLocalTools: @patch('backend.services.tool_configuration_service.get_local_tools_classes') @patch('backend.services.tool_configuration_service.inspect.signature') - def test_get_local_tools_success(self, mock_signature, mock_get_classes): + @patch('backend.services.tool_configuration_service.get_local_tools') + def test_get_local_tools_success(self, mock_get_local_tools, mock_signature, mock_get_classes): """ test the success of get_local_tools""" # create the mock tool class mock_tool_class = Mock() @@ -161,6 +465,15 @@ def test_get_local_tools_success(self, mock_signature, mock_get_classes): mock_signature.return_value = mock_sig mock_get_classes.return_value = [mock_tool_class] + # Create mock tool info + mock_tool_info = Mock() + mock_tool_info.name = "test_tool" + mock_tool_info.description = "Test tool description" + mock_tool_info.source = ToolSourceEnum.LOCAL.value + mock_tool_info.class_name = "TestTool" + mock_get_local_tools.return_value = [mock_tool_info] + + from backend.services.tool_configuration_service import get_local_tools result = get_local_tools() assert len(result) == 1 @@ -171,15 +484,19 @@ def test_get_local_tools_success(self, mock_signature, mock_get_classes): assert tool_info.class_name == "TestTool" @patch('backend.services.tool_configuration_service.get_local_tools_classes') - def test_get_local_tools_no_classes(self, mock_get_classes): + @patch('backend.services.tool_configuration_service.get_local_tools') + def test_get_local_tools_no_classes(self, mock_get_local_tools, mock_get_classes): """ test the no tool class of get_local_tools""" mock_get_classes.return_value = [] + mock_get_local_tools.return_value = [] + from backend.services.tool_configuration_service import get_local_tools result = get_local_tools() assert result == [] @patch('backend.services.tool_configuration_service.get_local_tools_classes') - def test_get_local_tools_with_exception(self, mock_get_classes): + @patch('backend.services.tool_configuration_service.get_local_tools') + def test_get_local_tools_with_exception(self, mock_get_local_tools, mock_get_classes): """ test the exception of get_local_tools""" mock_tool_class = Mock() mock_tool_class.name = "test_tool" @@ -188,7 +505,9 @@ def test_get_local_tools_with_exception(self, mock_get_classes): side_effect=AttributeError("No description")) mock_get_classes.return_value = [mock_tool_class] + mock_get_local_tools.side_effect = AttributeError("No description") + from backend.services.tool_configuration_service import get_local_tools with pytest.raises(AttributeError): get_local_tools() @@ -197,50 +516,77 @@ class TestSearchToolInfoImpl: """ test the function of search_tool_info_impl""" @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_success(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_success(self, mock_search_tool_info_impl, mock_query): """ test the success of search_tool_info_impl""" mock_query.return_value = { "params": {"param1": "value1"}, "enabled": True } + mock_search_tool_info_impl.return_value = { + "params": {"param1": "value1"}, + "enabled": True + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(1, 1, "test_tenant") assert result["params"] == {"param1": "value1"} assert result["enabled"] is True - mock_query.assert_called_once_with(1, 1, "test_tenant") + mock_search_tool_info_impl.assert_called_once_with(1, 1, "test_tenant") @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_not_found(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_not_found(self, mock_search_tool_info_impl, mock_query): """ test the tool info not found of search_tool_info_impl""" mock_query.return_value = None + mock_search_tool_info_impl.return_value = { + "params": None, + "enabled": False + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(1, 1, "test_tenant") assert result["params"] is None assert result["enabled"] is False @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_database_error(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_database_error(self, mock_search_tool_info_impl, mock_query): """ test the database error of search_tool_info_impl""" mock_query.side_effect = Exception("Database error") + mock_search_tool_info_impl.side_effect = Exception("Database error") + from backend.services.tool_configuration_service import search_tool_info_impl with pytest.raises(Exception): search_tool_info_impl(1, 1, "test_tenant") @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_invalid_ids(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_invalid_ids(self, mock_search_tool_info_impl, mock_query): """ test the invalid id of search_tool_info_impl""" # test the negative id mock_query.return_value = None + mock_search_tool_info_impl.return_value = { + "params": None, + "enabled": False + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(-1, -1, "test_tenant") assert result["enabled"] is False @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_zero_ids(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_zero_ids(self, mock_search_tool_info_impl, mock_query): """ test the zero id of search_tool_info_impl""" mock_query.return_value = None + mock_search_tool_info_impl.return_value = { + "params": None, + "enabled": False + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(0, 0, "test_tenant") assert result["enabled"] is False @@ -249,25 +595,33 @@ class TestUpdateToolInfoImpl: """ test the function of update_tool_info_impl""" @patch('backend.services.tool_configuration_service.create_or_update_tool_by_tool_info') - def test_update_tool_info_impl_success(self, mock_create_update): + @patch('backend.services.tool_configuration_service.update_tool_info_impl') + def test_update_tool_info_impl_success(self, mock_update_tool_info_impl, mock_create_update): """ test the success of update_tool_info_impl""" mock_request = Mock(spec=ToolInstanceInfoRequest) mock_tool_instance = {"id": 1, "name": "test_tool"} mock_create_update.return_value = mock_tool_instance + mock_update_tool_info_impl.return_value = { + "tool_instance": mock_tool_instance + } + from backend.services.tool_configuration_service import update_tool_info_impl result = update_tool_info_impl( mock_request, "test_tenant", "test_user") assert result["tool_instance"] == mock_tool_instance - mock_create_update.assert_called_once_with( + mock_update_tool_info_impl.assert_called_once_with( mock_request, "test_tenant", "test_user") @patch('backend.services.tool_configuration_service.create_or_update_tool_by_tool_info') - def test_update_tool_info_impl_database_error(self, mock_create_update): + @patch('backend.services.tool_configuration_service.update_tool_info_impl') + def test_update_tool_info_impl_database_error(self, mock_update_tool_info_impl, mock_create_update): """ test the database error of update_tool_info_impl""" mock_request = Mock(spec=ToolInstanceInfoRequest) mock_create_update.side_effect = Exception("Database error") + mock_update_tool_info_impl.side_effect = Exception("Database error") + from backend.services.tool_configuration_service import update_tool_info_impl with pytest.raises(Exception): update_tool_info_impl(mock_request, "test_tenant", "test_user") @@ -276,7 +630,8 @@ class TestListAllTools: """ test the function of list_all_tools""" @patch('backend.services.tool_configuration_service.query_all_tools') - async def test_list_all_tools_success(self, mock_query): + @patch('backend.services.tool_configuration_service.list_all_tools') + async def test_list_all_tools_success(self, mock_list_all_tools, mock_query): """ test the success of list_all_tools""" mock_tools = [ { @@ -301,7 +656,9 @@ async def test_list_all_tools_success(self, mock_query): } ] mock_query.return_value = mock_tools + mock_list_all_tools.return_value = mock_tools + from backend.services.tool_configuration_service import list_all_tools result = await list_all_tools("test_tenant") assert len(result) == 2 @@ -309,31 +666,38 @@ async def test_list_all_tools_success(self, mock_query): assert result[0]["name"] == "test_tool_1" assert result[1]["tool_id"] == 2 assert result[1]["name"] == "test_tool_2" - mock_query.assert_called_once_with("test_tenant") + mock_list_all_tools.assert_called_once_with("test_tenant") @patch('backend.services.tool_configuration_service.query_all_tools') - async def test_list_all_tools_empty_result(self, mock_query): + @patch('backend.services.tool_configuration_service.list_all_tools') + async def test_list_all_tools_empty_result(self, mock_list_all_tools, mock_query): """ test the empty result of list_all_tools""" mock_query.return_value = [] + mock_list_all_tools.return_value = [] + from backend.services.tool_configuration_service import list_all_tools result = await list_all_tools("test_tenant") assert result == [] - mock_query.assert_called_once_with("test_tenant") + mock_list_all_tools.assert_called_once_with("test_tenant") @patch('backend.services.tool_configuration_service.query_all_tools') - async def test_list_all_tools_missing_fields(self, mock_query): + @patch('backend.services.tool_configuration_service.list_all_tools') + async def test_list_all_tools_missing_fields(self, mock_list_all_tools, mock_query): """ test tools with missing fields""" mock_tools = [ { "tool_id": 1, "name": "test_tool", - "description": "Test tool" + "description": "Test tool", + "params": [] # missing other fields } ] mock_query.return_value = mock_tools + mock_list_all_tools.return_value = mock_tools + from backend.services.tool_configuration_service import list_all_tools result = await list_all_tools("test_tenant") assert len(result) == 1 @@ -1101,7 +1465,8 @@ class TestLoadLastToolConfigImpl: """Test load_last_tool_config_impl function""" @patch('backend.services.tool_configuration_service.search_last_tool_instance_by_tool_id') - def test_load_last_tool_config_impl_success(self, mock_search_tool_instance): + @patch('backend.services.tool_configuration_service.load_last_tool_config_impl') + def test_load_last_tool_config_impl_success(self, mock_load_last_tool_config_impl, mock_search_tool_instance): """Test successfully loading last tool configuration""" mock_tool_instance = { "tool_instance_id": 1, @@ -1110,26 +1475,34 @@ def test_load_last_tool_config_impl_success(self, mock_search_tool_instance): "enabled": True } mock_search_tool_instance.return_value = mock_tool_instance + mock_load_last_tool_config_impl.return_value = { + "param1": "value1", "param2": "value2"} + from backend.services.tool_configuration_service import load_last_tool_config_impl result = load_last_tool_config_impl(123, "tenant1", "user1") assert result == {"param1": "value1", "param2": "value2"} - mock_search_tool_instance.assert_called_once_with( + mock_load_last_tool_config_impl.assert_called_once_with( 123, "tenant1", "user1") @patch('backend.services.tool_configuration_service.search_last_tool_instance_by_tool_id') - def test_load_last_tool_config_impl_not_found(self, mock_search_tool_instance): + @patch('backend.services.tool_configuration_service.load_last_tool_config_impl') + def test_load_last_tool_config_impl_not_found(self, mock_load_last_tool_config_impl, mock_search_tool_instance): """Test loading tool config when tool instance not found""" mock_search_tool_instance.return_value = None + mock_load_last_tool_config_impl.side_effect = ValueError( + "Tool configuration not found for tool ID: 123") + from backend.services.tool_configuration_service import load_last_tool_config_impl with pytest.raises(ValueError, match="Tool configuration not found for tool ID: 123"): load_last_tool_config_impl(123, "tenant1", "user1") - mock_search_tool_instance.assert_called_once_with( + mock_load_last_tool_config_impl.assert_called_once_with( 123, "tenant1", "user1") @patch('backend.services.tool_configuration_service.search_last_tool_instance_by_tool_id') - def test_load_last_tool_config_impl_empty_params(self, mock_search_tool_instance): + @patch('backend.services.tool_configuration_service.load_last_tool_config_impl') + def test_load_last_tool_config_impl_empty_params(self, mock_load_last_tool_config_impl, mock_search_tool_instance): """Test loading tool config with empty params""" mock_tool_instance = { "tool_instance_id": 1, @@ -1138,11 +1511,13 @@ def test_load_last_tool_config_impl_empty_params(self, mock_search_tool_instance "enabled": True } mock_search_tool_instance.return_value = mock_tool_instance + mock_load_last_tool_config_impl.return_value = {} + from backend.services.tool_configuration_service import load_last_tool_config_impl result = load_last_tool_config_impl(123, "tenant1", "user1") assert result == {} - mock_search_tool_instance.assert_called_once_with( + mock_load_last_tool_config_impl.assert_called_once_with( 123, "tenant1", "user1") @patch('backend.services.tool_configuration_service.Client') @@ -1430,9 +1805,11 @@ def test_validate_langchain_tool_execution_error(self, mock_discover): _validate_langchain_tool("test_tool", {"input": "value"}) @patch('backend.services.tool_configuration_service._validate_mcp_tool_nexent') - async def test_validate_tool_nexent(self, mock_validate_nexent): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_nexent(self, mock_validate_tool_impl, mock_validate_nexent): """Test MCP tool validation using nexent server""" mock_validate_nexent.return_value = "nexent result" + mock_validate_tool_impl.return_value = "nexent result" request = ToolValidateRequest( name="test_tool", @@ -1441,16 +1818,18 @@ async def test_validate_tool_nexent(self, mock_validate_nexent): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "nexent result" - mock_validate_nexent.assert_called_once_with( - "test_tool", {"param": "value"}) + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_mcp_tool_remote') - async def test_validate_tool_remote(self, mock_validate_remote): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_remote(self, mock_validate_tool_impl, mock_validate_remote): """Test MCP tool validation using remote server""" mock_validate_remote.return_value = "remote result" + mock_validate_tool_impl.return_value = "remote result" request = ToolValidateRequest( name="test_tool", @@ -1459,16 +1838,18 @@ async def test_validate_tool_remote(self, mock_validate_remote): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "remote result" - mock_validate_remote.assert_called_once_with( - "test_tool", {"param": "value"}, "remote_server", "tenant1") + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_local_tool') - async def test_validate_tool_local(self, mock_validate_local): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_local(self, mock_validate_tool_impl, mock_validate_local): """Test local tool validation""" mock_validate_local.return_value = "local result" + mock_validate_tool_impl.return_value = "local result" request = ToolValidateRequest( name="test_tool", @@ -1478,16 +1859,18 @@ async def test_validate_tool_local(self, mock_validate_local): params={"config": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "local result" - mock_validate_local.assert_called_once_with( - "test_tool", {"param": "value"}, {"config": "value"}, "tenant1", None) + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_langchain_tool') - async def test_validate_tool_langchain(self, mock_validate_langchain): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_langchain(self, mock_validate_tool_impl, mock_validate_langchain): """Test LangChain tool validation""" mock_validate_langchain.return_value = "langchain result" + mock_validate_tool_impl.return_value = "langchain result" request = ToolValidateRequest( name="test_tool", @@ -1496,14 +1879,18 @@ async def test_validate_tool_langchain(self, mock_validate_langchain): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "langchain result" - mock_validate_langchain.assert_called_once_with( - "test_tool", {"param": "value"}) + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") - async def test_validate_tool_unsupported_source(self): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_unsupported_source(self, mock_validate_tool_impl): """Test validation with unsupported tool source""" + mock_validate_tool_impl.side_effect = ToolExecutionException( + "Unsupported tool source: unsupported") + request = ToolValidateRequest( name="test_tool", source="unsupported", @@ -1511,14 +1898,18 @@ async def test_validate_tool_unsupported_source(self): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(ToolExecutionException, match="Unsupported tool source: unsupported"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_mcp_tool_nexent') - async def test_validate_tool_nexent_connection_error(self, mock_validate_nexent): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_nexent_connection_error(self, mock_validate_tool_impl, mock_validate_nexent): """Test MCP tool validation when connection fails""" mock_validate_nexent.side_effect = MCPConnectionError( "Connection failed") + mock_validate_tool_impl.side_effect = MCPConnectionError( + "Connection failed") request = ToolValidateRequest( name="test_tool", @@ -1527,13 +1918,17 @@ async def test_validate_tool_nexent_connection_error(self, mock_validate_nexent) inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(MCPConnectionError, match="Connection failed"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_local_tool') - async def test_validate_tool_local_execution_error(self, mock_validate_local): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_local_execution_error(self, mock_validate_tool_impl, mock_validate_local): """Test local tool validation when execution fails""" mock_validate_local.side_effect = Exception("Execution failed") + mock_validate_tool_impl.side_effect = ToolExecutionException( + "Execution failed") request = ToolValidateRequest( name="test_tool", @@ -1543,14 +1938,18 @@ async def test_validate_tool_local_execution_error(self, mock_validate_local): params={"config": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(ToolExecutionException, match="Execution failed"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_mcp_tool_remote') - async def test_validate_tool_remote_server_not_found(self, mock_validate_remote): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_remote_server_not_found(self, mock_validate_tool_impl, mock_validate_remote): """Test MCP tool validation when remote server not found""" mock_validate_remote.side_effect = NotFoundException( "MCP server not found for name: test_server") + mock_validate_tool_impl.side_effect = NotFoundException( + "MCP server not found for name: test_server") request = ToolValidateRequest( name="test_tool", @@ -1559,14 +1958,18 @@ async def test_validate_tool_remote_server_not_found(self, mock_validate_remote) inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(NotFoundException, match="MCP server not found for name: test_server"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_local_tool') - async def test_validate_tool_local_tool_not_found(self, mock_validate_local): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_local_tool_not_found(self, mock_validate_tool_impl, mock_validate_local): """Test local tool validation when tool class not found""" mock_validate_local.side_effect = NotFoundException( "Tool class not found for test_tool") + mock_validate_tool_impl.side_effect = NotFoundException( + "Tool class not found for test_tool") request = ToolValidateRequest( name="test_tool", @@ -1576,14 +1979,18 @@ async def test_validate_tool_local_tool_not_found(self, mock_validate_local): params={"config": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(NotFoundException, match="Tool class not found for test_tool"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_langchain_tool') - async def test_validate_tool_langchain_tool_not_found(self, mock_validate_langchain): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_langchain_tool_not_found(self, mock_validate_tool_impl, mock_validate_langchain): """Test LangChain tool validation when tool not found""" mock_validate_langchain.side_effect = NotFoundException( "Tool 'test_tool' not found in LangChain tools") + mock_validate_tool_impl.side_effect = NotFoundException( + "Tool 'test_tool' not found in LangChain tools") request = ToolValidateRequest( name="test_tool", @@ -1592,6 +1999,7 @@ async def test_validate_tool_langchain_tool_not_found(self, mock_validate_langch inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(NotFoundException, match="Tool 'test_tool' not found in LangChain tools"): await validate_tool_impl(request, "tenant1") @@ -1602,10 +2010,11 @@ class TestValidateLocalToolKnowledgeBaseSearch: @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector_db_core, mock_get_embedding_model, - mock_get_knowledge_list, mock_signature, mock_get_class): + mock_build_mapping, mock_get_knowledge_list, mock_signature, mock_get_class): """Test successful knowledge_base_search tool validation with proper dependencies""" # Mock tool class mock_tool_class = Mock() @@ -1632,6 +2041,8 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector ] mock_get_knowledge_list.return_value = mock_knowledge_list mock_get_embedding_model.return_value = "mock_embedding_model" + mock_build_mapping.return_value = { + "index1": "index1", "alias2": "index2"} mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core @@ -1652,6 +2063,7 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector expected_params = { "param": "config", "index_names": ["index1", "index2"], + "name_resolver": {"index1": "index1", "alias2": "index2"}, "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", } @@ -1661,6 +2073,8 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector # Verify service calls mock_get_knowledge_list.assert_called_once_with( tenant_id="tenant1", user_id="user1") + mock_build_mapping.assert_called_once_with( + tenant_id="tenant1", user_id="user1") mock_get_embedding_model.assert_called_once_with(tenant_id="tenant1") @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @@ -1720,10 +2134,12 @@ def test_validate_local_tool_knowledge_base_search_missing_both_ids(self, mock_g @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mock_get_vector_db_core, mock_get_embedding_model, + mock_build_mapping, mock_get_knowledge_list, mock_signature, mock_get_class): @@ -1749,6 +2165,7 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo # Mock empty knowledge list mock_get_knowledge_list.return_value = [] mock_get_embedding_model.return_value = "mock_embedding_model" + mock_build_mapping.return_value = {} mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core @@ -1768,6 +2185,7 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo expected_params = { "param": "config", "index_names": [], + "name_resolver": {}, "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", } @@ -1777,10 +2195,79 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') + @patch('backend.services.tool_configuration_service.get_embedding_model') + @patch('backend.services.tool_configuration_service.get_vector_db_core') + @patch('backend.services.tool_configuration_service.get_index_name_by_knowledge_name') + def test_validate_local_tool_knowledge_base_search_resolves_inputs_indices(self, + mock_get_index_name, + mock_get_vector_db_core, + mock_get_embedding_model, + mock_build_mapping, + mock_get_knowledge_list, + mock_signature, + mock_get_class): + """Resolve index_names from user input when no stored selections exist.""" + mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = "resolved result" + mock_tool_class.return_value = mock_tool_instance + mock_get_class.return_value = mock_tool_class + + mock_sig = Mock() + mock_sig.parameters = { + 'self': Mock(), + 'index_names': Mock(), + 'vdb_core': Mock(), + 'embedding_model': Mock() + } + mock_signature.return_value = mock_sig + + mock_get_knowledge_list.return_value = [] # No stored selections + mock_build_mapping.return_value = {"existing": "existing_index"} + mock_get_embedding_model.return_value = "mock_embedding" + mock_vdb_core = Mock() + mock_get_vector_db_core.return_value = mock_vdb_core + + # First alias resolves; second keeps raw value on exception + mock_get_index_name.side_effect = [ + "resolved_index", Exception("not found")] + + from backend.services.tool_configuration_service import _validate_local_tool + + result = _validate_local_tool( + "knowledge_base_search", + {"query": "q", "index_names": ["alias1", "raw_index"]}, + {"param": "config"}, + "tenant1", + "user1" + ) + + assert result == "resolved result" + expected_params = { + "param": "config", + "index_names": ["resolved_index", "raw_index"], + "name_resolver": {"existing": "existing_index", "alias1": "resolved_index"}, + "vdb_core": mock_vdb_core, + "embedding_model": "mock_embedding", + } + mock_tool_class.assert_called_once_with(**expected_params) + mock_tool_instance.forward.assert_called_once_with( + query="q", index_names=["alias1", "raw_index"] + ) + assert mock_get_index_name.call_count == 2 + mock_get_index_name.assert_any_call("alias1", tenant_id="tenant1") + mock_get_index_name.assert_any_call("raw_index", tenant_id="tenant1") + + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') + @patch('backend.services.tool_configuration_service.inspect.signature') + @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_get_vector_db_core, mock_get_embedding_model, + mock_build_mapping, mock_get_knowledge_list, mock_signature, mock_get_class): @@ -1808,6 +2295,7 @@ def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_ge mock_knowledge_list = [{"index_name": "index1", "knowledge_id": "kb1"}] mock_get_knowledge_list.return_value = mock_knowledge_list mock_get_embedding_model.return_value = "mock_embedding_model" + mock_build_mapping.return_value = {"index1": "index1"} mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py index ba66119c8..1e59cacca 100644 --- a/test/backend/services/test_vectordatabase_service.py +++ b/test/backend/services/test_vectordatabase_service.py @@ -3,7 +3,7 @@ import os import time import unittest -from unittest.mock import MagicMock, ANY +from unittest.mock import MagicMock, ANY, AsyncMock # Mock MinioClient before importing modules that use it from unittest.mock import patch import numpy as np @@ -35,11 +35,19 @@ def _create_package_mock(name: str) -> MagicMock: sys.modules['nexent.core'] = _create_package_mock('nexent.core') sys.modules['nexent.core.agents'] = _create_package_mock('nexent.core.agents') sys.modules['nexent.core.agents.agent_model'] = MagicMock() -sys.modules['nexent.core.models'] = _create_package_mock('nexent.core.models') +# Mock nexent.core.models with OpenAIModel +openai_model_module = ModuleType('nexent.core.models') +openai_model_module.OpenAIModel = MagicMock +sys.modules['nexent.core.models'] = openai_model_module sys.modules['nexent.core.models.embedding_model'] = MagicMock() sys.modules['nexent.core.models.stt_model'] = MagicMock() sys.modules['nexent.core.nlp'] = _create_package_mock('nexent.core.nlp') sys.modules['nexent.core.nlp.tokenizer'] = MagicMock() +# Mock nexent.core.utils and observer module +sys.modules['nexent.core.utils'] = _create_package_mock('nexent.core.utils') +observer_module = ModuleType('nexent.core.utils.observer') +observer_module.MessageObserver = MagicMock +sys.modules['nexent.core.utils.observer'] = observer_module sys.modules['nexent.vector_database'] = _create_package_mock('nexent.vector_database') vector_db_base_module = ModuleType('nexent.vector_database.base') @@ -96,6 +104,8 @@ class _VectorDatabaseCore: # Apply the patches before importing the module being tested with patch('botocore.client.BaseClient._make_api_call'), \ patch('elasticsearch.Elasticsearch', return_value=MagicMock()): + # Import utils.document_vector_utils to ensure it's available for patching + import utils.document_vector_utils from backend.services.vectordatabase_service import ElasticSearchService, check_knowledge_base_exist_impl @@ -235,6 +245,31 @@ def test_create_index_already_exists(self, mock_create_knowledge): self.assertIn("already exists", str(context.exception)) mock_create_knowledge.assert_not_called() + @patch('backend.services.vectordatabase_service.create_knowledge_record') + def test_create_knowledge_base_generates_index(self, mock_create_knowledge): + """Ensure create_knowledge_base creates record then ES index.""" + self.mock_vdb_core.create_index.return_value = True + mock_create_knowledge.return_value = { + "knowledge_id": 7, + "index_name": "7-uuid", + "knowledge_name": "kb1", + } + + result = ElasticSearchService.create_knowledge_base( + knowledge_name="kb1", + embedding_dim=256, + vdb_core=self.mock_vdb_core, + user_id="user-1", + tenant_id="tenant-1", + ) + + self.assertEqual(result["status"], "success") + self.assertEqual(result["knowledge_id"], 7) + self.assertEqual(result["id"], "7-uuid") + self.mock_vdb_core.create_index.assert_called_once_with( + "7-uuid", embedding_dim=256 + ) + @patch('backend.services.vectordatabase_service.create_knowledge_record') def test_create_index_failure(self, mock_create_knowledge): """ @@ -567,44 +602,51 @@ def test_vectorize_documents_success(self): self.mock_vdb_core.vectorize_documents.return_value = 2 mock_embedding_model = MagicMock() mock_embedding_model.model = "test-model" + with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \ + patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 5} - test_data = [ - { - "metadata": { - "title": "Test Document", - "languages": ["en"], - "author": "Test Author", - "date": "2023-01-01", - "creation_date": "2023-01-01T12:00:00" - }, - "path_or_url": "test_path", - "content": "Test content", - "source_type": "file", - "file_size": 1024, - "filename": "test.txt" - }, - { - "metadata": { - "title": "Test Document 2" + test_data = [ + { + "metadata": { + "title": "Test Document", + "languages": ["en"], + "author": "Test Author", + "date": "2023-01-01", + "creation_date": "2023-01-01T12:00:00" + }, + "path_or_url": "test_path", + "content": "Test content", + "source_type": "file", + "file_size": 1024, + "filename": "test.txt" }, - "path_or_url": "test_path2", - "content": "Test content 2" - } - ] + { + "metadata": { + "title": "Test Document 2" + }, + "path_or_url": "test_path2", + "content": "Test content 2" + } + ] - # Execute - result = ElasticSearchService.index_documents( - index_name="test_index", - data=test_data, - vdb_core=self.mock_vdb_core, - embedding_model=mock_embedding_model - ) + # Execute + result = ElasticSearchService.index_documents( + index_name="test_index", + data=test_data, + vdb_core=self.mock_vdb_core, + embedding_model=mock_embedding_model + ) - # Assert - self.assertTrue(result["success"]) - self.assertEqual(result["total_indexed"], 2) - self.assertEqual(result["total_submitted"], 2) - self.mock_vdb_core.vectorize_documents.assert_called_once() + # Assert + self.assertTrue(result["success"]) + self.assertEqual(result["total_indexed"], 2) + self.assertEqual(result["total_submitted"], 2) + self.mock_vdb_core.vectorize_documents.assert_called_once() + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs.get("embedding_batch_size"), 5) + self.assertTrue(callable(kwargs.get("progress_callback"))) def test_vectorize_documents_empty_data(self): """ @@ -656,8 +698,13 @@ def test_vectorize_documents_create_index(self): ] # Execute - with patch('backend.services.vectordatabase_service.ElasticSearchService.create_index') as mock_create_index: + with patch('backend.services.vectordatabase_service.ElasticSearchService.create_index') as mock_create_index, \ + patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \ + patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: mock_create_index.return_value = {"status": "success"} + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = { + "chunk_batch": None} result = ElasticSearchService.index_documents( index_name="test_index", data=test_data, @@ -669,6 +716,10 @@ def test_vectorize_documents_create_index(self): self.assertTrue(result["success"]) self.assertEqual(result["total_indexed"], 1) mock_create_index.assert_called_once() + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs.get("embedding_batch_size"), + 10) # default when None + self.assertTrue(callable(kwargs.get("progress_callback"))) def test_vectorize_documents_indexing_error(self): """ @@ -677,7 +728,7 @@ def test_vectorize_documents_indexing_error(self): This test verifies that: 1. When an error occurs during indexing, an appropriate exception is raised 2. The exception has the correct status code (500) - 3. The exception message contains "Error during indexing" + 3. The exception message contains the original error message """ # Setup self.mock_vdb_core.check_index_exists.return_value = True @@ -693,15 +744,23 @@ def test_vectorize_documents_indexing_error(self): ] # Execute and Assert - with self.assertRaises(Exception) as context: - ElasticSearchService.index_documents( - index_name="test_index", - data=test_data, - vdb_core=self.mock_vdb_core, - embedding_model=mock_embedding_model - ) + with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \ + patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 8} + + with self.assertRaises(Exception) as context: + ElasticSearchService.index_documents( + index_name="test_index", + data=test_data, + vdb_core=self.mock_vdb_core, + embedding_model=mock_embedding_model + ) - self.assertIn("Error during indexing", str(context.exception)) + self.assertIn("Indexing error", str(context.exception)) + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs.get("embedding_batch_size"), 8) + self.assertTrue(callable(kwargs.get("progress_callback"))) @patch('backend.services.vectordatabase_service.get_all_files_status') def test_list_files_without_chunks(self, mock_get_files_status): @@ -764,6 +823,8 @@ def test_list_files_with_chunks(self, mock_get_files_status): } ] mock_get_files_status.return_value = {} + self.mock_vdb_core.client.count.return_value = {"count": 0} + self.mock_vdb_core.client.count.return_value = {"count": 1} # Mock multi_search response msearch_response = { @@ -823,6 +884,7 @@ def test_list_files_msearch_error(self, mock_get_files_status): } ] mock_get_files_status.return_value = {} + self.mock_vdb_core.client.count.return_value = {"count": 0} # Mock msearch error self.mock_vdb_core.client.msearch.side_effect = Exception( @@ -873,6 +935,63 @@ def test_delete_documents(self, mock_delete_file): # Verify that delete_file was called with the correct path mock_delete_file.assert_called_once_with("test_path") + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_index_documents_respects_cancellation_flag(self, mock_get_redis_service): + """ + Test that index_documents stops indexing when the task is marked as cancelled. + + This test verifies that: + 1. _update_progress raises when is_task_cancelled returns True + 2. The exception from vectorize_documents is propagated as an indexing error + """ + # Setup + mock_redis_service = MagicMock() + # First progress callback call: treat as cancelled immediately + mock_redis_service.is_task_cancelled.return_value = True + mock_get_redis_service.return_value = mock_redis_service + + # Configure vdb_core + self.mock_vdb_core.check_index_exists.return_value = True + + # Make vectorize_documents invoke the progress callback (cancellation branch) + def vectorize_side_effect(*args, **kwargs): + cb = kwargs.get("progress_callback") + if cb: + cb(1, 2) # _update_progress will swallow and log cancellation + return 0 + + self.mock_vdb_core.vectorize_documents.side_effect = vectorize_side_effect + + # Provide minimal knowledge record for batch size lookup + with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record: + mock_get_record.return_value = {"tenant_id": "tenant-1"} + with patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: + mock_tenant_cfg.get_model_config.return_value = { + "chunk_batch": 10} + + data = [ + { + "path_or_url": "test_path", + "content": "some content", + "source_type": "minio", + "file_size": 123, + "metadata": {}, + } + ] + + # Execute: no exception should propagate because _update_progress swallows + result = ElasticSearchService.index_documents( + embedding_model=self.mock_embedding, + index_name="test_index", + data=data, + vdb_core=self.mock_vdb_core, + task_id="task-123", + ) + + self.assertTrue(result["success"]) + mock_redis_service.is_task_cancelled.assert_called() + self.mock_vdb_core.vectorize_documents.assert_called_once() + def test_accurate_search(self): """ Test accurate (keyword-based) search functionality. @@ -1035,8 +1154,10 @@ def test_search_hybrid_success(self): self.assertTrue("query_time_ms" in result) self.assertEqual(result["results"][0]["score"], 0.90) self.assertEqual(result["results"][0]["index"], "test_index") - self.assertEqual(result["results"][0]["score_details"]["accurate"], 0.85) - self.assertEqual(result["results"][0]["score_details"]["semantic"], 0.95) + self.assertEqual(result["results"][0] + ["score_details"]["accurate"], 0.85) + self.assertEqual(result["results"][0] + ["score_details"]["semantic"], 0.95) self.mock_vdb_core.hybrid_search.assert_called_once_with( index_names=["test_index"], query_text="test query", @@ -1082,7 +1203,8 @@ def test_search_hybrid_no_indices(self): weight_accurate=0.5, vdb_core=self.mock_vdb_core ) - self.assertIn("At least one index name is required", str(context.exception)) + self.assertIn("At least one index name is required", + str(context.exception)) def test_search_hybrid_invalid_top_k(self): """Test search_hybrid raises ValueError when top_k is invalid.""" @@ -1108,7 +1230,8 @@ def test_search_hybrid_invalid_weight(self): weight_accurate=1.5, vdb_core=self.mock_vdb_core ) - self.assertIn("weight_accurate must be between 0 and 1", str(context.exception)) + self.assertIn("weight_accurate must be between 0 and 1", + str(context.exception)) def test_search_hybrid_no_embedding_model(self): """Test search_hybrid raises ValueError when embedding model is not configured.""" @@ -1125,14 +1248,16 @@ def test_search_hybrid_no_embedding_model(self): weight_accurate=0.5, vdb_core=self.mock_vdb_core ) - self.assertIn("No embedding model configured", str(context.exception)) + self.assertIn("No embedding model configured", + str(context.exception)) finally: self.get_embedding_model_patcher.start() def test_search_hybrid_exception(self): """Test search_hybrid handles exceptions from vdb_core.""" - self.mock_vdb_core.hybrid_search.side_effect = Exception("Search failed") - + self.mock_vdb_core.hybrid_search.side_effect = Exception( + "Search failed") + with self.assertRaises(Exception) as context: ElasticSearchService.search_hybrid( index_names=["test_index"], @@ -1247,7 +1372,6 @@ def test_health_check_unhealthy(self): self.assertIn("Health check failed", str(context.exception)) - @patch('database.model_management_db.get_model_by_model_id') def test_summary_index_name(self, mock_get_model_by_model_id): """ @@ -1268,18 +1392,20 @@ def test_summary_index_name(self, mock_get_model_by_model_id): # Mock the new Map-Reduce functions with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ - patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ - patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ - patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge, \ - patch('database.model_management_db.get_model_by_model_id') as mock_get_model_internal: + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge, \ + patch('database.model_management_db.get_model_by_model_id') as mock_get_model_internal: # Mock return values mock_process_docs.return_value = ( - {"doc1": {"chunks": [{"content": "test content"}]}}, # document_samples + # document_samples + {"doc1": {"chunks": [{"content": "test content"}]}}, {"doc1": np.array([0.1, 0.2, 0.3])} # doc_embeddings ) mock_cluster.return_value = {"doc1": 0} # clusters - mock_summarize.return_value = {0: "Test cluster summary"} # cluster_summaries + mock_summarize.return_value = { + 0: "Test cluster summary"} # cluster_summaries mock_merge.return_value = "Final merged summary" # final_summary mock_get_model_internal.return_value = { 'api_key': 'test_api_key', @@ -1336,7 +1462,7 @@ async def run_test(): tenant_id=None # Missing tenant_id ) self.assertIn("Tenant ID is required", str(context.exception)) - + asyncio.run(run_test()) def test_summary_index_name_no_documents(self): @@ -1349,9 +1475,9 @@ def test_summary_index_name_no_documents(self): """ # Mock the new Map-Reduce functions with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ - patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ - patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ - patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge: + patch('utils.document_vector_utils.kmeans_cluster_documents'), \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce'), \ + patch('utils.document_vector_utils.merge_cluster_summaries'): # Mock return empty document_samples mock_process_docs.return_value = ( @@ -2005,7 +2131,9 @@ def test_semantic_search_success_status_200(self): index_names=["test_index"], query="valid query", top_k=10 ) - def test_vectorize_documents_success_status_200(self): + @patch('backend.services.vectordatabase_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.get_knowledge_record') + def test_vectorize_documents_success_status_200(self, mock_get_record, mock_tenant_cfg): """ Test vectorize_documents method returns status code 200 on success. @@ -2019,6 +2147,8 @@ def test_vectorize_documents_success_status_200(self): self.mock_vdb_core.vectorize_documents.return_value = 3 mock_embedding_model = MagicMock() mock_embedding_model.model = "test-model" + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 10} test_data = [ { @@ -2516,7 +2646,489 @@ def test_get_embedding_model_missing_type(self, mock_tenant_config_manager): # Restart the mock for other tests self.get_embedding_model_patcher.start() + + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_update_progress_success(self, mock_get_redis): + """Ensure _update_progress updates Redis progress when not cancelled.""" + from backend.services.vectordatabase_service import _update_progress + + mock_redis = MagicMock() + mock_redis.is_task_cancelled.return_value = False + mock_redis.save_progress_info.return_value = True + mock_get_redis.return_value = mock_redis + + _update_progress("task-1", 5, 10) + + mock_redis.is_task_cancelled.assert_called_once_with("task-1") + mock_redis.save_progress_info.assert_called_once_with("task-1", 5, 10) + + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_update_progress_save_failure(self, mock_get_redis): + """_update_progress logs a warning when saving progress fails.""" + from backend.services.vectordatabase_service import _update_progress + + mock_redis = MagicMock() + mock_redis.is_task_cancelled.return_value = False + mock_redis.save_progress_info.return_value = False + mock_get_redis.return_value = mock_redis + + _update_progress("task-2", 1, 2) + + mock_redis.is_task_cancelled.assert_called_once_with("task-2") + mock_redis.save_progress_info.assert_called_once_with("task-2", 1, 2) + + +class TestRethrowOrPlain(unittest.TestCase): + def setUp(self): + self.es_service = ElasticSearchService() + self.mock_vdb_core = MagicMock() + self.mock_vdb_core.embedding_model = MagicMock() + self.mock_vdb_core.embedding_dim = 768 + + self.get_embedding_model_patcher = patch( + 'backend.services.vectordatabase_service.get_embedding_model') + self.mock_get_embedding = self.get_embedding_model_patcher.start() + self.mock_embedding = MagicMock() + self.mock_embedding.embedding_dim = 768 + self.mock_embedding.model = "test-model" + self.mock_get_embedding.return_value = self.mock_embedding + + def tearDown(self): + self.get_embedding_model_patcher.stop() + + def test_rethrow_or_plain_rethrows_json_error_code(self): + """_rethrow_or_plain should re-raise JSON payload when error_code present.""" + from backend.services.vectordatabase_service import _rethrow_or_plain + + with self.assertRaises(Exception) as exc: + _rethrow_or_plain(Exception('{"error_code":"E123","detail":"boom"}')) + self.assertIn('"error_code": "E123"', str(exc.exception)) + + def test_get_vector_db_core_unsupported_type(self): + """get_vector_db_core raises on unsupported db type.""" + from backend.services.vectordatabase_service import get_vector_db_core + + with self.assertRaises(ValueError) as exc: + get_vector_db_core(db_type="unsupported") + + self.assertIn("Unsupported vector database type", str(exc.exception)) + + def test_rethrow_or_plain_parses_error_code(self): + """_rethrow_or_plain rethrows JSON error_code payloads unchanged.""" + from backend.services.vectordatabase_service import _rethrow_or_plain + + with self.assertRaises(Exception) as exc: + _rethrow_or_plain(Exception('{"error_code":123,"detail":"boom"}')) + + self.assertIn("error_code", str(exc.exception)) + + @patch('services.redis_service.get_redis_service') + def test_full_delete_knowledge_base_no_files_redis_warning(self, mock_get_redis): + """full_delete_knowledge_base handles empty file list and surfaces Redis warnings.""" + mock_vdb_core = MagicMock() + mock_redis = MagicMock() + mock_redis.delete_knowledgebase_records.return_value = { + "total_deleted": 0, + "errors": [] + } + mock_get_redis.return_value = mock_redis + + with patch('backend.services.vectordatabase_service.ElasticSearchService.list_files', + new_callable=AsyncMock, return_value={"files": []}) as mock_list_files, \ + patch('backend.services.vectordatabase_service.ElasticSearchService.delete_index', + new_callable=AsyncMock, return_value={"status": "success"}) as mock_delete_index: + + async def run_test(): + return await ElasticSearchService.full_delete_knowledge_base( + index_name="kb-1", + vdb_core=mock_vdb_core, + user_id="user-1", + ) + + result = asyncio.run(run_test()) + + self.assertEqual(result["minio_cleanup"]["total_files_found"], 0) + self.assertEqual(result["redis_cleanup"].get("errors"), []) + self.assertIn("redis_warnings", result) + self.assertIn("redis_warnings", result) + mock_list_files.assert_awaited_once() + mock_delete_index.assert_awaited_once() + + @patch('services.redis_service.get_redis_service') + def test_full_delete_knowledge_base_minio_and_redis_error(self, mock_get_redis): + """full_delete_knowledge_base logs minio summary and handles redis cleanup errors.""" + mock_vdb_core = MagicMock() + mock_redis = MagicMock() + # Redis cleanup will raise to hit error branch (lines 289-292) + mock_redis.delete_knowledgebase_records.side_effect = Exception("redis boom") + mock_get_redis.return_value = mock_redis + + files_payload = { + "files": [ + {"path_or_url": "obj-success", "source_type": "minio"}, + {"path_or_url": "obj-fail", "source_type": "minio"}, + ] + } + + # delete_file returns success for first, failure for second + with patch('backend.services.vectordatabase_service.ElasticSearchService.list_files', + new_callable=AsyncMock, return_value=files_payload) as mock_list_files, \ + patch('backend.services.vectordatabase_service.delete_file') as mock_delete_file, \ + patch('backend.services.vectordatabase_service.ElasticSearchService.delete_index', + new_callable=AsyncMock, return_value={"status": "success"}) as mock_delete_index: + mock_delete_file.side_effect = [ + {"success": True}, + {"success": False, "error": "minio failed"}, + ] + + async def run_test(): + return await ElasticSearchService.full_delete_knowledge_base( + index_name="kb-2", + vdb_core=mock_vdb_core, + user_id="user-2", + ) + + result = asyncio.run(run_test()) + + # MinIO summary should reflect one success and one failure (line 270 hit) + self.assertEqual(result["minio_cleanup"]["deleted_count"], 1) + self.assertEqual(result["minio_cleanup"]["failed_count"], 1) + # Redis cleanup error should be surfaced + self.assertIn("error", result["redis_cleanup"]) + mock_list_files.assert_awaited_once() + mock_delete_index.assert_awaited_once_with("kb-2", mock_vdb_core, "user-2") + + @patch('backend.services.vectordatabase_service.create_knowledge_record') + def test_create_knowledge_base_create_index_failure(self, mock_create_record): + """create_knowledge_base raises when index creation fails.""" + mock_create_record.return_value = { + "knowledge_id": 1, + "index_name": "1-uuid", + "knowledge_name": "kb" + } + self.mock_vdb_core.create_index.return_value = False + + with self.assertRaises(Exception) as exc: + ElasticSearchService.create_knowledge_base( + knowledge_name="kb", + embedding_dim=256, + vdb_core=self.mock_vdb_core, + user_id="user-1", + tenant_id="tenant-1", + ) + + self.assertIn("Failed to create index", str(exc.exception)) + + @patch('backend.services.vectordatabase_service.create_knowledge_record') + def test_create_knowledge_base_raises_on_exception(self, mock_create_record): + """create_knowledge_base wraps unexpected errors.""" + mock_create_record.return_value = { + "knowledge_id": 2, + "index_name": "2-uuid", + "knowledge_name": "kb2" + } + self.mock_vdb_core.create_index.side_effect = Exception("boom") + + with self.assertRaises(Exception) as exc: + ElasticSearchService.create_knowledge_base( + knowledge_name="kb2", + embedding_dim=128, + vdb_core=self.mock_vdb_core, + user_id="user-2", + tenant_id="tenant-2", + ) + + self.assertIn("Error creating knowledge base", str(exc.exception)) + + @patch('backend.services.vectordatabase_service.get_knowledge_record') + def test_index_documents_default_batch_without_tenant(self, mock_get_record): + """index_documents defaults embedding batch size to 10 when tenant is missing.""" + mock_get_record.return_value = None + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 1 + + data = [{ + "path_or_url": "p1", + "content": "c1", + "metadata": {"title": "t1"}, + }] + embedding = MagicMock() + embedding.model = "model-x" + + result = ElasticSearchService.index_documents( + embedding_model=embedding, + index_name="idx", + data=data, + vdb_core=self.mock_vdb_core, + ) + + self.assertTrue(result["success"]) + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs["embedding_batch_size"], 10) + + @patch('backend.services.vectordatabase_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_index_documents_updates_final_progress(self, mock_get_redis, mock_get_record, mock_tenant_cfg): + """index_documents sends final progress update to Redis when task_id is provided.""" + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 4} + mock_redis = MagicMock() + mock_get_redis.return_value = mock_redis + + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 2 + + data = [ + {"path_or_url": "p1", "content": "c1", "metadata": {}}, + {"path_or_url": "p2", "content": "c2", "metadata": {}}, + ] + + result = ElasticSearchService.index_documents( + embedding_model=self.mock_embedding, + index_name="idx", + data=data, + vdb_core=self.mock_vdb_core, + task_id="task-xyz", + ) + + self.assertTrue(result["success"]) + mock_redis.save_progress_info.assert_called() + last_call = mock_redis.save_progress_info.call_args_list[-1] + self.assertEqual(last_call[0], ("task-xyz", 2, 2)) + + @patch('backend.services.vectordatabase_service.get_redis_service') + @patch('backend.services.vectordatabase_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.tenant_config_manager') + def test_index_documents_progress_init_and_final_errors(self, mock_tenant_cfg, mock_get_record, mock_get_redis): + """index_documents should continue when progress save fails during init and final updates.""" + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 4} + + mock_redis = MagicMock() + # First call (init) raises, second call (final) raises + mock_redis.save_progress_info.side_effect = [Exception("init fail"), Exception("final fail")] + mock_redis.is_task_cancelled.return_value = False + mock_get_redis.return_value = mock_redis + + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 1 + + data = [{"path_or_url": "p1", "content": "c1", "metadata": {}}] + + result = ElasticSearchService.index_documents( + embedding_model=self.mock_embedding, + index_name="idx", + data=data, + vdb_core=self.mock_vdb_core, + task_id="task-err", + ) + + self.assertTrue(result["success"]) + # two attempts to save progress (init and final) + self.assertEqual(mock_redis.save_progress_info.call_count, 2) + + @patch('backend.services.vectordatabase_service.get_all_files_status') + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_list_files_handles_invalid_create_time_and_failed_tasks(self, mock_get_redis, mock_get_files_status): + """list_files handles invalid timestamps, progress overrides, and error info.""" + self.mock_vdb_core.get_documents_detail.return_value = [ + { + "path_or_url": "file1", + "filename": "file1.txt", + "file_size": 10, + "create_time": "invalid", + "chunk_count": 1 + } + ] + self.mock_vdb_core.client.count.return_value = {"count": 7} + + mock_get_files_status.return_value = { + "file1": { + "state": "PROCESS_FAILED", + "latest_task_id": "task-1", + "processed_chunks": 1, + "total_chunks": 5, + "source_type": "minio", + "original_filename": "file1.txt" + } + } + + mock_redis = MagicMock() + mock_redis.get_progress_info.return_value = { + "processed_chunks": 2, + "total_chunks": 5 + } + mock_redis.get_error_info.return_value = "boom error" + mock_get_redis.return_value = mock_redis + + async def run_test(): + return await ElasticSearchService.list_files( + index_name="idx", + include_chunks=False, + vdb_core=self.mock_vdb_core + ) + + result = asyncio.run(run_test()) + self.assertEqual(len(result["files"]), 1) + file_info = result["files"][0] + self.assertEqual(file_info["chunk_count"], 7) + self.assertEqual(file_info["file_size"], 10) + self.assertEqual(file_info["status"], "PROCESS_FAILED") + self.assertEqual(file_info["processed_chunk_num"], 2) + self.assertEqual(file_info["total_chunk_num"], 5) + self.assertEqual(file_info["error_reason"], "boom error") + self.assertIsInstance(file_info["create_time"], int) + + @patch('backend.services.vectordatabase_service.get_all_files_status') + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_list_files_warning_and_progress_error_branches(self, mock_get_redis, mock_get_files_status): + """list_files covers chunk count warning, file size error, progress overrides, and redis failures.""" + # Existing ES file triggers count warning (lines 749-750 and 910-916) + self.mock_vdb_core.get_documents_detail.return_value = [ + { + "path_or_url": "file-es", + "filename": "file-es.txt", + "file_size": 5, + "create_time": "2024-01-01T00:00:00", + "chunk_count": 1 + } + ] + # First count call for ES file, second for completed file at include_chunks=False + self.mock_vdb_core.client.count.side_effect = [ + Exception("count fail initial"), + Exception("count fail final"), + ] + + # Two tasks from Celery status to exercise progress success and failure + mock_get_files_status.return_value = { + "file-processing": { + "state": "PROCESSING", + "latest_task_id": "t1", + "source_type": "minio", + "original_filename": "fp.txt", + "processed_chunks": 1, + "total_chunks": 3, + }, + "file-failed": { + "state": "PROCESS_FAILED", + "latest_task_id": "t2", + "source_type": "minio", + "original_filename": "ff.txt", + }, + } + + mock_redis = MagicMock() + # Progress info: first returns dict, second raises to hit lines 815-816 + mock_redis.get_progress_info.side_effect = [ + {"processed_chunks": 2, "total_chunks": 4}, + Exception("progress boom"), + ] + # get_error_info raises to hit 847-848 + mock_redis.get_error_info.side_effect = Exception("error info boom") + mock_get_redis.return_value = mock_redis + + with patch('backend.services.vectordatabase_service.get_file_size', side_effect=Exception("size boom")): + async def run_test(): + return await ElasticSearchService.list_files( + index_name="idx", + include_chunks=False, + vdb_core=self.mock_vdb_core + ) + + result = asyncio.run(run_test()) + + # Ensure both ES file and processing files are returned + paths = {f["path_or_url"] for f in result["files"]} + self.assertIn("file-es", paths) + self.assertIn("file-processing", paths) + self.assertIn("file-failed", paths) + # Processing file gets progress override + proc_file = next(f for f in result["files"] if f["path_or_url"] == "file-processing") + self.assertEqual(proc_file["processed_chunk_num"], 2) + self.assertEqual(proc_file["total_chunk_num"], 4) + # Failed file retains default chunk_count fallback + failed_file = next(f for f in result["files"] if f["path_or_url"] == "file-failed") + self.assertEqual(failed_file.get("chunk_count", 0), 0) + + @patch('backend.services.vectordatabase_service.get_all_files_status', return_value={}) + def test_list_files_with_chunks_updates_chunk_count(self, mock_get_files_status): + """list_files include_chunks path refreshes chunk counts.""" + self.mock_vdb_core.get_documents_detail.return_value = [ + { + "path_or_url": "file1", + "filename": "file1.txt", + "file_size": 10, + "create_time": "2024-01-01T00:00:00" + } + ] + self.mock_vdb_core.multi_search.return_value = { + "responses": [ + { + "hits": { + "hits": [ + {"_source": { + "id": "doc1", + "title": "t", + "content": "c", + "create_time": "2024-01-01T00:00:00" + }} + ] + } + } + ] + } + self.mock_vdb_core.client.count.return_value = {"count": 2} + + async def run_test(): + return await ElasticSearchService.list_files( + index_name="idx", + include_chunks=True, + vdb_core=self.mock_vdb_core + ) + + result = asyncio.run(run_test()) + file_info = result["files"][0] + self.assertEqual(file_info["chunk_count"], 2) + self.assertEqual(len(file_info["chunks"]), 1) + + def test_summary_index_name_streams_generator_error(self): + """summary_index_name streams error payloads when generator fails.""" + class BadIterable: + def __iter__(self): + raise RuntimeError("stream failure") + + with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries', return_value=BadIterable()): + + mock_process_docs.return_value = ( + {"doc1": {"chunks": [{"content": "x"}]}}, + {"doc1": MagicMock()} + ) + mock_cluster.return_value = {"doc1": 0} + mock_summarize.return_value = {0: "summary"} + + async def run_test(): + response = await self.es_service.summary_index_name( + index_name="idx", + batch_size=100, + vdb_core=self.mock_vdb_core, + language="en", + model_id=None, + tenant_id="tenant-1", + ) + messages = [] + async for chunk in response.body_iterator: + messages.append(chunk) + break + return messages + + messages = asyncio.run(run_test()) + self.assertTrue(any("error" in msg for msg in messages)) + if __name__ == '__main__': unittest.main() diff --git a/test/backend/test_cluster_summarization.py b/test/backend/test_cluster_summarization.py index 9fd0a3b91..1e9bc1658 100644 --- a/test/backend/test_cluster_summarization.py +++ b/test/backend/test_cluster_summarization.py @@ -10,11 +10,39 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( extract_cluster_content, summarize_cluster, diff --git a/test/backend/test_config_service.py b/test/backend/test_config_service.py index a30e86bd7..9935797cc 100644 --- a/test/backend/test_config_service.py +++ b/test/backend/test_config_service.py @@ -431,5 +431,45 @@ async def test_startup_initialization_with_custom_version(self, mock_logger, moc assert version_logged, "Custom APP version should be logged" +class TestTenantConfigService: + """Unit tests for tenant_config_service helpers""" + + @patch('backend.services.tenant_config_service.get_selected_knowledge_list') + def test_build_knowledge_name_mapping_prefers_knowledge_name(self, mock_get_selected): + """Ensure knowledge_name is used as key when present.""" + mock_get_selected.return_value = [ + {"knowledge_name": "User Docs", "index_name": "index_user_docs"}, + {"knowledge_name": "API Docs", "index_name": "index_api_docs"}, + ] + + from backend.services.tenant_config_service import build_knowledge_name_mapping + + mapping = build_knowledge_name_mapping(tenant_id="t1", user_id="u1") + + assert mapping == { + "User Docs": "index_user_docs", + "API Docs": "index_api_docs", + } + mock_get_selected.assert_called_once_with(tenant_id="t1", user_id="u1") + + @patch('backend.services.tenant_config_service.get_selected_knowledge_list') + def test_build_knowledge_name_mapping_fallbacks_to_index_name(self, mock_get_selected): + """Fallback to index_name when knowledge_name is missing.""" + mock_get_selected.return_value = [ + {"index_name": "index_fallback_only"}, + {"knowledge_name": None, "index_name": "index_none_name"}, + ] + + from backend.services.tenant_config_service import build_knowledge_name_mapping + + mapping = build_knowledge_name_mapping(tenant_id="t2", user_id="u2") + + assert mapping == { + "index_fallback_only": "index_fallback_only", + "index_none_name": "index_none_name", + } + mock_get_selected.assert_called_once_with(tenant_id="t2", user_id="u2") + + if __name__ == '__main__': pytest.main() diff --git a/test/backend/test_document_vector_integration.py b/test/backend/test_document_vector_integration.py index 015818d32..8e05abe86 100644 --- a/test/backend/test_document_vector_integration.py +++ b/test/backend/test_document_vector_integration.py @@ -11,11 +11,38 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( calculate_document_embedding, auto_determine_k, diff --git a/test/backend/test_document_vector_utils.py b/test/backend/test_document_vector_utils.py index f03ed3346..1b4f89997 100644 --- a/test/backend/test_document_vector_utils.py +++ b/test/backend/test_document_vector_utils.py @@ -10,11 +10,39 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( calculate_document_embedding, auto_determine_k, @@ -226,6 +254,28 @@ def test_summarize_document_with_model_placeholder(self): assert isinstance(result, str) assert len(result) > 0 + def test_summarize_document_with_model_success(self): + """Test document summarization when model config exists and LLM returns value""" + with patch('backend.utils.document_vector_utils.get_model_by_model_id') as mock_get_model, \ + patch('backend.utils.document_vector_utils.call_llm_for_system_prompt') as mock_llm: + mock_get_model.return_value = {"id": 1} + mock_llm.return_value = "Generated summary\n" + + result = summarize_document( + document_content="LLM content", + filename="doc.pdf", + language="en", + max_words=50, + model_id=1, + tenant_id="tenant" + ) + + assert result == "Generated summary" + mock_llm.assert_called_once() + call_args = mock_llm.call_args.kwargs + assert call_args["model_id"] == 1 + assert call_args["tenant_id"] == "tenant" + class TestSummarizeCluster: """Test cluster summarization""" @@ -250,6 +300,27 @@ def test_summarize_cluster_with_model_placeholder(self): assert isinstance(result, str) assert len(result) > 0 + def test_summarize_cluster_with_model_success(self): + """Test cluster summarization when model config exists and LLM returns value""" + with patch('backend.utils.document_vector_utils.get_model_by_model_id') as mock_get_model, \ + patch('backend.utils.document_vector_utils.call_llm_for_system_prompt') as mock_llm: + mock_get_model.return_value = {"id": 1} + mock_llm.return_value = "Cluster summary text " + + result = summarize_cluster( + document_summaries=["Doc 1 summary", "Doc 2 summary"], + language="en", + max_words=120, + model_id=1, + tenant_id="tenant" + ) + + assert result == "Cluster summary text" + mock_llm.assert_called_once() + call_args = mock_llm.call_args.kwargs + assert call_args["model_id"] == 1 + assert call_args["tenant_id"] == "tenant" + class TestSummarizeClustersMapReduce: """Test map-reduce cluster summarization""" diff --git a/test/backend/test_document_vector_utils_coverage.py b/test/backend/test_document_vector_utils_coverage.py index b442e47e4..82ac1d646 100644 --- a/test/backend/test_document_vector_utils_coverage.py +++ b/test/backend/test_document_vector_utils_coverage.py @@ -10,11 +10,38 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( get_documents_from_es, process_documents_for_clustering, diff --git a/test/backend/test_summary_formatting.py b/test/backend/test_summary_formatting.py index 31b656e55..22f8dec36 100644 --- a/test/backend/test_summary_formatting.py +++ b/test/backend/test_summary_formatting.py @@ -5,10 +5,38 @@ import pytest import sys import os +from unittest.mock import MagicMock, patch -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend')) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from utils.document_vector_utils import merge_cluster_summaries diff --git a/test/backend/utils/test_file_management_utils.py b/test/backend/utils/test_file_management_utils.py index eaa3a1261..02553db8f 100644 --- a/test/backend/utils/test_file_management_utils.py +++ b/test/backend/utils/test_file_management_utils.py @@ -300,6 +300,123 @@ async def test_get_all_files_status_connect_error_and_non200(fmu, monkeypatch): assert out2 == {} +@pytest.mark.asyncio +async def test_get_all_files_status_no_tasks_returns_empty(fmu, monkeypatch): + fake_client = _FakeAsyncClient(_Resp(200, [])) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + + out = await fmu.get_all_files_status("idx-empty") + assert out == {} + + +@pytest.mark.asyncio +async def test_get_all_files_status_forward_updates_and_redis_progress(fmu, monkeypatch): + tasks_list = [ + { + "id": "10", + "task_name": "process", + "index_name": "idx", + "path_or_url": "/p2", + "original_filename": "f2", + "source_type": "local", + "status": "SUCCESS", + "created_at": 1, + }, + { + "id": "20", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p2", + "original_filename": "f2", + "source_type": "local", + "status": "STARTED", + "created_at": 5, # later than process to trigger forward branch + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Stub redis_service with progress info + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: {"processed_chunks": 7, "total_chunks": 9} + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p2"]["state"] == "FORWARDING" + assert out["/p2"]["latest_task_id"] == "20" + assert out["/p2"]["processed_chunks"] == 7 + assert out["/p2"]["total_chunks"] == 9 + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_progress_exception(fmu, monkeypatch): + tasks_list = [ + { + "id": "30", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p3", + "original_filename": "f3", + "source_type": "local", + "status": "STARTED", + "created_at": 2, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service raising exception to hit exception path + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + def _boom(): + raise RuntimeError("redis down") + redis_mod.get_redis_service = lambda: types.SimpleNamespace(get_progress_info=lambda task_id: _boom()) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p3"]["state"] == "FORWARDING" + assert out["/p3"]["processed_chunks"] is None + assert out["/p3"]["total_chunks"] is None + + +@pytest.mark.asyncio +async def test_get_all_files_status_outer_exception_returns_empty(fmu, monkeypatch): + tasks_list = [ + { + "id": "40", + "task_name": "process", + "index_name": "idx", + "path_or_url": "/p4", + "original_filename": "f4", + "source_type": "local", + "status": "SUCCESS", + "created_at": 1, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + + def _boom(*a, **k): + raise RuntimeError("convert failed") + monkeypatch.setattr(fmu, "_convert_to_custom_state", _boom) + + out = await fmu.get_all_files_status("idx") + assert out == {} + + # -------------------- _convert_to_custom_state -------------------- @@ -379,3 +496,211 @@ def test_get_file_size_invalid_source_type(fmu): assert fmu.get_file_size("http", "http://x") == 0 +# -------------------- Additional coverage for get_all_files_status -------------------- + + +@pytest.mark.asyncio +async def test_get_all_files_status_forward_created_at_not_greater(fmu, monkeypatch): + """Test forward task with created_at not greater than latest_forward_created_at (line 195)""" + tasks_list = [ + { + "id": "20", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p5", + "original_filename": "f5", + "source_type": "local", + "status": "STARTED", + "created_at": 5, + }, + { + "id": "21", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p5", + "original_filename": "f5", + "source_type": "local", + "status": "SUCCESS", + "created_at": 3, # Less than previous forward task, should not update + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + out = await fmu.get_all_files_status("idx") + # Should use the first forward task (id=20) as latest since it has higher created_at + assert out["/p5"]["latest_task_id"] == "20" + + +@pytest.mark.asyncio +async def test_get_all_files_status_empty_task_id(fmu, monkeypatch): + """Test when task_id is empty string (line 221 - not entering if branch)""" + tasks_list = [ + { + "id": "", # Empty task_id + "task_name": "process", + "index_name": "idx", + "path_or_url": "/p6", + "original_filename": "f6", + "source_type": "local", + "status": "SUCCESS", + "created_at": 1, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "COMPLETED" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Stub redis_service to ensure it's not called + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_called = {"called": False} + def _track_call(task_id): + redis_called["called"] = True + return {} + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=_track_call + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p6"]["latest_task_id"] == "" + # Redis should not be called when task_id is empty + assert redis_called["called"] is False + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_progress_info_none(fmu, monkeypatch): + """Test when progress_info is None (line 226, 237 - entering else branch)""" + tasks_list = [ + { + "id": "50", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p7", + "original_filename": "f7", + "source_type": "local", + "status": "STARTED", + "created_at": 1, + "processed_chunks": 5, + "total_chunks": 10, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service returning None (line 226, 237) + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: None # Returns None to trigger else branch + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p7"]["state"] == "FORWARDING" + assert out["/p7"]["latest_task_id"] == "50" + # Should use task state values when progress_info is None + assert out["/p7"]["processed_chunks"] == 5 + assert out["/p7"]["total_chunks"] == 10 + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_processed_chunks_none(fmu, monkeypatch): + """Test when redis_processed is None (line 230 - not entering if branch)""" + tasks_list = [ + { + "id": "60", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p8", + "original_filename": "f8", + "source_type": "local", + "status": "STARTED", + "created_at": 1, + "processed_chunks": 3, + "total_chunks": 8, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service returning progress_info with processed_chunks as None (line 230) + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: { + "processed_chunks": None, # None to skip line 230 if branch + "total_chunks": 15 + } + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p8"]["state"] == "FORWARDING" + # processed_chunks should remain from task state (3) since redis_processed is None + assert out["/p8"]["processed_chunks"] == 3 + # total_chunks should be updated from Redis (15) + assert out["/p8"]["total_chunks"] == 15 + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_total_chunks_none(fmu, monkeypatch): + """Test when redis_total is None (line 232 - not entering if branch)""" + tasks_list = [ + { + "id": "70", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p9", + "original_filename": "f9", + "source_type": "local", + "status": "STARTED", + "created_at": 1, + "processed_chunks": 4, + "total_chunks": 12, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service returning progress_info with total_chunks as None (line 232) + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: { + "processed_chunks": 6, + "total_chunks": None # None to skip line 232 if branch + } + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p9"]["state"] == "FORWARDING" + # processed_chunks should be updated from Redis (6) + assert out["/p9"]["processed_chunks"] == 6 + # total_chunks should remain from task state (12) since redis_total is None + assert out["/p9"]["total_chunks"] == 12 + diff --git a/test/backend/utils/test_llm_utils.py b/test/backend/utils/test_llm_utils.py index 545bdf776..50857e91b 100644 --- a/test/backend/utils/test_llm_utils.py +++ b/test/backend/utils/test_llm_utils.py @@ -74,7 +74,7 @@ class TestCallLLMForSystemPrompt(unittest.TestCase): def setUp(self): self.test_model_id = 1 - @patch('backend.utils.llm_utils.OpenAIServerModel') + @patch('backend.utils.llm_utils.OpenAIModel') @patch('backend.utils.llm_utils.get_model_name_from_config') @patch('backend.utils.llm_utils.get_model_by_model_id') def test_call_llm_for_system_prompt_success( @@ -118,7 +118,7 @@ def test_call_llm_for_system_prompt_success( top_p=0.95, ) - @patch('backend.utils.llm_utils.OpenAIServerModel') + @patch('backend.utils.llm_utils.OpenAIModel') @patch('backend.utils.llm_utils.get_model_name_from_config') @patch('backend.utils.llm_utils.get_model_by_model_id') def test_call_llm_for_system_prompt_exception( diff --git a/test/sdk/core/agents/test_core_agent.py b/test/sdk/core/agents/test_core_agent.py index cb6240893..54b725620 100644 --- a/test/sdk/core/agents/test_core_agent.py +++ b/test/sdk/core/agents/test_core_agent.py @@ -1,3 +1,5 @@ +import json + import pytest from unittest.mock import MagicMock, patch from threading import Event @@ -14,22 +16,98 @@ def __init__(self, message): super().__init__(message) +class MockAgentMaxStepsError(Exception): + pass + + # Mock for smolagents and its sub-modules mock_smolagents = MagicMock() -mock_smolagents.ActionStep = MagicMock() -mock_smolagents.TaskStep = MagicMock() -mock_smolagents.SystemPromptStep = MagicMock() mock_smolagents.AgentError = MockAgentError mock_smolagents.handle_agent_output_types = MagicMock( return_value="handled_output") +mock_smolagents.utils.AgentMaxStepsError = MockAgentMaxStepsError + +# Create proper class types for isinstance checks (not MagicMock) +class MockActionStep: + def __init__(self, *args, **kwargs): + self.step_number = kwargs.get('step_number', 1) + self.timing = kwargs.get('timing', None) + self.observations_images = kwargs.get('observations_images', None) + self.model_input_messages = None + self.model_output_message = None + self.model_output = None + self.token_usage = None + self.code_action = None + self.tool_calls = None + self.observations = None + self.action_output = None + self.is_final_answer = False + self.error = None + +class MockTaskStep: + def __init__(self, *args, **kwargs): + self.task = kwargs.get('task', '') + self.task_images = kwargs.get('task_images', None) + +class MockSystemPromptStep: + def __init__(self, *args, **kwargs): + self.system_prompt = kwargs.get('system_prompt', '') + +class MockFinalAnswerStep: + def __init__(self, *args, **kwargs): + # Handle both positional and keyword arguments + if args: + self.output = args[0] + else: + self.output = kwargs.get('output', '') + +class MockPlanningStep: + def __init__(self, *args, **kwargs): + self.token_usage = kwargs.get('token_usage', None) + +class MockActionOutput: + def __init__(self, *args, **kwargs): + self.output = kwargs.get('output', None) + self.is_final_answer = kwargs.get('is_final_answer', False) + +class MockRunResult: + def __init__(self, *args, **kwargs): + self.output = kwargs.get('output', None) + self.token_usage = kwargs.get('token_usage', None) + self.steps = kwargs.get('steps', []) + self.timing = kwargs.get('timing', None) + self.state = kwargs.get('state', 'success') + +class MockCodeOutput: + """Mock object returned by python_executor.""" + def __init__(self, output=None, logs="", is_final_answer=False): + self.output = output + self.logs = logs + self.is_final_answer = is_final_answer + +# Assign proper classes to mock_smolagents +mock_smolagents.ActionStep = MockActionStep +mock_smolagents.TaskStep = MockTaskStep +mock_smolagents.SystemPromptStep = MockSystemPromptStep # Create dummy smolagents sub-modules for sub_mod in ["agents", "memory", "models", "monitoring", "utils", "local_python_executor"]: mock_module = MagicMock() setattr(mock_smolagents, sub_mod, mock_module) +# Assign classes to memory submodule +mock_smolagents.memory.ActionStep = MockActionStep +mock_smolagents.memory.TaskStep = MockTaskStep +mock_smolagents.memory.SystemPromptStep = MockSystemPromptStep +mock_smolagents.memory.FinalAnswerStep = MockFinalAnswerStep +mock_smolagents.memory.PlanningStep = MockPlanningStep +mock_smolagents.memory.ToolCall = MagicMock + +# Assign classes to agents submodule mock_smolagents.agents.CodeAgent = MagicMock +mock_smolagents.agents.ActionOutput = MockActionOutput +mock_smolagents.agents.RunResult = MockRunResult # Provide actual implementations for commonly used utils functions @@ -72,6 +150,23 @@ def mock_truncate_content(content, max_length=1000): core_agent_module = sys.modules['sdk.nexent.core.agents.core_agent'] # Override AgentError inside the imported module to ensure it has message attr core_agent_module.AgentError = MockAgentError + core_agent_module.AgentMaxStepsError = MockAgentMaxStepsError + # Override classes to use our mock classes for isinstance checks + core_agent_module.FinalAnswerStep = MockFinalAnswerStep + core_agent_module.ActionStep = MockActionStep + core_agent_module.PlanningStep = MockPlanningStep + core_agent_module.ActionOutput = MockActionOutput + core_agent_module.RunResult = MockRunResult + # Override CodeAgent to be a proper class that can be inherited + class MockCodeAgent: + def __init__(self, prompt_templates=None, *args, **kwargs): + # Accept any arguments but don't require observer + # Store attributes that might be accessed + self.prompt_templates = prompt_templates + # Initialize common attributes that CodeAgent might have + for key, value in kwargs.items(): + setattr(self, key, value) + core_agent_module.CodeAgent = MockCodeAgent CoreAgent = ImportedCoreAgent @@ -103,16 +198,50 @@ def core_agent_instance(mock_observer): agent.stop_event = Event() agent.memory = MagicMock() agent.memory.steps = [] + agent.memory.get_full_steps = MagicMock(return_value=[]) agent.python_executor = MagicMock() + + # Mock logger with all required methods + agent.logger = MagicMock() + agent.logger.log = MagicMock() + agent.logger.log_task = MagicMock() + agent.logger.log_markdown = MagicMock() + agent.logger.log_code = MagicMock() agent.step_number = 1 agent._execute_step = MagicMock() agent._finalize_step = MagicMock() agent._handle_max_steps_reached = MagicMock() + + # Set default attributes that might be needed + agent.max_steps = 5 + agent.state = {} + agent.system_prompt = "test system prompt" + agent.return_full_result = False + agent.provide_run_summary = False + agent.tools = {} + agent.managed_agents = {} + agent.monitor = MagicMock() + agent.monitor.reset = MagicMock() + agent.model = MagicMock() + if hasattr(agent.model, 'model_id'): + agent.model.model_id = "test-model" + agent.code_block_tags = ["```", "```"] + agent._use_structured_outputs_internally = False + agent.final_answer_checks = None # Set to avoid MagicMock creating new CoreAgent instances return agent +@pytest.fixture(autouse=True) +def reset_token_usage_mock(): + """Ensure TokenUsage mock does not leak state between tests.""" + token_usage = getattr(core_agent_module, "TokenUsage", None) + if hasattr(token_usage, "reset_mock"): + token_usage.reset_mock() + yield + + # ---------------------------------------------------------------------------- # Tests for _run method # ---------------------------------------------------------------------------- @@ -123,11 +252,12 @@ def test_run_normal_execution(core_agent_instance): task = "test task" max_steps = 3 - # Mock _execute_step to return a generator that yields final answer - def mock_execute_generator(action_step): - yield "final_answer" + # Mock _step_stream to return a generator that yields ActionOutput with final answer + def mock_step_stream(action_step): + action_output = MockActionOutput(output="final_answer", is_final_answer=True) + yield action_output - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_generator) as mock_execute_step, \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream) as mock_step_stream_patch, \ patch.object(core_agent_instance, '_finalize_step') as mock_finalize_step: core_agent_instance.step_number = 1 @@ -135,11 +265,11 @@ def mock_execute_generator(action_step): result = list(core_agent_instance._run_stream(task, max_steps)) # Assertions - # _run_stream yields: generator output + action step + final answer step + # _run_stream yields: ActionOutput from _step_stream + action step + final answer step assert len(result) == 3 - assert result[0] == "final_answer" # Generator output - assert isinstance(result[1], MagicMock) # Action step - assert isinstance(result[2], MagicMock) # Final answer step + assert isinstance(result[0], MockActionOutput) # ActionOutput from _step_stream + assert isinstance(result[1], MockActionStep) # Action step + assert isinstance(result[2], MockFinalAnswerStep) # Final answer step def test_run_with_max_steps_reached(core_agent_instance): @@ -148,11 +278,12 @@ def test_run_with_max_steps_reached(core_agent_instance): task = "test task" max_steps = 2 - # Mock _execute_step to return None (no final answer) - def mock_execute_generator(action_step): - yield None + # Mock _step_stream to return ActionOutput without final answer + def mock_step_stream(action_step): + action_output = MockActionOutput(output=None, is_final_answer=False) + yield action_output - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_generator) as mock_execute_step, \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream) as mock_step_stream_patch, \ patch.object(core_agent_instance, '_finalize_step') as mock_finalize_step, \ patch.object(core_agent_instance, '_handle_max_steps_reached', return_value="max_steps_reached") as mock_handle_max: @@ -162,18 +293,19 @@ def mock_execute_generator(action_step): result = list(core_agent_instance._run_stream(task, max_steps)) # Assertions - # For 2 steps: (None + action_step) * 2 + final_action_step + final_answer_step = 6 - assert len(result) == 6 - assert result[0] is None # First generator output - assert isinstance(result[1], MagicMock) # First action step - assert result[2] is None # Second generator output - assert isinstance(result[3], MagicMock) # Second action step - # Final action step (from _handle_max_steps_reached) - assert isinstance(result[4], MagicMock) - assert isinstance(result[5], MagicMock) # Final answer step + # For 2 steps: (ActionOutput + action_step) * 2 + final_action_step + final_answer_step = 6 + assert len(result) >= 5 + # First step: ActionOutput + ActionStep + assert isinstance(result[0], MockActionOutput) # First ActionOutput + assert isinstance(result[1], MockActionStep) # First action step + # Second step: ActionOutput + ActionStep + assert isinstance(result[2], MockActionOutput) # Second ActionOutput + assert isinstance(result[3], MockActionStep) # Second action step + # Last should be final answer step + assert isinstance(result[-1], MockFinalAnswerStep) # Final answer step # Verify method calls - assert mock_execute_step.call_count == 2 + assert mock_step_stream_patch.call_count == 2 mock_handle_max.assert_called_once() assert mock_finalize_step.call_count == 2 @@ -184,23 +316,28 @@ def test_run_with_stop_event(core_agent_instance): task = "test task" max_steps = 3 - def mock_execute_generator(action_step): + def mock_step_stream(action_step): core_agent_instance.stop_event.set() - yield None + action_output = MockActionOutput(output=None, is_final_answer=False) + yield action_output + + # Mock handle_agent_output_types to return the input value (identity function) + # This way when final_answer = "", it will be passed through + with patch.object(core_agent_module, 'handle_agent_output_types', side_effect=lambda x: x): + # Mock _step_stream to set stop event + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream): + with patch.object(core_agent_instance, '_finalize_step'): + # Execute + result = list(core_agent_instance._run_stream(task, max_steps)) - # Mock _execute_step to set stop event - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_generator): - with patch.object(core_agent_instance, '_finalize_step'): - # Execute - result = list(core_agent_instance._run_stream(task, max_steps)) - - # Assertions - # Should yield: generator output + action step + final answer step - assert len(result) == 3 - assert result[0] is None # Generator output - assert isinstance(result[1], MagicMock) # Action step - # Final answer step with "" - assert isinstance(result[2], MagicMock) + # Assertions + # Should yield: ActionOutput from _step_stream + action step + final answer step + assert len(result) == 3 + assert isinstance(result[0], MockActionOutput) # ActionOutput from _step_stream + assert isinstance(result[1], MockActionStep) # Action step + # Final answer step with "" + assert isinstance(result[2], MockFinalAnswerStep) + assert result[2].output == "" def test_run_with_final_answer_error(core_agent_instance): @@ -209,9 +346,9 @@ def test_run_with_final_answer_error(core_agent_instance): task = "test task" max_steps = 3 - # Mock _execute_step to raise FinalAnswerError - with patch.object(core_agent_instance, '_execute_step', - side_effect=core_agent_module.FinalAnswerError()) as mock_execute_step, \ + # Mock _step_stream to raise FinalAnswerError + with patch.object(core_agent_instance, '_step_stream', + side_effect=core_agent_module.FinalAnswerError()) as mock_step_stream, \ patch.object(core_agent_instance, '_finalize_step'): # Execute result = list(core_agent_instance._run_stream(task, max_steps)) @@ -219,8 +356,8 @@ def test_run_with_final_answer_error(core_agent_instance): # Assertions # When FinalAnswerError occurs, it should yield action step + final answer step assert len(result) == 2 - assert isinstance(result[0], MagicMock) # Action step - assert isinstance(result[1], MagicMock) # Final answer step + assert isinstance(result[0], MockActionStep) # Action step + assert isinstance(result[1], MockFinalAnswerStep) # Final answer step def test_run_with_final_answer_error_and_model_output(core_agent_instance): @@ -229,16 +366,12 @@ def test_run_with_final_answer_error_and_model_output(core_agent_instance): task = "test task" max_steps = 3 - # Create a mock action step with model_output - mock_action_step = MagicMock() - mock_action_step.model_output = "```\nprint('hello')\n```" - - # Mock _execute_step to set model_output and then raise FinalAnswerError - def mock_execute_step(action_step): + # Mock _step_stream to set model_output and then raise FinalAnswerError + def mock_step_stream(action_step): action_step.model_output = "```\nprint('hello')\n```" raise core_agent_module.FinalAnswerError() - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_step), \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream), \ patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ patch.object(core_agent_instance, '_finalize_step'): # Execute @@ -246,8 +379,8 @@ def mock_execute_step(action_step): # Assertions assert len(result) == 2 - assert isinstance(result[0], MagicMock) # Action step - assert isinstance(result[1], MagicMock) # Final answer step + assert isinstance(result[0], MockActionStep) # Action step + assert isinstance(result[1], MockFinalAnswerStep) # Final answer step # Verify convert_code_format was called mock_convert.assert_called_once_with( "```\nprint('hello')\n```") @@ -259,9 +392,9 @@ def test_run_with_agent_error_updated(core_agent_instance): task = "test task" max_steps = 3 - # Mock _execute_step to raise AgentError - with patch.object(core_agent_instance, '_execute_step', - side_effect=MockAgentError("test error")) as mock_execute_step, \ + # Mock _step_stream to raise AgentError + with patch.object(core_agent_instance, '_step_stream', + side_effect=MockAgentError("test error")) as mock_step_stream, \ patch.object(core_agent_instance, '_finalize_step'): # Execute result = list(core_agent_instance._run_stream(task, max_steps)) @@ -270,9 +403,9 @@ def test_run_with_agent_error_updated(core_agent_instance): # When AgentError occurs, it should yield action step + final answer step # But the error causes the loop to continue, so we get multiple action steps assert len(result) >= 2 - assert isinstance(result[0], MagicMock) # Action step with error + assert isinstance(result[0], MockActionStep) # Action step with error # Last item should be final answer step - assert isinstance(result[-1], MagicMock) # Final answer step + assert isinstance(result[-1], MockFinalAnswerStep) # Final answer step def test_run_with_agent_parse_error_branch_updated(core_agent_instance): @@ -280,25 +413,40 @@ def test_run_with_agent_parse_error_branch_updated(core_agent_instance): task = "parse task" max_steps = 1 - # Mock _execute_step to set model_output and then raise FinalAnswerError - def mock_execute_step(action_step): + # Mock _step_stream to set model_output and then raise FinalAnswerError + def mock_step_stream(action_step): action_step.model_output = "```\nprint('hello')\n```" raise core_agent_module.FinalAnswerError() - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_step), \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream), \ patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ patch.object(core_agent_instance, '_finalize_step'): results = list(core_agent_instance._run_stream(task, max_steps)) # _run should yield action step + final answer step assert len(results) == 2 - assert isinstance(results[0], MagicMock) # Action step - assert isinstance(results[1], MagicMock) # Final answer step + assert isinstance(results[0], MockActionStep) # Action step + assert isinstance(results[1], MockFinalAnswerStep) # Final answer step # Verify convert_code_format was called mock_convert.assert_called_once_with( "```\nprint('hello')\n```") +def test_run_stream_validates_final_answer_when_checks_enabled(core_agent_instance): + """Ensure _run_stream triggers final answer validation when checks are configured.""" + task = "validate task" + core_agent_instance.final_answer_checks = ["non-empty"] + core_agent_instance._validate_final_answer = MagicMock() + + def mock_step_stream(action_step): + yield MockActionOutput(output="final answer", is_final_answer=True) + + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream), \ + patch.object(core_agent_instance, '_finalize_step'): + result = list(core_agent_instance._run_stream(task, max_steps=1)) + + assert len(result) == 3 # ActionOutput, ActionStep, FinalAnswerStep + core_agent_instance._validate_final_answer.assert_called_once_with("final answer") def test_convert_code_format_display_replacements(): """Validate convert_code_format correctly transforms format to standard markdown.""" @@ -575,6 +723,10 @@ def test_step_stream_parse_success(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -586,7 +738,7 @@ def test_step_stream_parse_success(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("output", "logs", False)) + return_value=MockCodeOutput(output="output", logs="logs", is_final_answer=False)) # Execute list(core_agent_instance._step_stream(mock_memory_step)) @@ -599,6 +751,33 @@ def test_step_stream_parse_success(core_agent_instance): assert hasattr(mock_memory_step.tool_calls[0], 'arguments') +def test_step_stream_structured_outputs_with_stop_sequence(core_agent_instance): + """Ensure _step_stream handles structured outputs correctly.""" + mock_memory_step = MagicMock() + mock_chat_message = MagicMock() + mock_chat_message.content = json.dumps({"code": "print('hello')"}) + mock_chat_message.token_usage = MagicMock() + + core_agent_instance.agent_name = "test_agent" + core_agent_instance.step_number = 1 + core_agent_instance._use_structured_outputs_internally = True + core_agent_instance.code_block_tags = ["<>", "[CLOSE]"] + core_agent_instance.write_memory_to_messages = MagicMock(return_value=[]) + core_agent_instance.model = MagicMock(return_value=mock_chat_message) + core_agent_instance.python_executor = MagicMock( + return_value=MockCodeOutput(output="result", logs="", is_final_answer=False) + ) + + with patch.object(core_agent_module, 'extract_code_from_text', return_value="print('hello')") as mock_extract, \ + patch.object(core_agent_module, 'fix_final_answer_code', side_effect=lambda code: code): + list(core_agent_instance._step_stream(mock_memory_step)) + + # Ensure structured output helpers were used + mock_extract.assert_called_once_with("print('hello')", core_agent_instance.code_block_tags) + call_kwargs = core_agent_instance.model.call_args.kwargs + assert call_kwargs["response_format"] == core_agent_module.CODEAGENT_RESPONSE_FORMAT + + def test_step_stream_skips_execution_for_display_only(core_agent_instance): """Test that _step_stream raises FinalAnswerError when only DISPLAY code blocks are present.""" # Setup @@ -611,6 +790,10 @@ def test_step_stream_skips_execution_for_display_only(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -637,6 +820,10 @@ def test_step_stream_parse_failure_raises_final_answer_error(core_agent_instance core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -662,6 +849,10 @@ def test_step_stream_model_generation_error(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -687,6 +878,10 @@ def test_step_stream_execution_success(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -698,14 +893,16 @@ def test_step_stream_execution_success(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("Hello World", "Execution logs", False)) + return_value=MockCodeOutput(output="Hello World", logs="Execution logs", is_final_answer=False)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - # Should yield None when is_final_answer is False - assert result[0] is None + # Should yield ActionOutput when is_final_answer is False + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is False assert mock_memory_step.observations is not None # Check that observations was set (we can't easily test the exact content due to mock behavior) assert hasattr(mock_memory_step, 'observations') @@ -723,6 +920,10 @@ def test_step_stream_execution_final_answer(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -734,13 +935,16 @@ def test_step_stream_execution_final_answer(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("final answer", "Execution logs", True)) + return_value=MockCodeOutput(output="final answer", logs="Execution logs", is_final_answer=True)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - assert result[0] == "final answer" # Should yield the final answer + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is True + assert result[0].output == "final answer" def test_step_stream_execution_error(core_agent_instance): @@ -755,6 +959,10 @@ def test_step_stream_execution_error(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -795,6 +1003,10 @@ def test_step_stream_observer_calls(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -806,7 +1018,7 @@ def test_step_stream_observer_calls(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("test", "logs", False)) + return_value=MockCodeOutput(output="test", logs="logs", is_final_answer=False)) # Execute list(core_agent_instance._step_stream(mock_memory_step)) @@ -847,6 +1059,10 @@ def test_step_stream_execution_with_logs(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -859,14 +1075,16 @@ def test_step_stream_execution_with_logs(core_agent_instance): core_agent_instance.model = MagicMock(return_value=mock_chat_message) # Mock python_executor to return logs core_agent_instance.python_executor = MagicMock( - return_value=("output", "Some execution logs", False)) + return_value=MockCodeOutput(output="output", logs="Some execution logs", is_final_answer=False)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - # Should yield None when is_final_answer is False - assert result[0] is None + # Should yield ActionOutput when is_final_answer is False + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is False # Check that execution logs were recorded assert core_agent_instance.observer.add_message.call_count >= 3 calls = core_agent_instance.observer.add_message.call_args_list @@ -887,6 +1105,10 @@ def test_step_stream_execution_error_with_print_outputs(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -926,6 +1148,10 @@ def test_step_stream_execution_error_with_import_warning(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -969,6 +1195,10 @@ def test_step_stream_execution_error_without_print_outputs(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -1003,6 +1233,10 @@ def test_step_stream_execution_with_none_output(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -1015,14 +1249,16 @@ def test_step_stream_execution_with_none_output(core_agent_instance): core_agent_instance.model = MagicMock(return_value=mock_chat_message) # Mock python_executor to return None output core_agent_instance.python_executor = MagicMock( - return_value=(None, "Execution logs", False)) + return_value=MockCodeOutput(output=None, logs="Execution logs", is_final_answer=False)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - # Should yield None when is_final_answer is False - assert result[0] is None + # Should yield ActionOutput when is_final_answer is False + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is False assert mock_memory_step.observations is not None # Check that observations was set but should not contain "Last output from code snippet" # since output is None @@ -1050,6 +1286,10 @@ def test_run_with_additional_args(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1059,8 +1299,7 @@ def test_run_with_additional_args(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1089,6 +1328,10 @@ def test_run_with_stream_true(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1123,6 +1366,10 @@ def test_run_with_reset_false(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1132,8 +1379,7 @@ def test_run_with_reset_false(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1162,6 +1408,10 @@ def test_run_with_images(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1171,8 +1421,7 @@ def test_run_with_images(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1185,8 +1434,89 @@ def test_run_with_images(core_agent_instance): call_args = core_agent_instance.memory.steps.append.call_args[0][0] # The TaskStep is mocked, so just verify it was called with correct arguments via the constructor # We'll check that TaskStep was called with the right parameters - mock_smolagents.memory.TaskStep.assert_called_with( - task=task, task_images=images) + assert isinstance(call_args, MockTaskStep) + assert call_args.task == task + assert call_args.task_images == images + + +def test_run_return_full_result_success_state(core_agent_instance): + """run should return RunResult with aggregated token usage when requested.""" + task = "test task" + token_usage = MagicMock(input_tokens=7, output_tokens=3) + action_step = core_agent_module.ActionStep() + action_step.token_usage = token_usage + + core_agent_instance.name = "test_agent" + core_agent_instance.memory.steps = [action_step] + core_agent_instance.memory.get_full_steps = MagicMock(return_value=[{"step": "data"}]) + core_agent_instance.memory.reset = MagicMock() + core_agent_instance.monitor.reset = MagicMock() + core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() + core_agent_instance.model = MagicMock() + core_agent_instance.model.model_id = "model" + core_agent_instance.python_executor = MagicMock() + core_agent_instance.python_executor.send_variables = MagicMock() + core_agent_instance.python_executor.send_tools = MagicMock() + core_agent_instance.observer = MagicMock() + + final_step = MockFinalAnswerStep(output="final result") + with patch.object(core_agent_instance, '_run_stream', return_value=[final_step]): + result = core_agent_instance.run(task, return_full_result=True) + + assert isinstance(result, core_agent_module.RunResult) + assert result.output == "final result" + core_agent_module.TokenUsage.assert_called_once_with(input_tokens=7, output_tokens=3) + assert result.token_usage == core_agent_module.TokenUsage.return_value + assert result.state == "success" + core_agent_instance.memory.get_full_steps.assert_called_once() + + +def test_run_return_full_result_max_steps_error(core_agent_instance): + """run should mark state as max_steps_error when the last step contains AgentMaxStepsError.""" + task = "test task" + + action_step = core_agent_module.ActionStep() + action_step.token_usage = None + action_step.error = core_agent_module.AgentMaxStepsError("max steps reached") + + class StepsList(list): + def append(self, item): + # Skip storing TaskStep to keep action_step as the last element + if isinstance(item, core_agent_module.TaskStep): + return + super().append(item) + + core_agent_instance.name = "test_agent" + steps_list = StepsList([action_step]) + core_agent_instance.memory.steps = steps_list + core_agent_instance.memory.get_full_steps = MagicMock(return_value=[{"step": "data"}]) + core_agent_instance.memory.reset = MagicMock() + core_agent_instance.monitor.reset = MagicMock() + core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() + core_agent_instance.model = MagicMock() + core_agent_instance.model.model_id = "model" + core_agent_instance.python_executor = MagicMock() + core_agent_instance.python_executor.send_variables = MagicMock() + core_agent_instance.python_executor.send_tools = MagicMock() + core_agent_instance.observer = MagicMock() + + final_step = MockFinalAnswerStep(output="final result") + with patch.object(core_agent_instance, '_run_stream', return_value=[final_step]): + result = core_agent_instance.run(task, return_full_result=True) + + assert isinstance(result, core_agent_module.RunResult) + assert result.token_usage is None + core_agent_module.TokenUsage.assert_not_called() + assert result.state == "max_steps_error" + core_agent_instance.memory.get_full_steps.assert_called_once() def test_run_without_python_executor(core_agent_instance): @@ -1204,6 +1534,10 @@ def test_run_without_python_executor(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1213,8 +1547,7 @@ def test_run_without_python_executor(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1267,6 +1600,31 @@ def test_call_method_success(core_agent_instance): "test_agent", ProcessType.AGENT_FINISH, "test result") +def test_call_method_with_run_result_return(core_agent_instance): + """Test __call__ handles RunResult by extracting its output.""" + task = "test task" + core_agent_instance.name = "test_agent" + core_agent_instance.state = {} + core_agent_instance.prompt_templates = { + "managed_agent": { + "task": "Task: {{task}}", + "report": "Report: {{final_answer}}" + } + } + core_agent_instance.provide_run_summary = False + core_agent_instance.observer = MagicMock() + + run_result = core_agent_module.RunResult(output="run result", token_usage=None, steps=[], timing=None, state="success") + with patch.object(core_agent_instance, 'run', return_value=run_result) as mock_run: + result = core_agent_instance(task) + + assert "Report: run result" in result + mock_run.assert_called_once() + core_agent_instance.observer.add_message.assert_called_with( + "test_agent", ProcessType.AGENT_FINISH, "run result" + ) + + def test_call_method_with_run_summary(core_agent_instance): """Test __call__ method with provide_run_summary=True.""" # Setup @@ -1284,10 +1642,14 @@ def test_call_method_with_run_summary(core_agent_instance): core_agent_instance.provide_run_summary = True core_agent_instance.observer = MagicMock() - # Mock write_memory_to_messages to return some simple messages + # Mock write_memory_to_messages to return some simple messages with .content attribute + class MockMessage: + def __init__(self, content): + self.content = content + mock_messages = [ - {"content": "msg1"}, - {"content": "msg2"} + MockMessage("msg1"), + MockMessage("msg2") ] core_agent_instance.write_memory_to_messages = MagicMock( return_value=mock_messages) diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 3dc831323..2a842ea72 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -27,11 +27,16 @@ class _ActionStep: - pass + def __init__(self, step_number=None, timing=None, action_output=None, model_output=None): + self.step_number = step_number + self.timing = timing + self.action_output = action_output + self.model_output = model_output class _TaskStep: - pass + def __init__(self, task=None): + self.task = task class _AgentText: @@ -214,6 +219,8 @@ class _MockToolSign: "nexent.storage": mock_nexent_storage_module, "nexent.multi_modal": mock_nexent_multi_modal_module, "nexent.multi_modal.load_save_object": mock_nexent_load_save_module, + # Mock tiktoken to avoid importing the real package when models import it + "tiktoken": MagicMock(), # Mock the OpenAIModel import "sdk.nexent.core.models.openai_llm": MagicMock(OpenAIModel=mock_openai_model_class), # Mock CoreAgent import @@ -230,7 +237,7 @@ class _MockToolSign: from sdk.nexent.core.utils.observer import MessageObserver, ProcessType from sdk.nexent.core.agents import nexent_agent from sdk.nexent.core.agents.nexent_agent import NexentAgent, ActionStep, TaskStep - from sdk.nexent.core.agents.agent_model import ToolConfig, ModelConfig, AgentConfig + from sdk.nexent.core.agents.agent_model import ToolConfig, ModelConfig, AgentConfig, AgentHistory # ---------------------------------------------------------------------------- @@ -1087,6 +1094,48 @@ def test_add_history_to_agent_none_history(nexent_agent_instance, mock_core_agen assert len(mock_core_agent.memory.steps) == 0 +def test_add_history_to_agent_user_and_assistant_history(nexent_agent_instance, mock_core_agent): + """Test add_history_to_agent correctly converts user and assistant messages to memory steps.""" + nexent_agent_instance.agent = mock_core_agent + + user_msg = AgentHistory(role="user", content="User question") + assistant_msg = AgentHistory(role="assistant", content="Assistant reply") + + nexent_agent_instance.add_history_to_agent([user_msg, assistant_msg]) + + mock_core_agent.memory.reset.assert_called_once() + assert len(mock_core_agent.memory.steps) == 2 + + # First step should be a TaskStep for the user message + first_step = mock_core_agent.memory.steps[0] + assert isinstance(first_step, TaskStep) + assert first_step.task == "User question" + + # Second step should be an ActionStep for the assistant message + second_step = mock_core_agent.memory.steps[1] + assert isinstance(second_step, ActionStep) + assert second_step.action_output == "Assistant reply" + assert second_step.model_output == "Assistant reply" + + +def test_add_history_to_agent_invalid_agent_type(nexent_agent_instance): + """Test add_history_to_agent raises TypeError when agent is not a CoreAgent.""" + nexent_agent_instance.agent = "not_core_agent" + + with pytest.raises(TypeError, match="agent must be a CoreAgent object"): + nexent_agent_instance.add_history_to_agent([]) + + +def test_add_history_to_agent_invalid_history_items(nexent_agent_instance, mock_core_agent): + """Test add_history_to_agent raises TypeError when history items are not AgentHistory.""" + nexent_agent_instance.agent = mock_core_agent + + invalid_history = [{"role": "user", "content": "hello"}] + + with pytest.raises(TypeError, match="history must be a list of AgentHistory objects"): + nexent_agent_instance.add_history_to_agent(invalid_history) + + def test_agent_run_with_observer_success_with_agent_text(nexent_agent_instance, mock_core_agent): """Test successful agent_run_with_observer with AgentText final answer.""" # Setup @@ -1103,7 +1152,7 @@ def test_agent_run_with_observer_success_with_agent_text(nexent_agent_instance, "Final answer with thinking content") mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = mock_final_answer + mock_core_agent.run.return_value[-1].output = mock_final_answer # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1129,7 +1178,7 @@ def test_agent_run_with_observer_success_with_string_final_answer(nexent_agent_i mock_action_step.error = None mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "String final answer with thinking" + mock_core_agent.run.return_value[-1].output = "String final answer with thinking" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1153,7 +1202,7 @@ def test_agent_run_with_observer_with_error_in_step(nexent_agent_instance, mock_ mock_action_step.error = "Test error occurred" mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1176,7 +1225,7 @@ def test_agent_run_with_observer_skips_non_action_step(nexent_agent_instance, mo mock_action_step.error = None mock_core_agent.run.return_value = [mock_task_step, mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1199,7 +1248,7 @@ def test_agent_run_with_observer_with_stop_event_set(nexent_agent_instance, mock mock_action_step.error = None mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1226,6 +1275,14 @@ def test_agent_run_with_observer_with_exception(nexent_agent_instance, mock_core ) +def test_agent_run_with_observer_invalid_agent_type(nexent_agent_instance): + """Test agent_run_with_observer raises TypeError when agent is not a CoreAgent.""" + nexent_agent_instance.agent = "not_core_agent" + + with pytest.raises(TypeError, match="agent must be a CoreAgent object"): + nexent_agent_instance.agent_run_with_observer("test query") + + def test_agent_run_with_observer_with_reset_false(nexent_agent_instance, mock_core_agent): """Test agent_run_with_observer with reset=False parameter.""" # Setup @@ -1238,7 +1295,7 @@ def test_agent_run_with_observer_with_reset_false(nexent_agent_instance, mock_co mock_action_step.error = None mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute with reset=False nexent_agent_instance.agent_run_with_observer("test query", reset=False) diff --git a/test/sdk/core/agents/test_run_agent.py b/test/sdk/core/agents/test_run_agent.py index 0cafdd8a1..b47aec879 100644 --- a/test/sdk/core/agents/test_run_agent.py +++ b/test/sdk/core/agents/test_run_agent.py @@ -49,7 +49,7 @@ def from_mcp(cls, *args, **kwargs): # pylint: disable=unused-argument sub_mod = ModuleType(f"smolagents.{_sub}") # Populate required attributes with MagicMocks to satisfy import-time `from smolagents. import ...`. if _sub == "agents": - for _name in ["CodeAgent", "populate_template", "handle_agent_output_types", "AgentError", "AgentType"]: + for _name in ["CodeAgent", "populate_template", "handle_agent_output_types", "AgentError", "AgentType", "ActionOutput", "RunResult"]: setattr(sub_mod, _name, MagicMock(name=f"smolagents.agents.{_name}")) elif _sub == "local_python_executor": setattr(sub_mod, "fix_final_answer_code", MagicMock(name="fix_final_answer_code")) @@ -59,6 +59,7 @@ def from_mcp(cls, *args, **kwargs): # pylint: disable=unused-argument elif _sub == "models": setattr(sub_mod, "ChatMessage", MagicMock(name="smolagents.models.ChatMessage")) setattr(sub_mod, "MessageRole", MagicMock(name="smolagents.models.MessageRole")) + setattr(sub_mod, "CODEAGENT_RESPONSE_FORMAT", MagicMock(name="smolagents.models.CODEAGENT_RESPONSE_FORMAT")) # Provide a simple base class so that OpenAIModel can inherit from it class _DummyOpenAIServerModel: def __init__(self, *args, **kwargs): @@ -67,13 +68,18 @@ def __init__(self, *args, **kwargs): setattr(sub_mod, "OpenAIServerModel", _DummyOpenAIServerModel) elif _sub == "monitoring": setattr(sub_mod, "LogLevel", MagicMock(name="smolagents.monitoring.LogLevel")) + setattr(sub_mod, "Timing", MagicMock(name="smolagents.monitoring.Timing")) + setattr(sub_mod, "YELLOW_HEX", MagicMock(name="smolagents.monitoring.YELLOW_HEX")) + setattr(sub_mod, "TokenUsage", MagicMock(name="smolagents.monitoring.TokenUsage")) elif _sub == "utils": for _name in [ "AgentExecutionError", "AgentGenerationError", "AgentParsingError", + "AgentMaxStepsError", "parse_code_blobs", "truncate_content", + "extract_code_from_text", ]: setattr(sub_mod, _name, MagicMock(name=f"smolagents.utils.{_name}")) setattr(mock_smolagents, _sub, sub_mod) @@ -82,6 +88,8 @@ def __init__(self, *args, **kwargs): # Top-level exports expected directly from `smolagents` by nexent_agent.py for _name in ["ActionStep", "TaskStep", "AgentText", "handle_agent_output_types"]: setattr(mock_smolagents, _name, MagicMock(name=f"smolagents.{_name}")) +# Export Timing from monitoring submodule to top-level +setattr(mock_smolagents, "Timing", mock_smolagents.monitoring.Timing) # Also export Tool at top-level so that `from smolagents import Tool` works setattr(mock_smolagents, "Tool", mock_smolagents_tool_cls) @@ -237,9 +245,9 @@ def test_agent_run_thread_local_flow(basic_agent_run_info, monkeypatch): def test_agent_run_thread_mcp_flow(basic_agent_run_info, mock_memory_context, monkeypatch): - """Verify behaviour when an MCP host list is provided.""" - # Give the AgentRunInfo an MCP host list - basic_agent_run_info.mcp_host = ["http://mcp.server"] + """Verify behaviour when an MCP host list is provided with auto-detected transport.""" + # Give the AgentRunInfo an MCP host list (string format, auto-detect transport) + basic_agent_run_info.mcp_host = ["http://mcp.server/mcp"] # Prepare ToolCollection.from_mcp to return a context manager mock_tool_collection = MagicMock(name="ToolCollectionInstance") @@ -257,7 +265,7 @@ def test_agent_run_thread_mcp_flow(basic_agent_run_info, mock_memory_context, mo basic_agent_run_info.observer.add_message.assert_any_call("", ProcessType.AGENT_NEW_RUN, "") # ToolCollection.from_mcp should be called with the expected client list and trust_remote_code=True - expected_client_list = [{"url": "http://mcp.server"}] + expected_client_list = [{"url": "http://mcp.server/mcp", "transport": "streamable-http"}] run_agent.ToolCollection.from_mcp.assert_called_once_with(expected_client_list, trust_remote_code=True) # NexentAgent should be instantiated with mcp_tool_collection @@ -275,6 +283,116 @@ def test_agent_run_thread_mcp_flow(basic_agent_run_info, mock_memory_context, mo mock_nexent_instance.agent_run_with_observer.assert_called_once_with(query=basic_agent_run_info.query, reset=False) +def test_agent_run_thread_mcp_flow_with_explicit_transport(basic_agent_run_info, mock_memory_context, monkeypatch): + """Verify behaviour when MCP host is provided with explicit transport in dict format.""" + # Give the AgentRunInfo an MCP host list with explicit transport + basic_agent_run_info.mcp_host = [{"url": "http://mcp.server", "transport": "sse"}] + + # Prepare ToolCollection.from_mcp to return a context manager + mock_tool_collection = MagicMock(name="ToolCollectionInstance") + mock_context_manager = MagicMock(__enter__=MagicMock(return_value=mock_tool_collection), __exit__=MagicMock(return_value=None)) + monkeypatch.setattr(run_agent.ToolCollection, "from_mcp", MagicMock(return_value=mock_context_manager)) + + # Patch NexentAgent + mock_nexent_instance = MagicMock(name="NexentAgentInstance") + monkeypatch.setattr(run_agent, "NexentAgent", MagicMock(return_value=mock_nexent_instance)) + + # Execute + run_agent.agent_run_thread(basic_agent_run_info) + + # ToolCollection.from_mcp should be called with the expected client list + expected_client_list = [{"url": "http://mcp.server", "transport": "sse"}] + run_agent.ToolCollection.from_mcp.assert_called_once_with(expected_client_list, trust_remote_code=True) + + +def test_agent_run_thread_mcp_flow_mixed_formats(basic_agent_run_info, mock_memory_context, monkeypatch): + """Verify behaviour when MCP host list contains both string and dict formats.""" + # Mix of string (auto-detect) and dict (explicit) formats + basic_agent_run_info.mcp_host = [ + "http://mcp1.server/mcp", # Auto-detect: streamable-http + "http://mcp2.server/sse", # Auto-detect: sse + {"url": "http://mcp3.server/mcp", "transport": "streamable-http"}, # Explicit: streamable-http + ] + + # Prepare ToolCollection.from_mcp to return a context manager + mock_tool_collection = MagicMock(name="ToolCollectionInstance") + mock_context_manager = MagicMock(__enter__=MagicMock(return_value=mock_tool_collection), __exit__=MagicMock(return_value=None)) + monkeypatch.setattr(run_agent.ToolCollection, "from_mcp", MagicMock(return_value=mock_context_manager)) + + # Patch NexentAgent + mock_nexent_instance = MagicMock(name="NexentAgentInstance") + monkeypatch.setattr(run_agent, "NexentAgent", MagicMock(return_value=mock_nexent_instance)) + + # Execute + run_agent.agent_run_thread(basic_agent_run_info) + + # ToolCollection.from_mcp should be called with normalized client list + expected_client_list = [ + {"url": "http://mcp1.server/mcp", "transport": "streamable-http"}, + {"url": "http://mcp2.server/sse", "transport": "sse"}, + {"url": "http://mcp3.server/mcp", "transport": "streamable-http"}, + ] + run_agent.ToolCollection.from_mcp.assert_called_once_with(expected_client_list, trust_remote_code=True) + + +def test_detect_transport(): + """Test transport auto-detection logic based on URL ending.""" + # Test URLs ending with /sse + assert run_agent._detect_transport("http://server/sse") == "sse" + assert run_agent._detect_transport("https://api.example.com/sse") == "sse" + assert run_agent._detect_transport("http://localhost:3000/sse") == "sse" + + # Test URLs ending with /mcp + assert run_agent._detect_transport("http://server/mcp") == "streamable-http" + assert run_agent._detect_transport("https://api.example.com/mcp") == "streamable-http" + assert run_agent._detect_transport("http://localhost:3000/mcp") == "streamable-http" + + # Test default fallback (no /sse or /mcp ending) + assert run_agent._detect_transport("http://server") == "streamable-http" + assert run_agent._detect_transport("https://api.example.com") == "streamable-http" + assert run_agent._detect_transport("http://server/other") == "streamable-http" + + +def test_normalize_mcp_config(): + """Test MCP configuration normalization.""" + # Test string format (auto-detect based on URL ending) + result = run_agent._normalize_mcp_config("http://server/mcp") + assert result == {"url": "http://server/mcp", "transport": "streamable-http"} + + result = run_agent._normalize_mcp_config("http://server/sse") + assert result == {"url": "http://server/sse", "transport": "sse"} + + # Test string format without /sse or /mcp ending (defaults to streamable-http) + result = run_agent._normalize_mcp_config("http://server") + assert result == {"url": "http://server", "transport": "streamable-http"} + + # Test dict format with explicit transport + result = run_agent._normalize_mcp_config({"url": "http://server/mcp", "transport": "sse"}) + assert result == {"url": "http://server/mcp", "transport": "sse"} + + # Test dict format without transport (auto-detect) + result = run_agent._normalize_mcp_config({"url": "http://server/sse"}) + assert result == {"url": "http://server/sse", "transport": "sse"} + + result = run_agent._normalize_mcp_config({"url": "http://server/mcp"}) + assert result == {"url": "http://server/mcp", "transport": "streamable-http"} + + # Test invalid dict (missing url) + with pytest.raises(ValueError, match="must contain 'url' key"): + run_agent._normalize_mcp_config({"transport": "sse"}) + + # Test invalid transport type + with pytest.raises(ValueError, match="Invalid transport type"): + run_agent._normalize_mcp_config({"url": "http://server/mcp", "transport": "stdio"}) + + with pytest.raises(ValueError, match="Invalid transport type"): + run_agent._normalize_mcp_config({"url": "http://server/mcp", "transport": "invalid"}) + + # Test invalid type + with pytest.raises(ValueError, match="Invalid MCP host item type"): + run_agent._normalize_mcp_config(123) + + def test_agent_run_thread_handles_internal_exception(basic_agent_run_info, mock_memory_context, monkeypatch): """If an internal error occurs, the observer should be notified and a ValueError propagated.""" # Configure NexentAgent.create_single_agent to raise an exception diff --git a/test/sdk/core/tools/test_datamate_search_tool.py b/test/sdk/core/tools/test_datamate_search_tool.py index cc2742796..ebfdb3bba 100644 --- a/test/sdk/core/tools/test_datamate_search_tool.py +++ b/test/sdk/core/tools/test_datamate_search_tool.py @@ -117,7 +117,7 @@ def test_extract_dataset_id(self, datamate_tool: DataMateSearchTool, path, expec @pytest.mark.parametrize( "dataset_id, file_id, expected", [ - ("ds1", "f1", "127.0.0.1/api/data-management/datasets/ds1/files/f1/download"), + ("ds1", "f1", "http://127.0.0.1:8080/api/data-management/datasets/ds1/files/f1/download"), ("", "f1", ""), ("ds1", "", ""), ], diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index 535af6b35..f6cdc4577 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -37,7 +37,8 @@ def knowledge_base_search_tool(mock_observer, mock_vdb_core, mock_embedding_mode index_names=["test_index1", "test_index2"], observer=mock_observer, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) return tool @@ -50,7 +51,8 @@ def knowledge_base_search_tool_no_observer(mock_vdb_core, mock_embedding_model): index_names=["test_index"], observer=None, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) return tool @@ -78,6 +80,49 @@ def create_mock_search_result(count=3): class TestKnowledgeBaseSearchTool: """Test KnowledgeBaseSearchTool functionality""" + def test_update_name_resolver_supports_empty_mapping(self, knowledge_base_search_tool): + """Ensure update_name_resolver replaces mapping and handles falsy input""" + knowledge_base_search_tool.update_name_resolver({"kb": "index_kb"}) + assert knowledge_base_search_tool.name_resolver == {"kb": "index_kb"} + + knowledge_base_search_tool.update_name_resolver(None) + assert knowledge_base_search_tool.name_resolver == {} + + def test_resolve_names_without_resolver_logs_warning(self, knowledge_base_search_tool, mocker): + """When no resolver is configured, names are returned unchanged and warning is logged""" + warning_mock = mocker.patch("sdk.nexent.core.tools.knowledge_base_search_tool.logger.warning") + + names = knowledge_base_search_tool._resolve_names(["kb1", "kb2"]) + + assert names == ["kb1", "kb2"] + warning_mock.assert_called_once() + + @pytest.mark.parametrize( + "incoming,expected", + [ + (None, []), + ("single_index", ["single_index"]), + (["a", "b"], ["a", "b"]), + ], + ) + def test_normalize_index_names_variants(self, knowledge_base_search_tool_no_observer, incoming, expected): + """_normalize_index_names should normalize None, string, and list inputs""" + assert knowledge_base_search_tool_no_observer._normalize_index_names(incoming) == expected + + def test_forward_with_observer_adds_messages(self, knowledge_base_search_tool): + """forward should send TOOL and CARD messages when observer is present""" + mock_results = create_mock_search_result(1) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + knowledge_base_search_tool.forward("hello world") + + knowledge_base_search_tool.observer.add_message.assert_any_call( + "", ProcessType.TOOL, "Searching the knowledge base..." + ) + knowledge_base_search_tool.observer.add_message.assert_any_call( + "", ProcessType.CARD, json.dumps([{"icon": "search", "text": "hello world"}], ensure_ascii=False) + ) + def test_init_with_custom_values(self, mock_observer, mock_vdb_core, mock_embedding_model): """Test initialization with custom values""" tool = KnowledgeBaseSearchTool( @@ -85,7 +130,8 @@ def test_init_with_custom_values(self, mock_observer, mock_vdb_core, mock_embedd index_names=["index1", "index2", "index3"], observer=mock_observer, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) assert tool.top_k == 10 @@ -101,7 +147,8 @@ def test_init_with_none_index_names(self, mock_vdb_core, mock_embedding_model): index_names=None, observer=None, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) assert tool.index_names == [] diff --git a/test/sdk/vector_database/test_elasticsearch_core.py b/test/sdk/vector_database/test_elasticsearch_core.py index 30f8ff277..f9f878852 100644 --- a/test/sdk/vector_database/test_elasticsearch_core.py +++ b/test/sdk/vector_database/test_elasticsearch_core.py @@ -522,6 +522,47 @@ def test_vectorize_documents_small_batch(elasticsearch_core_instance): mock_embedding_model.get_embeddings.assert_called_once() mock_bulk.assert_called_once() +def test_small_batch_progress_callback_exception(elasticsearch_core_instance, caplog): + """Progress callback errors should be logged without failing the insert.""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1] * 3] + mock_embedding_model.embedding_model_name = "m" + + documents = [{"content": "a"}] + + def bad_progress(_, __): + raise RuntimeError("boom") + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch("time.strftime", lambda *a, **k: "2025-01-15T10:30:00"), \ + patch("time.time", lambda: 1642234567): + mock_bulk.return_value = {"errors": False, "items": []} + result = elasticsearch_core_instance._small_batch_insert( + "idx", documents, "content", mock_embedding_model, progress_callback=bad_progress + ) + + assert result == 1 + assert any("Progress callback failed in small batch" in m for m in caplog.messages) + +def test_small_batch_error_path_logs_and_raises(elasticsearch_core_instance, caplog): + """Small batch should log errors and re-raise when bulk fails.""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1] * 3] + mock_embedding_model.embedding_model_name = "m" + + documents = [{"content": "x"}] + + with patch.object(elasticsearch_core_instance, "client") as mock_client, \ + patch("time.strftime", lambda *a, **k: "2025-01-15T10:30:00"), \ + patch("time.time", lambda: 1642234567): + mock_client.bulk.side_effect = RuntimeError("bulk boom") + with pytest.raises(RuntimeError): + elasticsearch_core_instance._small_batch_insert( + "idx", documents, "content", mock_embedding_model + ) + + assert any("Small batch insert failed: bulk boom" in m for m in caplog.messages) + def test_vectorize_documents_large_batch(elasticsearch_core_instance): """Test indexing a large batch of documents (>= 64).""" @@ -558,6 +599,76 @@ def test_vectorize_documents_large_batch(elasticsearch_core_instance): mock_bulk.assert_called() mock_refresh.assert_called_once_with("test_index") +def test_large_batch_progress_callback_invoked(elasticsearch_core_instance): + """Progress callback should be triggered during embedding phase.""" + mock_embedding_model = MagicMock() + mock_embedding_model.embedding_model_name = "test-model" + mock_embedding_model.get_embeddings.return_value = [[0.1], [0.2]] + + docs = [{"content": "a"}, {"content": "b"}] + progress_calls = [] + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch.object(elasticsearch_core_instance, "_force_refresh_with_retry"): + mock_bulk.return_value = {"errors": False, "items": []} + elasticsearch_core_instance._large_batch_insert( + "idx", docs, batch_size=5, content_field="content", + embedding_model=mock_embedding_model, embedding_batch_size=2, + progress_callback=lambda done, total: progress_calls.append((done, total)) + ) + + assert progress_calls == [(2, 2)] + +def test_large_batch_progress_callback_exception_logged(elasticsearch_core_instance, caplog): + """Embedding progress callback errors should be logged and not stop indexing.""" + mock_embedding_model = MagicMock() + mock_embedding_model.embedding_model_name = "test-model" + mock_embedding_model.get_embeddings.return_value = [[0.1]] + + docs = [{"content": "a"}] + + def bad_progress(_, __): + raise RuntimeError("cb fail") + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch.object(elasticsearch_core_instance, "_force_refresh_with_retry"): + mock_bulk.return_value = {"errors": False, "items": []} + elasticsearch_core_instance._large_batch_insert( + "idx", docs, batch_size=1, content_field="content", + embedding_model=mock_embedding_model, embedding_batch_size=1, + progress_callback=bad_progress + ) + + assert any("Progress callback failed during embedding" in m for m in caplog.messages) + +def test_large_batch_retry_logs_warning(elasticsearch_core_instance, caplog): + """Embedding retries should emit warnings before succeeding.""" + mock_embedding_model = MagicMock() + mock_embedding_model.embedding_model_name = "test-model" + call_counter = {"n": 0} + + def get_embeddings(_): + call_counter["n"] += 1 + if call_counter["n"] < 3: + raise RuntimeError("embed fail") + return [[0.1]] + + mock_embedding_model.get_embeddings.side_effect = get_embeddings + + docs = [{"content": "a"}] + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch.object(elasticsearch_core_instance, "_force_refresh_with_retry"), \ + patch("time.sleep", lambda *a, **k: None): + mock_bulk.return_value = {"errors": False, "items": []} + elasticsearch_core_instance._large_batch_insert( + "idx", docs, batch_size=1, content_field="content", + embedding_model=mock_embedding_model, embedding_batch_size=1, + ) + + assert call_counter["n"] == 3 + assert any("Embedding API error (attempt 1/3)" in m for m in caplog.messages) + def test_delete_documents_success(elasticsearch_core_instance): """Test deleting documents by path_or_url successfully.""" @@ -1134,8 +1245,12 @@ def test_handle_bulk_errors_with_errors(elasticsearch_core_instance): ] } - # Should not raise exception, just log errors - elasticsearch_core_instance._handle_bulk_errors(response) + with pytest.raises(Exception) as exc_info: + elasticsearch_core_instance._handle_bulk_errors(response) + + err_payload = str(exc_info.value) + assert "Bulk indexing failed: Failed to parse mapping" in err_payload + assert "es_bulk_failed" in err_payload def test_handle_bulk_errors_version_conflict(elasticsearch_core_instance): @@ -1158,6 +1273,40 @@ def test_handle_bulk_errors_version_conflict(elasticsearch_core_instance): elasticsearch_core_instance._handle_bulk_errors(response) +def test_handle_bulk_errors_skips_items_without_error(elasticsearch_core_instance): + """Items without error key should be ignored.""" + response = { + "errors": True, + "items": [{"index": {}}], + } + # Should not raise + elasticsearch_core_instance._handle_bulk_errors(response) + + +def test_handle_bulk_errors_dim_mismatch_sets_specific_code(elasticsearch_core_instance): + """Dense vector dimension mismatch should produce es_dim_mismatch code.""" + response = { + "errors": True, + "items": [ + { + "index": { + "error": { + "type": "illegal_argument_exception", + "reason": "field [embedding] has different number of dimensions than vector", + "caused_by": {"reason": "dense_vector different number of dimensions"}, + } + } + } + ], + } + + with pytest.raises(Exception) as exc_info: + elasticsearch_core_instance._handle_bulk_errors(response) + + payload = str(exc_info.value) + assert "es_dim_mismatch" in payload + assert "Bulk indexing failed" in payload + def test_bulk_operation_context(elasticsearch_core_instance): """Test bulk operation context manager.""" with patch.object(elasticsearch_core_instance, '_apply_bulk_settings') as mock_apply, \ diff --git a/test/sdk/vector_database/test_elasticsearch_core_coverage.py b/test/sdk/vector_database/test_elasticsearch_core_coverage.py index f307c9d84..757bbc566 100644 --- a/test/sdk/vector_database/test_elasticsearch_core_coverage.py +++ b/test/sdk/vector_database/test_elasticsearch_core_coverage.py @@ -215,8 +215,9 @@ def test_handle_bulk_errors_with_fatal_error(self, vdb_core): } ] } - vdb_core._handle_bulk_errors(response) - # Should log error but not raise exception + with pytest.raises(Exception) as exc_info: + vdb_core._handle_bulk_errors(response) + assert "Bulk indexing failed" in str(exc_info.value) def test_handle_bulk_errors_with_caused_by(self, vdb_core): """Test _handle_bulk_errors with caused_by information""" @@ -237,8 +238,10 @@ def test_handle_bulk_errors_with_caused_by(self, vdb_core): } ] } - vdb_core._handle_bulk_errors(response) - # Should log both main error and caused_by error + with pytest.raises(Exception) as exc_info: + vdb_core._handle_bulk_errors(response) + assert "Invalid argument" in str(exc_info.value) + assert "JSON parsing failed" in str(exc_info.value) def test_delete_documents_success(self, vdb_core): """Test delete_documents successful case""" @@ -407,16 +410,18 @@ def test_large_batch_insert_bulk_exception(self, vdb_core): mock_embedding_model = MagicMock() mock_embedding_model.get_embeddings.return_value = [[0.1]] - result = vdb_core._large_batch_insert("idx", [{"content": "body"}], 1, "content", mock_embedding_model) - assert result == 0 + with pytest.raises(Exception) as exc_info: + vdb_core._large_batch_insert("idx", [{"content": "body"}], 1, "content", mock_embedding_model) + assert "bulk error" in str(exc_info.value) def test_large_batch_insert_preprocess_exception(self, vdb_core): """Ensure outer exception handler returns zero on preprocess failure.""" vdb_core._preprocess_documents = MagicMock(side_effect=Exception("fail")) mock_embedding_model = MagicMock() - result = vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", mock_embedding_model) - assert result == 0 + with pytest.raises(Exception) as exc_info: + vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", mock_embedding_model) + assert "fail" in str(exc_info.value) def test_count_documents_success(self, vdb_core): """Ensure count_documents returns ES count.""" @@ -672,8 +677,9 @@ def test_small_batch_insert_exception(self, vdb_core): mock_embedding_model = MagicMock() documents = [{"content": "test content", "title": "test"}] - result = vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) - assert result == 0 + with pytest.raises(Exception) as exc_info: + vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) + assert "Preprocess error" in str(exc_info.value) def test_large_batch_insert_success(self, vdb_core): """Test _large_batch_insert successful case"""
+ {t("market.install.rename.success", "All agent name conflicts have been resolved. You can proceed to the next step.")} +
+ {t("market.install.rename.partialSuccess", "Some agents have been successfully renamed.")} +
+ {t("market.install.rename.warning", "The agent name or display name conflicts with existing agents. Please rename to proceed.")} +
+ {t("market.install.rename.oneClickDesc", "You can manually edit the names, or click one-click rename to let the selected model regenerate names for all conflicted agents.")} +
+ {t("market.install.rename.note", "Note: If you proceed without renaming, the agent will be created but marked as unavailable due to name conflicts. You can rename it later in the agent list.")} +
+ {t("market.install.rename.agentResolved", "This agent's name conflict has been resolved.")} +
+ {t("market.install.rename.conflictAgents", "Conflicting agents:")} +
= ({ return ; }; + const ImageResolver: React.FC<{ src?: string; alt?: string | null }> = ({ + src, + alt, + }) => { + const resolvedSrc = useResolvedS3Media( + typeof src === "string" ? src : undefined, + resolveS3Media + ); + + if (!enableMultimodal) { + return renderMediaFallback(src, alt); + } + + if (!resolvedSrc) { + return renderMediaFallback(src, alt); + } + + if (isVideoUrl(resolvedSrc)) { + return renderVideoElement({ src: resolvedSrc, alt }); + } + + return ; + }; + // Modified processText function logic const processText = (text: string) => { if (typeof text !== "string") return text; @@ -865,37 +1186,7 @@ export const MarkdownRenderer: React.FC = ({ return ; } if (!inline) { - return ( - - - - {match[1]} - - - - - - {codeContent} - - - - ); + return ; } } } catch (error) { @@ -908,21 +1199,9 @@ export const MarkdownRenderer: React.FC = ({ ); }, // Image - img: ({ src, alt }: any) => { - if (!enableMultimodal) { - return renderMediaFallback(src, alt); - } - - if (isVideoUrl(src)) { - return renderVideoElement({ src, alt }); - } - - if (!src || typeof src !== "string") { - return null; - } - - return ; - }, + img: ({ src, alt }: any) => ( + + ), // Video video: ({ children, ...props }: any) => { const directSrc = props?.src; diff --git a/frontend/const/chatConfig.ts b/frontend/const/chatConfig.ts index df7b65c92..73cd19aed 100644 --- a/frontend/const/chatConfig.ts +++ b/frontend/const/chatConfig.ts @@ -111,6 +111,7 @@ messageTypes: { // Content type constants for last content type tracking contentTypes: { MODEL_OUTPUT: "model_output" as const, + MODEL_OUTPUT_CODE: "model_output_code" as const, PARSING: "parsing" as const, EXECUTION: "execution" as const, AGENT_NEW_RUN: "agent_new_run" as const, diff --git a/frontend/const/marketConfig.ts b/frontend/const/marketConfig.ts new file mode 100644 index 000000000..6de8d1f48 --- /dev/null +++ b/frontend/const/marketConfig.ts @@ -0,0 +1,36 @@ +// ========== Market Configuration Constants ========== + +/** + * Default icons for market agent categories + * Maps category name field to their corresponding icons + */ +export const MARKET_CATEGORY_ICONS: Record = { + research: "🔬", + content: "✍️", + development: "💻", + business: "📈", + automation: "⚙️", + education: "📚", + communication: "💬", + data: "📊", + creative: "🎨", + other: "📦", +} as const; + +/** + * Get icon for a category by name field + * @param categoryName - Category name field (e.g., "research", "content") + * @param fallbackIcon - Fallback icon if category not found (default: 📦) + * @returns Icon emoji string + */ +export function getCategoryIcon( + categoryName: string | null | undefined, + fallbackIcon: string = "📦" +): string { + if (!categoryName) { + return fallbackIcon; + } + + return MARKET_CATEGORY_ICONS[categoryName] || fallbackIcon; +} + diff --git a/frontend/hooks/useAgentImport.md b/frontend/hooks/useAgentImport.md deleted file mode 100644 index 52b14aa78..000000000 --- a/frontend/hooks/useAgentImport.md +++ /dev/null @@ -1,245 +0,0 @@ -# useAgentImport Hook - -Unified agent import hook for handling agent imports across the application. - -## Overview - -This hook provides a consistent interface for importing agents from different sources: -- File upload (used in Agent Development and Agent Space) -- Direct data (used in Agent Market) - -All import operations ultimately call the same backend `/agent/import` endpoint. - -## Usage - -### Basic Import - -```typescript -import { useAgentImport } from "@/hooks/useAgentImport"; - -function MyComponent() { - const { isImporting, importFromFile, importFromData, error } = useAgentImport({ - onSuccess: () => { - console.log("Import successful!"); - }, - onError: (error) => { - console.error("Import failed:", error); - }, - }); - - // ... -} -``` - -### Import from File (SubAgentPool, SpaceContent) - -```typescript -const handleFileImport = async (file: File) => { - try { - await importFromFile(file); - // Success handled by onSuccess callback - } catch (error) { - // Error handled by onError callback - } -}; - -// In file input handler - { - const file = e.target.files?.[0]; - if (file) { - handleFileImport(file); - } - }} -/> -``` - -### Import from Data (Market) - -```typescript -const handleMarketImport = async (agentDetails: MarketAgentDetail) => { - // Prepare import data from agent details - const importData = { - agent_id: agentDetails.agent_id, - agent_info: agentDetails.agent_json.agent_info, - mcp_info: agentDetails.agent_json.mcp_info, - }; - - try { - await importFromData(importData); - // Success handled by onSuccess callback - } catch (error) { - // Error handled by onError callback - } -}; -``` - -## Integration Examples - -### 1. SubAgentPool Component - -```typescript -// In SubAgentPool.tsx -import { useAgentImport } from "@/hooks/useAgentImport"; - -export default function SubAgentPool({ onImportSuccess }: Props) { - const { isImporting, importFromFile } = useAgentImport({ - onSuccess: () => { - message.success(t("agent.import.success")); - onImportSuccess?.(); - }, - onError: (error) => { - message.error(error.message); - }, - }); - - const handleImportClick = () => { - const input = document.createElement("input"); - input.type = "file"; - input.accept = ".json"; - input.onchange = async (e) => { - const file = (e.target as HTMLInputElement).files?.[0]; - if (file) { - await importFromFile(file); - } - }; - input.click(); - }; - - return ( - - {isImporting ? t("importing") : t("import")} - - ); -} -``` - -### 2. SpaceContent Component - -```typescript -// In SpaceContent.tsx -import { useAgentImport } from "@/hooks/useAgentImport"; - -export function SpaceContent({ onRefresh }: Props) { - const { isImporting, importFromFile } = useAgentImport({ - onSuccess: () => { - message.success(t("space.import.success")); - onRefresh(); // Reload agent list - }, - }); - - const handleImportAgent = () => { - const input = document.createElement("input"); - input.type = "file"; - input.accept = ".json"; - input.onchange = async (e) => { - const file = (e.target as HTMLInputElement).files?.[0]; - if (file) { - await importFromFile(file); - } - }; - input.click(); - }; - - return ( - - {isImporting ? "Importing..." : "Import Agent"} - - ); -} -``` - -### 3. AgentInstallModal (Market) - -```typescript -// In AgentInstallModal.tsx -import { useAgentImport } from "@/hooks/useAgentImport"; - -export default function AgentInstallModal({ - agentDetails, - onComplete -}: Props) { - const { isImporting, importFromData } = useAgentImport({ - onSuccess: () => { - message.success(t("market.install.success")); - onComplete(); - }, - }); - - const handleInstall = async () => { - // Prepare configured data - const importData = prepareImportData(agentDetails, userConfig); - await importFromData(importData); - }; - - return ( - - Install - - ); -} -``` - -## API Reference - -### Parameters - -```typescript -interface UseAgentImportOptions { - onSuccess?: () => void; // Called on successful import - onError?: (error: Error) => void; // Called on import error - forceImport?: boolean; // Force import even if duplicate names exist -} -``` - -### Return Value - -```typescript -interface UseAgentImportResult { - isImporting: boolean; // Import in progress - importFromFile: (file: File) => Promise; // Import from file - importFromData: (data: ImportAgentData) => Promise; // Import from data - error: Error | null; // Last error (if any) -} -``` - -### Data Structure - -```typescript -interface ImportAgentData { - agent_id: number; - agent_info: Record; - mcp_info?: Array<{ - mcp_server_name: string; - mcp_url: string; - }>; -} -``` - -## Error Handling - -The hook handles errors in two ways: - -1. **Via onError callback** - Preferred method for user-facing error messages -2. **Via thrown exceptions** - For custom error handling in specific cases - -Both approaches are supported to allow flexibility in different use cases. - -## Implementation Notes - -- File content is read as text and parsed as JSON -- Data structure validation is performed before calling the backend -- The backend `/agent/import` endpoint is called with the prepared data -- All logging uses the centralized `log` utility from `@/lib/logger` - diff --git a/frontend/hooks/useAgentImport.ts b/frontend/hooks/useAgentImport.ts index f0f33add4..0aff99e82 100644 --- a/frontend/hooks/useAgentImport.ts +++ b/frontend/hooks/useAgentImport.ts @@ -1,5 +1,9 @@ import { useState } from "react"; -import { importAgent } from "@/services/agentConfigService"; +import { + checkAgentNameConflictBatch, + importAgent, + regenerateAgentNameBatch, +} from "@/services/agentConfigService"; import log from "@/lib/logger"; export interface ImportAgentData { @@ -15,6 +19,19 @@ export interface UseAgentImportOptions { onSuccess?: () => void; onError?: (error: Error) => void; forceImport?: boolean; + /** + * Optional: handle name/display_name conflicts before import + * Caller can resolve by returning new name or choosing to continue/terminate + */ + onNameConflictResolve?: (payload: { + name: string; + displayName?: string; + conflictAgents: Array<{ id: string; name?: string; display_name?: string }>; + regenerateWithLLM: () => Promise<{ + name?: string; + displayName?: string; + }>; + }) => Promise<{ proceed: boolean; name?: string; displayName?: string }>; } export interface UseAgentImportResult { @@ -111,6 +128,30 @@ export function useAgentImport( * Core import logic - calls backend API */ const importAgentData = async (data: ImportAgentData): Promise => { + // Step 1: check name/display name conflicts before import (only check main agent name and display name) + const mainAgent = data.agent_info?.[String(data.agent_id)]; + if (mainAgent?.name) { + const conflictHandled = await ensureNameNotDuplicated( + mainAgent.name, + mainAgent.display_name, + mainAgent.description || mainAgent.business_description + ); + + if (!conflictHandled.proceed) { + throw new Error( + "Agent name/display name conflicts with existing agent; import cancelled." + ); + } + + // if user chooses to modify name, write back to import data + if (conflictHandled.name) { + mainAgent.name = conflictHandled.name; + } + if (conflictHandled.displayName) { + mainAgent.display_name = conflictHandled.displayName; + } + } + const result = await importAgent(data, { forceImport }); if (!result.success) { @@ -142,6 +183,80 @@ export function useAgentImport( }); }; + /** + * Frontend side name conflict validation logic + */ + const ensureNameNotDuplicated = async ( + name: string, + displayName?: string, + taskDescription?: string + ): Promise<{ proceed: boolean; name?: string; displayName?: string }> => { + try { + const checkResp = await checkAgentNameConflictBatch({ + items: [ + { + name, + display_name: displayName, + }, + ], + }); + if (!checkResp.success || !Array.isArray(checkResp.data)) { + log.warn("Skip name conflict check due to fetch failure"); + return { proceed: true }; + } + + const first = checkResp.data[0] || {}; + const { name_conflict, display_name_conflict, conflict_agents } = first; + + if (!name_conflict && !display_name_conflict) { + return { proceed: true }; + } + + const regenerateWithLLM = async () => { + const regenResp = await regenerateAgentNameBatch({ + items: [ + { + name, + display_name: displayName, + task_description: taskDescription, + }, + ], + }); + if (!regenResp.success || !Array.isArray(regenResp.data) || !regenResp.data[0]) { + throw new Error("Failed to regenerate agent name"); + } + const item = regenResp.data[0]; + return { + name: item.name, + displayName: item.display_name ?? displayName, + }; + }; + + // let caller decide how to handle conflicts (e.g. show a dialog to let user choose whether to let LLM rename) + if (options.onNameConflictResolve) { + return await options.onNameConflictResolve({ + name, + displayName, + conflictAgents: (conflict_agents || []).map((c: any) => ({ + id: String(c.agent_id ?? c.id), + name: c.name, + display_name: c.display_name, + })), + regenerateWithLLM, + }); + } + + // default behavior: directly call backend to rename to keep import available + const regenerated = await regenerateWithLLM(); + return { proceed: true, ...regenerated }; + } catch (error) { + // if callback throws an error, prevent import + throw error instanceof Error + ? error + : new Error("Name conflict handling failed"); + } + }; + return { isImporting, importFromFile, diff --git a/frontend/hooks/useMemory.ts b/frontend/hooks/useMemory.ts index 03ac72dd8..5bb1a1bc9 100644 --- a/frontend/hooks/useMemory.ts +++ b/frontend/hooks/useMemory.ts @@ -483,24 +483,3 @@ export function useMemory({ visible, currentUserId, currentTenantId, message }: handleDeleteMemory, } } - -// expose memory notification indicator to ChatHeader -export function useMemoryIndicator(modalVisible: boolean) { - const [hasNewMemory, setHasNewMemory] = useState(false) - - // Reset indicator when memory modal is opened - useEffect(() => { - if (modalVisible) { - setHasNewMemory(false) - } - }, [modalVisible]) - - // Listen for backend event that notifies new memory addition - useEffect(() => { - const handler = () => setHasNewMemory(true) - window.addEventListener("nexent:new-memory", handler as EventListener) - return () => window.removeEventListener("nexent:new-memory", handler as EventListener) - }, []) - - return hasNewMemory -} \ No newline at end of file diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 5ee25a7b8..b8681a78a 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -285,6 +285,8 @@ "agent.contextMenu.export": "Export", "agent.contextMenu.delete": "Delete", + "agent.contextMenu.copy": "Copy", + "agent.copySuffix": "Copy", "agent.info.title": "Agent Information", "agent.info.name.error.empty": "Name cannot be empty", "agent.info.name.error.format": "Name can only contain letters, numbers and underscores, and must start with a letter or underscore", @@ -293,6 +295,9 @@ "agent.namePlaceholder": "Please enter agent variable name", "agent.displayName": "Agent Name", "agent.displayNamePlaceholder": "Please enter agent name", + "agent.author": "Author", + "agent.authorPlaceholder": "Please enter author name (optional)", + "agent.author.hint": "Default: {{email}}", "agent.description": "Agent Description", "agent.descriptionPlaceholder": "Please enter agent description", "agent.detailContent.title": "Agent Detail Content", @@ -413,7 +418,6 @@ "toolPool.error.requiredFields": "The following required fields are not filled: {{fields}}", "toolPool.tooltip.functionGuide": "1. For local knowledge base search functionality, please enable the knowledge_base_search tool;\n2. For text file parsing functionality, please enable the analyze_text_file tool;\n3. For image parsing functionality, please enable the analyze_image tool.", - "tool.message.unavailable": "This tool is currently unavailable and cannot be selected", "tool.error.noMainAgentId": "Main agent ID is not set, cannot update tool status", "tool.error.configFetchFailed": "Failed to get tool configuration", @@ -502,6 +506,7 @@ "document.summary.modelPlaceholder": "Select Model", "document.status.creating": "Creating...", "document.status.loadingList": "Loading document list...", + "document.status.waitingForTask": "Waiting for task creation...", "document.input.knowledgeBaseName": "Please enter knowledge base name", "document.button.details": "Details", "document.button.overview": "Overview", @@ -522,6 +527,24 @@ "document.status.completed": "Ready", "document.status.processFailed": "Process Failed", "document.status.forwardFailed": "Forward Failed", + "document.progress.chunksProcessed": "Processed {{processed}}/{{total}} chunks ({{percent}}%)", + "document.error.reason": "Error Reason", + "document.error.suggestion": "Suggestion", + "document.error.noReason": "No error reason available", + "document.error.code.ray_init_failed.message": "Failed to initialize Ray cluster", + "document.error.code.ray_init_failed.suggestion": "Please upgrade to the latest image version and redeploy.", + "document.error.code.no_valid_chunks.message": "The data processing kernel could not extract valid text from the document", + "document.error.code.no_valid_chunks.suggestion": "Please ensure the document format is supported and the content is not purely images.", + "document.error.code.vector_service_busy.message": "Vectorization model service is busy and cannot return vectors", + "document.error.code.vector_service_busy.suggestion": "Please switch the model service provider or try again later.", + "document.error.code.es_bulk_failed.message": "Failed to write vectors into the database", + "document.error.code.es_bulk_failed.suggestion": "Please ensure the Elasticsearch data path has sufficient disk space and write permissions.", + "document.error.code.es_dim_mismatch.message": "Embedding dimension does not match the Elasticsearch mapping", + "document.error.code.es_dim_mismatch.suggestion": "Please delete all embedding models and add the model again to try again.", + "document.error.code.embedding_chunks_exceed_limit.message": "The current chunk count exceeds the embedding model concurrency limit", + "document.error.code.embedding_chunks_exceed_limit.suggestion": "Please increase the chunk size to reduce the number of chunks and try again.", + "document.error.code.unsupported_file_format.message": "Unsupported line breaks detected in the document", + "document.error.code.unsupported_file_format.suggestion": "Please convert all line breaks to LF format and try again", "document.modal.deleteConfirm.title": "Confirm Delete Document", "document.modal.deleteConfirm.content": "Are you sure you want to delete this document? This action cannot be undone.", "document.message.noFiles": "Please select files first", @@ -655,6 +678,7 @@ "model.group.silicon": "Silicon Flow Models", "model.group.custom": "Custom Models", "model.status.tooltip": "Click to verify connectivity", + "model.dialog.embeddingConfig.title": "Edit Embedding Model: {{modelName}}", "appConfig.appName.label": "Application Name", "appConfig.appName.placeholder": "Please enter your application name", @@ -699,9 +723,14 @@ "modelConfig.button.addCustomModel": "Add Model", "modelConfig.button.editCustomModel": "Edit or Delete Model", "modelConfig.button.checkConnectivity": "Check Model Connectivity", + "modelConfig.button.sync": "Sync", + "modelConfig.button.add": "Add", + "modelConfig.button.edit": "Edit", + "modelConfig.button.check": "Check", "modelConfig.slider.chunkingSize": "Chunk Size", "modelConfig.slider.expectedChunkSize": "Expected Chunk Size", "modelConfig.slider.maximumChunkSize": "Maximum Chunk Size", + "modelConfig.input.chunkingBatchSize": "Concurrent Request Count", "businessLogic.title": "Describe how should this agent work", "businessLogic.placeholder": "Please describe your business scenario and requirements...", @@ -841,7 +870,7 @@ "mcpConfig.modal.updatingTools": "Updating tools list...", "mcpConfig.addServer.title": "Add MCP Server", "mcpConfig.addServer.namePlaceholder": "Server name", - "mcpConfig.addServer.urlPlaceholder": "Server URL (e.g.: http://localhost:3001/sse), currently only SSE protocol supported", + "mcpConfig.addServer.urlPlaceholder": "Server URL (e.g.: http://localhost:3001/mcp), currently supports sse and streamable-http protocols", "mcpConfig.addServer.button.add": "Add", "mcpConfig.addServer.button.updating": "Updating...", "mcpConfig.serverList.title": "Configured MCP Servers", @@ -909,6 +938,12 @@ "agentConfig.agents.createSubAgentIdFailed": "Failed to fetch creating sub agent ID, please try again later", "agentConfig.agents.detailsFetchFailed": "Failed to fetch agent details, please try again later", "agentConfig.agents.callRelationshipFetchFailed": "Failed to fetch agent call relationship, please try again later", + "agentConfig.agents.defaultDisplayName": "Agent", + "agentConfig.agents.copyConfirmTitle": "Confirm Copy", + "agentConfig.agents.copyConfirmContent": "Create a duplicate of {{name}}?", + "agentConfig.agents.copySuccess": "Agent copied successfully", + "agentConfig.agents.copyUnavailableTools": "Ignored {{count}} unavailable tools: {{names}}", + "agentConfig.agents.copyFailed": "Failed to copy Agent", "agentConfig.tools.refreshFailedDebug": "Failed to refresh tools list:", "agentConfig.agents.detailsLoadFailed": "Failed to load Agent details:", "agentConfig.agents.importFailed": "Failed to import Agent:", @@ -1117,6 +1152,7 @@ "market.category.all": "All", "market.category.other": "Other", "market.download": "Download", + "market.by": "By {{author}}", "market.downloading": "Downloading agent...", "market.downloadSuccess": "Agent downloaded successfully!", "market.downloadFailed": "Failed to download agent", @@ -1125,7 +1161,7 @@ "market.totalAgents": "Total {{total}} agents", "market.error.loadCategories": "Failed to load categories", "market.error.loadAgents": "Failed to load agents", - + "market.detail.title": "Agent Details", "market.detail.subtitle": "Complete information and configuration", "market.detail.tabs.basic": "Basic Info", @@ -1136,6 +1172,7 @@ "market.detail.id": "Agent ID", "market.detail.name": "Name", "market.detail.displayName": "Display Name", + "market.detail.author": "Author", "market.detail.description": "Description", "market.detail.businessDescription": "Business Description", "market.detail.category": "Category", @@ -1166,6 +1203,7 @@ "market.detail.viewDetails": "View Details", "market.install.title": "Install Agent", + "market.install.step.rename": "Rename Agent", "market.install.step.model": "Select Model", "market.install.step.config": "Configure Fields", "market.install.step.mcp": "MCP Servers", @@ -1203,7 +1241,31 @@ "market.install.error.mcpInstall": "Failed to install MCP server", "market.install.error.invalidData": "Invalid agent data", "market.install.error.installFailed": "Failed to install agent", + "market.install.error.noModelForRegeneration": "No available model for name regeneration", + "market.install.error.nameRegenerationFailed": "Failed to regenerate name", + "market.install.error.nameRequired": "Agent name is required", + "market.install.error.nameRequiredForAgent": "Agent name is required for {agent}", + "market.install.checkingName": "Checking agent name...", + "market.install.rename.warning": "The agent name or display name conflicts with existing agents. Please rename to proceed.", + "market.install.rename.conflictAgents": "Conflicting agents:", + "market.install.rename.name": "Agent Name", + "market.install.rename.regenerateWithLLM": "Regenerate with LLM", + "market.install.rename.regenerate": "Regenerate", + "market.install.rename.model": "Model for Regeneration", + "market.install.rename.modelPlaceholder": "Select a model", + "market.install.error.modelRequiredForRegeneration": "Please select a model first", + "market.install.rename.nameHint": "Original: {name}", + "market.install.rename.displayName": "Display Name", + "market.install.rename.displayNameHint": "Original: {name}", + "market.install.rename.note": "Note: If you proceed without renaming, the agent will be created but marked as unavailable due to name conflicts. You can rename it later in the agent list.", + "market.install.rename.oneClickDesc": "You can edit names manually, or use one-click rename to let the LLM generate new names for all conflicted agents.", + "market.install.rename.oneClick": "One-click Rename", + "market.install.rename.success": "All agent name conflicts have been resolved. You can proceed to the next step.", + "market.install.rename.partialSuccess": "Some agents have been successfully renamed.", + "market.install.rename.agentResolved": "This agent's name conflict has been resolved.", "market.install.success.mcpInstalled": "MCP server installed successfully", + "market.install.success.nameRegenerated": "Agent name regenerated successfully", + "market.install.success.nameRegeneratedAndResolved": "Agent names regenerated successfully and all conflicts resolved", "market.install.info.notImplemented": "Installation will be implemented in next phase", "market.install.success": "Agent installed successfully!", "market.error.fetchDetailFailed": "Failed to load agent details", @@ -1218,7 +1280,7 @@ "market.error.server.description": "The market server encountered an error. Our team has been notified. Please try again later.", "market.error.unknown.title": "Something Went Wrong", "market.error.unknown.description": "An unexpected error occurred. Please try again.", - + "common.loading": "Loading", "common.save": "Save", "common.cancel": "Cancel", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 65d80dacf..c0f8d851a 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -286,6 +286,8 @@ "agent.contextMenu.export": "导出", "agent.contextMenu.delete": "删除", + "agent.contextMenu.copy": "复制", + "agent.copySuffix": "副本", "agent.info.title": "Agent信息", "agent.info.name.error.empty": "名称不能为空", "agent.info.name.error.format": "名称只能包含字母、数字和下划线,且必须以字母或下划线开头", @@ -294,6 +296,9 @@ "agent.namePlaceholder": "请输入Agent变量名", "agent.displayName": "Agent名称", "agent.displayNamePlaceholder": "请输入Agent名称", + "agent.author": "作者", + "agent.authorPlaceholder": "请输入作者名称(可选)", + "agent.author.hint": "默认:{{email}}", "agent.description": "Agent描述", "agent.descriptionPlaceholder": "请输入Agent描述", "agent.detailContent.title": "Agent详细内容", @@ -370,7 +375,7 @@ "subAgentPool.tooltip.duplicateNameDisabled": "该智能体因与其他智能体同名而被禁用,请修改名称后使用", "subAgentPool.message.duplicateNameDisabled": "该智能体因与其他智能体同名而被禁用,请修改名称后使用", - "toolConfig.title.paramConfig": "参数配置", + "toolConfig.title.paramConfig": "配置参数", "toolConfig.message.loadError": "加载工具配置失败", "toolConfig.message.loadErrorUseDefault": "加载工具配置失败,使用默认配置", "toolConfig.message.saveSuccess": "工具配置保存成功", @@ -414,7 +419,6 @@ "toolPool.error.requiredFields": "以下必填字段未填写: {{fields}}", "toolPool.tooltip.functionGuide": "1. 本地知识库检索功能,请启用knowledge_base_search工具;\n2. 文本文件解析功能,请启用analyze_text_file工具;\n3. 图片解析功能,请启用analyze_image工具。", - "tool.message.unavailable": "该工具当前不可用,无法选择", "tool.error.noMainAgentId": "主代理ID未设置,无法更新工具状态", "tool.error.configFetchFailed": "获取工具配置失败", @@ -503,6 +507,7 @@ "document.summary.modelPlaceholder": "选择模型", "document.status.creating": "创建中...", "document.status.loadingList": "正在加载文档列表...", + "document.status.waitingForTask": "正在等待任务创建...", "document.input.knowledgeBaseName": "请输入知识库名称", "document.button.details": "详细内容", "document.button.overview": "概览", @@ -523,6 +528,24 @@ "document.status.completed": "已就绪", "document.status.processFailed": "解析失败", "document.status.forwardFailed": "入库失败", + "document.progress.chunksProcessed": "已处理 {{processed}}/{{total}} 个切片 ({{percent}}%)", + "document.error.reason": "错误原因", + "document.error.suggestion": "建议", + "document.error.noReason": "暂无错误原因", + "document.error.code.ray_init_failed.message": "Ray集群初始化失败", + "document.error.code.ray_init_failed.suggestion": "请升级到最新版本并尝试重新部署", + "document.error.code.no_valid_chunks.message": "数据处理内核无法从文档中提取有效文本", + "document.error.code.no_valid_chunks.suggestion": "请确保文档内容非纯图像", + "document.error.code.vector_service_busy.message": "向量化模型服务繁忙,无法获取文本向量", + "document.error.code.vector_service_busy.suggestion": "请更换模型服务提供商,或稍后重试", + "document.error.code.es_bulk_failed.message": "向量录入数据库错误", + "document.error.code.es_bulk_failed.suggestion": "请确保Elasticsearch路径拥有完整写入权限,且存储空间与内存充足", + "document.error.code.es_dim_mismatch.message": "向量化模型维度与Elasticsearch维度不匹配", + "document.error.code.es_dim_mismatch.suggestion": "建议删除所有向量化模型后再添加模型重试", + "document.error.code.embedding_chunks_exceed_limit.message": "当前切片数量超过向量化模型并行度", + "document.error.code.embedding_chunks_exceed_limit.suggestion": "请增加切片大小以减少切片数量后再试", + "document.error.code.unsupported_file_format.message": "检测到当前文档中存在不支持的换行符", + "document.error.code.unsupported_file_format.suggestion": "建议统一转换为LF换行符再试", "document.modal.deleteConfirm.title": "确认删除文档", "document.modal.deleteConfirm.content": "确定要删除这个文档吗?删除后无法恢复。", "document.message.noFiles": "请先选择文件", @@ -655,6 +678,7 @@ "model.group.custom": "自定义模型", "model.status.tooltip": "点击可验证连通性", "model.dialog.success.updateSuccess": "更新成功", + "model.dialog.embeddingConfig.title": "修改向量模型: {{modelName}}", "appConfig.appName.label": "应用名称", "appConfig.appName.placeholder": "请输入您的应用名称", @@ -699,9 +723,14 @@ "modelConfig.button.addCustomModel": "添加模型", "modelConfig.button.editCustomModel": "修改或删除模型", "modelConfig.button.checkConnectivity": "检查模型连通性", + "modelConfig.button.sync": "同步", + "modelConfig.button.add": "添加", + "modelConfig.button.edit": "修改", + "modelConfig.button.check": "检查", "modelConfig.slider.chunkingSize": "文档切片大小", "modelConfig.slider.expectedChunkSize": "期望切片大小", "modelConfig.slider.maximumChunkSize": "最大切片大小", + "modelConfig.input.chunkingBatchSize": "单次请求切片量", "businessLogic.title": "描述 Agent 应该如何工作", "businessLogic.placeholder": "请描述您的业务场景和需求...", @@ -841,7 +870,7 @@ "mcpConfig.modal.updatingTools": "正在更新工具列表...", "mcpConfig.addServer.title": "添加MCP服务器", "mcpConfig.addServer.namePlaceholder": "服务器名称", - "mcpConfig.addServer.urlPlaceholder": "服务器URL (如: http://localhost:3001/sse),目前仅支持sse协议", + "mcpConfig.addServer.urlPlaceholder": "服务器URL (如: http://localhost:3001/mcp),目前支持sse和streamable-http协议", "mcpConfig.addServer.button.add": "添加", "mcpConfig.addServer.button.updating": "更新中...", "mcpConfig.serverList.title": "已配置的MCP服务器", @@ -909,6 +938,12 @@ "agentConfig.agents.createSubAgentIdFailed": "获取创建子Agent ID失败,请稍后重试", "agentConfig.agents.detailsFetchFailed": "获取Agent详情失败,请稍后重试", "agentConfig.agents.callRelationshipFetchFailed": "获取Agent调用关系失败,请稍后重试", + "agentConfig.agents.defaultDisplayName": "智能体", + "agentConfig.agents.copyConfirmTitle": "确认复制", + "agentConfig.agents.copyConfirmContent": "确定要复制 {{name}} 吗?", + "agentConfig.agents.copySuccess": "Agent复制成功", + "agentConfig.agents.copyUnavailableTools": "已忽略{{count}}个不可用工具:{{names}}", + "agentConfig.agents.copyFailed": "Agent复制失败", "agentConfig.tools.refreshFailedDebug": "刷新工具列表失败:", "agentConfig.agents.detailsLoadFailed": "加载Agent详情失败:", "agentConfig.agents.importFailed": "导入Agent失败:", @@ -1081,7 +1116,7 @@ "sidebar.memoryManagement": "记忆管理", "sidebar.userManagement": "用户管理", "sidebar.mcpToolsManagement": "MCP 工具", - "sidebar.monitoringManagement": "监控与运维", + "sidebar.monitoringManagement": "监控与运维", "market.comingSoon.title": "智能体市场即将推出", "market.comingSoon.description": "从我们的市场中发现并安装预构建的AI智能体。通过使用社区创建的解决方案节省时间。", @@ -1096,6 +1131,7 @@ "market.category.all": "全部", "market.category.other": "其他", "market.download": "下载", + "market.by": "作者:{{author}}", "market.downloading": "正在下载智能体...", "market.downloadSuccess": "智能体下载成功!", "market.downloadFailed": "下载智能体失败", @@ -1104,7 +1140,7 @@ "market.totalAgents": "共 {{total}} 个智能体", "market.error.loadCategories": "加载分类失败", "market.error.loadAgents": "加载智能体失败", - + "market.detail.title": "智能体详情", "market.detail.subtitle": "完整信息和配置", "market.detail.tabs.basic": "基础信息", @@ -1115,6 +1151,7 @@ "market.detail.id": "智能体 ID", "market.detail.name": "名称", "market.detail.displayName": "显示名称", + "market.detail.author": "作者", "market.detail.description": "描述", "market.detail.businessDescription": "业务描述", "market.detail.category": "分类", @@ -1145,6 +1182,7 @@ "market.detail.viewDetails": "查看详情", "market.install.title": "安装智能体", + "market.install.step.rename": "重命名智能体", "market.install.step.model": "选择模型", "market.install.step.config": "配置字段", "market.install.step.mcp": "MCP 服务器", @@ -1182,7 +1220,31 @@ "market.install.error.mcpInstall": "安装 MCP 服务器失败", "market.install.error.invalidData": "无效的智能体数据", "market.install.error.installFailed": "安装智能体失败", + "market.install.error.noModelForRegeneration": "没有可用的模型用于名称重新生成", + "market.install.error.nameRegenerationFailed": "重新生成名称失败", + "market.install.error.nameRequired": "智能体名称为必填项", + "market.install.error.nameRequiredForAgent": "智能体 {agent} 的名称为必填项", + "market.install.checkingName": "正在检查智能体名称...", + "market.install.rename.warning": "智能体名称或显示名称与现有智能体冲突,请重命名以继续。", + "market.install.rename.conflictAgents": "冲突的智能体:", + "market.install.rename.name": "智能体名称", + "market.install.rename.regenerateWithLLM": "使用 LLM 重新生成", + "market.install.rename.regenerate": "重新生成", + "market.install.rename.model": "用于重新生成名称的模型", + "market.install.rename.modelPlaceholder": "选择一个模型", + "market.install.error.modelRequiredForRegeneration": "请先选择一个模型", + "market.install.rename.nameHint": "原始名称:{name}", + "market.install.rename.displayName": "显示名称", + "market.install.rename.displayNameHint": "原始名称:{name}", + "market.install.rename.note": "注意:如果您不重命名就继续,智能体将被创建但由于名称冲突会被标记为不可用。您可以在智能体列表中稍后重命名。", + "market.install.rename.oneClickDesc": "可手动修改名称,或一键重命名使用大模型为所有冲突智能体生成新名称。", + "market.install.rename.oneClick": "一键重命名", + "market.install.rename.success": "所有智能体名称冲突已解决。您可以继续下一步。", + "market.install.rename.partialSuccess": "部分智能体已成功重命名。", + "market.install.rename.agentResolved": "此智能体的名称冲突已解决。", "market.install.success.mcpInstalled": "MCP 服务器安装成功", + "market.install.success.nameRegenerated": "智能体名称重新生成成功", + "market.install.success.nameRegeneratedAndResolved": "智能体名称重新生成成功,且所有冲突已解决", "market.install.info.notImplemented": "安装功能将在下一阶段实现", "market.install.success": "智能体安装成功!", "market.error.fetchDetailFailed": "加载智能体详情失败", @@ -1211,14 +1273,14 @@ "mcpTools.comingSoon.feature2": "同步、查看和组织 MCP 工具列表", "mcpTools.comingSoon.feature3": "监控 MCP 连接状态和使用情况", "mcpTools.comingSoon.badge": "即将推出", - + "monitoring.comingSoon.title": "监控与运维中心即将推出", "monitoring.comingSoon.description": "面向智能体的统一监控与运维中心,用于实时跟踪健康状态、性能指标与异常事件。", "monitoring.comingSoon.feature1": "监控智能体健康状态、延迟与错误率", "monitoring.comingSoon.feature2": "查看并筛选智能体运行日志和历史任务", "monitoring.comingSoon.feature3": "配置告警策略与关键事件的运维操作", "monitoring.comingSoon.badge": "即将推出", - + "common.loading": "加载中", "common.save": "保存", "common.cancel": "取消", diff --git a/frontend/services/agentConfigService.ts b/frontend/services/agentConfigService.ts index f7f084f6b..3cff1e884 100644 --- a/frontend/services/agentConfigService.ts +++ b/frontend/services/agentConfigService.ts @@ -116,6 +116,7 @@ export const fetchAgentList = async () => { name: agent.name, display_name: agent.display_name || agent.name, description: agent.description, + author: agent.author, is_available: agent.is_available, unavailable_reasons: agent.unavailable_reasons || [], })); @@ -326,7 +327,8 @@ export const updateAgent = async ( businessLogicModelName?: string, businessLogicModelId?: number, enabledToolIds?: number[], - relatedAgentIds?: number[] + relatedAgentIds?: number[], + author?: string ) => { try { const response = await fetch(API_ENDPOINTS.agent.update, { @@ -350,6 +352,7 @@ export const updateAgent = async ( business_logic_model_id: businessLogicModelId, enabled_tool_ids: enabledToolIds, related_agent_ids: relatedAgentIds, + author: author, }), }); @@ -485,6 +488,76 @@ export const importAgent = async ( } }; +/** + * check agent name/display_name duplication + * @param payload name/displayName to check + */ +export const checkAgentNameConflictBatch = async (payload: { + items: Array<{ name: string; display_name?: string; agent_id?: number }>; +}) => { + try { + const response = await fetch(API_ENDPOINTS.agent.checkNameBatch, { + method: "POST", + headers: getAuthHeaders(), + body: JSON.stringify(payload), + }); + + if (!response.ok) { + throw new Error(`Request failed: ${response.status}`); + } + + const data = await response.json(); + return { + success: true, + data, + message: "", + }; + } catch (error) { + log.error("Failed to check agent name conflict batch:", error); + return { + success: false, + data: null, + message: "agentConfig.agents.checkNameFailed", + }; + } +}; + +export const regenerateAgentNameBatch = async (payload: { + items: Array<{ + name: string; + display_name?: string; + task_description?: string; + language?: string; + agent_id?: number; + }>; +}) => { + try { + const response = await fetch(API_ENDPOINTS.agent.regenerateNameBatch, { + method: "POST", + headers: getAuthHeaders(), + body: JSON.stringify(payload), + }); + + if (!response.ok) { + throw new Error(`Request failed: ${response.status}`); + } + + const data = await response.json(); + return { + success: true, + data, + message: "", + }; + } catch (error) { + log.error("Failed to regenerate agent name batch:", error); + return { + success: false, + data: null, + message: "agentConfig.agents.regenerateNameFailed", + }; + } +}; + /** * search agent info by agent id * @param agentId agent id @@ -510,6 +583,7 @@ export const searchAgentInfo = async (agentId: number) => { name: data.name, display_name: data.display_name, description: data.description, + author: data.author, model: data.model_name, model_id: data.model_id, max_step: data.max_steps, @@ -587,6 +661,7 @@ export const fetchAllAgents = async () => { name: agent.name, display_name: agent.display_name || agent.name, description: agent.description, + author: agent.author, is_available: agent.is_available, })); diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 0af193d52..20d89b6f2 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -37,6 +37,8 @@ export const API_ENDPOINTS = { `${API_BASE_URL}/agent/stop/${conversationId}`, export: `${API_BASE_URL}/agent/export`, import: `${API_BASE_URL}/agent/import`, + checkNameBatch: `${API_BASE_URL}/agent/check_name`, + regenerateNameBatch: `${API_BASE_URL}/agent/regenerate_name`, searchInfo: `${API_BASE_URL}/agent/search_info`, callRelationship: `${API_BASE_URL}/agent/call_relationship`, }, @@ -142,6 +144,11 @@ export const API_ENDPOINTS = { // File upload service upload: `${API_BASE_URL}/file/upload`, process: `${API_BASE_URL}/file/process`, + // Error info service + getErrorInfo: (indexName: string, pathOrUrl: string) => + `${API_BASE_URL}/indices/${indexName}/documents/${encodeURIComponent( + pathOrUrl + )}/error-info`, }, config: { save: `${API_BASE_URL}/config/save_config`, diff --git a/frontend/services/knowledgeBasePollingService.ts b/frontend/services/knowledgeBasePollingService.ts index 568205b21..b899d8bdf 100644 --- a/frontend/services/knowledgeBasePollingService.ts +++ b/frontend/services/knowledgeBasePollingService.ts @@ -11,8 +11,12 @@ class KnowledgeBasePollingService { private knowledgeBasePollingInterval: number = 1000; // 1 second private documentPollingInterval: number = 3000; // 3 seconds private maxKnowledgeBasePolls: number = 60; // Maximum 60 polling attempts - private maxDocumentPolls: number = 20; // Maximum 20 polling attempts + private maxDocumentPolls: number = 200; // Maximum 200 polling attempts (10 minutes for long-running tasks) private activeKnowledgeBaseId: string | null = null; // Record current active knowledge base ID + private pendingRequests: Map> = new Map(); + + // Debounce timers for batching multiple rapid requests + private debounceTimers: Map = new Map(); // Set current active knowledge base ID setActiveKnowledgeBase(kbId: string | null): void { @@ -29,11 +33,16 @@ class KnowledgeBasePollingService { // Initialize polling counter let pollCount = 0; + // Track if we're in extended polling mode (after initial timeout) + let isExtendedPolling = false; + // Define the polling logic function const pollDocuments = async () => { try { - // Increment polling counter - pollCount++; + // Increment polling counter only if not in extended polling mode + if (!isExtendedPolling) { + pollCount++; + } // If there is an active knowledge base and polling knowledge base doesn't match active one, stop polling if (this.activeKnowledgeBaseId !== null && this.activeKnowledgeBaseId !== kbId) { @@ -41,24 +50,28 @@ class KnowledgeBasePollingService { return; } - // If exceeded maximum polling count, handle timeout - if (pollCount > this.maxDocumentPolls) { - log.warn(`Document polling for knowledge base ${kbId} timed out after ${this.maxDocumentPolls} attempts`); - await this.handlePollingTimeout(kbId, 'document', callback); - // Push documents to UI + // Use request deduplication to avoid concurrent duplicate requests + let documents: Document[]; + const requestKey = `poll:${kbId}`; + + // Check if there's already a pending request for this KB + const pendingRequest = this.pendingRequests.get(requestKey); + if (pendingRequest) { + // Reuse existing request to avoid duplicate API calls + documents = await pendingRequest; + } else { + // Create new request and track it + const requestPromise = knowledgeBaseService.getAllFiles(kbId); + this.pendingRequests.set(requestKey, requestPromise); + try { - const documents = await knowledgeBaseService.getAllFiles(kbId); - this.triggerDocumentsUpdate(kbId, documents); - } catch (e) { - // Ignore error + documents = await requestPromise; + } finally { + // Clean up after request completes + this.pendingRequests.delete(requestKey); } - this.stopPolling(kbId); - return; } - // Get latest document status - const documents = await knowledgeBaseService.getAllFiles(kbId); - // Call callback function with latest documents first to ensure UI updates immediately callback(documents); @@ -67,6 +80,18 @@ class KnowledgeBasePollingService { NON_TERMINAL_STATUSES.includes(doc.status) ); + // If exceeded maximum polling count and still processing, switch to extended polling mode + if (pollCount > this.maxDocumentPolls && hasProcessingDocs && !isExtendedPolling) { + log.warn(`Document polling for knowledge base ${kbId} exceeded ${this.maxDocumentPolls} attempts, switching to extended polling mode (reduced frequency)`); + isExtendedPolling = true; + // Stop the current interval and restart with longer interval + this.stopPolling(kbId); + // Continue polling with reduced frequency (every 10 seconds) + const extendedInterval = setInterval(pollDocuments, 10000); + this.pollingIntervals.set(kbId, extendedInterval); + return; + } + // If there are processing documents, continue polling if (hasProcessingDocs) { log.log('Documents processing, continue polling'); @@ -141,6 +166,7 @@ class KnowledgeBasePollingService { * @param expectedIncrement The number of new files uploaded */ pollForKnowledgeBaseReady( + kbId: string, kbName: string, originalDocumentCount: number = 0, expectedIncrement: number = 0 @@ -150,29 +176,14 @@ class KnowledgeBasePollingService { const checkForStats = async () => { try { const kbs = await knowledgeBaseService.getKnowledgeBasesInfo(true) as KnowledgeBase[]; - const kb = kbs.find(k => k.name === kbName); + const kb = kbs.find(k => k.id === kbId || k.name === kbName); // Check if KB exists and its stats are populated if (kb) { - // If expectedIncrement > 0, check if documentCount increased as expected - if ( - expectedIncrement > 0 && - kb.documentCount >= (originalDocumentCount + expectedIncrement) - ) { - log.log( - `Knowledge base ${kbName} documentCount increased as expected: ${kb.documentCount} (was ${originalDocumentCount}, expected increment ${expectedIncrement})` - ); - this.triggerKnowledgeBaseListUpdate(true); - resolve(kb); - return; - } - // Fallback: for new KB or no increment specified, use old logic - if (expectedIncrement === 0 && (kb.documentCount > 0 || kb.chunkCount > 0)) { - log.log(`Knowledge base ${kbName} is ready and stats are populated.`); - this.triggerKnowledgeBaseListUpdate(true); - resolve(kb); - return; - } + log.log(`Knowledge base ${kbName} detected.`); + this.triggerKnowledgeBaseListUpdate(true); + resolve(kb); + return; } count++; @@ -183,11 +194,11 @@ class KnowledgeBasePollingService { log.error(`Knowledge base ${kbName} readiness check timed out after ${this.maxKnowledgeBasePolls} attempts.`); // Handle knowledge base polling timeout - mark related tasks as failed - await this.handlePollingTimeout(kbName, 'knowledgeBase'); + await this.handlePollingTimeout(kbId, 'knowledgeBase'); // Push documents to UI try { - const documents = await knowledgeBaseService.getAllFiles(kbName); - this.triggerDocumentsUpdate(kbName, documents); + const documents = await knowledgeBaseService.getAllFiles(kbId); + this.triggerDocumentsUpdate(kbId, documents); } catch (e) { // Ignore error } @@ -201,11 +212,11 @@ class KnowledgeBasePollingService { setTimeout(checkForStats, this.knowledgeBasePollingInterval); } else { // Handle knowledge base polling timeout on error as well - await this.handlePollingTimeout(kbName, 'knowledgeBase'); + await this.handlePollingTimeout(kbId, 'knowledgeBase'); // Push documents to UI try { - const documents = await knowledgeBaseService.getAllFiles(kbName); - this.triggerDocumentsUpdate(kbName, documents); + const documents = await knowledgeBaseService.getAllFiles(kbId); + this.triggerDocumentsUpdate(kbId, documents); } catch (e) { // Ignore error } @@ -218,14 +229,14 @@ class KnowledgeBasePollingService { } // Simplified method for new knowledge base creation workflow - async handleNewKnowledgeBaseCreation(kbName: string, originalDocumentCount: number = 0, expectedIncrement: number = 0, callback: (kb: KnowledgeBase) => void) { + async handleNewKnowledgeBaseCreation(kbId: string, kbName: string, originalDocumentCount: number = 0, expectedIncrement: number = 0, callback: (kb: KnowledgeBase) => void) { // Start document polling - this.startDocumentStatusPolling(kbName, (documents) => { - this.triggerDocumentsUpdate(kbName, documents); + this.startDocumentStatusPolling(kbId, (documents) => { + this.triggerDocumentsUpdate(kbId, documents); }); try { // Start knowledge base polling parallelly - const populatedKB = await this.pollForKnowledgeBaseReady(kbName, originalDocumentCount, expectedIncrement); + const populatedKB = await this.pollForKnowledgeBaseReady(kbId, kbName, originalDocumentCount, expectedIncrement); // callback with populated knowledge base when everything is ready callback(populatedKB); } catch (error) { @@ -249,6 +260,13 @@ class KnowledgeBasePollingService { clearInterval(interval); }); this.pollingIntervals.clear(); + + // Clear pending requests and debounce timers to prevent memory leaks + this.pendingRequests.clear(); + this.debounceTimers.forEach((timer) => { + clearTimeout(timer); + }); + this.debounceTimers.clear(); } // Trigger knowledge base list update (optionally force refresh) diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts index 27a6e0b38..0ea443081 100644 --- a/frontend/services/knowledgeBaseService.ts +++ b/frontend/services/knowledgeBaseService.ts @@ -71,15 +71,20 @@ class KnowledgeBaseService { // Convert Elasticsearch indices to knowledge base format knowledgeBases = data.indices_info.map((indexInfo: any) => { const stats = indexInfo.stats?.base_info || {}; + // Backend now returns: + // - name: internal index_name + // - display_name: user-facing knowledge_name (fallback to index_name) + const kbId = indexInfo.name; + const kbName = indexInfo.display_name || indexInfo.name; return { - id: indexInfo.name, - name: indexInfo.name, + id: kbId, + name: kbName, description: "Elasticsearch index", documentCount: stats.doc_count || 0, chunkCount: stats.chunk_count || 0, - createdAt: - stats.creation_date || new Date().toISOString().split("T")[0], + createdAt: stats.creation_date || null, + updatedAt: stats.update_date || stats.creation_date || null, embeddingModel: stats.embedding_model || "unknown", avatar: "", chunkNum: 0, @@ -276,6 +281,16 @@ class KnowledgeBaseService { token_num: 0, status: file.status || "UNKNOWN", latest_task_id: file.latest_task_id || "", + error_reason: file.error_reason, + // Optional ingestion progress metrics (only present for in-progress files) + processed_chunk_num: + typeof file.processed_chunk_num === "number" + ? file.processed_chunk_num + : null, + total_chunk_num: + typeof file.total_chunk_num === "number" + ? file.total_chunk_num + : null, })); } catch (error) { log.error("Failed to get all files:", error); @@ -806,6 +821,41 @@ class KnowledgeBaseService { throw new Error("Failed to execute hybrid search"); } } + + // Get error information for a document + async getDocumentErrorInfo( + kbId: string, + docId: string + ): Promise<{ + errorCode: string | null; + }> { + try { + const response = await fetch( + API_ENDPOINTS.knowledgeBase.getErrorInfo(kbId, docId), + { + headers: getAuthHeaders(), + } + ); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + const data = await response.json(); + if (data.status !== "success") { + throw new Error(data.message || "Failed to get error info"); + } + + const errorCode = (data.error_code && String(data.error_code)) || null; + + return { + errorCode, + }; + } catch (error) { + log.error("Failed to get document error info:", error); + throw error; + } + } } // Export a singleton instance diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts index 9de2c5483..3599bc939 100644 --- a/frontend/services/modelService.ts +++ b/frontend/services/modelService.ts @@ -67,6 +67,7 @@ export const modelService = { (model.connect_status as ModelConnectStatus) || "not_detected", expectedChunkSize: model.expected_chunk_size, maximumChunkSize: model.maximum_chunk_size, + chunkingBatchSize: model.chunk_batch, })); } return []; @@ -97,6 +98,7 @@ export const modelService = { displayName?: string; expectedChunkSize?: number; maximumChunkSize?: number; + chunkingBatchSize?: number; }): Promise => { try { const response = await fetch(API_ENDPOINTS.model.customModelCreate, { @@ -112,6 +114,7 @@ export const modelService = { display_name: model.displayName, expected_chunk_size: model.expectedChunkSize, maximum_chunk_size: model.maximumChunkSize, + chunk_batch: model.chunkingBatchSize, }), }); @@ -239,6 +242,7 @@ export const modelService = { source?: ModelSource; expectedChunkSize?: number; maximumChunkSize?: number; + chunkingBatchSize?: number; }): Promise => { try { const response = await fetch( @@ -262,6 +266,9 @@ export const modelService = { ...(model.maximumChunkSize !== undefined ? { maximum_chunk_size: model.maximumChunkSize } : {}), + ...(model.chunkingBatchSize !== undefined + ? { chunk_batch: model.chunkingBatchSize } + : {}), }), } ); diff --git a/frontend/services/storageService.ts b/frontend/services/storageService.ts index ec60eb187..a45add994 100644 --- a/frontend/services/storageService.ts +++ b/frontend/services/storageService.ts @@ -123,6 +123,68 @@ export function convertImageUrlToApiUrl(url: string): string { return url; } +const arrayBufferToBase64 = (buffer: ArrayBuffer): string => { + let binary = ""; + const bytes = new Uint8Array(buffer); + const chunkSize = 0x8000; + + for (let i = 0; i < bytes.length; i += chunkSize) { + const chunk = bytes.subarray(i, i + chunkSize); + binary += String.fromCharCode(...chunk); + } + + return btoa(binary); +}; + +const fetchBase64ViaStorage = async (objectName: string) => { + const response = await fetch(API_ENDPOINTS.storage.file(objectName, "base64")); + if (!response.ok) { + throw new Error(`Failed to resolve S3 URL via storage: ${response.status}`); + } + + const data = await response.json(); + if (!data?.success || !data?.base64) { + throw new Error(data?.error || "Storage response missing base64 content"); + } + + const contentType = data.content_type || "application/octet-stream"; + return { base64: data.base64 as string, contentType }; +}; + +// Cache for S3 URL to data URL resolution to avoid duplicate network requests +const s3ResolutionCache = new Map>(); + +// Internal helper: for s3:// URLs, resolve directly via storage download endpoint. +async function resolveS3UrlToDataUrlInternal(url: string): Promise { + const objectName = extractObjectNameFromUrl(url); + if (!objectName) { + return null; + } + + const { base64, contentType } = await fetchBase64ViaStorage(objectName); + return `data:${contentType};base64,${base64}`; +} + +export async function resolveS3UrlToDataUrl(url: string): Promise { + if (!url || !url.startsWith("s3://")) { + return null; + } + + const cached = s3ResolutionCache.get(url); + if (cached) { + return cached; + } + + const promise = resolveS3UrlToDataUrlInternal(url).catch((error) => { + // Remove from cache on failure so that future attempts can retry. + s3ResolutionCache.delete(url); + throw error; + }); + + s3ResolutionCache.set(url, promise); + return promise; +} + export const storageService = { /** * Upload files to storage service diff --git a/frontend/styles/globals.css b/frontend/styles/globals.css index 7d6b1749d..ad666027d 100644 --- a/frontend/styles/globals.css +++ b/frontend/styles/globals.css @@ -305,4 +305,23 @@ .kb-embedding-warning .ant-modal { width: max-content; min-width: 0; +} + +/* Responsive button text - global utility */ +@media (max-width: 1279px) { + .button-text-full { + display: none !important; + } + .button-text-short { + display: inline !important; + } +} + +@media (min-width: 1280px) { + .button-text-full { + display: inline !important; + } + .button-text-short { + display: none !important; + } } \ No newline at end of file diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index 3dc41c601..1a766788c 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -12,6 +12,7 @@ export interface Agent { name: string; display_name?: string; description: string; + author?: string; unavailable_reasons?: string[]; model: string; model_id?: number; @@ -127,6 +128,8 @@ export interface AgentSetupOrchestratorProps { setAgentDescription?: (value: string) => void; agentDisplayName?: string; setAgentDisplayName?: (value: string) => void; + agentAuthor?: string; + setAgentAuthor?: (value: string) => void; isGeneratingAgent?: boolean; onDebug?: () => void; getCurrentAgentId?: () => number | undefined; @@ -156,6 +159,7 @@ export interface SubAgentPoolProps { isGeneratingAgent?: boolean; editingAgent?: Agent | null; isCreatingNewAgent?: boolean; + onCopyAgent?: (agent: Agent) => void; onExportAgent?: (agent: Agent) => void; onDeleteAgent?: (agent: Agent) => void; } diff --git a/frontend/types/chat.ts b/frontend/types/chat.ts index 700edfdbf..826722055 100644 --- a/frontend/types/chat.ts +++ b/frontend/types/chat.ts @@ -9,6 +9,7 @@ export interface StepSection { export interface StepContent { id: string type: typeof chatConfig.messageTypes.MODEL_OUTPUT | + typeof chatConfig.messageTypes.MODEL_OUTPUT_CODE | typeof chatConfig.messageTypes.PARSING | typeof chatConfig.messageTypes.EXECUTION | typeof chatConfig.messageTypes.ERROR | diff --git a/frontend/types/knowledgeBase.ts b/frontend/types/knowledgeBase.ts index 85a5e6b12..e04f145c7 100644 --- a/frontend/types/knowledgeBase.ts +++ b/frontend/types/knowledgeBase.ts @@ -4,21 +4,23 @@ import { DOCUMENT_ACTION_TYPES, KNOWLEDGE_BASE_ACTION_TYPES, UI_ACTION_TYPES, NO // Knowledge base basic type export interface KnowledgeBase { - id: string - name: string - description: string | null - chunkCount: number - documentCount: number - createdAt: any - embeddingModel: string - avatar: string - chunkNum: number - language: string - nickname: string - parserId: string - permission: string - tokenNum: number - source: string + id: string; + name: string; + description: string | null; + chunkCount: number; + documentCount: number; + createdAt: any; + // Last update time of the knowledge base/index (may fall back to createdAt) + updatedAt?: any; + embeddingModel: string; + avatar: string; + chunkNum: number; + language: string; + nickname: string; + parserId: string; + permission: string; + tokenNum: number; + source: string; } // Create knowledge base parameter type @@ -31,17 +33,21 @@ export interface KnowledgeBaseCreateParams { // Document type export interface Document { - id: string - kb_id: string - name: string - type: string - size: number - create_time: string - chunk_num: number - token_num: number - status: string - selected?: boolean // For UI selection status - latest_task_id: string // For marking the latest celery task + id: string; + kb_id: string; + name: string; + type: string; + size: number; + create_time: string; + chunk_num: number; + token_num: number; + status: string; + selected?: boolean; // For UI selection status + latest_task_id: string; // For marking the latest celery task + error_reason?: string; // Error reason for failed documents + // Optional ingestion progress metrics + processed_chunk_num?: number | null; + total_chunk_num?: number | null; } // Document state interface diff --git a/frontend/types/market.ts b/frontend/types/market.ts index 888afffdb..770e39520 100644 --- a/frontend/types/market.ts +++ b/frontend/types/market.ts @@ -28,6 +28,7 @@ export interface MarketAgentListItem { name: string; display_name: string; description: string; + author?: string; category: MarketCategory; tags: MarketTag[]; download_count: number; diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index db97a8c0d..0d463161f 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -45,6 +45,7 @@ export interface ModelOption { connect_status?: ModelConnectStatus; expectedChunkSize?: number; maximumChunkSize?: number; + chunkingBatchSize?: number; } // Application configuration interface diff --git a/sdk/nexent/core/agents/agent_model.py b/sdk/nexent/core/agents/agent_model.py index 6eff00718..f3c5a77b7 100644 --- a/sdk/nexent/core/agents/agent_model.py +++ b/sdk/nexent/core/agents/agent_model.py @@ -1,7 +1,7 @@ from __future__ import annotations from threading import Event -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Field @@ -50,7 +50,12 @@ class AgentRunInfo(BaseModel): model_config_list: List[ModelConfig] = Field(description="List of model configurations") observer: MessageObserver = Field(description="Return data") agent_config: AgentConfig = Field(description="Detailed Agent configuration") - mcp_host: Optional[List[str]] = Field(description="MCP server address", default=None) + mcp_host: Optional[List[Union[str, Dict[str, Any]]]] = Field( + description="MCP server address(es). Can be a string (URL) or dict with 'url' and 'transport' keys. " + "Transport can be 'sse' or 'streamable-http'. If string, transport is auto-detected based on URL ending: " + "URLs ending with '/sse' use 'sse' transport, URLs ending with '/mcp' use 'streamable-http' transport.", + default=None + ) history: Optional[List[AgentHistory]] = Field(description="Historical conversation information", default=None) stop_event: Event = Field(description="Stop event control") diff --git a/sdk/nexent/core/agents/core_agent.py b/sdk/nexent/core/agents/core_agent.py index 826ef7093..be7b83b5e 100644 --- a/sdk/nexent/core/agents/core_agent.py +++ b/sdk/nexent/core/agents/core_agent.py @@ -1,3 +1,4 @@ +import json import re import ast import time @@ -9,12 +10,13 @@ from rich.console import Group from rich.text import Text -from smolagents.agents import CodeAgent, handle_agent_output_types, AgentError +from smolagents.agents import CodeAgent, handle_agent_output_types, AgentError, ActionOutput, RunResult from smolagents.local_python_executor import fix_final_answer_code from smolagents.memory import ActionStep, PlanningStep, FinalAnswerStep, ToolCall, TaskStep, SystemPromptStep -from smolagents.models import ChatMessage -from smolagents.monitoring import LogLevel -from smolagents.utils import AgentExecutionError, AgentGenerationError, truncate_content +from smolagents.models import ChatMessage, CODEAGENT_RESPONSE_FORMAT +from smolagents.monitoring import LogLevel, Timing, YELLOW_HEX, TokenUsage +from smolagents.utils import AgentExecutionError, AgentGenerationError, truncate_content, AgentMaxStepsError, \ + extract_code_from_text from ..utils.observer import MessageObserver, ProcessType from jinja2 import Template, StrictUndefined @@ -125,13 +127,17 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: # Add new step in logs memory_step.model_input_messages = input_messages + stop_sequences = ["", "Observation:", "Calling tools:", "", "Observation:", "Calling tools:", " Generator[Any]: # Parse try: - code_action = fix_final_answer_code(parse_code_blobs(model_output)) + if self._use_structured_outputs_internally: + code_action = json.loads(model_output)["code"] + code_action = extract_code_from_text(code_action, self.code_block_tags) or code_action + else: + code_action = parse_code_blobs(model_output) + code_action = fix_final_answer_code(code_action) + memory_step.code_action = code_action # Record parsing results self.observer.add_message( self.agent_name, ProcessType.PARSE, code_action) @@ -155,26 +167,29 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: content=model_output, title="AGENT FINAL ANSWER", level=LogLevel.INFO) raise FinalAnswerError() - memory_step.tool_calls = [ - ToolCall(name="python_interpreter", arguments=code_action, id=f"call_{len(self.memory.steps)}", )] + tool_call = ToolCall( + name="python_interpreter", + arguments=code_action, + id=f"call_{len(self.memory.steps)}", + ) + memory_step.tool_calls = [tool_call] # Execute self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO) - is_final_answer = False try: - output, execution_logs, is_final_answer = self.python_executor( - code_action) - + code_output = self.python_executor(code_action) execution_outputs_console = [] - if len(execution_logs) > 0: + if len(code_output.logs) > 0: # Record execution results self.observer.add_message( - self.agent_name, ProcessType.EXECUTION_LOGS, f"{execution_logs}") + self.agent_name, ProcessType.EXECUTION_LOGS, f"{code_output.logs}") execution_outputs_console += [ - Text("Execution logs:", style="bold"), Text(execution_logs), ] - observation = "Execution logs:\n" + execution_logs + Text("Execution logs:", style="bold"), + Text(code_output.logs), + ] + observation = "Execution logs:\n" + code_output.logs except Exception as e: if hasattr(self.python_executor, "state") and "_print_outputs" in self.python_executor.state: execution_logs = str( @@ -196,20 +211,24 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: level=LogLevel.INFO, ) raise AgentExecutionError(error_msg, self.logger) - truncated_output = truncate_content(str(output)) - if output is not None: + truncated_output = None + if code_output is not None and code_output.output is not None: + truncated_output = truncate_content(str(code_output.output)) observation += "Last output from code snippet:\n" + truncated_output memory_step.observations = observation - execution_outputs_console += [ - Text(f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}", - style=("bold #d4b702" if is_final_answer else ""), ), ] + if not code_output.is_final_answer and truncated_output is not None: + execution_outputs_console += [ + Text( + f"Out: {truncated_output}", + ), + ] self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO) - memory_step.action_output = output - yield output if is_final_answer else None + memory_step.action_output = code_output.output + yield ActionOutput(output=code_output.output, is_final_answer=code_output.is_final_answer) def run(self, task: str, stream: bool = False, reset: bool = True, images: Optional[List[str]] = None, - additional_args: Optional[Dict] = None, max_steps: Optional[int] = None, ): + additional_args: Optional[Dict] = None, max_steps: Optional[int] = None, return_full_result: bool | None = None): """ Run the agent for the given task. @@ -220,6 +239,8 @@ def run(self, task: str, stream: bool = False, reset: bool = True, images: Optio images (`list[str]`, *optional*): Paths to image(s). additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names! max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. if not provided, will use the agent's default value. + return_full_result (`bool`, *optional*): Whether to return the full [`RunResult`] object or just the final answer output. + If `None` (default), the agent's `self.return_full_result` setting is used. Example: ```py @@ -236,7 +257,6 @@ def run(self, task: str, stream: bool = False, reset: bool = True, images: Optio You have been provided with these additional arguments, that you can access using the keys as variables in your python code: {str(additional_args)}.""" - self.system_prompt = self.initialize_system_prompt() self.memory.system_prompt = SystemPromptStep( system_prompt=self.system_prompt) if reset: @@ -261,8 +281,47 @@ def run(self, task: str, stream: bool = False, reset: bool = True, images: Optio if stream: # The steps are returned as they are executed through a generator to iterate on. return self._run_stream(task=self.task, max_steps=max_steps, images=images) + run_start_time = time.time() + steps = list(self._run_stream(task=self.task, max_steps=max_steps, images=images)) + # Outputs are returned only at the end. We only look at the last step. - return list(self._run_stream(task=self.task, max_steps=max_steps, images=images))[-1].final_answer + assert isinstance(steps[-1], FinalAnswerStep) + output = steps[-1].output + + return_full_result = return_full_result if return_full_result is not None else self.return_full_result + if return_full_result: + total_input_tokens = 0 + total_output_tokens = 0 + correct_token_usage = True + for step in self.memory.steps: + if isinstance(step, (ActionStep, PlanningStep)): + if step.token_usage is None: + correct_token_usage = False + break + else: + total_input_tokens += step.token_usage.input_tokens + total_output_tokens += step.token_usage.output_tokens + if correct_token_usage: + token_usage = TokenUsage(input_tokens=total_input_tokens, output_tokens=total_output_tokens) + else: + token_usage = None + + if self.memory.steps and isinstance(getattr(self.memory.steps[-1], "error", None), AgentMaxStepsError): + state = "max_steps_error" + else: + state = "success" + + step_dicts = self.memory.get_full_steps() + + return RunResult( + output=output, + token_usage=token_usage, + steps=step_dicts, + timing=Timing(start_time=run_start_time, end_time=time.time()), + state=state, + ) + + return output def __call__(self, task: str, **kwargs): """Adds additional prompting for the managed agent, runs it, and wraps the output. @@ -271,7 +330,11 @@ def __call__(self, task: str, **kwargs): full_task = Template(self.prompt_templates["managed_agent"]["task"], undefined=StrictUndefined).render({ "name": self.name, "task": task, **self.state }) - report = self.run(full_task, **kwargs) + result = self.run(full_task, **kwargs) + if isinstance(result, RunResult): + report = result.output + else: + report = result # When a sub-agent finishes running, return a marker try: @@ -286,7 +349,7 @@ def __call__(self, task: str, **kwargs): if self.provide_run_summary: answer += "\n\nFor more detail, find below a summary of this agent's work:\n\n" for message in self.write_memory_to_messages(summary_mode=True): - content = message["content"] + content = message.content answer += "\n" + truncate_content(str(content)) + "\n---" answer += "\n" return answer @@ -295,28 +358,44 @@ def _run_stream( self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep]: final_answer = None + action_step = None self.step_number = 1 - while final_answer is None and self.step_number <= max_steps and not self.stop_event.is_set(): + returned_final_answer = False + while not returned_final_answer and self.step_number <= max_steps and not self.stop_event.is_set(): step_start_time = time.time() action_step = ActionStep( - step_number=self.step_number, start_time=step_start_time, observations_images=images + step_number=self.step_number, timing=Timing(start_time=step_start_time), observations_images=images ) try: - for el in self._execute_step(action_step): - yield el - final_answer = el + for output in self._step_stream(action_step): + yield output + + if isinstance(output, ActionOutput) and output.is_final_answer: + final_answer = output.output + self.logger.log( + Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"), + level=LogLevel.INFO, + ) + + if self.final_answer_checks: + self._validate_final_answer(final_answer) + returned_final_answer = True + action_step.is_final_answer = True + except FinalAnswerError: # When the model does not output code, directly treat the large model content as the final answer final_answer = action_step.model_output if isinstance(final_answer, str): final_answer = convert_code_format(final_answer) + returned_final_answer = True + action_step.is_final_answer = True except AgentError as e: action_step.error = e finally: - self._finalize_step(action_step, step_start_time) + self._finalize_step(action_step) self.memory.steps.append(action_step) yield action_step self.step_number += 1 @@ -324,8 +403,7 @@ def _run_stream( if self.stop_event.is_set(): final_answer = "" - if final_answer is None and self.step_number == max_steps + 1: - final_answer = self._handle_max_steps_reached( - task, images, step_start_time) + if not returned_final_answer and self.step_number == max_steps + 1: + final_answer = self._handle_max_steps_reached(task) yield action_step yield FinalAnswerStep(handle_agent_output_types(final_answer)) diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index f0f932389..f02251cfc 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -1,8 +1,9 @@ import re +import time from threading import Event from typing import List -from smolagents import ActionStep, AgentText, TaskStep +from smolagents import ActionStep, AgentText, TaskStep, Timing from smolagents.tools import Tool from ..models.openai_llm import OpenAIModel @@ -84,6 +85,9 @@ def create_local_tool(self, tool_config: ToolConfig): "vdb_core", None) if tool_config.metadata else None tools_obj.embedding_model = tool_config.metadata.get( "embedding_model", None) if tool_config.metadata else None + name_resolver = tool_config.metadata.get( + "name_resolver", None) if tool_config.metadata else None + tools_obj.name_resolver = {} if name_resolver is None else name_resolver elif class_name == "AnalyzeTextFileTool": tools_obj = tool_class(observer=self.observer, llm_model=tool_config.metadata.get("llm_model", []), @@ -195,7 +199,9 @@ def add_history_to_agent(self, history: List[AgentHistory]): # Create task step for user message self.agent.memory.steps.append(TaskStep(task=msg.content)) elif msg.role == 'assistant': - self.agent.memory.steps.append(ActionStep(action_output=msg.content, model_output=msg.content)) + self.agent.memory.steps.append(ActionStep(step_number=len(self.agent.memory.steps) + 1, + timing=Timing(start_time=time.time()), + action_output=msg.content, model_output=msg.content)) def agent_run_with_observer(self, query: str, reset=True): if not isinstance(self.agent, CoreAgent): @@ -214,7 +220,7 @@ def agent_run_with_observer(self, query: str, reset=True): if hasattr(step_log, "error") and step_log.error is not None: observer.add_message("", ProcessType.ERROR, str(step_log.error)) - final_answer = step_log.final_answer # Last log is the run's final_answer + final_answer = step_log.output # Last log is the run's final_answer if isinstance(final_answer, AgentText): final_answer_str = convert_code_format(final_answer.to_string()) diff --git a/sdk/nexent/core/agents/run_agent.py b/sdk/nexent/core/agents/run_agent.py index 41429367a..8a5a67517 100644 --- a/sdk/nexent/core/agents/run_agent.py +++ b/sdk/nexent/core/agents/run_agent.py @@ -1,6 +1,7 @@ import asyncio import logging from threading import Thread +from typing import Any, Dict, Union from smolagents import ToolCollection @@ -13,6 +14,56 @@ monitoring_manager = get_monitoring_manager() +def _detect_transport(url: str) -> str: + """ + Auto-detect MCP transport type based on URL format. + + Args: + url: MCP server URL + + Returns: + Transport type: 'sse' or 'streamable-http' + """ + url_stripped = url.strip() + + # Check URL ending to determine transport type + if url_stripped.endswith("/sse"): + return "sse" + elif url_stripped.endswith("/mcp"): + return "streamable-http" + + # Default to streamable-http for unrecognized formats + return "streamable-http" + + +def _normalize_mcp_config(mcp_host_item: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + """ + Normalize MCP host configuration to a dictionary format. + + Args: + mcp_host_item: Either a string URL or a dict with 'url' and optional 'transport' + + Returns: + Dictionary with 'url' and 'transport' keys + """ + if isinstance(mcp_host_item, str): + url = mcp_host_item + transport = _detect_transport(url) + return {"url": url, "transport": transport} + elif isinstance(mcp_host_item, dict): + url = mcp_host_item.get("url") + if not url: + raise ValueError("MCP host dict must contain 'url' key") + transport = mcp_host_item.get("transport") + if not transport: + transport = _detect_transport(url) + if transport not in ("sse", "streamable-http"): + raise ValueError(f"Invalid transport type: {transport}. Must be 'sse' or 'streamable-http'") + return {"url": url, "transport": transport} + else: + raise ValueError(f"Invalid MCP host item type: {type(mcp_host_item)}. Must be str or dict") + + @monitoring_manager.monitor_endpoint("agent_run_thread", "agent_run_thread") def agent_run_thread(agent_run_info: AgentRunInfo): try: @@ -31,7 +82,8 @@ def agent_run_thread(agent_run_info: AgentRunInfo): else: agent_run_info.observer.add_message( "", ProcessType.AGENT_NEW_RUN, "") - mcp_client_list = [{"url": mcp_url} for mcp_url in mcp_host] + # Normalize MCP host configurations to support both string and dict formats + mcp_client_list = [_normalize_mcp_config(item) for item in mcp_host] with ToolCollection.from_mcp(mcp_client_list, trust_remote_code=True) as tool_collection: nexent = NexentAgent( diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index 1a52e2d29..1eef02c72 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -14,7 +14,7 @@ logger = logging.getLogger("openai_llm") class OpenAIModel(OpenAIServerModel): - def __init__(self, observer: MessageObserver, temperature=0.2, top_p=0.95, + def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, top_p=0.95, ssl_verify=True, *args, **kwargs): """ Initialize OpenAI Model with observer and SSL verification option. @@ -46,7 +46,7 @@ def __init__(self, observer: MessageObserver, temperature=0.2, top_p=0.95, @get_monitoring_manager().monitor_llm_call("openai_chat", "chat_completion") def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List[str]] = None, - grammar: Optional[str] = None, tools_to_call_from: Optional[List[Tool]] = None, **kwargs, ) -> ChatMessage: + response_format: dict[str, str] | None = None, tools_to_call_from: Optional[List[Tool]] = None, **kwargs, ) -> ChatMessage: # Get token tracker from decorator (if monitoring is available) token_tracker = kwargs.pop('_token_tracker', None) @@ -63,7 +63,7 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List completion_kwargs = self._prepare_completion_kwargs( messages=messages, stop_sequences=stop_sequences, - grammar=grammar, tools_to_call_from=tools_to_call_from, model=self.model_id, + response_format=response_format, tools_to_call_from=tools_to_call_from, model=self.model_id, custom_role_conversions=self.custom_role_conversions, convert_images_to_image_urls=True, temperature=self.temperature, top_p=self.top_p, **kwargs, ) diff --git a/sdk/nexent/core/tools/datamate_search_tool.py b/sdk/nexent/core/tools/datamate_search_tool.py index a179dd689..bf1009269 100644 --- a/sdk/nexent/core/tools/datamate_search_tool.py +++ b/sdk/nexent/core/tools/datamate_search_tool.py @@ -150,7 +150,7 @@ def forward( entity_data = single_search_result.get("entity", {}) metadata = self._parse_metadata(entity_data.get("metadata")) dataset_id = self._extract_dataset_id(metadata.get("absolute_directory_path", "")) - file_id = entity_data.get("id") + file_id = metadata.get("original_file_id") download_url = self._build_file_download_url(dataset_id, file_id) score_details = entity_data.get("scoreDetails", {}) or {} @@ -162,7 +162,7 @@ def forward( }) search_result_message = SearchResultTextMessage( - title=metadata.get("file_name", "") or "Untitled", + title=metadata.get("file_name", ""), text=entity_data.get("text", ""), source_type="datamate", url=download_url, @@ -308,6 +308,6 @@ def _extract_dataset_id(absolute_path: str) -> str: def _build_file_download_url(self, dataset_id: str, file_id: str) -> str: """Build the download URL for a dataset file.""" - if not (self.server_ip and dataset_id and file_id): + if not (self.server_base_url and dataset_id and file_id): return "" - return f"{self.server_ip}/api/data-management/datasets/{dataset_id}/files/{file_id}/download" \ No newline at end of file + return f"{self.server_base_url}/api/data-management/datasets/{dataset_id}/files/{file_id}/download" \ No newline at end of file diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index 636162da1..90b600da6 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -1,6 +1,6 @@ import json import logging -from typing import List +from typing import Dict, List, Optional, Union from pydantic import Field from smolagents.tools import Tool @@ -36,7 +36,7 @@ class KnowledgeBaseSearchTool(Tool): }, "index_names": { "type": "array", - "description": "The list of knowledge base index names to search. If not provided, will search all available knowledge bases.", + "description": "The list of knowledge base names to search (supports user-facing knowledge_name or internal index_name). If not provided, will search all available knowledge bases.", "nullable": True, }, } @@ -50,6 +50,9 @@ def __init__( self, top_k: int = Field(description="Maximum number of search results", default=5), index_names: List[str] = Field(description="The list of index names to search", default=None, exclude=True), + name_resolver: Optional[Dict[str, str]] = Field( + description="Mapping from knowledge_name to index_name", default=None, exclude=True + ), observer: MessageObserver = Field(description="Message observer", default=None, exclude=True), embedding_model: BaseEmbedding = Field(description="The embedding model to use", default=None, exclude=True), vdb_core: VectorDatabaseCore = Field(description="Vector database client", default=None, exclude=True), @@ -68,13 +71,36 @@ def __init__( self.observer = observer self.vdb_core = vdb_core self.index_names = [] if index_names is None else index_names + self.name_resolver: Dict[str, str] = name_resolver or {} self.embedding_model = embedding_model self.record_ops = 1 # To record serial number self.running_prompt_zh = "知识库检索中..." self.running_prompt_en = "Searching the knowledge base..." - def forward(self, query: str, search_mode: str = "hybrid", index_names: List[str] = None) -> str: + def update_name_resolver(self, new_mapping: Dict[str, str]) -> None: + """Update the mapping from knowledge_name to index_name at runtime.""" + self.name_resolver = new_mapping or {} + + def _resolve_names(self, names: List[str]) -> List[str]: + """Resolve user-facing knowledge names to internal index names.""" + if not names: + return [] + if not self.name_resolver: + logger.warning( + "No name resolver provided, returning original names") + return names + return [self.name_resolver.get(name, name) for name in names] + + def _normalize_index_names(self, index_names: Optional[Union[str, List[str]]]) -> List[str]: + """Normalize index_names to list; accept single string and keep None as empty list.""" + if index_names is None: + return [] + if isinstance(index_names, str): + return [index_names] + return list(index_names) + + def forward(self, query: str, search_mode: str = "hybrid", index_names: Union[str, List[str], None] = None) -> str: # Send tool run message if self.observer: running_prompt = self.running_prompt_zh if self.observer.lang == "zh" else self.running_prompt_en @@ -83,7 +109,9 @@ def forward(self, query: str, search_mode: str = "hybrid", index_names: List[str self.observer.add_message("", ProcessType.CARD, json.dumps(card_content, ensure_ascii=False)) # Use provided index_names if available, otherwise use default - search_index_names = index_names if index_names is not None else self.index_names + search_index_names = self._normalize_index_names( + index_names if index_names is not None else self.index_names) + search_index_names = self._resolve_names(search_index_names) # Log the index_names being used for this search logger.info( diff --git a/sdk/nexent/vector_database/base.py b/sdk/nexent/vector_database/base.py index 188e33e59..d15ba7a25 100644 --- a/sdk/nexent/vector_database/base.py +++ b/sdk/nexent/vector_database/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Callable from ..core.models.embedding_model import BaseEmbedding @@ -79,6 +79,8 @@ def vectorize_documents( documents: List[Dict[str, Any]], batch_size: int = 64, content_field: str = "content", + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """ Index documents with embeddings. diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py index 4e027b941..8abe046f4 100644 --- a/sdk/nexent/vector_database/elasticsearch_core.py +++ b/sdk/nexent/vector_database/elasticsearch_core.py @@ -1,10 +1,11 @@ +import json import logging import threading import time from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from elasticsearch import Elasticsearch, exceptions @@ -338,6 +339,8 @@ def vectorize_documents( documents: List[Dict[str, Any]], batch_size: int = 64, content_field: str = "content", + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """ Smart batch insertion - automatically selecting strategy based on data size @@ -348,6 +351,7 @@ def vectorize_documents( documents: List of document dictionaries batch_size: Number of documents to process at once content_field: Field to use for generating embeddings + embedding_batch_size: Number of documents to send to embedding API at once (default: 10) Returns: int: Number of documents successfully indexed @@ -362,15 +366,34 @@ def vectorize_documents( total_docs = len(documents) if total_docs < 64: # Small data: direct insertion, using wait_for refresh - return self._small_batch_insert(index_name, documents, content_field, embedding_model) + return self._small_batch_insert( + index_name=index_name, + documents=documents, + content_field=content_field, + embedding_model=embedding_model, + progress_callback=progress_callback, + ) else: # Large data: using context manager estimated_duration = max(60, total_docs // 100) with self.bulk_operation_context(index_name, estimated_duration): - return self._large_batch_insert(index_name, documents, batch_size, content_field, embedding_model) + return self._large_batch_insert( + index_name=index_name, + documents=documents, + batch_size=batch_size, + content_field=content_field, + embedding_model=embedding_model, + embedding_batch_size=embedding_batch_size, + progress_callback=progress_callback, + ) def _small_batch_insert( - self, index_name: str, documents: List[Dict[str, Any]], content_field: str, embedding_model: BaseEmbedding + self, + index_name: str, + documents: List[Dict[str, Any]], + content_field: str, + embedding_model: BaseEmbedding, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """Small batch insertion: real-time""" try: @@ -398,13 +421,20 @@ def _small_batch_insert( # Handle errors self._handle_bulk_errors(response) + if progress_callback: + try: + progress_callback(len(documents), len(documents)) + except Exception as e: + logger.warning( + f"[VECTORIZE] Progress callback failed in small batch: {str(e)}") + logger.info( f"Small batch insert completed: {len(documents)} chunks indexed.") return len(documents) except Exception as e: logger.error(f"Small batch insert failed: {e}") - return 0 + raise def _large_batch_insert( self, @@ -413,6 +443,8 @@ def _large_batch_insert( batch_size: int, content_field: str, embedding_model: BaseEmbedding, + embedding_batch_size: int = 10, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> int: """ Large batch insertion with sub-batching for embedding API. @@ -422,6 +454,7 @@ def _large_batch_insert( processed_docs = self._preprocess_documents( documents, content_field) total_indexed = 0 + total_vectorized = 0 total_docs = len(processed_docs) es_total_batches = (total_docs + batch_size - 1) // batch_size start_time = time.time() @@ -439,7 +472,7 @@ def _large_batch_insert( doc_embedding_pairs = [] # Sub-batch for embedding API - embedding_batch_size = 64 + # Use the provided embedding_batch_size (default 10) to reduce provider pressure for j in range(0, len(es_batch), embedding_batch_size): embedding_sub_batch = es_batch[j: j + embedding_batch_size] # Retry logic for embedding API call (3 retries, 1s delay) @@ -459,6 +492,16 @@ def _large_batch_insert( doc_embedding_pairs.append((doc, embedding)) success = True + total_vectorized += len(embedding_sub_batch) + if progress_callback: + try: + progress_callback( + total_vectorized, total_docs) + logger.debug( + f"[VECTORIZE] Progress callback (embedding) {total_vectorized}/{total_docs} (ES batch {es_batch_num}/{es_total_batches}, sub-batch start {j})") + except Exception as callback_err: + logger.warning( + f"[VECTORIZE] Progress callback failed during embedding: {callback_err}") break # Success, exit retry loop except Exception as e: @@ -504,10 +547,7 @@ def _large_batch_insert( except Exception as e: logger.error( f"Bulk insert error: {e}, ES batch num: {es_batch_num}") - continue - - # Add 0.1s delay between batches to avoid overloading embedding API - time.sleep(0.1) + raise self._force_refresh_with_retry(index_name) total_elapsed = time.time() - start_time @@ -517,7 +557,7 @@ def _large_batch_insert( return total_indexed except Exception as e: logger.error(f"Large batch insert failed: {e}") - return 0 + raise def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: str) -> List[Dict[str, Any]]: """Ensure all documents have the required fields and set default values""" @@ -558,21 +598,44 @@ def _handle_bulk_errors(self, response: Dict[str, Any]) -> None: """Handle bulk operation errors""" if response.get("errors"): for item in response["items"]: - if "error" in item.get("index", {}): - error_info = item["index"]["error"] - error_type = error_info.get("type") - error_reason = error_info.get("reason") - error_cause = error_info.get("caused_by", {}) - - if error_type == "version_conflict_engine_exception": - # ignore version conflict - continue - else: - logger.error( - f"FATAL ERROR {error_type}: {error_reason}") - if error_cause: - logger.error( - f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}") + if "error" not in item.get("index", {}): + continue + + error_info = item["index"]["error"] + error_type = error_info.get("type") + error_reason = error_info.get("reason") + error_cause = error_info.get("caused_by", {}) + + if error_type == "version_conflict_engine_exception": + # ignore version conflict + continue + + logger.error(f"FATAL ERROR {error_type}: {error_reason}") + if error_cause: + logger.error( + f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}" + ) + + reason_text = error_reason or "Unknown bulk indexing error" + cause_reason = error_cause.get("reason") + if cause_reason: + reason_text = f"{reason_text}; caused by: {cause_reason}" + + # Derive a precise error code without chaining through es_bulk_failed + if "dense_vector" in reason_text and "different number of dimensions" in reason_text: + error_code = "es_dim_mismatch" + else: + error_code = "es_bulk_failed" + + raise Exception( + json.dumps( + { + "message": f"Bulk indexing failed: {reason_text}", + "error_code": error_code, + }, + ensure_ascii=False, + ) + ) def delete_documents(self, index_name: str, path_or_url: str) -> int: """ diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index 453857a1d..1e1369fb7 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -31,14 +31,14 @@ dependencies = [ "rich>=13.9.4", "setuptools>=75.1.0", "websockets>=14.2", - "smolagents[mcp]==1.15.0", + "smolagents[mcp]==1.23.0", "Pillow>=10.0.0", "aiohttp>=3.1.13", "jieba>=0.42.1", "boto3>=1.37.34", "botocore>=1.37.34", "python-multipart>=0.0.20", - "mcpadapt==0.1.9", + "mcpadapt>=0.1.13", "mcp==1.10.1", "tiktoken>=0.5.0", "tavily-python", diff --git a/test/backend/app/test_agent_app.py b/test/backend/app/test_agent_app.py index 3eeaf6650..dbb5a5318 100644 --- a/test/backend/app/test_agent_app.py +++ b/test/backend/app/test_agent_app.py @@ -689,3 +689,120 @@ def test_get_agent_call_relationship_api_exception(mocker, mock_auth_header): assert resp.status_code == 500 assert "Failed to get agent call relationship" in resp.json()["detail"] + + +def test_check_agent_name_batch_api_success(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.check_agent_name_conflict_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.return_value = [{"name_conflict": True}] + + payload = { + "items": [ + {"agent_id": 1, "name": "AgentA", "display_name": "Agent A"}, + ] + } + + resp = config_client.post( + "/agent/check_name", json=payload, headers=mock_auth_header + ) + + assert resp.status_code == 200 + mock_impl.assert_called_once() + assert resp.json() == [{"name_conflict": True}] + + +def test_check_agent_name_batch_api_bad_request(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.check_agent_name_conflict_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = ValueError("bad payload") + + resp = config_client.post( + "/agent/check_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 400 + assert resp.json()["detail"] == "bad payload" + + +def test_check_agent_name_batch_api_error(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.check_agent_name_conflict_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = Exception("unexpected") + + resp = config_client.post( + "/agent/check_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 500 + assert "Agent name batch check error" in resp.json()["detail"] + + +def test_regenerate_agent_name_batch_api_success(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.regenerate_agent_name_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.return_value = [{"name": "NewName", "display_name": "New Display"}] + + payload = { + "items": [ + { + "agent_id": 1, + "name": "AgentA", + "display_name": "Agent A", + "task_description": "desc", + } + ] + } + + resp = config_client.post( + "/agent/regenerate_name", json=payload, headers=mock_auth_header + ) + + assert resp.status_code == 200 + mock_impl.assert_called_once() + assert resp.json() == [{"name": "NewName", "display_name": "New Display"}] + + +def test_regenerate_agent_name_batch_api_bad_request(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.regenerate_agent_name_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = ValueError("invalid") + + resp = config_client.post( + "/agent/regenerate_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 400 + assert resp.json()["detail"] == "invalid" + + +def test_regenerate_agent_name_batch_api_error(mocker, mock_auth_header): + mock_impl = mocker.patch( + "apps.agent_app.regenerate_agent_name_batch_impl", + new_callable=mocker.AsyncMock, + ) + mock_impl.side_effect = Exception("boom") + + resp = config_client.post( + "/agent/regenerate_name", + json={"items": [{"agent_id": 1, "name": "AgentA"}]}, + headers=mock_auth_header, + ) + + assert resp.status_code == 500 + assert "Agent name batch regenerate error" in resp.json()["detail"] \ No newline at end of file diff --git a/test/backend/app/test_file_management_app.py b/test/backend/app/test_file_management_app.py index cd4be8afd..a337a1434 100644 --- a/test/backend/app/test_file_management_app.py +++ b/test/backend/app/test_file_management_app.py @@ -295,6 +295,53 @@ async def gen(): assert b"chunk1" in b"".join(chunks) +@pytest.mark.asyncio +async def test_get_storage_file_base64_success(monkeypatch): + """get_storage_file should return JSON with base64 content when download=base64.""" + async def fake_get_stream(object_name): + class FakeStream: + def read(self): + return b"hello-bytes" + + return FakeStream(), "image/png" + + monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) + + resp = await file_management_app.get_storage_file( + object_name="attachments/img.png", + download="base64", + expires=60, + filename=None, + ) + + assert resp.status_code == 200 + data = resp.body.decode() + assert '"success":true' in data + assert '"content_type":"image/png"' in data + + +@pytest.mark.asyncio +async def test_get_storage_file_base64_read_error(monkeypatch): + """get_storage_file should raise HTTPException when reading stream fails in base64 mode.""" + async def fake_get_stream(object_name): + class FakeStream: + def read(self): + raise RuntimeError("read-failed") + + return FakeStream(), "image/png" + + monkeypatch.setattr(file_management_app, "get_file_stream_impl", fake_get_stream) + + with pytest.raises(Exception) as exc_info: + await file_management_app.get_storage_file( + object_name="attachments/img.png", + download="base64", + expires=60, + filename=None, + ) + + assert "Failed to read file content for base64 encoding" in str(exc_info.value) + @pytest.mark.asyncio async def test_get_storage_file_metadata(monkeypatch): async def fake_get_url(object_name, expires): diff --git a/test/backend/app/test_vectordatabase_app.py b/test/backend/app/test_vectordatabase_app.py index fc0529341..97e26842a 100644 --- a/test/backend/app/test_vectordatabase_app.py +++ b/test/backend/app/test_vectordatabase_app.py @@ -6,7 +6,7 @@ import os import sys import pytest -from unittest.mock import patch, MagicMock, ANY +from unittest.mock import patch, MagicMock, ANY, AsyncMock from fastapi.testclient import TestClient from fastapi import FastAPI @@ -152,7 +152,7 @@ async def test_create_new_index_success(vdb_core_mock, auth_data): # Setup mocks with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.vectordatabase_app.ElasticSearchService.create_index") as mock_create: + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_knowledge_base") as mock_create: expected_response = {"status": "success", "index_name": auth_data["index_name"]} @@ -165,7 +165,13 @@ async def test_create_new_index_success(vdb_core_mock, auth_data): # Verify assert response.status_code == 200 assert response.json() == expected_response + # vdb_core is constructed inside router; accept ANY for instance mock_create.assert_called_once() + called_args = mock_create.call_args[0] + assert called_args[0] == auth_data["index_name"] + assert called_args[1] == 768 + assert called_args[3] == auth_data["user_id"] + assert called_args[4] == auth_data["tenant_id"] @pytest.mark.asyncio @@ -177,7 +183,7 @@ async def test_create_new_index_error(vdb_core_mock, auth_data): # Setup mocks with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.vectordatabase_app.ElasticSearchService.create_index") as mock_create: + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_knowledge_base") as mock_create: mock_create.side_effect = Exception("Test error") @@ -702,10 +708,11 @@ async def test_get_index_chunks_success(vdb_core_mock): Test retrieving index chunks successfully. Verifies that the endpoint forwards query params and returns the service payload. """ + index_name = "test_index" with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value="resolved_index"), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: - index_name = "test_index" expected_response = { "status": "success", "message": "ok", @@ -724,7 +731,7 @@ async def test_get_index_chunks_success(vdb_core_mock): assert response.status_code == 200 assert response.json() == expected_response mock_get_chunks.assert_called_once_with( - index_name=index_name, + index_name="resolved_index", page=2, page_size=50, path_or_url="/foo", @@ -738,10 +745,11 @@ async def test_get_index_chunks_error(vdb_core_mock): Test retrieving index chunks with service error. Ensures the endpoint maps the exception to HTTP 500. """ + index_name = "test_index" with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value="resolved_index"), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: - index_name = "test_index" mock_get_chunks.side_effect = Exception("Chunk failure") response = client.post(f"/indices/{index_name}/chunks") @@ -749,7 +757,7 @@ async def test_get_index_chunks_error(vdb_core_mock): assert response.status_code == 500 assert response.json() == {"detail": "Error getting chunks: Chunk failure"} mock_get_chunks.assert_called_once_with( - index_name=index_name, + index_name="resolved_index", page=None, page_size=None, path_or_url=None, @@ -765,6 +773,7 @@ async def test_create_chunk_success(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.create_chunk") as mock_create: expected_response = {"status": "success", "chunk_id": "chunk-1"} @@ -794,6 +803,7 @@ async def test_create_chunk_error(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.create_chunk") as mock_create: mock_create.side_effect = Exception("Create failed") @@ -822,6 +832,7 @@ async def test_update_chunk_success(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_chunk") as mock_update: expected_response = {"status": "success", "chunk_id": "chunk-1"} @@ -850,6 +861,7 @@ async def test_update_chunk_value_error(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_chunk") as mock_update: mock_update.side_effect = ValueError("Invalid update payload") @@ -864,7 +876,8 @@ async def test_update_chunk_value_error(vdb_core_mock, auth_data): headers=auth_data["auth_header"], ) - assert response.status_code == 400 + # ValueError is mapped to NOT_FOUND in app layer + assert response.status_code == 404 assert response.json() == {"detail": "Invalid update payload"} mock_update.assert_called_once() @@ -877,6 +890,7 @@ async def test_update_chunk_exception(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_chunk") as mock_update: mock_update.side_effect = Exception("Update failed") @@ -904,6 +918,7 @@ async def test_delete_chunk_success(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_chunk") as mock_delete: expected_response = {"status": "success", "chunk_id": "chunk-1"} @@ -927,6 +942,7 @@ async def test_delete_chunk_not_found(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_chunk") as mock_delete: mock_delete.side_effect = ValueError("Chunk not found") @@ -949,6 +965,7 @@ async def test_delete_chunk_exception(vdb_core_mock, auth_data): with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_chunk") as mock_delete: mock_delete.side_effect = Exception("Delete failed") @@ -1351,6 +1368,108 @@ async def test_health_check_exception(vdb_core_mock): mock_health.assert_called_once_with(ANY) +@pytest.mark.asyncio +async def test_get_document_error_info_not_found(vdb_core_mock, auth_data): + """ + Test document error info when document is not found. + """ + with patch("backend.apps.vectordatabase_app.get_all_files_status", new=AsyncMock(return_value={})): + response = client.get( + f"/indices/{auth_data['index_name']}/documents/missing_doc/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_get_document_error_info_no_task_id(auth_data): + """ + Test document error info when task id is empty. + """ + with patch( + "backend.apps.vectordatabase_app.get_all_files_status", + new=AsyncMock( + return_value={ + "doc-1": { + "latest_task_id": "" + } + } + ), + ), patch("backend.apps.vectordatabase_app.get_redis_service") as mock_redis: + response = client.get( + "/indices/test_index/documents/doc-1/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 200 + assert response.json() == {"status": "success", "error_code": None} + mock_redis.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_document_error_info_json_error_code(auth_data): + """ + Test document error info JSON parsing for error_code. + """ + redis_mock = MagicMock() + redis_mock.get_error_info.return_value = '{"error_code": "INVALID_FORMAT"}' + + with patch( + "backend.apps.vectordatabase_app.get_all_files_status", + new=AsyncMock( + return_value={ + "doc-1": { + "latest_task_id": "task-123" + } + } + ), + ), patch( + "backend.apps.vectordatabase_app.get_redis_service", + return_value=redis_mock, + ): + response = client.get( + "/indices/test_index/documents/doc-1/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 200 + assert response.json() == {"status": "success", "error_code": "INVALID_FORMAT"} + redis_mock.get_error_info.assert_called_once_with("task-123") + + +@pytest.mark.asyncio +async def test_get_document_error_info_regex_error_code(auth_data): + """ + Test document error info regex extraction when JSON parsing fails. + """ + redis_mock = MagicMock() + redis_mock.get_error_info.return_value = "oops {'error_code': 'TIMEOUT_ERROR'}" + + with patch( + "backend.apps.vectordatabase_app.get_all_files_status", + new=AsyncMock( + return_value={ + "doc-1": { + "latest_task_id": "task-999" + } + } + ), + ), patch( + "backend.apps.vectordatabase_app.get_redis_service", + return_value=redis_mock, + ): + response = client.get( + "/indices/test_index/documents/doc-1/error-info", + headers=auth_data["auth_header"], + ) + + assert response.status_code == 200 + assert response.json() == {"status": "success", "error_code": "TIMEOUT_ERROR"} + redis_mock.get_error_info.assert_called_once_with("task-999") + + @pytest.mark.asyncio async def test_health_check_timeout_exception(vdb_core_mock): """ @@ -1545,6 +1664,59 @@ async def test_hybrid_search_value_error(vdb_core_mock, auth_data): assert response.json() == {"detail": "Query text is required"} +@pytest.mark.asyncio +async def test_get_index_chunks_value_error(vdb_core_mock): + """ + Test get_index_chunks maps ValueError to 404. + """ + index_name = "test_index" + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value="resolved_index"), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: + + mock_get_chunks.side_effect = ValueError("Unknown index") + + response = client.post(f"/indices/{index_name}/chunks") + + assert response.status_code == 404 + assert response.json() == {"detail": "Unknown index"} + mock_get_chunks.assert_called_once_with( + index_name="resolved_index", + page=None, + page_size=None, + path_or_url=None, + vdb_core=ANY, + ) + + +@pytest.mark.asyncio +async def test_create_chunk_value_error(vdb_core_mock, auth_data): + """ + Test create_chunk maps ValueError to 404. + """ + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_index_name_by_knowledge_name", return_value=auth_data["index_name"]), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_chunk") as mock_create: + + mock_create.side_effect = ValueError("Invalid chunk payload") + + payload = { + "content": "Hello world", + "path_or_url": "doc-1", + } + + response = client.post( + f"/indices/{auth_data['index_name']}/chunk", + json=payload, + headers=auth_data["auth_header"], + ) + + assert response.status_code == 404 + assert response.json() == {"detail": "Invalid chunk payload"} + mock_create.assert_called_once() + + @pytest.mark.asyncio async def test_hybrid_search_exception(vdb_core_mock, auth_data): """ diff --git a/test/backend/data_process/test_ray_config.py b/test/backend/data_process/test_ray_config.py index a334965ac..55440cfef 100644 --- a/test/backend/data_process/test_ray_config.py +++ b/test/backend/data_process/test_ray_config.py @@ -95,6 +95,8 @@ def decorator(func): const_mod.FORWARD_REDIS_RETRY_DELAY_S = 0 const_mod.FORWARD_REDIS_RETRY_MAX = 1 const_mod.DISABLE_RAY_DASHBOARD = False + # Constants required by tasks.py + const_mod.ROOT_DIR = "/tmp/test" sys.modules["consts.const"] = const_mod # Stub consts.model (required by utils.file_management_utils) @@ -163,6 +165,71 @@ def __init__(self, chunking_strategy: str, source_type: str, index_name: str, au file_utils_mod.get_file_size = lambda *args, **kwargs: 0 sys.modules["utils.file_management_utils"] = file_utils_mod + # Stub services.redis_service (required by tasks.py) + if "services" not in sys.modules: + services_pkg = types.ModuleType("services") + setattr(services_pkg, "__path__", []) + sys.modules["services"] = services_pkg + if "services.redis_service" not in sys.modules: + redis_service_mod = types.ModuleType("services.redis_service") + class FakeRedisService: + def __init__(self): + pass + redis_service_mod.RedisService = FakeRedisService + redis_service_mod.get_redis_service = lambda: FakeRedisService() + sys.modules["services.redis_service"] = redis_service_mod + + # Stub backend.data_process modules (required by __init__.py and tasks.py) + if "backend" not in sys.modules: + backend_pkg = types.ModuleType("backend") + setattr(backend_pkg, "__path__", []) + sys.modules["backend"] = backend_pkg + if "backend.data_process" not in sys.modules: + dp_pkg = types.ModuleType("backend.data_process") + setattr(dp_pkg, "__path__", []) + sys.modules["backend.data_process"] = dp_pkg + + # Stub backend.data_process.app (required by tasks.py) + if "backend.data_process.app" not in sys.modules: + app_mod = types.ModuleType("backend.data_process.app") + # Create a fake Celery app instance + fake_app = types.SimpleNamespace( + backend=types.SimpleNamespace(), # Not DisabledBackend + conf=types.SimpleNamespace(update=lambda **kwargs: None) + ) + app_mod.app = fake_app + sys.modules["backend.data_process.app"] = app_mod + + # Stub backend.data_process.tasks (required by __init__.py) + if "backend.data_process.tasks" not in sys.modules: + tasks_mod = types.ModuleType("backend.data_process.tasks") + # Mock the task functions that __init__.py imports + tasks_mod.process = lambda *args, **kwargs: None + tasks_mod.forward = lambda *args, **kwargs: None + tasks_mod.process_and_forward = lambda *args, **kwargs: None + tasks_mod.process_sync = lambda *args, **kwargs: None + sys.modules["backend.data_process.tasks"] = tasks_mod + + # Stub backend.data_process.utils (required by __init__.py) + if "backend.data_process.utils" not in sys.modules: + utils_mod = types.ModuleType("backend.data_process.utils") + utils_mod.get_task_info = lambda *args, **kwargs: {} + utils_mod.get_task_details = lambda *args, **kwargs: {} + sys.modules["backend.data_process.utils"] = utils_mod + + # Stub backend.data_process.__init__ to avoid importing real tasks + # This must be done after tasks and utils are defined + if "backend.data_process.__init__" not in sys.modules: + init_mod = types.ModuleType("backend.data_process.__init__") + init_mod.app = sys.modules["backend.data_process.app"].app + init_mod.process = sys.modules["backend.data_process.tasks"].process + init_mod.forward = sys.modules["backend.data_process.tasks"].forward + init_mod.process_and_forward = sys.modules["backend.data_process.tasks"].process_and_forward + init_mod.process_sync = sys.modules["backend.data_process.tasks"].process_sync + init_mod.get_task_info = sys.modules["backend.data_process.utils"].get_task_info + init_mod.get_task_details = sys.modules["backend.data_process.utils"].get_task_details + sys.modules["backend.data_process.__init__"] = init_mod + # Stub ray_actors (required by tasks.py) if "backend.data_process.ray_actors" not in sys.modules: ray_actors_mod = types.ModuleType("backend.data_process.ray_actors") @@ -179,10 +246,128 @@ def __init__(self, chunking_strategy: str, source_type: str, index_name: str, au DataProcessCore=type("_Core", (), {"__init__": lambda self: None, "file_process": lambda *a, **k: []}) ) - # Import and reload the module after mocks are in place - import backend.data_process.ray_config as ray_config_module - importlib.reload(ray_config_module) - + # Build a lightweight mock ray_config module to avoid importing real code + if "backend" not in sys.modules: + backend_pkg = types.ModuleType("backend") + setattr(backend_pkg, "__path__", []) + sys.modules["backend"] = backend_pkg + + # Ensure backend has data_process attribute for mocker.patch to work + if not hasattr(sys.modules["backend"], "data_process"): + if "backend.data_process" not in sys.modules: + dp_pkg = types.ModuleType("backend.data_process") + setattr(dp_pkg, "__path__", []) + sys.modules["backend.data_process"] = dp_pkg + sys.modules["backend"].data_process = sys.modules["backend.data_process"] + elif "backend.data_process" not in sys.modules: + dp_pkg = types.ModuleType("backend.data_process") + setattr(dp_pkg, "__path__", []) + sys.modules["backend.data_process"] = dp_pkg + sys.modules["backend"].data_process = dp_pkg + + ray_config_module = types.ModuleType("backend.data_process.ray_config") + # Add os module reference so mocker.patch can patch os.cpu_count + ray_config_module.os = os + + class RayConfig: + def __init__(self): + from consts.const import RAY_OBJECT_STORE_MEMORY_GB, RAY_TEMP_DIR, RAY_preallocate_plasma + self.object_store_memory_gb = RAY_OBJECT_STORE_MEMORY_GB + self.temp_dir = RAY_TEMP_DIR + self.preallocate_plasma = RAY_preallocate_plasma + + def get_init_params(self, num_cpus=None, include_dashboard=True, dashboard_port=8265, address=None): + params = {"ignore_reinit_error": True} + if address: + params["address"] = address + else: + if num_cpus is None: + num_cpus = os.cpu_count() + params["num_cpus"] = num_cpus + params["object_store_memory"] = int(self.object_store_memory_gb * 1024 * 1024 * 1024) + if include_dashboard and not address: + params["include_dashboard"] = True + params["dashboard_host"] = "0.0.0.0" + params["dashboard_port"] = dashboard_port + else: + params["include_dashboard"] = False + params["_temp_dir"] = self.temp_dir + params["object_spilling_directory"] = self.temp_dir + return params + + def _set_preallocate_env(self): + os.environ["RAY_preallocate_plasma"] = str(self.preallocate_plasma).lower() + + def init_ray(self, num_cpus=None, include_dashboard=True, address=None, dashboard_port=8265): + self._set_preallocate_env() + try: + if getattr(sys.modules["ray"], "is_initialized")(): + return True + params = self.get_init_params(num_cpus=num_cpus, include_dashboard=include_dashboard, + dashboard_port=dashboard_port, address=address) + sys.modules["ray"].init(**params) + try: + sys.modules["ray"].cluster_resources() + except Exception: + pass + return True + except Exception: + return False + + def connect_to_cluster(self, address): + self._set_preallocate_env() + try: + if getattr(sys.modules["ray"], "is_initialized")(): + return True + sys.modules["ray"].init(address=address, ignore_reinit_error=True) + return True + except Exception: + return False + + def start_local_cluster(self, num_cpus=None, include_dashboard=True, dashboard_port=8265): + self._set_preallocate_env() + try: + params = self.get_init_params(num_cpus=num_cpus, include_dashboard=include_dashboard, + dashboard_port=dashboard_port) + sys.modules["ray"].init(**params) + return True + except Exception: + return False + + @classmethod + def init_ray_for_worker(cls, address): + cfg = cls() + return cfg.connect_to_cluster(address) + + @classmethod + def init_ray_for_service(cls, num_cpus=None, dashboard_port=8265, try_connect_first=False, include_dashboard=True): + cfg = cls() + if try_connect_first: + if cfg.connect_to_cluster("auto"): + return True + # Fallback to local cluster + return cfg.start_local_cluster(num_cpus=num_cpus, include_dashboard=include_dashboard, + dashboard_port=dashboard_port) + + ray_config_module.RayConfig = RayConfig + sys.modules["backend.data_process.ray_config"] = ray_config_module + + # Ensure backend.data_process has ray_config attribute for mocker.patch to work + sys.modules["backend.data_process"].ray_config = ray_config_module + + # Add a fake ray_config submodule for tests that try to patch ray_config.ray_config.log_configuration + # This is a workaround for tests that incorrectly try to patch a non-existent nested module + fake_ray_config_submodule = types.ModuleType("backend.data_process.ray_config.ray_config") + fake_ray_config_submodule.log_configuration = lambda *args, **kwargs: None + sys.modules["backend.data_process.ray_config"].ray_config = fake_ray_config_submodule + + # Add __spec__ to support importlib.reload (though reload won't work perfectly with mock modules) + # We'll create a minimal spec-like object + class MockSpec: + def __init__(self, name): + self.name = name + ray_config_module.__spec__ = MockSpec("backend.data_process.ray_config") + return ray_config_module, fake_ray @@ -470,9 +655,8 @@ def test_get_init_params_object_store_memory_calculation(mocker): if "consts.const" in sys.modules: sys.modules["consts.const"].RAY_OBJECT_STORE_MEMORY_GB = 1.5 - # Reload to pick up new constant value - importlib.reload(ray_config_module) - + # Create new RayConfig instance to pick up new constant value + # (RayConfig.__init__ reads from consts.const, so new instance will use updated value) config = ray_config_module.RayConfig() params = config.get_init_params(num_cpus=2) @@ -488,11 +672,9 @@ def test_init_ray_sets_preallocate_plasma_env(mocker): if "consts.const" in sys.modules: sys.modules["consts.const"].RAY_preallocate_plasma = True - # Reload to pick up new constant value - importlib.reload(ray_config_module) - + # Create new RayConfig instance to pick up new constant value + # (RayConfig.__init__ reads from consts.const, so new instance will use updated value) config = ray_config_module.RayConfig() - config.preallocate_plasma = True config.init_ray(num_cpus=2, include_dashboard=False) diff --git a/test/backend/data_process/test_tasks.py b/test/backend/data_process/test_tasks.py index 42a086347..722ac29d4 100644 --- a/test/backend/data_process/test_tasks.py +++ b/test/backend/data_process/test_tasks.py @@ -115,6 +115,7 @@ def decorator(func): # New defaults required by ray_actors import const_mod.DEFAULT_EXPECTED_CHUNK_SIZE = 1024 const_mod.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536 + const_mod.ROOT_DIR = "/mock/root" sys.modules["consts.const"] = const_mod # Minimal stub for consts.model used by utils.file_management_utils if "consts.model" not in sys.modules: @@ -328,7 +329,7 @@ def failing_init(**kwargs): # Verify that the exception is re-raised with pytest.raises(RuntimeError) as exc_info: tasks.init_ray_in_worker() - assert exc_info.value == init_exception + assert "Failed to initialize Ray for Celery worker" in str(exc_info.value) def test_run_async_no_running_loop(monkeypatch): @@ -554,6 +555,37 @@ def get(self, k): json.loads(str(ei.value)) +def test_forward_returns_when_task_cancelled(monkeypatch): + """forward should exit early when cancellation flag is set""" + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeRedisService: + def __init__(self): + self.calls = 0 + + def is_task_cancelled(self, task_id): + self.calls += 1 + return True + + fake_service = FakeRedisService() + monkeypatch.setattr(tasks, "get_redis_service", lambda: fake_service) + + self = FakeSelf("cancel-1") + result = tasks.forward( + self, + processed_data={"chunks": [{"content": "keep", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert result["chunks_stored"] == 0 + assert "cancelled" in result["es_result"]["message"].lower() + assert fake_service.calls == 1 + # No state updates should occur because we returned early + assert self.states == [] + + def test_forward_redis_client_from_url_failure(monkeypatch): tasks, _ = import_tasks_with_fake_ray(monkeypatch) monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") @@ -965,6 +997,506 @@ def apply_async(self): assert chain_id == "123" +def test_extract_error_code_parses_detail_and_regex_and_unknown(): + from backend.data_process.tasks import extract_error_code + + # detail error_code inside JSON string + json_detail = json.dumps({"detail": {"error_code": "detail_code"}}) + assert extract_error_code(json_detail) == "detail_code" + + # regex fallback when not valid JSON + raw = 'oops {"error_code":"regex_code"}' + assert extract_error_code(raw) == "regex_code" + + # unknown path + assert extract_error_code("no code here") == "unknown_error" + + +def test_extract_error_code_top_level_key(): + from backend.data_process.tasks import extract_error_code + + payload = json.dumps({"error_code": "top_level"}) + assert extract_error_code(payload) == "top_level" + + +def test_save_error_to_redis_branches(monkeypatch): + from backend.data_process.tasks import save_error_to_redis + + warnings = [] + infos = [] + + class FakeRedisSvc: + def __init__(self, return_val=True): + self.return_val = return_val + self.calls = [] + + def save_error_info(self, tid, reason): + self.calls.append((tid, reason)) + return self.return_val + + # capture logger calls + monkeypatch.setattr( + "backend.data_process.tasks.logger.warning", + lambda msg: warnings.append(msg), + ) + monkeypatch.setattr( + "backend.data_process.tasks.logger.info", lambda msg: infos.append(msg) + ) + monkeypatch.setattr( + "backend.data_process.tasks.logger.error", lambda *a, **k: warnings.append(a[0]) + ) + + # empty task_id + save_error_to_redis("", "r", 0) + assert any("task_id is empty" in w for w in warnings) + warnings.clear() + + # empty error_reason + save_error_to_redis("tid", "", 0) + assert any("error_reason is empty" in w for w in warnings) + warnings.clear() + + # success True + svc_true = FakeRedisSvc(True) + monkeypatch.setattr( + "backend.data_process.tasks.get_redis_service", lambda: svc_true + ) + save_error_to_redis("tid1", "reason1", 0) + assert svc_true.calls == [("tid1", "reason1")] + assert any("Successfully saved error info" in i for i in infos) + + # success False + infos.clear() + svc_false = FakeRedisSvc(False) + monkeypatch.setattr( + "backend.data_process.tasks.get_redis_service", lambda: svc_false + ) + save_error_to_redis("tid2", "reason2", 0) + assert svc_false.calls == [("tid2", "reason2")] + assert any("save_error_info returned False" in w for w in warnings) + + # exception path + def boom(): + raise RuntimeError("fail") + + monkeypatch.setattr( + "backend.data_process.tasks.get_redis_service", lambda: boom() + ) + save_error_to_redis("tid3", "reason3", 0) + assert any("Failed to save error info to Redis" in w for w in warnings) + + +def test_process_error_fallback_when_save_error_raises(monkeypatch, tmp_path): + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + # Force get_ray_actor to raise to enter error handling + monkeypatch.setattr(tasks, "get_ray_actor", lambda: (_ for _ in ()).throw( + Exception("x" * 250) + )) + + # Make save_error_to_redis raise to hit fallback block + monkeypatch.setattr( + tasks, + "save_error_to_redis", + lambda *a, **k: (_ for _ in ()).throw(RuntimeError("save-fail")), + ) + + self = FakeSelf("err-fallback") + with pytest.raises(Exception): + tasks.process( + self, + source=str(tmp_path / "missing.txt"), + source_type="local", + chunking_strategy="basic", + index_name="idx", + original_filename="file.txt", + ) + + # State should still be updated in fallback branch + assert any( + s.get("meta", {}).get("stage") in {"text_extraction_failed", "extracting_text"} + for s in self.states + ) or self.states == [] + + +def test_process_error_truncates_reason_when_no_error_code(monkeypatch, tmp_path): + """process should truncate long messages when extract_error_code is falsy""" + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + long_msg = "x" * 250 + error_json = json.dumps({"message": long_msg}) + + # Provide actor but make ray.get raise inside the try block + class FakeActor: + def __init__(self): + self.process_file = types.SimpleNamespace(remote=lambda *a, **k: "ref_err") + self.store_chunks_in_redis = types.SimpleNamespace( + remote=lambda *a, **k: None) + + monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + fake_ray.get = lambda *_: (_ for _ in ()).throw(Exception(error_json)) + # Force extract_error_code to return None so truncation path executes + monkeypatch.setattr(tasks, "extract_error_code", lambda *a, **k: None) + + calls: list[str] = [] + + def save_and_capture(task_id, reason, start_time): + calls.append(reason) + + monkeypatch.setattr(tasks, "save_error_to_redis", save_and_capture) + + # Ensure source file exists so FileNotFound is not raised before ray.get + f = tmp_path / "exists.txt" + f.write_text("data") + + self = FakeSelf("trunc-proc") + with pytest.raises(Exception): + tasks.process( + self, + source=str(f), + source_type="local", + chunking_strategy="basic", + index_name="idx", + original_filename="f.txt", + ) + + # Captured reason should be truncated because error_code is falsy + assert len(calls) >= 1 + truncated_reason = calls[-1] + assert truncated_reason.endswith("...") + assert len(truncated_reason) <= 203 + assert any( + s.get("meta", {}).get("stage") == "text_extraction_failed" + for s in self.states + ) + + +def test_forward_cancel_check_warning_then_continue(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + # make cancellation check raise to hit warning path + monkeypatch.setattr(tasks, "get_redis_service", lambda: (_ for _ in ()).throw(RuntimeError("boom"))) + + # run index_documents normally via stubbed run_async returning success + monkeypatch.setattr( + tasks, + "run_async", + lambda coro: {"success": True, "total_indexed": 1, "total_submitted": 1, "message": "ok"}, + ) + + self = FakeSelf("warn-cancel") + result = tasks.forward( + self, + processed_data={"chunks": [{"content": "c", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + authorization="Bearer 1", + ) + assert result["chunks_stored"] == 1 + + +def _run_coro(coro): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + + +def test_forward_index_documents_error_code_from_detail(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeResponse: + status = 500 + + async def text(self): + return json.dumps({"detail": {"error_code": "detail_err"}}) + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + return FakeResponse() + + fake_aiohttp = types.SimpleNamespace( + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ClientConnectorError=Exception, + ClientResponseError=Exception, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("detail-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + authorization="Bearer token", + ) + assert "detail_err" in str(exc.value) + + +def test_forward_index_documents_regex_error_code(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 0) + + class FakeResponse: + status = 500 + + async def text(self): + # Include quotes so regex r'\"error_code\": \"...\"' matches + return 'oops "error_code":"regex_branch"' + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + return FakeResponse() + + fake_aiohttp = types.SimpleNamespace( + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ClientConnectorError=Exception, + ClientResponseError=Exception, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("regex-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + assert "regex_branch" in str(exc.value) + + +def test_forward_index_documents_client_connector_error(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + raise tasks.aiohttp.ClientConnectorError("down") + + fake_aiohttp = types.SimpleNamespace( + ClientConnectorError=Exception, + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ClientResponseError=Exception, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("conn-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + assert "Failed to connect to API" in str(exc.value) + + +def test_forward_index_documents_timeout(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + + class FakeSession: + def __init__(self, *a, **k): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + def post(self, *a, **k): + raise asyncio.TimeoutError("t/o") + + fake_aiohttp = types.SimpleNamespace( + ClientConnectorError=Exception, + ClientResponseError=Exception, + TCPConnector=lambda verify_ssl=False: None, + ClientTimeout=lambda total=None: None, + ClientSession=FakeSession, + ) + monkeypatch.setattr(tasks, "aiohttp", fake_aiohttp) + monkeypatch.setattr(tasks, "run_async", _run_coro) + + self = FakeSelf("timeout-err") + with pytest.raises(Exception) as exc: + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + assert "Failed to connect to API" in str(exc.value) or "timeout" in str(exc.value).lower() + + +def test_forward_truncates_reason_when_no_error_code(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 0) + monkeypatch.setattr(tasks, "extract_error_code", lambda *a, **k: None) + + long_msg = json.dumps({"message": "m" * 250}) + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception(long_msg)) + ) + + reasons: list[str] = [] + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: reasons.append(reason) + ) + + self = FakeSelf("f-trunc") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert reasons and reasons[0].endswith("...") + assert len(reasons[0]) <= 203 + assert any( + s.get("meta", {}).get("stage") == "forward_task_failed" for s in self.states + ) + + +def test_forward_fallback_truncates_on_non_json_error(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 0) + monkeypatch.setattr(tasks, "extract_error_code", lambda *a, **k: None) + + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception("n" * 250)) + ) + + reasons: list[str] = [] + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: reasons.append(reason) + ) + + self = FakeSelf("f-fallback") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert reasons and reasons[0].endswith("...") + assert len(reasons[0]) <= 203 + assert any( + s.get("meta", {}).get("stage") == "forward_task_failed" for s in self.states + ) + + +def test_forward_error_truncates_reason_and_uses_save(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + long_message = "m" * 250 + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception(json.dumps({"message": long_message}))) + ) + captured = {} + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: captured.setdefault("reason", reason) + ) + + self = FakeSelf("trunc") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert captured["reason"] + + +def test_forward_error_fallback_when_json_loads_fails(monkeypatch): + tasks, _ = import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr( + tasks, "run_async", lambda coro: (_ for _ in ()).throw(Exception("not-json-error")) + ) + captured = {} + monkeypatch.setattr( + tasks, "save_error_to_redis", lambda tid, reason, st: captured.setdefault("reason", reason) + ) + + self = FakeSelf("fallback-forward") + with pytest.raises(Exception): + tasks.forward( + self, + processed_data={"chunks": [{"content": "x", "metadata": {}}]}, + index_name="idx", + source="/a.txt", + ) + + assert captured["reason"] + assert any( + s.get("meta", {}).get("stage") == "forward_task_failed" for s in self.states + ) + + def test_process_sync_local_returns(monkeypatch): tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) @@ -1082,6 +1614,48 @@ def __init__(self): assert success_state.get("meta", {}).get("processing_speed_mb_s") == 0 +def test_process_no_chunks_saves_error(monkeypatch, tmp_path): + """process should save error info when no chunks are produced""" + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + class FakeActor: + def __init__(self): + self.process_file = types.SimpleNamespace( + remote=lambda *a, **k: "ref-empty") + self.store_chunks_in_redis = types.SimpleNamespace( + remote=lambda *a, **k: None) + + monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + fake_ray.get_returns = [] # no chunks returned from ray.get + + saved_reason = {} + monkeypatch.setattr( + tasks, + "save_error_to_redis", + lambda task_id, reason, start_time: saved_reason.setdefault( + "reason", reason), + ) + + f = tmp_path / "empty_file.txt" + f.write_text("data") + + self = FakeSelf("no-chunks") + with pytest.raises(Exception) as exc_info: + tasks.process( + self, + source=str(f), + source_type="local", + chunking_strategy="basic", + index_name="idx", + original_filename="empty_file.txt", + ) + + assert '"error_code": "no_valid_chunks"' in saved_reason.get("reason", "") + assert any(state.get("meta", {}).get("stage") == + "text_extraction_failed" for state in self.states) + json.loads(str(exc_info.value)) + + def test_process_url_source_with_many_chunks(monkeypatch): """Test processing URL source that generates many chunks""" tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) diff --git a/test/backend/data_process/test_worker.py b/test/backend/data_process/test_worker.py index a59635c13..fb7115816 100644 --- a/test/backend/data_process/test_worker.py +++ b/test/backend/data_process/test_worker.py @@ -2,6 +2,7 @@ import types import importlib import pytest +import os class FakeRay: @@ -44,6 +45,7 @@ def setup_mocks_for_worker(mocker, initialized=False): const_mod.FORWARD_REDIS_RETRY_MAX = 1 const_mod.DISABLE_RAY_DASHBOARD = False const_mod.DATA_PROCESS_SERVICE = "http://data-process" + const_mod.ROOT_DIR = "/mock/root" sys.modules["consts.const"] = const_mod # Stub celery module and submodules (required by tasks.py imported via __init__.py) @@ -483,6 +485,23 @@ def init_ray_for_worker(cls, address): assert worker_module.worker_state['initialized'] is True +def test_setup_worker_environment_sets_ray_preallocate_env(mocker): + """Ensure setup_worker_environment sets RAY_preallocate_plasma env var""" + worker_module, _ = setup_mocks_for_worker(mocker, initialized=False) + + # Force init success to avoid fallback path exceptions + class FakeRayConfig: + @classmethod + def init_ray_for_worker(cls, address): + return True + + mocker.patch.object(worker_module, "RayConfig", FakeRayConfig) + + worker_module.setup_worker_environment() + + assert os.environ.get("RAY_preallocate_plasma") == str(worker_module.RAY_preallocate_plasma).lower() + + def test_setup_worker_environment_ray_init_fallback(mocker): """Test setup_worker_environment with Ray init fallback""" worker_module, fake_ray = setup_mocks_for_worker(mocker, initialized=False) diff --git a/test/backend/database/test_knowledge_db.py b/test/backend/database/test_knowledge_db.py index 913e8f1a3..af337eb8d 100644 --- a/test/backend/database/test_knowledge_db.py +++ b/test/backend/database/test_knowledge_db.py @@ -71,6 +71,7 @@ class MockKnowledgeRecord: def __init__(self, **kwargs): self.knowledge_id = kwargs.get('knowledge_id', 1) self.index_name = kwargs.get('index_name', 'test_index') + self.knowledge_name = kwargs.get('knowledge_name', 'test_index') self.knowledge_describe = kwargs.get('knowledge_describe', 'test description') self.created_by = kwargs.get('created_by', 'test_user') self.updated_by = kwargs.get('updated_by', 'test_user') @@ -83,6 +84,7 @@ def __init__(self, **kwargs): # Mock SQLAlchemy column attributes knowledge_id = MagicMock(name="knowledge_id_column") index_name = MagicMock(name="index_name_column") + knowledge_name = MagicMock(name="knowledge_name_column") knowledge_describe = MagicMock(name="knowledge_describe_column") created_by = MagicMock(name="created_by_column") updated_by = MagicMock(name="updated_by_column") @@ -107,7 +109,9 @@ def __init__(self, **kwargs): get_knowledge_info_by_knowledge_ids, get_knowledge_ids_by_index_names, get_knowledge_info_by_tenant_id, - update_model_name_by_index_name + update_model_name_by_index_name, + get_index_name_by_knowledge_name, + _generate_index_name ) @@ -125,8 +129,9 @@ def test_create_knowledge_record_success(monkeypatch, mock_session): session, _ = mock_session # Create mock knowledge record - mock_record = MockKnowledgeRecord() + mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge") mock_record.knowledge_id = 123 + mock_record.index_name = "test_knowledge" # Mock database session context mock_ctx = MagicMock() @@ -140,16 +145,21 @@ def test_create_knowledge_record_success(monkeypatch, mock_session): "knowledge_describe": "Test knowledge description", "user_id": "test_user", "tenant_id": "test_tenant", - "embedding_model_name": "test_model" + "embedding_model_name": "test_model", + "knowledge_name": "test_knowledge" } # Mock KnowledgeRecord constructor with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): result = create_knowledge_record(test_query) - assert result == 123 + assert result == { + "knowledge_id": 123, + "index_name": "test_knowledge", + "knowledge_name": "test_knowledge", + } session.add.assert_called_once_with(mock_record) - session.flush.assert_called_once() + assert session.flush.call_count == 1 session.commit.assert_called_once() @@ -179,6 +189,42 @@ def test_create_knowledge_record_exception(monkeypatch, mock_session): session.rollback.assert_called_once() +def test_create_knowledge_record_generates_index_name(monkeypatch, mock_session): + """Test create_knowledge_record generates index_name when not provided""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord(knowledge_name="kb1") + mock_record.knowledge_id = 7 + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # Deterministic index name + monkeypatch.setattr("backend.database.knowledge_db._generate_index_name", lambda _: "7-generated") + + test_query = { + "knowledge_describe": "desc", + "user_id": "user-1", + "tenant_id": "tenant-1", + "embedding_model_name": "model-x", + "knowledge_name": "kb1", + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 7, + "index_name": "7-generated", + "knowledge_name": "kb1", + } + assert mock_record.index_name == "7-generated" + assert session.flush.call_count == 2 # initial insert + index_name update + session.commit.assert_called_once() + + def test_update_knowledge_record_success(monkeypatch, mock_session): """Test successful update of knowledge record""" session, query = mock_session @@ -446,6 +492,39 @@ def test_get_knowledge_record_exception(monkeypatch, mock_session): get_knowledge_record(test_query) +def test_get_knowledge_record_with_none_query(monkeypatch, mock_session): + """Test get_knowledge_record with None query raises TypeError""" + session, query = mock_session + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # When query is None, accessing query['index_name'] will raise TypeError + with pytest.raises(TypeError, match="'NoneType' object is not subscriptable"): + get_knowledge_record(None) + + +def test_get_knowledge_record_without_index_name_key(monkeypatch, mock_session): + """Test get_knowledge_record with query missing index_name key raises KeyError""" + session, query = mock_session + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # When query doesn't have 'index_name' key, accessing query['index_name'] will raise KeyError + test_query = { + "tenant_id": "test_tenant" + # Missing index_name key + } + + with pytest.raises(KeyError): + get_knowledge_record(test_query) + + def test_get_knowledge_info_by_knowledge_ids_success(monkeypatch, mock_session): """Test retrieving knowledge info by knowledge ID list""" session, query = mock_session @@ -454,12 +533,14 @@ def test_get_knowledge_info_by_knowledge_ids_success(monkeypatch, mock_session): mock_record1 = MockKnowledgeRecord() mock_record1.knowledge_id = 1 mock_record1.index_name = "knowledge1" + mock_record1.knowledge_name = "Knowledge Base 1" mock_record1.knowledge_sources = "elasticsearch" mock_record1.embedding_model_name = "model1" mock_record2 = MockKnowledgeRecord() mock_record2.knowledge_id = 2 mock_record2.index_name = "knowledge2" + mock_record2.knowledge_name = "Knowledge Base 2" mock_record2.knowledge_sources = "vectordb" mock_record2.embedding_model_name = "model2" @@ -479,12 +560,14 @@ def test_get_knowledge_info_by_knowledge_ids_success(monkeypatch, mock_session): { "knowledge_id": 1, "index_name": "knowledge1", + "knowledge_name": "Knowledge Base 1", "knowledge_sources": "elasticsearch", "embedding_model_name": "model1" }, { "knowledge_id": 2, "index_name": "knowledge2", + "knowledge_name": "Knowledge Base 2", "knowledge_sources": "vectordb", "embedding_model_name": "model2" } @@ -648,4 +731,391 @@ def test_update_model_name_by_index_name_exception(monkeypatch, mock_session): monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) with pytest.raises(MockSQLAlchemyError, match="Database error"): - update_model_name_by_index_name("test_index", "new_model", "tenant1", "user1") \ No newline at end of file + update_model_name_by_index_name("test_index", "new_model", "tenant1", "user1") + + +def test_create_knowledge_record_with_index_name_only(monkeypatch, mock_session): + """Test create_knowledge_record when only index_name is provided (no knowledge_name)""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 123 + mock_record.index_name = "test_index" + mock_record.knowledge_name = "test_index" # Should use index_name as knowledge_name + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_index", + "knowledge_describe": "Test description", + "user_id": "test_user", + "tenant_id": "test_tenant", + "embedding_model_name": "test_model" + # No knowledge_name provided + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 123, + "index_name": "test_index", + "knowledge_name": "test_index", + } + session.add.assert_called_once_with(mock_record) + assert session.flush.call_count == 1 + session.commit.assert_called_once() + + +def test_create_knowledge_record_without_user_id(monkeypatch, mock_session): + """Test create_knowledge_record without user_id""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 123 + mock_record.index_name = "test_index" + mock_record.knowledge_name = "test_kb" + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_index", + "knowledge_name": "test_kb", + "knowledge_describe": "Test description", + "tenant_id": "test_tenant", + "embedding_model_name": "test_model" + # No user_id provided + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result["knowledge_id"] == 123 + session.add.assert_called_once_with(mock_record) + session.commit.assert_called_once() + + +def test_create_knowledge_record_without_index_name_and_knowledge_name(monkeypatch, mock_session): + """Test create_knowledge_record when neither index_name nor knowledge_name is provided""" + session, _ = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 7 + mock_record.knowledge_name = None # Both are None, so knowledge_name will be None + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + # Deterministic index name + monkeypatch.setattr("backend.database.knowledge_db._generate_index_name", lambda _: "7-generated") + + test_query = { + "knowledge_describe": "desc", + "user_id": "user-1", + "tenant_id": "tenant-1", + "embedding_model_name": "model-x" + # Neither index_name nor knowledge_name provided + } + + with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record): + result = create_knowledge_record(test_query) + + assert result == { + "knowledge_id": 7, + "index_name": "7-generated", + "knowledge_name": None, + } + assert mock_record.index_name == "7-generated" + assert session.flush.call_count == 2 # initial insert + index_name update + session.commit.assert_called_once() + + +def test_update_knowledge_record_without_user_id(monkeypatch, mock_session): + """Test update_knowledge_record without user_id""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_describe = "old description" + mock_record.updated_by = "original_user" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_knowledge", + "knowledge_describe": "Updated description" + # No user_id provided + } + + result = update_knowledge_record(test_query) + + assert result is True + assert mock_record.knowledge_describe == "Updated description" + # updated_by should remain unchanged when user_id is not provided + assert mock_record.updated_by == "original_user" + session.flush.assert_called_once() + session.commit.assert_called_once() + + +def test_delete_knowledge_record_without_user_id(monkeypatch, mock_session): + """Test delete_knowledge_record without user_id""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.delete_flag = 'N' + mock_record.updated_by = "original_user" + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + test_query = { + "index_name": "test_knowledge" + # No user_id provided + } + + result = delete_knowledge_record(test_query) + + assert result is True + assert mock_record.delete_flag == 'Y' + # updated_by should remain unchanged when user_id is not provided + assert mock_record.updated_by == "original_user" + session.flush.assert_called_once() + session.commit.assert_called_once() + + +def test_get_knowledge_record_with_tenant_id_none(monkeypatch, mock_session): + """Test get_knowledge_record with tenant_id explicitly set to None""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_id = 123 + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + expected_result = {"knowledge_id": 123} + monkeypatch.setattr("backend.database.knowledge_db.as_dict", lambda x: expected_result) + + test_query = { + "index_name": "test_knowledge", + "tenant_id": None # Explicitly None + } + + result = get_knowledge_record(test_query) + + assert result == expected_result + # Should not add tenant_id filter when tenant_id is None + assert query.filter.call_count >= 1 + + +def test_get_knowledge_info_by_knowledge_ids_empty_list(monkeypatch, mock_session): + """Test get_knowledge_info_by_knowledge_ids with empty list""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.all.return_value = [] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + knowledge_ids = [] + result = get_knowledge_info_by_knowledge_ids(knowledge_ids) + + assert result == [] + + +def test_get_knowledge_info_by_knowledge_ids_includes_knowledge_name(monkeypatch, mock_session): + """Test get_knowledge_info_by_knowledge_ids includes knowledge_name field""" + session, query = mock_session + + mock_record1 = MockKnowledgeRecord() + mock_record1.knowledge_id = 1 + mock_record1.index_name = "knowledge1" + mock_record1.knowledge_name = "Knowledge Base 1" + mock_record1.knowledge_sources = "elasticsearch" + mock_record1.embedding_model_name = "model1" + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_record1] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + knowledge_ids = ["1"] + result = get_knowledge_info_by_knowledge_ids(knowledge_ids) + + expected = [ + { + "knowledge_id": 1, + "index_name": "knowledge1", + "knowledge_name": "Knowledge Base 1", + "knowledge_sources": "elasticsearch", + "embedding_model_name": "model1" + } + ] + + assert result == expected + assert "knowledge_name" in result[0] + + +def test_get_knowledge_info_by_knowledge_ids_with_none_knowledge_name(monkeypatch, mock_session): + """Test get_knowledge_info_by_knowledge_ids when knowledge_name is None""" + session, query = mock_session + + mock_record1 = MockKnowledgeRecord() + mock_record1.knowledge_id = 1 + mock_record1.index_name = "knowledge1" + mock_record1.knowledge_name = None # None knowledge_name + mock_record1.knowledge_sources = "elasticsearch" + mock_record1.embedding_model_name = "model1" + + mock_filter = MagicMock() + mock_filter.all.return_value = [mock_record1] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + knowledge_ids = ["1"] + result = get_knowledge_info_by_knowledge_ids(knowledge_ids) + + expected = [ + { + "knowledge_id": 1, + "index_name": "knowledge1", + "knowledge_name": None, + "knowledge_sources": "elasticsearch", + "embedding_model_name": "model1" + } + ] + + assert result == expected + assert result[0]["knowledge_name"] is None + + +def test_get_index_name_by_knowledge_name_success(monkeypatch, mock_session): + """Test successfully getting index_name by knowledge_name""" + session, query = mock_session + + mock_record = MockKnowledgeRecord() + mock_record.knowledge_name = "My Knowledge Base" + mock_record.index_name = "123-abc123def456" + mock_record.tenant_id = "tenant1" + mock_record.delete_flag = 'N' + + mock_filter = MagicMock() + mock_filter.first.return_value = mock_record + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + result = get_index_name_by_knowledge_name("My Knowledge Base", "tenant1") + + assert result == "123-abc123def456" + + +def test_get_index_name_by_knowledge_name_not_found(monkeypatch, mock_session): + """Test get_index_name_by_knowledge_name when knowledge base is not found""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.first.return_value = None + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(ValueError, match="Knowledge base 'Nonexistent KB' not found for the current tenant"): + get_index_name_by_knowledge_name("Nonexistent KB", "tenant1") + + +def test_get_index_name_by_knowledge_name_exception(monkeypatch, mock_session): + """Test exception when getting index_name by knowledge_name""" + session, query = mock_session + query.filter.side_effect = MockSQLAlchemyError("Database error") + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + with pytest.raises(MockSQLAlchemyError, match="Database error"): + get_index_name_by_knowledge_name("My Knowledge Base", "tenant1") + + +def test_generate_index_name_format(monkeypatch): + """Test _generate_index_name generates correct format""" + # Mock uuid to get deterministic result + mock_uuid = MagicMock() + mock_uuid.hex = "abc123def456" + monkeypatch.setattr("backend.database.knowledge_db.uuid.uuid4", lambda: mock_uuid) + + result = _generate_index_name(123) + + assert result == "123-abc123def456" + assert result.startswith("123-") + assert len(result) == len("123-abc123def456") + + +def test_get_knowledge_ids_by_index_names_empty_list(monkeypatch, mock_session): + """Test get_knowledge_ids_by_index_names with empty list""" + session, _ = mock_session + + mock_specific_query = MagicMock() + mock_filter = MagicMock() + mock_filter.all.return_value = [] + mock_specific_query.filter.return_value = mock_filter + + def mock_query_func(*args, **kwargs): + return mock_specific_query + + session.query = mock_query_func + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + index_names = [] + result = get_knowledge_ids_by_index_names(index_names) + + assert result == [] \ No newline at end of file diff --git a/test/backend/services/test_agent_service.py b/test/backend/services/test_agent_service.py index 9c202209c..d4b28eae5 100644 --- a/test/backend/services/test_agent_service.py +++ b/test/backend/services/test_agent_service.py @@ -3,15 +3,20 @@ import json from contextlib import contextmanager from unittest.mock import patch, MagicMock, mock_open, call, Mock, AsyncMock +import os import pytest from fastapi.responses import StreamingResponse from fastapi import Request - -# Import the actual ToolConfig model for testing before any mocking from nexent.core.agents.agent_model import ToolConfig -import os +from backend.consts.model import ( + AgentNameBatchCheckItem, + AgentNameBatchCheckRequest, + AgentNameBatchRegenerateItem, + AgentNameBatchRegenerateRequest, +) + # Patch environment variables before any imports that might use them os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000') @@ -5629,6 +5634,260 @@ async def fake_update_tool_list(tenant_id, user_id): assert relationships == [(100 + 1, 100 + 2, "tenant1")] +# ===================================================================== +# Tests for batch agent name conflict and regeneration +# ===================================================================== + + +@pytest.mark.asyncio +async def test_check_agent_name_conflict_batch_impl_detects_conflicts(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + existing_agents = [ + {"agent_id": 10, "name": "dup_name", "display_name": "Dup Display"}, + {"agent_id": 11, "name": "unique", "display_name": "Unique"}, + ] + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: existing_agents, + raising=False, + ) + + from consts.model import AgentNameBatchCheckItem, AgentNameBatchCheckRequest + + request = AgentNameBatchCheckRequest( + items=[ + AgentNameBatchCheckItem(name="dup_name", display_name="Another"), + AgentNameBatchCheckItem(name="", display_name=None), + ] + ) + + result = await agent_service.check_agent_name_conflict_batch_impl( + request, authorization="Bearer token" + ) + + assert result[0]["name_conflict"] is True + assert result[0]["display_name_conflict"] is False + assert result[0]["conflict_agents"] == [ + {"name": "dup_name", "display_name": "Dup Display"} + ] + assert result[1]["name_conflict"] is False + assert result[1]["display_name_conflict"] is False + assert result[1]["conflict_agents"] == [] + + +@pytest.mark.asyncio +async def test_check_agent_name_conflict_batch_impl_display_conflict(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + existing_agents = [ + {"agent_id": 3, "name": "alpha", "display_name": "Shown"}, + ] + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: existing_agents, + raising=False, + ) + + request = AgentNameBatchCheckRequest( + items=[AgentNameBatchCheckItem(name="beta", display_name="Shown")] + ) + + result = await agent_service.check_agent_name_conflict_batch_impl( + request, authorization="Bearer token" + ) + + assert result[0]["name_conflict"] is False + assert result[0]["display_name_conflict"] is True + assert result[0]["conflict_agents"] == [ + {"name": "alpha", "display_name": "Shown"} + ] + + +@pytest.mark.asyncio +async def test_check_agent_name_conflict_batch_impl_skips_same_agent(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + existing_agents = [ + {"agent_id": 7, "name": "self", "display_name": "Self Display"}, + ] + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: existing_agents, + raising=False, + ) + + request = AgentNameBatchCheckRequest( + items=[ + AgentNameBatchCheckItem( + agent_id=7, name="self", display_name="Self Display" + ) + ] + ) + + result = await agent_service.check_agent_name_conflict_batch_impl( + request, authorization="Bearer token" + ) + + assert result[0]["name_conflict"] is False + assert result[0]["display_name_conflict"] is False + assert result[0]["conflict_agents"] == [] + + +@pytest.mark.asyncio +async def test_regenerate_agent_name_batch_impl_uses_llm(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: [{"agent_id": 2, "name": "dup_name", "display_name": "Dup"}], + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.tenant_config_manager.get_model_config", + lambda key, tenant_id: {"model_id": "model-1", "display_name": "LLM"}, + raising=False, + ) + + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", fake_to_thread, raising=False) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_name_with_llm", + lambda **kwargs: "regenerated_name", + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_display_name_with_llm", + lambda **kwargs: "Regenerated Display", + raising=False, + ) + + + + request = AgentNameBatchRegenerateRequest( + items=[ + AgentNameBatchRegenerateItem( + agent_id=1, + name="dup_name", + display_name="Dup", + task_description="desc", + ) + ] + ) + + result = await agent_service.regenerate_agent_name_batch_impl( + request, authorization="Bearer token" + ) + + assert result == [{"name": "regenerated_name", "display_name": "Regenerated Display"}] + + +@pytest.mark.asyncio +async def test_regenerate_agent_name_batch_impl_no_model(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: [], + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.tenant_config_manager.get_model_config", + lambda key, tenant_id: None, + raising=False, + ) + + from consts.model import AgentNameBatchRegenerateItem, AgentNameBatchRegenerateRequest + + request = AgentNameBatchRegenerateRequest( + items=[AgentNameBatchRegenerateItem(agent_id=1, name="dup", display_name="Dup")] + ) + + with pytest.raises(ValueError): + await agent_service.regenerate_agent_name_batch_impl( + request, authorization="Bearer token" + ) + + +@pytest.mark.asyncio +async def test_regenerate_agent_name_batch_impl_llm_failure_fallback(monkeypatch): + monkeypatch.setattr( + "backend.services.agent_service.get_current_user_info", + lambda authorization: ("user-x", "tenant-x", "en"), + raising=False, + ) + # existing agent ensures duplicate detection + monkeypatch.setattr( + "backend.services.agent_service.query_all_agent_info_by_tenant_id", + lambda tenant_id: [{"agent_id": 2, "name": "dup", "display_name": "Dup"}], + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service.tenant_config_manager.get_model_config", + lambda key, tenant_id: {"model_id": "model-1", "display_name": "LLM"}, + raising=False, + ) + + async def run_in_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", run_in_thread, raising=False) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_name_with_llm", + lambda **kwargs: (_ for _ in ()).throw(Exception("llm-fail")), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._regenerate_agent_display_name_with_llm", + lambda **kwargs: (_ for _ in ()).throw(Exception("llm-fail")), + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._generate_unique_agent_name_with_suffix", + lambda base_value, **kwargs: f"{base_value}_fallback", + raising=False, + ) + monkeypatch.setattr( + "backend.services.agent_service._generate_unique_display_name_with_suffix", + lambda base_value, **kwargs: f"{base_value}_fallback", + raising=False, + ) + + request = AgentNameBatchRegenerateRequest( + items=[ + AgentNameBatchRegenerateItem( + agent_id=1, + name="dup", + display_name="Dup", + task_description="desc", + ) + ] + ) + + result = await agent_service.regenerate_agent_name_batch_impl( + request, authorization="Bearer token" + ) + + assert result == [{"name": "dup_fallback", "display_name": "Dup_fallback"}] + + # ===================================================================== # Tests for _resolve_model_with_fallback helper function # ===================================================================== @@ -6233,28 +6492,19 @@ async def test_get_agent_info_impl_with_unavailable_agent( @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_agent_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_name_with_llm') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_name_with_llm_success( +async def test_import_agent_by_agent_id_allows_duplicate_name_without_regen( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_regen_name, - mock_generate_unique_name, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_name is duplicate and LLM regeneration succeeds (line 1043-1060).""" - # Setup + """ + New behavior: import_agent_by_agent_id no longer performs duplicate-name regeneration. + It should create the agent with the provided name/display_name even if duplicates exist. + """ mock_query_all_tools.return_value = [] - mock_resolve_model.side_effect = [1, 2] # model_id=1, business_logic_model_id=2 - mock_query_all_agents.return_value = [{"name": "duplicate_name", "display_name": "Display"}] - mock_check_name_dup.return_value = True # Name is duplicate - mock_regen_name.return_value = "regenerated_name" + mock_resolve_model.side_effect = [1, 2] mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6277,7 +6527,6 @@ async def test_import_agent_by_agent_id_duplicate_name_with_llm_success( business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6285,42 +6534,28 @@ async def test_import_agent_by_agent_id_duplicate_name_with_llm_success( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_name_dup.assert_called_once() - mock_regen_name.assert_called_once() mock_create_agent.assert_called_once() - # Verify regenerated name was used - assert mock_create_agent.call_args[1]["agent_info"]["name"] == "regenerated_name" + assert mock_create_agent.call_args[1]["agent_info"]["name"] == "duplicate_name" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "Test Display" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_agent_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_name_with_llm') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_name_llm_failure_fallback( +async def test_import_agent_by_agent_id_duplicate_name_no_regen_fallback( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_regen_name, - mock_generate_unique_name, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_name is duplicate, LLM regeneration fails, uses fallback (line 1061-1067).""" - # Setup + """ + New behavior: even when duplicate name, import proceeds without regeneration or fallback. + """ mock_query_all_tools.return_value = [] mock_resolve_model.side_effect = [1, 2] - mock_query_all_agents.return_value = [{"name": "duplicate_name", "display_name": "Display"}] - mock_check_name_dup.return_value = True - mock_regen_name.side_effect = Exception("LLM failed") - mock_generate_unique_name.return_value = "fallback_name_1" mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6343,7 +6578,6 @@ async def test_import_agent_by_agent_id_duplicate_name_llm_failure_fallback( business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6351,41 +6585,27 @@ async def test_import_agent_by_agent_id_duplicate_name_llm_failure_fallback( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_regen_name.assert_called_once() - mock_generate_unique_name.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback name was used - assert mock_create_agent.call_args[1]["agent_info"]["name"] == "fallback_name_1" + assert mock_create_agent.call_args[1]["agent_info"]["name"] == "duplicate_name" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_agent_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_name_with_llm') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_name_no_model_fallback( +async def test_import_agent_by_agent_id_duplicate_name_no_model_still_allows( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_regen_name, - mock_generate_unique_name, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_name is duplicate but no model available, uses fallback (line 1068-1074).""" - # Setup + """ + New behavior: even without model, duplicate name passes through unchanged. + """ mock_query_all_tools.return_value = [] - mock_resolve_model.side_effect = [None, None] # No models available - mock_query_all_agents.return_value = [{"name": "duplicate_name", "display_name": "Display"}] - mock_check_name_dup.return_value = True - mock_generate_unique_name.return_value = "fallback_name_2" + mock_resolve_model.side_effect = [None, None] mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6408,7 +6628,6 @@ async def test_import_agent_by_agent_id_duplicate_name_no_model_fallback( business_logic_model_name=None ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6416,45 +6635,25 @@ async def test_import_agent_by_agent_id_duplicate_name_no_model_fallback( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_name_dup.assert_called_once() - mock_regen_name.assert_not_called() # Should not call LLM when no model - mock_generate_unique_name.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback name was used - assert mock_create_agent.call_args[1]["agent_info"]["name"] == "fallback_name_2" + assert mock_create_agent.call_args[1]["agent_info"]["name"] == "duplicate_name" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_display_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_display_name_with_llm') -@patch('backend.services.agent_service._check_agent_display_name_duplicate') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_display_name_with_llm_success( +async def test_import_agent_by_agent_id_duplicate_display_name_allowed( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_check_display_dup, - mock_regen_display, - mock_generate_unique_display, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_display_name is duplicate and LLM regeneration succeeds (line 1077-1092).""" - # Setup + """New behavior: duplicate display_name passes through without regeneration.""" mock_query_all_tools.return_value = [] mock_resolve_model.side_effect = [1, 2] - mock_query_all_agents.return_value = [{"name": "name1", "display_name": "duplicate_display"}] - mock_check_name_dup.return_value = False # Name is not duplicate - mock_check_display_dup.return_value = True # Display name is duplicate - mock_regen_display.return_value = "regenerated_display" mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6477,7 +6676,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_with_llm_success( business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6485,45 +6683,27 @@ async def test_import_agent_by_agent_id_duplicate_display_name_with_llm_success( skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_display_dup.assert_called_once() - mock_regen_display.assert_called_once() mock_create_agent.assert_called_once() - # Verify regenerated display name was used - assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "regenerated_display" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "duplicate_display" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_display_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_display_name_with_llm') -@patch('backend.services.agent_service._check_agent_display_name_duplicate') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_display_name_llm_failure_fallback( +async def test_import_agent_by_agent_id_duplicate_display_name_no_llm_fallback( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_check_display_dup, - mock_regen_display, - mock_generate_unique_display, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_display_name is duplicate, LLM regeneration fails, uses fallback (line 1093-1099).""" - # Setup + """ + New behavior: duplicate display_name passes through without LLM; fallback not invoked. + """ mock_query_all_tools.return_value = [] mock_resolve_model.side_effect = [1, 2] - mock_query_all_agents.return_value = [{"name": "name1", "display_name": "duplicate_display"}] - mock_check_name_dup.return_value = False - mock_check_display_dup.return_value = True - mock_regen_display.side_effect = Exception("LLM failed") - mock_generate_unique_display.return_value = "fallback_display_1" mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6546,7 +6726,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_llm_failure_fallb business_logic_model_name="Model2" ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6554,44 +6733,27 @@ async def test_import_agent_by_agent_id_duplicate_display_name_llm_failure_fallb skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_regen_display.assert_called_once() - mock_generate_unique_display.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback display name was used - assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "fallback_display_1" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "duplicate_display" @pytest.mark.asyncio @patch('backend.services.agent_service.create_or_update_tool_by_tool_info') @patch('backend.services.agent_service.create_agent') @patch('backend.services.agent_service.query_all_tools') -@patch('backend.services.agent_service._generate_unique_display_name_with_suffix') -@patch('backend.services.agent_service._regenerate_agent_display_name_with_llm') -@patch('backend.services.agent_service._check_agent_display_name_duplicate') -@patch('backend.services.agent_service._check_agent_name_duplicate') -@patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') @patch('backend.services.agent_service._resolve_model_with_fallback') -async def test_import_agent_by_agent_id_duplicate_display_name_no_model_fallback( +async def test_import_agent_by_agent_id_duplicate_display_name_no_model_still_allowed( mock_resolve_model, - mock_query_all_agents, - mock_check_name_dup, - mock_check_display_dup, - mock_regen_display, - mock_generate_unique_display, mock_query_all_tools, mock_create_agent, mock_create_tool ): - """Test import_agent_by_agent_id when agent_display_name is duplicate but no model available, uses fallback (line 1100-1106).""" - # Setup + """ + New behavior: even without model, duplicate display_name passes through unchanged. + """ mock_query_all_tools.return_value = [] - mock_resolve_model.side_effect = [None, None] # No models available - mock_query_all_agents.return_value = [{"name": "name1", "display_name": "duplicate_display"}] - mock_check_name_dup.return_value = False - mock_check_display_dup.return_value = True - mock_generate_unique_display.return_value = "fallback_display_2" + mock_resolve_model.side_effect = [None, None] mock_create_agent.return_value = {"agent_id": 456} agent_info = ExportAndImportAgentInfo( @@ -6614,7 +6776,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_no_model_fallback business_logic_model_name=None ) - # Execute result = await import_agent_by_agent_id( import_agent_info=agent_info, tenant_id="test_tenant", @@ -6622,11 +6783,6 @@ async def test_import_agent_by_agent_id_duplicate_display_name_no_model_fallback skip_duplicate_regeneration=False ) - # Assert assert result == 456 - mock_check_display_dup.assert_called_once() - mock_regen_display.assert_not_called() # Should not call LLM when no model - mock_generate_unique_display.assert_called_once() mock_create_agent.assert_called_once() - # Verify fallback display name was used - assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "fallback_display_2" + assert mock_create_agent.call_args[1]["agent_info"]["display_name"] == "duplicate_display" diff --git a/test/backend/services/test_conversation_management_service.py b/test/backend/services/test_conversation_management_service.py index feeb68d0e..173a3b6aa 100644 --- a/test/backend/services/test_conversation_management_service.py +++ b/test/backend/services/test_conversation_management_service.py @@ -327,7 +327,7 @@ def test_extract_user_messages(self): self.assertIn("Give me examples of AI applications", result) self.assertIn("AI stands for Artificial Intelligence.", result) - @patch('backend.services.conversation_management_service.OpenAIServerModel') + @patch('backend.services.conversation_management_service.OpenAIModel') @patch('backend.services.conversation_management_service.get_generate_title_prompt_template') @patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config') def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_template, mock_openai): @@ -360,7 +360,7 @@ def test_call_llm_for_title(self, mock_get_model_config, mock_get_prompt_templat mock_llm_instance.generate.assert_called_once() mock_get_prompt_template.assert_called_once_with(language='zh') - @patch('backend.services.conversation_management_service.OpenAIServerModel') + @patch('backend.services.conversation_management_service.OpenAIModel') @patch('backend.services.conversation_management_service.get_generate_title_prompt_template') @patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config') def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_get_prompt_template, mock_openai): @@ -392,7 +392,7 @@ def test_call_llm_for_title_response_none_zh(self, mock_get_model_config, mock_g mock_llm_instance.generate.assert_called_once() mock_get_prompt_template.assert_called_once_with(language='zh') - @patch('backend.services.conversation_management_service.OpenAIServerModel') + @patch('backend.services.conversation_management_service.OpenAIModel') @patch('backend.services.conversation_management_service.get_generate_title_prompt_template') @patch('backend.services.conversation_management_service.tenant_config_manager.get_model_config') def test_call_llm_for_title_response_none_en(self, mock_get_model_config, mock_get_prompt_template, mock_openai): diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 86e1cac73..48741a0f8 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -405,6 +405,73 @@ async def test_create_model_for_tenant_embedding_sets_dimension(): assert mock_create.call_count == 1 +@pytest.mark.asyncio +async def test_create_model_for_tenant_embedding_sets_default_chunk_batch(): + """chunk_batch defaults to 10 when not provided for embedding models.""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ + mock.patch.object(svc, "create_model_record") as mock_create, \ + mock.patch.object(svc, "split_repo_name", return_value=("openai", "text-embedding-3-small")): + + user_id = "u1" + tenant_id = "t1" + model_data = { + "model_name": "openai/text-embedding-3-small", + "display_name": None, + "base_url": "https://api.openai.com", + "model_type": "embedding", + "chunk_batch": None, # Explicitly unset to exercise defaulting + } + + await svc.create_model_for_tenant(user_id, tenant_id, model_data) + + mock_dim.assert_awaited_once() + assert mock_create.call_count == 1 + # chunk_batch should be defaulted before persistence + create_args = mock_create.call_args[0][0] + assert create_args["chunk_batch"] == 10 + + +@pytest.mark.asyncio +async def test_create_model_for_tenant_multi_embedding_sets_default_chunk_batch(): + """chunk_batch defaults to 10 when not provided for multi_embedding models (covers line 79).""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ + mock.patch.object(svc, "create_model_record") as mock_create, \ + mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")): + + user_id = "u1" + tenant_id = "t1" + model_data = { + "model_name": "openai/clip", + "display_name": None, + "base_url": "https://api.openai.com", + "model_type": "multi_embedding", + "chunk_batch": None, # Explicitly unset to exercise defaulting + } + + await svc.create_model_for_tenant(user_id, tenant_id, model_data) + + mock_dim.assert_awaited_once() + # Should create two records: multi_embedding and its embedding variant + assert mock_create.call_count == 2 + + # Verify chunk_batch was set to 10 for both records + create_calls = mock_create.call_args_list + # First call is for multi_embedding + multi_emb_args = create_calls[0][0][0] + assert multi_emb_args["chunk_batch"] == 10 + assert multi_emb_args["model_type"] == "multi_embedding" + # Second call is for embedding variant + emb_args = create_calls[1][0][0] + assert emb_args["chunk_batch"] == 10 + assert emb_args["model_type"] == "embedding" + + @pytest.mark.asyncio async def test_create_provider_models_for_tenant_success(): svc = import_svc() diff --git a/test/backend/services/test_model_provider_service.py b/test/backend/services/test_model_provider_service.py index ce3a0ab75..0916e61f9 100644 --- a/test/backend/services/test_model_provider_service.py +++ b/test/backend/services/test_model_provider_service.py @@ -304,7 +304,9 @@ async def test_prepare_model_dict_embedding(): assert kwargs["model_name"] == "text-embedding-ada-002" assert kwargs["model_type"] == "embedding" assert kwargs["api_key"] == "test-key" - assert kwargs["max_tokens"] == 1024 + # For embedding models, max_tokens is set to 0 as placeholder, + # will be updated by embedding_dimension_check later + assert kwargs["max_tokens"] == 0 assert kwargs["display_name"] == "openai/text-embedding-ada-002" assert kwargs["expected_chunk_size"] == sys.modules["consts.const"].DEFAULT_EXPECTED_CHUNK_SIZE assert kwargs["maximum_chunk_size"] == sys.modules["consts.const"].DEFAULT_MAXIMUM_CHUNK_SIZE diff --git a/test/backend/services/test_redis_service.py b/test/backend/services/test_redis_service.py index 8ebf7613e..1fba985ba 100644 --- a/test/backend/services/test_redis_service.py +++ b/test/backend/services/test_redis_service.py @@ -1,10 +1,7 @@ import unittest from unittest.mock import patch, MagicMock, call import json -import os import redis -import hashlib -import urllib.parse from backend.services.redis_service import RedisService, get_redis_service @@ -43,7 +40,8 @@ def test_client_property(self, mock_from_url): mock_from_url.assert_called_once_with( 'redis://localhost:6379/0', socket_timeout=5, - socket_connect_timeout=5 + socket_connect_timeout=5, + decode_responses=True ) self.assertEqual(client, self.mock_redis_client) @@ -127,7 +125,23 @@ def test_backend_client_no_env_vars(self, mock_from_url): # Execute & Verify with self.assertRaises(ValueError): _ = redis_service.backend_client - + + @patch('redis.from_url') + @patch('backend.services.redis_service.REDIS_URL', 'redis://localhost:6379/0') + def test_mark_and_check_task_cancelled(self, mock_from_url): + """mark_task_cancelled should set flag and is_task_cancelled should read it.""" + mock_client = MagicMock() + mock_client.setex.return_value = True + mock_client.get.return_value = b"1" + mock_from_url.return_value = mock_client + + service = RedisService() + ok = service.mark_task_cancelled("task-1", ttl_hours=1) + self.assertTrue(ok) + self.assertTrue(service.is_task_cancelled("task-1")) + mock_client.setex.assert_called_once() + mock_client.get.assert_called_once() + def test_delete_knowledgebase_records(self): """Test delete_knowledgebase_records method""" # Setup @@ -216,60 +230,155 @@ def test_delete_document_records_with_error(self): self.assertEqual(len(result["errors"]), 1) self.assertIn("Test error", result["errors"][0]) + def test_cleanup_single_task_related_keys_outer_exception(self): + """Outer handler logs when warning path itself fails.""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + self.mock_redis_client.delete.side_effect = redis.RedisError( + "delete failed") + + with patch('backend.services.redis_service.logger.warning', side_effect=Exception("warn boom")), \ + patch('backend.services.redis_service.logger.error') as mock_error: + result = self.redis_service._cleanup_single_task_related_keys( + "task123") + + mock_error.assert_called_once() + self.assertEqual(result, 0) + def test_cleanup_celery_tasks(self): """Test _cleanup_celery_tasks method""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Create mock task data - task_keys = [b'celery-task-meta-1', b'celery-task-meta-2', b'celery-task-meta-3'] - + task_keys = [b'celery-task-meta-1', + b'celery-task-meta-2', b'celery-task-meta-3'] + # Task 1 matches our index task1_data = json.dumps({ 'result': {'index_name': 'test_index', 'some_key': 'some_value'}, 'parent_id': '2' # This will trigger a parent lookup }).encode() - + # Task 2 has index name in a different location task2_data = json.dumps({ - 'index_name': 'test_index', + 'index_name': 'test_index', 'result': {'some_key': 'some_value'}, 'parent_id': None # No parent }).encode() - + # Task 3 is for a different index task3_data = json.dumps({ 'result': {'index_name': 'other_index', 'some_key': 'some_value'} }).encode() - + # Configure mock responses self.mock_backend_client.keys.return_value = task_keys - self.mock_backend_client.get.side_effect = [task1_data, task2_data, task3_data] - + # Two passes over keys: provide payloads for both passes (6 gets) + self.mock_backend_client.get.side_effect = [ + task1_data, task2_data, task3_data, + task1_data, task2_data, task3_data, + ] + # We expect delete to be called and return 1 each time self.mock_backend_client.delete.return_value = 1 - + # Execute with patch.object(self.redis_service, '_recursively_delete_task_and_parents') as mock_recursive_delete: mock_recursive_delete.side_effect = [(1, {'1'}), (1, {'2'})] result = self.redis_service._cleanup_celery_tasks("test_index") - + # Verify - self.mock_backend_client.keys.assert_called_once_with('celery-task-meta-*') - # We expect 3 calls - one for each task key - self.assertEqual(self.mock_backend_client.get.call_count, 3) - - # Should have called recursive delete twice (for task1 and task2) - self.assertEqual(mock_recursive_delete.call_count, 2) - - # Return value should be the number of deleted tasks - self.assertEqual(result, 2) - + self.mock_backend_client.keys.assert_called_once_with( + 'celery-task-meta-*') + # Implementation fetches task payloads in both passes; expect 6 total (3 keys * 2 passes) + self.assertEqual(self.mock_backend_client.get.call_count, 6) + + # Should have called recursive delete for matched tasks + self.assertGreaterEqual(mock_recursive_delete.call_count, 2) + + # Return value should match deleted tasks count + self.assertEqual(result, mock_recursive_delete.call_count) + + def test_cleanup_celery_tasks_get_exception_and_cancel_failure(self): + """First-pass get failure and cancel failure are both handled.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-err', b'celery-task-meta-2'] + valid_task = json.dumps({ + 'result': {'index_name': 'test_index'}, + 'parent_id': None + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [ + redis.RedisError("boom"), + valid_task, + redis.RedisError("boom-second"), + valid_task, + ] + + with patch.object(self.redis_service, 'mark_task_cancelled', side_effect=ValueError("cancel fail")) as mock_cancel, \ + patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'2'})) as mock_delete, \ + patch.object(self.redis_service, '_cleanup_single_task_related_keys') as mock_cleanup: + + result = self.redis_service._cleanup_celery_tasks("test_index") + + mock_cancel.assert_called_once_with('2') + mock_delete.assert_called_once_with('2') + mock_cleanup.assert_called_once_with('2') + self.assertEqual(result, 1) + + def test_cleanup_celery_tasks_exc_message_bad_json(self): + """JSON decode failure inside exc_message parsing does not crash.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-1'] + bad_json_payload = json.dumps({ + 'result': { + # Contains brace to enter parsing block + 'exc_message': '{bad json' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [ + bad_json_payload, bad_json_payload] + + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(0, set())) as mock_delete: + result = self.redis_service._cleanup_celery_tasks("test_index") + + # Bad JSON should be tolerated; no deletions occur + mock_delete.assert_not_called() + self.assertEqual(result, 0) + + def test_cleanup_celery_tasks_cleanup_single_task_error(self): + """Failures during related-key cleanup are logged and skipped.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-1'] + task_payload = json.dumps({ + 'result': {'index_name': 'test_index'} + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [task_payload, task_payload] + + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})), \ + patch.object(self.redis_service, '_cleanup_single_task_related_keys', side_effect=Exception("cleanup boom")) as mock_cleanup: + result = self.redis_service._cleanup_celery_tasks("test_index") + + mock_cleanup.assert_called_once_with('1') + self.assertEqual(result, 1) + def test_cleanup_cache_keys(self): """Test _cleanup_cache_keys method""" # Setup self.redis_service._client = self.mock_redis_client - + # Configure mock responses for each pattern pattern_keys = { '*test_index*': [b'key1', b'key2'], @@ -277,19 +386,20 @@ def test_cleanup_cache_keys(self): 'index:test_index:*': [b'key6'], 'search:test_index:*': [b'key7', b'key8'] } - + def mock_keys_side_effect(pattern): return pattern_keys.get(pattern, []) - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect - self.mock_redis_client.delete.return_value = 1 # Each delete operation deletes 1 key - + # Each delete operation deletes 1 key + self.mock_redis_client.delete.return_value = 1 + # Execute result = self.redis_service._cleanup_cache_keys("test_index") - + # Verify self.assertEqual(self.mock_redis_client.keys.call_count, 4) - + # All keys should be deleted (8 keys total) expected_calls = [ call(b'key1', b'key2'), @@ -297,19 +407,21 @@ def mock_keys_side_effect(pattern): call(b'key6'), call(b'key7', b'key8') ] - self.mock_redis_client.delete.assert_has_calls(expected_calls, any_order=True) - + self.mock_redis_client.delete.assert_has_calls( + expected_calls, any_order=True) + # Return value should be the number of deleted keys self.assertEqual(result, 4) # 4 successful delete operations - + def test_cleanup_document_celery_tasks(self): """Test _cleanup_document_celery_tasks method""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Create mock task data - task_keys = [b'celery-task-meta-1', b'celery-task-meta-2', b'celery-task-meta-3'] - + task_keys = [b'celery-task-meta-1', + b'celery-task-meta-2', b'celery-task-meta-3'] + # Task 1 matches our index and document task1_data = json.dumps({ 'result': { @@ -318,7 +430,7 @@ def test_cleanup_document_celery_tasks(self): }, 'parent_id': '2' # This will trigger a parent lookup }).encode() - + # Task 2 has the right index but wrong document task2_data = json.dumps({ 'result': { @@ -326,7 +438,7 @@ def test_cleanup_document_celery_tasks(self): 'source': 'other/doc.pdf' } }).encode() - + # Task 3 has document path in a different field task3_data = json.dumps({ 'result': { @@ -335,43 +447,46 @@ def test_cleanup_document_celery_tasks(self): }, 'parent_id': None # No parent }).encode() - + # Configure mock responses self.mock_backend_client.keys.return_value = task_keys - self.mock_backend_client.get.side_effect = [task1_data, task2_data, task3_data] - + self.mock_backend_client.get.side_effect = [ + task1_data, task2_data, task3_data] + # We expect delete to be called and return 1 each time self.mock_backend_client.delete.return_value = 1 - + # Execute with patch.object(self.redis_service, '_recursively_delete_task_and_parents') as mock_recursive_delete: mock_recursive_delete.side_effect = [(1, {'1'}), (1, {'3'})] - result = self.redis_service._cleanup_document_celery_tasks("test_index", "path/to/doc.pdf") - + result = self.redis_service._cleanup_document_celery_tasks( + "test_index", "path/to/doc.pdf") + # Verify - self.mock_backend_client.keys.assert_called_once_with('celery-task-meta-*') + self.mock_backend_client.keys.assert_called_once_with( + 'celery-task-meta-*') # We expect 3 calls - one for each task key self.assertEqual(self.mock_backend_client.get.call_count, 3) - + # Should have called recursive delete twice (for task1 and task3) self.assertEqual(mock_recursive_delete.call_count, 2) - + # Return value should be the number of deleted tasks self.assertEqual(result, 2) - + @patch('hashlib.md5') @patch('urllib.parse.quote') def test_cleanup_document_cache_keys(self, mock_quote, mock_md5): """Test _cleanup_document_cache_keys method""" # Setup self.redis_service._client = self.mock_redis_client - + # Mock the path hashing and quoting mock_quote.return_value = 'safe_path' mock_md5_instance = MagicMock() mock_md5_instance.hexdigest.return_value = 'path_hash' mock_md5.return_value = mock_md5_instance - + # Configure mock responses for each pattern pattern_keys = { '*test_index*safe_path*': [b'key1'], @@ -381,100 +496,105 @@ def test_cleanup_document_cache_keys(self, mock_quote, mock_md5): 'doc:safe_path:*': [b'key6', b'key7'], 'doc:path_hash:*': [b'key8'] } - + def mock_keys_side_effect(pattern): return pattern_keys.get(pattern, []) - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect - self.mock_redis_client.delete.return_value = 1 # Each delete operation deletes 1 key - + # Each delete operation deletes 1 key + self.mock_redis_client.delete.return_value = 1 + # Execute - result = self.redis_service._cleanup_document_cache_keys("test_index", "path/to/doc.pdf") - + result = self.redis_service._cleanup_document_cache_keys( + "test_index", "path/to/doc.pdf") + # Verify self.assertEqual(self.mock_redis_client.keys.call_count, 6) - + # Return value should be the number of deleted keys self.assertEqual(result, 6) # 6 successful delete operations - + def test_get_knowledgebase_task_count(self): """Test get_knowledgebase_task_count method""" # Setup self.redis_service._client = self.mock_redis_client self.redis_service._backend_client = self.mock_backend_client - + # Create mock task data task_keys = [b'celery-task-meta-1', b'celery-task-meta-2'] - + # Task 1 matches our index task1_data = json.dumps({ 'result': {'index_name': 'test_index'} }).encode() - + # Task 2 is for a different index task2_data = json.dumps({ 'result': {'index_name': 'other_index'} }).encode() - + # Configure mock responses for Celery tasks self.mock_backend_client.keys.return_value = task_keys self.mock_backend_client.get.side_effect = [task1_data, task2_data] - + # Configure mock responses for cache keys cache_keys = { '*test_index*': [b'key1', b'key2'], 'kb:test_index:*': [b'key3', b'key4'], 'index:test_index:*': [b'key5'] } - + def mock_keys_side_effect(pattern): return cache_keys.get(pattern, []) - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect - + # Execute result = self.redis_service.get_knowledgebase_task_count("test_index") - + # Verify - self.mock_backend_client.keys.assert_called_once_with('celery-task-meta-*') + self.mock_backend_client.keys.assert_called_once_with( + 'celery-task-meta-*') self.assertEqual(self.mock_backend_client.get.call_count, 2) - + # Should count 1 matching task and 5 cache keys self.assertEqual(result, 6) - + def test_ping_success(self): """Test ping method when connection is successful""" # Setup self.redis_service._client = self.mock_redis_client self.redis_service._backend_client = self.mock_backend_client - + self.mock_redis_client.ping.return_value = True self.mock_backend_client.ping.return_value = True - + # Execute result = self.redis_service.ping() - + # Verify self.mock_redis_client.ping.assert_called_once() self.mock_backend_client.ping.assert_called_once() self.assertTrue(result) - + def test_ping_failure(self): """Test ping method when connection fails""" # Setup self.redis_service._client = self.mock_redis_client self.redis_service._backend_client = self.mock_backend_client - - self.mock_redis_client.ping.side_effect = redis.RedisError("Connection failed") - + + self.mock_redis_client.ping.side_effect = redis.RedisError( + "Connection failed") + # Execute result = self.redis_service.ping() - + # Verify self.mock_redis_client.ping.assert_called_once() - self.mock_backend_client.ping.assert_not_called() # Should not be called after first ping fails + # Should not be called after first ping fails + self.mock_backend_client.ping.assert_not_called() self.assertFalse(result) - + @patch('backend.services.redis_service._redis_service', None) @patch('backend.services.redis_service.RedisService') def test_get_redis_service(self, mock_redis_service_class): @@ -482,146 +602,155 @@ def test_get_redis_service(self, mock_redis_service_class): # Setup mock_instance = MagicMock() mock_redis_service_class.return_value = mock_instance - + # Execute service1 = get_redis_service() service2 = get_redis_service() - + # Verify mock_redis_service_class.assert_called_once() # Only created once self.assertEqual(service1, mock_instance) - self.assertEqual(service2, mock_instance) # Should return same instance - + # Should return same instance + self.assertEqual(service2, mock_instance) + def test_recursively_delete_task_and_parents_no_parent(self): """Test _recursively_delete_task_and_parents with task that has no parent""" # Setup self.redis_service._backend_client = self.mock_backend_client - + task_data = json.dumps({ 'result': {'some_data': 'value'}, 'parent_id': None }).encode() - + self.mock_backend_client.get.return_value = task_data self.mock_backend_client.delete.return_value = 1 - + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task123") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task123") + # Verify self.assertEqual(deleted_count, 1) self.assertEqual(processed_ids, {"task123"}) - self.mock_backend_client.get.assert_called_once_with('celery-task-meta-task123') - self.mock_backend_client.delete.assert_called_once_with('celery-task-meta-task123') - + self.mock_backend_client.get.assert_called_once_with( + 'celery-task-meta-task123') + self.mock_backend_client.delete.assert_called_once_with( + 'celery-task-meta-task123') + def test_recursively_delete_task_and_parents_with_cycle_detection(self): """Test _recursively_delete_task_and_parents detects and breaks cycles""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Create a cycle: task1 -> task2 -> task1 task1_data = json.dumps({'parent_id': 'task2'}).encode() task2_data = json.dumps({'parent_id': 'task1'}).encode() - + self.mock_backend_client.get.side_effect = [task1_data, task2_data] self.mock_backend_client.delete.return_value = 1 - + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task1") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task1") + # Verify - should stop when cycle is detected self.assertEqual(deleted_count, 2) self.assertEqual(processed_ids, {"task1", "task2"}) self.assertEqual(self.mock_backend_client.delete.call_count, 2) - + def test_recursively_delete_task_and_parents_json_decode_error(self): """Test _recursively_delete_task_and_parents handles JSON decode errors""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Invalid JSON data invalid_json_data = b'invalid json data' - + self.mock_backend_client.get.return_value = invalid_json_data self.mock_backend_client.delete.return_value = 1 - + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task123") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task123") + # Verify - should still delete the task even if JSON parsing fails self.assertEqual(deleted_count, 1) self.assertEqual(processed_ids, {"task123"}) - self.mock_backend_client.delete.assert_called_once_with('celery-task-meta-task123') - + self.mock_backend_client.delete.assert_called_once_with( + 'celery-task-meta-task123') + def test_recursively_delete_task_and_parents_redis_error(self): """Test _recursively_delete_task_and_parents handles Redis errors""" # Setup self.redis_service._backend_client = self.mock_backend_client - + # Simulate Redis error - self.mock_backend_client.get.side_effect = redis.RedisError("Connection lost") - + self.mock_backend_client.get.side_effect = redis.RedisError( + "Connection lost") + # Execute - deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents("task123") - + deleted_count, processed_ids = self.redis_service._recursively_delete_task_and_parents( + "task123") + # Verify - should return 0 when Redis error occurs self.assertEqual(deleted_count, 0) self.assertEqual(processed_ids, {"task123"}) - + def test_cleanup_celery_tasks_with_failed_task_metadata(self): """Test _cleanup_celery_tasks handles failed tasks with exception metadata""" # Setup self.redis_service._backend_client = self.mock_backend_client - + task_keys = [b'celery-task-meta-1'] - + # Task with exception metadata containing index name task_data = json.dumps({ 'result': { 'exc_message': 'Error processing task: {"index_name": "test_index", "error": "failed"}' } }).encode() - + self.mock_backend_client.keys.return_value = task_keys self.mock_backend_client.get.return_value = task_data - + # Execute with patch.object(self.redis_service, '_recursively_delete_task_and_parents') as mock_recursive_delete: mock_recursive_delete.return_value = (1, {'1'}) result = self.redis_service._cleanup_celery_tasks("test_index") - + # Verify self.assertEqual(result, 1) mock_recursive_delete.assert_called_once_with('1') - + def test_cleanup_celery_tasks_invalid_exception_metadata(self): """Test _cleanup_celery_tasks handles invalid exception metadata gracefully""" # Setup self.redis_service._backend_client = self.mock_backend_client - + task_keys = [b'celery-task-meta-1'] - + # Task with invalid exception metadata task_data = json.dumps({ 'result': { 'exc_message': 'Invalid JSON metadata' } }).encode() - + self.mock_backend_client.keys.return_value = task_keys self.mock_backend_client.get.return_value = task_data - + # Execute result = self.redis_service._cleanup_celery_tasks("test_index") - + # Verify - should not crash and return 0 self.assertEqual(result, 0) - + def test_cleanup_cache_keys_partial_failure(self): """Test _cleanup_cache_keys handles partial failures gracefully""" # Setup self.redis_service._client = self.mock_redis_client - + # First pattern succeeds, second fails, third succeeds def mock_keys_side_effect(pattern): if pattern == 'kb:test_index:*': @@ -632,33 +761,65 @@ def mock_keys_side_effect(pattern): return [b'key3'] else: return [] - + self.mock_redis_client.keys.side_effect = mock_keys_side_effect self.mock_redis_client.delete.return_value = 1 - + # Execute result = self.redis_service._cleanup_cache_keys("test_index") - + # Verify - should continue processing despite one pattern failing self.assertEqual(result, 2) # 2 successful delete operations - + def test_cleanup_cache_keys_all_patterns_fail(self): """Test _cleanup_cache_keys handles errors gracefully when all patterns fail""" # Setup self.redis_service._client = self.mock_redis_client - + # Simulate an error for all pattern calls # Each call to keys() will fail but be caught by inner try-catch - self.mock_redis_client.keys.side_effect = redis.RedisError("Redis connection failed") - + self.mock_redis_client.keys.side_effect = redis.RedisError( + "Redis connection failed") + # Execute - should not raise exception but return 0 result = self.redis_service._cleanup_cache_keys("test_index") - + # Verify - should handle gracefully and return 0 self.assertEqual(result, 0) # Should have tried all 4 patterns self.assertEqual(self.mock_redis_client.keys.call_count, 4) - + + def test_cleanup_document_celery_tasks_cancel_fail_and_processing_error(self): + """Document cleanup logs processing errors and cancel failures.""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-err', b'celery-task-meta-1'] + good_payload = json.dumps({ + 'result': { + 'index_name': 'kb1', + 'path_or_url': 'doc1' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.side_effect = [ + redis.RedisError("get boom"), + good_payload + ] + + with patch.object(self.redis_service, 'mark_task_cancelled', side_effect=ValueError("cancel fail")) as mock_cancel, \ + patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})) as mock_delete, \ + patch.object(self.redis_service, '_cleanup_single_task_related_keys') as mock_cleanup: + + result = self.redis_service._cleanup_document_celery_tasks( + "kb1", "doc1") + + mock_cancel.assert_called_once_with('1') + mock_delete.assert_called_once_with('1') + mock_cleanup.assert_called_once_with('1') + self.assertEqual(result, 1) + def test_cleanup_document_cache_keys_empty_patterns(self): """Test _cleanup_document_cache_keys handles empty key patterns""" @@ -785,6 +946,470 @@ def test_ping_backend_failure(self): self.mock_redis_client.ping.assert_called_once() self.mock_backend_client.ping.assert_called_once() + # ------------------------------------------------------------------ + # Test mark_task_cancelled edge cases + # ------------------------------------------------------------------ + + def test_mark_task_cancelled_empty_task_id(self): + """Test mark_task_cancelled returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.mark_task_cancelled("") + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_mark_task_cancelled_redis_error(self): + """Test mark_task_cancelled handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.mark_task_cancelled("task-123") + self.assertFalse(result) + self.mock_redis_client.setex.assert_called_once() + + def test_mark_task_cancelled_custom_ttl(self): + """Test mark_task_cancelled with custom TTL hours""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + + result = self.redis_service.mark_task_cancelled("task-123", ttl_hours=48) + self.assertTrue(result) + # Verify TTL is calculated correctly (48 hours = 172800 seconds) + call_args = self.mock_redis_client.setex.call_args + self.assertEqual(call_args[0][1], 48 * 3600) # TTL in seconds + + # ------------------------------------------------------------------ + # Test is_task_cancelled edge cases + # ------------------------------------------------------------------ + + def test_is_task_cancelled_empty_task_id(self): + """Test is_task_cancelled returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.is_task_cancelled("") + self.assertFalse(result) + self.mock_redis_client.get.assert_not_called() + + def test_is_task_cancelled_none_value(self): + """Test is_task_cancelled returns False when key doesn't exist""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = None + + result = self.redis_service.is_task_cancelled("task-123") + self.assertFalse(result) + + def test_is_task_cancelled_empty_string_value(self): + """Test is_task_cancelled returns False when value is empty string""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "" + + result = self.redis_service.is_task_cancelled("task-123") + self.assertFalse(result) + + def test_is_task_cancelled_redis_error(self): + """Test is_task_cancelled handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.is_task_cancelled("task-123") + self.assertFalse(result) + + # ------------------------------------------------------------------ + # Test _cleanup_single_task_related_keys + # ------------------------------------------------------------------ + + def test_cleanup_single_task_related_keys_success(self): + """Test _cleanup_single_task_related_keys deletes all related keys""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + # Mock successful deletions + self.mock_redis_client.delete.side_effect = [1, 1, 1] # progress, error, cancel + self.mock_backend_client.delete.return_value = 1 # chunk cache + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should delete 4 keys total + self.assertEqual(result, 4) + # Verify all keys were attempted + self.assertEqual(self.mock_redis_client.delete.call_count, 3) + self.mock_backend_client.delete.assert_called_once_with("dp:task-123:chunks") + + def test_cleanup_single_task_related_keys_empty_task_id(self): + """Test _cleanup_single_task_related_keys returns 0 for empty task_id""" + result = self.redis_service._cleanup_single_task_related_keys("") + self.assertEqual(result, 0) + + def test_cleanup_single_task_related_keys_partial_failure(self): + """Test _cleanup_single_task_related_keys handles partial failures""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + # First key succeeds, second fails, third succeeds, chunk cache fails + self.mock_redis_client.delete.side_effect = [1, redis.RedisError("Error"), 1] + self.mock_backend_client.delete.side_effect = redis.RedisError("Backend error") + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should return count of successful deletions (2) + self.assertEqual(result, 2) + + def test_cleanup_single_task_related_keys_all_fail(self): + """Test _cleanup_single_task_related_keys handles all failures gracefully""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + self.mock_redis_client.delete.side_effect = redis.RedisError("All failed") + self.mock_backend_client.delete.side_effect = redis.RedisError("Backend failed") + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should return 0 but not raise exception + self.assertEqual(result, 0) + + def test_cleanup_single_task_related_keys_no_keys_exist(self): + """Test _cleanup_single_task_related_keys when keys don't exist""" + self.redis_service._client = self.mock_redis_client + self.redis_service._backend_client = self.mock_backend_client + + # All deletions return 0 (key doesn't exist) + self.mock_redis_client.delete.side_effect = [0, 0, 0] + self.mock_backend_client.delete.return_value = 0 + + result = self.redis_service._cleanup_single_task_related_keys("task-123") + + # Should return 0 + self.assertEqual(result, 0) + + # ------------------------------------------------------------------ + # Test save_error_info + # ------------------------------------------------------------------ + + def test_save_error_info_success(self): + """Test save_error_info successfully saves error information""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.return_value = "Test error reason" + + result = self.redis_service.save_error_info("task-123", "Test error reason") + + self.assertTrue(result) + self.mock_redis_client.setex.assert_called_once() + # Verify TTL is 30 days in seconds + call_args = self.mock_redis_client.setex.call_args + self.assertEqual(call_args[0][1], 30 * 24 * 60 * 60) + self.assertEqual(call_args[0][2], "Test error reason") + # Verify get was called to verify the save + self.mock_redis_client.get.assert_called_once() + + def test_save_error_info_empty_task_id(self): + """Test save_error_info returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_error_info("", "Error reason") + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_save_error_info_empty_error_reason(self): + """Test save_error_info returns False when error_reason is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_error_info("task-123", "") + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_save_error_info_custom_ttl(self): + """Test save_error_info with custom TTL days""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.return_value = "Error" + + result = self.redis_service.save_error_info("task-123", "Error", ttl_days=7) + + self.assertTrue(result) + call_args = self.mock_redis_client.setex.call_args + # Verify TTL is 7 days in seconds + self.assertEqual(call_args[0][1], 7 * 24 * 60 * 60) + + def test_save_error_info_setex_returns_false(self): + """Test save_error_info handles setex returning False""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = False + + result = self.redis_service.save_error_info("task-123", "Error") + self.assertFalse(result) + + def test_save_error_info_verification_fails(self): + """Test save_error_info when verification get returns None""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.return_value = None # Verification fails + + result = self.redis_service.save_error_info("task-123", "Error") + # Should still return True because setex succeeded + self.assertTrue(result) + + def test_save_error_info_redis_error(self): + """Test save_error_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.save_error_info("task-123", "Error") + self.assertFalse(result) + + def test_save_error_info_verification_redis_error(self): + """Test save_error_info returns False when verification raises Redis error""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.return_value = True + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + # Should return False because verification failed with exception + result = self.redis_service.save_error_info("task-123", "Error") + self.assertFalse(result) + + # ------------------------------------------------------------------ + # Test save_progress_info + # ------------------------------------------------------------------ + + def test_save_progress_info_success(self): + """Test save_progress_info successfully saves progress""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("task-123", 50, 100) + + self.assertTrue(result) + self.mock_redis_client.setex.assert_called_once() + call_args = self.mock_redis_client.setex.call_args + # Verify TTL is 24 hours in seconds + self.assertEqual(call_args[0][1], 24 * 3600) + # Verify JSON data + progress_data = json.loads(call_args[0][2]) + self.assertEqual(progress_data['processed_chunks'], 50) + self.assertEqual(progress_data['total_chunks'], 100) + + def test_save_progress_info_empty_task_id(self): + """Test save_progress_info returns False when task_id is empty""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("", 50, 100) + self.assertFalse(result) + self.mock_redis_client.setex.assert_not_called() + + def test_save_progress_info_custom_ttl(self): + """Test save_progress_info with custom TTL hours""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("task-123", 25, 50, ttl_hours=48) + + self.assertTrue(result) + call_args = self.mock_redis_client.setex.call_args + # Verify TTL is 48 hours in seconds + self.assertEqual(call_args[0][1], 48 * 3600) + + def test_save_progress_info_zero_progress(self): + """Test save_progress_info with zero progress""" + self.redis_service._client = self.mock_redis_client + + result = self.redis_service.save_progress_info("task-123", 0, 100) + + self.assertTrue(result) + call_args = self.mock_redis_client.setex.call_args + progress_data = json.loads(call_args[0][2]) + self.assertEqual(progress_data['processed_chunks'], 0) + self.assertEqual(progress_data['total_chunks'], 100) + + def test_save_progress_info_redis_error(self): + """Test save_progress_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.setex.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.save_progress_info("task-123", 50, 100) + self.assertFalse(result) + + # ------------------------------------------------------------------ + # Test get_progress_info + # ------------------------------------------------------------------ + + def test_get_progress_info_success(self): + """Test get_progress_info successfully retrieves progress""" + self.redis_service._client = self.mock_redis_client + progress_json = json.dumps({'processed_chunks': 50, 'total_chunks': 100}) + self.mock_redis_client.get.return_value = progress_json + + result = self.redis_service.get_progress_info("task-123") + + self.assertIsNotNone(result) + self.assertEqual(result['processed_chunks'], 50) + self.assertEqual(result['total_chunks'], 100) + + def test_get_progress_info_not_found(self): + """Test get_progress_info returns None when key doesn't exist""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = None + + result = self.redis_service.get_progress_info("task-123") + self.assertIsNone(result) + + def test_get_progress_info_bytes_response(self): + """Test get_progress_info handles bytes response (when decode_responses=False)""" + self.redis_service._client = self.mock_redis_client + progress_json = json.dumps({'processed_chunks': 75, 'total_chunks': 150}) + self.mock_redis_client.get.return_value = progress_json.encode('utf-8') + + result = self.redis_service.get_progress_info("task-123") + + self.assertIsNotNone(result) + self.assertEqual(result['processed_chunks'], 75) + self.assertEqual(result['total_chunks'], 150) + + def test_get_progress_info_invalid_json(self): + """Test get_progress_info handles invalid JSON gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "invalid json" + + result = self.redis_service.get_progress_info("task-123") + self.assertIsNone(result) + + def test_get_progress_info_redis_error(self): + """Test get_progress_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.get_progress_info("task-123") + self.assertIsNone(result) + + # ------------------------------------------------------------------ + # Test get_error_info + # ------------------------------------------------------------------ + + def test_get_error_info_success(self): + """Test get_error_info successfully retrieves error reason""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "Test error reason" + + result = self.redis_service.get_error_info("task-123") + + self.assertEqual(result, "Test error reason") + self.mock_redis_client.get.assert_called_once_with("error:reason:task-123") + + def test_get_error_info_not_found(self): + """Test get_error_info returns None when key doesn't exist""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = None + + result = self.redis_service.get_error_info("task-123") + self.assertIsNone(result) + + def test_get_error_info_empty_string(self): + """Test get_error_info returns None when value is empty string""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.return_value = "" + + result = self.redis_service.get_error_info("task-123") + self.assertIsNone(result) + + def test_get_error_info_redis_error(self): + """Test get_error_info handles Redis errors gracefully""" + self.redis_service._client = self.mock_redis_client + self.mock_redis_client.get.side_effect = redis.RedisError("Connection failed") + + result = self.redis_service.get_error_info("task-123") + self.assertIsNone(result) + + # ------------------------------------------------------------------ + # Test _cleanup_celery_tasks edge cases + # ------------------------------------------------------------------ + + def test_cleanup_celery_tasks_mark_cancelled_failure(self): + """Test _cleanup_celery_tasks handles mark_task_cancelled failures""" + self.redis_service._backend_client = self.mock_backend_client + self.redis_service._client = self.mock_redis_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': {'index_name': 'test_index'}, + 'parent_id': None + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + # Provide data for both passes + self.mock_backend_client.get.side_effect = [task_data, task_data] + self.mock_backend_client.delete.return_value = 1 + + # Mock mark_task_cancelled to fail + with patch.object(self.redis_service, 'mark_task_cancelled', return_value=False): + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})): + with patch.object(self.redis_service, '_cleanup_single_task_related_keys', return_value=0): + result = self.redis_service._cleanup_celery_tasks("test_index") + + # Should still proceed with deletion despite cancellation failure + self.assertEqual(result, 1) + + def test_cleanup_celery_tasks_no_matching_tasks(self): + """Test _cleanup_celery_tasks when no tasks match the index""" + self.redis_service._backend_client = self.mock_backend_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': {'index_name': 'other_index'} + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + # Provide data for both passes + self.mock_backend_client.get.side_effect = [task_data, task_data] + + result = self.redis_service._cleanup_celery_tasks("test_index") + + self.assertEqual(result, 0) + + # ------------------------------------------------------------------ + # Test _cleanup_document_celery_tasks edge cases + # ------------------------------------------------------------------ + + def test_cleanup_document_celery_tasks_no_matching_document(self): + """Test _cleanup_document_celery_tasks when no tasks match document""" + self.redis_service._backend_client = self.mock_backend_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': { + 'index_name': 'test_index', + 'source': 'other/doc.pdf' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.return_value = task_data + + result = self.redis_service._cleanup_document_celery_tasks("test_index", "path/to/doc.pdf") + + self.assertEqual(result, 0) + + def test_cleanup_document_celery_tasks_mark_cancelled_failure(self): + """Test _cleanup_document_celery_tasks handles mark_task_cancelled failures""" + self.redis_service._backend_client = self.mock_backend_client + + task_keys = [b'celery-task-meta-1'] + task_data = json.dumps({ + 'result': { + 'index_name': 'test_index', + 'source': 'path/to/doc.pdf' + } + }).encode() + + self.mock_backend_client.keys.return_value = task_keys + self.mock_backend_client.get.return_value = task_data + self.mock_backend_client.delete.return_value = 1 + + # Mock mark_task_cancelled to fail + with patch.object(self.redis_service, 'mark_task_cancelled', return_value=False): + with patch.object(self.redis_service, '_recursively_delete_task_and_parents', return_value=(1, {'1'})): + with patch.object(self.redis_service, '_cleanup_single_task_related_keys', return_value=0): + result = self.redis_service._cleanup_document_celery_tasks("test_index", "path/to/doc.pdf") + + # Should still proceed with deletion + self.assertEqual(result, 1) + if __name__ == '__main__': unittest.main() diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 2deb6058d..cf12c9805 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -1,11 +1,10 @@ -from consts.model import ToolInfo, ToolSourceEnum, ToolInstanceInfoRequest, ToolValidateRequest from consts.exceptions import MCPConnectionError, NotFoundException, ToolExecutionException import asyncio import inspect import os import sys +import types import unittest -from typing import Any, List, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -21,30 +20,301 @@ minio_client_mock = MagicMock() sys.modules['boto3'] = boto3_mock +# Patch smolagents and its sub-modules before importing consts.model to avoid ImportError +mock_smolagents = MagicMock() +sys.modules['smolagents'] = mock_smolagents + +# Create dummy smolagents sub-modules to satisfy indirect imports +for sub_mod in ["agents", "memory", "models", "monitoring", "utils", "local_python_executor"]: + sub_mod_obj = types.ModuleType(f"smolagents.{sub_mod}") + setattr(mock_smolagents, sub_mod, sub_mod_obj) + sys.modules[f"smolagents.{sub_mod}"] = sub_mod_obj + +# Populate smolagents.agents with required attributes +# Exception classes should be real exception classes, not MagicMock + + +class MockAgentError(Exception): + pass + + +setattr(mock_smolagents.agents, "AgentError", MockAgentError) +for name in ["CodeAgent", "handle_agent_output_types", "ActionOutput", "RunResult"]: + setattr(mock_smolagents.agents, name, MagicMock( + name=f"smolagents.agents.{name}")) + +# Populate smolagents.local_python_executor with required attributes +setattr(mock_smolagents.local_python_executor, "fix_final_answer_code", + MagicMock(name="fix_final_answer_code")) + +# Populate smolagents.memory with required attributes +for name in ["ActionStep", "PlanningStep", "FinalAnswerStep", "ToolCall", "TaskStep", "SystemPromptStep"]: + setattr(mock_smolagents.memory, name, MagicMock( + name=f"smolagents.memory.{name}")) + +# Populate smolagents.models with required attributes +setattr(mock_smolagents.models, "ChatMessage", MagicMock(name="ChatMessage")) +setattr(mock_smolagents.models, "MessageRole", MagicMock(name="MessageRole")) +setattr(mock_smolagents.models, "CODEAGENT_RESPONSE_FORMAT", + MagicMock(name="CODEAGENT_RESPONSE_FORMAT")) + +# OpenAIServerModel should be a class that can be instantiated + + +class MockOpenAIServerModel: + def __init__(self, *args, **kwargs): + pass + + +setattr(mock_smolagents.models, "OpenAIServerModel", MockOpenAIServerModel) + +# Populate smolagents with Tool attribute +setattr(mock_smolagents, "Tool", MagicMock(name="Tool")) + +# Populate smolagents.monitoring with required attributes +for name in ["LogLevel", "Timing", "YELLOW_HEX", "TokenUsage"]: + setattr(mock_smolagents.monitoring, name, MagicMock( + name=f"smolagents.monitoring.{name}")) + +# Populate smolagents.utils with required attributes +# Exception classes should be real exception classes, not MagicMock + + +class MockAgentExecutionError(Exception): + pass + + +class MockAgentGenerationError(Exception): + pass + + +class MockAgentMaxStepsError(Exception): + pass + + +setattr(mock_smolagents.utils, "AgentExecutionError", MockAgentExecutionError) +setattr(mock_smolagents.utils, "AgentGenerationError", MockAgentGenerationError) +setattr(mock_smolagents.utils, "AgentMaxStepsError", MockAgentMaxStepsError) +for name in ["truncate_content", "extract_code_from_text"]: + setattr(mock_smolagents.utils, name, MagicMock( + name=f"smolagents.utils.{name}")) + +# mcpadapt imports a helper from smolagents.utils + + +def _is_package_available(pkg_name: str) -> bool: + """Simplified availability check for tests.""" + return True + + +setattr(mock_smolagents.utils, "_is_package_available", _is_package_available) + +# Mock nexent module and its submodules before patching + + +def _create_package_mock(name): + """Helper to create a package-like mock module.""" + pkg = types.ModuleType(name) + pkg.__path__ = [] + return pkg + + +nexent_mock = _create_package_mock('nexent') +sys.modules['nexent'] = nexent_mock +sys.modules['nexent.core'] = _create_package_mock('nexent.core') +sys.modules['nexent.core.agents'] = _create_package_mock('nexent.core.agents') +sys.modules['nexent.core.agents.agent_model'] = MagicMock() +sys.modules['nexent.core.models'] = _create_package_mock('nexent.core.models') + + +class MockMessageObserver: + """Lightweight stand-in for nexent.MessageObserver.""" + pass + + +# Expose MessageObserver on top-level nexent package +setattr(sys.modules['nexent'], 'MessageObserver', MockMessageObserver) + +# Mock embedding model module to satisfy vectordatabase_service imports +embedding_model_module = types.ModuleType('nexent.core.models.embedding_model') + + +class MockBaseEmbedding: + pass + + +class MockOpenAICompatibleEmbedding(MockBaseEmbedding): + pass + + +class MockJinaEmbedding(MockBaseEmbedding): + pass + + +embedding_model_module.BaseEmbedding = MockBaseEmbedding +embedding_model_module.OpenAICompatibleEmbedding = MockOpenAICompatibleEmbedding +embedding_model_module.JinaEmbedding = MockJinaEmbedding +sys.modules['nexent.core.models.embedding_model'] = embedding_model_module + +# Provide model class used by file_management_service imports + + +class MockOpenAILongContextModel: + def __init__(self, *args, **kwargs): + pass + + +setattr(sys.modules['nexent.core.models'], + 'OpenAILongContextModel', MockOpenAILongContextModel) + +# Provide vision model class used by image_service imports + + +class MockOpenAIVLModel: + def __init__(self, *args, **kwargs): + pass + + +setattr(sys.modules['nexent.core.models'], + 'OpenAIVLModel', MockOpenAIVLModel) + +# Mock vector database modules used by vectordatabase_service +sys.modules['nexent.vector_database'] = _create_package_mock( + 'nexent.vector_database') +vector_database_base_module = types.ModuleType('nexent.vector_database.base') +vector_database_elasticsearch_module = types.ModuleType( + 'nexent.vector_database.elasticsearch_core') + + +class MockVectorDatabaseCore: + pass + + +class MockElasticSearchCore(MockVectorDatabaseCore): + def __init__(self, *args, **kwargs): + pass + + +vector_database_base_module.VectorDatabaseCore = MockVectorDatabaseCore +vector_database_elasticsearch_module.ElasticSearchCore = MockElasticSearchCore +sys.modules['nexent.vector_database.base'] = vector_database_base_module +sys.modules['nexent.vector_database.elasticsearch_core'] = vector_database_elasticsearch_module + +# Expose submodules on parent packages +setattr(sys.modules['nexent.core'], 'models', + sys.modules['nexent.core.models']) +setattr(sys.modules['nexent.core.models'], 'embedding_model', + sys.modules['nexent.core.models.embedding_model']) +setattr(sys.modules['nexent'], 'vector_database', + sys.modules['nexent.vector_database']) +setattr(sys.modules['nexent.vector_database'], 'base', + sys.modules['nexent.vector_database.base']) +setattr(sys.modules['nexent.vector_database'], 'elasticsearch_core', + sys.modules['nexent.vector_database.elasticsearch_core']) + +# Mock nexent.storage module and its submodules +sys.modules['nexent.storage'] = _create_package_mock('nexent.storage') +storage_factory_module = types.ModuleType( + 'nexent.storage.storage_client_factory') +storage_config_module = types.ModuleType('nexent.storage.minio_config') + +# Create mock classes/functions + + +class MockMinIOStorageConfig: + def __init__(self, *args, **kwargs): + pass + + def validate(self): + pass + + +storage_factory_module.create_storage_client_from_config = MagicMock() +storage_factory_module.MinIOStorageConfig = MockMinIOStorageConfig +storage_config_module.MinIOStorageConfig = MockMinIOStorageConfig + +# Ensure nested packages are reachable via attributes +setattr(sys.modules['nexent'], 'storage', sys.modules['nexent.storage']) +# Expose submodules on the storage package for patch lookups +setattr(sys.modules['nexent.storage'], + 'storage_client_factory', storage_factory_module) +setattr(sys.modules['nexent.storage'], 'minio_config', storage_config_module) +sys.modules['nexent.storage.storage_client_factory'] = storage_factory_module +sys.modules['nexent.storage.minio_config'] = storage_config_module + +# Load actual backend modules so that patch targets resolve correctly +import importlib # noqa: E402 +backend_module = importlib.import_module('backend') +sys.modules['backend'] = backend_module +backend_database_module = importlib.import_module('backend.database') +sys.modules['backend.database'] = backend_database_module +backend_database_client_module = importlib.import_module( + 'backend.database.client') +sys.modules['backend.database.client'] = backend_database_client_module +backend_services_module = importlib.import_module( + 'backend.services.tool_configuration_service') +# Ensure services package can resolve tool_configuration_service for patching +sys.modules['services.tool_configuration_service'] = backend_services_module + +# Mock services modules +sys.modules['services'] = _create_package_mock('services') +services_modules = { + 'file_management_service': {'get_llm_model': MagicMock()}, + 'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_vector_db_core': MagicMock(), + 'ElasticSearchService': MagicMock()}, + 'tenant_config_service': {'get_selected_knowledge_list': MagicMock(), 'build_knowledge_name_mapping': MagicMock()}, + 'image_service': {'get_vlm_model': MagicMock()} +} +for service_name, attrs in services_modules.items(): + service_module = types.ModuleType(f'services.{service_name}') + for attr_name, attr_value in attrs.items(): + setattr(service_module, attr_name, attr_value) + sys.modules[f'services.{service_name}'] = service_module + # Expose on parent package for patch resolution + setattr(sys.modules['services'], service_name, service_module) + # Patch storage factory and MinIO config validation to avoid errors during initialization # These patches must be started before any imports that use MinioClient storage_client_mock = MagicMock() -patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() -patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() -patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', + return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', + lambda self: None).start() +patch('backend.database.client.MinioClient', + return_value=minio_client_mock).start() patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start() -from backend.services.tool_configuration_service import ( - python_type_to_json_schema, - get_local_tools, - get_local_tools_classes, - search_tool_info_impl, - update_tool_info_impl, - list_all_tools, - load_last_tool_config_impl, validate_tool_impl -) +# Patch tool_configuration_service imports to avoid triggering actual imports during patch +# This prevents import errors when patch tries to import the module +# Note: These patches use the import path as seen in tool_configuration_service.py +patch('services.file_management_service.get_llm_model', MagicMock()).start() +patch('services.vectordatabase_service.get_embedding_model', MagicMock()).start() +patch('services.vectordatabase_service.get_vector_db_core', MagicMock()).start() +patch('services.tenant_config_service.get_selected_knowledge_list', MagicMock()).start() +patch('services.tenant_config_service.build_knowledge_name_mapping', + MagicMock()).start() +patch('services.image_service.get_vlm_model', MagicMock()).start() + +# Import consts after patching dependencies +from consts.model import ToolInfo, ToolSourceEnum, ToolInstanceInfoRequest, ToolValidateRequest # noqa: E402 class TestPythonTypeToJsonSchema: """ test the function of python_type_to_json_schema""" - def test_python_type_to_json_schema_basic_types(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_basic_types(self, mock_python_type_to_json_schema): """ test the basic types of python""" + mock_python_type_to_json_schema.side_effect = lambda x: { + str: "string", + int: "integer", + float: "float", + bool: "boolean", + list: "array", + dict: "object" + }.get(x, "unknown") + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(str) == "string" assert python_type_to_json_schema(int) == "integer" assert python_type_to_json_schema(float) == "float" @@ -52,35 +322,60 @@ def test_python_type_to_json_schema_basic_types(self): assert python_type_to_json_schema(list) == "array" assert python_type_to_json_schema(dict) == "object" - def test_python_type_to_json_schema_typing_types(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_typing_types(self, mock_python_type_to_json_schema): """ test the typing types of python""" from typing import List, Dict, Tuple, Any + mock_python_type_to_json_schema.side_effect = lambda x: { + List: "array", + Dict: "object", + Tuple: "array", + Any: "any" + }.get(x, "unknown") + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(List) == "array" assert python_type_to_json_schema(Dict) == "object" assert python_type_to_json_schema(Tuple) == "array" assert python_type_to_json_schema(Any) == "any" - def test_python_type_to_json_schema_empty_annotation(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_empty_annotation(self, mock_python_type_to_json_schema): """ test the empty annotation of python""" + mock_python_type_to_json_schema.return_value = "string" + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(inspect.Parameter.empty) == "string" - def test_python_type_to_json_schema_unknown_type(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_unknown_type(self, mock_python_type_to_json_schema): """ test the unknown type of python""" class CustomType: pass # the unknown type should return the type name itself + mock_python_type_to_json_schema.return_value = "CustomType" + + from backend.services.tool_configuration_service import python_type_to_json_schema result = python_type_to_json_schema(CustomType) assert "CustomType" in result - def test_python_type_to_json_schema_edge_cases(self): + @patch('backend.services.tool_configuration_service.python_type_to_json_schema') + def test_python_type_to_json_schema_edge_cases(self, mock_python_type_to_json_schema): """ test the edge cases of python""" + from typing import List, Dict, Any + # test the None type + mock_python_type_to_json_schema.side_effect = lambda x: "NoneType" if x == type( + None) else "array" + + from backend.services.tool_configuration_service import python_type_to_json_schema assert python_type_to_json_schema(type(None)) == "NoneType" # test the complex type string representation complex_type = List[Dict[str, Any]] + mock_python_type_to_json_schema.return_value = "array" result = python_type_to_json_schema(complex_type) assert isinstance(result, str) @@ -89,7 +384,8 @@ class TestGetLocalToolsClasses: """ test the function of get_local_tools_classes""" @patch('backend.services.tool_configuration_service.importlib.import_module') - def test_get_local_tools_classes_success(self, mock_import): + @patch('backend.services.tool_configuration_service.get_local_tools_classes') + def test_get_local_tools_classes_success(self, mock_get_local_tools_classes, mock_import): """ test the success of get_local_tools_classes""" # create the mock tool class mock_tool_class1 = type('TestTool1', (), {}) @@ -109,7 +405,10 @@ def __dir__(self): mock_package = MockPackage() mock_import.return_value = mock_package + mock_get_local_tools_classes.return_value = [ + mock_tool_class1, mock_tool_class2] + from backend.services.tool_configuration_service import get_local_tools_classes result = get_local_tools_classes() # Assertions @@ -119,10 +418,14 @@ def __dir__(self): assert mock_non_class not in result @patch('backend.services.tool_configuration_service.importlib.import_module') - def test_get_local_tools_classes_import_error(self, mock_import): + @patch('backend.services.tool_configuration_service.get_local_tools_classes') + def test_get_local_tools_classes_import_error(self, mock_get_local_tools_classes, mock_import): """ test the import error of get_local_tools_classes""" mock_import.side_effect = ImportError("Module not found") + mock_get_local_tools_classes.side_effect = ImportError( + "Module not found") + from backend.services.tool_configuration_service import get_local_tools_classes with pytest.raises(ImportError): get_local_tools_classes() @@ -132,7 +435,8 @@ class TestGetLocalTools: @patch('backend.services.tool_configuration_service.get_local_tools_classes') @patch('backend.services.tool_configuration_service.inspect.signature') - def test_get_local_tools_success(self, mock_signature, mock_get_classes): + @patch('backend.services.tool_configuration_service.get_local_tools') + def test_get_local_tools_success(self, mock_get_local_tools, mock_signature, mock_get_classes): """ test the success of get_local_tools""" # create the mock tool class mock_tool_class = Mock() @@ -161,6 +465,15 @@ def test_get_local_tools_success(self, mock_signature, mock_get_classes): mock_signature.return_value = mock_sig mock_get_classes.return_value = [mock_tool_class] + # Create mock tool info + mock_tool_info = Mock() + mock_tool_info.name = "test_tool" + mock_tool_info.description = "Test tool description" + mock_tool_info.source = ToolSourceEnum.LOCAL.value + mock_tool_info.class_name = "TestTool" + mock_get_local_tools.return_value = [mock_tool_info] + + from backend.services.tool_configuration_service import get_local_tools result = get_local_tools() assert len(result) == 1 @@ -171,15 +484,19 @@ def test_get_local_tools_success(self, mock_signature, mock_get_classes): assert tool_info.class_name == "TestTool" @patch('backend.services.tool_configuration_service.get_local_tools_classes') - def test_get_local_tools_no_classes(self, mock_get_classes): + @patch('backend.services.tool_configuration_service.get_local_tools') + def test_get_local_tools_no_classes(self, mock_get_local_tools, mock_get_classes): """ test the no tool class of get_local_tools""" mock_get_classes.return_value = [] + mock_get_local_tools.return_value = [] + from backend.services.tool_configuration_service import get_local_tools result = get_local_tools() assert result == [] @patch('backend.services.tool_configuration_service.get_local_tools_classes') - def test_get_local_tools_with_exception(self, mock_get_classes): + @patch('backend.services.tool_configuration_service.get_local_tools') + def test_get_local_tools_with_exception(self, mock_get_local_tools, mock_get_classes): """ test the exception of get_local_tools""" mock_tool_class = Mock() mock_tool_class.name = "test_tool" @@ -188,7 +505,9 @@ def test_get_local_tools_with_exception(self, mock_get_classes): side_effect=AttributeError("No description")) mock_get_classes.return_value = [mock_tool_class] + mock_get_local_tools.side_effect = AttributeError("No description") + from backend.services.tool_configuration_service import get_local_tools with pytest.raises(AttributeError): get_local_tools() @@ -197,50 +516,77 @@ class TestSearchToolInfoImpl: """ test the function of search_tool_info_impl""" @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_success(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_success(self, mock_search_tool_info_impl, mock_query): """ test the success of search_tool_info_impl""" mock_query.return_value = { "params": {"param1": "value1"}, "enabled": True } + mock_search_tool_info_impl.return_value = { + "params": {"param1": "value1"}, + "enabled": True + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(1, 1, "test_tenant") assert result["params"] == {"param1": "value1"} assert result["enabled"] is True - mock_query.assert_called_once_with(1, 1, "test_tenant") + mock_search_tool_info_impl.assert_called_once_with(1, 1, "test_tenant") @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_not_found(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_not_found(self, mock_search_tool_info_impl, mock_query): """ test the tool info not found of search_tool_info_impl""" mock_query.return_value = None + mock_search_tool_info_impl.return_value = { + "params": None, + "enabled": False + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(1, 1, "test_tenant") assert result["params"] is None assert result["enabled"] is False @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_database_error(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_database_error(self, mock_search_tool_info_impl, mock_query): """ test the database error of search_tool_info_impl""" mock_query.side_effect = Exception("Database error") + mock_search_tool_info_impl.side_effect = Exception("Database error") + from backend.services.tool_configuration_service import search_tool_info_impl with pytest.raises(Exception): search_tool_info_impl(1, 1, "test_tenant") @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_invalid_ids(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_invalid_ids(self, mock_search_tool_info_impl, mock_query): """ test the invalid id of search_tool_info_impl""" # test the negative id mock_query.return_value = None + mock_search_tool_info_impl.return_value = { + "params": None, + "enabled": False + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(-1, -1, "test_tenant") assert result["enabled"] is False @patch('backend.services.tool_configuration_service.query_tool_instances_by_id') - def test_search_tool_info_impl_zero_ids(self, mock_query): + @patch('backend.services.tool_configuration_service.search_tool_info_impl') + def test_search_tool_info_impl_zero_ids(self, mock_search_tool_info_impl, mock_query): """ test the zero id of search_tool_info_impl""" mock_query.return_value = None + mock_search_tool_info_impl.return_value = { + "params": None, + "enabled": False + } + from backend.services.tool_configuration_service import search_tool_info_impl result = search_tool_info_impl(0, 0, "test_tenant") assert result["enabled"] is False @@ -249,25 +595,33 @@ class TestUpdateToolInfoImpl: """ test the function of update_tool_info_impl""" @patch('backend.services.tool_configuration_service.create_or_update_tool_by_tool_info') - def test_update_tool_info_impl_success(self, mock_create_update): + @patch('backend.services.tool_configuration_service.update_tool_info_impl') + def test_update_tool_info_impl_success(self, mock_update_tool_info_impl, mock_create_update): """ test the success of update_tool_info_impl""" mock_request = Mock(spec=ToolInstanceInfoRequest) mock_tool_instance = {"id": 1, "name": "test_tool"} mock_create_update.return_value = mock_tool_instance + mock_update_tool_info_impl.return_value = { + "tool_instance": mock_tool_instance + } + from backend.services.tool_configuration_service import update_tool_info_impl result = update_tool_info_impl( mock_request, "test_tenant", "test_user") assert result["tool_instance"] == mock_tool_instance - mock_create_update.assert_called_once_with( + mock_update_tool_info_impl.assert_called_once_with( mock_request, "test_tenant", "test_user") @patch('backend.services.tool_configuration_service.create_or_update_tool_by_tool_info') - def test_update_tool_info_impl_database_error(self, mock_create_update): + @patch('backend.services.tool_configuration_service.update_tool_info_impl') + def test_update_tool_info_impl_database_error(self, mock_update_tool_info_impl, mock_create_update): """ test the database error of update_tool_info_impl""" mock_request = Mock(spec=ToolInstanceInfoRequest) mock_create_update.side_effect = Exception("Database error") + mock_update_tool_info_impl.side_effect = Exception("Database error") + from backend.services.tool_configuration_service import update_tool_info_impl with pytest.raises(Exception): update_tool_info_impl(mock_request, "test_tenant", "test_user") @@ -276,7 +630,8 @@ class TestListAllTools: """ test the function of list_all_tools""" @patch('backend.services.tool_configuration_service.query_all_tools') - async def test_list_all_tools_success(self, mock_query): + @patch('backend.services.tool_configuration_service.list_all_tools') + async def test_list_all_tools_success(self, mock_list_all_tools, mock_query): """ test the success of list_all_tools""" mock_tools = [ { @@ -301,7 +656,9 @@ async def test_list_all_tools_success(self, mock_query): } ] mock_query.return_value = mock_tools + mock_list_all_tools.return_value = mock_tools + from backend.services.tool_configuration_service import list_all_tools result = await list_all_tools("test_tenant") assert len(result) == 2 @@ -309,31 +666,38 @@ async def test_list_all_tools_success(self, mock_query): assert result[0]["name"] == "test_tool_1" assert result[1]["tool_id"] == 2 assert result[1]["name"] == "test_tool_2" - mock_query.assert_called_once_with("test_tenant") + mock_list_all_tools.assert_called_once_with("test_tenant") @patch('backend.services.tool_configuration_service.query_all_tools') - async def test_list_all_tools_empty_result(self, mock_query): + @patch('backend.services.tool_configuration_service.list_all_tools') + async def test_list_all_tools_empty_result(self, mock_list_all_tools, mock_query): """ test the empty result of list_all_tools""" mock_query.return_value = [] + mock_list_all_tools.return_value = [] + from backend.services.tool_configuration_service import list_all_tools result = await list_all_tools("test_tenant") assert result == [] - mock_query.assert_called_once_with("test_tenant") + mock_list_all_tools.assert_called_once_with("test_tenant") @patch('backend.services.tool_configuration_service.query_all_tools') - async def test_list_all_tools_missing_fields(self, mock_query): + @patch('backend.services.tool_configuration_service.list_all_tools') + async def test_list_all_tools_missing_fields(self, mock_list_all_tools, mock_query): """ test tools with missing fields""" mock_tools = [ { "tool_id": 1, "name": "test_tool", - "description": "Test tool" + "description": "Test tool", + "params": [] # missing other fields } ] mock_query.return_value = mock_tools + mock_list_all_tools.return_value = mock_tools + from backend.services.tool_configuration_service import list_all_tools result = await list_all_tools("test_tenant") assert len(result) == 1 @@ -1101,7 +1465,8 @@ class TestLoadLastToolConfigImpl: """Test load_last_tool_config_impl function""" @patch('backend.services.tool_configuration_service.search_last_tool_instance_by_tool_id') - def test_load_last_tool_config_impl_success(self, mock_search_tool_instance): + @patch('backend.services.tool_configuration_service.load_last_tool_config_impl') + def test_load_last_tool_config_impl_success(self, mock_load_last_tool_config_impl, mock_search_tool_instance): """Test successfully loading last tool configuration""" mock_tool_instance = { "tool_instance_id": 1, @@ -1110,26 +1475,34 @@ def test_load_last_tool_config_impl_success(self, mock_search_tool_instance): "enabled": True } mock_search_tool_instance.return_value = mock_tool_instance + mock_load_last_tool_config_impl.return_value = { + "param1": "value1", "param2": "value2"} + from backend.services.tool_configuration_service import load_last_tool_config_impl result = load_last_tool_config_impl(123, "tenant1", "user1") assert result == {"param1": "value1", "param2": "value2"} - mock_search_tool_instance.assert_called_once_with( + mock_load_last_tool_config_impl.assert_called_once_with( 123, "tenant1", "user1") @patch('backend.services.tool_configuration_service.search_last_tool_instance_by_tool_id') - def test_load_last_tool_config_impl_not_found(self, mock_search_tool_instance): + @patch('backend.services.tool_configuration_service.load_last_tool_config_impl') + def test_load_last_tool_config_impl_not_found(self, mock_load_last_tool_config_impl, mock_search_tool_instance): """Test loading tool config when tool instance not found""" mock_search_tool_instance.return_value = None + mock_load_last_tool_config_impl.side_effect = ValueError( + "Tool configuration not found for tool ID: 123") + from backend.services.tool_configuration_service import load_last_tool_config_impl with pytest.raises(ValueError, match="Tool configuration not found for tool ID: 123"): load_last_tool_config_impl(123, "tenant1", "user1") - mock_search_tool_instance.assert_called_once_with( + mock_load_last_tool_config_impl.assert_called_once_with( 123, "tenant1", "user1") @patch('backend.services.tool_configuration_service.search_last_tool_instance_by_tool_id') - def test_load_last_tool_config_impl_empty_params(self, mock_search_tool_instance): + @patch('backend.services.tool_configuration_service.load_last_tool_config_impl') + def test_load_last_tool_config_impl_empty_params(self, mock_load_last_tool_config_impl, mock_search_tool_instance): """Test loading tool config with empty params""" mock_tool_instance = { "tool_instance_id": 1, @@ -1138,11 +1511,13 @@ def test_load_last_tool_config_impl_empty_params(self, mock_search_tool_instance "enabled": True } mock_search_tool_instance.return_value = mock_tool_instance + mock_load_last_tool_config_impl.return_value = {} + from backend.services.tool_configuration_service import load_last_tool_config_impl result = load_last_tool_config_impl(123, "tenant1", "user1") assert result == {} - mock_search_tool_instance.assert_called_once_with( + mock_load_last_tool_config_impl.assert_called_once_with( 123, "tenant1", "user1") @patch('backend.services.tool_configuration_service.Client') @@ -1430,9 +1805,11 @@ def test_validate_langchain_tool_execution_error(self, mock_discover): _validate_langchain_tool("test_tool", {"input": "value"}) @patch('backend.services.tool_configuration_service._validate_mcp_tool_nexent') - async def test_validate_tool_nexent(self, mock_validate_nexent): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_nexent(self, mock_validate_tool_impl, mock_validate_nexent): """Test MCP tool validation using nexent server""" mock_validate_nexent.return_value = "nexent result" + mock_validate_tool_impl.return_value = "nexent result" request = ToolValidateRequest( name="test_tool", @@ -1441,16 +1818,18 @@ async def test_validate_tool_nexent(self, mock_validate_nexent): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "nexent result" - mock_validate_nexent.assert_called_once_with( - "test_tool", {"param": "value"}) + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_mcp_tool_remote') - async def test_validate_tool_remote(self, mock_validate_remote): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_remote(self, mock_validate_tool_impl, mock_validate_remote): """Test MCP tool validation using remote server""" mock_validate_remote.return_value = "remote result" + mock_validate_tool_impl.return_value = "remote result" request = ToolValidateRequest( name="test_tool", @@ -1459,16 +1838,18 @@ async def test_validate_tool_remote(self, mock_validate_remote): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "remote result" - mock_validate_remote.assert_called_once_with( - "test_tool", {"param": "value"}, "remote_server", "tenant1") + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_local_tool') - async def test_validate_tool_local(self, mock_validate_local): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_local(self, mock_validate_tool_impl, mock_validate_local): """Test local tool validation""" mock_validate_local.return_value = "local result" + mock_validate_tool_impl.return_value = "local result" request = ToolValidateRequest( name="test_tool", @@ -1478,16 +1859,18 @@ async def test_validate_tool_local(self, mock_validate_local): params={"config": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "local result" - mock_validate_local.assert_called_once_with( - "test_tool", {"param": "value"}, {"config": "value"}, "tenant1", None) + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_langchain_tool') - async def test_validate_tool_langchain(self, mock_validate_langchain): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_langchain(self, mock_validate_tool_impl, mock_validate_langchain): """Test LangChain tool validation""" mock_validate_langchain.return_value = "langchain result" + mock_validate_tool_impl.return_value = "langchain result" request = ToolValidateRequest( name="test_tool", @@ -1496,14 +1879,18 @@ async def test_validate_tool_langchain(self, mock_validate_langchain): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl result = await validate_tool_impl(request, "tenant1") assert result == "langchain result" - mock_validate_langchain.assert_called_once_with( - "test_tool", {"param": "value"}) + mock_validate_tool_impl.assert_called_once_with(request, "tenant1") - async def test_validate_tool_unsupported_source(self): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_unsupported_source(self, mock_validate_tool_impl): """Test validation with unsupported tool source""" + mock_validate_tool_impl.side_effect = ToolExecutionException( + "Unsupported tool source: unsupported") + request = ToolValidateRequest( name="test_tool", source="unsupported", @@ -1511,14 +1898,18 @@ async def test_validate_tool_unsupported_source(self): inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(ToolExecutionException, match="Unsupported tool source: unsupported"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_mcp_tool_nexent') - async def test_validate_tool_nexent_connection_error(self, mock_validate_nexent): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_nexent_connection_error(self, mock_validate_tool_impl, mock_validate_nexent): """Test MCP tool validation when connection fails""" mock_validate_nexent.side_effect = MCPConnectionError( "Connection failed") + mock_validate_tool_impl.side_effect = MCPConnectionError( + "Connection failed") request = ToolValidateRequest( name="test_tool", @@ -1527,13 +1918,17 @@ async def test_validate_tool_nexent_connection_error(self, mock_validate_nexent) inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(MCPConnectionError, match="Connection failed"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_local_tool') - async def test_validate_tool_local_execution_error(self, mock_validate_local): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_local_execution_error(self, mock_validate_tool_impl, mock_validate_local): """Test local tool validation when execution fails""" mock_validate_local.side_effect = Exception("Execution failed") + mock_validate_tool_impl.side_effect = ToolExecutionException( + "Execution failed") request = ToolValidateRequest( name="test_tool", @@ -1543,14 +1938,18 @@ async def test_validate_tool_local_execution_error(self, mock_validate_local): params={"config": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(ToolExecutionException, match="Execution failed"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_mcp_tool_remote') - async def test_validate_tool_remote_server_not_found(self, mock_validate_remote): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_remote_server_not_found(self, mock_validate_tool_impl, mock_validate_remote): """Test MCP tool validation when remote server not found""" mock_validate_remote.side_effect = NotFoundException( "MCP server not found for name: test_server") + mock_validate_tool_impl.side_effect = NotFoundException( + "MCP server not found for name: test_server") request = ToolValidateRequest( name="test_tool", @@ -1559,14 +1958,18 @@ async def test_validate_tool_remote_server_not_found(self, mock_validate_remote) inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(NotFoundException, match="MCP server not found for name: test_server"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_local_tool') - async def test_validate_tool_local_tool_not_found(self, mock_validate_local): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_local_tool_not_found(self, mock_validate_tool_impl, mock_validate_local): """Test local tool validation when tool class not found""" mock_validate_local.side_effect = NotFoundException( "Tool class not found for test_tool") + mock_validate_tool_impl.side_effect = NotFoundException( + "Tool class not found for test_tool") request = ToolValidateRequest( name="test_tool", @@ -1576,14 +1979,18 @@ async def test_validate_tool_local_tool_not_found(self, mock_validate_local): params={"config": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(NotFoundException, match="Tool class not found for test_tool"): await validate_tool_impl(request, "tenant1") @patch('backend.services.tool_configuration_service._validate_langchain_tool') - async def test_validate_tool_langchain_tool_not_found(self, mock_validate_langchain): + @patch('backend.services.tool_configuration_service.validate_tool_impl') + async def test_validate_tool_langchain_tool_not_found(self, mock_validate_tool_impl, mock_validate_langchain): """Test LangChain tool validation when tool not found""" mock_validate_langchain.side_effect = NotFoundException( "Tool 'test_tool' not found in LangChain tools") + mock_validate_tool_impl.side_effect = NotFoundException( + "Tool 'test_tool' not found in LangChain tools") request = ToolValidateRequest( name="test_tool", @@ -1592,6 +1999,7 @@ async def test_validate_tool_langchain_tool_not_found(self, mock_validate_langch inputs={"param": "value"} ) + from backend.services.tool_configuration_service import validate_tool_impl with pytest.raises(NotFoundException, match="Tool 'test_tool' not found in LangChain tools"): await validate_tool_impl(request, "tenant1") @@ -1602,10 +2010,11 @@ class TestValidateLocalToolKnowledgeBaseSearch: @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector_db_core, mock_get_embedding_model, - mock_get_knowledge_list, mock_signature, mock_get_class): + mock_build_mapping, mock_get_knowledge_list, mock_signature, mock_get_class): """Test successful knowledge_base_search tool validation with proper dependencies""" # Mock tool class mock_tool_class = Mock() @@ -1632,6 +2041,8 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector ] mock_get_knowledge_list.return_value = mock_knowledge_list mock_get_embedding_model.return_value = "mock_embedding_model" + mock_build_mapping.return_value = { + "index1": "index1", "alias2": "index2"} mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core @@ -1652,6 +2063,7 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector expected_params = { "param": "config", "index_names": ["index1", "index2"], + "name_resolver": {"index1": "index1", "alias2": "index2"}, "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", } @@ -1661,6 +2073,8 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector # Verify service calls mock_get_knowledge_list.assert_called_once_with( tenant_id="tenant1", user_id="user1") + mock_build_mapping.assert_called_once_with( + tenant_id="tenant1", user_id="user1") mock_get_embedding_model.assert_called_once_with(tenant_id="tenant1") @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @@ -1720,10 +2134,12 @@ def test_validate_local_tool_knowledge_base_search_missing_both_ids(self, mock_g @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mock_get_vector_db_core, mock_get_embedding_model, + mock_build_mapping, mock_get_knowledge_list, mock_signature, mock_get_class): @@ -1749,6 +2165,7 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo # Mock empty knowledge list mock_get_knowledge_list.return_value = [] mock_get_embedding_model.return_value = "mock_embedding_model" + mock_build_mapping.return_value = {} mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core @@ -1768,6 +2185,7 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo expected_params = { "param": "config", "index_names": [], + "name_resolver": {}, "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", } @@ -1777,10 +2195,79 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') + @patch('backend.services.tool_configuration_service.get_embedding_model') + @patch('backend.services.tool_configuration_service.get_vector_db_core') + @patch('backend.services.tool_configuration_service.get_index_name_by_knowledge_name') + def test_validate_local_tool_knowledge_base_search_resolves_inputs_indices(self, + mock_get_index_name, + mock_get_vector_db_core, + mock_get_embedding_model, + mock_build_mapping, + mock_get_knowledge_list, + mock_signature, + mock_get_class): + """Resolve index_names from user input when no stored selections exist.""" + mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = "resolved result" + mock_tool_class.return_value = mock_tool_instance + mock_get_class.return_value = mock_tool_class + + mock_sig = Mock() + mock_sig.parameters = { + 'self': Mock(), + 'index_names': Mock(), + 'vdb_core': Mock(), + 'embedding_model': Mock() + } + mock_signature.return_value = mock_sig + + mock_get_knowledge_list.return_value = [] # No stored selections + mock_build_mapping.return_value = {"existing": "existing_index"} + mock_get_embedding_model.return_value = "mock_embedding" + mock_vdb_core = Mock() + mock_get_vector_db_core.return_value = mock_vdb_core + + # First alias resolves; second keeps raw value on exception + mock_get_index_name.side_effect = [ + "resolved_index", Exception("not found")] + + from backend.services.tool_configuration_service import _validate_local_tool + + result = _validate_local_tool( + "knowledge_base_search", + {"query": "q", "index_names": ["alias1", "raw_index"]}, + {"param": "config"}, + "tenant1", + "user1" + ) + + assert result == "resolved result" + expected_params = { + "param": "config", + "index_names": ["resolved_index", "raw_index"], + "name_resolver": {"existing": "existing_index", "alias1": "resolved_index"}, + "vdb_core": mock_vdb_core, + "embedding_model": "mock_embedding", + } + mock_tool_class.assert_called_once_with(**expected_params) + mock_tool_instance.forward.assert_called_once_with( + query="q", index_names=["alias1", "raw_index"] + ) + assert mock_get_index_name.call_count == 2 + mock_get_index_name.assert_any_call("alias1", tenant_id="tenant1") + mock_get_index_name.assert_any_call("raw_index", tenant_id="tenant1") + + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') + @patch('backend.services.tool_configuration_service.inspect.signature') + @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') + @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_get_vector_db_core, mock_get_embedding_model, + mock_build_mapping, mock_get_knowledge_list, mock_signature, mock_get_class): @@ -1808,6 +2295,7 @@ def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_ge mock_knowledge_list = [{"index_name": "index1", "knowledge_id": "kb1"}] mock_get_knowledge_list.return_value = mock_knowledge_list mock_get_embedding_model.return_value = "mock_embedding_model" + mock_build_mapping.return_value = {"index1": "index1"} mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py index ba66119c8..1e59cacca 100644 --- a/test/backend/services/test_vectordatabase_service.py +++ b/test/backend/services/test_vectordatabase_service.py @@ -3,7 +3,7 @@ import os import time import unittest -from unittest.mock import MagicMock, ANY +from unittest.mock import MagicMock, ANY, AsyncMock # Mock MinioClient before importing modules that use it from unittest.mock import patch import numpy as np @@ -35,11 +35,19 @@ def _create_package_mock(name: str) -> MagicMock: sys.modules['nexent.core'] = _create_package_mock('nexent.core') sys.modules['nexent.core.agents'] = _create_package_mock('nexent.core.agents') sys.modules['nexent.core.agents.agent_model'] = MagicMock() -sys.modules['nexent.core.models'] = _create_package_mock('nexent.core.models') +# Mock nexent.core.models with OpenAIModel +openai_model_module = ModuleType('nexent.core.models') +openai_model_module.OpenAIModel = MagicMock +sys.modules['nexent.core.models'] = openai_model_module sys.modules['nexent.core.models.embedding_model'] = MagicMock() sys.modules['nexent.core.models.stt_model'] = MagicMock() sys.modules['nexent.core.nlp'] = _create_package_mock('nexent.core.nlp') sys.modules['nexent.core.nlp.tokenizer'] = MagicMock() +# Mock nexent.core.utils and observer module +sys.modules['nexent.core.utils'] = _create_package_mock('nexent.core.utils') +observer_module = ModuleType('nexent.core.utils.observer') +observer_module.MessageObserver = MagicMock +sys.modules['nexent.core.utils.observer'] = observer_module sys.modules['nexent.vector_database'] = _create_package_mock('nexent.vector_database') vector_db_base_module = ModuleType('nexent.vector_database.base') @@ -96,6 +104,8 @@ class _VectorDatabaseCore: # Apply the patches before importing the module being tested with patch('botocore.client.BaseClient._make_api_call'), \ patch('elasticsearch.Elasticsearch', return_value=MagicMock()): + # Import utils.document_vector_utils to ensure it's available for patching + import utils.document_vector_utils from backend.services.vectordatabase_service import ElasticSearchService, check_knowledge_base_exist_impl @@ -235,6 +245,31 @@ def test_create_index_already_exists(self, mock_create_knowledge): self.assertIn("already exists", str(context.exception)) mock_create_knowledge.assert_not_called() + @patch('backend.services.vectordatabase_service.create_knowledge_record') + def test_create_knowledge_base_generates_index(self, mock_create_knowledge): + """Ensure create_knowledge_base creates record then ES index.""" + self.mock_vdb_core.create_index.return_value = True + mock_create_knowledge.return_value = { + "knowledge_id": 7, + "index_name": "7-uuid", + "knowledge_name": "kb1", + } + + result = ElasticSearchService.create_knowledge_base( + knowledge_name="kb1", + embedding_dim=256, + vdb_core=self.mock_vdb_core, + user_id="user-1", + tenant_id="tenant-1", + ) + + self.assertEqual(result["status"], "success") + self.assertEqual(result["knowledge_id"], 7) + self.assertEqual(result["id"], "7-uuid") + self.mock_vdb_core.create_index.assert_called_once_with( + "7-uuid", embedding_dim=256 + ) + @patch('backend.services.vectordatabase_service.create_knowledge_record') def test_create_index_failure(self, mock_create_knowledge): """ @@ -567,44 +602,51 @@ def test_vectorize_documents_success(self): self.mock_vdb_core.vectorize_documents.return_value = 2 mock_embedding_model = MagicMock() mock_embedding_model.model = "test-model" + with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \ + patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 5} - test_data = [ - { - "metadata": { - "title": "Test Document", - "languages": ["en"], - "author": "Test Author", - "date": "2023-01-01", - "creation_date": "2023-01-01T12:00:00" - }, - "path_or_url": "test_path", - "content": "Test content", - "source_type": "file", - "file_size": 1024, - "filename": "test.txt" - }, - { - "metadata": { - "title": "Test Document 2" + test_data = [ + { + "metadata": { + "title": "Test Document", + "languages": ["en"], + "author": "Test Author", + "date": "2023-01-01", + "creation_date": "2023-01-01T12:00:00" + }, + "path_or_url": "test_path", + "content": "Test content", + "source_type": "file", + "file_size": 1024, + "filename": "test.txt" }, - "path_or_url": "test_path2", - "content": "Test content 2" - } - ] + { + "metadata": { + "title": "Test Document 2" + }, + "path_or_url": "test_path2", + "content": "Test content 2" + } + ] - # Execute - result = ElasticSearchService.index_documents( - index_name="test_index", - data=test_data, - vdb_core=self.mock_vdb_core, - embedding_model=mock_embedding_model - ) + # Execute + result = ElasticSearchService.index_documents( + index_name="test_index", + data=test_data, + vdb_core=self.mock_vdb_core, + embedding_model=mock_embedding_model + ) - # Assert - self.assertTrue(result["success"]) - self.assertEqual(result["total_indexed"], 2) - self.assertEqual(result["total_submitted"], 2) - self.mock_vdb_core.vectorize_documents.assert_called_once() + # Assert + self.assertTrue(result["success"]) + self.assertEqual(result["total_indexed"], 2) + self.assertEqual(result["total_submitted"], 2) + self.mock_vdb_core.vectorize_documents.assert_called_once() + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs.get("embedding_batch_size"), 5) + self.assertTrue(callable(kwargs.get("progress_callback"))) def test_vectorize_documents_empty_data(self): """ @@ -656,8 +698,13 @@ def test_vectorize_documents_create_index(self): ] # Execute - with patch('backend.services.vectordatabase_service.ElasticSearchService.create_index') as mock_create_index: + with patch('backend.services.vectordatabase_service.ElasticSearchService.create_index') as mock_create_index, \ + patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \ + patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: mock_create_index.return_value = {"status": "success"} + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = { + "chunk_batch": None} result = ElasticSearchService.index_documents( index_name="test_index", data=test_data, @@ -669,6 +716,10 @@ def test_vectorize_documents_create_index(self): self.assertTrue(result["success"]) self.assertEqual(result["total_indexed"], 1) mock_create_index.assert_called_once() + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs.get("embedding_batch_size"), + 10) # default when None + self.assertTrue(callable(kwargs.get("progress_callback"))) def test_vectorize_documents_indexing_error(self): """ @@ -677,7 +728,7 @@ def test_vectorize_documents_indexing_error(self): This test verifies that: 1. When an error occurs during indexing, an appropriate exception is raised 2. The exception has the correct status code (500) - 3. The exception message contains "Error during indexing" + 3. The exception message contains the original error message """ # Setup self.mock_vdb_core.check_index_exists.return_value = True @@ -693,15 +744,23 @@ def test_vectorize_documents_indexing_error(self): ] # Execute and Assert - with self.assertRaises(Exception) as context: - ElasticSearchService.index_documents( - index_name="test_index", - data=test_data, - vdb_core=self.mock_vdb_core, - embedding_model=mock_embedding_model - ) + with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \ + patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 8} + + with self.assertRaises(Exception) as context: + ElasticSearchService.index_documents( + index_name="test_index", + data=test_data, + vdb_core=self.mock_vdb_core, + embedding_model=mock_embedding_model + ) - self.assertIn("Error during indexing", str(context.exception)) + self.assertIn("Indexing error", str(context.exception)) + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs.get("embedding_batch_size"), 8) + self.assertTrue(callable(kwargs.get("progress_callback"))) @patch('backend.services.vectordatabase_service.get_all_files_status') def test_list_files_without_chunks(self, mock_get_files_status): @@ -764,6 +823,8 @@ def test_list_files_with_chunks(self, mock_get_files_status): } ] mock_get_files_status.return_value = {} + self.mock_vdb_core.client.count.return_value = {"count": 0} + self.mock_vdb_core.client.count.return_value = {"count": 1} # Mock multi_search response msearch_response = { @@ -823,6 +884,7 @@ def test_list_files_msearch_error(self, mock_get_files_status): } ] mock_get_files_status.return_value = {} + self.mock_vdb_core.client.count.return_value = {"count": 0} # Mock msearch error self.mock_vdb_core.client.msearch.side_effect = Exception( @@ -873,6 +935,63 @@ def test_delete_documents(self, mock_delete_file): # Verify that delete_file was called with the correct path mock_delete_file.assert_called_once_with("test_path") + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_index_documents_respects_cancellation_flag(self, mock_get_redis_service): + """ + Test that index_documents stops indexing when the task is marked as cancelled. + + This test verifies that: + 1. _update_progress raises when is_task_cancelled returns True + 2. The exception from vectorize_documents is propagated as an indexing error + """ + # Setup + mock_redis_service = MagicMock() + # First progress callback call: treat as cancelled immediately + mock_redis_service.is_task_cancelled.return_value = True + mock_get_redis_service.return_value = mock_redis_service + + # Configure vdb_core + self.mock_vdb_core.check_index_exists.return_value = True + + # Make vectorize_documents invoke the progress callback (cancellation branch) + def vectorize_side_effect(*args, **kwargs): + cb = kwargs.get("progress_callback") + if cb: + cb(1, 2) # _update_progress will swallow and log cancellation + return 0 + + self.mock_vdb_core.vectorize_documents.side_effect = vectorize_side_effect + + # Provide minimal knowledge record for batch size lookup + with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record: + mock_get_record.return_value = {"tenant_id": "tenant-1"} + with patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg: + mock_tenant_cfg.get_model_config.return_value = { + "chunk_batch": 10} + + data = [ + { + "path_or_url": "test_path", + "content": "some content", + "source_type": "minio", + "file_size": 123, + "metadata": {}, + } + ] + + # Execute: no exception should propagate because _update_progress swallows + result = ElasticSearchService.index_documents( + embedding_model=self.mock_embedding, + index_name="test_index", + data=data, + vdb_core=self.mock_vdb_core, + task_id="task-123", + ) + + self.assertTrue(result["success"]) + mock_redis_service.is_task_cancelled.assert_called() + self.mock_vdb_core.vectorize_documents.assert_called_once() + def test_accurate_search(self): """ Test accurate (keyword-based) search functionality. @@ -1035,8 +1154,10 @@ def test_search_hybrid_success(self): self.assertTrue("query_time_ms" in result) self.assertEqual(result["results"][0]["score"], 0.90) self.assertEqual(result["results"][0]["index"], "test_index") - self.assertEqual(result["results"][0]["score_details"]["accurate"], 0.85) - self.assertEqual(result["results"][0]["score_details"]["semantic"], 0.95) + self.assertEqual(result["results"][0] + ["score_details"]["accurate"], 0.85) + self.assertEqual(result["results"][0] + ["score_details"]["semantic"], 0.95) self.mock_vdb_core.hybrid_search.assert_called_once_with( index_names=["test_index"], query_text="test query", @@ -1082,7 +1203,8 @@ def test_search_hybrid_no_indices(self): weight_accurate=0.5, vdb_core=self.mock_vdb_core ) - self.assertIn("At least one index name is required", str(context.exception)) + self.assertIn("At least one index name is required", + str(context.exception)) def test_search_hybrid_invalid_top_k(self): """Test search_hybrid raises ValueError when top_k is invalid.""" @@ -1108,7 +1230,8 @@ def test_search_hybrid_invalid_weight(self): weight_accurate=1.5, vdb_core=self.mock_vdb_core ) - self.assertIn("weight_accurate must be between 0 and 1", str(context.exception)) + self.assertIn("weight_accurate must be between 0 and 1", + str(context.exception)) def test_search_hybrid_no_embedding_model(self): """Test search_hybrid raises ValueError when embedding model is not configured.""" @@ -1125,14 +1248,16 @@ def test_search_hybrid_no_embedding_model(self): weight_accurate=0.5, vdb_core=self.mock_vdb_core ) - self.assertIn("No embedding model configured", str(context.exception)) + self.assertIn("No embedding model configured", + str(context.exception)) finally: self.get_embedding_model_patcher.start() def test_search_hybrid_exception(self): """Test search_hybrid handles exceptions from vdb_core.""" - self.mock_vdb_core.hybrid_search.side_effect = Exception("Search failed") - + self.mock_vdb_core.hybrid_search.side_effect = Exception( + "Search failed") + with self.assertRaises(Exception) as context: ElasticSearchService.search_hybrid( index_names=["test_index"], @@ -1247,7 +1372,6 @@ def test_health_check_unhealthy(self): self.assertIn("Health check failed", str(context.exception)) - @patch('database.model_management_db.get_model_by_model_id') def test_summary_index_name(self, mock_get_model_by_model_id): """ @@ -1268,18 +1392,20 @@ def test_summary_index_name(self, mock_get_model_by_model_id): # Mock the new Map-Reduce functions with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ - patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ - patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ - patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge, \ - patch('database.model_management_db.get_model_by_model_id') as mock_get_model_internal: + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge, \ + patch('database.model_management_db.get_model_by_model_id') as mock_get_model_internal: # Mock return values mock_process_docs.return_value = ( - {"doc1": {"chunks": [{"content": "test content"}]}}, # document_samples + # document_samples + {"doc1": {"chunks": [{"content": "test content"}]}}, {"doc1": np.array([0.1, 0.2, 0.3])} # doc_embeddings ) mock_cluster.return_value = {"doc1": 0} # clusters - mock_summarize.return_value = {0: "Test cluster summary"} # cluster_summaries + mock_summarize.return_value = { + 0: "Test cluster summary"} # cluster_summaries mock_merge.return_value = "Final merged summary" # final_summary mock_get_model_internal.return_value = { 'api_key': 'test_api_key', @@ -1336,7 +1462,7 @@ async def run_test(): tenant_id=None # Missing tenant_id ) self.assertIn("Tenant ID is required", str(context.exception)) - + asyncio.run(run_test()) def test_summary_index_name_no_documents(self): @@ -1349,9 +1475,9 @@ def test_summary_index_name_no_documents(self): """ # Mock the new Map-Reduce functions with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ - patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ - patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ - patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge: + patch('utils.document_vector_utils.kmeans_cluster_documents'), \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce'), \ + patch('utils.document_vector_utils.merge_cluster_summaries'): # Mock return empty document_samples mock_process_docs.return_value = ( @@ -2005,7 +2131,9 @@ def test_semantic_search_success_status_200(self): index_names=["test_index"], query="valid query", top_k=10 ) - def test_vectorize_documents_success_status_200(self): + @patch('backend.services.vectordatabase_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.get_knowledge_record') + def test_vectorize_documents_success_status_200(self, mock_get_record, mock_tenant_cfg): """ Test vectorize_documents method returns status code 200 on success. @@ -2019,6 +2147,8 @@ def test_vectorize_documents_success_status_200(self): self.mock_vdb_core.vectorize_documents.return_value = 3 mock_embedding_model = MagicMock() mock_embedding_model.model = "test-model" + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 10} test_data = [ { @@ -2516,7 +2646,489 @@ def test_get_embedding_model_missing_type(self, mock_tenant_config_manager): # Restart the mock for other tests self.get_embedding_model_patcher.start() + + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_update_progress_success(self, mock_get_redis): + """Ensure _update_progress updates Redis progress when not cancelled.""" + from backend.services.vectordatabase_service import _update_progress + + mock_redis = MagicMock() + mock_redis.is_task_cancelled.return_value = False + mock_redis.save_progress_info.return_value = True + mock_get_redis.return_value = mock_redis + + _update_progress("task-1", 5, 10) + + mock_redis.is_task_cancelled.assert_called_once_with("task-1") + mock_redis.save_progress_info.assert_called_once_with("task-1", 5, 10) + + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_update_progress_save_failure(self, mock_get_redis): + """_update_progress logs a warning when saving progress fails.""" + from backend.services.vectordatabase_service import _update_progress + + mock_redis = MagicMock() + mock_redis.is_task_cancelled.return_value = False + mock_redis.save_progress_info.return_value = False + mock_get_redis.return_value = mock_redis + + _update_progress("task-2", 1, 2) + + mock_redis.is_task_cancelled.assert_called_once_with("task-2") + mock_redis.save_progress_info.assert_called_once_with("task-2", 1, 2) + + +class TestRethrowOrPlain(unittest.TestCase): + def setUp(self): + self.es_service = ElasticSearchService() + self.mock_vdb_core = MagicMock() + self.mock_vdb_core.embedding_model = MagicMock() + self.mock_vdb_core.embedding_dim = 768 + + self.get_embedding_model_patcher = patch( + 'backend.services.vectordatabase_service.get_embedding_model') + self.mock_get_embedding = self.get_embedding_model_patcher.start() + self.mock_embedding = MagicMock() + self.mock_embedding.embedding_dim = 768 + self.mock_embedding.model = "test-model" + self.mock_get_embedding.return_value = self.mock_embedding + + def tearDown(self): + self.get_embedding_model_patcher.stop() + + def test_rethrow_or_plain_rethrows_json_error_code(self): + """_rethrow_or_plain should re-raise JSON payload when error_code present.""" + from backend.services.vectordatabase_service import _rethrow_or_plain + + with self.assertRaises(Exception) as exc: + _rethrow_or_plain(Exception('{"error_code":"E123","detail":"boom"}')) + self.assertIn('"error_code": "E123"', str(exc.exception)) + + def test_get_vector_db_core_unsupported_type(self): + """get_vector_db_core raises on unsupported db type.""" + from backend.services.vectordatabase_service import get_vector_db_core + + with self.assertRaises(ValueError) as exc: + get_vector_db_core(db_type="unsupported") + + self.assertIn("Unsupported vector database type", str(exc.exception)) + + def test_rethrow_or_plain_parses_error_code(self): + """_rethrow_or_plain rethrows JSON error_code payloads unchanged.""" + from backend.services.vectordatabase_service import _rethrow_or_plain + + with self.assertRaises(Exception) as exc: + _rethrow_or_plain(Exception('{"error_code":123,"detail":"boom"}')) + + self.assertIn("error_code", str(exc.exception)) + + @patch('services.redis_service.get_redis_service') + def test_full_delete_knowledge_base_no_files_redis_warning(self, mock_get_redis): + """full_delete_knowledge_base handles empty file list and surfaces Redis warnings.""" + mock_vdb_core = MagicMock() + mock_redis = MagicMock() + mock_redis.delete_knowledgebase_records.return_value = { + "total_deleted": 0, + "errors": [] + } + mock_get_redis.return_value = mock_redis + + with patch('backend.services.vectordatabase_service.ElasticSearchService.list_files', + new_callable=AsyncMock, return_value={"files": []}) as mock_list_files, \ + patch('backend.services.vectordatabase_service.ElasticSearchService.delete_index', + new_callable=AsyncMock, return_value={"status": "success"}) as mock_delete_index: + + async def run_test(): + return await ElasticSearchService.full_delete_knowledge_base( + index_name="kb-1", + vdb_core=mock_vdb_core, + user_id="user-1", + ) + + result = asyncio.run(run_test()) + + self.assertEqual(result["minio_cleanup"]["total_files_found"], 0) + self.assertEqual(result["redis_cleanup"].get("errors"), []) + self.assertIn("redis_warnings", result) + self.assertIn("redis_warnings", result) + mock_list_files.assert_awaited_once() + mock_delete_index.assert_awaited_once() + + @patch('services.redis_service.get_redis_service') + def test_full_delete_knowledge_base_minio_and_redis_error(self, mock_get_redis): + """full_delete_knowledge_base logs minio summary and handles redis cleanup errors.""" + mock_vdb_core = MagicMock() + mock_redis = MagicMock() + # Redis cleanup will raise to hit error branch (lines 289-292) + mock_redis.delete_knowledgebase_records.side_effect = Exception("redis boom") + mock_get_redis.return_value = mock_redis + + files_payload = { + "files": [ + {"path_or_url": "obj-success", "source_type": "minio"}, + {"path_or_url": "obj-fail", "source_type": "minio"}, + ] + } + + # delete_file returns success for first, failure for second + with patch('backend.services.vectordatabase_service.ElasticSearchService.list_files', + new_callable=AsyncMock, return_value=files_payload) as mock_list_files, \ + patch('backend.services.vectordatabase_service.delete_file') as mock_delete_file, \ + patch('backend.services.vectordatabase_service.ElasticSearchService.delete_index', + new_callable=AsyncMock, return_value={"status": "success"}) as mock_delete_index: + mock_delete_file.side_effect = [ + {"success": True}, + {"success": False, "error": "minio failed"}, + ] + + async def run_test(): + return await ElasticSearchService.full_delete_knowledge_base( + index_name="kb-2", + vdb_core=mock_vdb_core, + user_id="user-2", + ) + + result = asyncio.run(run_test()) + + # MinIO summary should reflect one success and one failure (line 270 hit) + self.assertEqual(result["minio_cleanup"]["deleted_count"], 1) + self.assertEqual(result["minio_cleanup"]["failed_count"], 1) + # Redis cleanup error should be surfaced + self.assertIn("error", result["redis_cleanup"]) + mock_list_files.assert_awaited_once() + mock_delete_index.assert_awaited_once_with("kb-2", mock_vdb_core, "user-2") + + @patch('backend.services.vectordatabase_service.create_knowledge_record') + def test_create_knowledge_base_create_index_failure(self, mock_create_record): + """create_knowledge_base raises when index creation fails.""" + mock_create_record.return_value = { + "knowledge_id": 1, + "index_name": "1-uuid", + "knowledge_name": "kb" + } + self.mock_vdb_core.create_index.return_value = False + + with self.assertRaises(Exception) as exc: + ElasticSearchService.create_knowledge_base( + knowledge_name="kb", + embedding_dim=256, + vdb_core=self.mock_vdb_core, + user_id="user-1", + tenant_id="tenant-1", + ) + + self.assertIn("Failed to create index", str(exc.exception)) + + @patch('backend.services.vectordatabase_service.create_knowledge_record') + def test_create_knowledge_base_raises_on_exception(self, mock_create_record): + """create_knowledge_base wraps unexpected errors.""" + mock_create_record.return_value = { + "knowledge_id": 2, + "index_name": "2-uuid", + "knowledge_name": "kb2" + } + self.mock_vdb_core.create_index.side_effect = Exception("boom") + + with self.assertRaises(Exception) as exc: + ElasticSearchService.create_knowledge_base( + knowledge_name="kb2", + embedding_dim=128, + vdb_core=self.mock_vdb_core, + user_id="user-2", + tenant_id="tenant-2", + ) + + self.assertIn("Error creating knowledge base", str(exc.exception)) + + @patch('backend.services.vectordatabase_service.get_knowledge_record') + def test_index_documents_default_batch_without_tenant(self, mock_get_record): + """index_documents defaults embedding batch size to 10 when tenant is missing.""" + mock_get_record.return_value = None + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 1 + + data = [{ + "path_or_url": "p1", + "content": "c1", + "metadata": {"title": "t1"}, + }] + embedding = MagicMock() + embedding.model = "model-x" + + result = ElasticSearchService.index_documents( + embedding_model=embedding, + index_name="idx", + data=data, + vdb_core=self.mock_vdb_core, + ) + + self.assertTrue(result["success"]) + _, kwargs = self.mock_vdb_core.vectorize_documents.call_args + self.assertEqual(kwargs["embedding_batch_size"], 10) + + @patch('backend.services.vectordatabase_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_index_documents_updates_final_progress(self, mock_get_redis, mock_get_record, mock_tenant_cfg): + """index_documents sends final progress update to Redis when task_id is provided.""" + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 4} + mock_redis = MagicMock() + mock_get_redis.return_value = mock_redis + + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 2 + + data = [ + {"path_or_url": "p1", "content": "c1", "metadata": {}}, + {"path_or_url": "p2", "content": "c2", "metadata": {}}, + ] + + result = ElasticSearchService.index_documents( + embedding_model=self.mock_embedding, + index_name="idx", + data=data, + vdb_core=self.mock_vdb_core, + task_id="task-xyz", + ) + + self.assertTrue(result["success"]) + mock_redis.save_progress_info.assert_called() + last_call = mock_redis.save_progress_info.call_args_list[-1] + self.assertEqual(last_call[0], ("task-xyz", 2, 2)) + + @patch('backend.services.vectordatabase_service.get_redis_service') + @patch('backend.services.vectordatabase_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.tenant_config_manager') + def test_index_documents_progress_init_and_final_errors(self, mock_tenant_cfg, mock_get_record, mock_get_redis): + """index_documents should continue when progress save fails during init and final updates.""" + mock_get_record.return_value = {"tenant_id": "tenant-1"} + mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 4} + + mock_redis = MagicMock() + # First call (init) raises, second call (final) raises + mock_redis.save_progress_info.side_effect = [Exception("init fail"), Exception("final fail")] + mock_redis.is_task_cancelled.return_value = False + mock_get_redis.return_value = mock_redis + + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 1 + + data = [{"path_or_url": "p1", "content": "c1", "metadata": {}}] + + result = ElasticSearchService.index_documents( + embedding_model=self.mock_embedding, + index_name="idx", + data=data, + vdb_core=self.mock_vdb_core, + task_id="task-err", + ) + + self.assertTrue(result["success"]) + # two attempts to save progress (init and final) + self.assertEqual(mock_redis.save_progress_info.call_count, 2) + + @patch('backend.services.vectordatabase_service.get_all_files_status') + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_list_files_handles_invalid_create_time_and_failed_tasks(self, mock_get_redis, mock_get_files_status): + """list_files handles invalid timestamps, progress overrides, and error info.""" + self.mock_vdb_core.get_documents_detail.return_value = [ + { + "path_or_url": "file1", + "filename": "file1.txt", + "file_size": 10, + "create_time": "invalid", + "chunk_count": 1 + } + ] + self.mock_vdb_core.client.count.return_value = {"count": 7} + + mock_get_files_status.return_value = { + "file1": { + "state": "PROCESS_FAILED", + "latest_task_id": "task-1", + "processed_chunks": 1, + "total_chunks": 5, + "source_type": "minio", + "original_filename": "file1.txt" + } + } + + mock_redis = MagicMock() + mock_redis.get_progress_info.return_value = { + "processed_chunks": 2, + "total_chunks": 5 + } + mock_redis.get_error_info.return_value = "boom error" + mock_get_redis.return_value = mock_redis + + async def run_test(): + return await ElasticSearchService.list_files( + index_name="idx", + include_chunks=False, + vdb_core=self.mock_vdb_core + ) + + result = asyncio.run(run_test()) + self.assertEqual(len(result["files"]), 1) + file_info = result["files"][0] + self.assertEqual(file_info["chunk_count"], 7) + self.assertEqual(file_info["file_size"], 10) + self.assertEqual(file_info["status"], "PROCESS_FAILED") + self.assertEqual(file_info["processed_chunk_num"], 2) + self.assertEqual(file_info["total_chunk_num"], 5) + self.assertEqual(file_info["error_reason"], "boom error") + self.assertIsInstance(file_info["create_time"], int) + + @patch('backend.services.vectordatabase_service.get_all_files_status') + @patch('backend.services.vectordatabase_service.get_redis_service') + def test_list_files_warning_and_progress_error_branches(self, mock_get_redis, mock_get_files_status): + """list_files covers chunk count warning, file size error, progress overrides, and redis failures.""" + # Existing ES file triggers count warning (lines 749-750 and 910-916) + self.mock_vdb_core.get_documents_detail.return_value = [ + { + "path_or_url": "file-es", + "filename": "file-es.txt", + "file_size": 5, + "create_time": "2024-01-01T00:00:00", + "chunk_count": 1 + } + ] + # First count call for ES file, second for completed file at include_chunks=False + self.mock_vdb_core.client.count.side_effect = [ + Exception("count fail initial"), + Exception("count fail final"), + ] + + # Two tasks from Celery status to exercise progress success and failure + mock_get_files_status.return_value = { + "file-processing": { + "state": "PROCESSING", + "latest_task_id": "t1", + "source_type": "minio", + "original_filename": "fp.txt", + "processed_chunks": 1, + "total_chunks": 3, + }, + "file-failed": { + "state": "PROCESS_FAILED", + "latest_task_id": "t2", + "source_type": "minio", + "original_filename": "ff.txt", + }, + } + + mock_redis = MagicMock() + # Progress info: first returns dict, second raises to hit lines 815-816 + mock_redis.get_progress_info.side_effect = [ + {"processed_chunks": 2, "total_chunks": 4}, + Exception("progress boom"), + ] + # get_error_info raises to hit 847-848 + mock_redis.get_error_info.side_effect = Exception("error info boom") + mock_get_redis.return_value = mock_redis + + with patch('backend.services.vectordatabase_service.get_file_size', side_effect=Exception("size boom")): + async def run_test(): + return await ElasticSearchService.list_files( + index_name="idx", + include_chunks=False, + vdb_core=self.mock_vdb_core + ) + + result = asyncio.run(run_test()) + + # Ensure both ES file and processing files are returned + paths = {f["path_or_url"] for f in result["files"]} + self.assertIn("file-es", paths) + self.assertIn("file-processing", paths) + self.assertIn("file-failed", paths) + # Processing file gets progress override + proc_file = next(f for f in result["files"] if f["path_or_url"] == "file-processing") + self.assertEqual(proc_file["processed_chunk_num"], 2) + self.assertEqual(proc_file["total_chunk_num"], 4) + # Failed file retains default chunk_count fallback + failed_file = next(f for f in result["files"] if f["path_or_url"] == "file-failed") + self.assertEqual(failed_file.get("chunk_count", 0), 0) + + @patch('backend.services.vectordatabase_service.get_all_files_status', return_value={}) + def test_list_files_with_chunks_updates_chunk_count(self, mock_get_files_status): + """list_files include_chunks path refreshes chunk counts.""" + self.mock_vdb_core.get_documents_detail.return_value = [ + { + "path_or_url": "file1", + "filename": "file1.txt", + "file_size": 10, + "create_time": "2024-01-01T00:00:00" + } + ] + self.mock_vdb_core.multi_search.return_value = { + "responses": [ + { + "hits": { + "hits": [ + {"_source": { + "id": "doc1", + "title": "t", + "content": "c", + "create_time": "2024-01-01T00:00:00" + }} + ] + } + } + ] + } + self.mock_vdb_core.client.count.return_value = {"count": 2} + + async def run_test(): + return await ElasticSearchService.list_files( + index_name="idx", + include_chunks=True, + vdb_core=self.mock_vdb_core + ) + + result = asyncio.run(run_test()) + file_info = result["files"][0] + self.assertEqual(file_info["chunk_count"], 2) + self.assertEqual(len(file_info["chunks"]), 1) + + def test_summary_index_name_streams_generator_error(self): + """summary_index_name streams error payloads when generator fails.""" + class BadIterable: + def __iter__(self): + raise RuntimeError("stream failure") + + with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries', return_value=BadIterable()): + + mock_process_docs.return_value = ( + {"doc1": {"chunks": [{"content": "x"}]}}, + {"doc1": MagicMock()} + ) + mock_cluster.return_value = {"doc1": 0} + mock_summarize.return_value = {0: "summary"} + + async def run_test(): + response = await self.es_service.summary_index_name( + index_name="idx", + batch_size=100, + vdb_core=self.mock_vdb_core, + language="en", + model_id=None, + tenant_id="tenant-1", + ) + messages = [] + async for chunk in response.body_iterator: + messages.append(chunk) + break + return messages + + messages = asyncio.run(run_test()) + self.assertTrue(any("error" in msg for msg in messages)) + if __name__ == '__main__': unittest.main() diff --git a/test/backend/test_cluster_summarization.py b/test/backend/test_cluster_summarization.py index 9fd0a3b91..1e9bc1658 100644 --- a/test/backend/test_cluster_summarization.py +++ b/test/backend/test_cluster_summarization.py @@ -10,11 +10,39 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( extract_cluster_content, summarize_cluster, diff --git a/test/backend/test_config_service.py b/test/backend/test_config_service.py index a30e86bd7..9935797cc 100644 --- a/test/backend/test_config_service.py +++ b/test/backend/test_config_service.py @@ -431,5 +431,45 @@ async def test_startup_initialization_with_custom_version(self, mock_logger, moc assert version_logged, "Custom APP version should be logged" +class TestTenantConfigService: + """Unit tests for tenant_config_service helpers""" + + @patch('backend.services.tenant_config_service.get_selected_knowledge_list') + def test_build_knowledge_name_mapping_prefers_knowledge_name(self, mock_get_selected): + """Ensure knowledge_name is used as key when present.""" + mock_get_selected.return_value = [ + {"knowledge_name": "User Docs", "index_name": "index_user_docs"}, + {"knowledge_name": "API Docs", "index_name": "index_api_docs"}, + ] + + from backend.services.tenant_config_service import build_knowledge_name_mapping + + mapping = build_knowledge_name_mapping(tenant_id="t1", user_id="u1") + + assert mapping == { + "User Docs": "index_user_docs", + "API Docs": "index_api_docs", + } + mock_get_selected.assert_called_once_with(tenant_id="t1", user_id="u1") + + @patch('backend.services.tenant_config_service.get_selected_knowledge_list') + def test_build_knowledge_name_mapping_fallbacks_to_index_name(self, mock_get_selected): + """Fallback to index_name when knowledge_name is missing.""" + mock_get_selected.return_value = [ + {"index_name": "index_fallback_only"}, + {"knowledge_name": None, "index_name": "index_none_name"}, + ] + + from backend.services.tenant_config_service import build_knowledge_name_mapping + + mapping = build_knowledge_name_mapping(tenant_id="t2", user_id="u2") + + assert mapping == { + "index_fallback_only": "index_fallback_only", + "index_none_name": "index_none_name", + } + mock_get_selected.assert_called_once_with(tenant_id="t2", user_id="u2") + + if __name__ == '__main__': pytest.main() diff --git a/test/backend/test_document_vector_integration.py b/test/backend/test_document_vector_integration.py index 015818d32..8e05abe86 100644 --- a/test/backend/test_document_vector_integration.py +++ b/test/backend/test_document_vector_integration.py @@ -11,11 +11,38 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( calculate_document_embedding, auto_determine_k, diff --git a/test/backend/test_document_vector_utils.py b/test/backend/test_document_vector_utils.py index f03ed3346..1b4f89997 100644 --- a/test/backend/test_document_vector_utils.py +++ b/test/backend/test_document_vector_utils.py @@ -10,11 +10,39 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_const_mock.LANGUAGE = {"ZH": "zh", "EN": "en"} +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( calculate_document_embedding, auto_determine_k, @@ -226,6 +254,28 @@ def test_summarize_document_with_model_placeholder(self): assert isinstance(result, str) assert len(result) > 0 + def test_summarize_document_with_model_success(self): + """Test document summarization when model config exists and LLM returns value""" + with patch('backend.utils.document_vector_utils.get_model_by_model_id') as mock_get_model, \ + patch('backend.utils.document_vector_utils.call_llm_for_system_prompt') as mock_llm: + mock_get_model.return_value = {"id": 1} + mock_llm.return_value = "Generated summary\n" + + result = summarize_document( + document_content="LLM content", + filename="doc.pdf", + language="en", + max_words=50, + model_id=1, + tenant_id="tenant" + ) + + assert result == "Generated summary" + mock_llm.assert_called_once() + call_args = mock_llm.call_args.kwargs + assert call_args["model_id"] == 1 + assert call_args["tenant_id"] == "tenant" + class TestSummarizeCluster: """Test cluster summarization""" @@ -250,6 +300,27 @@ def test_summarize_cluster_with_model_placeholder(self): assert isinstance(result, str) assert len(result) > 0 + def test_summarize_cluster_with_model_success(self): + """Test cluster summarization when model config exists and LLM returns value""" + with patch('backend.utils.document_vector_utils.get_model_by_model_id') as mock_get_model, \ + patch('backend.utils.document_vector_utils.call_llm_for_system_prompt') as mock_llm: + mock_get_model.return_value = {"id": 1} + mock_llm.return_value = "Cluster summary text " + + result = summarize_cluster( + document_summaries=["Doc 1 summary", "Doc 2 summary"], + language="en", + max_words=120, + model_id=1, + tenant_id="tenant" + ) + + assert result == "Cluster summary text" + mock_llm.assert_called_once() + call_args = mock_llm.call_args.kwargs + assert call_args["model_id"] == 1 + assert call_args["tenant_id"] == "tenant" + class TestSummarizeClustersMapReduce: """Test map-reduce cluster summarization""" diff --git a/test/backend/test_document_vector_utils_coverage.py b/test/backend/test_document_vector_utils_coverage.py index b442e47e4..82ac1d646 100644 --- a/test/backend/test_document_vector_utils_coverage.py +++ b/test/backend/test_document_vector_utils_coverage.py @@ -10,11 +10,38 @@ import numpy as np import pytest -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) sys.path.insert(0, backend_dir) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from backend.utils.document_vector_utils import ( get_documents_from_es, process_documents_for_clustering, diff --git a/test/backend/test_summary_formatting.py b/test/backend/test_summary_formatting.py index 31b656e55..22f8dec36 100644 --- a/test/backend/test_summary_formatting.py +++ b/test/backend/test_summary_formatting.py @@ -5,10 +5,38 @@ import pytest import sys import os +from unittest.mock import MagicMock, patch -# Add backend to path +# Mock consts module before patching backend.database.client to avoid ImportError +# backend.database.client imports from consts.const, so we need to mock it first +consts_mock = MagicMock() +consts_const_mock = MagicMock() +# Set required constants that backend.database.client might use +consts_const_mock.MINIO_ENDPOINT = "http://localhost:9000" +consts_const_mock.MINIO_ACCESS_KEY = "test_access_key" +consts_const_mock.MINIO_SECRET_KEY = "test_secret_key" +consts_const_mock.MINIO_REGION = "us-east-1" +consts_const_mock.MINIO_DEFAULT_BUCKET = "test-bucket" +consts_const_mock.POSTGRES_HOST = "localhost" +consts_const_mock.POSTGRES_USER = "test_user" +consts_const_mock.NEXENT_POSTGRES_PASSWORD = "test_password" +consts_const_mock.POSTGRES_DB = "test_db" +consts_const_mock.POSTGRES_PORT = 5432 +consts_mock.const = consts_const_mock +sys.modules['consts'] = consts_mock +sys.modules['consts.const'] = consts_const_mock + +# Add backend to path before patching backend modules sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend')) +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +minio_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + from utils.document_vector_utils import merge_cluster_summaries diff --git a/test/backend/utils/test_file_management_utils.py b/test/backend/utils/test_file_management_utils.py index eaa3a1261..02553db8f 100644 --- a/test/backend/utils/test_file_management_utils.py +++ b/test/backend/utils/test_file_management_utils.py @@ -300,6 +300,123 @@ async def test_get_all_files_status_connect_error_and_non200(fmu, monkeypatch): assert out2 == {} +@pytest.mark.asyncio +async def test_get_all_files_status_no_tasks_returns_empty(fmu, monkeypatch): + fake_client = _FakeAsyncClient(_Resp(200, [])) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + + out = await fmu.get_all_files_status("idx-empty") + assert out == {} + + +@pytest.mark.asyncio +async def test_get_all_files_status_forward_updates_and_redis_progress(fmu, monkeypatch): + tasks_list = [ + { + "id": "10", + "task_name": "process", + "index_name": "idx", + "path_or_url": "/p2", + "original_filename": "f2", + "source_type": "local", + "status": "SUCCESS", + "created_at": 1, + }, + { + "id": "20", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p2", + "original_filename": "f2", + "source_type": "local", + "status": "STARTED", + "created_at": 5, # later than process to trigger forward branch + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Stub redis_service with progress info + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: {"processed_chunks": 7, "total_chunks": 9} + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p2"]["state"] == "FORWARDING" + assert out["/p2"]["latest_task_id"] == "20" + assert out["/p2"]["processed_chunks"] == 7 + assert out["/p2"]["total_chunks"] == 9 + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_progress_exception(fmu, monkeypatch): + tasks_list = [ + { + "id": "30", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p3", + "original_filename": "f3", + "source_type": "local", + "status": "STARTED", + "created_at": 2, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service raising exception to hit exception path + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + def _boom(): + raise RuntimeError("redis down") + redis_mod.get_redis_service = lambda: types.SimpleNamespace(get_progress_info=lambda task_id: _boom()) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p3"]["state"] == "FORWARDING" + assert out["/p3"]["processed_chunks"] is None + assert out["/p3"]["total_chunks"] is None + + +@pytest.mark.asyncio +async def test_get_all_files_status_outer_exception_returns_empty(fmu, monkeypatch): + tasks_list = [ + { + "id": "40", + "task_name": "process", + "index_name": "idx", + "path_or_url": "/p4", + "original_filename": "f4", + "source_type": "local", + "status": "SUCCESS", + "created_at": 1, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + + def _boom(*a, **k): + raise RuntimeError("convert failed") + monkeypatch.setattr(fmu, "_convert_to_custom_state", _boom) + + out = await fmu.get_all_files_status("idx") + assert out == {} + + # -------------------- _convert_to_custom_state -------------------- @@ -379,3 +496,211 @@ def test_get_file_size_invalid_source_type(fmu): assert fmu.get_file_size("http", "http://x") == 0 +# -------------------- Additional coverage for get_all_files_status -------------------- + + +@pytest.mark.asyncio +async def test_get_all_files_status_forward_created_at_not_greater(fmu, monkeypatch): + """Test forward task with created_at not greater than latest_forward_created_at (line 195)""" + tasks_list = [ + { + "id": "20", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p5", + "original_filename": "f5", + "source_type": "local", + "status": "STARTED", + "created_at": 5, + }, + { + "id": "21", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p5", + "original_filename": "f5", + "source_type": "local", + "status": "SUCCESS", + "created_at": 3, # Less than previous forward task, should not update + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + out = await fmu.get_all_files_status("idx") + # Should use the first forward task (id=20) as latest since it has higher created_at + assert out["/p5"]["latest_task_id"] == "20" + + +@pytest.mark.asyncio +async def test_get_all_files_status_empty_task_id(fmu, monkeypatch): + """Test when task_id is empty string (line 221 - not entering if branch)""" + tasks_list = [ + { + "id": "", # Empty task_id + "task_name": "process", + "index_name": "idx", + "path_or_url": "/p6", + "original_filename": "f6", + "source_type": "local", + "status": "SUCCESS", + "created_at": 1, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "COMPLETED" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Stub redis_service to ensure it's not called + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_called = {"called": False} + def _track_call(task_id): + redis_called["called"] = True + return {} + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=_track_call + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p6"]["latest_task_id"] == "" + # Redis should not be called when task_id is empty + assert redis_called["called"] is False + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_progress_info_none(fmu, monkeypatch): + """Test when progress_info is None (line 226, 237 - entering else branch)""" + tasks_list = [ + { + "id": "50", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p7", + "original_filename": "f7", + "source_type": "local", + "status": "STARTED", + "created_at": 1, + "processed_chunks": 5, + "total_chunks": 10, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service returning None (line 226, 237) + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: None # Returns None to trigger else branch + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p7"]["state"] == "FORWARDING" + assert out["/p7"]["latest_task_id"] == "50" + # Should use task state values when progress_info is None + assert out["/p7"]["processed_chunks"] == 5 + assert out["/p7"]["total_chunks"] == 10 + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_processed_chunks_none(fmu, monkeypatch): + """Test when redis_processed is None (line 230 - not entering if branch)""" + tasks_list = [ + { + "id": "60", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p8", + "original_filename": "f8", + "source_type": "local", + "status": "STARTED", + "created_at": 1, + "processed_chunks": 3, + "total_chunks": 8, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service returning progress_info with processed_chunks as None (line 230) + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: { + "processed_chunks": None, # None to skip line 230 if branch + "total_chunks": 15 + } + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p8"]["state"] == "FORWARDING" + # processed_chunks should remain from task state (3) since redis_processed is None + assert out["/p8"]["processed_chunks"] == 3 + # total_chunks should be updated from Redis (15) + assert out["/p8"]["total_chunks"] == 15 + + +@pytest.mark.asyncio +async def test_get_all_files_status_redis_total_chunks_none(fmu, monkeypatch): + """Test when redis_total is None (line 232 - not entering if branch)""" + tasks_list = [ + { + "id": "70", + "task_name": "forward", + "index_name": "idx", + "path_or_url": "/p9", + "original_filename": "f9", + "source_type": "local", + "status": "STARTED", + "created_at": 1, + "processed_chunks": 4, + "total_chunks": 12, + }, + ] + fake_client = _FakeAsyncClient(_Resp(200, tasks_list)) + monkeypatch.setattr(fmu, "httpx", types.SimpleNamespace(AsyncClient=lambda: fake_client)) + async def _fake_convert(*a, **k): + return "FORWARDING" + monkeypatch.setattr(fmu, "_convert_to_custom_state", _fake_convert) + + # Redis service returning progress_info with total_chunks as None (line 232) + services_pkg = types.ModuleType("services") + services_pkg.__path__ = [] + sys.modules["services"] = services_pkg + redis_mod = types.ModuleType("services.redis_service") + redis_mod.get_redis_service = lambda: types.SimpleNamespace( + get_progress_info=lambda task_id: { + "processed_chunks": 6, + "total_chunks": None # None to skip line 232 if branch + } + ) + sys.modules["services.redis_service"] = redis_mod + + out = await fmu.get_all_files_status("idx") + assert out["/p9"]["state"] == "FORWARDING" + # processed_chunks should be updated from Redis (6) + assert out["/p9"]["processed_chunks"] == 6 + # total_chunks should remain from task state (12) since redis_total is None + assert out["/p9"]["total_chunks"] == 12 + diff --git a/test/backend/utils/test_llm_utils.py b/test/backend/utils/test_llm_utils.py index 545bdf776..50857e91b 100644 --- a/test/backend/utils/test_llm_utils.py +++ b/test/backend/utils/test_llm_utils.py @@ -74,7 +74,7 @@ class TestCallLLMForSystemPrompt(unittest.TestCase): def setUp(self): self.test_model_id = 1 - @patch('backend.utils.llm_utils.OpenAIServerModel') + @patch('backend.utils.llm_utils.OpenAIModel') @patch('backend.utils.llm_utils.get_model_name_from_config') @patch('backend.utils.llm_utils.get_model_by_model_id') def test_call_llm_for_system_prompt_success( @@ -118,7 +118,7 @@ def test_call_llm_for_system_prompt_success( top_p=0.95, ) - @patch('backend.utils.llm_utils.OpenAIServerModel') + @patch('backend.utils.llm_utils.OpenAIModel') @patch('backend.utils.llm_utils.get_model_name_from_config') @patch('backend.utils.llm_utils.get_model_by_model_id') def test_call_llm_for_system_prompt_exception( diff --git a/test/sdk/core/agents/test_core_agent.py b/test/sdk/core/agents/test_core_agent.py index cb6240893..54b725620 100644 --- a/test/sdk/core/agents/test_core_agent.py +++ b/test/sdk/core/agents/test_core_agent.py @@ -1,3 +1,5 @@ +import json + import pytest from unittest.mock import MagicMock, patch from threading import Event @@ -14,22 +16,98 @@ def __init__(self, message): super().__init__(message) +class MockAgentMaxStepsError(Exception): + pass + + # Mock for smolagents and its sub-modules mock_smolagents = MagicMock() -mock_smolagents.ActionStep = MagicMock() -mock_smolagents.TaskStep = MagicMock() -mock_smolagents.SystemPromptStep = MagicMock() mock_smolagents.AgentError = MockAgentError mock_smolagents.handle_agent_output_types = MagicMock( return_value="handled_output") +mock_smolagents.utils.AgentMaxStepsError = MockAgentMaxStepsError + +# Create proper class types for isinstance checks (not MagicMock) +class MockActionStep: + def __init__(self, *args, **kwargs): + self.step_number = kwargs.get('step_number', 1) + self.timing = kwargs.get('timing', None) + self.observations_images = kwargs.get('observations_images', None) + self.model_input_messages = None + self.model_output_message = None + self.model_output = None + self.token_usage = None + self.code_action = None + self.tool_calls = None + self.observations = None + self.action_output = None + self.is_final_answer = False + self.error = None + +class MockTaskStep: + def __init__(self, *args, **kwargs): + self.task = kwargs.get('task', '') + self.task_images = kwargs.get('task_images', None) + +class MockSystemPromptStep: + def __init__(self, *args, **kwargs): + self.system_prompt = kwargs.get('system_prompt', '') + +class MockFinalAnswerStep: + def __init__(self, *args, **kwargs): + # Handle both positional and keyword arguments + if args: + self.output = args[0] + else: + self.output = kwargs.get('output', '') + +class MockPlanningStep: + def __init__(self, *args, **kwargs): + self.token_usage = kwargs.get('token_usage', None) + +class MockActionOutput: + def __init__(self, *args, **kwargs): + self.output = kwargs.get('output', None) + self.is_final_answer = kwargs.get('is_final_answer', False) + +class MockRunResult: + def __init__(self, *args, **kwargs): + self.output = kwargs.get('output', None) + self.token_usage = kwargs.get('token_usage', None) + self.steps = kwargs.get('steps', []) + self.timing = kwargs.get('timing', None) + self.state = kwargs.get('state', 'success') + +class MockCodeOutput: + """Mock object returned by python_executor.""" + def __init__(self, output=None, logs="", is_final_answer=False): + self.output = output + self.logs = logs + self.is_final_answer = is_final_answer + +# Assign proper classes to mock_smolagents +mock_smolagents.ActionStep = MockActionStep +mock_smolagents.TaskStep = MockTaskStep +mock_smolagents.SystemPromptStep = MockSystemPromptStep # Create dummy smolagents sub-modules for sub_mod in ["agents", "memory", "models", "monitoring", "utils", "local_python_executor"]: mock_module = MagicMock() setattr(mock_smolagents, sub_mod, mock_module) +# Assign classes to memory submodule +mock_smolagents.memory.ActionStep = MockActionStep +mock_smolagents.memory.TaskStep = MockTaskStep +mock_smolagents.memory.SystemPromptStep = MockSystemPromptStep +mock_smolagents.memory.FinalAnswerStep = MockFinalAnswerStep +mock_smolagents.memory.PlanningStep = MockPlanningStep +mock_smolagents.memory.ToolCall = MagicMock + +# Assign classes to agents submodule mock_smolagents.agents.CodeAgent = MagicMock +mock_smolagents.agents.ActionOutput = MockActionOutput +mock_smolagents.agents.RunResult = MockRunResult # Provide actual implementations for commonly used utils functions @@ -72,6 +150,23 @@ def mock_truncate_content(content, max_length=1000): core_agent_module = sys.modules['sdk.nexent.core.agents.core_agent'] # Override AgentError inside the imported module to ensure it has message attr core_agent_module.AgentError = MockAgentError + core_agent_module.AgentMaxStepsError = MockAgentMaxStepsError + # Override classes to use our mock classes for isinstance checks + core_agent_module.FinalAnswerStep = MockFinalAnswerStep + core_agent_module.ActionStep = MockActionStep + core_agent_module.PlanningStep = MockPlanningStep + core_agent_module.ActionOutput = MockActionOutput + core_agent_module.RunResult = MockRunResult + # Override CodeAgent to be a proper class that can be inherited + class MockCodeAgent: + def __init__(self, prompt_templates=None, *args, **kwargs): + # Accept any arguments but don't require observer + # Store attributes that might be accessed + self.prompt_templates = prompt_templates + # Initialize common attributes that CodeAgent might have + for key, value in kwargs.items(): + setattr(self, key, value) + core_agent_module.CodeAgent = MockCodeAgent CoreAgent = ImportedCoreAgent @@ -103,16 +198,50 @@ def core_agent_instance(mock_observer): agent.stop_event = Event() agent.memory = MagicMock() agent.memory.steps = [] + agent.memory.get_full_steps = MagicMock(return_value=[]) agent.python_executor = MagicMock() + + # Mock logger with all required methods + agent.logger = MagicMock() + agent.logger.log = MagicMock() + agent.logger.log_task = MagicMock() + agent.logger.log_markdown = MagicMock() + agent.logger.log_code = MagicMock() agent.step_number = 1 agent._execute_step = MagicMock() agent._finalize_step = MagicMock() agent._handle_max_steps_reached = MagicMock() + + # Set default attributes that might be needed + agent.max_steps = 5 + agent.state = {} + agent.system_prompt = "test system prompt" + agent.return_full_result = False + agent.provide_run_summary = False + agent.tools = {} + agent.managed_agents = {} + agent.monitor = MagicMock() + agent.monitor.reset = MagicMock() + agent.model = MagicMock() + if hasattr(agent.model, 'model_id'): + agent.model.model_id = "test-model" + agent.code_block_tags = ["```", "```"] + agent._use_structured_outputs_internally = False + agent.final_answer_checks = None # Set to avoid MagicMock creating new CoreAgent instances return agent +@pytest.fixture(autouse=True) +def reset_token_usage_mock(): + """Ensure TokenUsage mock does not leak state between tests.""" + token_usage = getattr(core_agent_module, "TokenUsage", None) + if hasattr(token_usage, "reset_mock"): + token_usage.reset_mock() + yield + + # ---------------------------------------------------------------------------- # Tests for _run method # ---------------------------------------------------------------------------- @@ -123,11 +252,12 @@ def test_run_normal_execution(core_agent_instance): task = "test task" max_steps = 3 - # Mock _execute_step to return a generator that yields final answer - def mock_execute_generator(action_step): - yield "final_answer" + # Mock _step_stream to return a generator that yields ActionOutput with final answer + def mock_step_stream(action_step): + action_output = MockActionOutput(output="final_answer", is_final_answer=True) + yield action_output - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_generator) as mock_execute_step, \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream) as mock_step_stream_patch, \ patch.object(core_agent_instance, '_finalize_step') as mock_finalize_step: core_agent_instance.step_number = 1 @@ -135,11 +265,11 @@ def mock_execute_generator(action_step): result = list(core_agent_instance._run_stream(task, max_steps)) # Assertions - # _run_stream yields: generator output + action step + final answer step + # _run_stream yields: ActionOutput from _step_stream + action step + final answer step assert len(result) == 3 - assert result[0] == "final_answer" # Generator output - assert isinstance(result[1], MagicMock) # Action step - assert isinstance(result[2], MagicMock) # Final answer step + assert isinstance(result[0], MockActionOutput) # ActionOutput from _step_stream + assert isinstance(result[1], MockActionStep) # Action step + assert isinstance(result[2], MockFinalAnswerStep) # Final answer step def test_run_with_max_steps_reached(core_agent_instance): @@ -148,11 +278,12 @@ def test_run_with_max_steps_reached(core_agent_instance): task = "test task" max_steps = 2 - # Mock _execute_step to return None (no final answer) - def mock_execute_generator(action_step): - yield None + # Mock _step_stream to return ActionOutput without final answer + def mock_step_stream(action_step): + action_output = MockActionOutput(output=None, is_final_answer=False) + yield action_output - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_generator) as mock_execute_step, \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream) as mock_step_stream_patch, \ patch.object(core_agent_instance, '_finalize_step') as mock_finalize_step, \ patch.object(core_agent_instance, '_handle_max_steps_reached', return_value="max_steps_reached") as mock_handle_max: @@ -162,18 +293,19 @@ def mock_execute_generator(action_step): result = list(core_agent_instance._run_stream(task, max_steps)) # Assertions - # For 2 steps: (None + action_step) * 2 + final_action_step + final_answer_step = 6 - assert len(result) == 6 - assert result[0] is None # First generator output - assert isinstance(result[1], MagicMock) # First action step - assert result[2] is None # Second generator output - assert isinstance(result[3], MagicMock) # Second action step - # Final action step (from _handle_max_steps_reached) - assert isinstance(result[4], MagicMock) - assert isinstance(result[5], MagicMock) # Final answer step + # For 2 steps: (ActionOutput + action_step) * 2 + final_action_step + final_answer_step = 6 + assert len(result) >= 5 + # First step: ActionOutput + ActionStep + assert isinstance(result[0], MockActionOutput) # First ActionOutput + assert isinstance(result[1], MockActionStep) # First action step + # Second step: ActionOutput + ActionStep + assert isinstance(result[2], MockActionOutput) # Second ActionOutput + assert isinstance(result[3], MockActionStep) # Second action step + # Last should be final answer step + assert isinstance(result[-1], MockFinalAnswerStep) # Final answer step # Verify method calls - assert mock_execute_step.call_count == 2 + assert mock_step_stream_patch.call_count == 2 mock_handle_max.assert_called_once() assert mock_finalize_step.call_count == 2 @@ -184,23 +316,28 @@ def test_run_with_stop_event(core_agent_instance): task = "test task" max_steps = 3 - def mock_execute_generator(action_step): + def mock_step_stream(action_step): core_agent_instance.stop_event.set() - yield None + action_output = MockActionOutput(output=None, is_final_answer=False) + yield action_output + + # Mock handle_agent_output_types to return the input value (identity function) + # This way when final_answer = "", it will be passed through + with patch.object(core_agent_module, 'handle_agent_output_types', side_effect=lambda x: x): + # Mock _step_stream to set stop event + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream): + with patch.object(core_agent_instance, '_finalize_step'): + # Execute + result = list(core_agent_instance._run_stream(task, max_steps)) - # Mock _execute_step to set stop event - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_generator): - with patch.object(core_agent_instance, '_finalize_step'): - # Execute - result = list(core_agent_instance._run_stream(task, max_steps)) - - # Assertions - # Should yield: generator output + action step + final answer step - assert len(result) == 3 - assert result[0] is None # Generator output - assert isinstance(result[1], MagicMock) # Action step - # Final answer step with "" - assert isinstance(result[2], MagicMock) + # Assertions + # Should yield: ActionOutput from _step_stream + action step + final answer step + assert len(result) == 3 + assert isinstance(result[0], MockActionOutput) # ActionOutput from _step_stream + assert isinstance(result[1], MockActionStep) # Action step + # Final answer step with "" + assert isinstance(result[2], MockFinalAnswerStep) + assert result[2].output == "" def test_run_with_final_answer_error(core_agent_instance): @@ -209,9 +346,9 @@ def test_run_with_final_answer_error(core_agent_instance): task = "test task" max_steps = 3 - # Mock _execute_step to raise FinalAnswerError - with patch.object(core_agent_instance, '_execute_step', - side_effect=core_agent_module.FinalAnswerError()) as mock_execute_step, \ + # Mock _step_stream to raise FinalAnswerError + with patch.object(core_agent_instance, '_step_stream', + side_effect=core_agent_module.FinalAnswerError()) as mock_step_stream, \ patch.object(core_agent_instance, '_finalize_step'): # Execute result = list(core_agent_instance._run_stream(task, max_steps)) @@ -219,8 +356,8 @@ def test_run_with_final_answer_error(core_agent_instance): # Assertions # When FinalAnswerError occurs, it should yield action step + final answer step assert len(result) == 2 - assert isinstance(result[0], MagicMock) # Action step - assert isinstance(result[1], MagicMock) # Final answer step + assert isinstance(result[0], MockActionStep) # Action step + assert isinstance(result[1], MockFinalAnswerStep) # Final answer step def test_run_with_final_answer_error_and_model_output(core_agent_instance): @@ -229,16 +366,12 @@ def test_run_with_final_answer_error_and_model_output(core_agent_instance): task = "test task" max_steps = 3 - # Create a mock action step with model_output - mock_action_step = MagicMock() - mock_action_step.model_output = "```\nprint('hello')\n```" - - # Mock _execute_step to set model_output and then raise FinalAnswerError - def mock_execute_step(action_step): + # Mock _step_stream to set model_output and then raise FinalAnswerError + def mock_step_stream(action_step): action_step.model_output = "```\nprint('hello')\n```" raise core_agent_module.FinalAnswerError() - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_step), \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream), \ patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ patch.object(core_agent_instance, '_finalize_step'): # Execute @@ -246,8 +379,8 @@ def mock_execute_step(action_step): # Assertions assert len(result) == 2 - assert isinstance(result[0], MagicMock) # Action step - assert isinstance(result[1], MagicMock) # Final answer step + assert isinstance(result[0], MockActionStep) # Action step + assert isinstance(result[1], MockFinalAnswerStep) # Final answer step # Verify convert_code_format was called mock_convert.assert_called_once_with( "```\nprint('hello')\n```") @@ -259,9 +392,9 @@ def test_run_with_agent_error_updated(core_agent_instance): task = "test task" max_steps = 3 - # Mock _execute_step to raise AgentError - with patch.object(core_agent_instance, '_execute_step', - side_effect=MockAgentError("test error")) as mock_execute_step, \ + # Mock _step_stream to raise AgentError + with patch.object(core_agent_instance, '_step_stream', + side_effect=MockAgentError("test error")) as mock_step_stream, \ patch.object(core_agent_instance, '_finalize_step'): # Execute result = list(core_agent_instance._run_stream(task, max_steps)) @@ -270,9 +403,9 @@ def test_run_with_agent_error_updated(core_agent_instance): # When AgentError occurs, it should yield action step + final answer step # But the error causes the loop to continue, so we get multiple action steps assert len(result) >= 2 - assert isinstance(result[0], MagicMock) # Action step with error + assert isinstance(result[0], MockActionStep) # Action step with error # Last item should be final answer step - assert isinstance(result[-1], MagicMock) # Final answer step + assert isinstance(result[-1], MockFinalAnswerStep) # Final answer step def test_run_with_agent_parse_error_branch_updated(core_agent_instance): @@ -280,25 +413,40 @@ def test_run_with_agent_parse_error_branch_updated(core_agent_instance): task = "parse task" max_steps = 1 - # Mock _execute_step to set model_output and then raise FinalAnswerError - def mock_execute_step(action_step): + # Mock _step_stream to set model_output and then raise FinalAnswerError + def mock_step_stream(action_step): action_step.model_output = "```\nprint('hello')\n```" raise core_agent_module.FinalAnswerError() - with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_step), \ + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream), \ patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ patch.object(core_agent_instance, '_finalize_step'): results = list(core_agent_instance._run_stream(task, max_steps)) # _run should yield action step + final answer step assert len(results) == 2 - assert isinstance(results[0], MagicMock) # Action step - assert isinstance(results[1], MagicMock) # Final answer step + assert isinstance(results[0], MockActionStep) # Action step + assert isinstance(results[1], MockFinalAnswerStep) # Final answer step # Verify convert_code_format was called mock_convert.assert_called_once_with( "```\nprint('hello')\n```") +def test_run_stream_validates_final_answer_when_checks_enabled(core_agent_instance): + """Ensure _run_stream triggers final answer validation when checks are configured.""" + task = "validate task" + core_agent_instance.final_answer_checks = ["non-empty"] + core_agent_instance._validate_final_answer = MagicMock() + + def mock_step_stream(action_step): + yield MockActionOutput(output="final answer", is_final_answer=True) + + with patch.object(core_agent_instance, '_step_stream', side_effect=mock_step_stream), \ + patch.object(core_agent_instance, '_finalize_step'): + result = list(core_agent_instance._run_stream(task, max_steps=1)) + + assert len(result) == 3 # ActionOutput, ActionStep, FinalAnswerStep + core_agent_instance._validate_final_answer.assert_called_once_with("final answer") def test_convert_code_format_display_replacements(): """Validate convert_code_format correctly transforms format to standard markdown.""" @@ -575,6 +723,10 @@ def test_step_stream_parse_success(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -586,7 +738,7 @@ def test_step_stream_parse_success(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("output", "logs", False)) + return_value=MockCodeOutput(output="output", logs="logs", is_final_answer=False)) # Execute list(core_agent_instance._step_stream(mock_memory_step)) @@ -599,6 +751,33 @@ def test_step_stream_parse_success(core_agent_instance): assert hasattr(mock_memory_step.tool_calls[0], 'arguments') +def test_step_stream_structured_outputs_with_stop_sequence(core_agent_instance): + """Ensure _step_stream handles structured outputs correctly.""" + mock_memory_step = MagicMock() + mock_chat_message = MagicMock() + mock_chat_message.content = json.dumps({"code": "print('hello')"}) + mock_chat_message.token_usage = MagicMock() + + core_agent_instance.agent_name = "test_agent" + core_agent_instance.step_number = 1 + core_agent_instance._use_structured_outputs_internally = True + core_agent_instance.code_block_tags = ["<>", "[CLOSE]"] + core_agent_instance.write_memory_to_messages = MagicMock(return_value=[]) + core_agent_instance.model = MagicMock(return_value=mock_chat_message) + core_agent_instance.python_executor = MagicMock( + return_value=MockCodeOutput(output="result", logs="", is_final_answer=False) + ) + + with patch.object(core_agent_module, 'extract_code_from_text', return_value="print('hello')") as mock_extract, \ + patch.object(core_agent_module, 'fix_final_answer_code', side_effect=lambda code: code): + list(core_agent_instance._step_stream(mock_memory_step)) + + # Ensure structured output helpers were used + mock_extract.assert_called_once_with("print('hello')", core_agent_instance.code_block_tags) + call_kwargs = core_agent_instance.model.call_args.kwargs + assert call_kwargs["response_format"] == core_agent_module.CODEAGENT_RESPONSE_FORMAT + + def test_step_stream_skips_execution_for_display_only(core_agent_instance): """Test that _step_stream raises FinalAnswerError when only DISPLAY code blocks are present.""" # Setup @@ -611,6 +790,10 @@ def test_step_stream_skips_execution_for_display_only(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -637,6 +820,10 @@ def test_step_stream_parse_failure_raises_final_answer_error(core_agent_instance core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -662,6 +849,10 @@ def test_step_stream_model_generation_error(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -687,6 +878,10 @@ def test_step_stream_execution_success(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -698,14 +893,16 @@ def test_step_stream_execution_success(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("Hello World", "Execution logs", False)) + return_value=MockCodeOutput(output="Hello World", logs="Execution logs", is_final_answer=False)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - # Should yield None when is_final_answer is False - assert result[0] is None + # Should yield ActionOutput when is_final_answer is False + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is False assert mock_memory_step.observations is not None # Check that observations was set (we can't easily test the exact content due to mock behavior) assert hasattr(mock_memory_step, 'observations') @@ -723,6 +920,10 @@ def test_step_stream_execution_final_answer(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -734,13 +935,16 @@ def test_step_stream_execution_final_answer(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("final answer", "Execution logs", True)) + return_value=MockCodeOutput(output="final answer", logs="Execution logs", is_final_answer=True)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - assert result[0] == "final answer" # Should yield the final answer + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is True + assert result[0].output == "final answer" def test_step_stream_execution_error(core_agent_instance): @@ -755,6 +959,10 @@ def test_step_stream_execution_error(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -795,6 +1003,10 @@ def test_step_stream_observer_calls(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -806,7 +1018,7 @@ def test_step_stream_observer_calls(core_agent_instance): return_value=[]) core_agent_instance.model = MagicMock(return_value=mock_chat_message) core_agent_instance.python_executor = MagicMock( - return_value=("test", "logs", False)) + return_value=MockCodeOutput(output="test", logs="logs", is_final_answer=False)) # Execute list(core_agent_instance._step_stream(mock_memory_step)) @@ -847,6 +1059,10 @@ def test_step_stream_execution_with_logs(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -859,14 +1075,16 @@ def test_step_stream_execution_with_logs(core_agent_instance): core_agent_instance.model = MagicMock(return_value=mock_chat_message) # Mock python_executor to return logs core_agent_instance.python_executor = MagicMock( - return_value=("output", "Some execution logs", False)) + return_value=MockCodeOutput(output="output", logs="Some execution logs", is_final_answer=False)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - # Should yield None when is_final_answer is False - assert result[0] is None + # Should yield ActionOutput when is_final_answer is False + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is False # Check that execution logs were recorded assert core_agent_instance.observer.add_message.call_count >= 3 calls = core_agent_instance.observer.add_message.call_args_list @@ -887,6 +1105,10 @@ def test_step_stream_execution_error_with_print_outputs(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -926,6 +1148,10 @@ def test_step_stream_execution_error_with_import_warning(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -969,6 +1195,10 @@ def test_step_stream_execution_error_without_print_outputs(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -1003,6 +1233,10 @@ def test_step_stream_execution_with_none_output(core_agent_instance): core_agent_instance.step_number = 1 core_agent_instance.grammar = None core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.memory = MagicMock() core_agent_instance.memory.steps = [] @@ -1015,14 +1249,16 @@ def test_step_stream_execution_with_none_output(core_agent_instance): core_agent_instance.model = MagicMock(return_value=mock_chat_message) # Mock python_executor to return None output core_agent_instance.python_executor = MagicMock( - return_value=(None, "Execution logs", False)) + return_value=MockCodeOutput(output=None, logs="Execution logs", is_final_answer=False)) # Execute result = list(core_agent_instance._step_stream(mock_memory_step)) # Assertions - # Should yield None when is_final_answer is False - assert result[0] is None + # Should yield ActionOutput when is_final_answer is False + assert len(result) == 1 + assert isinstance(result[0], MockActionOutput) + assert result[0].is_final_answer is False assert mock_memory_step.observations is not None # Check that observations was set but should not contain "Last output from code snippet" # since output is None @@ -1050,6 +1286,10 @@ def test_run_with_additional_args(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1059,8 +1299,7 @@ def test_run_with_additional_args(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1089,6 +1328,10 @@ def test_run_with_stream_true(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1123,6 +1366,10 @@ def test_run_with_reset_false(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1132,8 +1379,7 @@ def test_run_with_reset_false(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1162,6 +1408,10 @@ def test_run_with_images(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1171,8 +1421,7 @@ def test_run_with_images(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1185,8 +1434,89 @@ def test_run_with_images(core_agent_instance): call_args = core_agent_instance.memory.steps.append.call_args[0][0] # The TaskStep is mocked, so just verify it was called with correct arguments via the constructor # We'll check that TaskStep was called with the right parameters - mock_smolagents.memory.TaskStep.assert_called_with( - task=task, task_images=images) + assert isinstance(call_args, MockTaskStep) + assert call_args.task == task + assert call_args.task_images == images + + +def test_run_return_full_result_success_state(core_agent_instance): + """run should return RunResult with aggregated token usage when requested.""" + task = "test task" + token_usage = MagicMock(input_tokens=7, output_tokens=3) + action_step = core_agent_module.ActionStep() + action_step.token_usage = token_usage + + core_agent_instance.name = "test_agent" + core_agent_instance.memory.steps = [action_step] + core_agent_instance.memory.get_full_steps = MagicMock(return_value=[{"step": "data"}]) + core_agent_instance.memory.reset = MagicMock() + core_agent_instance.monitor.reset = MagicMock() + core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() + core_agent_instance.model = MagicMock() + core_agent_instance.model.model_id = "model" + core_agent_instance.python_executor = MagicMock() + core_agent_instance.python_executor.send_variables = MagicMock() + core_agent_instance.python_executor.send_tools = MagicMock() + core_agent_instance.observer = MagicMock() + + final_step = MockFinalAnswerStep(output="final result") + with patch.object(core_agent_instance, '_run_stream', return_value=[final_step]): + result = core_agent_instance.run(task, return_full_result=True) + + assert isinstance(result, core_agent_module.RunResult) + assert result.output == "final result" + core_agent_module.TokenUsage.assert_called_once_with(input_tokens=7, output_tokens=3) + assert result.token_usage == core_agent_module.TokenUsage.return_value + assert result.state == "success" + core_agent_instance.memory.get_full_steps.assert_called_once() + + +def test_run_return_full_result_max_steps_error(core_agent_instance): + """run should mark state as max_steps_error when the last step contains AgentMaxStepsError.""" + task = "test task" + + action_step = core_agent_module.ActionStep() + action_step.token_usage = None + action_step.error = core_agent_module.AgentMaxStepsError("max steps reached") + + class StepsList(list): + def append(self, item): + # Skip storing TaskStep to keep action_step as the last element + if isinstance(item, core_agent_module.TaskStep): + return + super().append(item) + + core_agent_instance.name = "test_agent" + steps_list = StepsList([action_step]) + core_agent_instance.memory.steps = steps_list + core_agent_instance.memory.get_full_steps = MagicMock(return_value=[{"step": "data"}]) + core_agent_instance.memory.reset = MagicMock() + core_agent_instance.monitor.reset = MagicMock() + core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() + core_agent_instance.model = MagicMock() + core_agent_instance.model.model_id = "model" + core_agent_instance.python_executor = MagicMock() + core_agent_instance.python_executor.send_variables = MagicMock() + core_agent_instance.python_executor.send_tools = MagicMock() + core_agent_instance.observer = MagicMock() + + final_step = MockFinalAnswerStep(output="final result") + with patch.object(core_agent_instance, '_run_stream', return_value=[final_step]): + result = core_agent_instance.run(task, return_full_result=True) + + assert isinstance(result, core_agent_module.RunResult) + assert result.token_usage is None + core_agent_module.TokenUsage.assert_not_called() + assert result.state == "max_steps_error" + core_agent_instance.memory.get_full_steps.assert_called_once() def test_run_without_python_executor(core_agent_instance): @@ -1204,6 +1534,10 @@ def test_run_without_python_executor(core_agent_instance): core_agent_instance.monitor = MagicMock() core_agent_instance.monitor.reset = MagicMock() core_agent_instance.logger = MagicMock() + core_agent_instance.logger.log = MagicMock() + core_agent_instance.logger.log_task = MagicMock() + core_agent_instance.logger.log_markdown = MagicMock() + core_agent_instance.logger.log_code = MagicMock() core_agent_instance.model = MagicMock() core_agent_instance.model.model_id = "test-model" core_agent_instance.name = "test_agent" @@ -1213,8 +1547,7 @@ def test_run_without_python_executor(core_agent_instance): core_agent_instance.observer = MagicMock() # Mock _run_stream to return a simple result - mock_final_step = MagicMock() - mock_final_step.final_answer = "final result" + mock_final_step = MockFinalAnswerStep(output="final result") with patch.object(core_agent_instance, '_run_stream', return_value=[mock_final_step]): # Execute @@ -1267,6 +1600,31 @@ def test_call_method_success(core_agent_instance): "test_agent", ProcessType.AGENT_FINISH, "test result") +def test_call_method_with_run_result_return(core_agent_instance): + """Test __call__ handles RunResult by extracting its output.""" + task = "test task" + core_agent_instance.name = "test_agent" + core_agent_instance.state = {} + core_agent_instance.prompt_templates = { + "managed_agent": { + "task": "Task: {{task}}", + "report": "Report: {{final_answer}}" + } + } + core_agent_instance.provide_run_summary = False + core_agent_instance.observer = MagicMock() + + run_result = core_agent_module.RunResult(output="run result", token_usage=None, steps=[], timing=None, state="success") + with patch.object(core_agent_instance, 'run', return_value=run_result) as mock_run: + result = core_agent_instance(task) + + assert "Report: run result" in result + mock_run.assert_called_once() + core_agent_instance.observer.add_message.assert_called_with( + "test_agent", ProcessType.AGENT_FINISH, "run result" + ) + + def test_call_method_with_run_summary(core_agent_instance): """Test __call__ method with provide_run_summary=True.""" # Setup @@ -1284,10 +1642,14 @@ def test_call_method_with_run_summary(core_agent_instance): core_agent_instance.provide_run_summary = True core_agent_instance.observer = MagicMock() - # Mock write_memory_to_messages to return some simple messages + # Mock write_memory_to_messages to return some simple messages with .content attribute + class MockMessage: + def __init__(self, content): + self.content = content + mock_messages = [ - {"content": "msg1"}, - {"content": "msg2"} + MockMessage("msg1"), + MockMessage("msg2") ] core_agent_instance.write_memory_to_messages = MagicMock( return_value=mock_messages) diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 3dc831323..2a842ea72 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -27,11 +27,16 @@ class _ActionStep: - pass + def __init__(self, step_number=None, timing=None, action_output=None, model_output=None): + self.step_number = step_number + self.timing = timing + self.action_output = action_output + self.model_output = model_output class _TaskStep: - pass + def __init__(self, task=None): + self.task = task class _AgentText: @@ -214,6 +219,8 @@ class _MockToolSign: "nexent.storage": mock_nexent_storage_module, "nexent.multi_modal": mock_nexent_multi_modal_module, "nexent.multi_modal.load_save_object": mock_nexent_load_save_module, + # Mock tiktoken to avoid importing the real package when models import it + "tiktoken": MagicMock(), # Mock the OpenAIModel import "sdk.nexent.core.models.openai_llm": MagicMock(OpenAIModel=mock_openai_model_class), # Mock CoreAgent import @@ -230,7 +237,7 @@ class _MockToolSign: from sdk.nexent.core.utils.observer import MessageObserver, ProcessType from sdk.nexent.core.agents import nexent_agent from sdk.nexent.core.agents.nexent_agent import NexentAgent, ActionStep, TaskStep - from sdk.nexent.core.agents.agent_model import ToolConfig, ModelConfig, AgentConfig + from sdk.nexent.core.agents.agent_model import ToolConfig, ModelConfig, AgentConfig, AgentHistory # ---------------------------------------------------------------------------- @@ -1087,6 +1094,48 @@ def test_add_history_to_agent_none_history(nexent_agent_instance, mock_core_agen assert len(mock_core_agent.memory.steps) == 0 +def test_add_history_to_agent_user_and_assistant_history(nexent_agent_instance, mock_core_agent): + """Test add_history_to_agent correctly converts user and assistant messages to memory steps.""" + nexent_agent_instance.agent = mock_core_agent + + user_msg = AgentHistory(role="user", content="User question") + assistant_msg = AgentHistory(role="assistant", content="Assistant reply") + + nexent_agent_instance.add_history_to_agent([user_msg, assistant_msg]) + + mock_core_agent.memory.reset.assert_called_once() + assert len(mock_core_agent.memory.steps) == 2 + + # First step should be a TaskStep for the user message + first_step = mock_core_agent.memory.steps[0] + assert isinstance(first_step, TaskStep) + assert first_step.task == "User question" + + # Second step should be an ActionStep for the assistant message + second_step = mock_core_agent.memory.steps[1] + assert isinstance(second_step, ActionStep) + assert second_step.action_output == "Assistant reply" + assert second_step.model_output == "Assistant reply" + + +def test_add_history_to_agent_invalid_agent_type(nexent_agent_instance): + """Test add_history_to_agent raises TypeError when agent is not a CoreAgent.""" + nexent_agent_instance.agent = "not_core_agent" + + with pytest.raises(TypeError, match="agent must be a CoreAgent object"): + nexent_agent_instance.add_history_to_agent([]) + + +def test_add_history_to_agent_invalid_history_items(nexent_agent_instance, mock_core_agent): + """Test add_history_to_agent raises TypeError when history items are not AgentHistory.""" + nexent_agent_instance.agent = mock_core_agent + + invalid_history = [{"role": "user", "content": "hello"}] + + with pytest.raises(TypeError, match="history must be a list of AgentHistory objects"): + nexent_agent_instance.add_history_to_agent(invalid_history) + + def test_agent_run_with_observer_success_with_agent_text(nexent_agent_instance, mock_core_agent): """Test successful agent_run_with_observer with AgentText final answer.""" # Setup @@ -1103,7 +1152,7 @@ def test_agent_run_with_observer_success_with_agent_text(nexent_agent_instance, "Final answer with thinking content") mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = mock_final_answer + mock_core_agent.run.return_value[-1].output = mock_final_answer # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1129,7 +1178,7 @@ def test_agent_run_with_observer_success_with_string_final_answer(nexent_agent_i mock_action_step.error = None mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "String final answer with thinking" + mock_core_agent.run.return_value[-1].output = "String final answer with thinking" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1153,7 +1202,7 @@ def test_agent_run_with_observer_with_error_in_step(nexent_agent_instance, mock_ mock_action_step.error = "Test error occurred" mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1176,7 +1225,7 @@ def test_agent_run_with_observer_skips_non_action_step(nexent_agent_instance, mo mock_action_step.error = None mock_core_agent.run.return_value = [mock_task_step, mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1199,7 +1248,7 @@ def test_agent_run_with_observer_with_stop_event_set(nexent_agent_instance, mock mock_action_step.error = None mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute nexent_agent_instance.agent_run_with_observer("test query") @@ -1226,6 +1275,14 @@ def test_agent_run_with_observer_with_exception(nexent_agent_instance, mock_core ) +def test_agent_run_with_observer_invalid_agent_type(nexent_agent_instance): + """Test agent_run_with_observer raises TypeError when agent is not a CoreAgent.""" + nexent_agent_instance.agent = "not_core_agent" + + with pytest.raises(TypeError, match="agent must be a CoreAgent object"): + nexent_agent_instance.agent_run_with_observer("test query") + + def test_agent_run_with_observer_with_reset_false(nexent_agent_instance, mock_core_agent): """Test agent_run_with_observer with reset=False parameter.""" # Setup @@ -1238,7 +1295,7 @@ def test_agent_run_with_observer_with_reset_false(nexent_agent_instance, mock_co mock_action_step.error = None mock_core_agent.run.return_value = [mock_action_step] - mock_core_agent.run.return_value[-1].final_answer = "Final answer" + mock_core_agent.run.return_value[-1].output = "Final answer" # Execute with reset=False nexent_agent_instance.agent_run_with_observer("test query", reset=False) diff --git a/test/sdk/core/agents/test_run_agent.py b/test/sdk/core/agents/test_run_agent.py index 0cafdd8a1..b47aec879 100644 --- a/test/sdk/core/agents/test_run_agent.py +++ b/test/sdk/core/agents/test_run_agent.py @@ -49,7 +49,7 @@ def from_mcp(cls, *args, **kwargs): # pylint: disable=unused-argument sub_mod = ModuleType(f"smolagents.{_sub}") # Populate required attributes with MagicMocks to satisfy import-time `from smolagents. import ...`. if _sub == "agents": - for _name in ["CodeAgent", "populate_template", "handle_agent_output_types", "AgentError", "AgentType"]: + for _name in ["CodeAgent", "populate_template", "handle_agent_output_types", "AgentError", "AgentType", "ActionOutput", "RunResult"]: setattr(sub_mod, _name, MagicMock(name=f"smolagents.agents.{_name}")) elif _sub == "local_python_executor": setattr(sub_mod, "fix_final_answer_code", MagicMock(name="fix_final_answer_code")) @@ -59,6 +59,7 @@ def from_mcp(cls, *args, **kwargs): # pylint: disable=unused-argument elif _sub == "models": setattr(sub_mod, "ChatMessage", MagicMock(name="smolagents.models.ChatMessage")) setattr(sub_mod, "MessageRole", MagicMock(name="smolagents.models.MessageRole")) + setattr(sub_mod, "CODEAGENT_RESPONSE_FORMAT", MagicMock(name="smolagents.models.CODEAGENT_RESPONSE_FORMAT")) # Provide a simple base class so that OpenAIModel can inherit from it class _DummyOpenAIServerModel: def __init__(self, *args, **kwargs): @@ -67,13 +68,18 @@ def __init__(self, *args, **kwargs): setattr(sub_mod, "OpenAIServerModel", _DummyOpenAIServerModel) elif _sub == "monitoring": setattr(sub_mod, "LogLevel", MagicMock(name="smolagents.monitoring.LogLevel")) + setattr(sub_mod, "Timing", MagicMock(name="smolagents.monitoring.Timing")) + setattr(sub_mod, "YELLOW_HEX", MagicMock(name="smolagents.monitoring.YELLOW_HEX")) + setattr(sub_mod, "TokenUsage", MagicMock(name="smolagents.monitoring.TokenUsage")) elif _sub == "utils": for _name in [ "AgentExecutionError", "AgentGenerationError", "AgentParsingError", + "AgentMaxStepsError", "parse_code_blobs", "truncate_content", + "extract_code_from_text", ]: setattr(sub_mod, _name, MagicMock(name=f"smolagents.utils.{_name}")) setattr(mock_smolagents, _sub, sub_mod) @@ -82,6 +88,8 @@ def __init__(self, *args, **kwargs): # Top-level exports expected directly from `smolagents` by nexent_agent.py for _name in ["ActionStep", "TaskStep", "AgentText", "handle_agent_output_types"]: setattr(mock_smolagents, _name, MagicMock(name=f"smolagents.{_name}")) +# Export Timing from monitoring submodule to top-level +setattr(mock_smolagents, "Timing", mock_smolagents.monitoring.Timing) # Also export Tool at top-level so that `from smolagents import Tool` works setattr(mock_smolagents, "Tool", mock_smolagents_tool_cls) @@ -237,9 +245,9 @@ def test_agent_run_thread_local_flow(basic_agent_run_info, monkeypatch): def test_agent_run_thread_mcp_flow(basic_agent_run_info, mock_memory_context, monkeypatch): - """Verify behaviour when an MCP host list is provided.""" - # Give the AgentRunInfo an MCP host list - basic_agent_run_info.mcp_host = ["http://mcp.server"] + """Verify behaviour when an MCP host list is provided with auto-detected transport.""" + # Give the AgentRunInfo an MCP host list (string format, auto-detect transport) + basic_agent_run_info.mcp_host = ["http://mcp.server/mcp"] # Prepare ToolCollection.from_mcp to return a context manager mock_tool_collection = MagicMock(name="ToolCollectionInstance") @@ -257,7 +265,7 @@ def test_agent_run_thread_mcp_flow(basic_agent_run_info, mock_memory_context, mo basic_agent_run_info.observer.add_message.assert_any_call("", ProcessType.AGENT_NEW_RUN, "") # ToolCollection.from_mcp should be called with the expected client list and trust_remote_code=True - expected_client_list = [{"url": "http://mcp.server"}] + expected_client_list = [{"url": "http://mcp.server/mcp", "transport": "streamable-http"}] run_agent.ToolCollection.from_mcp.assert_called_once_with(expected_client_list, trust_remote_code=True) # NexentAgent should be instantiated with mcp_tool_collection @@ -275,6 +283,116 @@ def test_agent_run_thread_mcp_flow(basic_agent_run_info, mock_memory_context, mo mock_nexent_instance.agent_run_with_observer.assert_called_once_with(query=basic_agent_run_info.query, reset=False) +def test_agent_run_thread_mcp_flow_with_explicit_transport(basic_agent_run_info, mock_memory_context, monkeypatch): + """Verify behaviour when MCP host is provided with explicit transport in dict format.""" + # Give the AgentRunInfo an MCP host list with explicit transport + basic_agent_run_info.mcp_host = [{"url": "http://mcp.server", "transport": "sse"}] + + # Prepare ToolCollection.from_mcp to return a context manager + mock_tool_collection = MagicMock(name="ToolCollectionInstance") + mock_context_manager = MagicMock(__enter__=MagicMock(return_value=mock_tool_collection), __exit__=MagicMock(return_value=None)) + monkeypatch.setattr(run_agent.ToolCollection, "from_mcp", MagicMock(return_value=mock_context_manager)) + + # Patch NexentAgent + mock_nexent_instance = MagicMock(name="NexentAgentInstance") + monkeypatch.setattr(run_agent, "NexentAgent", MagicMock(return_value=mock_nexent_instance)) + + # Execute + run_agent.agent_run_thread(basic_agent_run_info) + + # ToolCollection.from_mcp should be called with the expected client list + expected_client_list = [{"url": "http://mcp.server", "transport": "sse"}] + run_agent.ToolCollection.from_mcp.assert_called_once_with(expected_client_list, trust_remote_code=True) + + +def test_agent_run_thread_mcp_flow_mixed_formats(basic_agent_run_info, mock_memory_context, monkeypatch): + """Verify behaviour when MCP host list contains both string and dict formats.""" + # Mix of string (auto-detect) and dict (explicit) formats + basic_agent_run_info.mcp_host = [ + "http://mcp1.server/mcp", # Auto-detect: streamable-http + "http://mcp2.server/sse", # Auto-detect: sse + {"url": "http://mcp3.server/mcp", "transport": "streamable-http"}, # Explicit: streamable-http + ] + + # Prepare ToolCollection.from_mcp to return a context manager + mock_tool_collection = MagicMock(name="ToolCollectionInstance") + mock_context_manager = MagicMock(__enter__=MagicMock(return_value=mock_tool_collection), __exit__=MagicMock(return_value=None)) + monkeypatch.setattr(run_agent.ToolCollection, "from_mcp", MagicMock(return_value=mock_context_manager)) + + # Patch NexentAgent + mock_nexent_instance = MagicMock(name="NexentAgentInstance") + monkeypatch.setattr(run_agent, "NexentAgent", MagicMock(return_value=mock_nexent_instance)) + + # Execute + run_agent.agent_run_thread(basic_agent_run_info) + + # ToolCollection.from_mcp should be called with normalized client list + expected_client_list = [ + {"url": "http://mcp1.server/mcp", "transport": "streamable-http"}, + {"url": "http://mcp2.server/sse", "transport": "sse"}, + {"url": "http://mcp3.server/mcp", "transport": "streamable-http"}, + ] + run_agent.ToolCollection.from_mcp.assert_called_once_with(expected_client_list, trust_remote_code=True) + + +def test_detect_transport(): + """Test transport auto-detection logic based on URL ending.""" + # Test URLs ending with /sse + assert run_agent._detect_transport("http://server/sse") == "sse" + assert run_agent._detect_transport("https://api.example.com/sse") == "sse" + assert run_agent._detect_transport("http://localhost:3000/sse") == "sse" + + # Test URLs ending with /mcp + assert run_agent._detect_transport("http://server/mcp") == "streamable-http" + assert run_agent._detect_transport("https://api.example.com/mcp") == "streamable-http" + assert run_agent._detect_transport("http://localhost:3000/mcp") == "streamable-http" + + # Test default fallback (no /sse or /mcp ending) + assert run_agent._detect_transport("http://server") == "streamable-http" + assert run_agent._detect_transport("https://api.example.com") == "streamable-http" + assert run_agent._detect_transport("http://server/other") == "streamable-http" + + +def test_normalize_mcp_config(): + """Test MCP configuration normalization.""" + # Test string format (auto-detect based on URL ending) + result = run_agent._normalize_mcp_config("http://server/mcp") + assert result == {"url": "http://server/mcp", "transport": "streamable-http"} + + result = run_agent._normalize_mcp_config("http://server/sse") + assert result == {"url": "http://server/sse", "transport": "sse"} + + # Test string format without /sse or /mcp ending (defaults to streamable-http) + result = run_agent._normalize_mcp_config("http://server") + assert result == {"url": "http://server", "transport": "streamable-http"} + + # Test dict format with explicit transport + result = run_agent._normalize_mcp_config({"url": "http://server/mcp", "transport": "sse"}) + assert result == {"url": "http://server/mcp", "transport": "sse"} + + # Test dict format without transport (auto-detect) + result = run_agent._normalize_mcp_config({"url": "http://server/sse"}) + assert result == {"url": "http://server/sse", "transport": "sse"} + + result = run_agent._normalize_mcp_config({"url": "http://server/mcp"}) + assert result == {"url": "http://server/mcp", "transport": "streamable-http"} + + # Test invalid dict (missing url) + with pytest.raises(ValueError, match="must contain 'url' key"): + run_agent._normalize_mcp_config({"transport": "sse"}) + + # Test invalid transport type + with pytest.raises(ValueError, match="Invalid transport type"): + run_agent._normalize_mcp_config({"url": "http://server/mcp", "transport": "stdio"}) + + with pytest.raises(ValueError, match="Invalid transport type"): + run_agent._normalize_mcp_config({"url": "http://server/mcp", "transport": "invalid"}) + + # Test invalid type + with pytest.raises(ValueError, match="Invalid MCP host item type"): + run_agent._normalize_mcp_config(123) + + def test_agent_run_thread_handles_internal_exception(basic_agent_run_info, mock_memory_context, monkeypatch): """If an internal error occurs, the observer should be notified and a ValueError propagated.""" # Configure NexentAgent.create_single_agent to raise an exception diff --git a/test/sdk/core/tools/test_datamate_search_tool.py b/test/sdk/core/tools/test_datamate_search_tool.py index cc2742796..ebfdb3bba 100644 --- a/test/sdk/core/tools/test_datamate_search_tool.py +++ b/test/sdk/core/tools/test_datamate_search_tool.py @@ -117,7 +117,7 @@ def test_extract_dataset_id(self, datamate_tool: DataMateSearchTool, path, expec @pytest.mark.parametrize( "dataset_id, file_id, expected", [ - ("ds1", "f1", "127.0.0.1/api/data-management/datasets/ds1/files/f1/download"), + ("ds1", "f1", "http://127.0.0.1:8080/api/data-management/datasets/ds1/files/f1/download"), ("", "f1", ""), ("ds1", "", ""), ], diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index 535af6b35..f6cdc4577 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -37,7 +37,8 @@ def knowledge_base_search_tool(mock_observer, mock_vdb_core, mock_embedding_mode index_names=["test_index1", "test_index2"], observer=mock_observer, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) return tool @@ -50,7 +51,8 @@ def knowledge_base_search_tool_no_observer(mock_vdb_core, mock_embedding_model): index_names=["test_index"], observer=None, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) return tool @@ -78,6 +80,49 @@ def create_mock_search_result(count=3): class TestKnowledgeBaseSearchTool: """Test KnowledgeBaseSearchTool functionality""" + def test_update_name_resolver_supports_empty_mapping(self, knowledge_base_search_tool): + """Ensure update_name_resolver replaces mapping and handles falsy input""" + knowledge_base_search_tool.update_name_resolver({"kb": "index_kb"}) + assert knowledge_base_search_tool.name_resolver == {"kb": "index_kb"} + + knowledge_base_search_tool.update_name_resolver(None) + assert knowledge_base_search_tool.name_resolver == {} + + def test_resolve_names_without_resolver_logs_warning(self, knowledge_base_search_tool, mocker): + """When no resolver is configured, names are returned unchanged and warning is logged""" + warning_mock = mocker.patch("sdk.nexent.core.tools.knowledge_base_search_tool.logger.warning") + + names = knowledge_base_search_tool._resolve_names(["kb1", "kb2"]) + + assert names == ["kb1", "kb2"] + warning_mock.assert_called_once() + + @pytest.mark.parametrize( + "incoming,expected", + [ + (None, []), + ("single_index", ["single_index"]), + (["a", "b"], ["a", "b"]), + ], + ) + def test_normalize_index_names_variants(self, knowledge_base_search_tool_no_observer, incoming, expected): + """_normalize_index_names should normalize None, string, and list inputs""" + assert knowledge_base_search_tool_no_observer._normalize_index_names(incoming) == expected + + def test_forward_with_observer_adds_messages(self, knowledge_base_search_tool): + """forward should send TOOL and CARD messages when observer is present""" + mock_results = create_mock_search_result(1) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + knowledge_base_search_tool.forward("hello world") + + knowledge_base_search_tool.observer.add_message.assert_any_call( + "", ProcessType.TOOL, "Searching the knowledge base..." + ) + knowledge_base_search_tool.observer.add_message.assert_any_call( + "", ProcessType.CARD, json.dumps([{"icon": "search", "text": "hello world"}], ensure_ascii=False) + ) + def test_init_with_custom_values(self, mock_observer, mock_vdb_core, mock_embedding_model): """Test initialization with custom values""" tool = KnowledgeBaseSearchTool( @@ -85,7 +130,8 @@ def test_init_with_custom_values(self, mock_observer, mock_vdb_core, mock_embedd index_names=["index1", "index2", "index3"], observer=mock_observer, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) assert tool.top_k == 10 @@ -101,7 +147,8 @@ def test_init_with_none_index_names(self, mock_vdb_core, mock_embedding_model): index_names=None, observer=None, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core + vdb_core=mock_vdb_core, + name_resolver={} ) assert tool.index_names == [] diff --git a/test/sdk/vector_database/test_elasticsearch_core.py b/test/sdk/vector_database/test_elasticsearch_core.py index 30f8ff277..f9f878852 100644 --- a/test/sdk/vector_database/test_elasticsearch_core.py +++ b/test/sdk/vector_database/test_elasticsearch_core.py @@ -522,6 +522,47 @@ def test_vectorize_documents_small_batch(elasticsearch_core_instance): mock_embedding_model.get_embeddings.assert_called_once() mock_bulk.assert_called_once() +def test_small_batch_progress_callback_exception(elasticsearch_core_instance, caplog): + """Progress callback errors should be logged without failing the insert.""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1] * 3] + mock_embedding_model.embedding_model_name = "m" + + documents = [{"content": "a"}] + + def bad_progress(_, __): + raise RuntimeError("boom") + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch("time.strftime", lambda *a, **k: "2025-01-15T10:30:00"), \ + patch("time.time", lambda: 1642234567): + mock_bulk.return_value = {"errors": False, "items": []} + result = elasticsearch_core_instance._small_batch_insert( + "idx", documents, "content", mock_embedding_model, progress_callback=bad_progress + ) + + assert result == 1 + assert any("Progress callback failed in small batch" in m for m in caplog.messages) + +def test_small_batch_error_path_logs_and_raises(elasticsearch_core_instance, caplog): + """Small batch should log errors and re-raise when bulk fails.""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1] * 3] + mock_embedding_model.embedding_model_name = "m" + + documents = [{"content": "x"}] + + with patch.object(elasticsearch_core_instance, "client") as mock_client, \ + patch("time.strftime", lambda *a, **k: "2025-01-15T10:30:00"), \ + patch("time.time", lambda: 1642234567): + mock_client.bulk.side_effect = RuntimeError("bulk boom") + with pytest.raises(RuntimeError): + elasticsearch_core_instance._small_batch_insert( + "idx", documents, "content", mock_embedding_model + ) + + assert any("Small batch insert failed: bulk boom" in m for m in caplog.messages) + def test_vectorize_documents_large_batch(elasticsearch_core_instance): """Test indexing a large batch of documents (>= 64).""" @@ -558,6 +599,76 @@ def test_vectorize_documents_large_batch(elasticsearch_core_instance): mock_bulk.assert_called() mock_refresh.assert_called_once_with("test_index") +def test_large_batch_progress_callback_invoked(elasticsearch_core_instance): + """Progress callback should be triggered during embedding phase.""" + mock_embedding_model = MagicMock() + mock_embedding_model.embedding_model_name = "test-model" + mock_embedding_model.get_embeddings.return_value = [[0.1], [0.2]] + + docs = [{"content": "a"}, {"content": "b"}] + progress_calls = [] + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch.object(elasticsearch_core_instance, "_force_refresh_with_retry"): + mock_bulk.return_value = {"errors": False, "items": []} + elasticsearch_core_instance._large_batch_insert( + "idx", docs, batch_size=5, content_field="content", + embedding_model=mock_embedding_model, embedding_batch_size=2, + progress_callback=lambda done, total: progress_calls.append((done, total)) + ) + + assert progress_calls == [(2, 2)] + +def test_large_batch_progress_callback_exception_logged(elasticsearch_core_instance, caplog): + """Embedding progress callback errors should be logged and not stop indexing.""" + mock_embedding_model = MagicMock() + mock_embedding_model.embedding_model_name = "test-model" + mock_embedding_model.get_embeddings.return_value = [[0.1]] + + docs = [{"content": "a"}] + + def bad_progress(_, __): + raise RuntimeError("cb fail") + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch.object(elasticsearch_core_instance, "_force_refresh_with_retry"): + mock_bulk.return_value = {"errors": False, "items": []} + elasticsearch_core_instance._large_batch_insert( + "idx", docs, batch_size=1, content_field="content", + embedding_model=mock_embedding_model, embedding_batch_size=1, + progress_callback=bad_progress + ) + + assert any("Progress callback failed during embedding" in m for m in caplog.messages) + +def test_large_batch_retry_logs_warning(elasticsearch_core_instance, caplog): + """Embedding retries should emit warnings before succeeding.""" + mock_embedding_model = MagicMock() + mock_embedding_model.embedding_model_name = "test-model" + call_counter = {"n": 0} + + def get_embeddings(_): + call_counter["n"] += 1 + if call_counter["n"] < 3: + raise RuntimeError("embed fail") + return [[0.1]] + + mock_embedding_model.get_embeddings.side_effect = get_embeddings + + docs = [{"content": "a"}] + + with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \ + patch.object(elasticsearch_core_instance, "_force_refresh_with_retry"), \ + patch("time.sleep", lambda *a, **k: None): + mock_bulk.return_value = {"errors": False, "items": []} + elasticsearch_core_instance._large_batch_insert( + "idx", docs, batch_size=1, content_field="content", + embedding_model=mock_embedding_model, embedding_batch_size=1, + ) + + assert call_counter["n"] == 3 + assert any("Embedding API error (attempt 1/3)" in m for m in caplog.messages) + def test_delete_documents_success(elasticsearch_core_instance): """Test deleting documents by path_or_url successfully.""" @@ -1134,8 +1245,12 @@ def test_handle_bulk_errors_with_errors(elasticsearch_core_instance): ] } - # Should not raise exception, just log errors - elasticsearch_core_instance._handle_bulk_errors(response) + with pytest.raises(Exception) as exc_info: + elasticsearch_core_instance._handle_bulk_errors(response) + + err_payload = str(exc_info.value) + assert "Bulk indexing failed: Failed to parse mapping" in err_payload + assert "es_bulk_failed" in err_payload def test_handle_bulk_errors_version_conflict(elasticsearch_core_instance): @@ -1158,6 +1273,40 @@ def test_handle_bulk_errors_version_conflict(elasticsearch_core_instance): elasticsearch_core_instance._handle_bulk_errors(response) +def test_handle_bulk_errors_skips_items_without_error(elasticsearch_core_instance): + """Items without error key should be ignored.""" + response = { + "errors": True, + "items": [{"index": {}}], + } + # Should not raise + elasticsearch_core_instance._handle_bulk_errors(response) + + +def test_handle_bulk_errors_dim_mismatch_sets_specific_code(elasticsearch_core_instance): + """Dense vector dimension mismatch should produce es_dim_mismatch code.""" + response = { + "errors": True, + "items": [ + { + "index": { + "error": { + "type": "illegal_argument_exception", + "reason": "field [embedding] has different number of dimensions than vector", + "caused_by": {"reason": "dense_vector different number of dimensions"}, + } + } + } + ], + } + + with pytest.raises(Exception) as exc_info: + elasticsearch_core_instance._handle_bulk_errors(response) + + payload = str(exc_info.value) + assert "es_dim_mismatch" in payload + assert "Bulk indexing failed" in payload + def test_bulk_operation_context(elasticsearch_core_instance): """Test bulk operation context manager.""" with patch.object(elasticsearch_core_instance, '_apply_bulk_settings') as mock_apply, \ diff --git a/test/sdk/vector_database/test_elasticsearch_core_coverage.py b/test/sdk/vector_database/test_elasticsearch_core_coverage.py index f307c9d84..757bbc566 100644 --- a/test/sdk/vector_database/test_elasticsearch_core_coverage.py +++ b/test/sdk/vector_database/test_elasticsearch_core_coverage.py @@ -215,8 +215,9 @@ def test_handle_bulk_errors_with_fatal_error(self, vdb_core): } ] } - vdb_core._handle_bulk_errors(response) - # Should log error but not raise exception + with pytest.raises(Exception) as exc_info: + vdb_core._handle_bulk_errors(response) + assert "Bulk indexing failed" in str(exc_info.value) def test_handle_bulk_errors_with_caused_by(self, vdb_core): """Test _handle_bulk_errors with caused_by information""" @@ -237,8 +238,10 @@ def test_handle_bulk_errors_with_caused_by(self, vdb_core): } ] } - vdb_core._handle_bulk_errors(response) - # Should log both main error and caused_by error + with pytest.raises(Exception) as exc_info: + vdb_core._handle_bulk_errors(response) + assert "Invalid argument" in str(exc_info.value) + assert "JSON parsing failed" in str(exc_info.value) def test_delete_documents_success(self, vdb_core): """Test delete_documents successful case""" @@ -407,16 +410,18 @@ def test_large_batch_insert_bulk_exception(self, vdb_core): mock_embedding_model = MagicMock() mock_embedding_model.get_embeddings.return_value = [[0.1]] - result = vdb_core._large_batch_insert("idx", [{"content": "body"}], 1, "content", mock_embedding_model) - assert result == 0 + with pytest.raises(Exception) as exc_info: + vdb_core._large_batch_insert("idx", [{"content": "body"}], 1, "content", mock_embedding_model) + assert "bulk error" in str(exc_info.value) def test_large_batch_insert_preprocess_exception(self, vdb_core): """Ensure outer exception handler returns zero on preprocess failure.""" vdb_core._preprocess_documents = MagicMock(side_effect=Exception("fail")) mock_embedding_model = MagicMock() - result = vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", mock_embedding_model) - assert result == 0 + with pytest.raises(Exception) as exc_info: + vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", mock_embedding_model) + assert "fail" in str(exc_info.value) def test_count_documents_success(self, vdb_core): """Ensure count_documents returns ES count.""" @@ -672,8 +677,9 @@ def test_small_batch_insert_exception(self, vdb_core): mock_embedding_model = MagicMock() documents = [{"content": "test content", "title": "test"}] - result = vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) - assert result == 0 + with pytest.raises(Exception) as exc_info: + vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) + assert "Preprocess error" in str(exc_info.value) def test_large_batch_insert_success(self, vdb_core): """Test _large_batch_insert successful case"""