From e7ae3dbe9fa0383cfacedd4d7d4fc3afec7c6482 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Tue, 27 Jan 2026 16:00:14 +0800 Subject: [PATCH 001/167] =?UTF-8?q?=E2=9C=A8=20Refactor=20KnowledgeBaseSea?= =?UTF-8?q?rchTool=20and=20DataMateSearchTool=20to=20streamline=20index=20?= =?UTF-8?q?name=20handling=20and=20improve=20error=20logging.=20Update=20U?= =?UTF-8?q?I=20localization=20for=20knowledge=20base=20selection=20prompts?= =?UTF-8?q?.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/agents/create_agent_info.py | 34 +- .../services/tool_configuration_service.py | 42 --- frontend/public/locales/en/common.json | 2 + frontend/public/locales/zh/common.json | 2 + frontend/services/knowledgeBaseService.ts | 291 +++++++++++++----- sdk/nexent/core/agents/nexent_agent.py | 11 +- sdk/nexent/core/tools/datamate_search_tool.py | 51 +-- 7 files changed, 243 insertions(+), 190 deletions(-) diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index d09029a97..4fd2411a7 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -15,7 +15,6 @@ get_vector_db_core, get_embedding_model, ) -from services.tenant_config_service import get_selected_knowledge_list, build_knowledge_name_mapping from services.remote_mcp_service import get_remote_mcp_server_list from services.memory_config_service import build_memory_context from services.image_service import get_vlm_model @@ -146,20 +145,16 @@ async def create_agent_config( try: for tool in tool_list: if "KnowledgeBaseSearchTool" == tool.class_name: - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - if knowledge_info_list: - for knowledge_info in knowledge_info_list: - if knowledge_info.get('knowledge_sources') != 'elasticsearch': - continue - knowledge_name = knowledge_info.get("index_name") + index_names = tool.params.get("index_names") + if index_names: + for index_name in index_names: try: - message = ElasticSearchService().get_summary(index_name=knowledge_name) + message = ElasticSearchService().get_summary(index_name=index_name) summary = message.get("summary", "") - knowledge_base_summary += f"**{knowledge_name}**: {summary}\n\n" + knowledge_base_summary += f"**{index_name}**: {summary}\n\n" except Exception as e: logger.warning( - f"Failed to get summary for knowledge base {knowledge_name}: {e}") + f"Failed to get summary for knowledge base {index_name}: {e}") else: # TODO: Prompt should be refactored to yaml file knowledge_base_summary = "当前没有可用的知识库索引。\n" if language == 'zh' else "No knowledge base indexes are currently available.\n" @@ -238,24 +233,9 @@ async def create_tool_config_list(agent_id, tenant_id, user_id): # special logic for knowledge base search tool if tool_config.class_name == "KnowledgeBaseSearchTool": - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - index_names = [knowledge_info.get( - "index_name") for knowledge_info in knowledge_info_list if knowledge_info.get('knowledge_sources') == 'elasticsearch'] - tool_config.metadata = { - "index_names": index_names, + tool_config.metadata = { "vdb_core": get_vector_db_core(), "embedding_model": get_embedding_model(tenant_id=tenant_id), - "name_resolver": build_knowledge_name_mapping(tenant_id=tenant_id, user_id=user_id), - } - elif tool_config.class_name == "DataMateSearchTool": - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - index_names = [knowledge_info.get( - "index_name") for knowledge_info in knowledge_info_list if - knowledge_info.get('knowledge_sources') == 'datamate'] - tool_config.metadata = { - "index_names": index_names, } elif tool_config.class_name == "AnalyzeTextFileTool": tool_config.metadata = { diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index e7b39af3b..e0019860a 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -24,8 +24,6 @@ ) from services.file_management_service import get_llm_model from services.vectordatabase_service import get_embedding_model, get_vector_db_core -from services.tenant_config_service import get_selected_knowledge_list, build_knowledge_name_mapping -from database.knowledge_db import get_index_name_by_knowledge_name from database.client import minio_client from services.image_service import get_vlm_model @@ -537,54 +535,14 @@ def _validate_local_tool( instantiation_params[param_name] = param.default if tool_name == "knowledge_base_search": - if not tenant_id or not user_id: - raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation") - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - index_names = [knowledge_info.get("index_name") for knowledge_info in knowledge_info_list if knowledge_info.get('knowledge_sources') == 'elasticsearch'] - name_resolver = build_knowledge_name_mapping( - tenant_id=tenant_id, user_id=user_id) - - # Fallback: if user provided index_names in inputs, try to resolve them even when no selection stored - if (not index_names) and inputs and inputs.get("index_names"): - raw_names = inputs.get("index_names") - if isinstance(raw_names, str): - raw_names = [raw_names] - resolved_indices = [] - for raw in raw_names: - try: - resolved = get_index_name_by_knowledge_name( - raw, tenant_id=tenant_id) - name_resolver[raw] = resolved - resolved_indices.append(resolved) - except Exception: - # If not found as knowledge_name, assume it's already an index_name - resolved_indices.append(raw) - index_names = resolved_indices - embedding_model = get_embedding_model(tenant_id=tenant_id) vdb_core = get_vector_db_core() params = { **instantiation_params, - 'index_names': index_names, - 'name_resolver': name_resolver, 'vdb_core': vdb_core, 'embedding_model': embedding_model, } tool_instance = tool_class(**params) - elif tool_name == "datamate_search": - if not tenant_id or not user_id: - raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation") - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - index_names = [knowledge_info.get("index_name") for knowledge_info in knowledge_info_list if - knowledge_info.get('knowledge_sources') == 'datamate'] - - params = { - **instantiation_params, - 'index_names': index_names, - } - tool_instance = tool_class(**params) elif tool_name == "analyze_image": if not tenant_id or not user_id: raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation") diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index e0e227c18..6fb1d9cd1 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -384,6 +384,8 @@ "toolConfig.input.model.placeholder": "Please select model", "toolConfig.input.string.placeholder": "Please enter {{name}}", "toolConfig.input.array.placeholder": "Please enter JSON array", + "toolConfig.placeholder.selectKb": "Please select knowledge bases", + "toolConfig.validation.selectKb": "Please select at least one knowledge base", "toolConfig.input.object.placeholder": "Please enter JSON object", "toolConfig.toolTest.toolInfo": "Tool Information", "toolConfig.toolTest.configParams": "Parameter Configuration", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index d9acff57c..2d2763d9f 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -385,6 +385,8 @@ "toolConfig.input.model.placeholder": "请选择模型", "toolConfig.input.string.placeholder": "请输入{{name}}", "toolConfig.input.array.placeholder": "请输入JSON数组", + "toolConfig.placeholder.selectKb": "请选择知识库", + "toolConfig.validation.selectKb": "请选择至少一个知识库", "toolConfig.input.object.placeholder": "请输入JSON对象", "toolConfig.toolTest.toolInfo": "工具信息", "toolConfig.toolTest.configParams": "配置参数", diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts index 79ae671ee..aca17c8a7 100644 --- a/frontend/services/knowledgeBaseService.ts +++ b/frontend/services/knowledgeBaseService.ts @@ -17,6 +17,20 @@ import log from "@/lib/logger"; // @ts-ignore const fetch: typeof fetchWithAuth = fetchWithAuth; +// Simple in-memory cache and in-flight dedupe for KB info +let cachedKbInfo: KnowledgeBase[] | null = null; +let kbCacheExpiry = 0; +let kbInFlightPromise: Promise | null = null; +const KB_CACHE_TTL_MS = 30 * 1000; // 30s +// In-flight dedupe and short cache for DataMate sync +let datamateInFlightPromise: Promise | null = null; +let datamateCache: any = null; +let datamateCacheExpiry = 0; +const DATAMATE_CACHE_TTL_MS = 30 * 1000; // 30s +// Persistent id->name map storage +const KB_ID_NAME_MAP_KEY = "kb_id_name_map_v1"; +const KB_ID_NAME_MAP_TTL_MS = 24 * 60 * 60 * 1000; // 24 hours + // Knowledge base service class class KnowledgeBaseService { // Check Elasticsearch health (force refresh, no caching for setup page) @@ -48,50 +62,86 @@ class KnowledgeBaseService { indices_info: any[]; created_records: any[]; }> { - try { - const response = await fetch( - API_ENDPOINTS.datamate.syncDatamateKnowledges, - { - method: "POST", - headers: getAuthHeaders(), + // Return cached result when fresh + const now = Date.now(); + if (datamateCache && now < datamateCacheExpiry) { + return datamateCache; + } + + if (datamateInFlightPromise) return datamateInFlightPromise; + + datamateInFlightPromise = (async () => { + try { + const response = await fetch( + API_ENDPOINTS.datamate.syncDatamateKnowledges, + { + method: "POST", + headers: getAuthHeaders(), + } + ); + + const data = await response.json(); + + if (!response.ok) { + throw new Error( + data.detail || "Failed to sync DataMate knowledge bases and create records" + ); } - ); - const data = await response.json(); + // cache result briefly + datamateCache = data; + datamateCacheExpiry = Date.now() + DATAMATE_CACHE_TTL_MS; - if (!response.ok) { - throw new Error( - data.detail || - "Failed to sync DataMate knowledge bases and create records" + return data; + } catch (error) { + log.error( + "Failed to sync DataMate knowledge bases and create records:", + error ); + throw error; + } finally { + datamateInFlightPromise = null; } + })(); - return data; - } catch (error) { - log.error( - "Failed to sync DataMate knowledge bases and create records:", - error - ); - throw error; - } + return datamateInFlightPromise; } // Get knowledge bases with stats from all sources (very slow, don't use it) async getKnowledgeBasesInfo( skipHealthCheck = false, - includeDataMateSync = true + includeDataMateSync = true, + forceRefresh = false ): Promise { - try { - const knowledgeBases: KnowledgeBase[] = []; + // Return cached result when fresh unless forceRefresh is requested + const now = Date.now(); + if (!forceRefresh && cachedKbInfo && now < kbCacheExpiry) { + return cachedKbInfo; + } + + if (kbInFlightPromise) { + return kbInFlightPromise; + } - // Get knowledge bases from Elasticsearch + kbInFlightPromise = (async () => { try { - // First check Elasticsearch health (unless skipped) - if (!skipHealthCheck) { - const isElasticsearchHealthy = await this.checkHealth(); - if (!isElasticsearchHealthy) { - log.warn("Elasticsearch service unavailable"); - } else { + const knowledgeBases: KnowledgeBase[] = []; + + // Get knowledge bases from Elasticsearch + try { + // Decide whether to fetch indices: + // - If skipHealthCheck is true, skip health check but still fetch indices. + // - If skipHealthCheck is false, only fetch indices when health check passes. + let shouldFetchIndices = true; + if (!skipHealthCheck) { + const isElasticsearchHealthy = await this.checkHealth(); + if (!isElasticsearchHealthy) { + log.warn("Elasticsearch service unavailable"); + shouldFetchIndices = false; + } + } + + if (shouldFetchIndices) { const response = await fetch( `${API_ENDPOINTS.knowledgeBase.indices}?include_stats=true`, { @@ -118,7 +168,8 @@ class KnowledgeBaseService { documentCount: stats.doc_count || 0, chunkCount: stats.chunk_count || 0, createdAt: stats.creation_date || null, - updatedAt: stats.update_date || stats.creation_date || null, + updatedAt: + stats.update_date || stats.creation_date || null, embeddingModel: stats.embedding_model || "unknown", avatar: "", chunkNum: 0, @@ -134,53 +185,151 @@ class KnowledgeBaseService { knowledgeBases.push(...esKnowledgeBases); } } + } catch (error) { + log.error("Failed to get Elasticsearch indices:", error); } + + // Sync DataMate knowledge bases and get the synced data (only if enabled) + if (includeDataMateSync) { + try { + const syncResult = await this.syncDataMateAndCreateRecords(); + if (syncResult.indices_info) { + // Convert synced DataMate indices to knowledge base format + const datamateKnowledgeBases: KnowledgeBase[] = + syncResult.indices_info.map((indexInfo: any) => { + const stats = indexInfo.stats?.base_info || {}; + const kbId = indexInfo.name; + const kbName = indexInfo.display_name || indexInfo.name; + + return { + id: kbId, + name: kbName, + description: "DataMate knowledge base", + documentCount: stats.doc_count || 0, + chunkCount: stats.chunk_count || 0, + createdAt: stats.creation_date || null, + updatedAt: + stats.update_date || stats.creation_date || null, + embeddingModel: stats.embedding_model || "unknown", + avatar: "", + chunkNum: 0, + language: "", + nickname: "", + parserId: "", + permission: "", + tokenNum: 0, + source: "datamate", + }; + }); + knowledgeBases.push(...datamateKnowledgeBases); + } + } catch (error) { + log.error("Failed to sync DataMate knowledge bases:", error); + } + } + + // Cache and return + cachedKbInfo = knowledgeBases; + kbCacheExpiry = Date.now() + KB_CACHE_TTL_MS; + return knowledgeBases; } catch (error) { - log.error("Failed to get Elasticsearch indices:", error); + log.error("Failed to get knowledge base list:", error); + throw error; + } finally { + kbInFlightPromise = null; } + })(); - // Sync DataMate knowledge bases and get the synced data (only if enabled) - if (includeDataMateSync) { - try { - const syncResult = await this.syncDataMateAndCreateRecords(); - if (syncResult.indices_info) { - // Convert synced DataMate indices to knowledge base format - const datamateKnowledgeBases: KnowledgeBase[] = - syncResult.indices_info.map((indexInfo: any) => { - const stats = indexInfo.stats?.base_info || {}; - const kbId = indexInfo.name; - const kbName = indexInfo.display_name || indexInfo.name; - - return { - id: kbId, - name: kbName, - description: "DataMate knowledge base", - documentCount: stats.doc_count || 0, - chunkCount: stats.chunk_count || 0, - createdAt: stats.creation_date || null, - updatedAt: stats.update_date || stats.creation_date || null, - embeddingModel: stats.embedding_model || "unknown", - avatar: "", - chunkNum: 0, - language: "", - nickname: "", - parserId: "", - permission: "", - tokenNum: 0, - source: "datamate", - }; - }); - knowledgeBases.push(...datamateKnowledgeBases); - } - } catch (error) { - log.error("Failed to sync DataMate knowledge bases:", error); + return kbInFlightPromise; + } + + // Synchronously read id->name map from localStorage if available and fresh + getCachedIdNameMapSync(): Record | null { + try { + if (typeof window === "undefined" || !window.localStorage) return null; + const raw = window.localStorage.getItem(KB_ID_NAME_MAP_KEY); + if (!raw) return null; + const parsed = JSON.parse(raw); + if (!parsed || !parsed.ts || !parsed.map) return null; + if (Date.now() - parsed.ts > KB_ID_NAME_MAP_TTL_MS) return null; + return parsed.map as Record; + } catch (e) { + log.warn("Failed to read KB id->name map from localStorage:", e); + return null; + } + } + + // Ensure id->name map is available; will fetch/build if necessary + async ensureIdNameMap(): Promise> { + // 1. Prefer in-memory cache built from cachedKbInfo + if (cachedKbInfo && cachedKbInfo.length > 0) { + const map: Record = {}; + cachedKbInfo.forEach((kb) => { + map[kb.id] = kb.name; + }); + try { + if (typeof window !== "undefined" && window.localStorage) { + window.localStorage.setItem( + KB_ID_NAME_MAP_KEY, + JSON.stringify({ ts: Date.now(), map }) + ); } + } catch (e) { + log.warn("Failed to write KB id->name map to localStorage:", e); } + return map; + } - return knowledgeBases; - } catch (error) { - log.error("Failed to get knowledge base list:", error); - throw error; + // 2. Try localStorage sync read + const local = this.getCachedIdNameMapSync(); + if (local) return local; + + // 3. Fallback: fetch KB info and build map + try { + const kbs = await this.getKnowledgeBasesInfo(true, true); + const map: Record = {}; + kbs.forEach((kb) => { + map[kb.id] = kb.name; + }); + try { + if (typeof window !== "undefined" && window.localStorage) { + window.localStorage.setItem( + KB_ID_NAME_MAP_KEY, + JSON.stringify({ ts: Date.now(), map }) + ); + } + } catch (e) { + log.warn("Failed to write KB id->name map to localStorage:", e); + } + return map; + } catch (e) { + log.error("Failed to ensure KB id->name map:", e); + return {}; + } + } + + // Force refresh id->name map from server and update cache/storage + async refreshIdNameMap(): Promise> { + try { + const kbs = await this.getKnowledgeBasesInfo(true, true); + const map: Record = {}; + kbs.forEach((kb) => { + map[kb.id] = kb.name; + }); + try { + if (typeof window !== "undefined" && window.localStorage) { + window.localStorage.setItem( + KB_ID_NAME_MAP_KEY, + JSON.stringify({ ts: Date.now(), map }) + ); + } + } catch (e) { + log.warn("Failed to write KB id->name map to localStorage:", e); + } + return map; + } catch (e) { + log.error("Failed to refresh KB id->name map:", e); + return {}; } } diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index 12d7737df..4f2c38d07 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -73,28 +73,19 @@ def create_local_tool(self, tool_config: ToolConfig): # These parameters have exclude=True and cannot be passed to __init__ # due to smolagents.tools.Tool wrapper restrictions filtered_params = {k: v for k, v in params.items() - if k not in ["index_names", "vdb_core", "embedding_model", "observer"]} + if k not in ["vdb_core", "embedding_model", "observer"]} # Create instance with only non-excluded parameters tools_obj = tool_class(**filtered_params) # Set excluded parameters directly as attributes after instantiation # This bypasses smolagents wrapper restrictions tools_obj.observer = self.observer - index_names = tool_config.metadata.get( - "index_names", None) if tool_config.metadata else None - tools_obj.index_names = [] if index_names is None else index_names tools_obj.vdb_core = tool_config.metadata.get( "vdb_core", None) if tool_config.metadata else None tools_obj.embedding_model = tool_config.metadata.get( "embedding_model", None) if tool_config.metadata else None - name_resolver = tool_config.metadata.get( - "name_resolver", None) if tool_config.metadata else None - tools_obj.name_resolver = {} if name_resolver is None else name_resolver elif class_name == "DataMateSearchTool": tools_obj = tool_class(**params) tools_obj.observer = self.observer - index_names = tool_config.metadata.get( - "index_names", None) if tool_config.metadata else None - tools_obj.index_names = [] if index_names is None else index_names elif class_name == "AnalyzeTextFileTool": tools_obj = tool_class(observer=self.observer, llm_model=tool_config.metadata.get("llm_model", []), diff --git a/sdk/nexent/core/tools/datamate_search_tool.py b/sdk/nexent/core/tools/datamate_search_tool.py index ae81a87a4..1950e9c88 100644 --- a/sdk/nexent/core/tools/datamate_search_tool.py +++ b/sdk/nexent/core/tools/datamate_search_tool.py @@ -14,15 +14,6 @@ logger = logging.getLogger("datamate_search_tool") -def _normalize_index_names(index_names: Optional[Union[str, List[str]]]) -> List[str]: - """Normalize index_names to list; accept single string and keep None as empty list.""" - if index_names is None: - return [] - if isinstance(index_names, str): - return [index_names] - return list(index_names) - - class DataMateSearchTool(Tool): """DataMate knowledge base search tool""" name = "datamate_search" @@ -38,23 +29,6 @@ class DataMateSearchTool(Tool): "type": "string", "description": "The search query to perform.", }, - "top_k": { - "type": "integer", - "description": "Maximum number of search results to return.", - "default": 10, - "nullable": True, - }, - "threshold": { - "type": "number", - "description": "Similarity threshold for search results.", - "default": 0.2, - "nullable": True, - }, - "index_names": { - "type": "array", - "description": "The list of knowledge base names to search (supports user-facing knowledge_name or internal index_name). If not provided, will search all available knowledge bases.", - "nullable": True, - }, "kb_page": { "type": "integer", "description": "Page index when listing knowledge bases from DataMate.", @@ -80,9 +54,13 @@ def __init__( verify_ssl: bool = Field( description="Whether to verify SSL certificates for HTTPS connections", default=False), index_names: List[str] = Field( - description="The list of index names to search", default=None, exclude=True), + description="The list of index names to search", default=None), observer: MessageObserver = Field( description="Message observer", default=None, exclude=True), + top_k: int = Field( + description="Default maximum number of search results to return", default=3), + threshold: float = Field( + description="Default similarity threshold for search results", default=0.2), ): """Initialize the DataMateSearchTool. @@ -106,6 +84,8 @@ def __init__( self.use_https = parsed_url["use_https"] self.server_base_url = parsed_url["base_url"] self.index_names = [] if index_names is None else index_names + self.top_k = top_k + self.threshold = threshold # Determine SSL verification setting if verify_ssl is None: @@ -177,9 +157,6 @@ def _parse_server_url(server_url: str) -> dict: def forward( self, query: str, - top_k: int = 3, - threshold: float = 0.2, - index_names: Union[str, List[str], None] = None, kb_page: int = 0, kb_page_size: int = 20, ) -> str: @@ -187,9 +164,6 @@ def forward( Args: query: Search query text. - top_k: Optional override for maximum number of search results. - threshold: Optional override for similarity threshold. - index_names: The list of knowledge base names to search (supports user-facing knowledge_name or internal index_name). If not provided, will search all available knowledge bases. kb_page: Optional override for knowledge base list page index. kb_page_size: Optional override for knowledge base list page size. """ @@ -207,15 +181,12 @@ def forward( logger.info( f"DataMateSearchTool called with query: '{query}', base_url: '{self.server_base_url}', " - f"top_k: {top_k}, threshold: {threshold}, index_names: {index_names}" + f"top_k: {self.top_k}, threshold: {self.threshold}, index_names: {self.index_names}" ) try: # Step 1: Determine knowledge base IDs to search - # Use provided index_names if available, otherwise use default - knowledge_base_ids = _normalize_index_names( - index_names if index_names is not None else self.index_names) - + knowledge_base_ids = self.index_names if len(knowledge_base_ids) == 0: return json.dumps("No knowledge base selected. No relevant information found.", ensure_ascii=False) @@ -225,8 +196,8 @@ def forward( kb_search = self.datamate_core.hybrid_search( query_text=query, index_names=[knowledge_base_id], - top_k=top_k, - weight_accurate=threshold, + top_k=self.top_k, + weight_accurate=self.threshold, ) if not kb_search: raise Exception( From 4b21038075049d434dab90a7b9c48568d91401ea Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Tue, 27 Jan 2026 16:22:00 +0800 Subject: [PATCH 002/167] =?UTF-8?q?=E2=9C=A8=20Update=20UI=20localization?= =?UTF-8?q?=20for=20knowledge=20base=20selection=20prompts=20part2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/agentConfig/ToolManagement.tsx | 57 +++- .../agentConfig/tool/ToolConfigModal.tsx | 312 +++++++++++++----- .../knowledge/KnowledgeBaseList.tsx | 148 +++++---- frontend/services/configService.ts | 57 +++- .../core/tools/knowledge_base_search_tool.py | 69 ++-- 5 files changed, 413 insertions(+), 230 deletions(-) diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index 02ad0e8f6..81ff59a3e 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -1,9 +1,10 @@ "use client"; -import { useState, useEffect, useCallback } from "react"; +import { useState, useEffect, useCallback, useMemo } from "react"; import { useTranslation } from "react-i18next"; import ToolConfigModal from "./tool/ToolConfigModal"; import { ToolGroup, Tool, ToolParam } from "@/types/agentConfig"; +import { TOOL_PARAM_TYPES } from "@/const/agentConfig"; import { Tabs, Collapse } from "antd"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; import { useToolList } from "@/hooks/agent/useToolList"; @@ -50,6 +51,12 @@ export default function ToolManagement({ const [selectedTool, setSelectedTool] = useState(null); const [toolParams, setToolParams] = useState([]); + // Memoized array representation of expandedCategories to keep hook order stable + const activeKeyArray = useMemo( + () => Array.from(expandedCategories), + [expandedCategories] + ); + // Helper function to merge tool parameters with instance parameters const mergeToolParamsWithInstance = async ( tool: Tool, @@ -64,7 +71,7 @@ export default function ToolManagement({ if (tooInstance.success && tooInstance.data) { // Merge instance params with default params - const mergedParams = + const mergedParams: ToolParam[] = defaultTool.initParams?.map((param: ToolParam) => { const instanceValue = tooInstance.data?.params?.[param.name]; return { @@ -73,8 +80,33 @@ export default function ToolManagement({ instanceValue !== undefined ? instanceValue : param.value, }; }) || - defaultTool.initParams || + defaultTool.initParams?.slice() || []; + + // If instance contains params that are not defined in default initParams, + // append them so UI can show saved values like `index_names`. + const instanceParams = tooInstance.data?.params || {}; + Object.keys(instanceParams).forEach((key) => { + const exists = mergedParams.some((p) => p.name === key); + if (!exists) { + const val = instanceParams[key]; + const inferredType = Array.isArray(val) + ? TOOL_PARAM_TYPES.ARRAY + : typeof val === "boolean" + ? TOOL_PARAM_TYPES.BOOLEAN + : typeof val === "number" + ? TOOL_PARAM_TYPES.NUMBER + : TOOL_PARAM_TYPES.STRING; + mergedParams.push({ + name: key, + type: inferredType, + required: false, + value: val, + description: "", + } as any); + } + }); + return mergedParams; } else { return defaultTool.initParams || []; @@ -217,13 +249,22 @@ export default function ToolManagement({ <> {/* Collapsible categories using Ant Design Collapse */}
+ {/* Memoize activeKey array to avoid creating a new array on every render, + which could cause the Collapse to think its controlled prop changed + and trigger onChange repeatedly. */} { - const newSet = new Set( - typeof keys === "string" ? [keys] : keys - ); - setExpandedCategories(newSet); + const keyArray = typeof keys === "string" ? [keys] : keys || []; + const newSet = new Set(keyArray); + // Only update state if the set content actually changed + const sameSize = newSet.size === expandedCategories.size; + const allSame = sameSize + ? Array.from(newSet).every((k) => expandedCategories.has(k)) + : false; + if (!allSame) { + setExpandedCategories(newSet); + } }} ghost size="small" diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index 06abbf02c..c07b00613 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -2,14 +2,32 @@ import { useState, useEffect } from "react"; import { useTranslation } from "react-i18next"; -import { Modal, Input, Switch, InputNumber, Tag, Form, message } from "antd"; +import { + Modal, + Input, + Switch, + InputNumber, + Tag, + Form, + message, + Select, + Button, + Spin, +} from "antd"; import { useQueryClient } from "@tanstack/react-query"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; +import KnowledgeBaseList from "../../../../knowledges/components/knowledge/KnowledgeBaseList"; +import { KnowledgeBase } from "@/types/knowledgeBase"; + import { TOOL_PARAM_TYPES } from "@/const/agentConfig"; import { ToolParam, Tool } from "@/types/agentConfig"; import ToolTestPanel from "./ToolTestPanel"; +import KnowledgeBaseToolConfig from "./KnowledgeBaseToolConfig"; import { updateToolConfig } from "@/services/agentConfigService"; +import knowledgeBaseService from "@/services/knowledgeBaseService"; +import { configService } from "@/services/configService"; +import { ConfigStore } from "@/lib/config"; export interface ToolConfigModalProps { isOpen: boolean; @@ -41,6 +59,65 @@ export default function ToolConfigModal({ // Tool test panel visibility state const [testPanelVisible, setTestPanelVisible] = useState(false); + const [kbOptions, setKbOptions] = useState< + { label: React.ReactNode; value: string; description?: string }[] + >([]); + const [kbLoading, setKbLoading] = useState(false); + const [kbModalVisible, setKbModalVisible] = useState(false); + const [kbModalSelected, setKbModalSelected] = useState([]); + const [kbRawList, setKbRawList] = useState([]); + // Preload KB list once when modal opens to avoid repeated fetches/ syncing + useEffect(() => { + let cancelled = false; + const loadKbs = async () => { + if (!isOpen) return; + const toolName = tool?.name; + // Only preload for local indices tool to avoid expensive / duplicative DataMate sync. + if (toolName !== "knowledge_base_search") { + return; + } + + setKbLoading(true); + try { + const kbs = await knowledgeBaseService.getKnowledgeBasesInfo( + true, + false, + true + ); + if (cancelled) return; + setKbRawList(kbs); + setKbOptions( + kbs.map((kb: any) => ({ + value: kb.id, + label: kb.name, + description: kb.description || "", + })) + ); + } catch (e) { + // Non-fatal: leave child to handle DataMate fetch/errors when modal opens + } finally { + if (!cancelled) setKbLoading(false); + } + }; + + loadKbs(); + return () => { + cancelled = true; + }; + }, [isOpen, tool]); + + const buildKbOptions = (kbs: any[]) => { + // For preview/select usage we only need the KB name (no description). + return kbs.map((kb) => ({ + value: kb.id, + label: kb.name, + description: kb.description || "", + })); + }; + + // Configurable grouping: control which params appear in each group + const serverParamNames = ["server_url", "verify_ssl"]; + const retrievalParamNames = ["top_k", "threshold", "kb_page", "kb_page_size", "search_mode"]; // Initialize with provided params useEffect(() => { // Initialize form values @@ -50,6 +127,37 @@ export default function ToolConfigModal({ formValues[`param_${index}`] = param.value; }); form.setFieldsValue(formValues); + + // On first load of the tool config modal, attempt to load tenant config + // and populate the tool's `server_url` param from tenant `datamateUrl` + // if `server_url` exists and is empty. + (async () => { + try { + await configService.loadConfigToFrontend(); + const appConfig = ConfigStore.getInstance().getAppConfig(); + const tenantDataMateUrl = appConfig?.datamateUrl; + if (tenantDataMateUrl) { + const serverIdx = initialParams.findIndex((p) => p.name === "server_url"); + if (serverIdx !== -1) { + const currentVal = initialParams[serverIdx]?.value; + if (!currentVal) { + const newParams = [...initialParams]; + newParams[serverIdx] = { + ...newParams[serverIdx], + value: tenantDataMateUrl, + }; + setCurrentParams(newParams); + const fieldName2 = `param_${serverIdx}`; + form.setFieldsValue({ [fieldName2]: tenantDataMateUrl }); + } + } + } + } catch (e) { + // Non-fatal: log and continue + // eslint-disable-next-line no-console + console.warn("Failed to load tenant config on tool init:", e); + } + })(); }, [initialParams]); // Watch all form values and sync to currentParams @@ -67,6 +175,8 @@ export default function ToolConfigModal({ } }, [formValues]); + // KB list is loaded lazily when user opens the Select KBs modal (see button handler) + const handleSave = async () => { try { await form.validateFields(); @@ -136,8 +246,7 @@ export default function ToolConfigModal({ onSave(currentParams); } } catch (error) { - // Form validation failed, error will be shown by antd Form - message.error("Form validation failed:"); + // Form validation failed; AntD Form will display inline errors. No global message needed. } }; @@ -174,7 +283,23 @@ export default function ToolConfigModal({ case TOOL_PARAM_TYPES.ARRAY: case TOOL_PARAM_TYPES.OBJECT: default: - // Default TextArea for all text-like types and unknown types + // Special-case: render `search_mode` as a Select dropdown with fixed options + if (param.name === "search_mode") { + return ( + 0 ? optionsSource : quickOptions} + disabled={idx === -1} + placeholder={ + idx !== -1 + ? t("toolConfig.input.array.placeholder", { name: currentParams[idx].description }) + : t("toolConfig.placeholder.selectKb", "Select knowledge bases") + } + notFoundContent={isLoadingOptions ? : null} + open={false} + onClick={openKbModal} + onFocus={openKbModal} + /> + ); + })()} + +
+ + {/* Inline loading indicator shown while KB list loads */} +
+ {kbLoading && } +
+ + ); + })()} + + document.body} + zIndex={1200} + open={kbModalVisible} + title={t("toolConfig.modal.selectKbTitle", "Select Knowledge Bases")} + onCancel={() => { + setKbModalVisible(false); + // suppress immediate reopen caused by focus/click after cancel + setSuppressOpen(true); + if (suppressTimerRef.current) { + window.clearTimeout(suppressTimerRef.current); + } + suppressTimerRef.current = window.setTimeout(() => { + setSuppressOpen(false); + suppressTimerRef.current = null; + }, 300); + if (document.activeElement instanceof HTMLElement) { + document.activeElement.blur(); + } + }} + cancelText={t("common.button.cancel")} + okText={t("common.button.save")} + onOk={() => { + let idx2 = currentParams.findIndex((p) => p.name === "index_names"); + const newParams = [...currentParams]; + if (idx2 === -1) { + const newParam: ToolParam = { + name: "index_names", + type: "array", + required: false, + description: "List of knowledge base ids", + value: kbModalSelected, + } as ToolParam; + newParams.push(newParam); + setCurrentParams(newParams); + const fieldName2 = `param_${newParams.length - 1}`; + form.setFieldsValue({ [fieldName2]: kbModalSelected }); + } else { + newParams[idx2] = { ...newParams[idx2], value: kbModalSelected }; + setCurrentParams(newParams); + const fieldName2 = `param_${idx2}`; + form.setFieldsValue({ [fieldName2]: kbModalSelected }); + } + // Close modal and briefly suppress open to avoid immediate re-open via focus/click + setKbModalVisible(false); + setSuppressOpen(true); + if (suppressTimerRef.current) { + window.clearTimeout(suppressTimerRef.current); + } + suppressTimerRef.current = window.setTimeout(() => { + setSuppressOpen(false); + suppressTimerRef.current = null; + }, 300); + // Blur active element to reduce chance of immediate focus triggering open + if (document.activeElement instanceof HTMLElement) { + document.activeElement.blur(); + } + message.success(t("toolConfig.message.kbSelectSaved", "KB selection saved")); + }} + width={800} + > + {/* Ensure selected KB names are loaded into options so Select shows names instead of raw IDs */} + {/* If index_names exists but options missing those ids, load KB info once */} + {/* This effect will run when currentParams changes */} + {/* We put it here inside the JSX file scope (top-level) via useEffect below */} +
+ { + const exists = kbModalSelected.includes(id); + const newSelected = exists + ? kbModalSelected.filter((s) => s !== id) + : [...kbModalSelected, id]; + setKbModalSelected(newSelected); + }} + onClick={() => {}} + showDataMateConfig={false} + isSelectable={(kb: KnowledgeBase) => { + const docCount = typeof kb.documentCount === "number" ? kb.documentCount : 0; + const chunkCount = typeof kb.chunkCount === "number" ? kb.chunkCount : 0; + return (docCount + chunkCount) > 0; + }} + getModelDisplayName={(m: string) => m} + containerHeight="50vh" + /> +
+
+ + + ); +} + + From d949c6ecff12beda019d67f064efda70101af826 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 27 Jan 2026 19:39:26 +0800 Subject: [PATCH 004/167] Merge branch 'develop' into xyq/user_management # Conflicts: # frontend/app/[locale]/tenant-resources/components/UserManageComp.tsx # frontend/app/[locale]/users/page.tsx # frontend/components/navigation/SideNavigation.tsx --- .../agentInfo/AgentGenerateDetail.tsx | 6 +- frontend/app/[locale]/agents/page.tsx | 131 +++--- .../[locale]/chat/components/chatHeader.tsx | 8 +- .../chat/components/chatLeftSidebar.tsx | 5 +- .../[locale]/chat/internal/chatInterface.tsx | 66 +-- frontend/app/[locale]/chat/page.tsx | 31 +- frontend/app/[locale]/knowledges/page.tsx | 52 +-- frontend/app/[locale]/layout.client.tsx | 74 ++-- frontend/app/[locale]/layout.tsx | 2 +- frontend/app/[locale]/memory/page.tsx | 33 +- .../[locale]/models/ModelConfiguration.tsx | 3 - .../models/components/modelConfig.tsx | 77 ++-- frontend/app/[locale]/models/page.tsx | 44 +- frontend/app/[locale]/setup/page.tsx | 18 +- .../[locale]/space/components/AgentCard.tsx | 19 +- frontend/app/[locale]/space/page.tsx | 21 +- .../components/UserManageComp.tsx | 8 +- .../app/[locale]/tenant-resources/page.tsx | 17 +- frontend/app/[locale]/users/page.tsx | 18 +- frontend/components/auth/AuthDialogs.tsx | 155 +++++++ frontend/components/auth/avatarDropdown.tsx | 17 +- frontend/components/auth/index.ts | 1 - frontend/components/auth/loginModal.tsx | 198 ++++----- frontend/components/auth/registerModal.tsx | 417 +++++++++--------- frontend/components/auth/sessionListeners.tsx | 230 ---------- frontend/components/homepage/AuthDialogs.tsx | 186 -------- .../navigation/ChatTopNavContent.tsx | 48 ++ .../components/navigation/SideNavigation.tsx | 226 +++++----- frontend/components/navigation/TopNavbar.tsx | 46 +- frontend/components/permission/Can.tsx | 36 ++ frontend/components/permission/Cannot.tsx | 32 ++ .../providers/AuthenticationProvider.tsx | 42 ++ .../providers/AuthorizationProvider.tsx | 42 ++ .../providers/deploymentProvider.tsx | 20 +- .../components/providers/rootProvider.tsx | 54 ++- frontend/const/auth.ts | 21 +- frontend/hooks/auth/useAuthentication.ts | 55 +++ frontend/hooks/auth/useAuthenticationState.ts | 235 ++++++++++ frontend/hooks/auth/useAuthenticationUI.ts | 174 ++++++++ frontend/hooks/auth/useAuthorization.ts | 277 ++++++++++++ frontend/hooks/auth/useSessionManager.ts | 216 +++++++++ frontend/hooks/permission/usePermission.ts | 35 ++ frontend/hooks/useAuth.ts | 376 ---------------- frontend/hooks/useSetupFlow.ts | 77 +--- frontend/lib/auth.ts | 122 ++--- frontend/lib/authEvents.ts | 83 ++++ frontend/lib/session.ts | 94 ++++ frontend/public/locales/en/common.json | 7 - frontend/public/locales/zh/common.json | 10 +- frontend/services/api.ts | 32 +- frontend/services/authService.ts | 88 ++-- frontend/services/sessionService.ts | 93 +--- frontend/types/auth.ts | 178 +++++++- 53 files changed, 2588 insertions(+), 1968 deletions(-) create mode 100644 frontend/components/auth/AuthDialogs.tsx delete mode 100644 frontend/components/auth/sessionListeners.tsx delete mode 100644 frontend/components/homepage/AuthDialogs.tsx create mode 100644 frontend/components/navigation/ChatTopNavContent.tsx create mode 100644 frontend/components/permission/Can.tsx create mode 100644 frontend/components/permission/Cannot.tsx create mode 100644 frontend/components/providers/AuthenticationProvider.tsx create mode 100644 frontend/components/providers/AuthorizationProvider.tsx create mode 100644 frontend/hooks/auth/useAuthentication.ts create mode 100644 frontend/hooks/auth/useAuthenticationState.ts create mode 100644 frontend/hooks/auth/useAuthenticationUI.ts create mode 100644 frontend/hooks/auth/useAuthorization.ts create mode 100644 frontend/hooks/auth/useSessionManager.ts create mode 100644 frontend/hooks/permission/usePermission.ts delete mode 100644 frontend/hooks/useAuth.ts create mode 100644 frontend/lib/authEvents.ts create mode 100644 frontend/lib/session.ts diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index 88e081f7c..6da2e901b 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -32,7 +32,8 @@ import { GENERATE_PROMPT_STREAM_TYPES, } from "@/const/agentConfig"; import { generatePromptStream } from "@/services/promptService"; -import { useAuth } from "@/hooks/useAuth"; +import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; +import { useDeployment } from "@/components/providers/deploymentProvider"; import { useModelList } from "@/hooks/model/useModelList"; import ExpandEditModal from "./ExpandEditModal"; @@ -55,7 +56,8 @@ export default function AgentGenerateDetail({ }: AgentGenerateDetailProps) { const { t } = useTranslation("common"); const { message } = App.useApp(); - const { user, isSpeedMode } = useAuth(); + const { user } = useAuthorizationContext(); + const { isSpeedMode } = useDeployment(); const [form] = Form.useForm(); // Model data from React Query diff --git a/frontend/app/[locale]/agents/page.tsx b/frontend/app/[locale]/agents/page.tsx index da6e85420..3ed60a398 100644 --- a/frontend/app/[locale]/agents/page.tsx +++ b/frontend/app/[locale]/agents/page.tsx @@ -12,8 +12,7 @@ import AgentInfoComp from "./components/AgentInfoComp"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; export default function AgentSetupOrchestrator() { - const { pageVariants, pageTransition, canAccessProtectedData } = - useSetupFlow(); + const { pageVariants, pageTransition } = useSetupFlow(); const searchParams = useSearchParams(); const enterCreateMode = useAgentConfigStore((state) => state.enterCreateMode); @@ -29,76 +28,72 @@ export default function AgentSetupOrchestrator() { }, [searchParams, enterCreateMode]); return ( - <> - {canAccessProtectedData ? ( - + + - {` + .ant-card-body { + height: 100%; + } + `} + {/* Three-column layout using Ant Design Grid */} + - - - {/* Three-column layout using Ant Design Grid */} - - {/* Left column: Agent Management */} - - - + + - {/* Middle column: Agent Config */} - - - + {/* Middle column: Agent Config */} + + + - {/* Right column: Agent Info */} - - - - - - - - ) : null} - + {/* Right column: Agent Info */} + + + + + + + ); } diff --git a/frontend/app/[locale]/chat/components/chatHeader.tsx b/frontend/app/[locale]/chat/components/chatHeader.tsx index 9aae7b75b..858fe9199 100644 --- a/frontend/app/[locale]/chat/components/chatHeader.tsx +++ b/frontend/app/[locale]/chat/components/chatHeader.tsx @@ -9,8 +9,9 @@ import { loadMemoryConfig, setMemorySwitch } from "@/services/memoryService"; import { configStore } from "@/lib/config"; import log from "@/lib/logger"; import { useRouter } from "next/navigation"; -import { useAuth } from "@/hooks/useAuth"; -import { USER_ROLES } from "@/const/modelConfig"; +import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; +import { useDeployment } from "@/components/providers/deploymentProvider"; +import { USER_ROLES } from "@/const/auth"; import { saveView } from "@/lib/viewPersistence"; import { useConfirmModal } from "@/hooks/useConfirmModal"; @@ -27,7 +28,8 @@ export function ChatHeader({ title, onRename }: ChatHeaderProps) { const inputRef = useRef(null); const { t, i18n } = useTranslation("common"); - const { user, isSpeedMode } = useAuth(); + const { user } = useAuthorizationContext(); + const { isSpeedMode } = useDeployment(); const { confirm } = useConfirmModal(); const isAdmin = isSpeedMode || user?.role === USER_ROLES.ADMIN; diff --git a/frontend/app/[locale]/chat/components/chatLeftSidebar.tsx b/frontend/app/[locale]/chat/components/chatLeftSidebar.tsx index d93122654..baf569c36 100644 --- a/frontend/app/[locale]/chat/components/chatLeftSidebar.tsx +++ b/frontend/app/[locale]/chat/components/chatLeftSidebar.tsx @@ -14,7 +14,6 @@ import { Button, Dropdown } from "antd"; import { Input } from "@/components/ui/input"; import { Tooltip, TooltipProvider } from "@/components/ui/tooltip"; import { StaticScrollArea } from "@/components/ui/scrollArea"; -import { USER_ROLES } from "@/const/modelConfig"; import { useTranslation } from "react-i18next"; import { useConfirmModal } from "@/hooks/useConfirmModal"; import { ConversationListItem, ChatSidebarProps } from "@/types/chat"; @@ -94,13 +93,11 @@ export function ChatSidebar({ onRename, onDelete, onSettingsClick, - settingsMenuItems, onDropdownOpenChange, onToggleSidebar, expanded, userEmail, - userAvatarUrl, - userRole = USER_ROLES.USER, + userAvatarUrl }: ChatSidebarProps) { const { t } = useTranslation(); const { confirm } = useConfirmModal(); diff --git a/frontend/app/[locale]/chat/internal/chatInterface.tsx b/frontend/app/[locale]/chat/internal/chatInterface.tsx index 63db7382a..5c8846e73 100644 --- a/frontend/app/[locale]/chat/internal/chatInterface.tsx +++ b/frontend/app/[locale]/chat/internal/chatInterface.tsx @@ -7,9 +7,11 @@ import { useRouter } from "next/navigation"; import { v4 as uuidv4 } from "uuid"; import { useTranslation } from "react-i18next"; -import { chatConfig, MESSAGE_ROLES } from "@/const/chatConfig"; +import { ROLE_ASSISTANT } from "@/const/agentConfig"; +import { MESSAGE_ROLES } from "@/const/chatConfig"; import { useConfig } from "@/hooks/useConfig"; -import { useAuth } from "@/hooks/useAuth"; +import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; +import { useDeployment } from "@/components/providers/deploymentProvider"; import { conversationService } from "@/services/conversationService"; import { storageService, convertImageUrlToApiUrl } from "@/services/storageService"; import { useConversationManagement } from "@/hooks/chat/useConversationManagement"; @@ -55,7 +57,8 @@ const getI18nKeyByType = (type: string): string => { export function ChatInterface() { const router = useRouter(); - const { user, isSpeedMode } = useAuth(); // Get user information + const { user } = useAuthorizationContext(); + const { isSpeedMode } = useDeployment(); const [input, setInput] = useState(""); // Replace the original messages state const [sessionMessages, setSessionMessages] = useState<{ @@ -356,7 +359,7 @@ export function ChatInterface() { const assistantMessageId = uuidv4(); const initialAssistantMessage: ChatMessageType = { id: assistantMessageId, - role: MESSAGE_ROLES.ASSISTANT, + role: ROLE_ASSISTANT, content: "", timestamp: new Date(), isComplete: false, @@ -504,7 +507,7 @@ export function ChatInterface() { .map((msg) => ({ role: msg.role, content: - msg.role === MESSAGE_ROLES.ASSISTANT + msg.role === ROLE_ASSISTANT ? msg.finalAnswer?.trim() || msg.content || "" : msg.content || "", })), @@ -588,7 +591,7 @@ export function ChatInterface() { newMessages[currentConversationId]?.[ newMessages[currentConversationId].length - 1 ]; - if (lastMsg && lastMsg.role === MESSAGE_ROLES.ASSISTANT) { + if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { lastMsg.error = t("chatInterface.requestTimeoutRetry"); lastMsg.isComplete = true; lastMsg.thinking = undefined; @@ -680,7 +683,7 @@ export function ChatInterface() { newMessages[currentConversationId]?.[ newMessages[currentConversationId].length - 1 ]; - if (lastMsg && lastMsg.role === MESSAGE_ROLES.ASSISTANT) { + if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { lastMsg.content = t("chatInterface.conversationStopped"); lastMsg.isComplete = true; lastMsg.thinking = undefined; // Explicitly clear thinking state @@ -697,7 +700,7 @@ export function ChatInterface() { newMessages[currentConversationId]?.[ newMessages[currentConversationId].length - 1 ]; - if (lastMsg && lastMsg.role === MESSAGE_ROLES.ASSISTANT) { + if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { lastMsg.content = errorMessage; lastMsg.isComplete = true; lastMsg.error = errorMessage; @@ -1010,7 +1013,7 @@ export function ChatInterface() { conversationData.create_time ); formattedMessages.push(formattedUserMsg); - } else if (dialog_msg.role === MESSAGE_ROLES.ASSISTANT) { + } else if (dialog_msg.role === ROLE_ASSISTANT) { const formattedAssistantMsg: ChatMessageType = extractAssistantMsgFromResponse( dialog_msg, @@ -1194,7 +1197,7 @@ export function ChatInterface() { const lastMsg = newMessages[conversationManagement.conversationId]?.[newMessages[conversationManagement.conversationId].length - 1]; - if (lastMsg && lastMsg.role === MESSAGE_ROLES.ASSISTANT && lastMsg.images) { + if (lastMsg && lastMsg.role === ROLE_ASSISTANT && lastMsg.images) { // Filter out failed images lastMsg.images = lastMsg.images.filter((url) => url !== imageUrl); } @@ -1247,7 +1250,7 @@ export function ChatInterface() { const newMessages = { ...prev }; const lastMsg = newMessages[conversationManagement.conversationId]?.[newMessages[conversationManagement.conversationId].length - 1]; - if (lastMsg && lastMsg.role === MESSAGE_ROLES.ASSISTANT) { + if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { lastMsg.isComplete = true; lastMsg.thinking = undefined; // Explicitly clear thinking state } @@ -1278,7 +1281,7 @@ export function ChatInterface() { const newMessages = { ...prev }; const lastMsg = newMessages[conversationManagement.conversationId]?.[newMessages[conversationManagement.conversationId].length - 1]; - if (lastMsg && lastMsg.role === MESSAGE_ROLES.ASSISTANT) { + if (lastMsg && lastMsg.role === ROLE_ASSISTANT) { lastMsg.isComplete = true; lastMsg.thinking = undefined; // Explicitly clear thinking state lastMsg.error = t( @@ -1376,41 +1379,7 @@ export function ChatInterface() { // Both admin and regular users now use dropdown menus }; - // Settings menu items based on user role (speed mode is treated as admin) - const settingsMenuItems = (isSpeedMode || user?.role === "admin") ? [ - // Admin has three options - { - key: "models", - label: t("chatLeftSidebar.settingsMenu.modelConfig"), - onClick: () => { - localStorage.setItem("show_page", "1"); - router.push("/setup/models"); - }, - }, - { - key: "knowledges", - label: t("chatLeftSidebar.settingsMenu.knowledgeConfig"), - onClick: () => { - router.push("/setup/knowledges"); - }, - }, - { - key: "agents", - label: t("chatLeftSidebar.settingsMenu.agentConfig"), - onClick: () => { - router.push("/setup/agents"); - }, - }, - ] : [ - // Regular user only has knowledge base configuration - { - key: "knowledges", - label: t("chatLeftSidebar.settingsMenu.knowledgeConfig"), - onClick: () => { - router.push("/setup/knowledges"); - }, - }, - ]; + return ( <> @@ -1426,14 +1395,13 @@ export function ChatInterface() { onRename={handleConversationRename} onDelete={handleConversationDeleteClick} onSettingsClick={handleSettingsClick} - settingsMenuItems={settingsMenuItems} onDropdownOpenChange={(open: boolean, id: string | null) => setOpenDropdownId(open ? id : null) } onToggleSidebar={toggleSidebar} expanded={sidebarOpen} userEmail={user?.email} - userAvatarUrl={user?.avatar_url} + userAvatarUrl={user?.avatarUrl} userRole={user?.role} /> diff --git a/frontend/app/[locale]/chat/page.tsx b/frontend/app/[locale]/chat/page.tsx index 5349f5cd7..332e51846 100644 --- a/frontend/app/[locale]/chat/page.tsx +++ b/frontend/app/[locale]/chat/page.tsx @@ -1,10 +1,10 @@ "use client"; import { useEffect, useRef } from "react"; -import { useAuth } from "@/hooks/useAuth"; +import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; +import { useDeployment } from "@/components/providers/deploymentProvider"; import { useConfig } from "@/hooks/useConfig"; import { configService } from "@/services/configService"; -import { EVENTS } from "@/const/auth"; import { ChatInterface } from "./internal/chatInterface"; /** @@ -13,24 +13,21 @@ import { ChatInterface } from "./internal/chatInterface"; */ export default function ChatContent() { const { appConfig } = useConfig(); - const { user, isLoading: userLoading, isSpeedMode } = useAuth(); - const canAccessProtectedData = isSpeedMode || (!userLoading && !!user); + const { user, isLoading: userLoading } = useAuthorizationContext(); + const { isSpeedMode } = useDeployment(); const sessionExpiredTriggeredRef = useRef(false); useEffect(() => { - if (!canAccessProtectedData) { - return; - } // Load config from backend when entering chat page configService.loadConfigToFrontend(); if (appConfig.appName) { document.title = `${appConfig.appName}`; } - }, [appConfig.appName, canAccessProtectedData]); + }, [appConfig.appName]); - // Require login on chat page when unauthenticated (full mode only) - // Trigger SESSION_EXPIRED event to show "Login Expired" modal instead of directly opening login modal + // Require login on chat page when unauthenticated (skip in speed mode) + // Note: SESSION_EXPIRED event is triggered by useSessionManager.ts on initialization useEffect(() => { if (isSpeedMode) { sessionExpiredTriggeredRef.current = false; @@ -42,20 +39,10 @@ export default function ChatContent() { return; } - if (!userLoading && !sessionExpiredTriggeredRef.current) { - sessionExpiredTriggeredRef.current = true; - window.dispatchEvent( - new CustomEvent(EVENTS.SESSION_EXPIRED, { - detail: { message: "Session expired, please sign in again" }, - }) - ); - } + // Session expiration is handled by useSessionManager.ts + // Don't trigger SESSION_EXPIRED here to avoid duplicate handling }, [isSpeedMode, user, userLoading]); - // Avoid rendering and backend calls when unauthenticated (full mode only) - if (!canAccessProtectedData) { - return null; - } return (
diff --git a/frontend/app/[locale]/knowledges/page.tsx b/frontend/app/[locale]/knowledges/page.tsx index 835018df3..d7b50f452 100644 --- a/frontend/app/[locale]/knowledges/page.tsx +++ b/frontend/app/[locale]/knowledges/page.tsx @@ -4,9 +4,11 @@ import React, { useEffect } from "react"; import { motion } from "framer-motion"; import { useSetupFlow } from "@/hooks/useSetupFlow"; +import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; +import { useDeployment } from "@/components/providers/deploymentProvider"; import { configService } from "@/services/configService"; import { configStore } from "@/lib/config"; -import { USER_ROLES } from "@/const/modelConfig"; +import { USER_ROLES } from "@/const/auth"; import log from "@/lib/logger"; import DataConfig from "./KnowledgeBaseConfiguration"; @@ -17,16 +19,14 @@ import DataConfig from "./KnowledgeBaseConfiguration"; */ export default function KnowledgesContent() { // Use custom hook for common setup flow logic - const { - user, - isSpeedMode, - pageVariants, - pageTransition, - canAccessProtectedData, - } = useSetupFlow({ + const { pageVariants, pageTransition } = useSetupFlow({ requireAdmin: false, // Knowledge base accessible to all users }); + // Get auth state directly from providers + const { isSpeedMode } = useDeployment(); + const { user } = useAuthorizationContext(); + // Knowledge base specific initialization useEffect(() => { // Trigger knowledge base data acquisition when the page is initialized @@ -38,7 +38,7 @@ export default function KnowledgesContent() { // Load config for normal user const loadConfigForNormalUser = async () => { - if (!isSpeedMode && user && user.role !== USER_ROLES.ADMIN) { + if (!isSpeedMode && user) { try { await configService.loadConfigToFrontend(); configStore.reloadFromStorage(); @@ -49,26 +49,22 @@ export default function KnowledgesContent() { }; loadConfigForNormalUser(); - }, []); + }, [isSpeedMode, user]); return ( - <> -
- {canAccessProtectedData ? ( - -
- -
-
- ) : null} -
- +
+ +
+ +
+
+
); } diff --git a/frontend/app/[locale]/layout.client.tsx b/frontend/app/[locale]/layout.client.tsx index 6bb748a48..f30ac0344 100644 --- a/frontend/app/[locale]/layout.client.tsx +++ b/frontend/app/[locale]/layout.client.tsx @@ -1,7 +1,8 @@ "use client"; import { ReactNode, useState } from "react"; -import { Layout, Button } from "antd"; +import { usePathname } from "next/navigation"; +import { Layout, Button, Spin } from "antd"; import { TopNavbar } from "@/components/navigation/TopNavbar"; import { SideNavigation } from "@/components/navigation/SideNavigation"; import { FooterLayout } from "@/components/navigation/FooterLayout"; @@ -10,23 +11,28 @@ import { FOOTER_CONFIG, SIDER_CONFIG, } from "@/const/layoutConstants"; -import { AuthDialogs } from "@/components/homepage/AuthDialogs"; -import { useAuth } from "@/hooks/useAuth"; +import { AuthDialogs } from "@/components/auth/AuthDialogs"; +import { useAuthenticationContext } from "@/components/providers/AuthenticationProvider"; import { ChevronLeft, ChevronRight } from "lucide-react"; -import { usePathname } from "next/navigation"; +import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; +import { useDeployment } from "@/components/providers/deploymentProvider"; +import { getEffectiveRoutePath } from "@/lib/auth"; const { Header, Sider, Content, Footer } = Layout; export function ClientLayout({ children }: { children: ReactNode }) { - const { user, openLoginModal, openRegisterModal, isSpeedMode } = useAuth(); const pathname = usePathname(); + const { isAuthenticated } = useAuthenticationContext(); + const { isAuthorized } = useAuthorizationContext(); + const { isSpeedMode } = useDeployment(); // Check if current route is setup page const isSetupPage = pathname?.includes("/setup"); - // Authentication dialog states - const [loginPromptOpen, setLoginPromptOpen] = useState(false); - const [adminRequiredPromptOpen, setAdminRequiredPromptOpen] = useState(false); + const isChatPage = pathname?.includes("/chat"); + + // Home page does not require authorization + const isHomePage = getEffectiveRoutePath(pathname) === "/"; // Sidebar collapse state const [collapsed, setCollapsed] = useState(false); @@ -94,26 +100,10 @@ export function ClientLayout({ children }: { children: ReactNode }) { backgroundColor: "#fff", }; - // Authentication handlers - const handleAuthRequired = () => { - if (!isSpeedMode && !user) { - setLoginPromptOpen(true); - } - }; - - const handleAdminRequired = () => { - if (!isSpeedMode && user?.role !== "admin") { - setAdminRequiredPromptOpen(true); - } - }; - - const handleCloseLoginPrompt = () => setLoginPromptOpen(false); - const handleCloseAdminPrompt = () => setAdminRequiredPromptOpen(false); - return (
- +
@@ -127,11 +117,7 @@ export function ClientLayout({ children }: { children: ReactNode }) { className="dark:bg-slate-900/95 border-r border-slate-200 dark:border-slate-700 backdrop-blur-sm shadow-sm" >
- +
- ) : null} diff --git a/frontend/app/[locale]/models/ModelConfiguration.tsx b/frontend/app/[locale]/models/ModelConfiguration.tsx index 1fe154a76..0fb82db31 100644 --- a/frontend/app/[locale]/models/ModelConfiguration.tsx +++ b/frontend/app/[locale]/models/ModelConfiguration.tsx @@ -22,14 +22,12 @@ const { Title } = Typography; // Add interface definition interface AppModelConfigProps { skipModelVerification?: boolean; - canAccessProtectedData?: boolean; // Expose a ref from parent to allow programmatic dropdown change forwardedRef?: React.Ref; } export default function AppModelConfig({ skipModelVerification = false, - canAccessProtectedData = false, forwardedRef, }: AppModelConfigProps) { const { t } = useTranslation(); @@ -133,7 +131,6 @@ export default function AppModelConfig({ diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index 628f52392..2588ebf76 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -28,6 +28,7 @@ import { ModelListCard } from "./model/ModelListCard"; import { ModelAddDialog } from "./model/ModelAddDialog"; import { ModelDeleteDialog } from "./model/ModelDeleteDialog"; import { useConfirmModal } from "@/hooks/useConfirmModal"; +import { Can } from "@/components/permission/Can"; // ModelConnectStatus type definition type ModelConnectStatus = (typeof MODEL_STATUS)[keyof typeof MODEL_STATUS]; @@ -86,7 +87,6 @@ export interface ModelConfigSectionRef { interface ModelConfigSectionProps { skipVerification?: boolean; - canAccessProtectedData?: boolean; } export const ModelConfigSection = forwardRef< @@ -95,9 +95,11 @@ export const ModelConfigSection = forwardRef< >((props, ref): ReactNode => { const { t } = useTranslation(); const { message } = App.useApp(); - const { skipVerification = false, canAccessProtectedData = false } = props; + + const { skipVerification = false } = props; const { modelConfig, updateModelConfig, appConfig } = useConfig(); const modelEngineEnable = appConfig.modelEngineEnabled; + const modelData = getModelData(t); const { confirm } = useConfirmModal(); @@ -152,9 +154,6 @@ export const ModelConfigSection = forwardRef< // Initialize loading useEffect(() => { - if (!canAccessProtectedData) { - return; - } // Load configuration from backend first, then load model lists when component mounts const fetchData = async () => { await configService.loadConfigToFrontend(); @@ -163,7 +162,7 @@ export const ModelConfigSection = forwardRef< }; fetchData(); - }, [canAccessProtectedData, skipVerification]); + }, [skipVerification]); // Listen to field error highlight events useEffect(() => { @@ -853,37 +852,41 @@ export const ModelConfigSection = forwardRef< )} - - - - - - + + + + + + + + + + - )} - {/* Delete button - only for admin */} - {isAdmin && ( + - )} + + {/* Register button */} + + + + {/* GitHub support */} + + + + + {/* Permission denied dialog - shown when user is not authorized */} + + {t("page.permissionDenied.confirm")} + , + ]} + centered + > +
+

{t("暂时没有该页面权限,请咨询管理员提升相应权限!")}

+
+
+ + {/* Session expired dialog - shown when user session has expired */} + { + closeSessionExpiredModal(); + openLoginModal(); + }} + onCancel={closeSessionExpiredModal} + okText={t("login.expired.okText")} + cancelText={t("login.expired.cancelText")} + centered + closable={false} + okButtonProps={{ type: "primary" }} + > +
+ + {t("login.expired.content")} +
+
+ + ); +} diff --git a/frontend/components/auth/avatarDropdown.tsx b/frontend/components/auth/avatarDropdown.tsx index 73124d465..199a53de3 100644 --- a/frontend/components/auth/avatarDropdown.tsx +++ b/frontend/components/auth/avatarDropdown.tsx @@ -6,22 +6,25 @@ import { Dropdown, Avatar, Spin, Button, Tag, ConfigProvider, App } from "antd"; import { UserRound, LogOut, LogIn, Power, UserRoundPlus } from "lucide-react"; import type { ItemType } from "antd/es/menu/interface"; -import { useAuth } from "@/hooks/useAuth"; +import { useAuthenticationContext } from "@/components/providers/AuthenticationProvider"; +import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; import { useConfirmModal } from "@/hooks/useConfirmModal"; import { getRoleColor } from "@/lib/auth"; +import { USER_ROLES } from "@/const/auth"; export function AvatarDropdown() { - const { user, isLoading, logout, revoke, openLoginModal, openRegisterModal } = - useAuth(); + const { user, isAuthzReady } = useAuthorizationContext(); + const { isLoading, logout, revoke, openLoginModal, openRegisterModal } = + useAuthenticationContext(); const [dropdownOpen, setDropdownOpen] = useState(false); const { t } = useTranslation("common"); const { modal } = App.useApp(); const { confirm } = useConfirmModal(); + // Show loading while authentication is in progress if (isLoading) { return ; } - if (!user) { const items: ItemType[] = [ { @@ -91,7 +94,7 @@ export function AvatarDropdown() {
{user.email}
- {t(user.role === "admin" ? "auth.admin" : "auth.user")} + {t(`auth.${(user.role).toLowerCase()}`)}
@@ -126,7 +129,7 @@ export function AvatarDropdown() { // danger: true, className: "hover:!bg-red-100 focus:!bg-red-400 focus:!text-white", onClick: () => { - if (user.role === "admin") { + if (user.role === USER_ROLES.ADMIN) { modal.error({ title: t("auth.refuseRevoke"), content: t("auth.refuseRevokePrompt"), @@ -159,7 +162,7 @@ export function AvatarDropdown() { )} > } diff --git a/frontend/components/auth/index.ts b/frontend/components/auth/index.ts index 4987e4d15..be67fe0a7 100644 --- a/frontend/components/auth/index.ts +++ b/frontend/components/auth/index.ts @@ -5,4 +5,3 @@ export * from "./avatarDropdown"; export * from "./loginModal"; export * from "./registerModal"; -export * from "./sessionListeners"; diff --git a/frontend/components/auth/loginModal.tsx b/frontend/components/auth/loginModal.tsx index b53c8adb0..9c3944685 100644 --- a/frontend/components/auth/loginModal.tsx +++ b/frontend/components/auth/loginModal.tsx @@ -3,10 +3,13 @@ import { useTranslation } from "react-i18next"; import { Modal, Form, Input, Button, Typography, Space } from "antd"; import { UserRound, LockKeyhole } from "lucide-react"; +import { usePathname, useRouter } from "next/navigation"; -import { useAuth } from "@/hooks/useAuth"; +import { useAuthenticationContext } from "@/components/providers/AuthenticationProvider"; +import { useDeployment } from "@/components/providers/deploymentProvider"; import { useAuthForm } from "@/hooks/useAuthForm"; -import { EVENTS } from "@/const/auth"; +import { getEffectiveRoutePath } from "@/lib/auth"; +import log from "@/lib/logger"; const { Text } = Typography; @@ -19,13 +22,16 @@ export function LoginModal() { // Authentication state and methods from useAuth hook const { isLoginModalOpen, + isAuthenticated, closeLoginModal, openRegisterModal, login, - isFromSessionExpired, - setIsFromSessionExpired, authServiceUnavailable, - } = useAuth(); + } = useAuthenticationContext(); + const { isSpeedMode } = useDeployment(); + + const router = useRouter(); + const pathname = usePathname(); // Form state and validation methods from useAuthForm hook const { @@ -58,15 +64,12 @@ export function LoginModal() { // Attempt to login with provided credentials await login(values.email, values.password); - // Reset session expired flag after successful login - setIsFromSessionExpired(false); - - // Reset modal control state to prevent session expired modal from triggering again - // Small delay ensures proper state synchronization setTimeout(() => { - document.dispatchEvent(new CustomEvent("modalClosed")); + // Close the login modal after successful login + closeLoginModal(); }, 200); } catch (error: any) { + log.error("Login failed", error); // Clear email error and set password error flag setEmailError(""); setPasswordError(true); @@ -137,113 +140,108 @@ export function LoginModal() { resetForm(); closeLoginModal(); - // If login modal was opened due to session expiration, - // reset modal state and re-trigger the session expired event - if (isFromSessionExpired) { - // Reset modal state so session expired modal can be shown again - document.dispatchEvent(new CustomEvent("modalClosed")); - setTimeout(() => { - window.dispatchEvent( - new CustomEvent(EVENTS.SESSION_EXPIRED, { - detail: { message: t("auth.sessionExpired") }, - }) - ); - }, 100); + // If user manually cancels login from a protected page, + // redirect back to home instead of keeping them on the restricted page + if (!isAuthenticated && !isSpeedMode) { + const effectivePath = pathname ? getEffectiveRoutePath(pathname) : "/"; + if (effectivePath !== "/") { + router.push("/"); + } } }; return ( +
{t("auth.loginTitle")}
} open={isLoginModalOpen} onCancel={handleCancel} footer={null} - width={400} + width={420} centered forceRender - // Prevent modal from being closed by clicking mask when session is expired - // But allow close button to be visible so user can return to session expired prompt - maskClosable={!isFromSessionExpired} + maskClosable={false} closable={true} > -
- {/* Email input field */} - - } - placeholder={t("auth.emailPlaceholder")} - onChange={handleEmailChange} - size="large" - /> - - - {/* Password input field */} - + - } - placeholder={t("auth.passwordRequired")} - onChange={handlePasswordChange} - size="large" - status={passwordError ? "error" : ""} - /> - - - {/* Submit button */} - - - - - {/* Registration link section (hidden when opened from session expired flow) */} - {!isFromSessionExpired && ( -
- - {t("auth.noAccount")} - - -
- )} -
+ } + placeholder={t("auth.passwordRequired")} + onChange={handlePasswordChange} + size="large" + status={passwordError ? "error" : ""} + /> + + + {/* Submit button */} + + + + + {/* Registration link section (hidden when opened from session expired flow) */} + +
+ + {t("auth.noAccount")} + + +
+ + +
); } diff --git a/frontend/components/auth/registerModal.tsx b/frontend/components/auth/registerModal.tsx index f63dfab96..9633e6b0f 100644 --- a/frontend/components/auth/registerModal.tsx +++ b/frontend/components/auth/registerModal.tsx @@ -2,6 +2,7 @@ import { useState } from "react"; import { useTranslation } from "react-i18next"; +import { usePathname, useRouter } from "next/navigation"; import { Modal, Form, @@ -22,9 +23,11 @@ import { BookMarked, } from "lucide-react"; -import { useAuth } from "@/hooks/useAuth"; -import { AuthFormValues } from "@/types/auth" +import { useAuthenticationContext } from "@/components/providers/AuthenticationProvider"; +import { useDeployment } from "@/components/providers/deploymentProvider"; +import { AuthFormValues } from "@/types/auth"; import { useAuthForm } from "@/hooks/useAuthForm"; +import { getEffectiveRoutePath } from "@/lib/auth"; import log from "@/lib/logger"; const { Text } = Typography; @@ -32,11 +35,16 @@ const { Text } = Typography; export function RegisterModal() { const { isRegisterModalOpen, + isAuthenticated, closeRegisterModal, openLoginModal, register, authServiceUnavailable, - } = useAuth(); + } = useAuthenticationContext(); + const { isSpeedMode } = useDeployment(); + + const router = useRouter(); + const pathname = usePathname(); const { form, isLoading, @@ -211,7 +219,10 @@ export function RegisterModal() { ]); } // Registration service error - else if (errorType === "REGISTRATION_SERVICE_ERROR" || httpStatusCode === 500) { + else if ( + errorType === "REGISTRATION_SERVICE_ERROR" || + httpStatusCode === 500 + ) { const errorMsg = t("auth.registrationServiceError"); message.error(errorMsg); setEmailError(errorMsg); @@ -223,7 +234,10 @@ export function RegisterModal() { setEmailError(errorMsg); } // Auth service unavailable - else if (httpStatusCode === 503 || errorType === "AUTH_SERVICE_UNAVAILABLE") { + else if ( + httpStatusCode === 503 || + errorType === "AUTH_SERVICE_UNAVAILABLE" + ) { const errorMsg = t("auth.authServiceUnavailable"); message.error(errorMsg); setEmailError(errorMsg); @@ -252,6 +266,15 @@ export function RegisterModal() { setPasswordError({ target: "", message: "" }); setIsAdminMode(false); closeRegisterModal(); + + // If user manually cancels registration from a protected page, + // redirect back to home instead of keeping them on the restricted page + if (!isAuthenticated && !isSpeedMode) { + const effectivePath = pathname ? getEffectiveRoutePath(pathname) : "/"; + if (effectivePath !== "/") { + router.push("/"); + } + } }; // Handle email input change - real-time email format validation @@ -320,171 +343,171 @@ export function RegisterModal() { return ( +
{t("auth.registerTitle")}
} open={isRegisterModalOpen} onCancel={handleCancel} footer={null} - width={400} + width={420} centered forceRender > -
- { - if (!value) return Promise.resolve(); - if (!validateEmail(value)) { - return Promise.reject( - new Error(t("auth.invalidEmailFormat")) - ); - } - return Promise.resolve(); - }, - }, - ]} +
+ - } - placeholder="your@email.com" - size="large" - onChange={handleEmailInputChange} - /> - - - { - if (!value) return Promise.resolve(); - if (!validatePassword(value)) { - return Promise.reject(new Error(t("auth.passwordMinLength"))); - } - return Promise.resolve(); + { + if (!value) return Promise.resolve(); + if (!validateEmail(value)) { + return Promise.reject( + new Error(t("auth.invalidEmailFormat")) + ); + } + return Promise.resolve(); + }, }, - }, - ]} - hasFeedback - > - } - placeholder={t("auth.passwordRequired")} - size="large" - onChange={handlePasswordChange} - /> - - - ({ - validator(_, value) { - const password = getFieldValue("password"); - // First check password length using validation function - if (password && !validatePassword(password)) { + ]} + > + } + placeholder="your@email.com" + size="large" + onChange={handleEmailInputChange} + /> + + + { + if (!value) return Promise.resolve(); + if (!validatePassword(value)) { + return Promise.reject(new Error(t("auth.passwordMinLength"))); + } + return Promise.resolve(); + }, + }, + ]} + hasFeedback + > + } + placeholder={t("auth.passwordRequired")} + size="large" + onChange={handlePasswordChange} + /> + + + ({ + validator(_, value) { + const password = getFieldValue("password"); + // First check password length using validation function + if (password && !validatePassword(password)) { + setPasswordError({ + target: "password", + message: t("auth.passwordMinLength"), + }); + return Promise.reject(new Error(t("auth.passwordMinLength"))); + } + // Then check password match + if (!value || getFieldValue("password") === value) { + setPasswordError({ target: "", message: "" }); + return Promise.resolve(); + } setPasswordError({ - target: "password", - message: t("auth.passwordMinLength"), + target: "confirmPassword", + message: t("auth.passwordsDoNotMatch"), }); - return Promise.reject(new Error(t("auth.passwordMinLength"))); - } - // Then check password match - if (!value || getFieldValue("password") === value) { - setPasswordError({ target: "", message: "" }); - return Promise.resolve(); - } - setPasswordError({ - target: "confirmPassword", - message: t("auth.passwordsDoNotMatch"), - }); - return Promise.reject(new Error(t("auth.passwordsDoNotMatch"))); + return Promise.reject(new Error(t("auth.passwordsDoNotMatch"))); + }, + }), + ]} + > + } + placeholder={t("auth.confirmPasswordRequired")} + size="large" + onChange={handleConfirmPasswordChange} + /> + + + + + - } - placeholder={t("auth.confirmPasswordRequired")} - size="large" - onChange={handleConfirmPasswordChange} - /> - - - - - -
-
- - {t("auth.adminAccount")} -
- + } + placeholder={t("auth.inviteCodeRequired")} + size="large" /> -
- - {t("auth.adminAccountDescription")} - -
+
+ - {isAdminMode && ( <>
@@ -492,7 +515,7 @@ export function RegisterModal() { {t("auth.inviteCodeHint.title")}
-
+
✨ {t("auth.inviteCodeHint.step1")} {t("auth.inviteCodeHint.starAction")}
- ); } diff --git a/frontend/components/auth/sessionListeners.tsx b/frontend/components/auth/sessionListeners.tsx deleted file mode 100644 index bf23079e8..000000000 --- a/frontend/components/auth/sessionListeners.tsx +++ /dev/null @@ -1,230 +0,0 @@ -"use client"; - -import { useEffect, useRef } from "react"; -import { useRouter, usePathname } from "next/navigation"; -import { useTranslation } from "react-i18next"; -import { App, Modal } from "antd"; -import { ExclamationCircleOutlined } from "@ant-design/icons"; - -import { useAuth } from "@/hooks/useAuth"; -import { useConfirmModal } from "@/hooks/useConfirmModal"; -import { authService } from "@/services/authService"; -import { sessionService } from "@/services/sessionService"; -import { getSessionFromStorage } from "@/lib/auth"; -import { EVENTS } from "@/const/auth"; -import { - TOKEN_REFRESH_BEFORE_EXPIRY_MS, - MIN_ACTIVITY_CHECK_INTERVAL_MS, -} from "@/const/constants"; -import log from "@/lib/logger"; -import { saveView } from "@/lib/viewPersistence"; - -/** - * Session management component - * Handles session expiration, session refresh and other functions - */ -export function SessionListeners() { - const router = useRouter(); - const pathname = usePathname(); - const { t } = useTranslation("common"); - const { openLoginModal, setIsFromSessionExpired, clearLocalSession, isSpeedMode } = - useAuth(); - const { modal } = App.useApp(); - const { confirm } = useConfirmModal(); - const modalShownRef = useRef(false); - - const isLocaleHomePath = (path?: string | null) => { - if (!path) return false; - const segments = path.split("/").filter(Boolean); - return segments.length <= 1; - }; - - /** - * Show "Login Expired" confirmation modal - * This function handles debounce logic to prevent modal from appearing repeatedly - */ - const showSessionExpiredModal = () => { - // If already shown, return directly - if (modalShownRef.current) return; - modalShownRef.current = true; - - modal.confirm({ - title: t("login.expired.title"), - icon: , - content: t("login.expired.content"), - okText: t("login.expired.okText"), - cancelText: t("login.expired.cancelText"), - centered: true, - closable: false, - okButtonProps: { - type: "primary" - }, - onOk() { - // Clear local session state (session already expired on backend) - clearLocalSession(); - // Mark the source as session expired - setIsFromSessionExpired(true); - Modal.destroyAll(); - openLoginModal(); - setTimeout(() => (modalShownRef.current = false), 500); - }, - onCancel() { - // Clear local session state (session already expired on backend) - clearLocalSession(); - saveView("home"); - router.push("/"); - setTimeout(() => (modalShownRef.current = false), 500); - }, - }); - }; - - // Listen for events after successful login, reset modalShown state - useEffect(() => { - const handleModalClosed = () => { - modalShownRef.current = false; - }; - - // Add event listener - document.addEventListener("modalClosed", handleModalClosed); - - // Cleanup function - return () => { - document.removeEventListener("modalClosed", handleModalClosed); - }; - }, []); - - // Listen for session expiration events (skip in speed mode) - useEffect(() => { - if (isSpeedMode) return; - const handleSessionExpired = (event: CustomEvent) => { - // Directly call the wrapper function - showSessionExpiredModal(); - }; - - // Add event listener - window.addEventListener( - EVENTS.SESSION_EXPIRED, - handleSessionExpired as EventListener - ); - - // Cleanup function - return () => { - window.removeEventListener( - EVENTS.SESSION_EXPIRED, - handleSessionExpired as EventListener - ); - }; - // Remove confirm from dependency array to avoid duplicate registration due to function reference changes - }, [isSpeedMode]); - - // When component first mounts, if no local session is found, show modal immediately - useEffect(() => { - // Skip in speed mode - if (isSpeedMode) return; - }, []); - - // Session status check - useEffect(() => { - // Skip in speed mode - if (isSpeedMode) return; - // Check session status on first load - const checkSession = async () => { - try { - // Capture whether there was a local session before validation - const hadLocalSession = - typeof window !== "undefined" && !!localStorage.getItem("session"); - - // Try to get current session - const session = await authService.getSession(); - - // Only show session expired modal if a prior session existed and is now invalid - if ((!session && hadLocalSession) || (!session && !hadLocalSession && !isLocaleHomePath(pathname))) { - window.dispatchEvent( - new CustomEvent(EVENTS.SESSION_EXPIRED, { - detail: { message: "Session expired, please sign in again" }, - }) - ); - } - } catch (error) { - log.error("Error checking session status:", error); - } - }; - - checkSession(); - }, [pathname, isSpeedMode]); - - // Sliding expiration: refresh token shortly before expiry on user activity (skip in speed mode) - useEffect(() => { - if (isSpeedMode) return; - - let lastActivityCheckAt = 0; - - const maybeRefreshOnActivity = async () => { - try { - // Throttle activity-driven checks - const now = Date.now(); - if (now - lastActivityCheckAt < MIN_ACTIVITY_CHECK_INTERVAL_MS) return; - lastActivityCheckAt = now; - - // Do not run when page is hidden - if (typeof document !== "undefined" && document.hidden) return; - - const sessionObj = getSessionFromStorage(); - if (!sessionObj?.expires_at) return; - - const msUntilExpiry = sessionObj.expires_at * 1000 - now; - if (msUntilExpiry <= TOKEN_REFRESH_BEFORE_EXPIRY_MS) { - const ok = await sessionService.checkAndRefreshToken(); - if (!ok) { - // If refresh failed and token is already expired, raise expired flow - if (msUntilExpiry <= 0) { - window.dispatchEvent( - new CustomEvent(EVENTS.SESSION_EXPIRED, { - detail: { message: "Session expired, please sign in again" }, - }) - ); - } - } - } - } catch (error) { - log.error("Activity-based refresh check failed:", error); - } - }; - - const events: (keyof DocumentEventMap | keyof WindowEventMap)[] = [ - "click", - "keydown", - "mousemove", - "touchstart", - "focus", - "visibilitychange", - ]; - - const handler = () => { - // Wrap to avoid passing the event into async function - void maybeRefreshOnActivity(); - }; - - events.forEach((evt) => { - // Use window for focus/visibility, document for input/mouse - if (evt === "focus" || evt === "visibilitychange") { - window.addEventListener(evt as any, handler, { passive: true }); - } else { - document.addEventListener(evt as any, handler, { passive: true }); - } - }); - - return () => { - events.forEach((evt) => { - if (evt === "focus" || evt === "visibilitychange") { - window.removeEventListener(evt as any, handler as any); - } else { - document.removeEventListener(evt as any, handler as any); - } - }); - }; - }, [isSpeedMode]); - - // This component doesn't render UI elements - return null; -} diff --git a/frontend/components/homepage/AuthDialogs.tsx b/frontend/components/homepage/AuthDialogs.tsx deleted file mode 100644 index dcb66b917..000000000 --- a/frontend/components/homepage/AuthDialogs.tsx +++ /dev/null @@ -1,186 +0,0 @@ -"use client"; - -import { useTranslation, Trans } from "react-i18next"; -import { Modal, Button } from "antd"; - -interface AuthDialogsProps { - loginPromptOpen: boolean; - adminPromptOpen: boolean; - onCloseLoginPrompt: () => void; - onCloseAdminPrompt: () => void; - onLoginClick: () => void; - onRegisterClick: () => void; -} - -/** - * Authentication dialogs component - * Contains login prompt and admin prompt modals - */ -export function AuthDialogs({ - loginPromptOpen, - adminPromptOpen, - onCloseLoginPrompt, - onCloseAdminPrompt, - onLoginClick, - onRegisterClick, -}: AuthDialogsProps) { - const { t } = useTranslation("common"); - - return ( - <> - {/* Login prompt dialog */} - - {t("page.loginPrompt.register")} - , - , - ]} - centered - > - - - - {/* Admin prompt dialog */} - - {t("page.loginPrompt.register")} - , - , - ]} - centered - > -
-

{t("page.adminPrompt.intro")}

-
-
-

- {t("page.adminPrompt.unlockHeader")} -

-

- {t("page.adminPrompt.unlockIntro")} -

-
-

- {t("page.adminPrompt.permissionsTitle")} -

-
    - {( - t("page.adminPrompt.permissions", { - returnObjects: true, - }) as string[] - ).map((permission, i) => ( -
  • {permission}
  • - ))} -
-
-
-

- - ⭐️ Nexent is still growing, please help me by starring on - - GitHub - - , thank you. - -
-
- - 💡 Want to become an administrator? Please visit the - - official contact page - - to apply for an administrator account. - -

-
-
-
-
- - ); -} - diff --git a/frontend/components/navigation/ChatTopNavContent.tsx b/frontend/components/navigation/ChatTopNavContent.tsx new file mode 100644 index 000000000..c93e60555 --- /dev/null +++ b/frontend/components/navigation/ChatTopNavContent.tsx @@ -0,0 +1,48 @@ +"use client"; + +import { useConfig } from "@/hooks/useConfig"; +import { extractColorsFromUri } from "@/lib/avatar"; +import { useRouter } from "next/navigation"; +import { useTranslation } from "react-i18next"; + +/** + * ChatTopNavContent - Displays app logo and name in the top navbar for chat page + */ +export function ChatTopNavContent() { + const router = useRouter(); + const { i18n } = useTranslation(); + const { appConfig, getAppAvatarUrl } = useConfig(); + const sidebarAvatarUrl = getAppAvatarUrl(16); + + // Static font-size for top navbar (no responsive sizing required) + + const colors = extractColorsFromUri(appConfig.avatarUri || ""); + const mainColor = colors.mainColor || "273746"; + const secondaryColor = colors.secondaryColor || mainColor; + + return ( +
router.push(`/${i18n.language}`)} + > +
+ {appConfig.appName} +
+ + {appConfig.appName} + +
+ ); +} + diff --git a/frontend/components/navigation/SideNavigation.tsx b/frontend/components/navigation/SideNavigation.tsx index 56cba5c6c..2bbe4c9a9 100644 --- a/frontend/components/navigation/SideNavigation.tsx +++ b/frontend/components/navigation/SideNavigation.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect } from "react"; +import { useState, useEffect, useMemo } from "react"; import { useTranslation } from "react-i18next"; import { useRouter, usePathname } from "next/navigation"; import { Menu, ConfigProvider } from "antd"; @@ -20,143 +20,143 @@ import { Building2, } from "lucide-react"; import type { MenuProps } from "antd"; -import { useAuth } from "@/hooks/useAuth"; +import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; +import { useAuthenticationContext } from "@/components/providers/AuthenticationProvider"; +import { useDeployment } from "@/components/providers/deploymentProvider"; import { SIDER_CONFIG } from "@/const/layoutConstants"; +import { AUTH_EVENTS } from "@/const/auth"; +import { getEffectiveRoutePath } from "@/lib/auth"; +import { authEvents } from "@/lib/authEvents"; interface SideNavigationProps { - onAuthRequired?: () => void; - onAdminRequired?: () => void; collapsed?: boolean; } +/** + * Route configuration interface for menu items + */ +interface RouteConfig { + path: string; + Icon: React.ComponentType<{ className?: string }>; + labelKey: string; + order: number; +} + +/** + * Static route configuration mapping + * All available routes with their metadata + */ +const ROUTE_CONFIG: RouteConfig[] = [ + { path: "/", Icon: Home, labelKey: "sidebar.homePage", order: 0 }, + { path: "/chat", Icon: Bot, labelKey: "sidebar.startChat", order: 1 }, + { path: "/setup", Icon: Zap, labelKey: "sidebar.quickConfig", order: 2 }, + { path: "/space", Icon: Globe, labelKey: "sidebar.agentSpace", order: 3 }, + { path: "/market", Icon: ShoppingBag, labelKey: "sidebar.agentMarket", order: 4 }, + { path: "/agents", Icon: Code, labelKey: "sidebar.agentDev", order: 5 }, + { path: "/knowledges", Icon: BookOpen, labelKey: "sidebar.knowledgeBase", order: 6 }, + { path: "/mcp-tools", Icon: Puzzle, labelKey: "sidebar.mcpToolsManagement", order: 7 }, + { path: "/monitoring", Icon: Activity, labelKey: "sidebar.monitoringManagement", order: 8 }, + { path: "/models", Icon: Settings, labelKey: "sidebar.modelManagement", order: 9 }, + { path: "/memory", Icon: Database, labelKey: "sidebar.memoryManagement", order: 10 }, + { path: "/users", Icon: Users, labelKey: "sidebar.userManagement", order: 11 }, +]; + +/** + * Extract all available route paths from ROUTE_CONFIG + */ +const ROUTE_PATHS = ROUTE_CONFIG.map((route) => route.path); + /** * Side navigation component with collapsible menu - * Displays main navigation items for the application + * Displays main navigation items for the application based on user's accessible routes */ export function SideNavigation({ - onAuthRequired, - onAdminRequired, collapsed, }: SideNavigationProps) { const { t } = useTranslation("common"); - const { user, isSpeedMode } = useAuth(); + const { accessibleRoutes } = useAuthorizationContext(); + const { isAuthenticated, openAuthPromptModal } = useAuthenticationContext(); + const { isSpeedMode } = useDeployment(); const router = useRouter(); const pathname = usePathname(); - const [selectedKey, setSelectedKey] = useState("home"); + const [selectedKey, setSelectedKey] = useState("/"); + const [pendingNavigationPath, setPendingNavigationPath] = useState(null); const isCollapsed = typeof collapsed === "boolean" ? collapsed : false; - // 添加路径到key的映射 - const pathToKeyMap: Record = { - "/": "home", - "/chat": "chat", - "/setup": "setup", - "/space": "space", - "/market": "market", - "/agents": "agents", - "/knowledges": "knowledges", - "/mcp-tools": "mcp-tools", - "/monitoring": "monitoring", - "/models": "models", - "/memory": "memory", - "/users": "users", - "/tenant-resources": "tenant-resources", - }; + // Update selected key when pathname changes + useEffect(() => { + const currentPath = getEffectiveRoutePath(pathname); + const matchedKey = ROUTE_PATHS.includes(currentPath) ? currentPath : "/"; + setSelectedKey(matchedKey); + }, [pathname]); - // 添加useEffect来监听pathname变化并更新selectedKey + // Listen for login success event and navigate to pending path useEffect(() => { - // 从pathname中提取实际路径(去掉locale前缀) - const segments = pathname.split("/").filter(Boolean); - // 如果第一个segment是locale(zh或en),则去掉它 - if (segments.length > 0 && (segments[0] === "zh" || segments[0] === "en")) { - segments.shift(); + const handleLoginSuccess = () => { + if (pendingNavigationPath && isAuthenticated) { + // Small delay to ensure authentication state is fully updated + setTimeout(() => { + router.push(pendingNavigationPath); + setPendingNavigationPath(null); + }, 200); + } + }; + + const cleanup = authEvents.on(AUTH_EVENTS.LOGIN_SUCCESS, handleLoginSuccess); + return cleanup; + }, [pendingNavigationPath, isAuthenticated, router]); + + // Listen for back-to-home event and reset selected key + useEffect(() => { + const handleBackToHome = () => { + setSelectedKey("/"); + }; + + const cleanup = authEvents.on(AUTH_EVENTS.BACK_TO_HOME, handleBackToHome); + return cleanup; + }, []); + + // Filter and sort routes based on accessibleRoutes from authorization context + const accessibleMenuItems = useMemo((): RouteConfig[] => { + if (!accessibleRoutes || accessibleRoutes.length === 0) { + // If no accessibleRoutes available, show all routes (fallback) + return ROUTE_CONFIG; } - // 重新构建路径 - const currentPath = "/" + segments.join("/"); - // 查找对应的key,找不到则默认为home - const matchedKey = pathToKeyMap[currentPath] || "home"; - setSelectedKey(matchedKey); - }, [pathname]); + return ROUTE_CONFIG.filter((route) => + accessibleRoutes.includes(route.path) + ).sort((a, b) => a.order - b.order); + }, [accessibleRoutes]); - // Helper function to create menu item with consistent icon styling + /** + * Create a menu item from route configuration + * Pre-check authentication before navigation to avoid unnecessary route changes + */ const createMenuItem = ( - key: string, - path: string, - Icon: any, - labelKey: string, - requiresAuth = false, - requiresAdmin = false - ) => ({ - key, - path, - icon: , - label: t(labelKey), - onClick: () => { - if (!isSpeedMode && requiresAdmin && user?.role !== "admin") { - onAdminRequired?.(); - } else if (!isSpeedMode && requiresAuth && !user) { - onAuthRequired?.(); - } else { - setSelectedKey(key); - if (path) { - router.push(path); + route: RouteConfig + ): NonNullable[number] => { + return { + key: route.path, + icon: , + label: t(route.labelKey), + onClick: () => { + setSelectedKey(route.path); + + // Pre-check authentication - show auth prompt if user is not authenticated + if (!isAuthenticated && !isSpeedMode && route.path !== "/") { + setPendingNavigationPath(route.path); + openAuthPromptModal(); + return; // Prevent navigation } - } - }, - }); - - // Menu items configuration - paths without locale prefix (middleware will add it) - const menuItems: MenuProps["items"] = [ - createMenuItem("0", "/", Home, "sidebar.homePage"), - createMenuItem("1", "/chat", Bot, "sidebar.startChat", true), - createMenuItem("2", "/setup", Zap, "sidebar.quickConfig", false, true), - createMenuItem("3", "/space", Globe, "sidebar.agentSpace", true), - createMenuItem("4", "/market", ShoppingBag, "sidebar.agentMarket", true), - createMenuItem("5", "/agents", Code, "sidebar.agentDev", false, true), - createMenuItem("6", "/knowledges", BookOpen, "sidebar.knowledgeBase", true), - createMenuItem( - "10", - "/mcp-tools", - Puzzle, - "sidebar.mcpToolsManagement", - false, - true - ), - createMenuItem( - "11", - "/monitoring", - Activity, - "sidebar.monitoringManagement", - false, - true - ), - createMenuItem( - "7", - "/models", - Settings, - "sidebar.modelManagement", - false, - true - ), - createMenuItem( - "8", - "/memory", - Database, - "sidebar.memoryManagement", - false, - true - ), - createMenuItem("9", "/users", User, "sidebar.userManagement", false, true), - createMenuItem( - "12", - "/tenant-resources", - Building2, - "sidebar.tenantResources", - false, - true - ), - ]; + + router.push(route.path); + }, + }; + }; + + // Generate menu items from accessible routes + const menuItems: MenuProps["items"] = accessibleMenuItems.map(createMenuItem); return ( diff --git a/frontend/components/navigation/TopNavbar.tsx b/frontend/components/navigation/TopNavbar.tsx index 2cc5a28f4..2fbeee744 100644 --- a/frontend/components/navigation/TopNavbar.tsx +++ b/frontend/components/navigation/TopNavbar.tsx @@ -3,7 +3,6 @@ import { Button } from "antd"; import { AvatarDropdown } from "@/components/auth/avatarDropdown"; import { useTranslation } from "react-i18next"; -import { useAuth } from "@/hooks/useAuth"; import { ChevronDown, Globe } from "lucide-react"; import { Dropdown } from "antd"; import Link from "next/link"; @@ -11,24 +10,16 @@ import { HEADER_CONFIG, SIDER_CONFIG } from "@/const/layoutConstants"; import { languageOptions } from "@/const/constants"; import { useLanguageSwitch } from "@/lib/language"; import React from "react"; -import { Flex, Layout } from 'antd'; +import { Flex, Layout } from "antd"; +import { ChatTopNavContent } from "./ChatTopNavContent"; +import { useAuthorizationContext } from "../providers/AuthorizationProvider"; +import { useDeployment } from "../providers/deploymentProvider"; const { Header } = Layout; -interface TopNavbarProps { - /** Additional title text to display after logo (separated by |) */ - additionalTitle?: React.ReactNode; - /** Additional content to insert before default right nav items */ - additionalRightContent?: React.ReactNode; -} - -/** - * Top navigation bar component - * Displays logo, language switcher, and user authentication status - * Can be customized with additionalTitle and additionalRightContent props - */ -export function TopNavbar({ additionalTitle, additionalRightContent }: TopNavbarProps) { +export function TopNavbar({ isChatPage }: { isChatPage: boolean }) { const { t } = useTranslation("common"); - const { user, isLoading: userLoading, isSpeedMode } = useAuth(); + const { user, isLoading } = useAuthorizationContext(); + const { isSpeedMode } = useDeployment() const { currentLanguage, handleLanguageChange } = useLanguageSwitch(); // Left content - Logo + optional additional title (aligned with sidebar width) @@ -38,20 +29,16 @@ export function TopNavbar({ additionalTitle, additionalRightContent }: TopNavbar - ModelEngine + ModelEngine {t("assistant.name")} @@ -60,11 +47,11 @@ export function TopNavbar({ additionalTitle, additionalRightContent }: TopNavbar {/* Additional title with separator - outside of sidebar width */} - {additionalTitle && ( + {isChatPage && (
- {additionalTitle} +
)} @@ -74,8 +61,6 @@ export function TopNavbar({ additionalTitle, additionalRightContent }: TopNavbar // Right content - Additional content + default navigation items const rightContent = ( - {/* Additional right content (e.g., status badge) */} - {additionalRightContent} {/* GitHub link */} - {userLoading ? ( + {isLoading ? ( {t("common.loading")}... @@ -179,4 +164,3 @@ export function TopNavbar({ additionalTitle, additionalRightContent }: TopNavbar ); } - diff --git a/frontend/components/permission/Can.tsx b/frontend/components/permission/Can.tsx new file mode 100644 index 000000000..33a282c05 --- /dev/null +++ b/frontend/components/permission/Can.tsx @@ -0,0 +1,36 @@ +"use client"; + +import React from "react"; +import { usePermission } from "@/hooks/permission/usePermission"; + +interface CanProps { + permission: string | string[]; + children: React.ReactNode; + fallback?: React.ReactNode; +} + +/** + * Render children only when user HAS the permission + * + * @example + * ```tsx + * + * + * + * + * + * + * + * ``` + */ +export function Can({ permission, children, fallback = null }: CanProps) { + const { isReady, can, canAny } = usePermission(); + + if (!isReady) return null; + + const hasPermission = Array.isArray(permission) + ? canAny(permission) + : can(permission); + + return hasPermission ? <>{children} : <>{fallback}; +} diff --git a/frontend/components/permission/Cannot.tsx b/frontend/components/permission/Cannot.tsx new file mode 100644 index 000000000..99a730897 --- /dev/null +++ b/frontend/components/permission/Cannot.tsx @@ -0,0 +1,32 @@ +"use client"; + +import React from "react"; +import { usePermission } from "@/hooks/permission/usePermission"; + +interface CannotProps { + permission: string | string[]; + children: React.ReactNode; + fallback?: React.ReactNode; +} + +/** + * Render children only when user does NOT have the permission + * + * @example + * ```tsx + * + * + * + * ``` + */ +export function Cannot({ permission, children, fallback = null }: CannotProps) { + const { isReady, can, canAny } = usePermission(); + + if (!isReady) return null; + + const hasPermission = Array.isArray(permission) + ? canAny(permission) + : can(permission); + + return !hasPermission ? <>{children} : <>{fallback}; +} diff --git a/frontend/components/providers/AuthenticationProvider.tsx b/frontend/components/providers/AuthenticationProvider.tsx new file mode 100644 index 000000000..cbf2ce9d8 --- /dev/null +++ b/frontend/components/providers/AuthenticationProvider.tsx @@ -0,0 +1,42 @@ +"use client"; + +import React, { createContext, useContext, ReactNode } from "react"; +import { useAuthentication } from "@/hooks/auth/useAuthentication"; +import { AuthenticationContextType } from "@/types/auth"; + +/** + * Authentication Context + */ +const AuthenticationContext = createContext< + AuthenticationContextType | undefined +>(undefined); + +/** + * Authentication Provider Component + * Provides authentication state and methods to the component tree + */ +export function AuthenticationProvider({ children }: { children?: ReactNode }) { + const authValue = useAuthentication(); + + return ( + + {children} + + ); +} + +/** + * Hook to use authentication context + */ +export function useAuthenticationContext(): AuthenticationContextType { + const context = useContext(AuthenticationContext); + if (context === undefined) { + throw new Error( + "useAuthenticationContext must be used within an AuthenticationProvider" + ); + } + return context; +} + +// Export context for advanced use cases +export { AuthenticationContext }; diff --git a/frontend/components/providers/AuthorizationProvider.tsx b/frontend/components/providers/AuthorizationProvider.tsx new file mode 100644 index 000000000..b597cfcd2 --- /dev/null +++ b/frontend/components/providers/AuthorizationProvider.tsx @@ -0,0 +1,42 @@ +"use client"; + +import React, { createContext, useContext, ReactNode } from "react"; +import { useAuthorization } from "@/hooks/auth/useAuthorization"; +import { AuthorizationContextType } from "@/types/auth"; + +/** + * Authorization Context + */ +const AuthorizationContext = createContext< + AuthorizationContextType | undefined +>(undefined); + +/** + * Authorization Provider Component + * Provides authorization state and methods to the component tree + */ +export function AuthorizationProvider({ children }: { children?: ReactNode }) { + const authzValue = useAuthorization(); + + return ( + + {children} + + ); +} + +/** + * Hook to use authorization context + */ +export function useAuthorizationContext(): AuthorizationContextType { + const context = useContext(AuthorizationContext); + if (context === undefined) { + throw new Error( + "useAuthorizationContext must be used within an AuthorizationProvider" + ); + } + return context; +} + +// Export context for advanced use cases +export { AuthorizationContext }; diff --git a/frontend/components/providers/deploymentProvider.tsx b/frontend/components/providers/deploymentProvider.tsx index 2612f2b9d..8cb7dfe57 100644 --- a/frontend/components/providers/deploymentProvider.tsx +++ b/frontend/components/providers/deploymentProvider.tsx @@ -1,6 +1,12 @@ "use client"; -import { createContext, useContext, useState, useEffect, ReactNode } from "react"; +import { + createContext, + useContext, + useState, + useEffect, + ReactNode, +} from "react"; import { API_ENDPOINTS } from "@/services/api"; import log from "@/lib/logger"; @@ -21,14 +27,17 @@ export function DeploymentProvider({ children }: { children: ReactNode }) { useEffect(() => { const checkDeploymentVersion = async () => { try { - const response = await fetch(API_ENDPOINTS.tenantConfig.deploymentVersion); + const response = await fetch( + API_ENDPOINTS.tenantConfig.deploymentVersion + ); if (response.ok) { const data = await response.json(); - const version = data.content?.deployment_version || data.deployment_version; - setIsSpeedMode(version === 'speed'); + const version = + data.content?.deployment_version || data.deployment_version; + setIsSpeedMode(version === "speed"); } } catch (error) { - log.error('Failed to check deployment version:', error); + log.error("Failed to check deployment version:", error); setIsSpeedMode(false); } finally { setIsDeploymentReady(true); @@ -46,4 +55,3 @@ export function DeploymentProvider({ children }: { children: ReactNode }) { } export const useDeployment = () => useContext(DeploymentContext); - diff --git a/frontend/components/providers/rootProvider.tsx b/frontend/components/providers/rootProvider.tsx index 463cff979..09e6fe95d 100644 --- a/frontend/components/providers/rootProvider.tsx +++ b/frontend/components/providers/rootProvider.tsx @@ -5,23 +5,40 @@ import { ConfigProvider, App } from "antd"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { - AuthProvider, - AuthContext, - useAuth, -} from "@/hooks/useAuth"; + AuthenticationProvider, + useAuthenticationContext, +} from "@/components/providers/AuthenticationProvider"; +import { + AuthorizationProvider, + useAuthorizationContext, +} from "@/components/providers/AuthorizationProvider"; -import { LoginModal, RegisterModal, SessionListeners } from "@/components/auth"; +import { LoginModal } from "@/components/auth/loginModal"; +import { RegisterModal } from "@/components/auth/registerModal"; import { FullScreenLoading } from "@/components/ui/loading"; import { useDeployment } from "./deploymentProvider"; -function AppReadyWrapper({ children }: { children: ReactNode }) { - const { isDeploymentReady } = useDeployment(); - const auth = useAuth(); - const isAuthReady = (auth as any).isAuthReady; +function AppReadyWrapper({ children }: { children?: ReactNode }) { + const { isDeploymentReady, isSpeedMode } = useDeployment(); + const auth = useAuthenticationContext(); + const authz = useAuthorizationContext(); - const isAppReady = isDeploymentReady && isAuthReady; + // In speed mode, skip auth checks since authentication is bypassed + // isAuthChecking: allow rendering during auth state check to avoid blocking UI + const isAuthReady = isSpeedMode || !auth.isLoading || auth.isAuthenticated || auth.isAuthChecking; + const isAuthzReady = isSpeedMode || !authz.isLoading || auth.isAuthenticated || auth.isAuthChecking; + const isAppReady = isDeploymentReady && isAuthReady && isAuthzReady; - return isAppReady ? <>{children} : ; + // If login or register modal is open, user is performing an operation, + // don't show full screen loading (they can already see the page) + const isUserOperating = auth.isLoginModalOpen || auth.isRegisterModalOpen; + + // Only show FullScreenLoading during initial load, not during user operations + if (isAppReady || isUserOperating) { + return <>{children}; + } + + return ; } /** @@ -33,20 +50,15 @@ export function RootProvider({ children }: { children: ReactNode }) { document.body}> - - {(authContextValue) => ( - + + - <> - {children} - - + <>{children} - - )} - + + diff --git a/frontend/const/auth.ts b/frontend/const/auth.ts index eb0f66c4b..06a6b924e 100644 --- a/frontend/const/auth.ts +++ b/frontend/const/auth.ts @@ -9,7 +9,7 @@ export enum USER_ROLES { export const STATUS_CODES = { SUCCESS: 200, - + UNAUTHORIZED_HTTP: 401, REQUEST_ENTITY_TOO_LARGE: 413, @@ -32,3 +32,22 @@ export const EVENTS = { SESSION_EXPIRED: "session-expired", STORAGE_CHANGE: "storage", }; + +// Type-safe authentication events (used with authEvents emitter) +export const AUTH_EVENTS = { + LOGIN_SUCCESS: "auth:login-success", + REGISTER_SUCCESS: "auth:register-success", + LOGOUT: "auth:logout", + SESSION_EXPIRED: "auth:session-expired", + PERMISSION_DENIED: "auth:permission-denied", + TOKEN_REFRESHED: "auth:token-refreshed", + SERVICE_UNAVAILABLE: "auth:service-unavailable", + BACK_TO_HOME: "nav:back-to-home", +} as const; + +// Type-safe authorization events (used with authzEvents emitter) +export const AUTHZ_EVENTS = { + PERMISSIONS_READY: "authz:permissions-ready", + PERMISSIONS_UPDATED: "authz:permissions-updated", +} as const; + diff --git a/frontend/hooks/auth/useAuthentication.ts b/frontend/hooks/auth/useAuthentication.ts new file mode 100644 index 000000000..b360d613e --- /dev/null +++ b/frontend/hooks/auth/useAuthentication.ts @@ -0,0 +1,55 @@ +"use client"; + +import { useAuthenticationState } from "@/hooks/auth/useAuthenticationState"; +import { useAuthenticationUI } from "@/hooks/auth/useAuthenticationUI"; +import { AuthenticationContextType } from "@/types/auth"; + +/** + * Custom hook for authentication management + * Combines useAuthenticationState and useAuthenticationUI to provide full authentication functionality + */ +export function useAuthentication(): AuthenticationContextType { + const authState = useAuthenticationState(); + // Pass auth state to useAuthenticationUI to avoid circular dependency + const authUI = useAuthenticationUI({ + isAuthenticated: authState.isAuthenticated, + isAuthChecking: authState.isAuthChecking, + clearLocalSession: authState.clearLocalSession, + }); + + return { + // Authentication state + isAuthenticated: authState.isAuthenticated, + isAuthChecking: authState.isAuthChecking, + isLoading: authState.isLoading, + session: authState.session, + + authServiceUnavailable: authState.authServiceUnavailable, + + // Methods + login: authState.login, + register: authState.register, + logout: authState.logout, + clearLocalSession: authState.clearLocalSession, + revoke: authState.revoke, + + // UI state + isLoginModalOpen: authUI.isLoginModalOpen, + isRegisterModalOpen: authUI.isRegisterModalOpen, + isAuthPromptModalOpen: authUI.isAuthPromptModalOpen, + isSessionExpiredModalOpen: authUI.isSessionExpiredModalOpen, + + // UI methods + openLoginModal: authUI.openLoginModal, + closeLoginModal: authUI.closeLoginModal, + + openRegisterModal: authUI.openRegisterModal, + closeRegisterModal: authUI.closeRegisterModal, + + openAuthPromptModal: authUI.openAuthPromptModal, + closeAuthPromptModal: authUI.closeAuthPromptModal, + + openSessionExpiredModal: authUI.openSessionExpiredModal, + closeSessionExpiredModal: authUI.closeSessionExpiredModal + }; +} diff --git a/frontend/hooks/auth/useAuthenticationState.ts b/frontend/hooks/auth/useAuthenticationState.ts new file mode 100644 index 000000000..40fc3f25d --- /dev/null +++ b/frontend/hooks/auth/useAuthenticationState.ts @@ -0,0 +1,235 @@ +"use client"; + +import { useState, useEffect, useCallback } from "react"; +import { useTranslation } from "react-i18next"; +import { App } from "antd"; + +import { useDeployment } from "@/components/providers/deploymentProvider"; +import { authService } from "@/services/authService"; +import { getSessionFromStorage, removeSessionFromStorage, checkSessionValid } from "@/lib/session"; +import { Session, AuthenticationStateReturn, AuthenticationContextType } from "@/types/auth"; +import { STATUS_CODES } from "@/const/auth"; +import { authEventUtils } from "@/lib/authEvents"; +import log from "@/lib/logger"; + +/** + * Custom hook for authentication state management + * Handles JWT tokens, login/logout, session restoration, and modal states + */ +export function useAuthenticationState(): AuthenticationStateReturn { + const { t } = useTranslation("common"); + const { message } = App.useApp(); + const { isSpeedMode } = useDeployment(); + + // Authentication state + const [isAuthenticated, setIsAuthenticated] = useState(false); + const [isAuthChecking, setIsAuthChecking] = useState(true); + const [isLoading, setIsLoading] = useState(false); + const [session, setSession] = useState(null); + const [authServiceUnavailable, setAuthServiceUnavailable] = + useState(false); + + + + // Initialize authentication state based on session + useEffect(() => { + // Check session validity directly + if (checkSessionValid()) { + // Session is valid, restore it to state + const storedSession = getSessionFromStorage(); + if (storedSession) { + setSession(storedSession); + } + setIsAuthenticated(true); + } else { + // No valid session + setSession(null); + setIsAuthenticated(false); + } + setIsAuthChecking(false); + }, []); + + + const clearLocalSession = useCallback(() => { + removeSessionFromStorage(); + setSession(null); + setIsAuthenticated(false); + }, []); + + // Login method + const login = useCallback( + async ( + email: string, + password: string, + options: { showSuccessMessage?: boolean } = {} + ) => { + const { showSuccessMessage = true } = options; + + setIsLoading(true); + + try { + // First check auth service availability + const isAuthServiceAvailable = + await authService.checkAuthServiceAvailable(); + if (!isAuthServiceAvailable) { + const error = new Error(t("auth.authServiceUnavailable")); + (error as any).code = STATUS_CODES.AUTH_SERVICE_UNAVAILABLE; + setAuthServiceUnavailable(true); + throw error; + } + + setAuthServiceUnavailable(false); + + const { data, error } = await authService.signIn(email, password); + + if (error) { + log.error("Login failed: ", error.message); + throw error; + } + + if (data?.session) { + // Update authentication state + setSession(data.session); + setIsAuthenticated(true); + + // Delay to ensure UI updates + setTimeout(() => { + if (showSuccessMessage) { + message.success(t("auth.loginSuccess")); + } + + authEventUtils.emitLoginSuccess(); + }, 150); + } + } catch (error: any) { + log.error("Error during login process:", error.message); + throw error; + } finally { + setIsLoading(false); + } + }, + [] + ); + + // Register method + const register = useCallback( + async ( + email: string, + password: string, + isAdmin: boolean = false, + inviteCode?: string + ) => { + setIsLoading(true); + + try { + const { data, error } = await authService.signUp( + email, + password, + isAdmin, + inviteCode + ); + + if (error) { + log.error("Registration failed: ", error.message); + throw error; + } + + if (data?.session) { + setSession(data.session); + setIsAuthenticated(true); + + setTimeout(() => { + message.success(t("auth.registrationSuccess")); + + // Emit register success event to close register modal + authEventUtils.emitRegisterSuccess(); + // Emit login success event for permission fetching + authEventUtils.emitLoginSuccess(); + }, 150); + } + } catch (error: any) { + log.error("Error during registration process:", error.message); + throw error; + } finally { + setIsLoading(false); + } + }, + [] + ); + + // Logout method + const logout = useCallback( + async (options: { silent?: boolean } = {}) => { + const { silent = false } = options; + + try { + setIsLoading(true); + + if (!silent) { + // Call logout API + await authService.signOut(); + } + + // Clear local session + removeSessionFromStorage(); + setSession(null); + setIsAuthenticated(false); + + if (!silent) { + message.success(t("auth.logoutSuccess")); + } + + // Emit logout event + authEventUtils.emitLogout(); + } catch (error: any) { + log.error("Logout failed:", error?.message || error); + // Even if API call fails, clear local session + removeSessionFromStorage(); + setSession(null); + setIsAuthenticated(false); + + if (!silent) { + message.error(t("auth.logoutFailed")); + } + } finally { + setIsLoading(false); + } + }, + [] + ); + + // Revoke method + const revoke = useCallback(async () => { + try { + setIsLoading(true); + + await authService.revoke(); + + clearLocalSession(); + message.success(t("auth.revokeSuccess")); + + authEventUtils.emitLogout(); + } catch (error: any) { + log.error("Revoke failed:", error?.message || error); + message.error(t("auth.revokeFailed")); + } finally { + setIsLoading(false); + } + }, [clearLocalSession]); + + return { + // Authentication state + isAuthenticated, + isAuthChecking, + isLoading, + session, + authServiceUnavailable, + + // Methods + login, + register, + logout, + clearLocalSession, + revoke + }; +} diff --git a/frontend/hooks/auth/useAuthenticationUI.ts b/frontend/hooks/auth/useAuthenticationUI.ts new file mode 100644 index 000000000..bf13fc72c --- /dev/null +++ b/frontend/hooks/auth/useAuthenticationUI.ts @@ -0,0 +1,174 @@ +"use client"; + +import { useState, useCallback, useRef, useEffect } from "react"; +import { useRouter, usePathname } from "next/navigation"; +import { useTranslation } from "react-i18next"; + +import { useDeployment } from "@/components/providers/deploymentProvider"; +import { AUTH_EVENTS } from "@/const/auth"; +import { getEffectiveRoutePath } from "@/lib/auth"; +import { authEvents, authEventUtils } from "@/lib/authEvents"; +import { AuthenticationUIReturn } from "@/types/auth"; +import log from "@/lib/logger"; + +/** + * Custom hook for authentication UI management + * Handles login/register modals, auth prompt modals, and session expired modal + * Must be used within AuthenticationProvider + */ +export function useAuthenticationUI({ + isAuthenticated, + isAuthChecking, + clearLocalSession, +}: { + isAuthenticated: boolean; + isAuthChecking: boolean; + clearLocalSession: () => void; +}): AuthenticationUIReturn { + const router = useRouter(); + const pathname = usePathname(); + const { t } = useTranslation("common"); + const { isSpeedMode } = useDeployment(); + + // UI state for modals - managed locally within the hook + const [isLoginModalOpen, setIsLoginModalOpen] = useState(false); + const [isRegisterModalOpen, setIsRegisterModalOpen] = useState(false); + const [isAuthPromptModalOpen, setIsAuthPromptModalOpen] = useState(false); + const [isSessionExpiredModalOpen, setIsSessionExpiredModalOpen] = useState(false); + + const handleUnauthenticatedModalClose = (() => { + // Only emit back to home event and redirect if user is not authenticated + if (!isAuthenticated && !isSpeedMode) { + + // Emit event to notify SideNavigation to reset selected key + authEventUtils.emitBackToHome(); + // Redirect to home page if not already there + const effectivePath = pathname ? getEffectiveRoutePath(pathname) : "/"; + if (effectivePath !== "/") { + router.push("/"); + } + } + }); + + // Modal control functions + const openLoginModal = useCallback(() => setIsLoginModalOpen(true), []); + + const closeLoginModal = useCallback(() => { + setIsLoginModalOpen(false); + handleUnauthenticatedModalClose(); + }, [handleUnauthenticatedModalClose]); + + const openRegisterModal = useCallback(() => setIsRegisterModalOpen(true), []); + + const closeRegisterModal = useCallback(() => { + setIsRegisterModalOpen(false); + handleUnauthenticatedModalClose(); + }, [handleUnauthenticatedModalClose]); + + const openAuthPromptModal = useCallback(() => setIsAuthPromptModalOpen(true), []); + + const closeAuthPromptModal = useCallback(() => { + setIsAuthPromptModalOpen(false); + handleUnauthenticatedModalClose(); + }, [handleUnauthenticatedModalClose]); + + const openSessionExpiredModal = useCallback(() => setIsSessionExpiredModalOpen(true), []); + + const closeSessionExpiredModal = useCallback(() => { + clearLocalSession(); + setIsSessionExpiredModalOpen(false); + handleUnauthenticatedModalClose(); + }, [handleUnauthenticatedModalClose]); + + /** + * Check if current path is home page + * Home page paths: "/", "/zh", "/en" + */ + const isLocaleHomePath = (path?: string | null) => { + if (!path) return false; + const segments = path.split("/").filter(Boolean); + return segments.length <= 1; + }; + + useEffect(() => { + if (isSpeedMode) return; + + const handleSessionExpired = () => { + setIsSessionExpiredModalOpen(true); + }; + + const handleRegisterSuccess = () => { + setIsRegisterModalOpen(false); + }; + + // Add event listener using type-safe auth events + const cleanup = authEvents.on( + AUTH_EVENTS.SESSION_EXPIRED, + handleSessionExpired + ); + const cleanupRegister = authEvents.on( + AUTH_EVENTS.REGISTER_SUCCESS, + handleRegisterSuccess + ); + + // Return cleanup function + return () => { + cleanup(); + cleanupRegister(); + }; + }, [isSpeedMode, setIsSessionExpiredModalOpen]); + + + + // Route guard for unauthenticated users - check when pathname changes + useEffect(() => { + if (isSpeedMode) return; + // Skip while checking auth state + if (isAuthChecking) return; + // Skip if user is authenticated + if (isAuthenticated) return; + + // Skip if already on home page + if (isLocaleHomePath(pathname)) return; + + // For unauthenticated users accessing protected routes, show auth prompt + const effectivePath = getEffectiveRoutePath(pathname); + if (effectivePath !== "/") { + openAuthPromptModal(); + } + }, [pathname, isAuthenticated, isSpeedMode, isAuthChecking, openAuthPromptModal]); + + // After auth check completes, verify current route access if user is authenticated + useEffect(() => { + if (isSpeedMode) return; + if (isAuthChecking) return; + if (!isAuthenticated) return; + if (isLocaleHomePath(pathname)) return; + + // If we reach here, user is authenticated but accessing a route + // that previously triggered auth prompt - ensure prompt is closed + if (isAuthPromptModalOpen) { + closeAuthPromptModal(); + } + }, [isAuthChecking, isAuthenticated, pathname, isAuthPromptModalOpen, isSpeedMode, closeAuthPromptModal]); + + return { + // Login/Register Modal + isLoginModalOpen, + openLoginModal, + closeLoginModal, + isRegisterModalOpen, + openRegisterModal, + closeRegisterModal, + + // Auth prompt modal + isAuthPromptModalOpen, + openAuthPromptModal, + closeAuthPromptModal, + + // Session expired modal + isSessionExpiredModalOpen, + openSessionExpiredModal, + closeSessionExpiredModal, + }; +} diff --git a/frontend/hooks/auth/useAuthorization.ts b/frontend/hooks/auth/useAuthorization.ts new file mode 100644 index 000000000..29ce64746 --- /dev/null +++ b/frontend/hooks/auth/useAuthorization.ts @@ -0,0 +1,277 @@ +"use client"; + +import { useState, useEffect, useLayoutEffect, useCallback } from "react"; +import { useQuery } from "@tanstack/react-query"; +import { useRouter, usePathname } from "next/navigation"; +import { User, AuthInfoResponse, AuthorizationContextType } from "@/types/auth"; +import { getSessionFromStorage } from "@/lib/session"; +import { authService } from "@/services/authService"; +import { authEvents, authzEvents, authzEventUtils } from "@/lib/authEvents"; +import { AUTH_EVENTS, AUTHZ_EVENTS } from "@/const/auth"; +import { getEffectiveRoutePath } from "@/lib/auth"; +import log from "@/lib/logger"; + +/** + * Custom hook for authorization management + * Handles user permissions, accessible routes, and React Query caching + */ +export function useAuthorization(): AuthorizationContextType { + const router = useRouter(); + const pathname = usePathname(); + + const [user, setUser] = useState(null); + const [permissions, setPermissions] = useState([]); + const [accessibleRoutes, setAccessibleRoutes] = useState([]); + const [lastCheckedPath, setLastCheckedPath] = useState(null); + + // Authz prompt modal state (permission denied) + const [isAuthzPromptModalOpen, setIsAuthzPromptModalOpen] = useState(false); + + // True when authorization data is ready (permissions loaded) + const [isAuthzReady, setIsAuthzReady] = useState(false); + + // Query for current user authorization info + // enabled: false prevents automatic query execution on mount + const { + data: currentUserInfo, + isLoading, + error, + refetch, + } = useQuery({ + queryKey: ["currentUserInfo"], + queryFn: async (): Promise => { + const result = await authService.getCurrentUserInfo(); + if (!result) { + throw new Error("Failed to fetch user info"); + } + return result; + }, + enabled: false, // Disabled by default, enabled by manual refetch() calls + staleTime: 5 * 60 * 1000, // 5 minutes + gcTime: 10 * 60 * 1000, // 10 minutes + refetchInterval: 10 * 60 * 1000, // Auto refresh every 10 minutes + refetchIntervalInBackground: true, + refetchOnWindowFocus: true, + refetchOnReconnect: true, + }); + + // Update state when authorization data is received + useEffect(() => { + if (isLoading) return; + // Handle API error or null response (e.g., token expired) + if (!currentUserInfo) { + log.warn("Failed to get user info, clearing authorization state"); + setUser(null); + setPermissions([]); + setAccessibleRoutes([]); + return; + } + + if (typeof currentUserInfo === "object") { + // API returns: { user: { permissions, accessibleRoutes, ...userInfo } } + const { user } = currentUserInfo; + + if (user) { + const { permissions, accessibleRoutes, ...userInfo } = user; + // Only update if we have permissions (full user info) + if (permissions && accessibleRoutes) { + setUser(userInfo as User); + setPermissions(permissions); + setAccessibleRoutes(accessibleRoutes); + setIsAuthzReady(true); + authzEventUtils.emitPermissionsReady({ + ...userInfo, + permissions, + accessibleRoutes, + }); + + + } else { + log.warn("Missing permissions or accessibleRoutes in user info", { + hasPermissions: !!permissions, + hasAccessibleRoutes: !!accessibleRoutes, + }); + } + } + } + }, [currentUserInfo, isLoading, error]); + + // Listen for authentication events + useEffect(() => { + // Handle login success - set user info immediately, then fetch full permissions + const handleLoginSuccess = () => { + refetch().then((result) => { + // Manually process the data if refetch succeeded + // This is needed because with enabled: false, React Query might not update data automatically + // Check both status string and isSuccess boolean for compatibility + if (result.data && (result.status === 'success' || result.isSuccess)) { + const { user } = result.data; + + if (user) { + const { permissions, accessibleRoutes, ...userInfo } = user; + + if (permissions && accessibleRoutes) { + setUser(userInfo as User); + setPermissions(permissions); + setAccessibleRoutes(accessibleRoutes); + setIsAuthzReady(true); + + authzEventUtils.emitPermissionsReady({ + ...userInfo, + permissions, + accessibleRoutes, + }); + } else { + log.warn("Missing permissions or accessibleRoutes in refetch result"); + } + } + } + }).catch((error) => { + log.error("Refetch failed:", error); + }); + }; + + // Handle logout - clear authorization data + const handleLogout = () => { + log.info("User logged out, clearing authorization data..."); + setUser(null); + setPermissions([]); + setAccessibleRoutes([]); + setIsAuthzReady(false); + }; + + // Handle session expired - clear authorization data + const handleSessionExpired = () => { + log.info("Session expired, clearing authorization data..."); + setUser(null); + setPermissions([]); + setAccessibleRoutes([]); + setIsAuthzReady(false); + }; + + // Register event listeners + const cleanupLogin = authEvents.on( + AUTH_EVENTS.LOGIN_SUCCESS, + handleLoginSuccess + ); + const cleanupLogout = authEvents.on(AUTH_EVENTS.LOGOUT, handleLogout); + const cleanupSessionExpired = authEvents.on( + AUTH_EVENTS.SESSION_EXPIRED, + handleSessionExpired + ); + + return () => { + cleanupLogin(); + cleanupLogout(); + cleanupSessionExpired(); + }; + }, [refetch]); + + // Initialize authorization data on mount if user is already authenticated + useEffect(() => { + const initializeAuthz = () => { + const session = getSessionFromStorage(); + if (session?.access_token) { + const now = Date.now(); + const expiresAt = session.expires_at * 1000; + + if (expiresAt > now) { + log.info( + "Valid session found on initialization, fetching authorization info..." + ); + refetch().catch((error) => { + log.error("Initial refetch error:", error); + }); + } + } + }; + + // Small delay to ensure authentication state is initialized + const timeoutId = setTimeout(initializeAuthz, 100); + return () => clearTimeout(timeoutId); + }, [refetch]); + + // Authz prompt modal control functions (defined before useLayoutEffect) + const openAuthzPromptModal = useCallback(() => setIsAuthzPromptModalOpen(true), []); + const closeAuthzPromptModal = useCallback(() => setIsAuthzPromptModalOpen(false), []); + + // Check if current route has access (computed on each render) + const cleanPath = getEffectiveRoutePath(pathname); + const hasAccess = accessibleRoutes.includes(cleanPath); + + // Route guard - check authorization when pathname changes + // Use useLayoutEffect to prevent flash of unauthorized content + useLayoutEffect(() => { + // Skip if still loading authorization data + if (isLoading) return; + + // Skip if no user (not authenticated) - authentication should be handled by useAuthentication + if (!user) return; + + // Skip if no accessible routes loaded yet + if (accessibleRoutes.length === 0) return; + + // Skip if pathname hasn't changed + if (pathname === lastCheckedPath) return; + + if (!hasAccess) { + log.warn("Access denied to route:", { pathname: cleanPath, accessibleRoutes }); + // Only show authz prompt if user is fully authenticated + if (user) { + openAuthzPromptModal(); + } + // Use setTimeout to ensure redirect happens after current render cycle + setTimeout(() => { + router.replace("/"); + }, 0); + return; + } + + // Update last checked path to avoid redundant checks + setLastCheckedPath(pathname); + }, [pathname, isLoading, user, accessibleRoutes, lastCheckedPath, hasAccess, router, openAuthzPromptModal]); + + // Permission checking utilities + const hasPermission = (permission: string): boolean => { + return permissions.includes(permission); + }; + + const hasAnyPermission = (requiredPermissions: string[]): boolean => { + return requiredPermissions.some((permission) => + permissions.includes(permission) + ); + }; + + const canAccessRoute = (route: string): boolean => { + return accessibleRoutes.includes(route); + }; + + return { + // Authorization data + user, + permissions, + accessibleRoutes, + + // State + isLoading, + error: error as Error | null, + + // Authorization status + // True when authorization is complete and user has permission to access current route + isAuthorized: !isLoading && !!user && hasAccess, + + // True when authorization data is ready (permissions loaded) + isAuthzReady, + + // Methods + refetch, + hasPermission, + hasAnyPermission, + canAccessRoute, + + // Authz prompt modal (permission denied) + isAuthzPromptModalOpen, + openAuthzPromptModal, + closeAuthzPromptModal, + }; +} diff --git a/frontend/hooks/auth/useSessionManager.ts b/frontend/hooks/auth/useSessionManager.ts new file mode 100644 index 000000000..5b2c27254 --- /dev/null +++ b/frontend/hooks/auth/useSessionManager.ts @@ -0,0 +1,216 @@ +"use client"; + +import { useCallback, useEffect } from "react"; +import { usePathname } from "next/navigation"; + +import { useDeployment } from "@/components/providers/deploymentProvider"; +import { sessionService } from "@/services/sessionService"; +import { + getSessionFromStorage, + saveSessionToStorage, + removeSessionFromStorage, + checkSessionValid as checkSessionValidFn, + checkSessionExpired as checkSessionExpiredFn, + handleSessionExpired, +} from "@/lib/session"; +import { authEventUtils } from "@/lib/authEvents"; +import { + TOKEN_REFRESH_BEFORE_EXPIRY_MS, + MIN_ACTIVITY_CHECK_INTERVAL_MS, +} from "@/const/constants"; +import log from "@/lib/logger"; + +import type { Session } from "@/types/auth"; + +// ============================================================================ +// Utility functions - Session status checking +// ============================================================================ + +// Re-export from lib/session for convenience +export const checkSessionValid = checkSessionValidFn; +export const checkSessionExpired = checkSessionExpiredFn; + +/** + * Check if token is expiring soon (within threshold) + */ +export const isSessionExpiringSoon = (): boolean => { + const session = getSessionFromStorage(); + if (!session?.expires_at) return false; + + const now = Date.now(); + const msUntilExpiry = session.expires_at * 1000 - now; + + // Token must not be already expired, and remaining time must be within threshold + return msUntilExpiry > 0 && msUntilExpiry <= TOKEN_REFRESH_BEFORE_EXPIRY_MS; +}; + +// ============================================================================ +// Session operations +// ============================================================================ + +/** + * Save session to localStorage + */ +export const saveSession = (session: Session): void => { + saveSessionToStorage(session); + log.info("Session saved to localStorage"); +}; + +/** + * Clear session and emit expired event + * Unified handling for session expiration + */ +export const clearSession = (): void => { + handleSessionExpired(); +}; + +// ============================================================================ +// Business logic functions +// ============================================================================ + +/** + * Refresh session if needed (token expiring soon) + * Uses refresh_token to get new access_token + * @returns Whether refresh was successful + */ +export const refreshSessionIfNeeded = async (): Promise => { + const session = getSessionFromStorage(); + if (!session?.refresh_token) { + return false; + } + + const newSession = await sessionService.refreshToken(session.refresh_token); + if (newSession) { + saveSession(newSession); + log.info("Session refreshed successfully"); + return true; + } + + log.warn("Session refresh failed"); + return false; +}; + +/** + * Unified handler for session expiration + * Called when session is confirmed expired + */ +export const sessionExpiredHandler = (): void => { + log.info("Handling session expiration"); + handleSessionExpired(); +}; + + // ============================================================================ + // Hook implementation + // ============================================================================ + +export function useSessionManager() { + const { isSpeedMode } = useDeployment(); + + // Initialize session management when hook is used + useEffect(() => { + // In speed mode, skip session validation + if (isSpeedMode) return; + + if (checkSessionValid()) { + // Session is valid, no action needed + return; + } + + // Session is expired or invalid + handleSessionExpired(); + }, [isSpeedMode]); + + /** + * Setup automatic token refresh on user activity + * Refreshes token before expiry to implement sliding expiration + */ + const setupTokenAutoRefresh = useCallback(() => { + // Skip in speed mode + if (isSpeedMode) return () => {}; + + let lastActivityCheckAt = 0; + + const maybeRefreshOnActivity = async () => { + try { + // Throttle activity-driven checks + const now = Date.now(); + if (now - lastActivityCheckAt < MIN_ACTIVITY_CHECK_INTERVAL_MS) return; + lastActivityCheckAt = now; + + // Do not run when page is hidden + if (typeof document !== "undefined" && document.hidden) return; + + // Check if token is expiring soon + if (isSessionExpiringSoon()) { + const success = await refreshSessionIfNeeded(); + + // If refresh failed, it means refresh_token is also invalid + // The session will be cleared when backend returns 401 or when fetchWithAuth checks + if (!success) { + log.debug("Token refresh failed, waiting for 401 from backend"); + } + } + } catch (error) { + log.error("Activity-based refresh check failed:", error); + } + }; + + const events: (keyof DocumentEventMap | keyof WindowEventMap)[] = [ + "click", + "keydown", + "mousemove", + "touchstart", + "focus", + "visibilitychange", + ]; + + const handler = () => { + void maybeRefreshOnActivity(); + }; + + events.forEach((evt) => { + if (evt === "focus" || evt === "visibilitychange") { + window.addEventListener(evt as any, handler, { passive: true }); + } else { + document.addEventListener(evt as any, handler, { passive: true }); + } + }); + + return () => { + events.forEach((evt) => { + if (evt === "focus" || evt === "visibilitychange") { + window.removeEventListener(evt as any, handler); + } else { + document.removeEventListener(evt as any, handler); + } + }); + }; + }, [isSpeedMode]); + + // Setup auto refresh + useEffect(() => { + const cleanupAutoRefresh = setupTokenAutoRefresh(); + return () => { + cleanupAutoRefresh?.(); + }; + }, [setupTokenAutoRefresh]); + + return { + // Utility functions + checkSessionValid, + checkSessionExpired, + isSessionExpiringSoon, + + // Session operations + saveSession, + clearSession, + handleSessionExpired, + + // Business logic + refreshSessionIfNeeded, + sessionExpiredHandler, + + // Legacy functions + setupTokenAutoRefresh, + }; +} diff --git a/frontend/hooks/permission/usePermission.ts b/frontend/hooks/permission/usePermission.ts new file mode 100644 index 000000000..58af60cf2 --- /dev/null +++ b/frontend/hooks/permission/usePermission.ts @@ -0,0 +1,35 @@ +"use client"; + +import { useAuthorization } from "@/hooks/auth/useAuthorization"; +import { useAuthentication } from "@/hooks/auth/useAuthentication"; + +export function usePermission() { + const { hasPermission, hasAnyPermission, isAuthzReady, isLoading } = useAuthorization(); + const { isAuthenticated } = useAuthentication(); + + return { + isReady: isAuthzReady, + isAuthenticated, + isLoading, + + can: (permission: string): boolean => { + if (!isAuthenticated || !isAuthzReady) return false; + return hasPermission(permission); + }, + + cannot: (permission: string): boolean => { + if (!isAuthenticated || !isAuthzReady) return true; + return !hasPermission(permission); + }, + + canAny: (perms: string[]): boolean => { + if (!isAuthenticated || !isAuthzReady) return false; + return hasAnyPermission(perms); + }, + + canAll: (perms: string[]): boolean => { + if (!isAuthenticated || !isAuthzReady) return false; + return perms.every(p => hasPermission(p)); + }, + }; +} diff --git a/frontend/hooks/useAuth.ts b/frontend/hooks/useAuth.ts deleted file mode 100644 index 5b2f1074f..000000000 --- a/frontend/hooks/useAuth.ts +++ /dev/null @@ -1,376 +0,0 @@ -"use client" - -import { useState, useEffect, useContext, createContext, type ReactNode } from "react" -import { usePathname } from "next/navigation" -import { useTranslation } from "react-i18next" - -import { App } from "antd" -import { USER_ROLES } from "@/const/modelConfig" - -import { authService } from "@/services/authService" -import { configService } from "@/services/configService" -import { User, AuthContextType } from "@/types/auth" -import { EVENTS, STATUS_CODES } from "@/const/auth" -import { getSessionFromStorage, removeSessionFromStorage } from "@/lib/auth" -import log from "@/lib/logger" -import { useDeployment } from "@/components/providers/deploymentProvider" - -// Create auth context -const AuthContext = createContext(undefined) - -// Auth provider component -export function AuthProvider({ children }: { children: (value: AuthContextType) => ReactNode }) { - const { t } = useTranslation('common'); - const { message } = App.useApp(); - const { isSpeedMode } = useDeployment(); // 从 deployment context 获取 - const [user, setUser] = useState(null) - const [isLoading, setIsLoading] = useState(true) - const [isLoginModalOpen, setIsLoginModalOpen] = useState(false) - const [isRegisterModalOpen, setIsRegisterModalOpen] = useState(false) - const [isFromSessionExpired, setIsFromSessionExpired] = useState(false) - const [shouldCheckSession, setShouldCheckSession] = useState(false) - const [authServiceUnavailable, setAuthServiceUnavailable] = useState(false) - const pathname = usePathname() - - // Check auth service availability - const checkAuthService = async () => { - const isAvailable = await authService.checkAuthServiceAvailable() - setAuthServiceUnavailable(!isAvailable) - return isAvailable - } - - // When login or register modal is opened, check auth service availability - useEffect(() => { - if (isLoginModalOpen || isRegisterModalOpen) { - checkAuthService() - } - }, [isLoginModalOpen, isRegisterModalOpen]); - - // Auto login function (for speed mode) - const performAutoLogin = async () => { - try { - // Use mock credentials for auto login - await login('mock@example.com', 'mockpassword', false); - } catch (error) { - log.error('Auto-login failed:', error); - } - }; - - // When initializing, check user session (only read from local storage, not request backend) - useEffect(() => { - const syncUserFromLocalStorage = () => { - const storedSession = typeof window !== "undefined" ? localStorage.getItem("session") : null; - if (storedSession) { - try { - const session = JSON.parse(storedSession); - - // Check if token is expired before setting user - if (session?.expires_at) { - const now = Date.now(); - const expiresAt = session.expires_at * 1000; - if (expiresAt <= now) { - // Token expired, clear session and don't set user - log.warn("Token expired on initialization, clearing session"); - removeSessionFromStorage(); - setUser(null); - setShouldCheckSession(false); - return; - } - } - - if (session?.user) { - const safeUser: User = { - id: session.user.id, - email: session.user.email, - role: session.user.role === USER_ROLES.ADMIN ? USER_ROLES.ADMIN : USER_ROLES.USER, - avatar_url: session.user.avatar_url - }; - setUser(safeUser); - setShouldCheckSession(true); // When there is a user, enable session check - return; - } - } catch (e) { - // ignore parse error - } - } - setUser(null); - setShouldCheckSession(false); // When there is no user, disable session check - }; - - setIsLoading(true); - syncUserFromLocalStorage(); - setIsLoading(false); - - // Listen to local session change - const handleStorage = (event: StorageEvent) => { - if (event.key === "session") { - syncUserFromLocalStorage(); - } - }; - window.addEventListener("storage", handleStorage); - return () => { - window.removeEventListener("storage", handleStorage); - }; - }, []); - - // Check user login status (skip in speed mode) - useEffect(() => { - if (isSpeedMode) return; - if (!isLoading && !user) { - // When page is loaded, if not logged in, trigger session expired event - // Only trigger on non-home path, and only when there is a session before - if (pathname && pathname !== '/' && !pathname.startsWith('/?') && shouldCheckSession) { - window.dispatchEvent(new CustomEvent(EVENTS.SESSION_EXPIRED, { - detail: { message: t('auth.sessionExpired') } - })); - setShouldCheckSession(false); // After triggering the expired event, disable session check - } - } - }, [user, isLoading, pathname, shouldCheckSession, t, isSpeedMode]); - - // Session validity check, ensure the session in local storage is not expired (skip in speed mode) - useEffect(() => { - if (isSpeedMode || !user || isLoading || !shouldCheckSession) return; - - const verifySession = () => { - const lastVerifyTime = Number(localStorage.getItem('lastSessionVerifyTime') || 0); - const now = Date.now(); - // If the last verification is less than 10 seconds, skip - if (now - lastVerifyTime < 10000) { - return; - } - - try { - const sessionObj = getSessionFromStorage(); - if (!sessionObj || sessionObj.expires_at * 1000 <= now) { - // Session does not exist or has expired - window.dispatchEvent(new CustomEvent(EVENTS.SESSION_EXPIRED, { - detail: { message: t('auth.sessionExpired') } - })); - setShouldCheckSession(false); - } - - localStorage.setItem('lastSessionVerifyTime', now.toString()); - } catch (error) { - log.error('Session validation failed:', error); - } - }; - - // Immediately execute once - verifySession(); - - // Poll every 10 seconds - const intervalId = setInterval(verifySession, 10000); - - return () => clearInterval(intervalId); - }, [isSpeedMode, user, isLoading, shouldCheckSession, t]); - - const openLoginModal = () => { - setIsRegisterModalOpen(false) - setIsLoginModalOpen(true) - } - - const closeLoginModal = () => { - setIsLoginModalOpen(false) - } - - const openRegisterModal = () => { - setIsLoginModalOpen(false) - setIsRegisterModalOpen(true) - } - - const closeRegisterModal = () => { - setIsRegisterModalOpen(false) - } - - const login = async (email: string, password: string, showSuccessMessage: boolean = true) => { - try { - setIsLoading(true) - - // First check auth service availability - const isAuthServiceAvailable = await authService.checkAuthServiceAvailable() - if (!isAuthServiceAvailable) { - const error = new Error(t('auth.authServiceUnavailable')) - ;(error as any).code = STATUS_CODES.AUTH_SERVICE_UNAVAILABLE - throw error - } - - const { data, error } = await authService.signIn(email, password) - - if (error) { - log.error("Login failed: ", error.message) - throw error - } - - if (data?.session?.user) { - // Ensure role field is "user" or "admin" - const safeUser: User = { - id: data.session.user.id, - email: data.session.user.email, - role: data.session.user.role === USER_ROLES.ADMIN ? USER_ROLES.ADMIN : USER_ROLES.USER, - avatar_url: data.session.user.avatar_url - } - setUser(safeUser) - setShouldCheckSession(true) // After login, enable session check - - // Add delay to ensure local storage operation is completed - setTimeout(() => { - configService.loadConfigToFrontend() - closeLoginModal() - - if (showSuccessMessage) { - message.success(t('auth.loginSuccess')) - } - // Manually trigger storage event - window.dispatchEvent(new StorageEvent("storage", { key: "session", newValue: localStorage.getItem("session") })) - - // If on the chat page, trigger conversation list update - if (pathname.includes('/chat')) { - window.dispatchEvent(new CustomEvent('conversationListUpdated')) - } - }, 150) - } - } catch (error: any) { - log.error("Error during login process:", error.message) - throw error - } finally { - setIsLoading(false) - } - } - - const register = async (email: string, password: string, isAdmin?: boolean, inviteCode?: string) => { - try { - setIsLoading(true) - - // First check auth service availability - const isAuthServiceAvailable = await authService.checkAuthServiceAvailable() - if (!isAuthServiceAvailable) { - const error = new Error(t('auth.authServiceUnavailable')) - ;(error as any).code = STATUS_CODES.AUTH_SERVICE_UNAVAILABLE - throw error - } - - const { data, error } = await authService.signUp(email, password, isAdmin, inviteCode) - - if (error) { - throw error - } - - if (data?.user) { - // Ensure role field is "user" or "admin" - const safeUser: User = { - id: data.user.id, - email: data.user.email, - role: data.user.role === USER_ROLES.ADMIN ? USER_ROLES.ADMIN : USER_ROLES.USER, - avatar_url: data.user.avatar_url - } - - if (data.session) { - // Register and login successfully - setUser(safeUser) - configService.loadConfigToFrontend() - closeRegisterModal() - const successMessage = isAdmin ? t('auth.adminRegisterSuccessAutoLogin') : t('auth.registerSuccessAutoLogin') - message.success(successMessage) - // Manually trigger storage event - window.dispatchEvent(new StorageEvent("storage", { key: "session", newValue: localStorage.getItem("session") })) - } else { - // Register successfully but need to manually login - closeRegisterModal() - openLoginModal() - const successMessage = isAdmin ? t('auth.adminRegisterSuccessManualLogin') : t('auth.registerSuccessManualLogin') - message.success(successMessage) - } - } - } catch (error: any) { - throw error - } finally { - setIsLoading(false) - } - } - - // Clear local session state without calling backend API - // Used when session is already expired on the backend - const clearLocalSession = () => { - removeSessionFromStorage() - setUser(null) - setShouldCheckSession(false) - // Manually trigger storage event - window.dispatchEvent(new StorageEvent("storage", { key: "session", newValue: null })) - } - - const logout = async (options?: { silent?: boolean }) => { - try { - setIsLoading(true) - await authService.signOut() - setUser(null) - // When logging out, disable session check - setShouldCheckSession(false) - // Only show message when user actively logout - if (!options?.silent) { - message.success(t("auth.logoutSuccess")) - } - // Manually trigger storage event - window.dispatchEvent(new StorageEvent("storage", { key: "session", newValue: null })) - } catch (error: any) { - log.error("Logout failed:", error.message) - message.error(t('auth.logoutFailed')) - } finally { - setIsLoading(false) - } - } - - const revoke = async () => { - try { - setIsLoading(true); - await authService.revoke(); - setUser(null); - setShouldCheckSession(false); - message.success(t("auth.revokeSuccess")); - // Manually trigger storage event - window.dispatchEvent( - new StorageEvent("storage", { key: "session", newValue: null }) - ); - } catch (error: any) { - log.error("Revoke failed:", error?.message || error); - message.error(t("auth.revokeFailed")); - } finally { - setIsLoading(false); - } - }; - - const contextValue: AuthContextType = { - user, - isLoading, - isLoginModalOpen, - isRegisterModalOpen, - isFromSessionExpired, - authServiceUnavailable, - isSpeedMode, - isAuthReady: true, // 认证部分总是就绪 - openLoginModal, - closeLoginModal, - openRegisterModal, - closeRegisterModal, - setIsFromSessionExpired, - login, - register, - logout, - clearLocalSession, - revoke, - }; - - return children(contextValue); -} - -// Custom hook for accessing auth context -export function useAuth(): AuthContextType { - const context = useContext(AuthContext) - if (context === undefined) { - throw new Error("useAuth must be used within an AuthProvider") - } - return context -} - -// Export auth context for Provider use -export { AuthContext } \ No newline at end of file diff --git a/frontend/hooks/useSetupFlow.ts b/frontend/hooks/useSetupFlow.ts index 3a03d7582..b9c0923fb 100644 --- a/frontend/hooks/useSetupFlow.ts +++ b/frontend/hooks/useSetupFlow.ts @@ -1,13 +1,6 @@ -import {useState, useEffect, useRef} from "react"; import {useRouter} from "next/navigation"; import {useTranslation} from "react-i18next"; -import {useAuth} from "@/hooks/useAuth"; -import { - USER_ROLES, -} from "@/const/modelConfig"; -import {EVENTS} from "@/const/auth"; - interface UseSetupFlowOptions { /** Whether admin role is required to access this page */ requireAdmin?: boolean; @@ -16,12 +9,6 @@ interface UseSetupFlowOptions { } interface UseSetupFlowReturn { - // Auth related - user: any; - isLoading: boolean; - isSpeedMode: boolean; - canAccessProtectedData: boolean; - // Animation config pageVariants: { initial: { opacity: number; x: number }; @@ -43,68 +30,20 @@ interface UseSetupFlowReturn { * useSetupFlow - Custom hook for setup flow pages * * Provides common functionality for setup pages including: - * - Authentication and permission checks - * - Session expiration handling + * - Admin permission checks (if required) * - Page transition animations + * - Common utilities (router, translation) + * + * Note: Authentication and authorization are now handled by the global + * useAuthentication and useAuthorization hooks via route guards. * * @param options - Configuration options * @returns Setup flow utilities and state */ export function useSetupFlow(options: UseSetupFlowOptions = {}): UseSetupFlowReturn { - const { - requireAdmin = false, - nonAdminRedirect = "/setup/knowledges", - } = options; const router = useRouter(); const {t} = useTranslation(); - const {user, isLoading: userLoading, isSpeedMode} = useAuth(); - const sessionExpiredTriggeredRef = useRef(false); - - // Calculate if user can access protected data - const canAccessProtectedData = isSpeedMode || (!userLoading && !!user); - - - - // Check login status and handle session expiration - useEffect(() => { - if (isSpeedMode) { - sessionExpiredTriggeredRef.current = false; - return; - } - - if (user) { - sessionExpiredTriggeredRef.current = false; - return; - } - - // Trigger session expired event if user is not logged in - if (!userLoading && !sessionExpiredTriggeredRef.current) { - sessionExpiredTriggeredRef.current = true; - window.dispatchEvent( - new CustomEvent(EVENTS.SESSION_EXPIRED, { - detail: {message: "Session expired, please sign in again"}, - }) - ); - } - }, [isSpeedMode, user, userLoading]); - - // Check admin permission if required - useEffect(() => { - if (!requireAdmin) return; - - // Only check after user is loaded - if (userLoading) return; - - // Speed mode always has access - if (isSpeedMode) return; - - // Check if user has admin role - if (user && user.role !== USER_ROLES.ADMIN) { - router.push(nonAdminRedirect); - } - }, [requireAdmin, isSpeedMode, user, userLoading, router, nonAdminRedirect]); - // Animation variants for smooth page transitions const pageVariants = { @@ -129,12 +68,6 @@ export function useSetupFlow(options: UseSetupFlowOptions = {}): UseSetupFlowRet }; return { - // Auth - user, - isLoading: userLoading, - isSpeedMode, - canAccessProtectedData, - // Animation pageVariants, pageTransition, diff --git a/frontend/lib/auth.ts b/frontend/lib/auth.ts index 6510f8ebd..9d2b887dd 100644 --- a/frontend/lib/auth.ts +++ b/frontend/lib/auth.ts @@ -2,21 +2,29 @@ * Authentication utilities */ -import { fetchWithErrorHandling, ApiError } from "@/services/api"; -import { STORAGE_KEYS, STATUS_CODES } from "@/const/auth"; -import { Session } from "@/types/auth"; +import { fetchWithErrorHandling } from "@/services/api"; +import { STORAGE_KEYS } from "@/const/auth"; import { generateAvatarUrl as generateAvatar } from "@/lib/avatar"; -import log from "@/lib/logger"; +import { USER_ROLES } from "@/const/auth"; -// Get color corresponding to user role +/** + * Role color mapping - Ant Design color presets + */ +const ROLE_COLORS: Record = { + [USER_ROLES.SU]: "red", + [USER_ROLES.ADMIN]: "purple", + [USER_ROLES.DEV]: "cyan", + [USER_ROLES.USER]: "geekblue", + [USER_ROLES.SPEED]: "green", +}; + +/** + * Get color corresponding to user role + * @param role - User role string + * @returns Ant Design color preset name + */ export function getRoleColor(role: string): string { - switch (role) { - case "admin": - return "purple" - case "user": - default: - return "geekblue" - } + return ROLE_COLORS[role] || ROLE_COLORS[USER_ROLES.USER]; } // Generate avatar based on email (re-export from avatar.tsx for backward compatibility) @@ -26,37 +34,22 @@ export function generateAvatarUrl(email: string): string { /** * Request with authorization headers - * Checks token expiration before sending request to prevent sending expired tokens + * Only builds the request with auth token - no expiration checking + * Expiration should be handled when backend returns 401 */ export const fetchWithAuth = async (url: string, options: RequestInit = {}) => { - const session = typeof window !== "undefined" ? localStorage.getItem(STORAGE_KEYS.SESSION) : null; + const session = + typeof window !== "undefined" + ? localStorage.getItem(STORAGE_KEYS.SESSION) + : null; const sessionObj = session ? JSON.parse(session) : null; - // Check if token is expired before sending request - if (sessionObj?.access_token) { - const now = Date.now(); - const expiresAt = sessionObj.expires_at ? sessionObj.expires_at * 1000 : 0; - - // If token is expired, clear session and throw error - if (expiresAt > 0 && expiresAt <= now) { - log.warn("Token expired, clearing session before request"); - removeSessionFromStorage(); - - // Dispatch session expired event - if (typeof window !== "undefined" && window.dispatchEvent) { - window.dispatchEvent(new CustomEvent('session-expired', { - detail: { message: "Login expired, please login again" } - })); - } - - throw new ApiError(STATUS_CODES.TOKEN_EXPIRED, "Login expired, please login again"); - } - } - const isFormData = options.body instanceof FormData; const headers = { ...(isFormData ? {} : { "Content-Type": "application/json" }), - ...(sessionObj?.access_token && { "Authorization": `Bearer ${sessionObj.access_token}` }), + ...(sessionObj?.access_token && { + Authorization: `Bearer ${sessionObj.access_token}`, + }), ...options.headers, }; @@ -67,50 +60,31 @@ export const fetchWithAuth = async (url: string, options: RequestInit = {}) => { }); }; -/** - * Save session to local storage - */ -export const saveSessionToStorage = (session: Session) => { - if (typeof window !== "undefined") { - localStorage.setItem(STORAGE_KEYS.SESSION, JSON.stringify(session)); - } -}; - -/** - * Remove session from local storage - */ -export const removeSessionFromStorage = () => { - if (typeof window !== "undefined") { - localStorage.removeItem(STORAGE_KEYS.SESSION); - } -}; - -/** - * Get session from local storage - */ -export const getSessionFromStorage = (): Session | null => { - try { - const storedSession = typeof window !== "undefined" ? localStorage.getItem(STORAGE_KEYS.SESSION) : null; - if (!storedSession) return null; - - return JSON.parse(storedSession); - } catch (error) { - log.error("Failed to parse session info:", error); - return null; - } -}; - /** * Get the authorization header information for API requests * @returns HTTP headers object containing authentication and content type information */ export const getAuthHeaders = () => { - const session = typeof window !== "undefined" ? localStorage.getItem("session") : null; + const session = + typeof window !== "undefined" ? localStorage.getItem("session") : null; const sessionObj = session ? JSON.parse(session) : null; return { - 'Content-Type': 'application/json', - 'User-Agent': 'AgentFrontEnd/1.0', - ...(sessionObj?.access_token && { "Authorization": `Bearer ${sessionObj.access_token}` }), + "Content-Type": "application/json", + "User-Agent": "AgentFrontEnd/1.0", + ...(sessionObj?.access_token && { + Authorization: `Bearer ${sessionObj.access_token}`, + }), }; -}; \ No newline at end of file +}; + +/** + * Remove locale prefix from pathname to get effective route + */ +export function getEffectiveRoutePath(pathname: string): string { + const segments = pathname.split("/").filter(Boolean); + if (segments.length > 0 && (segments[0] === "zh" || segments[0] === "en")) { + segments.shift(); + } + return "/" + (segments.join("/") || ""); +} \ No newline at end of file diff --git a/frontend/lib/authEvents.ts b/frontend/lib/authEvents.ts new file mode 100644 index 000000000..f886630c2 --- /dev/null +++ b/frontend/lib/authEvents.ts @@ -0,0 +1,83 @@ +/** + * Authentication and Authorization Event System + * Provides type-safe event communication between authentication and authorization modules + */ + +import log from "@/lib/logger"; +import { AUTH_EVENTS, AUTHZ_EVENTS } from "@/const/auth"; + +// Event emitter for authentication events +class AuthEventEmitter { + emit( + event: K, + data?: import("@/types/auth").AuthEvents[K] + ) { + log.debug(`Auth event emitted: ${event}`, data); + window.dispatchEvent(new CustomEvent(event, { detail: data })); + } + + on( + event: K, + handler: (data?: import("@/types/auth").AuthEvents[K]) => void + ) { + const listener = (e: CustomEvent) => handler(e.detail); + window.addEventListener(event, listener as EventListener); + + // Return cleanup function + return () => { + window.removeEventListener(event, listener as EventListener); + }; + } +} + +// Event emitter for authorization events +class AuthzEventEmitter { + emit( + event: K, + data?: import("@/types/auth").AuthzEvents[K] + ) { + log.debug(`Authz event emitted: ${event}`, data); + window.dispatchEvent(new CustomEvent(event, { detail: data })); + } + + on( + event: K, + handler: (data?: import("@/types/auth").AuthzEvents[K]) => void + ) { + const listener = (e: CustomEvent) => handler(e.detail); + window.addEventListener(event, listener as EventListener); + + // Return cleanup function + return () => { + window.removeEventListener(event, listener as EventListener); + }; + } +} + +// Global instances +export const authEvents = new AuthEventEmitter(); +export const authzEvents = new AuthzEventEmitter(); + +// Utility functions for common auth events +export const authEventUtils = { + emitLoginSuccess: () => authEvents.emit(AUTH_EVENTS.LOGIN_SUCCESS), + emitRegisterSuccess: () => authEvents.emit(AUTH_EVENTS.REGISTER_SUCCESS), + emitLogout: () => authEvents.emit(AUTH_EVENTS.LOGOUT), + emitSessionExpired: () => authEvents.emit(AUTH_EVENTS.SESSION_EXPIRED), + emitTokenRefreshed: () => authEvents.emit(AUTH_EVENTS.TOKEN_REFRESHED), + emitServiceUnavailable: () => + authEvents.emit(AUTH_EVENTS.SERVICE_UNAVAILABLE), + emitBackToHome: () => + authEvents.emit(AUTH_EVENTS.BACK_TO_HOME), +}; + +export const authzEventUtils = { + emitPermissionsReady: ( + userData: import("@/types/auth").User & { + permissions: string[]; + accessibleRoutes: string[]; + } + ) => authzEvents.emit(AUTHZ_EVENTS.PERMISSIONS_READY, userData), + emitPermissionsUpdated: () => + authzEvents.emit(AUTHZ_EVENTS.PERMISSIONS_UPDATED), +}; diff --git a/frontend/lib/session.ts b/frontend/lib/session.ts new file mode 100644 index 000000000..fe27bc4ea --- /dev/null +++ b/frontend/lib/session.ts @@ -0,0 +1,94 @@ +/** + * Session utilities + * Pure functions for session management - no React dependencies + */ + +import { STORAGE_KEYS } from "@/const/auth"; +import { Session } from "@/types/auth"; +import { authEventUtils } from "@/lib/authEvents"; +import log from "@/lib/logger"; + +// Flag to prevent duplicate session expiration handling +let isHandlingSessionExpired = false; + +/** + * Save session to local storage + */ +export const saveSessionToStorage = (session: Session): void => { + if (typeof window !== "undefined") { + localStorage.setItem(STORAGE_KEYS.SESSION, JSON.stringify(session)); + } +}; + +/** + * Remove session from local storage + */ +export const removeSessionFromStorage = (): void => { + if (typeof window !== "undefined") { + localStorage.removeItem(STORAGE_KEYS.SESSION); + } +}; + +/** + * Get session from local storage + */ +export const getSessionFromStorage = (): Session | null => { + try { + const storedSession = + typeof window !== "undefined" + ? localStorage.getItem(STORAGE_KEYS.SESSION) + : null; + if (!storedSession) return null; + + return JSON.parse(storedSession); + } catch (error) { + log.error("Failed to parse session info:", error); + return null; + } +}; + +/** + * Check if session is valid (exists and not expired) + */ +export const checkSessionValid = (): boolean => { + const session = getSessionFromStorage(); + if (!session?.access_token || !session?.expires_at) { + return false; + } + + const now = Date.now(); + return session.expires_at * 1000 > now; +}; + +/** + * Check if session has expired + */ +export const checkSessionExpired = (): boolean => { + return !checkSessionValid(); +}; + +/** + * Clear session and emit expired event + * Unified handling for session expiration with duplicate prevention + */ +export const handleSessionExpired = (): void => { + // Prevent duplicate triggers + if (isHandlingSessionExpired) { + return; + } + isHandlingSessionExpired = true; + + log.info("Session expired, clearing and emitting event"); + removeSessionFromStorage(); + + // Emit event asynchronously to ensure isAuthenticated state has been updated + // This fixes the closure trap where showSessionExpiredModal captures stale isAuthenticated value + setTimeout(() => { + authEventUtils.emitSessionExpired(); + }, 0); + + // Reset flag after 300ms to allow future triggers + setTimeout(() => { + isHandlingSessionExpired = false; + }, 300); +}; diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 4424d0ff2..3a8e5d048 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -164,16 +164,9 @@ "Full access to enterprise knowledge base", "More accurate Q&A experience" ], - "page.loginPrompt.githubSupport": "⭐️ Nexent is still growing, please help me by starring on <1>GitHub, thank you.", - "page.loginPrompt.title": "Welcome to Nexent", - "page.loginPrompt.register": "Create New Account", - "page.loginPrompt.login": "Log in to account", - "page.loginPrompt.intro": "Please log in to your account to access this page.", "page.loginPrompt.githubSupport": "Support us on GitHub", "page.loginPrompt.noAccount": "Don't have an account yet? Click the \"Register\" button to create your exclusive account~", - "page.adminPrompt.title": "Oh, you are not an administrator", "page.adminPrompt.close": "OK", - "page.adminPrompt.intro": "Only administrators can adjust the configuration, please log in as an administrator first~", "page.adminPrompt.unlockHeader": "🌟 Become an administrator and unlock more capabilities!", "page.adminPrompt.unlockIntro": "After becoming an administrator, you can:", "page.adminPrompt.permissionsTitle": "✨ Administrator exclusive permissions:", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index a1c525803..176d817ef 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -152,11 +152,6 @@ ], "page.copyright": "Nexent © {{year}}", "page.termsOfUse": "使用条款", - "page.loginPrompt.title": "登录账号", - "page.loginPrompt.register": "注册", - "page.loginPrompt.login": "立即登录", - "page.loginPrompt.header": "🚀 准备启航!", - "page.loginPrompt.intro": "登录您的账户,开启智能问答之旅~", "page.loginPrompt.benefitsTitle": "✨ 登录后您将获得:", "page.loginPrompt.benefits": [ "专属的对话历史记录", @@ -164,16 +159,13 @@ "企业知识库完整访问权限", "更精准的问答体验" ], - "page.loginPrompt.githubSupport": "⭐️ Nexent还在成长中,帮帮我到<1>GitHub加星支持我吧,谢谢你。", "page.loginPrompt.title": "欢迎使用 Nexent", "page.loginPrompt.register": "注册账户", "page.loginPrompt.login": "登录账户", "page.loginPrompt.intro": "请登录您的账户以访问此页面。", "page.loginPrompt.githubSupport": "在 GitHub 上支持我们", "page.loginPrompt.noAccount": "还没有账号?点击\"注册\"按钮创建您的专属账号~", - "page.adminPrompt.title": "啊哦,您不是管理员", "page.adminPrompt.close": "好的", - "page.adminPrompt.intro": "只有管理员可以调整配置,请先登录为管理员账号~", "page.adminPrompt.unlockHeader": "🌟 成为管理员,解锁更多能力!", "page.adminPrompt.unlockIntro": "成为管理员后,您可以:", "page.adminPrompt.permissionsTitle": "✨ 管理员专属权限:", @@ -184,7 +176,7 @@ ], "page.adminPrompt.title": "暂无权限", "page.adminPrompt.intro": "您暂时没有权限访问该页面,请联系管理员为您提升权限。", - "page.adminPrompt.githubSupport": "⭐️ Nexent还在成长中,帮帮我到<1>GitHub加星支持我吧,谢谢你。", + "page.adminPrompt.githubSupport": "在 GitHub 上支持我们", "page.adminPrompt.becomeAdmin": "💡 想成为管理员?请访问<1>官网联系页,申请管理员账号。", "chatStreamMessage.appIconAlt": "应用图标", diff --git a/frontend/services/api.ts b/frontend/services/api.ts index f23cd6aaf..c78b8c432 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -1,4 +1,5 @@ import { STATUS_CODES } from "@/const/auth"; +import { handleSessionExpired } from "@/lib/session"; import log from "@/lib/logger"; const API_BASE_URL = "/api"; @@ -11,6 +12,7 @@ export const API_ENDPOINTS = { logout: `${API_BASE_URL}/user/logout`, session: `${API_BASE_URL}/user/session`, currentUserId: `${API_BASE_URL}/user/current_user_id`, + currentUserInfo: `${API_BASE_URL}/user/current_user_info`, serviceHealth: `${API_BASE_URL}/user/service_health`, revoke: `${API_BASE_URL}/user/revoke`, }, @@ -359,36 +361,6 @@ export const fetchWithErrorHandling = async ( } }; -// Method to handle session expiration -function handleSessionExpired() { - // Prevent duplicate triggers - if (window.__isHandlingSessionExpired) { - return; - } - - // Mark as processing - window.__isHandlingSessionExpired = true; - - // Clear locally stored session information - if (typeof window !== "undefined") { - localStorage.removeItem("session"); - - // Use custom events to notify other components in the app (such as SessionExpiredListener) - if (window.dispatchEvent) { - // Ensure using event name consistent with EVENTS.SESSION_EXPIRED constant - window.dispatchEvent( - new CustomEvent("session-expired", { - detail: { message: "Login expired, please login again" }, - }) - ); - } - - // Reset flag after 300ms to allow future triggers - setTimeout(() => { - window.__isHandlingSessionExpired = false; - }, 300); - } -} // Add global interface extensions for TypeScript declare global { diff --git a/frontend/services/authService.ts b/frontend/services/authService.ts index cc724661f..ed03cb8c7 100644 --- a/frontend/services/authService.ts +++ b/frontend/services/authService.ts @@ -1,15 +1,15 @@ /** * Authentication service */ -import { USER_ROLES } from "@/const/modelConfig"; import { API_ENDPOINTS } from "@/services/api"; import { sessionService } from "@/services/sessionService"; -import { Session, User, SessionResponse } from "@/types/auth"; +import { Session, User, SessionResponse, AuthInfoResponse } from "@/types/auth"; import { STATUS_CODES } from "@/const/auth"; -import { generateAvatarUrl, removeSessionFromStorage } from "@/lib/auth" -import { fetchWithAuth, getSessionFromStorage, saveSessionToStorage } from "@/lib/auth"; +import { generateAvatarUrl } from "@/lib/auth"; +import { fetchWithAuth } from "@/lib/auth"; +import { removeSessionFromStorage, getSessionFromStorage, saveSessionToStorage } from "@/lib/session"; import log from "@/lib/logger"; @@ -22,13 +22,6 @@ export const authService = { const sessionObj = getSessionFromStorage(); if (!sessionObj?.access_token) return null; - // Check if the token is about to expire, if so, try to refresh - const isTokenValid = await sessionService.checkAndRefreshToken(); - if (!isTokenValid) { - log.warn("Token is invalid or refresh failed"); - // We do not immediately clear the session, but wait for subsequent operations to fail - } - try { // Verify if the session is valid const response = await fetchWithAuth(API_ENDPOINTS.user.session); @@ -52,20 +45,6 @@ export const authService = { return sessionObj; } - const data = await response.json(); - - // Update user information (possibly changed on the backend) - if (data.data?.user) { - sessionObj.user = { - ...sessionObj.user, - ...data.data.user, - avatar_url: sessionObj.user.avatar_url, // Keep avatar - }; - - // Update stored session - saveSessionToStorage(sessionObj); - } - return sessionObj; } catch (error) { log.error("Error verifying session:", error); @@ -157,7 +136,7 @@ export const authService = { id: data.data.user.id, email: data.data.user.email, role: data.data.user.role, - avatar_url, + avatarUrl: avatar_url, }; // Build session object @@ -232,16 +211,6 @@ export const authService = { }; } - // Generate avatar URL - const avatar_url = generateAvatarUrl(email); - - // Build user object - const user: User = { - id: data.data.user.id, - email: data.data.user.email, - role: data.data.user.role || USER_ROLES.USER, - avatar_url, - }; // If the session information is not returned when registering, try to login if (!data.data.session || !data.data.session.access_token) { @@ -258,12 +227,11 @@ export const authService = { if (!loginResponse.ok) { // Return the result with only the user and no session - return { data: { user, session: null }, error: null }; + return { data: { session: null }, error: null }; } // Build complete session const session: Session = { - user, access_token: loginData.data.session.access_token, refresh_token: loginData.data.session.refresh_token, expires_at: loginData.data.session.expires_at, @@ -272,11 +240,10 @@ export const authService = { // Save session to local storage saveSessionToStorage(session); - return { data: { user, session }, error: null }; + return { data: { session }, error: null }; } else { // Use the session information returned by the registration interface const session: Session = { - user, access_token: data.data.session.access_token, refresh_token: data.data.session.refresh_token, expires_at: data.data.session.expires_at, @@ -285,7 +252,7 @@ export const authService = { // Save session to local storage saveSessionToStorage(session); - return { data: { user, session }, error: null }; + return { data: { session }, error: null }; } } catch (error) { log.error("Registration failed:", error); @@ -343,9 +310,46 @@ export const authService = { return null; } }, + getCurrentUserInfo: async (): Promise => { + try { + const response = await fetchWithAuth(API_ENDPOINTS.user.currentUserInfo); + // Check HTTP status code instead of data.code + if (!response.ok) { + log.warn("Failed to get user Info, HTTP status code:", response.status); + return null; + } + + const data = await response.json(); + if (!data.data) { + return null; + } + const userData = { + user: { + id: data.data.user.user_id, + email: data.data.user.user_email, + role: data.data.user.user_role, + avatarUrl: data.data.user.avatarUrl, + permissions: data.data.user.permissions.map((permission:string) => permission.toLowerCase()), + accessibleRoutes: data.data.user.accessibleRoutes.map((router:string) => router.toLowerCase()), + } + } + return userData as AuthInfoResponse; + } catch (error) { + log.error("Failed to get user Info:", error); + return null; + } + }, // Refresh token refreshToken: async (): Promise => { - return await sessionService.checkAndRefreshToken(); + const sessionObj = getSessionFromStorage(); + if (!sessionObj?.refresh_token) return false; + + const newSession = await sessionService.refreshToken(sessionObj.refresh_token); + if (newSession) { + saveSessionToStorage(newSession); + return true; + } + return false; }, }; \ No newline at end of file diff --git a/frontend/services/sessionService.ts b/frontend/services/sessionService.ts index 056b9cb3c..87f2fc472 100644 --- a/frontend/services/sessionService.ts +++ b/frontend/services/sessionService.ts @@ -1,86 +1,33 @@ /** * Session management service + * Pure API layer for session operations */ -import { TOKEN_REFRESH_CD } from "@/const/constants"; import { API_ENDPOINTS } from "./api"; - -import { fetchWithAuth, saveSessionToStorage, removeSessionFromStorage, getSessionFromStorage } from "@/lib/auth"; -import log from "@/lib/logger"; - -// Record the time of the last token refresh -let lastTokenRefreshTime = 0; +import { fetchWithAuth } from "@/lib/auth"; +import { Session } from "@/types/auth"; /** - * Check and refresh token (if needed) + * Call backend refresh token API + * @param refreshToken - refresh token string + * @returns new session or null if failed */ export const sessionService = { - checkAndRefreshToken: async (): Promise => { + refreshToken: async (refreshToken: string): Promise => { try { - const sessionObj = getSessionFromStorage(); - if (!sessionObj) return false; - - const now = Date.now(); - - // Check if the token is in the refresh cooldown period - const timeSinceLastRefresh = now - lastTokenRefreshTime; - if (timeSinceLastRefresh < TOKEN_REFRESH_CD) { - return true; // In cooldown period, default token is valid - } - - // Check if the token has expired - const expiresAt = sessionObj.expires_at * 1000; // Convert to milliseconds - if (expiresAt > now) { - // Token not expired, try to refresh - // Update the last refresh time, even if it hasn't succeeded, record the attempt time to avoid frequent requests - lastTokenRefreshTime = now; - - // Call the refresh token API - const response = await fetchWithAuth(API_ENDPOINTS.user.refreshToken, { - method: "POST", - body: JSON.stringify({ - refresh_token: sessionObj.refresh_token - }) - }); - - // Check HTTP status code instead of data.code - if (!response.ok) { - log.warn("Token refresh failed, HTTP status code:", response.status); - - // HTTP 401 means the token is expired - if (response.status === 401) { - removeSessionFromStorage(); - } - - return false; - } - - const data = await response.json(); - - if (data.data?.session) { - // Update the session information in local storage - const updatedSession = { - ...sessionObj, - access_token: data.data.session.access_token, - refresh_token: data.data.session.refresh_token, - expires_at: data.data.session.expires_at, - }; - - saveSessionToStorage(updatedSession); - return true; - } else { - log.warn("Token refresh failed: incorrect response data format"); - return false; - } - } else { - // Token expired, clear the session - log.warn("Token expired"); - removeSessionFromStorage(); - return false; + const response = await fetchWithAuth(API_ENDPOINTS.user.refreshToken, { + method: "POST", + body: JSON.stringify({ refresh_token: refreshToken }) + }); + + if (!response.ok) { + return null; } - } catch (error) { - log.error("Check token status failed:", error); - return false; + + const data = await response.json(); + return data.data?.session || null; + } catch { + return null; } } -}; \ No newline at end of file +}; diff --git a/frontend/types/auth.ts b/frontend/types/auth.ts index 6d35384c6..4507a1926 100644 --- a/frontend/types/auth.ts +++ b/frontend/types/auth.ts @@ -1,14 +1,18 @@ -// User type definition +// User type definition - contains only basic user information +import type { USER_ROLES } from "@/const/auth"; + +export type UserRole = USER_ROLES; + export interface User { id: string; email: string; - role: "user" | "admin"; - avatar_url?: string; + role: UserRole; + avatarUrl?: string; + tenantId?: string; } -// Session type definition +// Session type definition - contains only authentication tokens export interface Session { - user: User; access_token: string; refresh_token: string; expires_at: number; @@ -33,18 +37,17 @@ export interface AuthFormValues { // Authorization context type export interface AuthContextType { user: User | null; + permissions: string[]; + accessibleRoutes: string[]; isLoading: boolean; isLoginModalOpen: boolean; isRegisterModalOpen: boolean; - isFromSessionExpired: boolean; authServiceUnavailable: boolean; - isSpeedMode: boolean; isAuthReady: boolean; openLoginModal: () => void; closeLoginModal: () => void; openRegisterModal: () => void; closeRegisterModal: () => void; - setIsFromSessionExpired: (value: boolean) => void; login: (email: string, password: string) => Promise; register: ( email: string, @@ -65,3 +68,162 @@ export interface SessionResponse { }; error: ErrorResponse | null; } + +// Current user info response type (includes permissions and accessible routes) +// Backend returns user data directly, not nested under "user" property +export interface AuthInfoResponse { + user: User & { + permissions: string[]; + accessibleRoutes: string[]; + }; +} + +import type { AUTH_EVENTS, AUTHZ_EVENTS } from "@/const/auth"; + +export type AuthEventKey = (typeof AUTH_EVENTS)[keyof typeof AUTH_EVENTS]; +export type AuthzEventKey = (typeof AUTHZ_EVENTS)[keyof typeof AUTHZ_EVENTS]; + +// Authentication Events +export interface AuthEvents { + [AUTH_EVENTS.LOGIN_SUCCESS]: User | null; + [AUTH_EVENTS.REGISTER_SUCCESS]: void; + [AUTH_EVENTS.LOGOUT]: void; + [AUTH_EVENTS.SESSION_EXPIRED]: void; + [AUTH_EVENTS.PERMISSION_DENIED]: { pathname: string } | void; + [AUTH_EVENTS.TOKEN_REFRESHED]: void; + [AUTH_EVENTS.SERVICE_UNAVAILABLE]: void; + [AUTH_EVENTS.BACK_TO_HOME]: void; +} + +// Authorization Events +export interface AuthzEvents { + [AUTHZ_EVENTS.PERMISSIONS_READY]: User & { + permissions: string[]; + accessibleRoutes: string[]; + }; + [AUTHZ_EVENTS.PERMISSIONS_UPDATED]: void; +} + +// Authentication Context Type +export interface AuthenticationContextType { + // Authentication state + isAuthenticated: boolean; + isAuthChecking: boolean; + isLoading: boolean; + session: Session | null; + + // UI state + isLoginModalOpen: boolean; + isRegisterModalOpen: boolean; + authServiceUnavailable: boolean; + + // Methods + login: ( + email: string, + password: string, + options?: { showSuccessMessage?: boolean } + ) => Promise; + register: ( + email: string, + password: string, + isAdmin?: boolean, + inviteCode?: string + ) => Promise; + logout: (options?: { silent?: boolean }) => Promise; + clearLocalSession: () => void; + revoke: () => Promise; + + // UI methods + openLoginModal: () => void; + closeLoginModal: () => void; + openRegisterModal: () => void; + closeRegisterModal: () => void; + + // Auth prompt modal (for side navigation pre-check) + isAuthPromptModalOpen: boolean; + openAuthPromptModal: () => void; + closeAuthPromptModal: () => void; + + // Session expired modal + isSessionExpiredModalOpen: boolean; + openSessionExpiredModal: () => void; + closeSessionExpiredModal: () => void; +} + +// Authentication State Return Type - for useAuthenticationState hook +export interface AuthenticationStateReturn { + // Authentication state + isAuthenticated: boolean; + isAuthChecking: boolean; + isLoading: boolean; + session: Session | null; + authServiceUnavailable: boolean; + + // Methods + login: ( + email: string, + password: string, + options?: { showSuccessMessage?: boolean } + ) => Promise; + register: ( + email: string, + password: string, + isAdmin?: boolean, + inviteCode?: string + ) => Promise; + logout: (options?: { silent?: boolean }) => Promise; + clearLocalSession: () => void; + revoke: () => Promise; +} + +// Authentication UI Return Type - for useAuthenticationUI hook +export interface AuthenticationUIReturn { + // Login/Register Modal + isLoginModalOpen: boolean; + openLoginModal: () => void; + closeLoginModal: () => void; + isRegisterModalOpen: boolean; + openRegisterModal: () => void; + closeRegisterModal: () => void; + + // Auth prompt modal (for side navigation pre-check) + isAuthPromptModalOpen: boolean; + openAuthPromptModal: () => void; + closeAuthPromptModal: () => void; + + // Session expired modal + isSessionExpiredModalOpen: boolean; + openSessionExpiredModal: () => void; + closeSessionExpiredModal: () => void; +} + +// Authorization Context Type +export interface AuthorizationContextType { + // Authorization data + user: User | null; + permissions: string[]; + accessibleRoutes: string[]; + + // State + isLoading: boolean; + error: Error | null; + + // Authorization status + // True when authorization is complete and user has permission to access current route + isAuthorized: boolean; + + // True when authorization data is ready (permissions loaded) + // Does not indicate whether user has permission, only that the process is complete + isAuthzReady: boolean; + + // Methods + refetch: () => Promise; + hasPermission: (permission: string) => boolean; + hasAnyPermission: (permissions: string[]) => boolean; + canAccessRoute: (route: string) => boolean; + + // Authz prompt modal (permission denied) + isAuthzPromptModalOpen: boolean; + openAuthzPromptModal: () => void; + closeAuthzPromptModal: () => void; +} From 775140d7bde10e8748001cd1cc6245b359bf36bc Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 27 Jan 2026 19:23:16 +0800 Subject: [PATCH 005/167] Set permission denied event from auth to authz --- frontend/const/auth.ts | 10 ++-------- frontend/types/auth.ts | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/frontend/const/auth.ts b/frontend/const/auth.ts index 06a6b924e..bfad78f84 100644 --- a/frontend/const/auth.ts +++ b/frontend/const/auth.ts @@ -27,19 +27,12 @@ export const STORAGE_KEYS = { SESSION: "session", }; -// Custom events -export const EVENTS = { - SESSION_EXPIRED: "session-expired", - STORAGE_CHANGE: "storage", -}; - // Type-safe authentication events (used with authEvents emitter) export const AUTH_EVENTS = { LOGIN_SUCCESS: "auth:login-success", REGISTER_SUCCESS: "auth:register-success", LOGOUT: "auth:logout", - SESSION_EXPIRED: "auth:session-expired", - PERMISSION_DENIED: "auth:permission-denied", + SESSION_EXPIRED: "auth:session-expired", // Deprecated: this is an authorization event; prefer AUTHZ_EVENTS.PERMISSION_DENIED. TOKEN_REFRESHED: "auth:token-refreshed", SERVICE_UNAVAILABLE: "auth:service-unavailable", BACK_TO_HOME: "nav:back-to-home", @@ -47,6 +40,7 @@ export const AUTH_EVENTS = { // Type-safe authorization events (used with authzEvents emitter) export const AUTHZ_EVENTS = { + PERMISSION_DENIED: "authz:permission-denied", PERMISSIONS_READY: "authz:permissions-ready", PERMISSIONS_UPDATED: "authz:permissions-updated", } as const; diff --git a/frontend/types/auth.ts b/frontend/types/auth.ts index 4507a1926..41dff5bda 100644 --- a/frontend/types/auth.ts +++ b/frontend/types/auth.ts @@ -89,7 +89,6 @@ export interface AuthEvents { [AUTH_EVENTS.REGISTER_SUCCESS]: void; [AUTH_EVENTS.LOGOUT]: void; [AUTH_EVENTS.SESSION_EXPIRED]: void; - [AUTH_EVENTS.PERMISSION_DENIED]: { pathname: string } | void; [AUTH_EVENTS.TOKEN_REFRESHED]: void; [AUTH_EVENTS.SERVICE_UNAVAILABLE]: void; [AUTH_EVENTS.BACK_TO_HOME]: void; @@ -97,6 +96,7 @@ export interface AuthEvents { // Authorization Events export interface AuthzEvents { + [AUTHZ_EVENTS.PERMISSION_DENIED]: { pathname: string } | void; [AUTHZ_EVENTS.PERMISSIONS_READY]: User & { permissions: string[]; accessibleRoutes: string[]; From 6a5b65b83cb3c012539284060dadd4be1f5322a8 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 27 Jan 2026 19:32:39 +0800 Subject: [PATCH 006/167] Add slide session refreshing & periodically session checking ability --- frontend/app/[locale]/chat/page.tsx | 21 --- .../components/providers/rootProvider.tsx | 3 + frontend/hooks/auth/useSessionManager.ts | 155 +++++++++++------- 3 files changed, 98 insertions(+), 81 deletions(-) diff --git a/frontend/app/[locale]/chat/page.tsx b/frontend/app/[locale]/chat/page.tsx index 332e51846..b92b158f6 100644 --- a/frontend/app/[locale]/chat/page.tsx +++ b/frontend/app/[locale]/chat/page.tsx @@ -13,9 +13,6 @@ import { ChatInterface } from "./internal/chatInterface"; */ export default function ChatContent() { const { appConfig } = useConfig(); - const { user, isLoading: userLoading } = useAuthorizationContext(); - const { isSpeedMode } = useDeployment(); - const sessionExpiredTriggeredRef = useRef(false); useEffect(() => { // Load config from backend when entering chat page @@ -26,24 +23,6 @@ export default function ChatContent() { } }, [appConfig.appName]); - // Require login on chat page when unauthenticated (skip in speed mode) - // Note: SESSION_EXPIRED event is triggered by useSessionManager.ts on initialization - useEffect(() => { - if (isSpeedMode) { - sessionExpiredTriggeredRef.current = false; - return; - } - - if (user) { - sessionExpiredTriggeredRef.current = false; - return; - } - - // Session expiration is handled by useSessionManager.ts - // Don't trigger SESSION_EXPIRED here to avoid duplicate handling - }, [isSpeedMode, user, userLoading]); - - return (
diff --git a/frontend/components/providers/rootProvider.tsx b/frontend/components/providers/rootProvider.tsx index 09e6fe95d..e2bf1ac86 100644 --- a/frontend/components/providers/rootProvider.tsx +++ b/frontend/components/providers/rootProvider.tsx @@ -17,8 +17,11 @@ import { LoginModal } from "@/components/auth/loginModal"; import { RegisterModal } from "@/components/auth/registerModal"; import { FullScreenLoading } from "@/components/ui/loading"; import { useDeployment } from "./deploymentProvider"; +import { useSessionManager } from "@/hooks/auth/useSessionManager"; function AppReadyWrapper({ children }: { children?: ReactNode }) { + useSessionManager(); + const { isDeploymentReady, isSpeedMode } = useDeployment(); const auth = useAuthenticationContext(); const authz = useAuthorizationContext(); diff --git a/frontend/hooks/auth/useSessionManager.ts b/frontend/hooks/auth/useSessionManager.ts index 5b2c27254..7d2f3731d 100644 --- a/frontend/hooks/auth/useSessionManager.ts +++ b/frontend/hooks/auth/useSessionManager.ts @@ -1,19 +1,16 @@ "use client"; -import { useCallback, useEffect } from "react"; -import { usePathname } from "next/navigation"; +import { useCallback, useEffect, useRef } from "react"; import { useDeployment } from "@/components/providers/deploymentProvider"; import { sessionService } from "@/services/sessionService"; import { getSessionFromStorage, saveSessionToStorage, - removeSessionFromStorage, - checkSessionValid as checkSessionValidFn, - checkSessionExpired as checkSessionExpiredFn, + checkSessionValid, + checkSessionExpired, handleSessionExpired, } from "@/lib/session"; -import { authEventUtils } from "@/lib/authEvents"; import { TOKEN_REFRESH_BEFORE_EXPIRY_MS, MIN_ACTIVITY_CHECK_INTERVAL_MS, @@ -22,14 +19,6 @@ import log from "@/lib/logger"; import type { Session } from "@/types/auth"; -// ============================================================================ -// Utility functions - Session status checking -// ============================================================================ - -// Re-export from lib/session for convenience -export const checkSessionValid = checkSessionValidFn; -export const checkSessionExpired = checkSessionExpiredFn; - /** * Check if token is expiring soon (within threshold) */ @@ -44,36 +33,12 @@ export const isSessionExpiringSoon = (): boolean => { return msUntilExpiry > 0 && msUntilExpiry <= TOKEN_REFRESH_BEFORE_EXPIRY_MS; }; -// ============================================================================ -// Session operations -// ============================================================================ - -/** - * Save session to localStorage - */ -export const saveSession = (session: Session): void => { - saveSessionToStorage(session); - log.info("Session saved to localStorage"); -}; - -/** - * Clear session and emit expired event - * Unified handling for session expiration - */ -export const clearSession = (): void => { - handleSessionExpired(); -}; - -// ============================================================================ -// Business logic functions -// ============================================================================ - /** * Refresh session if needed (token expiring soon) * Uses refresh_token to get new access_token * @returns Whether refresh was successful */ -export const refreshSessionIfNeeded = async (): Promise => { +export const refreshSession = async (): Promise => { const session = getSessionFromStorage(); if (!session?.refresh_token) { return false; @@ -81,7 +46,7 @@ export const refreshSessionIfNeeded = async (): Promise => { const newSession = await sessionService.refreshToken(session.refresh_token); if (newSession) { - saveSession(newSession); + saveSessionToStorage(newSession); log.info("Session refreshed successfully"); return true; } @@ -90,27 +55,24 @@ export const refreshSessionIfNeeded = async (): Promise => { return false; }; -/** - * Unified handler for session expiration - * Called when session is confirmed expired - */ -export const sessionExpiredHandler = (): void => { - log.info("Handling session expiration"); - handleSessionExpired(); -}; - // ============================================================================ // Hook implementation // ============================================================================ export function useSessionManager() { const { isSpeedMode } = useDeployment(); + const reconcileExpiryRef = useRef<() => void>(() => {}); // Initialize session management when hook is used useEffect(() => { // In speed mode, skip session validation if (isSpeedMode) return; + const session = getSessionFromStorage(); + if (!session) { + return; + } + if (checkSessionValid()) { // Session is valid, no action needed return; @@ -120,6 +82,85 @@ export function useSessionManager() { handleSessionExpired(); }, [isSpeedMode]); + /** + * Proactive session expiry watcher + * Triggers session-expired even if user does not make any API request + */ + useEffect(() => { + if (isSpeedMode) return; + + let timeoutId: number | null = null; + let intervalId: number | null = null; + + const clearTimers = () => { + if (timeoutId !== null) { + window.clearTimeout(timeoutId); + timeoutId = null; + } + if (intervalId !== null) { + window.clearInterval(intervalId); + intervalId = null; + } + }; + + const scheduleExpiryCheck = () => { + clearTimers(); + + const session = getSessionFromStorage(); + if (!session?.expires_at) { + return; + } + + const now = Date.now(); + const delayMs = session.expires_at * 1000 - now; + + if (delayMs <= 0) { + handleSessionExpired(); + return; + } + + // Schedule an accurate one-shot check at expiry time + timeoutId = window.setTimeout(() => { + if (!checkSessionValid()) { + handleSessionExpired(); + } + }, delayMs); + + // Also reschedule periodically to account for token refresh extending expires_at + intervalId = window.setInterval(() => { + const s = getSessionFromStorage(); + if (!s) { + clearTimers(); + return; + } + if (!checkSessionValid()) { + handleSessionExpired(); + return; + } + if (s.expires_at !== session.expires_at) { + scheduleExpiryCheck(); + } + }, 30_000); + }; + + const reconcileExpiry = () => { + if (typeof document !== "undefined" && document.hidden) return; + if (!checkSessionValid() && getSessionFromStorage()) { + handleSessionExpired(); + return; + } + scheduleExpiryCheck(); + }; + + reconcileExpiryRef.current = reconcileExpiry; + scheduleExpiryCheck(); + + return () => { + reconcileExpiryRef.current = () => {}; + clearTimers(); + }; + }, [isSpeedMode]); + /** * Setup automatic token refresh on user activity * Refreshes token before expiry to implement sliding expiration @@ -132,6 +173,9 @@ export function useSessionManager() { const maybeRefreshOnActivity = async () => { try { + // Keep expiry timer in sync when the page becomes active again + reconcileExpiryRef.current(); + // Throttle activity-driven checks const now = Date.now(); if (now - lastActivityCheckAt < MIN_ACTIVITY_CHECK_INTERVAL_MS) return; @@ -142,7 +186,7 @@ export function useSessionManager() { // Check if token is expiring soon if (isSessionExpiringSoon()) { - const success = await refreshSessionIfNeeded(); + const success = await refreshSession(); // If refresh failed, it means refresh_token is also invalid // The session will be cleared when backend returns 401 or when fetchWithAuth checks @@ -196,19 +240,10 @@ export function useSessionManager() { }, [setupTokenAutoRefresh]); return { - // Utility functions - checkSessionValid, - checkSessionExpired, isSessionExpiringSoon, - // Session operations - saveSession, - clearSession, - handleSessionExpired, - // Business logic - refreshSessionIfNeeded, - sessionExpiredHandler, + refreshSession, // Legacy functions setupTokenAutoRefresh, From 31ec742303fb8411b213f058915edb77c27b235d Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 27 Jan 2026 19:33:10 +0800 Subject: [PATCH 007/167] Add pre session expiration check --- frontend/lib/auth.ts | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/frontend/lib/auth.ts b/frontend/lib/auth.ts index 9d2b887dd..6f19f1499 100644 --- a/frontend/lib/auth.ts +++ b/frontend/lib/auth.ts @@ -2,10 +2,16 @@ * Authentication utilities */ -import { fetchWithErrorHandling } from "@/services/api"; +import { ApiError, fetchWithErrorHandling } from "@/services/api"; import { STORAGE_KEYS } from "@/const/auth"; import { generateAvatarUrl as generateAvatar } from "@/lib/avatar"; import { USER_ROLES } from "@/const/auth"; +import { STATUS_CODES } from "@/const/auth"; +import { + checkSessionValid, + getSessionFromStorage, + handleSessionExpired, +} from "@/lib/session"; /** * Role color mapping - Ant Design color presets @@ -38,11 +44,19 @@ export function generateAvatarUrl(email: string): string { * Expiration should be handled when backend returns 401 */ export const fetchWithAuth = async (url: string, options: RequestInit = {}) => { - const session = - typeof window !== "undefined" - ? localStorage.getItem(STORAGE_KEYS.SESSION) - : null; - const sessionObj = session ? JSON.parse(session) : null; + // Frontend pre-check: detect session expiry without hitting backend + if (typeof window !== "undefined") { + const session = getSessionFromStorage(); + if (session && !checkSessionValid()) { + handleSessionExpired(); + throw new ApiError( + STATUS_CODES.TOKEN_EXPIRED, + "Login expired, please login again" + ); + } + } + + const sessionObj = getSessionFromStorage(); const isFormData = options.body instanceof FormData; const headers = { From b1523cc8e81d4b89c33e18b3aa1d448ee1d562fb Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Tue, 27 Jan 2026 19:51:05 +0800 Subject: [PATCH 008/167] fix icon name bug --- frontend/components/navigation/SideNavigation.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/components/navigation/SideNavigation.tsx b/frontend/components/navigation/SideNavigation.tsx index 2bbe4c9a9..1ea1a19d0 100644 --- a/frontend/components/navigation/SideNavigation.tsx +++ b/frontend/components/navigation/SideNavigation.tsx @@ -58,7 +58,7 @@ const ROUTE_CONFIG: RouteConfig[] = [ { path: "/monitoring", Icon: Activity, labelKey: "sidebar.monitoringManagement", order: 8 }, { path: "/models", Icon: Settings, labelKey: "sidebar.modelManagement", order: 9 }, { path: "/memory", Icon: Database, labelKey: "sidebar.memoryManagement", order: 10 }, - { path: "/users", Icon: Users, labelKey: "sidebar.userManagement", order: 11 }, + { path: "/users", Icon: User, labelKey: "sidebar.userManagement", order: 11 }, ]; /** From 3bf7244421c86793012e9fe7db26e63cc7ffd471 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Wed, 28 Jan 2026 09:36:14 +0800 Subject: [PATCH 009/167] =?UTF-8?q?=E2=9C=A8=20Update=20UI=20localization?= =?UTF-8?q?=20for=20knowledge=20base=20selection=20prompts=20part2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tool/KnowledgeBaseToolConfig.tsx | 71 +++++++++++++++---- 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx index 9c9d49c6d..bad406349 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx @@ -7,6 +7,7 @@ import KnowledgeBaseList from "../../../../knowledges/components/knowledge/Knowl import knowledgeBaseService from "@/services/knowledgeBaseService"; import { ToolParam } from "@/types/agentConfig"; import { KnowledgeBase } from "@/types/knowledgeBase"; +import { ConfigStore } from "@/lib/config"; export interface KnowledgeBaseToolConfigProps { currentParams: ToolParam[]; @@ -49,9 +50,30 @@ export default function KnowledgeBaseToolConfig({ typeof window !== "undefined" ? knowledgeBaseService.getCachedIdNameMapSync() : null ); // Prevent immediate re-opening of modal when Select regains focus after confirm - const [suppressOpen, setSuppressOpen] = useState(false); + // Use useRef for synchronous state checking to avoid race conditions + const suppressOpenRef = useRef(false); const suppressTimerRef = useRef(null); + // Memoize current embedding model to prevent unnecessary re-renders + const currentEmbeddingModel = useMemo(() => { + return ConfigStore.getInstance().getModelConfig().embedding?.modelName || null; + }, []); + + // Helper function to check if a knowledge base is selectable + const isKbSelectable = (kb: KnowledgeBase): boolean => { + const docCount = typeof kb.documentCount === "number" ? kb.documentCount : 0; + const chunkCount = typeof kb.chunkCount === "number" ? kb.chunkCount : 0; + const hasContent = (docCount + chunkCount) > 0; + + // Check model compatibility - only for local knowledge bases (nexent source) + const isModelCompatible = + kb.source !== "nexent" || // Non-local knowledge bases (e.g., DataMate) don't need model check + kb.embeddingModel === "unknown" || + kb.embeddingModel === currentEmbeddingModel; + + return hasContent && isModelCompatible; + }; + const buildKbOptions = (kbs: any[]) => { // For the preview/select we only need the KB name (no description). return kbs.map((kb) => ({ @@ -62,7 +84,26 @@ export default function KnowledgeBaseToolConfig({ }; const openKbModal = async () => { - if (suppressOpen) return; + // Capture suppress state at the start of the function to prevent race conditions + // This snapshot will be used later to verify we should still open the modal + const suppressSnapshot = suppressOpenRef.current; + + // If suppress was active at the start of this call, clear it and return + if (suppressSnapshot) { + suppressOpenRef.current = false; + if (suppressTimerRef.current) { + window.clearTimeout(suppressTimerRef.current); + suppressTimerRef.current = null; + } + return; + } + + // Clear any pending suppress timer since we're intentionally opening the modal + if (suppressTimerRef.current) { + window.clearTimeout(suppressTimerRef.current); + suppressTimerRef.current = null; + } + // If parent provided KB list, use it; otherwise fetch setKbLoading(true); try { @@ -119,6 +160,13 @@ export default function KnowledgeBaseToolConfig({ const idx = currentParams.findIndex((p) => p.name === "index_names"); const currentVal = idx !== -1 ? currentParams[idx].value : undefined; setKbModalSelected(Array.isArray(currentVal) ? currentVal : []); + + // Double-check suppress snapshot before showing modal to prevent race conditions + // If suppress was set during the async operation, don't open the modal + if (suppressSnapshot) { + setKbLoading(false); + return; + } setKbModalVisible(true); } catch (e) { message.error(t("toolConfig.message.kbRefreshFailed", "Failed to refresh KB list")); @@ -404,12 +452,12 @@ export default function KnowledgeBaseToolConfig({ onCancel={() => { setKbModalVisible(false); // suppress immediate reopen caused by focus/click after cancel - setSuppressOpen(true); + suppressOpenRef.current = true; if (suppressTimerRef.current) { window.clearTimeout(suppressTimerRef.current); } suppressTimerRef.current = window.setTimeout(() => { - setSuppressOpen(false); + suppressOpenRef.current = false; suppressTimerRef.current = null; }, 300); if (document.activeElement instanceof HTMLElement) { @@ -417,7 +465,7 @@ export default function KnowledgeBaseToolConfig({ } }} cancelText={t("common.button.cancel")} - okText={t("common.button.save")} + okText={t("common.confirm")} onOk={() => { let idx2 = currentParams.findIndex((p) => p.name === "index_names"); const newParams = [...currentParams]; @@ -441,19 +489,18 @@ export default function KnowledgeBaseToolConfig({ } // Close modal and briefly suppress open to avoid immediate re-open via focus/click setKbModalVisible(false); - setSuppressOpen(true); + suppressOpenRef.current = true; if (suppressTimerRef.current) { window.clearTimeout(suppressTimerRef.current); } suppressTimerRef.current = window.setTimeout(() => { - setSuppressOpen(false); + suppressOpenRef.current = false; suppressTimerRef.current = null; }, 300); // Blur active element to reduce chance of immediate focus triggering open if (document.activeElement instanceof HTMLElement) { document.activeElement.blur(); } - message.success(t("toolConfig.message.kbSelectSaved", "KB selection saved")); }} width={800} > @@ -466,7 +513,7 @@ export default function KnowledgeBaseToolConfig({ knowledgeBases={kbRawList as KnowledgeBase[]} selectedIds={kbModalSelected} activeKnowledgeBase={null} - currentEmbeddingModel={null} + currentEmbeddingModel={currentEmbeddingModel} isLoading={kbLoading} syncLoading={false} onSelect={(id: string) => { @@ -478,11 +525,7 @@ export default function KnowledgeBaseToolConfig({ }} onClick={() => {}} showDataMateConfig={false} - isSelectable={(kb: KnowledgeBase) => { - const docCount = typeof kb.documentCount === "number" ? kb.documentCount : 0; - const chunkCount = typeof kb.chunkCount === "number" ? kb.chunkCount : 0; - return (docCount + chunkCount) > 0; - }} + isSelectable={isKbSelectable} getModelDisplayName={(m: string) => m} containerHeight="50vh" /> From 8d4e518d3961868caa29d29162301efb3e1b96a5 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Wed, 28 Jan 2026 10:55:37 +0800 Subject: [PATCH 010/167] Be compatible with the new authentication&authorization --- frontend/app/[locale]/users/components/UserProfileComp.tsx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/frontend/app/[locale]/users/components/UserProfileComp.tsx b/frontend/app/[locale]/users/components/UserProfileComp.tsx index 6fb33f712..4c736c196 100644 --- a/frontend/app/[locale]/users/components/UserProfileComp.tsx +++ b/frontend/app/[locale]/users/components/UserProfileComp.tsx @@ -25,8 +25,9 @@ import { AlertTriangle, ChevronRight, } from "lucide-react"; -import { useAuth } from "@/hooks/useAuth"; import { USER_ROLES } from "@/const/modelConfig"; +import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; +import { useAuthenticationContext } from "@/components/providers/AuthenticationProvider"; const { Text, Paragraph } = Typography; @@ -43,7 +44,8 @@ const { Text, Paragraph } = Typography; export default function UserProfileComp() { const { t } = useTranslation("common"); const { message: antdMessage } = App.useApp(); - const { user, logout, revoke, isLoading } = useAuth(); + const { logout, revoke, isLoading } = useAuthenticationContext() + const { user } = useAuthorizationContext() // Modal states const [isEditModalOpen, setIsEditModalOpen] = useState(false); From 385c4662b83a2d84020817f95547c89ea12d3983 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Wed, 28 Jan 2026 10:59:44 +0800 Subject: [PATCH 011/167] =?UTF-8?q?=E2=9C=A8=20Update=20UI=20localization?= =?UTF-8?q?=20for=20knowledge=20base=20selection=20prompts=20part2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../knowledge/KnowledgeBaseList.tsx | 68 +++++++++---------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx index fe06aef1c..5a5dc0166 100644 --- a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx +++ b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx @@ -59,25 +59,25 @@ interface KnowledgeBaseListProps { } const KnowledgeBaseList: React.FC = ({ - knowledgeBases, - selectedIds, - activeKnowledgeBase, - currentEmbeddingModel, - isLoading = false, - syncLoading = false, - onSelect, - onClick, - onDelete, - onSync, - onCreateNew, - onDataMateConfig, - showDataMateConfig = false, - isSelectable, - getModelDisplayName, - containerHeight = "70vh", // Default container height consistent with DocumentList - onKnowledgeBaseChange, // New: callback function when knowledge base switches - }) => { - const {t} = useTranslation(); + knowledgeBases, + selectedIds, + activeKnowledgeBase, + currentEmbeddingModel, + isLoading = false, + syncLoading = false, + onSelect, + onClick, + onDelete, + onSync, + onCreateNew, + onDataMateConfig, + showDataMateConfig = false, + isSelectable, + getModelDisplayName, + containerHeight = "70vh", // Default container height consistent with DocumentList + onKnowledgeBaseChange, // New: callback function when knowledge base switches +}) => { + const { t } = useTranslation(); // Format date function, only keep date part const formatDate = (dateValue: any) => { @@ -123,7 +123,7 @@ const KnowledgeBaseList: React.FC = ({ {t("knowledgeBase.list.title")}
-
+
{onCreateNew && ( @@ -160,17 +160,17 @@ const KnowledgeBaseList: React.FC = ({ type="primary" onClick={onSync} > - - - + + + {t("knowledgeBase.button.sync")} )} @@ -189,7 +189,7 @@ const KnowledgeBaseList: React.FC = ({ className="hover:!bg-blue-600" type="primary" onClick={onDataMateConfig} - icon={} + icon={} > {t("knowledgeBase.button.dataMateConfig")} @@ -221,7 +221,7 @@ const KnowledgeBaseList: React.FC = ({ Date: Wed, 28 Jan 2026 15:07:23 +0800 Subject: [PATCH 012/167] =?UTF-8?q?The=20knowledge=20base=20retrieval=20to?= =?UTF-8?q?ol=20supports=20selecting=20a=20specified=20knowledge=20base=20?= =?UTF-8?q?=E2=80=94=E2=80=94=E2=80=94=E2=80=94=20backend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/agents/create_agent_info.py | 34 +- .../services/tool_configuration_service.py | 40 -- sdk/nexent/core/agents/nexent_agent.py | 11 +- sdk/nexent/core/tools/datamate_search_tool.py | 51 +-- .../core/tools/knowledge_base_search_tool.py | 69 +--- test/backend/agents/test_create_agent_info.py | 177 +-------- test/backend/database/test_agent_db.py | 361 ++++++++++-------- .../test_tool_configuration_service.py | 185 +++------ test/sdk/core/agents/test_nexent_agent.py | 8 +- .../core/tools/test_datamate_search_tool.py | 67 ++-- .../tools/test_knowledge_base_search_tool.py | 61 +-- 11 files changed, 365 insertions(+), 699 deletions(-) diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index d09029a97..4fd2411a7 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -15,7 +15,6 @@ get_vector_db_core, get_embedding_model, ) -from services.tenant_config_service import get_selected_knowledge_list, build_knowledge_name_mapping from services.remote_mcp_service import get_remote_mcp_server_list from services.memory_config_service import build_memory_context from services.image_service import get_vlm_model @@ -146,20 +145,16 @@ async def create_agent_config( try: for tool in tool_list: if "KnowledgeBaseSearchTool" == tool.class_name: - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - if knowledge_info_list: - for knowledge_info in knowledge_info_list: - if knowledge_info.get('knowledge_sources') != 'elasticsearch': - continue - knowledge_name = knowledge_info.get("index_name") + index_names = tool.params.get("index_names") + if index_names: + for index_name in index_names: try: - message = ElasticSearchService().get_summary(index_name=knowledge_name) + message = ElasticSearchService().get_summary(index_name=index_name) summary = message.get("summary", "") - knowledge_base_summary += f"**{knowledge_name}**: {summary}\n\n" + knowledge_base_summary += f"**{index_name}**: {summary}\n\n" except Exception as e: logger.warning( - f"Failed to get summary for knowledge base {knowledge_name}: {e}") + f"Failed to get summary for knowledge base {index_name}: {e}") else: # TODO: Prompt should be refactored to yaml file knowledge_base_summary = "当前没有可用的知识库索引。\n" if language == 'zh' else "No knowledge base indexes are currently available.\n" @@ -238,24 +233,9 @@ async def create_tool_config_list(agent_id, tenant_id, user_id): # special logic for knowledge base search tool if tool_config.class_name == "KnowledgeBaseSearchTool": - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - index_names = [knowledge_info.get( - "index_name") for knowledge_info in knowledge_info_list if knowledge_info.get('knowledge_sources') == 'elasticsearch'] - tool_config.metadata = { - "index_names": index_names, + tool_config.metadata = { "vdb_core": get_vector_db_core(), "embedding_model": get_embedding_model(tenant_id=tenant_id), - "name_resolver": build_knowledge_name_mapping(tenant_id=tenant_id, user_id=user_id), - } - elif tool_config.class_name == "DataMateSearchTool": - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - index_names = [knowledge_info.get( - "index_name") for knowledge_info in knowledge_info_list if - knowledge_info.get('knowledge_sources') == 'datamate'] - tool_config.metadata = { - "index_names": index_names, } elif tool_config.class_name == "AnalyzeTextFileTool": tool_config.metadata = { diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index e7b39af3b..02b865c2f 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -1,4 +1,3 @@ -import asyncio import importlib import inspect import json @@ -25,7 +24,6 @@ from services.file_management_service import get_llm_model from services.vectordatabase_service import get_embedding_model, get_vector_db_core from services.tenant_config_service import get_selected_knowledge_list, build_knowledge_name_mapping -from database.knowledge_db import get_index_name_by_knowledge_name from database.client import minio_client from services.image_service import get_vlm_model @@ -539,52 +537,14 @@ def _validate_local_tool( if tool_name == "knowledge_base_search": if not tenant_id or not user_id: raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation") - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - index_names = [knowledge_info.get("index_name") for knowledge_info in knowledge_info_list if knowledge_info.get('knowledge_sources') == 'elasticsearch'] - name_resolver = build_knowledge_name_mapping( - tenant_id=tenant_id, user_id=user_id) - - # Fallback: if user provided index_names in inputs, try to resolve them even when no selection stored - if (not index_names) and inputs and inputs.get("index_names"): - raw_names = inputs.get("index_names") - if isinstance(raw_names, str): - raw_names = [raw_names] - resolved_indices = [] - for raw in raw_names: - try: - resolved = get_index_name_by_knowledge_name( - raw, tenant_id=tenant_id) - name_resolver[raw] = resolved - resolved_indices.append(resolved) - except Exception: - # If not found as knowledge_name, assume it's already an index_name - resolved_indices.append(raw) - index_names = resolved_indices - embedding_model = get_embedding_model(tenant_id=tenant_id) vdb_core = get_vector_db_core() params = { **instantiation_params, - 'index_names': index_names, - 'name_resolver': name_resolver, 'vdb_core': vdb_core, 'embedding_model': embedding_model, } tool_instance = tool_class(**params) - elif tool_name == "datamate_search": - if not tenant_id or not user_id: - raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation") - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - index_names = [knowledge_info.get("index_name") for knowledge_info in knowledge_info_list if - knowledge_info.get('knowledge_sources') == 'datamate'] - - params = { - **instantiation_params, - 'index_names': index_names, - } - tool_instance = tool_class(**params) elif tool_name == "analyze_image": if not tenant_id or not user_id: raise ToolExecutionException(f"Tenant ID and User ID are required for {tool_name} validation") diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index 12d7737df..4f2c38d07 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -73,28 +73,19 @@ def create_local_tool(self, tool_config: ToolConfig): # These parameters have exclude=True and cannot be passed to __init__ # due to smolagents.tools.Tool wrapper restrictions filtered_params = {k: v for k, v in params.items() - if k not in ["index_names", "vdb_core", "embedding_model", "observer"]} + if k not in ["vdb_core", "embedding_model", "observer"]} # Create instance with only non-excluded parameters tools_obj = tool_class(**filtered_params) # Set excluded parameters directly as attributes after instantiation # This bypasses smolagents wrapper restrictions tools_obj.observer = self.observer - index_names = tool_config.metadata.get( - "index_names", None) if tool_config.metadata else None - tools_obj.index_names = [] if index_names is None else index_names tools_obj.vdb_core = tool_config.metadata.get( "vdb_core", None) if tool_config.metadata else None tools_obj.embedding_model = tool_config.metadata.get( "embedding_model", None) if tool_config.metadata else None - name_resolver = tool_config.metadata.get( - "name_resolver", None) if tool_config.metadata else None - tools_obj.name_resolver = {} if name_resolver is None else name_resolver elif class_name == "DataMateSearchTool": tools_obj = tool_class(**params) tools_obj.observer = self.observer - index_names = tool_config.metadata.get( - "index_names", None) if tool_config.metadata else None - tools_obj.index_names = [] if index_names is None else index_names elif class_name == "AnalyzeTextFileTool": tools_obj = tool_class(observer=self.observer, llm_model=tool_config.metadata.get("llm_model", []), diff --git a/sdk/nexent/core/tools/datamate_search_tool.py b/sdk/nexent/core/tools/datamate_search_tool.py index bd412e6c1..1950e9c88 100644 --- a/sdk/nexent/core/tools/datamate_search_tool.py +++ b/sdk/nexent/core/tools/datamate_search_tool.py @@ -14,15 +14,6 @@ logger = logging.getLogger("datamate_search_tool") -def _normalize_index_names(index_names: Optional[Union[str, List[str]]]) -> List[str]: - """Normalize index_names to list; accept single string and keep None as empty list.""" - if index_names is None: - return [] - if isinstance(index_names, str): - return [index_names] - return list(index_names) - - class DataMateSearchTool(Tool): """DataMate knowledge base search tool""" name = "datamate_search" @@ -38,23 +29,6 @@ class DataMateSearchTool(Tool): "type": "string", "description": "The search query to perform.", }, - "top_k": { - "type": "integer", - "description": "Maximum number of search results to return.", - "default": 3, - "nullable": True, - }, - "threshold": { - "type": "number", - "description": "Similarity threshold for search results.", - "default": 0.2, - "nullable": True, - }, - "index_names": { - "type": "array", - "description": "The list of knowledge base names to search (supports user-facing knowledge_name or internal index_name). If not provided, will search all available knowledge bases.", - "nullable": True, - }, "kb_page": { "type": "integer", "description": "Page index when listing knowledge bases from DataMate.", @@ -80,9 +54,13 @@ def __init__( verify_ssl: bool = Field( description="Whether to verify SSL certificates for HTTPS connections", default=False), index_names: List[str] = Field( - description="The list of index names to search", default=None, exclude=True), + description="The list of index names to search", default=None), observer: MessageObserver = Field( description="Message observer", default=None, exclude=True), + top_k: int = Field( + description="Default maximum number of search results to return", default=3), + threshold: float = Field( + description="Default similarity threshold for search results", default=0.2), ): """Initialize the DataMateSearchTool. @@ -106,6 +84,8 @@ def __init__( self.use_https = parsed_url["use_https"] self.server_base_url = parsed_url["base_url"] self.index_names = [] if index_names is None else index_names + self.top_k = top_k + self.threshold = threshold # Determine SSL verification setting if verify_ssl is None: @@ -177,9 +157,6 @@ def _parse_server_url(server_url: str) -> dict: def forward( self, query: str, - top_k: int = 3, - threshold: float = 0.2, - index_names: Union[str, List[str], None] = None, kb_page: int = 0, kb_page_size: int = 20, ) -> str: @@ -187,9 +164,6 @@ def forward( Args: query: Search query text. - top_k: Optional override for maximum number of search results. - threshold: Optional override for similarity threshold. - index_names: The list of knowledge base names to search (supports user-facing knowledge_name or internal index_name). If not provided, will search all available knowledge bases. kb_page: Optional override for knowledge base list page index. kb_page_size: Optional override for knowledge base list page size. """ @@ -207,15 +181,12 @@ def forward( logger.info( f"DataMateSearchTool called with query: '{query}', base_url: '{self.server_base_url}', " - f"top_k: {top_k}, threshold: {threshold}, index_names: {index_names}" + f"top_k: {self.top_k}, threshold: {self.threshold}, index_names: {self.index_names}" ) try: # Step 1: Determine knowledge base IDs to search - # Use provided index_names if available, otherwise use default - knowledge_base_ids = _normalize_index_names( - index_names if index_names is not None else self.index_names) - + knowledge_base_ids = self.index_names if len(knowledge_base_ids) == 0: return json.dumps("No knowledge base selected. No relevant information found.", ensure_ascii=False) @@ -225,8 +196,8 @@ def forward( kb_search = self.datamate_core.hybrid_search( query_text=query, index_names=[knowledge_base_id], - top_k=top_k, - weight_accurate=threshold, + top_k=self.top_k, + weight_accurate=self.threshold, ) if not kb_search: raise Exception( diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index a033bb69a..ab2d2e702 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -1,6 +1,6 @@ import json import logging -from typing import Dict, List, Optional, Union +from typing import List from pydantic import Field from smolagents.tools import Tool @@ -28,34 +28,29 @@ class KnowledgeBaseSearchTool(Tool): ) inputs = { "query": {"type": "string", "description": "The search query to perform."}, - "search_mode": { - "type": "string", - "description": "the search mode, optional values: hybrid, combining accurate matching and semantic search results across multiple indices.; accurate, Search for documents using fuzzy text matching across multiple indices; semantic, Search for similar documents using vector similarity across multiple indices.", - "default": "hybrid", - "nullable": True, - }, - "index_names": { - "type": "array", - "description": "The list of knowledge base names to search (supports user-facing knowledge_name or internal index_name). If not provided, will search all available knowledge bases.", - "nullable": True, - }, } output_type = "string" - category = ToolCategory.SEARCH.value + category = ToolCategory.SEARCH.value # Used to distinguish different index sources for summaries tool_sign = ToolSign.KNOWLEDGE_BASE.value def __init__( self, - top_k: int = Field(description="Maximum number of search results", default=3), - index_names: List[str] = Field(description="The list of index names to search", default=None, exclude=True), - name_resolver: Optional[Dict[str, str]] = Field( - description="Mapping from knowledge_name to index_name", default=None, exclude=True + top_k: int = Field( + description="Maximum number of search results", default=3), + index_names: List[str] = Field( + description="The list of index names to search", default=None), + search_mode: str = Field( + description="the search mode, optional values: hybrid, accurate, semantic", + default="hybrid", ), - observer: MessageObserver = Field(description="Message observer", default=None, exclude=True), - embedding_model: BaseEmbedding = Field(description="The embedding model to use", default=None, exclude=True), - vdb_core: VectorDatabaseCore = Field(description="Vector database client", default=None, exclude=True), + observer: MessageObserver = Field( + description="Message observer", default=None, exclude=True), + embedding_model: BaseEmbedding = Field( + description="The embedding model to use", default=None, exclude=True), + vdb_core: VectorDatabaseCore = Field( + description="Vector database client", default=None, exclude=True), ): """Initialize the KBSearchTool. @@ -71,36 +66,15 @@ def __init__( self.observer = observer self.vdb_core = vdb_core self.index_names = [] if index_names is None else index_names - self.name_resolver: Dict[str, str] = name_resolver or {} + self.search_mode = search_mode self.embedding_model = embedding_model self.record_ops = 1 # To record serial number self.running_prompt_zh = "知识库检索中..." self.running_prompt_en = "Searching the knowledge base..." - def update_name_resolver(self, new_mapping: Dict[str, str]) -> None: - """Update the mapping from knowledge_name to index_name at runtime.""" - self.name_resolver = new_mapping or {} - - def _resolve_names(self, names: List[str]) -> List[str]: - """Resolve user-facing knowledge names to internal index names.""" - if not names: - return [] - if not self.name_resolver: - logger.warning( - "No name resolver provided, returning original names") - return names - return [self.name_resolver.get(name, name) for name in names] - - def _normalize_index_names(self, index_names: Optional[Union[str, List[str]]]) -> List[str]: - """Normalize index_names to list; accept single string and keep None as empty list.""" - if index_names is None: - return [] - if isinstance(index_names, str): - return [index_names] - return list(index_names) - - def forward(self, query: str, search_mode: str = "hybrid", index_names: Union[str, List[str], None] = None) -> str: + + def forward(self, query: str) -> str: # Send tool run message if self.observer: running_prompt = self.running_prompt_zh if self.observer.lang == "zh" else self.running_prompt_en @@ -108,10 +82,9 @@ def forward(self, query: str, search_mode: str = "hybrid", index_names: Union[st card_content = [{"icon": "search", "text": query}] self.observer.add_message("", ProcessType.CARD, json.dumps(card_content, ensure_ascii=False)) - # Use provided index_names if available, otherwise use default - search_index_names = self._normalize_index_names( - index_names if index_names is not None else self.index_names) - search_index_names = self._resolve_names(search_index_names) + # Use the instance index_names and search_mode + search_index_names = self.index_names + search_mode = self.search_mode # Log the index_names being used for this search logger.info( diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index e9ae8baa7..fcb438009 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -295,7 +295,7 @@ async def test_create_tool_config_list_basic(self): """Test case for basic tool configuration list creation""" with patch('backend.agents.create_agent_info.discover_langchain_tools') as mock_discover, \ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ - patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge: + patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config: # Set mock return values mock_discover.return_value = [] @@ -311,7 +311,6 @@ async def test_create_tool_config_list_basic(self): "usage": None } ] - mock_knowledge.return_value = [] result = await create_tool_config_list("agent_1", "tenant_1", "user_1") @@ -333,7 +332,7 @@ async def test_create_tool_config_list_with_knowledge_base_tool(self): """Test case including the knowledge base search tool""" with patch('backend.agents.create_agent_info.discover_langchain_tools') as mock_discover, \ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ - patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge, \ + patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding: @@ -350,10 +349,6 @@ async def test_create_tool_config_list_with_knowledge_base_tool(self): "usage": None } ] - mock_knowledge.return_value = [ - {"index_name": "knowledge_1"}, - {"index_name": "knowledge_2"}, - ] mock_vdb_core = "mock_elastic_core" mock_get_vector_db_core.return_value = mock_vdb_core mock_embedding.return_value = "mock_embedding_model" @@ -363,6 +358,7 @@ async def test_create_tool_config_list_with_knowledge_base_tool(self): assert len(result) == 1 # Verify that ToolConfig was called correctly, including knowledge base metadata # Check if the last call was for KnowledgeBaseSearchTool + mock_tool_config.assert_called() last_call = mock_tool_config.call_args_list[-1] assert last_call[1]['class_name'] == "KnowledgeBaseSearchTool" @@ -441,90 +437,11 @@ async def test_create_tool_config_list_with_analyze_text_file_tool(self): @pytest.mark.asyncio async def test_create_tool_config_list_with_knowledge_base_tool_mixed_sources(self): - """Test KnowledgeBaseSearchTool filters only elasticsearch knowledge sources""" - mock_tool_instance = MagicMock() - mock_tool_instance.class_name = "KnowledgeBaseSearchTool" - mock_tool_config.return_value = mock_tool_instance - - with patch('backend.agents.create_agent_info.discover_langchain_tools') as mock_discover, \ - patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ - patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge, \ - patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ - patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ - patch('backend.agents.create_agent_info.build_knowledge_name_mapping') as mock_build_mapping: - - mock_discover.return_value = [] - mock_search_tools.return_value = [ - { - "class_name": "KnowledgeBaseSearchTool", - "name": "knowledge_search", - "description": "Knowledge search tool", - "inputs": "string", - "output_type": "string", - "params": [], - "source": "local", - "usage": None - } - ] - # Mix of elasticsearch and datamate sources - mock_knowledge.return_value = [ - {"index_name": "elastic_kb_1", "knowledge_sources": "elasticsearch"}, - {"index_name": "datamate_kb_1", "knowledge_sources": "datamate"}, - {"index_name": "elastic_kb_2", "knowledge_sources": "elasticsearch"}, - {"index_name": "other_kb", "knowledge_sources": "other"} - ] - mock_vdb_core = "mock_elastic_core" - mock_get_vector_db_core.return_value = mock_vdb_core - mock_embedding.return_value = "mock_embedding_model" - mock_build_mapping.return_value = {"mapping": "mock"} - - result = await create_tool_config_list("agent_1", "tenant_1", "user_1") - - assert len(result) == 1 - assert result[0] is mock_tool_instance - # Should only include elasticsearch index names - expected_index_names = ["elastic_kb_1", "elastic_kb_2"] - assert mock_tool_instance.metadata["index_names"] == expected_index_names + pass @pytest.mark.asyncio async def test_create_tool_config_list_with_datamate_tool(self): - """Test DataMateSearchTool filters only datamate knowledge sources""" - mock_tool_instance = MagicMock() - mock_tool_instance.class_name = "DataMateSearchTool" - mock_tool_config.return_value = mock_tool_instance - - with patch('backend.agents.create_agent_info.discover_langchain_tools') as mock_discover, \ - patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ - patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge: - - mock_discover.return_value = [] - mock_search_tools.return_value = [ - { - "class_name": "DataMateSearchTool", - "name": "datamate_search", - "description": "DataMate search tool", - "inputs": "string", - "output_type": "string", - "params": [], - "source": "local", - "usage": None - } - ] - # Mix of different knowledge sources - mock_knowledge.return_value = [ - {"index_name": "elastic_kb_1", "knowledge_sources": "elasticsearch"}, - {"index_name": "datamate_kb_1", "knowledge_sources": "datamate"}, - {"index_name": "datamate_kb_2", "knowledge_sources": "datamate"}, - {"index_name": "other_kb", "knowledge_sources": "other"} - ] - - result = await create_tool_config_list("agent_1", "tenant_1", "user_1") - - assert len(result) == 1 - assert result[0] is mock_tool_instance - # Should only include datamate index names - expected_index_names = ["datamate_kb_1", "datamate_kb_2"] - assert mock_tool_instance.metadata["index_names"] == expected_index_names + pass class TestCreateAgentConfig: @@ -539,7 +456,7 @@ async def test_create_agent_config_basic(self): patch('backend.agents.create_agent_info.get_agent_prompt_template') as mock_get_template, \ patch('backend.agents.create_agent_info.tenant_config_manager') as mock_tenant_config, \ patch('backend.agents.create_agent_info.build_memory_context') as mock_build_memory, \ - patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge, \ + patch('backend.agents.create_agent_info.AgentConfig') as mock_agent_config, \ patch('backend.agents.create_agent_info.prepare_prompt_templates') as mock_prepare_templates, \ patch('backend.agents.create_agent_info.get_model_by_model_id') as mock_get_model_by_id: @@ -567,7 +484,6 @@ async def test_create_agent_config_basic(self): user_id="user_1", agent_id="agent_1" ) - mock_knowledge.return_value = [] mock_prepare_templates.return_value = { "system_prompt": "populated_system_prompt"} mock_get_model_by_id.return_value = {"display_name": "test_model"} @@ -596,7 +512,7 @@ async def test_create_agent_config_with_sub_agents(self): patch('backend.agents.create_agent_info.tenant_config_manager') as mock_tenant_config, \ patch('backend.agents.create_agent_info.build_memory_context') as mock_build_memory, \ patch('backend.agents.create_agent_info.search_memory_in_levels', new_callable=AsyncMock) as mock_search_memory, \ - patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge, \ + patch('backend.agents.create_agent_info.AgentConfig') as mock_agent_config, \ patch('backend.agents.create_agent_info.prepare_prompt_templates') as mock_prepare_templates, \ patch('backend.agents.create_agent_info.get_model_by_model_id') as mock_get_model_by_id: @@ -624,7 +540,6 @@ async def test_create_agent_config_with_sub_agents(self): user_id="user_1", agent_id="agent_1" ) - mock_knowledge.return_value = [] mock_prepare_templates.return_value = { "system_prompt": "populated_system_prompt"} mock_get_model_by_id.return_value = {"display_name": "test_model"} @@ -663,7 +578,6 @@ async def test_create_agent_config_with_memory(self): patch('backend.agents.create_agent_info.tenant_config_manager') as mock_tenant_config, \ patch('backend.agents.create_agent_info.build_memory_context') as mock_build_memory, \ patch('backend.agents.create_agent_info.search_memory_in_levels', new_callable=AsyncMock) as mock_search_memory, \ - patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge, \ patch('backend.agents.create_agent_info.prepare_prompt_templates') as mock_prepare_templates, \ patch('backend.agents.create_agent_info.get_model_by_model_id') as mock_get_model_by_id: @@ -700,7 +614,6 @@ async def test_create_agent_config_with_memory(self): agent_id="agent_1" ) mock_search_memory.return_value = {"results": [{"memory": "test"}]} - mock_knowledge.return_value = [] mock_prepare_templates.return_value = { "system_prompt": "populated_system_prompt"} mock_get_model_by_id.return_value = {"display_name": "test_model"} @@ -745,9 +658,6 @@ async def test_create_agent_config_memory_disabled_no_search(self): "backend.agents.create_agent_info.search_memory_in_levels", new_callable=AsyncMock, ) as mock_search_memory, - patch( - "backend.agents.create_agent_info.get_selected_knowledge_list" - ) as mock_knowledge, patch( "backend.agents.create_agent_info.prepare_prompt_templates" ) as mock_prepare_templates, @@ -786,7 +696,6 @@ async def test_create_agent_config_memory_disabled_no_search(self): agent_id="agent_1", ) - mock_knowledge.return_value = [] mock_prepare_templates.return_value = { "system_prompt": "populated_system_prompt" } @@ -812,7 +721,7 @@ async def test_create_agent_config_model_id_none(self): patch('backend.agents.create_agent_info.get_agent_prompt_template') as mock_get_template, \ patch('backend.agents.create_agent_info.tenant_config_manager') as mock_tenant_config, \ patch('backend.agents.create_agent_info.build_memory_context') as mock_build_memory, \ - patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge, \ + patch('backend.agents.create_agent_info.AgentConfig') as mock_agent_config, \ patch('backend.agents.create_agent_info.prepare_prompt_templates') as mock_prepare_templates, \ patch('backend.agents.create_agent_info.get_model_by_model_id') as mock_get_model_by_id: @@ -840,7 +749,6 @@ async def test_create_agent_config_model_id_none(self): user_id="user_1", agent_id="agent_1" ) - mock_knowledge.return_value = [] mock_prepare_templates.return_value = { "system_prompt": "populated_system_prompt"} mock_get_model_by_id.return_value = None # Model not found @@ -885,9 +793,6 @@ async def test_create_agent_config_memory_exception(self): "backend.agents.create_agent_info.search_memory_in_levels", new_callable=AsyncMock, ) as mock_search_memory, - patch( - "backend.agents.create_agent_info.get_selected_knowledge_list" - ) as mock_knowledge, patch( "backend.agents.create_agent_info.prepare_prompt_templates" ) as mock_prepare_templates, @@ -926,7 +831,6 @@ async def test_create_agent_config_memory_exception(self): ) mock_search_memory.side_effect = Exception("boom") - mock_knowledge.return_value = [] mock_prepare_templates.return_value = { "system_prompt": "populated_system_prompt" } @@ -945,70 +849,7 @@ async def test_create_agent_config_memory_exception(self): @pytest.mark.asyncio async def test_create_agent_config_with_knowledge_base_summary_filtering(self): - """Test knowledge base summary generation filters only elasticsearch sources""" - with patch('backend.agents.create_agent_info.search_agent_info_by_agent_id') as mock_search_agent, \ - patch('backend.agents.create_agent_info.query_sub_agents_id_list') as mock_query_sub, \ - patch('backend.agents.create_agent_info.create_tool_config_list') as mock_create_tools, \ - patch('backend.agents.create_agent_info.get_agent_prompt_template') as mock_get_template, \ - patch('backend.agents.create_agent_info.tenant_config_manager') as mock_tenant_config, \ - patch('backend.agents.create_agent_info.build_memory_context') as mock_build_memory, \ - patch('backend.agents.create_agent_info.get_selected_knowledge_list') as mock_knowledge, \ - patch('backend.agents.create_agent_info.prepare_prompt_templates') as mock_prepare_templates, \ - patch('backend.agents.create_agent_info.get_model_by_model_id') as mock_get_model_by_id, \ - patch('backend.agents.create_agent_info.ElasticSearchService') as mock_es_service: - - # Set mock return values - mock_search_agent.return_value = { - "name": "test_agent", - "description": "test description", - "duty_prompt": "test duty", - "constraint_prompt": "test constraint", - "few_shots_prompt": "test few shots", - "max_steps": 5, - "model_id": 123, - "provide_run_summary": True - } - mock_query_sub.return_value = [] - - # Create a mock tool with KnowledgeBaseSearchTool class name - mock_tool = Mock() - mock_tool.class_name = "KnowledgeBaseSearchTool" - mock_create_tools.return_value = [mock_tool] - - mock_get_template.return_value = { - "system_prompt": "{{duty}} {{constraint}} {{few_shots}}"} - mock_tenant_config.get_app_config.side_effect = [ - "TestApp", "Test Description"] - mock_build_memory.return_value = Mock( - user_config=Mock(memory_switch=False), - memory_config={}, - tenant_id="tenant_1", - user_id="user_1", - agent_id="agent_1" - ) - mock_knowledge.return_value = [ - {"index_name": "elastic_kb_1", "knowledge_sources": "elasticsearch"}, - {"index_name": "datamate_kb_1", "knowledge_sources": "datamate"}, - {"index_name": "elastic_kb_2", "knowledge_sources": "elasticsearch"} - ] - mock_prepare_templates.return_value = { - "system_prompt": "populated_system_prompt"} - mock_get_model_by_id.return_value = {"display_name": "test_model"} - - # Mock ElasticSearchService - mock_es_instance = Mock() - mock_es_instance.get_summary.return_value = {"summary": "Test summary"} - mock_es_service.return_value = mock_es_instance - - result = await create_agent_config("agent_1", "tenant_1", "user_1", "zh", "test query") - - # Verify that ElasticSearchService was called only for elasticsearch sources - assert mock_es_service.call_count == 2 # Only for elastic_kb_1 and elastic_kb_2 - mock_es_instance.get_summary.assert_any_call(index_name="elastic_kb_1") - mock_es_instance.get_summary.assert_any_call(index_name="elastic_kb_2") - # Should not be called for datamate_kb_1 - called_index_names = [call[1]['index_name'] for call in mock_es_instance.get_summary.call_args_list] - assert "datamate_kb_1" not in called_index_names + pass class TestCreateModelConfigList: diff --git a/test/backend/database/test_agent_db.py b/test/backend/database/test_agent_db.py index 929808250..e43d04757 100644 --- a/test/backend/database/test_agent_db.py +++ b/test/backend/database/test_agent_db.py @@ -1,3 +1,17 @@ +from backend.database.agent_db import ( + search_agent_info_by_agent_id, + search_agent_id_by_agent_name, + search_blank_sub_agent_by_main_agent_id, + query_sub_agents_id_list, + create_agent, + update_agent, + delete_agent_by_id, + query_all_agent_info_by_tenant_id, + insert_related_agent, + delete_related_agent, + delete_agent_relationship, + update_related_agents +) import sys import pytest from unittest.mock import patch, MagicMock @@ -25,7 +39,8 @@ # 模拟utils模块 utils_mock = MagicMock() utils_mock.auth_utils = MagicMock() -utils_mock.auth_utils.get_current_user_id_from_token = MagicMock(return_value="test_user_id") +utils_mock.auth_utils.get_current_user_id_from_token = MagicMock( + return_value="test_user_id") # 将模拟的utils模块添加到sys.modules中 sys.modules['utils'] = utils_mock @@ -60,20 +75,7 @@ sys.modules['backend.database.db_models'] = db_models_mock # 现在可以安全地导入被测试的模块 -from backend.database.agent_db import ( - search_agent_info_by_agent_id, - search_agent_id_by_agent_name, - search_blank_sub_agent_by_main_agent_id, - query_sub_agents_id_list, - create_agent, - update_agent, - delete_agent_by_id, - query_all_agent_info_by_tenant_id, - insert_related_agent, - delete_related_agent, - delete_agent_relationship, - update_related_agents -) + class MockAgent: def __init__(self): @@ -86,10 +88,12 @@ def __init__(self): self.business_logic_model_id = None self.business_logic_model_name = None + class MockAgentRelation: def __init__(self): self.selected_agent_id = 2 + @pytest.fixture def mock_session(): """创建模拟的数据库会话""" @@ -98,29 +102,33 @@ def mock_session(): mock_session.query.return_value = mock_query return mock_session, mock_query + def test_search_agent_info_by_agent_id_success(monkeypatch, mock_session): """测试成功搜索agent信息""" session, query = mock_session mock_agent = MockAgent() - + mock_first = MagicMock() mock_first.return_value = mock_agent mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.agent_db.as_dict", lambda obj: obj.__dict__) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.agent_db.as_dict", + lambda obj: obj.__dict__) + result = search_agent_info_by_agent_id(1, "tenant1") - + assert result["agent_id"] == 1 assert result["name"] == "test_agent" assert result["tenant_id"] == "tenant1" + def test_search_agent_info_by_agent_id_not_found(monkeypatch, mock_session): """测试搜索不存在的agent""" session, query = mock_session @@ -129,35 +137,39 @@ def test_search_agent_info_by_agent_id_not_found(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + with pytest.raises(ValueError, match="agent not found"): search_agent_info_by_agent_id(999, "tenant1") + def test_search_agent_id_by_agent_name_success(monkeypatch, mock_session): """测试成功通过agent名称搜索agent ID""" session, query = mock_session mock_agent = MockAgent() - + mock_first = MagicMock() mock_first.return_value = mock_agent mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + result = search_agent_id_by_agent_name("test_agent", "tenant1") - + assert result == 1 + def test_search_agent_id_by_agent_name_not_found(monkeypatch, mock_session): """测试通过不存在的agent名称搜索""" session, query = mock_session @@ -166,36 +178,40 @@ def test_search_agent_id_by_agent_name_not_found(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + with pytest.raises(ValueError, match="agent not found"): search_agent_id_by_agent_name("nonexistent_agent", "tenant1") + def test_search_blank_sub_agent_by_main_agent_id_found(monkeypatch, mock_session): """测试成功搜索空白子agent""" session, query = mock_session mock_agent = MockAgent() mock_agent.enabled = False - + mock_first = MagicMock() mock_first.return_value = mock_agent mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + result = search_blank_sub_agent_by_main_agent_id("tenant1") - + assert result == 1 + def test_search_blank_sub_agent_by_main_agent_id_not_found(monkeypatch, mock_session): """测试搜索不到空白子agent""" session, query = mock_session @@ -204,83 +220,96 @@ def test_search_blank_sub_agent_by_main_agent_id_not_found(monkeypatch, mock_ses mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + result = search_blank_sub_agent_by_main_agent_id("tenant1") - + assert result is None + def test_query_sub_agents_id_list(monkeypatch, mock_session): """测试查询子agent ID列表""" session, query = mock_session mock_relation = MockAgentRelation() - + mock_all = MagicMock() mock_all.return_value = [mock_relation] mock_filter = MagicMock() mock_filter.all = mock_all query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + result = query_sub_agents_id_list(1, "tenant1") - + assert result == [2] + def test_create_agent_success(monkeypatch, mock_session): """测试成功创建agent""" session, query = mock_session session.add = MagicMock() session.flush = MagicMock() - + mock_agent = MockAgent() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.agent_db.filter_property", lambda data, model: data) - monkeypatch.setattr("backend.database.agent_db.as_dict", lambda obj: obj.__dict__) - monkeypatch.setattr("backend.database.agent_db.AgentInfo", lambda **kwargs: mock_agent) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.agent_db.filter_property", lambda data, model: data) + monkeypatch.setattr("backend.database.agent_db.as_dict", + lambda obj: obj.__dict__) + monkeypatch.setattr("backend.database.agent_db.AgentInfo", + lambda **kwargs: mock_agent) + agent_info = {"name": "new_agent", "description": "test description"} result = create_agent(agent_info, "tenant1", "user1") - + assert result["agent_id"] == 1 session.add.assert_called_once() session.flush.assert_called_once() + def test_update_agent_success(monkeypatch, mock_session): """测试成功更新agent""" session, query = mock_session mock_agent = MockAgent() - + mock_first = MagicMock() mock_first.return_value = mock_agent mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.agent_db.filter_property", lambda data, model: data) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.agent_db.filter_property", lambda data, model: data) + agent_info = MagicMock() - agent_info.__dict__ = {"name": "updated_agent", "description": "updated description"} - + agent_info.__dict__ = {"name": "updated_agent", + "description": "updated description"} + update_agent(1, agent_info, "tenant1", "user1") - + assert mock_agent.updated_by == "user1" + def test_update_agent_not_found(monkeypatch, mock_session): """测试更新不存在的agent""" session, query = mock_session @@ -289,18 +318,20 @@ def test_update_agent_not_found(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.first = mock_first query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + agent_info = MagicMock() agent_info.__dict__ = {"name": "updated_agent"} - + with pytest.raises(ValueError, match="ag_tenant_agent_t Agent not found"): update_agent(999, agent_info, "tenant1", "user1") + def test_delete_agent_by_id_success(monkeypatch, mock_session): """测试成功删除agent""" session, query = mock_session @@ -308,22 +339,24 @@ def test_delete_agent_by_id_success(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.update = mock_update query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + delete_agent_by_id(1, "tenant1", "user1") - + # 验证调用了两次update(一次更新AgentInfo,一次更新ToolInstance) assert mock_update.call_count == 2 + def test_query_all_agent_info_by_tenant_id(monkeypatch, mock_session): """测试查询所有agent信息""" session, query = mock_session mock_agent = MockAgent() - + mock_all = MagicMock() mock_all.return_value = [mock_agent] mock_order_by = MagicMock() @@ -331,53 +364,64 @@ def test_query_all_agent_info_by_tenant_id(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.order_by.return_value = mock_order_by query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.agent_db.as_dict", lambda obj: obj.__dict__) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.agent_db.as_dict", + lambda obj: obj.__dict__) + result = query_all_agent_info_by_tenant_id("tenant1") - + assert len(result) == 1 assert result[0]["agent_id"] == 1 + def test_insert_related_agent_success(monkeypatch, mock_session): """测试成功插入相关agent""" session, query = mock_session session.add = MagicMock() session.flush = MagicMock() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.agent_db.filter_property", lambda data, model: data) - monkeypatch.setattr("backend.database.agent_db.AgentRelation", lambda **kwargs: MagicMock()) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.agent_db.filter_property", lambda data, model: data) + monkeypatch.setattr( + "backend.database.agent_db.AgentRelation", lambda **kwargs: MagicMock()) + result = insert_related_agent(1, 2, "tenant1") - + assert result is True session.add.assert_called_once() session.flush.assert_called_once() + def test_insert_related_agent_failure(monkeypatch, mock_session): """测试插入相关agent失败""" session, query = mock_session session.add = MagicMock(side_effect=Exception("Database error")) - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.agent_db.filter_property", lambda data, model: data) - monkeypatch.setattr("backend.database.agent_db.AgentRelation", lambda **kwargs: MagicMock()) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.agent_db.filter_property", lambda data, model: data) + monkeypatch.setattr( + "backend.database.agent_db.AgentRelation", lambda **kwargs: MagicMock()) + result = insert_related_agent(1, 2, "tenant1") - + assert result is False + def test_delete_related_agent_success(monkeypatch, mock_session): """测试成功删除相关agent""" session, query = mock_session @@ -385,17 +429,19 @@ def test_delete_related_agent_success(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.update = mock_update query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + result = delete_related_agent(1, 2, "tenant1") - + assert result is True mock_update.assert_called_once() + def test_delete_related_agent_failure(monkeypatch, mock_session): """测试删除相关agent失败""" session, query = mock_session @@ -403,16 +449,18 @@ def test_delete_related_agent_failure(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.update = mock_update query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + result = delete_related_agent(1, 2, "tenant1") - + assert result is False + def test_delete_agent_relationship_success(monkeypatch, mock_session): """测试成功删除agent关系""" session, query = mock_session @@ -420,18 +468,20 @@ def test_delete_agent_relationship_success(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.update = mock_update query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + # 函数不返回任何值,只验证执行成功 delete_agent_relationship(1, "tenant1", "user1") - + # 验证调用了两次update(一次删除父关系,一次删除子关系) assert mock_update.call_count == 2 + def test_delete_agent_relationship_failure(monkeypatch, mock_session): """测试删除agent关系失败""" session, query = mock_session @@ -439,12 +489,13 @@ def test_delete_agent_relationship_failure(monkeypatch, mock_session): mock_filter = MagicMock() mock_filter.update = mock_update query.filter.return_value = mock_filter - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + # 函数应该抛出异常,因为数据库操作失败 with pytest.raises(Exception, match="Database error"): delete_agent_relationship(1, "tenant1", "user1") @@ -453,34 +504,36 @@ def test_delete_agent_relationship_failure(monkeypatch, mock_session): def test_update_related_agents_add_new(monkeypatch, mock_session): """测试更新相关agent - 添加新关系""" session, query = mock_session - + # Mock current relations (empty initially) mock_all = MagicMock() mock_all.return_value = [] # No existing relations - + # Mock for querying current relations mock_filter1 = MagicMock() mock_filter1.all = mock_all - + # Mock for update (soft delete) - should not be called since no deletions mock_update = MagicMock() mock_filter2 = MagicMock() mock_filter2.update = mock_update - + # Setup filter chain: first call returns filter1 (for query) # If update is called, it would return filter2, but it shouldn't be called query.filter.return_value = mock_filter1 - + # Mock for adding new relations session.add = MagicMock() session.commit = MagicMock() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.agent_db.filter_property", lambda data, model: data) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.agent_db.filter_property", lambda data, model: data) + # Create a Mock class for AgentRelation that supports both class attribute access and instantiation # The class attributes need to support comparison operations (==, !=, .in_()) for SQLAlchemy queries class MockAgentRelationClass: @@ -488,16 +541,17 @@ class MockAgentRelationClass: tenant_id = MagicMock() delete_flag = MagicMock() selected_agent_id = MagicMock() - + def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) - - monkeypatch.setattr("backend.database.agent_db.AgentRelation", MockAgentRelationClass) - + + monkeypatch.setattr( + "backend.database.agent_db.AgentRelation", MockAgentRelationClass) + # Execute - add new relations [2, 3] update_related_agents(1, [2, 3], "tenant1", "user1") - + # Verify: should add 2 new relations, no deletions assert session.add.call_count == 2 session.commit.assert_called_once() @@ -508,39 +562,40 @@ def __init__(self, **kwargs): def test_update_related_agents_delete_existing(monkeypatch, mock_session): """测试更新相关agent - 删除现有关系""" session, query = mock_session - + # Mock existing relations mock_relation1 = MockAgentRelation() mock_relation1.selected_agent_id = 2 mock_relation2 = MockAgentRelation() mock_relation2.selected_agent_id = 3 - + mock_all = MagicMock() mock_all.return_value = [mock_relation1, mock_relation2] - + # Mock for querying current relations mock_filter1 = MagicMock() mock_filter1.all = mock_all - + # Mock for update (soft delete) mock_update = MagicMock() mock_filter2 = MagicMock() mock_filter2.update = mock_update - + # Setup filter chain: first call returns filter1 (for query), subsequent calls return filter2 (for update) query.filter.side_effect = [mock_filter1, mock_filter2] - + session.add = MagicMock() session.commit = MagicMock() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + # Execute - remove all relations (empty list) update_related_agents(1, [], "tenant1", "user1") - + # Verify: should soft delete 2 relations, add none mock_update.assert_called_once() session.add.assert_not_called() @@ -550,37 +605,39 @@ def test_update_related_agents_delete_existing(monkeypatch, mock_session): def test_update_related_agents_replace_mixed(monkeypatch, mock_session): """测试更新相关agent - 混合添加和删除""" session, query = mock_session - + # Mock existing relations [2, 3] mock_relation1 = MockAgentRelation() mock_relation1.selected_agent_id = 2 mock_relation2 = MockAgentRelation() mock_relation2.selected_agent_id = 3 - + mock_all = MagicMock() mock_all.return_value = [mock_relation1, mock_relation2] - + # Mock for querying current relations mock_filter1 = MagicMock() mock_filter1.all = mock_all - + # Mock for update (soft delete) - will be called to delete 2 mock_update = MagicMock() mock_filter2 = MagicMock() mock_filter2.update = mock_update - + # Setup filter chain: first call returns filter1 (for query), subsequent calls return filter2 (for update) query.filter.side_effect = [mock_filter1, mock_filter2] - + session.add = MagicMock() session.commit = MagicMock() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - monkeypatch.setattr("backend.database.agent_db.filter_property", lambda data, model: data) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr( + "backend.database.agent_db.filter_property", lambda data, model: data) + # Create a Mock class for AgentRelation that supports both class attribute access and instantiation # The class attributes need to support comparison operations (==, !=, .in_()) for SQLAlchemy queries class MockAgentRelationClass: @@ -588,16 +645,17 @@ class MockAgentRelationClass: tenant_id = MagicMock() delete_flag = MagicMock() selected_agent_id = MagicMock() - + def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) - - monkeypatch.setattr("backend.database.agent_db.AgentRelation", MockAgentRelationClass) - + + monkeypatch.setattr( + "backend.database.agent_db.AgentRelation", MockAgentRelationClass) + # Execute - replace [2, 3] with [3, 4] (delete 2, add 4) update_related_agents(1, [3, 4], "tenant1", "user1") - + # Verify: should delete 2 (relation with selected_agent_id=2), add 4 mock_update.assert_called_once() assert session.add.call_count == 1 @@ -607,32 +665,33 @@ def __init__(self, **kwargs): def test_update_related_agents_no_changes(monkeypatch, mock_session): """测试更新相关agent - 无变化""" session, query = mock_session - + # Mock existing relations [2, 3] mock_relation1 = MockAgentRelation() mock_relation1.selected_agent_id = 2 mock_relation2 = MockAgentRelation() mock_relation2.selected_agent_id = 3 - + mock_all = MagicMock() mock_all.return_value = [mock_relation1, mock_relation2] - + # Mock for querying current relations mock_filter1 = MagicMock() mock_filter1.all = mock_all query.filter.return_value = mock_filter1 - + session.add = MagicMock() session.commit = MagicMock() - + mock_ctx = MagicMock() mock_ctx.__enter__.return_value = session mock_ctx.__exit__.return_value = None - monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) - + monkeypatch.setattr( + "backend.database.agent_db.get_db_session", lambda: mock_ctx) + # Execute - same relations [2, 3] update_related_agents(1, [2, 3], "tenant1", "user1") - + # Verify: no deletions, no additions session.add.assert_not_called() - session.commit.assert_called_once() \ No newline at end of file + session.commit.assert_called_once() diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 045d79b84..4a90d2183 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -1867,9 +1867,12 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector # Mock signature for knowledge_base_search tool mock_sig = Mock() + mock_index_names_param = Mock() + mock_index_names_param.default = ["default_index"] + mock_sig.parameters = { 'self': Mock(), - 'index_names': Mock(), + 'index_names': mock_index_names_param, 'vdb_core': Mock(), 'embedding_model': Mock() } @@ -1905,8 +1908,7 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector # Verify knowledge base specific parameters were passed expected_params = { "param": "config", - "index_names": ["index1", "index2"], - "name_resolver": {"index1": "index1", "alias2": "index2"}, + "index_names": ["default_index"], "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", } @@ -1914,10 +1916,6 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector mock_tool_instance.forward.assert_called_once_with(query="test query") # Verify service calls - mock_get_knowledge_list.assert_called_once_with( - tenant_id="tenant1", user_id="user1") - mock_build_mapping.assert_called_once_with( - tenant_id="tenant1", user_id="user1") mock_get_embedding_model.assert_called_once_with(tenant_id="tenant1") @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @@ -1997,9 +1995,11 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo # Mock signature for knowledge_base_search tool mock_sig = Mock() + mock_index_names_param = Mock() + mock_index_names_param.default = [] mock_sig.parameters = { 'self': Mock(), - 'index_names': Mock(), + 'index_names': mock_index_names_param, 'vdb_core': Mock(), 'embedding_model': Mock() } @@ -2028,79 +2028,12 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo expected_params = { "param": "config", "index_names": [], - "name_resolver": {}, "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") - @patch('backend.services.tool_configuration_service._get_tool_class_by_name') - @patch('backend.services.tool_configuration_service.inspect.signature') - @patch('backend.services.tool_configuration_service.get_selected_knowledge_list') - @patch('backend.services.tool_configuration_service.build_knowledge_name_mapping') - @patch('backend.services.tool_configuration_service.get_embedding_model') - @patch('backend.services.tool_configuration_service.get_vector_db_core') - @patch('backend.services.tool_configuration_service.get_index_name_by_knowledge_name') - def test_validate_local_tool_knowledge_base_search_resolves_inputs_indices(self, - mock_get_index_name, - mock_get_vector_db_core, - mock_get_embedding_model, - mock_build_mapping, - mock_get_knowledge_list, - mock_signature, - mock_get_class): - """Resolve index_names from user input when no stored selections exist.""" - mock_tool_class = Mock() - mock_tool_instance = Mock() - mock_tool_instance.forward.return_value = "resolved result" - mock_tool_class.return_value = mock_tool_instance - mock_get_class.return_value = mock_tool_class - - mock_sig = Mock() - mock_sig.parameters = { - 'self': Mock(), - 'index_names': Mock(), - 'vdb_core': Mock(), - 'embedding_model': Mock() - } - mock_signature.return_value = mock_sig - - mock_get_knowledge_list.return_value = [] # No stored selections - mock_build_mapping.return_value = {"existing": "existing_index"} - mock_get_embedding_model.return_value = "mock_embedding" - mock_vdb_core = Mock() - mock_get_vector_db_core.return_value = mock_vdb_core - - # First alias resolves; second keeps raw value on exception - mock_get_index_name.side_effect = [ - "resolved_index", Exception("not found")] - - from backend.services.tool_configuration_service import _validate_local_tool - - result = _validate_local_tool( - "knowledge_base_search", - {"query": "q", "index_names": ["alias1", "raw_index"]}, - {"param": "config"}, - "tenant1", - "user1" - ) - - assert result == "resolved result" - expected_params = { - "param": "config", - "index_names": ["resolved_index", "raw_index"], - "name_resolver": {"existing": "existing_index", "alias1": "resolved_index"}, - "vdb_core": mock_vdb_core, - "embedding_model": "mock_embedding", - } - mock_tool_class.assert_called_once_with(**expected_params) - mock_tool_instance.forward.assert_called_once_with( - query="q", index_names=["alias1", "raw_index"] - ) - assert mock_get_index_name.call_count == 2 - mock_get_index_name.assert_any_call("alias1", tenant_id="tenant1") - mock_get_index_name.assert_any_call("raw_index", tenant_id="tenant1") @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @@ -2126,9 +2059,11 @@ def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_ge # Mock signature for knowledge_base_search tool mock_sig = Mock() + mock_index_names_param = Mock() + mock_index_names_param.default = ["default_index"] mock_sig.parameters = { 'self': Mock(), - 'index_names': Mock(), + 'index_names': mock_index_names_param, 'vdb_core': Mock(), 'embedding_model': Mock() } @@ -2244,23 +2179,18 @@ def test_validate_local_tool_datamate_search_tool_success(self, mock_get_knowled mock_get_class.return_value = mock_tool_class # Mock signature for datamate_search_tool + # _validate_local_tool fills missing instantiation params from signature defaults. + # For datamate_search there is no special index selection logic, so index_names + # should come from the default value (empty list). mock_sig = Mock() mock_sig.parameters = { 'self': Mock(), - 'index_names': Mock(), + 'index_names': Mock(default=Mock(default=[])), } mock_signature.return_value = mock_sig - # Mock knowledge base dependencies - only datamate sources - mock_knowledge_list = [ - {"index_name": "datamate_index1", "knowledge_id": "kb1", - "knowledge_sources": "datamate"}, - {"index_name": "datamate_index2", "knowledge_id": "kb2", - "knowledge_sources": "datamate"}, - {"index_name": "other_index", "knowledge_id": "kb3", - "knowledge_sources": "other"} # Should be filtered out - ] - mock_get_knowledge_list.return_value = mock_knowledge_list + # datamate_search has no special knowledge-list dependency in _validate_local_tool + mock_get_knowledge_list.return_value = [] from backend.services.tool_configuration_service import _validate_local_tool @@ -2278,69 +2208,76 @@ def test_validate_local_tool_datamate_search_tool_success(self, mock_get_knowled # Verify datamate_search_tool specific parameters were passed expected_params = { "param": "config", - # Only datamate sources - "index_names": ["datamate_index1", "datamate_index2"], + # Filled from signature default + "index_names": [], } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") - # Verify service calls - mock_get_knowledge_list.assert_called_once_with( - tenant_id="tenant1", user_id="user1") + mock_get_knowledge_list.assert_not_called() @patch('backend.services.tool_configuration_service._get_tool_class_by_name') def test_validate_local_tool_datamate_search_tool_missing_tenant_id(self, mock_get_class): """Test datamate_search_tool validation when tenant_id is missing""" mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = "datamate search result" + mock_tool_class.return_value = mock_tool_instance mock_get_class.return_value = mock_tool_class from backend.services.tool_configuration_service import _validate_local_tool - with pytest.raises(ToolExecutionException, - match=r"Local tool datamate_search validation failed: Tenant ID and User ID are required for datamate_search validation"): - _validate_local_tool( - "datamate_search", - {"query": "test query"}, - {"param": "config"}, - None, # Missing tenant_id - "user1" - ) + # datamate_search does not require tenant/user in current implementation + result = _validate_local_tool( + "datamate_search", + {"query": "test query"}, + {"param": "config"}, + None, # Missing tenant_id + "user1" + ) + assert result == "datamate search result" @patch('backend.services.tool_configuration_service._get_tool_class_by_name') def test_validate_local_tool_datamate_search_tool_missing_user_id(self, mock_get_class): """Test datamate_search_tool validation when user_id is missing""" mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = "datamate search result" + mock_tool_class.return_value = mock_tool_instance mock_get_class.return_value = mock_tool_class from backend.services.tool_configuration_service import _validate_local_tool - with pytest.raises(ToolExecutionException, - match=r"Local tool datamate_search validation failed: Tenant ID and User ID are required for datamate_search validation"): - _validate_local_tool( - "datamate_search", - {"query": "test query"}, - {"param": "config"}, - "tenant1", - None # Missing user_id - ) + # datamate_search does not require tenant/user in current implementation + result = _validate_local_tool( + "datamate_search", + {"query": "test query"}, + {"param": "config"}, + "tenant1", + None # Missing user_id + ) + assert result == "datamate search result" @patch('backend.services.tool_configuration_service._get_tool_class_by_name') def test_validate_local_tool_datamate_search_tool_missing_both_ids(self, mock_get_class): """Test datamate_search_tool validation when both tenant_id and user_id are missing""" mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = "datamate search result" + mock_tool_class.return_value = mock_tool_instance mock_get_class.return_value = mock_tool_class from backend.services.tool_configuration_service import _validate_local_tool - with pytest.raises(ToolExecutionException, - match=r"Local tool datamate_search validation failed: Tenant ID and User ID are required for datamate_search validation"): - _validate_local_tool( - "datamate_search", - {"query": "test query"}, - {"param": "config"}, - None, # Missing tenant_id - None # Missing user_id - ) + # datamate_search does not require tenant/user in current implementation + result = _validate_local_tool( + "datamate_search", + {"query": "test query"}, + {"param": "config"}, + None, # Missing tenant_id + None # Missing user_id + ) + assert result == "datamate search result" @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @@ -2356,11 +2293,11 @@ def test_validate_local_tool_datamate_search_tool_empty_knowledge_list(self, moc mock_get_class.return_value = mock_tool_class - # Mock signature for datamate_search_tool + # Mock signature for datamate_search_tool (default empty list) mock_sig = Mock() mock_sig.parameters = { 'self': Mock(), - 'index_names': Mock(), + 'index_names': Mock(default=Mock(default=[])), } mock_signature.return_value = mock_sig @@ -2386,6 +2323,7 @@ def test_validate_local_tool_datamate_search_tool_empty_knowledge_list(self, moc } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") + mock_get_knowledge_list.assert_not_called() @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @@ -2401,11 +2339,11 @@ def test_validate_local_tool_datamate_search_tool_no_datamate_sources(self, mock mock_get_class.return_value = mock_tool_class - # Mock signature for datamate_search_tool + # Mock signature for datamate_search_tool (default empty list) mock_sig = Mock() mock_sig.parameters = { 'self': Mock(), - 'index_names': Mock(), + 'index_names': Mock(default=Mock(default=[])), } mock_signature.return_value = mock_sig @@ -2437,6 +2375,7 @@ def test_validate_local_tool_datamate_search_tool_no_datamate_sources(self, mock } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") + mock_get_knowledge_list.assert_not_called() @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 44edb6e58..075cfbcff 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -746,7 +746,6 @@ def test_create_local_tool_knowledge_base_search_tool_success(nexent_agent_insta # Verify excluded parameters were set directly as attributes after instantiation assert result == mock_kb_tool_instance assert mock_kb_tool_instance.observer == nexent_agent_instance.observer - assert mock_kb_tool_instance.index_names == ["index1", "index2"] assert mock_kb_tool_instance.vdb_core == mock_vdb_core assert mock_kb_tool_instance.embedding_model == mock_embedding_model @@ -797,11 +796,11 @@ def test_create_local_tool_knowledge_base_search_tool_with_conflicting_params(ne # Only non-excluded params should be passed to __init__ due to smolagents wrapper restrictions mock_kb_tool_class.assert_called_once_with( top_k=10, # From filtered_params (not in conflict list) + index_names=["conflicting_index"], # Not excluded by current implementation ) # Verify excluded parameters were set directly as attributes after instantiation assert result == mock_kb_tool_instance assert mock_kb_tool_instance.observer == nexent_agent_instance.observer - assert mock_kb_tool_instance.index_names == ["index1", "index2"] # From metadata, not params assert mock_kb_tool_instance.vdb_core == mock_vdb_core # From metadata, not params assert mock_kb_tool_instance.embedding_model == mock_embedding_model # From metadata, not params @@ -842,7 +841,6 @@ def test_create_local_tool_knowledge_base_search_tool_with_none_defaults(nexent_ # Verify excluded parameters were set directly as attributes with None defaults when metadata is missing assert result == mock_kb_tool_instance assert mock_kb_tool_instance.observer == nexent_agent_instance.observer - assert mock_kb_tool_instance.index_names == [] # Empty list when None assert mock_kb_tool_instance.vdb_core is None assert mock_kb_tool_instance.embedding_model is None assert result == mock_kb_tool_instance @@ -1367,7 +1365,6 @@ def test_create_local_tool_datamate_search_tool_success(nexent_agent_instance): # Verify excluded parameters were set directly as attributes after instantiation assert result == mock_datamate_tool_instance assert mock_datamate_tool_instance.observer == nexent_agent_instance.observer - assert mock_datamate_tool_instance.index_names == ["datamate_index1", "datamate_index2"] @@ -1407,7 +1404,6 @@ def test_create_local_tool_datamate_search_tool_with_none_defaults(nexent_agent_ # Verify excluded parameters were set directly as attributes with None defaults when metadata is missing assert result == mock_datamate_tool_instance assert mock_datamate_tool_instance.observer == nexent_agent_instance.observer - assert mock_datamate_tool_instance.index_names == [] # Empty list when None def test_create_local_tool_datamate_search_tool_success(nexent_agent_instance): @@ -1448,7 +1444,6 @@ def test_create_local_tool_datamate_search_tool_success(nexent_agent_instance): # Verify excluded parameters were set directly as attributes after instantiation assert result == mock_datamate_tool_instance assert mock_datamate_tool_instance.observer == nexent_agent_instance.observer - assert mock_datamate_tool_instance.index_names == ["datamate_index1", "datamate_index2"] def test_create_local_tool_datamate_search_tool_with_none_defaults(nexent_agent_instance): @@ -1487,7 +1482,6 @@ def test_create_local_tool_datamate_search_tool_with_none_defaults(nexent_agent_ # Verify excluded parameters were set directly as attributes with None defaults when metadata is missing assert result == mock_datamate_tool_instance assert mock_datamate_tool_instance.observer == nexent_agent_instance.observer - assert mock_datamate_tool_instance.index_names == [] # Empty list when None if __name__ == "__main__": diff --git a/test/sdk/core/tools/test_datamate_search_tool.py b/test/sdk/core/tools/test_datamate_search_tool.py index 71483e1e8..69d038572 100644 --- a/test/sdk/core/tools/test_datamate_search_tool.py +++ b/test/sdk/core/tools/test_datamate_search_tool.py @@ -5,10 +5,8 @@ import pytest from pytest_mock import MockFixture -from sdk.nexent.core.tools.datamate_search_tool import DataMateSearchTool, _normalize_index_names +from sdk.nexent.core.tools.datamate_search_tool import DataMateSearchTool from sdk.nexent.core.utils.observer import MessageObserver, ProcessType -from sdk.nexent.datamate.datamate_client import DataMateClient - @pytest.fixture def mock_observer() -> MessageObserver: @@ -22,6 +20,9 @@ def datamate_tool(mock_observer: MessageObserver) -> DataMateSearchTool: tool = DataMateSearchTool( server_url="http://127.0.0.1:8080", observer=mock_observer, + index_names=["kb1"], + top_k=2, + threshold=0.5, ) return tool @@ -151,22 +152,6 @@ def test_extract_dataset_id(self, datamate_tool: DataMateSearchTool, path, expec assert datamate_tool._extract_dataset_id(path) == expected -class TestNormalizeIndexNames: - @pytest.mark.parametrize( - "input_names, expected", - [ - (None, []), - ("single_kb", ["single_kb"]), - (["kb1", "kb2"], ["kb1", "kb2"]), - ([], []), - ("", [""]), # Edge case: empty string becomes list with empty string - ], - ) - def test_normalize_index_names(self, input_names, expected): - result = _normalize_index_names(input_names) - assert result == expected - - class TestForward: def test_forward_success_with_observer_en(self, datamate_tool: DataMateSearchTool, mocker: MockFixture): # Mock the hybrid_search method to return search results @@ -179,8 +164,7 @@ def test_forward_success_with_observer_en(self, datamate_tool: DataMateSearchToo datamate_tool.datamate_core.client, 'build_file_download_url') mock_build_url.side_effect = lambda ds, fid: f"http://dl/{ds}/{fid}" - result_json = datamate_tool.forward("test query", index_names=[ - "kb1"], top_k=2, threshold=0.5) + result_json = datamate_tool.forward("test query") results = json.loads(result_json) assert len(results) == 2 @@ -216,7 +200,7 @@ def test_forward_success_with_observer_zh(self, datamate_tool: DataMateSearchToo datamate_tool.datamate_core.client, 'build_file_download_url') mock_build_url.return_value = "http://dl/kb1/file-1" - datamate_tool.forward("测试查询", index_names=["kb1"]) + datamate_tool.forward("测试查询") datamate_tool.observer.add_message.assert_any_call( "", ProcessType.TOOL, datamate_tool.running_prompt_zh) @@ -235,7 +219,8 @@ def test_forward_no_observer(self, mocker: MockFixture): tool.datamate_core.client, 'build_file_download_url') mock_build_url.return_value = "http://dl/kb1/file-1" - result_json = tool.forward("query", index_names=["kb1"]) + tool.index_names = ["kb1"] + result_json = tool.forward("query") assert len(json.loads(result_json)) == 1 def test_forward_no_knowledge_bases(self, datamate_tool: DataMateSearchTool, mocker: MockFixture): @@ -243,7 +228,8 @@ def test_forward_no_knowledge_bases(self, datamate_tool: DataMateSearchTool, moc mock_hybrid_search = mocker.patch.object( datamate_tool.datamate_core, 'hybrid_search') - result = datamate_tool.forward("query", index_names=[]) + datamate_tool.index_names = [] + result = datamate_tool.forward("query") assert result == json.dumps( "No knowledge base selected. No relevant information found.", ensure_ascii=False) mock_hybrid_search.assert_not_called() @@ -255,7 +241,7 @@ def test_forward_no_results(self, datamate_tool: DataMateSearchTool, mocker: Moc mock_hybrid_search.return_value = [] with pytest.raises(Exception) as excinfo: - datamate_tool.forward("query", index_names=["kb1"]) + datamate_tool.forward("query") assert "No results found! Try a less restrictive/shorter query." in str( excinfo.value) @@ -267,7 +253,7 @@ def test_forward_wrapped_error(self, datamate_tool: DataMateSearchTool, mocker: mock_hybrid_search.side_effect = RuntimeError("low level error") with pytest.raises(Exception) as excinfo: - datamate_tool.forward("query", index_names=["kb1"]) + datamate_tool.forward("query") msg = str(excinfo.value) assert "Error during DataMate knowledge base search" in msg @@ -301,14 +287,14 @@ def test_forward_with_default_index_names(self, datamate_tool: DataMateSearchToo mock_hybrid_search.assert_any_call( query_text="query", index_names=["default_kb1"], - top_k=3, - weight_accurate=0.2 + top_k=2, + weight_accurate=0.5 ) mock_hybrid_search.assert_any_call( query_text="query", index_names=["default_kb2"], - top_k=3, - weight_accurate=0.2 + top_k=2, + weight_accurate=0.5 ) def test_forward_multiple_knowledge_bases(self, datamate_tool: DataMateSearchTool, mocker: MockFixture): @@ -328,8 +314,8 @@ def test_forward_multiple_knowledge_bases(self, datamate_tool: DataMateSearchToo datamate_tool.datamate_core.client, 'build_file_download_url') mock_build_url.side_effect = lambda ds, fid: f"http://dl/{ds}/{fid}" - result_json = datamate_tool.forward( - "query", index_names=["kb1", "kb2"]) + datamate_tool.index_names = ["kb1", "kb2"] + result_json = datamate_tool.forward("query") results = json.loads(result_json) assert len(results) == 3 # 1 from kb1 + 2 from kb2 @@ -339,14 +325,14 @@ def test_forward_multiple_knowledge_bases(self, datamate_tool: DataMateSearchToo mock_hybrid_search.assert_any_call( query_text="query", index_names=["kb1"], - top_k=3, - weight_accurate=0.2 + top_k=2, + weight_accurate=0.5 ) mock_hybrid_search.assert_any_call( query_text="query", index_names=["kb2"], - top_k=3, - weight_accurate=0.2 + top_k=2, + weight_accurate=0.5 ) def test_forward_with_custom_parameters(self, datamate_tool: DataMateSearchTool, mocker: MockFixture): @@ -361,11 +347,11 @@ def test_forward_with_custom_parameters(self, datamate_tool: DataMateSearchTool, datamate_tool.datamate_core.client, 'build_file_download_url') mock_build_url.return_value = "http://dl/kb1/file-1" + datamate_tool.index_names = ["kb1"] + datamate_tool.top_k = 5 + datamate_tool.threshold = 0.8 result_json = datamate_tool.forward( query="custom query", - index_names=["kb1"], - top_k=5, - threshold=0.8, kb_page=2, kb_page_size=50 ) @@ -432,7 +418,8 @@ def test_forward_metadata_parsing_edge_cases(self, datamate_tool: DataMateSearch datamate_tool.datamate_core.client, 'build_file_download_url') mock_build_url.return_value = "http://dl/kb1/file" - result_json = datamate_tool.forward("query", index_names=["kb1"]) + datamate_tool.index_names = ["kb1"] + result_json = datamate_tool.forward("query") results = json.loads(result_json) assert len(results) == 3 diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index f6cdc4577..082880782 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -35,10 +35,10 @@ def knowledge_base_search_tool(mock_observer, mock_vdb_core, mock_embedding_mode tool = KnowledgeBaseSearchTool( top_k=5, index_names=["test_index1", "test_index2"], + search_mode="hybrid", observer=mock_observer, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core, - name_resolver={} + vdb_core=mock_vdb_core ) return tool @@ -49,10 +49,10 @@ def knowledge_base_search_tool_no_observer(mock_vdb_core, mock_embedding_model): tool = KnowledgeBaseSearchTool( top_k=3, index_names=["test_index"], + search_mode="hybrid", observer=None, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core, - name_resolver={} + vdb_core=mock_vdb_core ) return tool @@ -80,35 +80,6 @@ def create_mock_search_result(count=3): class TestKnowledgeBaseSearchTool: """Test KnowledgeBaseSearchTool functionality""" - def test_update_name_resolver_supports_empty_mapping(self, knowledge_base_search_tool): - """Ensure update_name_resolver replaces mapping and handles falsy input""" - knowledge_base_search_tool.update_name_resolver({"kb": "index_kb"}) - assert knowledge_base_search_tool.name_resolver == {"kb": "index_kb"} - - knowledge_base_search_tool.update_name_resolver(None) - assert knowledge_base_search_tool.name_resolver == {} - - def test_resolve_names_without_resolver_logs_warning(self, knowledge_base_search_tool, mocker): - """When no resolver is configured, names are returned unchanged and warning is logged""" - warning_mock = mocker.patch("sdk.nexent.core.tools.knowledge_base_search_tool.logger.warning") - - names = knowledge_base_search_tool._resolve_names(["kb1", "kb2"]) - - assert names == ["kb1", "kb2"] - warning_mock.assert_called_once() - - @pytest.mark.parametrize( - "incoming,expected", - [ - (None, []), - ("single_index", ["single_index"]), - (["a", "b"], ["a", "b"]), - ], - ) - def test_normalize_index_names_variants(self, knowledge_base_search_tool_no_observer, incoming, expected): - """_normalize_index_names should normalize None, string, and list inputs""" - assert knowledge_base_search_tool_no_observer._normalize_index_names(incoming) == expected - def test_forward_with_observer_adds_messages(self, knowledge_base_search_tool): """forward should send TOOL and CARD messages when observer is present""" mock_results = create_mock_search_result(1) @@ -128,10 +99,10 @@ def test_init_with_custom_values(self, mock_observer, mock_vdb_core, mock_embedd tool = KnowledgeBaseSearchTool( top_k=10, index_names=["index1", "index2", "index3"], + search_mode="hybrid", observer=mock_observer, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core, - name_resolver={} + vdb_core=mock_vdb_core ) assert tool.top_k == 10 @@ -145,10 +116,10 @@ def test_init_with_none_index_names(self, mock_vdb_core, mock_embedding_model): tool = KnowledgeBaseSearchTool( top_k=5, index_names=None, + search_mode="hybrid", observer=None, embedding_model=mock_embedding_model, - vdb_core=mock_vdb_core, - name_resolver={} + vdb_core=mock_vdb_core ) assert tool.index_names == [] @@ -235,7 +206,8 @@ def test_forward_accurate_mode_success(self, knowledge_base_search_tool): mock_results = create_mock_search_result(2) knowledge_base_search_tool.vdb_core.accurate_search.return_value = mock_results - result = knowledge_base_search_tool.forward("test query", search_mode="accurate") + knowledge_base_search_tool.search_mode = "accurate" + result = knowledge_base_search_tool.forward("test query") # Parse result search_results = json.loads(result) @@ -249,7 +221,8 @@ def test_forward_semantic_mode_success(self, knowledge_base_search_tool): mock_results = create_mock_search_result(4) knowledge_base_search_tool.vdb_core.semantic_search.return_value = mock_results - result = knowledge_base_search_tool.forward("test query", search_mode="semantic") + knowledge_base_search_tool.search_mode = "semantic" + result = knowledge_base_search_tool.forward("test query") # Parse result search_results = json.loads(result) @@ -259,8 +232,9 @@ def test_forward_semantic_mode_success(self, knowledge_base_search_tool): def test_forward_invalid_search_mode(self, knowledge_base_search_tool): """Test forward method with invalid search mode""" + knowledge_base_search_tool.search_mode = "invalid" with pytest.raises(Exception) as excinfo: - knowledge_base_search_tool.forward("test query", search_mode="invalid") + knowledge_base_search_tool.forward("test query") assert "Invalid search mode" in str(excinfo.value) assert "hybrid, accurate, semantic" in str(excinfo.value) @@ -291,11 +265,8 @@ def test_forward_with_custom_index_names(self, knowledge_base_search_tool): mock_results = create_mock_search_result(2) knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results - result = knowledge_base_search_tool.forward( - "test query", - search_mode="hybrid", - index_names=["custom_index1", "custom_index2"] - ) + knowledge_base_search_tool.index_names = ["custom_index1", "custom_index2"] + result = knowledge_base_search_tool.forward("test query") # Verify vdb_core was called with custom index names knowledge_base_search_tool.vdb_core.hybrid_search.assert_called_once_with( From 0dcf730abfe285d293881c6e62c6a55bd085bf7f Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Wed, 28 Jan 2026 16:40:24 +0800 Subject: [PATCH 013/167] =?UTF-8?q?=E2=9C=A8The=20knowledge=20base=20retri?= =?UTF-8?q?eval=20tool=20supports=20selecting=20a=20specified=20knowledge?= =?UTF-8?q?=20base=20=E2=80=94=E2=80=94=20backend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/agents/test_create_agent_info.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index fcb438009..fbbdfb576 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -851,6 +851,58 @@ async def test_create_agent_config_memory_exception(self): async def test_create_agent_config_with_knowledge_base_summary_filtering(self): pass + @pytest.mark.asyncio + async def test_create_agent_config_knowledge_base_summary_error(self): + """Test case for error handling during knowledge base summary build""" + with patch('backend.agents.create_agent_info.search_agent_info_by_agent_id') as mock_search_agent, \ + patch('backend.agents.create_agent_info.query_sub_agents_id_list') as mock_query_sub, \ + patch('backend.agents.create_agent_info.create_tool_config_list') as mock_create_tools, \ + patch('backend.agents.create_agent_info.get_agent_prompt_template') as mock_get_template, \ + patch('backend.agents.create_agent_info.tenant_config_manager') as mock_tenant_config, \ + patch('backend.agents.create_agent_info.build_memory_context') as mock_build_memory, \ + patch('backend.agents.create_agent_info.AgentConfig') as mock_agent_config, \ + patch('backend.agents.create_agent_info.prepare_prompt_templates') as mock_prepare_templates, \ + patch('backend.agents.create_agent_info.get_model_by_model_id') as mock_get_model_by_id, \ + patch('backend.agents.create_agent_info.logger') as mock_logger: + + # Set mock return values + mock_search_agent.return_value = { + "name": "test_agent", + "description": "test description", + "duty_prompt": "test duty", + "constraint_prompt": "test constraint", + "few_shots_prompt": "test few shots", + "max_steps": 5, + "model_id": 123, + "provide_run_summary": True + } + mock_query_sub.return_value = [] + + # Create a tool that raises exception when accessing class_name + mock_tool = MagicMock() + type(mock_tool).class_name = PropertyMock(side_effect=Exception("Test Error")) + mock_create_tools.return_value = [mock_tool] + + mock_get_template.return_value = { + "system_prompt": "{{duty}} {{constraint}} {{few_shots}}"} + mock_tenant_config.get_app_config.side_effect = [ + "TestApp", "Test Description"] + mock_build_memory.return_value = Mock( + user_config=Mock(memory_switch=False), + memory_config={}, + tenant_id="tenant_1", + user_id="user_1", + agent_id="agent_1" + ) + mock_prepare_templates.return_value = { + "system_prompt": "populated_system_prompt"} + mock_get_model_by_id.return_value = {"display_name": "test_model"} + + await create_agent_config("agent_1", "tenant_1", "user_1", "zh", "test query") + + # Verify that error was logged + mock_logger.error.assert_called_with("Failed to build knowledge base summary: Test Error") + class TestCreateModelConfigList: """Tests for the create_model_config_list function""" From 25b7206027d428f512302a785189321f89ebd7d7 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Thu, 29 Jan 2026 11:49:09 +0800 Subject: [PATCH 014/167] =?UTF-8?q?=E2=9C=A8The=20knowledge=20base=20retri?= =?UTF-8?q?eval=20tool=20supports=20selecting=20a=20specified=20knowledge?= =?UTF-8?q?=20base=20=201.=20Remove=20unused=20interfaces?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/tenant_config_app.py | 76 +----- backend/services/tenant_config_service.py | 94 ------- .../services/tool_configuration_service.py | 1 - frontend/services/api.ts | 11 +- .../core/tools/knowledge_base_search_tool.py | 2 +- test/backend/app/test_tenant_config_app.py | 229 ----------------- .../services/test_tenant_config_service.py | 235 ------------------ 7 files changed, 8 insertions(+), 640 deletions(-) delete mode 100644 backend/services/tenant_config_service.py delete mode 100644 test/backend/services/test_tenant_config_service.py diff --git a/backend/apps/tenant_config_app.py b/backend/apps/tenant_config_app.py index 371e3f864..cd67f0c8f 100644 --- a/backend/apps/tenant_config_app.py +++ b/backend/apps/tenant_config_app.py @@ -1,14 +1,10 @@ import logging from http import HTTPStatus -from typing import List, Optional -from fastapi import APIRouter, Body, Header, HTTPException +from fastapi import APIRouter, HTTPException from fastapi.responses import JSONResponse from consts.const import DEPLOYMENT_VERSION, APP_VERSION -from consts.model import UpdateKnowledgeListRequest -from services.tenant_config_service import get_selected_knowledge_list, update_selected_knowledge -from utils.auth_utils import get_current_user_id logger = logging.getLogger("tenant_config_app") router = APIRouter(prefix="/tenant_config") @@ -34,74 +30,4 @@ def get_deployment_version(): ) -@router.get("/load_knowledge_list") -def load_knowledge_list( - authorization: Optional[str] = Header(None) -): - try: - user_id, tenant_id = get_current_user_id(authorization) - selected_knowledge_info = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - - content = {"selectedKbNames": [item["index_name"] for item in selected_knowledge_info], - "selectedKbModels": [item["embedding_model_name"] for item in selected_knowledge_info], - "selectedKbSources": [item["knowledge_sources"] for item in selected_knowledge_info]} - - return JSONResponse( - status_code=HTTPStatus.OK, - content={"content": content, "status": "success"} - ) - except Exception as e: - logger.error(f"load knowledge list failed, error: {e}") - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail="Failed to load configuration" - ) - -@router.post("/update_knowledge_list") -def update_knowledge_list( - authorization: Optional[str] = Header(None), - request: UpdateKnowledgeListRequest = Body(...) -): - try: - user_id, tenant_id = get_current_user_id(authorization) - - # Convert grouped request to flat lists - knowledge_list = [] - knowledge_sources = [] - - if request.nexent: - knowledge_list.extend(request.nexent) - knowledge_sources.extend(["nexent"] * len(request.nexent)) - - if request.datamate: - knowledge_list.extend(request.datamate) - knowledge_sources.extend(["datamate"] * len(request.datamate)) - - result = update_selected_knowledge( - tenant_id=tenant_id, user_id=user_id, index_name_list=knowledge_list, knowledge_sources=knowledge_sources) - if result: - # Get updated knowledge base information - selected_knowledge_info = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - - content = {"selectedKbNames": [item["index_name"] for item in selected_knowledge_info], - "selectedKbModels": [item["embedding_model_name"] for item in selected_knowledge_info], - "selectedKbSources": [item["knowledge_sources"] for item in selected_knowledge_info]} - - return JSONResponse( - status_code=HTTPStatus.OK, - content={"content": content, "message": "update success", "status": "success"} - ) - else: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail="Failed to update configuration" - ) - except Exception as e: - logger.error(f"update knowledge list failed, error: {e}") - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail="Failed to update configuration" - ) diff --git a/backend/services/tenant_config_service.py b/backend/services/tenant_config_service.py deleted file mode 100644 index c0e4d4afb..000000000 --- a/backend/services/tenant_config_service.py +++ /dev/null @@ -1,94 +0,0 @@ -import logging -from typing import List, Optional - -from database.knowledge_db import get_knowledge_info_by_knowledge_ids, get_knowledge_ids_by_index_names -from database.tenant_config_db import get_tenant_config_info, insert_config, delete_config_by_tenant_config_id - -logger = logging.getLogger("tenant_config_service") - - -def get_selected_knowledge_list(tenant_id: str, user_id: str): - record_list = get_tenant_config_info( - tenant_id=tenant_id, user_id=user_id, select_key="selected_knowledge_id") - if len(record_list) == 0: - return [] - knowledge_id_list = [record["config_value"] for record in record_list] - knowledge_info = get_knowledge_info_by_knowledge_ids(knowledge_id_list) - return knowledge_info - - -def update_selected_knowledge(tenant_id: str, user_id: str, index_name_list: List[str], knowledge_sources: Optional[List[str]] = None): - # Validate that knowledge_sources length matches index_name_list if provided - if knowledge_sources and len(knowledge_sources) != len(index_name_list): - logger.error( - f"Knowledge sources length mismatch: sources={len(knowledge_sources)}, names={len(index_name_list)}") - return False - - logger.info( - f"Updating knowledge list for tenant {tenant_id}, user {user_id}: " - f"names={index_name_list}, sources={knowledge_sources}") - - knowledge_ids = get_knowledge_ids_by_index_names(index_name_list) - record_list = get_tenant_config_info( - tenant_id=tenant_id, user_id=user_id, select_key="selected_knowledge_id") - record_ids = [record["tenant_config_id"] for record in record_list] - - # if knowledge_ids is not in record_list, insert the record of knowledge_ids - for knowledge_id in knowledge_ids: - if knowledge_id not in record_ids: - result = insert_config({ - "user_id": user_id, - "tenant_id": tenant_id, - "config_key": "selected_knowledge_id", - "config_value": knowledge_id, - "value_type": "multi" - }) - if not result: - logger.error( - f"insert_config failed, tenant_id: {tenant_id}, user_id: {user_id}, knowledge_id: {knowledge_id}") - return False - - # if record_list is not in knowledge_ids, delete the record of record_list - for record in record_list: - if record["config_value"] not in knowledge_ids: - result = delete_config_by_tenant_config_id( - record["tenant_config_id"]) - if not result: - logger.error( - f"delete_config_by_tenant_config_id failed, tenant_id: {tenant_id}, user_id: {user_id}, knowledge_id: {record['config_value']}") - return False - - return True - - -def delete_selected_knowledge_by_index_name(tenant_id: str, user_id: str, index_name: str): - knowledge_ids = get_knowledge_ids_by_index_names([index_name]) - record_list = get_tenant_config_info( - tenant_id=tenant_id, user_id=user_id, select_key="selected_knowledge_id") - - for record in record_list: - if record["config_value"] == str(knowledge_ids[0]): - result = delete_config_by_tenant_config_id( - record["tenant_config_id"]) - if not result: - logger.error( - f"delete_config_by_tenant_config_id failed, tenant_id: {tenant_id}, user_id: {user_id}, knowledge_id: {record['config_value']}") - return False - - return True - - -def build_knowledge_name_mapping(tenant_id: str, user_id: str): - """ - Build mapping from user-facing knowledge_name to internal index_name for the selected knowledge bases. - Falls back to using index_name as key when knowledge_name is missing for backward compatibility. - """ - knowledge_info_list = get_selected_knowledge_list( - tenant_id=tenant_id, user_id=user_id) - mapping = {} - for info in knowledge_info_list: - key = info.get("knowledge_name") or info.get("index_name") - value = info.get("index_name") - if key and value: - mapping[key] = value - return mapping diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index 7391737ee..32cdf6f65 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -23,7 +23,6 @@ ) from services.file_management_service import get_llm_model from services.vectordatabase_service import get_embedding_model, get_vector_db_core -from services.tenant_config_service import get_selected_knowledge_list, build_knowledge_name_mapping from database.client import minio_client from services.image_service import get_vlm_model diff --git a/frontend/services/api.ts b/frontend/services/api.ts index af27039f2..2ea7d5c9a 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -42,7 +42,8 @@ export const API_ENDPOINTS = { regenerateNameBatch: `${API_BASE_URL}/agent/regenerate_name`, searchInfo: `${API_BASE_URL}/agent/search_info`, callRelationship: `${API_BASE_URL}/agent/call_relationship`, - clearNew: (agentId: string | number) => `${API_BASE_URL}/agent/clear_new/${agentId}`, + clearNew: (agentId: string | number) => + `${API_BASE_URL}/agent/clear_new/${agentId}`, }, tool: { list: `${API_BASE_URL}/tool/list`, @@ -164,8 +165,6 @@ export const API_ENDPOINTS = { saveDataMateUrl: `${API_BASE_URL}/config/save_datamate_url`, }, tenantConfig: { - loadKnowledgeList: `${API_BASE_URL}/tenant_config/load_knowledge_list`, - updateKnowledgeList: `${API_BASE_URL}/tenant_config/update_knowledge_list`, deploymentVersion: `${API_BASE_URL}/tenant_config/deployment_version`, }, mcp: { @@ -255,8 +254,10 @@ export const API_ENDPOINTS = { invitations: { list: `${API_BASE_URL}/invitations/list`, create: `${API_BASE_URL}/invitations`, - update: (invitationCode: string) => `${API_BASE_URL}/invitations/${invitationCode}`, - delete: (invitationCode: string) => `${API_BASE_URL}/invitations/${invitationCode}`, + update: (invitationCode: string) => + `${API_BASE_URL}/invitations/${invitationCode}`, + delete: (invitationCode: string) => + `${API_BASE_URL}/invitations/${invitationCode}`, }, }; diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index ab2d2e702..6b17544b8 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -40,7 +40,7 @@ def __init__( top_k: int = Field( description="Maximum number of search results", default=3), index_names: List[str] = Field( - description="The list of index names to search", default=None), + description="The list of index names to search"), search_mode: str = Field( description="the search mode, optional values: hybrid, accurate, semantic", default="hybrid", diff --git a/test/backend/app/test_tenant_config_app.py b/test/backend/app/test_tenant_config_app.py index d79e71295..b9eab6199 100644 --- a/test/backend/app/test_tenant_config_app.py +++ b/test/backend/app/test_tenant_config_app.py @@ -114,209 +114,6 @@ def setUp(self): ] self.mock_update_knowledge.return_value = True - def test_load_knowledge_list_success(self): - """Test successful loading of knowledge list""" - response = self.client.get( - "/tenant_config/load_knowledge_list", - headers={"authorization": "Bearer test-token"} - ) - - self.assertEqual(response.status_code, HTTPStatus.OK) - data = response.json() - self.assertEqual(data["status"], "success") - self.assertIn("content", data) - - content = data["content"] - self.assertEqual(content["selectedKbNames"], ["kb1", "kb2"]) - self.assertEqual(content["selectedKbModels"], [ - "embedding-model-1", "embedding-model-2"]) - self.assertEqual(content["selectedKbSources"], [ - ["source1", "source2"], ["source3"]]) - - def test_load_knowledge_list_auth_error(self): - """Test knowledge list loading with authentication error""" - self.mock_get_user_id.side_effect = Exception("Authentication failed") - - response = self.client.get( - "/tenant_config/load_knowledge_list", - headers={"authorization": "Bearer invalid-token"} - ) - - self.assertEqual(response.status_code, - HTTPStatus.INTERNAL_SERVER_ERROR) - data = response.json() - self.assertIn("Failed to load configuration", data["detail"]) - - def test_load_knowledge_list_service_error(self): - """Test knowledge list loading with service error""" - self.mock_get_knowledge_list.side_effect = Exception("Database error") - - response = self.client.get( - "/tenant_config/load_knowledge_list", - headers={"authorization": "Bearer test-token"} - ) - - self.assertEqual(response.status_code, - HTTPStatus.INTERNAL_SERVER_ERROR) - data = response.json() - self.assertIn("Failed to load configuration", data["detail"]) - - def test_load_knowledge_list_empty(self): - """Test loading empty knowledge list""" - self.mock_get_knowledge_list.return_value = [] - - response = self.client.get( - "/tenant_config/load_knowledge_list", - headers={"authorization": "Bearer test-token"} - ) - - self.assertEqual(response.status_code, HTTPStatus.OK) - data = response.json() - self.assertEqual(data["status"], "success") - - content = data["content"] - self.assertEqual(content["selectedKbNames"], []) - self.assertEqual(content["selectedKbModels"], []) - self.assertEqual(content["selectedKbSources"], []) - - def test_load_knowledge_list_missing_model_name(self): - """Test loading knowledge list with missing model_name field""" - # This should cause a KeyError when trying to access model_name - self.mock_get_knowledge_list.return_value = [ - { - "index_name": "kb1", - "knowledge_sources": ["source1"] - # Missing embedding_model_name field - } - ] - - response = self.client.get( - "/tenant_config/load_knowledge_list", - headers={"authorization": "Bearer test-token"} - ) - - self.assertEqual(response.status_code, - HTTPStatus.INTERNAL_SERVER_ERROR) - data = response.json() - self.assertIn("Failed to load configuration", data["detail"]) - - def test_update_knowledge_list_success(self): - """Test successful knowledge list update""" - request_data = { - "nexent": ["kb1"], - "datamate": ["kb2"] - } - - response = self.client.post( - "/tenant_config/update_knowledge_list", - headers={"authorization": "Bearer test-token"}, - json=request_data - ) - - self.assertEqual(response.status_code, HTTPStatus.OK) - data = response.json() - self.assertEqual(data["status"], "success") - self.assertEqual(data["message"], "update success") - self.assertIn("content", data) - self.assertIn("selectedKbNames", data["content"]) - self.assertIn("selectedKbModels", data["content"]) - self.assertIn("selectedKbSources", data["content"]) - - # Verify the mock was called with correct parameters (flattened) - self.mock_update_knowledge.assert_called_once_with( - tenant_id="test_tenant", - user_id="test_user", - index_name_list=["kb1", "kb2"], - knowledge_sources=["nexent", "datamate"] - ) - - def test_update_knowledge_list_failure(self): - """Test knowledge list update failure""" - self.mock_update_knowledge.return_value = False - request_data = { - "nexent": ["kb1"], - "datamate": ["kb2"] - } - - response = self.client.post( - "/tenant_config/update_knowledge_list", - headers={"authorization": "Bearer test-token"}, - json=request_data - ) - - self.assertEqual(response.status_code, - HTTPStatus.INTERNAL_SERVER_ERROR) - data = response.json() - self.assertIn("Failed to update configuration", data["detail"]) - - def test_update_knowledge_list_auth_error(self): - """Test knowledge list update with authentication error""" - self.mock_get_user_id.side_effect = Exception("Authentication failed") - request_data = { - "nexent": ["kb1"], - "datamate": ["kb2"] - } - - response = self.client.post( - "/tenant_config/update_knowledge_list", - headers={"authorization": "Bearer invalid-token"}, - json=request_data - ) - - self.assertEqual(response.status_code, - HTTPStatus.INTERNAL_SERVER_ERROR) - data = response.json() - self.assertIn("Failed to update configuration", data["detail"]) - - def test_update_knowledge_list_service_error(self): - """Test knowledge list update with service error""" - self.mock_update_knowledge.side_effect = Exception("Database error") - request_data = { - "nexent": ["kb1"], - "datamate": ["kb2"] - } - - response = self.client.post( - "/tenant_config/update_knowledge_list", - headers={"authorization": "Bearer test-token"}, - json=request_data - ) - - self.assertEqual(response.status_code, - HTTPStatus.INTERNAL_SERVER_ERROR) - data = response.json() - self.assertIn("Failed to update configuration", data["detail"]) - - def test_update_knowledge_list_empty_list(self): - """Test updating with empty knowledge list""" - request_data = { - "nexent": [], - "datamate": [] - } - - response = self.client.post( - "/tenant_config/update_knowledge_list", - headers={"authorization": "Bearer test-token"}, - json=request_data - ) - - self.assertEqual(response.status_code, HTTPStatus.OK) - data = response.json() - self.assertEqual(data["status"], "success") - self.assertEqual(data["message"], "update success") - - def test_update_knowledge_list_no_body(self): - """Test updating without request body""" - response = self.client.post( - "/tenant_config/update_knowledge_list", - headers={"authorization": "Bearer test-token"} - ) - - # When no body is provided, Pydantic will raise validation error - self.assertEqual(response.status_code, 422) # Unprocessable Entity - data = response.json() - self.assertIn("detail", data) - def test_get_deployment_version_success(self): """Test successful retrieval of deployment version""" response = self.client.get("/tenant_config/deployment_version") @@ -328,32 +125,6 @@ def test_get_deployment_version_success(self): self.assertIn("app_version", data) self.assertEqual(len(data.keys()), 3) - def test_load_knowledge_list_no_auth_header(self): - """Test loading knowledge list without authorization header""" - response = self.client.get("/tenant_config/load_knowledge_list") - - # This should still work as the authorization parameter is Optional - self.assertEqual(response.status_code, HTTPStatus.OK) - data = response.json() - self.assertEqual(data["status"], "success") - - def test_update_knowledge_list_no_auth_header(self): - """Test updating knowledge list without authorization header""" - request_data = { - "nexent": ["kb1"], - "datamate": ["kb2"] - } - - response = self.client.post( - "/tenant_config/update_knowledge_list", - json=request_data - ) - - # This should still work as the authorization parameter is Optional - self.assertEqual(response.status_code, HTTPStatus.OK) - data = response.json() - self.assertEqual(data["status"], "success") - if __name__ == '__main__': unittest.main() diff --git a/test/backend/services/test_tenant_config_service.py b/test/backend/services/test_tenant_config_service.py deleted file mode 100644 index ad9d74672..000000000 --- a/test/backend/services/test_tenant_config_service.py +++ /dev/null @@ -1,235 +0,0 @@ -import sys -import os -import types -import unittest -from unittest.mock import MagicMock, patch - -# Add backend directory to Python path for proper imports -project_root = os.path.abspath(os.path.join( - os.path.dirname(__file__), '../../../')) -backend_dir = os.path.join(project_root, 'backend') -if backend_dir not in sys.path: - sys.path.insert(0, backend_dir) - -# Patch boto3 and other dependencies before importing anything from backend -boto3_mock = MagicMock() -sys.modules['boto3'] = boto3_mock - -# Apply critical patches before importing any modules -# This prevents real AWS/MinIO/Elasticsearch calls during import -patch('botocore.client.BaseClient._make_api_call', return_value={}).start() - -# Patch storage factory and MinIO config validation to avoid errors during initialization -storage_client_mock = MagicMock() -minio_client_mock = MagicMock() -minio_client_mock._ensure_bucket_exists = MagicMock() -minio_client_mock.client = MagicMock() - -# Mock the entire MinIOStorageConfig class to avoid validation -minio_config_mock = MagicMock() -minio_config_mock.validate = MagicMock() - -# Import backend modules after all patches are applied -# Use additional context manager to ensure MinioClient is properly mocked during import -with patch('backend.database.client.MinioClient', return_value=minio_client_mock), \ - patch('nexent.storage.minio_config.MinIOStorageConfig', return_value=minio_config_mock): - from backend.services.tenant_config_service import ( - get_selected_knowledge_list, - update_selected_knowledge, - delete_selected_knowledge_by_index_name, - ) - - -class TestTenantConfigService(unittest.TestCase): - def setUp(self): - self.tenant_id = "test_tenant_id" - self.user_id = "test_user_id" - self.index_name = "test_index_name" - self.index_name_list = ["test_index_name1", "test_index_name2"] - self.knowledge_id = "knowledge_id_1" - self.knowledge_ids = ["knowledge_id_1", "knowledge_id_2"] - self.tenant_config_id = "tenant_config_id_1" - - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_info_by_knowledge_ids") - def test_get_selected_knowledge_list_empty( - self, mock_get_knowledge_info, mock_get_config - ): - mock_get_config.return_value = [] - result = get_selected_knowledge_list(self.tenant_id, self.user_id) - self.assertEqual(result, []) - mock_get_knowledge_info.assert_not_called() - - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_info_by_knowledge_ids") - def test_get_selected_knowledge_list_with_records( - self, mock_get_knowledge_info, mock_get_config - ): - mock_get_config.return_value = [ - {"config_value": self.knowledge_id, - "tenant_config_id": self.tenant_config_id} - ] - mock_get_knowledge_info.return_value = [ - {"knowledge_id": self.knowledge_id, "name": "Test Knowledge"} - ] - - result = get_selected_knowledge_list(self.tenant_id, self.user_id) - - self.assertEqual( - result, [{"knowledge_id": self.knowledge_id, - "name": "Test Knowledge"}] - ) - mock_get_knowledge_info.assert_called_once_with([self.knowledge_id]) - - @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") - @patch("backend.services.tenant_config_service.insert_config") - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") - def test_update_selected_knowledge_add_only( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete - ): - mock_get_ids.return_value = self.knowledge_ids - mock_get_config.return_value = [] - mock_insert.return_value = True - - result = update_selected_knowledge( - self.tenant_id, self.user_id, self.index_name_list - ) - self.assertTrue(result) - self.assertEqual(mock_insert.call_count, 2) - mock_delete.assert_not_called() - - @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") - @patch("backend.services.tenant_config_service.insert_config") - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") - def test_update_selected_knowledge_remove_only( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete - ): - mock_get_ids.return_value = [] - mock_get_config.return_value = [ - {"config_value": self.knowledge_id, - "tenant_config_id": self.tenant_config_id} - ] - mock_delete.return_value = True - - result = update_selected_knowledge(self.tenant_id, self.user_id, []) - self.assertTrue(result) - mock_insert.assert_not_called() - mock_delete.assert_called_once_with(self.tenant_config_id) - - @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") - @patch("backend.services.tenant_config_service.insert_config") - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") - def test_update_selected_knowledge_add_and_remove( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete - ): - mock_get_ids.return_value = ["knowledge_id_2"] - mock_get_config.return_value = [ - {"config_value": "knowledge_id_1", - "tenant_config_id": "tenant_config_id_1"} - ] - mock_insert.return_value = True - mock_delete.return_value = True - - result = update_selected_knowledge( - self.tenant_id, self.user_id, ["new_index"]) - self.assertTrue(result) - mock_insert.assert_called_once() - mock_delete.assert_called_once_with("tenant_config_id_1") - - @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") - @patch("backend.services.tenant_config_service.insert_config") - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") - def test_update_selected_knowledge_insert_failure( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete - ): - mock_get_ids.return_value = self.knowledge_ids - mock_get_config.return_value = [] - mock_insert.return_value = False - - result = update_selected_knowledge( - self.tenant_id, self.user_id, self.index_name_list - ) - self.assertFalse(result) - mock_insert.assert_called_once() - - @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") - @patch("backend.services.tenant_config_service.insert_config") - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") - def test_update_selected_knowledge_delete_failure( - self, mock_get_ids, mock_get_config, mock_insert, mock_delete - ): - mock_get_ids.return_value = [] - mock_get_config.return_value = [ - {"config_value": self.knowledge_id, - "tenant_config_id": self.tenant_config_id} - ] - mock_delete.return_value = False - - result = update_selected_knowledge(self.tenant_id, self.user_id, []) - self.assertFalse(result) - mock_delete.assert_called_once_with(self.tenant_config_id) - - @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") - def test_delete_selected_knowledge_by_index_name_success( - self, mock_get_ids, mock_get_config, mock_delete - ): - mock_get_ids.return_value = [self.knowledge_id] - mock_get_config.return_value = [ - {"config_value": self.knowledge_id, - "tenant_config_id": self.tenant_config_id} - ] - mock_delete.return_value = True - - result = delete_selected_knowledge_by_index_name( - self.tenant_id, self.user_id, self.index_name - ) - self.assertTrue(result) - mock_delete.assert_called_once_with(self.tenant_config_id) - - @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") - def test_delete_selected_knowledge_by_index_name_no_match( - self, mock_get_ids, mock_get_config, mock_delete - ): - mock_get_ids.return_value = ["different_id"] - mock_get_config.return_value = [ - {"config_value": self.knowledge_id, - "tenant_config_id": self.tenant_config_id} - ] - - result = delete_selected_knowledge_by_index_name( - self.tenant_id, self.user_id, self.index_name - ) - self.assertTrue(result) - mock_delete.assert_not_called() - - @patch("backend.services.tenant_config_service.delete_config_by_tenant_config_id") - @patch("backend.services.tenant_config_service.get_tenant_config_info") - @patch("backend.services.tenant_config_service.get_knowledge_ids_by_index_names") - def test_delete_selected_knowledge_by_index_name_failure( - self, mock_get_ids, mock_get_config, mock_delete - ): - mock_get_ids.return_value = [self.knowledge_id] - mock_get_config.return_value = [ - {"config_value": self.knowledge_id, - "tenant_config_id": self.tenant_config_id} - ] - mock_delete.return_value = False - - result = delete_selected_knowledge_by_index_name( - self.tenant_id, self.user_id, self.index_name - ) - self.assertFalse(result) - mock_delete.assert_called_once_with(self.tenant_config_id) - - -if __name__ == "__main__": - unittest.main() From 8558d1356a85e3fd6e17107ff879f1c6878f81e5 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Thu, 29 Jan 2026 14:50:43 +0800 Subject: [PATCH 015/167] =?UTF-8?q?=E2=9C=A8The=20knowledge=20base=20retri?= =?UTF-8?q?eval=20tool=20supports=20selecting=20a=20specified=20knowledge?= =?UTF-8?q?=20base=20=201.=20frontend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tool/KnowledgeBaseSelectionModal.tsx | 146 +++++ .../tool/KnowledgeBaseToolConfig.tsx | 504 ++++-------------- .../agentConfig/tool/ToolConfigModal.tsx | 101 +--- .../knowledges/KnowledgeBaseConfiguration.tsx | 311 ++--------- .../knowledge/KnowledgeBaseList.tsx | 86 +-- .../contexts/KnowledgeBaseContext.tsx | 190 ++----- frontend/app/[locale]/knowledges/page.tsx | 2 +- frontend/services/knowledgeBaseService.ts | 18 +- 8 files changed, 428 insertions(+), 930 deletions(-) create mode 100644 frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseSelectionModal.tsx diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseSelectionModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseSelectionModal.tsx new file mode 100644 index 000000000..71c81484e --- /dev/null +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseSelectionModal.tsx @@ -0,0 +1,146 @@ +"use client"; + +import { useState, useEffect, useMemo } from "react"; +import { useTranslation } from "react-i18next"; +import { Modal, Spin, message } from "antd"; +import KnowledgeBaseList from "../../../../knowledges/components/knowledge/KnowledgeBaseList"; +import knowledgeBaseService from "@/services/knowledgeBaseService"; +import { KnowledgeBase } from "@/types/knowledgeBase"; +import { ConfigStore } from "@/lib/config"; + +export interface KnowledgeBaseSelectionModalProps { + isOpen: boolean; + onCancel: () => void; + onSave: (selectedIds: string[]) => void; + initialSelectedIds: string[]; + toolName?: string; +} + +export default function KnowledgeBaseSelectionModal({ + isOpen, + onCancel, + onSave, + initialSelectedIds, + toolName, +}: KnowledgeBaseSelectionModalProps) { + const { t } = useTranslation("common"); + const [kbLoading, setKbLoading] = useState(false); + const [kbModalSelected, setKbModalSelected] = + useState(initialSelectedIds); + const [kbRawList, setKbRawList] = useState([]); + + const currentEmbeddingModel = useMemo(() => { + return ( + ConfigStore.getInstance().getModelConfig().embedding?.modelName || null + ); + }, []); + + const isKbSelectable = (kb: KnowledgeBase): boolean => { + const docCount = + typeof kb.documentCount === "number" ? kb.documentCount : 0; + const chunkCount = typeof kb.chunkCount === "number" ? kb.chunkCount : 0; + const hasContent = docCount + chunkCount > 0; + + const isModelCompatible = + kb.source !== "nexent" || + kb.embeddingModel === "unknown" || + kb.embeddingModel === currentEmbeddingModel; + + return hasContent && isModelCompatible; + }; + + useEffect(() => { + if (isOpen) { + loadKnowledgeBases(); + setKbModalSelected(initialSelectedIds); + } + }, [isOpen, initialSelectedIds]); + + const loadKnowledgeBases = async () => { + setKbLoading(true); + try { + let kbs: any[] = []; + if (toolName === "datamate_search") { + try { + const syncResult = + await knowledgeBaseService.syncDataMateAndCreateRecords(); + if (syncResult && syncResult.indices_info) { + kbs = syncResult.indices_info.map((indexInfo: any) => { + const stats = indexInfo.stats?.base_info || {}; + const kbId = indexInfo.name; + const kbName = indexInfo.display_name || indexInfo.name; + return { + id: kbId, + name: kbName, + description: "DataMate knowledge base", + documentCount: stats.doc_count || 0, + chunkCount: stats.chunk_count || 0, + createdAt: stats.creation_date || null, + updatedAt: stats.update_date || stats.creation_date || null, + embeddingModel: stats.embedding_model || "unknown", + source: "datamate", + }; + }); + } + } catch (e) { + kbs = []; + } + } else { + const nexentKbs = await knowledgeBaseService.getKnowledgeBasesInfo( + true, + false, + true + ); + kbs = nexentKbs.map((kb) => ({ ...kb, source: "nexent" })); + } + setKbRawList(kbs); + } catch (e) { + message.error( + t("toolConfig.message.kbRefreshFailed", "Failed to refresh KB list") + ); + } finally { + setKbLoading(false); + } + }; + + const handleSave = () => { + onSave(kbModalSelected); + }; + + return ( + document.body} + zIndex={1200} + open={isOpen} + // title={t("toolConfig.modal.selectKbTitle", "Select Knowledge Bases")} + onCancel={onCancel} + cancelText={t("common.button.cancel")} + okText={t("common.confirm")} + onOk={handleSave} + width={800} + > +
+ { + const exists = kbModalSelected.includes(id); + const newSelected = exists + ? kbModalSelected.filter((s) => s !== id) + : [...kbModalSelected, id]; + setKbModalSelected(newSelected); + }} + onClick={() => {}} + showDataMateConfig={false} + isSelectable={isKbSelectable} + getModelDisplayName={(m: string) => m} + containerHeight="50vh" + /> +
+
+ ); +} diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx index bad406349..50125fe05 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx @@ -1,13 +1,11 @@ "use client"; -import { useState, useRef, useEffect, useMemo } from "react"; +import { useState, useEffect, useMemo } from "react"; import { useTranslation } from "react-i18next"; -import { Modal, Spin, Select, Form, message } from "antd"; -import KnowledgeBaseList from "../../../../knowledges/components/knowledge/KnowledgeBaseList"; +import { Spin, Select, Form } from "antd"; import knowledgeBaseService from "@/services/knowledgeBaseService"; import { ToolParam } from "@/types/agentConfig"; -import { KnowledgeBase } from "@/types/knowledgeBase"; -import { ConfigStore } from "@/lib/config"; +import KnowledgeBaseSelectionModal from "./KnowledgeBaseSelectionModal"; export interface KnowledgeBaseToolConfigProps { currentParams: ToolParam[]; @@ -16,12 +14,7 @@ export interface KnowledgeBaseToolConfigProps { serverParamNames: string[]; retrievalParamNames: string[]; renderParamInput: (param: ToolParam, index: number) => React.ReactNode; - externalKbOptions?: { label: React.ReactNode; value: string; description?: string }[]; - externalKbRawList?: any[]; - externalKbLoading?: boolean; - includeDataMateSync?: boolean; toolName?: string; - toolSource?: string; } export default function KnowledgeBaseToolConfig({ @@ -31,246 +24,94 @@ export default function KnowledgeBaseToolConfig({ serverParamNames, retrievalParamNames, renderParamInput, - externalKbOptions, - externalKbRawList, - externalKbLoading = false, - includeDataMateSync = false, toolName, - toolSource, }: KnowledgeBaseToolConfigProps) { const { t } = useTranslation("common"); const [kbOptions, setKbOptions] = useState< - { label: React.ReactNode; value: string; description?: string }[] - >(externalKbOptions || []); - const [kbLoading, setKbLoading] = useState(externalKbLoading); - const [kbModalVisible, setKbModalVisible] = useState(false); - const [kbModalSelected, setKbModalSelected] = useState([]); - const [kbRawList, setKbRawList] = useState(externalKbRawList || []); - const [idNameMap, setIdNameMap] = useState | null>(() => - typeof window !== "undefined" ? knowledgeBaseService.getCachedIdNameMapSync() : null + { label: React.ReactNode; value: string }[] + >([]); + const [kbLoading, setKbLoading] = useState(false); + const [isModalOpen, setIsModalOpen] = useState(false); + const [nameIdMap, setNameIdMap] = useState | null>( + null ); - // Prevent immediate re-opening of modal when Select regains focus after confirm - // Use useRef for synchronous state checking to avoid race conditions - const suppressOpenRef = useRef(false); - const suppressTimerRef = useRef(null); - // Memoize current embedding model to prevent unnecessary re-renders - const currentEmbeddingModel = useMemo(() => { - return ConfigStore.getInstance().getModelConfig().embedding?.modelName || null; - }, []); - - // Helper function to check if a knowledge base is selectable - const isKbSelectable = (kb: KnowledgeBase): boolean => { - const docCount = typeof kb.documentCount === "number" ? kb.documentCount : 0; - const chunkCount = typeof kb.chunkCount === "number" ? kb.chunkCount : 0; - const hasContent = (docCount + chunkCount) > 0; - - // Check model compatibility - only for local knowledge bases (nexent source) - const isModelCompatible = - kb.source !== "nexent" || // Non-local knowledge bases (e.g., DataMate) don't need model check - kb.embeddingModel === "unknown" || - kb.embeddingModel === currentEmbeddingModel; - - return hasContent && isModelCompatible; - }; - - const buildKbOptions = (kbs: any[]) => { - // For the preview/select we only need the KB name (no description). - return kbs.map((kb) => ({ - value: kb.id, - label: kb.name, - description: kb.description || "", - })); - }; - - const openKbModal = async () => { - // Capture suppress state at the start of the function to prevent race conditions - // This snapshot will be used later to verify we should still open the modal - const suppressSnapshot = suppressOpenRef.current; - - // If suppress was active at the start of this call, clear it and return - if (suppressSnapshot) { - suppressOpenRef.current = false; - if (suppressTimerRef.current) { - window.clearTimeout(suppressTimerRef.current); - suppressTimerRef.current = null; - } - return; - } - - // Clear any pending suppress timer since we're intentionally opening the modal - if (suppressTimerRef.current) { - window.clearTimeout(suppressTimerRef.current); - suppressTimerRef.current = null; - } + useEffect(() => { + const loadKbOptions = async () => { + setKbLoading(true); + try { + let kbMap: Record = {}; - // If parent provided KB list, use it; otherwise fetch - setKbLoading(true); - try { - let kbs: any[] = []; - if (externalKbRawList && externalKbRawList.length > 0) { - kbs = externalKbRawList; - } else if (includeDataMateSync || toolName === "datamate_search") { - // Fetch DataMate knowledge bases only - try { - const syncResult = await knowledgeBaseService.syncDataMateAndCreateRecords(); + if (toolName === "datamate_search") { + // For datamate_search, sync DataMate knowledge bases and build map + // Use index_name as key, display_name as value + const syncResult = + await knowledgeBaseService.syncDataMateAndCreateRecords(); if (syncResult && syncResult.indices_info) { - kbs = syncResult.indices_info.map((indexInfo: any) => { - const stats = indexInfo.stats?.base_info || {}; - const kbId = indexInfo.name; - const kbName = indexInfo.display_name || indexInfo.name; - return { - id: kbId, - name: kbName, - description: "DataMate knowledge base", - documentCount: stats.doc_count || 0, - chunkCount: stats.chunk_count || 0, - createdAt: stats.creation_date || null, - updatedAt: stats.update_date || stats.creation_date || null, - embeddingModel: stats.embedding_model || "unknown", - avatar: "", - chunkNum: 0, - language: "", - nickname: "", - parserId: "", - permission: "", - tokenNum: 0, - source: "datamate", - }; + syncResult.indices_info.forEach((indexInfo: any) => { + const kbId = indexInfo.name; // index_name + const kbName = indexInfo.display_name || indexInfo.name; // display_name or fallback + kbMap[kbId] = kbName; }); } - } catch (e) { - // fallback to empty list on error - kbs = []; + } else { + // For other tools, use the standard cached map + kbMap = await knowledgeBaseService.ensureIdNameMap(); } - } else { - // Default: fetch local Elasticsearch indices (no DataMate sync) - kbs = await knowledgeBaseService.getKnowledgeBasesInfo(true, false, true); - } - if (!externalKbRawList || externalKbRawList.length === 0) { - setKbRawList(kbs); - setKbOptions(buildKbOptions(kbs)); - } else { - // Ensure local options reflect external options - setKbOptions(externalKbOptions || buildKbOptions(kbs)); - setKbRawList(externalKbRawList || kbs); - } - - // initialize modal selection from current form value - const idx = currentParams.findIndex((p) => p.name === "index_names"); - const currentVal = idx !== -1 ? currentParams[idx].value : undefined; - setKbModalSelected(Array.isArray(currentVal) ? currentVal : []); - - // Double-check suppress snapshot before showing modal to prevent race conditions - // If suppress was set during the async operation, don't open the modal - if (suppressSnapshot) { - setKbLoading(false); - return; - } - setKbModalVisible(true); - } catch (e) { - message.error(t("toolConfig.message.kbRefreshFailed", "Failed to refresh KB list")); - } finally { - setKbLoading(false); - } - }; - - useEffect(() => { - return () => { - if (suppressTimerRef.current) { - window.clearTimeout(suppressTimerRef.current); - } - }; - }, []); - - // Fill id->name map in background (if not already available) - useEffect(() => { - let cancelled = false; - if (idNameMap) return; - (async () => { - try { - const map = await knowledgeBaseService.ensureIdNameMap(); - if (cancelled) return; - setIdNameMap(map && Object.keys(map).length > 0 ? map : null); - } catch (_e) { - // ignore - } - })(); - return () => { - cancelled = true; - }; - }, [idNameMap]); - - // Load KB info if there are existing selected ids but options don't include their labels - useEffect(() => { - const idx = currentParams.findIndex((p) => p.name === "index_names"); - if (idx === -1) return; - const val = currentParams[idx].value; - if (!Array.isArray(val) || val.length === 0) return; - // If external options provided, check them first - const optionsSource = externalKbOptions || kbOptions; - const missing = val.some((id: string) => !optionsSource.find((o) => o.value === id)); - if (!missing) return; - let cancelled = false; - (async () => { - setKbLoading(true); - try { - let kbs: any[] = []; - if (externalKbRawList && externalKbRawList.length > 0) { - kbs = externalKbRawList; - } else if (includeDataMateSync || toolName === "datamate_search") { - try { - const syncResult = await knowledgeBaseService.syncDataMateAndCreateRecords(); - if (syncResult && syncResult.indices_info) { - kbs = syncResult.indices_info.map((indexInfo: any) => { - const stats = indexInfo.stats?.base_info || {}; - const kbId = indexInfo.name; - const kbName = indexInfo.display_name || indexInfo.name; - return { - id: kbId, - name: kbName, - description: "DataMate knowledge base", - documentCount: stats.doc_count || 0, - chunkCount: stats.chunk_count || 0, - createdAt: stats.creation_date || null, - updatedAt: stats.update_date || stats.creation_date || null, - embeddingModel: stats.embedding_model || "unknown", - avatar: "", - chunkNum: 0, - language: "", - nickname: "", - parserId: "", - permission: "", - tokenNum: 0, - source: "datamate", - }; - }); - } - } catch (e) { - kbs = []; - } - } else { - kbs = await knowledgeBaseService.getKnowledgeBasesInfo(true, false, true); + if (Object.keys(kbMap).length > 0) { + // Build reverse map: display_name -> index_name + const reverseMap: Record = {}; + Object.entries(kbMap).forEach(([id, name]) => { + reverseMap[name] = id; + }); + setNameIdMap(reverseMap); + const options = Object.entries(kbMap).map(([id, name]) => ({ + value: id, // index_name (for saving) + label: name, // display_name (for display) + })); + setKbOptions(options); } - if (cancelled) return; - setKbRawList(kbs); - setKbOptions(buildKbOptions(kbs)); - } catch (e) { - // ignore - we don't want to surface a user-facing error here + } catch (error) { + console.error("Failed to load KB options:", error); } finally { - if (!cancelled) setKbLoading(false); + setKbLoading(false); } - })(); - - return () => { - cancelled = true; }; - }, [currentParams]); + + loadKbOptions(); + }, [toolName]); + + const handleSaveSelection = (selectedIds: string[]) => { + let idx = currentParams.findIndex((p) => p.name === "index_names"); + const newParams = [...currentParams]; + // Convert display_name back to index_name for saving + const actualIds = selectedIds.map((id) => nameIdMap?.[id] || id); + if (idx === -1) { + const newParam: ToolParam = { + name: "index_names", + type: "array", + required: false, + description: "List of knowledge base ids", + value: actualIds, + } as ToolParam; + newParams.push(newParam); + idx = newParams.length - 1; + } else { + newParams[idx] = { ...newParams[idx], value: actualIds }; + } + setCurrentParams(newParams); + const fieldName = `param_${idx}`; + form.setFieldsValue({ [fieldName]: actualIds }); + setIsModalOpen(false); + }; const effectiveRetrievalParamNames = useMemo(() => { - const names = Array.isArray(retrievalParamNames) ? [...retrievalParamNames] : []; - const hasSearchModeInParams = currentParams.findIndex((p) => p.name === "search_mode") !== -1; + const names = Array.isArray(retrievalParamNames) + ? [...retrievalParamNames] + : []; + const hasSearchModeInParams = + currentParams.findIndex((p) => p.name === "search_mode") !== -1; if (hasSearchModeInParams && !names.includes("search_mode")) { names.push("search_mode"); } @@ -347,13 +188,11 @@ export default function KnowledgeBaseToolConfig({ {/* Knowledge base selection */}
- {(() => { - // Always render the index_names preview field. If the param does not exist yet, - // we render a disabled Select with placeholder and keep the "Add" button inline. const idx = currentParams.findIndex((p) => p.name === "index_names"); const fieldName = idx === -1 ? undefined : `param_${idx}`; - const selectedValue = idx === -1 ? [] : currentParams[idx].value || []; + const selectedValue = + idx === -1 ? [] : currentParams[idx].value || []; return (
@@ -361,8 +200,15 @@ export default function KnowledgeBaseToolConfig({ - {idx !== -1 ? currentParams[idx].name : t("toolConfig.field.indexNames", "index_names")} + + {idx !== -1 + ? currentParams[idx].name + : t("toolConfig.field.indexNames", "index_names")} } name={fieldName} @@ -377,163 +223,45 @@ export default function KnowledgeBaseToolConfig({ } className="mb-0" > - {(() => { - const selectedValue = idx === -1 ? [] : currentParams[idx].value || []; - const optionsSource = - externalKbOptions && externalKbOptions.length > 0 - ? externalKbOptions - : kbOptions; - const isLoadingOptions = externalKbLoading || kbLoading; - - // If mapping is available, create quickOptions for selected ids to show labels immediately - const quickOptions = - idNameMap && Array.isArray(selectedValue) && selectedValue.length > 0 - ? selectedValue.map((id: string) => ({ - value: id, - label: idNameMap[id] || id, - })) - : []; - - const missingForSelected = - Array.isArray(selectedValue) && - selectedValue.length > 0 && - selectedValue.some((id: string) => { - // check if optionsSource has the id OR idNameMap has label - const inOptions = optionsSource.find((o) => o.value === id); - const inMap = idNameMap && idNameMap[id]; - return !inOptions && !inMap; - }); - - // If we have saved ids but the mapping/options are not ready, - // show a loading placeholder to avoid rendering raw ids then swapping to names. - if (idx !== -1 && missingForSelected && isLoadingOptions) { - return ( -
- - {t("toolConfig.message.loadingKbNames", "Loading knowledge base names...")} -
- ); + : null} - open={false} - onClick={() => setIsModalOpen(true)} - /> -
-
- {kbLoading && ( -
- -
- )} -
- ); - })()} -
- - setIsModalOpen(false)} - onSave={handleSaveSelection} - initialSelectedIds={ - currentParams.find((p) => p.name === "index_names")?.value || [] - } - toolName={toolName} - /> - - ); -} diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index ae9cd2ec2..c790d0f32 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -5,8 +5,11 @@ import { useState, useEffect, useRef, useLayoutEffect } from "react"; import { useTranslation } from "react-i18next"; import { App, Modal, Row, Col, theme, Button, Input, Form } from "antd"; -import { WarningFilled, InfoCircleFilled } from "@ant-design/icons"; - +import { + ExclamationCircleFilled, + WarningFilled, + InfoCircleFilled, +} from "@ant-design/icons"; import { DOCUMENT_ACTION_TYPES, KNOWLEDGE_BASE_ACTION_TYPES, @@ -120,21 +123,6 @@ function DataConfig({ isActive }: DataConfigProps) { const { confirm } = useConfirmModal(); const { modelConfig } = useConfig(); const { token } = theme.useToken(); - const hasUserInteractedRef = useRef(false); - - // Helper function to get authorization headers - const getAuthHeaders = () => { - const session = - typeof window !== "undefined" ? localStorage.getItem("session") : null; - const sessionObj = session ? JSON.parse(session) : null; - return { - "Content-Type": "application/json", - "User-Agent": "AgentFrontEnd/1.0", - ...(sessionObj?.access_token && { - Authorization: `Bearer ${sessionObj.access_token}`, - }), - }; - }; // Clear cache when component initializes useEffect(() => { @@ -217,7 +205,8 @@ function DataConfig({ isActive }: DataConfigProps) { hasKnowledgeBaseModelMismatch, refreshKnowledgeBaseData, refreshKnowledgeBaseDataWithDataMate, - + loadUserSelectedKnowledgeBases, + saveUserSelectedKnowledgeBases, dispatch: kbDispatch, } = useKnowledgeBaseContext(); @@ -240,7 +229,7 @@ function DataConfig({ isActive }: DataConfigProps) { const [uploadFiles, setUploadFiles] = useState([]); const [hasClickedUpload, setHasClickedUpload] = useState(false); const [showEmbeddingWarning, setShowEmbeddingWarning] = useState(false); - + const [showAutoDeselectModal, setShowAutoDeselectModal] = useState(false); const [newlyCreatedKbId, setNewlyCreatedKbId] = useState(null); // Track newly created KB waiting for documents // Search and filter state @@ -286,17 +275,199 @@ function DataConfig({ isActive }: DataConfigProps) { setHasClickedUpload, ]); - // When the active knowledge base changes, fetch its documents + // User configuration loading and saving logic based on isActive state + const prevIsActiveRef = useRef(null); // Initialize as null to distinguish first render + const hasLoadedRef = useRef(false); // Track whether configuration has been loaded + const savedSelectedIdsRef = useRef([]); // Save currently selected knowledge base IDs + const savedKnowledgeBasesRef = useRef([]); // Save current knowledge base list + const hasUserInteractedRef = useRef(false); // Track whether user has interacted (prevent saving empty state during initial load) + const hasCleanedRef = useRef(false); // Ensure auto-deselect runs only once per entry + const shouldPersistSelectionRef = useRef(false); // Flag to persist selection after change + + // Listen for isActive state changes + useLayoutEffect(() => { + // Clear cache that might affect state + localStorage.removeItem("preloaded_kb_data"); + localStorage.removeItem("kb_cache"); + + const prevIsActive = prevIsActiveRef.current; + + // Mark ready to load when entering second page + if ((prevIsActive === null || !prevIsActive) && isActive) { + hasLoadedRef.current = false; // Reset loading state + hasUserInteractedRef.current = false; // Reset interaction state to prevent incorrect saving + hasCleanedRef.current = false; // Reset auto-clean flag on entering + } + + // Save user configuration when leaving second page + if (prevIsActive === true && !isActive) { + // Only save after user has interacted to prevent saving empty state during initial load + if (hasUserInteractedRef.current) { + const saveConfig = async () => { + localStorage.removeItem("preloaded_kb_data"); + localStorage.removeItem("kb_cache"); + + try { + await saveUserSelectedKnowledgeBases(); + } catch (error) { + log.error("保存用户配置失败:", error); + } + }; + + saveConfig(); + } + + hasLoadedRef.current = false; // Reset loading state + } + + // Update ref + prevIsActiveRef.current = isActive; + }, [isActive]); + + // Save current state to ref in real-time to ensure access during unmount useEffect(() => { - if (kbState.activeKnowledgeBase) { - fetchDocuments( - kbState.activeKnowledgeBase.id, - false, - kbState.activeKnowledgeBase.source - ); - fetchKnowledgeBases(false); + savedSelectedIdsRef.current = kbState.selectedIds; + savedKnowledgeBasesRef.current = kbState.knowledgeBases; + }, [kbState.selectedIds, kbState.knowledgeBases]); + + // Helper function to get authorization headers + const getAuthHeaders = () => { + const session = + typeof window !== "undefined" ? localStorage.getItem("session") : null; + const sessionObj = session ? JSON.parse(session) : null; + return { + "Content-Type": "application/json", + "User-Agent": "AgentFrontEnd/1.0", + ...(sessionObj?.access_token && { + Authorization: `Bearer ${sessionObj.access_token}`, + }), + }; + }; + + // Save logic when component unmounts + useEffect(() => { + return () => { + // When component unmounts, if previously active and user has interacted, execute save + if (prevIsActiveRef.current === true && hasUserInteractedRef.current) { + // Use saved state instead of current potentially cleared state + const selectedKnowledgeBases = savedKnowledgeBasesRef.current.filter( + (kb) => savedSelectedIdsRef.current.includes(kb.id) + ); + + // Group knowledge bases by source + const knowledgeBySource: { nexent?: string[]; datamate?: string[] } = + {}; + selectedKnowledgeBases.forEach((kb) => { + const source = kb.source as keyof typeof knowledgeBySource; + if (!knowledgeBySource[source]) { + knowledgeBySource[source] = []; + } + knowledgeBySource[source]!.push(kb.id); + }); + + try { + // Use fetch with keepalive to ensure request can be sent during page unload + fetch(API_ENDPOINTS.tenantConfig.updateKnowledgeList, { + method: "POST", + headers: { + "Content-Type": "application/json", + ...getAuthHeaders(), + }, + body: JSON.stringify(knowledgeBySource), + keepalive: true, + }).catch((error) => { + log.error("卸载时保存失败:", error); + }); + } catch (error) { + log.error("卸载时保存请求异常:", error); + } + } + }; + }, []); + + // Separately listen for knowledge base loading state, load user configuration when knowledge base loading is complete and in active state + useEffect(() => { + // Only execute when second page is active, knowledge base is loaded, and user configuration hasn't been loaded yet + if ( + isActive && + kbState.knowledgeBases.length > 0 && + !kbState.isLoading && + !hasLoadedRef.current + ) { + const loadConfig = async () => { + try { + await loadUserSelectedKnowledgeBases(); + hasLoadedRef.current = true; + } catch (error) { + log.error("加载用户配置失败:", error); + } + }; + + loadConfig(); } - }, [kbState.activeKnowledgeBase?.id]); + }, [isActive, kbState.knowledgeBases.length, kbState.isLoading]); + + // Auto-deselect incompatible knowledge bases once after selections are loaded and page is active + useEffect(() => { + if (!isActive) return; + if (!hasLoadedRef.current) return; // ensure user selections loaded + if (kbState.isLoading) return; // avoid running during list loading + if (hasCleanedRef.current) return; // run once per entry + + const embeddingName = modelConfig?.embedding?.modelName?.trim() || ""; + const multiEmbeddingName = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + + const allowedModels = new Set(); + if (embeddingName) allowedModels.add(embeddingName); + if (multiEmbeddingName) allowedModels.add(multiEmbeddingName); + + const currentSelected = kbState.selectedIds; + if (currentSelected.length === 0) { + hasCleanedRef.current = true; + return; + } + + // If both empty, clear all + if (allowedModels.size === 0) { + shouldPersistSelectionRef.current = true; + kbDispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, + payload: [], + }); + hasUserInteractedRef.current = true; + setShowAutoDeselectModal(true); + hasCleanedRef.current = true; + return; + } + + const filtered = currentSelected.filter((id) => { + const kb = kbState.knowledgeBases.find((k) => k.id === id); + if (!kb) return false; + // DataMate knowledge bases are always allowed (skip model check) + if (kb.source === "datamate") return true; + return allowedModels.has(kb.embeddingModel); + }); + + if (filtered.length !== currentSelected.length) { + shouldPersistSelectionRef.current = true; + kbDispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, + payload: filtered, + }); + hasUserInteractedRef.current = true; + setShowAutoDeselectModal(true); + } + + hasCleanedRef.current = true; + }, [ + isActive, + kbState.isLoading, + kbState.knowledgeBases, + modelConfig?.embedding?.modelName, + modelConfig?.multiEmbedding?.modelName, + kbDispatch, + ]); // Generate unique knowledge base name const generateUniqueKbName = (existingKbs: KnowledgeBase[]): string => { @@ -324,6 +495,7 @@ function DataConfig({ isActive }: DataConfigProps) { ) => { // Only reset creation mode when user clicks if (fromUserClick) { + hasUserInteractedRef.current = true; // Mark user interaction setIsCreatingMode(false); // Reset creating mode setHasClickedUpload(false); // Reset upload button click state } @@ -477,7 +649,7 @@ function DataConfig({ isActive }: DataConfigProps) { // Check if ModelEngine is enabled to determine sync behavior if (modelEngineEnabled) { // When ModelEngine is enabled, sync both local and DataMate knowledge bases - await refreshKnowledgeBaseDataWithDataMate(true); + await refreshKnowledgeBaseDataWithDataMate(); } else { // When ModelEngine is disabled, only sync local knowledge bases await refreshKnowledgeBaseData(true); @@ -748,6 +920,46 @@ function DataConfig({ isActive }: DataConfigProps) { } }, [newlyCreatedKbId, viewingDocuments.length]); + // Handle knowledge base selection + const handleSelectKnowledgeBase = (id: string) => { + hasUserInteractedRef.current = true; // Mark user interaction + selectKnowledgeBase(id); + // Persist selection immediately after reducer updates state + shouldPersistSelectionRef.current = true; + + // When selecting knowledge base also get latest data (low priority background operation) + setTimeout(async () => { + try { + // Use lower priority to refresh data as this is not a critical operation + await refreshKnowledgeBaseData(true); + } catch (error) { + log.error("刷新知识库数据失败:", error); + // Error doesn't affect user experience + } + }, 500); // Delay execution, lower priority + }; + + // Persist user selection changes immediately when flagged + useEffect(() => { + if (!isActive) return; + if (!shouldPersistSelectionRef.current) return; + let cancelled = false; + (async () => { + try { + await saveUserSelectedKnowledgeBases(); + } catch (error) { + log.error("保存用户选择的知识库失败:", error); + } finally { + if (!cancelled) { + shouldPersistSelectionRef.current = false; + } + } + })(); + return () => { + cancelled = true; + }; + }, [kbState.selectedIds, isActive, saveUserSelectedKnowledgeBases]); + // Update active knowledge base ID in polling service when component initializes or active knowledge base changes useEffect(() => { if (kbState.activeKnowledgeBase) { @@ -829,14 +1041,13 @@ function DataConfig({ isActive }: DataConfigProps) { xxl={TWO_COLUMN_LAYOUT.LEFT_COLUMN.xxl} > - setSourceFilter(Array.isArray(values) ? values : [values]) + setSourceFilter( + Array.isArray(values) ? values : values ? [values] : [] + ) } modelFilter={modelFilter} onModelFilterChange={(values) => - setModelFilter(Array.isArray(values) ? values : [values]) + setModelFilter( + Array.isArray(values) ? values : values ? [values] : [] + ) } /> @@ -945,6 +1160,36 @@ function DataConfig({ isActive }: DataConfigProps) {
+ setShowAutoDeselectModal(false)} + onCancel={() => setShowAutoDeselectModal(false)} + okText={t("common.confirm")} + cancelButtonProps={{ style: { display: "none" } }} + centered + okButtonProps={{ type: "primary", danger: true }} + getContainer={() => contentRef.current || document.body} + > +
+ +
+
+ {t("embedding.knowledgeBaseAutoDeselectModal.title")} +
+
+ {t("embedding.knowledgeBaseAutoDeselectModal.content")} +
+
+
+
+ void; onClick: (kb: KnowledgeBase) => void; - onDelete?: (id: string) => void; - onSync?: () => void; - onCreateNew?: () => void; + onDelete: (id: string) => void; + onSync: () => void; + onCreateNew: () => void; onDataMateConfig?: () => void; showDataMateConfig?: boolean; // Control whether to show DataMate config button isSelectable: (kb: KnowledgeBase) => boolean; @@ -73,7 +72,6 @@ interface KnowledgeBaseListProps { } const KnowledgeBaseList: React.FC = ({ - showSelection = true, // New: Default to true knowledgeBases, selectedIds, activeKnowledgeBase, @@ -233,56 +231,52 @@ const KnowledgeBaseList: React.FC = ({
- {onCreateNew && ( - + - )} - {onSync && ( - - )} + + + {t("knowledgeBase.button.sync")} + {showDataMateConfig && ( - )} +
= ({ )} {kb.embeddingModel !== "unknown" && kb.embeddingModel !== currentEmbeddingModel && - kb.source == "nexent" && ( + kb.source !== "datamate" && ( diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index 6386aa54f..d96c8fa33 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -110,9 +110,9 @@ export const KnowledgeBaseContext = createContext<{ isKnowledgeBaseSelectable: (kb: KnowledgeBase) => boolean; hasKnowledgeBaseModelMismatch: (kb: KnowledgeBase) => boolean; refreshKnowledgeBaseData: (forceRefresh?: boolean) => Promise; - refreshKnowledgeBaseDataWithDataMate: ( - forceRefresh?: boolean - ) => Promise; + refreshKnowledgeBaseDataWithDataMate: () => Promise; + loadUserSelectedKnowledgeBases: () => Promise; + saveUserSelectedKnowledgeBases: () => Promise; }>({ state: { knowledgeBases: [], @@ -133,6 +133,8 @@ export const KnowledgeBaseContext = createContext<{ hasKnowledgeBaseModelMismatch: () => false, refreshKnowledgeBaseData: async () => {}, refreshKnowledgeBaseDataWithDataMate: async () => {}, + loadUserSelectedKnowledgeBases: async () => {}, + saveUserSelectedKnowledgeBases: async () => false, }); // Custom hook for using the context @@ -203,12 +205,55 @@ export const KnowledgeBaseProvider: React.FC = ({ [state.currentEmbeddingModel] ); - // Load knowledge base data (supports force fetch from server) - optimized with useCallback + // Load user selected knowledge bases from backend + const loadUserSelectedKnowledgeBases = useCallback(async () => { + try { + const userConfig = await userConfigService.loadKnowledgeList(); + if (userConfig) { + let allSelectedNames: string[] = []; + + // Handle new format (selectedKbNames array) + if ( + userConfig.selectedKbNames && + userConfig.selectedKbNames.length > 0 + ) { + allSelectedNames = userConfig.selectedKbNames; + } + // Fallback to legacy grouped format for backward compatibility + else if (userConfig.nexent || userConfig.datamate) { + allSelectedNames = [ + ...(userConfig.nexent || []), + ...(userConfig.datamate || []), + ]; + } + + if (allSelectedNames.length > 0) { + // Find matching knowledge base IDs based on index names + const selectedIds = state.knowledgeBases + .filter((kb) => allSelectedNames.includes(kb.id)) + .map((kb) => kb.id); + + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, + payload: selectedIds, + }); + } + } + } catch (error) { + log.error(t("knowledgeBase.error.loadSelected"), error); + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, + payload: t("knowledgeBase.error.loadSelectedRetry"), + }); + } + }, [state.knowledgeBases]); + + // Load knowledge base data (supports force fetch from server and load selected status) - optimized with useCallback const fetchKnowledgeBases = useCallback( async ( skipHealthCheck = true, - includeDataMateSync = true, - forceRefresh = false + shouldLoadSelected = true, + includeDataMateSync = true ) => { // If already loading, return directly if (state.isLoading) { @@ -224,14 +269,18 @@ export const KnowledgeBaseProvider: React.FC = ({ // Get knowledge base list data directly from server const kbs = await knowledgeBaseService.getKnowledgeBasesInfo( skipHealthCheck, - includeDataMateSync, - forceRefresh + includeDataMateSync ); dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.FETCH_SUCCESS, payload: kbs, }); + + // After loading knowledge bases, automatically load user's selected knowledge bases if requested + if (shouldLoadSelected && kbs.length > 0) { + await loadUserSelectedKnowledgeBases(); + } } catch (error) { log.error(t("knowledgeBase.error.fetchList"), error); dispatch({ @@ -242,7 +291,7 @@ export const KnowledgeBaseProvider: React.FC = ({ dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.LOADING, payload: false }); } }, - [state.isLoading, t] + [state.isLoading, t, loadUserSelectedKnowledgeBases] ); // Select knowledge base - memoized with useCallback @@ -353,12 +402,50 @@ export const KnowledgeBaseProvider: React.FC = ({ [state.knowledgeBases, state.selectedIds, state.activeKnowledgeBase] ); + // Save user selected knowledge bases to backend + const saveUserSelectedKnowledgeBases = useCallback(async () => { + try { + // Get selected knowledge bases grouped by source + const selectedKnowledgeBases = state.knowledgeBases.filter((kb) => + state.selectedIds.includes(kb.id) + ); + + // Group knowledge bases by source + const knowledgeBySource: { nexent?: string[]; datamate?: string[] } = {}; + selectedKnowledgeBases.forEach((kb) => { + const source = kb.source as keyof typeof knowledgeBySource; + if (!knowledgeBySource[source]) { + knowledgeBySource[source] = []; + } + knowledgeBySource[source]!.push(kb.id); + }); + + const result = + await userConfigService.updateKnowledgeList(knowledgeBySource); + if (!result) { + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, + payload: t("knowledgeBase.error.saveSelected"), + }); + return false; + } + return true; + } catch (error) { + log.error(t("knowledgeBase.error.saveSelected"), error); + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, + payload: t("knowledgeBase.error.saveSelectedRetry"), + }); + return false; + } + }, [state.knowledgeBases, state.selectedIds, t]); + // Add a function to refresh the knowledge base data const refreshKnowledgeBaseData = useCallback( async (forceRefresh = false) => { try { // Get latest knowledge base data directly from server, but don't reload user selections, include DataMate sync to prevent DataMate KBs from disappearing - await fetchKnowledgeBases(false, true, forceRefresh); + await fetchKnowledgeBases(false, false, true); // If there is an active knowledge base, also refresh its document information if (state.activeKnowledgeBase) { @@ -393,47 +480,41 @@ export const KnowledgeBaseProvider: React.FC = ({ ); // Add a function to refresh the knowledge base data with DataMate sync and create records - const refreshKnowledgeBaseDataWithDataMate = useCallback( - async (forceRefresh = false) => { - try { - // Get latest knowledge base data directly from server, which includes DataMate sync - // The getKnowledgeBasesInfo method already handles syncDataMateAndCreateRecords internally - await fetchKnowledgeBases(false, true, forceRefresh); - - // If there is an active knowledge base, also refresh its document information - if (state.activeKnowledgeBase) { - // Publish document update event to notify document list component to refresh document data - try { - const documents = await knowledgeBaseService.getAllFiles( - state.activeKnowledgeBase.id, - state.activeKnowledgeBase.source - ); - log.log("documents", documents); - window.dispatchEvent( - new CustomEvent("documentsUpdated", { - detail: { - kbId: state.activeKnowledgeBase.id, - documents, - }, - }) - ); - } catch (error) { - log.error("Failed to refresh document information:", error); - } + const refreshKnowledgeBaseDataWithDataMate = useCallback(async () => { + try { + // Get latest knowledge base data directly from server, which includes DataMate sync + // The getKnowledgeBasesInfo method already handles syncDataMateAndCreateRecords internally + await fetchKnowledgeBases(false, false, true); + + // If there is an active knowledge base, also refresh its document information + if (state.activeKnowledgeBase) { + // Publish document update event to notify document list component to refresh document data + try { + const documents = await knowledgeBaseService.getAllFiles( + state.activeKnowledgeBase.id, + state.activeKnowledgeBase.source + ); + log.log("documents", documents); + window.dispatchEvent( + new CustomEvent("documentsUpdated", { + detail: { + kbId: state.activeKnowledgeBase.id, + documents, + }, + }) + ); + } catch (error) { + log.error("Failed to refresh document information:", error); } - } catch (error) { - log.error( - "Failed to refresh knowledge base data with DataMate:", - error - ); - dispatch({ - type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, - payload: "Failed to refresh knowledge base data with DataMate", - }); } - }, - [fetchKnowledgeBases, state.activeKnowledgeBase] - ); + } catch (error) { + log.error("Failed to refresh knowledge base data with DataMate:", error); + dispatch({ + type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, + payload: "Failed to refresh knowledge base data with DataMate", + }); + } + }, [fetchKnowledgeBases, state.activeKnowledgeBase]); // Initial data loading - with optimized dependencies useEffect(() => { @@ -467,7 +548,7 @@ export const KnowledgeBaseProvider: React.FC = ({ }); // Reload knowledge base list when model changes - fetchKnowledgeBases(true, true); + fetchKnowledgeBases(true, true, true); } }; @@ -482,7 +563,7 @@ export const KnowledgeBaseProvider: React.FC = ({ }); // Reload knowledge base list when model changes - fetchKnowledgeBases(true, true); + fetchKnowledgeBases(true, true, true); } }; @@ -494,7 +575,8 @@ export const KnowledgeBaseProvider: React.FC = ({ // If first time loading data or force refresh, get from server if (!initialDataLoaded || forceRefresh) { - fetchKnowledgeBases(false, true); + // For force refresh, don't reload user selections to preserve current state + fetchKnowledgeBases(false, !forceRefresh, true); initialDataLoaded = true; } }; @@ -542,6 +624,8 @@ export const KnowledgeBaseProvider: React.FC = ({ hasKnowledgeBaseModelMismatch, refreshKnowledgeBaseData, refreshKnowledgeBaseDataWithDataMate, + loadUserSelectedKnowledgeBases, + saveUserSelectedKnowledgeBases, }), [ state, @@ -553,6 +637,8 @@ export const KnowledgeBaseProvider: React.FC = ({ isKnowledgeBaseSelectable, refreshKnowledgeBaseData, refreshKnowledgeBaseDataWithDataMate, + loadUserSelectedKnowledgeBases, + saveUserSelectedKnowledgeBases, ] ); diff --git a/frontend/app/[locale]/knowledges/page.tsx b/frontend/app/[locale]/knowledges/page.tsx index 3d1f494b9..835018df3 100644 --- a/frontend/app/[locale]/knowledges/page.tsx +++ b/frontend/app/[locale]/knowledges/page.tsx @@ -64,7 +64,7 @@ export default function KnowledgesContent() { style={{ width: "100%", height: "100%" }} >
- +
) : null} diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 25675282c..a157f55e2 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -138,8 +138,7 @@ export const API_ENDPOINTS = { chunkDetail: (indexName: string, chunkId: string) => `${API_BASE_URL}/indices/${indexName}/chunk/${chunkId}`, // Update knowledge base info - updateIndex: (indexName: string) => - `${API_BASE_URL}/indices/${indexName}`, + updateIndex: (indexName: string) => `${API_BASE_URL}/indices/${indexName}`, searchHybrid: `${API_BASE_URL}/indices/search/hybrid`, summary: (indexName: string) => `${API_BASE_URL}/summary/${indexName}/auto_summary`, @@ -168,10 +167,10 @@ export const API_ENDPOINTS = { saveDataMateUrl: `${API_BASE_URL}/config/save_datamate_url`, }, tenantConfig: { + loadKnowledgeList: `${API_BASE_URL}/tenant_config/load_knowledge_list`, + updateKnowledgeList: `${API_BASE_URL}/tenant_config/update_knowledge_list`, deploymentVersion: `${API_BASE_URL}/tenant_config/deployment_version`, - loadKnowledgeList: `${API_BASE_URL}/tenant_config/knowledge_list`, - updateKnowledgeList: `${API_BASE_URL}/tenant_config/knowledge_list`, - } as const, + }, mcp: { tools: `${API_BASE_URL}/mcp/tools`, add: `${API_BASE_URL}/mcp/add`, @@ -255,7 +254,8 @@ export const API_ENDPOINTS = { addMember: (groupId: number) => `${API_BASE_URL}/groups/${groupId}/members`, removeMember: (groupId: number, userId: string) => `${API_BASE_URL}/groups/${groupId}/members/${userId}`, - default: (tenantId: string) => `${API_BASE_URL}/groups/tenants/${tenantId}/default`, + default: (tenantId: string) => + `${API_BASE_URL}/groups/tenants/${tenantId}/default`, }, invitations: { list: `${API_BASE_URL}/invitations/list`, diff --git a/frontend/services/configService.ts b/frontend/services/configService.ts index d22238760..57ae9751c 100644 --- a/frontend/services/configService.ts +++ b/frontend/services/configService.ts @@ -9,11 +9,6 @@ import log from "@/lib/logger"; const fetch = fetchWithAuth; export class ConfigService { - // In-flight dedupe and short-term cache for loadConfigToFrontend - private _loadInFlight: Promise | null = null; - private _lastLoadTs: number | null = null; - private readonly _LOAD_TTL_MS = 5 * 1000; // 5 seconds - // Save global configuration to backend async saveConfigToBackend(config: GlobalConfig): Promise { try { @@ -39,60 +34,43 @@ export class ConfigService { // Add: Load configuration from backend and write to localStorage async loadConfigToFrontend(): Promise { - // Dedupe concurrent calls and avoid repeat calls within short TTL - if (this._loadInFlight) return this._loadInFlight; - const now = Date.now(); - if (this._lastLoadTs && now - this._lastLoadTs < this._LOAD_TTL_MS) { - return Promise.resolve(true); - } + try { + const response = await fetch(API_ENDPOINTS.config.load, { + method: "GET", + headers: getAuthHeaders(), + }); + if (!response.ok) { + const errorData = await response.json(); + log.error("Failed to load configuration:", errorData); + return false; + } + const result = await response.json(); + const config = result.config; + if (config) { + // Use the conversion function of configStore + const frontendConfig = ConfigStore.transformBackend2Frontend(config); - this._loadInFlight = (async (): Promise => { - try { - const response = await fetch(API_ENDPOINTS.config.load, { - method: "GET", - headers: getAuthHeaders(), - }); - if (!response.ok) { - const errorData = await response.json(); - log.error("Failed to load configuration:", errorData); - this._lastLoadTs = Date.now(); - return false; + // Write to localStorage separately + if (frontendConfig.app) { + localStorage.setItem("app", JSON.stringify(frontendConfig.app)); + } + if (frontendConfig.models) { + localStorage.setItem("model", JSON.stringify(frontendConfig.models)); } - const result = await response.json(); - const config = result.config; - if (config) { - // Use the conversion function of configStore - const frontendConfig = ConfigStore.transformBackend2Frontend(config); - - // Write to localStorage separately - if (frontendConfig.app) { - localStorage.setItem("app", JSON.stringify(frontendConfig.app)); - } - if (frontendConfig.models) { - localStorage.setItem("model", JSON.stringify(frontendConfig.models)); - } - - // Trigger configuration reload and dispatch event - if (typeof window !== "undefined") { - const configStore = ConfigStore.getInstance(); - configStore.reloadFromStorage(); - } - this._lastLoadTs = Date.now(); - return true; + // Trigger configuration reload and dispatch event + if (typeof window !== "undefined") { + const configStore = ConfigStore.getInstance(); + configStore.reloadFromStorage(); } - return false; - } catch (error) { - log.error("Load configuration request exception:", error); - this._lastLoadTs = Date.now(); - return false; - } finally { - // clear in-flight after completion - this._loadInFlight = null; - } - })(); - return this._loadInFlight; + return true; + } + return false; + } catch (error) { + log.error("Load configuration request exception:", error); + return false; + } } } From 2ac2cfbb53a6f57032d9f6ca83937631a8bfe16a Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Mon, 2 Feb 2026 09:31:51 +0800 Subject: [PATCH 025/167] =?UTF-8?q?=F0=9F=94=A7=20Refactor=20services=20an?= =?UTF-8?q?d=20imports=20for=20knowledge=20base=20integration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Updated import paths to reflect the new structure under `services.knowledge_base`. - Modified various files to ensure proper routing and service functionality for the knowledge base applications. - Enhanced error handling in the Dify service to improve robustness. - Adjusted test cases to align with the new import paths and service structure. --- .../agentConfig/tool/ToolConfigModal.tsx | 166 +++++++++++++++++- frontend/public/locales/en/common.json | 9 + frontend/public/locales/zh/common.json | 9 + 3 files changed, 182 insertions(+), 2 deletions(-) diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index 57d1d2aa0..398682df4 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect } from "react"; +import { useState, useEffect, useCallback, useMemo } from "react"; import { useTranslation } from "react-i18next"; import { Modal, @@ -11,14 +11,21 @@ import { Form, message, Select, + Button, + Space, + Popover, } from "antd"; import { useQueryClient } from "@tanstack/react-query"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; +import { CloseCircleOutlined } from "@ant-design/icons"; import { TOOL_PARAM_TYPES, getToolParamOptions } from "@/const/agentConfig"; import { ToolParam, Tool } from "@/types/agentConfig"; +import { KnowledgeBase } from "@/types/knowledgeBase"; import ToolTestPanel from "./ToolTestPanel"; import { updateToolConfig } from "@/services/agentConfigService"; +import KnowledgeBaseSelectorModal from "@/components/tool-config/KnowledgeBaseSelectorModal"; +import { useKnowledgeBasesForToolConfig } from "@/hooks/useKnowledgeBaseSelector"; export interface ToolConfigModalProps { isOpen: boolean; @@ -31,6 +38,13 @@ export interface ToolConfigModalProps { currentAgentId?: number; } +// Tool types that require knowledge base selection +const TOOLS_REQUIRING_KB_SELECTION = [ + "knowledge_base_search", + "dify_search", + "datamate_search", +]; + export default function ToolConfigModal({ isOpen, onCancel, @@ -50,6 +64,27 @@ export default function ToolConfigModal({ // Tool test panel visibility state const [testPanelVisible, setTestPanelVisible] = useState(false); + + // Knowledge base selector state + const [kbSelectorVisible, setKbSelectorVisible] = useState(false); + const [currentKbParamIndex, setCurrentKbParamIndex] = useState(null); + const [selectedKbIds, setSelectedKbIds] = useState([]); + const [selectedKbDisplayNames, setSelectedKbDisplayNames] = useState([]); + + // Fetch knowledge bases for tool config + const { knowledgeBases, isLoading: kbLoading } = useKnowledgeBasesForToolConfig(); + + // Check if current tool requires knowledge base selection + const toolRequiresKbSelection = useMemo(() => { + return TOOLS_REQUIRING_KB_SELECTION.includes(tool?.name); + }, [tool?.name]); + + // Get index_names parameter info if exists + const indexNamesParam = useMemo(() => { + if (!toolRequiresKbSelection) return null; + return currentParams.find((param) => param.name === "index_names"); + }, [currentParams, toolRequiresKbSelection]); + // Initialize with provided params useEffect(() => { // Initialize form values @@ -59,7 +94,37 @@ export default function ToolConfigModal({ formValues[`param_${index}`] = param.value; }); form.setFieldsValue(formValues); - }, [initialParams]); + + // Parse initial index_names value for knowledge base selection + if (toolRequiresKbSelection) { + const indexNamesParam = initialParams.find((p) => p.name === "index_names"); + if (indexNamesParam?.value) { + try { + // Try to parse as JSON array + const parsed = typeof indexNamesParam.value === "string" + ? JSON.parse(indexNamesParam.value) + : indexNamesParam.value; + if (Array.isArray(parsed)) { + setSelectedKbIds(parsed); + } + } catch { + // If not JSON, might be comma-separated string + if (typeof indexNamesParam.value === "string") { + const ids = indexNamesParam.value.split(",").filter(Boolean); + setSelectedKbIds(ids); + } + } + } + } + }, [initialParams, toolRequiresKbSelection]); + + // Update selected KB display names when IDs change + useEffect(() => { + const names = knowledgeBases + .filter((kb) => selectedKbIds.includes(kb.id)) + .map((kb) => kb.name); + setSelectedKbDisplayNames(names); + }, [selectedKbIds, knowledgeBases]); // Watch all form values and sync to currentParams const formValues = Form.useWatch([], form); @@ -154,6 +219,7 @@ export default function ToolConfigModal({ setTestPanelVisible(false); onCancel(); }; + // Handle tool testing - toggle test panel const handleTestTool = () => { setTestPanelVisible(!testPanelVisible); @@ -164,6 +230,84 @@ export default function ToolConfigModal({ setTestPanelVisible(false); }; + // Open knowledge base selector + const openKbSelector = (paramIndex: number) => { + setCurrentKbParamIndex(paramIndex); + setKbSelectorVisible(true); + }; + + // Handle knowledge base selection confirm + const handleKbConfirm = (selectedKnowledgeBases: KnowledgeBase[]) => { + const ids = selectedKnowledgeBases.map((kb) => kb.id); + const names = selectedKnowledgeBases.map((kb) => kb.name); + + setSelectedKbIds(ids); + setSelectedKbDisplayNames(names); + + // Update form value + if (currentKbParamIndex !== null) { + const param = currentParams[currentKbParamIndex]; + if (param) { + // Store as JSON array for consistency + const formFieldName = `param_${currentKbParamIndex}`; + form.setFieldValue(formFieldName, JSON.stringify(ids)); + } + } + + setKbSelectorVisible(false); + setCurrentKbParamIndex(null); + }; + + // Clear knowledge base selection + const clearKbSelection = () => { + setSelectedKbIds([]); + setSelectedKbDisplayNames([]); + + if (currentKbParamIndex !== null) { + const param = currentParams[currentKbParamIndex]; + if (param) { + const formFieldName = `param_${currentKbParamIndex}`; + form.setFieldValue(formFieldName, []); + } + } + }; + + // Get tool type for knowledge base selector + const getToolType = (): "knowledge_base_search" | "dify_search" | "datamate_search" => { + const name = tool?.name; + if (name === "dify_search") return "dify_search"; + if (name === "datamate_search") return "datamate_search"; + return "knowledge_base_search"; + }; + + // Render knowledge base selector input (no button, just clickable input) + const renderKbSelectorInput = (param: ToolParam, index: number) => { + return ( + openKbSelector(index)} + className="cursor-pointer bg-white" + suffix={ + selectedKbIds.length > 0 ? ( +
+ {/* Knowledge Base Selector Modal */} + setKbSelectorVisible(false)} + onConfirm={handleKbConfirm} + selectedIds={selectedKbIds} + toolType={getToolType()} + knowledgeBases={knowledgeBases} + isLoading={kbLoading} + /> ); } diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index b32b6062b..d25dbb7d2 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -404,6 +404,12 @@ "toolConfig.button.closeTest": "Close Test Tool", "toolConfig.toolTest.manualInput": "Manual Input", "toolConfig.toolTest.parseMode": "Parse Mode", + "toolConfig.button.selectKnowledgeBases": "Select Knowledge Bases", + "toolConfig.input.knowledgeBaseSelector.placeholder": "Click to select {{name}}", + "toolConfig.knowledgeBaseSelector.title.default": "Select Knowledge Base", + "toolConfig.knowledgeBaseSelector.title.local": "Select Local Knowledge Base", + "toolConfig.knowledgeBaseSelector.title.dify": "Select Dify Knowledge Base", + "toolConfig.knowledgeBaseSelector.title.datamate": "Select DataMate Knowledge Base", "toolPool.title": "Select tools", "toolPool.loading": "Loading...", "toolPool.loadingTools": "Loading tools...", @@ -463,6 +469,8 @@ "knowledgeBase.button.syncDataMate": "Sync DataMate Knowledge Bases", "knowledgeBase.selected.prefix": "Selected", "knowledgeBase.selected.suffix": "knowledge bases for retrieval", + "knowledgeBase.selected.count": "{{count}} selected", + "knowledgeBase.button.clearSelection": "Clear Selection", "knowledgeBase.button.removeKb": "Remove knowledge base {{name}}", "knowledgeBase.tag.documents": "{{count}} Documents", "knowledgeBase.tag.chunks": "{{count}} Chunks", @@ -479,6 +487,7 @@ "knowledgeBase.filter.clear": "Clear filters", "knowledgeBase.source.nexent": "Nexent", "knowledgeBase.source.datamate": "DataMate", + "knowledgeBase.source.dify": "Dify", "knowledgeBase.filter.allSources": "All Sources", "knowledgeBase.filter.allModels": "All Models", "knowledgeBase.filter.source": "Source", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 0cab729d5..e5cdc74dd 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -405,6 +405,12 @@ "toolConfig.button.closeTest": "关闭工具测试", "toolConfig.toolTest.manualInput": "手动输入", "toolConfig.toolTest.parseMode": "解析模式", + "toolConfig.button.selectKnowledgeBases": "选择知识库", + "toolConfig.input.knowledgeBaseSelector.placeholder": "点击选择{{name}}", + "toolConfig.knowledgeBaseSelector.title.default": "选择知识库", + "toolConfig.knowledgeBaseSelector.title.local": "选择本地知识库", + "toolConfig.knowledgeBaseSelector.title.dify": "选择 Dify 知识库", + "toolConfig.knowledgeBaseSelector.title.datamate": "选择 DataMate 知识库", "toolPool.title": "选择智能体的工具", "toolPool.loading": "加载中...", "toolPool.loadingTools": "加载工具中...", @@ -464,6 +470,8 @@ "knowledgeBase.button.syncDataMate": "同步DataMate知识库", "knowledgeBase.selected.prefix": "已选择", "knowledgeBase.selected.suffix": "个知识库用于知识检索", + "knowledgeBase.selected.count": "已选择 {{count}} 个", + "knowledgeBase.button.clearSelection": "清除选择", "knowledgeBase.button.removeKb": "移除知识库 {{name}}", "knowledgeBase.tag.documents": "{{count}} 文档", "knowledgeBase.tag.chunks": "{{count}} 分块", @@ -479,6 +487,7 @@ "knowledgeBase.filter.model.placeholder": "筛选模型", "knowledgeBase.source.nexent": "Nexent", "knowledgeBase.source.datamate": "DataMate", + "knowledgeBase.source.dify": "Dify", "knowledgeBase.filter.allSources": "全部来源", "knowledgeBase.filter.allModels": "全部模型", "knowledgeBase.filter.source": "来源", From 31825834a548e28bd31ca4a7a3b84bba617ce8fe Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Sat, 31 Jan 2026 15:55:35 +0800 Subject: [PATCH 026/167] =?UTF-8?q?=E2=9C=A8=20Enhance=20tool=20configurat?= =?UTF-8?q?ion=20and=20search=20tools?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/config_app.py | 2 + backend/apps/dify_app.py | 71 +++ backend/services/dify_service.py | 159 +++++ .../components/agentConfig/ToolManagement.tsx | 57 +- .../tool/KnowledgeBaseSelectionModal.tsx | 146 ----- .../tool/KnowledgeBaseToolConfig.tsx | 267 --------- .../knowledges/KnowledgeBaseConfiguration.tsx | 311 ++++++++-- .../knowledge/KnowledgeBaseList.tsx | 198 +++---- .../contexts/KnowledgeBaseContext.tsx | 190 ++++-- frontend/app/[locale]/knowledges/page.tsx | 2 +- frontend/services/api.ts | 12 +- frontend/services/configService.ts | 86 +-- test/backend/agents/test_create_agent_info.py | 293 ++++++++- test/backend/app/test_dify_app.py | 559 ++++++++++++++++++ test/backend/services/test_dify_service.py | 501 ++++++++++++++++ .../core/tools/test_datamate_search_tool.py | 2 +- .../tools/test_knowledge_base_search_tool.py | 47 -- 17 files changed, 2141 insertions(+), 762 deletions(-) create mode 100644 backend/apps/dify_app.py create mode 100644 backend/services/dify_service.py delete mode 100644 frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseSelectionModal.tsx delete mode 100644 frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx create mode 100644 test/backend/app/test_dify_app.py create mode 100644 test/backend/services/test_dify_service.py diff --git a/backend/apps/config_app.py b/backend/apps/config_app.py index db979ac98..8691b15e0 100644 --- a/backend/apps/config_app.py +++ b/backend/apps/config_app.py @@ -8,6 +8,7 @@ from apps.config_sync_app import router as config_sync_router from apps.datamate_app import router as datamate_router from apps.vectordatabase_app import router as vectordatabase_router +from apps.dify_app import router as dify_router from apps.file_management_app import file_management_config_router as file_manager_router from apps.image_app import router as proxy_router from apps.knowledge_summary_app import router as summary_router @@ -50,6 +51,7 @@ app.include_router(file_manager_router) app.include_router(proxy_router) app.include_router(tool_config_router) +app.include_router(dify_router) # Choose user management router based on IS_SPEED_MODE if IS_SPEED_MODE: diff --git a/backend/apps/dify_app.py b/backend/apps/dify_app.py new file mode 100644 index 000000000..c7d9321af --- /dev/null +++ b/backend/apps/dify_app.py @@ -0,0 +1,71 @@ +""" +Dify App Layer +FastAPI endpoints for Dify knowledge base operations. + +This module provides API endpoints to interact with Dify's datasets API, +including fetching knowledge bases and transforming responses to a format +compatible with the frontend. +""" +import logging +from http import HTTPStatus +from typing import Optional + +from fastapi import APIRouter, Header, HTTPException, Query +from fastapi.responses import JSONResponse + +from services.dify_service import fetch_dify_datasets_impl +from utils.auth_utils import get_current_user_id + +router = APIRouter(prefix="/dify") +logger = logging.getLogger("dify_app") + + +@router.get("/datasets") +async def fetch_dify_datasets_api( + dify_api_base: str = Query(..., description="Dify API base URL"), + api_key: str = Query(..., description="Dify API key"), + authorization: Optional[str] = Header(None) +): + """ + Fetch datasets (knowledge bases) from Dify API. + + Returns knowledge bases in a format consistent with DataMate for frontend compatibility. + """ + try: + # Normalize URL by removing trailing slash + dify_api_base = dify_api_base.rstrip('/') + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Failed to fetch Dify datasets: {str(e)}" + ) + + try: + _, tenant_id = get_current_user_id(authorization) + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Failed to fetch Dify datasets: {str(e)}" + ) + + try: + result = fetch_dify_datasets_impl( + dify_api_base=dify_api_base, + api_key=api_key, + ) + return JSONResponse( + status_code=HTTPStatus.OK, + content=result + ) + except ValueError as e: + logger.warning(f"Invalid Dify configuration: {e}") + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=str(e) + ) + except Exception as e: + logger.error(f"Failed to fetch Dify datasets: {e}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Failed to fetch Dify datasets: {str(e)}" + ) diff --git a/backend/services/dify_service.py b/backend/services/dify_service.py new file mode 100644 index 000000000..15524576b --- /dev/null +++ b/backend/services/dify_service.py @@ -0,0 +1,159 @@ +""" +Dify Service Layer +Handles API calls to Dify for knowledge base operations. + +This service layer provides functionality to interact with Dify's API, +including fetching datasets (knowledge bases) and transforming responses +to DataMate-compatible format for frontend compatibility. +""" +import json +import logging + +import httpx +from typing import Any, Dict + +logger = logging.getLogger("dify_service") + + +def fetch_dify_datasets_impl( + dify_api_base: str, + api_key: str, +) -> Dict[str, Any]: + """ + Fetch datasets (knowledge bases) from Dify API and transform to DataMate-compatible format. + + Args: + dify_api_base: Dify API base URL + api_key: Dify API key with Bearer token + + Returns: + Dictionary containing knowledge bases in DataMate-compatible format: + { + "indices": ["dataset_id_1", "dataset_id_2", ...], + "count": 2, + "indices_info": [ + { + "name": "dataset_id_1", + "display_name": "知识库名称", + "stats": { + "base_info": { + "doc_count": 10, + "chunk_count": 100, + "store_size": "", + "process_source": "Dify", + "embedding_model": "", + "embedding_dim": 0, + "creation_date": timestamp, + "update_date": timestamp + }, + "search_performance": { + "total_search_count": 0, + "hit_count": 0 + } + } + }, + ... + ], + "pagination": { + "embedding_available": False + } + } + + Raises: + ValueError: If invalid parameters provided + Exception: If API request fails + """ + # Validate inputs + if not dify_api_base or not isinstance(dify_api_base, str): + raise ValueError( + "dify_api_base is required and must be a non-empty string") + + if not api_key or not isinstance(api_key, str): + raise ValueError("api_key is required and must be a non-empty string") + + # Normalize API base URL + api_base = dify_api_base.rstrip("/") + + # Build request URL with pagination + url = f"{api_base}/v1/datasets" + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + logger.info(f"Fetching Dify datasets from: {url}") + + try: + with httpx.Client(timeout=30, verify=False) as client: + response = client.get(url, headers=headers) + response.raise_for_status() + + result = response.json() + + # Parse Dify API response + datasets_data = result.get("data", []) + + # Transform to DataMate-compatible format + indices = [] + indices_info = [] + embedding_available = False # Default value if no datasets or all skipped + + for dataset in datasets_data: + dataset_id = dataset.get("id", "") + dataset_name = dataset.get("name", "") + document_count = dataset.get("document_count", 0) + created_at = dataset.get("created_at", 0) + updated_at = dataset.get("updated_at", 0) + embedding_available = dataset.get("embedding_available", False) + + if not dataset_id: + continue + + indices.append(dataset_id) + + # Create indices_info entry (compatible with DataMate format) + indices_info.append({ + "name": dataset_id, + "display_name": dataset_name, + "stats": { + "base_info": { + "doc_count": document_count, + "chunk_count": 0, # Dify doesn't provide chunk count directly + "store_size": "", + "process_source": "Dify", + "embedding_model": dataset.get("embedding_model", ""), + "embedding_dim": 0, + "creation_date": created_at * 1000 if created_at else 0, # Convert to milliseconds + "update_date": updated_at * 1000 if updated_at else 0 + }, + "search_performance": { + "total_search_count": 0, + "hit_count": 0 + } + } + }) + + return { + "indices": indices, + "count": len(indices), + "indices_info": indices_info, + "pagination": { + "embedding_available": embedding_available + } + } + + except httpx.RequestError as e: + logger.error(f"Dify API request failed: {str(e)}") + raise Exception(f"Dify API request failed: {str(e)}") + except httpx.HTTPStatusError as e: + logger.error(f"Dify API HTTP error: {str(e)}") + raise Exception(f"Dify API HTTP error: {str(e)}") + except json.JSONDecodeError as e: + logger.error(f"Failed to parse Dify API response: {str(e)}") + raise Exception(f"Failed to parse Dify API response: {str(e)}") + except KeyError as e: + logger.error( + f"Unexpected Dify API response format: missing key {str(e)}") + raise Exception( + f"Unexpected Dify API response format: missing key {str(e)}") diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index 81ff59a3e..02ad0e8f6 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -1,10 +1,9 @@ "use client"; -import { useState, useEffect, useCallback, useMemo } from "react"; +import { useState, useEffect, useCallback } from "react"; import { useTranslation } from "react-i18next"; import ToolConfigModal from "./tool/ToolConfigModal"; import { ToolGroup, Tool, ToolParam } from "@/types/agentConfig"; -import { TOOL_PARAM_TYPES } from "@/const/agentConfig"; import { Tabs, Collapse } from "antd"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; import { useToolList } from "@/hooks/agent/useToolList"; @@ -51,12 +50,6 @@ export default function ToolManagement({ const [selectedTool, setSelectedTool] = useState(null); const [toolParams, setToolParams] = useState([]); - // Memoized array representation of expandedCategories to keep hook order stable - const activeKeyArray = useMemo( - () => Array.from(expandedCategories), - [expandedCategories] - ); - // Helper function to merge tool parameters with instance parameters const mergeToolParamsWithInstance = async ( tool: Tool, @@ -71,7 +64,7 @@ export default function ToolManagement({ if (tooInstance.success && tooInstance.data) { // Merge instance params with default params - const mergedParams: ToolParam[] = + const mergedParams = defaultTool.initParams?.map((param: ToolParam) => { const instanceValue = tooInstance.data?.params?.[param.name]; return { @@ -80,33 +73,8 @@ export default function ToolManagement({ instanceValue !== undefined ? instanceValue : param.value, }; }) || - defaultTool.initParams?.slice() || + defaultTool.initParams || []; - - // If instance contains params that are not defined in default initParams, - // append them so UI can show saved values like `index_names`. - const instanceParams = tooInstance.data?.params || {}; - Object.keys(instanceParams).forEach((key) => { - const exists = mergedParams.some((p) => p.name === key); - if (!exists) { - const val = instanceParams[key]; - const inferredType = Array.isArray(val) - ? TOOL_PARAM_TYPES.ARRAY - : typeof val === "boolean" - ? TOOL_PARAM_TYPES.BOOLEAN - : typeof val === "number" - ? TOOL_PARAM_TYPES.NUMBER - : TOOL_PARAM_TYPES.STRING; - mergedParams.push({ - name: key, - type: inferredType, - required: false, - value: val, - description: "", - } as any); - } - }); - return mergedParams; } else { return defaultTool.initParams || []; @@ -249,22 +217,13 @@ export default function ToolManagement({ <> {/* Collapsible categories using Ant Design Collapse */}
- {/* Memoize activeKey array to avoid creating a new array on every render, - which could cause the Collapse to think its controlled prop changed - and trigger onChange repeatedly. */} { - const keyArray = typeof keys === "string" ? [keys] : keys || []; - const newSet = new Set(keyArray); - // Only update state if the set content actually changed - const sameSize = newSet.size === expandedCategories.size; - const allSame = sameSize - ? Array.from(newSet).every((k) => expandedCategories.has(k)) - : false; - if (!allSame) { - setExpandedCategories(newSet); - } + const newSet = new Set( + typeof keys === "string" ? [keys] : keys + ); + setExpandedCategories(newSet); }} ghost size="small" diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseSelectionModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseSelectionModal.tsx deleted file mode 100644 index 71c81484e..000000000 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseSelectionModal.tsx +++ /dev/null @@ -1,146 +0,0 @@ -"use client"; - -import { useState, useEffect, useMemo } from "react"; -import { useTranslation } from "react-i18next"; -import { Modal, Spin, message } from "antd"; -import KnowledgeBaseList from "../../../../knowledges/components/knowledge/KnowledgeBaseList"; -import knowledgeBaseService from "@/services/knowledgeBaseService"; -import { KnowledgeBase } from "@/types/knowledgeBase"; -import { ConfigStore } from "@/lib/config"; - -export interface KnowledgeBaseSelectionModalProps { - isOpen: boolean; - onCancel: () => void; - onSave: (selectedIds: string[]) => void; - initialSelectedIds: string[]; - toolName?: string; -} - -export default function KnowledgeBaseSelectionModal({ - isOpen, - onCancel, - onSave, - initialSelectedIds, - toolName, -}: KnowledgeBaseSelectionModalProps) { - const { t } = useTranslation("common"); - const [kbLoading, setKbLoading] = useState(false); - const [kbModalSelected, setKbModalSelected] = - useState(initialSelectedIds); - const [kbRawList, setKbRawList] = useState([]); - - const currentEmbeddingModel = useMemo(() => { - return ( - ConfigStore.getInstance().getModelConfig().embedding?.modelName || null - ); - }, []); - - const isKbSelectable = (kb: KnowledgeBase): boolean => { - const docCount = - typeof kb.documentCount === "number" ? kb.documentCount : 0; - const chunkCount = typeof kb.chunkCount === "number" ? kb.chunkCount : 0; - const hasContent = docCount + chunkCount > 0; - - const isModelCompatible = - kb.source !== "nexent" || - kb.embeddingModel === "unknown" || - kb.embeddingModel === currentEmbeddingModel; - - return hasContent && isModelCompatible; - }; - - useEffect(() => { - if (isOpen) { - loadKnowledgeBases(); - setKbModalSelected(initialSelectedIds); - } - }, [isOpen, initialSelectedIds]); - - const loadKnowledgeBases = async () => { - setKbLoading(true); - try { - let kbs: any[] = []; - if (toolName === "datamate_search") { - try { - const syncResult = - await knowledgeBaseService.syncDataMateAndCreateRecords(); - if (syncResult && syncResult.indices_info) { - kbs = syncResult.indices_info.map((indexInfo: any) => { - const stats = indexInfo.stats?.base_info || {}; - const kbId = indexInfo.name; - const kbName = indexInfo.display_name || indexInfo.name; - return { - id: kbId, - name: kbName, - description: "DataMate knowledge base", - documentCount: stats.doc_count || 0, - chunkCount: stats.chunk_count || 0, - createdAt: stats.creation_date || null, - updatedAt: stats.update_date || stats.creation_date || null, - embeddingModel: stats.embedding_model || "unknown", - source: "datamate", - }; - }); - } - } catch (e) { - kbs = []; - } - } else { - const nexentKbs = await knowledgeBaseService.getKnowledgeBasesInfo( - true, - false, - true - ); - kbs = nexentKbs.map((kb) => ({ ...kb, source: "nexent" })); - } - setKbRawList(kbs); - } catch (e) { - message.error( - t("toolConfig.message.kbRefreshFailed", "Failed to refresh KB list") - ); - } finally { - setKbLoading(false); - } - }; - - const handleSave = () => { - onSave(kbModalSelected); - }; - - return ( - document.body} - zIndex={1200} - open={isOpen} - // title={t("toolConfig.modal.selectKbTitle", "Select Knowledge Bases")} - onCancel={onCancel} - cancelText={t("common.button.cancel")} - okText={t("common.confirm")} - onOk={handleSave} - width={800} - > -
- { - const exists = kbModalSelected.includes(id); - const newSelected = exists - ? kbModalSelected.filter((s) => s !== id) - : [...kbModalSelected, id]; - setKbModalSelected(newSelected); - }} - onClick={() => {}} - showDataMateConfig={false} - isSelectable={isKbSelectable} - getModelDisplayName={(m: string) => m} - containerHeight="50vh" - /> -
-
- ); -} diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx deleted file mode 100644 index 50125fe05..000000000 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/KnowledgeBaseToolConfig.tsx +++ /dev/null @@ -1,267 +0,0 @@ -"use client"; - -import { useState, useEffect, useMemo } from "react"; -import { useTranslation } from "react-i18next"; -import { Spin, Select, Form } from "antd"; -import knowledgeBaseService from "@/services/knowledgeBaseService"; -import { ToolParam } from "@/types/agentConfig"; -import KnowledgeBaseSelectionModal from "./KnowledgeBaseSelectionModal"; - -export interface KnowledgeBaseToolConfigProps { - currentParams: ToolParam[]; - setCurrentParams: (p: ToolParam[]) => void; - form: any; - serverParamNames: string[]; - retrievalParamNames: string[]; - renderParamInput: (param: ToolParam, index: number) => React.ReactNode; - toolName?: string; -} - -export default function KnowledgeBaseToolConfig({ - currentParams, - setCurrentParams, - form, - serverParamNames, - retrievalParamNames, - renderParamInput, - toolName, -}: KnowledgeBaseToolConfigProps) { - const { t } = useTranslation("common"); - const [kbOptions, setKbOptions] = useState< - { label: React.ReactNode; value: string }[] - >([]); - const [kbLoading, setKbLoading] = useState(false); - const [isModalOpen, setIsModalOpen] = useState(false); - const [nameIdMap, setNameIdMap] = useState | null>( - null - ); - - useEffect(() => { - const loadKbOptions = async () => { - setKbLoading(true); - try { - let kbMap: Record = {}; - - if (toolName === "datamate_search") { - // For datamate_search, sync DataMate knowledge bases and build map - // Use index_name as key, display_name as value - const syncResult = - await knowledgeBaseService.syncDataMateAndCreateRecords(); - if (syncResult && syncResult.indices_info) { - syncResult.indices_info.forEach((indexInfo: any) => { - const kbId = indexInfo.name; // index_name - const kbName = indexInfo.display_name || indexInfo.name; // display_name or fallback - kbMap[kbId] = kbName; - }); - } - } else { - // For other tools, use the standard cached map - kbMap = await knowledgeBaseService.ensureIdNameMap(); - } - - if (Object.keys(kbMap).length > 0) { - // Build reverse map: display_name -> index_name - const reverseMap: Record = {}; - Object.entries(kbMap).forEach(([id, name]) => { - reverseMap[name] = id; - }); - setNameIdMap(reverseMap); - const options = Object.entries(kbMap).map(([id, name]) => ({ - value: id, // index_name (for saving) - label: name, // display_name (for display) - })); - setKbOptions(options); - } - } catch (error) { - console.error("Failed to load KB options:", error); - } finally { - setKbLoading(false); - } - }; - - loadKbOptions(); - }, [toolName]); - - const handleSaveSelection = (selectedIds: string[]) => { - let idx = currentParams.findIndex((p) => p.name === "index_names"); - const newParams = [...currentParams]; - // Convert display_name back to index_name for saving - const actualIds = selectedIds.map((id) => nameIdMap?.[id] || id); - if (idx === -1) { - const newParam: ToolParam = { - name: "index_names", - type: "array", - required: false, - description: "List of knowledge base ids", - value: actualIds, - } as ToolParam; - newParams.push(newParam); - idx = newParams.length - 1; - } else { - newParams[idx] = { ...newParams[idx], value: actualIds }; - } - setCurrentParams(newParams); - const fieldName = `param_${idx}`; - form.setFieldsValue({ [fieldName]: actualIds }); - setIsModalOpen(false); - }; - - const effectiveRetrievalParamNames = useMemo(() => { - const names = Array.isArray(retrievalParamNames) - ? [...retrievalParamNames] - : []; - const hasSearchModeInParams = - currentParams.findIndex((p) => p.name === "search_mode") !== -1; - if (hasSearchModeInParams && !names.includes("search_mode")) { - names.push("search_mode"); - } - return names; - }, [retrievalParamNames, currentParams]); - - return ( - <> - {/* Server parameters */} -
-
- {t("toolConfig.group.serverParameters", "Server Parameters")} -
- {serverParamNames.map((name) => { - const idx = currentParams.findIndex((p) => p.name === name); - if (idx === -1) return null; - const fieldName = `param_${idx}`; - return ( - - {currentParams[idx].name} - - } - name={fieldName} - tooltip={{ - title: currentParams[idx].description, - placement: "topLeft", - styles: { root: { maxWidth: 400 } }, - }} - > - {renderParamInput(currentParams[idx], idx)} - - ); - })} -
- - {/* Retrieval parameters */} -
-
- {t("toolConfig.group.retrievalParameters", "Retrieval Parameters")} -
- {effectiveRetrievalParamNames.map((name) => { - const idx = currentParams.findIndex((p) => p.name === name); - if (idx === -1) return null; - const fieldName = `param_${idx}`; - return ( - - {currentParams[idx].name} - - } - name={fieldName} - tooltip={{ - title: currentParams[idx].description, - placement: "topLeft", - styles: { root: { maxWidth: 400 } }, - }} - > - {renderParamInput(currentParams[idx], idx)} - - ); - })} -
- - {/* Knowledge base selection */} -
- {(() => { - const idx = currentParams.findIndex((p) => p.name === "index_names"); - const fieldName = idx === -1 ? undefined : `param_${idx}`; - const selectedValue = - idx === -1 ? [] : currentParams[idx].value || []; - - return ( -
-
- - {idx !== -1 - ? currentParams[idx].name - : t("toolConfig.field.indexNames", "index_names")} - - } - name={fieldName} - tooltip={ - idx !== -1 - ? { - title: currentParams[idx].description, - placement: "topLeft", - styles: { root: { maxWidth: 400 } }, - } - : undefined - } - className="mb-0" - > - } @@ -577,13 +567,7 @@ export function RegisterModal() { className="mt-2" disabled={authServiceUnavailable} > - {isLoading - ? isAdminMode - ? t("auth.registeringAdmin") - : t("auth.registering") - : isAdminMode - ? t("auth.registerAdmin") - : t("auth.register")} + {isLoading? t("auth.registering"): t("auth.register")} From 007df6f7c5b8b64464d33eeb3db07ff34b36dd17 Mon Sep 17 00:00:00 2001 From: xuyaqist Date: Mon, 2 Feb 2026 10:14:54 +0800 Subject: [PATCH 031/167] resolve types incompatibility issues --- frontend/types/auth.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/frontend/types/auth.ts b/frontend/types/auth.ts index 36f5ca171..dacc567fc 100644 --- a/frontend/types/auth.ts +++ b/frontend/types/auth.ts @@ -52,8 +52,8 @@ export interface AuthContextType { register: ( email: string, password: string, - isAdmin?: boolean, - inviteCode?: string + inviteCode?: string, + withNewInvitation?: boolean ) => Promise; logout: (options?: { silent?: boolean }) => Promise; clearLocalSession: () => void; @@ -169,8 +169,8 @@ export interface AuthenticationStateReturn { register: ( email: string, password: string, - isAdmin?: boolean, - inviteCode?: string + inviteCode?: string, + withNewInvitation?: boolean ) => Promise; logout: (options?: { silent?: boolean }) => Promise; clearLocalSession: () => void; From 02e8cbf7069125d717bac7cc34cbbfe5254d5f79 Mon Sep 17 00:00:00 2001 From: wmc1112 <46217886+WMC001@users.noreply.github.com> Date: Mon, 2 Feb 2026 10:32:36 +0800 Subject: [PATCH 032/167] =?UTF-8?q?=E2=9C=A8=20User=20management:=20add=20?= =?UTF-8?q?resources=20page=20proxy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docker/init.sql | 2 ++ ...02_add_tenant_resources_route_permission.sql | 17 +++++++++++++++++ .../components/navigation/SideNavigation.tsx | 1 + 3 files changed, 20 insertions(+) create mode 100644 docker/sql/v1.8.0_0202_add_tenant_resources_route_permission.sql diff --git a/docker/init.sql b/docker/init.sql index c7e372ca5..4bfb31908 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -820,6 +820,7 @@ INSERT INTO nexent.role_permission_t (role_permission_id, user_role, permission_ (6, 'SU', 'VISIBILITY', 'LEFT_NAV_MENU', '/models'), (7, 'SU', 'VISIBILITY', 'LEFT_NAV_MENU', '/memory'), (8, 'SU', 'VISIBILITY', 'LEFT_NAV_MENU', '/users'), +(211, 'SU', 'VISIBILITY', 'LEFT_NAV_MENU', '/tenant-resources'), (9, 'SU', 'RESOURCE', 'AGENT', 'READ'), (10, 'SU', 'RESOURCE', 'AGENT', 'DELETE'), (11, 'SU', 'RESOURCE', 'KB', 'READ'), @@ -868,6 +869,7 @@ INSERT INTO nexent.role_permission_t (role_permission_id, user_role, permission_ (54, 'ADMIN', 'VISIBILITY', 'LEFT_NAV_MENU', '/models'), (55, 'ADMIN', 'VISIBILITY', 'LEFT_NAV_MENU', '/memory'), (56, 'ADMIN', 'VISIBILITY', 'LEFT_NAV_MENU', '/users'), +(212, 'ADMIN', 'VISIBILITY', 'LEFT_NAV_MENU', '/tenant-resources'), (57, 'ADMIN', 'RESOURCE', 'AGENT', 'CREATE'), (58, 'ADMIN', 'RESOURCE', 'AGENT', 'READ'), (59, 'ADMIN', 'RESOURCE', 'AGENT', 'UPDATE'), diff --git a/docker/sql/v1.8.0_0202_add_tenant_resources_route_permission.sql b/docker/sql/v1.8.0_0202_add_tenant_resources_route_permission.sql new file mode 100644 index 000000000..9eddef310 --- /dev/null +++ b/docker/sql/v1.8.0_0202_add_tenant_resources_route_permission.sql @@ -0,0 +1,17 @@ +-- ============================================================================= +-- File: v1.7.9.4_0202_add_tenant_resources_route_permission.sql +-- Description: Add /tenant-resources route permission for SU and ADMIN roles +-- Version: 1.7.9.4 +-- Date: 2026-02-02 +-- ============================================================================= + +-- Add /tenant-resources LEFT_NAV_MENU permission for SU (Super Admin) role +INSERT INTO nexent.role_permission_t (role_permission_id, user_role, permission_category, permission_type, permission_subtype) +VALUES (211, 'SU', 'VISIBILITY', 'LEFT_NAV_MENU', '/tenant-resources') +ON CONFLICT (role_permission_id) DO NOTHING; + +-- Add /tenant-resources LEFT_NAV_MENU permission for ADMIN (Tenant Admin) role +INSERT INTO nexent.role_permission_t (role_permission_id, user_role, permission_category, permission_type, permission_subtype) +VALUES (212, 'ADMIN', 'VISIBILITY', 'LEFT_NAV_MENU', '/tenant-resources') +ON CONFLICT (role_permission_id) DO NOTHING; + diff --git a/frontend/components/navigation/SideNavigation.tsx b/frontend/components/navigation/SideNavigation.tsx index 5838767ea..77671114b 100644 --- a/frontend/components/navigation/SideNavigation.tsx +++ b/frontend/components/navigation/SideNavigation.tsx @@ -59,6 +59,7 @@ const ROUTE_CONFIG: RouteConfig[] = [ { path: "/models", Icon: Settings, labelKey: "sidebar.modelManagement", order: 9 }, { path: "/memory", Icon: Database, labelKey: "sidebar.memoryManagement", order: 10 }, { path: "/users", Icon: User, labelKey: "sidebar.userManagement", order: 11 }, + { path: "/tenant-resources", Icon: Building2, labelKey: "sidebar.tenantResources", order: 12 }, ]; /** From 53f26ae35ea965339e1f6c420afc776e78456ecb Mon Sep 17 00:00:00 2001 From: wmc1112 <46217886+WMC001@users.noreply.github.com> Date: Mon, 2 Feb 2026 17:19:41 +0800 Subject: [PATCH 033/167] =?UTF-8?q?=E2=9C=A8=20User=20management:=20Model?= =?UTF-8?q?=20list=20needs=20tenant=20isolation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/model_managment_app.py | 38 +++++ backend/consts/model.py | 41 +++++ backend/services/model_management_service.py | 81 ++++++++- .../components/resources/ModelList.tsx | 45 ++++- frontend/hooks/model/useAdminTenantModels.ts | 64 +++++++ frontend/services/api.ts | 1 + frontend/services/modelService.ts | 75 +++++++++ test/backend/app/test_model_managment_app.py | 158 ++++++++++++++++++ .../services/test_model_management_service.py | 132 +++++++++++++++ 9 files changed, 630 insertions(+), 5 deletions(-) create mode 100644 frontend/hooks/model/useAdminTenantModels.ts diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py index 0c3c7a8cf..509c0e533 100644 --- a/backend/apps/model_managment_app.py +++ b/backend/apps/model_managment_app.py @@ -18,6 +18,8 @@ BatchCreateModelsRequest, ModelRequest, ProviderModelRequest, + AdminTenantModelRequest, + AdminTenantModelResponse, ) from fastapi import APIRouter, Header, Query, HTTPException @@ -39,6 +41,7 @@ delete_model_for_tenant, list_models_for_tenant, list_llm_models_for_tenant, + list_models_for_admin, ) from utils.auth_utils import get_current_user_id @@ -285,6 +288,41 @@ async def get_llm_model_list(authorization: Optional[str] = Header(None)): detail=str(e)) +@router.post("/admin/list", response_model=AdminTenantModelResponse) +async def get_model_list_for_admin(request: AdminTenantModelRequest, authorization: Optional[str] = Header(None)): + """Get model list for a specified tenant (admin only). + + This endpoint allows super admin to query models for any tenant. + + Args: + request: Contains target tenant_id, optional model_type filter, and pagination params + authorization: Bearer token for authentication + + Returns: + JSONResponse: Model list for the specified tenant with pagination info + """ + try: + user_id, _ = get_current_user_id(authorization) + logger.debug( + f"Start to list models for admin, user_id: {user_id}, target_tenant_id: {request.tenant_id}, " + f"page: {request.page}, page_size: {request.page_size}") + + result = await list_models_for_admin( + request.tenant_id, + request.model_type, + request.page, + request.page_size + ) + return JSONResponse(status_code=HTTPStatus.OK, content={ + "message": "Successfully retrieved model list", + "data": jsonable_encoder(result) + }) + except Exception as e: + logging.error(f"Failed to list models for admin: {str(e)}") + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=str(e)) + + @router.post("/healthcheck") async def check_model_health( display_name: str = Query(..., description="Display name to check"), diff --git a/backend/consts/model.py b/backend/consts/model.py index 089c09b27..7412c70b3 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -633,3 +633,44 @@ class InvitationUseResponse(BaseModel): code_type: str = Field(..., description="Code type") group_ids: Optional[List[int]] = Field( None, description="Associated group IDs") + + +# Admin Model Management Data Models +# --------------------------------------------------------------------------- +class AdminModelListRequest(BaseModel): + """Request model for admin to list models across tenants""" + tenant_ids: List[str] = Field( + ..., min_items=1, description="List of tenant IDs to query") + model_type: Optional[str] = Field( + None, description="Filter by model type (e.g., 'llm', 'embedding')") + page: int = Field(1, ge=1, description="Page number for pagination") + page_size: int = Field(20, ge=1, le=100, description="Items per page") + + +class TenantModelInfo(BaseModel): + """Model containing tenant info and their models""" + tenant_id: str = Field(..., description="Tenant identifier") + tenant_name: str = Field(..., description="Tenant display name") + models: List[Dict[str, Any]] = Field( + default_factory=list, description="List of models for this tenant") + + +class AdminTenantModelRequest(BaseModel): + """Request model for admin to query models for a specific tenant""" + tenant_id: str = Field(..., min_length=1, description="Target tenant ID to query models for") + model_type: Optional[str] = Field( + None, description="Filter by model type (e.g., 'llm', 'embedding')") + page: int = Field(1, ge=1, description="Page number for pagination") + page_size: int = Field(20, ge=1, le=100, description="Items per page") + + +class AdminTenantModelResponse(BaseModel): + """Response model for admin tenant model query""" + tenant_id: str = Field(..., description="Tenant identifier") + tenant_name: str = Field(..., description="Tenant display name") + models: List[Dict[str, Any]] = Field( + default_factory=list, description="List of models for this tenant") + total: int = Field(0, description="Total number of models") + page: int = Field(1, description="Current page number") + page_size: int = Field(20, description="Items per page") + total_pages: int = Field(0, description="Total number of pages") diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index dc38026d8..4b8265028 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -1,5 +1,5 @@ import logging -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST from consts.model import ModelConnectStatusEnum @@ -398,5 +398,84 @@ async def list_llm_models_for_tenant(tenant_id: str): raise Exception(f"Failed to retrieve model list: {str(e)}") +async def list_models_for_admin( + tenant_id: str, + model_type: Optional[str] = None, + page: int = 1, + page_size: int = 20 +) -> Dict[str, Any]: + """Get models for a specified tenant (admin operation) with pagination. + + Args: + tenant_id: Target tenant ID to query models for + model_type: Optional model type filter (e.g., 'llm', 'embedding') + page: Page number for pagination (1-indexed) + page_size: Number of items per page + + Returns: + Dict containing tenant_id, tenant_name, paginated models list, and pagination info + """ + try: + # Build filters + filters = None + if model_type: + filters = {"model_type": model_type} + + # Get model records for the specified tenant + records = get_model_records(filters, tenant_id) + + # Type mapping for backwards compatibility + type_map = { + "chat": "llm", + } + + # Normalize model records + normalized_models: List[Dict[str, Any]] = [] + for record in records: + record["model_name"] = add_repo_to_name( + model_repo=record["model_repo"], + model_name=record["model_name"], + ) + record["connect_status"] = ModelConnectStatusEnum.get_value( + record.get("connect_status")) + + # Map model_type if necessary + if record.get("model_type") in type_map: + record["model_type"] = type_map[record["model_type"]] + + normalized_models.append(record) + + # Calculate pagination + total = len(normalized_models) + total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0 + start_index = (page - 1) * page_size + end_index = start_index + page_size + paginated_models = normalized_models[start_index:end_index] + + # Get tenant name + from services.tenant_service import get_tenant_info + try: + tenant_info = get_tenant_info(tenant_id) + tenant_name = tenant_info.get("tenant_name", "") + except Exception: + tenant_name = "" + + result = { + "tenant_id": tenant_id, + "tenant_name": tenant_name, + "models": paginated_models, + "total": total, + "page": page, + "page_size": page_size, + "total_pages": total_pages + } + + logging.debug(f"Successfully retrieved admin model list for tenant: {tenant_id}, page: {page}, page_size: {page_size}") + return result + except Exception as e: + logging.error(f"Failed to retrieve admin model list: {str(e)}") + raise Exception(f"Failed to retrieve admin model list: {str(e)}") + + diff --git a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx index 67296b044..1c0a64d77 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx @@ -2,11 +2,11 @@ import React, { useState } from "react"; import { useTranslation } from "react-i18next"; -import { Table, Button, Popconfirm, message, Tag } from "antd"; +import { Table, Button, Popconfirm, message, Tag, Pagination } from "antd"; import { Edit, Trash2 } from "lucide-react"; import { Tooltip } from "@/components/ui/tooltip"; import { ColumnsType } from "antd/es/table"; -import { useModelList } from "@/hooks/model/useModelList"; +import { useAdminTenantModels } from "@/hooks/model/useAdminTenantModels"; import { modelService } from "@/services/modelService"; import { type ModelOption, type ModelType } from "@/types/modelConfig"; import { ModelAddDialog } from "../../../models/components/model/ModelAddDialog"; @@ -15,7 +15,23 @@ import { CheckCircle, CircleSlash, XCircle, CircleEllipsis, CircleHelp } from "l export default function ModelList({ tenantId }: { tenantId: string | null }) { const { t } = useTranslation("common"); - const { data: models = [], isLoading, refetch } = useModelList(); + + // Pagination state + const [page, setPage] = useState(1); + const [pageSize, setPageSize] = useState(10); + + // Use admin API to get models for the specified tenant + const { + models = [], + total = 0, + isLoading, + refetch, + } = useAdminTenantModels({ + tenantId: tenantId || "", + page, + pageSize, + }); + const [editingModel, setEditingModel] = useState(null); const [addDialogVisible, setAddDialogVisible] = useState(false); const [editDialogVisible, setEditDialogVisible] = useState(false); @@ -63,6 +79,15 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { } }; + // Handle pagination change + const handlePageChange = (newPage: number, newPageSize: number) => { + setPage(newPage); + if (newPageSize !== pageSize) { + setPageSize(newPageSize); + setPage(1); + } + }; + const columns: ColumnsType = [ { @@ -168,9 +193,21 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { dataSource={models} loading={isLoading} rowKey="id" - pagination={{ pageSize: 10 }} + pagination={false} /> +
+ `Total ${total} items`} + /> +
+ Promise; +} + +export function useAdminTenantModels(options: { + tenantId: string; + modelType?: string; + page?: number; + pageSize?: number; + enabled?: boolean; +}): AdminTenantModelResult { + const { tenantId, modelType, page = 1, pageSize = 20, enabled = true } = options; + + const query = useQuery({ + queryKey: ["admin-tenant-models", tenantId, modelType, page, pageSize], + queryFn: async (): Promise<{ + models: ModelOption[]; + total: number; + page: number; + pageSize: number; + totalPages: number; + tenantName: string; + }> => { + const result = await modelService.getAdminTenantModels({ + tenantId, + modelType, + page, + pageSize, + }); + return result; + }, + enabled: enabled && !!tenantId, + staleTime: 30_000, // 30 seconds default + }); + + return { + models: query.data?.models ?? [], + total: query.data?.total ?? 0, + page: query.data?.page ?? 1, + pageSize: query.data?.pageSize ?? 20, + totalPages: query.data?.totalPages ?? 0, + tenantName: query.data?.tenantName ?? "", + isLoading: query.isLoading, + isError: query.isError, + error: query.error as Error | null, + refetch: async () => { + await query.refetch(); + }, + }; +} + diff --git a/frontend/services/api.ts b/frontend/services/api.ts index ceec352b6..f7f76afb4 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -124,6 +124,7 @@ export const API_ENDPOINTS = { updateBatchModel: `${API_BASE_URL}/model/batch_update`, // LLM model list for generation llmModelList: `${API_BASE_URL}/model/llm_list`, + adminModelList: `${API_BASE_URL}/model/admin/list`, }, knowledgeBase: { // Elasticsearch service diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts index e329db237..76db80df9 100644 --- a/frontend/services/modelService.ts +++ b/frontend/services/modelService.ts @@ -466,6 +466,81 @@ export const modelService = { return []; } }, + + // Get model list for a specific tenant (admin operation) + getAdminTenantModels: async (params: { + tenantId: string; + modelType?: string; + page?: number; + pageSize?: number; + }): Promise<{ + models: ModelOption[]; + total: number; + page: number; + pageSize: number; + totalPages: number; + tenantName: string; + }> => { + try { + const response = await fetch(API_ENDPOINTS.model.adminModelList, { + method: "POST", + headers: { + ...getAuthHeaders(), + "Content-Type": "application/json", + }, + body: JSON.stringify({ + tenant_id: params.tenantId, + model_type: params.modelType, + page: params.page || 1, + page_size: params.pageSize || 20, + }), + }); + const result = await response.json(); + + if (response.status === STATUS_CODES.SUCCESS && result.data) { + return { + models: result.data.models.map((model: any) => ({ + id: model.model_id, + name: model.model_name, + type: model.model_type as ModelType, + maxTokens: model.max_tokens || 0, + source: model.model_factory as ModelSource, + apiKey: model.api_key || "", + apiUrl: model.base_url || "", + displayName: model.display_name || model.model_name, + connect_status: model.connect_status as ModelConnectStatus, + expectedChunkSize: model.expected_chunk_size, + maximumChunkSize: model.maximum_chunk_size, + chunkingBatchSize: model.chunk_batch, + })), + total: result.data.total || 0, + page: result.data.page || 1, + pageSize: result.data.page_size || 20, + totalPages: result.data.total_pages || 0, + tenantName: result.data.tenant_name || "", + }; + } + + return { + models: [], + total: 0, + page: 1, + pageSize: 20, + totalPages: 0, + tenantName: "", + }; + } catch (error) { + log.warn("Failed to load admin tenant models:", error); + return { + models: [], + total: 0, + page: 1, + pageSize: 20, + totalPages: 0, + tenantName: "", + }; + } + }, }; // -------- Provider detection helpers (for UI rendering) -------- diff --git a/test/backend/app/test_model_managment_app.py b/test/backend/app/test_model_managment_app.py index 3994cd86a..fbbafc390 100644 --- a/test/backend/app/test_model_managment_app.py +++ b/test/backend/app/test_model_managment_app.py @@ -616,5 +616,163 @@ async def mock_batch_update(*args, **kwargs): mock_batch_update.assert_called_once_with(user_credentials[0], user_credentials[1], models) +# Tests for /model/admin/list endpoint +@pytest.mark.asyncio +async def test_get_admin_model_list_success(client, auth_header, user_credentials, mocker): + """Test successful admin model list retrieval for a specified tenant.""" + mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials) + + async def mock_list_models_for_admin(*args, **kwargs): + return { + "tenant_id": "target_tenant", + "tenant_name": "Target Tenant", + "models": [ + { + "model_id": "model1", + "model_name": "huggingface/llama", + "display_name": "LLaMA Model", + "model_type": "llm", + "connect_status": "operational" + }, + { + "model_id": "model2", + "model_name": "openai/clip", + "display_name": "CLIP Model", + "model_type": "embedding", + "connect_status": "not_detected" + } + ], + "total": 2, + "page": 1, + "page_size": 20, + "total_pages": 1 + } + + mock_list = mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin) + + request_data = { + "tenant_id": "target_tenant", + "model_type": None, + "page": 1, + "page_size": 20 + } + response = client.post("/model/admin/list", json=request_data, headers=auth_header) + + assert response.status_code == HTTPStatus.OK + data = response.json() + assert "Successfully retrieved model list" in data["message"] + assert data["data"]["tenant_id"] == "target_tenant" + assert data["data"]["tenant_name"] == "Target Tenant" + assert data["data"]["total"] == 2 + assert data["data"]["page"] == 1 + assert data["data"]["page_size"] == 20 + assert data["data"]["total_pages"] == 1 + assert len(data["data"]["models"]) == 2 + assert data["data"]["models"][0]["model_name"] == "huggingface/llama" + assert data["data"]["models"][1]["model_name"] == "openai/clip" + mock_list.assert_called_once_with("target_tenant", None, 1, 20) + + +@pytest.mark.asyncio +async def test_get_admin_model_list_with_pagination(client, auth_header, user_credentials, mocker): + """Test admin model list retrieval with pagination parameters.""" + mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials) + + async def mock_list_models_for_admin(*args, **kwargs): + return { + "tenant_id": "target_tenant", + "tenant_name": "Target Tenant", + "models": [ + { + "model_id": "model3", + "model_name": "openai/gpt-3", + "display_name": "GPT-3", + "model_type": "llm", + "connect_status": "operational" + } + ], + "total": 25, + "page": 2, + "page_size": 10, + "total_pages": 3 + } + + mock_list = mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin) + + request_data = { + "tenant_id": "target_tenant", + "model_type": "llm", + "page": 2, + "page_size": 10 + } + response = client.post("/model/admin/list", json=request_data, headers=auth_header) + + assert response.status_code == HTTPStatus.OK + data = response.json() + assert data["data"]["page"] == 2 + assert data["data"]["page_size"] == 10 + assert data["data"]["total"] == 25 + assert data["data"]["total_pages"] == 3 + assert len(data["data"]["models"]) == 1 + mock_list.assert_called_once_with("target_tenant", "llm", 2, 10) + + +@pytest.mark.asyncio +async def test_get_admin_model_list_exception(client, auth_header, user_credentials, mocker): + """Test admin model list retrieval with exception.""" + mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials) + + async def mock_list_models_for_admin(*args, **kwargs): + raise Exception("Database connection error") + + mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin) + + request_data = { + "tenant_id": "target_tenant", + "model_type": None, + "page": 1, + "page_size": 20 + } + response = client.post("/model/admin/list", json=request_data, headers=auth_header) + + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + data = response.json() + assert "Database connection error" in data.get("detail", "") + + +@pytest.mark.asyncio +async def test_get_admin_model_list_empty(client, auth_header, user_credentials, mocker): + """Test admin model list retrieval with empty result.""" + mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials) + + async def mock_list_models_for_admin(*args, **kwargs): + return { + "tenant_id": "empty_tenant", + "tenant_name": "Empty Tenant", + "models": [], + "total": 0, + "page": 1, + "page_size": 20, + "total_pages": 0 + } + + mock_list = mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin) + + request_data = { + "tenant_id": "empty_tenant", + "model_type": None, + "page": 1, + "page_size": 20 + } + response = client.post("/model/admin/list", json=request_data, headers=auth_header) + + assert response.status_code == HTTPStatus.OK + data = response.json() + assert "Successfully retrieved model list" in data["message"] + assert data["data"]["total"] == 0 + assert len(data["data"]["models"]) == 0 + mock_list.assert_called_once_with("empty_tenant", None, 1, 20) + + if __name__ == "__main__": pytest.main([__file__]) \ No newline at end of file diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index e38e24976..562f79fd5 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -1089,3 +1089,135 @@ async def test_list_llm_models_for_tenant_handles_missing_repo(): assert len(result) == 2 assert result[0]["model_name"] == "local-model" # No repo prefix assert result[1]["model_name"] == "another-model" # No repo prefix + + +# Tests for list_models_for_admin +async def test_list_models_for_admin_success(): + """Test list_models_for_tenant returns models for a specified tenant.""" + svc = import_svc() + + records = [ + {"model_repo": "huggingface", "model_name": "llama", + "connect_status": "operational", "model_type": "llm"}, + {"model_repo": "openai", "model_name": "clip", "connect_status": None, "model_type": "embedding"}, + ] + + with mock.patch.object(svc, "get_model_records", return_value=records), \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + out = await svc.list_models_for_admin("t1") + assert out["tenant_id"] == "t1" + assert out["tenant_name"] == "Test Tenant" + assert out["total"] == 2 + assert out["page"] == 1 + assert out["page_size"] == 20 + assert out["total_pages"] == 1 + assert len(out["models"]) == 2 + assert out["models"][0]["model_name"] == "huggingface/llama" + + +async def test_list_models_for_admin_with_pagination(): + """Test list_models_for_tenant handles pagination correctly.""" + svc = import_svc() + + # Create 25 records to test pagination + records = [ + {"model_repo": "openai", "model_name": f"gpt-{i}", "connect_status": "operational", "model_type": "llm"} + for i in range(25) + ] + + with mock.patch.object(svc, "get_model_records", return_value=records), \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + # Page 1, page_size 10 + out = await svc.list_models_for_admin("t1", page=1, page_size=10) + assert out["page"] == 1 + assert out["page_size"] == 10 + assert out["total"] == 25 + assert out["total_pages"] == 3 + assert len(out["models"]) == 10 + assert out["models"][0]["model_name"] == "openai/gpt-0" + + # Page 2 + out = await svc.list_models_for_admin("t1", page=2, page_size=10) + assert out["page"] == 2 + assert len(out["models"]) == 10 + assert out["models"][0]["model_name"] == "openai/gpt-10" + + # Page 3 (last page) + out = await svc.list_models_for_admin("t1", page=3, page_size=10) + assert out["page"] == 3 + assert out["total_pages"] == 3 + assert len(out["models"]) == 5 + assert out["models"][0]["model_name"] == "openai/gpt-20" + + +async def test_list_models_for_admin_with_model_type_filter(): + """Test list_models_for_tenant filters by model_type.""" + svc = import_svc() + + records = [ + {"model_repo": "openai", "model_name": "gpt-4", "connect_status": "operational", "model_type": "llm"}, + {"model_repo": "openai", "model_name": "text-embedding", "connect_status": "operational", "model_type": "embedding"}, + ] + + with mock.patch.object(svc, "get_model_records", return_value=records) as mock_get_records, \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + # Filter by llm + out = await svc.list_models_for_admin("t1", model_type="llm") + mock_get_records.assert_called_once_with({"model_type": "llm"}, "t1") + assert out["total"] == 2 + assert out["models"][0]["model_type"] == "llm" + + +async def test_list_models_for_admin_empty_tenant(): + """Test list_models_for_tenant handles empty tenant gracefully.""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_records", return_value=[]), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": ""}): + out = await svc.list_models_for_admin("t1") + assert out["tenant_id"] == "t1" + assert out["tenant_name"] == "" + assert out["total"] == 0 + assert out["total_pages"] == 0 + assert len(out["models"]) == 0 + + +async def test_list_models_for_admin_exception(): + """Test list_models_for_tenant handles exceptions.""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_records", side_effect=Exception("db error")): + with pytest.raises(Exception) as exc: + await svc.list_models_for_admin("t1") + assert "Failed to retrieve admin model list" in str(exc.value) + + +async def test_list_models_for_admin_type_mapping(): + """Test list_models_for_tenant maps model_type from 'chat' to 'llm'.""" + svc = import_svc() + + records = [ + { + "model_id": "llm1", + "model_repo": "openai", + "model_name": "gpt-4", + "display_name": "GPT-4", + "model_type": "chat", # Should be mapped to "llm" + "connect_status": "operational" + }, + ] + + with mock.patch.object(svc, "get_model_records", return_value=records), \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + out = await svc.list_models_for_admin("t1") + + assert len(out["models"]) == 1 + assert out["models"][0]["model_type"] == "llm" # Should be mapped from "chat" \ No newline at end of file From 96516d414c7f31dee12199d4f6162ee4a6dfdc5f Mon Sep 17 00:00:00 2001 From: wmc1112 <46217886+wmc001@users.noreply.github.com> Date: Mon, 2 Feb 2026 17:19:41 +0800 Subject: [PATCH 034/167] =?UTF-8?q?=E2=9C=A8=20User=20management:=20Model?= =?UTF-8?q?=20list=20needs=20tenant=20isolation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/model_managment_app.py | 38 +++++ backend/consts/model.py | 41 +++++ backend/services/model_management_service.py | 81 ++++++++- .../components/resources/ModelList.tsx | 45 ++++- frontend/hooks/model/useAdminTenantModels.ts | 64 +++++++ frontend/services/api.ts | 1 + frontend/services/modelService.ts | 75 +++++++++ test/backend/app/test_model_managment_app.py | 158 ++++++++++++++++++ .../services/test_model_management_service.py | 148 ++++++++++++++++ 9 files changed, 646 insertions(+), 5 deletions(-) create mode 100644 frontend/hooks/model/useAdminTenantModels.ts diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py index 0c3c7a8cf..509c0e533 100644 --- a/backend/apps/model_managment_app.py +++ b/backend/apps/model_managment_app.py @@ -18,6 +18,8 @@ BatchCreateModelsRequest, ModelRequest, ProviderModelRequest, + AdminTenantModelRequest, + AdminTenantModelResponse, ) from fastapi import APIRouter, Header, Query, HTTPException @@ -39,6 +41,7 @@ delete_model_for_tenant, list_models_for_tenant, list_llm_models_for_tenant, + list_models_for_admin, ) from utils.auth_utils import get_current_user_id @@ -285,6 +288,41 @@ async def get_llm_model_list(authorization: Optional[str] = Header(None)): detail=str(e)) +@router.post("/admin/list", response_model=AdminTenantModelResponse) +async def get_model_list_for_admin(request: AdminTenantModelRequest, authorization: Optional[str] = Header(None)): + """Get model list for a specified tenant (admin only). + + This endpoint allows super admin to query models for any tenant. + + Args: + request: Contains target tenant_id, optional model_type filter, and pagination params + authorization: Bearer token for authentication + + Returns: + JSONResponse: Model list for the specified tenant with pagination info + """ + try: + user_id, _ = get_current_user_id(authorization) + logger.debug( + f"Start to list models for admin, user_id: {user_id}, target_tenant_id: {request.tenant_id}, " + f"page: {request.page}, page_size: {request.page_size}") + + result = await list_models_for_admin( + request.tenant_id, + request.model_type, + request.page, + request.page_size + ) + return JSONResponse(status_code=HTTPStatus.OK, content={ + "message": "Successfully retrieved model list", + "data": jsonable_encoder(result) + }) + except Exception as e: + logging.error(f"Failed to list models for admin: {str(e)}") + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=str(e)) + + @router.post("/healthcheck") async def check_model_health( display_name: str = Query(..., description="Display name to check"), diff --git a/backend/consts/model.py b/backend/consts/model.py index 089c09b27..7412c70b3 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -633,3 +633,44 @@ class InvitationUseResponse(BaseModel): code_type: str = Field(..., description="Code type") group_ids: Optional[List[int]] = Field( None, description="Associated group IDs") + + +# Admin Model Management Data Models +# --------------------------------------------------------------------------- +class AdminModelListRequest(BaseModel): + """Request model for admin to list models across tenants""" + tenant_ids: List[str] = Field( + ..., min_items=1, description="List of tenant IDs to query") + model_type: Optional[str] = Field( + None, description="Filter by model type (e.g., 'llm', 'embedding')") + page: int = Field(1, ge=1, description="Page number for pagination") + page_size: int = Field(20, ge=1, le=100, description="Items per page") + + +class TenantModelInfo(BaseModel): + """Model containing tenant info and their models""" + tenant_id: str = Field(..., description="Tenant identifier") + tenant_name: str = Field(..., description="Tenant display name") + models: List[Dict[str, Any]] = Field( + default_factory=list, description="List of models for this tenant") + + +class AdminTenantModelRequest(BaseModel): + """Request model for admin to query models for a specific tenant""" + tenant_id: str = Field(..., min_length=1, description="Target tenant ID to query models for") + model_type: Optional[str] = Field( + None, description="Filter by model type (e.g., 'llm', 'embedding')") + page: int = Field(1, ge=1, description="Page number for pagination") + page_size: int = Field(20, ge=1, le=100, description="Items per page") + + +class AdminTenantModelResponse(BaseModel): + """Response model for admin tenant model query""" + tenant_id: str = Field(..., description="Tenant identifier") + tenant_name: str = Field(..., description="Tenant display name") + models: List[Dict[str, Any]] = Field( + default_factory=list, description="List of models for this tenant") + total: int = Field(0, description="Total number of models") + page: int = Field(1, description="Current page number") + page_size: int = Field(20, description="Items per page") + total_pages: int = Field(0, description="Total number of pages") diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index dc38026d8..4b8265028 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -1,5 +1,5 @@ import logging -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST from consts.model import ModelConnectStatusEnum @@ -398,5 +398,84 @@ async def list_llm_models_for_tenant(tenant_id: str): raise Exception(f"Failed to retrieve model list: {str(e)}") +async def list_models_for_admin( + tenant_id: str, + model_type: Optional[str] = None, + page: int = 1, + page_size: int = 20 +) -> Dict[str, Any]: + """Get models for a specified tenant (admin operation) with pagination. + + Args: + tenant_id: Target tenant ID to query models for + model_type: Optional model type filter (e.g., 'llm', 'embedding') + page: Page number for pagination (1-indexed) + page_size: Number of items per page + + Returns: + Dict containing tenant_id, tenant_name, paginated models list, and pagination info + """ + try: + # Build filters + filters = None + if model_type: + filters = {"model_type": model_type} + + # Get model records for the specified tenant + records = get_model_records(filters, tenant_id) + + # Type mapping for backwards compatibility + type_map = { + "chat": "llm", + } + + # Normalize model records + normalized_models: List[Dict[str, Any]] = [] + for record in records: + record["model_name"] = add_repo_to_name( + model_repo=record["model_repo"], + model_name=record["model_name"], + ) + record["connect_status"] = ModelConnectStatusEnum.get_value( + record.get("connect_status")) + + # Map model_type if necessary + if record.get("model_type") in type_map: + record["model_type"] = type_map[record["model_type"]] + + normalized_models.append(record) + + # Calculate pagination + total = len(normalized_models) + total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0 + start_index = (page - 1) * page_size + end_index = start_index + page_size + paginated_models = normalized_models[start_index:end_index] + + # Get tenant name + from services.tenant_service import get_tenant_info + try: + tenant_info = get_tenant_info(tenant_id) + tenant_name = tenant_info.get("tenant_name", "") + except Exception: + tenant_name = "" + + result = { + "tenant_id": tenant_id, + "tenant_name": tenant_name, + "models": paginated_models, + "total": total, + "page": page, + "page_size": page_size, + "total_pages": total_pages + } + + logging.debug(f"Successfully retrieved admin model list for tenant: {tenant_id}, page: {page}, page_size: {page_size}") + return result + except Exception as e: + logging.error(f"Failed to retrieve admin model list: {str(e)}") + raise Exception(f"Failed to retrieve admin model list: {str(e)}") + + diff --git a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx index 67296b044..1c0a64d77 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx @@ -2,11 +2,11 @@ import React, { useState } from "react"; import { useTranslation } from "react-i18next"; -import { Table, Button, Popconfirm, message, Tag } from "antd"; +import { Table, Button, Popconfirm, message, Tag, Pagination } from "antd"; import { Edit, Trash2 } from "lucide-react"; import { Tooltip } from "@/components/ui/tooltip"; import { ColumnsType } from "antd/es/table"; -import { useModelList } from "@/hooks/model/useModelList"; +import { useAdminTenantModels } from "@/hooks/model/useAdminTenantModels"; import { modelService } from "@/services/modelService"; import { type ModelOption, type ModelType } from "@/types/modelConfig"; import { ModelAddDialog } from "../../../models/components/model/ModelAddDialog"; @@ -15,7 +15,23 @@ import { CheckCircle, CircleSlash, XCircle, CircleEllipsis, CircleHelp } from "l export default function ModelList({ tenantId }: { tenantId: string | null }) { const { t } = useTranslation("common"); - const { data: models = [], isLoading, refetch } = useModelList(); + + // Pagination state + const [page, setPage] = useState(1); + const [pageSize, setPageSize] = useState(10); + + // Use admin API to get models for the specified tenant + const { + models = [], + total = 0, + isLoading, + refetch, + } = useAdminTenantModels({ + tenantId: tenantId || "", + page, + pageSize, + }); + const [editingModel, setEditingModel] = useState(null); const [addDialogVisible, setAddDialogVisible] = useState(false); const [editDialogVisible, setEditDialogVisible] = useState(false); @@ -63,6 +79,15 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { } }; + // Handle pagination change + const handlePageChange = (newPage: number, newPageSize: number) => { + setPage(newPage); + if (newPageSize !== pageSize) { + setPageSize(newPageSize); + setPage(1); + } + }; + const columns: ColumnsType = [ { @@ -168,9 +193,21 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { dataSource={models} loading={isLoading} rowKey="id" - pagination={{ pageSize: 10 }} + pagination={false} /> +
+ `Total ${total} items`} + /> +
+ Promise; +} + +export function useAdminTenantModels(options: { + tenantId: string; + modelType?: string; + page?: number; + pageSize?: number; + enabled?: boolean; +}): AdminTenantModelResult { + const { tenantId, modelType, page = 1, pageSize = 20, enabled = true } = options; + + const query = useQuery({ + queryKey: ["admin-tenant-models", tenantId, modelType, page, pageSize], + queryFn: async (): Promise<{ + models: ModelOption[]; + total: number; + page: number; + pageSize: number; + totalPages: number; + tenantName: string; + }> => { + const result = await modelService.getAdminTenantModels({ + tenantId, + modelType, + page, + pageSize, + }); + return result; + }, + enabled: enabled && !!tenantId, + staleTime: 30_000, // 30 seconds default + }); + + return { + models: query.data?.models ?? [], + total: query.data?.total ?? 0, + page: query.data?.page ?? 1, + pageSize: query.data?.pageSize ?? 20, + totalPages: query.data?.totalPages ?? 0, + tenantName: query.data?.tenantName ?? "", + isLoading: query.isLoading, + isError: query.isError, + error: query.error as Error | null, + refetch: async () => { + await query.refetch(); + }, + }; +} + diff --git a/frontend/services/api.ts b/frontend/services/api.ts index ceec352b6..f7f76afb4 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -124,6 +124,7 @@ export const API_ENDPOINTS = { updateBatchModel: `${API_BASE_URL}/model/batch_update`, // LLM model list for generation llmModelList: `${API_BASE_URL}/model/llm_list`, + adminModelList: `${API_BASE_URL}/model/admin/list`, }, knowledgeBase: { // Elasticsearch service diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts index e329db237..76db80df9 100644 --- a/frontend/services/modelService.ts +++ b/frontend/services/modelService.ts @@ -466,6 +466,81 @@ export const modelService = { return []; } }, + + // Get model list for a specific tenant (admin operation) + getAdminTenantModels: async (params: { + tenantId: string; + modelType?: string; + page?: number; + pageSize?: number; + }): Promise<{ + models: ModelOption[]; + total: number; + page: number; + pageSize: number; + totalPages: number; + tenantName: string; + }> => { + try { + const response = await fetch(API_ENDPOINTS.model.adminModelList, { + method: "POST", + headers: { + ...getAuthHeaders(), + "Content-Type": "application/json", + }, + body: JSON.stringify({ + tenant_id: params.tenantId, + model_type: params.modelType, + page: params.page || 1, + page_size: params.pageSize || 20, + }), + }); + const result = await response.json(); + + if (response.status === STATUS_CODES.SUCCESS && result.data) { + return { + models: result.data.models.map((model: any) => ({ + id: model.model_id, + name: model.model_name, + type: model.model_type as ModelType, + maxTokens: model.max_tokens || 0, + source: model.model_factory as ModelSource, + apiKey: model.api_key || "", + apiUrl: model.base_url || "", + displayName: model.display_name || model.model_name, + connect_status: model.connect_status as ModelConnectStatus, + expectedChunkSize: model.expected_chunk_size, + maximumChunkSize: model.maximum_chunk_size, + chunkingBatchSize: model.chunk_batch, + })), + total: result.data.total || 0, + page: result.data.page || 1, + pageSize: result.data.page_size || 20, + totalPages: result.data.total_pages || 0, + tenantName: result.data.tenant_name || "", + }; + } + + return { + models: [], + total: 0, + page: 1, + pageSize: 20, + totalPages: 0, + tenantName: "", + }; + } catch (error) { + log.warn("Failed to load admin tenant models:", error); + return { + models: [], + total: 0, + page: 1, + pageSize: 20, + totalPages: 0, + tenantName: "", + }; + } + }, }; // -------- Provider detection helpers (for UI rendering) -------- diff --git a/test/backend/app/test_model_managment_app.py b/test/backend/app/test_model_managment_app.py index 3994cd86a..fbbafc390 100644 --- a/test/backend/app/test_model_managment_app.py +++ b/test/backend/app/test_model_managment_app.py @@ -616,5 +616,163 @@ async def mock_batch_update(*args, **kwargs): mock_batch_update.assert_called_once_with(user_credentials[0], user_credentials[1], models) +# Tests for /model/admin/list endpoint +@pytest.mark.asyncio +async def test_get_admin_model_list_success(client, auth_header, user_credentials, mocker): + """Test successful admin model list retrieval for a specified tenant.""" + mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials) + + async def mock_list_models_for_admin(*args, **kwargs): + return { + "tenant_id": "target_tenant", + "tenant_name": "Target Tenant", + "models": [ + { + "model_id": "model1", + "model_name": "huggingface/llama", + "display_name": "LLaMA Model", + "model_type": "llm", + "connect_status": "operational" + }, + { + "model_id": "model2", + "model_name": "openai/clip", + "display_name": "CLIP Model", + "model_type": "embedding", + "connect_status": "not_detected" + } + ], + "total": 2, + "page": 1, + "page_size": 20, + "total_pages": 1 + } + + mock_list = mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin) + + request_data = { + "tenant_id": "target_tenant", + "model_type": None, + "page": 1, + "page_size": 20 + } + response = client.post("/model/admin/list", json=request_data, headers=auth_header) + + assert response.status_code == HTTPStatus.OK + data = response.json() + assert "Successfully retrieved model list" in data["message"] + assert data["data"]["tenant_id"] == "target_tenant" + assert data["data"]["tenant_name"] == "Target Tenant" + assert data["data"]["total"] == 2 + assert data["data"]["page"] == 1 + assert data["data"]["page_size"] == 20 + assert data["data"]["total_pages"] == 1 + assert len(data["data"]["models"]) == 2 + assert data["data"]["models"][0]["model_name"] == "huggingface/llama" + assert data["data"]["models"][1]["model_name"] == "openai/clip" + mock_list.assert_called_once_with("target_tenant", None, 1, 20) + + +@pytest.mark.asyncio +async def test_get_admin_model_list_with_pagination(client, auth_header, user_credentials, mocker): + """Test admin model list retrieval with pagination parameters.""" + mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials) + + async def mock_list_models_for_admin(*args, **kwargs): + return { + "tenant_id": "target_tenant", + "tenant_name": "Target Tenant", + "models": [ + { + "model_id": "model3", + "model_name": "openai/gpt-3", + "display_name": "GPT-3", + "model_type": "llm", + "connect_status": "operational" + } + ], + "total": 25, + "page": 2, + "page_size": 10, + "total_pages": 3 + } + + mock_list = mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin) + + request_data = { + "tenant_id": "target_tenant", + "model_type": "llm", + "page": 2, + "page_size": 10 + } + response = client.post("/model/admin/list", json=request_data, headers=auth_header) + + assert response.status_code == HTTPStatus.OK + data = response.json() + assert data["data"]["page"] == 2 + assert data["data"]["page_size"] == 10 + assert data["data"]["total"] == 25 + assert data["data"]["total_pages"] == 3 + assert len(data["data"]["models"]) == 1 + mock_list.assert_called_once_with("target_tenant", "llm", 2, 10) + + +@pytest.mark.asyncio +async def test_get_admin_model_list_exception(client, auth_header, user_credentials, mocker): + """Test admin model list retrieval with exception.""" + mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials) + + async def mock_list_models_for_admin(*args, **kwargs): + raise Exception("Database connection error") + + mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin) + + request_data = { + "tenant_id": "target_tenant", + "model_type": None, + "page": 1, + "page_size": 20 + } + response = client.post("/model/admin/list", json=request_data, headers=auth_header) + + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + data = response.json() + assert "Database connection error" in data.get("detail", "") + + +@pytest.mark.asyncio +async def test_get_admin_model_list_empty(client, auth_header, user_credentials, mocker): + """Test admin model list retrieval with empty result.""" + mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials) + + async def mock_list_models_for_admin(*args, **kwargs): + return { + "tenant_id": "empty_tenant", + "tenant_name": "Empty Tenant", + "models": [], + "total": 0, + "page": 1, + "page_size": 20, + "total_pages": 0 + } + + mock_list = mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin) + + request_data = { + "tenant_id": "empty_tenant", + "model_type": None, + "page": 1, + "page_size": 20 + } + response = client.post("/model/admin/list", json=request_data, headers=auth_header) + + assert response.status_code == HTTPStatus.OK + data = response.json() + assert "Successfully retrieved model list" in data["message"] + assert data["data"]["total"] == 0 + assert len(data["data"]["models"]) == 0 + mock_list.assert_called_once_with("empty_tenant", None, 1, 20) + + if __name__ == "__main__": pytest.main([__file__]) \ No newline at end of file diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index e38e24976..03d2996a3 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -255,10 +255,15 @@ def _update_config_by_tenant_config_id_and_data(*args, **kwargs): return None +def _update_config_by_tenant_config_id(*args, **kwargs): + return None + + db_tenant_cfg_mod.delete_config_by_tenant_config_id = _delete_config_by_tenant_config_id db_tenant_cfg_mod.get_all_configs_by_tenant_id = _get_all_configs_by_tenant_id db_tenant_cfg_mod.get_single_config_info = _get_single_config_info db_tenant_cfg_mod.insert_config = _insert_config +db_tenant_cfg_mod.update_config_by_tenant_config_id = _update_config_by_tenant_config_id db_tenant_cfg_mod.update_config_by_tenant_config_id_and_data = _update_config_by_tenant_config_id_and_data sys.modules["database.tenant_config_db"] = db_tenant_cfg_mod @@ -282,6 +287,17 @@ async def _clear_model_memories(**kwargs): nexent_memory_mod.clear_model_memories = _clear_model_memories sys.modules["nexent.memory.memory_service"] = nexent_memory_mod +# Stub services.tenant_service required by list_models_for_admin +services_tenant_mod = types.ModuleType("services.tenant_service") + + +def _get_tenant_info(tenant_id): + return {"tenant_name": "Test Tenant"} + + +services_tenant_mod.get_tenant_info = _get_tenant_info +sys.modules["services.tenant_service"] = services_tenant_mod + def import_svc(): """Import service under MinioClient patch to avoid real initialization.""" @@ -1089,3 +1105,135 @@ async def test_list_llm_models_for_tenant_handles_missing_repo(): assert len(result) == 2 assert result[0]["model_name"] == "local-model" # No repo prefix assert result[1]["model_name"] == "another-model" # No repo prefix + + +# Tests for list_models_for_admin +async def test_list_models_for_admin_success(): + """Test list_models_for_tenant returns models for a specified tenant.""" + svc = import_svc() + + records = [ + {"model_repo": "huggingface", "model_name": "llama", + "connect_status": "operational", "model_type": "llm"}, + {"model_repo": "openai", "model_name": "clip", "connect_status": None, "model_type": "embedding"}, + ] + + with mock.patch.object(svc, "get_model_records", return_value=records), \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + out = await svc.list_models_for_admin("t1") + assert out["tenant_id"] == "t1" + assert out["tenant_name"] == "Test Tenant" + assert out["total"] == 2 + assert out["page"] == 1 + assert out["page_size"] == 20 + assert out["total_pages"] == 1 + assert len(out["models"]) == 2 + assert out["models"][0]["model_name"] == "huggingface/llama" + + +async def test_list_models_for_admin_with_pagination(): + """Test list_models_for_tenant handles pagination correctly.""" + svc = import_svc() + + # Create 25 records to test pagination + records = [ + {"model_repo": "openai", "model_name": f"gpt-{i}", "connect_status": "operational", "model_type": "llm"} + for i in range(25) + ] + + with mock.patch.object(svc, "get_model_records", return_value=records), \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + # Page 1, page_size 10 + out = await svc.list_models_for_admin("t1", page=1, page_size=10) + assert out["page"] == 1 + assert out["page_size"] == 10 + assert out["total"] == 25 + assert out["total_pages"] == 3 + assert len(out["models"]) == 10 + assert out["models"][0]["model_name"] == "openai/gpt-0" + + # Page 2 + out = await svc.list_models_for_admin("t1", page=2, page_size=10) + assert out["page"] == 2 + assert len(out["models"]) == 10 + assert out["models"][0]["model_name"] == "openai/gpt-10" + + # Page 3 (last page) + out = await svc.list_models_for_admin("t1", page=3, page_size=10) + assert out["page"] == 3 + assert out["total_pages"] == 3 + assert len(out["models"]) == 5 + assert out["models"][0]["model_name"] == "openai/gpt-20" + + +async def test_list_models_for_admin_with_model_type_filter(): + """Test list_models_for_tenant filters by model_type.""" + svc = import_svc() + + records = [ + {"model_repo": "openai", "model_name": "gpt-4", "connect_status": "operational", "model_type": "llm"}, + {"model_repo": "openai", "model_name": "text-embedding", "connect_status": "operational", "model_type": "embedding"}, + ] + + with mock.patch.object(svc, "get_model_records", return_value=records) as mock_get_records, \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + # Filter by llm + out = await svc.list_models_for_admin("t1", model_type="llm") + mock_get_records.assert_called_once_with({"model_type": "llm"}, "t1") + assert out["total"] == 2 + assert out["models"][0]["model_type"] == "llm" + + +async def test_list_models_for_admin_empty_tenant(): + """Test list_models_for_tenant handles empty tenant gracefully.""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_records", return_value=[]), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": ""}): + out = await svc.list_models_for_admin("t1") + assert out["tenant_id"] == "t1" + assert out["tenant_name"] == "" + assert out["total"] == 0 + assert out["total_pages"] == 0 + assert len(out["models"]) == 0 + + +async def test_list_models_for_admin_exception(): + """Test list_models_for_tenant handles exceptions.""" + svc = import_svc() + + with mock.patch.object(svc, "get_model_records", side_effect=Exception("db error")): + with pytest.raises(Exception) as exc: + await svc.list_models_for_admin("t1") + assert "Failed to retrieve admin model list" in str(exc.value) + + +async def test_list_models_for_admin_type_mapping(): + """Test list_models_for_tenant maps model_type from 'chat' to 'llm'.""" + svc = import_svc() + + records = [ + { + "model_id": "llm1", + "model_repo": "openai", + "model_name": "gpt-4", + "display_name": "GPT-4", + "model_type": "chat", # Should be mapped to "llm" + "connect_status": "operational" + }, + ] + + with mock.patch.object(svc, "get_model_records", return_value=records), \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ + mock.patch("services.tenant_service.get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + out = await svc.list_models_for_admin("t1") + + assert len(out["models"]) == 1 + assert out["models"][0]["model_type"] == "llm" # Should be mapped from "chat" \ No newline at end of file From 99befe982607969b4745bf248f798f392855bbc4 Mon Sep 17 00:00:00 2001 From: wmc1112 <46217886+wmc001@users.noreply.github.com> Date: Mon, 2 Feb 2026 18:59:43 +0800 Subject: [PATCH 035/167] =?UTF-8?q?=E2=9C=A8=20User=20management:=20Model?= =?UTF-8?q?=20list=20needs=20tenant=20isolation=20(unit=20test)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../services/test_model_management_service.py | 45 ++++++++----------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 1df55f260..6d0806299 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -287,26 +287,25 @@ async def _clear_model_memories(**kwargs): nexent_memory_mod.clear_model_memories = _clear_model_memories sys.modules["nexent.memory.memory_service"] = nexent_memory_mod -# Stub services.tenant_service required by list_models_for_admin +# Stub services.tenant_service required by list_models_for_admin BEFORE any imports services_tenant_mod = types.ModuleType("services.tenant_service") def _get_tenant_info(tenant_id): + """Mock implementation of get_tenant_info for testing.""" + # Raise exception for empty tenant to test error handling + if tenant_id == "empty_tenant": + raise Exception("Tenant not found") return {"tenant_name": "Test Tenant"} services_tenant_mod.get_tenant_info = _get_tenant_info sys.modules["services.tenant_service"] = services_tenant_mod -# Also stub the backend-level import path -backend_services_tenant_mod = types.ModuleType("backend.services.tenant_service") -backend_services_tenant_mod.get_tenant_info = _get_tenant_info -sys.modules["backend.services.tenant_service"] = backend_services_tenant_mod -# Stub parent 'services' package to prevent attribute access error -services_pkg = types.ModuleType("services") -services_pkg.tenant_service = services_tenant_mod -sys.modules["services"] = services_pkg +def _add_repo_to_name(model_repo, model_name): + """Mock implementation of add_repo_to_name for testing.""" + return f"{model_repo}/{model_name}" if model_repo else model_name def import_svc(): @@ -1130,8 +1129,7 @@ async def test_list_models_for_admin_success(): with mock.patch.object(svc, "get_model_records", return_value=records), \ mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ - mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ - mock.patch.object(svc, "get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"): out = await svc.list_models_for_admin("t1") assert out["tenant_id"] == "t1" assert out["tenant_name"] == "Test Tenant" @@ -1154,9 +1152,8 @@ async def test_list_models_for_admin_with_pagination(): ] with mock.patch.object(svc, "get_model_records", return_value=records), \ - mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ - mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ - mock.patch.object(svc, "get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + mock.patch("backend.utils.model_name_utils.add_repo_to_name", side_effect=_add_repo_to_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"): # Page 1, page_size 10 out = await svc.list_models_for_admin("t1", page=1, page_size=10) assert out["page"] == 1 @@ -1170,14 +1167,12 @@ async def test_list_models_for_admin_with_pagination(): out = await svc.list_models_for_admin("t1", page=2, page_size=10) assert out["page"] == 2 assert len(out["models"]) == 10 - assert out["models"][0]["model_name"] == "openai/gpt-10" # Page 3 (last page) out = await svc.list_models_for_admin("t1", page=3, page_size=10) assert out["page"] == 3 assert out["total_pages"] == 3 assert len(out["models"]) == 5 - assert out["models"][0]["model_name"] == "openai/gpt-20" async def test_list_models_for_admin_with_model_type_filter(): @@ -1190,9 +1185,8 @@ async def test_list_models_for_admin_with_model_type_filter(): ] with mock.patch.object(svc, "get_model_records", return_value=records) as mock_get_records, \ - mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ - mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ - mock.patch.object(svc, "get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + mock.patch("backend.utils.model_name_utils.add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"): # Filter by llm out = await svc.list_models_for_admin("t1", model_type="llm") mock_get_records.assert_called_once_with({"model_type": "llm"}, "t1") @@ -1204,10 +1198,10 @@ async def test_list_models_for_admin_empty_tenant(): """Test list_models_for_tenant handles empty tenant gracefully.""" svc = import_svc() - with mock.patch.object(svc, "get_model_records", return_value=[]), \ - mock.patch.object(svc, "get_tenant_info", return_value={"tenant_name": ""}): - out = await svc.list_models_for_admin("t1") - assert out["tenant_id"] == "t1" + with mock.patch.object(svc, "get_model_records", return_value=[]): + # Use "empty_tenant" ID to trigger exception in stub, resulting in empty tenant_name + out = await svc.list_models_for_admin("empty_tenant") + assert out["tenant_id"] == "empty_tenant" assert out["tenant_name"] == "" assert out["total"] == 0 assert out["total_pages"] == 0 @@ -1241,9 +1235,8 @@ async def test_list_models_for_admin_type_mapping(): with mock.patch.object(svc, "get_model_records", return_value=records), \ mock.patch.object(svc, "add_repo_to_name", side_effect=lambda model_repo, model_name: f"{model_repo}/{model_name}" if model_repo else model_name), \ - mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"), \ - mock.patch.object(svc, "get_tenant_info", return_value={"tenant_name": "Test Tenant"}): + mock.patch.object(svc.ModelConnectStatusEnum, "get_value", side_effect=lambda s: s or "not_detected"): out = await svc.list_models_for_admin("t1") assert len(out["models"]) == 1 - assert out["models"][0]["model_type"] == "llm" # Should be mapped from "chat" \ No newline at end of file + assert out["models"][0]["model_type"] == "llm" # Should be mapped from "chat" From 2baf591a32657015db3ef346ead44e25a1262d29 Mon Sep 17 00:00:00 2001 From: wmc1112 <46217886+WMC001@users.noreply.github.com> Date: Mon, 2 Feb 2026 20:04:28 +0800 Subject: [PATCH 036/167] =?UTF-8?q?=E2=9C=A8=20User=20management:=20Add=20?= =?UTF-8?q?default=20tenant=20name=20if=20not=20exist?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/tenant_service.py | 83 +++++++++++++++----- test/backend/services/test_tenant_service.py | 70 +++++++++++++---- 2 files changed, 119 insertions(+), 34 deletions(-) diff --git a/backend/services/tenant_service.py b/backend/services/tenant_service.py index b58ffc8b2..4aad0dd43 100644 --- a/backend/services/tenant_service.py +++ b/backend/services/tenant_service.py @@ -23,19 +23,22 @@ def get_tenant_info(tenant_id: str) -> Dict[str, Any]: """ Get tenant information by tenant ID + If TENANT_NAME config is missing, automatically create one with default name. + Args: tenant_id (str): Tenant ID Returns: Dict[str, Any]: Tenant information - - Raises: - NotFoundException: When tenant not found """ # Get tenant name name_config = get_single_config_info(tenant_id, TENANT_NAME) if not name_config: - logging.warning(f"The name of tenant {tenant_id} not found.") + logger.warning(f"The name of tenant {tenant_id} not found, creating default config.") + # Auto-create TENANT_NAME config with default name + _ensure_tenant_name_config(tenant_id) + # Re-fetch after creation + name_config = get_single_config_info(tenant_id, TENANT_NAME) group_config = get_single_config_info(tenant_id, DEFAULT_GROUP_ID) @@ -48,6 +51,38 @@ def get_tenant_info(tenant_id: str) -> Dict[str, Any]: return tenant_info +def _ensure_tenant_name_config(tenant_id: str) -> bool: + """ + Ensure TENANT_NAME config exists for the tenant. + Creates a default name config if it doesn't exist. + + Args: + tenant_id: Tenant ID + + Returns: + bool: True if config exists or was created successfully, False otherwise + """ + # Check if already exists (double-check in case of race condition) + existing = get_single_config_info(tenant_id, TENANT_NAME) + if existing: + return True + + # Create default TENANT_NAME config + tenant_name_data = { + "tenant_id": tenant_id, + "config_key": TENANT_NAME, + "config_value": "Unnamed Tenant", + "created_by": "system_auto_create", + "updated_by": "system_auto_create" + } + success = insert_config(tenant_name_data) + if success: + logger.info(f"Auto-created TENANT_NAME config for tenant {tenant_id}") + else: + logger.error(f"Failed to auto-create TENANT_NAME config for tenant {tenant_id}") + return success + + def get_all_tenants() -> List[Dict[str, Any]]: """ Get all tenants @@ -163,6 +198,8 @@ def update_tenant_info(tenant_id: str, tenant_name: str, updated_by: Optional[st """ Update tenant information + If TENANT_NAME config doesn't exist, creates it with the provided name. + Args: tenant_id (str): Tenant ID tenant_name (str): New tenant name @@ -172,25 +209,35 @@ def update_tenant_info(tenant_id: str, tenant_name: str, updated_by: Optional[st Dict[str, Any]: Updated tenant information Raises: - NotFoundException: When tenant not found - ValidationError: When tenant name is invalid + ValidationError: When tenant name is invalid or update fails """ - # Check if tenant exists and get current name config - name_config = get_single_config_info(tenant_id, TENANT_NAME) - if not name_config: - raise NotFoundException(f"Tenant {tenant_id} not found") - # Validate tenant name if not tenant_name or not tenant_name.strip(): raise ValidationError("Tenant name cannot be empty") - # Update tenant name - success = update_config_by_tenant_config_id( - name_config["tenant_config_id"], - tenant_name.strip() - ) - if not success: - raise ValidationError("Failed to update tenant name") + # Check if tenant name config exists + name_config = get_single_config_info(tenant_id, TENANT_NAME) + if not name_config: + # Tenant config doesn't exist, create it with the provided name + logger.info(f"TENANT_NAME config not found for {tenant_id}, creating new config.") + tenant_name_data = { + "tenant_id": tenant_id, + "config_key": TENANT_NAME, + "config_value": tenant_name.strip(), + "created_by": updated_by, + "updated_by": updated_by + } + success = insert_config(tenant_name_data) + if not success: + raise ValidationError("Failed to create tenant name configuration") + else: + # Update existing config + success = update_config_by_tenant_config_id( + name_config["tenant_config_id"], + tenant_name.strip() + ) + if not success: + raise ValidationError("Failed to update tenant name") # Return updated tenant information updated_tenant = get_tenant_info(tenant_id) diff --git a/test/backend/services/test_tenant_service.py b/test/backend/services/test_tenant_service.py index 40ff00edd..728eb449c 100644 --- a/test/backend/services/test_tenant_service.py +++ b/test/backend/services/test_tenant_service.py @@ -81,24 +81,30 @@ def test_get_tenant_info_success(self, service_mocks): tenant_id, "DEFAULT_GROUP_ID") def test_get_tenant_info_name_not_found(self, service_mocks): - """Test get_tenant_info when tenant name is not found""" + """Test get_tenant_info when tenant name is not found - should auto-create config""" # Setup tenant_id = "test_tenant_id" - # Mock config functions to return empty dict for name + # Mock config functions service_mocks['get_single_config_info'].side_effect = [ - {}, # TENANT_NAME not found + {}, # TENANT_NAME first check (not found) + {}, # TENANT_NAME check in _ensure_tenant_name_config (double-check) + {"config_value": "Unnamed Tenant", "tenant_config_id": 1}, # TENANT_NAME after auto-create {"config_value": "group-123"} # DEFAULT_GROUP_ID ] + service_mocks['insert_config'].return_value = True # Execute result = get_tenant_info(tenant_id) - # Assert - should return tenant info with empty name + # Assert - should return tenant info with auto-created default name assert result["tenant_id"] == tenant_id - assert result["tenant_name"] == "" + assert result["tenant_name"] == "Unnamed Tenant" assert result["default_group_id"] == "group-123" + # Verify insert_config was called to create the missing config + service_mocks['insert_config'].assert_called_once() + def test_get_tenant_info_with_empty_group_id(self, service_mocks): """Test get_tenant_info when default group ID is empty""" # Setup @@ -133,21 +139,34 @@ def test_get_tenant_info_get_single_config_exception(self, service_mocks): get_tenant_info(tenant_id) def test_get_tenant_info_both_configs_none(self, service_mocks): - """Test get_tenant_info when both configs return None""" + """Test get_tenant_info when both configs return None - should auto-create name config""" # Setup tenant_id = "test_tenant_id" - # Mock config functions to return None - service_mocks['get_single_config_info'].side_effect = [None, None] + # Mock config functions: + # 1st call: TENANT_NAME not found (None) + # 2nd call: TENANT_NAME check in _ensure_tenant_name_config (None - double-check) + # 3rd call: after insert, re-fetch returns the created config + # 4th call: DEFAULT_GROUP_ID returns None + service_mocks['get_single_config_info'].side_effect = [ + None, # TENANT_NAME first check (None) + None, # TENANT_NAME check in _ensure_tenant_name_config + {"config_value": "Unnamed Tenant", "tenant_config_id": 1}, # TENANT_NAME after auto-create + None # DEFAULT_GROUP_ID (None) + ] + service_mocks['insert_config'].return_value = True # Execute result = get_tenant_info(tenant_id) - # Assert - should return tenant info with empty name and group_id + # Assert - should return tenant info with auto-created default name and empty group_id assert result["tenant_id"] == tenant_id - assert result["tenant_name"] == "" + assert result["tenant_name"] == "Unnamed Tenant" assert result["default_group_id"] == "" + # Verify insert_config was called to create the missing config + service_mocks['insert_config'].assert_called_once() + class TestGetAllTenants: """Test cases for get_all_tenants function""" @@ -427,18 +446,37 @@ def test_update_tenant_info_success(self, service_mocks): assert result["tenant_name"] == new_tenant_name def test_update_tenant_info_tenant_not_found(self, service_mocks): - """Test update_tenant_info when tenant doesn't exist""" + """Test update_tenant_info when tenant doesn't exist - should auto-create config""" # Setup tenant_id = "nonexistent_tenant" new_tenant_name = "Updated Name" user_id = "updater_user" - # Mock get_single_config_info to return empty dict (not found) - service_mocks['get_single_config_info'].return_value = {} + # Mock get_single_config_info to return empty dict on first call (TENANT_NAME not found), + # then return the newly created config after auto-creation + service_mocks['get_single_config_info'].side_effect = [ + {}, # First check - not found + {"config_value": new_tenant_name, "tenant_config_id": 1} # After auto-create + ] + service_mocks['insert_config'].return_value = True - # Execute & Assert - with pytest.raises(NotFoundException, match="not found"): - update_tenant_info(tenant_id, new_tenant_name, user_id) + # Mock get_tenant_info to return updated info + with patch('backend.services.tenant_service.get_tenant_info') as mock_get_tenant_info: + mock_get_tenant_info.return_value = { + "tenant_id": tenant_id, + "tenant_name": new_tenant_name, + "default_group_id": "group-123" + } + + # Execute - should NOT raise NotFoundException, instead auto-create config + result = update_tenant_info(tenant_id, new_tenant_name, user_id) + + # Assert - update should succeed by auto-creating the config + assert result["tenant_id"] == tenant_id + assert result["tenant_name"] == new_tenant_name + + # Verify insert_config was called to create the missing config + service_mocks['insert_config'].assert_called_once() def test_update_tenant_info_empty_name(self, service_mocks): """Test update_tenant_info with empty name""" From db1a6a352e25f037759e355874ab4a562b9baba5 Mon Sep 17 00:00:00 2001 From: panyehong <2655992392@qq.com> Date: Mon, 2 Feb 2026 20:50:25 +0800 Subject: [PATCH 037/167] =?UTF-8?q?=E2=99=BB=EF=B8=8FRefactoring=20authent?= =?UTF-8?q?ication&authorization=20to=20develop=20the=20user=20management?= =?UTF-8?q?=20feature=20[Specification=20Details]=201.=20Add=20group=5Fids?= =?UTF-8?q?=20in=20authService.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/hooks/agent/useSaveGuard.ts | 15 +++++++++++---- frontend/hooks/auth/useAuthorization.ts | 14 +++++++++----- frontend/types/auth.ts | 1 + 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/frontend/hooks/agent/useSaveGuard.ts b/frontend/hooks/agent/useSaveGuard.ts index a8ab6c09d..e41b4e0b1 100644 --- a/frontend/hooks/agent/useSaveGuard.ts +++ b/frontend/hooks/agent/useSaveGuard.ts @@ -6,6 +6,7 @@ import { useConfirmModal } from "../useConfirmModal"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; import { updateAgentInfo, updateToolConfig } from "@/services/agentConfigService"; import { Agent } from "@/types/agentConfig"; +import log from "@/lib/logger"; /** * Hook for handling agent save guard logic @@ -43,12 +44,17 @@ export const useSaveGuard = () => { .map((id: any) => Number(id)) .filter((id: number) => Number.isFinite(id)); + const groupIds = (currentEditedAgent.group_ids || []) + .map((id: any) => Number(id)) + .filter((id: number) => Number.isFinite(id)); + const result = await updateAgentInfo({ agent_id: currentAgentId ?? undefined, // undefined=create, number=update name: currentEditedAgent.name, display_name: currentEditedAgent.display_name, description: currentEditedAgent.description, author: currentEditedAgent.author, + group_ids: groupIds, model_name: currentEditedAgent.model, model_id: currentEditedAgent.model_id ?? undefined, max_steps: currentEditedAgent.max_step, @@ -65,7 +71,7 @@ export const useSaveGuard = () => { }); if (result.success) { - useAgentConfigStore.getState().markAsSaved(); // 标记为已保存 + useAgentConfigStore.getState().markAsSaved(); // Mark as saved message.success( t("businessLogic.config.message.agentSaveSuccess") ); @@ -92,7 +98,7 @@ export const useSaveGuard = () => { try { await updateToolConfig(toolId, agentIdNumber, params, isEnabled); } catch (error) { - console.error(`Failed to save tool config for tool ${toolId}:`, error); + log.error(`Failed to save tool config for tool ${toolId}:`, error); // Continue with other tools even if one fails } } @@ -108,7 +114,7 @@ export const useSaveGuard = () => { }); // Get the updated agent data from the refreshed cache let updatedAgent = queryClient.getQueryData(["agentInfo", finalAgentId]) as Agent; - + // For new agents, the cache might not be populated yet // Construct a minimal Agent object from the edited data if (!updatedAgent && finalAgentId) { @@ -130,9 +136,10 @@ export const useSaveGuard = () => { business_logic_model_name: currentEditedAgent.business_logic_model_name, business_logic_model_id: currentEditedAgent.business_logic_model_id, sub_agent_id_list: currentEditedAgent.sub_agent_id_list, + group_ids: currentEditedAgent.group_ids || [], }; } - + if (updatedAgent) { useAgentConfigStore.getState().setCurrentAgent(updatedAgent); } diff --git a/frontend/hooks/auth/useAuthorization.ts b/frontend/hooks/auth/useAuthorization.ts index 38d78bd12..ba98abdb0 100644 --- a/frontend/hooks/auth/useAuthorization.ts +++ b/frontend/hooks/auth/useAuthorization.ts @@ -88,7 +88,7 @@ export function useAuthorization(): AuthorizationContextType { accessibleRoutes, }); - + } else { log.warn("Missing permissions or accessibleRoutes in user info", { hasPermissions: !!permissions, @@ -109,16 +109,17 @@ export function useAuthorization(): AuthorizationContextType { // Check both status string and isSuccess boolean for compatibility if (result.data && (result.status === 'success' || result.isSuccess)) { const { user } = result.data; - + if (user) { - const { permissions, accessibleRoutes, ...userInfo } = user; - + const { permissions, accessibleRoutes, groupIds, ...userInfo } = user; + if (permissions && accessibleRoutes) { setUser(userInfo as User); + setGroupIds(groupIds); setPermissions(permissions); setAccessibleRoutes(accessibleRoutes); setIsAuthzReady(true); - + authzEventUtils.emitPermissionsReady({ ...userInfo, permissions, @@ -138,6 +139,7 @@ export function useAuthorization(): AuthorizationContextType { const handleLogout = () => { log.info("User logged out, clearing authorization data..."); setUser(null); + setGroupIds([]); setPermissions([]); setAccessibleRoutes([]); setIsAuthzReady(false); @@ -147,6 +149,7 @@ export function useAuthorization(): AuthorizationContextType { const handleSessionExpired = () => { log.info("Session expired, clearing authorization data..."); setUser(null); + setGroupIds([]); setPermissions([]); setAccessibleRoutes([]); setIsAuthzReady(false); @@ -252,6 +255,7 @@ export function useAuthorization(): AuthorizationContextType { return { // Authorization data user, + groupIds, permissions, accessibleRoutes, diff --git a/frontend/types/auth.ts b/frontend/types/auth.ts index dacc567fc..477982f27 100644 --- a/frontend/types/auth.ts +++ b/frontend/types/auth.ts @@ -202,6 +202,7 @@ export interface AuthenticationUIReturn { export interface AuthorizationContextType { // Authorization data user: User | null; + groupIds: number[]; permissions: string[]; accessibleRoutes: string[]; From 211ab18cb0a23fb9225d0c0a3c33da982978e175 Mon Sep 17 00:00:00 2001 From: panyehong <2655992392@qq.com> Date: Mon, 2 Feb 2026 21:53:38 +0800 Subject: [PATCH 038/167] =?UTF-8?q?=E2=99=BB=EF=B8=8FRefactoring=20authent?= =?UTF-8?q?ication&authorization=20to=20develop=20the=20user=20management?= =?UTF-8?q?=20feature=20[Specification=20Details]=201.=20Add=20group=5Fids?= =?UTF-8?q?=20in=20authService.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/stores/agentConfigStore.ts | 50 ++++++++++++++++++++--------- frontend/types/agentConfig.ts | 16 +++++++++ 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/frontend/stores/agentConfigStore.ts b/frontend/stores/agentConfigStore.ts index d399d2247..00e81b230 100644 --- a/frontend/stores/agentConfigStore.ts +++ b/frontend/stores/agentConfigStore.ts @@ -36,10 +36,18 @@ export type EditableAgent = Pick< | "business_logic_model_name" | "business_logic_model_id" | "sub_agent_id_list" + | "group_ids" >; interface AgentConfigStoreState { currentAgentId: number | null; + /** + * Per-agent permission from /agent/list. + * - EDIT: editable + * - READ_ONLY: read-only + * null: unknown / not selected + */ + currentAgentPermission: "EDIT" | "READ_ONLY" | null; baselineAgent: EditableAgent | null; editedAgent: EditableAgent; hasUnsavedChanges: boolean; @@ -95,7 +103,7 @@ interface AgentConfigStoreState { * Reset all state (optional). */ reset: () => void; - + /** * Get the current baseline editable agent (null = create or initial state). * Use isCreatingMode to distinguish between initial state and create mode. @@ -120,6 +128,7 @@ const emptyEditableAgent: EditableAgent = { business_logic_model_name: "", business_logic_model_id: 0, sub_agent_id_list: [], + group_ids: [], }; const toEditable = (agent: Agent | null): EditableAgent => @@ -141,6 +150,7 @@ const toEditable = (agent: Agent | null): EditableAgent => business_logic_model_name: agent.business_logic_model_name || "", business_logic_model_id: agent.business_logic_model_id || 0, sub_agent_id_list: agent.sub_agent_id_list || [], + group_ids: agent.group_ids || [], } : { ...emptyEditableAgent }; @@ -149,7 +159,7 @@ const normalizeArray = (arr: number[]) => (a, b) => a - b ); -// 特定字段的脏检查函数 +// Dirty check helpers for specific field groups const isBusinessInfoDirty = (baselineAgent: EditableAgent | null, editedAgent: EditableAgent): boolean => { if (!baselineAgent) { return ( @@ -178,7 +188,8 @@ const isProfileInfoDirty = (baselineAgent: EditableAgent | null, editedAgent: Ed editedAgent.provide_run_summary !== false || editedAgent.duty_prompt !== "" || editedAgent.constraint_prompt !== "" || - editedAgent.few_shots_prompt !== "" + editedAgent.few_shots_prompt !== "" || + normalizeArray(editedAgent.group_ids || []).length > 0 ); } return ( @@ -192,7 +203,9 @@ const isProfileInfoDirty = (baselineAgent: EditableAgent | null, editedAgent: Ed baselineAgent.provide_run_summary !== editedAgent.provide_run_summary || baselineAgent.duty_prompt !== editedAgent.duty_prompt || baselineAgent.constraint_prompt !== editedAgent.constraint_prompt || - baselineAgent.few_shots_prompt !== editedAgent.few_shots_prompt + baselineAgent.few_shots_prompt !== editedAgent.few_shots_prompt || + JSON.stringify(normalizeArray(baselineAgent.group_ids ?? [])) !== + JSON.stringify(normalizeArray(editedAgent.group_ids ?? [])) ); }; @@ -230,7 +243,8 @@ const isDirty = (baselineAgent: EditableAgent | null, editedAgent: EditableAgent editedAgent.business_description !== "" || editedAgent.business_logic_model_name !== "" || editedAgent.business_logic_model_id !== 0 || - normalizeArray(editedAgent.sub_agent_id_list || []).length > 0 + normalizeArray(editedAgent.sub_agent_id_list || []).length > 0 || + normalizeArray(editedAgent.group_ids || []).length > 0 ); } @@ -251,12 +265,15 @@ const isDirty = (baselineAgent: EditableAgent | null, editedAgent: EditableAgent baselineAgent.business_logic_model_name !== editedAgent.business_logic_model_name || baselineAgent.business_logic_model_id !== editedAgent.business_logic_model_id || JSON.stringify(normalizeArray(baselineAgent.sub_agent_id_list ?? [])) !== - JSON.stringify(normalizeArray(editedAgent.sub_agent_id_list ?? [])) + JSON.stringify(normalizeArray(editedAgent.sub_agent_id_list ?? [])) || + JSON.stringify(normalizeArray(baselineAgent.group_ids ?? [])) !== + JSON.stringify(normalizeArray(editedAgent.group_ids ?? [])) ); }; export const useAgentConfigStore = create((set, get) => ({ currentAgentId: null, + currentAgentPermission: null, baselineAgent: null, editedAgent: { ...emptyEditableAgent }, hasUnsavedChanges: false, @@ -267,6 +284,7 @@ export const useAgentConfigStore = create((set, get) => ( const editedAgent = baselineAgent ? { ...baselineAgent } : { ...emptyEditableAgent }; set({ currentAgentId: agent ? parseInt(agent.id) : null, + currentAgentPermission: agent ? ((agent as any).permission ?? null) : null, baselineAgent, editedAgent, hasUnsavedChanges: false, @@ -277,6 +295,7 @@ export const useAgentConfigStore = create((set, get) => ( enterCreateMode: () => { set({ currentAgentId: null, + currentAgentPermission: "EDIT", baselineAgent: null, editedAgent: { ...emptyEditableAgent }, hasUnsavedChanges: false, @@ -287,8 +306,8 @@ export const useAgentConfigStore = create((set, get) => ( updateTools: (tools) => { set((state) => { const editedAgent = { ...state.editedAgent, tools: [...tools] }; - // 如果已经有未保存的更改,则无需重新计算,直接保持true - // 只有当前状态为干净时,才需要检查tools字段是否有更改 + // If there are already unsaved changes, keep it true and skip recalculation. + // Only when state is clean do we need to check whether tools changed. const hasUnsavedChanges = state.hasUnsavedChanges ? true : isToolsDirty(state.baselineAgent, editedAgent); @@ -303,8 +322,8 @@ export const useAgentConfigStore = create((set, get) => ( const nextIds = normalizeArray(ids); set((state) => { const editedAgent = { ...state.editedAgent, sub_agent_id_list: nextIds }; - // 如果已经有未保存的更改,则无需重新计算,直接保持true - // 只有当前状态为干净时,才需要检查sub agent ids字段是否有更改 + // If there are already unsaved changes, keep it true and skip recalculation. + // Only when state is clean do we need to check whether sub-agent IDs changed. const hasUnsavedChanges = state.hasUnsavedChanges ? true : isSubAgentIdsDirty(state.baselineAgent, editedAgent); @@ -318,8 +337,8 @@ export const useAgentConfigStore = create((set, get) => ( updateBusinessInfo: (payload) => { set((state) => { const editedAgent = { ...state.editedAgent, ...payload }; - // 如果已经有未保存的更改,则无需重新计算,直接保持true - // 只有当前状态为干净时,才需要检查business info字段是否有更改 + // If there are already unsaved changes, keep it true and skip recalculation. + // Only when state is clean do we need to check whether business info changed. const hasUnsavedChanges = state.hasUnsavedChanges ? true : isBusinessInfoDirty(state.baselineAgent, editedAgent); @@ -333,8 +352,8 @@ export const useAgentConfigStore = create((set, get) => ( updateProfileInfo: (payload) => { set((state) => { const editedAgent = { ...state.editedAgent, ...payload }; - // 如果已经有未保存的更改,则无需重新计算,直接保持true - // 只有当前状态为干净时,才需要检查profile info字段是否有更改 + // If there are already unsaved changes, keep it true and skip recalculation. + // Only when state is clean do we need to check whether profile info changed. const hasUnsavedChanges = state.hasUnsavedChanges ? true : isProfileInfoDirty(state.baselineAgent, editedAgent); @@ -367,13 +386,14 @@ export const useAgentConfigStore = create((set, get) => ( reset: () => { set({ currentAgentId: null, + currentAgentPermission: null, baselineAgent: null, editedAgent: { ...emptyEditableAgent }, hasUnsavedChanges: false, isCreatingMode: false, }); }, - + getCurrentAgent: () => { return get().baselineAgent; }, diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index 56f29ecfb..fdd39a0b9 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -33,6 +33,7 @@ export type AgentProfileInfo = Partial< | "duty_prompt" | "constraint_prompt" | "few_shots_prompt" + | "group_ids" > >; @@ -60,6 +61,11 @@ export interface Agent { is_new?: boolean; sub_agent_id_list?: number[]; group_ids?: number[]; + /** + * Per-agent permission returned by /agent/list. + * EDIT: editable, READ_ONLY: read-only. + */ + permission?: "EDIT" | "READ_ONLY"; } export interface Tool { @@ -338,6 +344,11 @@ export interface McpServer { status: boolean; remote_mcp_server_name?: string; remote_mcp_server?: string; + /** + * Per-item permission returned by /mcp/list. + * EDIT: editable, READ_ONLY: read-only. + */ + permission?: "EDIT" | "READ_ONLY"; } // MCP tool interface definition @@ -354,6 +365,11 @@ export interface McpContainer { status?: string; mcp_url?: string; host_port?: number; + /** + * Per-item permission returned by /mcp/containers. + * EDIT: editable, READ_ONLY: read-only. + */ + permission?: "EDIT" | "READ_ONLY"; } // ========== Prompt Service Interfaces ========== From b19fbab664b10b167a704a1f7ca6853ba2436bb3 Mon Sep 17 00:00:00 2001 From: panyehong <2655992392@qq.com> Date: Tue, 3 Feb 2026 09:44:17 +0800 Subject: [PATCH 039/167] =?UTF-8?q?=E2=99=BB=EF=B8=8FRefactoring=20authent?= =?UTF-8?q?ication&authorization=20to=20develop=20the=20user=20management?= =?UTF-8?q?=20feature=20[Specification=20Details]=201.=20Resolve=20conflic?= =?UTF-8?q?t.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/stores/agentConfigStore.ts | 47 ----------------------------- 1 file changed, 47 deletions(-) diff --git a/frontend/stores/agentConfigStore.ts b/frontend/stores/agentConfigStore.ts index 00e81b230..c0713db68 100644 --- a/frontend/stores/agentConfigStore.ts +++ b/frontend/stores/agentConfigStore.ts @@ -224,53 +224,6 @@ const isSubAgentIdsDirty = (baselineAgent: EditableAgent | null, editedAgent: Ed JSON.stringify(normalizeArray(editedAgent.sub_agent_id_list ?? [])); }; -const isDirty = (baselineAgent: EditableAgent | null, editedAgent: EditableAgent): boolean => { - if (!baselineAgent) { - // Create mode: any non-default value counts as dirty - return ( - editedAgent.name !== "" || - editedAgent.display_name !== "" || - editedAgent.description !== "" || - editedAgent.author !== "" || - editedAgent.model !== "" || - editedAgent.model_id !== 0 || - editedAgent.max_step !== 0 || - editedAgent.provide_run_summary !== false || - editedAgent.tools.length > 0 || - editedAgent.duty_prompt !== "" || - editedAgent.constraint_prompt !== "" || - editedAgent.few_shots_prompt !== "" || - editedAgent.business_description !== "" || - editedAgent.business_logic_model_name !== "" || - editedAgent.business_logic_model_id !== 0 || - normalizeArray(editedAgent.sub_agent_id_list || []).length > 0 || - normalizeArray(editedAgent.group_ids || []).length > 0 - ); - } - - return ( - baselineAgent.name !== editedAgent.name || - baselineAgent.display_name !== editedAgent.display_name || - baselineAgent.description !== editedAgent.description || - baselineAgent.author !== editedAgent.author || - baselineAgent.model !== editedAgent.model || - baselineAgent.model_id !== editedAgent.model_id || - baselineAgent.max_step !== editedAgent.max_step || - baselineAgent.provide_run_summary !== editedAgent.provide_run_summary || - JSON.stringify(baselineAgent.tools) !== JSON.stringify(editedAgent.tools) || - baselineAgent.duty_prompt !== editedAgent.duty_prompt || - baselineAgent.constraint_prompt !== editedAgent.constraint_prompt || - baselineAgent.few_shots_prompt !== editedAgent.few_shots_prompt || - baselineAgent.business_description !== editedAgent.business_description || - baselineAgent.business_logic_model_name !== editedAgent.business_logic_model_name || - baselineAgent.business_logic_model_id !== editedAgent.business_logic_model_id || - JSON.stringify(normalizeArray(baselineAgent.sub_agent_id_list ?? [])) !== - JSON.stringify(normalizeArray(editedAgent.sub_agent_id_list ?? [])) || - JSON.stringify(normalizeArray(baselineAgent.group_ids ?? [])) !== - JSON.stringify(normalizeArray(editedAgent.group_ids ?? [])) - ); -}; - export const useAgentConfigStore = create((set, get) => ({ currentAgentId: null, currentAgentPermission: null, From ee7209b8eb40400e71aa43c9c80ca8e65fa6b0dc Mon Sep 17 00:00:00 2001 From: biansimeng Date: Tue, 3 Feb 2026 10:46:37 +0800 Subject: [PATCH 040/167] Revise bugs: 1. support ./deploy.sh --infrastructure; 2. hide input for tool passward params --- docker/deploy.sh | 10 +++++----- .../components/agentConfig/tool/ToolConfigModal.tsx | 13 +++++++++++++ .../[locale]/space/components/AgentDetailModal.tsx | 9 +++------ frontend/public/locales/en/common.json | 4 ++-- frontend/public/locales/zh/common.json | 4 ++-- 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/docker/deploy.sh b/docker/deploy.sh index 2545bf2dc..ad6c106f8 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -475,18 +475,18 @@ select_deployment_mode() { MODE_CHOICE_SAVED="$mode_choice" case $mode_choice in - 2) + 2|"infrastructure") export DEPLOYMENT_MODE="infrastructure" export COMPOSE_FILE_SUFFIX=".yml" echo "✅ Selected infrastructure mode 🏗️" ;; - 3) + 3|"production") export DEPLOYMENT_MODE="production" export COMPOSE_FILE_SUFFIX=".prod.yml" disable_dashboard echo "✅ Selected production mode 🚀" ;; - *) + 1|"development"|*) export DEPLOYMENT_MODE="development" export COMPOSE_FILE_SUFFIX=".yml" echo "✅ Selected development mode 🛠️" @@ -686,11 +686,11 @@ select_deployment_version() { version_choice=$(sanitize_input "$version_choice") VERSION_CHOICE_SAVED="${version_choice}" case $version_choice in - 2) + 2|"full") export DEPLOYMENT_VERSION="full" echo "✅ Selected complete version 🎯" ;; - *) + 1|"speed"|*) export DEPLOYMENT_VERSION="speed" echo "✅ Selected speed version ⚡️" ;; diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index 57d1d2aa0..145a86a36 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -204,6 +204,19 @@ export default function ToolConfigModal({ case TOOL_PARAM_TYPES.ARRAY: case TOOL_PARAM_TYPES.OBJECT: default: + // Check if parameter name contains "password" for secure input + const isPasswordType = param.name.toLowerCase().includes("password"); + + if (isPasswordType) { + return ( + + ); + } + // Default TextArea for all text-like types and unknown types return ( + + {agentDetails?.business_logic_model_name || "-"} + {agentDetails?.model || "-"} {agentDetails?.max_step || 0} - - {agentDetails?.business_logic_model_name || "-"} - - - {agentDetails?.business_logic_model_id || "-"} - {agentDetails?.provide_run_summary ? ( {t("common.yes", "Yes")} diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 8d154e3e9..b29905b37 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -1185,7 +1185,7 @@ "space.detail.model": "Model Name", "space.detail.modelId": "Model ID", "space.detail.maxStep": "Max Steps", - "space.detail.businessLogicModel": "Business Logic Model", + "space.detail.businessLogicModel": "Business Logic Model Name", "space.detail.businessLogicModelId": "Business Logic Model ID", "space.detail.provideRunSummary": "Provide Run Summary", "space.detail.dutyPrompt": "Duty Prompt", @@ -1452,7 +1452,7 @@ "market.detail.modelId": "Model ID", "market.detail.maxSteps": "Max Steps", "market.detail.recommendedModel": "Recommended Model", - "market.detail.businessLogicModel": "Business Logic Model", + "market.detail.businessLogicModel": "Business Logic Model Name", "market.detail.businessLogicModelId": "Business Logic Model ID", "market.detail.provideRunSummary": "Provide Run Summary", "market.detail.dutyPrompt": "Duty Prompt", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 344d3ed79..5d36e6f47 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -1184,7 +1184,7 @@ "space.detail.model": "模型名称", "space.detail.modelId": "模型 ID", "space.detail.maxStep": "最大步数", - "space.detail.businessLogicModel": "业务逻辑模型", + "space.detail.businessLogicModel": "业务逻辑模型名称", "space.detail.businessLogicModelId": "业务逻辑模型 ID", "space.detail.provideRunSummary": "提供运行摘要", "space.detail.dutyPrompt": "职责提示词", @@ -1430,7 +1430,7 @@ "market.detail.modelId": "模型 ID", "market.detail.maxSteps": "最大步数", "market.detail.recommendedModel": "建议模型", - "market.detail.businessLogicModel": "业务逻辑模型", + "market.detail.businessLogicModel": "业务逻辑模型名称", "market.detail.businessLogicModelId": "业务逻辑模型 ID", "market.detail.provideRunSummary": "提供运行摘要", "market.detail.dutyPrompt": "职责提示词", From 6162d15b6d42aff7a9a8b74d09efa85c4ac3f301 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=91=9B=E9=94=90?= Date: Tue, 3 Feb 2026 11:27:19 +0800 Subject: [PATCH 041/167] feat(frontend): add clickable link to agent space after agent install - Add router.push to navigate to /space after successful installation - Split success message into prefix/link/suffix for blue clickable text - Add i18n keys for multilingual support --- frontend/app/[locale]/market/page.tsx | 25 ++++++++++++++++++++----- frontend/public/locales/en/common.json | 3 +++ frontend/public/locales/zh/common.json | 3 +++ 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/frontend/app/[locale]/market/page.tsx b/frontend/app/[locale]/market/page.tsx index a59465d70..3e8fa6c90 100644 --- a/frontend/app/[locale]/market/page.tsx +++ b/frontend/app/[locale]/market/page.tsx @@ -1,6 +1,7 @@ "use client"; import React, { useState, useEffect, useRef } from "react"; +import { useRouter } from "next/navigation"; import { motion } from "framer-motion"; import { useTranslation } from "react-i18next"; import { ShoppingBag, Search, RefreshCw, ChevronLeft, ChevronRight } from "lucide-react"; @@ -27,6 +28,7 @@ import "./MarketContent.css"; * Browse and download pre-built agents from the marketplace */ export default function MarketContent() { + const router = useRouter(); const { t, i18n } = useTranslation("common"); const { message } = App.useApp(); const isZh = i18n.language === "zh" || i18n.language === "zh-CN"; @@ -250,15 +252,28 @@ export default function MarketContent() { }; /** - * Handle install complete + * Handle install complete - Shows success message with navigation to agent space */ const handleInstallComplete = () => { setInstallModalVisible(false); setInstallAgent(null); - // Optionally reload agents or show success message - message.success( - t("market.install.success", "Agent installed successfully!") - ); + + // Show success message with clickable link to agent space + message.success({ + content: ( + + {t("market.install.success.viewSpace.prefix")} + + {t("market.install.success.viewSpace.suffix")} + + ), + duration: 4, + }); }; /** diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 8d154e3e9..0643757c0 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -1545,6 +1545,9 @@ "market.install.success.nameRegeneratedAndResolved": "Agent names regenerated successfully and all conflicts resolved", "market.install.info.notImplemented": "Installation will be implemented in next phase", "market.install.success": "Agent installed successfully!", + "market.install.success.viewSpace.prefix": "Agent installed successfully. You can ", + "market.install.success.viewSpace.link": "Go to Agent Space", + "market.install.success.viewSpace.suffix": " to view and use it.", "market.install.warning.title": "Agent May Be Unusable", "market.install.warning.description": "The following issues may make the Agent unusable:", "market.install.warning.nameConflict": "Unresolved name conflicts exist", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 344d3ed79..bb64607e9 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -1523,6 +1523,9 @@ "market.install.success.nameRegeneratedAndResolved": "智能体名称重新生成成功,且所有冲突已解决", "market.install.info.notImplemented": "安装功能将在下一阶段实现", "market.install.success": "智能体安装成功!", + "market.install.success.viewSpace.prefix": "智能体安装成功,您可", + "market.install.success.viewSpace.link": "前往智能体空间", + "market.install.success.viewSpace.suffix": "查看与使用", "market.install.warning.title": "智能体可能不可用", "market.install.warning.description": "以下问题可能导致智能体不可用:", "market.install.warning.nameConflict": "存在未解决的名称冲突", From b708afd1407d99901e906341318b8c647a73c28d Mon Sep 17 00:00:00 2001 From: Jasonxia007 Date: Tue, 3 Feb 2026 11:34:53 +0800 Subject: [PATCH 042/167] =?UTF-8?q?=E2=9C=A8=20Add=20user=20group=20tags?= =?UTF-8?q?=20in=20avatarDropdown.tsx=20and=20UserProfileComp.tsx?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../users/components/UserProfileComp.tsx | 55 ++++++++++++++++++- frontend/components/auth/avatarDropdown.tsx | 39 ++++++++++++- 2 files changed, 90 insertions(+), 4 deletions(-) diff --git a/frontend/app/[locale]/users/components/UserProfileComp.tsx b/frontend/app/[locale]/users/components/UserProfileComp.tsx index 43e2d686c..bbfa4804a 100644 --- a/frontend/app/[locale]/users/components/UserProfileComp.tsx +++ b/frontend/app/[locale]/users/components/UserProfileComp.tsx @@ -11,6 +11,8 @@ import { App, Flex, Alert, + Tag, + Tooltip, } from "antd"; import { motion } from "framer-motion"; import { useTranslation } from "react-i18next"; @@ -28,6 +30,8 @@ import { import { USER_ROLES } from "@/const/modelConfig"; import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; import { useAuthenticationContext } from "@/components/providers/AuthenticationProvider"; +import { useGroupList } from "@/hooks/group/useGroupList"; +import { useMemo } from "react"; const { Text, Paragraph } = Typography; @@ -45,7 +49,30 @@ export default function UserProfileComp() { const { t } = useTranslation("common"); const { message: antdMessage } = App.useApp(); const { logout, revoke, isLoading } = useAuthenticationContext() - const { user } = useAuthorizationContext() + const { user, groupIds } = useAuthorizationContext() + + // Fetch groups for group name mapping + const { data: groupData } = useGroupList(user?.tenantId || null, 1, 100); + const groups = groupData?.groups || []; + + // Create group name mapping from group_id to group_name + const groupNameMap = useMemo(() => { + const map = new Map(); + groups.forEach((group) => { + map.set(group.group_id, group.group_name); + }); + return map; + }, [groups]); + + // Get user's group names + const userGroupNames = useMemo(() => { + if (!groupIds || groupIds.length === 0) return []; + return groupIds.map((id) => ({ + id, + name: groupNameMap.get(id) || t("common.unknown"), + description: groups.find((g) => g.group_id === id)?.group_description || "", + })); + }, [groupIds, groupNameMap, groups, t]); // Modal states const [isEditModalOpen, setIsEditModalOpen] = useState(false); @@ -169,6 +196,32 @@ export default function UserProfileComp() { {getRoleDisplayName(user?.role || "user")}
+
+ + {t("agent.userGroup") || "User Group"} + +
+ {userGroupNames.length > 0 ? ( + userGroupNames.map((group) => ( + + + {group.name} + + + )) + ) : ( + + {t("agent.userGroup.empty")} + + )} +
+
diff --git a/frontend/components/auth/avatarDropdown.tsx b/frontend/components/auth/avatarDropdown.tsx index b41309dfd..64e47c8a7 100644 --- a/frontend/components/auth/avatarDropdown.tsx +++ b/frontend/components/auth/avatarDropdown.tsx @@ -1,6 +1,6 @@ "use client"; -import React, { useState } from "react"; +import React, { useState, useMemo } from "react"; import { useTranslation } from "react-i18next"; import { Dropdown, Avatar, Spin, Button, Tag, ConfigProvider } from "antd"; import { UserRound, LogOut, LogIn, UserRoundPlus, UserCircle, Power } from "lucide-react"; @@ -13,9 +13,10 @@ import { useAuthorizationContext } from "@/components/providers/AuthorizationPro import { useConfirmModal } from "@/hooks/useConfirmModal"; import { getRoleColor } from "@/lib/auth"; import { USER_ROLES } from "@/const/auth"; +import { useGroupList } from "@/hooks/group/useGroupList"; export function AvatarDropdown() { - const { user, isAuthzReady } = useAuthorizationContext(); + const { user, groupIds, isAuthzReady } = useAuthorizationContext(); const { isLoading, logout, revoke, openLoginModal, openRegisterModal } = useAuthenticationContext(); const [dropdownOpen, setDropdownOpen] = useState(false); @@ -23,6 +24,31 @@ export function AvatarDropdown() { const { modal } = App.useApp(); const { confirm } = useConfirmModal(); + // Fetch groups for group name mapping + const { data: groupData } = useGroupList(user?.tenantId || null, 1, 100); + const groups = groupData?.groups || []; + + // Create group name mapping from group_id to group_name + const groupNameMap = useMemo(() => { + const map = new Map(); + groups.forEach((group) => { + map.set(group.group_id, { + name: group.group_name, + description: group.group_description, + }); + }); + return map; + }, [groups]); + + // Get user's group info + const userGroups = useMemo(() => { + if (!groupIds || groupIds.length === 0) return []; + return groupIds.map((id) => ({ + id, + ...groupNameMap.get(id) || { name: t("common.unknown"), description: "" }, + })); + }, [groupIds, groupNameMap, t]); + // Show loading while authentication is in progress if (isLoading) { return ; @@ -94,10 +120,17 @@ export function AvatarDropdown() { label: (
{user.email}
-
+
{t(`auth.${(user.role).toLowerCase()}`)} + {userGroups.length > 0 ? ( + userGroups.map((group) => ( + + {group.name} + + )) + ) : null}
), From b240194e8a1b34af9152b46a586ae4a20e012596 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=91=9B=E9=94=90?= Date: Tue, 3 Feb 2026 11:51:27 +0800 Subject: [PATCH 043/167] feat(frontend): swap button styles in AgentImportWizard warning modal - Change cancel button to primary (blue) style - Change ok button to default (gray) style for better UX in warning dialog --- frontend/components/agent/AgentImportWizard.tsx | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/frontend/components/agent/AgentImportWizard.tsx b/frontend/components/agent/AgentImportWizard.tsx index 2de633b33..f1d3afade 100644 --- a/frontend/components/agent/AgentImportWizard.tsx +++ b/frontend/components/agent/AgentImportWizard.tsx @@ -820,8 +820,11 @@ export default function AgentImportWizard({ ), okText: t("market.install.warning.continue", "Continue Anyway"), cancelText: t("market.install.warning.goBack", "Go Back to Configure"), + cancelButtonProps: { + type: "primary", + }, okButtonProps: { - className: "bg-blue-600 hover:bg-blue-700 border-blue-600 hover:border-blue-700 text-white", + type: "default", }, onOk: async () => { await performImport(); From 08384612522de6d44a4b5304fbecbd21d9206fd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=91=9B=E9=94=90?= Date: Tue, 3 Feb 2026 15:56:39 +0800 Subject: [PATCH 044/167] feat(frontend): preserve business_logic_model fields during agent import - Add business_logic_model_id and business_logic_model_name to ImportAgentData interface - Pass these fields from market page to AgentImportWizard - Preserve the fields if they were passed from market server, otherwise set to null This fixes the issue where business_logic_model fields were lost when importing agents from the market. --- frontend/app/[locale]/market/page.tsx | 2 ++ .../components/agent/AgentImportWizard.tsx | 36 +++++++++++++++---- frontend/hooks/useAgentImport.ts | 2 ++ 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/frontend/app/[locale]/market/page.tsx b/frontend/app/[locale]/market/page.tsx index a59465d70..89e04bca3 100644 --- a/frontend/app/[locale]/market/page.tsx +++ b/frontend/app/[locale]/market/page.tsx @@ -539,6 +539,8 @@ export default function MarketContent() { agent_id: installAgent.agent_id, agent_info: installAgent.agent_json.agent_info, mcp_info: installAgent.agent_json.mcp_info, + business_logic_model_id: installAgent.business_logic_model_id, + business_logic_model_name: installAgent.business_logic_model_name, } as ImportAgentData) : null } diff --git a/frontend/components/agent/AgentImportWizard.tsx b/frontend/components/agent/AgentImportWizard.tsx index 2de633b33..2601fea43 100644 --- a/frontend/components/agent/AgentImportWizard.tsx +++ b/frontend/components/agent/AgentImportWizard.tsx @@ -879,6 +879,10 @@ export default function AgentImportWizard({ // Clone agent data structure const agentJson = JSON.parse(JSON.stringify(initialData)); + // Preserve business logic model fields from initial data (passed from market) + const preservedBusinessLogicModelId = initialData.business_logic_model_id; + const preservedBusinessLogicModelName = initialData.business_logic_model_name; + // Update all agents' name/display_name if renamed Object.entries(agentNameConflicts).forEach(([agentKey, conflict]) => { if (agentJson.agent_info[agentKey]) { @@ -898,9 +902,19 @@ export default function AgentImportWizard({ agentInfo.model_id = selectedModelId; agentInfo.model_name = selectedModelName; - // Clear business logic model fields - agentInfo.business_logic_model_id = null; - agentInfo.business_logic_model_name = null; + // Preserve business logic model fields if they were passed from market + // Otherwise clear them when user selects a new model + if (preservedBusinessLogicModelId !== undefined && preservedBusinessLogicModelId !== null) { + agentInfo.business_logic_model_id = preservedBusinessLogicModelId; + } else { + agentInfo.business_logic_model_id = null; + } + + if (preservedBusinessLogicModelName !== undefined && preservedBusinessLogicModelName !== null) { + agentInfo.business_logic_model_name = preservedBusinessLogicModelName; + } else { + agentInfo.business_logic_model_name = null; + } }); } else { // Individual mode: apply models to all agents @@ -910,9 +924,19 @@ export default function AgentImportWizard({ agentInfo.model_id = modelSelection.modelId; agentInfo.model_name = modelSelection.modelName; - // Clear business logic model fields - agentInfo.business_logic_model_id = null; - agentInfo.business_logic_model_name = null; + // Preserve business logic model fields if they were passed from market + // Otherwise clear them when user selects a new model + if (preservedBusinessLogicModelId !== undefined && preservedBusinessLogicModelId !== null) { + agentInfo.business_logic_model_id = preservedBusinessLogicModelId; + } else { + agentInfo.business_logic_model_id = null; + } + + if (preservedBusinessLogicModelName !== undefined && preservedBusinessLogicModelName !== null) { + agentInfo.business_logic_model_name = preservedBusinessLogicModelName; + } else { + agentInfo.business_logic_model_name = null; + } } }); } diff --git a/frontend/hooks/useAgentImport.ts b/frontend/hooks/useAgentImport.ts index 0aff99e82..39107a8d0 100644 --- a/frontend/hooks/useAgentImport.ts +++ b/frontend/hooks/useAgentImport.ts @@ -13,6 +13,8 @@ export interface ImportAgentData { mcp_server_name: string; mcp_url: string; }>; + business_logic_model_id?: number | null; + business_logic_model_name?: string | null; } export interface UseAgentImportOptions { From 1e5fdce5e6ae5fe787634da8afed8c2553fd34ad Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Mon, 2 Feb 2026 16:28:10 +0800 Subject: [PATCH 045/167] =?UTF-8?q?=E2=9C=A8=20Enhance=20tool=20configurat?= =?UTF-8?q?ion=20and=20search=20tools=20web?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/dify_service.py | 7 + .../knowledge/KnowledgeBaseList.tsx | 154 +--- .../contexts/KnowledgeBaseContext.tsx | 97 +-- .../KnowledgeBaseSelectorModal.tsx | 685 ++++++++++++++++++ frontend/components/tool-config/index.ts | 38 + frontend/const/knowledgeBaseLayout.ts | 59 ++ frontend/public/locales/en/common.json | 4 +- frontend/public/locales/zh/common.json | 4 +- frontend/services/api.ts | 5 +- frontend/services/knowledgeBaseService.ts | 177 ++++- frontend/services/userConfigService.ts | 59 -- frontend/types/knowledgeBase.ts | 126 ++-- sdk/nexent/core/tools/dify_search_tool.py | 16 +- test/sdk/core/tools/test_dify_search_tool.py | 97 ++- 14 files changed, 1167 insertions(+), 361 deletions(-) create mode 100644 frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx create mode 100644 frontend/components/tool-config/index.ts create mode 100644 frontend/const/knowledgeBaseLayout.ts delete mode 100644 frontend/services/userConfigService.ts diff --git a/backend/services/dify_service.py b/backend/services/dify_service.py index 15524576b..484ca8cd6 100644 --- a/backend/services/dify_service.py +++ b/backend/services/dify_service.py @@ -74,6 +74,13 @@ def fetch_dify_datasets_impl( # Normalize API base URL api_base = dify_api_base.rstrip("/") + # Remove /v1 or /v1/ suffix if present to avoid URL duplication + # E.g., "https://api.dify.ai/v1" -> "https://api.dify.ai" + if api_base.endswith("/v1"): + api_base = api_base[:-3] + elif api_base.endswith("/v1/"): + api_base = api_base[:-4] + # Build request URL with pagination url = f"{api_base}/v1/datasets" diff --git a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx index 86f3605d7..211818b43 100644 --- a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx +++ b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx @@ -1,7 +1,7 @@ import React, { useState, useMemo } from "react"; import { useTranslation } from "react-i18next"; -import { Button, Checkbox, ConfigProvider, Input, Select, Space } from "antd"; +import { Button, Input, Select } from "antd"; import { SyncOutlined, PlusOutlined, @@ -11,54 +11,20 @@ import { } from "@ant-design/icons"; import { KnowledgeBase } from "@/types/knowledgeBase"; - -// Knowledge base layout constants configuration -const KB_LAYOUT = { - // Knowledge base row height configuration - ROW_PADDING: "py-4", // Row vertical padding - HEADER_PADDING: "p-3", // List header padding - BUTTON_PADDING: "p-2", // Create button area padding - TAG_SPACING: "gap-0.5", // Spacing between tags - TAG_MARGIN: "mt-2.5", // Tag container top margin - // Tag related configuration - TAG_PADDING: "px-1.5 py-0.5", // Tag padding - TAG_TEXT: "text-xs font-medium", // Tag text style - TAG_ROUNDED: "rounded-md", // Tag rounded corners - // Line break related configuration - TAG_BREAK_HEIGHT: "h-0.5", // Line break interval height - SECOND_ROW_TAG_MARGIN: "mt-0.5", // Second row tag top margin - // Other layout configuration - TITLE_MARGIN: "ml-2", // Title left margin - EMPTY_STATE_PADDING: "py-4", // Empty state padding - // Title related configuration - TITLE_TEXT: "text-xl font-bold", // Title text style - KB_NAME_TEXT: "text-lg font-medium", // Knowledge base name text style - // Knowledge base name configuration - KB_NAME_MAX_WIDTH: "220px", // Knowledge base name max width - KB_NAME_OVERFLOW: { - // Knowledge base name overflow style - textOverflow: "ellipsis", - whiteSpace: "nowrap", - overflow: "hidden", - display: "block", - }, -}; +import { KB_LAYOUT, KB_TAG_VARIANTS } from "@/const/knowledgeBaseLayout"; interface KnowledgeBaseListProps { knowledgeBases: KnowledgeBase[]; - selectedIds: string[]; activeKnowledgeBase: KnowledgeBase | null; currentEmbeddingModel: string | null; isLoading?: boolean; syncLoading?: boolean; - onSelect: (id: string) => void; onClick: (kb: KnowledgeBase) => void; onDelete: (id: string) => void; onSync: () => void; onCreateNew: () => void; onDataMateConfig?: () => void; showDataMateConfig?: boolean; // Control whether to show DataMate config button - isSelectable: (kb: KnowledgeBase) => boolean; getModelDisplayName: (modelId: string) => string; containerHeight?: string; // Container total height, consistent with DocumentList onKnowledgeBaseChange?: () => void; // New: callback function when knowledge base switches @@ -73,19 +39,16 @@ interface KnowledgeBaseListProps { const KnowledgeBaseList: React.FC = ({ knowledgeBases, - selectedIds, activeKnowledgeBase, currentEmbeddingModel, isLoading = false, syncLoading = false, - onSelect, onClick, onDelete, onSync, onCreateNew, onDataMateConfig, showDataMateConfig = false, - isSelectable, getModelDisplayName, containerHeight = "70vh", // Default container height consistent with DocumentList onKnowledgeBaseChange, // New: callback function when knowledge base switches @@ -351,67 +314,11 @@ const KnowledgeBaseList: React.FC = ({
- {/* Fixed selection status area */} -
-
-
- - {t("knowledgeBase.selected.prefix")}{" "} - - - {selectedIds.length} - - - {t("knowledgeBase.selected.suffix")} - -
- - {selectedIds.length > 0 && ( -
- {selectedIds.map((id) => { - const kb = knowledgeBases.find((kb) => kb.id === id); - return kb ? ( - - - {kb.name} - - - - ) : null; - })} -
- )} -
-
- - {/* Scrollable knowledge base list area */}
{filteredKnowledgeBases.length > 0 ? (
{filteredKnowledgeBases.map((kb, index) => { - const canSelect = isSelectable(kb); - const isSelected = selectedIds.includes(kb.id); const isActive = activeKnowledgeBase?.id === kb.id; - const isMismatchedAndSelected = isSelected && !canSelect; return (
= ({ }} >
-
-
{ - e.stopPropagation(); - if (canSelect || isSelected) { - onSelect(kb.id); - } - }} - style={{ - minWidth: "40px", - minHeight: "40px", - display: "flex", - alignItems: "flex-start", - justifyContent: "center", - }} - > - - { - e.stopPropagation(); - onSelect(kb.id); - }} - disabled={!canSelect && !isSelected} - style={{ - cursor: - canSelect || isSelected - ? "pointer" - : "not-allowed", - transform: "scale(1.5)", - }} - /> - -
-

= ({ > {/* Document count tag */} {t("knowledgeBase.tag.documents", { count: kb.documentCount || 0, @@ -516,7 +378,7 @@ const KnowledgeBaseList: React.FC = ({ {/* Chunk count tag */} {t("knowledgeBase.tag.chunks", { count: kb.chunkCount || 0, @@ -525,7 +387,7 @@ const KnowledgeBaseList: React.FC = ({ {/* Always show source tag regardless of document/chunk count */} {t("knowledgeBase.tag.source", { source: kb.source, @@ -538,7 +400,7 @@ const KnowledgeBaseList: React.FC = ({ <> {/* Creation date tag - only show date */} {t("knowledgeBase.tag.createdAt", { date: formatDate(kb.createdAt), @@ -553,7 +415,7 @@ const KnowledgeBaseList: React.FC = ({ {/* Model tag - only show when model is not "unknown" */} {kb.embeddingModel !== "unknown" && ( {t("knowledgeBase.tag.model", { model: getModelDisplayName(kb.embeddingModel), @@ -564,7 +426,7 @@ const KnowledgeBaseList: React.FC = ({ kb.embeddingModel !== currentEmbeddingModel && kb.source !== "datamate" && ( {t("knowledgeBase.tag.modelMismatch")} diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index d96c8fa33..b38e257fd 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -12,7 +12,6 @@ import { import { useTranslation } from "react-i18next"; import knowledgeBaseService from "@/services/knowledgeBaseService"; -import { userConfigService } from "@/services/userConfigService"; import { KnowledgeBase, @@ -111,8 +110,6 @@ export const KnowledgeBaseContext = createContext<{ hasKnowledgeBaseModelMismatch: (kb: KnowledgeBase) => boolean; refreshKnowledgeBaseData: (forceRefresh?: boolean) => Promise; refreshKnowledgeBaseDataWithDataMate: () => Promise; - loadUserSelectedKnowledgeBases: () => Promise; - saveUserSelectedKnowledgeBases: () => Promise; }>({ state: { knowledgeBases: [], @@ -133,8 +130,6 @@ export const KnowledgeBaseContext = createContext<{ hasKnowledgeBaseModelMismatch: () => false, refreshKnowledgeBaseData: async () => {}, refreshKnowledgeBaseDataWithDataMate: async () => {}, - loadUserSelectedKnowledgeBases: async () => {}, - saveUserSelectedKnowledgeBases: async () => false, }); // Custom hook for using the context @@ -205,49 +200,6 @@ export const KnowledgeBaseProvider: React.FC = ({ [state.currentEmbeddingModel] ); - // Load user selected knowledge bases from backend - const loadUserSelectedKnowledgeBases = useCallback(async () => { - try { - const userConfig = await userConfigService.loadKnowledgeList(); - if (userConfig) { - let allSelectedNames: string[] = []; - - // Handle new format (selectedKbNames array) - if ( - userConfig.selectedKbNames && - userConfig.selectedKbNames.length > 0 - ) { - allSelectedNames = userConfig.selectedKbNames; - } - // Fallback to legacy grouped format for backward compatibility - else if (userConfig.nexent || userConfig.datamate) { - allSelectedNames = [ - ...(userConfig.nexent || []), - ...(userConfig.datamate || []), - ]; - } - - if (allSelectedNames.length > 0) { - // Find matching knowledge base IDs based on index names - const selectedIds = state.knowledgeBases - .filter((kb) => allSelectedNames.includes(kb.id)) - .map((kb) => kb.id); - - dispatch({ - type: KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, - payload: selectedIds, - }); - } - } - } catch (error) { - log.error(t("knowledgeBase.error.loadSelected"), error); - dispatch({ - type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, - payload: t("knowledgeBase.error.loadSelectedRetry"), - }); - } - }, [state.knowledgeBases]); - // Load knowledge base data (supports force fetch from server and load selected status) - optimized with useCallback const fetchKnowledgeBases = useCallback( async ( @@ -276,11 +228,6 @@ export const KnowledgeBaseProvider: React.FC = ({ type: KNOWLEDGE_BASE_ACTION_TYPES.FETCH_SUCCESS, payload: kbs, }); - - // After loading knowledge bases, automatically load user's selected knowledge bases if requested - if (shouldLoadSelected && kbs.length > 0) { - await loadUserSelectedKnowledgeBases(); - } } catch (error) { log.error(t("knowledgeBase.error.fetchList"), error); dispatch({ @@ -291,7 +238,7 @@ export const KnowledgeBaseProvider: React.FC = ({ dispatch({ type: KNOWLEDGE_BASE_ACTION_TYPES.LOADING, payload: false }); } }, - [state.isLoading, t, loadUserSelectedKnowledgeBases] + [state.isLoading, t] ); // Select knowledge base - memoized with useCallback @@ -402,44 +349,6 @@ export const KnowledgeBaseProvider: React.FC = ({ [state.knowledgeBases, state.selectedIds, state.activeKnowledgeBase] ); - // Save user selected knowledge bases to backend - const saveUserSelectedKnowledgeBases = useCallback(async () => { - try { - // Get selected knowledge bases grouped by source - const selectedKnowledgeBases = state.knowledgeBases.filter((kb) => - state.selectedIds.includes(kb.id) - ); - - // Group knowledge bases by source - const knowledgeBySource: { nexent?: string[]; datamate?: string[] } = {}; - selectedKnowledgeBases.forEach((kb) => { - const source = kb.source as keyof typeof knowledgeBySource; - if (!knowledgeBySource[source]) { - knowledgeBySource[source] = []; - } - knowledgeBySource[source]!.push(kb.id); - }); - - const result = - await userConfigService.updateKnowledgeList(knowledgeBySource); - if (!result) { - dispatch({ - type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, - payload: t("knowledgeBase.error.saveSelected"), - }); - return false; - } - return true; - } catch (error) { - log.error(t("knowledgeBase.error.saveSelected"), error); - dispatch({ - type: KNOWLEDGE_BASE_ACTION_TYPES.ERROR, - payload: t("knowledgeBase.error.saveSelectedRetry"), - }); - return false; - } - }, [state.knowledgeBases, state.selectedIds, t]); - // Add a function to refresh the knowledge base data const refreshKnowledgeBaseData = useCallback( async (forceRefresh = false) => { @@ -624,8 +533,6 @@ export const KnowledgeBaseProvider: React.FC = ({ hasKnowledgeBaseModelMismatch, refreshKnowledgeBaseData, refreshKnowledgeBaseDataWithDataMate, - loadUserSelectedKnowledgeBases, - saveUserSelectedKnowledgeBases, }), [ state, @@ -637,8 +544,6 @@ export const KnowledgeBaseProvider: React.FC = ({ isKnowledgeBaseSelectable, refreshKnowledgeBaseData, refreshKnowledgeBaseDataWithDataMate, - loadUserSelectedKnowledgeBases, - saveUserSelectedKnowledgeBases, ] ); diff --git a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx new file mode 100644 index 000000000..36900ea13 --- /dev/null +++ b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx @@ -0,0 +1,685 @@ +"use client"; + +import React, { useState, useMemo, useCallback, useEffect } from "react"; +import { useTranslation } from "react-i18next"; + +import { + Modal, + Button, + Input, + Select, + Spin, + Checkbox, + ConfigProvider, +} from "antd"; +import { + SearchOutlined, + SyncOutlined, +} from "@ant-design/icons"; + +import { KnowledgeBase } from "@/types/knowledgeBase"; +import { + KnowledgeBaseSelectorProps, + getKnowledgeBaseSourcesForTool, +} from "./index"; +import { KB_LAYOUT, KB_TAG_VARIANTS } from "@/const/knowledgeBaseLayout"; + +interface KnowledgeBaseSelectorModalProps extends KnowledgeBaseSelectorProps { + knowledgeBases: KnowledgeBase[]; + isLoading?: boolean; + getModelDisplayName?: (modelId: string) => string; + onSync?: ( + toolType: string, + difyConfig?: { serverUrl?: string; apiKey?: string } + ) => void; + showCheckbox?: boolean; + onSyncComplete?: (knowledgeBases: KnowledgeBase[]) => void; + syncLoading?: boolean; // Loading state for sync button + // Selection validation props + isSelectable?: (kb: KnowledgeBase) => boolean; + currentEmbeddingModel?: string | null; + // Dify configuration for fetching Dify knowledge bases + difyConfig?: { + serverUrl?: string; + apiKey?: string; + }; +} + +export default function KnowledgeBaseSelectorModal({ + isOpen, + onClose, + onConfirm, + selectedIds, + toolType, + title, + maxSelect, + knowledgeBases, + isLoading = false, + getModelDisplayName = (modelId: string) => modelId, + onSync, + showCheckbox = true, + onSyncComplete, + syncLoading = false, + isSelectable, + currentEmbeddingModel = null, + difyConfig, +}: KnowledgeBaseSelectorModalProps) { + const { t } = useTranslation("common"); + + // Selection state (kept for internal logic but not displayed) + const [tempSelectedIds, setTempSelectedIds] = useState([]); + // Search and filter state + const [searchKeyword, setSearchKeyword] = useState(""); + const [selectedSources, setSelectedSources] = useState([]); + const [selectedModels, setSelectedModels] = useState([]); + + // Initialize selection state when modal opens + useEffect(() => { + if (isOpen) { + setTempSelectedIds(selectedIds); + setSearchKeyword(""); + setSelectedSources([]); + setSelectedModels([]); + } + }, [isOpen, selectedIds]); + + // Get allowed sources for the tool type + const allowedSources = useMemo(() => { + return getKnowledgeBaseSourcesForTool(toolType); + }, [toolType]); + + // Calculate available filter options based on actual knowledge bases + const availableSources = useMemo(() => { + const sources = new Set(knowledgeBases.map((kb) => kb.source)); + return Array.from(sources) + .filter((source) => source && allowedSources.includes(source)) + .sort(); + }, [knowledgeBases, allowedSources]); + + const availableModels = useMemo(() => { + const models = new Set(knowledgeBases.map((kb) => kb.embeddingModel)); + return Array.from(models) + .filter((model) => model && model !== "unknown") + .sort(); + }, [knowledgeBases]); + + // Format date function, only keep date part + const formatDate = useCallback((dateValue: any) => { + try { + const date = + typeof dateValue === "number" + ? new Date(dateValue) + : new Date(dateValue); + return isNaN(date.getTime()) + ? String(dateValue ?? "") + : date.toISOString().split("T")[0]; + } catch (e) { + return String(dateValue ?? ""); + } + }, []); + + // Check if a knowledge base can be selected + const checkCanSelect = useCallback( + (kb: KnowledgeBase): boolean => { + // If custom isSelectable function is provided, use it + if (isSelectable) { + return isSelectable(kb); + } + + // Default selection logic: + // 1. Empty knowledge bases cannot be selected + const isEmpty = + (kb.documentCount || 0) === 0 && (kb.chunkCount || 0) === 0; + if (isEmpty) { + return false; + } + + // 2. For nexent source, check model matching + if (kb.source === "nexent" && currentEmbeddingModel) { + if ( + kb.embeddingModel && + kb.embeddingModel !== "unknown" && + kb.embeddingModel !== currentEmbeddingModel + ) { + return false; + } + } + + return true; + }, + [isSelectable, currentEmbeddingModel] + ); + + // Check if a knowledge base has model mismatch (for display purposes) + const checkModelMismatch = useCallback( + (kb: KnowledgeBase): boolean => { + if (kb.source !== "nexent" || !currentEmbeddingModel) { + return false; + } + const embeddingModel = kb.embeddingModel; + return Boolean( + embeddingModel && + embeddingModel !== "unknown" && + embeddingModel !== currentEmbeddingModel + ); + }, + [currentEmbeddingModel] + ); + + // Filter knowledge bases based on tool type, search, and filters + const filteredKnowledgeBases = useMemo(() => { + let filtered = knowledgeBases.filter((kb) => { + // Filter by tool type source + if (!allowedSources.includes(kb.source)) { + return false; + } + + // Keyword search + const keyword = searchKeyword.trim(); + if (keyword) { + const matchesSearch = + kb.name.toLowerCase().includes(keyword.toLowerCase()) || + (kb.description && + kb.description.toLowerCase().includes(keyword.toLowerCase())) || + (kb.nickname && + kb.nickname.toLowerCase().includes(keyword.toLowerCase())); + if (!matchesSearch) return false; + } + + // Source filter + if (selectedSources.length > 0 && !selectedSources.includes(kb.source)) { + return false; + } + + // Model filter + if ( + selectedModels.length > 0 && + !selectedModels.includes(kb.embeddingModel) + ) { + return false; + } + + return true; + }); + + // Sort by update time (latest first) + filtered = [...filtered].sort((a, b) => { + const aTime = a.updatedAt ? new Date(a.updatedAt).getTime() : 0; + const bTime = b.updatedAt ? new Date(b.updatedAt).getTime() : 0; + return bTime - aTime; + }); + + return filtered; + }, [ + knowledgeBases, + allowedSources, + searchKeyword, + selectedSources, + selectedModels, + ]); + + // Toggle selection (still needed for confirm) + const toggleSelection = useCallback( + (id: string) => { + // Find the knowledge base + const kb = knowledgeBases.find((k) => k.id === id); + if (!kb) return; + + // Check if can be selected + if (!checkCanSelect(kb)) { + return; + } + + setTempSelectedIds((prev) => { + if (prev.includes(id)) { + return prev.filter((itemId) => itemId !== id); + } + + // Check max select limit + if (maxSelect && prev.length >= maxSelect) { + return prev; + } + + return [...prev, id]; + }); + }, + [knowledgeBases, maxSelect, checkCanSelect] + ); + + // Clear all selections + const clearAllSelections = useCallback(() => { + setTempSelectedIds([]); + }, []); + + // Handle confirm + const handleConfirm = useCallback(() => { + const selectedKnowledgeBases = knowledgeBases.filter((kb) => + tempSelectedIds.includes(kb.id) + ); + onConfirm(selectedKnowledgeBases); + onClose(); + }, [knowledgeBases, tempSelectedIds, onConfirm, onClose]); + + // Handle cancel + const handleCancel = useCallback(() => { + setTempSelectedIds(selectedIds); + onClose(); + }, [selectedIds, onClose]); + + // Default title based on tool type + const defaultTitle = useMemo(() => { + const titles: Record = { + knowledge_base_search: t("toolConfig.knowledgeBaseSelector.title.local"), + dify_search: t("toolConfig.knowledgeBaseSelector.title.dify"), + datamate_search: t("toolConfig.knowledgeBaseSelector.title.datamate"), + }; + return ( + titles[toolType] || t("toolConfig.knowledgeBaseSelector.title.default") + ); + }, [toolType, t]); + + return ( + + {/* Fixed header area - consistent with KnowledgeBaseList */} +

+
+
+

+ {t("knowledgeBase.list.title")} +

+
+
+ +
+
+ + {/* Search and filter area */} +
+ } + value={searchKeyword} + onChange={(e) => setSearchKeyword(e.target.value)} + style={{ width: 250 }} + allowClear + /> + + {availableSources.length > 0 && ( + + )} + + {availableModels.length > 0 && ( + + )} +
+
+ + {/* Fixed selection status area */} +
+
+
+
+ + {t("knowledgeBase.selected.prefix")}{" "} + + + {tempSelectedIds.length} + + + {t("knowledgeBase.selected.suffix")} + +
+
+ {/* Select All button */} + {filteredKnowledgeBases.length > 0 && + tempSelectedIds.length < filteredKnowledgeBases.length && ( + + )} + {/* Clear Selection button */} + {tempSelectedIds.length > 0 && ( + + )} +
+
+ + {tempSelectedIds.length > 0 && ( +
+ {tempSelectedIds.map((id) => { + const kb = knowledgeBases.find((kb) => kb.id === id); + return kb ? ( + + + {kb.name} + + + + ) : null; + })} +
+ )} +
+
+ + {/* Knowledge base list - consistent with KnowledgeBaseList */} +
+ {isLoading ? ( +
+ +
+ ) : filteredKnowledgeBases.length > 0 ? ( +
+ {filteredKnowledgeBases.map((kb, index) => { + const isSelected = tempSelectedIds.includes(kb.id); + const canSelect = checkCanSelect(kb); + const hasModelMismatch = checkModelMismatch(kb); + + return ( +
canSelect && toggleSelection(kb.id)} + > +
+ {showCheckbox && ( +
{ + e.stopPropagation(); + }} + style={{ + minWidth: "40px", + minHeight: "40px", + display: "flex", + alignItems: "flex-start", + justifyContent: "center", + }} + > + + { + e.stopPropagation(); + toggleSelection(kb.id); + }} + style={{ + cursor: + canSelect || isSelected + ? "pointer" + : "not-allowed", + transform: "scale(1.5)", + }} + /> + +
+ )} +
+ {/* First row: Name */} +
+

+ {kb.name} +

+
+ + {/* First row: Basic info tags */} +
+ {/* Document count tag */} + + {t("knowledgeBase.tag.documents", { + count: kb.documentCount || 0, + })} + + + {/* Chunk count tag */} + + {t("knowledgeBase.tag.chunks", { + count: kb.chunkCount || 0, + })} + + + {/* Source tag */} + + {t("knowledgeBase.tag.source", { + source: t(`knowledgeBase.source.${kb.source}`, { + defaultValue: kb.source, + }), + })} + + + {/* Creation date - only show when there are documents or chunks */} + {((kb.documentCount || 0) > 0 || + (kb.chunkCount || 0) > 0) && ( + + {t("knowledgeBase.tag.createdAt", { + date: formatDate(kb.createdAt), + })} + + )} +
+ + {/* Second row: Model tags */} +
+ {/* Model tag - only show when model is not "unknown" and there are documents or chunks */} + {((kb.documentCount || 0) > 0 || + (kb.chunkCount || 0) > 0) && + kb.embeddingModel && + kb.embeddingModel !== "unknown" && ( + + {getModelDisplayName(kb.embeddingModel)} + {t("knowledgeBase.tag.model", { + model: "", + })} + + )} + {/* Model mismatch tag - only for nexent source */} + {hasModelMismatch && ( + + {t("knowledgeBase.tag.modelMismatch")} + + )} +
+
+
+
+ ); + })} +
+ ) : ( +
+ {searchKeyword || selectedSources.length > 0 + ? t("knowledgeBase.list.noResults") + : t("knowledgeBase.list.empty")} +
+ )} +
+ + ); +} diff --git a/frontend/components/tool-config/index.ts b/frontend/components/tool-config/index.ts new file mode 100644 index 000000000..b424b3225 --- /dev/null +++ b/frontend/components/tool-config/index.ts @@ -0,0 +1,38 @@ +// Tool configuration related types and interfaces + +import { KnowledgeBase } from "@/types/knowledgeBase"; + +// Knowledge base selector component props +export interface KnowledgeBaseSelectorProps { + isOpen: boolean; + onClose: () => void; + onConfirm: (selectedKnowledgeBases: KnowledgeBase[]) => void; + selectedIds: string[]; + toolType: "knowledge_base_search" | "dify_search" | "datamate_search"; + title?: string; + maxSelect?: number; + showCreateButton?: boolean; + showDeleteButton?: boolean; + showCheckbox?: boolean; + // Dify configuration for fetching Dify knowledge bases + difyConfig?: { + serverUrl?: string; + apiKey?: string; + }; +} + +// Get supported knowledge base sources for a tool type +export function getKnowledgeBaseSourcesForTool( + toolType: "knowledge_base_search" | "dify_search" | "datamate_search" +): string[] { + switch (toolType) { + case "knowledge_base_search": + return ["nexent"]; + case "dify_search": + return ["dify"]; + case "datamate_search": + return ["datamate"]; + default: + return ["nexent"]; + } +} diff --git a/frontend/const/knowledgeBaseLayout.ts b/frontend/const/knowledgeBaseLayout.ts new file mode 100644 index 000000000..082c40be5 --- /dev/null +++ b/frontend/const/knowledgeBaseLayout.ts @@ -0,0 +1,59 @@ +/** + * Knowledge Base List Layout Constants + * + * Shared layout configuration for knowledge base list components. + * Used by both KnowledgeBaseList (standalone page) and KnowledgeBaseSelectorModal (popup). + */ + +// Knowledge base layout constants configuration +export const KB_LAYOUT = { + // Row padding + ROW_PADDING: "py-3", + // Header padding + HEADER_PADDING: "p-3", + // Button area padding + BUTTON_AREA_PADDING: "p-2", + // Tag spacing + TAG_SPACING: "gap-1", + // Tag margin + TAG_MARGIN: "mt-1.5", + // Tag padding + TAG_PADDING: "px-2 py-0.5", + // Tag text style + TAG_TEXT: "text-xs font-medium", + // Tag rounded corners + TAG_ROUNDED: "rounded-md", + // Line break height + TAG_BREAK_HEIGHT: "h-0.5", + // Second row tag margin + SECOND_ROW_TAG_MARGIN: "mt-1", + // Title margin + TITLE_MARGIN: "ml-2", + // Empty state padding + EMPTY_STATE_PADDING: "py-4", + // Title text style + TITLE_TEXT: "text-lg font-bold", + // Knowledge base name text style + KB_NAME_TEXT: "text-base font-medium", + // Knowledge base name max width + KB_NAME_MAX_WIDTH: "220px", + // Knowledge base name overflow style + KB_NAME_OVERFLOW: { + textOverflow: "ellipsis", + whiteSpace: "nowrap", + overflow: "hidden", + display: "block", + }, +} as const; + +// Tag style variants for different contexts +export const KB_TAG_VARIANTS = { + // Default gray tag (used in modal) + default: "bg-gray-100 text-gray-600 border border-gray-200", + // Light gray tag (used in list) + light: "bg-gray-200 text-gray-800 border border-gray-200", + // Green tag for model + model: "bg-green-50 text-green-700 border border-green-200", + // Yellow tag for model mismatch + warning: "bg-yellow-100 text-yellow-800 border border-yellow-200", +} as const; diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index d25dbb7d2..5a0649fd5 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -407,7 +407,7 @@ "toolConfig.button.selectKnowledgeBases": "Select Knowledge Bases", "toolConfig.input.knowledgeBaseSelector.placeholder": "Click to select {{name}}", "toolConfig.knowledgeBaseSelector.title.default": "Select Knowledge Base", - "toolConfig.knowledgeBaseSelector.title.local": "Select Local Knowledge Base", + "toolConfig.knowledgeBaseSelector.title.local": "Select Nexent Knowledge Base", "toolConfig.knowledgeBaseSelector.title.dify": "Select Dify Knowledge Base", "toolConfig.knowledgeBaseSelector.title.datamate": "Select DataMate Knowledge Base", "toolPool.title": "Select tools", @@ -471,6 +471,7 @@ "knowledgeBase.selected.suffix": "knowledge bases for retrieval", "knowledgeBase.selected.count": "{{count}} selected", "knowledgeBase.button.clearSelection": "Clear Selection", + "knowledgeBase.button.selectAll": "Select All", "knowledgeBase.button.removeKb": "Remove knowledge base {{name}}", "knowledgeBase.tag.documents": "{{count}} Documents", "knowledgeBase.tag.chunks": "{{count}} Chunks", @@ -488,6 +489,7 @@ "knowledgeBase.source.nexent": "Nexent", "knowledgeBase.source.datamate": "DataMate", "knowledgeBase.source.dify": "Dify", + "knowledgeBase.datamate.editDisabled": "DataMate knowledge bases do not support document editing", "knowledgeBase.filter.allSources": "All Sources", "knowledgeBase.filter.allModels": "All Models", "knowledgeBase.filter.source": "Source", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index e5cdc74dd..5f16e18d9 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -408,7 +408,7 @@ "toolConfig.button.selectKnowledgeBases": "选择知识库", "toolConfig.input.knowledgeBaseSelector.placeholder": "点击选择{{name}}", "toolConfig.knowledgeBaseSelector.title.default": "选择知识库", - "toolConfig.knowledgeBaseSelector.title.local": "选择本地知识库", + "toolConfig.knowledgeBaseSelector.title.local": "选择 Nexent 知识库", "toolConfig.knowledgeBaseSelector.title.dify": "选择 Dify 知识库", "toolConfig.knowledgeBaseSelector.title.datamate": "选择 DataMate 知识库", "toolPool.title": "选择智能体的工具", @@ -472,6 +472,7 @@ "knowledgeBase.selected.suffix": "个知识库用于知识检索", "knowledgeBase.selected.count": "已选择 {{count}} 个", "knowledgeBase.button.clearSelection": "清除选择", + "knowledgeBase.button.selectAll": "全选", "knowledgeBase.button.removeKb": "移除知识库 {{name}}", "knowledgeBase.tag.documents": "{{count}} 文档", "knowledgeBase.tag.chunks": "{{count}} 分块", @@ -488,6 +489,7 @@ "knowledgeBase.source.nexent": "Nexent", "knowledgeBase.source.datamate": "DataMate", "knowledgeBase.source.dify": "Dify", + "knowledgeBase.datamate.editDisabled": "Nexent无法上传文件至DataMate知识库,请前往DataMate页面进行操作。", "knowledgeBase.filter.allSources": "全部来源", "knowledgeBase.filter.allModels": "全部模型", "knowledgeBase.filter.source": "来源", diff --git a/frontend/services/api.ts b/frontend/services/api.ts index a157f55e2..d88c9db0d 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -156,6 +156,9 @@ export const API_ENDPOINTS = { pathOrUrl )}/error-info`, }, + dify: { + datasets: `${API_BASE_URL}/dify/datasets`, + }, datamate: { syncDatamateKnowledges: `${API_BASE_URL}/datamate/sync_datamate_knowledges`, files: (knowledgeBaseId: string) => @@ -167,8 +170,6 @@ export const API_ENDPOINTS = { saveDataMateUrl: `${API_BASE_URL}/config/save_datamate_url`, }, tenantConfig: { - loadKnowledgeList: `${API_BASE_URL}/tenant_config/load_knowledge_list`, - updateKnowledgeList: `${API_BASE_URL}/tenant_config/update_knowledge_list`, deploymentVersion: `${API_BASE_URL}/tenant_config/deployment_version`, }, mcp: { diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts index 5fa9f4921..d038c800a 100644 --- a/frontend/services/knowledgeBaseService.ts +++ b/frontend/services/knowledgeBaseService.ts @@ -41,6 +41,94 @@ class KnowledgeBaseService { } } + // Sync Dify knowledge bases + async syncDifyKnowledgeBases( + difyApiBase: string, + apiKey: string + ): Promise<{ + indices: string[]; + count: number; + indices_info: any[]; + }> { + try { + // Call backend proxy endpoint to avoid CORS issues + const url = new URL(API_ENDPOINTS.dify.datasets, window.location.origin); + url.searchParams.set("dify_api_base", difyApiBase); + url.searchParams.set("api_key", apiKey); + + const response = await fetch(url.toString(), { + method: "GET", + headers: getAuthHeaders(), + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.detail || "Failed to fetch Dify datasets"); + } + + const result = await response.json(); + return { + indices: result.indices || [], + count: result.count || 0, + indices_info: result.indices_info || [], + }; + } catch (error) { + log.error("Failed to sync Dify knowledge bases:", error); + throw error; + } + } + + // Get Dify knowledge bases as KnowledgeBase array + async getDifyKnowledgeBases( + difyApiBase: string, + apiKey: string + ): Promise { + try { + const syncResult = await this.syncDifyKnowledgeBases(difyApiBase, apiKey); + + if (!syncResult.indices_info || syncResult.indices_info.length === 0) { + return []; + } + + // Transform to KnowledgeBase format + const difyKnowledgeBases: KnowledgeBase[] = syncResult.indices_info.map( + (indexInfo: any) => { + const stats = indexInfo.stats?.base_info || {}; + return { + id: indexInfo.name, + name: indexInfo.display_name || indexInfo.name, + display_name: indexInfo.display_name || indexInfo.name, + description: "Dify knowledge base", + documentCount: stats.doc_count || 0, + chunkCount: stats.chunk_count || 0, + createdAt: stats.creation_date || null, + updatedAt: stats.update_date || stats.creation_date || null, + embeddingModel: stats.embedding_model || "unknown", + knowledge_sources: "dify", + ingroup_permission: "", + group_ids: [], + store_size: stats.store_size || "", + process_source: stats.process_source || "Dify", + avatar: "", + chunkNum: 0, + language: "", + nickname: "", + parserId: "", + permission: "", + tokenNum: 0, + source: "dify", + tenant_id: "", + }; + } + ); + + return difyKnowledgeBases; + } catch (error) { + log.error("Failed to get Dify knowledge bases:", error); + throw error; + } + } + // Sync DataMate knowledge bases and create local records async syncDataMateAndCreateRecords(): Promise<{ indices: string[]; @@ -76,6 +164,78 @@ class KnowledgeBaseService { } } + // Sync Dify knowledge bases + async syncDifyDatasets( + difyApiBase: string, + apiKey: string + ): Promise<{ + indices: string[]; + count: number; + indices_info: any[]; + }> { + try { + // Normalize URL by removing trailing slash + const normalizedApiBase = difyApiBase.replace(/\/+$/, ""); + const url = `${normalizedApiBase}/v1/datasets`; + + const response = await fetch(url, { + method: "GET", + headers: { + Authorization: `Bearer ${apiKey}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + throw new Error(`Dify API error: ${response.status}`); + } + + const result = await response.json(); + const datasetsData = result.data || []; + + // Transform to internal format + const indices: string[] = []; + const indices_info: any[] = []; + + for (const dataset of datasetsData) { + const datasetId = dataset.id; + if (!datasetId) continue; + + indices.push(datasetId); + + indices_info.push({ + name: datasetId, + display_name: dataset.name, + stats: { + base_info: { + doc_count: dataset.document_count || 0, + chunk_count: 0, + store_size: "", + process_source: "Dify", + embedding_model: dataset.embedding_model || "", + embedding_dim: 0, + creation_date: (dataset.created_at || 0) * 1000, + update_date: (dataset.updated_at || 0) * 1000, + }, + search_performance: { + total_search_count: 0, + hit_count: 0, + }, + }, + }); + } + + return { + indices, + count: indices.length, + indices_info, + }; + } catch (error) { + log.error("Failed to sync Dify datasets:", error); + throw error; + } + } + // Get knowledge bases with stats from all sources (very slow, don't use it) async getKnowledgeBasesInfo( skipHealthCheck = false, @@ -121,14 +281,20 @@ class KnowledgeBaseService { return { id: kbId, name: kbName, + display_name: indexInfo.display_name || indexInfo.name, description: "Elasticsearch index", documentCount: stats.doc_count || 0, chunkCount: stats.chunk_count || 0, createdAt: stats.creation_date || null, // Use update_time from database for sorting, fallback to ES update_date - updatedAt: indexInfo.update_time || stats.update_date || stats.creation_date || null, + updatedAt: + indexInfo.update_time || + stats.update_date || + stats.creation_date || + null, embeddingModel: stats.embedding_model || "unknown", - knowledge_sources: indexInfo.knowledge_sources || "elasticsearch", + knowledge_sources: + indexInfo.knowledge_sources || "elasticsearch", ingroup_permission: indexInfo.ingroup_permission || "", group_ids: indexInfo.group_ids || [], store_size: stats.store_size || "", @@ -168,6 +334,7 @@ class KnowledgeBaseService { return { id: kbId, name: kbName, + display_name: indexInfo.display_name || indexInfo.name, description: "DataMate knowledge base", documentCount: stats.doc_count || 0, chunkCount: stats.chunk_count || 0, @@ -318,7 +485,7 @@ class KnowledgeBaseService { parserId: "", permission: "", tokenNum: 0, - source: params.source || "elasticsearch" + source: params.source || "elasticsearch", }; } catch (error) { log.error("Failed to create knowledge base:", error); @@ -954,7 +1121,9 @@ class KnowledgeBaseService { const result = await response.json(); if (!response.ok) { - throw new Error(result.detail || result.message || "Failed to update knowledge base"); + throw new Error( + result.detail || result.message || "Failed to update knowledge base" + ); } } catch (error) { log.error("Failed to update knowledge base:", error); diff --git a/frontend/services/userConfigService.ts b/frontend/services/userConfigService.ts deleted file mode 100644 index 99f4d70c0..000000000 --- a/frontend/services/userConfigService.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { API_ENDPOINTS } from './api'; -import { UserKnowledgeConfig, UpdateKnowledgeListRequest } from '../types/knowledgeBase'; - -import { fetchWithAuth, getAuthHeaders } from '@/lib/auth'; -// @ts-ignore -const fetch = fetchWithAuth; - -export class UserConfigService { - // Get user selected knowledge base list - async loadKnowledgeList(): Promise { - try { - const response = await fetch(API_ENDPOINTS.tenantConfig.loadKnowledgeList, { - method: 'GET', - headers: getAuthHeaders(), - }); - - if (!response.ok) { - return null; - } - - const result = await response.json(); - if (result.status === 'success') { - return result.content; - } - return null; - } catch (error) { - return null; - } - } - - // Update user selected knowledge base list - async updateKnowledgeList(request: UpdateKnowledgeListRequest): Promise { - try { - const response = await fetch( - API_ENDPOINTS.tenantConfig.updateKnowledgeList, - { - method: "POST", - headers: getAuthHeaders(), - body: JSON.stringify(request), - } - ); - - if (!response.ok) { - return null; - } - - const result = await response.json(); - if (result.status === 'success') { - return result.content; - } - return null; - } catch (error) { - return null; - } - } -} - -// Export singleton instance -export const userConfigService = new UserConfigService(); \ No newline at end of file diff --git a/frontend/types/knowledgeBase.ts b/frontend/types/knowledgeBase.ts index 5583a8f18..82c3ae49e 100644 --- a/frontend/types/knowledgeBase.ts +++ b/frontend/types/knowledgeBase.ts @@ -1,11 +1,17 @@ // Knowledge base related type definitions -import { DOCUMENT_ACTION_TYPES, KNOWLEDGE_BASE_ACTION_TYPES, UI_ACTION_TYPES, NOTIFICATION_TYPES } from "@/const/knowledgeBase"; +import { + DOCUMENT_ACTION_TYPES, + KNOWLEDGE_BASE_ACTION_TYPES, + UI_ACTION_TYPES, + NOTIFICATION_TYPES, +} from "@/const/knowledgeBase"; // Knowledge base basic type export interface KnowledgeBase { id: string; name: string; + display_name?: string; // User-friendly display name, falls back to name if not available description: string | null; chunkCount: number; documentCount: number; @@ -69,17 +75,32 @@ export interface DocumentState { // Document action type export type DocumentAction = - | { type: typeof DOCUMENT_ACTION_TYPES.FETCH_SUCCESS, payload: { kbId: string, documents: Document[] } } - | { type: typeof DOCUMENT_ACTION_TYPES.SELECT_DOCUMENT, payload: string } - | { type: typeof DOCUMENT_ACTION_TYPES.SELECT_DOCUMENTS, payload: string[] } - | { type: typeof DOCUMENT_ACTION_TYPES.SELECT_ALL, payload: { kbId: string, selected: boolean } } - | { type: typeof DOCUMENT_ACTION_TYPES.SET_UPLOAD_FILES, payload: File[] } - | { type: typeof DOCUMENT_ACTION_TYPES.SET_UPLOADING, payload: boolean } - | { type: typeof DOCUMENT_ACTION_TYPES.SET_LOADING_DOCUMENTS, payload: boolean } - | { type: typeof DOCUMENT_ACTION_TYPES.DELETE_DOCUMENT, payload: { kbId: string, docId: string } } - | { type: typeof DOCUMENT_ACTION_TYPES.SET_LOADING_KB_ID, payload: { kbId: string, isLoading: boolean } } - | { type: typeof DOCUMENT_ACTION_TYPES.CLEAR_DOCUMENTS, payload?: undefined } - | { type: typeof DOCUMENT_ACTION_TYPES.ERROR, payload: string }; + | { + type: typeof DOCUMENT_ACTION_TYPES.FETCH_SUCCESS; + payload: { kbId: string; documents: Document[] }; + } + | { type: typeof DOCUMENT_ACTION_TYPES.SELECT_DOCUMENT; payload: string } + | { type: typeof DOCUMENT_ACTION_TYPES.SELECT_DOCUMENTS; payload: string[] } + | { + type: typeof DOCUMENT_ACTION_TYPES.SELECT_ALL; + payload: { kbId: string; selected: boolean }; + } + | { type: typeof DOCUMENT_ACTION_TYPES.SET_UPLOAD_FILES; payload: File[] } + | { type: typeof DOCUMENT_ACTION_TYPES.SET_UPLOADING; payload: boolean } + | { + type: typeof DOCUMENT_ACTION_TYPES.SET_LOADING_DOCUMENTS; + payload: boolean; + } + | { + type: typeof DOCUMENT_ACTION_TYPES.DELETE_DOCUMENT; + payload: { kbId: string; docId: string }; + } + | { + type: typeof DOCUMENT_ACTION_TYPES.SET_LOADING_KB_ID; + payload: { kbId: string; isLoading: boolean }; + } + | { type: typeof DOCUMENT_ACTION_TYPES.CLEAR_DOCUMENTS; payload?: undefined } + | { type: typeof DOCUMENT_ACTION_TYPES.ERROR; payload: string }; // Knowledge base state interface export interface KnowledgeBaseState { @@ -94,15 +115,36 @@ export interface KnowledgeBaseState { // Knowledge base action type export type KnowledgeBaseAction = - | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.FETCH_SUCCESS, payload: KnowledgeBase[] } - | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE, payload: string[] } - | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SET_ACTIVE, payload: KnowledgeBase | null } - | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SET_MODEL, payload: string | null } - | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.DELETE_KNOWLEDGE_BASE, payload: string } - | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.ADD_KNOWLEDGE_BASE, payload: KnowledgeBase } - | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.LOADING, payload: boolean } - | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SET_SYNC_LOADING, payload: boolean } - | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.ERROR, payload: string }; + | { + type: typeof KNOWLEDGE_BASE_ACTION_TYPES.FETCH_SUCCESS; + payload: KnowledgeBase[]; + } + | { + type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SELECT_KNOWLEDGE_BASE; + payload: string[]; + } + | { + type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SET_ACTIVE; + payload: KnowledgeBase | null; + } + | { + type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SET_MODEL; + payload: string | null; + } + | { + type: typeof KNOWLEDGE_BASE_ACTION_TYPES.DELETE_KNOWLEDGE_BASE; + payload: string; + } + | { + type: typeof KNOWLEDGE_BASE_ACTION_TYPES.ADD_KNOWLEDGE_BASE; + payload: KnowledgeBase; + } + | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.LOADING; payload: boolean } + | { + type: typeof KNOWLEDGE_BASE_ACTION_TYPES.SET_SYNC_LOADING; + payload: boolean; + } + | { type: typeof KNOWLEDGE_BASE_ACTION_TYPES.ERROR; payload: string }; // UI state interface export interface UIState { @@ -112,35 +154,33 @@ export interface UIState { notifications: { id: string; message: string; - type: typeof NOTIFICATION_TYPES.SUCCESS | typeof NOTIFICATION_TYPES.ERROR | typeof NOTIFICATION_TYPES.INFO | typeof NOTIFICATION_TYPES.WARNING; + type: + | typeof NOTIFICATION_TYPES.SUCCESS + | typeof NOTIFICATION_TYPES.ERROR + | typeof NOTIFICATION_TYPES.INFO + | typeof NOTIFICATION_TYPES.WARNING; }[]; } // UI action type export type UIAction = - | { type: typeof UI_ACTION_TYPES.SET_DRAGGING, payload: boolean } - | { type: typeof UI_ACTION_TYPES.TOGGLE_CREATE_MODAL, payload: boolean } - | { type: typeof UI_ACTION_TYPES.TOGGLE_DOC_MODAL, payload: boolean } - | { type: typeof UI_ACTION_TYPES.ADD_NOTIFICATION, payload: { message: string; type: typeof NOTIFICATION_TYPES.SUCCESS | typeof NOTIFICATION_TYPES.ERROR | typeof NOTIFICATION_TYPES.INFO | typeof NOTIFICATION_TYPES.WARNING } } - | { type: typeof UI_ACTION_TYPES.REMOVE_NOTIFICATION, payload: string }; + | { type: typeof UI_ACTION_TYPES.SET_DRAGGING; payload: boolean } + | { type: typeof UI_ACTION_TYPES.TOGGLE_CREATE_MODAL; payload: boolean } + | { type: typeof UI_ACTION_TYPES.TOGGLE_DOC_MODAL; payload: boolean } + | { + type: typeof UI_ACTION_TYPES.ADD_NOTIFICATION; + payload: { + message: string; + type: + | typeof NOTIFICATION_TYPES.SUCCESS + | typeof NOTIFICATION_TYPES.ERROR + | typeof NOTIFICATION_TYPES.INFO + | typeof NOTIFICATION_TYPES.WARNING; + }; + } + | { type: typeof UI_ACTION_TYPES.REMOVE_NOTIFICATION; payload: string }; // Abortable error type for upload operations export interface AbortableError extends Error { name: string; } - -// User selected knowledge base configuration type -export interface UserKnowledgeConfig { - selectedKbNames?: string[]; - selectedKbModels?: string[]; - selectedKbSources?: string[]; - // Legacy support for grouped format - nexent?: string[]; - datamate?: string[]; -} - -// Update knowledge list request type -export interface UpdateKnowledgeListRequest { - nexent?: string[]; - datamate?: string[]; -} diff --git a/sdk/nexent/core/tools/dify_search_tool.py b/sdk/nexent/core/tools/dify_search_tool.py index bc964acd2..34fbb14d6 100644 --- a/sdk/nexent/core/tools/dify_search_tool.py +++ b/sdk/nexent/core/tools/dify_search_tool.py @@ -69,19 +69,23 @@ def __init__( raise ValueError( "api_key is required and must be a non-empty string") - # Parse and validate dataset_ids from JSON string - if not dataset_ids or not isinstance(dataset_ids, str): + # Parse and validate dataset_ids from string or list + if not dataset_ids: raise ValueError( - "dataset_ids is required and must be a non-empty JSON string array") + "dataset_ids is required and must be a non-empty JSON string array or list") try: - parsed_ids = json.loads(dataset_ids) + # Handle both JSON string array and plain list + if isinstance(dataset_ids, str): + parsed_ids = json.loads(dataset_ids) + else: + parsed_ids = dataset_ids if not isinstance(parsed_ids, list) or not parsed_ids: raise ValueError( - "dataset_ids must be a non-empty JSON array of strings") + "dataset_ids must be a non-empty array of strings") self.dataset_ids = [str(item) for item in parsed_ids] except (json.JSONDecodeError, TypeError) as e: raise ValueError( - f"dataset_ids must be a valid JSON string array: {str(e)}") + f"dataset_ids must be a valid JSON string array or list: {str(e)}") self.server_url = server_url.rstrip("/") self.api_key = api_key diff --git a/test/sdk/core/tools/test_dify_search_tool.py b/test/sdk/core/tools/test_dify_search_tool.py index 893801b88..e9eebcc2b 100644 --- a/test/sdk/core/tools/test_dify_search_tool.py +++ b/test/sdk/core/tools/test_dify_search_tool.py @@ -138,9 +138,9 @@ def test_init_invalid_api_key(self, api_key, expected_error): assert expected_error in str(excinfo.value) @pytest.mark.parametrize("dataset_ids,expected_error", [ - ("[]", "dataset_ids must be a non-empty JSON array of strings"), - ("", "dataset_ids is required and must be a non-empty JSON string array"), - (None, "dataset_ids is required and must be a non-empty JSON string array"), + ([], "dataset_ids is required and must be a non-empty JSON string array or list"), + ("", "dataset_ids is required and must be a non-empty JSON string array or list"), + (None, "dataset_ids is required and must be a non-empty JSON string array or list"), ]) def test_init_invaliddataset_ids(self, dataset_ids, expected_error): with pytest.raises(ValueError) as excinfo: @@ -151,6 +151,97 @@ def test_init_invaliddataset_ids(self, dataset_ids, expected_error): ) assert expected_error in str(excinfo.value) + def test_init_dataset_ids_empty_json_array_string(self, mock_observer: MessageObserver): + """Test that empty JSON array '[]' raises ValueError.""" + with pytest.raises(ValueError) as excinfo: + DifySearchTool( + server_url="https://api.dify.ai/v1", + api_key="test_key", + dataset_ids="[]", + observer=mock_observer, + ) + # Empty JSON array passes the first check (not falsy), but fails the list/empty check + assert "dataset_ids must be a non-empty array of strings" in str(excinfo.value) + + def test_init_dataset_ids_as_list(self, mock_observer: MessageObserver): + """Test dataset_ids can be passed as a Python list instead of JSON string.""" + tool = DifySearchTool( + server_url="https://api.dify.ai/v1", + api_key="test_key", + dataset_ids=["ds1", "ds2", "ds3"], + observer=mock_observer, + ) + + assert tool.dataset_ids == ["ds1", "ds2", "ds3"] + assert len(tool.dataset_ids) == 3 + + def test_init_dataset_ids_as_list_single_item(self, mock_observer: MessageObserver): + """Test dataset_ids as a list with single item.""" + tool = DifySearchTool( + server_url="https://api.dify.ai/v1", + api_key="test_key", + dataset_ids=["single_dataset"], + observer=mock_observer, + ) + + assert tool.dataset_ids == ["single_dataset"] + assert len(tool.dataset_ids) == 1 + + def test_init_dataset_ids_as_list_with_numeric_ids(self, mock_observer: MessageObserver): + """Test dataset_ids list with numeric IDs are converted to strings.""" + tool = DifySearchTool( + server_url="https://api.dify.ai/v1", + api_key="test_key", + dataset_ids=[123, 456, 789], + observer=mock_observer, + ) + + assert tool.dataset_ids == ["123", "456", "789"] + assert all(isinstance(id, str) for id in tool.dataset_ids) + + @pytest.mark.parametrize("invalid_json,expected_error_contains", [ + ("invalid_json", "dataset_ids must be a valid JSON string array or list"), + ("{key: value}", "dataset_ids must be a valid JSON string array or list"), + ("{'key': 'value'}", "dataset_ids must be a valid JSON string array or list"), + ("123", "dataset_ids must be a non-empty array of strings"), + ]) + def test_init_invalid_json_format(self, invalid_json, expected_error_contains, mock_observer: MessageObserver): + """Test dataset_ids with invalid JSON format raises appropriate error.""" + with pytest.raises(ValueError) as excinfo: + DifySearchTool( + server_url="https://api.dify.ai/v1", + api_key="test_key", + dataset_ids=invalid_json, + observer=mock_observer, + ) + assert expected_error_contains in str(excinfo.value) + + def test_init_dataset_ids_with_malformed_json_array(self, mock_observer: MessageObserver): + """Test dataset_ids with malformed JSON array.""" + with pytest.raises(ValueError) as excinfo: + DifySearchTool( + server_url="https://api.dify.ai/v1", + api_key="test_key", + dataset_ids='["ds1", "ds2"', # Missing closing bracket + observer=mock_observer, + ) + assert "dataset_ids must be a valid JSON string array or list" in str(excinfo.value) + + def test_init_dataset_ids_json_string_with_non_string_elements(self, mock_observer: MessageObserver): + """Test that non-string elements in JSON array are converted to strings.""" + tool = DifySearchTool( + server_url="https://api.dify.ai/v1", + api_key="test_key", + dataset_ids='["ds1", 123, true, null]', + observer=mock_observer, + ) + + # Elements should be converted to strings using Python's str() + # JSON true -> Python True -> str() -> 'True' + # JSON null -> Python None -> str() -> 'None' + assert tool.dataset_ids == ["ds1", "123", "True", "None"] + assert all(isinstance(id, str) for id in tool.dataset_ids) + class TestGetDocumentDownloadUrl: def test_get_document_download_url_success(self, mocker: MockFixture, dify_tool: DifySearchTool): From f2eeb96774256ed884cbb12fce4dd8f24db83245 Mon Sep 17 00:00:00 2001 From: panyehong <2655992392@qq.com> Date: Tue, 3 Feb 2026 20:50:12 +0800 Subject: [PATCH 046/167] =?UTF-8?q?=E2=9C=A8=20User=20management:=20Agent?= =?UTF-8?q?=20config=20page=20update=20&=20MCP=20tool=20management=20page?= =?UTF-8?q?=20#2148=20&=20#2149=20[Specification=20Details]=201.=20The=20c?= =?UTF-8?q?reators=20of=20the=20intelligent=20agents=20and=20MCP=20service?= =?UTF-8?q?s,=20as=20well=20as=20ADMIN,=20SU,=20and=20SPEED,=20have=20edit?= =?UTF-8?q?ing=20permissions,=20while=20other=20DEVs=20only=20have=20viewi?= =?UTF-8?q?ng=20and=20usage=20permissions.=202.=20Add=20a=20new=20SU=20acc?= =?UTF-8?q?ount=20during=20initialization.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apps/agent_app.py | 4 +- backend/apps/remote_mcp_app.py | 13 +- backend/consts/const.py | 8 + backend/services/agent_service.py | 34 +- backend/services/remote_mcp_service.py | 70 +- backend/services/vectordatabase_service.py | 16 +- docker/deploy.sh | 38 +- docker/init.sql | 4 + .../v1.8.0_0202_add_suadmin_user_tenant_t.sql | 3 + .../agents/components/AgentConfigComp.tsx | 2 - .../agents/components/AgentInfoComp.tsx | 45 +- .../agents/components/AgentManageComp.tsx | 8 +- .../agentConfig/CollaborativeAgent.tsx | 8 +- .../components/agentConfig/McpConfigModal.tsx | 135 ++-- .../components/agentConfig/ToolManagement.tsx | 13 +- .../agentInfo/AgentGenerateDetail.tsx | 392 +++++++--- frontend/public/locales/en/common.json | 35 +- frontend/public/locales/zh/common.json | 35 +- frontend/services/agentConfigService.ts | 1 + frontend/services/mcpService.ts | 3 +- test/backend/app/test_agent_app.py | 28 +- test/backend/app/test_remote_mcp_app.py | 39 +- test/backend/services/test_agent_service.py | 718 +++++++++++++----- .../services/test_remote_mcp_service.py | 36 + 24 files changed, 1206 insertions(+), 482 deletions(-) create mode 100644 docker/sql/v1.8.0_0202_add_suadmin_user_tenant_t.sql diff --git a/backend/apps/agent_app.py b/backend/apps/agent_app.py index 72fa198ca..e2cb7f950 100644 --- a/backend/apps/agent_app.py +++ b/backend/apps/agent_app.py @@ -200,8 +200,8 @@ async def list_all_agent_info_api(authorization: Optional[str] = Header(None), r list all agent info """ try: - _, tenant_id, _ = get_current_user_info(authorization, request) - return await list_all_agent_info_impl(tenant_id=tenant_id) + user_id, tenant_id, _ = get_current_user_info(authorization, request) + return await list_all_agent_info_impl(tenant_id=tenant_id, user_id=user_id) except Exception as e: logger.error(f"Agent list error: {str(e)}") raise HTTPException( diff --git a/backend/apps/remote_mcp_app.py b/backend/apps/remote_mcp_app.py index 7fdb7159d..b7770bfec 100644 --- a/backend/apps/remote_mcp_app.py +++ b/backend/apps/remote_mcp_app.py @@ -16,6 +16,7 @@ delete_mcp_by_container_id, upload_and_start_mcp_image, update_remote_mcp_server_list, + attach_mcp_container_permissions, ) from database.remote_mcp_db import check_mcp_name_exists from services.tool_configuration_service import get_tool_from_remote_mcp_server @@ -146,8 +147,11 @@ async def get_remote_proxies( ): """ Used to get the list of remote MCP servers """ try: - _, tenant_id = get_current_user_id(authorization) - remote_mcp_server_list = await get_remote_mcp_server_list(tenant_id=tenant_id) + user_id, tenant_id = get_current_user_id(authorization) + remote_mcp_server_list = await get_remote_mcp_server_list( + tenant_id=tenant_id, + user_id=user_id, + ) return JSONResponse( status_code=HTTPStatus.OK, content={"remote_mcp_server_list": remote_mcp_server_list, @@ -384,6 +388,11 @@ async def list_mcp_containers( ) containers = container_manager.list_mcp_containers(tenant_id=tenant_id) + containers = attach_mcp_container_permissions( + containers=containers, + tenant_id=tenant_id, + user_id=user_id, + ) return JSONResponse( status_code=HTTPStatus.OK, diff --git a/backend/consts/const.py b/backend/consts/const.py index f09ae9cf4..b4b4e2f89 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -64,6 +64,14 @@ class VectorDatabaseType(str, Enum): DEFAULT_USER_ID = "user_id" DEFAULT_TENANT_ID = "tenant_id" +# Roles that can edit all resources within a tenant (permission = EDIT). +# Keep this centralized to avoid drifting role logic across modules. +CAN_EDIT_ALL_USER_ROLES = {"SU", "ADMIN", "SPEED"} + +# Permission constants used by list endpoints (e.g., /agent/list, /mcp/list). +PERMISSION_READ = "READ_ONLY" +PERMISSION_EDIT = "EDIT" + # Deployment Version Configuration DEPLOYMENT_VERSION = os.getenv("DEPLOYMENT_VERSION", "speed") diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 868f5076b..24e8280e0 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -15,7 +15,8 @@ from agents.agent_run_manager import agent_run_manager from agents.create_agent_info import create_agent_run_info, create_tool_config_list from agents.preprocess_manager import preprocess_manager -from consts.const import MEMORY_SEARCH_START_MSG, MEMORY_SEARCH_DONE_MSG, MEMORY_SEARCH_FAIL_MSG, TOOL_TYPE_MAPPING, LANGUAGE, MESSAGE_ROLE, MODEL_CONFIG_MAPPING +from consts.const import MEMORY_SEARCH_START_MSG, MEMORY_SEARCH_DONE_MSG, MEMORY_SEARCH_FAIL_MSG, TOOL_TYPE_MAPPING, \ + LANGUAGE, MESSAGE_ROLE, MODEL_CONFIG_MAPPING, CAN_EDIT_ALL_USER_ROLES, PERMISSION_EDIT, PERMISSION_READ from consts.exceptions import MemoryPreparationException from consts.model import ( AgentInfoRequest, @@ -55,6 +56,7 @@ search_tools_for_sub_agent ) from database.group_db import query_group_ids_by_user +from database.user_tenant_db import get_user_tenant_by_user_id from utils.str_utils import convert_list_to_string, convert_string_to_list from services.conversation_management_service import save_conversation_assistant, save_conversation_user from services.memory_config_service import build_memory_context @@ -1244,12 +1246,13 @@ async def clear_agent_new_mark_impl(agent_id: int, tenant_id: str, user_id: str) -async def list_all_agent_info_impl(tenant_id: str) -> list[dict]: +async def list_all_agent_info_impl(tenant_id: str, user_id: str) -> list[dict]: """ list all agent info Args: tenant_id (str): tenant id + user_id (str): user id (used for permission calculation and filtering) Raises: ValueError: failed to query all agent info @@ -1258,6 +1261,22 @@ async def list_all_agent_info_impl(tenant_id: str) -> list[dict]: list: list of agent info """ try: + user_tenant_record = get_user_tenant_by_user_id(user_id) or {} + user_role = str(user_tenant_record.get("user_role") or "").upper() + + can_edit_all = user_role in CAN_EDIT_ALL_USER_ROLES + + # For DEV/USER, restrict visible agents to those whose group_ids overlap user's groups. + user_group_ids: set[int] = set() + if not can_edit_all: + try: + user_group_ids = set(query_group_ids_by_user(user_id) or []) + except Exception as e: + logger.warning( + f"Failed to query user group ids for filtering: user_id={user_id}, err={str(e)}" + ) + user_group_ids = set() + agent_list = query_all_agent_info_by_tenant_id(tenant_id=tenant_id) model_cache: Dict[int, Optional[dict]] = {} @@ -1267,6 +1286,12 @@ async def list_all_agent_info_impl(tenant_id: str) -> list[dict]: if not agent["enabled"]: continue + # Apply visibility filter for DEV/USER based on group overlap + if not can_edit_all: + agent_group_ids = set(convert_string_to_list(agent.get("group_ids"))) + if len(user_group_ids.intersection(agent_group_ids)) == 0: + continue + # Use shared availability check function _, unavailable_reasons = check_agent_availability( agent_id=agent["agent_id"], @@ -1297,6 +1322,8 @@ async def list_all_agent_info_impl(tenant_id: str) -> list[dict]: model_cache[model_id] = get_model_by_model_id(model_id, tenant_id) model_info = model_cache.get(model_id) + permission = PERMISSION_EDIT if can_edit_all or str(agent.get("created_by")) == str(user_id) else PERMISSION_READ + simple_agent_list.append({ "agent_id": agent["agent_id"], "name": agent["name"] if agent["name"] else agent["display_name"], @@ -1309,7 +1336,8 @@ async def list_all_agent_info_impl(tenant_id: str) -> list[dict]: "is_available": len(unavailable_reasons) == 0, "unavailable_reasons": unavailable_reasons, "is_new": agent.get("is_new", False), - "group_ids": convert_string_to_list(agent.get("group_ids")) + "group_ids": convert_string_to_list(agent.get("group_ids")), + "permission": permission, }) return simple_agent_list diff --git a/backend/services/remote_mcp_service.py b/backend/services/remote_mcp_service.py index 543536e66..0c8e4576f 100644 --- a/backend/services/remote_mcp_service.py +++ b/backend/services/remote_mcp_service.py @@ -4,6 +4,7 @@ from fastmcp import Client +from consts.const import CAN_EDIT_ALL_USER_ROLES, PERMISSION_EDIT, PERMISSION_READ from consts.exceptions import MCPConnectionError, MCPNameIllegal from database.remote_mcp_db import ( create_mcp_record, @@ -14,6 +15,7 @@ update_mcp_status_by_name_and_url, update_mcp_record_by_name_and_url, ) +from database.user_tenant_db import get_user_tenant_by_user_id from services.mcp_container_service import MCPContainerManager logger = logging.getLogger("remote_mcp_service") @@ -122,19 +124,83 @@ async def update_remote_mcp_server_list( ) -async def get_remote_mcp_server_list(tenant_id: str): +async def get_remote_mcp_server_list(tenant_id: str, user_id: str | None = None) -> list[dict]: mcp_records = get_mcp_records_by_tenant(tenant_id=tenant_id) mcp_records_list = [] + can_edit_all = False + if user_id: + user_tenant_record = get_user_tenant_by_user_id(user_id) or {} + user_role = str(user_tenant_record.get("user_role") or "").upper() + can_edit_all = user_role in CAN_EDIT_ALL_USER_ROLES for record in mcp_records: + created_by = record.get("created_by") or record.get("user_id") + if user_id is None: + permission = PERMISSION_READ + else: + permission = PERMISSION_EDIT if can_edit_all or str( + created_by) == str(user_id) else PERMISSION_READ + mcp_records_list.append({ "remote_mcp_server_name": record["mcp_name"], "remote_mcp_server": record["mcp_server"], - "status": record["status"] + "status": record["status"], + "permission": permission, }) return mcp_records_list +def attach_mcp_container_permissions( + *, + containers: list[dict], + tenant_id: str, + user_id: str | None = None, +) -> list[dict]: + """ + Attach permission (EDIT/READ) to each MCP container entry. + + Rules: + - If user's role is in CAN_EDIT_ALL_USER_ROLES => EDIT for all containers + - Otherwise => EDIT only if the container is associated with an MCP record created by this user + - If association cannot be determined => default to READ + """ + if not containers: + return [] + can_edit_all = False + if user_id: + user_tenant_record = get_user_tenant_by_user_id(user_id) or {} + user_role = str(user_tenant_record.get("user_role") or "").upper() + can_edit_all = user_role in CAN_EDIT_ALL_USER_ROLES + + created_by_by_container_id: dict[str, str] = {} + try: + for record in get_mcp_records_by_tenant(tenant_id=tenant_id) or []: + cid = record.get("container_id") + if not cid: + continue + created_by_by_container_id[str(cid)] = str( + record.get("created_by") or record.get("user_id") or "" + ) + except Exception as e: + logger.warning(f"Failed to load MCP records for permission mapping: {e}") + + enriched: list[dict] = [] + for container in containers: + container_id = str(container.get("container_id") or "") + created_by = created_by_by_container_id.get(container_id, "") + + if user_id is None: + permission = PERMISSION_READ + else: + permission = PERMISSION_EDIT if can_edit_all or ( + created_by and str(created_by) == str(user_id) + ) else PERMISSION_READ + + enriched.append({**container, "permission": permission}) + + return enriched + + async def check_mcp_health_and_update_db(mcp_url, service_name, tenant_id, user_id): # check the health of the MCP server try: diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index c2c61408e..fd222b390 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -25,7 +25,7 @@ from nexent.vector_database.elasticsearch_core import ElasticSearchCore from nexent.vector_database.datamate_core import DataMateCore -from consts.const import DATAMATE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType, IS_SPEED_MODE +from consts.const import DATAMATE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType, IS_SPEED_MODE, PERMISSION_EDIT, PERMISSION_READ from consts.model import ChunkCreateRequest, ChunkUpdateRequest from database.attachment_db import delete_file from database.knowledge_db import ( @@ -574,14 +574,14 @@ def list_indices( if effective_user_role in ["SU", "ADMIN", "SPEED"]: # SU, ADMIN and SPEED roles can see all knowledgebases - permission = "EDIT" + permission = PERMISSION_EDIT elif effective_user_role in ["USER", "DEV"]: # USER/DEV need group-based permission checking kb_group_ids_str = record.get("group_ids") kb_group_ids = convert_string_to_list(kb_group_ids_str or "") kb_created_by = record.get("created_by") kb_ingroup_permission = record.get( - "ingroup_permission") or "READ_ONLY" + "ingroup_permission") or PERMISSION_READ # Check if user belongs to any of the knowledgebase groups # Compatibility logic for legacy data: @@ -602,17 +602,17 @@ def list_indices( if has_group_intersection: # Determine permission level - permission = "READ_ONLY" # Default + permission = PERMISSION_READ # Default # User is creator: creator permission if kb_created_by == user_id: permission = "CREATOR" # Group permission allows editing - elif kb_ingroup_permission == "EDIT": - permission = "EDIT" + elif kb_ingroup_permission == PERMISSION_EDIT: + permission = PERMISSION_EDIT # Group permission is read-only: already set - elif kb_ingroup_permission == "READ_ONLY": - permission = "READ_ONLY" + elif kb_ingroup_permission == PERMISSION_READ: + permission = PERMISSION_READ # Group permission is private: not visible elif kb_ingroup_permission == "PRIVATE": permission = None diff --git a/docker/deploy.sh b/docker/deploy.sh index 2545bf2dc..efb43c792 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -859,21 +859,41 @@ select_terminal_tool() { echo "" } -create_default_admin_user() { - echo "🔧 Creating admin user..." - RESPONSE=$(docker exec nexent-config bash -c "curl -X POST http://kong:8000/auth/v1/signup -H \"apikey: ${SUPABASE_KEY}\" -H \"Authorization: Bearer ${SUPABASE_KEY}\" -H \"Content-Type: application/json\" -d '{\"email\":\"nexent@example.com\",\"password\":\"nexent@4321\",\"email_confirm\":true,\"data\":{\"role\":\"admin\"}}'" 2>/dev/null) +generate_random_password() { + # Generate a URL/JSON safe random password (alphanumeric only) + local pwd="" + if command -v openssl >/dev/null 2>&1; then + pwd=$(openssl rand -base64 32 | tr -dc 'A-Za-z0-9' | head -c 20) + else + pwd=$(tr -dc 'A-Za-z0-9' /dev/null) if [ -z "$RESPONSE" ]; then echo " ❌ No response received from Supabase." return 1 elif echo "$RESPONSE" | grep -q '"access_token"' && echo "$RESPONSE" | grep -q '"user"'; then - echo " ✅ Default admin user has been successfully created." + echo " ✅ Default super admin user has been successfully created." echo "" echo " Please save the following credentials carefully, which would ONLY be shown once." - echo " 📧 Email: nexent@example.com" - echo " 🔏 Password: nexent@4321" + echo " 📧 Email: ${email}" + echo " 🔏 Password: ${password}" elif echo "$RESPONSE" | grep -q '"error_code":"user_already_exists"' || echo "$RESPONSE" | grep -q '"code":422'; then - echo " 🚧 Default admin user already exists. Skipping creation." + echo " 🚧 Default super admin user already exists. Skipping creation." + echo " 📧 Email: ${email}" else echo " ❌ Response from Supabase does not contain 'access_token' or 'user'." return 1 @@ -981,9 +1001,9 @@ main_deploy() { echo "--------------------------------" echo "" - # Create default admin user + # Create default super admin user if [ "$DEPLOYMENT_VERSION" = "full" ]; then - create_default_admin_user || { echo "❌ Default admin user creation failed"; exit 1; } + create_default_super_admin_user || { echo "❌ Default super admin user creation failed"; exit 1; } fi persist_deploy_options diff --git a/docker/init.sql b/docker/init.sql index c547bd07b..a1a52460a 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -1031,3 +1031,7 @@ ON CONFLICT (role_permission_id) DO NOTHING; INSERT INTO nexent.user_tenant_t (user_id, tenant_id, user_role, user_email, created_by, updated_by) VALUES ('user_id', 'tenant_id', 'SPEED', NULL, 'system', 'system') ON CONFLICT (user_id, tenant_id) DO NOTHING; + +INSERT INTO nexent.user_tenant_t (user_id, tenant_id, user_role, user_email, created_by, updated_by) +VALUES ('suadmin', '', 'SU', NULL, 'system', 'system') +ON CONFLICT (user_id, tenant_id) DO NOTHING; diff --git a/docker/sql/v1.8.0_0202_add_suadmin_user_tenant_t.sql b/docker/sql/v1.8.0_0202_add_suadmin_user_tenant_t.sql new file mode 100644 index 000000000..60c19fd3a --- /dev/null +++ b/docker/sql/v1.8.0_0202_add_suadmin_user_tenant_t.sql @@ -0,0 +1,3 @@ +INSERT INTO nexent.user_tenant_t (user_id, tenant_id, user_role, user_email, created_by, updated_by) +VALUES ('suadmin', '', 'SU', NULL, 'system', 'system') +ON CONFLICT (user_id, tenant_id) DO NOTHING; diff --git a/frontend/app/[locale]/agents/components/AgentConfigComp.tsx b/frontend/app/[locale]/agents/components/AgentConfigComp.tsx index b510b120a..dc87104b5 100644 --- a/frontend/app/[locale]/agents/components/AgentConfigComp.tsx +++ b/frontend/app/[locale]/agents/components/AgentConfigComp.tsx @@ -24,8 +24,6 @@ export default function AgentConfigComp({}: AgentConfigCompProps) { const isCreatingMode = useAgentConfigStore((state) => state.isCreatingMode); - const editable = currentAgentId || isCreatingMode; - const [isMcpModalOpen, setIsMcpModalOpen] = useState(false); const [isRefreshing, setIsRefreshing] = useState(false); diff --git a/frontend/app/[locale]/agents/components/AgentInfoComp.tsx b/frontend/app/[locale]/agents/components/AgentInfoComp.tsx index 2a047f851..30a3d4a03 100644 --- a/frontend/app/[locale]/agents/components/AgentInfoComp.tsx +++ b/frontend/app/[locale]/agents/components/AgentInfoComp.tsx @@ -2,7 +2,7 @@ import { useState } from "react"; import { useTranslation } from "react-i18next"; -import { Row, Col, Flex, Badge, Divider, Button, Drawer, App } from "antd"; +import { Row, Col, Flex, Badge, Divider, Button, Drawer, Tooltip } from "antd"; import { Bug, Save, Info } from "lucide-react"; import { AGENT_SETUP_LAYOUT_DEFAULT } from "@/const/agentConfig"; @@ -19,14 +19,21 @@ export default function AgentInfoComp({}: AgentInfoCompProps) { const { t } = useTranslation("common"); // Get data from store - const { editedAgent, updateBusinessInfo, updateProfileInfo, isCreatingMode } = - useAgentConfigStore(); + const { + editedAgent, + updateBusinessInfo, + updateProfileInfo, + isCreatingMode, + currentAgentPermission, + } = useAgentConfigStore(); // Get state from store const currentAgentId = useAgentConfigStore((state) => state.currentAgentId); - const editable = + const isPanelActive = (currentAgentId != null && currentAgentId != undefined) || isCreatingMode; + const isReadOnly = isPanelActive && !isCreatingMode && currentAgentPermission === "READ_ONLY"; + const isEditable = isPanelActive && !isReadOnly; // Save guard hook const saveGuard = useSaveGuard(); @@ -73,7 +80,7 @@ export default function AgentInfoComp({}: AgentInfoCompProps) { - + + + + + } - {!editable && ( + {!isPanelActive && (
diff --git a/frontend/app/[locale]/agents/components/AgentManageComp.tsx b/frontend/app/[locale]/agents/components/AgentManageComp.tsx index ce0f36525..7552dba09 100644 --- a/frontend/app/[locale]/agents/components/AgentManageComp.tsx +++ b/frontend/app/[locale]/agents/components/AgentManageComp.tsx @@ -125,8 +125,14 @@ export default function AgentManageComp() { !agentInfoLoading && !agentInfoError ) { + const permissionFromList = + agentList.find((a: Agent) => String(a.id) === String(selectedAgentId)) + ?.permission ?? null; // Handle agent switch with unsaved changes check - handleAgentSwitch(agentDetail); + handleAgentSwitch({ + ...agentDetail, + permission: permissionFromList, + }); setSelectedAgentId(null); } else if (selectedAgentId && agentInfoError && !agentInfoLoading) { // Handle error diff --git a/frontend/app/[locale]/agents/components/agentConfig/CollaborativeAgent.tsx b/frontend/app/[locale]/agents/components/agentConfig/CollaborativeAgent.tsx index b39da4580..216f6bcee 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/CollaborativeAgent.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/CollaborativeAgent.tsx @@ -17,6 +17,9 @@ export default function CollaborativeAgent({}: CollaborativeAgentProps) { const currentAgentId = useAgentConfigStore((state) => state.currentAgentId); const isCreatingMode = useAgentConfigStore((state) => state.isCreatingMode); + const currentAgentPermission = useAgentConfigStore( + (state) => state.currentAgentPermission + ); const editedAgent = useAgentConfigStore((state) => state.editedAgent); const updateSubAgentIds = useAgentConfigStore( (state) => state.updateSubAgentIds @@ -24,7 +27,10 @@ export default function CollaborativeAgent({}: CollaborativeAgentProps) { const { availableAgents } = useAgentList(); - const editable = currentAgentId || isCreatingMode; + const editable = + !!isCreatingMode || + ((currentAgentId != null && currentAgentId != undefined) && + currentAgentPermission !== "READ_ONLY"); // Get related agents - use edited agent state (which includes current agent data when editing) const relatedAgentIds = Array.isArray(editedAgent?.sub_agent_id_list) diff --git a/frontend/app/[locale]/agents/components/agentConfig/McpConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/McpConfigModal.tsx index b7a02d384..dedf946d8 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/McpConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/McpConfigModal.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect } from "react"; +import { useState, useEffect, type ComponentProps } from "react"; import { useTranslation } from "react-i18next"; import { Modal, @@ -100,6 +100,37 @@ export default function McpConfigModal({ const [uploadServiceName, setUploadServiceName] = useState(""); const actionsLocked = updatingTools || addingContainer || uploadingImage; + const noMcpEditPermissionTitle = t("mcpConfig.permission.noEdit"); + + const renderPermissionControlledButton = (props: { + isReadOnly: boolean; + button: Omit, "disabled" | "onClick"> & { + disabled?: boolean; + onClick?: (() => void) | undefined; + }; + }) => { + const { isReadOnly, button } = props; + const { onClick, disabled, ...rest } = button; + + const finalDisabled = Boolean(disabled) || isReadOnly; + const finalOnClick = finalDisabled ? undefined : onClick; + + const element = ( + + + + )} - - + {renderPermissionControlledButton({ + isReadOnly, + button: { + type: "link", + icon: , + onClick: () => onEditServer(record), + size: "small", + disabled: actionsLocked, + children: t("mcpConfig.serverList.button.edit"), + }, + })} + {renderPermissionControlledButton({ + isReadOnly, + button: { + type: "link", + danger: true, + icon: , + onClick: () => onDeleteServer(record), + size: "small", + disabled: actionsLocked, + children: t("mcpConfig.serverList.button.delete"), + }, + })} ); }, @@ -500,29 +538,34 @@ export default function McpConfigModal({ title: t("mcpConfig.containerList.column.action"), key: "action", width: "25%", - render: (_: any, record: any) => ( - - - - - ), + render: (_: any, record: any) => { + const isReadOnly = record.permission === "READ_ONLY"; + return ( + + + {renderPermissionControlledButton({ + isReadOnly, + button: { + type: "link", + danger: true, + icon: , + onClick: () => onDeleteContainer(record), + size: "small", + disabled: actionsLocked, + children: t("mcpConfig.containerList.button.delete"), + }, + })} + + ); + }, }, ]; diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index 02ad0e8f6..c4be676d5 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -7,6 +7,7 @@ import { ToolGroup, Tool, ToolParam } from "@/types/agentConfig"; import { Tabs, Collapse } from "antd"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; import { useToolList } from "@/hooks/agent/useToolList"; +import log from "@/lib/logger"; import { Settings } from "lucide-react"; @@ -27,7 +28,13 @@ export default function ToolManagement({ }: ToolManagementProps) { const { t } = useTranslation("common"); - const editable = currentAgentId || isCreatingMode; + const currentAgentPermission = useAgentConfigStore( + (state) => state.currentAgentPermission + ); + const isPanelActive = + (currentAgentId != null && currentAgentId != undefined) || isCreatingMode; + const editable = + !!isPanelActive && (isCreatingMode || currentAgentPermission !== "READ_ONLY"); // Get state from store const originalSelectedTools = useAgentConfigStore( @@ -80,7 +87,7 @@ export default function ToolManagement({ return defaultTool.initParams || []; } } catch (error) { - console.error("Failed to fetch tool instance params:", error); + log.error("Failed to fetch tool instance params:", error); return defaultTool.initParams || []; } } else { @@ -96,6 +103,7 @@ export default function ToolManagement({ }, [toolGroups, activeTabKey]); const handleToolSettingsClick = async (tool: Tool) => { + if (!editable) return; // Get latest tools directly from store to avoid stale closure issues const currentTools = useAgentConfigStore.getState().editedAgent.tools; const configuredTool = currentTools.find( @@ -119,6 +127,7 @@ export default function ToolManagement({ }; const handleToolClick = async (toolId: string) => { + if (!editable) return; const numericId = parseInt(toolId, 10); const tool = availableTools.find((t) => parseInt(t.id) === numericId); diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index 6d9bace82..000e639dd 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -1,9 +1,10 @@ "use client"; -import { useState, useEffect, useRef } from "react"; +import { useState, useEffect, useMemo } from "react"; import { useTranslation } from "react-i18next"; import { Button, + Tooltip, Tabs, Form, Input, @@ -35,6 +36,9 @@ import { generatePromptStream } from "@/services/promptService"; import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; import { useDeployment } from "@/components/providers/deploymentProvider"; import { useModelList } from "@/hooks/model/useModelList"; +import { useTenantList } from "@/hooks/tenant/useTenantList"; +import { useGroupList } from "@/hooks/group/useGroupList"; +import { USER_ROLES } from "@/const/auth"; import ExpandEditModal from "./ExpandEditModal"; const { TextArea } = Input; @@ -60,13 +64,19 @@ export default function AgentGenerateDetail({ }: AgentGenerateDetailProps) { const { t } = useTranslation("common"); const { message } = App.useApp(); - const { user } = useAuthorizationContext(); + const { user, groupIds: allowedGroupIds } = useAuthorizationContext(); const { isSpeedMode } = useDeployment(); const [form] = Form.useForm(); // Model data from React Query const { availableLlmModels, defaultLlmModel, isLoading: loadingModels } = useModelList(); + // Tenant & group data for group selection + const { data: tenantData } = useTenantList(); + const tenantId = user?.tenantId ?? tenantData?.[0]?.tenant_id ?? null; + const { data: groupData } = useGroupList(tenantId, 1, 100); + const groups = groupData?.groups || []; + // State management const [activeTab, setActiveTab] = useState("agent-info"); @@ -74,6 +84,32 @@ export default function AgentGenerateDetail({ const [expandModalOpen, setExpandModalOpen] = useState(false); const [expandModalType, setExpandModalType] = useState<'duty' | 'constraint' | 'few-shots' | null>(null); + // Only show "no edit permission" tooltip when the panel is active and agent is read-only. + // Note: when no agent is selected, AgentInfoComp shows an overlay and we should not show + // this tooltip in that state. + const showNoEditPermissionTip = + !editable && currentAgentId !== null && currentAgentId !== undefined; + + const noEditPermissionTitle = showNoEditPermissionTip + ? t("agent.noEditPermission") + : undefined; + + const wrapNoEditTooltipBlock = (node: React.ReactNode) => { + return ( + + {node} + + ); + }; + + const wrapNoEditTooltipInline = (node: React.ReactNode) => { + return ( + + {node} + + ); + }; + // Ensure tenant config is loaded for default model selection useEffect(() => { @@ -122,9 +158,57 @@ export default function AgentGenerateDetail({ businessLogicModelId: 0, }); + const normalizeNumberArray = (value: unknown): number[] => { + const arr = Array.isArray(value) ? value : []; + return Array.from( + new Set(arr.map((id) => Number(id)).filter((id) => Number.isFinite(id))) + ).sort((a, b) => a - b); + }; + + const groupSelectOptions = useMemo(() => { + const selectedIds = normalizeNumberArray(editedAgent.group_ids || []); + const allowedSet = new Set(normalizeNumberArray(allowedGroupIds || [])); + const canSelectAllGroups = + user?.role === USER_ROLES.SU || + user?.role === USER_ROLES.ADMIN || + user?.role === USER_ROLES.SPEED; + + const baseGroups = canSelectAllGroups + ? groups + : groups.filter((g) => allowedSet.has(g.group_id)); + + const baseSet = new Set(baseGroups.map((g) => g.group_id)); + const groupById = new Map(groups.map((g) => [g.group_id, g] as const)); + + const options: Array<{ label: string; value: number; disabled?: boolean }> = + baseGroups.map((g) => ({ + label: g.group_name, + value: g.group_id, + })); + + // Keep already-selected groups visible even if they are not selectable (disabled). + for (const id of selectedIds) { + if (baseSet.has(id)) continue; + const g = groupById.get(id); + options.push({ + label: g?.group_name ?? `Group ${id}`, + value: id, + disabled: true, + }); + } + + return options; + }, [allowedGroupIds, editedAgent.group_ids, groups, user?.role]); + // Initialize form values when component mounts or currentAgentId changes useEffect(() => { - const initialAgentInfo = { + const isCreateMode = editable && (currentAgentId === null || currentAgentId === undefined); + + // Note: + // In create mode, do not set group_ids here. Otherwise, when switching from an existing + // agent to create mode (currentAgentId changes to null), this initializer can overwrite + // the default-group selection effect and leave group_ids empty. + const initialAgentInfo: Record = { agentName: editedAgent.name || "", agentDisplayName: editedAgent.display_name || "", agentAuthor: editedAgent.author || "", @@ -132,11 +216,16 @@ export default function AgentGenerateDetail({ editedAgent.model || defaultLlmModel?.displayName || "", mainAgentMaxStep: editedAgent.max_step || 5, agentDescription: editedAgent.description || "", + group_ids: normalizeNumberArray(editedAgent.group_ids || []), dutyPrompt: editedAgent.duty_prompt || "", constraintPrompt: editedAgent.constraint_prompt || "", fewShotsPrompt: editedAgent.few_shots_prompt || "", }; + if (isCreateMode) { + delete initialAgentInfo.group_ids; + } + const initialBusinessInfo = { businessDescription: editedAgent.business_description || "", businessLogicModelName: @@ -150,7 +239,37 @@ export default function AgentGenerateDetail({ setBusinessInfo(initialBusinessInfo); form.setFieldsValue(initialAgentInfo); - }, [currentAgentId, editedAgent, availableLlmModels, defaultLlmModel]); + // We intentionally initialize the form only when switching agent (or when default model becomes available), + // otherwise it can create update loops with Form-controlled fields updating the store. + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [currentAgentId, defaultLlmModel?.id]); + + // Default to selecting all groups when creating a new agent. + // Only applies when groups are loaded and no group is selected yet. + useEffect(() => { + const isCreateMode = editable && (currentAgentId === null || currentAgentId === undefined); + if (!isCreateMode) return; + if (!groups || groups.length === 0) return; + + const currentGroupIds = normalizeNumberArray(editedAgent.group_ids || []); + if (currentGroupIds.length > 0) return; + + const allowedSet = new Set(normalizeNumberArray(allowedGroupIds || [])); + const canSelectAllGroups = + user?.role === USER_ROLES.SU || + user?.role === USER_ROLES.ADMIN || + user?.role === USER_ROLES.SPEED; + const selectableGroups = canSelectAllGroups + ? groups + : groups.filter((g) => allowedSet.has(g.group_id)); + + const allGroupIds = normalizeNumberArray(selectableGroups.map((g) => g.group_id)); + if (allGroupIds.length === 0) return; + + form.setFieldsValue({ group_ids: allGroupIds }); + onUpdateProfile({ group_ids: allGroupIds }); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [editable, currentAgentId, groups, allowedGroupIds, user?.role]); // Handle business description change const handleBusinessDescriptionChange = (value: string) => { @@ -175,10 +294,61 @@ export default function AgentGenerateDetail({ // Handle expand modal functions const handleOpenExpandModal = (type: 'duty' | 'constraint' | 'few-shots') => { + if (!editable) return; setExpandModalType(type); setExpandModalOpen(true); }; + const renderExpandButton = (type: "duty" | "constraint" | "few-shots") => { + return wrapNoEditTooltipInline( +