(
onMultimodalChange(!isMultimodal)}
- onClick={() =>
- canToggleMultimodal &&
- onMultimodalChange &&
- onMultimodalChange(!isMultimodal)
- }
+ onClick={() => {
+ if (canToggleMultimodal) {
+ onMultimodalChange?.(!isMultimodal);
+ }
+ }}
+ onKeyDown={(event) => {
+ if (!canToggleMultimodal) {
+ return;
+ }
+ if (event.key === "Enter" || event.key === " ") {
+ event.preventDefault();
+ onMultimodalChange?.(!isMultimodal);
+ }
+ }}
+ role="button"
+ tabIndex={canToggleMultimodal ? 0 : -1}
+ aria-pressed={isMultimodal}
+ aria-disabled={!canToggleMultimodal}
style={{
width: 80, // Keep width aligned with adjacent controls
- display: 'inline-block',
- textAlign: 'center',
- cursor: canToggleMultimodal ? 'pointer' : 'default',
- color: isMultimodal ? '#52c41a' : '#000000', // Success green when enabled, black when disabled
+ display: "inline-block",
+ textAlign: "center",
+ cursor: canToggleMultimodal ? "pointer" : "default",
+ color: isMultimodal ? "#52c41a" : "#000000", // Success green when enabled, black when disabled
fontWeight: isMultimodal ? 500 : 400, // Slightly bolder when enabled
- userSelect: 'none', // Prevent text selection on double-click
- lineHeight: '32px' // Align with Select height
+ userSelect: "none", // Prevent text selection on double-click
+ lineHeight: "32px", // Align with Select height
+ outline: "none",
}}
- title={isMultimodal ? "Multimodal: Enabled" : "Multimodal: Disabled"}
+ title={
+ isMultimodal
+ ? "Multimodal: Enabled"
+ : "Multimodal: Disabled"
+ }
>
Multimodal
diff --git a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
index fb8a256a1..6c78c279b 100644
--- a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
+++ b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
@@ -186,6 +186,50 @@ export default function KnowledgeBaseSelectorModal({
}
}, []);
+ const isMultimodalConstraintMismatch = useCallback(
+ (kb: KnowledgeBase) => {
+ return (
+ toolMultimodal !== null &&
+ ((toolMultimodal && !kb.is_multimodal) ||
+ (!toolMultimodal && kb.is_multimodal))
+ );
+ },
+ [toolMultimodal]
+ );
+
+ const isEmbeddingModelCompatible = useCallback(
+ (kb: KnowledgeBase) => {
+ if (kb.is_multimodal) {
+ if (!currentMultiEmbeddingModel) {
+ return false;
+ }
+ if (
+ kb.embeddingModel &&
+ kb.embeddingModel !== "unknown" &&
+ kb.embeddingModel !== currentMultiEmbeddingModel
+ ) {
+ return false;
+ }
+ return true;
+ }
+
+ if (!currentEmbeddingModel) {
+ return true;
+ }
+
+ if (
+ kb.embeddingModel &&
+ kb.embeddingModel !== "unknown" &&
+ kb.embeddingModel !== currentEmbeddingModel
+ ) {
+ return false;
+ }
+
+ return true;
+ },
+ [currentEmbeddingModel, currentMultiEmbeddingModel]
+ );
+
// Check if a knowledge base can be selected
const checkCanSelect = useCallback(
(kb: KnowledgeBase): boolean => {
@@ -204,79 +248,52 @@ export default function KnowledgeBaseSelectorModal({
// 2. For nexent source, check model matching
if (kb.source === "nexent") {
- const hasMultimodalConstraintMismatch =
- toolMultimodal !== null &&
- ((toolMultimodal && !kb.is_multimodal) ||
- (!toolMultimodal && kb.is_multimodal));
-
- if (hasMultimodalConstraintMismatch) {
- return false;
- }
-
- if (kb.is_multimodal) {
- if (!currentMultiEmbeddingModel) {
- return false;
- }
- if (
- kb.embeddingModel &&
- kb.embeddingModel !== "unknown" &&
- kb.embeddingModel !== currentMultiEmbeddingModel
- ) {
- return false;
- }
- } else if (
- kb.embeddingModel &&
- kb.embeddingModel !== "unknown" &&
- kb.embeddingModel !== currentEmbeddingModel
- ) {
+ if (isMultimodalConstraintMismatch(kb)) {
return false;
}
+ return isEmbeddingModelCompatible(kb);
}
return true;
},
[
isSelectable,
- currentEmbeddingModel,
- currentMultiEmbeddingModel,
- toolMultimodal,
+ isEmbeddingModelCompatible,
+ isMultimodalConstraintMismatch,
]
);
// Check if a knowledge base has model mismatch (for display purposes)
- const checkModelMismatch = useCallback(
- (kb: KnowledgeBase): boolean => {
- if (kb.source !== "nexent") {
- return false;
- }
+ const checkModelMismatch = (kb: KnowledgeBase): boolean => {
+ if (kb.source !== "nexent") {
+ return false;
+ }
- const hasMultimodalConstraintMismatch =
- toolMultimodal !== null &&
- ((toolMultimodal && !kb.is_multimodal) ||
- (!toolMultimodal && kb.is_multimodal));
- if (hasMultimodalConstraintMismatch) {
- return true;
- }
+ const hasMultimodalConstraintMismatch =
+ toolMultimodal !== null &&
+ ((toolMultimodal && !kb.is_multimodal) ||
+ (!toolMultimodal && kb.is_multimodal));
+ if (hasMultimodalConstraintMismatch) {
+ return true;
+ }
- const embeddingModel = kb.embeddingModel;
- if (!embeddingModel || embeddingModel === "unknown") {
- return false;
- }
+ const embeddingModel = kb.embeddingModel;
+ if (!embeddingModel || embeddingModel === "unknown") {
+ return false;
+ }
- if (kb.is_multimodal) {
- if (!currentMultiEmbeddingModel) {
- return true;
- }
- return embeddingModel !== currentMultiEmbeddingModel;
+ if (kb.is_multimodal) {
+ if (!currentMultiEmbeddingModel) {
+ return true;
}
+ return embeddingModel !== currentMultiEmbeddingModel;
+ }
- if (!currentEmbeddingModel) {
- return false;
- }
- return embeddingModel !== currentEmbeddingModel;
- },
- [currentEmbeddingModel, currentMultiEmbeddingModel, toolMultimodal]
- );
+ if (!currentEmbeddingModel) {
+ return false;
+ }
+ return embeddingModel !== currentEmbeddingModel;
+ };
// Filter knowledge bases based on tool type, search, and filters
const filteredKnowledgeBases = useMemo(() => {
diff --git a/frontend/services/api.ts b/frontend/services/api.ts
index f12ab92de..2e115044b 100644
--- a/frontend/services/api.ts
+++ b/frontend/services/api.ts
@@ -132,7 +132,7 @@ export const API_ENDPOINTS = {
customModelHealthcheck: (displayName: string, modelType: string) =>
`${API_BASE_URL}/model/healthcheck?display_name=${encodeURIComponent(
displayName
- )}&modelType=${encodeURIComponent(modelType)}`,
+ )}&model_type=${encodeURIComponent(modelType)}`,
verifyModelConfig: `${API_BASE_URL}/model/temporary_healthcheck`,
updateSingleModel: (displayName: string) =>
`${API_BASE_URL}/model/update?display_name=${encodeURIComponent(displayName)}`,
diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py
index ea1277793..48f43c5ef 100644
--- a/sdk/nexent/core/tools/knowledge_base_search_tool.py
+++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py
@@ -126,40 +126,80 @@ def forward(self, query: str, index_names: List[str]) -> str:
f"KnowledgeBaseSearchTool called with query: '{query}', search_mode: '{search_mode}', index_names: {search_index_names}"
)
- if len(search_index_names) == 0:
- return json.dumps("No knowledge base selected. No relevant information found.", ensure_ascii=False)
-
- if search_mode == "hybrid":
- kb_search_data = self.search_hybrid(
- query=query, index_names=search_index_names)
- elif search_mode == "accurate":
- kb_search_data = self.search_accurate(
- query=query, index_names=search_index_names)
- elif search_mode == "semantic":
- kb_search_data = self.search_semantic(
- query=query, index_names=search_index_names)
- else:
- raise Exception(
- f"Invalid search mode: {search_mode}, only support: hybrid, accurate, semantic")
+ if not search_index_names:
+ return json.dumps(
+ "No knowledge base selected. No relevant information found.",
+ ensure_ascii=False,
+ )
+ kb_search_data = self._run_search(
+ query=query, index_names=search_index_names, search_mode=search_mode
+ )
kb_search_results = kb_search_data["results"]
if not kb_search_results:
raise Exception(
"No results found! Try a less restrictive/shorter query.")
- search_results_json = [] # Organize search results into a unified format
- search_results_return = [] # Format for input to the large model
-
+ (
+ search_results_json,
+ search_results_return,
+ images_list_url,
+ ) = self._build_search_results(kb_search_results)
+
+ self.record_ops += len(search_results_return)
+
+ self._record_search_results(
+ search_results_json=search_results_json,
+ images_list_url=images_list_url,
+ query=query,
+ )
+
+ return json.dumps(search_results_return, ensure_ascii=False)
+
+ def _notify_search_start(self, query: str) -> None:
+ if not self.observer:
+ return
+ running_prompt = (
+ self.running_prompt_zh
+ if self.observer.lang == "zh"
+ else self.running_prompt_en
+ )
+ self.observer.add_message("", ProcessType.TOOL, running_prompt)
+ card_content = [{"icon": "search", "text": query}]
+ self.observer.add_message(
+ "", ProcessType.CARD, json.dumps(card_content, ensure_ascii=False)
+ )
+
+ def _run_search(self, query: str, index_names: List[str], search_mode: str):
+ search_handlers = {
+ "hybrid": self.search_hybrid,
+ "accurate": self.search_accurate,
+ "semantic": self.search_semantic,
+ }
+ handler = search_handlers.get(search_mode)
+ if not handler:
+ raise Exception(
+ f"Invalid search mode: {search_mode}, only support: hybrid, accurate, semantic"
+ )
+ return handler(query=query, index_names=index_names)
+
+ @staticmethod
+ def _normalize_source_type(source_type: str) -> str:
+ return "file" if source_type in ["local", "minio"] else source_type
+
+ def _build_search_results(self, kb_search_results):
+ search_results_json = []
+ search_results_return = []
images_list_url = []
+
for index, single_search_result in enumerate(kb_search_results):
- # Temporarily correct the source_type stored in the knowledge base
- source_type = single_search_result.get("source_type", "")
- source_type = "file" if source_type in [
- "local", "minio"] else source_type
- title = single_search_result.get("title")
- if not title:
- title = single_search_result.get("filename", "")
+ source_type = self._normalize_source_type(
+ single_search_result.get("source_type", "")
+ )
+ title = single_search_result.get("title") or single_search_result.get(
+ "filename", ""
+ )
search_result_message = SearchResultTextMessage(
title=title,
text=single_search_result.get("content", ""),
@@ -173,39 +213,52 @@ def forward(self, query: str, index_names: List[str]) -> str:
search_type=self.name,
tool_sign=self.tool_sign,
)
-
- if single_search_result.get('process_source') == 'UniversalImageExtractor':
- try:
- meta_data = json.loads(single_search_result.get('content'))
- except (json.JSONDecodeError, TypeError):
- logger.error("Failed to parse image metadata")
- img_url = meta_data.get("image_url", None)
- if img_url:
- images_list_url.append(img_url)
+
+ image_url = self._extract_image_url(single_search_result)
+ if image_url:
+ images_list_url.append(image_url)
search_results_json.append(search_result_message.to_dict())
search_results_return.append(search_result_message.to_model_dict())
- self.record_ops += len(search_results_return)
+ return search_results_json, search_results_return, images_list_url
- # Record the detailed content of this search
- if self.observer:
- search_results_data = json.dumps(
- search_results_json, ensure_ascii=False)
+ @staticmethod
+ def _extract_image_url(single_search_result):
+ if single_search_result.get("process_source") != "UniversalImageExtractor":
+ return None
+ try:
+ meta_data = json.loads(single_search_result.get("content"))
+ except (json.JSONDecodeError, TypeError):
+ logger.error("Failed to parse image metadata")
+ return None
+ return meta_data.get("image_url", None)
+
+ def _record_search_results(
+ self,
+ search_results_json: List[dict],
+ images_list_url: List[str],
+ query: str,
+ ) -> None:
+ if not self.observer:
+ return
+
+ search_results_data = json.dumps(
+ search_results_json, ensure_ascii=False)
+ self.observer.add_message(
+ "", ProcessType.SEARCH_CONTENT, search_results_data)
+
+ if not images_list_url:
+ return
+
+        logger.debug("img list: %s", images_list_url)
+        final_filtered_images = self._filter_images(images_list_url, query)
+        logger.debug("final_list: %s", final_filtered_images)
+ if final_filtered_images:
+ search_images_list_json = json.dumps(
+                {"images_url": final_filtered_images}, ensure_ascii=False)
self.observer.add_message(
- "", ProcessType.SEARCH_CONTENT, search_results_data)
-
- if len(images_list_url) > 0:
- print("img list: ", images_list_url)
- final_filtered_images = self._filter_images(
- images_list_url, query)
- print("final_list: ", final_filtered_images)
- if len(final_filtered_images) > 0:
- search_images_list_json = json.dumps(
- {"images_url": images_list_url}, ensure_ascii=False)
- self.observer.add_message(
- "", ProcessType.PICTURE_WEB, search_images_list_json)
- return json.dumps(search_results_return, ensure_ascii=False)
+ "", ProcessType.PICTURE_WEB, search_images_list_json)
def search_hybrid(self, query, index_names):
try:
diff --git a/sdk/nexent/data_process/core.py b/sdk/nexent/data_process/core.py
index 6a0f9af7c..84bff7c5a 100644
--- a/sdk/nexent/data_process/core.py
+++ b/sdk/nexent/data_process/core.py
@@ -1,6 +1,6 @@
import logging
import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
from .extract_image import UniversalImageExtractor
@@ -55,7 +55,7 @@ def file_process(
chunking_strategy: str = "basic",
processor: Optional[str] = None,
**params,
- ) -> List[Dict]:
+ ) -> Tuple[List[Dict], List[Dict]]:
"""
Facade pattern that automatically detects file type and processes files
@@ -68,11 +68,13 @@ def file_process(
**params: Additional processing parameters
Returns:
- List of processed chunks, each dictionary contains the following fields:
+ Tuple[List[Dict], List[Dict]]: (chunks, images_info)
+ chunks: List of processed chunks, each dictionary contains the following fields:
- content: Text content
- filename: Filename
- metadata: Metadata (optional, includes chunk_index, source_type, etc.)
- language: Language identifier (optional)
+ images_info: List of extracted image metadata dicts (may be empty)
Raises:
ValueError: Invalid parameters
@@ -129,7 +131,7 @@ def _validate_parameters(self, chunking_strategy: str, processor: Optional[str])
logger.debug(
f"Parameter validation passed: chunking_strategy={chunking_strategy}, processor={processor}")
- def _select_processor_by_filename(self, filename: str) -> str:
+ def _select_processor_by_filename(self, filename: str) -> Tuple[str, Optional[str]]:
"""Selects a processor based on the file extension."""
_, file_extension = os.path.splitext(filename)
file_extension = file_extension.lower()
diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py
index c764f4be2..0134a2733 100644
--- a/sdk/nexent/data_process/extract_image.py
+++ b/sdk/nexent/data_process/extract_image.py
@@ -198,7 +198,7 @@ def _extract_excel(self, xlsx_path):
if drawing is None:
continue
- rId = drawing.get(
+ rel_id = drawing.get(
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id")
rel_path = sheet_file.replace(
"worksheets", "worksheets/_rels") + ".rels"
@@ -210,7 +210,7 @@ def _extract_excel(self, xlsx_path):
drawing_file = None
for r in rel_xml:
- if r.get("Id") == rId:
+ if r.get("Id") == rel_id:
drawing_file = "xl/" + \
r.get("Target").replace("../", "")
break
@@ -254,12 +254,12 @@ def _extract_excel(self, xlsx_path):
if blip is None:
continue
- rId = blip.get(
+ embed_rel_id = blip.get(
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
- if rId not in rel_map:
+ if embed_rel_id not in rel_map:
continue
- img_bytes = z.read(rel_map[rId])
+ img_bytes = z.read(rel_map[embed_rel_id])
h = self._hash(img_bytes)
if h in seen:
@@ -340,24 +340,24 @@ def process_file(self, file_bytes: bytes, chunking_strategy: str, filename: str,
converted_path = None
try:
- if suffix == ".xlsx":
- return self._extract_excel(temp_path)
- if suffix == ".xls":
- converted_path = self._convert_file(temp_path, "xlsx")
- return self._extract_excel(converted_path)
-
- if suffix == ".pptx":
- return self._extract_pptx(temp_path, **params)
- if suffix == ".ppt":
- converted_path = self._convert_file(temp_path, "pptx")
- return self._extract_pptx(converted_path, **params)
-
- if suffix in [".docx", ".doc"]:
- converted_path = self._convert_file(temp_path, "pdf")
- return self._extract_pdf(converted_path, **params)
-
- if suffix == ".pdf":
- return self._extract_pdf(temp_path, **params)
+ direct_extractors = {
+ ".xlsx": lambda: self._extract_excel(temp_path),
+ ".pptx": lambda: self._extract_pptx(temp_path, **params),
+ ".pdf": lambda: self._extract_pdf(temp_path, **params),
+ }
+ if suffix in direct_extractors:
+ return direct_extractors[suffix]()
+
+ conversions = {
+ ".xls": ("xlsx", lambda path: self._extract_excel(path)),
+ ".ppt": ("pptx", lambda path: self._extract_pptx(path, **params)),
+ ".docx": ("pdf", lambda path: self._extract_pdf(path, **params)),
+ ".doc": ("pdf", lambda path: self._extract_pdf(path, **params)),
+ }
+ if suffix in conversions:
+ target_format, extractor = conversions[suffix]
+ converted_path = self._convert_file(temp_path, target_format)
+ return extractor(converted_path)
return []
@@ -377,37 +377,4 @@ def process_file(self, file_bytes: bytes, chunking_strategy: str, filename: str,
try:
os.remove(f_path)
except Exception:
- pass
-
-
-if __name__ == "__main__":
- extractor = UniversalImageExtractor()
-
- input_path = r"C:\Users\pc\Desktop\files\docx.docx"
-
- output_dir = "output_images"
- os.makedirs(output_dir, exist_ok=True)
-
- if not os.path.exists(input_path):
- print(f"Error: Input file not found: {input_path}")
- else:
- with open(input_path, "rb") as f:
- file_bytes = f.read()
-
- images = extractor.process_file(
- file_bytes, os.path.basename(input_path))
- if not images:
- print("No images found.")
- else:
- for i, img_info in enumerate(images, start=1):
- img_data = img_info["image_bytes"]
- img_fmt = img_info.get("image_format", "png")
- img_filename = f"image_{i}.{img_fmt}"
- img_path = os.path.join(output_dir, img_filename)
-
- with open(img_path, "wb") as f:
- f.write(img_data)
-
- print(f"Saved: {img_path} (Format: {img_fmt})")
-
- print(f"\nFinished. Total {len(images)} images extracted.")
+ pass
\ No newline at end of file
diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py
index ceaa6bfe5..190a98554 100644
--- a/sdk/nexent/vector_database/elasticsearch_core.py
+++ b/sdk/nexent/vector_database/elasticsearch_core.py
@@ -401,44 +401,18 @@ def _small_batch_insert(
# Preprocess documents
processed_docs = self._preprocess_documents(
documents, content_field)
-
- if embedding_model.model_type == "multimodal":
- # kerry
- inputs = []
- for doc in processed_docs:
-
- # Get embeddings
- if doc.get("process_source") == "UniversalImageExtractor":
- img_bytes = doc.pop("image_bytes", "")
- if len(img_bytes) > 0:
- image_base64_str = base64.b64encode(
- img_bytes).decode('utf-8')
- data = f"data:image/jpeg;base64,{image_base64_str}"
- inputs.append({"image": data})
- else:
- inputs.append({"text": doc[content_field]})
-
- # Get embeddings
- embeddings = embedding_model.get_multimodal_embeddings(inputs)
- else:
- processed_docs[:] = [doc for doc in processed_docs if doc.get(
- "process_source") != "UniversalImageExtractor"]
- inputs = [doc[content_field] for doc in processed_docs]
- embeddings = embedding_model.get_embeddings(inputs)
+ processed_docs, embeddings = self._prepare_small_batch_embeddings(
+ processed_docs, content_field, embedding_model
+ )
- # # Get embeddings
- # inputs = [doc[content_field] for doc in processed_docs]
- # embeddings = embedding_model.get_embeddings(inputs)
# Prepare bulk operations
- operations = []
- for doc, embedding in zip(processed_docs, embeddings):
- operations.append({"index": {"_index": index_name}})
- doc["multi_embedding" if doc["process_source"]
- == "UniversalImageExtractor"else "embedding"] = embedding
- if "embedding_model_name" not in doc:
- doc["embedding_model_name"] = embedding_model.embedding_model_name
- operations.append(doc)
+ operations = self._build_bulk_operations(
+ index_name=index_name,
+ processed_docs=processed_docs,
+ embeddings=embeddings,
+ embedding_model=embedding_model,
+ )
indexed_count = len(processed_docs)
if indexed_count == 0:
@@ -467,6 +441,57 @@ def _small_batch_insert(
logger.error(f"Small batch insert failed: {e}")
raise
+ def _prepare_small_batch_embeddings(
+ self,
+ processed_docs: List[Dict[str, Any]],
+ content_field: str,
+ embedding_model: BaseEmbedding,
+ ):
+ if embedding_model.model_type == "multimodal":
+ inputs = []
+ for doc in processed_docs:
+ if doc.get("process_source") == "UniversalImageExtractor":
+ img_bytes = doc.pop("image_bytes", "")
+ if len(img_bytes) > 0:
+ image_base64_str = base64.b64encode(
+ img_bytes).decode("utf-8")
+ data = f"data:image/jpeg;base64,{image_base64_str}"
+ inputs.append({"image": data})
+ else:
+ inputs.append({"text": doc[content_field]})
+ embeddings = embedding_model.get_multimodal_embeddings(inputs)
+ return processed_docs, embeddings
+
+ filtered_docs = [
+ doc
+ for doc in processed_docs
+ if doc.get("process_source") != "UniversalImageExtractor"
+ ]
+ inputs = [doc[content_field] for doc in filtered_docs]
+ embeddings = embedding_model.get_embeddings(inputs)
+ return filtered_docs, embeddings
+
+ @staticmethod
+ def _build_bulk_operations(
+ index_name: str,
+ processed_docs: List[Dict[str, Any]],
+ embeddings: List[Any],
+ embedding_model: BaseEmbedding,
+ ) -> List[Dict[str, Any]]:
+ operations = []
+ for doc, embedding in zip(processed_docs, embeddings):
+ operations.append({"index": {"_index": index_name}})
+ embedding_field = (
+ "multi_embedding"
+ if doc.get("process_source") == "UniversalImageExtractor"
+ else "embedding"
+ )
+ doc[embedding_field] = embedding
+ if "embedding_model_name" not in doc:
+ doc["embedding_model_name"] = embedding_model.embedding_model_name
+ operations.append(doc)
+ return operations
+
def _large_batch_insert(
self,
index_name: str,
From d1591fe7bef42fc0d0922a1590127da7eb13c871 Mon Sep 17 00:00:00 2001
From: wyxkerry <1012700194@qq.com>
Date: Tue, 31 Mar 2026 11:08:57 +0800
Subject: [PATCH 03/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/apps/model_managment_app.py | 6 +-
backend/services/vectordatabase_service.py | 25 ++-
docker/deploy.sh | 19 ++-
.../components/document/DocumentList.tsx | 22 +--
.../KnowledgeBaseSelectorModal.tsx | 55 +++---
.../core/tools/knowledge_base_search_tool.py | 4 +-
sdk/nexent/data_process/extract_image.py | 159 ++++++++++--------
.../vector_database/elasticsearch_core.py | 4 -
.../test_tool_configuration_service.py | 1 +
.../services/test_vectordatabase_service.py | 1 +
test/sdk/data_process/test_core.py | 3 -
11 files changed, 165 insertions(+), 134 deletions(-)
diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py
index fa7759264..c04c577f5 100644
--- a/backend/apps/model_managment_app.py
+++ b/backend/apps/model_managment_app.py
@@ -33,7 +33,7 @@
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from http import HTTPStatus
-from typing import List, Optional
+from typing import Annotated, List, Optional
from services.model_health_service import (
check_model_connectivity,
verify_model_config_connectivity,
@@ -297,8 +297,8 @@ async def get_llm_model_list(authorization: Optional[str] = Header(None)):
@router.post("/healthcheck")
async def check_model_health(
- display_name: str = Query(..., description="Display name to check"),
- model_type: str = Query(..., description="..."),
+ display_name: Annotated[str, Query(..., description="Display name to check")],
+ model_type: Annotated[str, Query(..., description="...")],
authorization: Optional[str] = Header(None)
):
"""Check and update model connectivity, returning the latest status.
diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py
index 3c60aa283..38502afe3 100644
--- a/backend/services/vectordatabase_service.py
+++ b/backend/services/vectordatabase_service.py
@@ -256,6 +256,9 @@ def get_embedding_model(tenant_id: str, is_multimodal: bool = False, model_name:
Returns:
Embedding model instance or None
"""
+ if model_name is None and (isinstance(is_multimodal, str) or is_multimodal is None):
+ model_name = is_multimodal
+ is_multimodal = False
# If model_name is provided, try to find it in the tenant's models
if model_name:
try:
@@ -443,6 +446,7 @@ def create_knowledge_base(
tenant_id: Optional[str],
ingroup_permission: Optional[str] = None,
group_ids: Optional[List[int]] = None,
+ embedding_model_name: Optional[str] = None,
is_multimodal: bool = False,
):
"""
@@ -468,7 +472,20 @@ def create_knowledge_base(
with an explicit index_name.
"""
try:
- embedding_model = get_embedding_model(tenant_id, is_multimodal,)
+ if embedding_model_name is None:
+ if is_multimodal:
+ embedding_model = get_embedding_model(tenant_id, is_multimodal=True)
+ else:
+ embedding_model = get_embedding_model(tenant_id, None)
+ else:
+ if is_multimodal:
+ embedding_model = get_embedding_model(
+ tenant_id,
+ is_multimodal=True,
+ model_name=embedding_model_name,
+ )
+ else:
+ embedding_model = get_embedding_model(tenant_id, embedding_model_name)
# Create knowledge record first to obtain knowledge_id and generated index_name
knowledge_data = {
@@ -476,7 +493,11 @@ def create_knowledge_base(
"knowledge_describe": "",
"user_id": user_id,
"tenant_id": tenant_id,
- "embedding_model_name": embedding_model.model if embedding_model else None,
+ "embedding_model_name": (
+ embedding_model_name
+ if embedding_model_name is not None
+ else (embedding_model.model if embedding_model else None)
+ ),
"is_multimodal": is_multimodal,
}
diff --git a/docker/deploy.sh b/docker/deploy.sh
index 5d41c1cf9..233c14604 100755
--- a/docker/deploy.sh
+++ b/docker/deploy.sh
@@ -631,7 +631,8 @@ download_and_config_models() {
TT_MODEL_DIR_NAME="table-transformer-structure-recognition"
TT_MODEL_DIR_PATH="$MODEL_ROOT/$TT_MODEL_DIR_NAME"
- TT_MODEL_FILE_CHECK="$TT_MODEL_DIR_PATH/model.safetensors"
+ MODEL_SAFETENSORS_FILE="model.safetensors"
+ TT_MODEL_FILE_CHECK="$TT_MODEL_DIR_PATH/$MODEL_SAFETENSORS_FILE"
cd "$MODEL_ROOT" || return 1
@@ -661,29 +662,29 @@ download_and_config_models() {
cd "$TT_MODEL_DIR_NAME" || return 1
echo "INFO: Step 2/2: Download model.safetensors..."
- LARGE_FILE_URL="$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME/resolve/main/model.safetensors"
+ LARGE_FILE_URL="$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME/resolve/main/$MODEL_SAFETENSORS_FILE"
if command -v curl &> /dev/null; then
- curl -L -o "model.safetensors" "$LARGE_FILE_URL" --progress-bar
+ curl -L -o "$MODEL_SAFETENSORS_FILE" "$LARGE_FILE_URL" --progress-bar
elif command -v wget &> /dev/null; then
- wget "$LARGE_FILE_URL" -O "model.safetensors"
+ wget "$LARGE_FILE_URL" -O "$MODEL_SAFETENSORS_FILE"
else
echo "ERROR: curl or wget is required to download model files." >&2
cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$original_dir"; return 1
fi
- if [[ ! -f "model.safetensors" ]]; then
- echo "ERROR: model.safetensors download failed." >&2
+ if [[ ! -f "$MODEL_SAFETENSORS_FILE" ]]; then
+ echo "ERROR: $MODEL_SAFETENSORS_FILE download failed." >&2
cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$original_dir"; return 1
fi
- FILE_SIZE=$(stat -c%s "model.safetensors" 2>/dev/null || stat -f%z "model.safetensors" 2>/dev/null)
+ FILE_SIZE=$(stat -c%s "$MODEL_SAFETENSORS_FILE" 2>/dev/null || stat -f%z "$MODEL_SAFETENSORS_FILE" 2>/dev/null)
if [[ "$FILE_SIZE" -lt 1000000 ]]; then
- echo "ERROR: model.safetensors seems too small (size: $FILE_SIZE bytes)." >&2
+ echo "ERROR: $MODEL_SAFETENSORS_FILE seems too small (size: $FILE_SIZE bytes)." >&2
cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$original_dir"; return 1
fi
- echo "INFO: model.safetensors downloaded (size: $(du -h model.safetensors | cut -f1))"
+ echo "INFO: $MODEL_SAFETENSORS_FILE downloaded (size: $(du -h "$MODEL_SAFETENSORS_FILE" | cut -f1))"
cd "$MODEL_ROOT"
fi
diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx
index 7374f1278..256917d77 100644
--- a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx
+++ b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx
@@ -533,26 +533,15 @@ const DocumentListContainer = forwardRef(
/>
- onMultimodalChange(!isMultimodal)}
+
+
diff --git a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
index 6c78c279b..cdde3b2b2 100644
--- a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
+++ b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
@@ -263,37 +263,39 @@ export default function KnowledgeBaseSelectorModal({
]
);
- // Check if a knowledge base has model mismatch (for display purposes)
- const checkModelMismatch = (kb: KnowledgeBase): boolean => {
- if (kb.source !== "nexent") {
- return false;
- }
+ const getModelMismatch = useCallback(
+ (kb: KnowledgeBase): boolean => {
+ if (kb.source !== "nexent") {
+ return false;
+ }
- const hasMultimodalConstraintMismatch =
- toolMultimodal !== null &&
- ((toolMultimodal && !kb.is_multimodal) ||
- (!toolMultimodal && kb.is_multimodal));
- if (hasMultimodalConstraintMismatch) {
- return true;
- }
+ const hasMultimodalConstraintMismatch =
+ toolMultimodal !== null &&
+ ((toolMultimodal && !kb.is_multimodal) ||
+ (!toolMultimodal && kb.is_multimodal));
+ if (hasMultimodalConstraintMismatch) {
+ return true;
+ }
- const embeddingModel = kb.embeddingModel;
- if (!embeddingModel || embeddingModel === "unknown") {
- return false;
- }
+ const embeddingModel = kb.embeddingModel;
+ if (!embeddingModel || embeddingModel === "unknown") {
+ return false;
+ }
- if (kb.is_multimodal) {
- if (!currentMultiEmbeddingModel) {
- return true;
+ if (kb.is_multimodal) {
+ if (!currentMultiEmbeddingModel) {
+ return true;
+ }
+ return embeddingModel !== currentMultiEmbeddingModel;
}
- return embeddingModel !== currentMultiEmbeddingModel;
- }
- if (!currentEmbeddingModel) {
- return false;
- }
- return embeddingModel !== currentEmbeddingModel;
- };
+ if (!currentEmbeddingModel) {
+ return false;
+ }
+ return embeddingModel !== currentEmbeddingModel;
+ },
+ [currentEmbeddingModel, currentMultiEmbeddingModel, toolMultimodal]
+ );
// Filter knowledge bases based on tool type, search, and filters
const filteredKnowledgeBases = useMemo(() => {
@@ -714,6 +716,7 @@ export default function KnowledgeBaseSelectorModal({
String(selectedId).trim() === String(kb.id).trim()
);
const canSelect = checkCanSelect(kb);
+ const hasModelMismatch = getModelMismatch(kb);
return (
list:
# kerry
diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py
index 0134a2733..8c40e9346 100644
--- a/sdk/nexent/data_process/extract_image.py
+++ b/sdk/nexent/data_process/extract_image.py
@@ -4,7 +4,7 @@
import hashlib
import tempfile
import subprocess
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional
import zipfile
from xml.etree import ElementTree
@@ -171,6 +171,76 @@ def _extract_pdf(self, pdf_path: str, **params) -> List[Dict]:
return results
+ def _excel_sheet_files(self, z: zipfile.ZipFile) -> List[str]:
+ return [f for f in z.namelist() if f.startswith("xl/worksheets/sheet")]
+
+
+ def _excel_drawing_file(self, z: zipfile.ZipFile, sheet_file: str) -> Optional[str]:
+ sheet_xml = ElementTree.fromstring(z.read(sheet_file))
+ drawing = sheet_xml.find(
+ ".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing")
+ if drawing is None:
+ return None
+
+ rel_id = drawing.get(
+ "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id")
+ rel_path = sheet_file.replace("worksheets", "worksheets/_rels") + ".rels"
+ if rel_path not in z.namelist():
+ return None
+
+ rel_xml = ElementTree.fromstring(z.read(rel_path))
+ for rel in rel_xml:
+ if rel.get("Id") == rel_id:
+ return "xl/" + rel.get("Target").replace("../", "")
+
+ return None
+
+
+ def _excel_rel_map(self, z: zipfile.ZipFile, drawing_file: str) -> Optional[Dict[str, str]]:
+ rel_file = drawing_file.replace("drawings/", "drawings/_rels/") + ".rels"
+ if rel_file not in z.namelist():
+ return None
+
+ rel_root = ElementTree.fromstring(z.read(rel_file))
+ return {
+ rel.get("Id"): "xl/" + rel.get("Target").replace("../", "")
+ for rel in rel_root
+ }
+
+
+ def _excel_anchors(self, z: zipfile.ZipFile, drawing_file: str, ns: Dict[str, str]) -> List[Any]:
+ drawing_root = ElementTree.fromstring(z.read(drawing_file))
+ return drawing_root.findall(".//xdr:twoCellAnchor", ns) + \
+ drawing_root.findall(".//xdr:oneCellAnchor", ns)
+
+
+ def _excel_anchor_coords(self, anchor: Any, ns: Dict[str, str]) -> Optional[Dict[str, int]]:
+ from_node = anchor.find("xdr:from", ns)
+ if from_node is None:
+ return None
+
+ row1 = int(from_node.find("xdr:row", ns).text) + 1
+ col1 = int(from_node.find("xdr:col", ns).text) + 1
+
+ to_node = anchor.find("xdr:to", ns)
+ if to_node is not None:
+ row2 = int(to_node.find("xdr:row", ns).text) + 1
+ col2 = int(to_node.find("xdr:col", ns).text) + 1
+ else:
+ row2, col2 = row1, col1
+
+ return {"row1": row1, "col1": col1, "row2": row2, "col2": col2}
+
+
+ def _excel_anchor_embed_id(self, anchor: Any, ns: Dict[str, str]) -> Optional[str]:
+ blip = anchor.find(".//a:blip", ns)
+ if blip is None:
+ return None
+
+ return blip.get(
+ "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
+
+
def _extract_excel(self, xlsx_path):
results = []
seen = set()
@@ -182,86 +252,35 @@ def _extract_excel(self, xlsx_path):
"r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
}
- workbook = ElementTree.fromstring(z.read("xl/workbook.xml"))
- sheets = {}
- for s in workbook.findall(".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}sheet"):
- sheets[s.get("r:id")] = s.get("name")
-
- sheet_files = [f for f in z.namelist(
- ) if f.startswith("xl/worksheets/sheet")]
+ sheet_files = self._excel_sheet_files(z)
for sheet_file in sheet_files:
- sheet_xml = ElementTree.fromstring(z.read(sheet_file))
- drawing = sheet_xml.find(
- ".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing")
-
- if drawing is None:
- continue
-
- rel_id = drawing.get(
- "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id")
- rel_path = sheet_file.replace(
- "worksheets", "worksheets/_rels") + ".rels"
-
- if rel_path not in z.namelist():
- continue
-
- rel_xml = ElementTree.fromstring(z.read(rel_path))
- drawing_file = None
-
- for r in rel_xml:
- if r.get("Id") == rel_id:
- drawing_file = "xl/" + \
- r.get("Target").replace("../", "")
- break
-
+ drawing_file = self._excel_drawing_file(z, sheet_file)
if drawing_file is None:
continue
- sheet_name = os.path.basename(sheet_file)
- drawing_root = ElementTree.fromstring(z.read(drawing_file))
-
- rel_file = drawing_file.replace(
- "drawings/", "drawings/_rels/") + ".rels"
- if rel_file not in z.namelist():
+ rel_map = self._excel_rel_map(z, drawing_file)
+ if not rel_map:
continue
- rel_root = ElementTree.fromstring(z.read(rel_file))
- rel_map = {
- r.get("Id"): "xl/" + r.get("Target").replace("../", "")
- for r in rel_root
- }
-
- anchors = drawing_root.findall(".//xdr:twoCellAnchor", ns) + \
- drawing_root.findall(".//xdr:oneCellAnchor", ns)
+ anchors = self._excel_anchors(z, drawing_file, ns)
+ sheet_name = os.path.basename(sheet_file)
for anchor in anchors:
- from_node = anchor.find("xdr:from", ns)
- if from_node is None:
+ coords = self._excel_anchor_coords(anchor, ns)
+ if coords is None:
continue
- row1 = int(from_node.find("xdr:row", ns).text) + 1
- col1 = int(from_node.find("xdr:col", ns).text) + 1
-
- to_node = anchor.find("xdr:to", ns)
- if to_node is not None:
- row2 = int(to_node.find("xdr:row", ns).text) + 1
- col2 = int(to_node.find("xdr:col", ns).text) + 1
- else:
- row2, col2 = row1, col1
-
- blip = anchor.find(".//a:blip", ns)
- if blip is None:
+ embed_rel_id = self._excel_anchor_embed_id(anchor, ns)
+ if not embed_rel_id:
continue
- embed_rel_id = blip.get(
- "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
- if embed_rel_id not in rel_map:
+ target = rel_map.get(embed_rel_id)
+ if not target:
continue
- img_bytes = z.read(rel_map[embed_rel_id])
+ img_bytes = z.read(target)
h = self._hash(img_bytes)
-
if h in seen:
continue
seen.add(h)
@@ -270,10 +289,10 @@ def _extract_excel(self, xlsx_path):
"position": {
"sheet_name": sheet_name,
"coordinates": {
- "x1": col1,
- "x2": col2,
- "y1": row1,
- "y2": row2
+ "x1": coords["col1"],
+ "x2": coords["col2"],
+ "y1": coords["row1"],
+ "y2": coords["row2"]
}
},
"image_format": self.detect_image_format(img_bytes),
@@ -377,4 +396,4 @@ def process_file(self, file_bytes: bytes, chunking_strategy: str, filename: str,
try:
os.remove(f_path)
except Exception:
- pass
\ No newline at end of file
+ pass
diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py
index 190a98554..4f230786c 100644
--- a/sdk/nexent/vector_database/elasticsearch_core.py
+++ b/sdk/nexent/vector_database/elasticsearch_core.py
@@ -1081,10 +1081,6 @@ def semantic_search(
"_source": {"excludes": ["multi_embedding"]},
}
raw_results = self.exec_query(index_pattern, search_text_query) + self.exec_query(index_pattern, search_image_query)
-
- # raw_results = raw_results + raw_results2
- print("raw_results: ", raw_results)
-
else:
search_query = {
"knn": {
diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py
index 4207c9e62..8f4bad20c 100644
--- a/test/backend/services/test_tool_configuration_service.py
+++ b/test/backend/services/test_tool_configuration_service.py
@@ -26,6 +26,7 @@
memory_pkg.__path__ = []
memory_service_stub = types.ModuleType("nexent.memory.memory_service")
async def _clear_memory_stub(*_args, **_kwargs):
+ await asyncio.sleep(0)
return None
memory_service_stub.clear_memory = _clear_memory_stub
sys.modules["nexent.memory.memory_service"] = memory_service_stub
diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py
index cbb4904c3..6726a1701 100644
--- a/test/backend/services/test_vectordatabase_service.py
+++ b/test/backend/services/test_vectordatabase_service.py
@@ -29,6 +29,7 @@
memory_pkg.__path__ = []
memory_service_stub = ModuleType("nexent.memory.memory_service")
async def _clear_memory_stub(*_args, **_kwargs):
+ await asyncio.sleep(0)
return None
memory_service_stub.clear_memory = _clear_memory_stub
sys.modules["nexent.memory.memory_service"] = memory_service_stub
diff --git a/test/sdk/data_process/test_core.py b/test/sdk/data_process/test_core.py
index 359204c29..af325b52f 100644
--- a/test/sdk/data_process/test_core.py
+++ b/test/sdk/data_process/test_core.py
@@ -346,9 +346,6 @@ def test_file_process_returns_images_when_extractor_available(self, core, mocker
{"image_bytes": b"img", "image_format": "png", "position": {"page_number": 1}}
]
core.processors["Unstructured"] = mock_processor
- core.processors["UniversalImageExtractor"] = Mock(
- process_file=Mock(return_value=[])
- )
core.processors["UniversalImageExtractor"] = mock_extractor
result = core.file_process(
From 6b427c0a62ed218c60934fb6993cab3bd0bedd0f Mon Sep 17 00:00:00 2001
From: wyxkerry <1012700194@qq.com>
Date: Tue, 31 Mar 2026 11:53:05 +0800
Subject: [PATCH 04/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/services/vectordatabase_service.py | 42 ++++----
sdk/nexent/data_process/extract_image.py | 113 +++++++++++++--------
2 files changed, 93 insertions(+), 62 deletions(-)
diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py
index 38502afe3..c7e9b800a 100644
--- a/backend/services/vectordatabase_service.py
+++ b/backend/services/vectordatabase_service.py
@@ -284,6 +284,20 @@ def get_embedding_model(tenant_id: str, is_multimodal: bool = False, model_name:
return _build_embedding_from_config(model_config)
+def _resolve_embedding_model(
+ tenant_id: str,
+ is_multimodal: bool,
+ embedding_model_name: Optional[str],
+) -> Optional[BaseEmbedding]:
+ if embedding_model_name:
+ return get_embedding_model(
+ tenant_id,
+ is_multimodal=is_multimodal,
+ model_name=embedding_model_name,
+ )
+ return get_embedding_model(tenant_id, is_multimodal=is_multimodal)
+
+
class ElasticSearchService:
@staticmethod
async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCore, user_id: str):
@@ -472,20 +486,14 @@ def create_knowledge_base(
with an explicit index_name.
"""
try:
- if embedding_model_name is None:
- if is_multimodal:
- embedding_model = get_embedding_model(tenant_id, is_multimodal=True)
- else:
- embedding_model = get_embedding_model(tenant_id, None)
- else:
- if is_multimodal:
- embedding_model = get_embedding_model(
- tenant_id,
- is_multimodal=True,
- model_name=embedding_model_name,
- )
- else:
- embedding_model = get_embedding_model(tenant_id, embedding_model_name)
+ embedding_model = _resolve_embedding_model(
+ tenant_id=tenant_id,
+ is_multimodal=is_multimodal,
+ embedding_model_name=embedding_model_name,
+ )
+ resolved_embedding_model_name = embedding_model_name
+ if resolved_embedding_model_name is None and embedding_model:
+ resolved_embedding_model_name = getattr(embedding_model, "model", None)
# Create knowledge record first to obtain knowledge_id and generated index_name
knowledge_data = {
@@ -493,11 +501,7 @@ def create_knowledge_base(
"knowledge_describe": "",
"user_id": user_id,
"tenant_id": tenant_id,
- "embedding_model_name": (
- embedding_model_name
- if embedding_model_name is not None
- else (embedding_model.model if embedding_model else None)
- ),
+ "embedding_model_name": resolved_embedding_model_name,
"is_multimodal": is_multimodal,
}
diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py
index 8c40e9346..0b34987e8 100644
--- a/sdk/nexent/data_process/extract_image.py
+++ b/sdk/nexent/data_process/extract_image.py
@@ -241,6 +241,73 @@ def _excel_anchor_embed_id(self, anchor: Any, ns: Dict[str, str]) -> Optional[st
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
+ def _extract_excel_anchors(
+ self,
+ z: zipfile.ZipFile,
+ anchors: List[Any],
+ rel_map: Dict[str, str],
+ sheet_name: str,
+ ns: Dict[str, str],
+ seen: set,
+ ) -> List[Dict[str, Any]]:
+ results = []
+ for anchor in anchors:
+ coords = self._excel_anchor_coords(anchor, ns)
+ if coords is None:
+ continue
+
+ embed_rel_id = self._excel_anchor_embed_id(anchor, ns)
+ if not embed_rel_id:
+ continue
+
+ target = rel_map.get(embed_rel_id)
+ if not target:
+ continue
+
+ img_bytes = z.read(target)
+ h = self._hash(img_bytes)
+ if h in seen:
+ continue
+ seen.add(h)
+
+ results.append({
+ "position": {
+ "sheet_name": sheet_name,
+ "coordinates": {
+ "x1": coords["col1"],
+ "x2": coords["col2"],
+ "y1": coords["row1"],
+ "y2": coords["row2"]
+ }
+ },
+ "image_format": self.detect_image_format(img_bytes),
+ "image_bytes": img_bytes
+ })
+
+ return results
+
+
+ def _extract_excel_sheet(
+ self,
+ z: zipfile.ZipFile,
+ sheet_file: str,
+ ns: Dict[str, str],
+ seen: set,
+ ) -> List[Dict[str, Any]]:
+ drawing_file = self._excel_drawing_file(z, sheet_file)
+ if drawing_file is None:
+ return []
+
+ rel_map = self._excel_rel_map(z, drawing_file)
+ if not rel_map:
+ return []
+
+ anchors = self._excel_anchors(z, drawing_file, ns)
+ sheet_name = os.path.basename(sheet_file)
+
+ return self._extract_excel_anchors(z, anchors, rel_map, sheet_name, ns, seen)
+
+
def _extract_excel(self, xlsx_path):
results = []
seen = set()
@@ -255,49 +322,9 @@ def _extract_excel(self, xlsx_path):
sheet_files = self._excel_sheet_files(z)
for sheet_file in sheet_files:
- drawing_file = self._excel_drawing_file(z, sheet_file)
- if drawing_file is None:
- continue
-
- rel_map = self._excel_rel_map(z, drawing_file)
- if not rel_map:
- continue
-
- anchors = self._excel_anchors(z, drawing_file, ns)
- sheet_name = os.path.basename(sheet_file)
-
- for anchor in anchors:
- coords = self._excel_anchor_coords(anchor, ns)
- if coords is None:
- continue
-
- embed_rel_id = self._excel_anchor_embed_id(anchor, ns)
- if not embed_rel_id:
- continue
-
- target = rel_map.get(embed_rel_id)
- if not target:
- continue
-
- img_bytes = z.read(target)
- h = self._hash(img_bytes)
- if h in seen:
- continue
- seen.add(h)
-
- results.append({
- "position": {
- "sheet_name": sheet_name,
- "coordinates": {
- "x1": coords["col1"],
- "x2": coords["col2"],
- "y1": coords["row1"],
- "y2": coords["row2"]
- }
- },
- "image_format": self.detect_image_format(img_bytes),
- "image_bytes": img_bytes
- })
+ results.extend(
+ self._extract_excel_sheet(z, sheet_file, ns, seen)
+ )
return results
From 2deeb8ae415963e99037af8d6c953bcae5b957f0 Mon Sep 17 00:00:00 2001
From: wyxkerry <1012700194@qq.com>
Date: Tue, 31 Mar 2026 14:30:41 +0800
Subject: [PATCH 05/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../agentConfig/tool/ToolConfigModal.tsx | 42 ++-----
.../KnowledgeBaseSelectorModal.tsx | 46 ++-----
frontend/lib/knowledgeBaseCompatibility.ts | 46 +++++++
frontend/services/knowledgeBaseService.ts | 11 +-
.../core/tools/knowledge_base_search_tool.py | 119 ++++++++----------
sdk/nexent/data_process/extract_image.py | 52 +++++---
test/backend/data_process/test_ray_actors.py | 92 +++++++-------
test/backend/database/test_knowledge_db.py | 63 +++-------
.../tools/test_knowledge_base_search_tool.py | 6 +-
9 files changed, 228 insertions(+), 249 deletions(-)
create mode 100644 frontend/lib/knowledgeBaseCompatibility.ts
diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx
index 2dc55592c..5f0d969cf 100644
--- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx
+++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx
@@ -30,6 +30,10 @@ import { API_ENDPOINTS } from "@/services/api";
import knowledgeBaseService from "@/services/knowledgeBaseService";
import log from "@/lib/logger";
import { isZhLocale, getLocalizedDescription } from "@/lib/utils";
+import {
+ isEmbeddingModelCompatible as isEmbeddingModelCompatibleBase,
+ isMultimodalConstraintMismatch as isMultimodalConstraintMismatchBase,
+} from "@/lib/knowledgeBaseCompatibility";
export interface ToolConfigModalProps {
isOpen: boolean;
@@ -488,44 +492,18 @@ export default function ToolConfigModal({
const isMultimodalConstraintMismatch = useCallback(
(kb: KnowledgeBase) => {
- return (
- toolMultimodal !== null &&
- ((toolMultimodal && !kb.is_multimodal) ||
- (!toolMultimodal && kb.is_multimodal))
- );
+ return isMultimodalConstraintMismatchBase(kb, toolMultimodal);
},
[toolMultimodal]
);
const isEmbeddingModelCompatible = useCallback(
(kb: KnowledgeBase) => {
- if (kb.is_multimodal) {
- if (!currentMultiEmbeddingModel) {
- return false;
- }
- if (
- kb.embeddingModel &&
- kb.embeddingModel !== "unknown" &&
- kb.embeddingModel !== currentMultiEmbeddingModel
- ) {
- return false;
- }
- return true;
- }
-
- if (!currentEmbeddingModel) {
- return true;
- }
-
- if (
- kb.embeddingModel &&
- kb.embeddingModel !== "unknown" &&
- kb.embeddingModel !== currentEmbeddingModel
- ) {
- return false;
- }
-
- return true;
+ return isEmbeddingModelCompatibleBase(
+ kb,
+ currentEmbeddingModel,
+ currentMultiEmbeddingModel
+ );
},
[currentEmbeddingModel, currentMultiEmbeddingModel]
);
diff --git a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
index cdde3b2b2..a3c72e804 100644
--- a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
+++ b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
@@ -20,6 +20,10 @@ import {
import { KnowledgeBase } from "@/types/knowledgeBase";
import { KB_LAYOUT, KB_TAG_VARIANTS } from "@/const/knowledgeBaseLayout";
+import {
+ isEmbeddingModelCompatible as isEmbeddingModelCompatibleBase,
+ isMultimodalConstraintMismatch as isMultimodalConstraintMismatchBase,
+} from "@/lib/knowledgeBaseCompatibility";
interface KnowledgeBaseSelectorProps {
isOpen: boolean;
@@ -188,44 +192,18 @@ export default function KnowledgeBaseSelectorModal({
const isMultimodalConstraintMismatch = useCallback(
(kb: KnowledgeBase) => {
- return (
- toolMultimodal !== null &&
- ((toolMultimodal && !kb.is_multimodal) ||
- (!toolMultimodal && kb.is_multimodal))
- );
+ return isMultimodalConstraintMismatchBase(kb, toolMultimodal);
},
[toolMultimodal]
);
const isEmbeddingModelCompatible = useCallback(
(kb: KnowledgeBase) => {
- if (kb.is_multimodal) {
- if (!currentMultiEmbeddingModel) {
- return false;
- }
- if (
- kb.embeddingModel &&
- kb.embeddingModel !== "unknown" &&
- kb.embeddingModel !== currentMultiEmbeddingModel
- ) {
- return false;
- }
- return true;
- }
-
- if (!currentEmbeddingModel) {
- return true;
- }
-
- if (
- kb.embeddingModel &&
- kb.embeddingModel !== "unknown" &&
- kb.embeddingModel !== currentEmbeddingModel
- ) {
- return false;
- }
-
- return true;
+ return isEmbeddingModelCompatibleBase(
+ kb,
+ currentEmbeddingModel,
+ currentMultiEmbeddingModel
+ );
},
[currentEmbeddingModel, currentMultiEmbeddingModel]
);
@@ -270,9 +248,7 @@ export default function KnowledgeBaseSelectorModal({
}
const hasMultimodalConstraintMismatch =
- toolMultimodal !== null &&
- ((toolMultimodal && !kb.is_multimodal) ||
- (!toolMultimodal && kb.is_multimodal));
+ isMultimodalConstraintMismatchBase(kb, toolMultimodal);
if (hasMultimodalConstraintMismatch) {
return true;
}
diff --git a/frontend/lib/knowledgeBaseCompatibility.ts b/frontend/lib/knowledgeBaseCompatibility.ts
new file mode 100644
index 000000000..37381c048
--- /dev/null
+++ b/frontend/lib/knowledgeBaseCompatibility.ts
@@ -0,0 +1,46 @@
+import { KnowledgeBase } from "@/types/knowledgeBase";
+
+export const isMultimodalConstraintMismatch = (
+ kb: KnowledgeBase,
+ toolMultimodal: boolean | null
+): boolean => {
+ return (
+ toolMultimodal !== null &&
+ ((toolMultimodal && !kb.is_multimodal) ||
+ (!toolMultimodal && kb.is_multimodal))
+ );
+};
+
+export const isEmbeddingModelCompatible = (
+ kb: KnowledgeBase,
+ currentEmbeddingModel: string | null,
+ currentMultiEmbeddingModel: string | null
+): boolean => {
+ if (kb.is_multimodal) {
+ if (!currentMultiEmbeddingModel) {
+ return false;
+ }
+ if (
+ kb.embeddingModel &&
+ kb.embeddingModel !== "unknown" &&
+ kb.embeddingModel !== currentMultiEmbeddingModel
+ ) {
+ return false;
+ }
+ return true;
+ }
+
+ if (!currentEmbeddingModel) {
+ return true;
+ }
+
+ if (
+ kb.embeddingModel &&
+ kb.embeddingModel !== "unknown" &&
+ kb.embeddingModel !== currentEmbeddingModel
+ ) {
+ return false;
+ }
+
+ return true;
+};
diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts
index 0c8262e9a..657160fc7 100644
--- a/frontend/services/knowledgeBaseService.ts
+++ b/frontend/services/knowledgeBaseService.ts
@@ -30,6 +30,9 @@ const normalizeIsMultimodal = (value: unknown): boolean => {
return false;
};
+const resolveIsMultimodal = (indexInfo: any, stats: any): boolean =>
+ normalizeIsMultimodal(indexInfo.is_multimodal ?? stats.is_multimodal);
+
// Knowledge base service class
class KnowledgeBaseService {
// Check Elasticsearch health (force refresh, no caching for setup page)
@@ -499,9 +502,7 @@ class KnowledgeBaseService {
stats.creation_date ||
null,
embeddingModel: stats.embedding_model || "unknown",
- is_multimodal: normalizeIsMultimodal(
- indexInfo.is_multimodal ?? stats.is_multimodal
- ),
+ is_multimodal: resolveIsMultimodal(indexInfo, stats),
knowledge_sources:
indexInfo.knowledge_sources || "elasticsearch",
ingroup_permission: indexInfo.ingroup_permission || "",
@@ -569,9 +570,7 @@ class KnowledgeBaseService {
createdAt: stats.creation_date || null,
updatedAt: stats.update_date || stats.creation_date || null,
embeddingModel: stats.embedding_model || "unknown",
- is_multimodal: normalizeIsMultimodal(
- indexInfo.is_multimodal ?? stats.is_multimodal
- ),
+ is_multimodal: resolveIsMultimodal(indexInfo, stats),
knowledge_sources:
indexInfo.knowledge_sources || "datamate",
ingroup_permission: indexInfo.ingroup_permission || "",
diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py
index d47aad063..530d51765 100644
--- a/sdk/nexent/core/tools/knowledge_base_search_tool.py
+++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py
@@ -39,6 +39,7 @@ class KnowledgeBaseSearchTool(Tool):
"index_names": {
"type": "array",
"description": "The list of index names to search",
+ "nullable": True,
"description_zh": "要索引的知识库"
},
}
@@ -106,9 +107,14 @@ def __init__(
self.running_prompt_en = "Searching the knowledge base..."
- def forward(self, query: str, index_names: List[str]) -> str:
- # Parse index_names from string (always required)
- search_index_names = index_names
+ def forward(self, query: str, index_names: List[str] | str | None = None) -> str:
+ # Parse index_names from string (optional)
+ if index_names is None:
+ search_index_names = self.index_names
+ elif isinstance(index_names, str):
+ search_index_names = [name.strip() for name in index_names.split(",") if name.strip()]
+ else:
+ search_index_names = index_names
# Use the instance search_mode
search_mode = self.search_mode
@@ -260,82 +266,59 @@ def _record_search_results(
self.observer.add_message(
"", ProcessType.PICTURE_WEB, search_images_list_json)
- def search_hybrid(self, query, index_names):
+ @staticmethod
+ def _format_vdb_results(results):
+ formatted_results = []
+ for result in results:
+ doc = result["document"]
+ doc["score"] = result["score"]
+ # Include source index in results
+ doc["index"] = result["index"]
+ if "content" in result:
+ doc["content"] = result["content"]
+ if "process_source" in result:
+ doc["process_source"] = result["process_source"]
+ formatted_results.append(doc)
+ return formatted_results
+
+ def _search_with(self, search_fn, query, index_names, label, **kwargs):
try:
- results = self.vdb_core.hybrid_search(
- index_names=index_names, query_text=query, embedding_model=self.embedding_model, top_k=self.top_k
+ results = search_fn(
+ index_names=index_names, query_text=query, top_k=self.top_k, **kwargs
)
-
- # Format results
- formatted_results = []
- for result in results:
- doc = result["document"]
- doc["score"] = result["score"]
- # Include source index in results
- doc["index"] = result["index"]
- if "content" in result:
- doc["content"] = result["content"]
- if "process_source" in result:
- doc["process_source"] = result["process_source"]
- formatted_results.append(doc)
-
+ formatted_results = self._format_vdb_results(results)
return {
"results": formatted_results,
"total": len(formatted_results),
}
except Exception as e:
- raise Exception(detail=f"Error during hybrid search: {str(e)}")
+ raise Exception(f"Error during {label} search: {str(e)}")
- def search_accurate(self, query, index_names):
- try:
- results = self.vdb_core.accurate_search(
- index_names=index_names, query_text=query, top_k=self.top_k)
-
- # Format results
- formatted_results = []
- for result in results:
- doc = result["document"]
- doc["score"] = result["score"]
- # Include source index in results
- doc["index"] = result["index"]
- if "content" in result:
- doc["content"] = result["content"]
- if "process_source" in result:
- doc["process_source"] = result["process_source"]
- formatted_results.append(doc)
+ def search_hybrid(self, query, index_names):
+ return self._search_with(
+ self.vdb_core.hybrid_search,
+ query,
+ index_names,
+ "hybrid",
+ embedding_model=self.embedding_model,
+ )
- return {
- "results": formatted_results,
- "total": len(formatted_results),
- }
- except Exception as e:
- raise Exception(detail=f"Error during accurate search: {str(e)}")
+ def search_accurate(self, query, index_names):
+ return self._search_with(
+ self.vdb_core.accurate_search,
+ query,
+ index_names,
+ "accurate",
+ )
def search_semantic(self, query, index_names):
- try:
- results = self.vdb_core.semantic_search(
- index_names=index_names, query_text=query, embedding_model=self.embedding_model, top_k=self.top_k
- )
-
- # Format results
- formatted_results = []
- for result in results:
- doc = result["document"]
- doc["score"] = result["score"]
- # Include source index in results
- doc["index"] = result["index"]
- if "content" in result:
- doc["content"] = result["content"]
- if "process_source" in result:
- doc["process_source"] = result["process_source"]
- formatted_results.append(doc)
-
- return {
- "results": formatted_results,
- "total": len(formatted_results),
- }
- except Exception as e:
- raise Exception(detail=f"Error during semantic search: {str(e)}")
+ return self._search_with(
+ self.vdb_core.semantic_search,
+ query,
+ index_names,
+ "semantic",
+ embedding_model=self.embedding_model,
+ )
def _filter_images(self, images_list_url, query) -> list:
# kerry
diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py
index 0b34987e8..3fbe377f1 100644
--- a/sdk/nexent/data_process/extract_image.py
+++ b/sdk/nexent/data_process/extract_image.py
@@ -43,7 +43,24 @@ class UniversalImageExtractor(FileProcessor):
@staticmethod
def _hash(data: bytes) -> str:
- return hashlib.md5(data).hexdigest()
+ # Use a modern hash for safe, collision-resistant de-duplication.
+ return hashlib.sha256(data).hexdigest()
+
+ @staticmethod
+ def _openxml_namespace_maps() -> List[Dict[str, str]]:
+ # Prefer https URIs, but retain http for compatibility with existing files.
+ return [
+ {
+ "xdr": "https://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
+ "a": "https://schemas.openxmlformats.org/drawingml/2006/main",
+ "r": "https://schemas.openxmlformats.org/officeDocument/2006/relationships",
+ },
+ {
+ "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
+ "a": "http://schemas.openxmlformats.org/drawingml/2006/main",
+ "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
+ },
+ ]
def _write_temp_file(self, data: bytes, suffix: str) -> str:
@@ -178,12 +195,18 @@ def _excel_sheet_files(self, z: zipfile.ZipFile) -> List[str]:
def _excel_drawing_file(self, z: zipfile.ZipFile, sheet_file: str) -> Optional[str]:
sheet_xml = ElementTree.fromstring(z.read(sheet_file))
drawing = sheet_xml.find(
- ".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing")
+ ".//{https://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing")
+ if drawing is None:
+ drawing = sheet_xml.find(
+ ".//{http://schemas.openxmlformats.org/spreadsheetml/2006/main}drawing")
if drawing is None:
return None
rel_id = drawing.get(
- "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id")
+ "{https://schemas.openxmlformats.org/officeDocument/2006/relationships}id")
+ if rel_id is None:
+ rel_id = drawing.get(
+ "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id")
rel_path = sheet_file.replace("worksheets", "worksheets/_rels") + ".rels"
if rel_path not in z.namelist():
return None
@@ -237,8 +260,12 @@ def _excel_anchor_embed_id(self, anchor: Any, ns: Dict[str, str]) -> Optional[st
if blip is None:
return None
- return blip.get(
- "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
+ embed_id = blip.get(
+ "{https://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
+ if embed_id is None:
+ embed_id = blip.get(
+ "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
+ return embed_id
def _extract_excel_anchors(
@@ -313,18 +340,15 @@ def _extract_excel(self, xlsx_path):
seen = set()
with zipfile.ZipFile(xlsx_path) as z:
- ns = {
- "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
- "a": "http://schemas.openxmlformats.org/drawingml/2006/main",
- "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
- }
-
sheet_files = self._excel_sheet_files(z)
for sheet_file in sheet_files:
- results.extend(
- self._extract_excel_sheet(z, sheet_file, ns, seen)
- )
+ extracted = []
+ for ns in self._openxml_namespace_maps():
+ extracted = self._extract_excel_sheet(z, sheet_file, ns, seen)
+ if extracted:
+ break
+ results.extend(extracted)
return results
diff --git a/test/backend/data_process/test_ray_actors.py b/test/backend/data_process/test_ray_actors.py
index 41cf69c56..07fe6676d 100644
--- a/test/backend/data_process/test_ray_actors.py
+++ b/test/backend/data_process/test_ray_actors.py
@@ -53,6 +53,27 @@ def expire(self, key, seconds):
self.expirations[key] = seconds
+def make_temp_file(tmp_path, name: str, content: bytes = b"file-bytes") -> str:
+ path = tmp_path / name
+ path.write_bytes(content)
+ return str(path)
+
+
+def stub_consts(monkeypatch):
+ fake_consts_pkg = types.ModuleType("consts")
+ fake_consts_const = types.ModuleType("consts.const")
+ fake_consts_const.RAY_ACTOR_NUM_CPUS = 1
+ fake_consts_const.REDIS_BACKEND_URL = ""
+ # New defaults required by ray_actors import
+ fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024
+ fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536
+ fake_consts_const.TABLE_TRANSFORMER_MODEL_PATH = "/models/table"
+ fake_consts_const.UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = "/models/unstructured.json"
+ monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg)
+ monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const)
+ return fake_consts_const
+
+
@pytest.fixture(autouse=True)
def stub_ray_before_import(monkeypatch):
# Ensure that when module under test imports ray, it gets our stub
@@ -138,17 +159,7 @@ class _Redis:
monkeypatch.setitem(sys.modules, "backend.data_process.tasks", fake_dp_tasks)
# Stub consts.const needed by ray_actors imports
- fake_consts_pkg = types.ModuleType("consts")
- fake_consts_const = types.ModuleType("consts.const")
- fake_consts_const.RAY_ACTOR_NUM_CPUS = 1
- fake_consts_const.REDIS_BACKEND_URL = ""
- # New defaults required by ray_actors import
- fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024
- fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536
- fake_consts_const.TABLE_TRANSFORMER_MODEL_PATH = "/models/table"
- fake_consts_const.UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = "/models/unstructured.json"
- monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg)
- monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const)
+ stub_consts(monkeypatch)
# Ensure model_management_db is stubbed to avoid importing real DB layer
if "database.model_management_db" not in sys.modules:
@@ -184,12 +195,13 @@ class _Redis:
return ray_actors
-def test_process_file_happy_path(monkeypatch):
+def test_process_file_happy_path(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
actor = ray_actors.DataProcessorRayActor()
+ source_path = make_temp_file(tmp_path, "a.txt")
chunks = actor.process_file(
- source="/tmp/a.txt",
+ source=source_path,
chunking_strategy="basic",
destination="local",
task_id="tid-1",
@@ -201,7 +213,7 @@ def test_process_file_happy_path(monkeypatch):
assert chunks[0]["content"] == "hello world"
-def test_process_file_applies_chunk_sizes_from_model(monkeypatch):
+def test_process_file_applies_chunk_sizes_from_model(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
# Recorder core to capture params
@@ -229,8 +241,9 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
)
actor = ray_actors.DataProcessorRayActor()
+ source_path = make_temp_file(tmp_path, "a.txt")
actor.process_file(
- source="/tmp/a.txt",
+ source=source_path,
chunking_strategy="basic",
destination="local",
model_id=9,
@@ -246,7 +259,7 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
) == "/models/unstructured.json"
-def test_process_file_no_model_omits_chunk_params(monkeypatch):
+def test_process_file_no_model_omits_chunk_params(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
class RecorderCore:
@@ -268,8 +281,9 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
)
actor = ray_actors.DataProcessorRayActor()
+ source_path = make_temp_file(tmp_path, "b.txt")
actor.process_file(
- source="/tmp/b.txt",
+ source=source_path,
chunking_strategy="basic",
destination="local",
model_id=10,
@@ -285,7 +299,7 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
) == "/models/unstructured.json"
-def test_process_file_model_lookup_exception_uses_defaults(monkeypatch):
+def test_process_file_model_lookup_exception_uses_defaults(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
class RecorderCore:
@@ -308,8 +322,9 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
)
actor = ray_actors.DataProcessorRayActor()
+ source_path = make_temp_file(tmp_path, "c.txt")
actor.process_file(
- source="/tmp/c.txt",
+ source=source_path,
chunking_strategy="basic",
destination="local",
model_id=11,
@@ -392,17 +407,7 @@ class _Redis:
fake_dp_tasks.process_sync = lambda *a, **k: None
monkeypatch.setitem(sys.modules, "backend.data_process.tasks", fake_dp_tasks)
# Stub consts.const again for reload path
- fake_consts_pkg = types.ModuleType("consts")
- fake_consts_const = types.ModuleType("consts.const")
- fake_consts_const.RAY_ACTOR_NUM_CPUS = 1
- fake_consts_const.REDIS_BACKEND_URL = ""
- # Provide defaults required by backend.data_process.ray_actors import
- fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024
- fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536
- fake_consts_const.TABLE_TRANSFORMER_MODEL_PATH = "/models/table"
- fake_consts_const.UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = "/models/unstructured.json"
- monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg)
- monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const)
+ stub_consts(monkeypatch)
# Stub database.model_management_db and link to parent to avoid real DB import
if "database.model_management_db" not in sys.modules:
@@ -433,7 +438,7 @@ class _Redis:
actor.process_file("url://missing", "basic", destination="minio")
-def test_process_file_core_returns_none_list_variants(monkeypatch):
+def test_process_file_core_returns_none_list_variants(monkeypatch, tmp_path):
class CoreNone(FakeDataProcessCore):
def file_process(self, *a, **k):
return None
@@ -505,17 +510,7 @@ class _Redis:
fake_dp_tasks.process_sync = lambda *a, **k: None
monkeypatch.setitem(sys.modules, "backend.data_process.tasks", fake_dp_tasks)
# Stub consts.const for ray_actors imports
- fake_consts_pkg = types.ModuleType("consts")
- fake_consts_const = types.ModuleType("consts.const")
- fake_consts_const.RAY_ACTOR_NUM_CPUS = 1
- fake_consts_const.REDIS_BACKEND_URL = ""
- # Provide defaults required by backend.data_process.ray_actors import
- fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024
- fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536
- fake_consts_const.TABLE_TRANSFORMER_MODEL_PATH = "/models/table"
- fake_consts_const.UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = "/models/unstructured.json"
- monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg)
- monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const)
+ stub_consts(monkeypatch)
# Ensure model_management_db is stubbed to avoid importing real DB layer
if "database.model_management_db" not in sys.modules:
@@ -530,7 +525,8 @@ class _Redis:
import backend.data_process.ray_actors as ray_actors
reload(ray_actors)
actor = ray_actors.DataProcessorRayActor()
- chunks = actor.process_file("/tmp/a.txt", "basic", destination="local")
+ source_path = make_temp_file(tmp_path, f"a_{core_cls.__name__}.txt")
+ chunks = actor.process_file(source_path, "basic", destination="local")
assert chunks == []
@@ -575,7 +571,7 @@ def test_store_chunks_in_redis_no_url_returns_false(monkeypatch):
assert actor.store_chunks_in_redis("k", [{"content": "x"}]) is False
-def test_process_file_appends_image_chunks(monkeypatch):
+def test_process_file_appends_image_chunks(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
class CoreWithImages:
@@ -604,14 +600,15 @@ def file_process(self, *a, **k):
)
actor = ray_actors.DataProcessorRayActor()
- chunks = actor.process_file("/tmp/a.pdf", "basic", destination="local")
+ source_path = make_temp_file(tmp_path, "a.pdf", content=b"%PDF-1.4")
+ chunks = actor.process_file(source_path, "basic", destination="local")
assert len(chunks) == 2
assert chunks[1]["metadata"]["process_source"] == "UniversalImageExtractor"
assert "image_url" in chunks[1]["metadata"]
-def test_process_file_skips_invalid_image_entries(monkeypatch):
+def test_process_file_skips_invalid_image_entries(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
class CoreWithBadImages:
@@ -623,7 +620,8 @@ def file_process(self, *a, **k):
monkeypatch.setattr(ray_actors, "DataProcessCore", CoreWithBadImages)
actor = ray_actors.DataProcessorRayActor()
- chunks = actor.process_file("/tmp/a.pdf", "basic", destination="local")
+ source_path = make_temp_file(tmp_path, "a.pdf", content=b"%PDF-1.4")
+ chunks = actor.process_file(source_path, "basic", destination="local")
assert chunks == [{"content": "text", "metadata": {}}]
diff --git a/test/backend/database/test_knowledge_db.py b/test/backend/database/test_knowledge_db.py
index dc5147503..219372994 100644
--- a/test/backend/database/test_knowledge_db.py
+++ b/test/backend/database/test_knowledge_db.py
@@ -198,27 +198,32 @@ def mock_session():
return mock_session, mock_query
-def test_create_knowledge_record_success(monkeypatch, mock_session):
- """Test successful creation of knowledge record"""
- session, _ = mock_session
-
- # Create mock knowledge record
- mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge")
- mock_record.knowledge_id = 123
- mock_record.index_name = "test_knowledge"
-
- # Mock database session context
+def setup_mock_db_session(monkeypatch, session):
mock_ctx = MagicMock()
mock_ctx.__enter__.return_value = session
- # Mock the context manager to call rollback on exception, like the real get_db_session does
def mock_exit(exc_type, exc_val, exc_tb):
if exc_type is not None:
session.rollback()
return None # Don't suppress the exception
+
mock_ctx.__exit__.side_effect = mock_exit
monkeypatch.setattr(
"backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
+ return mock_ctx
+
+
+def test_create_knowledge_record_success(monkeypatch, mock_session):
+ """Test successful creation of knowledge record"""
+ session, _ = mock_session
+
+ # Create mock knowledge record
+ mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge")
+ mock_record.knowledge_id = 123
+ mock_record.index_name = "test_knowledge"
+
+ # Mock database session context
+ setup_mock_db_session(monkeypatch, session)
# Prepare test data
test_query = {
@@ -305,17 +310,7 @@ def test_create_knowledge_record_sets_multimodal_flag(monkeypatch, mock_session)
mock_record.knowledge_id = 123
mock_record.index_name = "test_knowledge"
- mock_ctx = MagicMock()
- mock_ctx.__enter__.return_value = session
-
- def mock_exit(exc_type, exc_val, exc_tb):
- if exc_type is not None:
- session.rollback()
- return None
-
- mock_ctx.__exit__.side_effect = mock_exit
- monkeypatch.setattr(
- "backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
+ setup_mock_db_session(monkeypatch, session)
test_query = {
"index_name": "test_knowledge",
@@ -339,17 +334,7 @@ def test_create_knowledge_record_exception(monkeypatch, mock_session):
session, _ = mock_session
session.add.side_effect = MockSQLAlchemyError("Database error")
- mock_ctx = MagicMock()
- mock_ctx.__enter__.return_value = session
- # Mock the context manager to call rollback on exception, like the real get_db_session does
-
- def mock_exit(exc_type, exc_val, exc_tb):
- if exc_type is not None:
- session.rollback()
- return None # Don't suppress the exception
- mock_ctx.__exit__.side_effect = mock_exit
- monkeypatch.setattr(
- "backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
+ setup_mock_db_session(monkeypatch, session)
test_query = {
"index_name": "test_knowledge",
@@ -374,17 +359,7 @@ def test_create_knowledge_record_generates_index_name(monkeypatch, mock_session)
mock_record = MockKnowledgeRecord(knowledge_name="kb1")
mock_record.knowledge_id = 7
- mock_ctx = MagicMock()
- mock_ctx.__enter__.return_value = session
- # Mock the context manager to call rollback on exception, like the real get_db_session does
-
- def mock_exit(exc_type, exc_val, exc_tb):
- if exc_type is not None:
- session.rollback()
- return None # Don't suppress the exception
- mock_ctx.__exit__.side_effect = mock_exit
- monkeypatch.setattr(
- "backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
+ setup_mock_db_session(monkeypatch, session)
# Deterministic index name
monkeypatch.setattr(
diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py
index 8f377aa74..478b2687e 100644
--- a/test/sdk/core/tools/test_knowledge_base_search_tool.py
+++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py
@@ -185,7 +185,7 @@ def test_search_hybrid_error(self, knowledge_base_search_tool):
with pytest.raises(Exception) as excinfo:
knowledge_base_search_tool.search_hybrid("test query", ["test_index1"])
- assert "Error during semantic search" in str(excinfo.value)
+ assert "Error during hybrid search" in str(excinfo.value)
def test_forward_accurate_mode_success(self, knowledge_base_search_tool):
"""Test forward method with accurate search mode"""
@@ -305,8 +305,8 @@ def test_forward_title_fallback(self, knowledge_base_search_tool):
def test_forward_adds_picture_web_for_images(self, knowledge_base_search_tool, monkeypatch):
"""Forward should add picture messages when image results are present."""
- monkeypatch.setenv("DATA_PROCESS_SERVICE", "http://data-process")
- knowledge_base_search_tool.data_process_service = "http://data-process"
+ monkeypatch.setenv("DATA_PROCESS_SERVICE", "https://data-process")
+ knowledge_base_search_tool.data_process_service = "https://data-process"
mock_results = [
{
From a959469753dab81dcf344138f77af1ff51b58b9b Mon Sep 17 00:00:00 2001
From: wyxkerry <1012700194@qq.com>
Date: Tue, 31 Mar 2026 14:55:00 +0800
Subject: [PATCH 06/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
sdk/nexent/data_process/extract_image.py | 28 +++++++-----------------
1 file changed, 8 insertions(+), 20 deletions(-)
diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py
index 3fbe377f1..2fd87dc5c 100644
--- a/sdk/nexent/data_process/extract_image.py
+++ b/sdk/nexent/data_process/extract_image.py
@@ -47,20 +47,12 @@ def _hash(data: bytes) -> str:
return hashlib.sha256(data).hexdigest()
@staticmethod
- def _openxml_namespace_maps() -> List[Dict[str, str]]:
- # Prefer https URIs, but retain http for compatibility with existing files.
- return [
- {
- "xdr": "https://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
- "a": "https://schemas.openxmlformats.org/drawingml/2006/main",
- "r": "https://schemas.openxmlformats.org/officeDocument/2006/relationships",
- },
- {
- "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
- "a": "http://schemas.openxmlformats.org/drawingml/2006/main",
- "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
- },
- ]
+ def _openxml_namespace_maps() -> Dict[str, str]:
+ return {
+ "xdr": "https://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
+ "a": "https://schemas.openxmlformats.org/drawingml/2006/main",
+ "r": "https://schemas.openxmlformats.org/officeDocument/2006/relationships",
+ }
def _write_temp_file(self, data: bytes, suffix: str) -> str:
@@ -342,13 +334,9 @@ def _extract_excel(self, xlsx_path):
with zipfile.ZipFile(xlsx_path) as z:
sheet_files = self._excel_sheet_files(z)
+ ns = self._openxml_namespace_maps()
for sheet_file in sheet_files:
- extracted = []
- for ns in self._openxml_namespace_maps():
- extracted = self._extract_excel_sheet(z, sheet_file, ns, seen)
- if extracted:
- break
- results.extend(extracted)
+ results.extend(self._extract_excel_sheet(z, sheet_file, ns, seen))
return results
From 2abcdb3826be7a20887dcf128ec78398c7e0fed7 Mon Sep 17 00:00:00 2001
From: wyxkerry <1012700194@qq.com>
Date: Tue, 31 Mar 2026 22:58:27 +0800
Subject: [PATCH 07/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../[locale]/knowledges/KnowledgeBaseConfiguration.tsx | 10 ----------
sdk/nexent/data_process/extract_image.py | 6 +++---
2 files changed, 3 insertions(+), 13 deletions(-)
diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx
index fc86ada16..c9e79f149 100644
--- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx
+++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx
@@ -26,7 +26,6 @@ import knowledgeBaseService from "@/services/knowledgeBaseService";
import knowledgeBasePollingService from "@/services/knowledgeBasePollingService";
import { KnowledgeBase } from "@/types/knowledgeBase";
import { useConfig } from "@/hooks/useConfig";
-import { useModelList } from "@/hooks/model/useModelList";
import {
SETUP_PAGE_CONTAINER,
TWO_COLUMN_LAYOUT,
@@ -128,9 +127,6 @@ function DataConfig({ isActive }: DataConfigProps) {
const { modelConfig, data: configData, invalidateConfig, config, updateConfig, saveConfig } = useConfig();
const { token } = theme.useToken();
- // Get available embedding models for knowledge base creation
- const { availableEmbeddingModels } = useModelList({ enabled: true });
-
// Clear cache when component initializes
useEffect(() => {
localStorage.removeItem("preloaded_kb_data");
@@ -635,12 +631,6 @@ function DataConfig({ isActive }: DataConfigProps) {
setNewKbName(defaultName);
setNewKbIngroupPermission("READ_ONLY");
setNewKbGroupIds([]);
- // Set default embedding model - prioritize config's default model, fall back to first available model
- const configModel = modelConfig?.embedding?.modelName;
- const defaultModel = configModel || (availableEmbeddingModels.length > 0
- ? availableEmbeddingModels[0].displayName
- : "");
- setNewKbEmbeddingModel(defaultModel);
setIsCreatingMode(true);
setHasClickedUpload(false); // Reset upload button click state
setUploadFiles([]); // Reset upload files array, clear all pending upload files
diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py
index 2fd87dc5c..6d5051132 100644
--- a/sdk/nexent/data_process/extract_image.py
+++ b/sdk/nexent/data_process/extract_image.py
@@ -49,9 +49,9 @@ def _hash(data: bytes) -> str:
@staticmethod
def _openxml_namespace_maps() -> Dict[str, str]:
return {
- "xdr": "https://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
- "a": "https://schemas.openxmlformats.org/drawingml/2006/main",
- "r": "https://schemas.openxmlformats.org/officeDocument/2006/relationships",
+ "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
+ "a": "http://schemas.openxmlformats.org/drawingml/2006/main",
+ "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
}
From 5ed7e8adc8aacdc35dcc180a26074e7cffa53032 Mon Sep 17 00:00:00 2001
From: wyxkerry <1012700194@qq.com>
Date: Wed, 1 Apr 2026 00:19:13 +0800
Subject: [PATCH 08/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
sdk/nexent/data_process/extract_image.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py
index 6d5051132..38b452d6d 100644
--- a/sdk/nexent/data_process/extract_image.py
+++ b/sdk/nexent/data_process/extract_image.py
@@ -49,9 +49,9 @@ def _hash(data: bytes) -> str:
@staticmethod
def _openxml_namespace_maps() -> Dict[str, str]:
return {
- "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
- "a": "http://schemas.openxmlformats.org/drawingml/2006/main",
- "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
+ "xdr": "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing", # NOSONAR
+ "a": "http://schemas.openxmlformats.org/drawingml/2006/main", # NOSONAR
+ "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", # NOSONAR
}
From 39744e86315c58b21d0f5b55e2f67312e1b5f243 Mon Sep 17 00:00:00 2001
From: wyxkerry <1012700194@qq.com>
Date: Sun, 5 Apr 2026 11:58:07 +0800
Subject: [PATCH 09/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/apps/file_management_app.py | 6 ++++--
backend/consts/model.py | 1 +
backend/data_process/ray_actors.py | 3 +++
backend/services/data_process_service.py | 6 +++++-
backend/utils/file_management_utils.py | 6 +++++-
.../knowledges/KnowledgeBaseConfiguration.tsx | 8 ++++++--
.../knowledges/contexts/DocumentContext.tsx | 8 ++++----
.../contexts/KnowledgeBaseContext.tsx | 2 +-
frontend/services/knowledgeBaseService.ts | 6 ++++--
frontend/tsconfig.json | 2 +-
sdk/nexent/data_process/core.py | 13 ++++++++-----
.../vector_database/elasticsearch_core.py | 18 +++++++++---------
12 files changed, 51 insertions(+), 28 deletions(-)
diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py
index 5b7c7bc3c..77af77650 100644
--- a/backend/apps/file_management_app.py
+++ b/backend/apps/file_management_app.py
@@ -120,6 +120,7 @@ async def process_files(
chunking_strategy: Optional[str] = Body("basic"),
index_name: str = Body(...),
destination: str = Body(...),
+ is_multimodal: Optional[bool] = Body(False),
authorization: Optional[str] = Header(None)
):
"""
@@ -133,7 +134,8 @@ async def process_files(
chunking_strategy=chunking_strategy,
source_type=destination,
index_name=index_name,
- authorization=authorization
+ authorization=authorization,
+ is_multimodal=is_multimodal
)
process_result = await trigger_data_process(files, process_params)
@@ -639,4 +641,4 @@ async def preview_file(
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"Failed to preview file: {str(e)}"
- )
\ No newline at end of file
+ )
diff --git a/backend/consts/model.py b/backend/consts/model.py
index 2728d95ca..128fe81d4 100644
--- a/backend/consts/model.py
+++ b/backend/consts/model.py
@@ -234,6 +234,7 @@ class ProcessParams(BaseModel):
source_type: str
index_name: str
authorization: Optional[str] = None
+ is_multimodal: Optional[bool] = False
class OpinionRequest(BaseModel):
diff --git a/backend/data_process/ray_actors.py b/backend/data_process/ray_actors.py
index 934be1720..b9fd982ae 100644
--- a/backend/data_process/ray_actors.py
+++ b/backend/data_process/ray_actors.py
@@ -118,9 +118,12 @@ def _apply_model_chunk_sizes(
maximum_chunk_size = model_record.get(
'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE)
model_name = model_record.get('display_name')
+ model_type = model_record.get('model_type')
params['max_characters'] = maximum_chunk_size
params['new_after_n_chars'] = expected_chunk_size
+ if model_type:
+ params['model_type'] = model_type
logger.info(
f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): "
diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py
index 9eae72407..dcd80424d 100644
--- a/backend/services/data_process_service.py
+++ b/backend/services/data_process_service.py
@@ -474,6 +474,8 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B
chunking_strategy = source_config.get('chunking_strategy')
index_name = source_config.get('index_name')
original_filename = source_config.get('original_filename')
+ embedding_model_id = source_config.get('embedding_model_id')
+ tenant_id = source_config.get('tenant_id')
# Validate required fields
if not source:
@@ -492,7 +494,9 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B
source_type=source_type,
chunking_strategy=chunking_strategy,
index_name=index_name,
- original_filename=original_filename
+ original_filename=original_filename,
+ embedding_model_id=embedding_model_id,
+ tenant_id=tenant_id
).set(queue='process_q'),
forward.s(
index_name=index_name,
diff --git a/backend/utils/file_management_utils.py b/backend/utils/file_management_utils.py
index 57025e350..4103770f3 100644
--- a/backend/utils/file_management_utils.py
+++ b/backend/utils/file_management_utils.py
@@ -40,11 +40,13 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams)
# Get chunking size according to the embedding model
embedding_model_id = None
tenant_id = None
+ is_multimodal = process_params.is_multimodal
try:
_, tenant_id = get_current_user_id(process_params.authorization)
# Get embedding model ID from tenant config
tenant_config = tenant_config_manager.load_config(tenant_id)
- embedding_model_id_str = tenant_config.get("EMBEDDING_ID") if tenant_config else None
+ embedding_id_key = "MULTI_EMBEDDING_ID" if is_multimodal else "EMBEDDING_ID"
+ embedding_model_id_str = tenant_config.get(embedding_id_key) if tenant_config else None
if embedding_model_id_str:
embedding_model_id = int(embedding_model_id_str)
except Exception as e:
@@ -66,6 +68,7 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams)
"index_name": process_params.index_name,
"original_filename": file_details.get("filename"),
"embedding_model_id": embedding_model_id,
+ "is_multimodal": is_multimodal,
"tenant_id": tenant_id
}
@@ -97,6 +100,7 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams)
"index_name": process_params.index_name,
"original_filename": file_details.get("filename"),
"embedding_model_id": embedding_model_id,
+ "is_multimodal": is_multimodal,
"tenant_id": tenant_id
}
sources.append(source)
diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx
index c9e79f149..954c3b82e 100644
--- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx
+++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx
@@ -711,7 +711,7 @@ function DataConfig({ isActive }: DataConfigProps) {
setHasClickedUpload(false);
setNewlyCreatedKbId(newKB.id); // Mark this KB as newly created
- await uploadDocuments(newKB.id, filesToUpload);
+ await uploadDocuments(newKB.id, filesToUpload, isMultimodal);
setUploadFiles([]);
knowledgeBasePollingService
@@ -747,7 +747,11 @@ function DataConfig({ isActive }: DataConfigProps) {
}
try {
- await uploadDocuments(kbId, filesToUpload);
+ await uploadDocuments(
+ kbId,
+ filesToUpload,
+ kbState.activeKnowledgeBase?.is_multimodal
+ );
setUploadFiles([]);
knowledgeBasePollingService.triggerKnowledgeBaseListUpdate(true);
diff --git a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx
index b956dd919..7a2dcfb2e 100644
--- a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx
+++ b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx
@@ -112,7 +112,7 @@ export const DocumentContext = createContext<{
state: DocumentState;
dispatch: React.Dispatch;
fetchDocuments: (kbId: string, forceRefresh?: boolean, kbSource?: string) => Promise;
- uploadDocuments: (kbId: string, files: File[]) => Promise;
+ uploadDocuments: (kbId: string, files: File[], isMultimodal?: boolean) => Promise;
deleteDocument: (kbId: string, docId: string) => Promise;
}>({
state: {
@@ -202,11 +202,11 @@ export const DocumentProvider: React.FC = ({ children })
}, [state.loadingKbIds, state.documentsMap, t]);
// Upload documents to a knowledge base
- const uploadDocuments = useCallback(async (kbId: string, files: File[]) => {
+ const uploadDocuments = useCallback(async (kbId: string, files: File[], isMultimodal?: boolean) => {
dispatch({ type: DOCUMENT_ACTION_TYPES.SET_UPLOADING, payload: true });
try {
- await knowledgeBaseService.uploadDocuments(kbId, files);
+ await knowledgeBaseService.uploadDocuments(kbId, files, undefined, isMultimodal);
// Set loading state before fetching latest documents
dispatch({ type: DOCUMENT_ACTION_TYPES.SET_LOADING_DOCUMENTS, payload: true });
@@ -265,4 +265,4 @@ export const DocumentProvider: React.FC = ({ children })
{children}
);
-};
\ No newline at end of file
+};
diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx
index 1a087a6a5..0aa397863 100644
--- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx
+++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx
@@ -110,7 +110,7 @@ export const KnowledgeBaseContext = createContext<{
source?: string,
ingroup_permission?: string,
group_ids?: number[],
- is_multiimodal?: boolean,
+ is_multimodal?: boolean,
) => Promise;
deleteKnowledgeBase: (id: string) => Promise;
selectKnowledgeBase: (id: string) => void;
diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts
index 657160fc7..b2cc4fce3 100644
--- a/frontend/services/knowledgeBaseService.ts
+++ b/frontend/services/knowledgeBaseService.ts
@@ -700,7 +700,7 @@ class KnowledgeBaseService {
} = {
name: params.name,
description: params.description || "",
- embeddingModel: params.embeddingModel || "",
+ embedding_model_name: params.embeddingModel || "",
is_multimodal: params.is_multimodal || false
};
@@ -846,7 +846,8 @@ class KnowledgeBaseService {
async uploadDocuments(
kbId: string,
files: File[],
- chunkingStrategy?: string
+ chunkingStrategy?: string,
+ isMultimodal?: boolean
): Promise {
try {
// Create FormData object
@@ -908,6 +909,7 @@ class KnowledgeBaseService {
files: filesToProcess,
chunking_strategy: chunkingStrategy,
destination: "minio",
+ is_multimodal: isMultimodal ?? false,
}),
});
diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json
index d61634fac..75f792957 100644
--- a/frontend/tsconfig.json
+++ b/frontend/tsconfig.json
@@ -8,7 +8,7 @@
"noEmit": true,
"esModuleInterop": true,
"module": "esnext",
- "moduleResolution": "node",
+ "moduleResolution": "bundler",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
diff --git a/sdk/nexent/data_process/core.py b/sdk/nexent/data_process/core.py
index 84bff7c5a..b58e6fe03 100644
--- a/sdk/nexent/data_process/core.py
+++ b/sdk/nexent/data_process/core.py
@@ -1,6 +1,6 @@
import logging
import os
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from .extract_image import UniversalImageExtractor
@@ -86,10 +86,10 @@ def file_process(
# Select appropriate processor
if processor:
processor_name = processor
- _, extractor = self._select_processor_by_filename(filename)
+ _, extractor = self._select_processor_by_filename(filename, params)
else:
processor_name, extractor = self._select_processor_by_filename(
- filename)
+ filename, params)
processor_instance = self.processors.get(processor_name)
extract_image_processor_instance = (
@@ -131,13 +131,16 @@ def _validate_parameters(self, chunking_strategy: str, processor: Optional[str])
logger.debug(
f"Parameter validation passed: chunking_strategy={chunking_strategy}, processor={processor}")
- def _select_processor_by_filename(self, filename: str) -> Tuple[str, Optional[str]]:
+ def _select_processor_by_filename(
+ self, filename: str, params: Optional[Dict[str, Any]] = None
+ ) -> Tuple[str, Optional[str]]:
"""Selects a processor based on the file extension."""
_, file_extension = os.path.splitext(filename)
file_extension = file_extension.lower()
extract_image = None
- if file_extension in self.EXTRACT_IMAGE_EXTENSIONS:
+ model_type = params.get("model_type")
+ if model_type == "multi_embedding" and file_extension in self.EXTRACT_IMAGE_EXTENSIONS:
extract_image = "UniversalImageExtractor"
if file_extension in self.EXCEL_EXTENSIONS:
return "OpenPyxl", extract_image
diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py
index 4f230786c..6d5953f0c 100644
--- a/sdk/nexent/vector_database/elasticsearch_core.py
+++ b/sdk/nexent/vector_database/elasticsearch_core.py
@@ -461,15 +461,15 @@ def _prepare_small_batch_embeddings(
inputs.append({"text": doc[content_field]})
embeddings = embedding_model.get_multimodal_embeddings(inputs)
return processed_docs, embeddings
-
- filtered_docs = [
- doc
- for doc in processed_docs
- if doc.get("process_source") != "UniversalImageExtractor"
- ]
- inputs = [doc[content_field] for doc in filtered_docs]
- embeddings = embedding_model.get_embeddings(inputs)
- return filtered_docs, embeddings
+ else:
+ filtered_docs = [
+ doc
+ for doc in processed_docs
+ if doc.get("process_source") != "UniversalImageExtractor"
+ ]
+ inputs = [doc[content_field] for doc in filtered_docs]
+ embeddings = embedding_model.get_embeddings(inputs)
+ return filtered_docs, embeddings
@staticmethod
def _build_bulk_operations(
From 845a369542d57f60a5cf978618a77e57e8efd8cb Mon Sep 17 00:00:00 2001
From: wyxkerry <1012700194@qq.com>
Date: Sun, 5 Apr 2026 13:30:55 +0800
Subject: [PATCH 10/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/apps/file_management_app.py | 16 ++++++++--------
test/backend/database/test_attachment_db.py | 2 ++
.../services/test_data_process_service.py | 8 ++++++--
.../services/test_vectordatabase_service.py | 6 ++++--
test/sdk/data_process/test_core.py | 11 ++++++++---
5 files changed, 28 insertions(+), 15 deletions(-)
diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py
index 77af77650..3c9a95fbe 100644
--- a/backend/apps/file_management_app.py
+++ b/backend/apps/file_management_app.py
@@ -2,7 +2,7 @@
import re
import base64
from http import HTTPStatus
-from typing import List, Optional
+from typing import Annotated, List, Optional
from urllib.parse import urlparse, urlunparse, unquote, quote
import httpx
@@ -115,13 +115,13 @@ async def upload_files(
@file_management_config_router.post("/process")
async def process_files(
- files: List[dict] = Body(
- ..., description="List of file details to process, including path_or_url and filename"),
- chunking_strategy: Optional[str] = Body("basic"),
- index_name: str = Body(...),
- destination: str = Body(...),
- is_multimodal: Optional[bool] = Body(False),
- authorization: Optional[str] = Header(None)
+ files: Annotated[List[dict], Body(
+ ..., description="List of file details to process, including path_or_url and filename")],
+ index_name: Annotated[str, Body(...)],
+ destination: Annotated[str, Body(...)],
+ chunking_strategy: Annotated[Optional[str], Body("basic")],
+ is_multimodal: Annotated[Optional[bool], Body(False)],
+ authorization: Annotated[Optional[str], Header(None)]
):
"""
Trigger data processing for a list of uploaded files.
diff --git a/test/backend/database/test_attachment_db.py b/test/backend/database/test_attachment_db.py
index afb080682..771b90b27 100644
--- a/test/backend/database/test_attachment_db.py
+++ b/test/backend/database/test_attachment_db.py
@@ -17,6 +17,8 @@
# Mock consts module
consts_mock = MagicMock()
consts_mock.const = MagicMock()
+# Ensure constants are real strings to avoid startswith TypeError
+consts_mock.const.S3_URL_PREFIX = "s3://"
# Environment variables are now configured in conftest.py
sys.modules['consts'] = consts_mock
diff --git a/test/backend/services/test_data_process_service.py b/test/backend/services/test_data_process_service.py
index 393bea339..f306c54a5 100644
--- a/test/backend/services/test_data_process_service.py
+++ b/test/backend/services/test_data_process_service.py
@@ -1667,14 +1667,18 @@ async def async_test_create_batch_tasks_impl_success(self, mock_process, mock_fo
'source_type': 'url',
'chunking_strategy': 'semantic',
'index_name': 'test_index_1',
- 'original_filename': 'doc1.pdf'
+ 'original_filename': 'doc1.pdf',
+ 'embedding_model_id': None,
+ 'tenant_id': None
},
{
'source': 'http://example.com/doc2.pdf',
'source_type': 'url',
'chunking_strategy': 'fixed',
'index_name': 'test_index_2',
- 'original_filename': 'doc2.pdf'
+ 'original_filename': 'doc2.pdf',
+ 'embedding_model_id': None,
+ 'tenant_id': None
}
]
actual_process_calls = [kwargs for args,
diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py
index 66c7e8a7a..b583565bb 100644
--- a/test/backend/services/test_vectordatabase_service.py
+++ b/test/backend/services/test_vectordatabase_service.py
@@ -512,7 +512,9 @@ def test_create_knowledge_base_with_embedding_model_name(self, mock_get_embeddin
self.assertEqual(result["knowledge_id"], 10)
# Verify get_embedding_model was called with the model name
- mock_get_embedding.assert_called_once_with("tenant-1", "text-embedding-3-small")
+ mock_get_embedding.assert_called_once_with(
+ "tenant-1", is_multimodal=False, model_name="text-embedding-3-small"
+ )
# Verify knowledge record was created with the embedding model name
mock_create_knowledge.assert_called_once()
@@ -559,7 +561,7 @@ def test_create_knowledge_base_without_embedding_model_name_uses_default(self, m
self.assertEqual(result["status"], "success")
# Verify get_embedding_model was called with None (no specific model)
- mock_get_embedding.assert_called_once_with("tenant-1", None)
+ mock_get_embedding.assert_called_once_with("tenant-1", is_multimodal=False)
# Verify knowledge record was created with the model's display name
mock_create_knowledge.assert_called_once()
diff --git a/test/sdk/data_process/test_core.py b/test/sdk/data_process/test_core.py
index af325b52f..5dfff546f 100644
--- a/test/sdk/data_process/test_core.py
+++ b/test/sdk/data_process/test_core.py
@@ -207,7 +207,8 @@ def test_validate_parameters_invalid_processor(self, core):
)
def test_select_processor_by_filename(self, core, filename, expected_processor, expected_extractor):
"""Test processor selection based on filename"""
- processor_name, extractor = core._select_processor_by_filename(filename)
+ params = {"model_type": "multi_embedding"} if expected_extractor else {}
+ processor_name, extractor = core._select_processor_by_filename(filename, params)
assert processor_name == expected_processor
assert extractor == expected_extractor
@@ -349,7 +350,7 @@ def test_file_process_returns_images_when_extractor_available(self, core, mocker
core.processors["UniversalImageExtractor"] = mock_extractor
result = core.file_process(
- b"data", "sample.pdf", chunking_strategy="basic"
+ b"data", "sample.pdf", chunking_strategy="basic", model_type="multi_embedding"
)
chunks = _unpack_chunks(result)
@@ -366,7 +367,11 @@ def test_file_process_with_explicit_processor_still_extracts_images(self, core):
)
result = core.file_process(
- b"data", "report.pdf", chunking_strategy="basic", processor="Unstructured"
+ b"data",
+ "report.pdf",
+ chunking_strategy="basic",
+ processor="Unstructured",
+ model_type="multi_embedding",
)
chunks = _unpack_chunks(result)
From 46e71223d1bc5f524abf7064f294d24785a63fc0 Mon Sep 17 00:00:00 2001
From: wyxkerry <1012700194@qq.com>
Date: Sun, 5 Apr 2026 17:53:46 +0800
Subject: [PATCH 11/11] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/apps/file_management_app.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py
index 3c9a95fbe..e0321237d 100644
--- a/backend/apps/file_management_app.py
+++ b/backend/apps/file_management_app.py
@@ -119,9 +119,9 @@ async def process_files(
..., description="List of file details to process, including path_or_url and filename")],
index_name: Annotated[str, Body(...)],
destination: Annotated[str, Body(...)],
- chunking_strategy: Annotated[Optional[str], Body("basic")],
- is_multimodal: Annotated[Optional[bool], Body(False)],
- authorization: Annotated[Optional[str], Header(None)]
+ chunking_strategy: Annotated[Optional[str], Body(...)] = "basic",
+ is_multimodal: Annotated[Optional[bool], Body(...)] = False,
+ authorization: Annotated[Optional[str], Header()] = None
):
"""
Trigger data processing for a list of uploaded files.