diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 8a59e5838..3cae4f1c0 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -9,7 +9,11 @@ from nexent.core.agents.agent_model import AgentRunInfo, ModelConfig, AgentConfig, ToolConfig from nexent.memory.memory_service import search_memory_in_levels -from services.elasticsearch_service import ElasticSearchService, elastic_core, get_embedding_model +from services.vectordatabase_service import ( + ElasticSearchService, + get_vector_db_core, + get_embedding_model, +) from services.tenant_config_service import get_selected_knowledge_list from services.remote_mcp_service import get_remote_mcp_server_list from services.memory_config_service import build_memory_context @@ -227,9 +231,11 @@ async def create_tool_config_list(agent_id, tenant_id, user_id): tenant_id=tenant_id, user_id=user_id) index_names = [knowledge_info.get( "index_name") for knowledge_info in knowledge_info_list] - tool_config.metadata = {"index_names": index_names, - "es_core": elastic_core, - "embedding_model": get_embedding_model(tenant_id=tenant_id)} + tool_config.metadata = { + "index_names": index_names, + "vdb_core": get_vector_db_core(), + "embedding_model": get_embedding_model(tenant_id=tenant_id), + } tool_config_list.append(tool_config) return tool_config_list diff --git a/backend/apps/agent_app.py b/backend/apps/agent_app.py index 698b2ea5f..3e63cb357 100644 --- a/backend/apps/agent_app.py +++ b/backend/apps/agent_app.py @@ -22,12 +22,13 @@ # Import monitoring utilities from utils.monitoring import monitoring_manager -router = APIRouter(prefix="/agent") +agent_runtime_router = APIRouter(prefix="/agent") +agent_config_router = APIRouter(prefix="/agent") logger = logging.getLogger("agent_app") # Define API route -@router.post("/run") +@agent_runtime_router.post("/run") @monitoring_manager.monitor_endpoint("agent.run", exclude_params=["authorization"]) async def 
agent_run_api(agent_request: AgentRequest, http_request: Request, authorization: str = Header(None)): """ @@ -45,7 +46,7 @@ async def agent_run_api(agent_request: AgentRequest, http_request: Request, auth status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent run error.") -@router.get("/stop/{conversation_id}") +@agent_runtime_router.get("/stop/{conversation_id}") async def agent_stop_api(conversation_id: int, authorization: Optional[str] = Header(None)): """ stop agent run and preprocess tasks for specified conversation_id @@ -58,7 +59,7 @@ async def agent_stop_api(conversation_id: int, authorization: Optional[str] = He detail=f"no running agent or preprocess tasks found for conversation_id {conversation_id}") -@router.post("/search_info") +@agent_config_router.post("/search_info") async def search_agent_info_api(agent_id: int = Body(...), authorization: Optional[str] = Header(None)): """ Search agent info by agent_id @@ -72,7 +73,7 @@ async def search_agent_info_api(agent_id: int = Body(...), authorization: Option status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent search info error.") -@router.get("/get_creating_sub_agent_id") +@agent_config_router.get("/get_creating_sub_agent_id") async def get_creating_sub_agent_info_api(authorization: Optional[str] = Header(None)): """ Create a new sub agent, return agent_ID @@ -85,7 +86,7 @@ async def get_creating_sub_agent_info_api(authorization: Optional[str] = Header( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent create error.") -@router.post("/update") +@agent_config_router.post("/update") async def update_agent_info_api(request: AgentInfoRequest, authorization: Optional[str] = Header(None)): """ Update an existing agent @@ -99,7 +100,7 @@ async def update_agent_info_api(request: AgentInfoRequest, authorization: Option status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent update error.") -@router.delete("") +@agent_config_router.delete("") async def delete_agent_api(request: 
AgentIDRequest, authorization: Optional[str] = Header(None)): """ Delete an agent @@ -113,7 +114,7 @@ async def delete_agent_api(request: AgentIDRequest, authorization: Optional[str] status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent delete error.") -@router.post("/export") +@agent_config_router.post("/export") async def export_agent_api(request: AgentIDRequest, authorization: Optional[str] = Header(None)): """ export an agent @@ -127,7 +128,7 @@ async def export_agent_api(request: AgentIDRequest, authorization: Optional[str] status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent export error.") -@router.post("/import") +@agent_config_router.post("/import") async def import_agent_api(request: AgentImportRequest, authorization: Optional[str] = Header(None)): """ import an agent @@ -141,7 +142,7 @@ async def import_agent_api(request: AgentImportRequest, authorization: Optional[ status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent import error.") -@router.get("/list") +@agent_config_router.get("/list") async def list_all_agent_info_api(authorization: Optional[str] = Header(None), request: Request = None): """ list all agent info @@ -155,7 +156,7 @@ async def list_all_agent_info_api(authorization: Optional[str] = Header(None), r status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Agent list error.") -@router.get("/call_relationship/{agent_id}") +@agent_config_router.get("/call_relationship/{agent_id}") async def get_agent_call_relationship_api(agent_id: int, authorization: Optional[str] = Header(None)): """ Get agent call relationship tree including tools and sub-agents diff --git a/backend/apps/base_app.py b/backend/apps/config_app.py similarity index 84% rename from backend/apps/base_app.py rename to backend/apps/config_app.py index 4683a5d39..eb2c824c1 100644 --- a/backend/apps/base_app.py +++ b/backend/apps/config_app.py @@ -4,14 +4,12 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse -from 
apps.agent_app import router as agent_router +from apps.agent_app import agent_config_router as agent_router from apps.config_sync_app import router as config_sync_router -from apps.conversation_management_app import router as conversation_management_router -from apps.elasticsearch_app import router as elasticsearch_router -from apps.file_management_app import router as file_manager_router +from apps.vectordatabase_app import router as vectordatabase_router +from apps.file_management_app import file_management_config_router as file_manager_router from apps.image_app import router as proxy_router from apps.knowledge_summary_app import router as summary_router -from apps.memory_config_app import router as memory_router from apps.me_model_managment_app import router as me_model_manager_router from apps.mock_user_management_app import router as mock_user_management_router from apps.model_managment_app import router as model_manager_router @@ -20,7 +18,7 @@ from apps.tenant_config_app import router as tenant_config_router from apps.tool_config_app import router as tool_config_router from apps.user_management_app import router as user_management_router -from apps.voice_app import router as voice_router +from apps.voice_app import voice_config_router as voice_router from consts.const import IS_SPEED_MODE # Import monitoring utilities @@ -41,11 +39,9 @@ app.include_router(me_model_manager_router) app.include_router(model_manager_router) -app.include_router(memory_router) app.include_router(config_sync_router) app.include_router(agent_router) -app.include_router(conversation_management_router) -app.include_router(elasticsearch_router) +app.include_router(vectordatabase_router) app.include_router(voice_router) app.include_router(file_manager_router) app.include_router(proxy_router) diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index e1f2f6b52..448b03a61 100644 --- a/backend/apps/file_management_app.py +++ 
b/backend/apps/file_management_app.py @@ -16,11 +16,12 @@ logger = logging.getLogger("file_management_app") # Create API router -router = APIRouter(prefix="/file") +file_management_runtime_router = APIRouter(prefix="/file") +file_management_config_router = APIRouter(prefix="/file") # Handle preflight requests -@router.options("/{full_path:path}") +@file_management_config_router.options("/{full_path:path}") async def options_route(full_path: str): return JSONResponse( status_code=HTTPStatus.OK, @@ -28,7 +29,7 @@ async def options_route(full_path: str): ) -@router.post("/upload") +@file_management_config_router.post("/upload") async def upload_files( file: List[UploadFile] = File(..., alias="file"), destination: str = Form(..., @@ -59,7 +60,7 @@ async def upload_files( detail="No valid files uploaded") -@router.post("/process") +@file_management_config_router.post("/process") async def process_files( files: List[dict] = Body( ..., description="List of file details to process, including path_or_url and filename"), @@ -100,7 +101,7 @@ async def process_files( ) -@router.post("/storage") +@file_management_runtime_router.post("/storage") async def storage_upload_files( files: List[UploadFile] = File(..., description="List of files to upload"), folder: str = Form( @@ -125,7 +126,7 @@ async def storage_upload_files( } -@router.get("/storage") +@file_management_config_router.get("/storage") async def get_storage_files( prefix: str = Query("", description="File prefix filter"), limit: int = Query(100, description="Maximum number of files to return"), @@ -160,7 +161,7 @@ async def get_storage_files( ) -@router.get("/storage/{path}/{object_name}") +@file_management_config_router.get("/storage/{path}/{object_name}") async def get_storage_file( object_name: str = PathParam(..., description="File object name"), download: str = Query("ignore", description="How to get the file"), @@ -200,7 +201,7 @@ async def get_storage_file( ) -@router.delete("/storage/{object_name:path}") 
+@file_management_config_router.delete("/storage/{object_name:path}") async def remove_storage_file( object_name: str = PathParam(..., description="File object name to delete") ): @@ -224,7 +225,7 @@ async def remove_storage_file( ) -@router.post("/storage/batch-urls") +@file_management_config_router.post("/storage/batch-urls") async def get_storage_file_batch_urls( request_data: dict = Body(..., description="JSON containing list of file object names"), @@ -272,7 +273,7 @@ async def get_storage_file_batch_urls( } -@router.post("/preprocess") +@file_management_runtime_router.post("/preprocess") async def agent_preprocess_api( request: Request, query: str = Form(...), files: List[UploadFile] = File(...), diff --git a/backend/apps/knowledge_summary_app.py b/backend/apps/knowledge_summary_app.py index 526f4b5ac..e4e11ace9 100644 --- a/backend/apps/knowledge_summary_app.py +++ b/backend/apps/knowledge_summary_app.py @@ -3,10 +3,10 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query, Request from fastapi.responses import StreamingResponse -from nexent.vector_database.elasticsearch_core import ElasticSearchCore +from nexent.vector_database.base import VectorDatabaseCore from consts.model import ChangeSummaryRequest -from services.elasticsearch_service import ElasticSearchService, get_es_core +from services.vectordatabase_service import ElasticSearchService, get_vector_db_core from utils.auth_utils import get_current_user_id, get_current_user_info router = APIRouter(prefix="/summary") @@ -22,7 +22,7 @@ async def auto_summary( 1000, description="Number of documents to retrieve per batch"), model_id: Optional[int] = Query( None, description="Model ID to use for summary generation"), - es_core: ElasticSearchCore = Depends(get_es_core), + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), authorization: Optional[str] = Header(None) ): """Summary Elasticsearch index_name by model""" @@ -34,15 +34,16 @@ async def auto_summary( return 
await service.summary_index_name( index_name=index_name, batch_size=batch_size, - es_core=es_core, + vdb_core=vdb_core, tenant_id=tenant_id, language=language, model_id=model_id ) except Exception as e: - logger.error("Knowledge base summary generation failed", exc_info=True) + logger.error( + f"Knowledge base summary generation failed: {e}", exc_info=True) return StreamingResponse( - f"data: {{\"status\": \"error\", \"message\": \"Knowledge base summary generation failed due to an internal error.\"}}\n\n", + "data: {{\"status\": \"error\", \"message\": \"Knowledge base summary generation failed due to an internal error.\"}}\n\n", media_type="text/event-stream", status_code=500 ) diff --git a/backend/apps/prompt_app.py b/backend/apps/prompt_app.py index 0512bd0e7..7c0b799dc 100644 --- a/backend/apps/prompt_app.py +++ b/backend/apps/prompt_app.py @@ -27,7 +27,9 @@ async def generate_and_save_system_prompt_api( task_description=prompt_request.task_description, user_id=user_id, tenant_id=tenant_id, - language=language + language=language, + tool_ids=prompt_request.tool_ids, + sub_agent_ids=prompt_request.sub_agent_ids ), media_type="text/event-stream") except Exception as e: logger.exception(f"Error occurred while generating system prompt: {e}") diff --git a/backend/apps/runtime_app.py b/backend/apps/runtime_app.py new file mode 100644 index 000000000..6db480ab6 --- /dev/null +++ b/backend/apps/runtime_app.py @@ -0,0 +1,58 @@ +import logging + +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from apps.agent_app import agent_runtime_router as agent_router +from apps.voice_app import voice_runtime_router as voice_router +from apps.conversation_management_app import router as conversation_management_router +from apps.memory_config_app import router as memory_config_router +from apps.file_management_app import file_management_runtime_router as file_management_router + +# Import 
monitoring utilities +from utils.monitoring import monitoring_manager + +# Create logger instance +logger = logging.getLogger("runtime_app") +app = FastAPI(root_path="/api") + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.include_router(agent_router) +app.include_router(conversation_management_router) +app.include_router(memory_config_router) +app.include_router(file_management_router) +app.include_router(voice_router) + +# Initialize monitoring for the application +monitoring_manager.setup_fastapi_app(app) + + +# Global exception handler for HTTP exceptions +@app.exception_handler(HTTPException) +async def http_exception_handler(request, exc): + logger.error(f"HTTPException: {exc.detail}") + return JSONResponse( + status_code=exc.status_code, + content={"message": exc.detail}, + ) + + +# Global exception handler for all uncaught exceptions +@app.exception_handler(Exception) +async def generic_exception_handler(request, exc): + logger.error(f"Generic Exception: {exc}") + return JSONResponse( + status_code=500, + content={"message": "Internal server error, please try again later."}, + ) + + diff --git a/backend/apps/elasticsearch_app.py b/backend/apps/vectordatabase_app.py similarity index 74% rename from backend/apps/elasticsearch_app.py rename to backend/apps/vectordatabase_app.py index 5fe9a1a21..b2b410264 100644 --- a/backend/apps/elasticsearch_app.py +++ b/backend/apps/vectordatabase_app.py @@ -3,29 +3,34 @@ from typing import Any, Dict, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query +from fastapi.responses import JSONResponse from consts.model import IndexingResponse -from nexent.vector_database.elasticsearch_core import ElasticSearchCore -from services.elasticsearch_service import ElasticSearchService, get_embedding_model, get_es_core, \ - check_knowledge_base_exist_impl +from 
nexent.vector_database.base import VectorDatabaseCore +from services.vectordatabase_service import ( + ElasticSearchService, + get_embedding_model, + get_vector_db_core, + check_knowledge_base_exist_impl, +) from services.redis_service import get_redis_service from utils.auth_utils import get_current_user_id router = APIRouter(prefix="/indices") service = ElasticSearchService() -logger = logging.getLogger("elasticsearch_app") +logger = logging.getLogger("vectordatabase_app") @router.get("/check_exist/{index_name}") async def check_knowledge_base_exist( index_name: str = Path(..., description="Name of the index to check"), - es_core: ElasticSearchCore = Depends(get_es_core), + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), authorization: Optional[str] = Header(None) ): """Check if a knowledge base name exists and in which scope.""" try: user_id, tenant_id = get_current_user_id(authorization) - return check_knowledge_base_exist_impl(index_name=index_name, es_core=es_core, user_id=user_id, tenant_id=tenant_id) + return check_knowledge_base_exist_impl(index_name=index_name, vdb_core=vdb_core, user_id=user_id, tenant_id=tenant_id) except Exception as e: logger.error( f"Error checking knowledge base existence for '{index_name}': {str(e)}", exc_info=True) @@ -38,13 +43,13 @@ def create_new_index( index_name: str = Path(..., description="Name of the index to create"), embedding_dim: Optional[int] = Query( None, description="Dimension of the embedding vectors"), - es_core: ElasticSearchCore = Depends(get_es_core), + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), authorization: Optional[str] = Header(None) ): """Create a new vector index and store it in the knowledge table""" try: user_id, tenant_id = get_current_user_id(authorization) - return ElasticSearchService.create_index(index_name, embedding_dim, es_core, user_id, tenant_id) + return ElasticSearchService.create_index(index_name, embedding_dim, vdb_core, user_id, tenant_id) except Exception 
as e: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error creating index: {str(e)}") @@ -53,7 +58,7 @@ def create_new_index( @router.delete("/{index_name}") async def delete_index( index_name: str = Path(..., description="Name of the index to delete"), - es_core: ElasticSearchCore = Depends(get_es_core), + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), authorization: Optional[str] = Header(None) ): """Delete an index and all its related data by calling the centralized service.""" @@ -61,7 +66,7 @@ async def delete_index( try: user_id, tenant_id = get_current_user_id(authorization) # Call the centralized full deletion service - result = await ElasticSearchService.full_delete_knowledge_base(index_name, es_core, user_id) + result = await ElasticSearchService.full_delete_knowledge_base(index_name, vdb_core, user_id) return result except Exception as e: logger.error( @@ -75,13 +80,13 @@ def get_list_indices( pattern: str = Query("*", description="Pattern to match index names"), include_stats: bool = Query( False, description="Whether to include index stats"), - es_core: ElasticSearchCore = Depends(get_es_core), + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), authorization: Optional[str] = Header(None), ): """List all user indices with optional stats""" try: user_id, tenant_id = get_current_user_id(authorization) - return ElasticSearchService.list_indices(pattern, include_stats, tenant_id, user_id, es_core) + return ElasticSearchService.list_indices(pattern, include_stats, tenant_id, user_id, vdb_core) except Exception as e: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error get index: {str(e)}") @@ -93,7 +98,7 @@ def create_index_documents( index_name: str = Path(..., description="Name of the index"), data: List[Dict[str, Any] ] = Body(..., description="Document List to process"), - es_core: ElasticSearchCore = Depends(get_es_core), + vdb_core: VectorDatabaseCore = 
Depends(get_vector_db_core), authorization: Optional[str] = Header(None) ): """ @@ -103,7 +108,7 @@ def create_index_documents( try: user_id, tenant_id = get_current_user_id(authorization) embedding_model = get_embedding_model(tenant_id) - return ElasticSearchService.index_documents(embedding_model, index_name, data, es_core) + return ElasticSearchService.index_documents(embedding_model, index_name, data, vdb_core) except Exception as e: error_msg = str(e) logger.error(f"Error indexing documents: {error_msg}") @@ -114,11 +119,11 @@ def create_index_documents( @router.get("/{index_name}/files") async def get_index_files( index_name: str = Path(..., description="Name of the index"), - es_core: ElasticSearchCore = Depends(get_es_core) + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) ): """Get all files from an index, including those that are not yet stored in ES""" try: - result = await ElasticSearchService.list_files(index_name, include_chunks=False, es_core=es_core) + result = await ElasticSearchService.list_files(index_name, include_chunks=False, vdb_core=vdb_core) # Transform result to match frontend expectations return { "status": "success", @@ -136,13 +141,13 @@ def delete_documents( index_name: str = Path(..., description="Name of the index"), path_or_url: str = Query(..., description="Path or URL of documents to delete"), - es_core: ElasticSearchCore = Depends(get_es_core) + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) ): """Delete documents by path or URL and clean up related Redis records""" try: # First delete the documents using existing service result = ElasticSearchService.delete_documents( - index_name, path_or_url, es_core) + index_name, path_or_url, vdb_core) # Then clean up Redis records related to this specific document try: @@ -184,10 +189,40 @@ def delete_documents( # Health check @router.get("/health") -def health_check(es_core: ElasticSearchCore = Depends(get_es_core)): +def health_check(vdb_core: VectorDatabaseCore = 
Depends(get_vector_db_core)): """Check API and Elasticsearch health""" try: # Try to list indices as a health check - return ElasticSearchService.health_check(es_core) + return ElasticSearchService.health_check(vdb_core) except Exception as e: raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"{str(e)}") + + +@router.post("/{index_name}/chunks") +def get_index_chunks( + index_name: str = Path(..., + description="Name of the index to get chunks from"), + page: int = Query( + None, description="Page number (1-based) for pagination"), + page_size: int = Query( + None, description="Number of records per page for pagination"), + path_or_url: Optional[str] = Query( + None, description="Filter chunks by document path_or_url"), + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) +): + """Get chunks from the specified index, with optional pagination support""" + try: + result = ElasticSearchService.get_index_chunks( + index_name=index_name, + page=page, + page_size=page_size, + path_or_url=path_or_url, + vdb_core=vdb_core, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except Exception as e: + error_msg = str(e) + logger.error( + f"Error getting chunks for index '{index_name}': {error_msg}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error getting chunks: {error_msg}") diff --git a/backend/apps/voice_app.py b/backend/apps/voice_app.py index 837101442..8f517cd07 100644 --- a/backend/apps/voice_app.py +++ b/backend/apps/voice_app.py @@ -16,10 +16,11 @@ logger = logging.getLogger("voice_app") -router = APIRouter(prefix="/voice") +voice_runtime_router = APIRouter(prefix="/voice") +voice_config_router = APIRouter(prefix="/voice") -@router.websocket("/stt/ws") +@voice_runtime_router.websocket("/stt/ws") async def stt_websocket(websocket: WebSocket): """WebSocket endpoint for real-time audio streaming and STT""" logger.info("STT WebSocket connection attempt...") @@ -39,7 +40,7 @@ async 
def stt_websocket(websocket: WebSocket): logger.info("STT WebSocket connection closed") -@router.websocket("/tts/ws") +@voice_runtime_router.websocket("/tts/ws") async def tts_websocket(websocket: WebSocket): """WebSocket endpoint for streaming TTS""" logger.info("TTS WebSocket connection attempt...") @@ -73,7 +74,7 @@ async def tts_websocket(websocket: WebSocket): await websocket.close() -@router.post("/connectivity") +@voice_config_router.post("/connectivity") async def check_voice_connectivity(request: VoiceConnectivityRequest): """ Check voice service connectivity diff --git a/backend/main_service.py b/backend/config_service.py similarity index 94% rename from backend/main_service.py rename to backend/config_service.py index edb2a3831..f98c7b155 100644 --- a/backend/main_service.py +++ b/backend/config_service.py @@ -10,13 +10,13 @@ from dotenv import load_dotenv load_dotenv() -from apps.base_app import app +from apps.config_app import app from utils.logging_utils import configure_logging, configure_elasticsearch_logging from services.tool_configuration_service import initialize_tools_on_startup configure_logging(logging.INFO) configure_elasticsearch_logging() -logger = logging.getLogger("main_service") +logger = logging.getLogger("config_service") async def startup_initialization(): diff --git a/backend/consts/const.py b/backend/consts/const.py index bb011d26d..53dac1068 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -1,4 +1,5 @@ import os +from enum import Enum from dotenv import load_dotenv # Load environment variables @@ -10,6 +11,11 @@ os.path.dirname(__file__)), 'assets', 'test.wav') +# Vector database providers +class VectorDatabaseType(str, Enum): + ELASTICSEARCH = "elasticsearch" + + # ModelEngine Configuration MODEL_ENGINE_HOST = os.getenv('MODEL_ENGINE_HOST') MODEL_ENGINE_APIKEY = os.getenv('MODEL_ENGINE_APIKEY') @@ -273,7 +279,7 @@ os.getenv("LLM_SLOW_TOKEN_RATE_THRESHOLD", "10.0")) # tokens per second # APP Version 
-APP_VERSION = "v1.7.5.3" +APP_VERSION = "v1.7.6" DEFAULT_ZH_TITLE = "新对话" DEFAULT_EN_TITLE = "New Conversation" diff --git a/backend/consts/model.py b/backend/consts/model.py index 66c891d8f..2d2208a91 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -193,6 +193,8 @@ class GeneratePromptRequest(BaseModel): task_description: str agent_id: int model_id: int + tool_ids: Optional[List[int]] = None # Optional: tool IDs from frontend (takes precedence over database query) + sub_agent_ids: Optional[List[int]] = None # Optional: sub-agent IDs from frontend (takes precedence over database query) class GenerateTitleRequest(BaseModel): diff --git a/backend/data_process/ray_config.py b/backend/data_process/ray_config.py index 179ea6c20..8ae85bd58 100644 --- a/backend/data_process/ray_config.py +++ b/backend/data_process/ray_config.py @@ -212,7 +212,6 @@ def start_local_cluster( def log_configuration(self): """Log current configuration information""" logger.debug("Ray Configuration:") - logger.debug(f" Plasma directory: {self.plasma_directory}") logger.debug(f" ObjectStore memory: {self.object_store_memory_gb} GB") logger.debug(f" Temp directory: {self.temp_dir}") logger.debug(f" Preallocate plasma: {self.preallocate_plasma}") diff --git a/backend/database/agent_db.py b/backend/database/agent_db.py index 6a4816d25..51169d678 100644 --- a/backend/database/agent_db.py +++ b/backend/database/agent_db.py @@ -76,14 +76,15 @@ def create_agent(agent_info, tenant_id: str, user_id: str): :param user_id: :return: Created agent object """ - agent_info.update({ + info_with_metadata = dict(agent_info) + info_with_metadata.setdefault("max_steps", 5) + info_with_metadata.update({ "tenant_id": tenant_id, "created_by": user_id, "updated_by": user_id, - "max_steps": 5 }) with get_db_session() as session: - new_agent = AgentInfo(**filter_property(agent_info, AgentInfo)) + new_agent = AgentInfo(**filter_property(info_with_metadata, AgentInfo)) new_agent.delete_flag = 'N' 
session.add(new_agent) session.flush() diff --git a/backend/nexent_mcp_service.py b/backend/mcp_service.py similarity index 92% rename from backend/nexent_mcp_service.py rename to backend/mcp_service.py index a04e4f9b1..c36d476ca 100644 --- a/backend/nexent_mcp_service.py +++ b/backend/mcp_service.py @@ -10,7 +10,7 @@ """ configure_logging(logging.INFO) -logger = logging.getLogger("nexent_mcp_service") +logger = logging.getLogger("mcp_service") # initialize main mcp service nexent_mcp = FastMCP(name="nexent_mcp") diff --git a/backend/prompts/managed_system_prompt_template.yaml b/backend/prompts/managed_system_prompt_template.yaml index 4848029a4..b89dcc405 100644 --- a/backend/prompts/managed_system_prompt_template.yaml +++ b/backend/prompts/managed_system_prompt_template.yaml @@ -61,8 +61,8 @@ system_prompt: |- - 用简单的Python编写代码 - 遵循python代码规范和python语法 - 根据格式规范正确调用工具 - - 考虑到代码执行与展示用户代码的区别,使用'代码:\n```\n'开头,并以'```'表达运行代码,使用'代码:\n```\n'开头,并以'```'表达展示代码 - - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码:\n```\n'开头,并以'```'表达展示代码。 + - 考虑到代码执行与展示用户代码的区别,使用'代码:\n```\n'开头,并以'```'表达运行代码,使用'代码:\n```\n'开头,并以'```'表达展示代码 + - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码:\n```\n'开头,并以'```'表达展示代码。 3. 观察结果: - 查看代码执行结果 @@ -70,13 +70,25 @@ system_prompt: |- 在思考结束后,当你认为可以回答用户问题,那么可以不生成代码,直接生成最终回答给到用户并停止循环。 生成最终回答时,你需要遵循以下规范: - 1. 使用Markdown格式格式化你的输出。 - 2. 若使用了检索工具获取到具体信息并基于这些信息回答问题,则需在回答的对应位置添加引用标记: - - 引用标记的字母和数字需要与检索工具的检索结果一一对应 - - 引用标记格式为'[[对应字母+数字]]',例如:'[[a1]][[b2]][[c3]]' + 1. **Markdown格式要求**: + - 使用标准Markdown语法格式化输出,支持标题、列表、表格、代码块、链接等 + - 展示图片和视频使用链接方式,不需要外套代码块,格式:[链接文本](URL),图片格式:![alt文本](图片URL),视频格式: + - 段落之间使用单个空行分隔,避免多个连续空行 + - 数学公式使用标准Markdown格式:行内公式用 $公式$,块级公式用 $$公式$$ + + 2. **引用标记规范**(仅在使用了检索工具时): + - 引用标记格式必须严格为:`[[字母+数字]]`,例如:`[[a1]]`、`[[b2]]`、`[[c3]]` + - 字母部分必须是单个小写字母(a-e),数字部分必须是整数 + - 引用标记的字母和数字必须与检索工具的检索结果一一对应 - 引用标记应紧跟在相关信息或句子之后,通常放在句末或段落末尾 - - 注意仅添加引用标记,不需要添加链接、参考文献等多余内容 - 3. 
若未使用检索工具,则不添加任何引用标记 + - 多个引用标记可以连续使用,例如:`[[a1]][[b2]]` + - **重要**:仅添加引用标记,不要添加链接、参考文献列表等多余内容 + - 如果检索结果中没有匹配的引用,则不显示该引用标记 + + 3. **格式细节要求**: + - 避免在Markdown中使用HTML标签,优先使用Markdown原生语法 + - 代码块中的代码应保持原始格式,不要添加额外的转义字符 + - 若未使用检索工具,则不添加任何引用标记 注意最后生成的回答要语义连贯,信息清晰,可读性高。 @@ -101,7 +113,7 @@ system_prompt: |- {{ constraint }} ### python代码规范 - 1. 如果认为是需要执行的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾。如果是不需要执行仅用于展示的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾,其中语言类型例如python、java、javascript等; + 1. 如果认为是需要执行的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾。如果是不需要执行仅用于展示的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾,其中语言类型例如python、java、javascript等; 2. 只使用已定义的变量,变量将在多次调用之间持续保持; 3. 使用“print()”函数让下一次的模型调用看到对应变量信息; 4. 正确使用工具的入参,使用关键字参数,不要用字典形式; diff --git a/backend/prompts/managed_system_prompt_template_en.yaml b/backend/prompts/managed_system_prompt_template_en.yaml index 82381e461..9c3a2799c 100644 --- a/backend/prompts/managed_system_prompt_template_en.yaml +++ b/backend/prompts/managed_system_prompt_template_en.yaml @@ -61,8 +61,8 @@ system_prompt: |- - Write code in simple Python - Follow Python coding standards and Python syntax - Call tools correctly according to format specifications - - To distinguish between code execution and displaying user code, use 'Code: \n```\n' to start executing code and '```' to indicate its completion. Use 'Code: \n```\n' to start displaying code and '```' to indicate its completion. - - Note that executed code is not visible to users. If users need to see the code, use 'Code: \n```\n' as the start and '```' to denote displayed code. + - To distinguish between code execution and displaying user code, use 'Code: \n```\n' to start executing code and '```' to indicate its completion. Use 'Code: \n```\n' to start displaying code and '```' to indicate its completion. + - Note that executed code is not visible to users. If users need to see the code, use 'Code: \n```\n' as the start and '```' to denote displayed code. 3. 
Observe Results: - View code execution results @@ -70,13 +70,25 @@ system_prompt: |- After thinking, when you believe you can answer the user's question, you can generate a final answer directly to the user without generating code and stop the loop. When generating the final answer, you need to follow these specifications: - 1. Use Markdown format to format your output. - 2. If you have used retrieval tools to obtain specific information and answer questions based on this information, you need to add reference marks at the corresponding positions in your answer: - - The letters and numbers of the reference marks need to correspond one-to-one with the retrieval results of the retrieval tools - - The reference mark format is '[[corresponding letter+number]]', for example: '[[a1]][[b2]][[c3]]' - - Reference marks should be placed immediately after the relevant information or sentence, usually at the end of the sentence or paragraph - - Note that only reference marks need to be added, no need to add links, references, or other extraneous content - 3. If no retrieval tools are used, do not add any reference marks + 1. **Markdown Format Requirements**: + - Use standard Markdown syntax to format your output, supporting headings, lists, tables, code blocks, and links. + - Display images and videos using links instead of wrapping them in code blocks. Use `[link text](URL)` for links, `![alt text](image URL)` for images, and `` for videos. + - Use a single blank line between paragraphs, avoid multiple consecutive blank lines + - Mathematical formulas use standard Markdown format: inline formulas use $formula$, block formulas use $$formula$$ + + 2. 
**Reference Mark Specifications** (only when retrieval tools are used): + - Reference mark format must strictly be: `[[letter+number]]`, for example: `[[a1]]`, `[[b2]]`, `[[c3]]` + - The letter part must be a single lowercase letter (a-e), the number part must be an integer + - The letters and numbers of reference marks must correspond one-to-one with the retrieval results of retrieval tools + - Reference marks should be placed immediately after relevant information or sentences, usually at the end of sentences or paragraphs + - Multiple reference marks can be used consecutively, for example: `[[a1]][[b2]]` + - **Important**: Only add reference marks, do not add links, reference lists, or other extraneous content + - If there is no matching reference in the retrieval results, do not display that reference mark + + 3. **Format Detail Requirements**: + - Avoid using HTML tags in Markdown, prioritize native Markdown syntax + - Code in code blocks should maintain original format, do not add extra escape characters + - If no retrieval tools are used, do not add any reference marks Note that the final generated answer should be semantically coherent, with clear information and high readability. @@ -101,7 +113,7 @@ system_prompt: |- {{ constraint }} ### Python Code Specifications - 1. If it is considered to be code that needs to be executed, the code content begins with 'code: \n```\n' and ends with '```'. If the code does not need to be executed for display only, the code content begins with 'code:\n```\n', and ends with '```', where language_type can be python, java, javascript, etc; + 1. If it is considered to be code that needs to be executed, the code content begins with 'code: \n```\n' and ends with '```'. If the code does not need to be executed for display only, the code content begins with 'code:\n```\n', and ends with '```', where language_type can be python, java, javascript, etc; 2. Only use defined variables, variables will persist between multiple calls; 3. 
Use "print()" function to let the next model call see corresponding variable information; 4. Use tool input parameters correctly, use keyword arguments, not dictionary format; diff --git a/backend/prompts/manager_system_prompt_template.yaml b/backend/prompts/manager_system_prompt_template.yaml index 453e4919b..8effcd54a 100644 --- a/backend/prompts/manager_system_prompt_template.yaml +++ b/backend/prompts/manager_system_prompt_template.yaml @@ -62,8 +62,8 @@ system_prompt: |- - 用简单的Python编写代码 - 遵循python代码规范和python语法 - 正确调用工具或助手解决问题 - - 考虑到代码执行与展示用户代码的区别,使用'代码:\n```\n'开头,并以'```'表达运行代码,使用'代码:\n```\n'开头,并以'```'表达展示代码 - - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码:\n```\n'开头,并以'```'表达展示代码。 + - 考虑到代码执行与展示用户代码的区别,使用'代码:\n```\n'开头,并以'```'表达运行代码,使用'代码:\n```\n'开头,并以'```'表达展示代码 + - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码:\n```\n'开头,并以'```'表达展示代码。 3. 观察结果: - 查看代码执行结果 @@ -72,13 +72,25 @@ system_prompt: |- 在思考结束后,当你认为可以回答用户问题,那么可以不生成代码,直接生成最终回答给到用户并停止循环。 生成最终回答时,你需要遵循以下规范: - 1. 使用Markdown格式格式化你的输出。 - 2. 若使用了检索工具获取到具体信息并基于这些信息回答问题,则需在回答的对应位置添加引用标记: - - 引用标记的字母和数字需要与检索工具的检索结果一一对应 - - 引用标记格式为'[[对应字母+数字]]',例如:'[[a1]][[b2]][[c3]]' + 1. Markdown格式要求: + - 使用标准Markdown语法格式化输出,支持标题、列表、表格、代码块、链接等 + - 展示图片和视频使用链接方式,不需要外套代码块,格式:[链接文本](URL),图片格式:![alt文本](图片URL),视频格式: + - 段落之间使用单个空行分隔,避免多个连续空行 + - 数学公式使用标准Markdown格式:行内公式用 $公式$,块级公式用 $$公式$$ + + 2. 引用标记规范(仅在使用了检索工具时): + - 引用标记格式必须严格为:`[[字母+数字]]`,例如:`[[a1]]`、`[[b2]]`、`[[c3]]` + - 字母部分必须是单个小写字母(a-e),数字部分必须是整数 + - 引用标记的字母和数字必须与检索工具的检索结果一一对应 - 引用标记应紧跟在相关信息或句子之后,通常放在句末或段落末尾 - - 注意仅添加引用标记,不需要添加链接、参考文献等多余内容 - 3. 若未使用检索工具,则不添加任何引用标记 + - 多个引用标记可以连续使用,例如:`[[a1]][[b2]]` + - **重要**:仅添加引用标记,不要添加链接、参考文献列表等多余内容 + - 如果检索结果中没有匹配的引用,则不显示该引用标记 + + 3. 格式细节要求: + - 避免在Markdown中使用HTML标签,优先使用Markdown原生语法 + - 代码块中的代码应保持原始格式,不要添加额外的转义字符 + - 若未使用检索工具,则不添加任何引用标记 ### 可用资源 你只能使用以下资源,不得使用任何其他工具或助手: @@ -129,7 +141,7 @@ system_prompt: |- {{ constraint }} ### python代码规范 - 1. 
如果认为是需要执行的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾。如果是不需要执行仅用于展示的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾,其中语言类型例如python、java、javascript等; + 1. 如果认为是需要执行的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾。如果是不需要执行仅用于展示的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾,其中语言类型例如python、java、javascript等; 2. 只使用已定义的变量,变量将在多次调用之间持续保持; 3. 使用“print()”函数让下一次的模型调用看到对应变量信息; 4. 正确使用工具/助手的入参,使用关键字参数,不要用字典形式; diff --git a/backend/prompts/manager_system_prompt_template_en.yaml b/backend/prompts/manager_system_prompt_template_en.yaml index 3df4f2665..8da048bfe 100644 --- a/backend/prompts/manager_system_prompt_template_en.yaml +++ b/backend/prompts/manager_system_prompt_template_en.yaml @@ -62,8 +62,8 @@ system_prompt: |- - Write code in simple Python - Follow Python coding standards and Python syntax - Correctly call tools or agents to solve problems - - To distinguish between code execution and displaying user code, use 'Code: \n```\n' to start executing code and '```' to indicate its completion. Use 'Code: \n```\n' to start displaying code and '```' to indicate its completion. - - Note that executed code is not visible to users. If users need to see the code, use 'Code: \n```\n' as the start and '```' to denote displayed code. + - To distinguish between code execution and displaying user code, use 'Code: \n```\n' to start executing code and '```' to indicate its completion. Use 'Code: \n```\n' to start displaying code and '```' to indicate its completion. + - Note that executed code is not visible to users. If users need to see the code, use 'Code: \n```\n' as the start and '```' to denote displayed code. 3. Observe Results: - View code execution results @@ -72,13 +72,25 @@ system_prompt: |- After thinking, when you believe you can answer the user's question, you can generate a final answer directly to the user without generating code and stop the loop. When generating the final answer, you need to follow these specifications: - 1. Use Markdown format to format your output. - 2. 
If you have used retrieval tools to obtain specific information and answer questions based on this information, you need to add reference marks at the corresponding positions in your answer: - - The letters and numbers of the reference marks need to correspond one-to-one with the retrieval results of the retrieval tools - - The reference mark format is '[[corresponding letter+number]]', for example: '[[a1]][[b2]][[c3]]' - - Reference marks should be placed immediately after the relevant information or sentence, usually at the end of the sentence or paragraph - - Note that only reference marks need to be added, no need to add links, references, or other extraneous content - 3. If no retrieval tools are used, do not add any reference marks + 1. **Markdown Format Requirements**: + - Use standard Markdown syntax to format your output, supporting headings, lists, tables, code blocks, and links. + - Display images and videos using links instead of wrapping them in code blocks. Use `[link text](URL)` for links, `![alt text](image URL)` for images, and `` for videos. + - Use a single blank line between paragraphs, avoid multiple consecutive blank lines + - Mathematical formulas use standard Markdown format: inline formulas use $formula$, block formulas use $$formula$$ + + 2. 
**Reference Mark Specifications** (only when retrieval tools are used): + - Reference mark format must strictly be: `[[letter+number]]`, for example: `[[a1]]`, `[[b2]]`, `[[c3]]` + - The letter part must be a single lowercase letter (a-e), the number part must be an integer + - The letters and numbers of reference marks must correspond one-to-one with the retrieval results of retrieval tools + - Reference marks should be placed immediately after relevant information or sentences, usually at the end of sentences or paragraphs + - Multiple reference marks can be used consecutively, for example: `[[a1]][[b2]]` + - **Important**: Only add reference marks, do not add links, reference lists, or other extraneous content + - If there is no matching reference in the retrieval results, do not display that reference mark + + 3. **Format Detail Requirements**: + - Avoid using HTML tags in Markdown, prioritize native Markdown syntax + - Code in code blocks should maintain original format, do not add extra escape characters + - If no retrieval tools are used, do not add any reference marks ### Available Resources You can only use the following resources, and may not use any other tools or agents: @@ -129,7 +141,7 @@ system_prompt: |- {{ constraint }} ### Python Code Specifications - 1. If it is considered to be code that needs to be executed, the code content begins with 'code: \n```\n' and ends with '```'. If the code does not need to be executed for display only, the code content begins with 'code: \n```\n', and ends with '```', where language_type can be python, java, javascript, etc; + 1. If it is considered to be code that needs to be executed, the code content begins with 'code: \n```\n' and ends with '```'. If the code does not need to be executed for display only, the code content begins with 'code: \n```\n', and ends with '```', where language_type can be python, java, javascript, etc; 2. Only use defined variables, variables will persist between multiple calls; 3. 
Use "print()" function to let the next model call see corresponding variable information; 4. Use tool/agent input parameters correctly, use keyword arguments, not dictionary format; diff --git a/backend/prompts/utils/prompt_generate.yaml b/backend/prompts/utils/prompt_generate.yaml index fc102788b..9832ab671 100644 --- a/backend/prompts/utils/prompt_generate.yaml +++ b/backend/prompts/utils/prompt_generate.yaml @@ -52,8 +52,8 @@ FEW_SHOTS_SYSTEM_PROMPT: |- - 用简单的Python编写代码 - 遵循python代码规范和python语法 - 根据格式规范正确调用工具/助手 - - 考虑到代码执行与展示用户代码的区别,使用'代码:\n```\n'开头,并以'```'表达运行代码,使用'代码:\n```\n'开头,并以'```'表达展示代码 - - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码:\n```\n'开头,并以'```'表达展示代码。 + - 考虑到代码执行与展示用户代码的区别,使用'代码:\n```\n'开头,并以'```'表达运行代码,使用'代码:\n```\n'开头,并以'```'表达展示代码 + - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码:\n```\n'开头,并以'```'表达展示代码。 3. 观察结果: - 查看代码执行结果 @@ -61,7 +61,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- 在思考结束后,当Agent认为可以回答用户问题,那么可以不生成代码,直接生成最终回答给到用户并停止循环。 ### python代码规范 - 1. 如果认为是需要执行的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾。如果是不需要执行仅用于展示的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾,其中语言类型例如python、java、javascript等; + 1. 如果认为是需要执行的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾。如果是不需要执行仅用于展示的代码,代码内容以'代码:\n```\n'开头,并以'```'标识符结尾,其中语言类型例如python、java、javascript等; 2. 只使用已定义的变量,变量将在多次调用之间持续保持; 3. 使用“print()”函数让下一次的模型调用看到对应变量信息; 4. 
正确使用工具/助手的入参,使用关键字参数,不要用字典形式; @@ -160,11 +160,12 @@ FEW_SHOTS_SYSTEM_PROMPT: |- middle = [x for x in arr if x == pivot] right = [x for x in arr if x > pivot] return quick_sort(left) + middle + quick_sort(right) - ``` + ``` 观察结果:快速排序的python代码。 思考:我已经获得了快速排序的python代码,现在我将生成最终回答。 快速排序的python代码如下: + 代码: ``` def quick_sort(arr): if len(arr) <= 1: @@ -174,7 +175,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- middle = [x for x in arr if x == pivot] right = [x for x in arr if x > pivot] return quick_sort(left) + middle + quick_sort(right) - ``` + ``` --- diff --git a/backend/prompts/utils/prompt_generate_en.yaml b/backend/prompts/utils/prompt_generate_en.yaml index 8ddf8899c..41db11957 100644 --- a/backend/prompts/utils/prompt_generate_en.yaml +++ b/backend/prompts/utils/prompt_generate_en.yaml @@ -53,8 +53,8 @@ FEW_SHOTS_SYSTEM_PROMPT: |- - Write code in simple Python - Follow Python coding standards and Python syntax - Call tools/assistants correctly according to format specifications - - To distinguish between code execution and displaying user code, use 'Code: \n```\n' to start executing code and '```' to indicate its completion. Use 'Code: \n```\n' to start displaying code and '```' to indicate its completion. - - Note that executed code is not visible to users. If users need to see the code, use 'Code: \n```\n' as the start and '```' to denote displayed code. + - To distinguish between code execution and displaying user code, use 'Code: \n```\n' to start executing code and '```' to indicate its completion. Use 'Code: \n```\n' to start displaying code and '```' to indicate its completion. + - Note that executed code is not visible to users. If users need to see the code, use 'Code: \n```\n' as the start and '```' to denote displayed code. 3. 
Observe Results: - View code execution results @@ -62,7 +62,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- After thinking, when you believe you can answer the user's question, you can generate a final answer directly to the user without generating code and stop the loop. ### Python Code Specifications - 1. If it is considered to be code that needs to be executed, the code content begins with 'Code:\n```\n' and ends with '```'. If the code does not need to be executed for display only, the code content begins with 'Code:\n```\n', and ends with '```', where language_type can be python, java, javascript, etc.; + 1. If it is considered to be code that needs to be executed, the code content begins with 'Code:\n```\n' and ends with '```'. If the code does not need to be executed for display only, the code content begins with 'Code:\n```\n', and ends with '```', where language_type can be python, java, javascript, etc.; 2. Only use defined variables, variables will persist between multiple calls; 3. Use "print()" function to let the next model call see corresponding variable information; 4. Use tool/assistant input parameters correctly, use keyword arguments, not dictionary format; @@ -158,7 +158,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- middle = [x for x in arr if x == pivot] right = [x for x in arr if x > pivot] return quick_sort(left) + middle + quick_sort(right) - ``` + ``` Observe Results: The Python quick sort code. Think: I have obtained the Python quick sort code, now I will generate the final answer. 
@@ -172,7 +172,7 @@ FEW_SHOTS_SYSTEM_PROMPT: |- middle = [x for x in arr if x == pivot] right = [x for x in arr if x > pivot] return quick_sort(left) + middle + quick_sort(right) - ``` + ``` --- diff --git a/backend/runtime_service.py b/backend/runtime_service.py new file mode 100644 index 000000000..faa3d2981 --- /dev/null +++ b/backend/runtime_service.py @@ -0,0 +1,42 @@ +import uvicorn +import logging +import warnings +import asyncio + +from consts.const import APP_VERSION + +warnings.filterwarnings("ignore", category=UserWarning) + +from dotenv import load_dotenv +load_dotenv() + +from apps.runtime_app import app +from utils.logging_utils import configure_logging, configure_elasticsearch_logging +from services.tool_configuration_service import initialize_tools_on_startup + +configure_logging(logging.INFO) +configure_elasticsearch_logging() +logger = logging.getLogger("runtime_service") + + +async def startup_initialization(): + """ + Perform initialization tasks during server startup + """ + logger.info("Starting server initialization...") + logger.info(f"APP version is: {APP_VERSION}") + try: + # Initialize tools on startup - service layer handles detailed logging + await initialize_tools_on_startup() + logger.info("Server initialization completed successfully!") + except Exception as e: + logger.error(f"Server initialization failed: {str(e)}") + # Don't raise the exception to allow server to start even if initialization fails + logger.warning("Server will continue to start despite initialization issues") + + +if __name__ == "__main__": + asyncio.run(startup_initialization()) + uvicorn.run(app, host="0.0.0.0", port=5014, log_level="info") + + diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index b1e9ff813..30bb3d453 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -4,7 +4,7 @@ import os import uuid from collections import deque -from typing import Optional +from typing import Optional, 
Dict from fastapi import Header, Request from fastapi.responses import JSONResponse, StreamingResponse @@ -23,7 +23,7 @@ ExportAndImportDataFormat, MCPInfo, ToolInstanceInfoRequest, - ToolSourceEnum + ToolSourceEnum, ModelConnectStatusEnum ) from database.agent_db import ( create_agent, @@ -801,29 +801,140 @@ async def list_all_agent_info_impl(tenant_id: str) -> list[dict]: try: agent_list = query_all_agent_info_by_tenant_id(tenant_id=tenant_id) - simple_agent_list = [] + model_cache: Dict[int, Optional[dict]] = {} + enriched_agents: list[dict] = [] + for agent in agent_list: - # check agent is available if not agent["enabled"]: continue + + unavailable_reasons: list[str] = [] + tool_info = search_tools_for_sub_agent( agent_id=agent["agent_id"], tenant_id=tenant_id) - tool_id_list = [tool["tool_id"] for tool in tool_info] - is_available = all(check_tool_is_available(tool_id_list)) + tool_id_list = [tool["tool_id"] + for tool in tool_info if tool.get("tool_id") is not None] + if tool_id_list: + tool_statuses = check_tool_is_available(tool_id_list) + if not all(tool_statuses): + unavailable_reasons.append("tool_unavailable") + + model_reasons = _collect_model_availability_reasons( + agent=agent, + tenant_id=tenant_id, + model_cache=model_cache + ) + unavailable_reasons.extend(model_reasons) + + # Preserve the raw data so we can adjust availability for duplicates + enriched_agents.append({ + "raw_agent": agent, + "unavailable_reasons": unavailable_reasons, + }) + + # Handle duplicate name/display_name: keep the earliest created agent available, + # mark later ones as unavailable due to duplication. 
+ _apply_duplicate_name_availability_rules(enriched_agents) + + simple_agent_list: list[dict] = [] + for entry in enriched_agents: + agent = entry["raw_agent"] + unavailable_reasons = list(dict.fromkeys(entry["unavailable_reasons"])) simple_agent_list.append({ "agent_id": agent["agent_id"], "name": agent["name"] if agent["name"] else agent["display_name"], "display_name": agent["display_name"] if agent["display_name"] else agent["name"], "description": agent["description"], - "is_available": is_available + "is_available": len(unavailable_reasons) == 0, + "unavailable_reasons": unavailable_reasons }) + return simple_agent_list except Exception as e: logger.error(f"Failed to query all agent info: {str(e)}") raise ValueError(f"Failed to query all agent info: {str(e)}") +def _apply_duplicate_name_availability_rules(enriched_agents: list[dict]) -> None: + """ + For agents that share the same name or display_name, only the earliest created + agent should remain available (if it has no other unavailable reasons). + All later-created agents in the same group become unavailable due to duplication. + """ + # Group by name and display_name + name_groups: dict[str, list[dict]] = {} + display_name_groups: dict[str, list[dict]] = {} + + for entry in enriched_agents: + agent = entry["raw_agent"] + name = agent.get("name") + if name: + name_groups.setdefault(name, []).append(entry) + + display_name = agent.get("display_name") + if display_name: + display_name_groups.setdefault(display_name, []).append(entry) + + def _mark_duplicates(groups: dict[str, list[dict]], reason_key: str) -> None: + for entries in groups.values(): + if len(entries) <= 1: + continue + + # Sort by create_time ascending so the earliest created agent comes first + sorted_entries = sorted( + entries, + key=lambda e: e["raw_agent"].get("create_time"), + ) + + # The first (earliest) agent keeps its current availability; + # subsequent agents are marked as duplicates. 
+ for duplicate_entry in sorted_entries[1:]: + duplicate_entry["unavailable_reasons"].append(reason_key) + + _mark_duplicates(name_groups, "duplicate_name") + _mark_duplicates(display_name_groups, "duplicate_display_name") + + +def _collect_model_availability_reasons(agent: dict, tenant_id: str, model_cache: Dict[int, Optional[dict]]) -> list[str]: + """ + Build a list of reasons related to model availability issues for a given agent. + """ + reasons: list[str] = [] + reasons.extend(_check_single_model_availability( + model_id=agent.get("model_id"), + tenant_id=tenant_id, + model_cache=model_cache, + reason_key="model_unavailable" + )) + + return reasons + + +def _check_single_model_availability( + model_id: int | None, + tenant_id: str, + model_cache: Dict[int, Optional[dict]], + reason_key: str, +) -> list[str]: + if not model_id: + return [] + + if model_id not in model_cache: + model_cache[model_id] = get_model_by_model_id(model_id, tenant_id) + + model_info = model_cache.get(model_id) + if not model_info: + return [reason_key] + + connect_status = ModelConnectStatusEnum.get_value( + model_info.get("connect_status")) + if connect_status != ModelConnectStatusEnum.AVAILABLE.value: + return [reason_key] + + return [] + + def insert_related_agent_impl(parent_agent_id, child_agent_id, tenant_id): # search the agent by bfs, check if there is a circular call search_list = deque([child_agent_id]) diff --git a/backend/services/file_management_service.py b/backend/services/file_management_service.py index 84b68fe90..1e1175228 100644 --- a/backend/services/file_management_service.py +++ b/backend/services/file_management_service.py @@ -20,7 +20,7 @@ list_files ) from utils.attachment_utils import convert_image_to_text, convert_long_text_to_text -from services.elasticsearch_service import ElasticSearchService, get_es_core +from services.vectordatabase_service import ElasticSearchService, get_vector_db_core from utils.prompt_template_utils import 
get_file_processing_messages_template from utils.file_management_utils import save_upload_file @@ -79,8 +79,8 @@ async def upload_files_impl(destination: str, file: List[UploadFile], folder: st # Resolve filename conflicts against existing KB documents by renaming (e.g., name -> name_1) if index_name: try: - es_core = get_es_core() - existing = await ElasticSearchService.list_files(index_name, include_chunks=False, es_core=es_core) + vdb_core = get_vector_db_core() + existing = await ElasticSearchService.list_files(index_name, include_chunks=False, vdb_core=vdb_core) existing_files = existing.get( "files", []) if isinstance(existing, dict) else [] # Prefer 'file' field; fall back to 'filename' if present diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 778ee2b88..84936f393 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -1,5 +1,5 @@ import logging -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST from consts.model import ModelConnectStatusEnum @@ -25,7 +25,7 @@ sort_models_by_id, ) from utils.memory_utils import build_memory_config as build_memory_config_for_tenant -from services.elasticsearch_service import get_es_core +from services.vectordatabase_service import get_vector_db_core from nexent.memory.memory_service import clear_model_memories logger = logging.getLogger("model_management_service") @@ -244,12 +244,12 @@ async def delete_model_for_tenant(user_id: str, tenant_id: str, display_name: st # Best-effort memory cleanup using the fetched variants try: - es_core = get_es_core() + vdb_core = get_vector_db_core() base_memory_config = build_memory_config_for_tenant(tenant_id) for t, m in models_by_type.items(): try: await clear_model_memories( - es_core=es_core, + vdb_core=vdb_core, model_repo=m.get("model_repo", ""), 
model_name=m.get("model_name", ""), embedding_dims=int(m.get("max_tokens") or 0), diff --git a/backend/services/prompt_service.py b/backend/services/prompt_service.py index e0bae7870..5b9d57e0e 100644 --- a/backend/services/prompt_service.py +++ b/backend/services/prompt_service.py @@ -2,17 +2,17 @@ import logging import queue import threading +from typing import Optional, List from jinja2 import StrictUndefined, Template from smolagents import OpenAIServerModel -from consts.const import LANGUAGE, MODEL_CONFIG_MAPPING, MESSAGE_ROLE, THINK_END_PATTERN, THINK_START_PATTERN +from consts.const import LANGUAGE, MESSAGE_ROLE, THINK_END_PATTERN, THINK_START_PATTERN from consts.model import AgentInfoRequest -from database.agent_db import update_agent, query_sub_agents_id_list, search_agent_info_by_agent_id +from database.agent_db import update_agent, search_agent_info_by_agent_id from database.model_management_db import get_model_by_model_id from database.tool_db import query_tools_by_ids -from services.agent_service import get_enable_tool_id_by_agent_id -from utils.config_utils import tenant_config_manager, get_model_name_from_config +from utils.config_utils import get_model_name_from_config from utils.prompt_template_utils import get_prompt_generate_prompt_template # Configure logging @@ -34,7 +34,7 @@ def _process_thinking_tokens(new_token: str, is_thinking: bool, token_join: list """ # Handle thinking mode if is_thinking: - return not (THINK_END_PATTERN in new_token) + return THINK_END_PATTERN not in new_token # Handle start of thinking if THINK_START_PATTERN in new_token: @@ -98,14 +98,16 @@ def call_llm_for_system_prompt(model_id: int, user_prompt: str, system_prompt: s raise e -def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: str, user_id: str, tenant_id: str, language: str): +def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: str, user_id: str, tenant_id: str, language: str, tool_ids: Optional[List[int]] 
= None, sub_agent_ids: Optional[List[int]] = None): for system_prompt in generate_and_save_system_prompt_impl( agent_id=agent_id, model_id=model_id, task_description=task_description, user_id=user_id, tenant_id=tenant_id, - language=language + language=language, + tool_ids=tool_ids, + sub_agent_ids=sub_agent_ids ): # SSE format, each message ends with \n\n yield f"data: {json.dumps({'success': True, 'data': system_prompt}, ensure_ascii=False)}\n\n" @@ -116,17 +118,35 @@ def generate_and_save_system_prompt_impl(agent_id: int, task_description: str, user_id: str, tenant_id: str, - language: str): - # Get description of tool and agent - # In create mode (agent_id=0), return empty lists - if agent_id == 0: + language: str, + tool_ids: Optional[List[int]] = None, + sub_agent_ids: Optional[List[int]] = None): + # Get description of tool and agent from frontend-provided IDs + # Frontend always provides tool_ids and sub_agent_ids (could be empty arrays) + + # Handle tool IDs + if tool_ids and len(tool_ids) > 0: + tool_info_list = query_tools_by_ids(tool_ids) + logger.debug(f"Using frontend-provided tool IDs: {tool_ids}") + else: tool_info_list = [] + logger.debug("No tools selected (empty tool_ids list)") + + # Handle sub-agent IDs + if sub_agent_ids and len(sub_agent_ids) > 0: sub_agent_info_list = [] + for sub_agent_id in sub_agent_ids: + try: + sub_agent_info = search_agent_info_by_agent_id( + agent_id=sub_agent_id, tenant_id=tenant_id) + sub_agent_info_list.append(sub_agent_info) + except Exception as e: + logger.warning( + f"Failed to get sub-agent info for agent_id {sub_agent_id}: {str(e)}") + logger.debug(f"Using frontend-provided sub-agent IDs: {sub_agent_ids}") else: - tool_info_list = get_enabled_tool_description_for_generate_prompt( - tenant_id=tenant_id, agent_id=agent_id) - sub_agent_info_list = get_enabled_sub_agent_description_for_generate_prompt( - tenant_id=tenant_id, agent_id=agent_id) + sub_agent_info_list = [] + logger.debug("No sub-agents selected 
(empty sub_agent_ids list)") # 1. Real-time streaming push final_results = {"duty": "", "constraint": "", "few_shots": "", "agent_var_name": "", "agent_display_name": "", @@ -292,27 +312,3 @@ def join_info_for_generate_system_prompt(prompt_for_generate, sub_agent_info_lis "assistant_description": assistant_description }) return content - - -def get_enabled_tool_description_for_generate_prompt(agent_id: int, tenant_id: str): - # Get tool information - logger.info("Fetching tool instances") - tool_id_list = get_enable_tool_id_by_agent_id( - agent_id=agent_id, tenant_id=tenant_id) - tool_info_list = query_tools_by_ids(tool_id_list) - return tool_info_list - - -def get_enabled_sub_agent_description_for_generate_prompt(agent_id: int, tenant_id: str): - logger.info("Fetching sub-agents information") - - sub_agent_id_list = query_sub_agents_id_list( - main_agent_id=agent_id, tenant_id=tenant_id) - - sub_agent_info_list = [] - for sub_agent_id in sub_agent_id_list: - sub_agent_info = search_agent_info_by_agent_id( - agent_id=sub_agent_id, tenant_id=tenant_id) - - sub_agent_info_list.append(sub_agent_info) - return sub_agent_info_list diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index ba1ff5628..5aebaf614 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -20,10 +20,10 @@ query_all_tools, query_tool_instances_by_id, update_tool_table_from_scan_tool_list, - search_last_tool_instance_by_tool_id + search_last_tool_instance_by_tool_id, ) from database.user_tenant_db import get_all_tenant_ids -from services.elasticsearch_service import get_embedding_model, elastic_core +from services.vectordatabase_service import get_embedding_model, get_vector_db_core from services.tenant_config_service import get_selected_knowledge_list logger = logging.getLogger("tool_configuration_service") @@ -605,11 +605,12 @@ def _validate_local_tool( index_names = 
[knowledge_info.get("index_name") for knowledge_info in knowledge_info_list] embedding_model = get_embedding_model(tenant_id=tenant_id) + vdb_core = get_vector_db_core() params = { **instantiation_params, 'index_names': index_names, - 'es_core': elastic_core, - 'embedding_model': embedding_model + 'vdb_core': vdb_core, + 'embedding_model': embedding_model, } tool_instance = tool_class(**params) else: diff --git a/backend/services/elasticsearch_service.py b/backend/services/vectordatabase_service.py similarity index 82% rename from backend/services/elasticsearch_service.py rename to backend/services/vectordatabase_service.py index 5193c2e4e..bab6bd284 100644 --- a/backend/services/elasticsearch_service.py +++ b/backend/services/vectordatabase_service.py @@ -20,9 +20,10 @@ from fastapi.responses import StreamingResponse from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, BaseEmbedding from nexent.core.nlp.tokenizer import calculate_term_weights +from nexent.vector_database.base import VectorDatabaseCore from nexent.vector_database.elasticsearch_core import ElasticSearchCore -from consts.const import ES_API_KEY, ES_HOST, LANGUAGE +from consts.const import ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType from database.attachment_db import delete_file from database.knowledge_db import ( create_knowledge_record, @@ -34,32 +35,46 @@ from utils.config_utils import tenant_config_manager, get_model_name_from_config from utils.file_management_utils import get_all_files_status, get_file_size +ALLOWED_CHUNK_FIELDS = {"filename", + "path_or_url", "content", "create_time", "id"} + # Configure logging -logger = logging.getLogger("elasticsearch_service") +logger = logging.getLogger("vectordatabase_service") +def get_vector_db_core( + db_type: VectorDatabaseType = VectorDatabaseType.ELASTICSEARCH, +) -> VectorDatabaseCore: + """ + Return a VectorDatabaseCore implementation based on the requested type. 
+ Args: + db_type: Target vector database provider. Defaults to Elasticsearch. -# Old keyword-based summary method removed - replaced with Map-Reduce approach -# See utils/document_vector_utils.py for new implementation + Returns: + VectorDatabaseCore: Concrete vector database implementation. + Raises: + ValueError: If the requested database type is not supported. + """ + if db_type == VectorDatabaseType.ELASTICSEARCH: + return ElasticSearchCore( + host=ES_HOST, + api_key=ES_API_KEY, + verify_certs=False, + ssl_show_warn=False, + ) -# Initialize ElasticSearchCore instance with HTTPS support -elastic_core = ElasticSearchCore( - host=ES_HOST, - api_key=ES_API_KEY, - verify_certs=False, - ssl_show_warn=False, -) + raise ValueError(f"Unsupported vector database type: {db_type}") -def check_knowledge_base_exist_impl(index_name: str, es_core, user_id: str, tenant_id: str) -> dict: +def check_knowledge_base_exist_impl(index_name: str, vdb_core: VectorDatabaseCore, user_id: str, tenant_id: str) -> dict: """ Check knowledge base existence and handle orphan cases Args: index_name: Name of the index to check - es_core: Elasticsearch core instance + vdb_core: Elasticsearch core instance user_id: Current user ID tenant_id: Current tenant ID @@ -67,7 +82,7 @@ def check_knowledge_base_exist_impl(index_name: str, es_core, user_id: str, tena dict: Status information about the knowledge base """ # 1. Check index existence in ES and corresponding record in PG - es_exists = es_core.client.indices.exists(index=index_name) + es_exists = vdb_core.check_index_exists(index_name) pg_record = get_knowledge_record({"index_name": index_name}) # Case A: Orphan in ES only (exists in ES, missing in PG) @@ -75,7 +90,7 @@ def check_knowledge_base_exist_impl(index_name: str, es_core, user_id: str, tena logger.warning( f"Detected orphan knowledge base '{index_name}' – present in ES, absent in PG. 
Deleting ES index only.") try: - es_core.delete_index(index_name) + vdb_core.delete_index(index_name) # Clean up Redis records related to this index to avoid stale tasks try: redis_service = get_redis_service() @@ -121,11 +136,6 @@ def check_knowledge_base_exist_impl(index_name: str, es_core, user_id: str, tena return {"status": "exists_in_other_tenant"} -def get_es_core(): - # ensure embedding model is latest - return elastic_core - - def get_embedding_model(tenant_id: str): # Get the tenant config model_config = tenant_config_manager.get_model_config( @@ -144,7 +154,7 @@ def get_embedding_model(tenant_id: str): class ElasticSearchService: @staticmethod - async def full_delete_knowledge_base(index_name: str, es_core: ElasticSearchCore, user_id: str): + async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCore, user_id: str): """ Completely delete a knowledge base, including its index, associated files in MinIO, and all related records in Redis and PostgreSQL. @@ -157,7 +167,7 @@ async def full_delete_knowledge_base(index_name: str, es_core: ElasticSearchCore f"Step 1/4: Retrieving file list for index: {index_name}") try: file_list_result = await ElasticSearchService.list_files(index_name, include_chunks=False, - es_core=es_core) + vdb_core=vdb_core) files_to_delete = file_list_result.get("files", []) logger.debug( f"Found {len(files_to_delete)} files to delete from MinIO for index '{index_name}'.") @@ -209,7 +219,7 @@ async def full_delete_knowledge_base(index_name: str, es_core: ElasticSearchCore # 3. Delete Elasticsearch index and its DB record logger.debug( f"Step 3/4: Deleting Elasticsearch index '{index_name}' and its database record.") - delete_index_result = await ElasticSearchService.delete_index(index_name, es_core, user_id) + delete_index_result = await ElasticSearchService.delete_index(index_name, vdb_core, user_id) # 4. 
Clean up Redis records related to this knowledge base logger.debug( @@ -262,17 +272,17 @@ def create_index( description="Name of the index to create"), embedding_dim: Optional[int] = Query( None, description="Dimension of the embedding vectors"), - es_core: ElasticSearchCore = Depends(get_es_core), + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), user_id: Optional[str] = Body( None, description="ID of the user creating the knowledge base"), tenant_id: Optional[str] = Body( None, description="ID of the tenant creating the knowledge base"), ): try: - if es_core.client.indices.exists(index=index_name): + if vdb_core.check_index_exists(index_name): raise Exception(f"Index {index_name} already exists") embedding_model = get_embedding_model(tenant_id) - success = es_core.create_vector_index(index_name, embedding_dim=embedding_dim or ( + success = vdb_core.create_index(index_name, embedding_dim=embedding_dim or ( embedding_model.embedding_dim if embedding_model else 1024)) if not success: raise Exception(f"Failed to create index {index_name}") @@ -289,14 +299,14 @@ def create_index( async def delete_index( index_name: str = Path(..., description="Name of the index to delete"), - es_core: ElasticSearchCore = Depends(get_es_core), + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), user_id: Optional[str] = Body( None, description="ID of the user delete the knowledge base"), ): try: # 1. Get list of files from the index try: - files_to_delete = await ElasticSearchService.list_files(index_name, es_core=es_core) + files_to_delete = await ElasticSearchService.list_files(index_name, vdb_core=vdb_core) if files_to_delete and files_to_delete.get("files"): # 2. Delete files from MinIO storage for file_info in files_to_delete["files"]: @@ -312,7 +322,7 @@ async def delete_index( f"Error deleting associated files from MinIO for index {index_name}: {str(e)}") # 3. 
Delete the index in Elasticsearch - success = es_core.delete_index(index_name) + success = vdb_core.delete_index(index_name) if not success: # Even if deletion fails, we proceed to database record cleanup logger.warning( @@ -342,7 +352,7 @@ def list_indices( description="ID of the tenant listing the knowledge base"), user_id: str = Body( description="ID of the user listing the knowledge base"), - es_core: ElasticSearchCore = Depends(get_es_core) + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) ): """ List all indices that the current user has permissions to access. @@ -353,12 +363,12 @@ def list_indices( include_stats: Whether to include index stats tenant_id: ID of the tenant listing the knowledge base user_id: ID of the user listing the knowledge base - es_core: ElasticSearchCore instance + vdb_core: VectorDatabaseCore instance Returns: Dict[str, Any]: A dictionary containing the list of indices and the count. """ - all_indices_list = es_core.get_user_indices(pattern) + all_indices_list = vdb_core.get_user_indices(pattern) db_record = get_knowledge_info_by_tenant_id(tenant_id=tenant_id) @@ -385,7 +395,7 @@ def list_indices( if include_stats: stats_info = [] if filtered_indices_list: - indice_stats = es_core.get_index_stats(filtered_indices_list) + indice_stats = vdb_core.get_indices_detail(filtered_indices_list) for index_name in filtered_indices_list: index_stats = indice_stats.get(index_name, {}) stats_info.append({ @@ -401,80 +411,13 @@ def list_indices( return response - @staticmethod - def get_index_name( - index_name: str = Path(..., description="Name of the index"), - es_core: ElasticSearchCore = Depends(get_es_core) - ): - """ - Get detailed information about the index, including statistics, field mappings, file list, and processing - information - - Args: - index_name: Index name - es_core: ElasticSearchCore instance - - Returns: - Dictionary containing detailed index information - """ - try: - # Get all the info in one combined response - 
stats = es_core.get_index_stats([index_name]) - mappings = es_core.get_index_mapping([index_name]) - - # Check if stats and mappings are valid - if stats and index_name in stats: - index_stats = stats[index_name] - else: - logger.error(f"404: Index {index_name} not found in stats") - index_stats = {} - - if mappings and index_name in mappings: - fields = mappings[index_name] - else: - logger.error(f"404: Index {index_name} not found in mappings:") - fields = [] - - # Check if base_info exists in stats - search_performance = {} - if index_stats and "base_info" in index_stats: - base_info = index_stats["base_info"] - search_performance = index_stats.get("search_performance", {}) - else: - logger.error(f"404: Index {index_name} may not be created yet") - base_info = { - "doc_count": 0, - "unique_sources_count": 0, - "store_size": "0", - "process_source": "Unknown", - "embedding_model": "Unknown", - } - - return { - "base_info": base_info, - "search_performance": search_performance, - "fields": fields - } - except Exception as e: - error_msg = str(e) - # Check if it's an ElasticSearch connection issue - if "503" in error_msg or "search_phase_execution_exception" in error_msg: - raise Exception( - f"ElasticSearch service unavailable for index {index_name}: {error_msg}") - elif "ApiError" in error_msg: - raise Exception( - f"ElasticSearch API error for index {index_name}: {error_msg}") - else: - raise Exception( - f"Error getting info for index {index_name}: {error_msg}") - @staticmethod def index_documents( embedding_model: BaseEmbedding, index_name: str = Path(..., description="Name of the index"), data: List[Dict[str, Any] ] = Body(..., description="Document List to process"), - es_core: ElasticSearchCore = Depends(get_es_core) + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) ): """ Index documents and create vector embeddings, create index if it doesn't exist @@ -483,7 +426,7 @@ def index_documents( embedding_model: Optional embedding model to use for 
generating document vectors index_name: Index name data: List containing document data to be indexed - es_core: ElasticSearchCore instance + vdb_core: VectorDatabaseCore instance Returns: IndexingResponse object containing indexing result information @@ -493,10 +436,10 @@ def index_documents( raise Exception("Index name is required") # Create index if needed (ElasticSearchCore will handle embedding_dim automatically) - if not es_core.client.indices.exists(index=index_name): + if not vdb_core.check_index_exists(index_name): try: ElasticSearchService.create_index( - index_name, es_core=es_core) + index_name, vdb_core=vdb_core) logger.info(f"Created new index {index_name}") except Exception as create_error: raise Exception( @@ -565,7 +508,7 @@ def index_documents( # Index documents (use default batch_size and content_field) try: - total_indexed = es_core.index_documents( + total_indexed = vdb_core.vectorize_documents( index_name=index_name, embedding_model=embedding_model, documents=documents, @@ -592,7 +535,7 @@ async def list_files( index_name: str = Path(..., description="Name of the index"), include_chunks: bool = Query( False, description="Whether to include text chunks for each file"), - es_core: ElasticSearchCore = Depends(get_es_core) + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) ): """ Get file list for the specified index, including files that are not yet stored in ES @@ -600,7 +543,7 @@ async def list_files( Args: index_name: Name of the index include_chunks: Whether to include text chunks for each file - es_core: ElasticSearchCore instance + vdb_core: VectorDatabaseCore instance Returns: Dictionary containing file list @@ -608,7 +551,7 @@ async def list_files( try: files = [] # Get existing files from ES - existing_files = es_core.get_file_list_with_details(index_name) + existing_files = vdb_core.get_documents_detail(index_name) # Get unique celery files list and the status of each file celery_task_files = await 
get_all_files_status(index_name) @@ -632,7 +575,8 @@ async def list_files( 'file_size': file_info.get('file_size', 0), 'create_time': int(utc_create_timestamp * 1000), 'status': "COMPLETED", - 'latest_task_id': '' + 'latest_task_id': '', + 'chunk_count': file_info.get('chunk_count', 0) } files.append(file_data) @@ -690,13 +634,13 @@ async def list_files( # Initialize chunks for all files for file_data in files: file_data['chunks'] = [] - file_data['chunk_count'] = 0 + file_data['chunk_count'] = file_data.get('chunk_count', 0) if msearch_body: try: - msearch_responses = es_core.client.msearch( + msearch_responses = vdb_core.multi_search( body=msearch_body, - index=index_name + index_name=index_name ) for i, file_path in enumerate(completed_files_map.keys()): @@ -727,7 +671,7 @@ async def list_files( else: for file_data in files: file_data['chunks'] = [] - file_data['chunk_count'] = 0 + file_data['chunk_count'] = file_data.get('chunk_count', 0) return {"files": files} @@ -740,29 +684,29 @@ def delete_documents( index_name: str = Path(..., description="Name of the index"), path_or_url: str = Query(..., description="Path or URL of documents to delete"), - es_core: ElasticSearchCore = Depends(get_es_core) + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) ): # 1. Delete ES documents - deleted_count = es_core.delete_documents_by_path_or_url( + deleted_count = vdb_core.delete_documents( index_name, path_or_url) # 2. 
Delete MinIO file minio_result = delete_file(path_or_url) return {"status": "success", "deleted_es_count": deleted_count, "deleted_minio": minio_result.get("success")} @staticmethod - def health_check(es_core: ElasticSearchCore = Depends(get_es_core)): + def health_check(vdb_core: VectorDatabaseCore = Depends(get_vector_db_core)): """ Check the health status of the API and Elasticsearch Args: - es_core: ElasticSearchCore instance + vdb_core: VectorDatabaseCore instance Returns: Response containing health status information """ try: # Try to list indices as a health check - indices = es_core.get_user_indices() + indices = vdb_core.get_user_indices() return { "status": "healthy", "elasticsearch": "connected", @@ -776,8 +720,7 @@ async def summary_index_name(self, ..., description="Name of the index to get documents from"), batch_size: int = Query( 1000, description="Number of documents to retrieve per batch"), - es_core: ElasticSearchCore = Depends( - get_es_core), + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), user_id: Optional[str] = Body( None, description="ID of the user delete the knowledge base"), tenant_id: Optional[str] = Body( @@ -797,7 +740,8 @@ async def summary_index_name(self, Args: index_name: Name of the index to summarize batch_size: Number of documents to sample (default: 1000) - es_core: ElasticSearchCore instance + vdb_core: VectorDatabaseCore instance + user_id: ID of the user delete the knowledge base tenant_id: ID of the tenant language: Language of the summary (default: 'zh') model_id: Model ID for LLM summarization @@ -819,32 +763,45 @@ async def summary_index_name(self, # Use new Map-Reduce approach sample_count = min(batch_size // 5, 200) # Sample reasonable number of documents - # Step 1: Get documents and calculate embeddings - document_samples, doc_embeddings = process_documents_for_clustering( - index_name=index_name, - es_core=es_core, - sample_doc_count=sample_count - ) - - if not document_samples: - raise Exception("No 
documents found in index.") - - # Step 2: Cluster documents - clusters = kmeans_cluster_documents(doc_embeddings, k=None) - - # Step 3: Map-Reduce summarization - cluster_summaries = summarize_clusters_map_reduce( - document_samples=document_samples, - clusters=clusters, - language=language, - doc_max_words=100, - cluster_max_words=150, - model_id=model_id, - tenant_id=tenant_id - ) + # Define a helper function to run all blocking operations in a thread pool + def _generate_summary_sync(): + """Synchronous function that performs all blocking operations""" + # Step 1: Get documents and calculate embeddings + document_samples, doc_embeddings = process_documents_for_clustering( + index_name=index_name, + vdb_core=vdb_core, + sample_doc_count=sample_count + ) + + if not document_samples: + raise Exception("No documents found in index.") + + # Step 2: Cluster documents (CPU-intensive operation) + clusters = kmeans_cluster_documents(doc_embeddings, k=None) + + # Step 3: Map-Reduce summarization (contains blocking LLM calls) + cluster_summaries = summarize_clusters_map_reduce( + document_samples=document_samples, + clusters=clusters, + language=language, + doc_max_words=100, + cluster_max_words=150, + model_id=model_id, + tenant_id=tenant_id + ) + + # Step 4: Merge into final summary + final_summary = merge_cluster_summaries(cluster_summaries) + return final_summary - # Step 4: Merge into final summary - final_summary = merge_cluster_summaries(cluster_summaries) + # Run blocking operations in a thread pool to avoid blocking the event loop + # Use get_running_loop() for better compatibility with modern asyncio + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # Fallback for edge cases + loop = asyncio.get_event_loop() + final_summary = await loop.run_in_executor(None, _generate_summary_sync) # Stream the result async def generate_summary(): @@ -872,7 +829,7 @@ def get_random_documents( description="Name of the index to get documents from"), batch_size: 
int = Query( 1000, description="Maximum number of documents to retrieve"), - es_core: ElasticSearchCore = Depends(get_es_core) + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) ): """ Get random sample of documents from the specified index @@ -880,15 +837,14 @@ def get_random_documents( Args: index_name: Name of the index to get documents from batch_size: Maximum number of documents to retrieve, default 1000 - es_core: ElasticSearchCore instance + vdb_core: VectorDatabaseCore instance Returns: Dictionary containing total count and sampled documents """ try: # Get total document count - count_response = es_core.client.count(index=index_name) - total_docs = count_response['count'] + total_docs = vdb_core.count_documents(index_name) # Construct the random sampling query using random_score query = { @@ -906,9 +862,9 @@ def get_random_documents( } # Execute the query - response = es_core.client.search( - index=index_name, - body=query + response = vdb_core.search( + index_name=index_name, + query=query ) # Extract and process the sampled documents @@ -981,3 +937,62 @@ def get_summary(index_name: str = Path(..., description="Name of the index to ge except Exception as e: error_msg = f"Failed to get summary: {str(e)}" raise Exception(error_msg) + + @staticmethod + def get_index_chunks( + index_name: str, + page: Optional[int] = None, + page_size: Optional[int] = None, + path_or_url: Optional[str] = None, + vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), + ): + """ + Retrieve chunk records for the specified index with optional pagination. 
+ + Args: + index_name: Name of the index to query + page: Page number (1-based) when paginating + page_size: Page size when paginating + path_or_url: Optional document filter + vdb_core: VectorDatabaseCore instance + + Returns: + Dictionary containing status, chunk list, total, and pagination metadata + """ + try: + result = vdb_core.get_index_chunks( + index_name, + page=page, + page_size=page_size, + path_or_url=path_or_url, + ) + raw_chunks = result.get("chunks", []) + total = result.get("total", len(raw_chunks)) + result_page = result.get("page", page) + result_page_size = result.get("page_size", page_size) + + filtered_chunks: List[Any] = [] + for chunk in raw_chunks: + if isinstance(chunk, dict): + filtered_chunks.append( + { + field: chunk.get(field) + for field in ALLOWED_CHUNK_FIELDS + if field in chunk + } + ) + else: + filtered_chunks.append(chunk) + + return { + "status": "success", + "message": f"Successfully retrieved {len(filtered_chunks)} chunks from index {index_name}", + "chunks": filtered_chunks, + "total": total, + "page": result_page, + "page_size": result_page_size + } + except Exception as e: + error_msg = f"Error retrieving chunks from index {index_name}: {str(e)}" + logger.error(error_msg) + raise Exception(error_msg) diff --git a/backend/utils/document_vector_utils.py b/backend/utils/document_vector_utils.py index d470071a5..7d8e5b112 100644 --- a/backend/utils/document_vector_utils.py +++ b/backend/utils/document_vector_utils.py @@ -13,24 +13,29 @@ from typing import Dict, List, Optional, Tuple import numpy as np -import yaml from jinja2 import Template, StrictUndefined +from nexent.vector_database.base import VectorDatabaseCore from sklearn.cluster import KMeans from sklearn.metrics import silhouette_score from sklearn.metrics.pairwise import cosine_similarity from consts.const import LANGUAGE +from utils.prompt_template_utils import ( + get_document_summary_prompt_template, + get_cluster_summary_reduce_prompt_template, + 
get_cluster_summary_agent_prompt_template +) logger = logging.getLogger("document_vector_utils") -def get_documents_from_es(index_name: str, es_core, sample_doc_count: int = 200) -> Dict[str, Dict]: +def get_documents_from_es(index_name: str, vdb_core: VectorDatabaseCore, sample_doc_count: int = 200) -> Dict[str, Dict]: """ Get document samples from Elasticsearch, aggregated by path_or_url Args: index_name: Name of the index to query - es_core: ElasticSearchCore instance + vdb_core: VectorDatabaseCore instance sample_doc_count: Number of documents to sample Returns: @@ -51,7 +56,7 @@ def get_documents_from_es(index_name: str, es_core, sample_doc_count: int = 200) } logger.info(f"Fetching unique documents from index {index_name}") - agg_response = es_core.client.search(index=index_name, body=agg_query) + agg_response = vdb_core.search(index_name=index_name, query=agg_query) all_documents = agg_response['aggregations']['unique_documents']['buckets'] if not all_documents: @@ -89,7 +94,7 @@ def get_documents_from_es(index_name: str, es_core, sample_doc_count: int = 200) ] } - chunks_response = es_core.client.search(index=index_name, body=chunks_query) + chunks_response = vdb_core.search(index_name=index_name, query=chunks_query) chunks = [hit['_source'] for hit in chunks_response['hits']['hits']] # Build document object @@ -444,13 +449,13 @@ def kmeans_cluster_documents(doc_embeddings: Dict[str, np.ndarray], k: Optional[ raise Exception(f"Failed to cluster documents: {str(e)}") -def process_documents_for_clustering(index_name: str, es_core, sample_doc_count: int = 200) -> Tuple[Dict[str, Dict], Dict[str, np.ndarray]]: +def process_documents_for_clustering(index_name: str, vdb_core, sample_doc_count: int = 200) -> Tuple[Dict[str, Dict], Dict[str, np.ndarray]]: """ Complete workflow: Get documents from ES and calculate their embeddings Args: index_name: Name of the index to query - es_core: ElasticSearchCore instance + vdb_core: VectorDatabaseCore instance
sample_doc_count: Number of documents to sample Returns: @@ -458,7 +463,7 @@ def process_documents_for_clustering(index_name: str, es_core, sample_doc_count: """ try: # Step 1: Get documents from ES - document_samples = get_documents_from_es(index_name, es_core, sample_doc_count) + document_samples = get_documents_from_es(index_name, vdb_core, sample_doc_count) if not document_samples: logger.warning("No documents retrieved from ES") @@ -547,14 +552,8 @@ def summarize_document(document_content: str, filename: str, language: str = LAN Document summary text """ try: - # Select prompt file based on language - if language == LANGUAGE["ZH"]: - prompt_path = 'backend/prompts/document_summary_agent_zh.yaml' - else: - prompt_path = 'backend/prompts/document_summary_agent.yaml' - - with open(prompt_path, 'r', encoding='utf-8') as f: - prompts = yaml.safe_load(f) + # Get prompt template from prompt_template_utils + prompts = get_document_summary_prompt_template(language) system_prompt = prompts.get('system_prompt', '') user_prompt_template = prompts.get('user_prompt', '') @@ -625,14 +624,8 @@ def summarize_cluster(document_summaries: List[str], language: str = LANGUAGE["Z Cluster summary text """ try: - # Select prompt file based on language - if language == LANGUAGE["ZH"]: - prompt_path = 'backend/prompts/cluster_summary_reduce_zh.yaml' - else: - prompt_path = 'backend/prompts/cluster_summary_reduce.yaml' - - with open(prompt_path, 'r', encoding='utf-8') as f: - prompts = yaml.safe_load(f) + # Get prompt template from prompt_template_utils + prompts = get_cluster_summary_reduce_prompt_template(language) system_prompt = prompts.get('system_prompt', '') user_prompt_template = prompts.get('user_prompt', '') @@ -937,9 +930,8 @@ def summarize_cluster_legacy(cluster_content: str, language: str = LANGUAGE["ZH" Cluster summary text """ try: - prompt_path = 'backend/prompts/cluster_summary_agent.yaml' - with open(prompt_path, 'r', encoding='utf-8') as f: - prompts = 
yaml.safe_load(f) + # Get prompt template from prompt_template_utils + prompts = get_cluster_summary_agent_prompt_template(language) system_prompt = prompts.get('system_prompt', '') user_prompt_template = prompts.get('user_prompt', '') diff --git a/backend/utils/prompt_template_utils.py b/backend/utils/prompt_template_utils.py index 4232ef193..dfbb2c89c 100644 --- a/backend/utils/prompt_template_utils.py +++ b/backend/utils/prompt_template_utils.py @@ -21,6 +21,9 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw - 'analyze_file': File analysis template - 'generate_title': Title generation template - 'file_processing_messages': File processing messages template + - 'document_summary': Document summary template (Map stage) + - 'cluster_summary_reduce': Cluster summary reduce template (Reduce stage) + - 'cluster_summary_agent': Cluster summary agent template (legacy) language: Language code ('zh' or 'en') **kwargs: Additional parameters, for agent type need to pass is_manager parameter @@ -61,6 +64,18 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw 'file_processing_messages': { LANGUAGE["ZH"]: 'backend/prompts/utils/file_processing_messages.yaml', LANGUAGE["EN"]: 'backend/prompts/utils/file_processing_messages_en.yaml' + }, + 'document_summary': { + LANGUAGE["ZH"]: 'backend/prompts/document_summary_agent_zh.yaml', + LANGUAGE["EN"]: 'backend/prompts/document_summary_agent.yaml' + }, + 'cluster_summary_reduce': { + LANGUAGE["ZH"]: 'backend/prompts/cluster_summary_reduce_zh.yaml', + LANGUAGE["EN"]: 'backend/prompts/cluster_summary_reduce.yaml' + }, + 'cluster_summary_agent': { + LANGUAGE["ZH"]: 'backend/prompts/cluster_summary_agent.yaml', + LANGUAGE["EN"]: 'backend/prompts/cluster_summary_agent.yaml' } } @@ -164,3 +179,42 @@ def get_file_processing_messages_template(language: str = 'zh') -> Dict[str, Any dict: Loaded file processing messages configuration """ return 
get_prompt_template('file_processing_messages', language) + + +def get_document_summary_prompt_template(language: str = LANGUAGE["ZH"]) -> Dict[str, Any]: + """ + Get document summary prompt template (Map stage) + + Args: + language: Language code ('zh' or 'en') + + Returns: + dict: Loaded document summary prompt template configuration + """ + return get_prompt_template('document_summary', language) + + +def get_cluster_summary_reduce_prompt_template(language: str = LANGUAGE["ZH"]) -> Dict[str, Any]: + """ + Get cluster summary reduce prompt template (Reduce stage) + + Args: + language: Language code ('zh' or 'en') + + Returns: + dict: Loaded cluster summary reduce prompt template configuration + """ + return get_prompt_template('cluster_summary_reduce', language) + + +def get_cluster_summary_agent_prompt_template(language: str = LANGUAGE["ZH"]) -> Dict[str, Any]: + """ + Get cluster summary agent prompt template (legacy) + + Args: + language: Language code ('zh' or 'en') + + Returns: + dict: Loaded cluster summary agent prompt template configuration + """ + return get_prompt_template('cluster_summary_agent', language) diff --git a/doc/docs/en/backend/overview.md b/doc/docs/en/backend/overview.md index 9b6219182..3e8620551 100644 --- a/doc/docs/en/backend/overview.md +++ b/doc/docs/en/backend/overview.md @@ -28,7 +28,7 @@ backend/ ├── services/ # Business service layer │ ├── agent_service.py # Agent business logic │ ├── conversation_management_service.py # Conversation management -│ ├── elasticsearch_service.py # Search engine service +│ ├── vectordatabase_service.py # Search engine service │ ├── model_health_service.py # Model health checks │ ├── prompt_service.py # Prompt service │ └── tenant_config_service.py # Tenant configuration service @@ -64,7 +64,8 @@ backend/ │ └── utils/ # Prompt utilities ├── sql/ # SQL scripts ├── assets/ # Backend resource files -├── main_service.py # Main service entry point +├── config_service.py # Config service entry point +├── 
runtime_service.py # Runtime service entry point ├── data_process_service.py # Data processing service entry point └── requirements.txt # Python dependencies ``` @@ -179,8 +180,9 @@ uv sync && uv pip install -e ../sdk ### Service Startup ```bash python backend/data_process_service.py # Data processing service -python backend/main_service.py # Main service -python backend/nexent_mcp_service.py # MCP service +python backend/config_service.py # Config service +python backend/runtime_service.py # Runtime service +python backend/mcp_service.py # MCP service ``` ## Performance and Scalability diff --git a/doc/docs/en/backend/tools/mcp.md b/doc/docs/en/backend/tools/mcp.md index 7c4a46ba0..67ad79e48 100644 --- a/doc/docs/en/backend/tools/mcp.md +++ b/doc/docs/en/backend/tools/mcp.md @@ -48,15 +48,15 @@ graph TD This system implements a **dual-service proxy architecture** consisting of two independent services: -### 1. Main Service (FastAPI) - Port 5010 +### 1. Main Service (FastAPI) - Port 5010 5014 - **Purpose**: Provides web management interface and RESTful API, serving as the single entry point for frontend - **Features**: User-oriented management with authentication, multi-tenant support, and proxy calls to MCP service -- **Startup File**: `main_service.py` +- **Startup File**: `config_service.py, runtime_service.py` ### 2. MCP Service (FastMCP) - Port 5011 - **Purpose**: Provides MCP protocol services and proxy management (internal service) - **Features**: MCP protocol standard, supports local services and remote proxies, primarily called by main service -- **Startup File**: `nexent_mcp_service.py` +- **Startup File**: `mcp_service.py` **Important Note**: Frontend clients only directly access the main service (5010). All MCP-related operations are completed by the main service proxying calls to the MCP service (5011). 
@@ -144,14 +144,15 @@ DELETE /remote-proxies?service_name={service_name} **Start Main Service** ```bash cd backend -python main_service.py +python config_service.py +python runtime_service.py ``` Service will start at `http://localhost:5010`. **Start MCP Service** ```bash cd backend -python nexent_mcp_service.py +python mcp_service.py ``` Service will start at `http://localhost:5011`. @@ -189,12 +190,12 @@ curl -X POST http://localhost:5011/add-remote-proxies \ ## Code Structure -### Main Service Components (main_service.py) +### Main Service Components (config_service.py, runtime_service.py) - **FastAPI Application**: Provides Web API and management interface - **Multi-tenant Support**: Multi-tenant management based on authentication - **Router Management**: Contains routers for multiple functional modules -### MCP Service Components (nexent_mcp_service.py) +### MCP Service Components (mcp_service.py) #### RemoteProxyManager Class Responsible for managing the lifecycle of all remote proxies: diff --git a/doc/docs/en/getting-started/development-guide.md b/doc/docs/en/getting-started/development-guide.md index 09c2599d5..d781d4e28 100644 --- a/doc/docs/en/getting-started/development-guide.md +++ b/doc/docs/en/getting-started/development-guide.md @@ -125,9 +125,10 @@ Nexent includes three core backend services that need to be started separately: ```bash # Execute in the project root directory, please follow this order: -source .env && python backend/nexent_mcp_service.py # MCP service +source .env && python backend/mcp_service.py # MCP service source .env && python backend/data_process_service.py # Data processing service -source .env && python backend/main_service.py # Main service +source .env && python backend/config_service.py # Config service +source .env && python backend/runtime_service.py # Runtime service ``` ::: warning Important Notes diff --git a/doc/docs/en/opensource-memorial-wall.md b/doc/docs/en/opensource-memorial-wall.md index 019c1cc2a..299a26c71 
100644 --- a/doc/docs/en/opensource-memorial-wall.md +++ b/doc/docs/en/opensource-memorial-wall.md @@ -63,3 +63,7 @@ just dropping by to say nice work 👍 starred the repo ::: info IPM - 2025-10-15 Really impressed by Nexent — smooth interface and powerful agent framework. Great work! ::: + +::: info uu - 2024-01-15 +华为ICT智能体,感谢nexent平台支持! +::: diff --git a/doc/docs/en/sdk/monitoring.md b/doc/docs/en/sdk/monitoring.md index 2f9b86d35..4aa625132 100644 --- a/doc/docs/en/sdk/monitoring.md +++ b/doc/docs/en/sdk/monitoring.md @@ -31,7 +31,8 @@ uv sync --extra performance export ENABLE_TELEMETRY=true # 4. Start backend service -python backend/main_service.py +python backend/config_service.py +python backend/runtime_service.py ``` ## 📊 Access Monitoring Interfaces diff --git a/doc/docs/en/version/version-management.md b/doc/docs/en/version/version-management.md index 7dd22f799..2c416c0b3 100644 --- a/doc/docs/en/version/version-management.md +++ b/doc/docs/en/version/version-management.md @@ -74,7 +74,7 @@ Version is configured directly in `backend/consts/const.py`. 
Backend startup will print version information in the logs: ```python -# backend/main_service.py +# backend/config_service.py logger.info(f"APP version is: {APP_VERSION}") ``` @@ -92,7 +92,7 @@ APP_VERSION="v1.1.0" ```bash # Start the backend service cd backend - python main_service.py + python config_service.py # Check the version information in the startup logs # Output example: APP version is: v1.1.0 diff --git a/doc/docs/zh/backend/overview.md b/doc/docs/zh/backend/overview.md index 9d1b89427..70a9c87fa 100644 --- a/doc/docs/zh/backend/overview.md +++ b/doc/docs/zh/backend/overview.md @@ -28,7 +28,7 @@ backend/ ├── services/ # 业务服务层 │ ├── agent_service.py # 代理业务逻辑 │ ├── conversation_management_service.py # 对话管理 -│ ├── elasticsearch_service.py # 搜索引擎服务 +│ ├── vectordatabase_service.py # 搜索引擎服务 │ ├── model_health_service.py # 模型健康检查 │ ├── prompt_service.py # 提示词服务 │ └── tenant_config_service.py # 租户配置服务 @@ -64,7 +64,8 @@ backend/ │ └── utils/ # 提示词工具 ├── sql/ # SQL脚本 ├── assets/ # 后端资源文件 -├── main_service.py # 主服务入口 +├── config_service.py # 编辑态服务入口 +├── runtime_service.py # 运行态服务入口 ├── data_process_service.py # 数据处理服务入口 └── requirements.txt # Python依赖 ``` @@ -179,8 +180,9 @@ uv sync && uv pip install -e ../sdk ### 服务启动 ```bash python backend/data_process_service.py # 数据处理服务 -python backend/main_service.py # 主服务 -python backend/nexent_mcp_service.py # MCP服务 +python backend/config_service.py # 编辑态服务 +python backend/runtime_service.py # 运行态服务 +python backend/mcp_service.py # MCP服务 ``` ## 性能和可扩展性 diff --git a/doc/docs/zh/backend/tools/mcp.md b/doc/docs/zh/backend/tools/mcp.md index a76e181a7..4655249b4 100644 --- a/doc/docs/zh/backend/tools/mcp.md +++ b/doc/docs/zh/backend/tools/mcp.md @@ -4,15 +4,15 @@ Nexent采用**本地MCP服务 + 直接远程连接**的架构,通过MCP(Model Context Protocol)协议实现本地服务与远程服务的统一管理。系统包含两个核心服务: -### 1. 主服务 (FastAPI) - 端口 5010 +### 1. 
主服务 (FastAPI) - 端口 5010、5014 - **用途**:提供Web管理界面和RESTful API,作为前端唯一入口 - **特点**:面向用户管理,包含认证、多租户支持,管理MCP服务器配置 -- **启动文件**:`main_service.py` +- **启动文件**:`config_service.py, runtime_service.py` ### 2. 本地MCP服务 (FastMCP) - 端口 5011 - **用途**:提供本地MCP协议服务,挂载本地工具 - **特点**:MCP协议标准,仅提供本地服务,不代理远程服务 -- **启动文件**:`nexent_mcp_service.py` +- **启动文件**:`mcp_service.py` ### 3. 远程MCP服务 - **用途**:外部MCP服务,提供远程工具 @@ -54,7 +54,7 @@ graph TD ## 核心功能模块 -### 1. 本地MCP服务管理 (nexent_mcp_service.py) +### 1. 本地MCP服务管理 (mcp_service.py) **本地MCP服务实现**: ```python @@ -323,11 +323,12 @@ CREATE TABLE mcp_servers ( ```bash # 启动主服务 cd backend -python main_service.py +python config_service.py +python runtime_service.py # 启动本地MCP服务 cd backend -python nexent_mcp_service.py +python mcp_service.py ``` ### 2. 添加远程MCP服务器 diff --git a/doc/docs/zh/getting-started/development-guide.md b/doc/docs/zh/getting-started/development-guide.md index ec8fe415f..333f41536 100644 --- a/doc/docs/zh/getting-started/development-guide.md +++ b/doc/docs/zh/getting-started/development-guide.md @@ -125,9 +125,10 @@ Nexent 包含三个核心后端服务,需要分别启动: ```bash # 在项目根目录下执行,请按以下顺序执行: -source .env && python backend/nexent_mcp_service.py # MCP 服务 +source .env && python backend/mcp_service.py # MCP 服务 source .env && python backend/data_process_service.py # 数据处理服务 -source .env && python backend/main_service.py # 主服务 +source .env && python backend/config_service.py # 编辑态服务 +source .env && python backend/runtime_service.py # 运行态服务 ``` ::: warning 重要提示 diff --git a/doc/docs/zh/opensource-memorial-wall.md b/doc/docs/zh/opensource-memorial-wall.md index 077a65c40..6592039ff 100644 --- a/doc/docs/zh/opensource-memorial-wall.md +++ b/doc/docs/zh/opensource-memorial-wall.md @@ -15,6 +15,13 @@ 每条消息应包含您的姓名/昵称和日期。 请保持消息的礼貌和尊重,符合我们的行为准则。 --> +::: info china-king-hs - 2025-11-20 +希望能正常使用nexent +::: + +::: info happyzhang - 2025-11-13 +也许我们正见证着未来的“后起之秀”😀 +::: ::: info KevinLeeNJ - 2025-11-13 来参加华为ICT大赛的,nexent很不错,希望后续能有更多功能!
@@ -393,3 +400,47 @@ Nexent功能如此之强大,给我很多帮助,感谢开发者!厉害 ::: info chengliuxiang2002 - 2025-11-13 我又来了,通过华为ICT了解到nexent,正在学习中... ::: + +::: info user - 2025-11-14 +我要参加华为ICT +::: + +::: tip Locker - 2025-11-15 +感谢 Nexent 让我踏上了开源之旅!我们将在华为ICT大赛中使用 Nexent 构建智能体,期待能深入学习和贡献。 +::: + +::: info xlp888 - 2025-11-16 +通过华为ICT大赛了解到nexent,希望能在智能体之路上更进一步 +::: + +::: info user - 2025-11-16 +第一次参加,加油 +::: + +::: info user - 2025-11-17 +感谢 Nexent 让我第一次接触到智能体,希望参加ICT比赛过程中可以学到更多知识! +::: + +::: tip kon-do - 2025-11-17 +感谢 Nexent 让我踏上了开源之旅!这个项目的文档真的很棒,帮助我快速上手。 +::: + +::: tip user - 2025-11-19 +感谢 Nexent 让我第一次感受到智能体 希望参加ICT比赛过程中可以学到更多知识 能够对该领域有更多的了解和认识! +::: + +::: info 开源小白 - 2025-11-19 +感谢 Nexent 让我踏上了开源之旅!这个项目的文档真的很棒,帮助我快速上手。 +::: + +::: tip qinjavermo - 2025-11-19 +感谢 Nexent 让我踏上了开源之旅!给我一个机会制作智能体 +::: + +::: info chengyudan - 2025-10-20 +感谢 Nexent 让我踏上了开源之旅! +::: + +::: info user - 2025-11-20 +学习ai - agent非常好的项目,后面会持续输出贡献! +::: diff --git a/doc/docs/zh/sdk/monitoring.md b/doc/docs/zh/sdk/monitoring.md index 0a614e4d2..c592df267 100644 --- a/doc/docs/zh/sdk/monitoring.md +++ b/doc/docs/zh/sdk/monitoring.md @@ -31,7 +31,8 @@ uv sync --extra performance export ENABLE_TELEMETRY=true # 4. 
启动后端服务 -python backend/main_service.py +python backend/config_service.py +python backend/runtime_service.py ``` ## 📊 访问监控界面 diff --git a/doc/docs/zh/sdk/vector-database.md b/doc/docs/zh/sdk/vector-database.md index 68c72af45..940af9c33 100644 --- a/doc/docs/zh/sdk/vector-database.md +++ b/doc/docs/zh/sdk/vector-database.md @@ -234,7 +234,7 @@ docker network rm elastic - `elasticsearch_core.py`: 主类,包含所有 Elasticsearch 操作 - `embedding_model.py`: 处理使用 Jina AI 模型生成嵌入向量 - `utils.py`: 数据格式化和显示的工具函数 -- `elasticsearch_service.py`: FastAPI 服务,提供 REST API 接口 +- `vectordatabase_service.py`: FastAPI 服务,提供 REST API 接口 ## 使用示例 @@ -244,10 +244,10 @@ docker network rm elastic from nexent.vector_database.elasticsearch_core import ElasticSearchCore # 使用 .env 文件中的凭据初始化 -es_core = ElasticSearchCore() +vdb_core = ElasticSearchCore() # 或直接指定凭据 -es_core = ElasticSearchCore( +vdb_core = ElasticSearchCore( host="https://localhost:9200", api_key="your_api_key", verify_certs=False, @@ -259,21 +259,21 @@ es_core = ElasticSearchCore( ```python # 创建新的向量索引 -es_core.create_vector_index("my_documents") +vdb_core.create_index("my_documents") # 列出所有用户索引 -indices = es_core.get_user_indices() +indices = vdb_core.get_user_indices() print(indices) # 获取所有索引的统计信息 -all_indices_stats = es_core.get_all_indices_stats() +all_indices_stats = vdb_core.get_all_indices_stats() print(all_indices_stats) # 删除索引 -es_core.delete_index("my_documents") +vdb_core.delete_index("my_documents") # 创建测试知识库 -index_name, doc_count = es_core.create_test_knowledge_base() +index_name, doc_count = vdb_core.create_test_knowledge_base() print(f"创建了测试知识库 {index_name},包含 {doc_count} 个文档") ``` @@ -304,11 +304,11 @@ documents = [ } ] # 支持批量处理,默认批处理大小为3000 -total_indexed = es_core.index_documents("my_documents", documents, batch_size=3000) +total_indexed = vdb_core.vectorize_documents("my_documents", documents, batch_size=3000) print(f"成功索引了 {total_indexed} 个文档") # 通过 URL 或路径删除文档 -deleted_count = 
es_core.delete_documents_by_path_or_url("my_documents", "https://example.com/doc1") +deleted_count = vdb_core.delete_documents("my_documents", "https://example.com/doc1") print(f"删除了 {deleted_count} 个文档") ``` @@ -316,17 +316,17 @@ print(f"删除了 {deleted_count} 个文档") ```python # 文本精确搜索 -results = es_core.accurate_search("my_documents", "示例查询", top_k=5) +results = vdb_core.accurate_search("my_documents", "示例查询", top_k=5) for result in results: print(f"得分: {result['score']}, 文档: {result['document']['title']}") # 语义向量搜索 -results = es_core.semantic_search("my_documents", "示例查询", top_k=5) +results = vdb_core.semantic_search("my_documents", "示例查询", top_k=5) for result in results: print(f"得分: {result['score']}, 文档: {result['document']['title']}") # 混合搜索 -results = es_core.hybrid_search( +results = vdb_core.hybrid_search( "my_documents", "示例查询", top_k=5, @@ -340,19 +340,19 @@ for result in results: ```python # 获取索引统计信息 -stats = es_core.get_index_stats("my_documents") +stats = vdb_core.get_indices_detail(["my_documents"]) print(stats) # 获取文件列表及详细信息 -file_details = es_core.get_file_list_with_details("my_documents") +file_details = vdb_core.get_documents_detail("my_documents") print(file_details) # 获取嵌入模型信息 -embedding_model = es_core.get_embedding_model_info("my_documents") +embedding_model = vdb_core.get_embedding_model_info("my_documents") print(f"使用的嵌入模型: {embedding_model}") # 打印所有索引信息 -es_core.print_all_indices_info() +vdb_core.print_all_indices_info() ``` ## ElasticSearchCore 主要功能 @@ -368,7 +368,7 @@ ElasticSearchCore 类提供了以下主要功能: ```python # 获取索引的文件列表及详细信息 -files = es_core.get_file_list_with_details("my_documents") +files = vdb_core.get_documents_detail("my_documents") for file in files: print(f"文件路径: {file['path_or_url']}") print(f"文件名: {file['file']}") @@ -377,11 +377,11 @@ for file in files: print("---") # 获取嵌入模型信息 -model_info = es_core.get_embedding_model_info("my_documents") +model_info = vdb_core.get_embedding_model_info("my_documents") print(f"使用的嵌入模型: {model_info}") 
# 获取所有索引的综合统计信息 -all_stats = es_core.get_all_indices_stats() +all_stats = vdb_core.get_all_indices_stats() for index_name, stats in all_stats.items(): print(f"索引: {index_name}") print(f"文档数: {stats['base_info']['doc_count']}") @@ -392,12 +392,12 @@ for index_name, stats in all_stats.items(): ## API 服务接口 -通过 `elasticsearch_service.py` 提供的 FastAPI 服务,可使用 REST API 访问上述所有功能。 +通过 `vectordatabase_service.py` 提供的 FastAPI 服务,可使用 REST API 访问上述所有功能。 ### 服务启动 ```bash -python -m nexent.service.elasticsearch_service +python -m nexent.service.vectordatabase_service ``` 服务默认在 `http://localhost:8000` 运行。 @@ -836,13 +836,13 @@ print(json.dumps(response.json(), indent=2, ensure_ascii=False)) ```python # 初始化 ElasticSearchCore -es_core = ElasticSearchCore() +vdb_core = ElasticSearchCore() # 获取或创建测试知识库 index_name = "sample_articles" # 列出所有用户索引 -user_indices = es_core.get_user_indices() +user_indices = vdb_core.get_user_indices() for idx in user_indices: print(f" - {idx}") @@ -850,20 +850,20 @@ for idx in user_indices: if index_name in user_indices: # 精确搜索 query = "Doctor" - accurate_results = es_core.accurate_search(index_name, query, top_k=2) + accurate_results = vdb_core.accurate_search(index_name, query, top_k=2) # 语义搜索 query = "medical professionals in London" - semantic_results = es_core.semantic_search(index_name, query, top_k=2) + semantic_results = vdb_core.semantic_search(index_name, query, top_k=2) # 混合搜索 query = "medical professionals in London" - semantic_results = es_core.hybrid_search(index_name, query, top_k=2, weight_accurate=0.5) + semantic_results = vdb_core.hybrid_search(index_name, query, top_k=2, weight_accurate=0.5) # 获取索引统计信息 - stats = es_core.get_index_stats(index_name) - fields = es_core.get_index_mapping(index_name) - unique_sources = es_core.get_unique_sources_count(index_name) + stats = vdb_core.get_indices_detail([index_name]) + fields = vdb_core.get_index_mapping(index_name) + unique_sources = vdb_core.get_unique_sources_count(index_name) ``` ## 许可证 diff 
--git a/doc/docs/zh/version/version-management.md b/doc/docs/zh/version/version-management.md index fefb1f6a7..318b88253 100644 --- a/doc/docs/zh/version/version-management.md +++ b/doc/docs/zh/version/version-management.md @@ -74,7 +74,7 @@ APP_VERSION = "v1.0.0" 后端启动时会在日志中打印版本信息: ```python -# backend/main_service.py +# backend/config_service.py logger.info(f"APP version is: {APP_VERSION}") ``` @@ -92,7 +92,7 @@ APP_VERSION="v1.1.0" ```bash # 启动后端服务 cd backend - python main_service.py + python config_service.py # 查看启动日志中的版本信息 # 输出示例:APP version is: v1.1.0 diff --git a/docker/.env.example b/docker/.env.example index 41b27f1b0..2e18a6068 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -35,10 +35,21 @@ ES_DISK_WATERMARK_HIGH=90% ES_DISK_WATERMARK_FLOOD_STAGE=95% # Main Services -ELASTICSEARCH_SERVICE=http://nexent:5010/api -NEXENT_MCP_SERVER=http://nexent:5011 +# Config service (port 5010) - Main API service for config operations +CONFIG_SERVICE_URL=http://nexent-config:5010 +ELASTICSEARCH_SERVICE=http://nexent-config:5010/api + +# Runtime service (port 5014) - Runtime execution service for agent operations +RUNTIME_SERVICE_URL=http://nexent-runtime:5014 + +# MCP service (port 5011) - MCP protocol service +NEXENT_MCP_SERVER=http://nexent-mcp:5011 + +# Data process service (port 5012) - Data processing service DATA_PROCESS_SERVICE=http://nexent-data-process:5012/api -NORTHBOUND_API_SERVER=http://nexent:5013/api + +# Northbound service (port 5013) - Northbound API service +NORTHBOUND_API_SERVER=http://nexent-northbound:5013/api # Postgres Config POSTGRES_HOST=nexent-postgresql diff --git a/docker/deploy.sh b/docker/deploy.sh index f2ef8a75f..2b3ea9618 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -375,7 +375,7 @@ prepare_directory_and_data() { deploy_core_services() { # Function to deploy core services echo "👀 Starting core services..." - if ! 
${docker_compose_command} -p nexent -f "docker-compose${COMPOSE_FILE_SUFFIX}" up -d nexent nexent-web nexent-data-process; then + if ! ${docker_compose_command} -p nexent -f "docker-compose${COMPOSE_FILE_SUFFIX}" up -d nexent-config nexent-runtime nexent-mcp nexent-northbound nexent-web nexent-data-process; then echo " ❌ ERROR Failed to start core services" exit 1 fi @@ -618,7 +618,7 @@ select_terminal_tool() { create_default_admin_user() { echo "🔧 Creating admin user..." - RESPONSE=$(docker exec nexent bash -c "curl -X POST http://kong:8000/auth/v1/signup -H \"apikey: ${SUPABASE_KEY}\" -H \"Authorization: Bearer ${SUPABASE_KEY}\" -H \"Content-Type: application/json\" -d '{\"email\":\"nexent@example.com\",\"password\":\"nexent@4321\",\"email_confirm\":true,\"data\":{\"role\":\"admin\"}}'" 2>/dev/null) + RESPONSE=$(docker exec nexent-config bash -c "curl -X POST http://kong:8000/auth/v1/signup -H \"apikey: ${SUPABASE_KEY}\" -H \"Authorization: Bearer ${SUPABASE_KEY}\" -H \"Content-Type: application/json\" -d '{\"email\":\"nexent@example.com\",\"password\":\"nexent@4321\",\"email_confirm\":true,\"data\":{\"role\":\"admin\"}}'" 2>/dev/null) if [ -z "$RESPONSE" ]; then echo " ❌ No response received from Supabase." 
diff --git a/docker/docker-compose.prod.yml b/docker/docker-compose.prod.yml index 311a2cd0d..5e3fc1a22 100644 --- a/docker/docker-compose.prod.yml +++ b/docker/docker-compose.prod.yml @@ -69,9 +69,9 @@ services: networks: - nexent - nexent: + nexent-config: image: ${NEXENT_IMAGE} - container_name: nexent + container_name: nexent-config restart: always volumes: - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent @@ -93,9 +93,85 @@ services: max-file: "3" # Maximum number of log files to keep networks: - nexent - ports: - - "5013:5013" # Northbound service port - entrypoint: ["/bin/bash", "-c", "python backend/nexent_mcp_service.py & python backend/northbound_service.py & python backend/main_service.py"] + entrypoint: ["/bin/bash", "-c", "python backend/config_service.py"] + + nexent-runtime: + image: ${NEXENT_IMAGE} + container_name: nexent-runtime + restart: always + volumes: + - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent + - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro + environment: + <<: [*minio-vars, *es-vars] + skip_proxy: "true" + UMASK: 0022 + env_file: + - .env + user: root + depends_on: + nexent-elasticsearch: + condition: service_healthy + logging: + driver: "json-file" + options: + max-size: "10m" # Maximum size of a single log file + max-file: "3" # Maximum number of log files to keep + networks: + - nexent + entrypoint: ["/bin/bash", "-c", "python backend/runtime_service.py"] + + nexent-mcp: + image: ${NEXENT_IMAGE} + container_name: nexent-mcp + restart: always + volumes: + - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent + - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro + environment: + <<: [*minio-vars, *es-vars] + skip_proxy: "true" + UMASK: 0022 + env_file: + - .env + user: root + depends_on: + nexent-elasticsearch: + condition: service_healthy + logging: + driver: "json-file" + options: + max-size: "10m" # Maximum size of a single log file + max-file: "3" # Maximum number of log files to keep + networks: + - nexent + entrypoint: 
["/bin/bash", "-c", "python backend/mcp_service.py"] + + nexent-northbound: + image: ${NEXENT_IMAGE} + container_name: nexent-northbound + restart: always + volumes: + - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent + - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro + environment: + <<: [*minio-vars, *es-vars] + skip_proxy: "true" + UMASK: 0022 + env_file: + - .env + user: root + depends_on: + nexent-elasticsearch: + condition: service_healthy + logging: + driver: "json-file" + options: + max-size: "10m" # Maximum size of a single log file + max-file: "3" # Maximum number of log files to keep + networks: + - nexent + entrypoint: ["/bin/bash", "-c", "python backend/northbound_service.py"] nexent-web: image: ${NEXENT_WEB_IMAGE} @@ -106,8 +182,9 @@ services: ports: - "3000:3000" environment: - - HTTP_BACKEND=http://nexent:5010 - - WS_BACKEND=ws://nexent:5010 + - HTTP_BACKEND=http://nexent-config:5010 + - WS_BACKEND=ws://nexent-runtime:5014 + - RUNTIME_HTTP_BACKEND=http://nexent-runtime:5014 - MINIO_ENDPOINT=http://nexent-minio:9000 logging: driver: "json-file" diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index fb0981894..8011b8f28 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -74,12 +74,95 @@ services: networks: - nexent - nexent: + nexent-config: + image: ${NEXENT_IMAGE} + container_name: nexent-config + restart: always + ports: + - "5010:5010" # Config service port + volumes: + - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent + - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro + environment: + <<: [*minio-vars, *es-vars] + skip_proxy: "true" + UMASK: 0022 + env_file: + - .env + user: root + depends_on: + nexent-elasticsearch: + condition: service_healthy + logging: + driver: "json-file" + options: + max-size: "10m" # Maximum size of a single log file + max-file: "3" # Maximum number of log files to keep + networks: + - nexent + entrypoint: ["/bin/bash", "-c", "python backend/config_service.py"] + + 
nexent-runtime: + image: ${NEXENT_IMAGE} + container_name: nexent-runtime + restart: always + ports: + - "5014:5014" # Runtime service port + volumes: + - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent + - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro + environment: + <<: [*minio-vars, *es-vars] + skip_proxy: "true" + UMASK: 0022 + env_file: + - .env + user: root + depends_on: + nexent-elasticsearch: + condition: service_healthy + logging: + driver: "json-file" + options: + max-size: "10m" # Maximum size of a single log file + max-file: "3" # Maximum number of log files to keep + networks: + - nexent + entrypoint: ["/bin/bash", "-c", "python backend/runtime_service.py"] + + nexent-mcp: + image: ${NEXENT_IMAGE} + container_name: nexent-mcp + restart: always + ports: + - "5011:5011" # MCP service port + volumes: + - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent + - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro + environment: + <<: [*minio-vars, *es-vars] + skip_proxy: "true" + UMASK: 0022 + env_file: + - .env + user: root + depends_on: + nexent-elasticsearch: + condition: service_healthy + logging: + driver: "json-file" + options: + max-size: "10m" # Maximum size of a single log file + max-file: "3" # Maximum number of log files to keep + networks: + - nexent + entrypoint: ["/bin/bash", "-c", "python backend/mcp_service.py"] + + nexent-northbound: image: ${NEXENT_IMAGE} - container_name: nexent + container_name: nexent-northbound restart: always ports: - - "5010:5010" # Backend service port - "5013:5013" # Northbound service port volumes: - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent @@ -101,7 +184,7 @@ services: max-file: "3" # Maximum number of log files to keep networks: - nexent - entrypoint: ["/bin/bash", "-c", "python backend/nexent_mcp_service.py & python backend/northbound_service.py & python backend/main_service.py"] + entrypoint: ["/bin/bash", "-c", "python backend/northbound_service.py"] nexent-web: image: ${NEXENT_WEB_IMAGE} @@ -112,8 +195,9 
@@ services: ports: - "3000:3000" environment: - - HTTP_BACKEND=http://nexent:5010 - - WS_BACKEND=ws://nexent:5010 + - HTTP_BACKEND=http://nexent-config:5010 + - WS_BACKEND=ws://nexent-runtime:5014 + - RUNTIME_HTTP_BACKEND=http://nexent-runtime:5014 - MINIO_ENDPOINT=http://nexent-minio:9000 logging: driver: "json-file" diff --git a/docker/generate_env.sh b/docker/generate_env.sh index 246873f5a..4be94e28b 100755 --- a/docker/generate_env.sh +++ b/docker/generate_env.sh @@ -106,6 +106,23 @@ update_env_file() { echo "ELASTICSEARCH_HOST=http://localhost:9210" >> ../.env fi + # Main Services + # CONFIG_SERVICE_URL + if grep -q "^CONFIG_SERVICE_URL=" ../.env; then + sed -i.bak "s~^CONFIG_SERVICE_URL=.*~CONFIG_SERVICE_URL=http://localhost:5010~" ../.env + else + echo "" >> ../.env + echo "# Main Services" >> ../.env + echo "CONFIG_SERVICE_URL=http://localhost:5010" >> ../.env + fi + + # RUNTIME_SERVICE_URL + if grep -q "^RUNTIME_SERVICE_URL=" ../.env; then + sed -i.bak "s~^RUNTIME_SERVICE_URL=.*~RUNTIME_SERVICE_URL=http://localhost:5014~" ../.env + else + echo "RUNTIME_SERVICE_URL=http://localhost:5014" >> ../.env + fi + # ELASTICSEARCH_SERVICE if grep -q "^ELASTICSEARCH_SERVICE=" ../.env; then sed -i.bak "s~^ELASTICSEARCH_SERVICE=.*~ELASTICSEARCH_SERVICE=http://localhost:5010/api~" ../.env diff --git a/frontend/app/[locale]/setup/agents/config.tsx b/frontend/app/[locale]/agents/AgentConfiguration.tsx similarity index 95% rename from frontend/app/[locale]/setup/agents/config.tsx rename to frontend/app/[locale]/agents/AgentConfiguration.tsx index e6dee2283..ab5411811 100644 --- a/frontend/app/[locale]/setup/agents/config.tsx +++ b/frontend/app/[locale]/agents/AgentConfiguration.tsx @@ -36,7 +36,7 @@ import { configStore } from "@/lib/config"; import AgentSetupOrchestrator from "./components/AgentSetupOrchestrator"; import DebugConfig from "./components/DebugConfig"; -import "../../i18n"; +import "../i18n"; // Layout Height Constant Configuration const LAYOUT_CONFIG: 
LayoutConfig = AGENT_SETUP_LAYOUT_DEFAULT; @@ -161,12 +161,25 @@ export default forwardRef(function AgentCon const currentAgentName = agentName; const currentAgentDisplayName = agentDisplayName; + // Extract tool IDs from selected tools (convert string IDs to numbers) + // Always pass tool_ids array (empty array means no tools selected, undefined means use database) + // In edit mode, we want to use current selection, so pass the array even if empty + const toolIds = selectedTools.map((tool) => Number(tool.id)); + + // Get sub-agent IDs from enabledAgentIds + // Always pass sub_agent_ids array (empty array means no sub-agents selected, undefined means use database) + // In edit mode, we want to use current selection, so pass the array even if empty + const subAgentIds = [...enabledAgentIds]; + // Call backend API to generate agent prompt + // Pass tool_ids and sub_agent_ids to use frontend selection instead of database query await generatePromptStream( { agent_id: agentIdToUse, task_description: businessLogic, model_id: selectedModel?.id?.toString() || "", + tool_ids: toolIds, + sub_agent_ids: subAgentIds, }, (data) => { // Process streaming response data @@ -525,13 +538,13 @@ export default forwardRef(function AgentCon className="w-full mx-auto" style={{ maxWidth: SETUP_PAGE_CONTAINER.MAX_WIDTH, - height: SETUP_PAGE_CONTAINER.MAIN_CONTENT_HEIGHT, + padding: `0 ${SETUP_PAGE_CONTAINER.HORIZONTAL_PADDING}`, }} >
diff --git a/frontend/app/[locale]/agents/AgentsContent.tsx b/frontend/app/[locale]/agents/AgentsContent.tsx new file mode 100644 index 000000000..af658c0a4 --- /dev/null +++ b/frontend/app/[locale]/agents/AgentsContent.tsx @@ -0,0 +1,110 @@ +"use client"; + +import React, {useState, useEffect, useRef} from "react"; +import {motion} from "framer-motion"; + +import {useSetupFlow} from "@/hooks/useSetupFlow"; +import { + ConnectionStatus, +} from "@/const/modelConfig"; + +import AgentConfig, {AgentConfigHandle} from "./AgentConfiguration"; +import SaveConfirmModal from "./components/SaveConfirmModal"; + +interface AgentsContentProps { + /** Whether currently saving */ + isSaving?: boolean; + /** Connection status */ + connectionStatus?: ConnectionStatus; + /** Is checking connection */ + isCheckingConnection?: boolean; + /** Check connection callback */ + onCheckConnection?: () => void; + /** Callback to expose connection status */ + onConnectionStatusChange?: (status: ConnectionStatus) => void; + /** Callback to expose saving state */ + onSavingStateChange?: (isSaving: boolean) => void; +} + +/** + * AgentsContent - Main component for agent configuration + * Can be used in setup flow or as standalone page + */ +export default function AgentsContent({ + isSaving: externalIsSaving, + connectionStatus: externalConnectionStatus, + isCheckingConnection: externalIsCheckingConnection, + onCheckConnection: externalOnCheckConnection, + onConnectionStatusChange, + onSavingStateChange, +}: AgentsContentProps) { + const agentConfigRef = useRef(null); + const [showSaveConfirm, setShowSaveConfirm] = useState(false); + const pendingNavRef = useRef void)>(null); + + // Use custom hook for common setup flow logic + const { + canAccessProtectedData, + pageVariants, + pageTransition, + } = useSetupFlow({ + requireAdmin: true, + externalConnectionStatus, + externalIsCheckingConnection, + onCheckConnection: externalOnCheckConnection, + onConnectionStatusChange, + nonAdminRedirect: 
"/setup/knowledges", + }); + + const [internalIsSaving, setInternalIsSaving] = useState(false); + const isSaving = externalIsSaving ?? internalIsSaving; + + // Update external saving state + useEffect(() => { + onSavingStateChange?.(isSaving); + }, [isSaving, onSavingStateChange]); + + return ( + <> + + {canAccessProtectedData ? ( + + ) : null} + + + { + // Reload data from backend to discard changes + await agentConfigRef.current?.reloadCurrentAgentData?.(); + setShowSaveConfirm(false); + const go = pendingNavRef.current; + pendingNavRef.current = null; + if (go) go(); + }} + onSave={async () => { + try { + setInternalIsSaving(true); + await agentConfigRef.current?.saveAllChanges?.(); + setShowSaveConfirm(false); + const go = pendingNavRef.current; + pendingNavRef.current = null; + if (go) go(); + } catch (e) { + // errors are surfaced by underlying save + } finally { + setInternalIsSaving(false); + } + }} + /> + + ); +} + diff --git a/frontend/app/[locale]/setup/agents/components/AgentSetupOrchestrator.tsx b/frontend/app/[locale]/agents/components/AgentSetupOrchestrator.tsx similarity index 97% rename from frontend/app/[locale]/setup/agents/components/AgentSetupOrchestrator.tsx rename to frontend/app/[locale]/agents/components/AgentSetupOrchestrator.tsx index d2c4cb7cd..dd19e3abd 100644 --- a/frontend/app/[locale]/setup/agents/components/AgentSetupOrchestrator.tsx +++ b/frontend/app/[locale]/agents/components/AgentSetupOrchestrator.tsx @@ -4,7 +4,7 @@ import { useState, useEffect, useCallback, useRef, useMemo } from "react"; import { useTranslation } from "react-i18next"; import { TFunction } from "i18next"; -import { App, Modal, Typography, Button } from "antd"; +import { App, Modal, Button } from "antd"; import { WarningFilled } from "@ant-design/icons"; import { TooltipProvider } from "@/components/ui/tooltip"; @@ -113,6 +113,38 @@ export default function AgentSetupOrchestrator({ // Edit agent related status const [isEditingAgent, setIsEditingAgent] = 
useState(false); const [editingAgent, setEditingAgent] = useState(null); + const activeEditingAgent = editingAgentFromParent || editingAgent; + const isAgentUnavailable = activeEditingAgent?.is_available === false; + const agentUnavailableReasons = + isAgentUnavailable && Array.isArray(activeEditingAgent?.unavailable_reasons) + ? (activeEditingAgent?.unavailable_reasons as string[]) + : []; + const mergeAgentAvailabilityMetadata = useCallback( + (detail: Agent, fallback?: Agent | null): Agent => { + const detailReasons = Array.isArray(detail?.unavailable_reasons) + ? detail.unavailable_reasons + : []; + const fallbackReasons = Array.isArray(fallback?.unavailable_reasons) + ? fallback!.unavailable_reasons! + : []; + const normalizedReasons = + detailReasons.length > 0 ? detailReasons : fallbackReasons; + + const normalizedAvailability = + typeof detail?.is_available === "boolean" + ? detail.is_available + : typeof fallback?.is_available === "boolean" + ? fallback.is_available + : detail?.is_available; + + return { + ...detail, + unavailable_reasons: normalizedReasons, + is_available: normalizedAvailability, + }; + }, + [] + ); // Add a flag to track if it has been initialized to avoid duplicate calls const hasInitialized = useRef(false); @@ -287,7 +319,11 @@ export default function AgentSetupOrchestrator({ return; } - const agentDetail = result.data; + const agentDetail = mergeAgentAvailabilityMetadata( + result.data as Agent, + editingAgent + ); + setEditingAgent(agentDetail); // Reload all agent data to match backend state setAgentName?.(agentDetail.name || ""); @@ -296,7 +332,7 @@ export default function AgentSetupOrchestrator({ // Load Agent data to interface setMainAgentModel(agentDetail.model); - setMainAgentModelId(agentDetail.model_id); + setMainAgentModelId(agentDetail.model_id ?? 
null); setMainAgentMaxStep(agentDetail.max_step); setBusinessLogic(agentDetail.business_description || ""); setBusinessLogicModel(agentDetail.business_logic_model_name || null); @@ -971,7 +1007,9 @@ export default function AgentSetupOrchestrator({ try { const detail = await searchAgentInfo(newId); if (detail.success && detail.data) { - const agentDetail = detail.data; + const agentDetail = mergeAgentAvailabilityMetadata( + detail.data as Agent + ); setIsEditingAgent(true); setEditingAgent(agentDetail); setMainAgentId(agentDetail.id); @@ -982,7 +1020,7 @@ export default function AgentSetupOrchestrator({ setAgentDisplayName?.(agentDetail.display_name || ""); onEditingStateChange?.(true, agentDetail); setMainAgentModel(agentDetail.model); - setMainAgentModelId(agentDetail.model_id); + setMainAgentModelId(agentDetail.model_id ?? null); setMainAgentMaxStep(agentDetail.max_step); setBusinessLogic(agentDetail.business_description || ""); setBusinessLogicModel( @@ -1146,7 +1184,10 @@ export default function AgentSetupOrchestrator({ return; } - const agentDetail = result.data; + const agentDetail = mergeAgentAvailabilityMetadata( + result.data as Agent, + agent + ); // Set editing state and highlight after successfully getting information setIsEditingAgent(true); @@ -1170,7 +1211,7 @@ export default function AgentSetupOrchestrator({ // Load Agent data to interface setMainAgentModel(agentDetail.model); - setMainAgentModelId(agentDetail.model_id); + setMainAgentModelId(agentDetail.model_id ?? 
null); setMainAgentMaxStep(agentDetail.max_step); setBusinessLogic(agentDetail.business_description || ""); setBusinessLogicModel(agentDetail.business_logic_model_name || null); @@ -1595,7 +1636,6 @@ export default function AgentSetupOrchestrator({ isGeneratingAgent={isGeneratingAgent} editingAgent={editingAgent} isCreatingNewAgent={isCreatingNewAgent} - editingAgentName={agentName || null} onExportAgent={handleExportAgentFromList} onDeleteAgent={handleDeleteAgentFromList} unsavedAgentId={ @@ -1682,6 +1722,7 @@ export default function AgentSetupOrchestrator({ isEditingMode={isEditingAgent || isCreatingNewAgent} isGeneratingAgent={isGeneratingAgent} isEmbeddingConfigured={isEmbeddingConfigured} + agentUnavailableReasons={agentUnavailableReasons} />
diff --git a/frontend/app/[locale]/setup/agents/components/DebugConfig.tsx b/frontend/app/[locale]/agents/components/DebugConfig.tsx similarity index 100% rename from frontend/app/[locale]/setup/agents/components/DebugConfig.tsx rename to frontend/app/[locale]/agents/components/DebugConfig.tsx diff --git a/frontend/app/[locale]/setup/agents/components/McpConfigModal.tsx b/frontend/app/[locale]/agents/components/McpConfigModal.tsx similarity index 100% rename from frontend/app/[locale]/setup/agents/components/McpConfigModal.tsx rename to frontend/app/[locale]/agents/components/McpConfigModal.tsx diff --git a/frontend/app/[locale]/setup/agents/components/PromptManager.tsx b/frontend/app/[locale]/agents/components/PromptManager.tsx similarity index 100% rename from frontend/app/[locale]/setup/agents/components/PromptManager.tsx rename to frontend/app/[locale]/agents/components/PromptManager.tsx diff --git a/frontend/app/[locale]/setup/agents/components/SaveConfirmModal.tsx b/frontend/app/[locale]/agents/components/SaveConfirmModal.tsx similarity index 100% rename from frontend/app/[locale]/setup/agents/components/SaveConfirmModal.tsx rename to frontend/app/[locale]/agents/components/SaveConfirmModal.tsx diff --git a/frontend/app/[locale]/setup/agents/components/agent/AgentCallRelationshipModal.tsx b/frontend/app/[locale]/agents/components/agent/AgentCallRelationshipModal.tsx similarity index 100% rename from frontend/app/[locale]/setup/agents/components/agent/AgentCallRelationshipModal.tsx rename to frontend/app/[locale]/agents/components/agent/AgentCallRelationshipModal.tsx diff --git a/frontend/app/[locale]/setup/agents/components/agent/AgentConfigModal.tsx b/frontend/app/[locale]/agents/components/agent/AgentConfigModal.tsx similarity index 85% rename from frontend/app/[locale]/setup/agents/components/agent/AgentConfigModal.tsx rename to frontend/app/[locale]/agents/components/agent/AgentConfigModal.tsx index b618b62fc..098cd1b93 100644 --- 
a/frontend/app/[locale]/setup/agents/components/agent/AgentConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agent/AgentConfigModal.tsx @@ -10,6 +10,7 @@ import { Bug, Save, Maximize2 } from "lucide-react"; import log from "@/lib/logger"; import { ModelOption } from "@/types/modelConfig"; +import { Agent } from "@/types/agentConfig"; import { modelService } from "@/services/modelService"; import { checkAgentName, @@ -49,7 +50,7 @@ export interface AgentConfigModalProps { onDeleteSuccess?: () => void; // New prop for handling delete success onSaveAgent?: () => void; isCreatingNewAgent?: boolean; - editingAgent?: any; + editingAgent?: Agent | null; canSaveAgent?: boolean; getButtonTitle?: () => string; onViewCallRelationship?: () => void; // New prop for viewing call relationship @@ -84,7 +85,7 @@ export default function AgentConfigModal({ onDeleteSuccess, onSaveAgent, isCreatingNewAgent = false, - editingAgent, + editingAgent = null, canSaveAgent = false, getButtonTitle, }: AgentConfigModalProps) { @@ -126,6 +127,49 @@ export default function AgentConfigModal({ // Local fallback for selected main model display name (used when parent has not yet propagated) const [localMainAgentModel, setLocalMainAgentModel] = useState(""); + const isAgentUnavailable = editingAgent?.is_available === false; + const normalizedUnavailableReasons = + isAgentUnavailable && Array.isArray(editingAgent?.unavailable_reasons) + ? 
(editingAgent?.unavailable_reasons as string[]) + : []; + const hasDuplicateDisplayNameReason = normalizedUnavailableReasons.some( + (reason) => + ["duplicate_display_name", "duplicate_disaplay_name"].includes(reason) + ); + const hasDuplicateNameReason = + normalizedUnavailableReasons.includes("duplicate_name"); + const hasModelUnavailableReason = + normalizedUnavailableReasons.includes("model_unavailable"); + const currentDisplayName = (agentDisplayName || "").trim(); + const originalDisplayName = (editingAgent?.display_name || "").trim(); + const currentAgentName = (agentName || "").trim(); + const originalAgentName = (editingAgent?.name || "").trim(); + const shouldShowDuplicateDisplayNameReason = + hasDuplicateDisplayNameReason && + !!currentDisplayName && + currentDisplayName === originalDisplayName; + const shouldShowDuplicateNameReason = + hasDuplicateNameReason && + !!currentAgentName && + currentAgentName === originalAgentName; + const originalModelId = + typeof editingAgent?.model_id === "number" + ? editingAgent.model_id + : null; + const selectedModelId = + typeof mainAgentModelId === "number" + ? mainAgentModelId + : originalModelId; + const effectiveModelName = + mainAgentModel || + localMainAgentModel || + editingAgent?.model || + ""; + const shouldShowModelUnavailableReason = + hasModelUnavailableReason && + originalModelId !== null && + selectedModelId === originalModelId; + // Load LLM models on component mount useEffect(() => { const loadLLMModels = async () => { @@ -272,8 +316,18 @@ export default function AgentConfigModal({ [validateAgentDisplayName, onAgentDisplayNameChange] ); - // Check agent name existence - only when user is actively typing + // Check agent name existence - when creating new agent or when editing and name changed useEffect(() => { + // Perform real-time check when: + // 1. Creating new agent, OR + // 2. 
Editing existing agent and the name has changed + const shouldCheck = isCreatingNewAgent || + (!isCreatingNewAgent && currentAgentName !== originalAgentName); + + if (!shouldCheck) { + return; + } + if (!agentName) { return; } @@ -306,7 +360,7 @@ export default function AgentConfigModal({ return () => { clearTimeout(timer); }; - }, [isEditingMode, agentName, agentNameError, agentId, agentNameStatus, t]); + }, [isCreatingNewAgent, agentName, agentNameError, agentId, agentNameStatus, currentAgentName, originalAgentName, t]); // Reset user typing state after user stops typing useEffect(() => { @@ -328,12 +382,26 @@ export default function AgentConfigModal({ } }, [agentName]); - // Check agent display name existence - only when user is actively typing + // Clear name status when editing and name hasn't changed (should only use backend markers) useEffect(() => { - if ( - (!isEditingMode && !isCreatingNewAgent) || - !agentDisplayName - ) { + if (!isCreatingNewAgent && currentAgentName === originalAgentName) { + setAgentNameStatus(NAME_CHECK_STATUS.AVAILABLE); + } + }, [isCreatingNewAgent, currentAgentName, originalAgentName]); + + // Check agent display name existence - when creating new agent or when editing and display name changed + useEffect(() => { + // Perform real-time check when: + // 1. Creating new agent, OR + // 2. 
Editing existing agent and the display name has changed + const shouldCheck = isCreatingNewAgent || + (!isCreatingNewAgent && currentDisplayName !== originalDisplayName); + + if (!shouldCheck) { + return; + } + + if (!agentDisplayName) { return; } @@ -365,7 +433,7 @@ export default function AgentConfigModal({ return () => { clearTimeout(timer); }; - }, [isEditingMode, isCreatingNewAgent, agentDisplayName, agentDisplayNameError, agentId, agentDisplayNameStatus, t]); + }, [isCreatingNewAgent, agentDisplayName, agentDisplayNameError, agentId, agentDisplayNameStatus, currentDisplayName, originalDisplayName, t]); // Reset user typing state for display name after user stops typing useEffect(() => { @@ -387,6 +455,13 @@ export default function AgentConfigModal({ } }, [agentDisplayName]); + // Clear display name status when editing and display name hasn't changed (should only use backend markers) + useEffect(() => { + if (!isCreatingNewAgent && currentDisplayName === originalDisplayName) { + setAgentDisplayNameStatus(NAME_CHECK_STATUS.AVAILABLE); + } + }, [isCreatingNewAgent, currentDisplayName, originalDisplayName]); + // Handle delete confirmation const handleDeleteConfirm = useCallback(() => { setIsDeleteModalVisible(false); @@ -461,12 +536,26 @@ export default function AgentConfigModal({ }, [agentDisplayName, isEditingMode, validateAgentDisplayName]); // Calculate whether save buttons should be enabled + // Check real-time status when: + // 1. Creating new agent, OR + // 2. 
Editing and name/display name has changed + const shouldCheckNameStatus = isCreatingNewAgent || currentAgentName !== originalAgentName; + const shouldCheckDisplayNameStatus = isCreatingNewAgent || currentDisplayName !== originalDisplayName; + + // Disable save if there are any error indicators from backend (unavailable_reasons) + // These errors should block saving even if names haven't changed + const hasBackendErrors = + shouldShowDuplicateNameReason || + shouldShowDuplicateDisplayNameReason || + shouldShowModelUnavailableReason; + const canActuallySave = canSaveAgent && !agentNameError && - agentNameStatus !== NAME_CHECK_STATUS.EXISTS_IN_TENANT && + (shouldCheckNameStatus ? agentNameStatus !== NAME_CHECK_STATUS.EXISTS_IN_TENANT : true) && !agentDisplayNameError && - agentDisplayNameStatus !== NAME_CHECK_STATUS.EXISTS_IN_TENANT; + (shouldCheckDisplayNameStatus ? agentDisplayNameStatus !== NAME_CHECK_STATUS.EXISTS_IN_TENANT : true) && + !hasBackendErrors; // Render individual content sections const renderAgentInfo = () => ( @@ -488,7 +577,9 @@ export default function AgentConfigModal({ disabled={!isEditingMode} status={ agentDisplayNameError || - agentDisplayNameStatus === NAME_CHECK_STATUS.EXISTS_IN_TENANT + ((isCreatingNewAgent || currentDisplayName !== originalDisplayName) && + agentDisplayNameStatus === NAME_CHECK_STATUS.EXISTS_IN_TENANT) || + shouldShowDuplicateDisplayNameReason ? "error" : "" } @@ -497,6 +588,7 @@ export default function AgentConfigModal({

{agentDisplayNameError}

)} {!agentDisplayNameError && + (isCreatingNewAgent || currentDisplayName !== originalDisplayName) && agentDisplayNameStatus === NAME_CHECK_STATUS.EXISTS_IN_TENANT && (

{t("agent.error.displayNameExists", { @@ -504,6 +596,15 @@ export default function AgentConfigModal({ })}

)} + {!agentDisplayNameError && + agentDisplayNameStatus !== NAME_CHECK_STATUS.EXISTS_IN_TENANT && + shouldShowDuplicateDisplayNameReason && ( +

+ {t("agent.error.displayNameExists", { + displayName: agentDisplayName || editingAgent?.display_name || "", + })} +

+ )} {/* Agent Name */} @@ -521,7 +622,9 @@ export default function AgentConfigModal({ disabled={!isEditingMode} status={ agentNameError || - agentNameStatus === NAME_CHECK_STATUS.EXISTS_IN_TENANT + ((isCreatingNewAgent || currentAgentName !== originalAgentName) && + agentNameStatus === NAME_CHECK_STATUS.EXISTS_IN_TENANT) || + shouldShowDuplicateNameReason ? "error" : "" } @@ -530,11 +633,21 @@ export default function AgentConfigModal({

{agentNameError}

)} {!agentNameError && + (isCreatingNewAgent || currentAgentName !== originalAgentName) && agentNameStatus === NAME_CHECK_STATUS.EXISTS_IN_TENANT && (

{t("agent.error.nameExists", { name: agentName })}

)} + {!agentNameError && + agentNameStatus !== NAME_CHECK_STATUS.EXISTS_IN_TENANT && + shouldShowDuplicateNameReason && ( +

+ {t("agent.error.nameExists", { + name: agentName || editingAgent?.name || "", + })} +

+ )} {/* Model Selection */} @@ -543,6 +656,7 @@ export default function AgentConfigModal({ {t("businessLogic.config.model")}: + {shouldShowModelUnavailableReason && ( +

+ {t("agent.error.modelUnavailable", { + modelName: effectiveModelName, + })} +

+ )} {llmModels.length === 0 && (

{t("businessLogic.config.error.noAvailableModels")} diff --git a/frontend/app/[locale]/setup/agents/components/agent/CollaborativeAgentDisplay.tsx b/frontend/app/[locale]/agents/components/agent/CollaborativeAgentDisplay.tsx similarity index 93% rename from frontend/app/[locale]/setup/agents/components/agent/CollaborativeAgentDisplay.tsx rename to frontend/app/[locale]/agents/components/agent/CollaborativeAgentDisplay.tsx index 822c12089..9659ee4a0 100644 --- a/frontend/app/[locale]/setup/agents/components/agent/CollaborativeAgentDisplay.tsx +++ b/frontend/app/[locale]/agents/components/agent/CollaborativeAgentDisplay.tsx @@ -183,23 +183,21 @@ export default function CollaborativeAgentDisplay({ handleRemoveCollaborativeAgent(Number(agent.id))} closeIcon={} style={{ - fontSize: "9px", - padding: "1px 2px", - lineHeight: "1.2", - height: "auto", - minHeight: "16px", - borderRadius: "4px", - display: "inline-flex", - alignItems: "center", - justifyContent: "center", + maxWidth: "200px", }} > - + {agent.display_name || agent.name} @@ -209,3 +207,4 @@ export default function CollaborativeAgentDisplay({ ); } + diff --git a/frontend/app/[locale]/setup/agents/components/agent/SubAgentPool.tsx b/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx similarity index 83% rename from frontend/app/[locale]/setup/agents/components/agent/SubAgentPool.tsx rename to frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx index 1b0e06ef8..6314a75f4 100644 --- a/frontend/app/[locale]/setup/agents/components/agent/SubAgentPool.tsx +++ b/frontend/app/[locale]/agents/components/agent/SubAgentPool.tsx @@ -1,9 +1,9 @@ "use client"; -import { useState, useMemo } from "react"; +import { useState } from "react"; import { useTranslation } from "react-i18next"; -import { App, Button } from "antd"; +import { Button } from "antd"; import { ExclamationCircleOutlined } from "@ant-design/icons"; import { FileOutput, Network, FileInput, Trash2, Plus, X } from "lucide-react"; @@ -16,7 
+16,7 @@ import { } from "@/components/ui/tooltip"; import { Agent, SubAgentPoolProps } from "@/types/agentConfig"; -import AgentCallRelationshipModal from "./AgentCallRelationshipModal"; +import AgentCallRelationshipModal from "@/components/ui/AgentCallRelationshipModal"; /** * Sub Agent Pool Component @@ -37,13 +37,11 @@ export default function SubAgentPool({ isGeneratingAgent = false, editingAgent = null, isCreatingNewAgent = false, - editingAgentName = null, onExportAgent, onDeleteAgent, unsavedAgentId = null, }: ExtendedSubAgentPoolProps) { const { t } = useTranslation("common"); - const { message } = App.useApp(); // Call relationship related state const [callRelationshipModalVisible, setCallRelationshipModalVisible] = @@ -63,53 +61,6 @@ export default function SubAgentPool({ setSelectedAgentForRelationship(null); }; - // Detect duplicate agent names and mark later-added agents as disabled - // For agents with the same name, keep the first one (smallest ID) enabled, disable the rest - const duplicateAgentInfo = useMemo(() => { - // Create a map to track agents by name - const nameToAgents = new Map(); - - subAgentList.forEach((agent) => { - // Use the current editing name if this agent is being edited, otherwise use the original name - const agentName = - editingAgent && - String(editingAgent.id) === String(agent.id) && - editingAgentName - ? 
editingAgentName - : agent.name; - - if (!nameToAgents.has(agentName)) { - nameToAgents.set(agentName, []); - } - nameToAgents.get(agentName)!.push(agent); - }); - - // For each group of agents with the same name, sort by ID (smallest first) - // Mark all except the first one as disabled - const disabledAgentIds = new Set(); - - nameToAgents.forEach((agents, name) => { - if (agents.length > 1) { - // Sort by ID (treating as number if possible, otherwise as string) - const sortedAgents = [...agents].sort((a, b) => { - const idA = Number(a.id); - const idB = Number(b.id); - if (!isNaN(idA) && !isNaN(idB)) { - return idA - idB; - } - return String(a.id).localeCompare(String(b.id)); - }); - - // Mark all except the first one as disabled - for (let i = 1; i < sortedAgents.length; i++) { - disabledAgentIds.add(String(sortedAgents[i].id)); - } - } - }); - - return { disabledAgentIds, nameToAgents }; - }, [subAgentList, editingAgent, editingAgentName]); - return ( + + ); +} + diff --git a/frontend/components/ui/markdownRenderer.tsx b/frontend/components/ui/markdownRenderer.tsx index 29ec151be..a96fbaebd 100644 --- a/frontend/components/ui/markdownRenderer.tsx +++ b/frontend/components/ui/markdownRenderer.tsx @@ -29,8 +29,31 @@ interface MarkdownRendererProps { searchResults?: SearchResult[]; showDiagramToggle?: boolean; onCitationHover?: () => void; + enableMultimodal?: boolean; } +const VIDEO_EXTENSIONS = [".mp4", ".webm", ".ogg", ".mov", ".m4v"]; + +const extractExtension = (value: string): string => { + const normalized = value.split("?")[0].split("#")[0]; + const match = normalized.toLowerCase().match(/\.[a-z0-9]+$/); + return match?.[0] ?? 
""; +}; + +const isVideoUrl = (url?: string): boolean => { + if (!url) { + return false; + } + + const trimmed = url.trim(); + if (!trimmed.startsWith("http://") && !trimmed.startsWith("https://")) { + return false; + } + + const extension = extractExtension(trimmed); + return VIDEO_EXTENSIONS.includes(extension); +}; + // Get background color for different tool signs const getBackgroundColor = (toolSign: string) => { switch (toolSign) { @@ -364,6 +387,7 @@ export const MarkdownRenderer: React.FC = ({ searchResults = [], showDiagramToggle = true, onCitationHover, + enableMultimodal = true, }) => { const { t } = useTranslation("common"); @@ -407,6 +431,71 @@ export const MarkdownRenderer: React.FC = ({ }, }; + const renderCodeFallback = (text: string, key?: React.Key) => ( + + {text} + + ); + + const buildMediaFallbackText = (src?: string | null, alt?: string | null) => { + if (alt) { + return `${t("chatStreamMessage.imageTextFallbackTitle", { + defaultValue: "Media (text view)", + })}: ${alt}${src ? ` - ${src}` : ""}`; + } + return ( + src ?? + t("chatStreamMessage.imageTextFallbackTitle", { + defaultValue: "Media (text view)", + }) + ); + }; + + const renderMediaFallback = (src?: string | null, alt?: string | null) => + renderCodeFallback(buildMediaFallbackText(src, alt)); + + const renderVideoElement = ({ + src, + alt, + props = {}, + }: { + src?: string | null; + alt?: string | null; + props?: React.VideoHTMLAttributes; + }) => { + if (!src) { + return null; + } + + if (!enableMultimodal) { + return renderMediaFallback(src, alt); + } + + return ( +

+ + {alt ? ( +
{alt}
+ ) : null} +
+ ); + }; + // Modified processText function logic const processText = (text: string) => { if (typeof text !== "string") return text; @@ -445,6 +534,9 @@ export const MarkdownRenderer: React.FC = ({ const mmd = part.match(/^:mermaid\[([^\]]+)\]$/); if (mmd) { const code = mmd[1]; + if (!enableMultimodal) { + return renderCodeFallback(code, `mmd-placeholder-${index}`); + } return ; } // Handle line breaks in text content @@ -627,11 +719,13 @@ export const MarkdownRenderer: React.FC = ({ ), // Link - a: ({ href, children, ...props }: any) => ( - - {children} - - ), + a: ({ href, children, ...props }: any) => { + return ( + + {children} + + ); + }, pre: ({ children }: any) => <>{children}, // Code blocks and inline code code({ node, inline, className, children, ...props }: any) { @@ -644,6 +738,9 @@ export const MarkdownRenderer: React.FC = ({ if (match && match[1]) { // Check if it's a Mermaid diagram if (match[1] === "mermaid") { + if (!enableMultimodal) { + return renderCodeFallback(codeContent); + } return ; } if (!inline) { @@ -689,10 +786,40 @@ export const MarkdownRenderer: React.FC = ({ ); }, - // Image - img: ({ src, alt }: any) => ( - {alt} - ), + // Image (also handles video previews emitted as image markdown) + img: ({ src, alt }: any) => { + if (!enableMultimodal) { + return renderMediaFallback(src, alt); + } + + if (isVideoUrl(src)) { + return renderVideoElement({ src, alt }); + } + + return {alt}; + }, + video: ({ children, ...props }: any) => { + const directSrc = props?.src; + const childSource = React.Children.toArray(children) + .map((child) => + React.isValidElement(child) ? child.props?.src : undefined + ) + .find(Boolean); + const videoSrc = directSrc ?? childSource; + const caption = + props?.["aria-label"] ?? + props?.title ?? + props?.["data-caption"] ?? + undefined; + + const element = renderVideoElement({ + src: videoSrc, + alt: caption, + props, + }); + + return element ?? 
renderMediaFallback(undefined, caption); + }, }} > {processedContent} @@ -702,3 +829,5 @@ export const MarkdownRenderer: React.FC = ({ ); }; + + \ No newline at end of file diff --git a/frontend/components/ui/navbar.tsx b/frontend/components/ui/navbar.tsx new file mode 100644 index 000000000..45a938e56 --- /dev/null +++ b/frontend/components/ui/navbar.tsx @@ -0,0 +1,129 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { AvatarDropdown } from "@/components/auth/avatarDropdown"; +import { useTranslation } from "react-i18next"; +import { useAuth } from "@/hooks/useAuth"; +import { Globe } from "lucide-react"; +import { Dropdown } from "antd"; +import { DownOutlined } from "@ant-design/icons"; +import Link from "next/link"; +import { HEADER_CONFIG } from "@/const/layoutConstants"; +import { languageOptions } from "@/const/constants"; +import { useLanguageSwitch } from "@/lib/language"; + +/** + * Main navigation bar component + * Displays logo, navigation links, language switcher, and user authentication status + */ +export function Navbar() { + const { t } = useTranslation("common"); + const { user, isLoading: userLoading, isSpeedMode } = useAuth(); + const { currentLanguage, handleLanguageChange } = useLanguageSwitch(); + + return ( +
+ {/* Left section - Logo */} + +

+ ModelEngine + + {t("assistant.name")} + +

+ + + {/* Right section - Navigation links and user controls */} +
+ {/* GitHub link */} + + + Github + + + {/* ModelEngine link */} + + ModelEngine + + + {/* Language switcher */} + ({ + key: opt.value, + label: opt.label, + })), + onClick: ({ key }) => handleLanguageChange(key as string), + }} + > + + + {languageOptions.find((o) => o.value === currentLanguage)?.label || + currentLanguage} + + + + + {/* User status - only shown in full version */} + {!isSpeedMode && ( + <> + {userLoading ? ( + + {t("common.loading")}... + + ) : user ? ( + + {user.email} + + ) : null} + + + )} +
+ + {/* Mobile hamburger menu button */} + +
+ ); +} + diff --git a/frontend/const/layoutConstants.ts b/frontend/const/layoutConstants.ts index b9d170330..ccc4fa80b 100644 --- a/frontend/const/layoutConstants.ts +++ b/frontend/const/layoutConstants.ts @@ -5,8 +5,11 @@ // Header configuration export const HEADER_CONFIG = { - // Header height (including padding) - HEIGHT: "64px", + // Actual displayed height (including padding) + DISPLAY_HEIGHT: "55px", + + // Space reserved for layout calculation (may be larger than display height) + RESERVED_HEIGHT: "55px", // Vertical padding VERTICAL_PADDING: "16px", // py-4 @@ -17,11 +20,14 @@ export const HEADER_CONFIG = { // Footer configuration export const FOOTER_CONFIG = { - // Footer height (including padding) - HEIGHT: "64px", + // Actual displayed height (including padding) + DISPLAY_HEIGHT: "40px", + + // Space reserved for layout calculation (smaller than header, no extra space) + RESERVED_HEIGHT: "40px", // Vertical padding - VERTICAL_PADDING: "16px", // py-4 + VERTICAL_PADDING: "12px", // py-3 // Horizontal padding HORIZONTAL_PADDING: "16px", // px-4 diff --git a/frontend/hooks/useSetupFlow.ts b/frontend/hooks/useSetupFlow.ts new file mode 100644 index 000000000..39c152e45 --- /dev/null +++ b/frontend/hooks/useSetupFlow.ts @@ -0,0 +1,210 @@ +import {useState, useEffect, useRef} from "react"; +import {useRouter} from "next/navigation"; +import {useTranslation} from "react-i18next"; + +import {useAuth} from "@/hooks/useAuth"; +import modelEngineService from "@/services/modelEngineService"; +import { + USER_ROLES, + CONNECTION_STATUS, + ConnectionStatus, +} from "@/const/modelConfig"; +import {EVENTS} from "@/const/auth"; +import log from "@/lib/logger"; + +interface UseSetupFlowOptions { + /** Whether admin role is required to access this page */ + requireAdmin?: boolean; + /** External connection status (if managed by parent) */ + externalConnectionStatus?: ConnectionStatus; + /** External checking connection state (if managed by parent) */ + 
externalIsCheckingConnection?: boolean; + /** External check connection handler (if managed by parent) */ + onCheckConnection?: () => void; + /** Callback to expose connection status changes */ + onConnectionStatusChange?: (status: ConnectionStatus) => void; + /** Redirect path for non-admin users */ + nonAdminRedirect?: string; +} + +interface UseSetupFlowReturn { + // Auth related + user: any; + isLoading: boolean; + isSpeedMode: boolean; + canAccessProtectedData: boolean; + + // Connection status + connectionStatus: ConnectionStatus; + isCheckingConnection: boolean; + checkModelEngineConnection: () => Promise; + + // Animation config + pageVariants: { + initial: { opacity: number; x: number }; + in: { opacity: number; x: number }; + out: { opacity: number; x: number }; + }; + pageTransition: { + type: "tween"; + ease: "anticipate"; + duration: number; + }; + + // Utilities + router: ReturnType; + t: ReturnType["t"]; +} + +/** + * useSetupFlow - Custom hook for setup flow pages + * + * Provides common functionality for setup pages including: + * - Authentication and permission checks + * - Connection status management + * - Session expiration handling + * - Page transition animations + * + * @param options - Configuration options + * @returns Setup flow utilities and state + */ +export function useSetupFlow(options: UseSetupFlowOptions = {}): UseSetupFlowReturn { + const { + requireAdmin = false, + externalConnectionStatus, + externalIsCheckingConnection, + onCheckConnection: externalOnCheckConnection, + onConnectionStatusChange, + nonAdminRedirect = "/setup/knowledges", + } = options; + + const router = useRouter(); + const {t} = useTranslation(); + const {user, isLoading: userLoading, isSpeedMode} = useAuth(); + const sessionExpiredTriggeredRef = useRef(false); + + // Calculate if user can access protected data + const canAccessProtectedData = isSpeedMode || (!userLoading && !!user); + + // Internal connection status management (if not provided externally) + 
const [internalConnectionStatus, setInternalConnectionStatus] = useState( + CONNECTION_STATUS.PROCESSING + ); + const [internalIsCheckingConnection, setInternalIsCheckingConnection] = useState(false); + + // Use external status if provided, otherwise use internal + const connectionStatus = externalConnectionStatus ?? internalConnectionStatus; + const isCheckingConnection = externalIsCheckingConnection ?? internalIsCheckingConnection; + + // Check login status and handle session expiration + useEffect(() => { + if (isSpeedMode) { + sessionExpiredTriggeredRef.current = false; + return; + } + + if (user) { + sessionExpiredTriggeredRef.current = false; + return; + } + + // Trigger session expired event if user is not logged in + if (!userLoading && !sessionExpiredTriggeredRef.current) { + sessionExpiredTriggeredRef.current = true; + window.dispatchEvent( + new CustomEvent(EVENTS.SESSION_EXPIRED, { + detail: {message: "Session expired, please sign in again"}, + }) + ); + } + }, [isSpeedMode, user, userLoading]); + + // Check admin permission if required + useEffect(() => { + if (!requireAdmin) return; + + // Only check after user is loaded + if (userLoading) return; + + // Speed mode always has access + if (isSpeedMode) return; + + // Check if user has admin role + if (user && user.role !== USER_ROLES.ADMIN) { + router.push(nonAdminRedirect); + } + }, [requireAdmin, isSpeedMode, user, userLoading, router, nonAdminRedirect]); + + // Internal check connection function + const checkModelEngineConnectionInternal = async () => { + setInternalIsCheckingConnection(true); + + try { + const result = await modelEngineService.checkConnection(); + setInternalConnectionStatus(result.status); + onConnectionStatusChange?.(result.status); + } catch (error) { + log.error(t("setup.page.error.checkConnection"), error); + setInternalConnectionStatus(CONNECTION_STATUS.ERROR); + onConnectionStatusChange?.(CONNECTION_STATUS.ERROR); + } finally { + setInternalIsCheckingConnection(false); + } + 
}; + + // Use external handler if provided, otherwise use internal + const checkModelEngineConnection = externalOnCheckConnection + ? async () => { await Promise.resolve(externalOnCheckConnection()); } + : checkModelEngineConnectionInternal; + + // Check connection on mount if not externally managed + useEffect(() => { + if (canAccessProtectedData && !externalOnCheckConnection) { + checkModelEngineConnectionInternal(); + } + }, [canAccessProtectedData, externalOnCheckConnection]); + + // Animation variants for smooth page transitions + const pageVariants = { + initial: { + opacity: 0, + x: 20, + }, + in: { + opacity: 1, + x: 0, + }, + out: { + opacity: 0, + x: -20, + }, + }; + + const pageTransition = { + type: "tween" as const, + ease: "anticipate" as const, + duration: 0.4, + }; + + return { + // Auth + user, + isLoading: userLoading, + isSpeedMode, + canAccessProtectedData, + + // Connection + connectionStatus, + isCheckingConnection, + checkModelEngineConnection, + + // Animation + pageVariants, + pageTransition, + + // Utilities + router, + t, + }; +} + diff --git a/frontend/lib/auth.ts b/frontend/lib/auth.ts index c6bda9652..fae920cd3 100644 --- a/frontend/lib/auth.ts +++ b/frontend/lib/auth.ts @@ -2,12 +2,10 @@ * Authentication utilities */ -import { createAvatar } from '@dicebear/core'; -import * as initialsStyle from '@dicebear/initials'; - import { fetchWithErrorHandling } from "@/services/api"; import { STORAGE_KEYS } from "@/const/auth"; import { Session } from "@/types/auth"; +import { generateAvatarUrl as generateAvatar } from "@/lib/avatar"; import log from "@/lib/logger"; // Get color corresponding to user role @@ -21,17 +19,10 @@ export function getRoleColor(role: string): string { } } -// Generate avatar based on email +// Generate avatar based on email (re-export from avatar.tsx for backward compatibility) export function generateAvatarUrl(email: string): string { - // Use local dicebear package to generate avatar - const avatar = 
createAvatar(initialsStyle, { - seed: email, - backgroundType: ['gradientLinear'] - }); - - // Return SVG data URI - return avatar.toDataUri(); - } + return generateAvatar(email); +} /** * Request with authorization headers @@ -64,7 +55,7 @@ export const saveSessionToStorage = (session: Session) => { }; /** - * 从本地存储删除会话 + * Remove session from local storage */ export const removeSessionFromStorage = () => { if (typeof window !== "undefined") { @@ -73,7 +64,7 @@ export const removeSessionFromStorage = () => { }; /** - * 从本地存储获取会话 + * Get session from local storage */ export const getSessionFromStorage = (): Session | null => { try { @@ -82,7 +73,7 @@ export const getSessionFromStorage = (): Session | null => { return JSON.parse(storedSession); } catch (error) { - log.error("解析会话信息失败:", error); + log.error("Failed to parse session info:", error); return null; } }; diff --git a/frontend/lib/avatar.tsx b/frontend/lib/avatar.tsx index 59e913fc1..badde189b 100644 --- a/frontend/lib/avatar.tsx +++ b/frontend/lib/avatar.tsx @@ -190,6 +190,53 @@ function hslToRgb(h: number, s: number, l: number): [number, number, number] { return [Math.round(r * 255), Math.round(g * 255), Math.round(b * 255)]; } +/** + * Generate avatar from name (for agent avatars) + * @param name Agent name or display name + * @param size Avatar size (default: 30) + * @param scale Scale percentage (default: 80) + * @returns Generated avatar data URI + */ +export const generateAvatarFromName = (name: string, size: number = 30, scale: number = 80): string => { + // Use name as seed to generate consistent color + const seed = name || "default"; + const random = new SeededRandom(seed); + + // Generate main color from name + const r = random.randomInt(50, 200); + const g = random.randomInt(50, 200); + const b = random.randomInt(50, 200); + const mainColor = ((1 << 24) + (r << 16) + (g << 8) + b).toString(16).slice(1); + const secondaryColor = generateComplementaryColor(mainColor); + + // Select icon based on 
name + const iconIndex = random.randomInt(0, presetIcons.length - 1); + const selectedIcon = presetIcons[iconIndex]; + + const avatar = createAvatar(iconStyle, { + seed: seed, + backgroundColor: [mainColor, secondaryColor], + backgroundType: ["gradientLinear"], + icon: [selectedIcon.key], + scale: scale, + size: size, + radius: 50 + }); + + return avatar.toDataUri(); +}; + +/** + * Generate avatar from email or identifier (for user avatars) + * @param identifier Email or other identifier + * @param size Avatar size (default: 30) + * @param scale Scale percentage (default: 80) + * @returns Generated avatar data URI + */ +export const generateAvatarUrl = (identifier: string, size: number = 30, scale: number = 80): string => { + return generateAvatarFromName(identifier, size, scale); +}; + /** * Extract main and secondary colors from Dicebear generated Data URI, reserved for app name use * @param dataUri Dicebear generated avatar data URI diff --git a/frontend/lib/viewPersistence.ts b/frontend/lib/viewPersistence.ts new file mode 100644 index 000000000..a08b9f2fe --- /dev/null +++ b/frontend/lib/viewPersistence.ts @@ -0,0 +1,86 @@ +/** + * View persistence utilities for managing current view state across page refreshes + * Uses localStorage to persist the current view selection + */ + +const VIEW_STORAGE_KEY = 'nexent_current_view'; + +type ViewType = + | "home" + | "memory" + | "models" + | "agents" + | "knowledges" + | "space" + | "setup" + | "chat" + | "market" + | "users"; + +const VALID_VIEWS: ViewType[] = [ + "home", + "memory", + "models", + "agents", + "knowledges", + "space", + "setup", + "chat", + "market", + "users", +]; + +/** + * Get the saved view from localStorage + * @returns The saved view or "home" as default + */ +export function getSavedView(): ViewType { + if (typeof window === 'undefined') { + return "home"; + } + + try { + const savedView = localStorage.getItem(VIEW_STORAGE_KEY); + if (savedView && VALID_VIEWS.includes(savedView as ViewType)) { 
+ return savedView as ViewType; + } + } catch (error) { + // localStorage might be disabled or throw errors + console.warn('Failed to read saved view from localStorage:', error); + } + + return "home"; +} + +/** + * Save the current view to localStorage + * @param view The view to save + */ +export function saveView(view: ViewType): void { + if (typeof window === 'undefined') { + return; + } + + try { + localStorage.setItem(VIEW_STORAGE_KEY, view); + } catch (error) { + // localStorage might be disabled or throw errors + console.warn('Failed to save view to localStorage:', error); + } +} + +/** + * Clear the saved view from localStorage + */ +export function clearSavedView(): void { + if (typeof window === 'undefined') { + return; + } + + try { + localStorage.removeItem(VIEW_STORAGE_KEY); + } catch (error) { + console.warn('Failed to clear saved view from localStorage:', error); + } +} + diff --git a/frontend/next.config.mjs b/frontend/next.config.mjs index 524ae5161..a44561b9c 100644 --- a/frontend/next.config.mjs +++ b/frontend/next.config.mjs @@ -21,8 +21,6 @@ const nextConfig = { parallelServerBuildTraces: true, parallelServerCompiles: true, }, - // Optimize resource preloading - optimizeCss: true, compress: true, } diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 1677820d5..5171cbd5c 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -118,6 +118,7 @@ "page.description": "No orchestration, no complex drag-and-drop required. Integrate data, models, and tools into one intelligent hub.", "page.startChat": "Start Chatting", "page.quickConfig": "Quick Setup", + "page.agentSpace": "Agent Space", "page.dataProtection": "Free trial environment does not retain data. 
Data may be lost during updates - please take note.", "page.coreFeatures": "Core Features", "page.features": [ @@ -290,6 +291,7 @@ "agent.action.viewCallRelationship": "View Call Relationship", "agent.error.nameExists": "Agent var name {{name}} already exists, please modify", "agent.error.displayNameExists": "Agent name {{displayName}} already exists, please modify", + "agent.error.modelUnavailable": "LLM {{modelName}} is unavailable, please modify", "agent.debug.placeholder": "Enter test question...", "agent.debug.stop": "Stop", "agent.debug.send": "Send", @@ -335,6 +337,7 @@ "subAgentPool.button.create": "Create Agent", "subAgentPool.button.import": "Import Agent", "subAgentPool.button.importing": "Importing...", + "subAgentPool.message.unavailable": "This Agent is unavailable", "subAgentPool.tooltip.unavailableAgent": "Agent is unavailable", "subAgentPool.tooltip.hasUnavailableTools": "This agent has been disabled because it contains unavailable tools. Please modify the tool configuration before using it", "subAgentPool.section.agentList": "Agent List", @@ -346,7 +349,7 @@ "subAgentPool.tooltip.exitCreateMode": "Click to exit create mode", "subAgentPool.tooltip.exitEditMode": "Click to exit edit mode", "subAgentPool.tooltip.editAgent": "Click to edit", - "subAgentPool.tooltip.duplicateNameDisabled": "This agent is disabled due to duplicate name with other agents. Please change the name to use it", + "subAgentPool.tooltip.duplicateNameDisabled": "Agent name already exists", "subAgentPool.message.duplicateNameDisabled": "This agent is disabled due to duplicate name with other agents. 
Please change the name to use it", "toolConfig.title.paramConfig": "Parameter Configuration", @@ -381,6 +384,7 @@ "toolPool.tooltip.viewOnlyMode": "View mode, cannot select tools", "toolPool.message.unavailable": "This tool is unavailable", "toolPool.message.viewOnlyMode": "Currently in view mode, cannot select tools", + "toolPool.error.unavailableSelected": "Agent contains unavailable tools, please modify", "toolPool.tag.mcp": "MCP Tool", "toolPool.tag.local": "Local Tool", "toolPool.tag.langchain": "LangChain Tool", @@ -397,6 +401,11 @@ "common.confirm": "Confirm", "common.enabled": "enabled", "common.disabled": "disabled", + "common.yes": "Yes", + "common.no": "No", + "common.none": "None", + "common.required": "Required", + "common.refresh": "Refresh", "common.unknownError": "Unknown error", "common.retryLater": "Please try again later", "common.back": "Back", @@ -413,6 +422,7 @@ "tool.message.unavailable": "This tool is currently unavailable and cannot be selected", "tool.error.noMainAgentId": "Main agent ID is not set, cannot update tool status", "tool.error.configFetchFailed": "Failed to get tool configuration", + "tool.message.statusUpdated": "Tool {{name}} has been {{status}}", "tool.error.updateFailed": "Failed to update tool status", "tool.error.updateRetry": "Failed to update tool status, please try again later", @@ -499,6 +509,8 @@ "document.status.loadingList": "Loading document list...", "document.input.knowledgeBaseName": "Please enter knowledge base name", "document.button.details": "Details", + "document.button.overview": "Overview", + "document.button.detail": "Chunk Details", "document.button.autoSummary": "Auto Summary", "document.title.createNew": "Create New Knowledge Base", "document.hint.uploadToCreate": "Please select files to upload to complete knowledge base creation", @@ -519,6 +531,10 @@ "document.modal.deleteConfirm.content": "Are you sure you want to delete this document? 
This action cannot be undone.", "document.message.noFiles": "Please select files first", "document.message.uploadError": "Failed to upload files", + "document.chunk.noChunks": "No chunks available", + "document.chunk.characterCount": "{{count}} characters", + "document.chunk.error.loadFailed": "Failed to load chunks", + "document.chunk.error.downloadFailed": "Failed to download chunk", "model.dialog.title": "Add Model", "model.dialog.label.type": "Model Type", @@ -546,6 +562,8 @@ "model.dialog.hint.batchImportDisabled": "Batch add disabled. Only a single model will be added.", "model.provider.silicon": "SiliconFlow", "model.dialog.modelList.title": "Show Models", + "model.dialog.modelList.searchPlaceholder": "Search models by name", + "model.dialog.modelList.noResults": "No models match your search", "model.dialog.connectivity.title": "Connectivity Verification", "model.dialog.connectivity.status.checking": "Detecting", "model.dialog.connectivity.status.available": "Available", @@ -576,6 +594,7 @@ "model.dialog.editTitle": "Edit Model", "model.dialog.editSuccess": "Model updated successfully", "model.dialog.error.editFailed": "Failed to update model", + "model.dialog.error.nameConflict": "Name '{{name}}' is already in use, please choose another display name", "model.dialog.error.serverError": "Server internal error, please try again later", "model.type.llm": "Large Language Model", "model.type.embedding": "Embedding Model", @@ -928,10 +947,10 @@ "memoryService.loadMemoryError": "Failed to load memories, please try again later", "memoryService.tenantSharedGroupTitle": "Tenant shared memories", - "memoryService.agentSharedGroupTitle": "Shared memories of all users in {{agentName}}", + "memoryService.agentSharedGroupTitle": "{{agentName}}", "memoryService.agentSharedPlaceholder": "Agent Shared Memories", "memoryService.userPersonalGroupTitle": "User's personal memories", - "memoryService.userAgentGroupTitle": "User memories in {{agentName}}", + 
"memoryService.userAgentGroupTitle": "{{agentName}}", "memoryService.userAgentPlaceholder": "User Agent Memories", "memoryManageModal.title": "Memory Management", @@ -968,5 +987,78 @@ "diagram.format.svg": "SVG", "diagram.format.png": "PNG", "diagram.format.selectFormat": "Select Format", - "diagram.error.renderFailed": "Render Failed" + "diagram.error.renderFailed": "Render Failed", + + "space.title": "Agent Space", + "space.description": "Manage and interact with your intelligent agents", + "space.createAgent": "Create New Agent", + "space.noAgents": "No agents yet. Create your first agent to get started!", + "space.noDescription": "No description", + "space.status.available": "Available", + "space.status.unavailable": "Unavailable", + "space.deleteConfirm.title": "Delete Agent", + "space.deleteConfirm.content": "Are you sure you want to delete this agent? This action cannot be undone.", + "space.deleteSuccess": "Agent deleted successfully", + "space.exportSuccess": "Agent exported successfully", + "space.actions.edit": "Edit", + "space.actions.delete": "Delete", + "space.actions.export": "Export", + "space.actions.relationship": "View Relationships", + "space.actions.chat": "Chat", + + "space.detail.title": "Agent Details", + "space.detail.subtitle": "Detailed configuration and information", + "space.detail.tabs.basic": "Basic Info", + "space.detail.tabs.model": "Model Config", + "space.detail.tabs.prompts": "Prompts", + "space.detail.tabs.tools": "Tools", + "space.detail.tabs.subAgents": "Sub Agents", + "space.detail.id": "Agent ID", + "space.detail.name": "Name", + "space.detail.displayName": "Display Name", + "space.detail.description": "Description", + "space.detail.status": "Status", + "space.detail.enabled": "Enabled", + "space.detail.model": "Model Name", + "space.detail.modelId": "Model ID", + "space.detail.maxStep": "Max Steps", + "space.detail.businessLogicModel": "Business Logic Model", + "space.detail.businessLogicModelId": "Business Logic Model ID", 
+ "space.detail.provideRunSummary": "Provide Run Summary", + "space.detail.dutyPrompt": "Duty Prompt", + "space.detail.constraintPrompt": "Constraint Prompt", + "space.detail.fewShotsPrompt": "Few-Shots Prompt", + "space.detail.businessDescription": "Business Description", + "space.detail.noTools": "No tools configured", + "space.detail.noSubAgents": "No sub agents configured", + "space.detail.subAgentId": "Sub Agent ID", + "space.detail.source": "Source", + "space.detail.category": "Category", + "space.detail.usage": "Usage", + "space.detail.parameters": "Parameters", + + "sidebar.homePage": "Home Page", + "sidebar.startChat": "Start Chat", + "sidebar.quickConfig": "Quick Setup", + "sidebar.agentSpace": "Agent Space", + "sidebar.agentMarket": "Agent Market", + "sidebar.agentDev": "Agent Development", + "sidebar.knowledgeBase": "Knowledge Base", + "sidebar.modelManagement": "Model Management", + "sidebar.memoryManagement": "Memory Management", + "sidebar.userManagement": "User Management", + + "market.comingSoon.title": "Agent Market Coming Soon", + "market.comingSoon.description": "Discover and install pre-built AI agents from our marketplace. Save time by leveraging community-created solutions.", + "market.comingSoon.feature1": "Browse curated agent templates", + "market.comingSoon.feature2": "One-click installation and deployment", + "market.comingSoon.feature3": "Share your own agents with the community", + "market.comingSoon.badge": "Coming Soon", + + "users.comingSoon.title": "User Management Coming Soon", + "users.comingSoon.description": "Comprehensive user management system for administrators. 
Control access, roles, and permissions across your organization.", + "users.comingSoon.feature1": "Manage user accounts and roles", + "users.comingSoon.feature2": "Configure fine-grained permissions", + "users.comingSoon.feature3": "Monitor user activity and usage", + "users.comingSoon.badge": "Coming Soon" } diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 2aad5d809..656910f8b 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -118,6 +118,7 @@ "page.description": "无需编排,无需复杂拖拉拽,将数据、模型和工具整合到一个智能中心中。", "page.startChat": "开始问答", "page.quickConfig": "快速配置", + "page.agentSpace": "智能体空间", "page.dataProtection": "免费试用环境不做数据留存,数据可能随更新丢失,请注意", "page.coreFeatures": "核心功能", "page.features": [ @@ -146,7 +147,7 @@ "description": "基于多模态知识库、数据处理能力,提供多模态的智能体服务,支持文本、图像、音频等多种数据类型的输入输出。" } ], - "page.copyright": "Nexent 智能体 © {{year}}", + "page.copyright": "Nexent © {{year}}", "page.termsOfUse": "使用条款", "page.loginPrompt.title": "登录账号", "page.loginPrompt.register": "注册", @@ -291,6 +292,7 @@ "agent.action.viewCallRelationship": "查看调用关系", "agent.error.nameExists": "Agent变量名{{name}}已存在,请修改", "agent.error.displayNameExists": "Agent名称{{displayName}}已存在,请修改", + "agent.error.modelUnavailable": "大语言模型{{modelName}}不可用,请修改", "agent.debug.placeholder": "输入测试问题...", "agent.debug.stop": "停止", "agent.debug.send": "发送", @@ -336,6 +338,7 @@ "subAgentPool.button.create": "创建Agent", "subAgentPool.button.import": "导入Agent", "subAgentPool.button.importing": "导入中...", + "subAgentPool.message.unavailable": "该Agent不可用", "subAgentPool.tooltip.unavailableAgent": "Agent不可用", "subAgentPool.tooltip.hasUnavailableTools": "该智能体因包含不可用工具而被禁用,请修改工具配置后使用", "subAgentPool.section.agentList": "智能体列表", @@ -382,6 +385,7 @@ "toolPool.tooltip.viewOnlyMode": "查看模式,无法选择工具", "toolPool.message.unavailable": "该工具不可用", "toolPool.message.viewOnlyMode": "当前为查看模式,无法选择工具", + "toolPool.error.unavailableSelected": "Agent存在不可用工具,请修改", 
"toolPool.tag.mcp": "MCP工具", "toolPool.tag.local": "本地工具", "toolPool.tag.langchain": "LangChain工具", @@ -398,6 +402,11 @@ "common.confirm": "确定", "common.enabled": "已启用", "common.disabled": "已禁用", + "common.yes": "是", + "common.no": "否", + "common.none": "无", + "common.required": "必填", + "common.refresh": "刷新", "common.unknownError": "未知错误", "common.retryLater": "请稍后重试", "common.back": "返回", @@ -414,6 +423,7 @@ "tool.message.unavailable": "该工具当前不可用,无法选择", "tool.error.noMainAgentId": "主代理ID未设置,无法更新工具状态", "tool.error.configFetchFailed": "获取工具配置失败", + "tool.message.statusUpdated": "工具{{name}}{{status}}", "tool.error.updateFailed": "更新工具状态失败", "tool.error.updateRetry": "更新工具状态失败,请稍后重试", @@ -500,6 +510,8 @@ "document.status.loadingList": "正在加载文档列表...", "document.input.knowledgeBaseName": "请输入知识库名称", "document.button.details": "详细内容", + "document.button.overview": "概览", + "document.button.detail": "分片详情", "document.button.autoSummary": "自动总结", "document.title.createNew": "创建新知识库", "document.hint.uploadToCreate": "请选择文件上传以完成知识库创建", @@ -520,6 +532,10 @@ "document.modal.deleteConfirm.content": "确定要删除这个文档吗?删除后无法恢复。", "document.message.noFiles": "请先选择文件", "document.message.uploadError": "文件上传失败", + "document.chunk.noChunks": "暂无分片数据", + "document.chunk.characterCount": "{{count}} 字符", + "document.chunk.error.loadFailed": "加载分片失败", + "document.chunk.error.downloadFailed": "下载分片失败", "model.dialog.title": "添加模型", "model.dialog.label.type": "模型类型", @@ -547,6 +563,8 @@ "model.dialog.hint.batchImportDisabled": "批量添加模式已关闭,仅添加单个模型", "model.provider.silicon": "硅基流动", "model.dialog.modelList.title": "显示模型", + "model.dialog.modelList.searchPlaceholder": "按名称搜索模型", + "model.dialog.modelList.noResults": "没有匹配的模型", "model.dialog.connectivity.title": "连通性验证", "model.dialog.connectivity.status.checking": "检测中", "model.dialog.connectivity.status.available": "可用", @@ -575,6 +593,7 @@ "model.dialog.editTitle": "编辑模型", "model.dialog.editSuccess": "模型更新成功", "model.dialog.error.editFailed": 
"更新模型失败", + "model.dialog.error.nameConflict": "名称 '{{name}}' 已被使用,请选择其他显示名称", "model.dialog.error.serverError": "服务器内部错误,请稍后重试", "model.type.llm": "大语言模型", "model.type.embedding": "向量模型", @@ -928,10 +947,10 @@ "memoryService.loadMemoryError": "加载记忆失败,请稍后重试", "memoryService.tenantSharedGroupTitle": "租户所有用户的共享记忆", - "memoryService.agentSharedGroupTitle": "{{agentName}} 所有用户的共享记忆", + "memoryService.agentSharedGroupTitle": "{{agentName}}", "memoryService.agentSharedPlaceholder": "Agent 共享记忆", "memoryService.userPersonalGroupTitle": "用户的个性化记忆", - "memoryService.userAgentGroupTitle": "{{agentName}} 中的用户记忆", + "memoryService.userAgentGroupTitle": "{{agentName}}", "memoryService.userAgentPlaceholder": "用户 Agent 记忆", "memoryManageModal.title": "记忆管理", @@ -968,5 +987,78 @@ "diagram.format.svg": "SVG", "diagram.format.png": "PNG", "diagram.format.selectFormat": "选择格式", - "diagram.error.renderFailed": "渲染失败" + "diagram.error.renderFailed": "渲染失败", + + "space.title": "智能体空间", + "space.description": "管理和使用您的智能体", + "space.createAgent": "创建新智能体", + "space.noAgents": "暂无智能体,创建您的第一个智能体吧!", + "space.noDescription": "暂无描述", + "space.status.available": "可用", + "space.status.unavailable": "不可用", + "space.deleteConfirm.title": "删除智能体", + "space.deleteConfirm.content": "确定要删除此智能体吗?此操作无法撤销。", + "space.deleteSuccess": "智能体删除成功", + "space.exportSuccess": "智能体导出成功", + "space.actions.edit": "编辑", + "space.actions.delete": "删除", + "space.actions.export": "导出", + "space.actions.relationship": "查看关系", + "space.actions.chat": "聊天", + + "space.detail.title": "智能体详情", + "space.detail.subtitle": "详细配置和信息", + "space.detail.tabs.basic": "基础信息", + "space.detail.tabs.model": "模型配置", + "space.detail.tabs.prompts": "提示词", + "space.detail.tabs.tools": "工具", + "space.detail.tabs.subAgents": "子智能体", + "space.detail.id": "智能体 ID", + "space.detail.name": "名称", + "space.detail.displayName": "显示名称", + "space.detail.description": "描述", + "space.detail.status": "状态", + "space.detail.enabled": "已启用", + 
"space.detail.model": "模型名称", + "space.detail.modelId": "模型 ID", + "space.detail.maxStep": "最大步数", + "space.detail.businessLogicModel": "业务逻辑模型", + "space.detail.businessLogicModelId": "业务逻辑模型 ID", + "space.detail.provideRunSummary": "提供运行摘要", + "space.detail.dutyPrompt": "职责提示词", + "space.detail.constraintPrompt": "约束提示词", + "space.detail.fewShotsPrompt": "少样本提示词", + "space.detail.businessDescription": "业务描述", + "space.detail.noTools": "暂无配置工具", + "space.detail.noSubAgents": "暂无配置子智能体", + "space.detail.subAgentId": "子智能体 ID", + "space.detail.source": "来源", + "space.detail.category": "分类", + "space.detail.usage": "用途", + "space.detail.parameters": "参数", + + "sidebar.homePage": "首页", + "sidebar.startChat": "开始问答", + "sidebar.quickConfig": "快速配置", + "sidebar.agentSpace": "智能体空间", + "sidebar.agentMarket": "智能体市场", + "sidebar.agentDev": "智能体开发", + "sidebar.knowledgeBase": "知识库", + "sidebar.modelManagement": "模型管理", + "sidebar.memoryManagement": "记忆管理", + "sidebar.userManagement": "用户管理", + + "market.comingSoon.title": "智能体市场即将推出", + "market.comingSoon.description": "从我们的市场中发现并安装预构建的AI智能体。通过使用社区创建的解决方案节省时间。", + "market.comingSoon.feature1": "浏览精选的智能体模板", + "market.comingSoon.feature2": "一键安装和部署", + "market.comingSoon.feature3": "与社区分享您自己的智能体", + "market.comingSoon.badge": "即将推出", + + "users.comingSoon.title": "用户管理即将推出", + "users.comingSoon.description": "为管理员提供全面的用户管理系统。控制组织内的访问权限、角色和权限。", + "users.comingSoon.feature1": "管理用户账户和角色", + "users.comingSoon.feature2": "配置精细化权限", + "users.comingSoon.feature3": "监控用户活动和使用情况", + "users.comingSoon.badge": "即将推出" } diff --git a/frontend/server.js b/frontend/server.js index 11ef45658..0d7a03d05 100644 --- a/frontend/server.js +++ b/frontend/server.js @@ -10,8 +10,9 @@ const app = next({ const handle = app.getRequestHandler(); // Backend addresses -const HTTP_BACKEND = process.env.HTTP_BACKEND || 'http://localhost:5010'; -const WS_BACKEND = process.env.WS_BACKEND || 'ws://localhost:5010'; +const HTTP_BACKEND = 
process.env.HTTP_BACKEND || 'http://localhost:5010'; // config +const WS_BACKEND = process.env.WS_BACKEND || 'ws://localhost:5014'; // runtime +const RUNTIME_HTTP_BACKEND = process.env.RUNTIME_HTTP_BACKEND || 'http://localhost:5014'; // runtime const MINIO_BACKEND = process.env.MINIO_ENDPOINT || 'http://localhost:9000'; const PORT = 3000; @@ -26,8 +27,16 @@ app.prepare().then(() => { if (pathname.includes('/attachments/') && !pathname.startsWith('/api/')) { proxy.web(req, res, { target: MINIO_BACKEND }); } else if (pathname.startsWith('/api/')) { - // All /api/ requests (including the initial handshake for WebSockets) go to the backend - proxy.web(req, res, { target: HTTP_BACKEND, changeOrigin: true }); + // Route runtime endpoints to runtime backend, others to config backend + const isRuntime = + pathname.startsWith('/api/agent/run') || + pathname.startsWith('/api/agent/stop') || + pathname.startsWith('/api/conversation/') || + pathname.startsWith('/api/memory/') || + pathname.startsWith('/api/file/storage') || + pathname.startsWith('/api/file/preprocess'); + const target = isRuntime ? 
RUNTIME_HTTP_BACKEND : HTTP_BACKEND; + proxy.web(req, res, { target, changeOrigin: true }); } else { // Let Next.js handle all other requests handle(req, res, parsedUrl); diff --git a/frontend/services/agentConfigService.ts b/frontend/services/agentConfigService.ts index 5736a419b..44e3b8766 100644 --- a/frontend/services/agentConfigService.ts +++ b/frontend/services/agentConfigService.ts @@ -117,6 +117,7 @@ export const fetchAgentList = async () => { display_name: agent.display_name || agent.name, description: agent.description, is_available: agent.is_available, + unavailable_reasons: agent.unavailable_reasons || [], })); return { @@ -517,6 +518,7 @@ export const searchAgentInfo = async (agentId: number) => { provide_run_summary: data.provide_run_summary, enabled: data.enabled, is_available: data.is_available, + unavailable_reasons: data.unavailable_reasons || [], sub_agent_id_list: data.sub_agent_id_list || [], // Add sub_agent_id_list tools: data.tools ? data.tools.map((tool: any) => { diff --git a/frontend/services/api.ts b/frontend/services/api.ts index e82fde3b7..69ec7bc55 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -104,6 +104,8 @@ export const API_ENDPOINTS = { listFiles: (indexName: string) => `${API_BASE_URL}/indices/${indexName}/files`, indexDetail: (indexName: string) => `${API_BASE_URL}/indices/${indexName}`, + chunks: (indexName: string) => + `${API_BASE_URL}/indices/${indexName}/chunks`, summary: (indexName: string) => `${API_BASE_URL}/summary/${indexName}/auto_summary`, changeSummary: (indexName: string) => diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts index e747283b8..93be15a6e 100644 --- a/frontend/services/knowledgeBaseService.ts +++ b/frontend/services/knowledgeBaseService.ts @@ -272,7 +272,7 @@ class KnowledgeBaseService { type: this.getFileTypeFromName(file.file || file.path_or_url), size: file.file_size, create_time: file.create_time, - chunk_num: 
file.chunk_count || 0, + chunk_num: file.chunk_count ?? 0, token_num: 0, status: file.status || "UNKNOWN", latest_task_id: file.latest_task_id || "", @@ -570,6 +570,104 @@ class KnowledgeBaseService { throw new Error("Failed to get summary"); } } + + // Preview chunks from a knowledge base + async previewChunks( + indexName: string, + batchSize: number = 1000 + ): Promise { + try { + const url = new URL( + API_ENDPOINTS.knowledgeBase.chunks(indexName), + window.location.origin + ); + url.searchParams.set("batch_size", batchSize.toString()); + + const response = await fetch(url.toString(), { + method: "POST", + headers: getAuthHeaders(), + }); + + const data = await response.json(); + + if (!response.ok) { + throw new Error( + data.detail || + data.message || + `HTTP error! status: ${response.status}` + ); + } + + if (data.status !== "success") { + throw new Error(data.message || "Failed to get chunks"); + } + + return data.chunks || []; + } catch (error) { + log.error("Error getting chunks:", error); + if (error instanceof Error) { + throw error; + } + throw new Error("Failed to get chunks"); + } + } + + // Preview chunks from a knowledge base with pagination + async previewChunksPaginated( + indexName: string, + page: number = 1, + pageSize: number = 10, + pathOrUrl?: string + ): Promise<{ + chunks: any[]; + total: number; + page: number; + pageSize: number; + }> { + try { + const url = new URL( + API_ENDPOINTS.knowledgeBase.chunks(indexName), + window.location.origin + ); + url.searchParams.set("page", page.toString()); + url.searchParams.set("page_size", pageSize.toString()); + if (pathOrUrl) { + url.searchParams.set("path_or_url", pathOrUrl); + } + + const response = await fetch(url.toString(), { + method: "POST", + headers: getAuthHeaders(), + }); + + const data = await response.json(); + + if (!response.ok) { + throw new Error( + data.detail || + data.message || + `HTTP error! 
status: ${response.status}` + ); + } + + if (data.status !== "success") { + throw new Error(data.message || "Failed to get chunks"); + } + + return { + chunks: data.chunks || [], + total: data.total || 0, + page: data.page || page, + pageSize: data.page_size || pageSize, + }; + } catch (error) { + log.error("Error getting chunks with pagination:", error); + if (error instanceof Error) { + throw error; + } + throw new Error("Failed to get chunks"); + } + } } // Export a singleton instance diff --git a/frontend/styles/globals.css b/frontend/styles/globals.css index 3da4b79ce..64b90fded 100644 --- a/frontend/styles/globals.css +++ b/frontend/styles/globals.css @@ -223,6 +223,30 @@ line-height: 1.2 !important; } +/* Document chunks tabs layout */ +.document-chunk-tabs { + height: 100%; +} + +.document-chunk-tabs .ant-tabs { + height: 100%; +} + +.document-chunk-tabs .ant-tabs-content-holder, +.document-chunk-tabs .ant-tabs-content, +.document-chunk-tabs .ant-tabs-tabpane { + height: 100%; +} + +.document-chunk-tabs .ant-tabs-nav { + height: 100%; +} + +.document-chunk-tabs .ant-tabs-nav-list { + height: 100%; + overflow-y: auto; +} + /* Styles for Embedding warning modal */ .kb-embedding-warning .ant-modal-wrap { position: absolute; diff --git a/frontend/styles/react-markdown.css b/frontend/styles/react-markdown.css index 08091ecd6..24a0e4f33 100644 --- a/frontend/styles/react-markdown.css +++ b/frontend/styles/react-markdown.css @@ -249,6 +249,34 @@ display: block; } +.markdown-video-wrapper { + margin: 1rem 0; + display: flex; + flex-direction: column; + gap: 0.5rem; + width: 100%; + align-items: stretch; +} + +.markdown-video { + width: 100%; + max-height: 70vh; + border-radius: 0.75rem; + background-color: #000; + box-shadow: 0 6px 12px -3px rgba(0, 0, 0, 0.2); +} + +.markdown-video:focus { + outline: 2px solid var(--color-nord10); + outline-offset: 2px; +} + +.markdown-video-caption { + font-size: 0.875rem; + color: var(--color-nord2); + line-height: 1.4; +} + /* 
Global markdown container */ .task-message-content { color: hsl(var(--foreground)) !important; diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index 03e4f65b1..f73549dce 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -10,6 +10,7 @@ export interface Agent { name: string; display_name?: string; description: string; + unavailable_reasons?: string[]; model: string; model_id?: number; max_step: number; @@ -153,7 +154,6 @@ export interface SubAgentPoolProps { isGeneratingAgent?: boolean; editingAgent?: Agent | null; isCreatingNewAgent?: boolean; - editingAgentName?: string | null; onExportAgent?: (agent: Agent) => void; onDeleteAgent?: (agent: Agent) => void; } @@ -170,6 +170,7 @@ export interface ToolPoolProps { isEditingMode?: boolean; isGeneratingAgent?: boolean; isEmbeddingConfigured?: boolean; + agentUnavailableReasons?: string[]; } // Simple prompt editor props interface @@ -322,6 +323,8 @@ export interface GeneratePromptParams { agent_id: number; task_description: string; model_id: string; + tool_ids?: number[]; // Optional: tool IDs selected in frontend (takes precedence over database query) + sub_agent_ids?: number[]; // Optional: sub-agent IDs selected in frontend (takes precedence over database query) } /** diff --git a/sdk/nexent/core/agents/core_agent.py b/sdk/nexent/core/agents/core_agent.py index ff7f5fcda..2e3fc05e7 100644 --- a/sdk/nexent/core/agents/core_agent.py +++ b/sdk/nexent/core/agents/core_agent.py @@ -92,6 +92,7 @@ def convert_code_format(text): # Restore if it was affected by the above replacement text = text.replace("```", "```") + text = text.replace("```", "```") # Clean up any remaining ```< patterns text = text.replace("```<", "```") diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index 85b41cc12..1a3d44a6f 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -66,11 +66,23 @@ def 
create_local_tool(self, tool_config: ToolConfig): raise ValueError(f"{class_name} not found in local") else: if class_name == "KnowledgeBaseSearchTool": - tools_obj = tool_class(index_names=tool_config.metadata.get("index_names", []), - observer=self.observer, - es_core=tool_config.metadata.get("es_core", []), - embedding_model=tool_config.metadata.get("embedding_model", []), - **params) + # Filter out conflicting parameters from params to avoid conflicts + # These parameters have exclude=True and cannot be passed to __init__ + # due to smolagents.tools.Tool wrapper restrictions + filtered_params = {k: v for k, v in params.items() + if k not in ["index_names", "vdb_core", "embedding_model", "observer"]} + # Create instance with only non-excluded parameters + tools_obj = tool_class(**filtered_params) + # Set excluded parameters directly as attributes after instantiation + # This bypasses smolagents wrapper restrictions + tools_obj.observer = self.observer + index_names = tool_config.metadata.get( + "index_names", None) if tool_config.metadata else None + tools_obj.index_names = [] if index_names is None else index_names + tools_obj.vdb_core = tool_config.metadata.get( + "vdb_core", None) if tool_config.metadata else None + tools_obj.embedding_model = tool_config.metadata.get( + "embedding_model", None) if tool_config.metadata else None else: tools_obj = tool_class(**params) if hasattr(tools_obj, 'observer'): diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index 5ef97468f..636162da1 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -2,14 +2,14 @@ import logging from typing import List -import requests +from pydantic import Field from smolagents.tools import Tool -from ..utils.observer import MessageObserver, ProcessType -from ..utils.tools_common_message import SearchResultTextMessage, ToolSign, ToolCategory -from pydantic 
import Field -from ...vector_database.elasticsearch_core import ElasticSearchCore +from ...vector_database.base import VectorDatabaseCore from ..models.embedding_model import BaseEmbedding +from ..utils.observer import MessageObserver, ProcessType +from ..utils.tools_common_message import SearchResultTextMessage, ToolCategory, ToolSign + # Get logger instance logger = logging.getLogger("knowledge_base_search_tool") @@ -17,40 +17,56 @@ class KnowledgeBaseSearchTool(Tool): """Knowledge base search tool""" + name = "knowledge_base_search" - description = "Performs a local knowledge base search based on your query then returns the top search results. " \ - "A tool for retrieving domain-specific knowledge, documents, and information stored in the local knowledge base. " \ - "Use this tool when users ask questions related to specialized knowledge, technical documentation, " \ - "domain expertise, personal notes, or any information that has been indexed in the knowledge base. " \ - "Suitable for queries requiring access to stored knowledge that may not be publicly available." - inputs = {"query": {"type": "string", "description": "The search query to perform."}, - "search_mode": {"type": "string", "description": "the search mode, optional values: hybrid, combining accurate matching and semantic search results across multiple indices.; accurate, Search for documents using fuzzy text matching across multiple indices; semantic, Search for similar documents using vector similarity across multiple indices.", - "default": "hybrid", "nullable": True}, - "index_names": {"type": "array", "description": "The list of knowledge base index names to search. If not provided, will search all available knowledge bases.", "nullable": True}} + description = ( + "Performs a local knowledge base search based on your query then returns the top search results. " + "A tool for retrieving domain-specific knowledge, documents, and information stored in the local knowledge base. 
" + "Use this tool when users ask questions related to specialized knowledge, technical documentation, " + "domain expertise, personal notes, or any information that has been indexed in the knowledge base. " + "Suitable for queries requiring access to stored knowledge that may not be publicly available." + ) + inputs = { + "query": {"type": "string", "description": "The search query to perform."}, + "search_mode": { + "type": "string", + "description": "the search mode, optional values: hybrid, combining accurate matching and semantic search results across multiple indices.; accurate, Search for documents using fuzzy text matching across multiple indices; semantic, Search for similar documents using vector similarity across multiple indices.", + "default": "hybrid", + "nullable": True, + }, + "index_names": { + "type": "array", + "description": "The list of knowledge base index names to search. If not provided, will search all available knowledge bases.", + "nullable": True, + }, + } output_type = "string" category = ToolCategory.SEARCH.value - tool_sign = ToolSign.KNOWLEDGE_BASE.value # Used to distinguish different index sources for summaries - - def __init__(self, top_k: int = Field(description="Maximum number of search results", default=5), - index_names: List[str] = Field(description="The list of index names to search", default=None, exclude=True) , - observer: MessageObserver = Field(description="Message observer", default=None, exclude=True), - embedding_model: BaseEmbedding = Field(description="The embedding model to use", default=None, exclude=True), - es_core: ElasticSearchCore = Field(description="Elasticsearch client", default=None, exclude=True) - ): + # Used to distinguish different index sources for summaries + tool_sign = ToolSign.KNOWLEDGE_BASE.value + + def __init__( + self, + top_k: int = Field(description="Maximum number of search results", default=5), + index_names: List[str] = Field(description="The list of index names to search", 
default=None, exclude=True), + observer: MessageObserver = Field(description="Message observer", default=None, exclude=True), + embedding_model: BaseEmbedding = Field(description="The embedding model to use", default=None, exclude=True), + vdb_core: VectorDatabaseCore = Field(description="Vector database client", default=None, exclude=True), + ): """Initialize the KBSearchTool. - + Args: top_k (int, optional): Number of results to return. Defaults to 5. observer (MessageObserver, optional): Message observer instance. Defaults to None. - + Raises: ValueError: If language is not supported """ super().__init__() self.top_k = top_k self.observer = observer - self.es_core = es_core + self.vdb_core = vdb_core self.index_names = [] if index_names is None else index_names self.embedding_model = embedding_model @@ -68,19 +84,21 @@ def forward(self, query: str, search_mode: str = "hybrid", index_names: List[str # Use provided index_names if available, otherwise use default search_index_names = index_names if index_names is not None else self.index_names - + # Log the index_names being used for this search - logger.info(f"KnowledgeBaseSearchTool called with query: '{query}', search_mode: '{search_mode}', index_names: {search_index_names}") - + logger.info( + f"KnowledgeBaseSearchTool called with query: '{query}', search_mode: '{search_mode}', index_names: {search_index_names}" + ) + if len(search_index_names) == 0: return json.dumps("No knowledge base selected. 
No relevant information found.", ensure_ascii=False) - if search_mode=="hybrid": - kb_search_data = self.es_search_hybrid(query=query, index_names=search_index_names) - elif search_mode=="accurate": - kb_search_data = self.es_search_accurate(query=query, index_names=search_index_names) - elif search_mode=="semantic": - kb_search_data = self.es_search_semantic(query=query, index_names=search_index_names) + if search_mode == "hybrid": + kb_search_data = self.search_hybrid(query=query, index_names=search_index_names) + elif search_mode == "accurate": + kb_search_data = self.search_accurate(query=query, index_names=search_index_names) + elif search_mode == "semantic": + kb_search_data = self.search_semantic(query=query, index_names=search_index_names) else: raise Exception(f"Invalid search mode: {search_mode}, only support: hybrid, accurate, semantic") @@ -98,12 +116,19 @@ def forward(self, query: str, search_mode: str = "hybrid", index_names: List[str title = single_search_result.get("title") if not title: title = single_search_result.get("filename", "") - search_result_message = SearchResultTextMessage(title=title, - text=single_search_result.get("content", ""), source_type=source_type, - url=single_search_result.get("path_or_url", ""), filename=single_search_result.get("filename", ""), - published_date=single_search_result.get("create_time", ""), score=single_search_result.get("score", 0), - score_details=single_search_result.get("score_details", {}), cite_index=self.record_ops + index, - search_type=self.name, tool_sign=self.tool_sign) + search_result_message = SearchResultTextMessage( + title=title, + text=single_search_result.get("content", ""), + source_type=source_type, + url=single_search_result.get("path_or_url", ""), + filename=single_search_result.get("filename", ""), + published_date=single_search_result.get("create_time", ""), + score=single_search_result.get("score", 0), + score_details=single_search_result.get("score_details", {}), + 
cite_index=self.record_ops + index, + search_type=self.name, + tool_sign=self.tool_sign, + ) search_results_json.append(search_result_message.to_dict()) search_results_return.append(search_result_message.to_model_dict()) @@ -116,20 +141,19 @@ def forward(self, query: str, search_mode: str = "hybrid", index_names: List[str self.observer.add_message("", ProcessType.SEARCH_CONTENT, search_results_data) return json.dumps(search_results_return, ensure_ascii=False) - - def es_search_hybrid(self, query, index_names): + def search_hybrid(self, query, index_names): try: - results = self.es_core.hybrid_search(index_names=index_names, - query_text=query, - embedding_model=self.embedding_model, - top_k=self.top_k) + results = self.vdb_core.hybrid_search( + index_names=index_names, query_text=query, embedding_model=self.embedding_model, top_k=self.top_k + ) # Format results formatted_results = [] for result in results: doc = result["document"] doc["score"] = result["score"] - doc["index"] = result["index"] # Include source index in results + # Include source index in results + doc["index"] = result["index"] formatted_results.append(doc) return { @@ -139,18 +163,17 @@ def es_search_hybrid(self, query, index_names): except Exception as e: raise Exception(f"Error during semantic search: {str(e)}") - def es_search_accurate(self, query, index_names): + def search_accurate(self, query, index_names): try: - results = self.es_core.accurate_search(index_names=index_names, - query_text=query, - top_k=self.top_k) + results = self.vdb_core.accurate_search(index_names=index_names, query_text=query, top_k=self.top_k) # Format results formatted_results = [] for result in results: doc = result["document"] doc["score"] = result["score"] - doc["index"] = result["index"] # Include source index in results + # Include source index in results + doc["index"] = result["index"] formatted_results.append(doc) return { @@ -160,19 +183,19 @@ def es_search_accurate(self, query, index_names): except 
Exception as e: raise Exception(detail=f"Error during accurate search: {str(e)}") - def es_search_semantic(self, query, index_names): + def search_semantic(self, query, index_names): try: - results = self.es_core.semantic_search(index_names=index_names, - query_text=query, - embedding_model=self.embedding_model, - top_k=self.top_k) + results = self.vdb_core.semantic_search( + index_names=index_names, query_text=query, embedding_model=self.embedding_model, top_k=self.top_k + ) # Format results formatted_results = [] for result in results: doc = result["document"] doc["score"] = result["score"] - doc["index"] = result["index"] # Include source index in results + # Include source index in results + doc["index"] = result["index"] formatted_results.append(doc) return { @@ -180,4 +203,4 @@ def es_search_semantic(self, query, index_names): "total": len(formatted_results), } except Exception as e: - raise Exception(detail=f"Error during semantic search: {str(e)}") \ No newline at end of file + raise Exception(detail=f"Error during semantic search: {str(e)}") diff --git a/sdk/nexent/data_process/README.md b/sdk/nexent/data_process/README.md deleted file mode 100644 index ae7736a2e..000000000 --- a/sdk/nexent/data_process/README.md +++ /dev/null @@ -1,237 +0,0 @@ -# DataProcessCore 使用指南 - -## 概述 - -`DataProcessCore` 是一个统一的文件处理核心类,支持多种文件格式的自动检测和处理,提供灵活的分块策略和多种输入源支持。 - -## 主要功能 - -### 1. 
核心处理方法:`file_process()` - -**函数签名:** -```python -def file_process(self, - file_path_or_url: Optional[str] = None, - file_data: Optional[bytes] = None, - chunking_strategy: str = "basic", - destination: str = "local", - filename: Optional[str] = None, - **params) -> List[Dict] -``` - -**参数说明:** - -| 参数名 | 类型 | 必需 | 描述 | 可选值 | -|--------|------|------|------|--------| -| `file_path_or_url` | `str` | 否* | 本地文件路径或远程URL | 任何有效的文件路径或URL | -| `file_data` | `bytes` | 否* | 文件的字节数据(用于内存处理) | 任何有效的字节数据 | -| `chunking_strategy` | `str` | 否 | 分块策略 | `"basic"`, `"by_title"`, `"none"` | -| `destination` | `str` | 否 | 目标类型,指示文件来源 | `"local"`, `"minio"`, `"url"` | -| `filename` | `str` | 否** | 文件名 | 任何有效的文件名 | -| `**params` | `dict` | 否 | 额外的处理参数 | 见下方参数详情 | - -*注:`file_path_or_url` 和 `file_data` 必须提供其中一个 -**注:使用 `file_data` 时,`filename` 为必需参数 - -**分块策略 (`chunking_strategy`) 详解:** - -| 策略值 | 描述 | 适用场景 | 输出特点 | -|--------|------|----------|----------| -| `"basic"` | 基础分块策略 | 大多数文档处理场景 | 根据内容长度自动分块 | -| `"by_title"` | 按标题分块 | 结构化文档(如技术文档、报告) | 以标题为界限进行分块 | -| `"none"` | 不分块 | 短文档或需要完整内容的场景 | 返回单个包含全部内容的块 | - -**目标类型 (`destination`) 详解:** - -| 目标值 | 描述 | 使用场景 | 要求 | -|--------|------|----------|------| -| `"local"` | 本地文件 | 处理本地存储的文件 | 提供有效的本地文件路径 | -| `"minio"` | MinIO存储 | 处理云存储中的文件 | 需要数据库依赖 | -| `"url"` | 远程URL | 处理网络资源 | 需要数据库依赖 | - -**额外参数 (`**params`) 详解:** - -| 参数名 | 类型 | 默认值 | 描述 | 适用处理器 | -|--------|------|--------|------|-----------| -| `max_characters` | `int` | `1500` | 每个块的最大字符数 | Generic | -| `new_after_n_chars` | `int` | `1200` | 达到此字符数后开始新块 | Generic | -| `strategy` | `str` | `"fast"` | 处理策略 | Generic | -| `skip_infer_table_types` | `list` | `[]` | 跳过推断的表格类型 | Generic | -| `task_id` | `str` | `""` | 任务标识符 | Generic | - -**返回值格式:** - -返回 `List[Dict]`,每个字典包含以下字段: - -**通用字段:** -| 字段名 | 类型 | 描述 | 示例 | -|--------|------|------|------| -| `content` | `str` | 文本内容 | `"这是文档的第一段..."` | -| `path_or_url` | `str` | 文件路径或URL | `"/path/to/file.pdf"` | -| `filename` | `str` | 文件名 | 
`"document.pdf"` | - -**Excel文件额外字段:** -| 字段名 | 类型 | 描述 | 示例 | -|--------|------|------|------| -| `metadata` | `dict` | 元数据信息 | `{"chunk_index": 0, "file_type": "xlsx"}` | - -**Generic文件额外字段:** -| 字段名 | 类型 | 描述 | 示例 | -|--------|------|------|------| -| `language` | `str` | 语言标识(可选) | `"en"` | -| `metadata` | `dict` | 元数据信息(可选) | `{"chunk_index": 0}` | - -## 支持的文件类型 - -### Excel文件 -- `.xlsx` - Excel 2007及更高版本 -- `.xls` - Excel 97-2003版本 - -### 通用文件 -- `.txt` - 纯文本文件 -- `.pdf` - PDF文档 -- `.docx` - Word 2007及更高版本 -- `.doc` - Word 97-2003版本 -- `.html`, `.htm` - HTML文档 -- `.md` - Markdown文件 -- `.rtf` - 富文本格式 -- `.odt` - OpenDocument文本 -- `.pptx` - PowerPoint 2007及更高版本 -- `.ppt` - PowerPoint 97-2003版本 - -## 使用示例 - -### 示例1:处理本地文本文件 -```python -from nexent.data_process import DataProcessCore - -core = DataProcessCore() - -# 基础处理 -result = core.file_process( - file_path_or_url="/path/to/document.txt", - destination="local", - chunking_strategy="basic" -) - -print(f"处理得到 {len(result)} 个块") -for i, chunk in enumerate(result): - print(f"块 {i}: {chunk['content'][:100]}...") -``` - -### 示例2:处理Excel文件 -```python -# 处理Excel文件 -result = core.file_process( - file_path_or_url="/path/to/spreadsheet.xlsx", - destination="local", - chunking_strategy="none" # Excel通常不需要分块 -) - -for chunk in result: - print(f"文件: {chunk['filename']}") - print(f"内容: {chunk['content']}") - print(f"元数据: {chunk['metadata']}") -``` - -### 示例3:处理内存中的文件 -```python -# 读取文件到内存 -with open("/path/to/document.pdf", "rb") as f: - file_bytes = f.read() - -# 处理内存中的文件 -result = core.file_process( - file_data=file_bytes, - filename="document.pdf", - chunking_strategy="by_title", - max_characters=2000 # 自定义参数 -) -``` - -### 示例4:处理远程文件(需要数据库依赖) -```python -# 处理MinIO中的文件 -result = core.file_process( - file_path_or_url="minio://bucket/path/to/file.docx", - destination="minio", - filename="file.docx", - chunking_strategy="basic" -) -``` - -## 辅助方法 - -### 1. 
获取支持的文件类型 -```python -core = DataProcessCore() -supported_types = core.get_supported_file_types() -print("Excel文件:", supported_types["excel"]) -print("通用文件:", supported_types["generic"]) -``` - -### 2. 验证文件类型 -```python -is_supported = core.validate_file_type("document.pdf") -print(f"PDF文件是否支持: {is_supported}") -``` - -### 3. 获取处理器信息 -```python -info = core.get_processor_info("spreadsheet.xlsx") -print(f"处理器类型: {info['processor_type']}") -print(f"文件扩展名: {info['file_extension']}") -print(f"是否支持: {info['is_supported']}") -``` - -### 4. 获取支持的策略和目标类型 -```python -strategies = core.get_supported_strategies() -destinations = core.get_supported_destinations() -print(f"支持的分块策略: {strategies}") -print(f"支持的目标类型: {destinations}") -``` - -## 错误处理 - -### 常见异常 - -| 异常类型 | 触发条件 | 解决方案 | -|----------|----------|----------| -| `ValueError` | 参数无效(如同时提供file_path_or_url和file_data) | 检查参数组合 | -| `FileNotFoundError` | 本地文件不存在或远程文件无法获取 | 验证文件路径 | -| `ImportError` | 处理远程文件时缺少数据库依赖 | 安装相关依赖 | - -### 错误处理示例 -```python -try: - result = core.file_process( - file_path_or_url="/nonexistent/file.txt", - destination="local" - ) -except FileNotFoundError as e: - print(f"文件未找到: {e}") -except ValueError as e: - print(f"参数错误: {e}") -except Exception as e: - print(f"处理失败: {e}") -``` - -## 性能优化建议 - -1. **选择合适的分块策略**: - - 小文件使用 `"none"` - - 大文件使用 `"basic"` - - 结构化文档使用 `"by_title"` - -2. **调整分块参数**: - - 根据下游处理需求调整 `max_characters` - - 平衡处理速度和内存使用 - -3. **文件类型优化**: - - Excel文件通常不需要分块 - - PDF文件建议使用较大的 `max_characters` - -4. 
**批量处理**: - 复用 `DataProcessCore` 实例 - 避免重复初始化 \ No newline at end of file diff --git a/sdk/nexent/memory/memory_service.py b/sdk/nexent/memory/memory_service.py index 09a7be208..45a6fc72d 100644 --- a/sdk/nexent/memory/memory_service.py +++ b/sdk/nexent/memory/memory_service.py @@ -293,7 +293,7 @@ async def reset_all_memory(memory_config: Dict[str, Any]) -> bool: async def clear_model_memories( - es_core: Any, + vdb_core: Any, model_repo: str, model_name: str, embedding_dims: int, @@ -305,7 +305,7 @@ async def clear_model_memories( memory utilities, while remaining SDK-only and transport-agnostic. Args: - es_core: An initialized Elasticsearch core instance (must expose ``client.indices`` and ``delete_index``). + vdb_core: An initialized vector database core instance (must expose ``client.indices`` and ``delete_index``). model_repo: Optional repository/namespace of the embedding model (e.g., "jina-ai"). Empty if none. model_name: The embedding model name (e.g., "jina-embeddings-v2-base-en"). embedding_dims: The embedding vector dimension for this model configuration. 
@@ -329,7 +329,7 @@ async def clear_model_memories( # 1) If index does not exist in ES, nothing to do try: - es_exists = es_core.client.indices.exists(index=index_name) + es_exists = vdb_core.client.indices.exists(index=index_name) except Exception: # If existence check fails, proceed defensively to attempt cleanup via mem0 then ES delete es_exists = True @@ -371,7 +371,7 @@ async def clear_model_memories( # 4) Drop ES index try: - es_core.delete_index(index_name) + vdb_core.delete_index(index_name) except Exception: # Swallow delete errors and report as best-effort pass diff --git a/sdk/nexent/multi_modal/__init__.py b/sdk/nexent/multi_modal/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sdk/nexent/multi_modal/load_save_object.py b/sdk/nexent/multi_modal/load_save_object.py new file mode 100644 index 000000000..9e85f4880 --- /dev/null +++ b/sdk/nexent/multi_modal/load_save_object.py @@ -0,0 +1,296 @@ +import functools +import inspect +import logging +from io import BytesIO +from typing import Any, Callable, List, Optional, Tuple +import requests + +from .utils import ( + UrlType, + is_url, + generate_object_name, + detect_content_type_from_bytes, + guess_extension_from_content_type, + parse_s3_url +) + +logger = logging.getLogger("multi_modal") + + +class LoadSaveObjectManager: + """ + Provide load/save decorators that operate on a specific storage client. + + The manager can be instantiated with a storage client and exposes decorator + factories for `load_object` and `save_object`. A default module-level manager + is also provided for backwards compatibility with existing helper functions. + """ + + def __init__(self, storage_client: Any): + self._storage_client = storage_client + + def _get_client(self) -> Any: + """ + Return a ready-to-use storage client, ensuring initialization first. 
+ """ + if self._storage_client is None: + raise ValueError("Storage client is not initialized.") + return self._storage_client + + def download_file_from_url( + self, + url: str, + url_type: UrlType, + timeout: int = 30 + ) -> Optional[bytes]: + """ + Download file content from S3 URL or HTTP/HTTPS URL as bytes. + """ + if not url: + return None + + if not url_type: + raise ValueError("url_type must be provided for download_file_from_url") + + try: + if url_type in ("http", "https"): + response = requests.get(url, timeout=timeout) + response.raise_for_status() + return response.content + + if url_type == "s3": + client = self._get_client() + bucket, object_name = parse_s3_url(url) + + if not hasattr(client, 'get_file_stream'): + raise ValueError("Storage client does not have get_file_stream method") + + success, stream = client.get_file_stream(object_name, bucket) + if not success: + raise ValueError(f"Failed to get file stream from storage: {stream}") + + try: + bytes_data = stream.read() + if hasattr(stream, 'close'): + stream.close() + return bytes_data + except Exception as exc: + raise ValueError(f"Failed to read stream content: {exc}") from exc + + raise ValueError(f"Unsupported URL type: {url_type}") + + except Exception as exc: + logger.error(f"Failed to download file from URL: {exc}") + return None + + def _upload_bytes_to_minio( + self, + bytes_data: bytes, + object_name: Optional[str] = None, + bucket: str = "multi-modal", + content_type: str = "application/octet-stream", + ) -> str: + """ + Upload bytes to MinIO and return the resulting file URL. 
+ """ + client = self._get_client() + + if not hasattr(client, 'upload_fileobj'): + raise ValueError("Storage client must have upload_fileobj method") + + if object_name is None: + file_ext = guess_extension_from_content_type(content_type) + object_name = generate_object_name(file_ext) + + file_obj = BytesIO(bytes_data) + success, result = client.upload_fileobj(file_obj, object_name, bucket) + + if not success: + raise ValueError(f"Failed to upload file to MinIO: {result}") + + return result + + def load_object( + self, + input_names: List[str], + input_data_transformer: Optional[List[Callable[[bytes], Any]]] = None, + ): + """ + Decorator factory that downloads inputs before invoking the wrapped callable. + """ + + def decorator(func: Callable): + @functools.wraps(func) + def wrapper(*args, **kwargs): + def _transform_single_value(param_name: str, value: Any, + transformer: Optional[Callable[[bytes], Any]]) -> Any: + if isinstance(value, str): + url_type = is_url(value) + if url_type: + bytes_data = self.download_file_from_url(value, url_type=url_type) + + if bytes_data is None: + raise ValueError(f"Failed to download file from URL: {value}") + + if transformer: + transformed_data = transformer(bytes_data) + logger.info( + f"Downloaded {param_name} from URL and transformed " + f"using {transformer.__name__}" + ) + return transformed_data + + logger.info(f"Downloaded {param_name} from URL as bytes (binary stream)") + return bytes_data + + raise ValueError( + f"Parameter '{param_name}' is not a URL string. " + f"load_object decorator expects S3 or HTTP/HTTPS URLs. 
" + f"Got: {type(value).__name__}" + ) + + def _process_value(param_name: str, value: Any, + transformer: Optional[Callable[[bytes], Any]]) -> Any: + if value is None: + return None + + if isinstance(value, (list, tuple)): + processed_items = [ + _process_value(param_name, item, transformer) + for item in value + ] + return type(value)(processed_items) + + return _transform_single_value(param_name, value, transformer) + + sig = inspect.signature(func) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + for i, param_name in enumerate(input_names): + if param_name not in bound_args.arguments: + continue + + original_data = bound_args.arguments[param_name] + if original_data is None: + continue + + transformer_func = ( + input_data_transformer[i] + if input_data_transformer and i < len(input_data_transformer) + else None + ) + + transformed_data = _process_value(param_name, original_data, transformer_func) + bound_args.arguments[param_name] = transformed_data + + return func(*bound_args.args, **bound_args.kwargs) + + return wrapper + + return decorator + + def save_object( + self, + output_names: List[str], + output_transformers: Optional[List[Callable[[Any], bytes]]] = None, + bucket: str = "multi-modal", + ): + """ + Decorator factory that uploads outputs to storage after function execution. 
+ """ + + def decorator(func: Callable) -> Callable: + def _handle_results(results: Any): + if not isinstance(results, tuple): + results_tuple = (results,) + else: + results_tuple = results + + if len(results_tuple) != len(output_names): + raise ValueError( + f"Function returned {len(results_tuple)} values, " + f"but expected {len(output_names)} outputs" + ) + + def _upload_single_output( + name: str, + value: Any, + transformer: Optional[Callable[[Any], bytes]] + ) -> str: + if transformer: + bytes_data = transformer(value) + if not isinstance(bytes_data, bytes): + raise ValueError( + f"Transformer {transformer.__name__} for {name} must return bytes, " + f"got {type(bytes_data).__name__}" + ) + logger.info(f"Transformed {name} using {transformer.__name__} to bytes") + else: + if not isinstance(value, bytes): + raise ValueError( + f"Return value for {name} must be bytes when no transformer is provided, " + f"got {type(value).__name__}" + ) + bytes_data = value + logger.info(f"Using {name} as bytes directly") + + content_type = detect_content_type_from_bytes(bytes_data) + logger.info(f"Detected content type for {name}: {content_type}") + + file_url = self._upload_bytes_to_minio( + bytes_data, + object_name=None, + content_type=content_type, + bucket=bucket, + ) + logger.info(f"Uploaded {name} to MinIO: {file_url}") + return "s3:/" + file_url + + def _process_output_value( + name: str, + value: Any, + transformer: Optional[Callable[[Any], bytes]] + ) -> Any: + if value is None: + return None + + if isinstance(value, (list, tuple)): + processed_items = [ + _process_output_value(name, item, transformer) + for item in value + ] + return type(value)(processed_items) + + return _upload_single_output(name, value, transformer) + + uploaded_urls = [] + for i, (result, name) in enumerate(zip(results_tuple, output_names)): + transformer_func = ( + output_transformers[i] + if output_transformers and i < len(output_transformers) + else None + ) + processed_result = 
_process_output_value(name, result, transformer_func) + uploaded_urls.append(processed_result) + + if len(uploaded_urls) == 1: + return uploaded_urls[0] + return tuple(uploaded_urls) + + if inspect.iscoroutinefunction(func): + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + results = await func(*args, **kwargs) + return _handle_results(results) + + return async_wrapper + + @functools.wraps(func) + def wrapper(*args, **kwargs): + results = func(*args, **kwargs) + return _handle_results(results) + + return wrapper + + return decorator \ No newline at end of file diff --git a/sdk/nexent/multi_modal/utils.py b/sdk/nexent/multi_modal/utils.py new file mode 100644 index 000000000..e118f6940 --- /dev/null +++ b/sdk/nexent/multi_modal/utils.py @@ -0,0 +1,354 @@ +import base64 +import logging +from datetime import datetime +import uuid +from typing import Literal, Optional, Tuple +import mimetypes +from pathlib import PurePosixPath + + +logger = logging.getLogger("multi_modal") + +UrlType = Literal["http", "https", "s3"] + + +def is_url(url: str) -> Optional[UrlType]: + """ + Classify a string URL as HTTP(S) or S3. + + Args: + url: URL candidate + + Returns: + 'http', 'https', or 's3' when the input matches the respective + scheme. Returns None when the input is not a supported URL. 
+ """ + if not url or not isinstance(url, str): + return None + + url = url.strip() + + if url.startswith("http://"): + return "http" + + if url.startswith("https://"): + return "https" + + if url.startswith("s3://"): + bucket_path = url.replace("s3://", "", 1) + bucket_object = bucket_path.split("/", 1) + if len(bucket_object) == 2 and all(bucket_object): + return "s3" + return None + + if url.startswith("/"): + stripped = url.lstrip("/") + parts = stripped.split("/", 1) + if len(parts) == 2 and all(parts): + return "s3" + return None + + return None + + +def bytes_to_base64(bytes_data: bytes, content_type: str = "application/octet-stream") -> str: + """ + Convert bytes to base64 data URI string + + Args: + bytes_data: File content as bytes + content_type: MIME type (e.g., 'image/png', 'video/mp4', 'application/pdf') + + Returns: + Base64 data URI string (e.g., "data:image/png;base64,...") + """ + if not bytes_data: + raise ValueError("bytes_data cannot be empty") + + b64_bytes = base64.b64encode(bytes_data) + b64_string = b64_bytes.decode("utf-8") + return f"data:{content_type};base64,{b64_string}" + + +def guess_content_type_from_url(url: str) -> str: + """ + Guess content type from URL file extension + + Args: + url: URL string + + Returns: + MIME type string + """ + # Extract file extension + url_without_params = url.split("?")[0] # Remove query params + file_ext = PurePosixPath(url_without_params).suffix.lower() + + # Try mimetypes first + content_type, _ = mimetypes.guess_type(url_without_params) + if content_type: + return content_type + + # Fallback to common types + common_types = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".gif": "image/gif", + ".webp": "image/webp", + ".bmp": "image/bmp", + ".mp4": "video/mp4", + ".avi": "video/x-msvideo", + ".mov": "video/quicktime", + ".webm": "video/webm", + ".mp3": "audio/mpeg", + ".wav": "audio/wav", + ".ogg": "audio/ogg", + ".flac": "audio/flac", + ".pdf": "application/pdf", + 
".txt": "text/plain", + ".json": "application/json", + } + + return common_types.get(file_ext, "application/octet-stream") + + +def base64_to_bytes(base64_data: str) -> Tuple[bytes, str]: + """ + Convert base64 data URI to bytes and extract content type + + Args: + base64_data: Base64 data URI string (e.g., "data:image/png;base64,...") + + Returns: + Tuple[bytes, content_type]: File content as bytes and MIME type + + Raises: + ValueError: If base64_data format is invalid + """ + if not base64_data or not isinstance(base64_data, str): + raise ValueError("base64_data must be a non-empty string") + + # Check if it is a data URI + if base64_data.startswith("data:"): + # Parse data URI: data:content/type;base64, + parts = base64_data.split(",", 1) + if len(parts) != 2: + raise ValueError(f"Invalid data URI format: {base64_data[:50]}...") + + header = parts[0] + data = parts[1] + + # Extract content type + if ";base64" in header: + content_type = header.replace("data:", "").replace(";base64", "") + else: + content_type = header.replace("data:", "") + + if not content_type: + content_type = "application/octet-stream" + + # Decode base64 + try: + bytes_data = base64.b64decode(data) + return bytes_data, content_type + except Exception as e: + raise ValueError(f"Failed to decode base64 data: {e}") + else: + # Assume it is raw base64 string without data URI prefix + try: + bytes_data = base64.b64decode(base64_data) + return bytes_data, "application/octet-stream" + except Exception as e: + raise ValueError(f"Failed to decode base64 string: {e}") + + +def generate_object_name(file_extension: str = "") -> str: + """ + Generate unique object name for MinIO upload + + Args: + file_extension: File extension (e.g., '.png', '.jpg') + + Returns: + Unique object name string + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + unique_id = str(uuid.uuid4())[:8] + + if file_extension and not file_extension.startswith("."): + file_extension = "." 
+ file_extension + + return f"{timestamp}_{unique_id}{file_extension}" + + +def detect_content_type_from_bytes(bytes_data: bytes) -> str: + """ + Detect content type from binary data using magic bytes (file signatures) + + Args: + bytes_data: Binary data to analyze + + Returns: + MIME type string (e.g., 'image/png', 'video/mp4') + """ + if not bytes_data or len(bytes_data) < 4: + return "application/octet-stream" + + # Get first bytes for magic number detection + header = bytes_data[:12] + + # PNG: 89 50 4E 47 0D 0A 1A 0A + if len(bytes_data) >= 8 and header[:8] == b"\x89PNG\r\n\x1a\n": + return "image/png" + + # JPEG: FF D8 FF + if len(bytes_data) >= 3 and header[:3] == b"\xff\xd8\xff": + return "image/jpeg" + + # GIF: 47 49 46 38 (GIF8) + if len(bytes_data) >= 6 and header[:6] in (b"GIF87a", b"GIF89a"): + return "image/gif" + + # WebP: 52 49 46 46 ... 57 45 42 50 (RIFF....WEBP) + if len(bytes_data) >= 12 and header[:4] == b"RIFF" and header[8:12] == b"WEBP": + return "image/webp" + + # BMP: 42 4D (BM) + if len(bytes_data) >= 2 and header[:2] == b"BM": + return "image/bmp" + + # PDF: 25 50 44 46 (%PDF) + if len(bytes_data) >= 4 and header[:4] == b"%PDF": + return "application/pdf" + + # MP4: 00 00 00 ?? 66 74 79 70 (ftyp) + if len(bytes_data) >= 8: + # Check for ftyp at offset 4 + if header[4:8] == b"ftyp": + return "video/mp4" + # Also check for quicktime/mov format + if header[4:8] == b"qt ": + return "video/quicktime" + + # MP3: Check for ID3 tag or MPEG frame sync + if len(bytes_data) >= 3: + # ID3 tag: 49 44 33 (ID3) + if header[:3] == b"ID3": + return "audio/mpeg" + # MPEG frame sync: FF FB or FF F3 + if header[:2] == b"\xff\xfb" or header[:2] == b"\xff\xf3": + return "audio/mpeg" + + # WAV: 52 49 46 46 ... 
57 41 56 45 (RIFF....WAVE) + if len(bytes_data) >= 12 and header[:4] == b"RIFF" and header[8:12] == b"WAVE": + return "audio/wav" + + # OGG: 4F 67 67 53 (OggS) + if len(bytes_data) >= 4 and header[:4] == b"OggS": + return "audio/ogg" + + # FLAC: 66 4C 61 43 (fLaC) + if len(bytes_data) >= 4 and header[:4] == b"fLaC": + return "audio/flac" + + # WebM: 1A 45 DF A3 (EBML header) + if len(bytes_data) >= 4 and header[:4] == b"\x1a\x45\xdf\xa3": + return "video/webm" + + # AVI: 52 49 46 46 ... 41 56 49 20 (RIFF....AVI ) + if len(bytes_data) >= 12 and header[:4] == b"RIFF" and header[8:12] == b"AVI ": + return "video/x-msvideo" + + # JSON: Check if it starts with { or [ + try: + if bytes_data[:1] in (b"{", b"["): + # Try to decode as UTF-8 and parse as JSON + text = bytes_data[:100].decode("utf-8", errors="ignore").strip() + if text.startswith(("{", "[")): + return "application/json" + except Exception: + pass + + # Text: Check if it is valid UTF-8 text + try: + text = bytes_data[:100].decode("utf-8", errors="strict") + # If it is mostly printable ASCII, consider it text + if all(32 <= ord(c) <= 126 or c in "\n\r\t" for c in text[:50]): + return "text/plain" + except Exception: + pass + + # Default: unknown binary + return "application/octet-stream" + + +def guess_extension_from_content_type(content_type: str) -> str: + """ + Guess file extension from content type + + Args: + content_type: MIME type (e.g., 'image/png', 'video/mp4') + + Returns: + File extension (e.g., '.png', '.mp4') + """ + content_type_to_ext = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", + "video/mp4": ".mp4", + "video/x-msvideo": ".avi", + "video/quicktime": ".mov", + "video/webm": ".webm", + "audio/mpeg": ".mp3", + "audio/wav": ".wav", + "audio/ogg": ".ogg", + "audio/flac": ".flac", + "application/pdf": ".pdf", + "text/plain": ".txt", + "application/json": ".json", + } + + return content_type_to_ext.get(content_type, "") + + 
+def parse_s3_url(s3_url: str) -> Tuple[str, str]: + """ + Parse S3 URL to extract bucket and object_name + + Supports formats: + - s3://bucket/key + - /bucket/key (MinIO path format) + + Args: + s3_url: S3 URL string + + Returns: + Tuple[bucket, object_name] + + Raises: + ValueError: If URL format is not recognized + """ + if not s3_url: + raise ValueError("S3 URL cannot be empty") + + if s3_url.startswith('s3://'): + parts = s3_url.replace('s3://', '').split('/', 1) + if len(parts) == 2: + bucket, object_name = parts + if not bucket or not object_name: + raise ValueError(f"Invalid s3:// URL format: {s3_url}") + return bucket, object_name + raise ValueError(f"Invalid s3:// URL format: {s3_url}") + + if s3_url.startswith('/'): + parts = s3_url.lstrip('/').split('/', 1) + if len(parts) == 2: + bucket, object_name = parts + return bucket, object_name + raise ValueError(f"Invalid path format: {s3_url}") + + raise ValueError(f"Unrecognized S3 URL format: {s3_url[:50]}...") \ No newline at end of file diff --git a/sdk/nexent/vector_database/README.md b/sdk/nexent/vector_database/README.md deleted file mode 100644 index 68c72af45..000000000 --- a/sdk/nexent/vector_database/README.md +++ /dev/null @@ -1,871 +0,0 @@ -# Elasticsearch 向量数据库 - -一个用于 Elasticsearch 的向量搜索和文档管理服务,支持 Jina 嵌入模型集成。 - -## 环境设置 - -1. 创建一个包含凭据的 `.env` 文件: - -``` -ELASTICSEARCH_HOST=https://localhost:9200 -ELASTICSEARCH_API_KEY=your_api_key_here -JINA_API_URL=https://api.jina.ai/v1/embeddings -JINA_MODEL=jian_model_name -JINA_API_KEY=your_jina_api_key_here -``` - -2. 安装依赖: - -```bash -pip install elasticsearch python-dotenv requests fastapi uvicorn -``` - -## Docker 部署指南 - -### 前置条件 - -1. 安装Docker - - 访问 [Get Docker](https://www.docker.com/products/docker-desktop) 安装Docker - - 如果使用Docker Desktop,请确保分配至少4GB内存 - - 可以在Docker Desktop的 **Settings > Resources** 中调整内存使用 - -2. 创建Docker网络 - ```bash - docker network create elastic - ``` - -### Elasticsearch部署 - -1. 
拉取Elasticsearch镜像 - ```bash - docker pull docker.elastic.co/elasticsearch/elasticsearch:8.17.4 - ``` - -2. 启动Elasticsearch容器 (静默模式,等待3-5分钟) - ```bash - docker run -d --name es01 --net elastic -p 9200:9200 -m 6GB -e "xpack.ml.use_auto_machine_memory_percent=true" docker.elastic.co/elasticsearch/elasticsearch:8.17.4 - ``` - -3. 查看Elasticsearch日志 - ```bash - docker logs -f es01 - ``` - -4. 重置密码(确认Yes) - ```bash - docker exec -it es01 /usr/share/elasticsearch/bin/elasticsearch-reset-password -u elastic - ``` - -5. 保存重要信息 - - 容器启动时会显示 `elastic` 用户密码和Kibana的注册令牌 - - 建议将密码保存为环境变量: - ```bash - export ELASTIC_PASSWORD="your_password" - ``` - -6. 复制SSL证书 - ```bash - docker cp es01:/usr/share/elasticsearch/config/certs/http_ca.crt . - ``` - -7. 验证部署 - ```bash - curl --cacert http_ca.crt -u elastic:$ELASTIC_PASSWORD https://localhost:9200 -k - ``` - -8. 获取api_key - ```bash - curl --cacert http_ca.crt \ - -u elastic:$ELASTIC_PASSWORD \ - --request POST \ - --url https://localhost:9200/_security/api_key \ - --header 'Content-Type: application/json' \ - --data '{ - "name": "取个名字" - }' - ``` - -9. 检验key有效 - ```bash - curl --request GET \ - --url https://XXX.XX.XXX.XX:9200/_cluster/health \ - --header 'Authorization: ApiKey API-KEY' - ``` - -### Kibana部署 (可选) - -1. 拉取Kibana镜像 - ```bash - docker pull docker.elastic.co/kibana/kibana:8.17.4 - ``` - -2. 启动Kibana容器 - ```bash - docker run -d --name kib01 --net elastic -p 5601:5601 docker.elastic.co/kibana/kibana:8.17.4 - ``` - -3. 查看Kibana日志 - ```bash - docker logs -f kib01 - ``` - -4. 配置Kibana - - 生成令牌,运行: - ```bash - docker exec -it es01 /usr/share/elasticsearch/bin/elasticsearch-create-enrollment-token -s kibana - ``` - - 在浏览器中,访问http://localhost:5601输入生成的注册令牌 - - 可能需要`docker logs -f kib01`查看验证码 - -5. 使用elastic用户和之前生成的密码登录Kibana - -### 常用管理命令 - -```bash -# 停止容器 -docker stop es01 -docker stop kib01 - -# 删除容器 -docker rm es01 -docker rm kib01 - -# 删除网络 -docker network rm elastic -``` - -### 生产环境注意事项 - -1. 
数据持久化 - - 必须绑定数据卷到 `/usr/share/elasticsearch/data` - - 启动命令示例: - ```bash - docker run -d --name es01 --net elastic -p 9200:9200 -m 6GB -v es_data:/usr/share/elasticsearch/data docker.elastic.co/elasticsearch/elasticsearch:8.17.4 - ``` - -2. 内存配置 - - 根据实际需求调整容器内存限制 - - 建议至少分配6GB内存 - -3. 故障排除 - - 内存不足: 检查Docker Desktop的内存设置 - - 端口冲突: 确保9200端口未被占用 - - 证书问题: 确保正确复制了SSL证书 - - 昇腾服务器vm.max_map_count问题: - ```bash - # 错误信息 - # node validation exception: bootstrap checks failed - # max virtual memory areas vm.max_map_count [65530] is too low, increase to at least [262144] - - # 解决方案(在宿主机执行): - sudo sysctl -w vm.max_map_count=262144 - - # 永久生效,编辑 /etc/sysctl.conf 添加: - vm.max_map_count=262144 - - # 然后执行: - sudo sysctl -p - ``` - -### 远程部署调试指南 - -当Elasticsearch部署在远程服务器上时,可能会遇到一些网络访问的问题。以下是常见问题和解决方案: - -1. 远程访问被拒绝 - - 症状:curl请求返回 "Connection reset by peer" - - 解决方案: - ```bash - # 使用SSH隧道进行端口转发 - ssh -L 9200:localhost:9200 user@remote_server - - # 在新终端中通过本地端口访问 - curl -H "Authorization: ApiKey your_api_key" https://localhost:9200/_cluster/health\?pretty -k - ``` - -2. 网络配置检查清单 - - 确保远程服务器的防火墙允许9201端口访问 - ```bash - # 对于使用iptables的系统 - sudo iptables -A INPUT -p tcp --dport 9200 -j ACCEPT - sudo service iptables save - ``` - - - 检查Elasticsearch网络配置 - ```yaml - # elasticsearch.yml 配置示例 - network.host: 0.0.0.0 - http.cors.enabled: true - http.cors.allow-origin: "*" - ``` - -3. 安全配置建议 - - 在生产环境中,建议: - - 限制CORS的 `allow-origin` 为特定域名 - - 使用反向代理(如Nginx)管理SSL终端 - - 配置适当的网络安全组规则 - - 使用SSL证书而不是自签名证书 - -4. 使用环境变量 - - 在 `.env` 文件中配置远程连接: - ``` - ELASTICSEARCH_HOST=https://remote_server:9200 - ELASTICSEARCH_API_KEY=your_api_key - ``` - - - 如果使用SSH隧道,可以保持使用localhost: - ``` - ELASTICSEARCH_HOST=https://localhost:9200 - ``` - -5. 
故障排除命令 - ```bash - # 检查端口监听状态 - netstat -tulpn | grep 9200 - - # 检查ES日志 - docker logs es01 - - # 测试SSL连接 - openssl s_client -connect remote_server:9200 - ``` - -## 核心组件 - -- `elasticsearch_core.py`: 主类,包含所有 Elasticsearch 操作 -- `embedding_model.py`: 处理使用 Jina AI 模型生成嵌入向量 -- `utils.py`: 数据格式化和显示的工具函数 -- `elasticsearch_service.py`: FastAPI 服务,提供 REST API 接口 - -## 使用示例 - -### 基本初始化 - -```python -from nexent.vector_database.elasticsearch_core import ElasticSearchCore - -# 使用 .env 文件中的凭据初始化 -es_core = ElasticSearchCore() - -# 或直接指定凭据 -es_core = ElasticSearchCore( - host="https://localhost:9200", - api_key="your_api_key", - verify_certs=False, - ssl_show_warn=False, -) -``` - -### 索引管理 - -```python -# 创建新的向量索引 -es_core.create_vector_index("my_documents") - -# 列出所有用户索引 -indices = es_core.get_user_indices() -print(indices) - -# 获取所有索引的统计信息 -all_indices_stats = es_core.get_all_indices_stats() -print(all_indices_stats) - -# 删除索引 -es_core.delete_index("my_documents") - -# 创建测试知识库 -index_name, doc_count = es_core.create_test_knowledge_base() -print(f"创建了测试知识库 {index_name},包含 {doc_count} 个文档") -``` - -### 文档操作 - -```python -# 索引文档(自动生成嵌入向量) -documents = [ - { - "id": "doc1", - "title": "文档 1", - "file": "文件1.txt", - "path_or_url": "https://example.com/doc1", - "content": "这是文档 1 的内容", - "process_source": "Web", - "embedding_model_name": "jina-embeddings-v2-base-en", # 指定嵌入模型 - "file_size": 1024, # 文件大小(字节) - "create_time": "2023-06-01T10:30:00" # 文件创建时间 - }, - { - "id": "doc2", - "title": "文档 2", - "file": "文件2.txt", - "path_or_url": "https://example.com/doc2", - "content": "这是文档 2 的内容", - "process_source": "Web" - # 如果未提供其他字段,将使用默认值 - } -] -# 支持批量处理,默认批处理大小为3000 -total_indexed = es_core.index_documents("my_documents", documents, batch_size=3000) -print(f"成功索引了 {total_indexed} 个文档") - -# 通过 URL 或路径删除文档 -deleted_count = es_core.delete_documents_by_path_or_url("my_documents", "https://example.com/doc1") -print(f"删除了 {deleted_count} 个文档") -``` - -### 搜索功能 - -```python -# 文本精确搜索 
-results = es_core.accurate_search("my_documents", "示例查询", top_k=5) -for result in results: - print(f"得分: {result['score']}, 文档: {result['document']['title']}") - -# 语义向量搜索 -results = es_core.semantic_search("my_documents", "示例查询", top_k=5) -for result in results: - print(f"得分: {result['score']}, 文档: {result['document']['title']}") - -# 混合搜索 -results = es_core.hybrid_search( - "my_documents", - "示例查询", - top_k=5, - weight_accurate=0.3 # 精确搜索权重为0.3,向量搜索权重为0.7 -) -for result in results: - print(f"得分: {result['score']}, 文档: {result['document']['title']}") -``` - -### 统计和监控 - -```python -# 获取索引统计信息 -stats = es_core.get_index_stats("my_documents") -print(stats) - -# 获取文件列表及详细信息 -file_details = es_core.get_file_list_with_details("my_documents") -print(file_details) - -# 获取嵌入模型信息 -embedding_model = es_core.get_embedding_model_info("my_documents") -print(f"使用的嵌入模型: {embedding_model}") - -# 打印所有索引信息 -es_core.print_all_indices_info() -``` - -## ElasticSearchCore 主要功能 - -ElasticSearchCore 类提供了以下主要功能: - -- **索引管理**: 创建和删除索引,获取用户索引列表和统计信息 -- **文档操作**: 批量索引带有嵌入向量的文档,删除指定文档 -- **搜索操作**: 提供精确文本搜索、语义向量搜索、以及混合搜索 -- **统计和监控**: 获取索引统计数据,查看数据源、创建时间和文件列表等信息 - -### 新增高级功能 - -```python -# 获取索引的文件列表及详细信息 -files = es_core.get_file_list_with_details("my_documents") -for file in files: - print(f"文件路径: {file['path_or_url']}") - print(f"文件名: {file['file']}") - print(f"文件大小: {file['file_size']} 字节") - print(f"创建时间: {file['create_time']}") - print("---") - -# 获取嵌入模型信息 -model_info = es_core.get_embedding_model_info("my_documents") -print(f"使用的嵌入模型: {model_info}") - -# 获取所有索引的综合统计信息 -all_stats = es_core.get_all_indices_stats() -for index_name, stats in all_stats.items(): - print(f"索引: {index_name}") - print(f"文档数: {stats['base_info']['doc_count']}") - print(f"唯一源数量: {stats['base_info']['unique_sources_count']}") - print(f"使用的嵌入模型: {stats['base_info']['embedding_model']}") - print("---") -``` - -## API 服务接口 - -通过 `elasticsearch_service.py` 提供的 FastAPI 服务,可使用 REST API 访问上述所有功能。 - -### 服务启动 - -```bash 
-python -m nexent.service.elasticsearch_service -``` - -服务默认在 `http://localhost:8000` 运行。 - -### API 接口文档 - -#### 健康检查 - -- **GET** `/health`: 检查 API 和 Elasticsearch 连接状态 - - 返回示例: `{"status": "healthy", "elasticsearch": "connected", "indices_count": 5}` - -#### 索引管理 -- **POST** `/indices/{index_name}`: 创建索引 - - 参数: - - `index_name`: 索引名称 (路径参数) - - `embedding_dim`: 向量化维度 (查询参数,可选) - - 返回示例: `{"status": "success", "message": "Index my_documents created successfully"}` - -- **DELETE** `/indices/{index_name}`: 删除索引 - - 参数: `index_name`: 索引名称 (路径参数) - - 返回示例: `{"status": "success", "message": "Index my_documents deleted successfully"}` - -- **GET** `/indices`: 列出所有索引,可选包含详细统计信息 - - 参数: - - `pattern`: 索引名称匹配模式 (查询参数,默认为 "*") - - `include_stats`: 是否包含索引统计信息 (查询参数,默认为 false) - - 基本返回示例: `{"indices": ["index1", "index2"], "count": 2}` - - 包含统计信息的返回示例: - ```json - { - "indices": ["index1", "index2"], - "count": 2, - "indices_info": [ - { - "name": "index1", - "stats": { - "base_info": { - "doc_count": 100, - "unique_sources_count": 10, - "store_size": "1.2 MB", - "process_source": "Web", - "embedding_model": "jina-embeddings-v2-base-en", - "creation_date": "2023-06-01 12:00:00", - "update_date": "2023-06-02 15:30:00" - }, - "search_performance": { - "total_search_count": 150, - "hit_count": 120 - } - } - }, - { - "name": "index2", - "stats": { "..." 
} - } - ] - } - ``` - -- **GET** `/indices/{index_name}/info`: 获取索引的综合信息 - - 参数: - - `index_name`: 索引名称 (路径参数) - - `include_files`: 是否包含文件列表信息 (查询参数,默认为 true) - - `include_chunks`: 是否包含文本块信息 (查询参数,默认为 false) - - 返回综合信息,包括基本信息、搜索性能、字段列表、文件列表和文本块列表 - - 返回示例: - ```json - { - "base_info": { - "doc_count": 100, - "unique_sources_count": 10, - "store_size": "1.2 MB", - "process_source": "Web", - "embedding_model": "jina-embeddings-v2-base-en", - "embedding_dim": 1024, - "creation_date": "2023-06-01 12:00:00", - "update_date": "2023-06-02 15:30:00" - }, - "search_performance": { - "total_search_count": 150, - "hit_count": 120 - }, - "fields": ["id", "title", "content", "embedding", "embedding_model_name", "file_size", "create_time", "..."], - "files": [ - { - "path_or_url": "https://example.com/doc1", - "file": "文件1.txt", - "file_size": 1024, - "create_time": "2023-06-01T10:30:00", - "chunks_count": 6, - "status": "PROCESSING", - "chunks": [] - }, - { - "path_or_url": "https://example.com/doc2", - "file": "文件2.txt", - "file_size": 2048, - "create_time": "2023-06-01T11:45:00", - "chunks_count": 10, - "status": "WAITING", - "chunks": [] - }, - { - "path_or_url": "https://example.com/doc3", - "file": "文件3.txt", - "file_size": 0, - "create_time": "2023-06-01T12:00:00", - "chunks_count": 0, - "status": "COMPLETED", - "chunks": [ - { - "id": "task-0", - "title": "title-0", - "content": "content-0", - "create_time": "2023-06-01T12:30:00" - }, - { - "id": "task-1", - "title": "title-1", - "content": "content-1", - "create_time": "2023-06-01T12:30:00" - } - ], - } - ] - } - ``` - - 文件状态说明: - - `WAITING`: 文件正在等待处理 - - `PROCESSING`: 文件正在被处理 - - `FORWARDING`: 文件正在被转发到向量知识库服务 - - `COMPLETED`: 文件已完成处理并成功入库 - - `FAILED`: 文件处理失败 - - 文件列表包含: - - 已存在于ES中的文件(状态为 COMPLETED 或活跃任务中的状态) - - 正在数据清洗服务中处理但尚未进入ES的文件(状态为 WAITING/PROCESSING/FORWARDING/FAILED) - -#### 文档操作 - -- **POST** `/indices/{index_name}/documents`: 索引文档 - - 参数: - - `index_name`: 索引名称 (路径参数) - - `data`: 包含任务ID和文档的请求体 
(IndexingRequest) - - `embedding_model_name`: 指定要使用的嵌入模型名称 (查询参数,可选) - - IndexingRequest 格式示例: - ```json - { - "task_id": "task-123", - "index_name": "my_documents", - "results": [ - { - "metadata": { - "title": "文档标题", - "filename": "文件名.txt", - "languages": ["zh"], - "author": "作者", - "file_size": 1024, - "creation_date": "2023-06-01T10:30:00" - }, - "source": "https://example.com/doc1", - "source_type": "url", - "text": "文档内容" - } - ], - "embedding_dim": 1024 - } - ``` - - 返回示例: - ```json - { - "success": true, - "message": "Successfully indexed 1 documents", - "total_indexed": 1, - "total_submitted": 1 - } - ``` - -- **DELETE** `/indices/{index_name}/documents`: 删除文档 - - 参数: - - `index_name`: 索引名称 (路径参数) - - `path_or_url`: 文档路径或URL (查询参数) - - 返回示例: `{"status": "success", "deleted_count": 1}` - -#### 搜索操作 - -- **POST** `/indices/search/accurate`: 精确文本搜索 - - 请求体 (SearchRequest): - ```json - { - "index_names": ["index1", "index2"], - "query": "搜索关键词", - "top_k": 5 - } - ``` - - 返回格式: - ```json - { - "results": [ - { - "id": "doc1", - "title": "文档标题", - "file": "文件名.txt", - "path_or_url": "https://example.com/doc1", - "content": "文档内容", - "process_source": "Web", - "embedding_model_name": "jina-embeddings-v2-base-en", - "file_size": 1024, - "create_time": "2023-06-01T10:30:00", - "score": 0.95, - "index": "index1" - }, - { - "id": "doc2", - "title": "文档标题", - "file": "文件名.txt", - "path_or_url": "https://example.com/doc2", - "content": "文档内容", - "process_source": "Web", - "embedding_model_name": "jina-embeddings-v2-base-en", - "file_size": 1024, - "create_time": "2023-06-01T10:30:00", - "score": 0.85, - "index": "index2" - } - ], - "total": 2, - "query_time_ms": 25.4 - } - ``` - -- **POST** `/indices/search/semantic`: 语义向量搜索 - - 请求体格式与精确搜索相同 (SearchRequest) - - 返回格式与精确搜索相同,但基于语义相似度评分 - -- **POST** `/indices/search/hybrid`: 混合搜索 - - 请求体 (HybridSearchRequest): - ```json - { - "index_names": ["index1", "index2"], - "query": "搜索关键词", - "top_k": 5, - "weight_accurate": 
0.3 - } - ``` - - 返回格式与精确搜索相同,但包含详细的得分信息: - ```json - { - "results": [ - { - "id": "doc1", - "title": "文档标题", - "file": "文件名.txt", - "path_or_url": "https://example.com/doc1", - "content": "文档内容", - "process_source": "Web", - "embedding_model_name": "jina-embeddings-v2-base-en", - "file_size": 1024, - "create_time": "2023-06-01T10:30:00", - "score": 0.798, - "index": "index1", - "score_details": { - "accurate": 0.80, - "semantic": 0.90 - } - }, - { - "id": "doc2", - "title": "文档标题", - "file": "文件名.txt", - "path_or_url": "https://example.com/doc2", - "content": "文档内容", - "process_source": "Web", - "embedding_model_name": "jina-embeddings-v2-base-en", - "file_size": 1024, - "create_time": "2023-06-01T10:30:00", - "score": 0.756, - "index": "index1", - "score_details": { - "accurate": 0.60, - "semantic": 0.90 - } - } - ], - "total": 2, - "query_time_ms": 35.2 - } - ``` - -### API 使用示例 - -#### 使用 curl 请求示例 - -```bash -# 健康检查 -curl -X GET "http://localhost:8000/health" - -# 列出所有索引(包含统计信息) -curl -X GET "http://localhost:8000/indices?include_stats=true" - -# 获取索引详细信息(包含文本块列表) -curl -X GET "http://localhost:8000/indices/my_documents/info?include_chunks=true" - -# 精确搜索(支持多索引搜索) -curl -X POST "http://localhost:8000/indices/search/accurate" \ - -H "Content-Type: application/json" \ - -d '{ - "index_names": ["my_documents", "other_index"], - "query": "示例查询", - "top_k": 3 - }' - -# 语义搜索(支持多索引搜索) -curl -X POST "http://localhost:8000/indices/search/semantic" \ - -H "Content-Type: application/json" \ - -d '{ - "index_names": ["my_documents", "other_index"], - "query": "相似含义查询", - "top_k": 3 - }' - -# 混合搜索(支持多索引搜索) -curl -X POST "http://localhost:8000/indices/search/hybrid" \ - -H "Content-Type: application/json" \ - -d '{ - "index_names": ["my_documents", "other_index"], - "query": "示例查询", - "top_k": 3, - "weight_accurate": 0.3 - }' - -# 删除文档 -curl -X DELETE "http://localhost:8000/indices/my_documents/documents?path_or_url=https://example.com/doc1" - -# 创建索引 -curl -X POST 
"http://localhost:8000/indices/my_documents" - -# 删除索引 -curl -X DELETE "http://localhost:8000/indices/my_documents" -``` - -#### 使用 Python requests 示例 - -```python -import requests -import json -import time - -BASE_URL = "http://localhost:8000" - -# 当前时间,ISO格式 -current_time = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()) - -# 准备 IndexingRequest -indexing_request = { - "task_id": f"task-{int(time.time())}", - "index_name": "my_documents", - "results": [ - { - "metadata": { - "title": "示例文档", - "filename": "example.txt", - "language": "zh", - "author": "作者", - "file_size": 1024, - "creation_date": current_time - }, - "source": "https://example.com/doc1", - "text": "这是一个示例文档" - } - ], - "embedding_dim": 1024 -} - -# 索引文档 -response = requests.post( - f"{BASE_URL}/indices/my_documents/documents", - json=indexing_request, - params={ - "embedding_model_name": "jina-embeddings-v2-base-en" # 可选参数:指定嵌入模型 - } -) -print(response.json()) - -# 获取索引信息,包含文件列表 -response = requests.get( - f"{BASE_URL}/indices/my_documents/info", - params={"include_files": True} -) -print(json.dumps(response.json(), indent=2, ensure_ascii=False)) - -# 获取所有索引信息,包含统计 -response = requests.get( - f"{BASE_URL}/indices", - params={"include_stats": True} -) -print(json.dumps(response.json(), indent=2, ensure_ascii=False)) - -# 精确搜索 -response = requests.post( - f"{BASE_URL}/indices/search/accurate", - json={ - "index_names": ["my_documents", "other_index"], - "query": "示例内容", - "top_k": 3 - } -) -print(json.dumps(response.json(), indent=2, ensure_ascii=False)) - -# 语义搜索 -response = requests.post( - f"{BASE_URL}/indices/search/semantic", - json={ - "index_names": ["my_documents", "other_index"], - "query": "示例内容", - "top_k": 3 - } -) -print(json.dumps(response.json(), indent=2, ensure_ascii=False)) - -# 混合搜索 -response = requests.post( - f"{BASE_URL}/indices/search/hybrid", - json={ - "index_names": ["my_documents", "other_index"], - "query": "示例内容", - "top_k": 3, - "weight_accurate": 0.3 - } -) 
-print(json.dumps(response.json(), indent=2, ensure_ascii=False)) -``` - -## 完整示例 - -查看 ElasticSearchCore 类的 main 函数,了解完整功能演示: - -```python -# 初始化 ElasticSearchCore -es_core = ElasticSearchCore() - -# 获取或创建测试知识库 -index_name = "sample_articles" - -# 列出所有用户索引 -user_indices = es_core.get_user_indices() -for idx in user_indices: - print(f" - {idx}") - -# 执行搜索 -if index_name in user_indices: - # 精确搜索 - query = "Doctor" - accurate_results = es_core.accurate_search(index_name, query, top_k=2) - - # 语义搜索 - query = "medical professionals in London" - semantic_results = es_core.semantic_search(index_name, query, top_k=2) - - # 混合搜索 - query = "medical professionals in London" - semantic_results = es_core.hybrid_search(index_name, query, top_k=2, weight_accurate=0.5) - - # 获取索引统计信息 - stats = es_core.get_index_stats(index_name) - fields = es_core.get_index_mapping(index_name) - unique_sources = es_core.get_unique_sources_count(index_name) -``` - -## 许可证 - -该项目根据 MIT 许可证授权 - 详情请参阅 LICENSE 文件。 diff --git a/sdk/nexent/vector_database/base.py b/sdk/nexent/vector_database/base.py new file mode 100644 index 000000000..be53aab99 --- /dev/null +++ b/sdk/nexent/vector_database/base.py @@ -0,0 +1,272 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from ..core.models.embedding_model import BaseEmbedding + + +class VectorDatabaseCore(ABC): + """ + Abstract base class for vector database operations. + + All vector database implementations must inherit from this class and implement + all abstract methods. This abstraction enables support for multiple vector + database backends (e.g., Elasticsearch, Milvus) while maintaining a consistent + interface for the service layer. + """ + + # ---- INDEX MANAGEMENT ---- + + @abstractmethod + def create_index(self, index_name: str, embedding_dim: Optional[int] = None) -> bool: + """ + Create a new vector search index with appropriate mappings. 
+ + Args: + index_name: Name of the index to create + embedding_dim: Dimension of the embedding vectors (optional, will use model's dim if not provided) + + Returns: + bool: True if creation was successful + """ + pass + + @abstractmethod + def delete_index(self, index_name: str) -> bool: + """ + Delete an entire index. + + Args: + index_name: Name of the index to delete + + Returns: + bool: True if deletion was successful + """ + pass + + @abstractmethod + def get_user_indices(self, index_pattern: str = "*") -> List[str]: + """ + Get list of user created indices (excluding system indices). + + Args: + index_pattern: Pattern to match index names + + Returns: + List of index names + """ + pass + + @abstractmethod + def check_index_exists(self, index_name: str) -> bool: + """ + Check if an index exists. + + Args: + index_name: Name of the index to check + + Returns: + bool: True if index exists, False otherwise + """ + pass + + # ---- DOCUMENT OPERATIONS ---- + + @abstractmethod + def vectorize_documents( + self, + index_name: str, + embedding_model: BaseEmbedding, + documents: List[Dict[str, Any]], + batch_size: int = 64, + content_field: str = "content", + ) -> int: + """ + Index documents with embeddings. + + Args: + index_name: Name of the index to add documents to + embedding_model: Model used to generate embeddings for documents + documents: List of document dictionaries + batch_size: Number of documents to process at once + content_field: Field to use for generating embeddings + + Returns: + int: Number of documents successfully indexed + """ + pass + + @abstractmethod + def delete_documents(self, index_name: str, path_or_url: str) -> int: + """ + Delete documents based on their path_or_url field. 
+ + Args: + index_name: Name of the index to delete documents from + path_or_url: The URL or path of the documents to delete + + Returns: + int: Number of documents deleted + """ + pass + + @abstractmethod + def get_index_chunks( + self, + index_name: str, + page: Optional[int] = None, + page_size: Optional[int] = None, + path_or_url: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Retrieve chunk records for the specified index with optional pagination. + + Args: + index_name: Name of the index to query + page: Page number to return (1-based). If None, all chunks are returned. + page_size: Page size for pagination. Must be provided together with page. + path_or_url: Optional filter for a specific document path or URL. + + Returns: + Dict containing chunks, total count, and pagination metadata + """ + pass + + @abstractmethod + def count_documents(self, index_name: str) -> int: + """ + Count the total number of documents in an index. + + Args: + index_name: Name of the index to count documents in + + Returns: + int: Total number of documents + """ + pass + + @abstractmethod + def search(self, index_name: str, query: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute a search query on an index. + + Args: + index_name: Name of the index to search + query: Search query dictionary + + Returns: + Dict containing search results + """ + pass + + @abstractmethod + def multi_search(self, body: List[Dict[str, Any]], index_name: str) -> Dict[str, Any]: + """ + Execute multiple search queries in a single request. + + Args: + body: List of search queries (alternating index and query) + index_name: Name of the index to search + + Returns: + Dict containing responses for all queries + """ + pass + + # ---- SEARCH OPERATIONS ---- + + @abstractmethod + def accurate_search(self, index_names: List[str], query_text: str, top_k: int = 5) -> List[Dict[str, Any]]: + """ + Search for documents using fuzzy text matching across multiple indices. 
+ + Args: + index_names: List of index names to search in + query_text: The text query to search for + top_k: Number of results to return + + Returns: + List of search results with scores and document content + """ + pass + + @abstractmethod + def semantic_search( + self, index_names: List[str], query_text: str, embedding_model: BaseEmbedding, top_k: int = 5 + ) -> List[Dict[str, Any]]: + """ + Search for similar documents using vector similarity across multiple indices. + + Args: + index_names: List of index names to search in + query_text: The text query to search for + embedding_model: The embedding model to use + top_k: Number of results to return + + Returns: + List of search results with scores and document content + """ + pass + + @abstractmethod + def hybrid_search( + self, + index_names: List[str], + query_text: str, + embedding_model: BaseEmbedding, + top_k: int = 5, + weight_accurate: float = 0.3, + ) -> List[Dict[str, Any]]: + """ + Hybrid search method, combining accurate matching and semantic search results across multiple indices. + + Args: + index_names: List of index names to search in + query_text: The text query to search for + embedding_model: The embedding model to use + top_k: Number of results to return + weight_accurate: The weight of the accurate matching score (0-1), + the semantic search weight is 1-weight_accurate + + Returns: + List of search results sorted by combined score + """ + pass + + # ---- STATISTICS AND MONITORING ---- + + @abstractmethod + def get_documents_detail(self, index_name: str) -> List[Dict[str, Any]]: + """ + Get a list of unique source files with metadata. 
+ + Args: + index_name: Name of the index to query + + Returns: + List of dictionaries, each containing: + - path_or_url: Source identifier + - filename: Optional display name + - file_size: Size in bytes + - create_time: ISO timestamp string + """ + pass + + @abstractmethod + def get_indices_detail( + self, index_names: List[str], embedding_dim: Optional[int] = None + ) -> Dict[str, Dict[str, Dict[str, Any]]]: + """ + Get formatted statistics for multiple indices. + + Args: + index_names: List of index names to get stats for + embedding_dim: Optional embedding dimension (for display purposes) + + Returns: + Dict mapping each index name to: + - base_info: Dict with doc_count, chunk_count, store_size, + process_source, embedding_model, embedding_dim, + creation_date, update_date + - search_performance: Dict with total_search_count, hit_count + """ + pass diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py index 5cd4b27f7..2b23a8be1 100644 --- a/sdk/nexent/vector_database/elasticsearch_core.py +++ b/sdk/nexent/vector_database/elasticsearch_core.py @@ -1,27 +1,37 @@ -import time import logging import threading -from typing import List, Dict, Any, Optional +import time from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timedelta -from ..core.models.embedding_model import BaseEmbedding -from .utils import format_size, format_timestamp, build_weighted_query +from typing import Any, Dict, List, Optional + from elasticsearch import Elasticsearch, exceptions +from ..core.models.embedding_model import BaseEmbedding from ..core.nlp.tokenizer import calculate_term_weights +from .base import VectorDatabaseCore +from .utils import build_weighted_query, format_size + logger = logging.getLogger("elasticsearch_core") + @dataclass class BulkOperation: """Bulk operation status tracking""" + index_name: str operation_id: str start_time: datetime expected_duration: timedelta 
-class ElasticSearchCore: + +SCROLL_TTL = "2m" +DEFAULT_SCROLL_SIZE = 1000 + + +class ElasticSearchCore(VectorDatabaseCore): """ Core class for Elasticsearch operations including: - Index management @@ -32,9 +42,9 @@ class ElasticSearchCore: - Hybrid search - Index statistics """ - + def __init__( - self, + self, host: Optional[str], api_key: Optional[str], verify_certs: bool = False, @@ -42,7 +52,7 @@ def __init__( ): """ Initialize ElasticSearchCore with Elasticsearch client and JinaEmbedding model. - + Args: host: Elasticsearch host URL (defaults to env variable) api_key: Elasticsearch API key (defaults to env variable) @@ -52,7 +62,7 @@ def __init__( # Get credentials from environment if not provided self.host = host self.api_key = api_key - + # Initialize Elasticsearch client with HTTPS support self.client = Elasticsearch( self.host, @@ -64,7 +74,7 @@ def __init__( retry_on_timeout=True, retry_on_status=[502, 503, 504], # Retry on these status codes, ) - + # Initialize embedding model self._bulk_operations: Dict[str, List[BulkOperation]] = {} self._settings_lock = threading.Lock() @@ -77,22 +87,22 @@ def __init__( self.max_retries = 3 # Number of retries for failed embedding batches # ---- INDEX MANAGEMENT ---- - - def create_vector_index(self, index_name: str, embedding_dim: Optional[int] = None) -> bool: + + def create_index(self, index_name: str, embedding_dim: Optional[int] = None) -> bool: """ Create a new vector search index with appropriate mappings in a celery-friendly way. 
- + Args: index_name: Name of the index to create embedding_dim: Dimension of the embedding vectors (optional, will use model's dim if not provided) - + Returns: bool: True if creation was successful """ try: # Use provided embedding_dim or get from model actual_embedding_dim = embedding_dim or 1024 - + # Use balanced fixed settings to avoid dynamic adjustment settings = { "number_of_shards": 1, @@ -100,29 +110,20 @@ def create_vector_index(self, index_name: str, embedding_dim: Optional[int] = No "refresh_interval": "5s", "index": { "max_result_window": 50000, - "translog": { - "durability": "async", - "sync_interval": "5s" - }, - "write": { - "wait_for_active_shards": "1" - }, + "translog": {"durability": "async", "sync_interval": "5s"}, + "write": {"wait_for_active_shards": "1"}, # Memory optimization for bulk operations - "merge": { - "policy": { - "max_merge_at_once": 5, - "segments_per_tier": 5 - } - } - } + "merge": {"policy": {"max_merge_at_once": 5, "segments_per_tier": 5}}, + }, } # Check if index already exists if self.client.indices.exists(index=index_name): - logger.info(f"Index {index_name} already exists, skipping creation") + logger.info( + f"Index {index_name} already exists, skipping creation") self._ensure_index_ready(index_name) return True - + # Define the mapping with vector field mappings = { "properties": { @@ -146,13 +147,10 @@ def create_vector_index(self, index_name: str, embedding_dim: Optional[int] = No }, } } - + # Create the index with the defined mappings self.client.indices.create( - index=index_name, - mappings=mappings, - settings=settings, - wait_for_active_shards="1" + index=index_name, mappings=mappings, settings=settings, wait_for_active_shards="1" ) # Force refresh to ensure visibility @@ -161,11 +159,12 @@ def create_vector_index(self, index_name: str, embedding_dim: Optional[int] = No logger.info(f"Successfully created index: {index_name}") return True - + except exceptions.RequestError as e: # Handle the case where index 
already exists (error 400) if "resource_already_exists_exception" in str(e): - logger.info(f"Index {index_name} already exists, skipping creation") + logger.info( + f"Index {index_name} already exists, skipping creation") self._ensure_index_ready(index_name) return True logger.error(f"Error creating index: {str(e)}") @@ -200,23 +199,19 @@ def _ensure_index_ready(self, index_name: str, timeout: int = 10) -> bool: try: # Check cluster health health = self.client.cluster.health( - index=index_name, - wait_for_status="yellow", - timeout="1s" - ) + index=index_name, wait_for_status="yellow", timeout="1s") if health["status"] in ["green", "yellow"]: # Double check: try simple query - self.client.search( - index=index_name, - body={"query": {"match_all": {}}, "size": 0} - ) + self.client.search(index=index_name, body={ + "query": {"match_all": {}}, "size": 0}) return True - except Exception as e: + except Exception: time.sleep(0.1) - logger.warning(f"Index {index_name} may not be fully ready after {timeout}s") + logger.warning( + f"Index {index_name} may not be fully ready after {timeout}s") return False @contextmanager @@ -231,7 +226,7 @@ def bulk_operation_context(self, index_name: str, estimated_duration: int = 60): index_name=index_name, operation_id=operation_id, start_time=datetime.now(), - expected_duration=timedelta(seconds=estimated_duration) + expected_duration=timedelta(seconds=estimated_duration), ) with self._settings_lock: @@ -250,8 +245,7 @@ def bulk_operation_context(self, index_name: str, estimated_duration: int = 60): with self._settings_lock: # Remove operation record self._bulk_operations[index_name] = [ - op for op in self._bulk_operations[index_name] - if op.operation_id != operation_id + op for op in self._bulk_operations[index_name] if op.operation_id != operation_id ] # If there are no other bulk operations, restore settings @@ -264,11 +258,8 @@ def _apply_bulk_settings(self, index_name: str): try: self.client.indices.put_settings( 
index=index_name, - body={ - "refresh_interval": "30s", - "translog.durability": "async", - "translog.sync_interval": "10s" - } + body={"refresh_interval": "30s", "translog.durability": "async", + "translog.sync_interval": "10s"}, ) logger.debug(f"Applied bulk settings to {index_name}") except Exception as e: @@ -278,11 +269,8 @@ def _restore_normal_settings(self, index_name: str): """Restore normal settings""" try: self.client.indices.put_settings( - index=index_name, - body={ - "refresh_interval": "5s", - "translog.durability": "request" - } + index=index_name, body={ + "refresh_interval": "5s", "translog.durability": "request"} ) # Refresh after restoration self._force_refresh_with_retry(index_name) @@ -293,10 +281,10 @@ def _restore_normal_settings(self, index_name: str): def delete_index(self, index_name: str) -> bool: """ Delete an entire index - + Args: index_name: Name of the index to delete - + Returns: bool: True if deletion was successful """ @@ -310,45 +298,57 @@ def delete_index(self, index_name: str) -> bool: except Exception as e: logger.error(f"Error deleting index: {str(e)}") return False - + def get_user_indices(self, index_pattern: str = "*") -> List[str]: """ Get list of user created indices (excluding system indices) - + Args: index_pattern: Pattern to match index names - + Returns: List of index names """ try: indices = self.client.indices.get_alias(index=index_pattern) # Filter out system indices (starting with '.') - return [index_name for index_name in indices.keys() if not index_name.startswith('.')] + return [index_name for index_name in indices.keys() if not index_name.startswith(".")] except Exception as e: logger.error(f"Error getting user indices: {str(e)}") return [] - + + def check_index_exists(self, index_name: str) -> bool: + """ + Check if an index exists. 
+ + Args: + index_name: Name of the index to check + + Returns: + bool: True if index exists, False otherwise + """ + return self.client.indices.exists(index=index_name) + # ---- DOCUMENT OPERATIONS ---- - - def index_documents( - self, + + def vectorize_documents( + self, index_name: str, embedding_model: BaseEmbedding, documents: List[Dict[str, Any]], batch_size: int = 64, - content_field: str = "content" + content_field: str = "content", ) -> int: """ Smart batch insertion - automatically selecting strategy based on data size - + Args: index_name: Name of the index to add documents to embedding_model: Model used to generate embeddings for documents documents: List of document dictionaries batch_size: Number of documents to process at once content_field: Field to use for generating embeddings - + Returns: int: Number of documents successfully indexed """ @@ -369,12 +369,15 @@ def index_documents( with self.bulk_operation_context(index_name, estimated_duration): return self._large_batch_insert(index_name, documents, batch_size, content_field, embedding_model) - def _small_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], content_field: str, embedding_model:BaseEmbedding) -> int: + def _small_batch_insert( + self, index_name: str, documents: List[Dict[str, Any]], content_field: str, embedding_model: BaseEmbedding + ) -> int: """Small batch insertion: real-time""" try: # Preprocess documents - processed_docs = self._preprocess_documents(documents, content_field) - + processed_docs = self._preprocess_documents( + documents, content_field) + # Get embeddings inputs = [doc[content_field] for doc in processed_docs] embeddings = embedding_model.get_embeddings(inputs) @@ -390,38 +393,45 @@ def _small_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], # Execute bulk insertion, wait for refresh to complete response = self.client.bulk( - index=index_name, - operations=operations, - refresh='wait_for' - ) + index=index_name, 
operations=operations, refresh="wait_for") # Handle errors self._handle_bulk_errors(response) - logger.info(f"Small batch insert completed: {len(documents)} chunks indexed.") + logger.info( + f"Small batch insert completed: {len(documents)} chunks indexed.") return len(documents) - + except Exception as e: logger.error(f"Small batch insert failed: {e}") return 0 - def _large_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], batch_size: int, content_field: str, embedding_model: BaseEmbedding) -> int: + def _large_batch_insert( + self, + index_name: str, + documents: List[Dict[str, Any]], + batch_size: int, + content_field: str, + embedding_model: BaseEmbedding, + ) -> int: """ Large batch insertion with sub-batching for embedding API. Splits large document batches into smaller chunks to respect embedding API limits before bulk inserting into Elasticsearch. """ try: - processed_docs = self._preprocess_documents(documents, content_field) + processed_docs = self._preprocess_documents( + documents, content_field) total_indexed = 0 total_docs = len(processed_docs) es_total_batches = (total_docs + batch_size - 1) // batch_size start_time = time.time() logger.info( - f"=== [INDEXING START] Total chunks: {total_docs}, ES batch size: {batch_size}, Total ES batches: {es_total_batches} ===") + f"=== [INDEXING START] Total chunks: {total_docs}, ES batch size: {batch_size}, Total ES batches: {es_total_batches} ===" + ) for i in range(0, total_docs, batch_size): - es_batch = processed_docs[i:i + batch_size] + es_batch = processed_docs[i: i + batch_size] es_batch_num = i // batch_size + 1 es_batch_start_time = time.time() @@ -431,7 +441,7 @@ def _large_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], # Sub-batch for embedding API embedding_batch_size = 64 for j in range(0, len(es_batch), embedding_batch_size): - embedding_sub_batch = es_batch[j:j + embedding_batch_size] + embedding_sub_batch = es_batch[j: j + embedding_batch_size] # Retry 
logic for embedding API call (3 retries, 1s delay) # Note: embedding_model.get_embeddings() already has built-in retries with exponential backoff # This outer retry handles additional failures @@ -454,19 +464,22 @@ def _large_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], except Exception as e: if retry_attempt < max_retries - 1: logger.warning( - f"Embedding API error (attempt {retry_attempt + 1}/{max_retries}): {e}, ES batch num: {es_batch_num}, sub-batch start: {j}, size: {len(embedding_sub_batch)}. Retrying in {retry_delay}s...") + f"Embedding API error (attempt {retry_attempt + 1}/{max_retries}): {e}, ES batch num: {es_batch_num}, sub-batch start: {j}, size: {len(embedding_sub_batch)}. Retrying in {retry_delay}s..." + ) time.sleep(retry_delay) else: logger.error( - f"Embedding API error after {max_retries} attempts: {e}, ES batch num: {es_batch_num}, sub-batch start: {j}, size: {len(embedding_sub_batch)}") + f"Embedding API error after {max_retries} attempts: {e}, ES batch num: {es_batch_num}, sub-batch start: {j}, size: {len(embedding_sub_batch)}" + ) if not success: # Skip this sub-batch after all retries failed continue - + # Perform a single bulk insert for the entire Elasticsearch batch if not doc_embedding_pairs: - logger.warning(f"No documents with embeddings to index for ES batch {es_batch_num}") + logger.warning( + f"No documents with embeddings to index for ES batch {es_batch_num}") continue operations = [] @@ -474,32 +487,33 @@ def _large_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], operations.append({"index": {"_index": index_name}}) doc["embedding"] = embedding if "embedding_model_name" not in doc: - doc["embedding_model_name"] = getattr(embedding_model, 'embedding_model_name', 'unknown') + doc["embedding_model_name"] = getattr( + embedding_model, "embedding_model_name", "unknown") operations.append(doc) try: response = self.client.bulk( - index=index_name, - operations=operations, - refresh=False - ) 
+ index=index_name, operations=operations, refresh=False) self._handle_bulk_errors(response) total_indexed += len(doc_embedding_pairs) es_batch_elapsed = time.time() - es_batch_start_time logger.info( - f"[ES BATCH {es_batch_num}/{es_total_batches}] Indexed {len(doc_embedding_pairs)} documents in {es_batch_elapsed:.2f}s. Total progress: {total_indexed}/{total_docs}") + f"[ES BATCH {es_batch_num}/{es_total_batches}] Indexed {len(doc_embedding_pairs)} documents in {es_batch_elapsed:.2f}s. Total progress: {total_indexed}/{total_docs}" + ) except Exception as e: - logger.error(f"Bulk insert error: {e}, ES batch num: {es_batch_num}") + logger.error( + f"Bulk insert error: {e}, ES batch num: {es_batch_num}") continue - + # Add 0.1s delay between batches to avoid overloading embedding API time.sleep(0.1) self._force_refresh_with_retry(index_name) total_elapsed = time.time() - start_time logger.info( - f"=== [INDEXING COMPLETE] Successfully indexed {total_indexed}/{total_docs} chunks in {total_elapsed:.2f}s (avg: {total_elapsed/es_total_batches:.2f}s/batch) ===") + f"=== [INDEXING COMPLETE] Successfully indexed {total_indexed}/{total_docs} chunks in {total_elapsed:.2f}s (avg: {total_elapsed / es_total_batches:.2f}s/batch) ===" + ) return total_indexed except Exception as e: logger.error(f"Large batch insert failed: {e}") @@ -508,7 +522,7 @@ def _large_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: str) -> List[Dict[str, Any]]: """Ensure all documents have the required fields and set default values""" current_time = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()) - current_date = time.strftime('%Y-%m-%d', time.gmtime()) + current_date = time.strftime("%Y-%m-%d", time.gmtime()) processed_docs = [] for doc in documents: @@ -533,7 +547,8 @@ def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: # Ensure all documents have an ID if not 
doc_copy.get("id"): - doc_copy["id"] = f"{int(time.time())}_{hash(doc_copy[content_field])}"[:20] + doc_copy["id"] = f"{int(time.time())}_{hash(doc_copy[content_field])}"[ + :20] processed_docs.append(doc_copy) @@ -541,52 +556,214 @@ def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: def _handle_bulk_errors(self, response: Dict[str, Any]) -> None: """Handle bulk operation errors""" - if response.get('errors'): - for item in response['items']: - if 'error' in item.get('index', {}): - error_info = item['index']['error'] - error_type = error_info.get('type') - error_reason = error_info.get('reason') - error_cause = error_info.get('caused_by', {}) - - if error_type == 'version_conflict_engine_exception': + if response.get("errors"): + for item in response["items"]: + if "error" in item.get("index", {}): + error_info = item["index"]["error"] + error_type = error_info.get("type") + error_reason = error_info.get("reason") + error_cause = error_info.get("caused_by", {}) + + if error_type == "version_conflict_engine_exception": # ignore version conflict continue else: - logger.error(f"FATAL ERROR {error_type}: {error_reason}") + logger.error( + f"FATAL ERROR {error_type}: {error_reason}") if error_cause: - logger.error(f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}") - - def delete_documents_by_path_or_url(self, index_name: str, path_or_url: str) -> int: + logger.error( + f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}") + + def delete_documents(self, index_name: str, path_or_url: str) -> int: """ Delete documents based on their path_or_url field - + Args: index_name: Name of the index to delete documents from path_or_url: The URL or path of the documents to delete - + Returns: int: Number of documents deleted """ try: result = self.client.delete_by_query( - index=index_name, - body={ - "query": { - "term": { - "path_or_url": path_or_url - } - } - } + index=index_name, body={ + "query": {"term": 
{"path_or_url": path_or_url}}} + ) + logger.info( + f"Successfully deleted {result['deleted']} documents with path_or_url: {path_or_url} from index: {index_name}" ) - logger.info(f"Successfully deleted {result['deleted']} documents with path_or_url: {path_or_url} from index: {index_name}") - return result['deleted'] + return result["deleted"] except Exception as e: logger.error(f"Error deleting documents: {str(e)}") return 0 - + + def count_documents(self, index_name: str) -> int: + """ + Count the total number of documents in an index. + + Args: + index_name: Name of the index to count documents in + + Returns: + int: Total number of documents + """ + try: + count_response = self.client.count(index=index_name) + return count_response["count"] + except Exception as e: + logger.error(f"Error counting documents: {str(e)}") + return 0 + + def get_index_chunks( + self, + index_name: str, + page: Optional[int] = None, + page_size: Optional[int] = None, + path_or_url: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Retrieve chunk records for the specified index with optional pagination. + + Args: + index_name: Name of the index to query + page: Page number (1-based). Provide together with page_size. + page_size: Number of records per page. Provide together with page. + path_or_url: Optional path_or_url filter. 
+ + Returns: + Dictionary containing chunks, total count, page, and page_size + """ + chunks: List[Dict[str, Any]] = [] + total = 0 + scroll_id: Optional[str] = None + paginate = page is not None and page_size is not None + result_page = page if paginate else None + result_page_size = page_size if paginate else None + + try: + query: Dict[str, Any] = {"match_all": {}} + if path_or_url: + query = {"term": {"path_or_url": path_or_url}} + + count_response = self.client.count( + index=index_name, + body={"query": query}, + ) + total = count_response.get("count", 0) + + if total == 0: + return { + "chunks": [], + "total": 0, + "page": result_page, + "page_size": result_page_size, + } + + source_filter = {"_source": {"excludes": ["embedding"]}} + + if paginate: + safe_page = max(page, 1) + safe_page_size = max(page_size, 1) + from_index = (safe_page - 1) * safe_page_size + response = self.client.search( + index=index_name, + body={ + "query": query, + **source_filter, + }, + from_=from_index, + size=safe_page_size, + ) + hits = response.get("hits", {}).get("hits", []) + for hit in hits: + chunk = hit.get("_source", {}).copy() + if "id" not in chunk: + chunk["id"] = hit.get("_id") + chunks.append(chunk) + else: + response = self.client.search( + index=index_name, + body={ + "query": query, + **source_filter, + }, + size=DEFAULT_SCROLL_SIZE, + scroll=SCROLL_TTL, + ) + scroll_id = response.get("_scroll_id") + + while True: + hits = response.get("hits", {}).get("hits", []) + if not hits: + break + + for hit in hits: + chunk = hit.get("_source", {}).copy() + if "id" not in chunk: + chunk["id"] = hit.get("_id") + chunks.append(chunk) + + if not scroll_id: + break + + response = self.client.scroll( + scroll_id=scroll_id, + scroll=SCROLL_TTL, + ) + scroll_id = response.get("_scroll_id") + + except exceptions.NotFoundError: + logger.info(f"Index {index_name} not found when fetching chunks") + chunks = [] + total = 0 + except Exception as e: + logger.error(f"Error fetching chunks 
for index {index_name}: {e}") + raise + finally: + if scroll_id: + try: + self.client.clear_scroll(scroll_id=scroll_id) + except Exception as cleanup_error: + logger.warning( + f"Failed to clear scroll context for index {index_name}: {cleanup_error}" + ) + + return { + "chunks": chunks, + "total": total, + "page": result_page, + "page_size": result_page_size, + } + + def search(self, index_name: str, query: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute a search query on an index. + + Args: + index_name: Name of the index to search + query: Search query dictionary + + Returns: + Dict containing search results + """ + return self.client.search(index=index_name, body=query) + + def multi_search(self, body: List[Dict[str, Any]], index_name: str) -> Dict[str, Any]: + """ + Execute multiple search queries in a single request. + + Args: + body: List of search queries (alternating index and query) + index_name: Name of the index to search + + Returns: + Dict containing responses for all queries + """ + return self.client.msearch(body=body, index=index_name) + # ---- SEARCH OPERATIONS ---- - + def accurate_search(self, index_names: List[str], query_text: str, top_k: int = 5) -> List[Dict[str, Any]]: """ Search for documents using fuzzy text matching across multiple indices. 
@@ -595,7 +772,7 @@ def accurate_search(self, index_names: List[str], query_text: str, top_k: int = index_names: Name of the index to search in query_text: The text query to search for top_k: Number of results to return - + Returns: List of search results with scores and document content """ @@ -607,39 +784,38 @@ def accurate_search(self, index_names: List[str], query_text: str, top_k: int = # Prepare the search query using match query for fuzzy matching search_query = build_weighted_query(query_text, weights) | { "size": top_k, - "_source": { - "excludes": ["embedding"] - } + "_source": {"excludes": ["embedding"]}, } # Execute the search across multiple indices return self.exec_query(index_pattern, search_query) def exec_query(self, index_pattern, search_query): - response = self.client.search( - index=index_pattern, - body=search_query - ) + response = self.client.search(index=index_pattern, body=search_query) # Process and return results results = [] for hit in response["hits"]["hits"]: - results.append({ - "score": hit["_score"], - "document": hit["_source"], - "index": hit["_index"] # Include source index in results - }) + results.append( + { + "score": hit["_score"], + "document": hit["_source"], + "index": hit["_index"], # Include source index in results + } + ) return results - def semantic_search(self, index_names: List[str], query_text: str, embedding_model: BaseEmbedding, top_k: int = 5) -> List[Dict[str, Any]]: + def semantic_search( + self, index_names: List[str], query_text: str, embedding_model: BaseEmbedding, top_k: int = 5 + ) -> List[Dict[str, Any]]: """ Search for similar documents using vector similarity across multiple indices. 
- + Args: index_names: List of index names to search in query_text: The text query to search for embedding_model: The embedding model to use top_k: Number of results to return - + Returns: List of search results with scores and document content """ @@ -648,7 +824,7 @@ def semantic_search(self, index_names: List[str], query_text: str, embedding_mod # Get query embedding query_embedding = embedding_model.get_embeddings(query_text)[0] - + # Prepare the search query search_query = { "knn": { @@ -658,11 +834,9 @@ def semantic_search(self, index_names: List[str], query_text: str, embedding_mod "num_candidates": top_k * 2, }, "size": top_k, - "_source": { - "excludes": ["embedding"] - } + "_source": {"excludes": ["embedding"]}, } - + # Execute the search across multiple indices return self.exec_query(index_pattern, search_query) @@ -672,11 +846,11 @@ def hybrid_search( query_text: str, embedding_model: BaseEmbedding, top_k: int = 5, - weight_accurate: float = 0.3 + weight_accurate: float = 0.3, ) -> List[Dict[str, Any]]: """ Hybrid search method, combining accurate matching and semantic search results across multiple indices. 
- + Args: index_names: List of index names to search in query_text: The text query to search for @@ -688,8 +862,10 @@ def hybrid_search( List of search results sorted by combined score """ # Get results from both searches - accurate_results = self.accurate_search(index_names, query_text, top_k=top_k) - semantic_results = self.semantic_search(index_names, query_text, embedding_model=embedding_model, top_k=top_k) + accurate_results = self.accurate_search( + index_names, query_text, top_k=top_k) + semantic_results = self.semantic_search( + index_names, query_text, embedding_model=embedding_model, top_k=top_k) # Create a mapping from document ID to results combined_results = {} @@ -697,79 +873,85 @@ def hybrid_search( # Process accurate matching results for result in accurate_results: try: - doc_id = result['document']['id'] + doc_id = result["document"]["id"] combined_results[doc_id] = { - 'document': result['document'], - 'accurate_score': result.get('score', 0), - 'semantic_score': 0, - 'index': result['index'] # Keep track of source index + "document": result["document"], + "accurate_score": result.get("score", 0), + "semantic_score": 0, + "index": result["index"], # Keep track of source index } except KeyError as e: - logger.warning(f"Warning: Missing required field in accurate result: {e}") + logger.warning( + f"Warning: Missing required field in accurate result: {e}") continue # Process semantic search results for result in semantic_results: try: - doc_id = result['document']['id'] + doc_id = result["document"]["id"] if doc_id in combined_results: - combined_results[doc_id]['semantic_score'] = result.get('score', 0) + combined_results[doc_id]["semantic_score"] = result.get( + "score", 0) else: combined_results[doc_id] = { - 'document': result['document'], - 'accurate_score': 0, - 'semantic_score': result.get('score', 0), - 'index': result['index'] # Keep track of source index + "document": result["document"], + "accurate_score": 0, + "semantic_score": 
result.get("score", 0), + "index": result["index"], # Keep track of source index } except KeyError as e: - logger.warning(f"Warning: Missing required field in semantic result: {e}") + logger.warning( + f"Warning: Missing required field in semantic result: {e}") continue # Calculate maximum scores - max_accurate = max([r.get('score', 0) for r in accurate_results]) if accurate_results else 1 - max_semantic = max([r.get('score', 0) for r in semantic_results]) if semantic_results else 1 + max_accurate = max([r.get("score", 0) + for r in accurate_results]) if accurate_results else 1 + max_semantic = max([r.get("score", 0) + for r in semantic_results]) if semantic_results else 1 # Calculate combined scores and sort results = [] for doc_id, result in combined_results.items(): try: # Get scores safely - accurate_score = result.get('accurate_score', 0) - semantic_score = result.get('semantic_score', 0) + accurate_score = result.get("accurate_score", 0) + semantic_score = result.get("semantic_score", 0) # Normalize scores normalized_accurate = accurate_score / max_accurate if max_accurate > 0 else 0 normalized_semantic = semantic_score / max_semantic if max_semantic > 0 else 0 # Calculate weighted combined score - combined_score = (weight_accurate * normalized_accurate + - (1 - weight_accurate) * normalized_semantic) - - results.append({ - 'score': combined_score, - 'document': result['document'], - 'index': result['index'], # Include source index in results - 'scores': { - 'accurate': normalized_accurate, - 'semantic': normalized_semantic + combined_score = weight_accurate * normalized_accurate + \ + (1 - weight_accurate) * normalized_semantic + + results.append( + { + "score": combined_score, + "document": result["document"], + # Include source index in results + "index": result["index"], + "scores": {"accurate": normalized_accurate, "semantic": normalized_semantic}, } - }) + ) except KeyError as e: - logger.warning(f"Warning: Error processing result for doc_id {doc_id}: 
{e}") + logger.warning( + f"Warning: Error processing result for doc_id {doc_id}: {e}") continue # Sort by combined score and return top k results - results.sort(key=lambda x: x['score'], reverse=True) + results.sort(key=lambda x: x["score"], reverse=True) return results[:top_k] # ---- STATISTICS AND MONITORING ---- - def get_file_list_with_details(self, index_name: str) -> List[Dict[str, Any]]: + def get_documents_detail(self, index_name: str) -> List[Dict[str, Any]]: """ Get a list of unique path_or_url values with their file_size and create_time - + Args: index_name: Name of the index to query - + Returns: List of dictionaries with path_or_url, file_size, and create_time """ @@ -779,58 +961,40 @@ def get_file_list_with_details(self, index_name: str) -> List[Dict[str, Any]]: "unique_sources": { "terms": { "field": "path_or_url", - "size": 1000 # Limit to 1000 files for performance + "size": 1000, # Limit to 1000 files for performance }, "aggs": { "file_sample": { - "top_hits": { - "size": 1, - "_source": ["path_or_url", "file_size", "create_time", "filename"] - } + "top_hits": {"size": 1, "_source": ["path_or_url", "file_size", "create_time", "filename"]} } - } + }, } - } + }, } - + try: - result = self.client.search( - index=index_name, - body=agg_query - ) - + result = self.client.search(index=index_name, body=agg_query) + file_list = [] - for bucket in result['aggregations']['unique_sources']['buckets']: - source = bucket['file_sample']['hits']['hits'][0]['_source'] + for bucket in result["aggregations"]["unique_sources"]["buckets"]: + source = bucket["file_sample"]["hits"]["hits"][0]["_source"] file_info = { "path_or_url": source["path_or_url"], "filename": source.get("filename", ""), "file_size": source.get("file_size", 0), - "create_time": source.get("create_time", None) + "create_time": source.get("create_time", None), + "chunk_count": bucket.get("doc_count", 0), } file_list.append(file_info) - + return file_list except Exception as e: logger.error(f"Error 
getting file list: {str(e)}") return [] - - def get_index_mapping(self, index_names: List[str]) -> Dict[str, List[str]]: - """Get field mappings for multiple indices""" - mappings = {} - for index_name in index_names: - try: - mapping = self.client.indices.get_mapping(index=index_name) - if mapping[index_name].get('mappings') and mapping[index_name]['mappings'].get('properties'): - mappings[index_name] = list(mapping[index_name]['mappings']['properties'].keys()) - else: - mappings[index_name] = [] - except Exception as e: - logger.error(f"Error getting mapping for index {index_name}: {str(e)}") - mappings[index_name] = [] - return mappings - - def get_index_stats(self, index_names: List[str], embedding_dim: Optional[int] = None) -> Dict[str, Dict[str, Dict[str, Any]]]: + + def get_indices_detail( + self, index_names: List[str], embedding_dim: Optional[int] = None + ) -> Dict[str, Dict[str, Dict[str, Any]]]: """Get formatted statistics for multiple indices""" all_stats = {} for index_name in index_names: @@ -842,40 +1006,33 @@ def get_index_stats(self, index_names: List[str], embedding_dim: Optional[int] = agg_query = { "size": 0, "aggs": { - "unique_path_or_url_count": { - "cardinality": { - "field": "path_or_url" - } - }, - "process_sources": { - "terms": { - "field": "process_source", - "size": 10 - } - }, - "embedding_models": { - "terms": { - "field": "embedding_model_name", - "size": 10 - } - } - } + "unique_path_or_url_count": {"cardinality": {"field": "path_or_url"}}, + "process_sources": {"terms": {"field": "process_source", "size": 10}}, + "embedding_models": {"terms": {"field": "embedding_model_name", "size": 10}}, + }, } # Execute query agg_result = self.client.search( - index=index_name, - body=agg_query - ) + index=index_name, body=agg_query) - unique_sources_count = agg_result['aggregations']['unique_path_or_url_count']['value'] - process_source = agg_result['aggregations']['process_sources']['buckets'][0]['key'] if 
agg_result['aggregations']['process_sources']['buckets'] else "" - embedding_model = agg_result['aggregations']['embedding_models']['buckets'][0]['key'] if agg_result['aggregations']['embedding_models']['buckets'] else "" + unique_sources_count = agg_result["aggregations"]["unique_path_or_url_count"]["value"] + process_source = ( + agg_result["aggregations"]["process_sources"]["buckets"][0]["key"] + if agg_result["aggregations"]["process_sources"]["buckets"] + else "" + ) + embedding_model = ( + agg_result["aggregations"]["embedding_models"]["buckets"][0]["key"] + if agg_result["aggregations"]["embedding_models"]["buckets"] + else "" + ) index_stats = stats["indices"][index_name]["primaries"] # Get creation and update timestamps from settings - creation_date = int(settings[index_name]['settings']['index']['creation_date']) + creation_date = int( + settings[index_name]["settings"]["index"]["creation_date"]) # Update time defaults to creation time if not modified update_time = creation_date @@ -888,16 +1045,16 @@ def get_index_stats(self, index_names: List[str], embedding_dim: Optional[int] = "embedding_model": embedding_model, "embedding_dim": embedding_dim or 1024, "creation_date": creation_date, - "update_date": update_time + "update_date": update_time, }, "search_performance": { "total_search_count": index_stats["search"]["query_total"], "hit_count": index_stats["request_cache"]["hit_count"], - } + }, } except Exception as e: - logger.error(f"Error getting stats for index {index_name}: {str(e)}") + logger.error( + f"Error getting stats for index {index_name}: {str(e)}") all_stats[index_name] = {"error": str(e)} return all_stats - \ No newline at end of file diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index 9ef6882d8..e4e66edec 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -59,7 +59,7 @@ sys.modules['database.agent_db'] = MagicMock() 
sys.modules['database.tool_db'] = MagicMock() sys.modules['database.model_management_db'] = MagicMock() -sys.modules['services.elasticsearch_service'] = MagicMock() +sys.modules['services.vectordatabase_service'] = MagicMock() sys.modules['services.tenant_config_service'] = MagicMock() sys.modules['utils.prompt_template_utils'] = MagicMock() sys.modules['utils.config_utils'] = MagicMock() @@ -250,7 +250,7 @@ async def test_create_tool_config_list_with_knowledge_base_tool(self): with patch('backend.agents.create_agent_info.discover_langchain_tools') as mock_discover, \ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge, \ - patch('backend.agents.create_agent_info.elastic_core') as mock_elastic, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding: mock_discover.return_value = [] @@ -268,9 +268,10 @@ async def test_create_tool_config_list_with_knowledge_base_tool(self): ] mock_knowledge.return_value = [ {"index_name": "knowledge_1"}, - {"index_name": "knowledge_2"} + {"index_name": "knowledge_2"}, ] - mock_elastic.return_value = "mock_elastic_core" + mock_vdb_core = "mock_elastic_core" + mock_get_vector_db_core.return_value = mock_vdb_core mock_embedding.return_value = "mock_embedding_model" result = await create_tool_config_list("agent_1", "tenant_1", "user_1") diff --git a/test/backend/app/test_agent_app.py b/test/backend/app/test_agent_app.py index 7bff04f76..3eeaf6650 100644 --- a/test/backend/app/test_agent_app.py +++ b/test/backend/app/test_agent_app.py @@ -1,5 +1,5 @@ import atexit -from unittest.mock import patch, Mock, MagicMock +from unittest.mock import patch, Mock, MagicMock, ANY import os import sys import types @@ -35,7 +35,7 @@ patch('backend.database.client.minio_client', minio_mock).start() 
patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start() -# Apply patches before importing any app modules (similar to test_base_app.py) +# Apply patches before importing any app modules (similar to test_config_app.py) patches = [ # Mock database sessions patch('backend.database.client.get_db_session', return_value=Mock()) @@ -45,7 +45,7 @@ p.start() # Import target endpoints with all external dependencies patched -from apps.agent_app import router +from apps.agent_app import agent_config_router, agent_runtime_router # Mock external dependencies before importing the modules that use them # Stub nexent.core.agents.agent_model.ToolConfig to satisfy type imports in consts.model @@ -138,10 +138,14 @@ def stop_patches(): atexit.register(stop_patches) -# Create FastAPI app for testing -app = FastAPI() -app.include_router(router) -client = TestClient(app) +# Create FastAPI apps for runtime and config routers +runtime_app = FastAPI() +runtime_app.include_router(agent_runtime_router) +runtime_client = TestClient(runtime_app) + +config_app = FastAPI() +config_app.include_router(agent_config_router) +config_client = TestClient(config_app) @pytest.fixture @@ -168,7 +172,7 @@ async def mock_stream(): mock_run_agent_stream.return_value = StreamingResponse( mock_stream(), media_type="text/event-stream") - response = client.post( + response = runtime_client.post( "/agent/run", json={ "agent_id": 1, @@ -200,7 +204,7 @@ def test_agent_stop_api_success(mocker, mock_conversation_id): mock_stop_tasks = mocker.patch("apps.agent_app.stop_agent_tasks") mock_stop_tasks.return_value = {"status": "success"} - response = client.get( + response = runtime_client.get( f"/agent/stop/{mock_conversation_id}", headers={"Authorization": "Bearer test_token"} ) @@ -221,7 +225,7 @@ def test_agent_stop_api_not_found(mocker, mock_conversation_id): mock_stop_tasks = mocker.patch("apps.agent_app.stop_agent_tasks") mock_stop_tasks.return_value = {"status": "error"} # Simulate not found - 
response = client.get( + response = runtime_client.get( f"/agent/stop/{mock_conversation_id}", headers={"Authorization": "Bearer test_token"} ) @@ -244,7 +248,7 @@ def test_search_agent_info_api_success(mocker, mock_auth_header): mock_get_agent_info.return_value = {"agent_id": 123, "name": "Test Agent"} # Test the endpoint - response = client.post( + response = config_client.post( "/agent/search_info", json=123, # agent_id as body parameter headers=mock_auth_header @@ -266,7 +270,7 @@ def test_search_agent_info_api_exception(mocker, mock_auth_header): mock_get_agent_info.side_effect = Exception("Test error") # Test the endpoint - response = client.post( + response = config_client.post( "/agent/search_info", json=123, headers=mock_auth_header @@ -284,7 +288,7 @@ def test_get_creating_sub_agent_info_api_success(mocker, mock_auth_header): mock_get_creating_agent.return_value = {"agent_id": 456} # Test the endpoint - this is a GET request - response = client.get( + response = config_client.get( "/agent/get_creating_sub_agent_id", headers=mock_auth_header ) @@ -303,7 +307,7 @@ def test_get_creating_sub_agent_info_api_exception(mocker, mock_auth_header): mock_get_creating_agent.side_effect = Exception("Test error") # Test the endpoint - this is a GET request - response = client.get( + response = config_client.get( "/agent/get_creating_sub_agent_id", headers=mock_auth_header ) @@ -320,7 +324,7 @@ def test_update_agent_info_api_success(mocker, mock_auth_header): mock_update_agent.return_value = None # Test the endpoint - response = client.post( + response = config_client.post( "/agent/update", json={"agent_id": 123, "name": "Updated Agent", "display_name": "Updated Display Name"}, @@ -340,7 +344,7 @@ def test_update_agent_info_api_exception(mocker, mock_auth_header): mock_update_agent.side_effect = Exception("Test error") # Test the endpoint - response = client.post( + response = config_client.post( "/agent/update", json={"agent_id": 123, "name": "Updated Agent", 
"display_name": "Updated Display Name"}, @@ -359,7 +363,7 @@ def test_delete_agent_api_success(mocker, mock_auth_header): mock_delete_agent.return_value = None # Test the endpoint - response = client.request( + response = config_client.request( "DELETE", "/agent", json={"agent_id": 123}, @@ -380,7 +384,7 @@ def test_delete_agent_api_exception(mocker, mock_auth_header): mock_delete_agent.side_effect = Exception("Test error") # Test the endpoint - response = client.request( + response = config_client.request( "DELETE", "/agent", json={"agent_id": 123}, @@ -400,7 +404,7 @@ async def test_export_agent_api_success(mocker, mock_auth_header): mock_export_agent.return_value = '{"agent_id": 123, "name": "Test Agent"}' # Test the endpoint - response = client.post( + response = config_client.post( "/agent/export", json={"agent_id": 123}, headers=mock_auth_header @@ -422,7 +426,7 @@ async def test_export_agent_api_exception(mocker, mock_auth_header): mock_export_agent.side_effect = Exception("Test error") # Test the endpoint - response = client.post( + response = config_client.post( "/agent/export", json={"agent_id": 123}, headers=mock_auth_header @@ -440,7 +444,7 @@ def test_import_agent_api_success(mocker, mock_auth_header): mock_import_agent.return_value = None # Test the endpoint - following the ExportAndImportDataFormat structure - response = client.post( + response = config_client.post( "/agent/import", json={ "agent_info": { @@ -484,7 +488,7 @@ def test_import_agent_api_exception(mocker, mock_auth_header): mock_import_agent.side_effect = Exception("Test error") # Test the endpoint - following the ExportAndImportDataFormat structure - response = client.post( + response = config_client.post( "/agent/import", json={ "agent_info": { @@ -530,14 +534,14 @@ def test_list_all_agent_info_api_success(mocker, mock_auth_header): ] # Test the endpoint - response = client.get( + response = config_client.get( "/agent/list", headers=mock_auth_header ) # Assertions assert 
response.status_code == 200 - mock_get_user_info.assert_called_once() + mock_get_user_info.assert_called_once_with(mock_auth_header["Authorization"], ANY) mock_list_all_agent.assert_called_once_with(tenant_id="test_tenant") assert len(response.json()) == 2 assert response.json()[0]["agent_id"] == 1 @@ -556,14 +560,14 @@ def test_list_all_agent_info_api_exception(mocker, mock_auth_header): mock_list_all_agent.side_effect = Exception("Test error") # Test the endpoint - response = client.get( + response = config_client.get( "/agent/list", headers=mock_auth_header ) # Assertions assert response.status_code == 500 - mock_get_user_info.assert_called_once() + mock_get_user_info.assert_called_once_with(mock_auth_header["Authorization"], ANY) mock_list_all_agent.assert_called_once_with(tenant_id="test_tenant") assert "Agent list error" in response.json()["detail"] @@ -587,7 +591,7 @@ async def test_export_agent_api_detailed(mocker, mock_auth_header): mock_export_agent.return_value = agent_data # Test with complex data - response = client.post( + response = config_client.post( "/agent/export", json={"agent_id": 456}, headers=mock_auth_header @@ -616,7 +620,7 @@ async def test_export_agent_api_empty_response(mocker, mock_auth_header): mock_export_agent.return_value = {} # Send request - response = client.post( + response = config_client.post( "/agent/export", json={"agent_id": 789}, headers=mock_auth_header @@ -636,34 +640,34 @@ async def test_export_agent_api_empty_response(mocker, mock_auth_header): def _alias_services_for_tests(): """ - 兼容路由里使用的 'services.agent_service' 动态导入路径。 - 将 backend.services.* 映射为 services.* 以便 mocker.patch 能找到。 + Provide fallback aliases for dynamic `services.agent_service` imports used by the routers. + Map `backend.services.*` modules to `services.*` so mocker.patch can locate them. 
""" import sys try: import backend.services as b_services import backend.services.agent_service as b_agent_service - # 父包与子模块都映射一份 + # Map both the package and submodule for compatibility sys.modules['services'] = b_services sys.modules['services.agent_service'] = b_agent_service except Exception: - # 如果你的工程本来就能直接 import services.*,这里兜底不做处理 + # If the project already supports direct imports, ignore the failure pass def test_get_agent_call_relationship_api_success(mocker, mock_auth_header): - # patch 鉴权 + # Patch authentication helper mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id") mock_get_user_id.return_value = ("user_id_x", "tenant_abc") - # 现在改 patch 这里:指向 apps.agent_app 命名空间的顶层符号 + # Patch the implementation referenced from the apps.agent_app namespace mock_impl = mocker.patch("apps.agent_app.get_agent_call_relationship_impl") mock_impl.return_value = { "agent_id": 1, "tree": {"tools": [], "sub_agents": []} } - resp = client.get("/agent/call_relationship/1", headers=mock_auth_header) + resp = config_client.get("/agent/call_relationship/1", headers=mock_auth_header) assert resp.status_code == 200 mock_get_user_id.assert_called_once_with(mock_auth_header["Authorization"]) @@ -677,11 +681,11 @@ def test_get_agent_call_relationship_api_exception(mocker, mock_auth_header): mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id") mock_get_user_id.return_value = ("user_id_x", "tenant_abc") - # 同样改这里 + # Patch the same implementation for the error path mock_impl = mocker.patch("apps.agent_app.get_agent_call_relationship_impl") mock_impl.side_effect = Exception("boom") - resp = client.get("/agent/call_relationship/999", headers=mock_auth_header) + resp = config_client.get("/agent/call_relationship/999", headers=mock_auth_header) assert resp.status_code == 500 assert "Failed to get agent call relationship" in resp.json()["detail"] diff --git a/test/backend/app/test_base_app.py b/test/backend/app/test_config_app.py similarity index 99% 
rename from test/backend/app/test_base_app.py rename to test/backend/app/test_config_app.py index c8c7d5758..fcfcbc92e 100644 --- a/test/backend/app/test_base_app.py +++ b/test/backend/app/test_config_app.py @@ -67,7 +67,7 @@ # Now safe to import app modules from fastapi import HTTPException from fastapi.testclient import TestClient -from apps.base_app import app +from apps.config_app import app # Stop all patches at the end of the module diff --git a/test/backend/app/test_knowledge_summary_app.py b/test/backend/app/test_knowledge_summary_app.py index c3b1a8ab0..4c8cc5080 100644 --- a/test/backend/app/test_knowledge_summary_app.py +++ b/test/backend/app/test_knowledge_summary_app.py @@ -1,11 +1,16 @@ import pytest -import json import sys import os +import types from unittest.mock import patch, MagicMock, AsyncMock # Add path for correct imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend")) +CURRENT_DIR = os.path.dirname(__file__) +PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../../..")) +BACKEND_DIR = os.path.join(PROJECT_ROOT, "backend") +for path in (PROJECT_ROOT, BACKEND_DIR): + if path not in sys.path: + sys.path.insert(0, path) # Patch environment variables before any imports that might use them os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000') @@ -29,7 +34,20 @@ sys.modules['nexent.core.models.tts_model'] = MagicMock() sys.modules['nexent.core.nlp'] = MagicMock() sys.modules['nexent.core.nlp.tokenizer'] = MagicMock() -sys.modules['nexent.vector_database'] = MagicMock() +vector_db_module = types.ModuleType("nexent.vector_database") +vector_db_base_module = types.ModuleType("nexent.vector_database.base") + + +class MockVectorDatabaseCore: + def __init__(self, *args, **kwargs): + pass + + +vector_db_base_module.VectorDatabaseCore = MockVectorDatabaseCore +vector_db_module.base = vector_db_base_module + +sys.modules['nexent.vector_database'] = vector_db_module 
+sys.modules['nexent.vector_database.base'] = vector_db_base_module sys.modules['nexent.vector_database.elasticsearch_core'] = MagicMock() # Mock specific classes that are imported @@ -80,7 +98,7 @@ def __init__(self, *args, **kwargs): pass from fastapi.testclient import TestClient from fastapi import FastAPI from pydantic import BaseModel - from apps.knowledge_summary_app import router + from backend.apps.knowledge_summary_app import router # Define test models class ChangeSummaryRequest(BaseModel): @@ -107,7 +125,7 @@ def test_data(): def test_auto_summary_success(test_data): """Test successful auto summary generation""" # Setup mock responses - mock_es_core_instance = MagicMock() + mock_vdb_core_instance = MagicMock() mock_user_info = ("test_user_id", "test_tenant_id", "en") # Setup service mock @@ -117,9 +135,9 @@ def test_auto_summary_success(test_data): mock_service_instance.summary_index_name.return_value = stream_response # Patch all necessary components directly in the app module - with patch('apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ - patch('apps.knowledge_summary_app.get_es_core', return_value=mock_es_core_instance), \ - patch('apps.knowledge_summary_app.get_current_user_info', return_value=mock_user_info): + with patch('backend.apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ + patch('backend.apps.knowledge_summary_app.get_vector_db_core', return_value=mock_vdb_core_instance), \ + patch('backend.apps.knowledge_summary_app.get_current_user_info', return_value=mock_user_info): # Execute test with model_id parameter response = client.post( @@ -127,6 +145,8 @@ def test_auto_summary_success(test_data): headers=test_data["auth_header"] ) + assert response.status_code == 200 + # Assertions - verify the function was called exactly once assert mock_service_instance.summary_index_name.call_count == 1 @@ -141,7 +161,7 @@ def test_auto_summary_success(test_data): def 
test_auto_summary_without_model_id(test_data): """Test successful auto summary generation without model_id parameter""" # Setup mock responses - mock_es_core_instance = MagicMock() + mock_vdb_core_instance = MagicMock() mock_user_info = ("test_user_id", "test_tenant_id", "en") # Setup service mock @@ -151,9 +171,9 @@ def test_auto_summary_without_model_id(test_data): mock_service_instance.summary_index_name.return_value = stream_response # Patch all necessary components directly in the app module - with patch('apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ - patch('apps.knowledge_summary_app.get_es_core', return_value=mock_es_core_instance), \ - patch('apps.knowledge_summary_app.get_current_user_info', return_value=mock_user_info): + with patch('backend.apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ + patch('backend.apps.knowledge_summary_app.get_vector_db_core', return_value=mock_vdb_core_instance), \ + patch('backend.apps.knowledge_summary_app.get_current_user_info', return_value=mock_user_info): # Execute test without model_id parameter response = client.post( @@ -161,6 +181,8 @@ def test_auto_summary_without_model_id(test_data): headers=test_data["auth_header"] ) + assert response.status_code == 200 + # Assertions - verify the function was called exactly once assert mock_service_instance.summary_index_name.call_count == 1 @@ -175,7 +197,7 @@ def test_auto_summary_without_model_id(test_data): def test_auto_summary_exception(test_data): """Test auto summary generation with exception""" # Setup mock to raise exception - mock_es_core_instance = MagicMock() + mock_vdb_core_instance = MagicMock() mock_user_info = ("test_user_id", "test_tenant_id", "en") # Setup service mock to raise exception @@ -184,10 +206,10 @@ def test_auto_summary_exception(test_data): side_effect=Exception("Error generating summary") ) - # Patch both the ElasticSearchService and get_es_core in the route handler 
- with patch('apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ - patch('apps.knowledge_summary_app.get_es_core', return_value=mock_es_core_instance), \ - patch('apps.knowledge_summary_app.get_current_user_info', return_value=mock_user_info): + # Patch both the ElasticSearchService and get_vector_db_core in the route handler + with patch('backend.apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ + patch('backend.apps.knowledge_summary_app.get_vector_db_core', return_value=mock_vdb_core_instance), \ + patch('backend.apps.knowledge_summary_app.get_current_user_info', return_value=mock_user_info): # Execute test response = client.post( @@ -219,8 +241,8 @@ def test_change_summary_success(test_data): mock_service_instance.change_summary.return_value = expected_response # Execute test with direct patching of route handler function - with patch('apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ - patch('apps.knowledge_summary_app.get_current_user_id', return_value=test_data["user_id"]): + with patch('backend.apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ + patch('backend.apps.knowledge_summary_app.get_current_user_id', return_value=test_data["user_id"]): response = client.post( f"/summary/{test_data['index_name']}/summary", @@ -254,8 +276,8 @@ def test_change_summary_exception(test_data): mock_service_instance.change_summary.side_effect = Exception("Error updating summary") # Execute test - with patch('apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ - patch('apps.knowledge_summary_app.get_current_user_id', return_value=test_data["user_id"]): + with patch('backend.apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance), \ + patch('backend.apps.knowledge_summary_app.get_current_user_id', return_value=test_data["user_id"]): response = client.post( 
f"/summary/{test_data['index_name']}/summary", @@ -280,7 +302,7 @@ def test_get_summary_success(test_data): mock_service_instance = MagicMock() mock_service_instance.get_summary.return_value = expected_response - with patch('apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance): + with patch('backend.apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance): # Execute test response = client.get(f"/summary/{test_data['index_name']}/summary") @@ -299,7 +321,7 @@ def test_get_summary_exception(test_data): mock_service_instance = MagicMock() mock_service_instance.get_summary.side_effect = Exception("Error getting summary") - with patch('apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance): + with patch('backend.apps.knowledge_summary_app.ElasticSearchService', return_value=mock_service_instance): # Execute test response = client.get(f"/summary/{test_data['index_name']}/summary") diff --git a/test/backend/app/test_model_managment_app.py b/test/backend/app/test_model_managment_app.py index bdc57e3a5..18a41b54f 100644 --- a/test/backend/app/test_model_managment_app.py +++ b/test/backend/app/test_model_managment_app.py @@ -32,17 +32,17 @@ def client(mocker): mocker.patch('boto3.client') # Patch MinioClient at both possible import paths mocker.patch('backend.database.client.MinioClient') - # Stub services.elasticsearch_service to avoid real ES initialization + # Stub services.vectordatabase_service to avoid real VDB initialization import types import sys as _sys - if "services.elasticsearch_service" not in _sys.modules: - services_es_mod = types.ModuleType("services.elasticsearch_service") + if "services.vectordatabase_service" not in _sys.modules: + services_vdb_mod = types.ModuleType("services.vectordatabase_service") - def _get_es_core(): # minimal stub + def _get_vector_db_core(): # minimal stub return object() - services_es_mod.get_es_core = _get_es_core - 
_sys.modules["services.elasticsearch_service"] = services_es_mod + services_vdb_mod.get_vector_db_core = _get_vector_db_core + _sys.modules["services.vectordatabase_service"] = services_vdb_mod # Import after mocking (only backend path is required by app imports) from apps.model_managment_app import router diff --git a/test/backend/app/test_elasticsearch_app.py b/test/backend/app/test_vectordatabase_app.py similarity index 70% rename from test/backend/app/test_elasticsearch_app.py rename to test/backend/app/test_vectordatabase_app.py index ff552201b..5455a8293 100644 --- a/test/backend/app/test_elasticsearch_app.py +++ b/test/backend/app/test_vectordatabase_app.py @@ -78,7 +78,7 @@ class IndexingResponse(BaseModel): RedisService = MagicMock() # Import routes and services -from backend.apps.elasticsearch_app import router +from backend.apps.vectordatabase_app import router from nexent.vector_database.elasticsearch_core import ElasticSearchCore # Create test client @@ -96,15 +96,10 @@ class IndexingResponse(BaseModel): @pytest.fixture -def es_core_mock(): +def vdb_core_mock(): return MagicMock(spec=ElasticSearchCore) -@pytest.fixture -def es_service_mock(): - return MagicMock(spec=ElasticSearchService) - - @pytest.fixture def redis_service_mock(): mock = MagicMock() @@ -126,15 +121,15 @@ def auth_data(): @pytest.mark.asyncio -async def test_create_new_index_success(es_core_mock, auth_data): +async def test_create_new_index_success(vdb_core_mock, auth_data): """ Test creating a new index successfully. Verifies that the endpoint returns the expected response when index creation succeeds. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.create_index") as mock_create: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_index") as mock_create: expected_response = {"status": "success", "index_name": auth_data["index_name"]} @@ -151,15 +146,15 @@ async def test_create_new_index_success(es_core_mock, auth_data): @pytest.mark.asyncio -async def test_create_new_index_error(es_core_mock, auth_data): +async def test_create_new_index_error(vdb_core_mock, auth_data): """ Test creating a new index with error. Verifies that the endpoint returns an appropriate error response when index creation fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.create_index") as mock_create: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.create_index") as mock_create: mock_create.side_effect = Exception("Test error") @@ -174,18 +169,18 @@ async def test_create_new_index_error(es_core_mock, auth_data): @pytest.mark.asyncio -async def test_delete_index_success(es_core_mock, redis_service_mock, auth_data): +async def test_delete_index_success(vdb_core_mock, redis_service_mock, auth_data): """ Test deleting an index successfully. Verifies that the endpoint returns the expected response and performs Redis cleanup. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.get_redis_service", return_value=redis_service_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.list_files") as mock_list_files, \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.delete_index") as mock_delete, \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.full_delete_knowledge_base") as mock_full_delete: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_redis_service", return_value=redis_service_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.list_files") as mock_list_files, \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_index") as mock_delete, \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.full_delete_knowledge_base") as mock_full_delete: # Properly setup the async mock for list_files mock_list_files.return_value = {"files": []} @@ -237,27 +232,27 @@ async def test_delete_index_success(es_core_mock, redis_service_mock, auth_data) assert "minio_cleanup" in actual_response # Verify full_delete_knowledge_base was called with the correct parameters - # Use ANY for the es_core parameter because the actual object may differ + # Use ANY for the vdb_core parameter because the actual object may differ mock_full_delete.assert_called_once_with( auth_data["index_name"], - ANY, # Use ANY instead of es_core_mock to ignore object identity + ANY, # Use ANY instead of vdb_core_mock to ignore object identity auth_data["user_id"] ) @pytest.mark.asyncio -async def 
test_delete_index_redis_error(es_core_mock, redis_service_mock, auth_data): +async def test_delete_index_redis_error(vdb_core_mock, redis_service_mock, auth_data): """ Test deleting an index with Redis error. Verifies that the endpoint still succeeds with ES but reports Redis cleanup error. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.get_redis_service", return_value=redis_service_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.list_files") as mock_list_files, \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.delete_index") as mock_delete, \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.full_delete_knowledge_base") as mock_full_delete: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_redis_service", return_value=redis_service_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.list_files") as mock_list_files, \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_index") as mock_delete, \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.full_delete_knowledge_base") as mock_full_delete: # Properly setup the async mock for list_files mock_list_files.return_value = {"files": []} @@ -310,23 +305,23 @@ async def test_delete_index_redis_error(es_core_mock, redis_service_mock, auth_d ) or "error" in str(actual_response).lower() # Verify full_delete_knowledge_base was called with the correct parameters - # Use ANY for the es_core parameter because the actual object may differ + # Use ANY for the vdb_core parameter because the actual 
object may differ mock_full_delete.assert_called_once_with( auth_data["index_name"], - ANY, # Use ANY instead of es_core_mock to ignore object identity + ANY, # Use ANY instead of vdb_core_mock to ignore object identity auth_data["user_id"] ) @pytest.mark.asyncio -async def test_get_list_indices_success(es_core_mock): +async def test_get_list_indices_success(vdb_core_mock): """ Test listing indices successfully. Verifies that the endpoint returns the expected list of indices. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.list_indices") as mock_list: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.list_indices") as mock_list: expected_response = {"indices": ["index1", "index2"]} mock_list.return_value = expected_response @@ -342,14 +337,14 @@ async def test_get_list_indices_success(es_core_mock): @pytest.mark.asyncio -async def test_get_list_indices_error(es_core_mock): +async def test_get_list_indices_error(vdb_core_mock): """ Test listing indices with error. Verifies that the endpoint returns an appropriate error response when listing fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.list_indices") as mock_list: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.list_indices") as mock_list: mock_list.side_effect = Exception("Test error") @@ -362,16 +357,16 @@ async def test_get_list_indices_error(es_core_mock): @pytest.mark.asyncio -async def test_create_index_documents_success(es_core_mock, auth_data): +async def test_create_index_documents_success(vdb_core_mock, auth_data): """ Test indexing documents successfully. Verifies that the endpoint returns the expected response after documents are indexed. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.index_documents") as mock_index, \ - patch("backend.apps.elasticsearch_app.get_embedding_model", return_value=MagicMock()): + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index, \ + patch("backend.apps.vectordatabase_app.get_embedding_model", return_value=MagicMock()): index_name = "test_index" documents = [{"id": 1, "text": "test doc"}] @@ -397,16 +392,16 @@ async def test_create_index_documents_success(es_core_mock, auth_data): @pytest.mark.asyncio -async def test_create_index_documents_exception(es_core_mock, auth_data): +async def test_create_index_documents_exception(vdb_core_mock, auth_data): """ Test indexing documents 
with exception. Verifies that the endpoint returns an appropriate error response when an exception occurs during indexing. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.index_documents") as mock_index, \ - patch("backend.apps.elasticsearch_app.get_embedding_model", return_value=MagicMock()): + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index, \ + patch("backend.apps.vectordatabase_app.get_embedding_model", return_value=MagicMock()): index_name = "test_index" documents = [{"id": 1, "text": "test doc"}] @@ -430,15 +425,15 @@ async def test_create_index_documents_exception(es_core_mock, auth_data): @pytest.mark.asyncio -async def test_create_index_documents_auth_exception(es_core_mock, auth_data): +async def test_create_index_documents_auth_exception(vdb_core_mock, auth_data): """ Test indexing documents with authentication exception. Verifies that the endpoint returns an appropriate error response when authentication fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id") as mock_get_user, \ - patch("backend.apps.elasticsearch_app.get_embedding_model", return_value=MagicMock()): + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id") as mock_get_user, \ + patch("backend.apps.vectordatabase_app.get_embedding_model", return_value=MagicMock()): index_name = "test_index" documents = [{"id": 1, "text": "test doc"}] @@ -462,15 +457,15 @@ async def test_create_index_documents_auth_exception(es_core_mock, auth_data): @pytest.mark.asyncio -async def test_create_index_documents_embedding_model_exception(es_core_mock, auth_data): +async def test_create_index_documents_embedding_model_exception(vdb_core_mock, auth_data): """ Test indexing documents with embedding model exception. Verifies that the endpoint returns an appropriate error response when embedding model fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.get_embedding_model") as mock_get_embedding: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.get_embedding_model") as mock_get_embedding: index_name = "test_index" documents = [{"id": 1, "text": "test doc"}] @@ -495,16 +490,16 @@ async def test_create_index_documents_embedding_model_exception(es_core_mock, au @pytest.mark.asyncio -async def test_create_index_documents_validation_exception(es_core_mock, auth_data): +async def test_create_index_documents_validation_exception(vdb_core_mock, auth_data): """ Test indexing documents with validation exception. Verifies that the endpoint returns an appropriate error response when document validation fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.index_documents") as mock_index, \ - patch("backend.apps.elasticsearch_app.get_embedding_model", return_value=MagicMock()): + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index, \ + patch("backend.apps.vectordatabase_app.get_embedding_model", return_value=MagicMock()): index_name = "test_index" documents = [{"id": 1, "text": "test doc"}] @@ -528,14 +523,14 @@ async def test_create_index_documents_validation_exception(es_core_mock, auth_da @pytest.mark.asyncio -async def test_get_index_files_success(es_core_mock): +async def test_get_index_files_success(vdb_core_mock): """ Test listing index files successfully. Using pytest-asyncio to properly handle async operations. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.list_files") as mock_list_files: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.list_files") as mock_list_files: index_name = "test_index" expected_files = { @@ -560,14 +555,14 @@ async def test_get_index_files_success(es_core_mock): @pytest.mark.asyncio -async def test_get_index_files_exception(es_core_mock): +async def test_get_index_files_exception(vdb_core_mock): """ Test listing index files with exception. 
Verifies that the endpoint returns an appropriate error response when an exception occurs during file listing. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.list_files") as mock_list_files: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.list_files") as mock_list_files: index_name = "test_index" @@ -586,20 +581,20 @@ async def test_get_index_files_exception(es_core_mock): assert response.json() == {"detail": expected_error_detail} # Verify list_files was called with correct parameters - # Use ANY for the es_core parameter because the actual object may differ + # Use ANY for the vdb_core parameter because the actual object may differ mock_list_files.assert_called_once_with( - index_name, include_chunks=False, es_core=ANY) + index_name, include_chunks=False, vdb_core=ANY) @pytest.mark.asyncio -async def test_get_index_files_validation_exception(es_core_mock): +async def test_get_index_files_validation_exception(vdb_core_mock): """ Test listing index files with validation exception. Verifies that the endpoint returns an appropriate error response when index validation fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.list_files") as mock_list_files: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.list_files") as mock_list_files: index_name = "test_index" @@ -621,14 +616,14 @@ async def test_get_index_files_validation_exception(es_core_mock): @pytest.mark.asyncio -async def test_get_index_files_timeout_exception(es_core_mock): +async def test_get_index_files_timeout_exception(vdb_core_mock): """ Test listing index files with timeout exception. Verifies that the endpoint returns an appropriate error response when operation times out. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.list_files") as mock_list_files: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.list_files") as mock_list_files: index_name = "test_index" @@ -650,14 +645,14 @@ async def test_get_index_files_timeout_exception(es_core_mock): @pytest.mark.asyncio -async def test_get_index_files_permission_exception(es_core_mock): +async def test_get_index_files_permission_exception(vdb_core_mock): """ Test listing index files with permission exception. Verifies that the endpoint returns an appropriate error response when permission is denied. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.list_files") as mock_list_files: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.list_files") as mock_list_files: index_name = "test_index" @@ -679,14 +674,75 @@ async def test_get_index_files_permission_exception(es_core_mock): @pytest.mark.asyncio -async def test_health_check_success(es_core_mock): +async def test_get_index_chunks_success(vdb_core_mock): + """ + Test retrieving index chunks successfully. + Verifies that the endpoint forwards query params and returns the service payload. + """ + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: + + index_name = "test_index" + expected_response = { + "status": "success", + "message": "ok", + "chunks": [{"id": "1"}], + "total": 1, + "page": 2, + "page_size": 50, + } + mock_get_chunks.return_value = expected_response + + response = client.post( + f"/indices/{index_name}/chunks", + params={"page": 2, "page_size": 50, "path_or_url": "/foo"} + ) + + assert response.status_code == 200 + assert response.json() == expected_response + mock_get_chunks.assert_called_once_with( + index_name=index_name, + page=2, + page_size=50, + path_or_url="/foo", + vdb_core=ANY, + ) + + +@pytest.mark.asyncio +async def test_get_index_chunks_error(vdb_core_mock): + """ + Test retrieving index chunks with service error. + Ensures the endpoint maps the exception to HTTP 500. 
+ """ + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.get_index_chunks") as mock_get_chunks: + + index_name = "test_index" + mock_get_chunks.side_effect = Exception("Chunk failure") + + response = client.post(f"/indices/{index_name}/chunks") + + assert response.status_code == 500 + assert response.json() == {"detail": "Error getting chunks: Chunk failure"} + mock_get_chunks.assert_called_once_with( + index_name=index_name, + page=None, + page_size=None, + path_or_url=None, + vdb_core=ANY, + ) + + +@pytest.mark.asyncio +async def test_health_check_success(vdb_core_mock): """ Test health check endpoint successfully. Using pytest-asyncio to properly handle async operations. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.health_check") as mock_health: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.health_check") as mock_health: expected_response = {"status": "ok", "elasticsearch": "connected"} mock_health.return_value = expected_response @@ -700,13 +756,13 @@ async def test_health_check_success(es_core_mock): @pytest.mark.asyncio -async def test_check_knowledge_base_exist_success(es_core_mock, auth_data): +async def test_check_knowledge_base_exist_success(vdb_core_mock, auth_data): """ Test check knowledge base exist endpoint success. 
""" - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.check_knowledge_base_exist_impl") as mock_impl: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.check_knowledge_base_exist_impl") as mock_impl: expected_response = {"exist": True, "scope": "tenant"} mock_impl.return_value = expected_response @@ -719,13 +775,13 @@ async def test_check_knowledge_base_exist_success(es_core_mock, auth_data): @pytest.mark.asyncio -async def test_check_knowledge_base_exist_error(es_core_mock, auth_data): +async def test_check_knowledge_base_exist_error(vdb_core_mock, auth_data): """ Test check knowledge base exist endpoint error path. 
""" - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.check_knowledge_base_exist_impl") as mock_impl: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.check_knowledge_base_exist_impl") as mock_impl: mock_impl.side_effect = Exception("Test error") @@ -738,15 +794,15 @@ async def test_check_knowledge_base_exist_error(es_core_mock, auth_data): @pytest.mark.asyncio -async def test_delete_index_exception(es_core_mock, auth_data): +async def test_delete_index_exception(vdb_core_mock, auth_data): """ Test deleting an index with exception. Verifies that the endpoint returns an appropriate error response when an exception occurs during deletion. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.full_delete_knowledge_base") as mock_full_delete: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.full_delete_knowledge_base") as mock_full_delete: # Setup the mock to raise an exception mock_full_delete.side_effect = Exception("Database connection failed") @@ -765,20 +821,20 @@ async def test_delete_index_exception(es_core_mock, auth_data): # Verify full_delete_knowledge_base was called with the correct parameters mock_full_delete.assert_called_once_with( auth_data["index_name"], - ANY, # Use ANY instead of es_core_mock to ignore object identity + ANY, # Use ANY instead of vdb_core_mock to ignore object identity auth_data["user_id"] ) @pytest.mark.asyncio -async def test_delete_index_auth_exception(es_core_mock, auth_data): +async def test_delete_index_auth_exception(vdb_core_mock, auth_data): """ Test deleting an index with authentication exception. Verifies that the endpoint returns an appropriate error response when authentication fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_current_user_id") as mock_get_user: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_current_user_id") as mock_get_user: # Setup the mock to raise an authentication exception mock_get_user.side_effect = Exception("Invalid authorization token") @@ -799,15 +855,15 @@ async def test_delete_index_auth_exception(es_core_mock, auth_data): @pytest.mark.asyncio -async def test_delete_documents_success(es_core_mock, redis_service_mock): +async def test_delete_documents_success(vdb_core_mock, redis_service_mock): """ Test deleting documents successfully. Verifies that the endpoint returns the expected response and performs Redis cleanup. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_redis_service", return_value=redis_service_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.delete_documents") as mock_delete_docs: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_redis_service", return_value=redis_service_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_documents") as mock_delete_docs: index_name = "test_index" path_or_url = "test_document.pdf" @@ -852,22 +908,22 @@ async def test_delete_documents_success(es_core_mock, redis_service_mock): assert actual_response["redis_cleanup"] == redis_result # Verify delete_documents was called with the correct parameters - # Use ANY for the es_core parameter because the actual object may differ + # Use ANY for the vdb_core parameter because the actual object may differ mock_delete_docs.assert_called_once_with(index_name, path_or_url, ANY) 
redis_service_mock.delete_document_records.assert_called_once_with( index_name, path_or_url) @pytest.mark.asyncio -async def test_delete_documents_redis_error(es_core_mock, redis_service_mock): +async def test_delete_documents_redis_error(vdb_core_mock, redis_service_mock): """ Test deleting documents with Redis error. Verifies that the endpoint still succeeds with ES but reports Redis cleanup error. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_redis_service", return_value=redis_service_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.delete_documents") as mock_delete_docs: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_redis_service", return_value=redis_service_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_documents") as mock_delete_docs: index_name = "test_index" path_or_url = "test_document.pdf" @@ -906,21 +962,21 @@ async def test_delete_documents_redis_error(es_core_mock, redis_service_mock): assert actual_response["redis_cleanup_error"] == redis_error_message # Verify delete_documents was called - # Use ANY for the es_core parameter because the actual object may differ + # Use ANY for the vdb_core parameter because the actual object may differ mock_delete_docs.assert_called_once_with(index_name, path_or_url, ANY) redis_service_mock.delete_document_records.assert_called_once_with( index_name, path_or_url) @pytest.mark.asyncio -async def test_delete_documents_es_exception(es_core_mock): +async def test_delete_documents_es_exception(vdb_core_mock): """ Test deleting documents with Elasticsearch exception. Verifies that the endpoint returns an appropriate error response when ES deletion fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.delete_documents") as mock_delete_docs: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_documents") as mock_delete_docs: index_name = "test_index" path_or_url = "test_document.pdf" @@ -941,20 +997,20 @@ async def test_delete_documents_es_exception(es_core_mock): assert response.json() == {"detail": expected_error_detail} # Verify delete_documents was called - # Use ANY for the es_core parameter because the actual object may differ + # Use ANY for the vdb_core parameter because the actual object may differ mock_delete_docs.assert_called_once_with(index_name, path_or_url, ANY) @pytest.mark.asyncio -async def test_delete_documents_redis_warnings(es_core_mock, redis_service_mock): +async def test_delete_documents_redis_warnings(vdb_core_mock, redis_service_mock): """ Test deleting documents with Redis warnings. Verifies that the endpoint handles Redis warnings properly. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.get_redis_service", return_value=redis_service_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.delete_documents") as mock_delete_docs: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.get_redis_service", return_value=redis_service_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_documents") as mock_delete_docs: index_name = "test_index" path_or_url = "test_document.pdf" @@ -1000,21 +1056,21 @@ async def test_delete_documents_redis_warnings(es_core_mock, redis_service_mock) "Some cache keys could not be deleted"] # Verify delete_documents was called - # Use ANY for the es_core parameter because the actual object may differ + # Use ANY for the vdb_core parameter because the actual object may differ mock_delete_docs.assert_called_once_with(index_name, path_or_url, ANY) redis_service_mock.delete_document_records.assert_called_once_with( index_name, path_or_url) @pytest.mark.asyncio -async def test_delete_documents_validation_exception(es_core_mock): +async def test_delete_documents_validation_exception(vdb_core_mock): """ Test deleting documents with validation exception. Verifies that the endpoint returns an appropriate error response when validation fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.delete_documents") as mock_delete_docs: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.delete_documents") as mock_delete_docs: index_name = "test_index" path_or_url = "test_document.pdf" @@ -1035,19 +1091,19 @@ async def test_delete_documents_validation_exception(es_core_mock): assert response.json() == {"detail": expected_error_detail} # Verify delete_documents was called - # Use ANY for the es_core parameter because the actual object may differ + # Use ANY for the vdb_core parameter because the actual object may differ mock_delete_docs.assert_called_once_with(index_name, path_or_url, ANY) @pytest.mark.asyncio -async def test_health_check_exception(es_core_mock): +async def test_health_check_exception(vdb_core_mock): """ Test health check endpoint with exception. Verifies that the endpoint returns an appropriate error response when an exception occurs during health check. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.health_check") as mock_health: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.health_check") as mock_health: # Setup the mock to raise an exception mock_health.side_effect = Exception("Elasticsearch connection failed") @@ -1062,19 +1118,19 @@ async def test_health_check_exception(es_core_mock): assert response.json() == {"detail": expected_error_detail} # Verify health_check was called - # Use ANY for the es_core parameter because the actual object may differ + # Use ANY for the vdb_core parameter because the actual object may differ mock_health.assert_called_once_with(ANY) @pytest.mark.asyncio -async def test_health_check_timeout_exception(es_core_mock): +async def test_health_check_timeout_exception(vdb_core_mock): """ Test health check endpoint with timeout exception. Verifies that the endpoint returns an appropriate error response when operation times out. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.health_check") as mock_health: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.health_check") as mock_health: # Setup the mock to raise a timeout exception mock_health.side_effect = TimeoutError("Health check timed out") @@ -1094,14 +1150,14 @@ async def test_health_check_timeout_exception(es_core_mock): @pytest.mark.asyncio -async def test_health_check_connection_exception(es_core_mock): +async def test_health_check_connection_exception(vdb_core_mock): """ Test health check endpoint with connection exception. 
Verifies that the endpoint returns an appropriate error response when connection fails. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.health_check") as mock_health: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.health_check") as mock_health: # Setup the mock to raise a connection exception mock_health.side_effect = ConnectionError( @@ -1122,14 +1178,14 @@ async def test_health_check_connection_exception(es_core_mock): @pytest.mark.asyncio -async def test_health_check_permission_exception(es_core_mock): +async def test_health_check_permission_exception(vdb_core_mock): """ Test health check endpoint with permission exception. Verifies that the endpoint returns an appropriate error response when permission is denied. """ # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.health_check") as mock_health: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.health_check") as mock_health: # Setup the mock to raise a permission exception mock_health.side_effect = PermissionError( @@ -1150,14 +1206,14 @@ async def test_health_check_permission_exception(es_core_mock): @pytest.mark.asyncio -async def test_health_check_validation_exception(es_core_mock): +async def test_health_check_validation_exception(vdb_core_mock): """ Test health check endpoint with validation exception. Verifies that the endpoint returns an appropriate error response when validation fails. 
""" # Setup mocks - with patch("backend.apps.elasticsearch_app.get_es_core", return_value=es_core_mock), \ - patch("backend.apps.elasticsearch_app.ElasticSearchService.health_check") as mock_health: + with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \ + patch("backend.apps.vectordatabase_app.ElasticSearchService.health_check") as mock_health: # Setup the mock to raise a validation exception mock_health.side_effect = ValueError( diff --git a/test/backend/app/test_voice_app.py b/test/backend/app/test_voice_app.py index af28b7c26..87d17f294 100644 --- a/test/backend/app/test_voice_app.py +++ b/test/backend/app/test_voice_app.py @@ -25,7 +25,7 @@ def __init__(self): # Now import the app under test -from apps.voice_app import router +from apps.voice_app import voice_runtime_router, voice_config_router class TestVoiceApp: @@ -34,7 +34,8 @@ class TestVoiceApp: def setup_method(self): """Set up test fixtures""" self.app = FastAPI() - self.app.include_router(router) + self.app.include_router(voice_runtime_router) + self.app.include_router(voice_config_router) self.client = TestClient(self.app) def test_stt_websocket_success(self): @@ -276,7 +277,8 @@ class TestVoiceAppIntegration: def setup_method(self): """Set up test fixtures""" self.app = FastAPI() - self.app.include_router(router) + self.app.include_router(voice_runtime_router) + self.app.include_router(voice_config_router) self.client = TestClient(self.app) def test_voice_connectivity_real_logic_stt(self): diff --git a/test/backend/data_process/test_tasks.py b/test/backend/data_process/test_tasks.py index da2a71aee..42a086347 100644 --- a/test/backend/data_process/test_tasks.py +++ b/test/backend/data_process/test_tasks.py @@ -595,7 +595,7 @@ def test_forward_skips_empty_chunk_without_preprocess(monkeypatch): # We asserted path executed; exact stored count depends on implementation but should not error -def test_forward_index_documents_client_connector_error(monkeypatch): 
+def test_forward_vectorize_documents_client_connector_error(monkeypatch): tasks, _ = import_tasks_with_fake_ray(monkeypatch) monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") # Speed up retries @@ -649,7 +649,7 @@ def __init__(self, status=None): json.loads(str(ei.value)) -def test_forward_index_documents_client_response_503(monkeypatch): +def test_forward_vectorize_documents_client_response_503(monkeypatch): tasks, _ = import_tasks_with_fake_ray(monkeypatch) monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") @@ -731,7 +731,7 @@ def test_forward_api_returns_error_and_unexpected_format(monkeypatch): json.loads(str(ei2.value)) -def test_forward_index_documents_timeout_error(monkeypatch): +def test_forward_vectorize_documents_timeout_error(monkeypatch): tasks, _ = import_tasks_with_fake_ray(monkeypatch) monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") @@ -789,7 +789,7 @@ class DummyClientConnectorError(Exception): json.loads(str(ei.value)) -def test_forward_index_documents_unexpected_error(monkeypatch): +def test_forward_vectorize_documents_unexpected_error(monkeypatch): tasks, _ = import_tasks_with_fake_ray(monkeypatch) monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") diff --git a/test/backend/services/test_agent_service.py b/test/backend/services/test_agent_service.py index 995f9aad2..69f0a1c56 100644 --- a/test/backend/services/test_agent_service.py +++ b/test/backend/services/test_agent_service.py @@ -82,6 +82,8 @@ def _mock_context(): run_agent_stream, stop_agent_tasks, _resolve_user_tenant_language, + _apply_duplicate_name_availability_rules, + _check_single_model_availability, ) from consts.model import ExportAndImportAgentInfo, ExportAndImportDataFormat, MCPInfo, AgentRequest @@ -1554,10 +1556,12 @@ async def test_list_all_agent_info_impl_success(): assert result[0]["name"] == "Agent 1" assert result[0]["display_name"] == "Display Agent 1" assert result[0]["is_available"] == True + assert 
result[0]["unavailable_reasons"] == [] assert result[1]["agent_id"] == 2 assert result[1]["name"] == "Agent 2" assert result[1]["display_name"] == "Display Agent 2" assert result[1]["is_available"] == True + assert result[1]["unavailable_reasons"] == [] # Verify mock calls mock_query_agents.assert_called_once_with(tenant_id="test_tenant") @@ -1619,7 +1623,9 @@ async def test_list_all_agent_info_impl_with_unavailable_tools(): # Assert assert len(result) == 2 assert result[0]["is_available"] == True + assert result[0]["unavailable_reasons"] == [] assert result[1]["is_available"] == False + assert result[1]["unavailable_reasons"] == ["tool_unavailable"] # Verify mock calls mock_query_agents.assert_called_once_with(tenant_id="test_tenant") @@ -1647,6 +1653,74 @@ async def test_list_all_agent_info_impl_query_error(): mock_query_agents.assert_called_once_with(tenant_id="test_tenant") +async def test_list_all_agent_info_impl_model_unavailable(): + mock_agents = [ + { + "agent_id": 1, + "name": "Agent 1", + "display_name": "Display Agent 1", + "description": "Agent with unavailable model", + "enabled": True, + "model_id": 101 + } + ] + + with patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') as mock_query_agents, \ + patch('backend.services.agent_service.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.services.agent_service.get_model_by_model_id') as mock_get_model: + mock_query_agents.return_value = mock_agents + mock_search_tools.return_value = [] + mock_get_model.return_value = { + "connect_status": agent_service.ModelConnectStatusEnum.UNAVAILABLE.value + } + + result = await list_all_agent_info_impl(tenant_id="test_tenant") + + assert len(result) == 1 + assert result[0]["is_available"] is False + assert result[0]["unavailable_reasons"] == ["model_unavailable"] + + +async def test_list_all_agent_info_impl_duplicate_names(): + mock_agents = [ + { + "agent_id": 1, + "name": "Duplicated", + "create_time": "2024-01-01T00:00:00", 
+ "display_name": "Agent Display 1", + "description": "First agent", + "enabled": True + }, + { + "agent_id": 2, + "name": "Duplicated", + "create_time": "2024-02-01T00:00:00", + "display_name": "Agent Display 2", + "description": "Second agent", + "enabled": True + } + ] + + with patch('backend.services.agent_service.query_all_agent_info_by_tenant_id') as mock_query_agents, \ + patch('backend.services.agent_service.search_tools_for_sub_agent') as mock_search_tools: + mock_query_agents.return_value = mock_agents + mock_search_tools.return_value = [] + + result = await list_all_agent_info_impl(tenant_id="test_tenant") + + assert len(result) == 2 + + # The earliest created agent (agent_id=1) should remain available + agent1 = next(a for a in result if a["agent_id"] == 1) + assert agent1["is_available"] is True + assert "duplicate_name" not in agent1["unavailable_reasons"] + + # The later created agent (agent_id=2) should be unavailable due to duplication + agent2 = next(a for a in result if a["agent_id"] == 2) + assert agent2["is_available"] is False + assert "duplicate_name" in agent2["unavailable_reasons"] + + @patch('backend.services.agent_service.query_sub_agents_id_list') @patch('backend.services.agent_service.create_tool_config_list', new_callable=AsyncMock) @patch('backend.services.agent_service.search_agent_info_by_agent_id') @@ -3896,6 +3970,67 @@ async def test_list_all_agent_info_impl_all_disabled_agents(): mock_check_tools.assert_not_called() +def test_apply_duplicate_name_availability_rules_handles_missing_fields(): + """ + Ensure duplicate detection gracefully handles agents without name/display_name. 
+ """ + enriched_agents = [ + { + "raw_agent": { + "agent_id": 1, + "name": None, + "display_name": None, + "create_time": "2024-01-01T00:00:00", + }, + "unavailable_reasons": [], + }, + { + "raw_agent": { + "agent_id": 2, + "name": "dup", + "display_name": None, + "create_time": "2024-01-01T00:00:00", + }, + "unavailable_reasons": [], + }, + { + "raw_agent": { + "agent_id": 3, + "name": "dup", + "display_name": None, + "create_time": "2024-02-01T00:00:00", + }, + "unavailable_reasons": [], + }, + { + "raw_agent": { + "agent_id": 4, + "name": None, + "display_name": "display-dup", + "create_time": "2024-01-01T00:00:00", + }, + "unavailable_reasons": [], + }, + { + "raw_agent": { + "agent_id": 5, + "name": None, + "display_name": "display-dup", + "create_time": "2024-02-01T00:00:00", + }, + "unavailable_reasons": [], + }, + ] + + _apply_duplicate_name_availability_rules(enriched_agents) + + assert enriched_agents[0]["unavailable_reasons"] == [] + assert "duplicate_name" not in enriched_agents[1]["unavailable_reasons"] + assert "duplicate_name" in enriched_agents[2]["unavailable_reasons"] + assert "duplicate_display_name" not in enriched_agents[3]["unavailable_reasons"] + assert "duplicate_display_name" in enriched_agents[4]["unavailable_reasons"] + + # ============================================================================ # Tests for Agent Export/Import Integration with model_name fields # ============================================================================ @@ -5348,3 +5483,60 @@ async def test_resolve_model_quick_config_exception( model_label="Model", tenant_id="tenant_011" ) + + +def test_check_single_model_availability_no_model_id(): + reasons = _check_single_model_availability( + model_id=None, + tenant_id="tenant", + model_cache={}, + reason_key="model_unavailable", + ) + assert reasons == [] + + +@patch("backend.services.agent_service.get_model_by_model_id") +def 
test_check_single_model_availability_fetches_and_handles_missing_model(mock_get_model): + model_cache = {} + mock_get_model.return_value = None + + reasons = _check_single_model_availability( + model_id=123, + tenant_id="tenant", + model_cache=model_cache, + reason_key="model_unavailable", + ) + + assert reasons == ["model_unavailable"] + assert 123 in model_cache + mock_get_model.assert_called_once_with(123, "tenant") + + +def test_check_single_model_availability_uses_cached_unavailable_model(): + model_cache = { + 456: {"connect_status": agent_service.ModelConnectStatusEnum.UNAVAILABLE.value} + } + + reasons = _check_single_model_availability( + model_id=456, + tenant_id="tenant", + model_cache=model_cache, + reason_key="model_unavailable", + ) + + assert reasons == ["model_unavailable"] + + +def test_check_single_model_availability_returns_empty_for_available_model(): + model_cache = { + 789: {"connect_status": agent_service.ModelConnectStatusEnum.AVAILABLE.value} + } + + reasons = _check_single_model_availability( + model_id=789, + tenant_id="tenant", + model_cache=model_cache, + reason_key="model_unavailable", + ) + + assert reasons == [] diff --git a/test/backend/services/test_file_management_service.py b/test/backend/services/test_file_management_service.py index c3c907e1e..4eabb4cae 100644 --- a/test/backend/services/test_file_management_service.py +++ b/test/backend/services/test_file_management_service.py @@ -45,23 +45,23 @@ services_stub.__path__ = [] # Mark as package sys.modules.setdefault('services', services_stub) -es_stub = types.ModuleType('services.elasticsearch_service') +vdb_stub = types.ModuleType('services.vectordatabase_service') class _StubElasticSearchService: @staticmethod - async def list_files(index_name, include_chunks=False, es_core=None): + async def list_files(index_name, include_chunks=False, vdb_core=None): return {"files": []} -def _stub_get_es_core(): +def _stub_get_vector_db_core(): return None -es_stub.ElasticSearchService = 
_StubElasticSearchService -es_stub.get_es_core = _stub_get_es_core -sys.modules['services.elasticsearch_service'] = es_stub -setattr(services_stub, 'elasticsearch_service', es_stub) +vdb_stub.ElasticSearchService = _StubElasticSearchService +vdb_stub.get_vector_db_core = _stub_get_vector_db_core +sys.modules['services.vectordatabase_service'] = vdb_stub +setattr(services_stub, 'vectordatabase_service', vdb_stub) # Import the service module after mocking external dependencies file_management_service = importlib.import_module( @@ -311,7 +311,7 @@ async def test_upload_files_impl_minio_conflict_resolution(self): } with patch('backend.services.file_management_service.upload_to_minio', AsyncMock(return_value=minio_return)) as mock_upload, \ - patch('backend.services.file_management_service.get_es_core', MagicMock()) as mock_es_core, \ + patch('backend.services.file_management_service.get_vector_db_core', MagicMock()) as mock_vdb_core, \ patch('backend.services.file_management_service.ElasticSearchService.list_files', AsyncMock(return_value=existing)) as mock_list: errors, uploaded_paths, uploaded_names = await upload_files_impl( @@ -342,7 +342,7 @@ async def test_upload_files_impl_minio_conflict_resolution_case_insensitive_dupl existing = {"files": [{"file": "doc.pdf"}]} with patch('backend.services.file_management_service.upload_to_minio', AsyncMock(return_value=minio_return)), \ - patch('backend.services.file_management_service.get_es_core', MagicMock()), \ + patch('backend.services.file_management_service.get_vector_db_core', MagicMock()), \ patch('backend.services.file_management_service.ElasticSearchService.list_files', AsyncMock(return_value=existing)): errors, uploaded_paths, uploaded_names = await upload_files_impl( @@ -364,7 +364,7 @@ async def test_upload_files_impl_minio_conflict_resolution_es_exception(self): ] with patch('backend.services.file_management_service.upload_to_minio', AsyncMock(return_value=minio_return)), \ - 
patch('backend.services.file_management_service.get_es_core', MagicMock()), \ + patch('backend.services.file_management_service.get_vector_db_core', MagicMock()), \ patch('backend.services.file_management_service.ElasticSearchService.list_files', AsyncMock(side_effect=Exception("boom"))), \ patch('backend.services.file_management_service.logger') as mock_logger: diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 9a6dc73d1..8d5cdcd4d 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -93,7 +93,7 @@ def get_value(status): consts_const_mod.LOCALHOST_IP = "127.0.0.1" consts_const_mod.LOCALHOST_NAME = "localhost" consts_const_mod.DOCKER_INTERNAL_HOST = "host.docker.internal" -# Fields required by utils.memory_utils and services.elasticsearch_service +# Fields required by utils.memory_utils and services.vectordatabase_service consts_const_mod.MODEL_CONFIG_MAPPING = { "llm": "LLM_ID", "embedding": "EMBEDDING_ID"} consts_const_mod.ES_HOST = "http://localhost:9200" @@ -255,16 +255,16 @@ def _update_config_by_tenant_config_id_and_data(*args, **kwargs): db_tenant_cfg_mod.update_config_by_tenant_config_id_and_data = _update_config_by_tenant_config_id_and_data sys.modules["database.tenant_config_db"] = db_tenant_cfg_mod -# Stub services.elasticsearch_service to avoid heavy imports -services_es_mod = types.ModuleType("services.elasticsearch_service") +# Stub services.vectordatabase_service to avoid heavy imports +services_vdb_mod = types.ModuleType("services.vectordatabase_service") -def _get_es_core(): +def _get_vector_db_core(): return object() -services_es_mod.get_es_core = _get_es_core -sys.modules["services.elasticsearch_service"] = services_es_mod +services_vdb_mod.get_vector_db_core = _get_vector_db_core +sys.modules["services.vectordatabase_service"] = services_vdb_mod # Stub 
nexent.memory.memory_service.clear_model_memories nexent_memory_mod = types.ModuleType("nexent.memory.memory_service") @@ -583,13 +583,13 @@ async def test_delete_model_for_tenant_embedding_deletes_both(): ] with mock.patch.object(svc, "get_model_by_display_name", side_effect=side_effect) as mock_get, \ mock.patch.object(svc, "delete_model_record") as mock_delete, \ - mock.patch.object(svc, "get_es_core", return_value=object()) as mock_get_es, \ + mock.patch.object(svc, "get_vector_db_core", return_value=object()) as mock_get_vdb, \ mock.patch.object(svc, "build_memory_config_for_tenant", return_value={}) as mock_build_cfg, \ mock.patch.object(svc, "clear_model_memories", new=mock.AsyncMock()) as mock_clear: await svc.delete_model_for_tenant("u1", "t1", "name") assert mock_delete.call_count == 2 mock_get.assert_called() - mock_get_es.assert_called_once() + mock_get_vdb.assert_called_once() mock_build_cfg.assert_called_once_with("t1") # Best-effort cleanup may call once or twice depending on state assert mock_clear.await_count >= 1 @@ -606,7 +606,7 @@ async def test_delete_model_for_tenant_cleanup_inner_exception(caplog): ] with mock.patch.object(svc, "get_model_by_display_name", side_effect=side_effect), \ mock.patch.object(svc, "delete_model_record") as mock_delete, \ - mock.patch.object(svc, "get_es_core", return_value=object()), \ + mock.patch.object(svc, "get_vector_db_core", return_value=object()), \ mock.patch.object(svc, "build_memory_config_for_tenant", return_value={}), \ mock.patch.object(svc, "clear_model_memories", new=mock.AsyncMock(side_effect=Exception("boom"))): @@ -629,7 +629,7 @@ async def test_delete_model_for_tenant_cleanup_outer_exception(caplog): ] with mock.patch.object(svc, "get_model_by_display_name", side_effect=side_effect), \ mock.patch.object(svc, "delete_model_record") as mock_delete, \ - mock.patch.object(svc, "get_es_core", side_effect=Exception("es_down")), \ + mock.patch.object(svc, "get_vector_db_core", 
side_effect=Exception("vdb_down")), \ mock.patch.object(svc, "build_memory_config_for_tenant", return_value={}): with caplog.at_level(logging.WARNING): diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 0646b1435..b1345f720 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -36,8 +36,6 @@ call_llm_for_system_prompt, generate_and_save_system_prompt_impl, gen_system_prompt_streamable, - get_enabled_tool_description_for_generate_prompt, - get_enabled_sub_agent_description_for_generate_prompt, generate_system_prompt, join_info_for_generate_system_prompt, _process_thinking_tokens @@ -91,21 +89,21 @@ def test_call_llm_for_system_prompt(self, mock_get_model_name, mock_openai, mock ) @patch('backend.services.prompt_service.generate_system_prompt') - @patch('backend.services.prompt_service.get_enabled_sub_agent_description_for_generate_prompt') - @patch('backend.services.prompt_service.get_enabled_tool_description_for_generate_prompt') + @patch('backend.services.prompt_service.query_tools_by_ids') + @patch('backend.services.prompt_service.search_agent_info_by_agent_id') @patch('backend.services.prompt_service.update_agent') - def test_generate_and_save_system_prompt_impl(self, mock_update_agent, mock_get_tool_desc, - mock_get_agent_desc, mock_generate_system_prompt): + def test_generate_and_save_system_prompt_impl(self, mock_update_agent, mock_search_agent_info, + mock_query_tools, mock_generate_system_prompt): # Setup mock_tool1 = {"name": "tool1", "description": "Tool 1 desc", "inputs": "input1", "output_type": "output1"} mock_tool2 = {"name": "tool2", "description": "Tool 2 desc", "inputs": "input2", "output_type": "output2"} - mock_get_tool_desc.return_value = [mock_tool1, mock_tool2] + mock_query_tools.return_value = [mock_tool1, mock_tool2] mock_agent1 = {"name": "agent1", "description": "Agent 1 desc"} mock_agent2 = {"name": "agent2", 
"description": "Agent 2 desc"} - mock_get_agent_desc.return_value = [mock_agent1, mock_agent2] + mock_search_agent_info.side_effect = [mock_agent1, mock_agent2] # Mock the generator to return the expected data structure def mock_generator(*args, **kwargs): @@ -121,33 +119,34 @@ def mock_generator(*args, **kwargs): mock_generate_system_prompt.side_effect = mock_generator - # Execute - test as a generator + # Execute - test as a generator with frontend-provided IDs result_gen = generate_and_save_system_prompt_impl( agent_id=123, model_id=self.test_model_id, task_description="Test task", user_id="user123", tenant_id="tenant456", - language="zh" + language="zh", + tool_ids=[1, 2], + sub_agent_ids=[10, 20] ) result = list(result_gen) # Convert generator to list for assertion # Assert self.assertGreater(len(result), 0) - mock_get_tool_desc.assert_called_once_with( - agent_id=123, tenant_id="tenant456") - mock_get_agent_desc.assert_called_once_with( - agent_id=123, tenant_id="tenant456") + # Verify tools and agents were queried using frontend-provided IDs + mock_query_tools.assert_called_once_with([1, 2]) + self.assertEqual(mock_search_agent_info.call_count, 2) + mock_search_agent_info.assert_any_call(agent_id=10, tenant_id="tenant456") + mock_search_agent_info.assert_any_call(agent_id=20, tenant_id="tenant456") - mock_generate_system_prompt.assert_called_once_with( - mock_get_agent_desc.return_value, - "Test task", - mock_get_tool_desc.return_value, - "tenant456", - self.test_model_id, - "zh" - ) + # Verify generate_system_prompt was called with correct parameters + mock_generate_system_prompt.assert_called_once() + call_args = mock_generate_system_prompt.call_args + self.assertEqual(call_args[0][0], [mock_agent1, mock_agent2]) # sub_agent_info_list + self.assertEqual(call_args[0][1], "Test task") # task_description + self.assertEqual(call_args[0][2], [mock_tool1, mock_tool2]) # tool_info_list # Verify update_agent was called with the correct parameters 
mock_update_agent.assert_called_once() @@ -162,11 +161,8 @@ def mock_generator(*args, **kwargs): self.assertEqual(agent_info.business_description, "Test task") @patch('backend.services.prompt_service.generate_system_prompt') - @patch('backend.services.prompt_service.get_enabled_sub_agent_description_for_generate_prompt') - @patch('backend.services.prompt_service.get_enabled_tool_description_for_generate_prompt') @patch('backend.services.prompt_service.update_agent') - def test_generate_and_save_system_prompt_impl_create_mode(self, mock_update_agent, mock_get_tool_desc, - mock_get_agent_desc, mock_generate_system_prompt): + def test_generate_and_save_system_prompt_impl_create_mode(self, mock_update_agent, mock_generate_system_prompt): """Test generate_and_save_system_prompt_impl in create mode (agent_id=0)""" # Setup - Mock the generator to return the expected data structure def mock_generator(*args, **kwargs): @@ -182,25 +178,22 @@ def mock_generator(*args, **kwargs): mock_generate_system_prompt.side_effect = mock_generator - # Execute - test as a generator with agent_id=0 (create mode) + # Execute - test as a generator with agent_id=0 (create mode) and empty tool/sub-agent IDs result_gen = generate_and_save_system_prompt_impl( agent_id=0, model_id=self.test_model_id, task_description="Test task", user_id="user123", tenant_id="tenant456", - language="zh" + language="zh", + tool_ids=[], + sub_agent_ids=[] ) result = list(result_gen) # Convert generator to list for assertion # Assert self.assertGreater(len(result), 0) - # In create mode, should NOT call get_enabled_tool_description_for_generate_prompt - # and get_enabled_sub_agent_description_for_generate_prompt - mock_get_tool_desc.assert_not_called() - mock_get_agent_desc.assert_not_called() - # Should call generate_system_prompt with empty lists for tools and sub-agents mock_generate_system_prompt.assert_called_once_with( [], # Empty sub_agent_info_list @@ -246,7 +239,9 @@ def 
test_gen_system_prompt_streamable(self, mock_generate_impl): task_description="Test task", user_id="user123", tenant_id="tenant456", - language="zh" + language="zh", + tool_ids=None, + sub_agent_ids=None, ) # Verify output format - should be SSE format @@ -473,68 +468,6 @@ def test_join_info_for_generate_system_prompt(self, mock_template): self.assertEqual( template_vars["task_description"], mock_task_description) - @patch('backend.services.prompt_service.get_enable_tool_id_by_agent_id') - @patch('backend.services.prompt_service.query_tools_by_ids') - def test_get_enabled_tool_description_for_generate_prompt(self, mock_query_tools, mock_get_tool_ids): - # Setup - mock_get_tool_ids.return_value = [1, 2, 3] - mock_tools = [{"id": 1, "name": "tool1"}, { - "id": 2, "name": "tool2"}, {"id": 3, "name": "tool3"}] - mock_query_tools.return_value = mock_tools - - # Execute - result = get_enabled_tool_description_for_generate_prompt( - agent_id=123, - tenant_id="tenant456" - ) - - # Assert - self.assertEqual(result, mock_tools) - mock_get_tool_ids.assert_called_once_with( - agent_id=123, tenant_id="tenant456") - mock_query_tools.assert_called_once_with([1, 2, 3]) - - @patch('backend.services.prompt_service.search_agent_info_by_agent_id') - @patch('backend.services.prompt_service.query_sub_agents_id_list') - def test_get_enabled_sub_agent_description_for_generate_prompt(self, mock_query_sub_agents_id_list, mock_search_agent_info): - # Setup - mock_query_sub_agents_id_list.return_value = [1, 2, 3] - - # Mock search_agent_info_by_agent_id to return different agent info for each ID - def mock_search_agent_info_side_effect(agent_id, tenant_id): - agent_info_map = { - 1: {"id": 1, "name": "agent1", "enabled": True}, - 2: {"id": 2, "name": "agent2", "enabled": False}, - 3: {"id": 3, "name": "agent3", "enabled": True} - } - return agent_info_map.get(agent_id, {}) - - mock_search_agent_info.side_effect = mock_search_agent_info_side_effect - - # Execute - result = 
get_enabled_sub_agent_description_for_generate_prompt( - agent_id=123, - tenant_id="tenant456" - ) - - # Assert - expected_result = [ - {"id": 1, "name": "agent1", "enabled": True}, - {"id": 2, "name": "agent2", "enabled": False}, - {"id": 3, "name": "agent3", "enabled": True} - ] - self.assertEqual(result, expected_result) - mock_query_sub_agents_id_list.assert_called_once_with( - main_agent_id=123, tenant_id="tenant456") - - # Verify search_agent_info_by_agent_id was called for each sub agent ID - self.assertEqual(mock_search_agent_info.call_count, 3) - mock_search_agent_info.assert_any_call( - agent_id=1, tenant_id="tenant456") - mock_search_agent_info.assert_any_call( - agent_id=2, tenant_id="tenant456") - mock_search_agent_info.assert_any_call( - agent_id=3, tenant_id="tenant456") @patch('backend.services.prompt_service.get_model_by_model_id') @patch('backend.services.prompt_service.OpenAIServerModel') diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 9c3383c27..75c1ce7a2 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -1603,8 +1603,8 @@ class TestValidateLocalToolKnowledgeBaseSearch: @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') @patch('backend.services.tool_configuration_service.get_embedding_model') - @patch('backend.services.tool_configuration_service.elastic_core') - def test_validate_local_tool_knowledge_base_search_success(self, mock_elastic_core, mock_get_embedding_model, + @patch('backend.services.tool_configuration_service.get_vector_db_core') + def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector_db_core, mock_get_embedding_model, mock_get_knowledge_list, mock_signature, mock_get_class): """Test successful knowledge_base_search tool validation 
with proper dependencies""" # Mock tool class @@ -1620,7 +1620,7 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_elastic_co mock_sig.parameters = { 'self': Mock(), 'index_names': Mock(), - 'es_core': Mock(), + 'vdb_core': Mock(), 'embedding_model': Mock() } mock_signature.return_value = mock_sig @@ -1632,7 +1632,8 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_elastic_co ] mock_get_knowledge_list.return_value = mock_knowledge_list mock_get_embedding_model.return_value = "mock_embedding_model" - # elastic_core is already a mock object, we don't need to set return_value + mock_vdb_core = Mock() + mock_get_vector_db_core.return_value = mock_vdb_core from backend.services.tool_configuration_service import _validate_local_tool @@ -1651,8 +1652,8 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_elastic_co expected_params = { "param": "config", "index_names": ["index1", "index2"], - "es_core": mock_elastic_core, # Use the mock object directly - "embedding_model": "mock_embedding_model" + "vdb_core": mock_vdb_core, + "embedding_model": "mock_embedding_model", } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") @@ -1720,8 +1721,8 @@ def test_validate_local_tool_knowledge_base_search_missing_both_ids(self, mock_g @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') @patch('backend.services.tool_configuration_service.get_embedding_model') - @patch('backend.services.tool_configuration_service.elastic_core') - def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mock_elastic_core, + @patch('backend.services.tool_configuration_service.get_vector_db_core') + def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mock_get_vector_db_core, mock_get_embedding_model, 
mock_get_knowledge_list, mock_signature, @@ -1740,7 +1741,7 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo mock_sig.parameters = { 'self': Mock(), 'index_names': Mock(), - 'es_core': Mock(), + 'vdb_core': Mock(), 'embedding_model': Mock() } mock_signature.return_value = mock_sig @@ -1748,7 +1749,8 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo # Mock empty knowledge list mock_get_knowledge_list.return_value = [] mock_get_embedding_model.return_value = "mock_embedding_model" - # elastic_core is already a mock object, we don't need to set return_value + mock_vdb_core = Mock() + mock_get_vector_db_core.return_value = mock_vdb_core from backend.services.tool_configuration_service import _validate_local_tool @@ -1766,8 +1768,8 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo expected_params = { "param": "config", "index_names": [], - "es_core": mock_elastic_core, # Use the mock object directly - "embedding_model": "mock_embedding_model" + "vdb_core": mock_vdb_core, + "embedding_model": "mock_embedding_model", } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") @@ -1776,8 +1778,8 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') @patch('backend.services.tool_configuration_service.get_embedding_model') - @patch('backend.services.tool_configuration_service.elastic_core') - def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_elastic_core, + @patch('backend.services.tool_configuration_service.get_vector_db_core') + def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_get_vector_db_core, mock_get_embedding_model, mock_get_knowledge_list, mock_signature, @@ 
-1797,7 +1799,7 @@ def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_el mock_sig.parameters = { 'self': Mock(), 'index_names': Mock(), - 'es_core': Mock(), + 'vdb_core': Mock(), 'embedding_model': Mock() } mock_signature.return_value = mock_sig @@ -1806,7 +1808,8 @@ def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_el mock_knowledge_list = [{"index_name": "index1", "knowledge_id": "kb1"}] mock_get_knowledge_list.return_value = mock_knowledge_list mock_get_embedding_model.return_value = "mock_embedding_model" - # elastic_core is already a mock object, we don't need to set return_value + mock_vdb_core = Mock() + mock_get_vector_db_core.return_value = mock_vdb_core from backend.services.tool_configuration_service import _validate_local_tool diff --git a/test/backend/services/test_elasticsearch_service.py b/test/backend/services/test_vectordatabase_service.py similarity index 63% rename from test/backend/services/test_elasticsearch_service.py rename to test/backend/services/test_vectordatabase_service.py index 87f2b2481..d345b0c45 100644 --- a/test/backend/services/test_elasticsearch_service.py +++ b/test/backend/services/test_vectordatabase_service.py @@ -7,6 +7,7 @@ # Mock MinioClient before importing modules that use it from unittest.mock import patch import numpy as np +from types import ModuleType, SimpleNamespace from fastapi.responses import StreamingResponse @@ -22,20 +23,37 @@ sys.modules['boto3'] = boto3_mock # Mock nexent modules before importing modules that use them -nexent_mock = MagicMock() +def _create_package_mock(name: str) -> MagicMock: + pkg = MagicMock() + pkg.__path__ = [] # Mark as package for importlib + pkg.__spec__ = SimpleNamespace(name=name, submodule_search_locations=[]) + return pkg + + +nexent_mock = _create_package_mock('nexent') sys.modules['nexent'] = nexent_mock -sys.modules['nexent.core'] = MagicMock() -sys.modules['nexent.core.agents'] = MagicMock() +sys.modules['nexent.core'] 
= _create_package_mock('nexent.core') +sys.modules['nexent.core.agents'] = _create_package_mock('nexent.core.agents') sys.modules['nexent.core.agents.agent_model'] = MagicMock() -sys.modules['nexent.core.models'] = MagicMock() +sys.modules['nexent.core.models'] = _create_package_mock('nexent.core.models') sys.modules['nexent.core.models.embedding_model'] = MagicMock() sys.modules['nexent.core.models.stt_model'] = MagicMock() -sys.modules['nexent.core.nlp'] = MagicMock() +sys.modules['nexent.core.nlp'] = _create_package_mock('nexent.core.nlp') sys.modules['nexent.core.nlp.tokenizer'] = MagicMock() -sys.modules['nexent.vector_database'] = MagicMock() +sys.modules['nexent.vector_database'] = _create_package_mock('nexent.vector_database') +vector_db_base_module = ModuleType('nexent.vector_database.base') + + +class _VectorDatabaseCore: + """Lightweight stand-in for the real VectorDatabaseCore for import-time typing.""" + pass + + +vector_db_base_module.VectorDatabaseCore = _VectorDatabaseCore +sys.modules['nexent.vector_database.base'] = vector_db_base_module sys.modules['nexent.vector_database.elasticsearch_core'] = MagicMock() # Mock nexent.storage module and its submodules before any imports -sys.modules['nexent.storage'] = MagicMock() +sys.modules['nexent.storage'] = _create_package_mock('nexent.storage') storage_factory_module = MagicMock() storage_config_module = MagicMock() # Create mock classes/functions that will be imported @@ -78,17 +96,17 @@ # Apply the patches before importing the module being tested with patch('botocore.client.BaseClient._make_api_call'), \ patch('elasticsearch.Elasticsearch', return_value=MagicMock()): - from backend.services.elasticsearch_service import ElasticSearchService, check_knowledge_base_exist_impl + from backend.services.vectordatabase_service import ElasticSearchService, check_knowledge_base_exist_impl -def _accurate_search_impl(request, es_core): +def _accurate_search_impl(request, vdb_core): start_time = time.time() if not 
request.query or not request.query.strip(): raise Exception("Search query cannot be empty") if not request.index_names: raise Exception("At least one index name is required") - results = es_core.accurate_search( + results = vdb_core.accurate_search( index_names=request.index_names, query=request.query, top_k=request.top_k @@ -103,9 +121,9 @@ def _accurate_search_impl(request, es_core): } -def _semantic_search_impl(request, es_core): +def _semantic_search_impl(request, vdb_core): start_time = time.time() - results = es_core.semantic_search( + results = vdb_core.semantic_search( index_names=request.index_names, query=request.query, top_k=request.top_k @@ -120,9 +138,9 @@ def _semantic_search_impl(request, es_core): } -def _hybrid_search_impl(request, es_core): +def _hybrid_search_impl(request, vdb_core): start_time = time.time() - results = es_core.hybrid_search( + results = vdb_core.hybrid_search( index_names=request.index_names, query=request.query, top_k=request.top_k, @@ -152,13 +170,13 @@ def setUp(self): that will be used across test cases. 
""" self.es_service = ElasticSearchService() - self.mock_es_core = MagicMock() - self.mock_es_core.embedding_model = MagicMock() - self.mock_es_core.embedding_dim = 768 + self.mock_vdb_core = MagicMock() + self.mock_vdb_core.embedding_model = MagicMock() + self.mock_vdb_core.embedding_dim = 768 # Patch get_embedding_model for all tests self.get_embedding_model_patcher = patch( - 'backend.services.elasticsearch_service.get_embedding_model') + 'backend.services.vectordatabase_service.get_embedding_model') self.mock_get_embedding = self.get_embedding_model_patcher.start() self.mock_embedding = MagicMock() self.mock_embedding.embedding_dim = 768 @@ -178,7 +196,7 @@ def tearDown(self): del ElasticSearchService.semantic_search del ElasticSearchService.hybrid_search - @patch('backend.services.elasticsearch_service.create_knowledge_record') + @patch('backend.services.vectordatabase_service.create_knowledge_record') def test_create_index_success(self, mock_create_knowledge): """ Test successful index creation. @@ -190,28 +208,28 @@ def test_create_index_success(self, mock_create_knowledge): 4. 
The method returns a success status """ # Setup - self.mock_es_core.client.indices.exists.return_value = False - self.mock_es_core.create_vector_index.return_value = True + self.mock_vdb_core.check_index_exists.return_value = False + self.mock_vdb_core.create_index.return_value = True mock_create_knowledge.return_value = True # Execute result = ElasticSearchService.create_index( index_name="test_index", embedding_dim=768, - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="test_tenant" # Added explicit tenant_id ) # Assert self.assertEqual(result["status"], "success") - self.mock_es_core.client.indices.exists.assert_called_once_with( - index="test_index") - self.mock_es_core.create_vector_index.assert_called_once_with( + self.mock_vdb_core.check_index_exists.assert_called_once_with( + "test_index") + self.mock_vdb_core.create_index.assert_called_once_with( "test_index", embedding_dim=768) mock_create_knowledge.assert_called_once() - @patch('backend.services.elasticsearch_service.create_knowledge_record') + @patch('backend.services.vectordatabase_service.create_knowledge_record') def test_create_index_already_exists(self, mock_create_knowledge): """ Test index creation when the index already exists. @@ -222,14 +240,14 @@ def test_create_index_already_exists(self, mock_create_knowledge): 3. 
No knowledge record is created """ # Setup - self.mock_es_core.client.indices.exists.return_value = True + self.mock_vdb_core.check_index_exists.return_value = True # Execute and Assert with self.assertRaises(Exception) as context: ElasticSearchService.create_index( index_name="test_index", embedding_dim=768, - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user" ) @@ -237,7 +255,7 @@ def test_create_index_already_exists(self, mock_create_knowledge): self.assertIn("already exists", str(context.exception)) mock_create_knowledge.assert_not_called() - @patch('backend.services.elasticsearch_service.create_knowledge_record') + @patch('backend.services.vectordatabase_service.create_knowledge_record') def test_create_index_failure(self, mock_create_knowledge): """ Test index creation failure. @@ -248,15 +266,15 @@ def test_create_index_failure(self, mock_create_knowledge): 3. No knowledge record is created """ # Setup - self.mock_es_core.client.indices.exists.return_value = False - self.mock_es_core.create_vector_index.return_value = False + self.mock_vdb_core.check_index_exists.return_value = False + self.mock_vdb_core.create_index.return_value = False # Execute and Assert with self.assertRaises(Exception) as context: ElasticSearchService.create_index( index_name="test_index", embedding_dim=768, - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="test_tenant" # Added explicit tenant_id ) @@ -264,7 +282,7 @@ def test_create_index_failure(self, mock_create_knowledge): self.assertIn("Failed to create index", str(context.exception)) mock_create_knowledge.assert_not_called() - @patch('backend.services.elasticsearch_service.delete_knowledge_record') + @patch('backend.services.vectordatabase_service.delete_knowledge_record') def test_delete_index_success(self, mock_delete_knowledge): """ Test successful index deletion. @@ -275,26 +293,26 @@ def test_delete_index_success(self, mock_delete_knowledge): 3. 
The method returns a success status """ # Setup - self.mock_es_core.delete_index.return_value = True + self.mock_vdb_core.delete_index.return_value = True mock_delete_knowledge.return_value = True # Execute async def run_test(): result = await ElasticSearchService.delete_index( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user" ) # Assert self.assertEqual(result["status"], "success") - self.mock_es_core.delete_index.assert_called_once_with( + self.mock_vdb_core.delete_index.assert_called_once_with( "test_index") mock_delete_knowledge.assert_called_once() asyncio.run(run_test()) - @patch('backend.services.elasticsearch_service.delete_knowledge_record') + @patch('backend.services.vectordatabase_service.delete_knowledge_record') def test_delete_index_failure(self, mock_delete_knowledge): """ Test index deletion failure. @@ -304,26 +322,26 @@ def test_delete_index_failure(self, mock_delete_knowledge): 2. The method returns success status if knowledge record deletion succeeds """ # Setup - self.mock_es_core.delete_index.return_value = False + self.mock_vdb_core.delete_index.return_value = False mock_delete_knowledge.return_value = True # Execute async def run_test(): result = await ElasticSearchService.delete_index( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user" ) # Assert self.assertEqual(result["status"], "success") - self.mock_es_core.delete_index.assert_called_once_with( + self.mock_vdb_core.delete_index.assert_called_once_with( "test_index") mock_delete_knowledge.assert_called_once() asyncio.run(run_test()) - @patch('backend.services.elasticsearch_service.delete_knowledge_record') + @patch('backend.services.vectordatabase_service.delete_knowledge_record') def test_delete_index_knowledge_record_failure(self, mock_delete_knowledge): """ Test deletion when the index is deleted but knowledge record deletion fails. 
@@ -334,7 +352,7 @@ def test_delete_index_knowledge_record_failure(self, mock_delete_knowledge): 3. The exception message contains "Error deleting knowledge record" """ # Setup - self.mock_es_core.delete_index.return_value = True + self.mock_vdb_core.delete_index.return_value = True mock_delete_knowledge.return_value = False # Execute and Assert @@ -342,7 +360,7 @@ async def run_test(): with self.assertRaises(Exception) as context: await ElasticSearchService.delete_index( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user" ) @@ -351,7 +369,7 @@ async def run_test(): asyncio.run(run_test()) - @patch('backend.services.elasticsearch_service.get_knowledge_info_by_tenant_id') + @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') def test_list_indices_without_stats(self, mock_get_knowledge): """ Test listing indices without including statistics. @@ -362,7 +380,7 @@ def test_list_indices_without_stats(self, mock_get_knowledge): 3. 
No statistics are requested when include_stats is False """ # Setup - self.mock_es_core.get_user_indices.return_value = ["index1", "index2"] + self.mock_vdb_core.get_user_indices.return_value = ["index1", "index2"] mock_get_knowledge.return_value = [ {"index_name": "index1", "embedding_model_name": "test-model"}, {"index_name": "index2", "embedding_model_name": "test-model"} @@ -374,16 +392,16 @@ def test_list_indices_without_stats(self, mock_get_knowledge): include_stats=False, tenant_id="test_tenant", # Now required parameter user_id="test_user", # New required parameter - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert self.assertEqual(len(result["indices"]), 2) self.assertEqual(result["count"], 2) - self.mock_es_core.get_user_indices.assert_called_once_with("*") + self.mock_vdb_core.get_user_indices.assert_called_once_with("*") mock_get_knowledge.assert_called_once_with(tenant_id="test_tenant") - @patch('backend.services.elasticsearch_service.get_knowledge_info_by_tenant_id') + @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') def test_list_indices_with_stats(self, mock_get_knowledge): """ Test listing indices with statistics included. @@ -394,8 +412,8 @@ def test_list_indices_with_stats(self, mock_get_knowledge): 3. 
Both indices and their stats are included in the response """ # Setup - self.mock_es_core.get_user_indices.return_value = ["index1", "index2"] - self.mock_es_core.get_index_stats.return_value = { + self.mock_vdb_core.get_user_indices.return_value = ["index1", "index2"] + self.mock_vdb_core.get_indices_detail.return_value = { "index1": {"base_info": {"doc_count": 10, "embedding_model": "test-model"}}, "index2": {"base_info": {"doc_count": 20, "embedding_model": "test-model"}} } @@ -410,282 +428,151 @@ def test_list_indices_with_stats(self, mock_get_knowledge): include_stats=True, tenant_id="test_tenant", # Now required parameter user_id="test_user", # New required parameter - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert self.assertEqual(len(result["indices"]), 2) self.assertEqual(result["count"], 2) self.assertEqual(len(result["indices_info"]), 2) - self.mock_es_core.get_user_indices.assert_called_once_with("*") - self.mock_es_core.get_index_stats.assert_called_once_with( + self.mock_vdb_core.get_user_indices.assert_called_once_with("*") + self.mock_vdb_core.get_indices_detail.assert_called_once_with( ["index1", "index2"]) mock_get_knowledge.assert_called_once_with(tenant_id="test_tenant") - def test_get_index_name_success(self): + @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') + @patch('backend.services.vectordatabase_service.delete_knowledge_record') + def test_list_indices_removes_stale_pg_records(self, mock_delete_knowledge, mock_get_info): """ - Test successful retrieval of index information. - - This test verifies that: - 1. Index statistics are correctly retrieved - 2. Index mapping details are correctly retrieved - 3. The response contains both base information and field details - 4. All expected information is present in the response + Test that list_indices deletes PostgreSQL records whose indices are missing in Elasticsearch. 
""" - # Setup - self.mock_es_core.get_index_stats.return_value = { - "test_index": { - "base_info": { - "doc_count": 10, - "unique_sources_count": 5, - "store_size": "1MB", - "process_source": "Test", - "embedding_model": "Test" - }, - "search_performance": {"avg_time": 10} - } - } - self.mock_es_core.get_index_mapping.return_value = { - "test_index": ["field1", "field2"] - } + self.mock_vdb_core.get_user_indices.return_value = ["es_index"] + mock_get_info.return_value = [ + {"index_name": "dangling_index", "embedding_model_name": "model-A"} + ] - # Execute - result = ElasticSearchService.get_index_name( - index_name="test_index", - es_core=self.mock_es_core + result = ElasticSearchService.list_indices( + pattern="*", + include_stats=False, + tenant_id="tenant-1", + user_id="user-1", + vdb_core=self.mock_vdb_core ) - # Assert - self.assertEqual(result["base_info"]["doc_count"], 10) - self.assertEqual(len(result["fields"]), 2) - self.mock_es_core.get_index_stats.assert_called_once_with([ - "test_index"]) - self.mock_es_core.get_index_mapping.assert_called_once_with([ - "test_index"]) - - def test_get_index_name_stats_not_found(self): - """ - Test get_index_name when index stats are not found. - - This test verifies that: - 1. When index stats are not found, appropriate error logging occurs - 2. The method continues execution with empty stats - 3. 
The response contains default values for missing stats - """ - # Setup - self.mock_es_core.get_index_stats.return_value = {} - self.mock_es_core.get_index_mapping.return_value = { - "test_index": ["field1", "field2"] - } - - # Execute - result = ElasticSearchService.get_index_name( - index_name="test_index", - es_core=self.mock_es_core + mock_delete_knowledge.assert_called_once_with( + {"index_name": "dangling_index", "user_id": "user-1"} ) + self.assertEqual(result["indices"], []) + self.assertEqual(result["count"], 0) - # Assert - self.assertEqual(result["base_info"]["doc_count"], 0) - self.assertEqual(result["base_info"]["process_source"], "Unknown") - self.assertEqual(result["base_info"]["embedding_model"], "Unknown") - self.assertEqual(len(result["fields"]), 2) - - def test_get_index_name_mappings_not_found(self): + @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') + def test_list_indices_stats_defaults_when_missing(self, mock_get_info): """ - Test get_index_name when index mappings are not found. - - This test verifies that: - 1. When index mappings are not found, appropriate error logging occurs - 2. The method continues execution with empty fields - 3. The response contains empty fields list + Test list_indices include_stats path when Elasticsearch returns no stats for an index. 
""" - # Setup - self.mock_es_core.get_index_stats.return_value = { - "test_index": { - "base_info": { - "doc_count": 10, - "unique_sources_count": 5, - "store_size": "1MB", - "process_source": "Test", - "embedding_model": "Test" - } - } - } - self.mock_es_core.get_index_mapping.return_value = {} + self.mock_vdb_core.get_user_indices.return_value = ["index1"] + mock_get_info.return_value = [ + {"index_name": "index1", "embedding_model_name": "model-A"} + ] + self.mock_vdb_core.get_indices_detail.return_value = {} - # Execute - result = ElasticSearchService.get_index_name( - index_name="test_index", - es_core=self.mock_es_core + result = ElasticSearchService.list_indices( + pattern="*", + include_stats=True, + tenant_id="tenant-1", + user_id="user-1", + vdb_core=self.mock_vdb_core ) - # Assert - self.assertEqual(result["base_info"]["doc_count"], 10) - self.assertEqual(result["fields"], []) + self.assertEqual(result["indices"], ["index1"]) + self.assertEqual(result["indices_info"][0]["name"], "index1") + self.assertEqual(result["indices_info"][0]["stats"], {}) - def test_get_index_name_no_base_info(self): + @patch('backend.services.vectordatabase_service.update_model_name_by_index_name') + @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') + def test_list_indices_backfills_missing_model_names(self, mock_get_info, mock_update_model): """ - Test get_index_name when base_info is missing from stats. - - This test verifies that: - 1. When base_info is missing, appropriate error logging occurs - 2. The method provides default values for missing base_info - 3. The response contains reasonable defaults + Test that list_indices updates database records when embedding_model_name is missing. 
""" - # Setup - self.mock_es_core.get_index_stats.return_value = { - "test_index": { - "search_performance": {"avg_time": 10} - } - } - self.mock_es_core.get_index_mapping.return_value = { - "test_index": ["field1"] + self.mock_vdb_core.get_user_indices.return_value = ["index1"] + mock_get_info.return_value = [ + {"index_name": "index1", "embedding_model_name": None} + ] + self.mock_vdb_core.get_indices_detail.return_value = { + "index1": {"base_info": {"embedding_model": "text-embedding-ada-002"}} } - # Execute - result = ElasticSearchService.get_index_name( - index_name="test_index", - es_core=self.mock_es_core + result = ElasticSearchService.list_indices( + pattern="*", + include_stats=True, + tenant_id="tenant-1", + user_id="user-1", + vdb_core=self.mock_vdb_core ) - # Assert - self.assertEqual(result["base_info"]["doc_count"], 0) - self.assertEqual(result["base_info"]["process_source"], "Unknown") - self.assertEqual(result["base_info"]["embedding_model"], "Unknown") - # Fix: search_performance should still be preserved even when base_info is missing - self.assertEqual(result["search_performance"], {}) - - def test_get_index_name_elasticsearch_connection_error(self): - """ - Test get_index_name when Elasticsearch connection fails. - - This test verifies that: - 1. When Elasticsearch connection fails (503 error), appropriate exception is raised - 2. The exception message contains "ElasticSearch service unavailable" - 3. The error is properly categorized as a connection issue - """ - # Setup - self.mock_es_core.get_index_stats.side_effect = Exception( - "503 Service Unavailable") - - # Execute and Assert - with self.assertRaises(Exception) as context: - ElasticSearchService.get_index_name( - index_name="test_index", - es_core=self.mock_es_core - ) - - self.assertIn("ElasticSearch service unavailable", - str(context.exception)) - - def test_get_index_name_api_error(self): - """ - Test get_index_name when Elasticsearch API returns an error. 
- - This test verifies that: - 1. When Elasticsearch API error occurs, appropriate exception is raised - 2. The exception message contains "ElasticSearch API error" - 3. The error is properly categorized as an API issue - """ - # Setup - self.mock_es_core.get_index_stats.side_effect = Exception( - "ApiError: Invalid request") - - # Execute and Assert - with self.assertRaises(Exception) as context: - ElasticSearchService.get_index_name( - index_name="test_index", - es_core=self.mock_es_core - ) - - self.assertIn("ElasticSearch API error", str(context.exception)) - - def test_get_index_name_generic_error(self): - """ - Test get_index_name when a generic error occurs. - - This test verifies that: - 1. When a generic error occurs, appropriate exception is raised - 2. The exception message contains "Error getting info for index" - 3. The error is properly categorized as a generic issue - """ - # Setup - self.mock_es_core.get_index_stats.side_effect = Exception( - "Generic error message") - - # Execute and Assert - with self.assertRaises(Exception) as context: - ElasticSearchService.get_index_name( - index_name="test_index", - es_core=self.mock_es_core - ) - - self.assertIn("Error getting info for index", str(context.exception)) + mock_update_model.assert_called_once_with( + "index1", "text-embedding-ada-002", "tenant-1", "user-1" + ) + self.assertEqual(result["count"], 1) + self.assertEqual(result["indices"][0], "index1") - def test_get_index_name_search_phase_execution_exception(self): + @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') + def test_list_indices_stats_surfaces_elasticsearch_errors(self, mock_get_info): """ - Test get_index_name when search_phase_execution_exception occurs. - - This test verifies that: - 1. When search_phase_execution_exception occurs, appropriate exception is raised - 2. The exception message contains "ElasticSearch service unavailable" - 3. 
The error is properly categorized as a connection issue + Test that list_indices propagates Elasticsearch errors while fetching stats. """ - # Setup - self.mock_es_core.get_index_stats.side_effect = Exception( - "search_phase_execution_exception: No shard available") + self.mock_vdb_core.get_user_indices.return_value = ["index1"] + mock_get_info.return_value = [ + {"index_name": "index1", "embedding_model_name": "model-A"} + ] + self.mock_vdb_core.get_indices_detail.side_effect = Exception( + "503 Service Unavailable" + ) - # Execute and Assert with self.assertRaises(Exception) as context: - ElasticSearchService.get_index_name( - index_name="test_index", - es_core=self.mock_es_core + ElasticSearchService.list_indices( + pattern="*", + include_stats=True, + tenant_id="tenant-1", + user_id="user-1", + vdb_core=self.mock_vdb_core ) - self.assertIn("ElasticSearch service unavailable", - str(context.exception)) + self.assertIn("503 Service Unavailable", str(context.exception)) - def test_get_index_name_success_status_200(self): + @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') + def test_list_indices_stats_keeps_non_stat_fields(self, mock_get_info): """ - Test get_index_name method returns status code 200 on success. - - This test verifies that: - 1. The get_index_name method successfully retrieves index information - 2. The response contains the expected data structure - 3. The method completes without raising exceptions, implying a 200 status code + Test that list_indices preserves all stats fields returned by ElasticSearchCore. 
""" - # Setup - self.mock_es_core.get_index_stats.return_value = { - "test_index": { + self.mock_vdb_core.get_user_indices.return_value = ["index1"] + mock_get_info.return_value = [ + {"index_name": "index1", "embedding_model_name": "model-A"} + ] + detailed_stats = { + "index1": { "base_info": { - "doc_count": 15, - "unique_sources_count": 8, - "store_size": "2MB", + "doc_count": 42, "process_source": "Unstructured", - "embedding_model": "text-embedding-ada-002" + "embedding_model": "text-embedding-3-large" }, - "search_performance": {"avg_query_time": 25.5} + "search_performance": {"avg_time": 12.3} } } - self.mock_es_core.get_index_mapping.return_value = { - "test_index": ["title", "content", "path_or_url", "create_time"] - } + self.mock_vdb_core.get_indices_detail.return_value = detailed_stats - # Execute - result = ElasticSearchService.get_index_name( - index_name="test_index", - es_core=self.mock_es_core + result = ElasticSearchService.list_indices( + pattern="*", + include_stats=True, + tenant_id="tenant-1", + user_id="user-1", + vdb_core=self.mock_vdb_core ) - # Assert - self.assertIsInstance(result, dict) # Success response is a dictionary - self.assertIn("base_info", result) - self.assertIn("search_performance", result) - self.assertIn("fields", result) - self.assertEqual(result["base_info"]["doc_count"], 15) - self.assertEqual(len(result["fields"]), 4) + self.assertEqual(len(result["indices_info"]), 1) + self.assertEqual(result["indices_info"][0]["stats"], detailed_stats["index1"]) - def test_index_documents_success(self): + def test_vectorize_documents_success(self): """ Test successful document indexing. @@ -696,8 +583,8 @@ def test_index_documents_success(self): 4. 
Documents with various metadata fields are handled correctly """ # Setup - self.mock_es_core.client.indices.exists.return_value = True - self.mock_es_core.index_documents.return_value = 2 + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 2 mock_embedding_model = MagicMock() mock_embedding_model.model = "test-model" @@ -729,7 +616,7 @@ def test_index_documents_success(self): result = ElasticSearchService.index_documents( index_name="test_index", data=test_data, - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, embedding_model=mock_embedding_model ) @@ -737,9 +624,9 @@ def test_index_documents_success(self): self.assertTrue(result["success"]) self.assertEqual(result["total_indexed"], 2) self.assertEqual(result["total_submitted"], 2) - self.mock_es_core.index_documents.assert_called_once() + self.mock_vdb_core.vectorize_documents.assert_called_once() - def test_index_documents_empty_data(self): + def test_vectorize_documents_empty_data(self): """ Test document indexing with empty data. @@ -756,7 +643,7 @@ def test_index_documents_empty_data(self): result = ElasticSearchService.index_documents( index_name="test_index", data=test_data, - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, embedding_model=mock_embedding_model ) @@ -764,9 +651,9 @@ def test_index_documents_empty_data(self): self.assertTrue(result["success"]) self.assertEqual(result["total_indexed"], 0) self.assertEqual(result["total_submitted"], 0) - self.mock_es_core.index_documents.assert_not_called() + self.mock_vdb_core.vectorize_documents.assert_not_called() - def test_index_documents_create_index(self): + def test_vectorize_documents_create_index(self): """ Test document indexing when the index doesn't exist. @@ -776,9 +663,9 @@ def test_index_documents_create_index(self): 3. 
The response contains the correct status and document counts """ # Setup - self.mock_es_core.client.indices.exists.return_value = False - self.mock_es_core.create_vector_index.return_value = True - self.mock_es_core.index_documents.return_value = 1 + self.mock_vdb_core.check_index_exists.return_value = False + self.mock_vdb_core.create_index.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 1 mock_embedding_model = MagicMock() test_data = [ { @@ -789,12 +676,12 @@ def test_index_documents_create_index(self): ] # Execute - with patch('backend.services.elasticsearch_service.ElasticSearchService.create_index') as mock_create_index: + with patch('backend.services.vectordatabase_service.ElasticSearchService.create_index') as mock_create_index: mock_create_index.return_value = {"status": "success"} result = ElasticSearchService.index_documents( index_name="test_index", data=test_data, - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, embedding_model=mock_embedding_model ) @@ -803,7 +690,7 @@ def test_index_documents_create_index(self): self.assertEqual(result["total_indexed"], 1) mock_create_index.assert_called_once() - def test_index_documents_indexing_error(self): + def test_vectorize_documents_indexing_error(self): """ Test document indexing when an error occurs during indexing. @@ -813,8 +700,8 @@ def test_index_documents_indexing_error(self): 3. 
The exception message contains "Error during indexing" """ # Setup - self.mock_es_core.client.indices.exists.return_value = True - self.mock_es_core.index_documents.side_effect = Exception( + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.side_effect = Exception( "Indexing error") mock_embedding_model = MagicMock() test_data = [ @@ -830,13 +717,13 @@ def test_index_documents_indexing_error(self): ElasticSearchService.index_documents( index_name="test_index", data=test_data, - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, embedding_model=mock_embedding_model ) self.assertIn("Error during indexing", str(context.exception)) - @patch('backend.services.elasticsearch_service.get_all_files_status') + @patch('backend.services.vectordatabase_service.get_all_files_status') def test_list_files_without_chunks(self, mock_get_files_status): """ Test listing files without including document chunks. @@ -848,7 +735,7 @@ def test_list_files_without_chunks(self, mock_get_files_status): 4. 
The status of each file is correctly set (COMPLETED or PROCESSING) """ # Setup - self.mock_es_core.get_file_list_with_details.return_value = [ + self.mock_vdb_core.get_documents_detail.return_value = [ { "path_or_url": "file1", "filename": "file1.txt", @@ -864,7 +751,7 @@ async def run_test(): return await ElasticSearchService.list_files( index_name="test_index", include_chunks=False, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) result = asyncio.run(run_test()) @@ -873,10 +760,10 @@ async def run_test(): self.assertEqual(len(result["files"]), 2) self.assertEqual(result["files"][0]["status"], "COMPLETED") self.assertEqual(result["files"][1]["status"], "PROCESSING") - self.mock_es_core.get_file_list_with_details.assert_called_once_with( + self.mock_vdb_core.get_documents_detail.assert_called_once_with( "test_index") - @patch('backend.services.elasticsearch_service.get_all_files_status') + @patch('backend.services.vectordatabase_service.get_all_files_status') def test_list_files_with_chunks(self, mock_get_files_status): """ Test listing files with document chunks included. @@ -888,7 +775,7 @@ def test_list_files_with_chunks(self, mock_get_files_status): 4. 
The chunk count is correctly calculated """ # Setup - self.mock_es_core.get_file_list_with_details.return_value = [ + self.mock_vdb_core.get_documents_detail.return_value = [ { "path_or_url": "file1", "filename": "file1.txt", @@ -898,7 +785,7 @@ def test_list_files_with_chunks(self, mock_get_files_status): ] mock_get_files_status.return_value = {} - # Mock msearch response + # Mock multi_search response msearch_response = { 'responses': [ { @@ -917,14 +804,14 @@ def test_list_files_with_chunks(self, mock_get_files_status): } ] } - self.mock_es_core.client.msearch.return_value = msearch_response + self.mock_vdb_core.multi_search.return_value = msearch_response # Execute async def run_test(): return await ElasticSearchService.list_files( index_name="test_index", include_chunks=True, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) result = asyncio.run(run_test()) @@ -933,9 +820,9 @@ async def run_test(): self.assertEqual(len(result["files"]), 1) self.assertEqual(len(result["files"][0]["chunks"]), 1) self.assertEqual(result["files"][0]["chunk_count"], 1) - self.mock_es_core.client.msearch.assert_called_once() + self.mock_vdb_core.multi_search.assert_called_once() - @patch('backend.services.elasticsearch_service.get_all_files_status') + @patch('backend.services.vectordatabase_service.get_all_files_status') def test_list_files_msearch_error(self, mock_get_files_status): """ Test listing files when msearch encounters an error. @@ -947,7 +834,7 @@ def test_list_files_msearch_error(self, mock_get_files_status): 4. 
The overall operation doesn't fail due to msearch errors """ # Setup - self.mock_es_core.get_file_list_with_details.return_value = [ + self.mock_vdb_core.get_documents_detail.return_value = [ { "path_or_url": "file1", "filename": "file1.txt", @@ -958,7 +845,7 @@ def test_list_files_msearch_error(self, mock_get_files_status): mock_get_files_status.return_value = {} # Mock msearch error - self.mock_es_core.client.msearch.side_effect = Exception( + self.mock_vdb_core.client.msearch.side_effect = Exception( "MSSearch Error") # Execute @@ -966,7 +853,7 @@ async def run_test(): return await ElasticSearchService.list_files( index_name="test_index", include_chunks=True, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) result = asyncio.run(run_test()) @@ -976,7 +863,7 @@ async def run_test(): self.assertEqual(len(result["files"][0]["chunks"]), 0) self.assertEqual(result["files"][0]["chunk_count"], 0) - @patch('backend.services.elasticsearch_service.delete_file') + @patch('backend.services.vectordatabase_service.delete_file') def test_delete_documents(self, mock_delete_file): """ Test document deletion by path or URL. @@ -986,7 +873,7 @@ def test_delete_documents(self, mock_delete_file): 2. 
The response contains a success status """ # Setup - self.mock_es_core.delete_documents_by_path_or_url.return_value = 5 + self.mock_vdb_core.delete_documents.return_value = 5 # Configure delete_file to return a success response mock_delete_file.return_value = {"success": True, "object_name": "test_path"} @@ -994,14 +881,14 @@ def test_delete_documents(self, mock_delete_file): result = ElasticSearchService.delete_documents( index_name="test_index", path_or_url="test_path", - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert self.assertEqual(result["status"], "success") self.assertEqual(result["deleted_minio"], True) - # Verify that delete_documents_by_path_or_url was called with correct parameters - self.mock_es_core.delete_documents_by_path_or_url.assert_called_once_with( + # Verify that delete_documents was called with correct parameters + self.mock_vdb_core.delete_documents.assert_called_once_with( "test_index", "test_path") # Verify that delete_file was called with the correct path mock_delete_file.assert_called_once_with("test_path") @@ -1022,7 +909,7 @@ def test_accurate_search(self): search_request.query = "test query" search_request.top_k = 10 - self.mock_es_core.accurate_search.return_value = [ + self.mock_vdb_core.accurate_search.return_value = [ { "document": {"title": "Doc1", "content": "Content1"}, "score": 0.95, @@ -1033,14 +920,14 @@ def test_accurate_search(self): # Execute result = ElasticSearchService.accurate_search( request=search_request, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert self.assertEqual(len(result["results"]), 1) self.assertEqual(result["total"], 1) self.assertTrue("query_time_ms" in result) - self.mock_es_core.accurate_search.assert_called_once_with( + self.mock_vdb_core.accurate_search.assert_called_once_with( index_names=["test_index"], query="test query", top_k=10 ) @@ -1063,7 +950,7 @@ def test_accurate_search_empty_query(self): with self.assertRaises(Exception) as context: 
ElasticSearchService.accurate_search( request=search_request, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) self.assertIn("Search query cannot be empty", str(context.exception)) @@ -1087,7 +974,7 @@ def test_accurate_search_no_indices(self): with self.assertRaises(Exception) as context: ElasticSearchService.accurate_search( request=search_request, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) self.assertIn("At least one index name is required", @@ -1109,8 +996,8 @@ def test_semantic_search(self): search_request.query = "test query" search_request.top_k = 10 - # Create a mock response directly on the es_core instance - self.mock_es_core.semantic_search.return_value = [ + # Create a mock response directly on the vdb_core instance + self.mock_vdb_core.semantic_search.return_value = [ { "document": {"title": "Doc1", "content": "Content1"}, "score": 0.85, @@ -1121,14 +1008,14 @@ def test_semantic_search(self): # Execute result = ElasticSearchService.semantic_search( request=search_request, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert self.assertEqual(len(result["results"]), 1) self.assertEqual(result["total"], 1) self.assertTrue("query_time_ms" in result) - self.mock_es_core.semantic_search.assert_called_once_with( + self.mock_vdb_core.semantic_search.assert_called_once_with( index_names=["test_index"], query="test query", top_k=10 ) @@ -1149,8 +1036,8 @@ def test_hybrid_search(self): search_request.top_k = 10 search_request.weight_accurate = 0.5 - # Create a mock response directly on the es_core instance - self.mock_es_core.hybrid_search.return_value = [ + # Create a mock response directly on the vdb_core instance + self.mock_vdb_core.hybrid_search.return_value = [ { "document": {"title": "Doc1", "content": "Content1"}, "score": 0.90, @@ -1162,7 +1049,7 @@ def test_hybrid_search(self): # Execute result = ElasticSearchService.hybrid_search( request=search_request, - es_core=self.mock_es_core + 
vdb_core=self.mock_vdb_core ) # Assert @@ -1173,7 +1060,7 @@ def test_hybrid_search(self): ["score_details"]["accurate"], 0.85) self.assertEqual(result["results"][0] ["score_details"]["semantic"], 0.95) - self.mock_es_core.hybrid_search.assert_called_once_with( + self.mock_vdb_core.hybrid_search.assert_called_once_with( index_names=["test_index"], query="test query", top_k=10, weight_accurate=0.5 ) @@ -1187,10 +1074,10 @@ def test_health_check_healthy(self): 3. The health_check method returns without raising exceptions """ # Setup - self.mock_es_core.get_user_indices.return_value = ["index1", "index2"] + self.mock_vdb_core.get_user_indices.return_value = ["index1", "index2"] # Execute - result = ElasticSearchService.health_check(es_core=self.mock_es_core) + result = ElasticSearchService.health_check(vdb_core=self.mock_vdb_core) # Assert self.assertEqual(result["status"], "healthy") @@ -1207,17 +1094,17 @@ def test_health_check_unhealthy(self): 3. The exception message contains "Health check failed" """ # Setup - self.mock_es_core.get_user_indices.side_effect = Exception( + self.mock_vdb_core.get_user_indices.side_effect = Exception( "Connection error") # Execute and Assert with self.assertRaises(Exception) as context: - ElasticSearchService.health_check(es_core=self.mock_es_core) + ElasticSearchService.health_check(vdb_core=self.mock_vdb_core) self.assertIn("Health check failed", str(context.exception)) - @patch('backend.services.elasticsearch_service.calculate_term_weights') + @patch('backend.services.vectordatabase_service.calculate_term_weights') @patch('database.model_management_db.get_model_by_model_id') def test_summary_index_name(self, mock_get_model_by_model_id, mock_calculate_weights): """ @@ -1265,7 +1152,7 @@ async def run_test(): result = await self.es_service.summary_index_name( index_name="test_index", batch_size=1000, - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, language='en', model_id=1, tenant_id="test_tenant" @@ -1289,6 +1176,277 @@ 
async def run_test(): # Basic functionality test - just verify the response is correct type # The detailed function calls are tested in their own unit tests + def test_summary_index_name_no_tenant_id(self): + """ + Test summary_index_name raises exception when tenant_id is missing. + + This test verifies that: + 1. An exception is raised when tenant_id is None + 2. The exception message contains "Tenant ID is required" + """ + # Execute and Assert + async def run_test(): + with self.assertRaises(Exception) as context: + await self.es_service.summary_index_name( + index_name="test_index", + batch_size=1000, + vdb_core=self.mock_vdb_core, + language='en', + model_id=1, + tenant_id=None # Missing tenant_id + ) + self.assertIn("Tenant ID is required", str(context.exception)) + + asyncio.run(run_test()) + + def test_summary_index_name_no_documents(self): + """ + Test summary_index_name when no documents are found in index. + + This test verifies that: + 1. An exception is raised when document_samples is empty + 2. 
The exception message contains "No documents found in index" + """ + # Mock the new Map-Reduce functions + with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge: + + # Mock return empty document_samples + mock_process_docs.return_value = ( + {}, # Empty document_samples + {} # Empty doc_embeddings + ) + + # Execute + async def run_test(): + with self.assertRaises(Exception) as context: + result = await self.es_service.summary_index_name( + index_name="test_index", + batch_size=1000, + vdb_core=self.mock_vdb_core, + language='en', + model_id=1, + tenant_id="test_tenant" + ) + # Consume the stream to trigger execution + generator = result.body_iterator + async for item in generator: + break + + self.assertIn("No documents found in index", str(context.exception)) + + asyncio.run(run_test()) + + def test_summary_index_name_runtime_error_fallback(self): + """ + Test summary_index_name fallback when get_running_loop raises RuntimeError. + + This test verifies that: + 1. When get_running_loop() raises RuntimeError, get_event_loop() is used as fallback + 2. 
The summary generation still works correctly + """ + # Mock the new Map-Reduce functions + with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge: + + # Mock return values + mock_process_docs.return_value = ( + {"doc1": {"chunks": [{"content": "test content"}]}}, # document_samples + {"doc1": np.array([0.1, 0.2, 0.3])} # doc_embeddings + ) + mock_cluster.return_value = {"doc1": 0} # clusters + mock_summarize.return_value = {0: "Test cluster summary"} # cluster_summaries + mock_merge.return_value = "Final merged summary" # final_summary + + # Create a mock loop with run_in_executor that returns a coroutine + mock_loop = MagicMock() + async def mock_run_in_executor(executor, func, *args): + # Execute the function synchronously and return its result + return func() + mock_loop.run_in_executor = mock_run_in_executor + + # Patch asyncio functions to trigger RuntimeError fallback + with patch('backend.services.vectordatabase_service.asyncio.get_running_loop', side_effect=RuntimeError("No running event loop")), \ + patch('backend.services.vectordatabase_service.asyncio.get_event_loop', return_value=mock_loop) as mock_get_event_loop: + + # Execute + async def run_test(): + result = await self.es_service.summary_index_name( + index_name="test_index", + batch_size=1000, + vdb_core=self.mock_vdb_core, + language='en', + model_id=1, + tenant_id="test_tenant" + ) + + # Consume part of the stream to trigger execution + generator = result.body_iterator + try: + async for item in generator: + break + except StopAsyncIteration: + pass + + return result + + result = asyncio.run(run_test()) + + # Assert + self.assertIsInstance(result, StreamingResponse) + # Verify fallback was used + 
mock_get_event_loop.assert_called() + + def test_summary_index_name_generator_exception(self): + """ + Test summary_index_name handles exceptions in the generator function. + + This test verifies that: + 1. Exceptions in the generator are caught and streamed as error messages + 2. The error status is properly formatted + """ + # Mock the new Map-Reduce functions + with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge: + + # Mock return values + mock_process_docs.return_value = ( + {"doc1": {"chunks": [{"content": "test content"}]}}, # document_samples + {"doc1": np.array([0.1, 0.2, 0.3])} # doc_embeddings + ) + mock_cluster.return_value = {"doc1": 0} # clusters + mock_summarize.return_value = {0: "Test cluster summary"} # cluster_summaries + mock_merge.return_value = "Final merged summary" # final_summary + + # Execute + async def run_test(): + result = await self.es_service.summary_index_name( + index_name="test_index", + batch_size=1000, + vdb_core=self.mock_vdb_core, + language='en', + model_id=1, + tenant_id="test_tenant" + ) + + # Consume the stream completely + generator = result.body_iterator + items = [] + try: + async for item in generator: + items.append(item) + except Exception: + pass + + return result, items + + result, items = asyncio.run(run_test()) + + # Assert + self.assertIsInstance(result, StreamingResponse) + # Verify that items were generated (at least the completed message) + self.assertGreater(len(items), 0) + + def test_summary_index_name_sample_count_calculation(self): + """ + Test summary_index_name correctly calculates sample_count from batch_size. + + This test verifies that: + 1. sample_count is calculated as min(batch_size // 5, 200) + 2. 
The sample_doc_count parameter is passed correctly to process_documents_for_clustering + """ + # Test with batch_size=1000 -> sample_count should be min(200, 200) = 200 + with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge: + + # Mock return values + mock_process_docs.return_value = ( + {"doc1": {"chunks": [{"content": "test content"}]}}, # document_samples + {"doc1": np.array([0.1, 0.2, 0.3])} # doc_embeddings + ) + mock_cluster.return_value = {"doc1": 0} # clusters + mock_summarize.return_value = {0: "Test cluster summary"} # cluster_summaries + mock_merge.return_value = "Final merged summary" # final_summary + + # Execute with batch_size=1000 + async def run_test(): + result = await self.es_service.summary_index_name( + index_name="test_index", + batch_size=1000, + vdb_core=self.mock_vdb_core, + language='en', + model_id=1, + tenant_id="test_tenant" + ) + + # Consume part of the stream to trigger execution + generator = result.body_iterator + try: + async for item in generator: + break + except StopAsyncIteration: + pass + + return result + + asyncio.run(run_test()) + + # Verify sample_doc_count was called with 200 (min(1000 // 5, 200) = 200) + self.assertTrue(mock_process_docs.called) + call_args = mock_process_docs.call_args + self.assertEqual(call_args.kwargs['sample_doc_count'], 200) + + # Test with batch_size=50 -> sample_count should be min(10, 200) = 10 + with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + 
patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge: + + # Mock return values + mock_process_docs.return_value = ( + {"doc1": {"chunks": [{"content": "test content"}]}}, + {"doc1": np.array([0.1, 0.2, 0.3])} + ) + mock_cluster.return_value = {"doc1": 0} + mock_summarize.return_value = {0: "Test cluster summary"} + mock_merge.return_value = "Final merged summary" + + # Execute with batch_size=50 + async def run_test_small(): + result = await self.es_service.summary_index_name( + index_name="test_index", + batch_size=50, + vdb_core=self.mock_vdb_core, + language='en', + model_id=1, + tenant_id="test_tenant" + ) + + # Consume part of the stream to trigger execution + generator = result.body_iterator + try: + async for item in generator: + break + except StopAsyncIteration: + pass + + return result + + asyncio.run(run_test_small()) + + # Verify sample_doc_count was called with 10 (min(50 // 5, 200) = 10) + self.assertTrue(mock_process_docs.called) + call_args = mock_process_docs.call_args + self.assertEqual(call_args.kwargs['sample_doc_count'], 10) + def test_get_random_documents(self): """ Test retrieving random documents from an index. @@ -1299,8 +1457,7 @@ def test_get_random_documents(self): 3. 
The response contains both the total count and the sampled documents """ # Setup - count_response = {'count': 100} - self.mock_es_core.client.count.return_value = count_response + self.mock_vdb_core.count_documents.return_value = 100 search_response = { 'hits': { @@ -1316,23 +1473,22 @@ def test_get_random_documents(self): ] } } - self.mock_es_core.client.search.return_value = search_response + self.mock_vdb_core.search.return_value = search_response # Execute result = ElasticSearchService.get_random_documents( index_name="test_index", batch_size=10, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert self.assertEqual(result["total"], 100) self.assertEqual(len(result["documents"]), 2) - self.mock_es_core.client.count.assert_called_once_with( - index="test_index") - self.mock_es_core.client.search.assert_called_once() + self.mock_vdb_core.count_documents.assert_called_once_with("test_index") + self.mock_vdb_core.search.assert_called_once() - @patch('backend.services.elasticsearch_service.update_knowledge_record') + @patch('backend.services.vectordatabase_service.update_knowledge_record') def test_change_summary(self, mock_update_record): """ Test changing the summary of a knowledge base. @@ -1357,7 +1513,7 @@ def test_change_summary(self, mock_update_record): self.assertEqual(result["summary"], "Test summary") mock_update_record.assert_called_once() - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_get_summary(self, mock_get_record): """ Test retrieving the summary of a knowledge base. 
@@ -1380,7 +1536,7 @@ def test_get_summary(self, mock_get_record): self.assertEqual(result["summary"], "Test summary") mock_get_record.assert_called_once_with({'index_name': 'test_index'}) - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_get_summary_not_found(self, mock_get_record): """ Test retrieving a summary when the knowledge record doesn't exist. @@ -1399,7 +1555,70 @@ def test_get_summary_not_found(self, mock_get_record): self.assertIn("Unable to get summary", str(context.exception)) - @patch('backend.services.elasticsearch_service.get_knowledge_info_by_tenant_id') + def test_get_index_chunks_filters_fields(self): + """ + Test chunk retrieval filters unsupported fields and reports totals. + """ + self.mock_vdb_core.get_index_chunks.return_value = { + "chunks": [ + {"id": "1", "content": "A", "path_or_url": "/a", "extra": "ignore"}, + {"content": "B", "create_time": "2024-01-01T00:00:00"} + ], + "total": 2, + "page": None, + "page_size": None, + } + + result = ElasticSearchService.get_index_chunks( + index_name="kb-index", + vdb_core=self.mock_vdb_core + ) + + self.assertEqual(result["status"], "success") + self.assertEqual(result["total"], 2) + self.assertEqual(result["chunks"][0], {"id": "1", "content": "A", "path_or_url": "/a"}) + self.assertEqual(result["chunks"][1], {"content": "B", "create_time": "2024-01-01T00:00:00"}) + self.mock_vdb_core.get_index_chunks.assert_called_once_with( + "kb-index", + page=None, + page_size=None, + path_or_url=None, + ) + + def test_get_index_chunks_keeps_non_dict_entries(self): + """ + Test chunk retrieval keeps non-dict entries unchanged. 
+ """ + self.mock_vdb_core.get_index_chunks.return_value = { + "chunks": ["raw_chunk"], + "total": 1, + "page": 1, + "page_size": 1, + } + + result = ElasticSearchService.get_index_chunks( + index_name="kb-index", + vdb_core=self.mock_vdb_core + ) + + self.assertEqual(result["chunks"], ["raw_chunk"]) + self.assertEqual(result["total"], 1) + + def test_get_index_chunks_error(self): + """ + Test chunk retrieval error handling. + """ + self.mock_vdb_core.get_index_chunks.side_effect = Exception("boom") + + with self.assertRaises(Exception) as exc: + ElasticSearchService.get_index_chunks( + index_name="kb-index", + vdb_core=self.mock_vdb_core + ) + + self.assertIn("Error retrieving chunks from index kb-index: boom", str(exc.exception)) + + @patch('backend.services.vectordatabase_service.get_knowledge_info_by_tenant_id') @patch('fastapi.Response') def test_list_indices_success_status_200(self, mock_response, mock_get_knowledge): """ @@ -1411,7 +1630,7 @@ def test_list_indices_success_status_200(self, mock_response, mock_get_knowledge 3. 
The method completes without raising exceptions, implying a 200 status code """ # Setup - self.mock_es_core.get_user_indices.return_value = ["index1", "index2"] + self.mock_vdb_core.get_user_indices.return_value = ["index1", "index2"] mock_response.status_code = 200 mock_get_knowledge.return_value = [ {"index_name": "index1", "embedding_model_name": "test-model"}, @@ -1424,7 +1643,7 @@ def test_list_indices_success_status_200(self, mock_response, mock_get_knowledge include_stats=False, tenant_id="test_tenant", # Now required parameter user_id="test_user", # New required parameter - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert @@ -1432,7 +1651,7 @@ def test_list_indices_success_status_200(self, mock_response, mock_get_knowledge self.assertEqual(result["count"], 2) # Verify no exception is raised, implying 200 status code self.assertIsInstance(result, dict) # Success response is a dictionary - self.mock_es_core.get_user_indices.assert_called_once_with("*") + self.mock_vdb_core.get_user_indices.assert_called_once_with("*") mock_get_knowledge.assert_called_once_with(tenant_id="test_tenant") def test_health_check_success_status_200(self): @@ -1445,10 +1664,10 @@ def test_health_check_success_status_200(self): 3. The method completes without raising exceptions, implying a 200 status code """ # Setup - self.mock_es_core.get_user_indices.return_value = ["index1", "index2"] + self.mock_vdb_core.get_user_indices.return_value = ["index1", "index2"] # Execute - result = ElasticSearchService.health_check(es_core=self.mock_es_core) + result = ElasticSearchService.health_check(vdb_core=self.mock_vdb_core) # Assert self.assertEqual(result["status"], "healthy") @@ -1466,8 +1685,7 @@ def test_get_random_documents_success_status_200(self): 3. 
The method completes without raising exceptions, implying a 200 status code """ # Setup - count_response = {'count': 100} - self.mock_es_core.client.count.return_value = count_response + self.mock_vdb_core.count_documents.return_value = 100 search_response = { 'hits': { @@ -1479,13 +1697,13 @@ def test_get_random_documents_success_status_200(self): ] } } - self.mock_es_core.client.search.return_value = search_response + self.mock_vdb_core.search.return_value = search_response # Execute result = ElasticSearchService.get_random_documents( index_name="test_index", batch_size=10, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert @@ -1511,7 +1729,7 @@ def test_semantic_search_success_status_200(self): search_request.query = "valid query" search_request.top_k = 10 - self.mock_es_core.semantic_search.return_value = [ + self.mock_vdb_core.semantic_search.return_value = [ { "document": {"title": "Doc1", "content": "Content1"}, "score": 0.85, @@ -1522,7 +1740,7 @@ def test_semantic_search_success_status_200(self): # Execute result = ElasticSearchService.semantic_search( request=search_request, - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert @@ -1532,22 +1750,22 @@ def test_semantic_search_success_status_200(self): self.assertIn("results", result) self.assertIn("total", result) self.assertIn("query_time_ms", result) - self.mock_es_core.semantic_search.assert_called_once_with( + self.mock_vdb_core.semantic_search.assert_called_once_with( index_names=["test_index"], query="valid query", top_k=10 ) - def test_index_documents_success_status_200(self): + def test_vectorize_documents_success_status_200(self): """ - Test index_documents method returns status code 200 on success. + Test vectorize_documents method returns status code 200 on success. This test verifies that: - 1. The index_documents method successfully indexes multiple documents + 1. The vectorize_documents method successfully indexes multiple documents 2. 
The response indicates success and correct document counts 3. The method completes without raising exceptions, implying a 200 status code """ # Setup - self.mock_es_core.client.indices.exists.return_value = True - self.mock_es_core.index_documents.return_value = 3 + self.mock_vdb_core.check_index_exists.return_value = True + self.mock_vdb_core.vectorize_documents.return_value = 3 mock_embedding_model = MagicMock() mock_embedding_model.model = "test-model" @@ -1573,7 +1791,7 @@ def test_index_documents_success_status_200(self): result = ElasticSearchService.index_documents( index_name="test_index", data=test_data, - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, embedding_model=mock_embedding_model ) @@ -1586,7 +1804,7 @@ def test_index_documents_success_status_200(self): self.assertIn("success", result) self.assertTrue(result["success"]) - @patch('backend.services.elasticsearch_service.delete_file') + @patch('backend.services.vectordatabase_service.delete_file') def test_delete_documents_success_status_200(self, mock_delete_file): """ Test delete_documents method returns status code 200 on success. @@ -1597,7 +1815,7 @@ def test_delete_documents_success_status_200(self, mock_delete_file): 3. 
The method completes without raising exceptions, implying a 200 status code """ # Setup - self.mock_es_core.delete_documents_by_path_or_url.return_value = 5 + self.mock_vdb_core.delete_documents.return_value = 5 # Configure delete_file to return a success response mock_delete_file.return_value = {"success": True, "object_name": "test_path"} @@ -1605,7 +1823,7 @@ def test_delete_documents_success_status_200(self, mock_delete_file): result = ElasticSearchService.delete_documents( index_name="test_index", path_or_url="test_path", - es_core=self.mock_es_core + vdb_core=self.mock_vdb_core ) # Assert @@ -1613,13 +1831,13 @@ def test_delete_documents_success_status_200(self, mock_delete_file): self.assertIsInstance(result, dict) self.assertEqual(result["status"], "success") self.assertEqual(result["deleted_minio"], True) - # Verify that delete_documents_by_path_or_url was called with correct parameters - self.mock_es_core.delete_documents_by_path_or_url.assert_called_once_with( + # Verify that delete_documents was called with correct parameters + self.mock_vdb_core.delete_documents.assert_called_once_with( "test_index", "test_path") # Verify that delete_file was called with the correct path mock_delete_file.assert_called_once_with("test_path") - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_get_summary_success_status_200(self, mock_get_record): """ Test get_summary method returns status code 200 on success. 
@@ -1646,12 +1864,12 @@ def test_get_summary_success_status_200(self, mock_get_record): self.assertEqual(result["status"], "success") mock_get_record.assert_called_once_with({'index_name': 'test_index'}) - @patch('backend.services.elasticsearch_service.get_redis_service') - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_redis_service') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_check_kb_exist_orphan_in_es(self, mock_get_knowledge, mock_get_redis_service): """Test handling of orphaned knowledge base existing only in Elasticsearch.""" # Setup: ES index exists, PG record missing - self.mock_es_core.client.indices.exists.return_value = True + self.mock_vdb_core.check_index_exists.return_value = True mock_get_knowledge.return_value = None # Mock Redis service @@ -1663,24 +1881,24 @@ def test_check_kb_exist_orphan_in_es(self, mock_get_knowledge, mock_get_redis_se # Execute result = check_knowledge_base_exist_impl( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="tenant1" ) # Assert - self.mock_es_core.delete_index.assert_called_once_with("test_index") + self.mock_vdb_core.delete_index.assert_called_once_with("test_index") mock_redis_service.delete_knowledgebase_records.assert_called_once_with( "test_index") self.assertEqual(result["status"], "error_cleaning_orphans") self.assertEqual(result["action"], "cleaned_es") - @patch('backend.services.elasticsearch_service.delete_knowledge_record') - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.delete_knowledge_record') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_check_kb_exist_orphan_in_pg(self, mock_get_knowledge, mock_delete_record): """Test handling of orphaned knowledge base existing only in PostgreSQL.""" # Setup: ES index missing, 
PG record exists - self.mock_es_core.client.indices.exists.return_value = False + self.mock_vdb_core.check_index_exists.return_value = False mock_get_knowledge.return_value = { "index_name": "test_index", "tenant_id": "tenant1"} mock_delete_record.return_value = True @@ -1688,7 +1906,7 @@ def test_check_kb_exist_orphan_in_pg(self, mock_get_knowledge, mock_delete_recor # Execute result = check_knowledge_base_exist_impl( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="tenant1" ) @@ -1698,17 +1916,17 @@ def test_check_kb_exist_orphan_in_pg(self, mock_get_knowledge, mock_delete_recor self.assertEqual(result["status"], "error_cleaning_orphans") self.assertEqual(result["action"], "cleaned_pg") - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_check_kb_exist_available(self, mock_get_knowledge): """Test knowledge base name availability when neither ES nor PG has the record.""" # Setup: ES index missing, PG record missing - self.mock_es_core.client.indices.exists.return_value = False + self.mock_vdb_core.check_index_exists.return_value = False mock_get_knowledge.return_value = None # Execute result = check_knowledge_base_exist_impl( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="tenant1" ) @@ -1716,18 +1934,18 @@ def test_check_kb_exist_available(self, mock_get_knowledge): # Assert self.assertEqual(result["status"], "available") - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_check_kb_exist_exists_in_tenant(self, mock_get_knowledge): """Test detection when knowledge base exists within the same tenant.""" # Setup: ES index exists, PG record exists with same tenant_id - self.mock_es_core.client.indices.exists.return_value = True + 
self.mock_vdb_core.check_index_exists.return_value = True mock_get_knowledge.return_value = { "index_name": "test_index", "tenant_id": "tenant1"} # Execute result = check_knowledge_base_exist_impl( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="tenant1" ) @@ -1735,18 +1953,18 @@ def test_check_kb_exist_exists_in_tenant(self, mock_get_knowledge): # Assert self.assertEqual(result["status"], "exists_in_tenant") - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_check_kb_exist_exists_in_other_tenant(self, mock_get_knowledge): """Test detection when knowledge base exists in a different tenant.""" # Setup: ES index exists, PG record exists with different tenant_id - self.mock_es_core.client.indices.exists.return_value = True + self.mock_vdb_core.check_index_exists.return_value = True mock_get_knowledge.return_value = { "index_name": "test_index", "tenant_id": "other_tenant"} # Execute result = check_knowledge_base_exist_impl( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="tenant1" ) @@ -1754,12 +1972,12 @@ def test_check_kb_exist_exists_in_other_tenant(self, mock_get_knowledge): # Assert self.assertEqual(result["status"], "exists_in_other_tenant") - @patch('backend.services.elasticsearch_service.get_redis_service') - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_redis_service') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_check_kb_exist_orphan_in_es_redis_failure(self, mock_get_knowledge, mock_get_redis_service): """Test orphan ES case when Redis cleanup raises an exception.""" # Setup: ES index exists, PG record missing - self.mock_es_core.client.indices.exists.return_value = True + 
self.mock_vdb_core.check_index_exists.return_value = True mock_get_knowledge.return_value = None # Mock Redis service that raises an exception @@ -1771,46 +1989,46 @@ def test_check_kb_exist_orphan_in_es_redis_failure(self, mock_get_knowledge, moc # Execute result = check_knowledge_base_exist_impl( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="tenant1" ) # Assert: ES index deletion attempted, Redis cleanup attempted and exception handled - self.mock_es_core.delete_index.assert_called_once_with("test_index") + self.mock_vdb_core.delete_index.assert_called_once_with("test_index") mock_redis_service.delete_knowledgebase_records.assert_called_once_with( "test_index") self.assertEqual(result["status"], "error_cleaning_orphans") self.assertEqual(result["action"], "cleaned_es") - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_check_kb_exist_orphan_in_es_delete_failure(self, mock_get_knowledge): """Test failure when deleting orphan ES index raises an exception.""" # Setup: ES index exists, PG record missing, delete_index raises - self.mock_es_core.client.indices.exists.return_value = True + self.mock_vdb_core.check_index_exists.return_value = True mock_get_knowledge.return_value = None - self.mock_es_core.delete_index.side_effect = Exception( + self.mock_vdb_core.delete_index.side_effect = Exception( "Delete index failed") # Execute result = check_knowledge_base_exist_impl( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="tenant1" ) # Assert - self.mock_es_core.delete_index.assert_called_once_with("test_index") + self.mock_vdb_core.delete_index.assert_called_once_with("test_index") self.assertEqual(result["status"], "error_cleaning_orphans") self.assertTrue(result.get("error")) - 
@patch('backend.services.elasticsearch_service.delete_knowledge_record') - @patch('backend.services.elasticsearch_service.get_knowledge_record') + @patch('backend.services.vectordatabase_service.delete_knowledge_record') + @patch('backend.services.vectordatabase_service.get_knowledge_record') def test_check_kb_exist_orphan_in_pg_delete_failure(self, mock_get_knowledge, mock_delete_record): """Test failure when deleting orphan PG record raises an exception.""" # Setup: ES index missing, PG record exists, deletion raises - self.mock_es_core.client.indices.exists.return_value = False + self.mock_vdb_core.check_index_exists.return_value = False mock_get_knowledge.return_value = { "index_name": "test_index", "tenant_id": "tenant1"} mock_delete_record.side_effect = Exception("Delete PG record failed") @@ -1818,7 +2036,7 @@ def test_check_kb_exist_orphan_in_pg_delete_failure(self, mock_get_knowledge, mo # Execute result = check_knowledge_base_exist_impl( index_name="test_index", - es_core=self.mock_es_core, + vdb_core=self.mock_vdb_core, user_id="test_user", tenant_id="tenant1" ) @@ -1831,25 +2049,25 @@ def test_check_kb_exist_orphan_in_pg_delete_failure(self, mock_get_knowledge, mo # Note: generate_knowledge_summary_stream function has been removed # These tests are no longer relevant as the function was replaced with summary_index_name - def test_get_es_core(self): + def test_get_vdb_core(self): """ - Test get_es_core function returns the elastic_core instance. + Test get_vdb_core function returns the elastic_core instance. This test verifies that: - 1. The get_es_core function returns the correct elastic_core instance + 1. The get_vdb_core function returns the correct elastic_core instance 2. 
The function is properly imported and accessible """ - from backend.services.elasticsearch_service import get_es_core + from backend.services.vectordatabase_service import get_vector_db_core # Execute - result = get_es_core() + result = get_vector_db_core() # Assert self.assertIsNotNone(result) # The result should be the elastic_core instance self.assertTrue(hasattr(result, 'client')) - @patch('backend.services.elasticsearch_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.tenant_config_manager') def test_get_embedding_model_embedding_type(self, mock_tenant_config_manager): """ Test get_embedding_model with embedding model type. @@ -1872,32 +2090,31 @@ def test_get_embedding_model_embedding_type(self, mock_tenant_config_manager): self.get_embedding_model_patcher.stop() try: - with patch('backend.services.elasticsearch_service.OpenAICompatibleEmbedding') as mock_embedding_class: + with patch('backend.services.vectordatabase_service.OpenAICompatibleEmbedding') as mock_embedding_class, \ + patch('backend.services.vectordatabase_service.get_model_name_from_config') as mock_get_model_name: mock_embedding_instance = MagicMock() mock_embedding_class.return_value = mock_embedding_instance - - with patch('backend.services.elasticsearch_service.get_model_name_from_config') as mock_get_model_name: - mock_get_model_name.return_value = "test-model" - - # Execute - now we can call the real function - from backend.services.elasticsearch_service import get_embedding_model - result = get_embedding_model("test_tenant") - - # Assert - self.assertEqual(result, mock_embedding_instance) - mock_tenant_config_manager.get_model_config.assert_called_once_with( - key="EMBEDDING_ID", tenant_id="test_tenant") - mock_embedding_class.assert_called_once_with( - api_key="test_api_key", - base_url="https://test.api.com", - model_name="test-model", - embedding_dim=1024 - ) + mock_get_model_name.return_value = "test-model" + + # Execute - now we can call the real 
function + from backend.services.vectordatabase_service import get_embedding_model + result = get_embedding_model("test_tenant") + + # Assert + self.assertEqual(result, mock_embedding_instance) + mock_tenant_config_manager.get_model_config.assert_called_once_with( + key="EMBEDDING_ID", tenant_id="test_tenant") + mock_embedding_class.assert_called_once_with( + api_key="test_api_key", + base_url="https://test.api.com", + model_name="test-model", + embedding_dim=1024 + ) finally: # Restart the mock for other tests self.get_embedding_model_patcher.start() - @patch('backend.services.elasticsearch_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.tenant_config_manager') def test_get_embedding_model_multi_embedding_type(self, mock_tenant_config_manager): """ Test get_embedding_model with multi_embedding model type. @@ -1920,32 +2137,31 @@ def test_get_embedding_model_multi_embedding_type(self, mock_tenant_config_manag self.get_embedding_model_patcher.stop() try: - with patch('backend.services.elasticsearch_service.JinaEmbedding') as mock_embedding_class: + with patch('backend.services.vectordatabase_service.JinaEmbedding') as mock_embedding_class, \ + patch('backend.services.vectordatabase_service.get_model_name_from_config') as mock_get_model_name: mock_embedding_instance = MagicMock() mock_embedding_class.return_value = mock_embedding_instance - - with patch('backend.services.elasticsearch_service.get_model_name_from_config') as mock_get_model_name: - mock_get_model_name.return_value = "test-model" - - # Execute - now we can call the real function - from backend.services.elasticsearch_service import get_embedding_model - result = get_embedding_model("test_tenant") - - # Assert - self.assertEqual(result, mock_embedding_instance) - mock_tenant_config_manager.get_model_config.assert_called_once_with( - key="EMBEDDING_ID", tenant_id="test_tenant") - mock_embedding_class.assert_called_once_with( - api_key="test_api_key", - 
base_url="https://test.api.com", - model_name="test-model", - embedding_dim=2048 - ) + mock_get_model_name.return_value = "test-model" + + # Execute - now we can call the real function + from backend.services.vectordatabase_service import get_embedding_model + result = get_embedding_model("test_tenant") + + # Assert + self.assertEqual(result, mock_embedding_instance) + mock_tenant_config_manager.get_model_config.assert_called_once_with( + key="EMBEDDING_ID", tenant_id="test_tenant") + mock_embedding_class.assert_called_once_with( + api_key="test_api_key", + base_url="https://test.api.com", + model_name="test-model", + embedding_dim=2048 + ) finally: # Restart the mock for other tests self.get_embedding_model_patcher.start() - @patch('backend.services.elasticsearch_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.tenant_config_manager') def test_get_embedding_model_unknown_type(self, mock_tenant_config_manager): """ Test get_embedding_model with unknown model type. @@ -1969,7 +2185,7 @@ def test_get_embedding_model_unknown_type(self, mock_tenant_config_manager): try: # Execute - now we can call the real function - from backend.services.elasticsearch_service import get_embedding_model + from backend.services.vectordatabase_service import get_embedding_model result = get_embedding_model("test_tenant") # Assert @@ -1980,7 +2196,7 @@ def test_get_embedding_model_unknown_type(self, mock_tenant_config_manager): # Restart the mock for other tests self.get_embedding_model_patcher.start() - @patch('backend.services.elasticsearch_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.tenant_config_manager') def test_get_embedding_model_empty_type(self, mock_tenant_config_manager): """ Test get_embedding_model with empty model type. 
@@ -2004,7 +2220,7 @@ def test_get_embedding_model_empty_type(self, mock_tenant_config_manager): try: # Execute - now we can call the real function - from backend.services.elasticsearch_service import get_embedding_model + from backend.services.vectordatabase_service import get_embedding_model result = get_embedding_model("test_tenant") # Assert @@ -2015,7 +2231,7 @@ def test_get_embedding_model_empty_type(self, mock_tenant_config_manager): # Restart the mock for other tests self.get_embedding_model_patcher.start() - @patch('backend.services.elasticsearch_service.tenant_config_manager') + @patch('backend.services.vectordatabase_service.tenant_config_manager') def test_get_embedding_model_missing_type(self, mock_tenant_config_manager): """ Test get_embedding_model with missing model type. @@ -2038,7 +2254,7 @@ def test_get_embedding_model_missing_type(self, mock_tenant_config_manager): try: # Execute - now we can call the real function - from backend.services.elasticsearch_service import get_embedding_model + from backend.services.vectordatabase_service import get_embedding_model result = get_embedding_model("test_tenant") # Assert @@ -2052,3 +2268,4 @@ def test_get_embedding_model_missing_type(self, mock_tenant_config_manager): if __name__ == '__main__': unittest.main() + diff --git a/test/backend/test_config_service.py b/test/backend/test_config_service.py new file mode 100644 index 000000000..77f318b10 --- /dev/null +++ b/test/backend/test_config_service.py @@ -0,0 +1,399 @@ +import os +import sys +import asyncio +import types +from unittest.mock import patch, MagicMock, AsyncMock + +import pytest + +# Dynamically determine the backend path - MUST BE FIRST +current_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) +sys.path.insert(0, backend_dir) + +# Patch environment variables before any imports that might use them +os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000') 
+os.environ.setdefault('MINIO_ACCESS_KEY', 'minioadmin') +os.environ.setdefault('MINIO_SECRET_KEY', 'minioadmin') +os.environ.setdefault('MINIO_REGION', 'us-east-1') +os.environ.setdefault('MINIO_DEFAULT_BUCKET', 'test-bucket') + +# Mock boto3 before importing the module under test +boto3_mock = MagicMock() +minio_client_mock = MagicMock() +sys.modules['boto3'] = boto3_mock + +# Mock nexent modules before importing modules that use them +nexent_mock = MagicMock() +sys.modules['nexent'] = nexent_mock +sys.modules['nexent.core'] = MagicMock() +sys.modules['nexent.core.models'] = MagicMock() +sys.modules['nexent.core.models.embedding_model'] = MagicMock() +sys.modules['nexent.core.nlp'] = MagicMock() +sys.modules['nexent.core.nlp.tokenizer'] = MagicMock() +sys.modules['nexent.vector_database'] = MagicMock() +sys.modules['nexent.vector_database.elasticsearch_core'] = MagicMock() +sys.modules['nexent.core.agents'] = MagicMock() +sys.modules['nexent.core.agents.agent_model'] = MagicMock() +sys.modules['nexent.storage.storage_client_factory'] = MagicMock() + +# Patch storage factory and MinIO config validation to avoid errors during initialization +# These patches must be started before any imports that use MinioClient +storage_client_mock = MagicMock() +patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() +patch('nexent.storage.minio_config.MinIOStorageConfig.validate', lambda self: None).start() +patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() + +# Create stub vector database modules to satisfy imports +vector_db_module = types.ModuleType("nexent.vector_database") +vector_db_module.__path__ = [] # Mark as package +vector_db_base_module = types.ModuleType("nexent.vector_database.base") + +class MockVectorDatabaseCore: + def __init__(self, *args, **kwargs): + pass + +vector_db_base_module.VectorDatabaseCore = MockVectorDatabaseCore + +vector_db_es_module = 
types.ModuleType("nexent.vector_database.elasticsearch_core") + +class MockElasticSearchCore: + def __init__(self, *args, **kwargs): + pass + +vector_db_es_module.ElasticSearchCore = MockElasticSearchCore + +sys.modules['nexent.vector_database'] = vector_db_module +sys.modules['nexent.vector_database.base'] = vector_db_base_module +sys.modules['nexent.vector_database.elasticsearch_core'] = vector_db_es_module +setattr(vector_db_module, "base", vector_db_base_module) +setattr(vector_db_module, "elasticsearch_core", vector_db_es_module) + +# Pre-inject a stubbed base_app to avoid import side effects +backend_pkg = types.ModuleType("backend") +apps_pkg = types.ModuleType("backend.apps") +base_app_mod = types.ModuleType("backend.apps.base_app") +base_app_mod.app = MagicMock() + +# Install stubs into sys.modules +sys.modules.setdefault("backend", backend_pkg) +sys.modules["backend.apps"] = apps_pkg +sys.modules["backend.apps.base_app"] = base_app_mod + +# Also stub non-namespaced imports used by the application +apps_pkg_flat = types.ModuleType("apps") +base_app_mod_flat = types.ModuleType("apps.config_app") +base_app_mod_flat.app = MagicMock() +sys.modules["apps"] = apps_pkg_flat +sys.modules["apps.config_app"] = base_app_mod_flat +setattr(apps_pkg_flat, "config_app", base_app_mod_flat) + +# Wire package attributes +setattr(backend_pkg, "apps", apps_pkg) +setattr(apps_pkg, "config_app", base_app_mod) + +# Mock external dependencies before importing backend modules +with patch('backend.database.client.MinioClient', return_value=minio_client_mock), \ + patch('elasticsearch.Elasticsearch', return_value=MagicMock()), \ + patch('nexent.vector_database.elasticsearch_core.ElasticSearchCore', return_value=MagicMock()): + # Mock dotenv before importing config_service + with patch('dotenv.load_dotenv'): + # Mock logging configuration + with patch('utils.logging_utils.configure_logging'), \ + patch('utils.logging_utils.configure_elasticsearch_logging'): + from config_service 
import startup_initialization + + +class TestMainService: + """Test cases for config_service module""" + + @pytest.mark.asyncio + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_success(self, mock_logger, mock_initialize_tools): + """ + Test successful startup initialization. + + This test verifies that: + 1. The function logs the start of initialization + 2. It logs the APP version + 3. It calls initialize_tools_on_startup + 4. It logs successful completion + """ + # Setup + mock_initialize_tools.return_value = None + + # Execute + await startup_initialization() + + # Assert + # Check that appropriate log messages were called + mock_logger.info.assert_any_call("Starting server initialization...") + mock_logger.info.assert_any_call( + "Server initialization completed successfully!") + + # Verify initialize_tools_on_startup was called + mock_initialize_tools.assert_called_once() + + @pytest.mark.asyncio + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_with_version_log(self, mock_logger, mock_initialize_tools): + """ + Test that startup initialization logs the APP version. + + This test verifies that: + 1. 
The function logs the APP version from consts.const + """ + # Setup + mock_initialize_tools.return_value = None + + # Execute + await startup_initialization() + + # Assert + # Check that version logging was called (should contain "APP version is:") + version_logged = any( + call for call in mock_logger.info.call_args_list + if len(call.args) > 0 and "APP version is:" in str(call.args[0]) + ) + assert version_logged, "APP version should be logged during initialization" + + @pytest.mark.asyncio + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_tool_initialization_failure(self, mock_logger, mock_initialize_tools): + """ + Test startup initialization when tool initialization fails. + + This test verifies that: + 1. When initialize_tools_on_startup raises an exception + 2. The function catches the exception and logs an error + 3. The function logs a warning about continuing despite issues + 4. The function does not re-raise the exception + """ + # Setup + mock_initialize_tools.side_effect = Exception( + "Tool initialization failed") + + # Execute - should not raise exception + await startup_initialization() + + # Assert + mock_logger.error.assert_called_once() + error_call = mock_logger.error.call_args[0][0] + assert "Server initialization failed:" in error_call + assert "Tool initialization failed" in error_call + + mock_logger.warning.assert_called_once_with( + "Server will continue to start despite initialization issues" + ) + + # Verify initialize_tools_on_startup was called + mock_initialize_tools.assert_called_once() + + @pytest.mark.asyncio + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_database_error(self, mock_logger, mock_initialize_tools): + """ + Test startup initialization when database connection fails. + + This test verifies that: + 1. 
Database-related exceptions are handled gracefully + 2. Appropriate error messages are logged + 3. The server startup is not blocked + """ + # Setup + mock_initialize_tools.side_effect = ConnectionError( + "Database connection failed") + + # Execute - should not raise exception + await startup_initialization() + + # Assert + mock_logger.error.assert_called_once() + error_message = mock_logger.error.call_args[0][0] + assert "Server initialization failed:" in error_message + assert "Database connection failed" in error_message + + mock_logger.warning.assert_called_once_with( + "Server will continue to start despite initialization issues" + ) + + @pytest.mark.asyncio + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_timeout_error(self, mock_logger, mock_initialize_tools): + """ + Test startup initialization when tool initialization times out. + + This test verifies that: + 1. Timeout exceptions are handled gracefully + 2. Appropriate error messages are logged + 3. The function continues execution + """ + # Setup + mock_initialize_tools.side_effect = asyncio.TimeoutError( + "Tool initialization timed out") + + # Execute - should not raise exception + await startup_initialization() + + # Assert + mock_logger.error.assert_called_once() + error_message = mock_logger.error.call_args[0][0] + assert "Server initialization failed:" in error_message + + mock_logger.warning.assert_called_once_with( + "Server will continue to start despite initialization issues" + ) + + @pytest.mark.asyncio + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_multiple_calls_safe(self, mock_logger, mock_initialize_tools): + """ + Test that multiple calls to startup_initialization are safe. + + This test verifies that: + 1. The function can be called multiple times without issues + 2. 
Each call properly executes the initialization sequence + """ + # Setup + mock_initialize_tools.return_value = None + + # Execute multiple times + await startup_initialization() + await startup_initialization() + + # Assert + assert mock_initialize_tools.call_count == 2 + # At least 2 calls * 2 info messages per call + assert mock_logger.info.call_count >= 4 + + @pytest.mark.asyncio + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_logging_order(self, mock_logger, mock_initialize_tools): + """ + Test that logging occurs in the correct order during initialization. + + This test verifies that: + 1. Start message is logged first + 2. Version message is logged second + 3. Success message is logged last (when successful) + """ + # Setup + mock_initialize_tools.return_value = None + + # Execute + await startup_initialization() + + # Assert + info_calls = [call.args[0] for call in mock_logger.info.call_args_list] + + # Check order of log messages + assert len(info_calls) >= 3 + assert "Starting server initialization..." in info_calls[0] + assert "APP version is:" in info_calls[1] + assert "Server initialization completed successfully!" in info_calls[-1] + + @pytest.mark.asyncio + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_exception_details_logged(self, mock_logger, mock_initialize_tools): + """ + Test that exception details are properly logged. + + This test verifies that: + 1. The specific exception message is included in error logs + 2. 
Both error and warning messages are logged on failure + """ + # Setup + specific_error_message = "Specific tool configuration error occurred" + mock_initialize_tools.side_effect = ValueError(specific_error_message) + + # Execute + await startup_initialization() + + # Assert + mock_logger.error.assert_called_once() + error_call_args = mock_logger.error.call_args[0][0] + assert specific_error_message in error_call_args + assert "Server initialization failed:" in error_call_args + + mock_logger.warning.assert_called_once() + + @pytest.mark.asyncio + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_no_exception_propagation(self, mock_logger, mock_initialize_tools): + """ + Test that exceptions during initialization do not propagate. + + This test verifies that: + 1. Even when initialize_tools_on_startup fails, no exception is raised + 2. This allows the server to continue starting up + """ + # Setup + mock_initialize_tools.side_effect = RuntimeError( + "Critical initialization error") + + # Execute and Assert - should not raise any exception + try: + await startup_initialization() + except Exception as e: + pytest.fail( + f"startup_initialization should not raise exceptions, but raised: {e}") + + # Verify that error handling occurred + mock_logger.error.assert_called_once() + mock_logger.warning.assert_called_once() + + +class TestMainServiceModuleIntegration: + """Integration tests for config_service module dependencies""" + + @patch('config_service.configure_logging') + @patch('config_service.configure_elasticsearch_logging') + def test_logging_configuration_called_on_import(self, mock_configure_es, mock_configure_logging): + """ + Test that logging configuration functions are called when module is imported. + + This test verifies that: + 1. configure_logging is called with logging.INFO + 2. 
configure_elasticsearch_logging is called + """ + # Note: This test checks that logging configuration happens during module import + # The mocks should have been called when the module was imported + # In a real scenario, you might need to reload the module to test this properly + pass # The actual verification would depend on how the test runner handles imports + + @patch('config_service.APP_VERSION', 'test_version_1.2.3') + @patch('config_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('config_service.logger') + async def test_startup_initialization_with_custom_version(self, mock_logger, mock_initialize_tools): + """ + Test startup initialization with a custom APP_VERSION. + + This test verifies that: + 1. The custom version is properly logged + """ + # Setup + mock_initialize_tools.return_value = None + + # Execute + await startup_initialization() + + # Assert + version_logged = any( + "test_version_1.2.3" in str(call.args[0]) + for call in mock_logger.info.call_args_list + if len(call.args) > 0 + ) + assert version_logged, "Custom APP version should be logged" + + +if __name__ == '__main__': + pytest.main() diff --git a/test/backend/test_document_vector_utils.py b/test/backend/test_document_vector_utils.py index 8f6d0c5ba..f03ed3346 100644 --- a/test/backend/test_document_vector_utils.py +++ b/test/backend/test_document_vector_utils.py @@ -339,9 +339,9 @@ class TestGetDocumentsFromEs: """Test ES document retrieval""" def test_get_documents_from_es_mock(self): - """Test ES document retrieval with mocked client""" - mock_es_core = MagicMock() - mock_es_core.client.search.return_value = { + """Test ES document retrieval with mocked VectorDatabaseCore search""" + mock_vdb_core = MagicMock() + mock_vdb_core.search.return_value = { 'hits': { 'hits': [ { @@ -367,17 +367,17 @@ def test_get_documents_from_es_mock(self): } } - result = get_documents_from_es('test_index', mock_es_core, sample_doc_count=10) + result = get_documents_from_es( + 
'test_index', mock_vdb_core, sample_doc_count=10) assert isinstance(result, dict) - # The function returns a dict with document IDs as keys, not 'documents' key assert len(result) > 0 # Check that we have document data first_doc = list(result.values())[0] assert 'chunks' in first_doc # Verify that sort parameter is included in the query - call_args = mock_es_core.client.search.call_args + call_args = mock_vdb_core.search.call_args if call_args: query_body = call_args[1].get('body') or call_args[0][1] if len(call_args[0]) > 1 else None if query_body and 'sort' in query_body: @@ -392,8 +392,8 @@ class TestProcessDocumentsForClustering: def test_process_documents_for_clustering_mock(self): """Test document processing with mocked functions""" - mock_es_core = MagicMock() - mock_es_core.client.search.return_value = { + mock_vdb_core = MagicMock() + mock_vdb_core.client.search.return_value = { 'hits': { 'hits': [ { @@ -422,7 +422,7 @@ def test_process_documents_for_clustering_mock(self): mock_calc_embedding.return_value = np.array([1.0, 2.0, 3.0]) documents, embeddings = process_documents_for_clustering( - 'test_index', mock_es_core, sample_doc_count=10 + 'test_index', mock_vdb_core, sample_doc_count=10 ) assert isinstance(documents, dict) diff --git a/test/backend/test_document_vector_utils_coverage.py b/test/backend/test_document_vector_utils_coverage.py index eabcfe697..b442e47e4 100644 --- a/test/backend/test_document_vector_utils_coverage.py +++ b/test/backend/test_document_vector_utils_coverage.py @@ -39,8 +39,8 @@ class TestGetDocumentsFromES: def test_get_documents_from_es_success(self): """Test successful document retrieval from ES""" - mock_es_core = MagicMock() - mock_es_core.client.search.return_value = { + mock_vdb_core = MagicMock() + mock_vdb_core.search.return_value = { 'aggregations': { 'unique_documents': { 'buckets': [ @@ -63,14 +63,14 @@ def test_get_documents_from_es_success(self): } } - result = get_documents_from_es('test_index', mock_es_core, 
sample_doc_count=10) + result = get_documents_from_es('test_index', mock_vdb_core, sample_doc_count=10) assert isinstance(result, dict) - assert mock_es_core.client.search.called + assert mock_vdb_core.search.called def test_get_documents_from_es_empty(self): """Test ES retrieval with no documents""" - mock_es_core = MagicMock() - mock_es_core.client.search.return_value = { + mock_vdb_core = MagicMock() + mock_vdb_core.search.return_value = { 'aggregations': { 'unique_documents': { 'buckets': [] @@ -78,16 +78,16 @@ def test_get_documents_from_es_empty(self): } } - result = get_documents_from_es('test_index', mock_es_core) + result = get_documents_from_es('test_index', mock_vdb_core) assert result == {} def test_get_documents_from_es_error(self): """Test ES retrieval error handling""" - mock_es_core = MagicMock() - mock_es_core.client.search.side_effect = Exception("ES error") + mock_vdb_core = MagicMock() + mock_vdb_core.search.side_effect = Exception("ES error") with pytest.raises(Exception, match="Failed to retrieve documents from Elasticsearch"): - get_documents_from_es('test_index', mock_es_core) + get_documents_from_es('test_index', mock_vdb_core) class TestProcessDocumentsForClustering: @@ -105,8 +105,8 @@ def test_process_documents_success(self, mock_calc_emb, mock_get_docs): } mock_calc_emb.return_value = np.array([0.1, 0.2, 0.3]) - mock_es_core = MagicMock() - docs, embeddings = process_documents_for_clustering('test_index', mock_es_core) + mock_vdb_core = MagicMock() + docs, embeddings = process_documents_for_clustering('test_index', mock_vdb_core) assert isinstance(docs, dict) assert isinstance(embeddings, dict) @@ -118,8 +118,8 @@ def test_process_documents_empty(self, mock_get_docs): """Test processing with no documents""" mock_get_docs.return_value = {} - mock_es_core = MagicMock() - docs, embeddings = process_documents_for_clustering('test_index', mock_es_core) + mock_vdb_core = MagicMock() + docs, embeddings = 
process_documents_for_clustering('test_index', mock_vdb_core) assert docs == {} assert embeddings == {} @@ -318,10 +318,10 @@ class TestAdditionalCoverage: def test_get_documents_from_es_non_list_documents(self): """Test ES retrieval when all_documents is not a list""" - mock_es_core = MagicMock() + mock_vdb_core = MagicMock() # Mock the first search call to return a tuple instead of list - mock_es_core.client.search.side_effect = [ + mock_vdb_core.client.search.side_effect = [ { 'aggregations': { 'unique_documents': { @@ -347,13 +347,13 @@ def test_get_documents_from_es_non_list_documents(self): } ] - result = get_documents_from_es('test_index', mock_es_core) + result = get_documents_from_es('test_index', mock_vdb_core) assert isinstance(result, dict) def test_get_documents_from_es_no_chunks(self): """Test ES retrieval when document has no chunks""" - mock_es_core = MagicMock() - mock_es_core.client.search.side_effect = [ + mock_vdb_core = MagicMock() + mock_vdb_core.client.search.side_effect = [ { 'aggregations': { 'unique_documents': { @@ -370,7 +370,7 @@ def test_get_documents_from_es_no_chunks(self): } ] - result = get_documents_from_es('test_index', mock_es_core) + result = get_documents_from_es('test_index', mock_vdb_core) assert result == {} # Should return empty dict when no chunks def test_calculate_document_embedding_exception(self): @@ -428,16 +428,16 @@ def test_kmeans_cluster_documents_exception(self): def test_process_documents_for_clustering_exception(self): """Test process_documents_for_clustering with exception""" - mock_es_core = MagicMock() - mock_es_core.client.search.side_effect = Exception("ES error") + mock_vdb_core = MagicMock() + mock_vdb_core.search.side_effect = Exception("ES error") with pytest.raises(Exception, match="Failed to process documents"): - process_documents_for_clustering('test_index', mock_es_core) + process_documents_for_clustering('test_index', mock_vdb_core) def test_process_documents_for_clustering_no_embeddings(self): 
"""Test process_documents_for_clustering when some documents fail embedding calculation""" - mock_es_core = MagicMock() - mock_es_core.client.search.return_value = { + mock_vdb_core = MagicMock() + mock_vdb_core.search.return_value = { 'aggregations': { 'unique_documents': { 'buckets': [ @@ -463,7 +463,7 @@ def test_process_documents_for_clustering_no_embeddings(self): with patch('backend.utils.document_vector_utils.calculate_document_embedding') as mock_calc: mock_calc.return_value = None - docs, embeddings = process_documents_for_clustering('test_index', mock_es_core) + docs, embeddings = process_documents_for_clustering('test_index', mock_vdb_core) assert isinstance(docs, dict) assert isinstance(embeddings, dict) assert len(embeddings) == 0 # No successful embeddings diff --git a/test/backend/test_main_service.py b/test/backend/test_runtime_service.py similarity index 82% rename from test/backend/test_main_service.py rename to test/backend/test_runtime_service.py index 1939b967e..cfe77604d 100644 --- a/test/backend/test_main_service.py +++ b/test/backend/test_runtime_service.py @@ -31,8 +31,32 @@ sys.modules['nexent.core.models.embedding_model'] = MagicMock() sys.modules['nexent.core.nlp'] = MagicMock() sys.modules['nexent.core.nlp.tokenizer'] = MagicMock() -sys.modules['nexent.vector_database'] = MagicMock() -sys.modules['nexent.vector_database.elasticsearch_core'] = MagicMock() + +# Create stub vector database modules to satisfy imports +vector_db_module = types.ModuleType("nexent.vector_database") +vector_db_module.__path__ = [] # Mark as package +vector_db_base_module = types.ModuleType("nexent.vector_database.base") + +class MockVectorDatabaseCore: + def __init__(self, *args, **kwargs): + pass + +vector_db_base_module.VectorDatabaseCore = MockVectorDatabaseCore + +vector_db_es_module = types.ModuleType("nexent.vector_database.elasticsearch_core") + +class MockElasticSearchCore: + def __init__(self, *args, **kwargs): + pass + 
+vector_db_es_module.ElasticSearchCore = MockElasticSearchCore + +sys.modules['nexent.vector_database'] = vector_db_module +sys.modules['nexent.vector_database.base'] = vector_db_base_module +sys.modules['nexent.vector_database.elasticsearch_core'] = vector_db_es_module +setattr(vector_db_module, "base", vector_db_base_module) +setattr(vector_db_module, "elasticsearch_core", vector_db_es_module) + sys.modules['nexent.core.agents'] = MagicMock() sys.modules['nexent.core.agents.agent_model'] = MagicMock() sys.modules['nexent.storage.storage_client_factory'] = MagicMock() @@ -57,33 +81,33 @@ # Also stub non-namespaced imports used by the application apps_pkg_flat = types.ModuleType("apps") -base_app_mod_flat = types.ModuleType("apps.base_app") +base_app_mod_flat = types.ModuleType("apps.runtime_app") base_app_mod_flat.app = MagicMock() sys.modules["apps"] = apps_pkg_flat -sys.modules["apps.base_app"] = base_app_mod_flat -setattr(apps_pkg_flat, "base_app", base_app_mod_flat) +sys.modules["apps.runtime_app"] = base_app_mod_flat +setattr(apps_pkg_flat, "runtime_app", base_app_mod_flat) # Wire package attributes setattr(backend_pkg, "apps", apps_pkg) -setattr(apps_pkg, "base_app", base_app_mod) +setattr(apps_pkg, "runtime_app", base_app_mod) # Mock external dependencies before importing backend modules with patch('elasticsearch.Elasticsearch', return_value=MagicMock()), \ patch('nexent.vector_database.elasticsearch_core.ElasticSearchCore', return_value=MagicMock()): - # Mock dotenv before importing main_service + # Mock dotenv before importing runtime_service with patch('dotenv.load_dotenv'): # Mock logging configuration with patch('utils.logging_utils.configure_logging'), \ patch('utils.logging_utils.configure_elasticsearch_logging'): - from main_service import startup_initialization + from runtime_service import startup_initialization class TestMainService: - """Test cases for main_service module""" + """Test cases for runtime_service module""" @pytest.mark.asyncio - 
@patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_success(self, mock_logger, mock_initialize_tools): """ Test successful startup initialization. @@ -110,8 +134,8 @@ async def test_startup_initialization_success(self, mock_logger, mock_initialize mock_initialize_tools.assert_called_once() @pytest.mark.asyncio - @patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_with_version_log(self, mock_logger, mock_initialize_tools): """ Test that startup initialization logs the APP version. @@ -134,8 +158,8 @@ async def test_startup_initialization_with_version_log(self, mock_logger, mock_i assert version_logged, "APP version should be logged during initialization" @pytest.mark.asyncio - @patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_tool_initialization_failure(self, mock_logger, mock_initialize_tools): """ Test startup initialization when tool initialization fails. 
@@ -167,8 +191,8 @@ async def test_startup_initialization_tool_initialization_failure(self, mock_log mock_initialize_tools.assert_called_once() @pytest.mark.asyncio - @patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_database_error(self, mock_logger, mock_initialize_tools): """ Test startup initialization when database connection fails. @@ -196,8 +220,8 @@ async def test_startup_initialization_database_error(self, mock_logger, mock_ini ) @pytest.mark.asyncio - @patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_timeout_error(self, mock_logger, mock_initialize_tools): """ Test startup initialization when tool initialization times out. @@ -224,8 +248,8 @@ async def test_startup_initialization_timeout_error(self, mock_logger, mock_init ) @pytest.mark.asyncio - @patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_multiple_calls_safe(self, mock_logger, mock_initialize_tools): """ Test that multiple calls to startup_initialization are safe. 
@@ -247,8 +271,8 @@ async def test_startup_initialization_multiple_calls_safe(self, mock_logger, moc assert mock_logger.info.call_count >= 4 @pytest.mark.asyncio - @patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_logging_order(self, mock_logger, mock_initialize_tools): """ Test that logging occurs in the correct order during initialization. @@ -274,8 +298,8 @@ async def test_startup_initialization_logging_order(self, mock_logger, mock_init assert "Server initialization completed successfully!" in info_calls[-1] @pytest.mark.asyncio - @patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_exception_details_logged(self, mock_logger, mock_initialize_tools): """ Test that exception details are properly logged. @@ -300,8 +324,8 @@ async def test_startup_initialization_exception_details_logged(self, mock_logger mock_logger.warning.assert_called_once() @pytest.mark.asyncio - @patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_no_exception_propagation(self, mock_logger, mock_initialize_tools): """ Test that exceptions during initialization do not propagate. 
@@ -327,10 +351,10 @@ async def test_startup_initialization_no_exception_propagation(self, mock_logger class TestMainServiceModuleIntegration: - """Integration tests for main_service module dependencies""" + """Integration tests for runtime_service module dependencies""" - @patch('main_service.configure_logging') - @patch('main_service.configure_elasticsearch_logging') + @patch('runtime_service.configure_logging') + @patch('runtime_service.configure_elasticsearch_logging') def test_logging_configuration_called_on_import(self, mock_configure_es, mock_configure_logging): """ Test that logging configuration functions are called when module is imported. @@ -344,9 +368,9 @@ def test_logging_configuration_called_on_import(self, mock_configure_es, mock_co # In a real scenario, you might need to reload the module to test this properly pass # The actual verification would depend on how the test runner handles imports - @patch('main_service.APP_VERSION', 'test_version_1.2.3') - @patch('main_service.initialize_tools_on_startup', new_callable=AsyncMock) - @patch('main_service.logger') + @patch('runtime_service.APP_VERSION', 'test_version_1.2.3') + @patch('runtime_service.initialize_tools_on_startup', new_callable=AsyncMock) + @patch('runtime_service.logger') async def test_startup_initialization_with_custom_version(self, mock_logger, mock_initialize_tools): """ Test startup initialization with a custom APP_VERSION. 
diff --git a/test/sdk/core/agents/test_core_agent.py b/test/sdk/core/agents/test_core_agent.py index 3027da3c1..cb6240893 100644 --- a/test/sdk/core/agents/test_core_agent.py +++ b/test/sdk/core/agents/test_core_agent.py @@ -231,15 +231,15 @@ def test_run_with_final_answer_error_and_model_output(core_agent_instance): # Create a mock action step with model_output mock_action_step = MagicMock() - mock_action_step.model_output = "```\nprint('hello')\n```" + mock_action_step.model_output = "```\nprint('hello')\n```" # Mock _execute_step to set model_output and then raise FinalAnswerError def mock_execute_step(action_step): - action_step.model_output = "```\nprint('hello')\n```" + action_step.model_output = "```\nprint('hello')\n```" raise core_agent_module.FinalAnswerError() with patch.object(core_agent_instance, '_execute_step', side_effect=mock_execute_step), \ - patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ + patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ patch.object(core_agent_instance, '_finalize_step'): # Execute result = list(core_agent_instance._run_stream(task, max_steps)) @@ -250,7 +250,7 @@ def mock_execute_step(action_step): assert isinstance(result[1], MagicMock) # Final answer step # Verify convert_code_format was called mock_convert.assert_called_once_with( - "```\nprint('hello')\n```") + "```\nprint('hello')\n```") def test_run_with_agent_error_updated(core_agent_instance): @@ -282,11 +282,11 @@ def test_run_with_agent_parse_error_branch_updated(core_agent_instance): # Mock _execute_step to set model_output and then raise FinalAnswerError def mock_execute_step(action_step): - action_step.model_output = "```\nprint('hello')\n```" + action_step.model_output = "```\nprint('hello')\n```" raise core_agent_module.FinalAnswerError() with patch.object(core_agent_instance, '_execute_step', 
side_effect=mock_execute_step), \ - patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ + patch.object(core_agent_module, 'convert_code_format', return_value="```python\nprint('hello')\n```") as mock_convert, \ patch.object(core_agent_instance, '_finalize_step'): results = list(core_agent_instance._run_stream(task, max_steps)) @@ -296,7 +296,7 @@ def mock_execute_step(action_step): assert isinstance(results[1], MagicMock) # Final answer step # Verify convert_code_format was called mock_convert.assert_called_once_with( - "```\nprint('hello')\n```") + "```\nprint('hello')\n```") def test_convert_code_format_display_replacements(): @@ -305,7 +305,7 @@ def test_convert_code_format_display_replacements(): original_text = """Here is code: ``` print('hello') -``` +``` And some more text.""" expected_text = """Here is code: @@ -392,19 +392,19 @@ def test_parse_code_blobs_python_match(): assert result == expected -def test_parse_code_blobs_display_format_ignored(): - """Test parse_code_blobs ignores ```\ncontent\n``` pattern.""" +def test_parse_code_blobs_display_format_raises_value_error(): + """Test parse_code_blobs raises ValueError when only DISPLAY code blocks are present.""" text = """Here is some code: ``` def hello(): return "Hello" -``` +``` And some more text.""" - # This should raise ValueError because parse_code_blobs only handles format + # This should raise ValueError when only DISPLAY code blocks are found (no executable code) with pytest.raises(ValueError) as exc_info: core_agent_module.parse_code_blobs(text) - + assert "executable code block pattern" in str(exc_info.value) @@ -599,6 +599,32 @@ def test_step_stream_parse_success(core_agent_instance): assert hasattr(mock_memory_step.tool_calls[0], 'arguments') +def test_step_stream_skips_execution_for_display_only(core_agent_instance): + """Test that _step_stream raises FinalAnswerError when only DISPLAY code blocks are present.""" + # Setup 
+ mock_memory_step = MagicMock() + mock_chat_message = MagicMock() + mock_chat_message.content = "```\nprint('hello')\n```" + + # Set all required attributes on the instance + core_agent_instance.agent_name = "test_agent" + core_agent_instance.step_number = 1 + core_agent_instance.grammar = None + core_agent_instance.logger = MagicMock() + core_agent_instance.memory = MagicMock() + core_agent_instance.memory.steps = [] + + # Mock parse_code_blobs to raise ValueError (no executable code found) + with patch.object(core_agent_module, 'parse_code_blobs', side_effect=ValueError("No executable code found")): + # Mock the methods directly on the instance + core_agent_instance.write_memory_to_messages = MagicMock(return_value=[]) + core_agent_instance.model = MagicMock(return_value=mock_chat_message) + + # Execute and assert that FinalAnswerError is raised + with pytest.raises(core_agent_module.FinalAnswerError): + list(core_agent_instance._step_stream(mock_memory_step)) + + def test_step_stream_parse_failure_raises_final_answer_error(core_agent_instance): """Test _step_stream method when parsing fails and raises FinalAnswerError.""" # Setup diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 5f747f89f..a186d80bb 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -86,6 +86,15 @@ class _TestCoreAgent: "openai.types.chat.chat_completion_message_param": MagicMock(), # Mock exa_py to avoid importing the real package when sdk.nexent.core.tools imports it "exa_py": MagicMock(Exa=MagicMock()), + # Mock paramiko and cryptography to avoid PyO3 import issues in tests + "paramiko": MagicMock(), + "cryptography": MagicMock(), + "cryptography.hazmat": MagicMock(), + "cryptography.hazmat.primitives": MagicMock(), + "cryptography.hazmat.primitives.ciphers": MagicMock(), + "cryptography.hazmat.primitives.ciphers.base": MagicMock(), + "cryptography.hazmat.bindings": MagicMock(), + 
"cryptography.hazmat.bindings._rust": MagicMock(), # Mock the OpenAIModel import "sdk.nexent.core.models.openai_llm": MagicMock(OpenAIModel=mock_openai_model_class), # Mock CoreAgent import @@ -100,6 +109,7 @@ class _TestCoreAgent: # --------------------------------------------------------------------------- with patch.dict("sys.modules", module_mocks): from sdk.nexent.core.utils.observer import MessageObserver, ProcessType + from sdk.nexent.core.agents import nexent_agent from sdk.nexent.core.agents.nexent_agent import NexentAgent, ActionStep, TaskStep from sdk.nexent.core.agents.agent_model import ToolConfig, ModelConfig, AgentConfig @@ -447,6 +457,236 @@ def test_create_tool_with_local_source(nexent_agent_instance): assert result == "local_tool" +def test_create_local_tool_success(nexent_agent_instance): + """Test successful creation of a local tool.""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_class.return_value = mock_tool_instance + + tool_config = ToolConfig( + class_name="DummyTool", + name="dummy", + description="desc", + inputs="{}", + output_type="string", + params={"param1": "value1", "param2": 42}, + source="local", + metadata={}, + ) + + # Patch the module's globals to include our mock tool class + original_value = nexent_agent.__dict__.get("DummyTool") + nexent_agent.__dict__["DummyTool"] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + # Restore original value + if original_value is not None: + nexent_agent.__dict__["DummyTool"] = original_value + elif "DummyTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["DummyTool"] + + mock_tool_class.assert_called_once_with(param1="value1", param2=42) + assert result == mock_tool_instance + + +def test_create_local_tool_class_not_found(nexent_agent_instance): + """Test create_local_tool raises ValueError when class is not found.""" + tool_config = ToolConfig( + class_name="NonExistentTool", + name="dummy", + 
description="desc", + inputs="{}", + output_type="string", + params={}, + source="local", + metadata={}, + ) + + with pytest.raises(ValueError, match="NonExistentTool not found in local"): + nexent_agent_instance.create_local_tool(tool_config) + + +def test_create_local_tool_knowledge_base_search_tool_success(nexent_agent_instance): + """Test successful creation of KnowledgeBaseSearchTool with metadata.""" + mock_kb_tool_class = MagicMock() + mock_kb_tool_instance = MagicMock() + mock_kb_tool_class.return_value = mock_kb_tool_instance + + mock_vdb_core = MagicMock() + mock_embedding_model = MagicMock() + + tool_config = ToolConfig( + class_name="KnowledgeBaseSearchTool", + name="knowledge_base_search", + description="desc", + inputs="{}", + output_type="string", + params={"top_k": 10}, + source="local", + metadata={ + "index_names": ["index1", "index2"], + "vdb_core": mock_vdb_core, + "embedding_model": mock_embedding_model, + }, + ) + + original_value = nexent_agent.__dict__.get("KnowledgeBaseSearchTool") + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = mock_kb_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + # Restore original value + if original_value is not None: + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = original_value + elif "KnowledgeBaseSearchTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["KnowledgeBaseSearchTool"] + + # Verify only non-excluded params are passed to __init__ + mock_kb_tool_class.assert_called_once_with( + top_k=10, # Only non-excluded params passed to __init__ + ) + # Verify excluded parameters were set directly as attributes after instantiation + assert result == mock_kb_tool_instance + assert mock_kb_tool_instance.observer == nexent_agent_instance.observer + assert mock_kb_tool_instance.index_names == ["index1", "index2"] + assert mock_kb_tool_instance.vdb_core == mock_vdb_core + assert mock_kb_tool_instance.embedding_model == mock_embedding_model + + +def 
test_create_local_tool_knowledge_base_search_tool_with_conflicting_params(nexent_agent_instance): + """Test KnowledgeBaseSearchTool creation filters out conflicting params from params dict.""" + mock_kb_tool_class = MagicMock() + mock_kb_tool_instance = MagicMock() + mock_kb_tool_class.return_value = mock_kb_tool_instance + + mock_vdb_core = MagicMock() + mock_embedding_model = MagicMock() + + tool_config = ToolConfig( + class_name="KnowledgeBaseSearchTool", + name="knowledge_base_search", + description="desc", + inputs="{}", + output_type="string", + params={ + "top_k": 10, + "index_names": ["conflicting_index"], # This should be filtered out + "vdb_core": "conflicting_vdb", # This should be filtered out + "embedding_model": "conflicting_model", # This should be filtered out + "observer": "conflicting_observer", # This should be filtered out + }, + source="local", + metadata={ + "index_names": ["index1", "index2"], # These should be used instead + "vdb_core": mock_vdb_core, + "embedding_model": mock_embedding_model, + }, + ) + + original_value = nexent_agent.__dict__.get("KnowledgeBaseSearchTool") + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = mock_kb_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + # Restore original value + if original_value is not None: + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = original_value + elif "KnowledgeBaseSearchTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["KnowledgeBaseSearchTool"] + + # Verify conflicting params were filtered out from __init__ call + # Only non-excluded params should be passed to __init__ due to smolagents wrapper restrictions + mock_kb_tool_class.assert_called_once_with( + top_k=10, # From filtered_params (not in conflict list) + ) + # Verify excluded parameters were set directly as attributes after instantiation + assert result == mock_kb_tool_instance + assert mock_kb_tool_instance.observer == nexent_agent_instance.observer + assert 
mock_kb_tool_instance.index_names == ["index1", "index2"] # From metadata, not params + assert mock_kb_tool_instance.vdb_core == mock_vdb_core # From metadata, not params + assert mock_kb_tool_instance.embedding_model == mock_embedding_model # From metadata, not params + + +def test_create_local_tool_knowledge_base_search_tool_with_none_defaults(nexent_agent_instance): + """Test KnowledgeBaseSearchTool creation with None defaults when metadata is missing.""" + mock_kb_tool_class = MagicMock() + mock_kb_tool_instance = MagicMock() + mock_kb_tool_class.return_value = mock_kb_tool_instance + + tool_config = ToolConfig( + class_name="KnowledgeBaseSearchTool", + name="knowledge_base_search", + description="desc", + inputs="{}", + output_type="string", + params={"top_k": 5}, + source="local", + metadata={}, # No metadata provided + ) + + original_value = nexent_agent.__dict__.get("KnowledgeBaseSearchTool") + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = mock_kb_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + # Restore original value + if original_value is not None: + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = original_value + elif "KnowledgeBaseSearchTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["KnowledgeBaseSearchTool"] + + # Verify only non-excluded params are passed to __init__ + mock_kb_tool_class.assert_called_once_with( + top_k=5, + ) + # Verify excluded parameters were set directly as attributes with None defaults when metadata is missing + assert result == mock_kb_tool_instance + assert mock_kb_tool_instance.observer == nexent_agent_instance.observer + assert mock_kb_tool_instance.index_names == [] # Empty list when None + assert mock_kb_tool_instance.vdb_core is None + assert mock_kb_tool_instance.embedding_model is None + assert result == mock_kb_tool_instance + + +def test_create_local_tool_with_observer_attribute(nexent_agent_instance): + """Test create_local_tool sets observer 
attribute on tool if it exists.""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_instance.observer = None # Initially no observer + mock_tool_class.return_value = mock_tool_instance + + tool_config = ToolConfig( + class_name="ToolWithObserver", + name="tool", + description="desc", + inputs="{}", + output_type="string", + params={}, + source="local", + metadata={}, + ) + + original_value = nexent_agent.__dict__.get("ToolWithObserver") + nexent_agent.__dict__["ToolWithObserver"] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + # Restore original value + if original_value is not None: + nexent_agent.__dict__["ToolWithObserver"] = original_value + elif "ToolWithObserver" in nexent_agent.__dict__: + del nexent_agent.__dict__["ToolWithObserver"] + + # Verify observer was set on the tool instance + assert result.observer == nexent_agent_instance.observer + + def test_create_tool_with_mcp_source(nexent_agent_instance): """Ensure create_tool dispatches to create_mcp_tool for mcp source.""" tool_config = ToolConfig( diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index 3bbfaa0e9..535af6b35 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -16,10 +16,10 @@ def mock_observer(): @pytest.fixture -def mock_es_core(): +def mock_vdb_core(): """Create a mock ElasticSearchCore for testing""" - es_core = MagicMock() - return es_core + vdb_core = MagicMock() + return vdb_core @pytest.fixture @@ -30,27 +30,27 @@ def mock_embedding_model(): @pytest.fixture -def knowledge_base_search_tool(mock_observer, mock_es_core, mock_embedding_model): +def knowledge_base_search_tool(mock_observer, mock_vdb_core, mock_embedding_model): """Create KnowledgeBaseSearchTool instance for testing""" tool = KnowledgeBaseSearchTool( top_k=5, 
index_names=["test_index1", "test_index2"], observer=mock_observer, embedding_model=mock_embedding_model, - es_core=mock_es_core + vdb_core=mock_vdb_core ) return tool @pytest.fixture -def knowledge_base_search_tool_no_observer(mock_es_core, mock_embedding_model): +def knowledge_base_search_tool_no_observer(mock_vdb_core, mock_embedding_model): """Create KnowledgeBaseSearchTool instance without observer for testing""" tool = KnowledgeBaseSearchTool( top_k=3, index_names=["test_index"], observer=None, embedding_model=mock_embedding_model, - es_core=mock_es_core + vdb_core=mock_vdb_core ) return tool @@ -78,41 +78,41 @@ def create_mock_search_result(count=3): class TestKnowledgeBaseSearchTool: """Test KnowledgeBaseSearchTool functionality""" - def test_init_with_custom_values(self, mock_observer, mock_es_core, mock_embedding_model): + def test_init_with_custom_values(self, mock_observer, mock_vdb_core, mock_embedding_model): """Test initialization with custom values""" tool = KnowledgeBaseSearchTool( top_k=10, index_names=["index1", "index2", "index3"], observer=mock_observer, embedding_model=mock_embedding_model, - es_core=mock_es_core + vdb_core=mock_vdb_core ) assert tool.top_k == 10 assert tool.index_names == ["index1", "index2", "index3"] assert tool.observer == mock_observer assert tool.embedding_model == mock_embedding_model - assert tool.es_core == mock_es_core + assert tool.vdb_core == mock_vdb_core - def test_init_with_none_index_names(self, mock_es_core, mock_embedding_model): + def test_init_with_none_index_names(self, mock_vdb_core, mock_embedding_model): """Test initialization with None index_names""" tool = KnowledgeBaseSearchTool( top_k=5, index_names=None, observer=None, embedding_model=mock_embedding_model, - es_core=mock_es_core + vdb_core=mock_vdb_core ) assert tool.index_names == [] - def test_es_search_hybrid_success(self, knowledge_base_search_tool): + def test_search_hybrid_success(self, knowledge_base_search_tool): """Test successful hybrid 
search""" # Mock search results mock_results = create_mock_search_result(3) - knowledge_base_search_tool.es_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results - result = knowledge_base_search_tool.es_search_hybrid("test query", ["test_index1"]) + result = knowledge_base_search_tool.search_hybrid("test query", ["test_index1"]) # Verify result structure assert result["total"] == 3 @@ -126,59 +126,59 @@ def test_es_search_hybrid_success(self, knowledge_base_search_tool): assert "index" in doc assert doc["title"] == f"Test Document {i}" - # Verify es_core was called correctly - knowledge_base_search_tool.es_core.hybrid_search.assert_called_once_with( + # Verify vdb_core was called correctly + knowledge_base_search_tool.vdb_core.hybrid_search.assert_called_once_with( index_names=["test_index1"], query_text="test query", embedding_model=knowledge_base_search_tool.embedding_model, top_k=5 ) - def test_es_search_accurate_success(self, knowledge_base_search_tool): + def test_search_accurate_success(self, knowledge_base_search_tool): """Test successful accurate search""" # Mock search results mock_results = create_mock_search_result(2) - knowledge_base_search_tool.es_core.accurate_search.return_value = mock_results + knowledge_base_search_tool.vdb_core.accurate_search.return_value = mock_results - result = knowledge_base_search_tool.es_search_accurate("test query", ["test_index1"]) + result = knowledge_base_search_tool.search_accurate("test query", ["test_index1"]) # Verify result structure assert result["total"] == 2 assert len(result["results"]) == 2 - # Verify es_core was called correctly - knowledge_base_search_tool.es_core.accurate_search.assert_called_once_with( + # Verify vdb_core was called correctly + knowledge_base_search_tool.vdb_core.accurate_search.assert_called_once_with( index_names=["test_index1"], query_text="test query", top_k=5 ) - def test_es_search_semantic_success(self, 
knowledge_base_search_tool): + def test_search_semantic_success(self, knowledge_base_search_tool): """Test successful semantic search""" # Mock search results mock_results = create_mock_search_result(4) - knowledge_base_search_tool.es_core.semantic_search.return_value = mock_results + knowledge_base_search_tool.vdb_core.semantic_search.return_value = mock_results - result = knowledge_base_search_tool.es_search_semantic("test query", ["test_index1"]) + result = knowledge_base_search_tool.search_semantic("test query", ["test_index1"]) # Verify result structure assert result["total"] == 4 assert len(result["results"]) == 4 - # Verify es_core was called correctly - knowledge_base_search_tool.es_core.semantic_search.assert_called_once_with( + # Verify vdb_core was called correctly + knowledge_base_search_tool.vdb_core.semantic_search.assert_called_once_with( index_names=["test_index1"], query_text="test query", embedding_model=knowledge_base_search_tool.embedding_model, top_k=5 ) - def test_es_search_hybrid_error(self, knowledge_base_search_tool): + def test_search_hybrid_error(self, knowledge_base_search_tool): """Test hybrid search with error""" - knowledge_base_search_tool.es_core.hybrid_search.side_effect = Exception("Search error") + knowledge_base_search_tool.vdb_core.hybrid_search.side_effect = Exception("Search error") with pytest.raises(Exception) as excinfo: - knowledge_base_search_tool.es_search_hybrid("test query", ["test_index1"]) + knowledge_base_search_tool.search_hybrid("test query", ["test_index1"]) assert "Error during semantic search" in str(excinfo.value) @@ -186,7 +186,7 @@ def test_forward_accurate_mode_success(self, knowledge_base_search_tool): """Test forward method with accurate search mode""" # Mock search results mock_results = create_mock_search_result(2) - knowledge_base_search_tool.es_core.accurate_search.return_value = mock_results + knowledge_base_search_tool.vdb_core.accurate_search.return_value = mock_results result = 
knowledge_base_search_tool.forward("test query", search_mode="accurate") @@ -200,7 +200,7 @@ def test_forward_semantic_mode_success(self, knowledge_base_search_tool): """Test forward method with semantic search mode""" # Mock search results mock_results = create_mock_search_result(4) - knowledge_base_search_tool.es_core.semantic_search.return_value = mock_results + knowledge_base_search_tool.vdb_core.semantic_search.return_value = mock_results result = knowledge_base_search_tool.forward("test query", search_mode="semantic") @@ -231,7 +231,7 @@ def test_forward_no_index_names(self, knowledge_base_search_tool): def test_forward_no_results(self, knowledge_base_search_tool): """Test forward method with no search results""" # Mock empty search results - knowledge_base_search_tool.es_core.hybrid_search.return_value = [] + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = [] with pytest.raises(Exception) as excinfo: knowledge_base_search_tool.forward("test query") @@ -242,7 +242,7 @@ def test_forward_with_custom_index_names(self, knowledge_base_search_tool): """Test forward method with custom index names""" # Mock search results mock_results = create_mock_search_result(2) - knowledge_base_search_tool.es_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results result = knowledge_base_search_tool.forward( "test query", @@ -250,8 +250,8 @@ def test_forward_with_custom_index_names(self, knowledge_base_search_tool): index_names=["custom_index1", "custom_index2"] ) - # Verify es_core was called with custom index names - knowledge_base_search_tool.es_core.hybrid_search.assert_called_once_with( + # Verify vdb_core was called with custom index names + knowledge_base_search_tool.vdb_core.hybrid_search.assert_called_once_with( index_names=["custom_index1", "custom_index2"], query_text="test query", embedding_model=knowledge_base_search_tool.embedding_model, @@ -265,7 +265,7 @@ def 
test_forward_chinese_language_observer(self, knowledge_base_search_tool): # Mock search results mock_results = create_mock_search_result(2) - knowledge_base_search_tool.es_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results result = knowledge_base_search_tool.forward("test query") @@ -291,7 +291,7 @@ def test_forward_title_fallback(self, knowledge_base_search_tool): "index": "test_index" } ] - knowledge_base_search_tool.es_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results result = knowledge_base_search_tool.forward("test query") diff --git a/test/sdk/memory/test_memory_service.py b/test/sdk/memory/test_memory_service.py index ade09f93f..5894d2d9c 100644 --- a/test/sdk/memory/test_memory_service.py +++ b/test/sdk/memory/test_memory_service.py @@ -562,7 +562,7 @@ async def _reset(cfg): # noqa: ANN001 monkeypatch.setattr(memory_service, "reset_all_memory", _reset) ok = await memory_service.clear_model_memories( - es_core=es, + vdb_core=es, model_repo="jina-ai", model_name="jina-embeddings-v2-base-en", embedding_dims=768, @@ -586,7 +586,7 @@ async def _reset(cfg): # noqa: ANN001 monkeypatch.setattr(memory_service, "reset_all_memory", _reset) ok = await memory_service.clear_model_memories( - es_core=es, + vdb_core=es, model_repo="jina-ai", model_name="jina-embeddings-v2-base-en", embedding_dims=1024, @@ -622,7 +622,7 @@ async def _reset(cfg): # noqa: ANN001 monkeypatch.setattr(memory_service, "reset_all_memory", _reset) ok = await memory_service.clear_model_memories( - es_core=es, + vdb_core=es, model_repo="", model_name="m", embedding_dims=128, @@ -644,7 +644,7 @@ async def _reset(_: Dict[str, Any]): monkeypatch.setattr(memory_service, "reset_all_memory", _reset) ok = await memory_service.clear_model_memories( - es_core=es, + vdb_core=es, model_repo=None, model_name="Model", embedding_dims=256, @@ -660,7 +660,7 @@ async 
def _reset(_: Dict[str, Any]): async def test_clear_model_memories_invalid_model_name(): es = _DummyESCore(exists_behavior=lambda index: True) ok = await memory_service.clear_model_memories( - es_core=es, + vdb_core=es, model_repo="any", model_name="", embedding_dims=512, diff --git a/test/sdk/multi_modal/test_load_save_object.py b/test/sdk/multi_modal/test_load_save_object.py new file mode 100644 index 000000000..0a05d8b1f --- /dev/null +++ b/test/sdk/multi_modal/test_load_save_object.py @@ -0,0 +1,443 @@ +import io +from typing import Any, Tuple +from unittest.mock import MagicMock + +import pytest + +from sdk.nexent.multi_modal import load_save_object as lso + + +def make_manager(client: Any = None) -> lso.LoadSaveObjectManager: + if client is None: + client = object() + return lso.LoadSaveObjectManager(storage_client=client) + + +def test_get_client_returns_configured_storage(): + sentinel = object() + manager = make_manager(sentinel) + assert manager._get_client() is sentinel + + +def test_get_client_requires_initialized_storage(): + manager = lso.LoadSaveObjectManager(storage_client=None) + + with pytest.raises(ValueError): + manager._get_client() + + +def test_download_file_from_http(monkeypatch): + manager = make_manager() + + class _Response: + def __init__(self): + self.content = b"binary" + + def raise_for_status(self): + return None + + monkeypatch.setattr(lso.requests, "get", lambda url, timeout: _Response()) + data = manager.download_file_from_url( + "https://example.com/file.png", + url_type="https", + ) + assert data == b"binary" + + +def test_download_file_from_s3(monkeypatch): + class _FakeClient: + def get_file_stream(self, object_name: str, bucket: str) -> Tuple[bool, Any]: + assert object_name == "path/to/object" + assert bucket == "bucket" + return True, io.BytesIO(b"payload") + + manager = make_manager(_FakeClient()) + data = manager.download_file_from_url("s3://bucket/path/to/object", url_type="s3") + assert data == b"payload" + + +def 
test_download_file_from_s3_failure_returns_none(): + class _FailingClient: + def get_file_stream(self, object_name: str, bucket: str): + return False, "boom" + + manager = make_manager(_FailingClient()) + assert manager.download_file_from_url("s3://bucket/object", url_type="s3") is None + + +def test_download_file_from_s3_missing_method_returns_none(): + class _InvalidClient: + pass + + manager = make_manager(_InvalidClient()) + assert manager.download_file_from_url("s3://bucket/object", url_type="s3") is None + + +def test_download_file_requires_url_type(): + manager = make_manager() + with pytest.raises(ValueError): + manager.download_file_from_url("https://example.com/file.png", url_type=None) # type: ignore[arg-type] + + +def test_download_file_empty_url_returns_none(): + manager = make_manager() + assert manager.download_file_from_url("", url_type="https") is None + + +def test_download_file_stream_read_failure(monkeypatch): + class _FailingStream: + def read(self): + raise RuntimeError("cannot read") + + def close(self): + pass + + class _Client: + def get_file_stream(self, object_name: str, bucket: str): + return True, _FailingStream() + + manager = make_manager(_Client()) + assert manager.download_file_from_url("s3://bucket/object", url_type="s3") is None + + +def test_upload_bytes_to_minio_generates_object_name(monkeypatch): + captured = {} + + class _UploadClient: + def upload_fileobj(self, file_obj, object_name, bucket): + captured["data"] = file_obj.read() + captured["object_name"] = object_name + captured["bucket"] = bucket + return True, "/bucket/generated.bin" + + manager = make_manager(_UploadClient()) + monkeypatch.setattr(lso, "guess_extension_from_content_type", lambda c: ".bin") + monkeypatch.setattr(lso, "generate_object_name", lambda ext: f"generated{ext}") + + result = manager._upload_bytes_to_minio(b"payload", content_type="application/octet-stream") + + assert result == "/bucket/generated.bin" + assert captured["data"] == b"payload" + 
assert captured["object_name"] == "generated.bin" + assert captured["bucket"] == "multi-modal" + + +def test_upload_bytes_to_minio_generates_name_without_extension(monkeypatch): + captured = {} + + class _UploadClient: + def upload_fileobj(self, file_obj, object_name, bucket): + captured["object_name"] = object_name + return True, "/bucket/generated" + + manager = make_manager(_UploadClient()) + + monkeypatch.setattr(lso, "guess_extension_from_content_type", lambda _: "") + + def _generate(ext: str): + captured["ext"] = ext + return "generated" + + monkeypatch.setattr(lso, "generate_object_name", _generate) + + path = manager._upload_bytes_to_minio(b"bytes", content_type="application/octet-stream") + assert path == "/bucket/generated" + assert captured["ext"] == "" + assert captured["object_name"] == "generated" + + +def test_upload_bytes_to_minio_requires_upload_method(): + class _InvalidClient: + pass + + manager = make_manager(_InvalidClient()) + + with pytest.raises(ValueError): + manager._upload_bytes_to_minio(b"bytes") + + +def test_upload_bytes_to_minio_failure_propagates_error(): + class _UploadClient: + def upload_fileobj(self, file_obj, object_name, bucket): + return False, "failed" + + manager = make_manager(_UploadClient()) + + with pytest.raises(ValueError): + manager._upload_bytes_to_minio(b"bytes") + + +def test_upload_bytes_to_minio_with_explicit_object_name(): + captured = {} + + class _UploadClient: + def upload_fileobj(self, file_obj, object_name, bucket): + captured["name"] = object_name + captured["bucket"] = bucket + return True, "/bucket/custom.bin" + + manager = make_manager(_UploadClient()) + result = manager._upload_bytes_to_minio( + b"payload", + object_name="provided.bin", + bucket="custom-bucket" + ) + + assert result == "/bucket/custom.bin" + assert captured["name"] == "provided.bin" + assert captured["bucket"] == "custom-bucket" + + +def test_load_object_transforms_single_argument(monkeypatch): + manager = make_manager() + 
download_mock = MagicMock(return_value=b"file-bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + + @manager.load_object(input_names=["image"]) + def handler(image): + return image + + result = handler("https://example.com/img.png") + + download_mock.assert_called_once_with("https://example.com/img.png", url_type="https") + assert result == b"file-bytes" + + +def test_load_object_transforms_iterable_with_transformer(monkeypatch): + manager = make_manager() + + def transformer(data: bytes) -> str: + return data.decode("utf-8") + + download_mock = MagicMock(side_effect=[b"first", b"second"]) + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + + @manager.load_object(input_names=["images"], input_data_transformer=[transformer]) + def handler(images): + return images + + result = handler(["https://a", "https://b"]) + + assert result == ["first", "second"] + + +def test_load_object_preserves_tuple_type(monkeypatch): + manager = make_manager() + download_mock = MagicMock(side_effect=[b"alpha", b"beta"]) + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + + @manager.load_object(input_names=["images"]) + def handler(images): + return images + + result = handler(("https://a", "https://b")) + + assert isinstance(result, tuple) + assert result == (b"alpha", b"beta") + + +def test_load_object_skips_missing_arguments(monkeypatch): + manager = make_manager() + download_mock = MagicMock(return_value=b"bytes") + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + + @manager.load_object(input_names=["image", "mask"]) + def handler(image, other=None): + return image, other + + result = handler("https://example.com/a.png") + download_mock.assert_called_once_with("https://example.com/a.png", url_type="https") + assert result == (b"bytes", None) + + +def test_load_object_raises_for_non_url(): + manager = make_manager() + + @manager.load_object(input_names=["image"]) + def handler(image): + 
return image + + with pytest.raises(ValueError): + handler(123) + + +def test_load_object_allows_none_input(): + manager = make_manager() + + @manager.load_object(input_names=["image"]) + def handler(image): + return image + + assert handler(None) is None + + +def test_load_object_transformer_error_propagates(monkeypatch): + def transformer(_data: bytes): + raise RuntimeError("boom") + + manager = make_manager() + monkeypatch.setattr(manager, "download_file_from_url", MagicMock(return_value=b"bytes")) + + @manager.load_object(input_names=["image"], input_data_transformer=[transformer]) + def handler(image): + return image + + with pytest.raises(RuntimeError): + handler("https://example.com/test.png") + + +def test_load_object_transformer_list_shorter_than_inputs(monkeypatch): + manager = make_manager() + download_mock = MagicMock(side_effect=[b"first", b"second"]) + monkeypatch.setattr(manager, "download_file_from_url", download_mock) + + def decode(data: bytes) -> str: + return data.decode("utf-8") + + @manager.load_object( + input_names=["primary", "secondary"], + input_data_transformer=[decode], + ) + def handler(primary, secondary): + return primary, secondary + + result = handler("https://a", "https://b") + assert result == ("first", b"second") + assert download_mock.call_count == 2 + + +def test_save_object_uploads_bytes(monkeypatch): + manager = make_manager() + upload_mock = MagicMock(return_value="/bucket/object") + monkeypatch.setattr(manager, "_upload_bytes_to_minio", upload_mock) + monkeypatch.setattr( + lso, "detect_content_type_from_bytes", lambda data: "image/png" + ) + + @manager.save_object(output_names=["image"]) + def handler(): + return b"\x89PNG\r\n\x1a\n" + + result = handler() + upload_mock.assert_called_once() + assert result == "s3://bucket/object" + + +def test_save_object_with_transformer_and_nested(monkeypatch): + manager = make_manager() + upload_mock = MagicMock(side_effect=["/bucket/a", "/bucket/b"]) + monkeypatch.setattr(manager, 
"_upload_bytes_to_minio", upload_mock) + monkeypatch.setattr( + lso, "detect_content_type_from_bytes", lambda data: "application/octet-stream" + ) + + def to_bytes(value: str) -> bytes: + return value.encode("utf-8") + + @manager.save_object(output_names=["images"], output_transformers=[to_bytes]) + def handler(): + return ["one", "two"] + + result = handler() + assert result == ["s3://bucket/a", "s3://bucket/b"] + assert upload_mock.call_count == 2 + + +def test_save_object_validates_return_value_count(): + manager = make_manager() + + @manager.save_object(output_names=["first", "second"]) + def handler(): + return b"only-one" + + with pytest.raises(ValueError): + handler() + + +def test_save_object_transformer_must_return_bytes(): + def identity(value): + return value # not bytes + + manager = make_manager() + + @manager.save_object(output_names=["payload"], output_transformers=[identity]) + def handler(): + return "text" + + with pytest.raises(ValueError): + handler() + + +def test_save_object_requires_bytes_without_transformer(): + manager = make_manager() + + @manager.save_object(output_names=["image"]) + def handler(): + return "text" + + with pytest.raises(ValueError): + handler() + + +def test_save_object_handles_none_output(monkeypatch): + manager = make_manager() + upload_mock = MagicMock() + monkeypatch.setattr(manager, "_upload_bytes_to_minio", upload_mock) + + @manager.save_object(output_names=["image"]) + def handler(): + return None + + assert handler() is None + upload_mock.assert_not_called() + + +def test_save_object_returns_tuple_for_multiple_outputs(monkeypatch): + manager = make_manager() + monkeypatch.setattr( + lso, "detect_content_type_from_bytes", lambda data: "application/octet-stream" + ) + upload_mock = MagicMock(side_effect=["/bucket/a", "/bucket/b"]) + monkeypatch.setattr(manager, "_upload_bytes_to_minio", upload_mock) + + @manager.save_object(output_names=["first", "second"]) + def handler(): + return b"a", b"b" + + assert handler() 
== ("s3://bucket/a", "s3://bucket/b") + assert upload_mock.call_count == 2 + + +def test_save_object_nested_none_structure(monkeypatch): + manager = make_manager() + monkeypatch.setattr( + lso, "detect_content_type_from_bytes", lambda data: "application/octet-stream" + ) + upload_mock = MagicMock(return_value="/bucket/value") + monkeypatch.setattr(manager, "_upload_bytes_to_minio", upload_mock) + + @manager.save_object(output_names=["images"]) + def handler_nested(): + return [None, b"bytes"] + + result = handler_nested() + assert result == [None, "s3://bucket/value"] + upload_mock.assert_called_once() + + +@pytest.mark.asyncio +async def test_save_object_supports_async_functions(monkeypatch): + manager = make_manager() + upload_mock = MagicMock(return_value="/bucket/object") + monkeypatch.setattr(manager, "_upload_bytes_to_minio", upload_mock) + monkeypatch.setattr( + lso, "detect_content_type_from_bytes", lambda data: "image/png" + ) + + @manager.save_object(output_names=["image"]) + async def handler(): + return b"\x89PNG\r\n\x1a\n" + + result = await handler() + assert result == "s3://bucket/object" + upload_mock.assert_called_once() diff --git a/test/sdk/multi_modal/test_utils.py b/test/sdk/multi_modal/test_utils.py new file mode 100644 index 000000000..f72ce9d30 --- /dev/null +++ b/test/sdk/multi_modal/test_utils.py @@ -0,0 +1,205 @@ +import base64 + +import pytest + +from sdk.nexent.multi_modal import utils + + +def test_is_url_variants(): + assert utils.is_url("https://example.com/image.png") == "https" + assert utils.is_url("http://example.com/image.png") == "http" + assert utils.is_url("s3://bucket/key") == "s3" + assert utils.is_url("/bucket/key") == "s3" + assert utils.is_url("not-a-url") is None + assert utils.is_url(123) is None # type: ignore[arg-type] + + +def test_is_url_requires_bucket_and_key(): + assert utils.is_url("/bucket") is None + assert utils.is_url("s3://bucket/") is None + assert utils.is_url("") is None + + +def 
test_bytes_to_base64_and_back(): + data = b"sample" + encoded = utils.bytes_to_base64(data, content_type="text/plain") + assert encoded.startswith("data:text/plain;base64,") + decoded, content_type = utils.base64_to_bytes(encoded) + assert decoded == data + assert content_type == "text/plain" + + +def test_bytes_to_base64_requires_data(): + with pytest.raises(ValueError): + utils.bytes_to_base64(b"") + + +def test_base64_to_bytes_without_prefix(): + payload = base64.b64encode(b"raw-data").decode("utf-8") + decoded, content_type = utils.base64_to_bytes(payload) + assert decoded == b"raw-data" + assert content_type == "application/octet-stream" + + +def test_base64_to_bytes_invalid_input(): + with pytest.raises(ValueError): + utils.base64_to_bytes("data:image/png;base64,invalid!!") + + +def test_base64_to_bytes_requires_string(): + with pytest.raises(ValueError): + utils.base64_to_bytes(b"not-a-string") # type: ignore[arg-type] + + +def test_base64_to_bytes_invalid_header_format(): + with pytest.raises(ValueError): + utils.base64_to_bytes("data:image/png;base64") # missing comma + + +def test_generate_object_name_appends_extension(monkeypatch: pytest.MonkeyPatch): + class _FixedDateTime: + @staticmethod + def now(): + class _Value: + def strftime(self, fmt: str) -> str: + return "20240102_030405" + + return _Value() + + class _FixedUUID: + @staticmethod + def uuid4(): + return "12345678-abcdef" + + monkeypatch.setattr(utils, "datetime", _FixedDateTime()) + monkeypatch.setattr(utils, "uuid", _FixedUUID()) + + name = utils.generate_object_name("png") + assert name == "20240102_030405_12345678.png" + + +def test_generate_object_name_accepts_dot_prefix(monkeypatch: pytest.MonkeyPatch): + class _FixedDateTime: + @staticmethod + def now(): + class _Value: + def strftime(self, fmt: str) -> str: + return "20240102_030405" + + return _Value() + + class _FixedUUID: + @staticmethod + def uuid4(): + return "12345678-abcdef" + + monkeypatch.setattr(utils, "datetime", 
_FixedDateTime()) + monkeypatch.setattr(utils, "uuid", _FixedUUID()) + + name = utils.generate_object_name(".gif") + assert name.endswith(".gif") + + +def test_detect_content_type_known_signatures(): + png_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 10 + assert utils.detect_content_type_from_bytes(png_bytes) == "image/png" + + pdf_bytes = b"%PDF" + b"\x00" * 10 + assert utils.detect_content_type_from_bytes(pdf_bytes) == "application/pdf" + + jpeg_bytes = b"\xff\xd8\xff" + b"\x00" * 5 + assert utils.detect_content_type_from_bytes(jpeg_bytes) == "image/jpeg" + + gif_bytes = b"GIF89a" + b"\x00" * 6 + assert utils.detect_content_type_from_bytes(gif_bytes) == "image/gif" + + webp_bytes = b"RIFF" + b"\x00" * 4 + b"WEBP" + assert utils.detect_content_type_from_bytes(webp_bytes) == "image/webp" + + wav_bytes = b"RIFF" + b"\x00" * 4 + b"WAVE" + assert utils.detect_content_type_from_bytes(wav_bytes) == "audio/wav" + + +def test_detect_content_type_audio_video_variants(): + mp4_bytes = b"\x00\x00\x00\x20ftyp" + b"\x00" * 10 + assert utils.detect_content_type_from_bytes(mp4_bytes) == "video/mp4" + + mp3_bytes = b"ID3" + b"\x00" * 5 + assert utils.detect_content_type_from_bytes(mp3_bytes) == "audio/mpeg" + + +def test_detect_content_type_text_and_default(): + text_bytes = b"Hello world" + assert utils.detect_content_type_from_bytes(text_bytes) == "text/plain" + assert utils.detect_content_type_from_bytes(b"\x00\x01\x02") == "application/octet-stream" + json_bytes = b'{"key": "value"}' + assert utils.detect_content_type_from_bytes(json_bytes) == "application/json" + + +def test_guess_content_type_from_url(): + assert utils.guess_content_type_from_url("http://example.com/file.webp") == "image/webp" + assert utils.guess_content_type_from_url("http://example.com/file.unknown") == "application/octet-stream" + assert utils.guess_content_type_from_url("http://example.com/file.jpg?token=1") == "image/jpeg" + + +def test_guess_content_type_from_url_uses_case_insensitive_suffix(): + assert 
utils.guess_content_type_from_url("http://example.com/VIDEO.MP4") == "video/mp4" + + +def test_guess_extension_from_content_type(): + assert utils.guess_extension_from_content_type("image/png") == ".png" + assert utils.guess_extension_from_content_type("unknown/type") == "" + + +def test_parse_s3_url_variants(): + assert utils.parse_s3_url("s3://bucket/key") == ("bucket", "key") + assert utils.parse_s3_url("/bucket/key") == ("bucket", "key") + + +def test_parse_s3_url_invalid(): + with pytest.raises(ValueError): + utils.parse_s3_url("invalid") + + +def test_parse_s3_url_requires_object_name(): + with pytest.raises(ValueError): + utils.parse_s3_url("s3://bucket/") + + with pytest.raises(ValueError): + utils.parse_s3_url("/bucket") + + +def test_base64_to_bytes_header_without_base64_flag(): + payload = base64.b64encode(b"json-bytes").decode("utf-8") + decoded, content_type = utils.base64_to_bytes( + f"data:application/json,{payload}" + ) + assert decoded == b"json-bytes" + assert content_type == "application/json" + + +@pytest.mark.parametrize( + ("payload", "expected"), + [ + (b"\x00\x00\x00 qt " + b"\x00" * 6, "video/quicktime"), + (b"OggS" + b"\x00" * 8, "audio/ogg"), + (b"fLaC" + b"\x00" * 8, "audio/flac"), + (b"\x1a\x45\xdf\xa3" + b"\x00" * 8, "video/webm"), + (b"RIFF" + b"\x00" * 4 + b"AVI ", "video/x-msvideo"), + ], +) +def test_detect_content_type_expanded_signatures(payload: bytes, expected: str): + assert utils.detect_content_type_from_bytes(payload) == expected + + +def test_detect_content_type_mp3_frame_sync(): + payload = b"\xff\xfb" + b"\x00" * 4 + assert utils.detect_content_type_from_bytes(payload) == "audio/mpeg" + + +@pytest.mark.parametrize("value", ["", None]) +def test_parse_s3_url_rejects_empty(value): + with pytest.raises(ValueError): + utils.parse_s3_url(value) # type: ignore[arg-type] + + diff --git a/test/sdk/vector_database/test_elasticsearch_core.py b/test/sdk/vector_database/test_elasticsearch_core.py index 0d495c4f4..c8b576051 100644 --- 
a/test/sdk/vector_database/test_elasticsearch_core.py +++ b/test/sdk/vector_database/test_elasticsearch_core.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import time from typing import List, Dict, Any +from elasticsearch import exceptions # Import the class under test from sdk.nexent.vector_database.elasticsearch_core import ElasticSearchCore @@ -394,7 +395,7 @@ def test_preprocess_documents_maintains_order(elasticsearch_core_instance): # Tests for index management methods # ---------------------------------------------------------------------------- -def test_create_vector_index_success(elasticsearch_core_instance): +def test_create_index_success(elasticsearch_core_instance): """Test creating a new vector index successfully.""" with patch.object(elasticsearch_core_instance.client.indices, 'exists') as mock_exists, \ patch.object(elasticsearch_core_instance.client.indices, 'create') as mock_create, \ @@ -406,7 +407,7 @@ def test_create_vector_index_success(elasticsearch_core_instance): mock_refresh.return_value = True mock_ready.return_value = True - result = elasticsearch_core_instance.create_vector_index( + result = elasticsearch_core_instance.create_index( "test_index", embedding_dim=1024) assert result is True @@ -416,7 +417,7 @@ def test_create_vector_index_success(elasticsearch_core_instance): mock_ready.assert_called_once_with("test_index") -def test_create_vector_index_already_exists(elasticsearch_core_instance): +def test_create_index_already_exists(elasticsearch_core_instance): """Test creating an index that already exists.""" with patch.object(elasticsearch_core_instance.client.indices, 'exists') as mock_exists, \ patch.object(elasticsearch_core_instance, '_ensure_index_ready') as mock_ready: @@ -424,7 +425,7 @@ def test_create_vector_index_already_exists(elasticsearch_core_instance): mock_exists.return_value = True mock_ready.return_value = True - result = elasticsearch_core_instance.create_vector_index( + result = 
elasticsearch_core_instance.create_index( "existing_index") assert result is True @@ -445,8 +446,6 @@ def test_delete_index_success(elasticsearch_core_instance): def test_delete_index_not_found(elasticsearch_core_instance): """Test deleting an index that doesn't exist.""" - from elasticsearch import exceptions - with patch.object(elasticsearch_core_instance.client.indices, 'delete') as mock_delete: mock_delete.side_effect = exceptions.NotFoundError( "Index not found", {}, {}) @@ -478,11 +477,11 @@ def test_get_user_indices_success(elasticsearch_core_instance): # Tests for document operations # ---------------------------------------------------------------------------- -def test_index_documents_empty_list(elasticsearch_core_instance): +def test_vectorize_documents_empty_list(elasticsearch_core_instance): """Test indexing an empty list of documents.""" mock_embedding_model = MagicMock() - result = elasticsearch_core_instance.index_documents( + result = elasticsearch_core_instance.vectorize_documents( "test_index", mock_embedding_model, [], @@ -492,7 +491,7 @@ def test_index_documents_empty_list(elasticsearch_core_instance): assert result == 0 -def test_index_documents_small_batch(elasticsearch_core_instance): +def test_vectorize_documents_small_batch(elasticsearch_core_instance): """Test indexing a small batch of documents (< 64).""" mock_embedding_model = MagicMock() mock_embedding_model.get_embeddings.return_value = [[0.1] * 1024] * 3 @@ -512,7 +511,7 @@ def test_index_documents_small_batch(elasticsearch_core_instance): mock_time.return_value = 1642234567 mock_bulk.return_value = {"errors": False, "items": []} - result = elasticsearch_core_instance.index_documents( + result = elasticsearch_core_instance.vectorize_documents( "test_index", mock_embedding_model, documents, @@ -524,7 +523,7 @@ def test_index_documents_small_batch(elasticsearch_core_instance): mock_bulk.assert_called_once() -def test_index_documents_large_batch(elasticsearch_core_instance): +def 
test_vectorize_documents_large_batch(elasticsearch_core_instance): """Test indexing a large batch of documents (>= 64).""" mock_embedding_model = MagicMock() mock_embedding_model.get_embeddings.return_value = [[0.1] * 1024] * 64 @@ -546,7 +545,7 @@ def test_index_documents_large_batch(elasticsearch_core_instance): mock_bulk.return_value = {"errors": False, "items": []} mock_refresh.return_value = True - result = elasticsearch_core_instance.index_documents( + result = elasticsearch_core_instance.vectorize_documents( "test_index", mock_embedding_model, documents, @@ -560,12 +559,12 @@ def test_index_documents_large_batch(elasticsearch_core_instance): mock_refresh.assert_called_once_with("test_index") -def test_delete_documents_by_path_or_url_success(elasticsearch_core_instance): +def test_delete_documents_success(elasticsearch_core_instance): """Test deleting documents by path_or_url successfully.""" with patch.object(elasticsearch_core_instance.client, 'delete_by_query') as mock_delete: mock_delete.return_value = {"deleted": 5} - result = elasticsearch_core_instance.delete_documents_by_path_or_url( + result = elasticsearch_core_instance.delete_documents( "test_index", "/path/to/file.pdf" ) @@ -574,6 +573,97 @@ def test_delete_documents_by_path_or_url_success(elasticsearch_core_instance): mock_delete.assert_called_once() +def test_get_index_chunks_success(elasticsearch_core_instance): + """Test fetching chunks via scroll API.""" + elasticsearch_core_instance.client = MagicMock() + elasticsearch_core_instance.client.count.return_value = {"count": 2} + elasticsearch_core_instance.client.search.return_value = { + "_scroll_id": "scroll123", + "hits": { + "hits": [ + {"_id": "doc-1", "_source": {"id": "chunk-1", "content": "A"}}, + {"_id": "doc-2", "_source": {"content": "B"}} + ] + } + } + elasticsearch_core_instance.client.scroll.return_value = { + "_scroll_id": "scroll123", + "hits": {"hits": []} + } + + result = elasticsearch_core_instance.get_index_chunks("kb-index") 
+ + assert result["chunks"] == [ + {"id": "chunk-1", "content": "A"}, + {"content": "B", "id": "doc-2"} + ] + assert result["total"] == 2 + elasticsearch_core_instance.client.search.assert_called_once() + elasticsearch_core_instance.client.scroll.assert_called_once_with(scroll_id="scroll123", scroll="2m") + elasticsearch_core_instance.client.clear_scroll.assert_called_once_with(scroll_id="scroll123") + + +def test_get_index_chunks_paginated(elasticsearch_core_instance): + """Test fetching chunks with pagination parameters.""" + elasticsearch_core_instance.client = MagicMock() + elasticsearch_core_instance.client.count.return_value = {"count": 5} + elasticsearch_core_instance.client.search.return_value = { + "hits": { + "hits": [ + {"_id": "doc-2", "_source": {"content": "B"}}, + ] + } + } + + result = elasticsearch_core_instance.get_index_chunks( + "kb-index", page=2, page_size=1) + + assert result["chunks"] == [{"content": "B", "id": "doc-2"}] + assert result["page"] == 2 + assert result["page_size"] == 1 + assert result["total"] == 5 + elasticsearch_core_instance.client.scroll.assert_not_called() + elasticsearch_core_instance.client.clear_scroll.assert_not_called() + + +def test_get_index_chunks_not_found(elasticsearch_core_instance): + """Test fetching chunks when index does not exist.""" + elasticsearch_core_instance.client = MagicMock() + elasticsearch_core_instance.client.count.side_effect = exceptions.NotFoundError( + 404, "not found", {}) + + chunks = elasticsearch_core_instance.get_index_chunks("missing-index") + + assert chunks == {"chunks": [], "total": 0, + "page": None, "page_size": None} + elasticsearch_core_instance.client.clear_scroll.assert_not_called() + + +def test_get_index_chunks_cleanup_failure(elasticsearch_core_instance): + """Test cleanup warning path when clear_scroll raises.""" + elasticsearch_core_instance.client = MagicMock() + elasticsearch_core_instance.client.count.return_value = {"count": 1} + 
elasticsearch_core_instance.client.search.return_value = { + "_scroll_id": "scroll123", + "hits": { + "hits": [ + {"_id": "doc-1", "_source": {"content": "A"}} + ] + } + } + elasticsearch_core_instance.client.scroll.return_value = { + "_scroll_id": "scroll123", + "hits": {"hits": []} + } + elasticsearch_core_instance.client.clear_scroll.side_effect = Exception("cleanup error") + + chunks = elasticsearch_core_instance.get_index_chunks("kb-index") + + assert len(chunks["chunks"]) == 1 + assert chunks["chunks"][0]["id"] == "doc-1" + elasticsearch_core_instance.client.clear_scroll.assert_called_once_with(scroll_id="scroll123") + + # ---------------------------------------------------------------------------- # Tests for search operations # ---------------------------------------------------------------------------- @@ -608,6 +698,32 @@ def test_accurate_search_success(elasticsearch_core_instance): mock_exec.assert_called_once() +def test_accurate_search_builds_multi_index_query(elasticsearch_core_instance): + """Ensure accurate_search joins indices and applies top_k sizing.""" + with patch.object(elasticsearch_core_instance, 'exec_query') as mock_exec, \ + patch('sdk.nexent.vector_database.elasticsearch_core.calculate_term_weights') as mock_weights, \ + patch('sdk.nexent.vector_database.elasticsearch_core.build_weighted_query') as mock_build: + + mock_weights.return_value = {"test": 0.5} + mock_build.return_value = {"query": {"match_all": {}}} + mock_exec.return_value = [] + + elasticsearch_core_instance.accurate_search( + ["index_a", "index_b"], + "multi query", + top_k=7, + ) + + mock_weights.assert_called_once_with("multi query") + mock_build.assert_called_once_with("multi query", {"test": 0.5}) + mock_exec.assert_called_once() + + index_pattern, search_query = mock_exec.call_args[0] + assert index_pattern == "index_a,index_b" + assert search_query["size"] == 7 + assert search_query["_source"]["excludes"] == ["embedding"] + + def 
test_semantic_search_success(elasticsearch_core_instance): """Test semantic search with vector similarity.""" mock_embedding_model = MagicMock() @@ -636,6 +752,32 @@ def test_semantic_search_success(elasticsearch_core_instance): mock_exec.assert_called_once() +def test_semantic_search_sets_knn_parameters(elasticsearch_core_instance): + """Ensure semantic_search sets k and num_candidates based on top_k.""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.2] * 8] + + with patch.object(elasticsearch_core_instance, 'exec_query') as mock_exec: + mock_exec.return_value = [] + + elasticsearch_core_instance.semantic_search( + ["index_x"], + "query terms", + mock_embedding_model, + top_k=4, + ) + + mock_embedding_model.get_embeddings.assert_called_once_with( + "query terms") + mock_exec.assert_called_once() + + _, search_query = mock_exec.call_args[0] + assert search_query["knn"]["k"] == 4 + assert search_query["knn"]["num_candidates"] == 8 + assert search_query["size"] == 4 + assert search_query["_source"]["excludes"] == ["embedding"] + + def test_hybrid_search_success(elasticsearch_core_instance): """Test hybrid search combining accurate and semantic results.""" mock_embedding_model = MagicMock() @@ -683,7 +825,7 @@ def test_hybrid_search_success(elasticsearch_core_instance): # Tests for statistics and monitoring # ---------------------------------------------------------------------------- -def test_get_file_list_with_details_success(elasticsearch_core_instance): +def test_get_documents_detail_success(elasticsearch_core_instance): """Test getting file list with details.""" with patch.object(elasticsearch_core_instance.client, 'search') as mock_search: mock_search.return_value = { @@ -691,6 +833,7 @@ def test_get_file_list_with_details_success(elasticsearch_core_instance): "unique_sources": { "buckets": [ { + "doc_count": 3, "file_sample": { "hits": { "hits": [ @@ -711,39 +854,18 @@ def 
test_get_file_list_with_details_success(elasticsearch_core_instance): } } - result = elasticsearch_core_instance.get_file_list_with_details( + result = elasticsearch_core_instance.get_documents_detail( "test_index") assert len(result) == 1 assert result[0]["path_or_url"] == "/path/to/file1.pdf" assert result[0]["filename"] == "file1.pdf" assert result[0]["file_size"] == 1024 + assert result[0]["chunk_count"] == 3 mock_search.assert_called_once() -def test_get_index_mapping_success(elasticsearch_core_instance): - """Test getting index mapping.""" - with patch.object(elasticsearch_core_instance.client.indices, 'get_mapping') as mock_get_mapping: - mock_get_mapping.return_value = { - "test_index": { - "mappings": { - "properties": { - "content": {"type": "text"}, - "embedding": {"type": "dense_vector"} - } - } - } - } - - result = elasticsearch_core_instance.get_index_mapping(["test_index"]) - - assert "test_index" in result - assert "content" in result["test_index"] - assert "embedding" in result["test_index"] - mock_get_mapping.assert_called_once() - - -def test_get_index_stats_success(elasticsearch_core_instance): +def test_get_indices_detail_success(elasticsearch_core_instance): """Test getting index statistics.""" with patch.object(elasticsearch_core_instance.client.indices, 'stats') as mock_stats, \ patch.object(elasticsearch_core_instance.client.indices, 'get_settings') as mock_settings, \ @@ -780,7 +902,7 @@ def test_get_index_stats_success(elasticsearch_core_instance): } } - result = elasticsearch_core_instance.get_index_stats( + result = elasticsearch_core_instance.get_indices_detail( ["test_index"], embedding_dim=1024) assert "test_index" in result diff --git a/test/sdk/vector_database/test_elasticsearch_core_coverage.py b/test/sdk/vector_database/test_elasticsearch_core_coverage.py index 5bc28bc44..f307c9d84 100644 --- a/test/sdk/vector_database/test_elasticsearch_core_coverage.py +++ b/test/sdk/vector_database/test_elasticsearch_core_coverage.py @@ -9,6 
+9,7 @@ import os import sys from typing import List, Dict, Any +from datetime import datetime, timedelta # Add the project root to the path current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -16,7 +17,7 @@ sys.path.insert(0, project_root) # Import the class under test -from sdk.nexent.vector_database.elasticsearch_core import ElasticSearchCore +from sdk.nexent.vector_database.elasticsearch_core import ElasticSearchCore, BulkOperation from elasticsearch import exceptions @@ -24,7 +25,7 @@ class TestElasticSearchCoreCoverage: """Test class for improving elasticsearch_core coverage""" @pytest.fixture - def es_core(self): + def vdb_core(self): """Create an ElasticSearchCore instance for testing.""" return ElasticSearchCore( host="http://localhost:9200", @@ -33,147 +34,147 @@ def es_core(self): ssl_show_warn=False ) - def test_force_refresh_with_retry_success(self, es_core): + def test_force_refresh_with_retry_success(self, vdb_core): """Test _force_refresh_with_retry successful refresh""" - es_core.client = MagicMock() - es_core.client.indices.refresh.return_value = {"_shards": {"total": 1, "successful": 1}} + vdb_core.client = MagicMock() + vdb_core.client.indices.refresh.return_value = {"_shards": {"total": 1, "successful": 1}} - result = es_core._force_refresh_with_retry("test_index") + result = vdb_core._force_refresh_with_retry("test_index") assert result is True - es_core.client.indices.refresh.assert_called_once_with(index="test_index") + vdb_core.client.indices.refresh.assert_called_once_with(index="test_index") - def test_force_refresh_with_retry_failure_retry(self, es_core): + def test_force_refresh_with_retry_failure_retry(self, vdb_core): """Test _force_refresh_with_retry with retries""" - es_core.client = MagicMock() - es_core.client.indices.refresh.side_effect = [ + vdb_core.client = MagicMock() + vdb_core.client.indices.refresh.side_effect = [ Exception("Connection error"), Exception("Still failing"), {"_shards": {"total": 1, "successful": 1}} 
] with patch('time.sleep'): # Mock sleep to speed up test - result = es_core._force_refresh_with_retry("test_index", max_retries=3) + result = vdb_core._force_refresh_with_retry("test_index", max_retries=3) assert result is True - assert es_core.client.indices.refresh.call_count == 3 + assert vdb_core.client.indices.refresh.call_count == 3 - def test_force_refresh_with_retry_max_retries_exceeded(self, es_core): + def test_force_refresh_with_retry_max_retries_exceeded(self, vdb_core): """Test _force_refresh_with_retry when max retries exceeded""" - es_core.client = MagicMock() - es_core.client.indices.refresh.side_effect = Exception("Persistent error") + vdb_core.client = MagicMock() + vdb_core.client.indices.refresh.side_effect = Exception("Persistent error") with patch('time.sleep'): # Mock sleep to speed up test - result = es_core._force_refresh_with_retry("test_index", max_retries=2) + result = vdb_core._force_refresh_with_retry("test_index", max_retries=2) assert result is False - assert es_core.client.indices.refresh.call_count == 2 + assert vdb_core.client.indices.refresh.call_count == 2 - def test_ensure_index_ready_success(self, es_core): + def test_ensure_index_ready_success(self, vdb_core): """Test _ensure_index_ready successful case""" - es_core.client = MagicMock() - es_core.client.cluster.health.return_value = {"status": "green"} - es_core.client.search.return_value = {"hits": {"total": {"value": 0}}} + vdb_core.client = MagicMock() + vdb_core.client.cluster.health.return_value = {"status": "green"} + vdb_core.client.search.return_value = {"hits": {"total": {"value": 0}}} - result = es_core._ensure_index_ready("test_index") + result = vdb_core._ensure_index_ready("test_index") assert result is True - def test_ensure_index_ready_yellow_status(self, es_core): + def test_ensure_index_ready_yellow_status(self, vdb_core): """Test _ensure_index_ready with yellow status""" - es_core.client = MagicMock() - es_core.client.cluster.health.return_value = 
{"status": "yellow"} - es_core.client.search.return_value = {"hits": {"total": {"value": 0}}} + vdb_core.client = MagicMock() + vdb_core.client.cluster.health.return_value = {"status": "yellow"} + vdb_core.client.search.return_value = {"hits": {"total": {"value": 0}}} - result = es_core._ensure_index_ready("test_index") + result = vdb_core._ensure_index_ready("test_index") assert result is True - def test_ensure_index_ready_timeout(self, es_core): + def test_ensure_index_ready_timeout(self, vdb_core): """Test _ensure_index_ready timeout scenario""" - es_core.client = MagicMock() - es_core.client.cluster.health.return_value = {"status": "red"} + vdb_core.client = MagicMock() + vdb_core.client.cluster.health.return_value = {"status": "red"} with patch('time.sleep'): # Mock sleep to speed up test - result = es_core._ensure_index_ready("test_index", timeout=1) + result = vdb_core._ensure_index_ready("test_index", timeout=1) assert result is False - def test_ensure_index_ready_exception(self, es_core): + def test_ensure_index_ready_exception(self, vdb_core): """Test _ensure_index_ready with exception""" - es_core.client = MagicMock() - es_core.client.cluster.health.side_effect = Exception("Connection error") + vdb_core.client = MagicMock() + vdb_core.client.cluster.health.side_effect = Exception("Connection error") with patch('time.sleep'): # Mock sleep to speed up test - result = es_core._ensure_index_ready("test_index", timeout=1) + result = vdb_core._ensure_index_ready("test_index", timeout=1) assert result is False - def test_apply_bulk_settings_success(self, es_core): + def test_apply_bulk_settings_success(self, vdb_core): """Test _apply_bulk_settings successful case""" - es_core.client = MagicMock() - es_core.client.indices.put_settings.return_value = {"acknowledged": True} + vdb_core.client = MagicMock() + vdb_core.client.indices.put_settings.return_value = {"acknowledged": True} - es_core._apply_bulk_settings("test_index") - 
es_core.client.indices.put_settings.assert_called_once() + vdb_core._apply_bulk_settings("test_index") + vdb_core.client.indices.put_settings.assert_called_once() - def test_apply_bulk_settings_failure(self, es_core): + def test_apply_bulk_settings_failure(self, vdb_core): """Test _apply_bulk_settings with exception""" - es_core.client = MagicMock() - es_core.client.indices.put_settings.side_effect = Exception("Settings error") + vdb_core.client = MagicMock() + vdb_core.client.indices.put_settings.side_effect = Exception("Settings error") # Should not raise exception, just log warning - es_core._apply_bulk_settings("test_index") - es_core.client.indices.put_settings.assert_called_once() + vdb_core._apply_bulk_settings("test_index") + vdb_core.client.indices.put_settings.assert_called_once() - def test_restore_normal_settings_success(self, es_core): + def test_restore_normal_settings_success(self, vdb_core): """Test _restore_normal_settings successful case""" - es_core.client = MagicMock() - es_core.client.indices.put_settings.return_value = {"acknowledged": True} - es_core._force_refresh_with_retry = MagicMock(return_value=True) + vdb_core.client = MagicMock() + vdb_core.client.indices.put_settings.return_value = {"acknowledged": True} + vdb_core._force_refresh_with_retry = MagicMock(return_value=True) - es_core._restore_normal_settings("test_index") - es_core.client.indices.put_settings.assert_called_once() - es_core._force_refresh_with_retry.assert_called_once_with("test_index") + vdb_core._restore_normal_settings("test_index") + vdb_core.client.indices.put_settings.assert_called_once() + vdb_core._force_refresh_with_retry.assert_called_once_with("test_index") - def test_restore_normal_settings_failure(self, es_core): + def test_restore_normal_settings_failure(self, vdb_core): """Test _restore_normal_settings with exception""" - es_core.client = MagicMock() - es_core.client.indices.put_settings.side_effect = Exception("Settings error") + vdb_core.client = 
MagicMock() + vdb_core.client.indices.put_settings.side_effect = Exception("Settings error") # Should not raise exception, just log warning - es_core._restore_normal_settings("test_index") - es_core.client.indices.put_settings.assert_called_once() + vdb_core._restore_normal_settings("test_index") + vdb_core.client.indices.put_settings.assert_called_once() - def test_delete_index_success(self, es_core): + def test_delete_index_success(self, vdb_core): """Test delete_index successful case""" - es_core.client = MagicMock() - es_core.client.indices.delete.return_value = {"acknowledged": True} + vdb_core.client = MagicMock() + vdb_core.client.indices.delete.return_value = {"acknowledged": True} - result = es_core.delete_index("test_index") + result = vdb_core.delete_index("test_index") assert result is True - es_core.client.indices.delete.assert_called_once_with(index="test_index") + vdb_core.client.indices.delete.assert_called_once_with(index="test_index") - def test_delete_index_not_found(self, es_core): + def test_delete_index_not_found(self, vdb_core): """Test delete_index when index not found""" - es_core.client = MagicMock() + vdb_core.client = MagicMock() # Create a proper NotFoundError with required parameters not_found_error = exceptions.NotFoundError(404, "Index not found", {"error": {"type": "index_not_found_exception"}}) - es_core.client.indices.delete.side_effect = not_found_error + vdb_core.client.indices.delete.side_effect = not_found_error - result = es_core.delete_index("test_index") + result = vdb_core.delete_index("test_index") assert result is False - es_core.client.indices.delete.assert_called_once_with(index="test_index") + vdb_core.client.indices.delete.assert_called_once_with(index="test_index") - def test_delete_index_general_exception(self, es_core): + def test_delete_index_general_exception(self, vdb_core): """Test delete_index with general exception""" - es_core.client = MagicMock() - es_core.client.indices.delete.side_effect = 
Exception("General error") + vdb_core.client = MagicMock() + vdb_core.client.indices.delete.side_effect = Exception("General error") - result = es_core.delete_index("test_index") + result = vdb_core.delete_index("test_index") assert result is False - es_core.client.indices.delete.assert_called_once_with(index="test_index") + vdb_core.client.indices.delete.assert_called_once_with(index="test_index") - def test_handle_bulk_errors_no_errors(self, es_core): + def test_handle_bulk_errors_no_errors(self, vdb_core): """Test _handle_bulk_errors when no errors in response""" response = {"errors": False, "items": []} - es_core._handle_bulk_errors(response) + vdb_core._handle_bulk_errors(response) # Should not raise any exceptions - def test_handle_bulk_errors_with_version_conflict(self, es_core): + def test_handle_bulk_errors_with_version_conflict(self, vdb_core): """Test _handle_bulk_errors with version conflict (should be ignored)""" response = { "errors": True, @@ -192,10 +193,10 @@ def test_handle_bulk_errors_with_version_conflict(self, es_core): } ] } - es_core._handle_bulk_errors(response) + vdb_core._handle_bulk_errors(response) # Should not raise any exceptions for version conflicts - def test_handle_bulk_errors_with_fatal_error(self, es_core): + def test_handle_bulk_errors_with_fatal_error(self, vdb_core): """Test _handle_bulk_errors with fatal error""" response = { "errors": True, @@ -214,10 +215,10 @@ def test_handle_bulk_errors_with_fatal_error(self, es_core): } ] } - es_core._handle_bulk_errors(response) + vdb_core._handle_bulk_errors(response) # Should log error but not raise exception - def test_handle_bulk_errors_with_caused_by(self, es_core): + def test_handle_bulk_errors_with_caused_by(self, vdb_core): """Test _handle_bulk_errors with caused_by information""" response = { "errors": True, @@ -236,60 +237,289 @@ def test_handle_bulk_errors_with_caused_by(self, es_core): } ] } - es_core._handle_bulk_errors(response) + vdb_core._handle_bulk_errors(response) # 
Should log both main error and caused_by error - def test_delete_documents_by_path_or_url_success(self, es_core): - """Test delete_documents_by_path_or_url successful case""" - es_core.client = MagicMock() - es_core.client.delete_by_query.return_value = {"deleted": 5} + def test_delete_documents_success(self, vdb_core): + """Test delete_documents successful case""" + vdb_core.client = MagicMock() + vdb_core.client.delete_by_query.return_value = {"deleted": 5} - result = es_core.delete_documents_by_path_or_url("test_index", "/path/to/file.pdf") + result = vdb_core.delete_documents("test_index", "/path/to/file.pdf") assert result == 5 - es_core.client.delete_by_query.assert_called_once() + vdb_core.client.delete_by_query.assert_called_once() - def test_delete_documents_by_path_or_url_exception(self, es_core): - """Test delete_documents_by_path_or_url with exception""" - es_core.client = MagicMock() - es_core.client.delete_by_query.side_effect = Exception("Delete error") + def test_delete_documents_exception(self, vdb_core): + """Test delete_documents with exception""" + vdb_core.client = MagicMock() + vdb_core.client.delete_by_query.side_effect = Exception("Delete error") - result = es_core.delete_documents_by_path_or_url("test_index", "/path/to/file.pdf") + result = vdb_core.delete_documents("test_index", "/path/to/file.pdf") assert result == 0 - es_core.client.delete_by_query.assert_called_once() - - def test_get_index_mapping_success(self, es_core): - """Test get_index_mapping successful case""" - es_core.client = MagicMock() - es_core.client.indices.get_mapping.return_value = { - "test_index": { - "mappings": { - "properties": { - "title": {"type": "text"}, - "content": {"type": "text"} + vdb_core.client.delete_by_query.assert_called_once() + + def test_get_index_chunks_not_found(self, vdb_core): + """Ensure get_index_chunks handles missing index gracefully.""" + vdb_core.client = MagicMock() + vdb_core.client.count.side_effect = exceptions.NotFoundError( + 404, 
"missing", {}) + + result = vdb_core.get_index_chunks("missing-index") + + assert result == {"chunks": [], "total": 0, + "page": None, "page_size": None} + vdb_core.client.clear_scroll.assert_not_called() + + def test_get_index_chunks_cleanup_warning(self, vdb_core): + """Ensure clear_scroll errors are swallowed.""" + vdb_core.client = MagicMock() + vdb_core.client.count.return_value = {"count": 1} + vdb_core.client.search.return_value = { + "_scroll_id": "scroll123", + "hits": {"hits": [{"_id": "doc-1", "_source": {"content": "A"}}]} + } + vdb_core.client.scroll.return_value = { + "_scroll_id": "scroll123", + "hits": {"hits": []} + } + vdb_core.client.clear_scroll.side_effect = Exception("cleanup-failed") + + result = vdb_core.get_index_chunks("kb-index") + + assert len(result["chunks"]) == 1 + assert result["chunks"][0]["id"] == "doc-1" + vdb_core.client.clear_scroll.assert_called_once_with( + scroll_id="scroll123") + + def test_create_index_request_error_existing(self, vdb_core): + """Ensure RequestError with resource already exists still succeeds.""" + vdb_core.client = MagicMock() + vdb_core.client.indices.exists.return_value = False + meta = MagicMock(status=400) + vdb_core.client.indices.create.side_effect = exceptions.RequestError( + "resource_already_exists_exception", meta, {"error": {"reason": "exists"}} + ) + vdb_core._ensure_index_ready = MagicMock(return_value=True) + + assert vdb_core.create_index("test_index") is True + vdb_core._ensure_index_ready.assert_called_once_with("test_index") + + def test_create_index_request_error_failure(self, vdb_core): + """Ensure create_index returns False for non recoverable RequestError.""" + vdb_core.client = MagicMock() + vdb_core.client.indices.exists.return_value = False + meta = MagicMock(status=400) + vdb_core.client.indices.create.side_effect = exceptions.RequestError( + "validation_exception", meta, {"error": {"reason": "bad"}} + ) + + assert vdb_core.create_index("test_index") is False + + def 
test_create_index_general_exception(self, vdb_core): + """Ensure unexpected exception from create_index returns False.""" + vdb_core.client = MagicMock() + vdb_core.client.indices.exists.return_value = False + vdb_core.client.indices.create.side_effect = Exception("boom") + + assert vdb_core.create_index("test_index") is False + + def test_force_refresh_with_retry_zero_attempts(self, vdb_core): + """Ensure guard clause without attempts returns False.""" + vdb_core.client = MagicMock() + result = vdb_core._force_refresh_with_retry("idx", max_retries=0) + assert result is False + + def test_bulk_operation_context_preexisting_operation(self, vdb_core): + """Ensure context skips apply/restore when operations remain.""" + existing = BulkOperation( + index_name="test_index", + operation_id="existing", + start_time=datetime.utcnow(), + expected_duration=timedelta(seconds=30), + ) + vdb_core._bulk_operations = {"test_index": [existing]} + + with patch.object(vdb_core, "_apply_bulk_settings") as mock_apply, \ + patch.object(vdb_core, "_restore_normal_settings") as mock_restore: + + with vdb_core.bulk_operation_context("test_index") as op_id: + assert op_id != existing.operation_id + + mock_apply.assert_not_called() + mock_restore.assert_not_called() + assert vdb_core._bulk_operations["test_index"] == [existing] + + def test_get_user_indices_exception(self, vdb_core): + """Ensure get_user_indices returns empty list on failure.""" + vdb_core.client = MagicMock() + vdb_core.client.indices.get_alias.side_effect = Exception("failure") + + assert vdb_core.get_user_indices() == [] + + def test_check_index_exists(self, vdb_core): + """Ensure check_index_exists delegates to client.""" + vdb_core.client = MagicMock() + vdb_core.client.indices.exists.return_value = True + + assert vdb_core.check_index_exists("idx") is True + vdb_core.client.indices.exists.assert_called_once_with(index="idx") + + def test_small_batch_insert_sets_embedding_model_name(self, vdb_core): + 
"""_small_batch_insert should attach embedding model name.""" + vdb_core.client = MagicMock() + vdb_core.client.bulk.return_value = {"errors": False, "items": []} + vdb_core._preprocess_documents = MagicMock(return_value=[{"content": "body"}]) + vdb_core._handle_bulk_errors = MagicMock() + + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1, 0.2]] + mock_embedding_model.embedding_model_name = "demo-model" + + vdb_core._small_batch_insert("idx", [{"content": "body"}], "content", mock_embedding_model) + operations = vdb_core.client.bulk.call_args.kwargs["operations"] + inserted_doc = operations[1] + assert inserted_doc["embedding_model_name"] == "demo-model" + + def test_large_batch_insert_sets_default_embedding_model_name(self, vdb_core): + """_large_batch_insert should fall back to 'unknown' when attr missing.""" + vdb_core.client = MagicMock() + vdb_core.client.bulk.return_value = {"errors": False, "items": []} + vdb_core._preprocess_documents = MagicMock(return_value=[{"content": "body"}]) + vdb_core._handle_bulk_errors = MagicMock() + + class SimpleEmbedding: + def get_embeddings(self, texts): + return [[0.1 for _ in texts]] + + embedding_model = SimpleEmbedding() + + vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", embedding_model) + operations = vdb_core.client.bulk.call_args.kwargs["operations"] + inserted_doc = operations[1] + assert inserted_doc["embedding_model_name"] == "unknown" + + def test_large_batch_insert_bulk_exception(self, vdb_core): + """Ensure bulk exceptions are handled and indexing continues.""" + vdb_core.client = MagicMock() + vdb_core.client.bulk.side_effect = Exception("bulk error") + vdb_core._preprocess_documents = MagicMock(return_value=[{"content": "body"}]) + + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1]] + + result = vdb_core._large_batch_insert("idx", [{"content": "body"}], 1, "content", mock_embedding_model) + 
assert result == 0 + + def test_large_batch_insert_preprocess_exception(self, vdb_core): + """Ensure outer exception handler returns zero on preprocess failure.""" + vdb_core._preprocess_documents = MagicMock(side_effect=Exception("fail")) + + mock_embedding_model = MagicMock() + result = vdb_core._large_batch_insert("idx", [{"content": "body"}], 10, "content", mock_embedding_model) + assert result == 0 + + def test_count_documents_success(self, vdb_core): + """Ensure count_documents returns ES count.""" + vdb_core.client = MagicMock() + vdb_core.client.count.return_value = {"count": 42} + + assert vdb_core.count_documents("idx") == 42 + + def test_count_documents_exception(self, vdb_core): + """Ensure count_documents returns zero on error.""" + vdb_core.client = MagicMock() + vdb_core.client.count.side_effect = Exception("fail") + + assert vdb_core.count_documents("idx") == 0 + + def test_search_and_multi_search_passthrough(self, vdb_core): + """Ensure search helpers delegate to the client.""" + vdb_core.client = MagicMock() + vdb_core.client.search.return_value = {"hits": {}} + vdb_core.client.msearch.return_value = {"responses": []} + + assert vdb_core.search("idx", {"query": {"match_all": {}}}) == {"hits": {}} + assert vdb_core.multi_search([{"query": {"match_all": {}}}], "idx") == {"responses": []} + + def test_exec_query_formats_results(self, vdb_core): + """Ensure exec_query strips metadata and exposes scores.""" + vdb_core.client = MagicMock() + vdb_core.client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 1.23, + "_index": "idx", + "_source": {"id": "doc1", "content": "body"}, } - } + ] } } - - result = es_core.get_index_mapping(["test_index"]) - assert "test_index" in result - assert "title" in result["test_index"] - assert "content" in result["test_index"] - - def test_get_index_mapping_exception(self, es_core): - """Test get_index_mapping with exception""" - es_core.client = MagicMock() - es_core.client.indices.get_mapping.side_effect 
= Exception("Mapping error") - - result = es_core.get_index_mapping(["test_index"]) - # The function returns empty list for failed indices, not empty dict - assert "test_index" in result - assert result["test_index"] == [] - - def test_get_index_stats_success(self, es_core): - """Test get_index_stats successful case""" - es_core.client = MagicMock() - es_core.client.indices.stats.return_value = { + + results = vdb_core.exec_query("idx", {"query": {}}) + assert results == [ + {"score": 1.23, "document": {"id": "doc1", "content": "body"}, "index": "idx"} + ] + + def test_hybrid_search_missing_fields_logged_for_accurate(self, vdb_core): + """Ensure hybrid_search tolerates missing accurate fields.""" + mock_embedding_model = MagicMock() + with patch.object(vdb_core, "accurate_search", return_value=[{"score": 1.0}]), \ + patch.object(vdb_core, "semantic_search", return_value=[]): + assert vdb_core.hybrid_search(["idx"], "query", mock_embedding_model) == [] + + def test_hybrid_search_missing_fields_logged_for_semantic(self, vdb_core): + """Ensure hybrid_search tolerates missing semantic fields.""" + mock_embedding_model = MagicMock() + with patch.object(vdb_core, "accurate_search", return_value=[]), \ + patch.object(vdb_core, "semantic_search", return_value=[{"score": 0.5}]): + assert vdb_core.hybrid_search(["idx"], "query", mock_embedding_model) == [] + + def test_hybrid_search_faulty_combined_results(self, vdb_core): + """Inject faulty combined result to hit KeyError handling in final loop.""" + mock_embedding_model = MagicMock() + accurate_payload = [ + {"score": 1.0, "document": {"id": "doc1"}, "index": "idx"} + ] + + with patch.object(vdb_core, "accurate_search", return_value=accurate_payload), \ + patch.object(vdb_core, "semantic_search", return_value=[]): + + injected = {"done": False} + + def tracer(frame, event, arg): + if ( + frame.f_code.co_name == "hybrid_search" + and event == "line" + and frame.f_lineno == 788 + and not injected["done"] + ): + 
frame.f_locals["combined_results"]["faulty"] = { + "accurate_score": 0, + "semantic_score": 0, + } + injected["done"] = True + return tracer + + sys.settrace(tracer) + try: + results = vdb_core.hybrid_search(["idx"], "query", mock_embedding_model) + finally: + sys.settrace(None) + + assert len(results) == 1 + + def test_get_documents_detail_exception(self, vdb_core): + """Ensure get_documents_detail returns empty list on failure.""" + vdb_core.client = MagicMock() + vdb_core.client.search.side_effect = Exception("fail") + + assert vdb_core.get_documents_detail("idx") == [] + + def test_get_indices_detail_success(self, vdb_core): + """Test get_indices_detail successful case""" + vdb_core.client = MagicMock() + vdb_core.client.indices.stats.return_value = { "indices": { "test_index": { "primaries": { @@ -301,7 +531,7 @@ def test_get_index_stats_success(self, es_core): } } } - es_core.client.indices.get_settings.return_value = { + vdb_core.client.indices.get_settings.return_value = { "test_index": { "settings": { "index": { @@ -312,7 +542,7 @@ def test_get_index_stats_success(self, es_core): } } } - es_core.client.search.return_value = { + vdb_core.client.search.return_value = { "aggregations": { "unique_path_or_url_count": {"value": 10}, "process_sources": {"buckets": [{"key": "test_source"}]}, @@ -320,25 +550,25 @@ def test_get_index_stats_success(self, es_core): } } - result = es_core.get_index_stats(["test_index"]) + result = vdb_core.get_indices_detail(["test_index"]) assert "test_index" in result assert "base_info" in result["test_index"] assert "search_performance" in result["test_index"] - def test_get_index_stats_exception(self, es_core): - """Test get_index_stats with exception""" - es_core.client = MagicMock() - es_core.client.indices.stats.side_effect = Exception("Stats error") + def test_get_indices_detail_exception(self, vdb_core): + """Test get_indices_detail with exception""" + vdb_core.client = MagicMock() + vdb_core.client.indices.stats.side_effect = 
Exception("Stats error") - result = es_core.get_index_stats(["test_index"]) + result = vdb_core.get_indices_detail(["test_index"]) # The function returns error info for failed indices, not empty dict assert "test_index" in result assert "error" in result["test_index"] - def test_get_index_stats_with_embedding_dim(self, es_core): - """Test get_index_stats with embedding dimension""" - es_core.client = MagicMock() - es_core.client.indices.stats.return_value = { + def test_get_indices_detail_with_embedding_dim(self, vdb_core): + """Test get_indices_detail with embedding dimension""" + vdb_core.client = MagicMock() + vdb_core.client.indices.stats.return_value = { "indices": { "test_index": { "primaries": { @@ -350,7 +580,7 @@ def test_get_index_stats_with_embedding_dim(self, es_core): } } } - es_core.client.indices.get_settings.return_value = { + vdb_core.client.indices.get_settings.return_value = { "test_index": { "settings": { "index": { @@ -361,7 +591,7 @@ def test_get_index_stats_with_embedding_dim(self, es_core): } } } - es_core.client.search.return_value = { + vdb_core.client.search.return_value = { "aggregations": { "unique_path_or_url_count": {"value": 10}, "process_sources": {"buckets": [{"key": "test_source"}]}, @@ -369,61 +599,61 @@ def test_get_index_stats_with_embedding_dim(self, es_core): } } - result = es_core.get_index_stats(["test_index"], embedding_dim=512) + result = vdb_core.get_indices_detail(["test_index"], embedding_dim=512) assert "test_index" in result assert "base_info" in result["test_index"] assert "search_performance" in result["test_index"] assert result["test_index"]["base_info"]["embedding_dim"] == 512 - def test_bulk_operation_context_success(self, es_core): + def test_bulk_operation_context_success(self, vdb_core): """Test bulk_operation_context successful case""" - es_core._bulk_operations = {} - es_core._operation_counter = 0 - es_core._settings_lock = MagicMock() - es_core._apply_bulk_settings = MagicMock() - 
es_core._restore_normal_settings = MagicMock() + vdb_core._bulk_operations = {} + vdb_core._operation_counter = 0 + vdb_core._settings_lock = MagicMock() + vdb_core._apply_bulk_settings = MagicMock() + vdb_core._restore_normal_settings = MagicMock() - with es_core.bulk_operation_context("test_index") as operation_id: + with vdb_core.bulk_operation_context("test_index") as operation_id: assert operation_id is not None - assert "test_index" in es_core._bulk_operations - es_core._apply_bulk_settings.assert_called_once_with("test_index") + assert "test_index" in vdb_core._bulk_operations + vdb_core._apply_bulk_settings.assert_called_once_with("test_index") # After context exit, should restore settings - es_core._restore_normal_settings.assert_called_once_with("test_index") + vdb_core._restore_normal_settings.assert_called_once_with("test_index") - def test_bulk_operation_context_multiple_operations(self, es_core): + def test_bulk_operation_context_multiple_operations(self, vdb_core): """Test bulk_operation_context with multiple operations""" - es_core._bulk_operations = {} - es_core._operation_counter = 0 - es_core._settings_lock = MagicMock() - es_core._apply_bulk_settings = MagicMock() - es_core._restore_normal_settings = MagicMock() + vdb_core._bulk_operations = {} + vdb_core._operation_counter = 0 + vdb_core._settings_lock = MagicMock() + vdb_core._apply_bulk_settings = MagicMock() + vdb_core._restore_normal_settings = MagicMock() # First operation - with es_core.bulk_operation_context("test_index") as op1: + with vdb_core.bulk_operation_context("test_index") as op1: assert op1 is not None - es_core._apply_bulk_settings.assert_called_once() + vdb_core._apply_bulk_settings.assert_called_once() # After first operation exits, settings should be restored - es_core._restore_normal_settings.assert_called_once_with("test_index") + vdb_core._restore_normal_settings.assert_called_once_with("test_index") # Second operation - will apply settings again since first operation is 
done - with es_core.bulk_operation_context("test_index") as op2: + with vdb_core.bulk_operation_context("test_index") as op2: assert op2 is not None # Should call apply_bulk_settings again since first operation is done - assert es_core._apply_bulk_settings.call_count == 2 + assert vdb_core._apply_bulk_settings.call_count == 2 # After second operation exits, should restore settings again - assert es_core._restore_normal_settings.call_count == 2 + assert vdb_core._restore_normal_settings.call_count == 2 - def test_small_batch_insert_success(self, es_core): + def test_small_batch_insert_success(self, vdb_core): """Test _small_batch_insert successful case""" - es_core.client = MagicMock() - es_core.client.bulk.return_value = {"items": [], "errors": False} - es_core._preprocess_documents = MagicMock(return_value=[ + vdb_core.client = MagicMock() + vdb_core.client.bulk.return_value = {"items": [], "errors": False} + vdb_core._preprocess_documents = MagicMock(return_value=[ {"content": "test content", "title": "test"} ]) - es_core._handle_bulk_errors = MagicMock() + vdb_core._handle_bulk_errors = MagicMock() mock_embedding_model = MagicMock() mock_embedding_model.get_embeddings.return_value = [[0.1, 0.2, 0.3]] @@ -431,28 +661,28 @@ def test_small_batch_insert_success(self, es_core): documents = [{"content": "test content", "title": "test"}] - result = es_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) + result = vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) assert result == 1 - es_core.client.bulk.assert_called_once() + vdb_core.client.bulk.assert_called_once() - def test_small_batch_insert_exception(self, es_core): + def test_small_batch_insert_exception(self, vdb_core): """Test _small_batch_insert with exception""" - es_core._preprocess_documents = MagicMock(side_effect=Exception("Preprocess error")) + vdb_core._preprocess_documents = MagicMock(side_effect=Exception("Preprocess error")) 
mock_embedding_model = MagicMock() documents = [{"content": "test content", "title": "test"}] - result = es_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) + result = vdb_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) assert result == 0 - def test_large_batch_insert_success(self, es_core): + def test_large_batch_insert_success(self, vdb_core): """Test _large_batch_insert successful case""" - es_core.client = MagicMock() - es_core.client.bulk.return_value = {"items": [], "errors": False} - es_core._preprocess_documents = MagicMock(return_value=[ + vdb_core.client = MagicMock() + vdb_core.client.bulk.return_value = {"items": [], "errors": False} + vdb_core._preprocess_documents = MagicMock(return_value=[ {"content": "test content", "title": "test"} ]) - es_core._handle_bulk_errors = MagicMock() + vdb_core._handle_bulk_errors = MagicMock() mock_embedding_model = MagicMock() mock_embedding_model.get_embeddings.return_value = [[0.1, 0.2, 0.3]] @@ -460,14 +690,14 @@ def test_large_batch_insert_success(self, es_core): documents = [{"content": "test content", "title": "test"}] - result = es_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) + result = vdb_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) assert result == 1 - es_core.client.bulk.assert_called_once() + vdb_core.client.bulk.assert_called_once() - def test_large_batch_insert_embedding_error(self, es_core): + def test_large_batch_insert_embedding_error(self, vdb_core): """Test _large_batch_insert with embedding API error""" - es_core.client = MagicMock() - es_core._preprocess_documents = MagicMock(return_value=[ + vdb_core.client = MagicMock() + vdb_core._preprocess_documents = MagicMock(return_value=[ {"content": "test content", "title": "test"} ]) @@ -476,13 +706,13 @@ def test_large_batch_insert_embedding_error(self, es_core): documents = [{"content": "test content", 
"title": "test"}] - result = es_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) + result = vdb_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) assert result == 0 # No documents indexed due to embedding error - def test_large_batch_insert_no_embeddings(self, es_core): + def test_large_batch_insert_no_embeddings(self, vdb_core): """Test _large_batch_insert with no successful embeddings""" - es_core.client = MagicMock() - es_core._preprocess_documents = MagicMock(return_value=[ + vdb_core.client = MagicMock() + vdb_core._preprocess_documents = MagicMock(return_value=[ {"content": "test content", "title": "test"} ]) @@ -491,5 +721,5 @@ def test_large_batch_insert_no_embeddings(self, es_core): documents = [{"content": "test content", "title": "test"}] - result = es_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) + result = vdb_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) assert result == 0 # No documents indexed