diff --git a/openviking/models/embedder/vikingdb_embedders.py b/openviking/models/embedder/vikingdb_embedders.py
index d9c5cf49b..de59640e5 100644
--- a/openviking/models/embedder/vikingdb_embedders.py
+++ b/openviking/models/embedder/vikingdb_embedders.py
@@ -40,11 +40,15 @@ def _call_api(
         texts: List[str],
         dense_model: Dict[str, Any] = None,
         sparse_model: Optional[Dict[str, Any]] = None,
+        input_type: Optional[str] = None,
     ) -> List[Dict[str, Any]]:
         """Call VikingDB Embedding API"""
         path = "/api/vikingdb/embedding"
 
         data_items = [{"text": text} for text in texts]
+        if input_type is not None:
+            for item in data_items:
+                item["input_type"] = input_type
 
         req_body = {"data": data_items}
         if dense_model:
@@ -116,6 +120,8 @@ def __init__(
         dimension: Optional[int] = None,
         embedding_type: str = "text",
         config: Optional[Dict[str, Any]] = None,
+        query_param: Optional[str] = None,
+        document_param: Optional[str] = None,
     ):
         DenseEmbedderBase.__init__(self, model_name, config)
         self._init_vikingdb_client(ak, sk, region, host)
@@ -123,10 +129,21 @@ def __init__(
         self.dimension = dimension
         self.embedding_type = embedding_type
         self.dense_model = {"name": model_name, "version": model_version, "dim": dimension}
+        self.query_param = query_param
+        self.document_param = document_param
+
+    def _resolve_input_type(self, is_query: bool) -> Optional[str]:
+        """Return the input_type for the query or document side.
+
+        ``None`` means that side is not configured (symmetric mode), so no
+        ``input_type`` is attached to the request items.
+        """
+        return self.query_param if is_query else self.document_param
 
     def embed(self, text: str, is_query: bool = False) -> EmbedResult:
+        input_type = self._resolve_input_type(is_query)
         results = transient_retry(
-            lambda: self._call_api([text], dense_model=self.dense_model),
+            lambda: self._call_api([text], dense_model=self.dense_model, input_type=input_type),
             max_retries=self.max_retries,
         )
         if not results:
@@ -142,8 +159,9 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
     def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
         if not texts:
             return []
+        input_type = self._resolve_input_type(is_query)
         raw_results = transient_retry(
-            lambda: self._call_api(texts, dense_model=self.dense_model),
+            lambda: self._call_api(texts, dense_model=self.dense_model, input_type=input_type),
             max_retries=self.max_retries,
         )
         return [
@@ -224,6 +242,8 @@ def __init__(
         dimension: Optional[int] = None,
         embedding_type: str = "text",
         config: Optional[Dict[str, Any]] = None,
+        query_param: Optional[str] = None,
+        document_param: Optional[str] = None,
     ):
         HybridEmbedderBase.__init__(self, model_name, config)
         self._init_vikingdb_client(ak, sk, region, host)
@@ -235,11 +255,25 @@ def __init__(
             "name": model_name,
             "version": model_version,
         }
+        self.query_param = query_param
+        self.document_param = document_param
+
+    def _resolve_input_type(self, is_query: bool) -> Optional[str]:
+        """Return the input_type for the query or document side.
+
+        ``None`` means that side is not configured (symmetric mode), so no
+        ``input_type`` is attached to the request items.
+        """
+        return self.query_param if is_query else self.document_param
 
     def embed(self, text: str, is_query: bool = False) -> EmbedResult:
+        input_type = self._resolve_input_type(is_query)
         results = transient_retry(
             lambda: self._call_api(
-                [text], dense_model=self.dense_model, sparse_model=self.sparse_model
+                [text],
+                dense_model=self.dense_model,
+                sparse_model=self.sparse_model,
+                input_type=input_type,
             ),
             max_retries=self.max_retries,
         )
@@ -260,9 +294,13 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
     def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
         if not texts:
             return []
+        input_type = self._resolve_input_type(is_query)
         raw_results = transient_retry(
             lambda: self._call_api(
-                texts, dense_model=self.dense_model, sparse_model=self.sparse_model
+                texts,
+                dense_model=self.dense_model,
+                sparse_model=self.sparse_model,
+                input_type=input_type,
             ),
             max_retries=self.max_retries,
         )
diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py
index 1377d5f4a..0536e34e0 100644
--- a/openviking_cli/utils/config/embedding_config.py
+++ b/openviking_cli/utils/config/embedding_config.py
@@ -399,6 +399,8 @@ def _create_embedder(
                     "host": cfg.host,
                     "dimension": cfg.dimension,
                     "input_type": cfg.input,
+                    **({"query_param": cfg.query_param} if cfg.query_param else {}),
+                    **({"document_param": cfg.document_param} if cfg.document_param else {}),
                 },
             ),
             ("vikingdb", "sparse"): (
@@ -423,6 +425,8 @@ def _create_embedder(
                     "host": cfg.host,
                     "dimension": cfg.dimension,
                     "input_type": cfg.input,
+                    **({"query_param": cfg.query_param} if cfg.query_param else {}),
+                    **({"document_param": cfg.document_param} if cfg.document_param else {}),
                 },
             ),
             ("jina", "dense"): (
diff --git a/tests/unit/embedder/test_vikingdb_nonsymmetric.py b/tests/unit/embedder/test_vikingdb_nonsymmetric.py
new file mode 100644
index 000000000..3ce630b41
--- /dev/null
+++ b/tests/unit/embedder/test_vikingdb_nonsymmetric.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: AGPL-3.0
+"""Tests for VikingDB non-symmetric embedding support."""
+
+import pytest
+
+from openviking.models.embedder.vikingdb_embedders import (
+    VikingDBDenseEmbedder,
+    VikingDBHybridEmbedder,
+)
+
+
+def test_dense_resolve_input_type_symmetric():
+    """When no query_param/document_param, input_type is None (symmetric)."""
+    embedder = VikingDBDenseEmbedder.__new__(VikingDBDenseEmbedder)
+    embedder.query_param = None
+    embedder.document_param = None
+    assert embedder._resolve_input_type(is_query=True) is None
+    assert embedder._resolve_input_type(is_query=False) is None
+
+
+def test_dense_resolve_input_type_nonsymmetric():
+    """When query_param/document_param set, return correct value for is_query."""
+    embedder = VikingDBDenseEmbedder.__new__(VikingDBDenseEmbedder)
+    embedder.query_param = "query"
+    embedder.document_param = "passage"
+    assert embedder._resolve_input_type(is_query=True) == "query"
+    assert embedder._resolve_input_type(is_query=False) == "passage"
+
+
+def test_hybrid_resolve_input_type_nonsymmetric():
+    """Hybrid embedder also resolves input_type correctly."""
+    embedder = VikingDBHybridEmbedder.__new__(VikingDBHybridEmbedder)
+    embedder.query_param = "search_query"
+    embedder.document_param = "search_document"
+    assert embedder._resolve_input_type(is_query=True) == "search_query"
+    assert embedder._resolve_input_type(is_query=False) == "search_document"
+
+
+def test_dense_backward_compat_no_params():
+    """VikingDBDenseEmbedder without query_param/document_param works."""
+    embedder = VikingDBDenseEmbedder.__new__(VikingDBDenseEmbedder)
+    embedder.query_param = None
+    embedder.document_param = None
+    embedder.model_name = "test"
+    embedder.dimension = 1024
+    # Should not raise
+    assert embedder._resolve_input_type(is_query=True) is None
+
+
+@pytest.mark.parametrize("embedder_cls", [VikingDBDenseEmbedder, VikingDBHybridEmbedder])
+@pytest.mark.parametrize(
+    ("query_param", "document_param"),
+    [("query", None), (None, "passage"), (None, None)],
+)
+def test_resolve_input_type_partial_config(embedder_cls, query_param, document_param):
+    """Each side resolves independently; an unset side stays symmetric (None)."""
+    embedder = embedder_cls.__new__(embedder_cls)
+    embedder.query_param = query_param
+    embedder.document_param = document_param
+    assert embedder._resolve_input_type(is_query=True) == query_param
+    assert embedder._resolve_input_type(is_query=False) == document_param