Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions openviking/models/embedder/vikingdb_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ def _call_api(
texts: List[str],
dense_model: Dict[str, Any] = None,
sparse_model: Optional[Dict[str, Any]] = None,
input_type: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Call VikingDB Embedding API"""
path = "/api/vikingdb/embedding"

data_items = [{"text": text} for text in texts]
if input_type is not None:
for item in data_items:
item["input_type"] = input_type

req_body = {"data": data_items}
if dense_model:
Expand Down Expand Up @@ -116,17 +120,30 @@ def __init__(
dimension: Optional[int] = None,
embedding_type: str = "text",
config: Optional[Dict[str, Any]] = None,
query_param: Optional[str] = None,
document_param: Optional[str] = None,
):
DenseEmbedderBase.__init__(self, model_name, config)
self._init_vikingdb_client(ak, sk, region, host)
self.model_version = model_version
self.dimension = dimension
self.embedding_type = embedding_type
self.dense_model = {"name": model_name, "version": model_version, "dim": dimension}
self.query_param = query_param
self.document_param = document_param

def _resolve_input_type(self, is_query: bool) -> Optional[str]:
    """Pick the ``input_type`` value for the side being embedded.

    Args:
        is_query: True when embedding a query, False when embedding a document.

    Returns:
        ``self.query_param`` on the query side, ``self.document_param`` on the
        document side. The result is ``None`` when that parameter is unset,
        i.e. symmetric mode.
    """
    # An unset parameter is already None, so no separate fallback is needed.
    side_param = self.query_param if is_query else self.document_param
    return side_param

def embed(self, text: str, is_query: bool = False) -> EmbedResult:
input_type = self._resolve_input_type(is_query)
results = transient_retry(
lambda: self._call_api([text], dense_model=self.dense_model),
lambda: self._call_api([text], dense_model=self.dense_model, input_type=input_type),
max_retries=self.max_retries,
)
if not results:
Expand All @@ -142,8 +159,9 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
if not texts:
return []
input_type = self._resolve_input_type(is_query)
raw_results = transient_retry(
lambda: self._call_api(texts, dense_model=self.dense_model),
lambda: self._call_api(texts, dense_model=self.dense_model, input_type=input_type),
max_retries=self.max_retries,
)
return [
Expand Down Expand Up @@ -224,6 +242,8 @@ def __init__(
dimension: Optional[int] = None,
embedding_type: str = "text",
config: Optional[Dict[str, Any]] = None,
query_param: Optional[str] = None,
document_param: Optional[str] = None,
):
HybridEmbedderBase.__init__(self, model_name, config)
self._init_vikingdb_client(ak, sk, region, host)
Expand All @@ -235,11 +255,25 @@ def __init__(
"name": model_name,
"version": model_version,
}
self.query_param = query_param
self.document_param = document_param

def _resolve_input_type(self, is_query: bool) -> Optional[str]:
    """Select the ``input_type`` value matching the requested side.

    Args:
        is_query: True for the query side, False for the document side.

    Returns:
        The configured ``query_param`` or ``document_param`` for that side,
        or ``None`` when it is not set (symmetric mode).
    """
    if is_query:
        return self.query_param
    # Document side: an unset document_param is None, which is the symmetric result.
    return self.document_param

def embed(self, text: str, is_query: bool = False) -> EmbedResult:
input_type = self._resolve_input_type(is_query)
results = transient_retry(
lambda: self._call_api(
[text], dense_model=self.dense_model, sparse_model=self.sparse_model
[text],
dense_model=self.dense_model,
sparse_model=self.sparse_model,
input_type=input_type,
),
max_retries=self.max_retries,
)
Expand All @@ -260,9 +294,13 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
if not texts:
return []
input_type = self._resolve_input_type(is_query)
raw_results = transient_retry(
lambda: self._call_api(
texts, dense_model=self.dense_model, sparse_model=self.sparse_model
texts,
dense_model=self.dense_model,
sparse_model=self.sparse_model,
input_type=input_type,
),
max_retries=self.max_retries,
)
Expand Down
4 changes: 4 additions & 0 deletions openviking_cli/utils/config/embedding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ def _create_embedder(
"host": cfg.host,
"dimension": cfg.dimension,
"input_type": cfg.input,
**({"query_param": cfg.query_param} if cfg.query_param else {}),
**({"document_param": cfg.document_param} if cfg.document_param else {}),
},
),
("vikingdb", "sparse"): (
Expand All @@ -423,6 +425,8 @@ def _create_embedder(
"host": cfg.host,
"dimension": cfg.dimension,
"input_type": cfg.input,
**({"query_param": cfg.query_param} if cfg.query_param else {}),
**({"document_param": cfg.document_param} if cfg.document_param else {}),
},
),
("jina", "dense"): (
Expand Down
60 changes: 60 additions & 0 deletions tests/unit/embedder/test_vikingdb_nonsymmetric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: AGPL-3.0
"""Tests for VikingDB non-symmetric embedding support."""

from unittest.mock import patch

import pytest

from openviking.models.embedder.vikingdb_embedders import (
VikingDBDenseEmbedder,
VikingDBHybridEmbedder,
)


@pytest.fixture
def mock_vikingdb_client():
    """Yield a mock that replaces ``VikingDBDenseEmbedder._init_vikingdb_client``.

    The patch is active for the duration of the test using the fixture and
    is undone when the ``with`` block exits.
    """
    init_patcher = patch.object(
        VikingDBDenseEmbedder,
        "_init_vikingdb_client",
        return_value=None,
    )
    with init_patcher as mock_init:
        # Accept any call signature and return None.
        mock_init.side_effect = lambda *args, **kwargs: None
        yield mock_init


def test_dense_resolve_input_type_symmetric():
    """With neither side parameter set, both sides resolve to None (symmetric)."""
    # __new__ skips __init__, so no client setup runs.
    dense = VikingDBDenseEmbedder.__new__(VikingDBDenseEmbedder)
    dense.query_param = dense.document_param = None
    for query_side in (True, False):
        assert dense._resolve_input_type(is_query=query_side) is None


def test_dense_resolve_input_type_nonsymmetric():
    """With both side parameters set, each side resolves to its own value."""
    # __new__ skips __init__, so no client setup runs.
    dense = VikingDBDenseEmbedder.__new__(VikingDBDenseEmbedder)
    dense.query_param, dense.document_param = "query", "passage"
    expected_by_side = {True: "query", False: "passage"}
    for query_side, expected in expected_by_side.items():
        assert dense._resolve_input_type(is_query=query_side) == expected


def test_hybrid_resolve_input_type_nonsymmetric():
    """The hybrid embedder resolves each side to its configured value."""
    # __new__ skips __init__, so no client setup runs.
    hybrid = VikingDBHybridEmbedder.__new__(VikingDBHybridEmbedder)
    hybrid.query_param, hybrid.document_param = "search_query", "search_document"
    expected_by_side = {True: "search_query", False: "search_document"}
    for query_side, expected in expected_by_side.items():
        assert hybrid._resolve_input_type(is_query=query_side) == expected


def test_dense_backward_compat_no_params():
    """A dense embedder with no query/document parameters still resolves cleanly."""
    # __new__ skips __init__, so no client setup runs.
    dense = VikingDBDenseEmbedder.__new__(VikingDBDenseEmbedder)
    attribute_values = (
        ("query_param", None),
        ("document_param", None),
        ("model_name", "test"),
        ("dimension", 1024),
    )
    for attr_name, attr_value in attribute_values:
        setattr(dense, attr_name, attr_value)
    # Should not raise
    resolved = dense._resolve_input_type(is_query=True)
    assert resolved is None