diff --git a/openviking/models/embedder/base.py b/openviking/models/embedder/base.py index 9a9e5df3e..62370c37a 100644 --- a/openviking/models/embedder/base.py +++ b/openviking/models/embedder/base.py @@ -74,6 +74,7 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): """ self.model_name = model_name self.config = config or {} + self.max_retries = self.config.get("max_retries", 3) if self.config else 3 @abstractmethod def embed(self, text: str, is_query: bool = False) -> EmbedResult: @@ -255,7 +256,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes return [ EmbedResult(dense_vector=d.dense_vector, sparse_vector=s.sparse_vector) - for d, s in zip(dense_results, sparse_results) + for d, s in zip(dense_results, sparse_results, strict=False) ] def get_dimension(self) -> int: diff --git a/openviking/models/embedder/gemini_embedders.py b/openviking/models/embedder/gemini_embedders.py index c11d73e47..6bc1e1ab6 100644 --- a/openviking/models/embedder/gemini_embedders.py +++ b/openviking/models/embedder/gemini_embedders.py @@ -29,6 +29,7 @@ EmbedResult, truncate_and_normalize, ) +from openviking.models.retry import transient_retry logger = logging.getLogger("gemini_embedders") @@ -146,15 +147,13 @@ def __init__( ) if dimension is not None and not (1 <= dimension <= 3072): raise ValueError(f"dimension must be between 1 and 3072, got {dimension}") + # Disable SDK-level retry; we use transient_retry for unified retry logic if _HTTP_RETRY_AVAILABLE: self.client = genai.Client( api_key=api_key, http_options=HttpOptions( retry_options=HttpRetryOptions( - attempts=3, - initial_delay=1.0, - max_delay=30.0, - exp_base=2.0, + attempts=1, ) ), ) @@ -209,11 +208,16 @@ def embed( task_type = self.document_param # SDK accepts plain str; converts to REST Parts format internally. 
try: - result = self.client.models.embed_content( - model=self.model_name, - contents=text, - config=self._build_config(task_type=task_type, title=title), - ) + embed_config = self._build_config(task_type=task_type, title=title) + + def _call(): + return self.client.models.embed_content( + model=self.model_name, + contents=text, + config=embed_config, + ) + + result = transient_retry(_call, max_retries=self.max_retries) vector = truncate_and_normalize(list(result.embeddings[0].values), self._dimension) return EmbedResult(dense_vector=vector) except (APIError, ClientError) as e: @@ -233,7 +237,7 @@ def embed_batch( if titles is not None: return [ self.embed(text, is_query=is_query, task_type=task_type, title=title) - for text, title in zip(texts, titles) + for text, title in zip(texts, titles, strict=False) ] # Resolve effective task_type from is_query when no explicit override if task_type is None: @@ -254,13 +258,17 @@ def embed_batch( non_empty_texts = [batch[j] for j in non_empty_indices] try: - response = self.client.models.embed_content( - model=self.model_name, - contents=non_empty_texts, - config=config, - ) + + def _batch_call(texts=non_empty_texts, cfg=config): + return self.client.models.embed_content( + model=self.model_name, + contents=texts, + config=cfg, + ) + + response = transient_retry(_batch_call, max_retries=self.max_retries) batch_results = [None] * len(batch) - for j, emb in zip(non_empty_indices, response.embeddings): + for j, emb in zip(non_empty_indices, response.embeddings, strict=False): batch_results[j] = EmbedResult( dense_vector=truncate_and_normalize(list(emb.values), self._dimension) ) diff --git a/openviking/models/embedder/jina_embedders.py b/openviking/models/embedder/jina_embedders.py index f94650fcc..25159f2df 100644 --- a/openviking/models/embedder/jina_embedders.py +++ b/openviking/models/embedder/jina_embedders.py @@ -10,6 +10,7 @@ DenseEmbedderBase, EmbedResult, ) +from openviking.models.retry import transient_retry # Default 
dimensions for Jina embedding models JINA_MODEL_DIMENSIONS = { @@ -113,9 +114,11 @@ def __init__( raise ValueError("api_key is required") # Initialize OpenAI-compatible client with Jina base URL + # Disable SDK retry; we use transient_retry for unified retry logic self.client = openai.OpenAI( api_key=self.api_key, base_url=self.api_base, + max_retries=0, ) # Determine dimension @@ -174,7 +177,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: if extra_body: kwargs["extra_body"] = extra_body - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) vector = response.data[0].embedding return EmbedResult(dense_vector=vector) @@ -209,7 +215,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if extra_body: kwargs["extra_body"] = extra_body - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) return [EmbedResult(dense_vector=item.embedding) for item in response.data] except openai.APIError as e: diff --git a/openviking/models/embedder/litellm_embedders.py b/openviking/models/embedder/litellm_embedders.py index ea24c8141..903b2fc67 100644 --- a/openviking/models/embedder/litellm_embedders.py +++ b/openviking/models/embedder/litellm_embedders.py @@ -13,6 +13,7 @@ import litellm from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult +from openviking.models.retry import transient_retry from openviking.telemetry import get_current_telemetry logger = logging.getLogger(__name__) @@ -157,7 +158,11 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: try: kwargs = self._build_kwargs(is_query=is_query) kwargs["input"] = [text] - response = litellm.embedding(**kwargs) + + def _call(): + return litellm.embedding(**kwargs) + + response 
= transient_retry(_call, max_retries=self.max_retries) self._update_telemetry_token_usage(response) vector = response.data[0]["embedding"] return EmbedResult(dense_vector=vector) @@ -183,7 +188,11 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes try: kwargs = self._build_kwargs(is_query=is_query) kwargs["input"] = texts - response = litellm.embedding(**kwargs) + + def _call(): + return litellm.embedding(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) self._update_telemetry_token_usage(response) return [EmbedResult(dense_vector=item["embedding"]) for item in response.data] except Exception as e: diff --git a/openviking/models/embedder/minimax_embedders.py b/openviking/models/embedder/minimax_embedders.py index aba462968..1547b13af 100644 --- a/openviking/models/embedder/minimax_embedders.py +++ b/openviking/models/embedder/minimax_embedders.py @@ -9,6 +9,7 @@ from urllib3.util.retry import Retry from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult +from openviking.models.retry import transient_retry from openviking_cli.utils.logger import default_logger as logger @@ -89,12 +90,8 @@ def __init__( def _create_session(self) -> requests.Session: """Create a requests session with retry logic""" session = requests.Session() - retry_strategy = Retry( - total=6, - backoff_factor=1, # 1s, 2s, 4s, 8s, 16s, 32s - status_forcelist=[429, 500, 502, 503, 504], - allowed_methods=["POST"], - ) + # Disable transport-level retry; we use transient_retry for unified retry logic + retry_strategy = Retry(total=0) adapter = HTTPAdapter(max_retries=retry_strategy) session.mount("https://", adapter) session.mount("http://", adapter) @@ -163,7 +160,10 @@ def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float def embed(self, text: str, is_query: bool = False) -> EmbedResult: """Perform dense embedding on text""" - vectors = self._call_api([text], is_query=is_query) + vectors = 
transient_retry( + lambda: self._call_api([text], is_query=is_query), + max_retries=self.max_retries, + ) return EmbedResult(dense_vector=vectors[0]) def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: @@ -171,9 +171,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - # MiniMax might have batch size limits, but let's assume the caller handles batching or use safe defaults - # For now, we pass through. If needed, we can implement internal chunking. - vectors = self._call_api(texts, is_query=is_query) + vectors = transient_retry( + lambda: self._call_api(texts, is_query=is_query), + max_retries=self.max_retries, + ) return [EmbedResult(dense_vector=v) for v in vectors] def get_dimension(self) -> int: diff --git a/openviking/models/embedder/openai_embedders.py b/openviking/models/embedder/openai_embedders.py index 0ebabbeee..0477270a6 100644 --- a/openviking/models/embedder/openai_embedders.py +++ b/openviking/models/embedder/openai_embedders.py @@ -12,6 +12,7 @@ HybridEmbedderBase, SparseEmbedderBase, ) +from openviking.models.retry import transient_retry from openviking.models.vlm.registry import DEFAULT_AZURE_API_VERSION from openviking.telemetry import get_current_telemetry @@ -118,7 +119,10 @@ def __init__( if not self.api_key and not self.api_base: raise ValueError("api_key is required") - client_kwargs: Dict[str, Any] = {"api_key": self.api_key or "no-key"} + client_kwargs: Dict[str, Any] = { + "api_key": self.api_key or "no-key", + "max_retries": 0, # Disable SDK retry; we use transient_retry + } if self._provider == "azure": if not self.api_base: raise ValueError("api_base (Azure endpoint) is required for Azure provider") @@ -242,7 +246,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: if extra_body: kwargs["extra_body"] = extra_body - response = self.client.embeddings.create(**kwargs) + def _call(): + return 
self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) self._update_telemetry_token_usage(response) vector = response.data[0].embedding @@ -277,7 +284,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if extra_body: kwargs["extra_body"] = extra_body - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) self._update_telemetry_token_usage(response) return [EmbedResult(dense_vector=item.embedding) for item in response.data] diff --git a/openviking/models/embedder/vikingdb_embedders.py b/openviking/models/embedder/vikingdb_embedders.py index 0253af9dc..d9c5cf49b 100644 --- a/openviking/models/embedder/vikingdb_embedders.py +++ b/openviking/models/embedder/vikingdb_embedders.py @@ -10,6 +10,7 @@ HybridEmbedderBase, SparseEmbedderBase, ) +from openviking.models.retry import transient_retry from openviking.storage.vectordb.collection.volcengine_clients import ClientForDataApi from openviking_cli.utils.logger import default_logger as logger @@ -124,7 +125,10 @@ def __init__( self.dense_model = {"name": model_name, "version": model_version, "dim": dimension} def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = self._call_api([text], dense_model=self.dense_model) + results = transient_retry( + lambda: self._call_api([text], dense_model=self.dense_model), + max_retries=self.max_retries, + ) if not results: return EmbedResult(dense_vector=[]) @@ -138,7 +142,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = self._call_api(texts, dense_model=self.dense_model) + raw_results = transient_retry( + lambda: self._call_api(texts, dense_model=self.dense_model), + max_retries=self.max_retries, + ) 
return [ EmbedResult( dense_vector=self._truncate_and_normalize( @@ -174,7 +181,10 @@ def __init__( } def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = self._call_api([text], sparse_model=self.sparse_model) + results = transient_retry( + lambda: self._call_api([text], sparse_model=self.sparse_model), + max_retries=self.max_retries, + ) if not results: return EmbedResult(sparse_vector={}) @@ -188,7 +198,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = self._call_api(texts, sparse_model=self.sparse_model) + raw_results = transient_retry( + lambda: self._call_api(texts, sparse_model=self.sparse_model), + max_retries=self.max_retries, + ) return [ EmbedResult( sparse_vector=self._process_sparse_embedding(item.get("sparse_embedding", {})) @@ -224,8 +237,11 @@ def __init__( } def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = self._call_api( - [text], dense_model=self.dense_model, sparse_model=self.sparse_model + results = transient_retry( + lambda: self._call_api( + [text], dense_model=self.dense_model, sparse_model=self.sparse_model + ), + max_retries=self.max_retries, ) if not results: return EmbedResult(dense_vector=[], sparse_vector={}) @@ -244,8 +260,11 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = self._call_api( - texts, dense_model=self.dense_model, sparse_model=self.sparse_model + raw_results = transient_retry( + lambda: self._call_api( + texts, dense_model=self.dense_model, sparse_model=self.sparse_model + ), + max_retries=self.max_retries, ) results = [] for item in raw_results: diff --git a/openviking/models/embedder/volcengine_embedders.py b/openviking/models/embedder/volcengine_embedders.py index 
7a3ef6f4e..15ea42cca 100644 --- a/openviking/models/embedder/volcengine_embedders.py +++ b/openviking/models/embedder/volcengine_embedders.py @@ -11,29 +11,13 @@ EmbedResult, HybridEmbedderBase, SparseEmbedderBase, - exponential_backoff_retry, truncate_and_normalize, ) +from openviking.models.retry import transient_retry from openviking.telemetry import get_current_telemetry from openviking_cli.utils.logger import default_logger as logger -def is_429_error(exception: Exception) -> bool: - """ - 判断异常是否为 429 限流错误 - - Args: - exception: 要检查的异常 - - Returns: - 如果是 429 错误则返回 True,否则返回 False - """ - exception_str = str(exception) - return ( - "429" in exception_str or "TooManyRequests" in exception_str or "RateLimit" in exception_str - ) - - def process_sparse_embedding(sparse_data: Any) -> Dict[str, float]: """Process sparse embedding data from SDK response""" if not sparse_data: @@ -177,15 +161,7 @@ def _embed_call(): return EmbedResult(dense_vector=vector) try: - return exponential_backoff_retry( - _embed_call, - max_wait=10.0, - base_delay=0.5, - max_delay=2.0, - jitter=True, - is_retryable=is_429_error, - logger=logger, - ) + return transient_retry(_embed_call, max_retries=self.max_retries) except Exception as e: raise RuntimeError(f"Volcengine embedding failed: {str(e)}") from e @@ -205,7 +181,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - try: + def _batch_call(): if self.input_type == "multimodal": multimodal_inputs = [{"type": "text", "text": text} for text in texts] response = self.client.multimodal_embeddings.create( @@ -222,6 +198,9 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes EmbedResult(dense_vector=truncate_and_normalize(item.embedding, self.dimension)) for item in data ] + + try: + return transient_retry(_batch_call, max_retries=self.max_retries) except Exception as e: logger.error( f"Volcengine batch embedding failed, texts length: {len(texts)}, 
input_type: {self.input_type}, model_name: {self.model_name}" @@ -295,15 +274,7 @@ def _embed_call(): return EmbedResult(sparse_vector=process_sparse_embedding(sparse_vector)) try: - return exponential_backoff_retry( - _embed_call, - max_wait=10.0, - base_delay=0.5, - max_delay=2.0, - jitter=True, - is_retryable=is_429_error, - logger=logger, - ) + return transient_retry(_embed_call, max_retries=self.max_retries) except Exception as e: raise RuntimeError(f"Volcengine sparse embedding failed: {str(e)}") from e @@ -400,15 +371,7 @@ def _embed_call(): ) try: - return exponential_backoff_retry( - _embed_call, - max_wait=10.0, - base_delay=0.5, - max_delay=2.0, - jitter=True, - is_retryable=is_429_error, - logger=logger, - ) + return transient_retry(_embed_call, max_retries=self.max_retries) except Exception as e: raise RuntimeError(f"Volcengine hybrid embedding failed: {str(e)}") from e diff --git a/openviking/models/embedder/voyage_embedders.py b/openviking/models/embedder/voyage_embedders.py index ed8d49f04..db8b85b3a 100644 --- a/openviking/models/embedder/voyage_embedders.py +++ b/openviking/models/embedder/voyage_embedders.py @@ -7,6 +7,7 @@ import openai from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult +from openviking.models.retry import transient_retry VOYAGE_MODEL_DIMENSIONS = { "voyage-3": 1024, @@ -74,9 +75,11 @@ def __init__( f"Supported dimensions: {supported}." 
) + # Disable SDK retry; we use transient_retry for unified retry logic self.client = openai.OpenAI( api_key=self.api_key, base_url=self.api_base, + max_retries=0, ) self._dimension = dimension or get_voyage_model_default_dimension(normalized_model_name) @@ -88,7 +91,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: if self.dimension is not None: kwargs["extra_body"] = {"output_dimension": self.dimension} - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) vector = response.data[0].embedding return EmbedResult(dense_vector=vector) except openai.APIError as e: @@ -106,7 +112,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if self.dimension is not None: kwargs["extra_body"] = {"output_dimension": self.dimension} - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) return [EmbedResult(dense_vector=item.embedding) for item in response.data] except openai.APIError as e: raise RuntimeError(f"Voyage API error: {e.message}") from e diff --git a/openviking/models/retry.py b/openviking/models/retry.py new file mode 100644 index 000000000..bf4b4e697 --- /dev/null +++ b/openviking/models/retry.py @@ -0,0 +1,287 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Unified retry logic for VLM backends and embedding providers. + +Provides three public helpers: + +- ``is_transient_error`` — classifies an exception as transient (retryable) + or permanent (should propagate immediately). +- ``transient_retry`` — synchronous retry loop with exponential backoff. +- ``transient_retry_async`` — asynchronous counterpart using ``asyncio.sleep``. 
+ +Transient errors are those that may resolve on their own (rate-limits, temporary +server errors, network resets). Permanent errors indicate a caller mistake +(bad auth, invalid input) and should never be retried. + +Usage example:: + + result = transient_retry(lambda: client.chat(...), max_retries=3) + result = await transient_retry_async(lambda: client.chat_async(...), max_retries=3) +""" + +from __future__ import annotations + +import asyncio +import logging +import random +import time +from collections.abc import Callable +from typing import Optional, TypeVar + +logger = logging.getLogger("openviking.models.retry") + +T = TypeVar("T") + +# --------------------------------------------------------------------------- +# Status code helpers +# --------------------------------------------------------------------------- + +_TRANSIENT_STATUS_CODES: frozenset[int] = frozenset({429, 500, 502, 503, 504}) +_PERMANENT_STATUS_CODES: frozenset[int] = frozenset({400, 401, 403, 404, 422}) + +# String patterns — permanent check runs first (more specific) +_PERMANENT_STR_PATTERNS: tuple[str, ...] = ( + "InvalidRequestError", + "AuthenticationError", +) +_TRANSIENT_STR_PATTERNS: tuple[str, ...] = ( + "TooManyRequests", + "RateLimit", + "RequestBurstTooFast", + "timed out", + "timeout", +) + + +def _extract_status_code(exc: Exception) -> int | None: + """Return numeric HTTP status from common status-bearing attributes. + + Checks ``.status_code``, ``.code``, and ``.http_status`` in that order. + Returns ``None`` if none of the attributes exist or hold an integer. 
+ """ + for attr in ("status_code", "code", "http_status"): + value = getattr(exc, attr, None) + if isinstance(value, int): + return value + return None + + +# --------------------------------------------------------------------------- +# is_transient_error +# --------------------------------------------------------------------------- + + +def is_transient_error(exc: Exception) -> bool: + """Classify an exception as transient (retryable) or permanent. + + Evaluation order: + 1. Extract numeric status code from the exception attributes; check + permanent codes first, then transient codes. + 2. Check the exception type directly (built-in connection / timeout types). + 3. Scan ``str(exc)`` for known permanent string patterns, then transient + ones. + 4. Attempt to import ``openai`` and check against its error hierarchy. + 5. Default to ``False`` (conservative — unknown errors are not retried). + + Args: + exc: The exception to classify. + + Returns: + ``True`` if the error is likely transient and worth retrying. + ``False`` for permanent errors or any unrecognised exception. + """ + # ── 1. Numeric status code ──────────────────────────────────────────── + status = _extract_status_code(exc) + if status is not None: + if status in _PERMANENT_STATUS_CODES: + return False + if status in _TRANSIENT_STATUS_CODES: + return True + + # ── 2. Exception type ───────────────────────────────────────────────── + # asyncio.TimeoutError is a subclass of TimeoutError on 3.11+, but treat + # both explicitly for clarity on 3.10. + if isinstance(exc, (ConnectionError, ConnectionResetError, ConnectionRefusedError)): + return True + if isinstance(exc, (TimeoutError, asyncio.TimeoutError)): + return True + + # ── 3. String patterns ──────────────────────────────────────────────── + message = str(exc) + + for pattern in _PERMANENT_STR_PATTERNS: + if pattern in message: + return False + + for pattern in _TRANSIENT_STR_PATTERNS: + if pattern in message: + return True + + # ── 4. 
openai error types (optional dependency) ─────────────────────── + try: + import openai # type: ignore[import-untyped] + + # Permanent openai errors — check before transient + if isinstance(exc, openai.AuthenticationError): + return False + + # Transient openai errors + if isinstance( + exc, (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError) + ): + return True + except ImportError: + pass + + # ── 5. Default: do not retry unknown errors ─────────────────────────── + return False + + +# --------------------------------------------------------------------------- +# transient_retry (sync) +# --------------------------------------------------------------------------- + + +def transient_retry( + func: Callable[[], T], + max_retries: int = 3, + base_delay: float = 0.5, + max_delay: float = 8.0, + jitter: bool = True, + is_retryable: Optional[Callable[[Exception], bool]] = None, +) -> T: + """Call *func* and retry on transient failures with exponential backoff. + + The delay between attempts follows the formula:: + + delay = min(base_delay * 2^attempt, max_delay) + + When ``jitter=True`` the delay is multiplied by a random factor in + ``[0.5, 1.5)`` to spread concurrent retries. + + Args: + func: Zero-argument callable to invoke. + max_retries: Maximum number of *additional* attempts after the first + failure. ``0`` disables retrying entirely. + base_delay: Initial delay in seconds before the first retry. + max_delay: Upper bound on the computed delay (seconds). + jitter: Whether to apply random jitter to the delay. + is_retryable: Optional predicate that decides whether an exception + should be retried. Defaults to ``is_transient_error``. + + Returns: + The return value of *func* on success. + + Raises: + Exception: The last exception raised by *func* after all retries are + exhausted, or immediately if the error is not retryable. 
+ """ + _check = is_retryable if is_retryable is not None else is_transient_error + + last_exc: Exception + for attempt in range(max_retries + 1): + try: + return func() + except Exception as exc: + last_exc = exc + + if not _check(exc): + # Permanent — propagate immediately + raise + + if attempt >= max_retries: + # Retries exhausted + logger.warning( + "transient_retry: all %d retries exhausted; last error: %s", + max_retries, + exc, + ) + raise + + delay = min(base_delay * (2**attempt), max_delay) + if jitter: + delay *= 0.5 + random.random() # [0.5, 1.5) + + logger.info( + "transient_retry: attempt %d/%d failed (%s); retrying in %.2fs", + attempt + 1, + max_retries, + exc, + delay, + ) + time.sleep(delay) + + # Unreachable, but satisfies the type checker + raise last_exc # type: ignore[possibly-undefined] + + +# --------------------------------------------------------------------------- +# transient_retry_async +# --------------------------------------------------------------------------- + + +async def transient_retry_async( + coro_func: Callable[[], "asyncio.Coroutine[object, object, T]"], + max_retries: int = 3, + base_delay: float = 0.5, + max_delay: float = 8.0, + jitter: bool = True, + is_retryable: Optional[Callable[[Exception], bool]] = None, +) -> T: + """Async version of :func:`transient_retry`. + + Identical semantics to the sync variant but uses ``asyncio.sleep`` + so it does not block the event loop during backoff. + + Args: + coro_func: Zero-argument async callable (coroutine factory) to invoke. + max_retries: Maximum number of *additional* attempts after the first + failure. ``0`` disables retrying entirely. + base_delay: Initial delay in seconds before the first retry. + max_delay: Upper bound on the computed delay (seconds). + jitter: Whether to apply random jitter to the delay. + is_retryable: Optional predicate that decides whether an exception + should be retried. Defaults to ``is_transient_error``. 
+ + Returns: + The return value of *coro_func()* on success. + + Raises: + Exception: The last exception raised by *coro_func* after all retries + are exhausted, or immediately if the error is not retryable. + """ + _check = is_retryable if is_retryable is not None else is_transient_error + + last_exc: Exception + for attempt in range(max_retries + 1): + try: + return await coro_func() + except Exception as exc: + last_exc = exc + + if not _check(exc): + raise + + if attempt >= max_retries: + logger.warning( + "transient_retry_async: all %d retries exhausted; last error: %s", + max_retries, + exc, + ) + raise + + delay = min(base_delay * (2**attempt), max_delay) + if jitter: + delay *= 0.5 + random.random() + + logger.info( + "transient_retry_async: attempt %d/%d failed (%s); retrying in %.2fs", + attempt + 1, + max_retries, + exc, + delay, + ) + await asyncio.sleep(delay) + + raise last_exc # type: ignore[possibly-undefined] diff --git a/openviking/models/vlm/backends/litellm_vlm.py b/openviking/models/vlm/backends/litellm_vlm.py index ca4a36aa7..36a2bf2d9 100644 --- a/openviking/models/vlm/backends/litellm_vlm.py +++ b/openviking/models/vlm/backends/litellm_vlm.py @@ -8,7 +8,6 @@ os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" -import asyncio import base64 import time from pathlib import Path @@ -17,6 +16,8 @@ import litellm from litellm import acompletion, completion +from openviking.models.retry import transient_retry, transient_retry_async + from ..base import ToolCall, VLMBase, VLMResponse logger = logging.getLogger(__name__) @@ -294,8 +295,11 @@ def get_completion( kwargs = self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) + def _call(): + return completion(**kwargs) + t0 = time.perf_counter() - response = completion(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return 
self._build_vlm_response(response, has_tools=bool(tools)) @@ -304,7 +308,6 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -318,25 +321,17 @@ async def get_completion_async( kwargs = self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) - last_error = None - for attempt in range(max_retries + 1): - try: - t0 = time.perf_counter() - response = await acompletion(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response( - response, - duration_seconds=elapsed, - ) - return self._build_vlm_response(response, has_tools=bool(tools)) - except Exception as e: - last_error = e - if attempt < max_retries: - await asyncio.sleep(2**attempt) - - if last_error: - raise last_error - raise RuntimeError("Unknown error in async completion") + async def _call(): + return await acompletion(**kwargs) + + t0 = time.perf_counter() + response = await transient_retry_async(_call, max_retries=self.max_retries) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response( + response, + duration_seconds=elapsed, + ) + return self._build_vlm_response(response, has_tools=bool(tools)) def get_vision_completion( self, @@ -362,8 +357,11 @@ def get_vision_completion( kwargs = self._build_kwargs(model, kwargs_messages, tools, thinking=thinking) + def _call(): + return completion(**kwargs) + t0 = time.perf_counter() - response = completion(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) @@ -392,8 +390,11 @@ async def get_vision_completion_async( kwargs = self._build_kwargs(model, kwargs_messages, tools, thinking=thinking) + async def 
_call(): + return await acompletion(**kwargs) + t0 = time.perf_counter() - response = await acompletion(**kwargs) + response = await transient_retry_async(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) diff --git a/openviking/models/vlm/backends/openai_vlm.py b/openviking/models/vlm/backends/openai_vlm.py index 05c28c768..abcc35ba0 100644 --- a/openviking/models/vlm/backends/openai_vlm.py +++ b/openviking/models/vlm/backends/openai_vlm.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: AGPL-3.0 """OpenAI VLM backend implementation""" -import asyncio import base64 import json import logging @@ -11,6 +10,8 @@ from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse +from openviking.models.retry import transient_retry, transient_retry_async + from ..base import ToolCall, VLMBase, VLMResponse from ..registry import DEFAULT_AZURE_API_VERSION @@ -69,6 +70,7 @@ def get_client(self): self.api_version, self.extra_headers, ) + kwargs["max_retries"] = 0 # Disable SDK retry; we use transient_retry if self.provider == "azure": self._sync_client = openai.AzureOpenAI(**kwargs) else: @@ -89,6 +91,7 @@ def get_async_client(self): self.api_version, self.extra_headers, ) + kwargs["max_retries"] = 0 # Disable SDK retry; we use transient_retry_async if self.provider == "azure": self._async_client = openai.AsyncAzureOpenAI(**kwargs) else: @@ -289,8 +292,11 @@ def get_completion( kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" + def _call(): + return client.chat.completions.create(**kwargs) + t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 if tools: @@ -309,7 +315,6 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - 
max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -335,36 +340,27 @@ async def get_completion_async( kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" - last_error = None - for attempt in range(max_retries + 1): - try: - t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - - if tools: - self._update_token_usage_from_response(response) - return self._build_vlm_response(response, has_tools=bool(tools)) - - if self.stream: - content = await self._process_streaming_response_async(response) - else: - self._update_token_usage_from_response( - response, - duration_seconds=elapsed, - ) - content = self._extract_content_from_response(response) - - return self._clean_response(content) - except Exception as e: - last_error = e - if attempt < max_retries: - await asyncio.sleep(2**attempt) - - if last_error: - raise last_error + async def _call(): + return await client.chat.completions.create(**kwargs) + + t0 = time.perf_counter() + response = await transient_retry_async(_call, max_retries=self.max_retries) + elapsed = time.perf_counter() - t0 + + if tools: + self._update_token_usage_from_response(response) + return self._build_vlm_response(response, has_tools=bool(tools)) + + if self.stream: + content = await self._process_streaming_response_async(response) else: - raise RuntimeError("Unknown error in async completion") + self._update_token_usage_from_response( + response, + duration_seconds=elapsed, + ) + content = self._extract_content_from_response(response) + + return self._clean_response(content) def _detect_image_format(self, data: bytes) -> str: """Detect image format from magic bytes. 
@@ -454,8 +450,11 @@ def get_vision_completion( kwargs["tools"] = tools kwargs["tool_choice"] = "auto" + def _call(): + return client.chat.completions.create(**kwargs) + t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 if tools: @@ -506,8 +505,11 @@ async def get_vision_completion_async( kwargs["tools"] = tools kwargs["tool_choice"] = "auto" + async def _call(): + return await client.chat.completions.create(**kwargs) + t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) + response = await transient_retry_async(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 if tools: diff --git a/openviking/models/vlm/backends/volcengine_vlm.py b/openviking/models/vlm/backends/volcengine_vlm.py index 06616def1..50912c600 100644 --- a/openviking/models/vlm/backends/volcengine_vlm.py +++ b/openviking/models/vlm/backends/volcengine_vlm.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union -# Import run_async for sync-to-async calls +from openviking.models.retry import transient_retry_async from openviking_cli.utils import run_async from ..base import ToolCall, VLMResponse @@ -266,6 +266,20 @@ def _update_token_usage_from_response( ) return + def _parse_tool_calls(self, message) -> List[ToolCall]: + """Parse tool calls from VolcEngine response message.""" + tool_calls = [] + if hasattr(message, "tool_calls") and message.tool_calls: + for tc in message.tool_calls: + args = tc.function.arguments + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + args = {"raw": args} + tool_calls.append(ToolCall(id=tc.id, name=tc.function.name, arguments=args)) + return tool_calls + def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMResponse]: """Build response from VolcEngine Responses API response. 
@@ -279,10 +293,10 @@ def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMRespon # logger.info(f"[VolcEngineVLM] Full response: {response}") if hasattr(response, "output"): # logger.debug(f"[VolcEngineVLM] Output items: {len(response.output)}") - for i, item in enumerate(response.output): - # logger.debug(f"[VolcEngineVLM] Item {i}: type={getattr(item, 'type', 'unknown')}") + for _i, _item in enumerate(response.output): + # logger.debug(f"[VolcEngineVLM] Item {_i}: type={getattr(_item, 'type', 'unknown')}") # Print full item for debugging - # logger.info(f"[VolcEngineVLM] Item {i} full: {item}") + # logger.info(f"[VolcEngineVLM] Item {_i} full: {_item}") pass # Extract content from Responses API format @@ -436,7 +450,7 @@ def _convert_messages_to_input(self, messages: List[Dict[str, Any]]) -> List[Dic url = image_url.get("url", "") if url: image_urls.append(url) - has_images = True + has_images = True # noqa: F841 # Handle other block types else: # Try to extract text from any dict block @@ -539,7 +553,6 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -561,32 +574,19 @@ async def get_completion_async( # If we have static segments, try prefix cache response_format = None # Can be extended for structured output - try: - # Use prefix cache with multiple segments - response = await self.responseapi_prefixcache_completion( + async def _call(): + return await self.responseapi_prefixcache_completion( static_segments=static_segments, dynamic_messages=dynamic_messages, response_format=response_format, tools=tools, tool_choice=tool_choice, ) - elapsed = 0 # Timing handled in responseapi methods - self._update_token_usage_from_response(response, duration_seconds=elapsed) - return self._build_vlm_response(response, has_tools=bool(tools)) - except Exception as e: - 
last_error = e - # Log token info from error response if available - error_response = getattr(e, "response", None) - if error_response and hasattr(error_response, "usage"): - u = error_response.usage - prompt_tokens = getattr(u, "input_tokens", 0) or 0 - completion_tokens = getattr(u, "output_tokens", 0) or 0 - logger.info( - f"[VolcEngineVLM] Error response - Input tokens: {prompt_tokens}, Output tokens: {completion_tokens}" - ) - logger.warning(f"[VolcEngineVLM] Request failed: {e}") - raise last_error + response = await transient_retry_async(_call, max_retries=self.max_retries) + elapsed = 0 # Timing handled in responseapi methods + self._update_token_usage_from_response(response, duration_seconds=elapsed) + return self._build_vlm_response(response, has_tools=bool(tools)) def _detect_image_format(self, data: bytes) -> str: """Detect image format from magic bytes. diff --git a/openviking/models/vlm/base.py b/openviking/models/vlm/base.py index c54285ad4..270444283 100644 --- a/openviking/models/vlm/base.py +++ b/openviking/models/vlm/base.py @@ -58,7 +58,7 @@ def __init__(self, config: Dict[str, Any]): self.api_key = config.get("api_key") self.api_base = config.get("api_base") self.temperature = config.get("temperature", 0.0) - self.max_retries = config.get("max_retries", 2) + self.max_retries = config.get("max_retries", 3) self.max_tokens = config.get("max_tokens") self.extra_headers = config.get("extra_headers") self.stream = config.get("stream", False) @@ -94,7 +94,6 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -104,7 +103,6 @@ async def get_completion_async( Args: prompt: Text prompt (used if messages not provided) thinking: Whether to enable thinking mode - max_retries: Maximum number of retries tools: Optional list of tool definitions in OpenAI function format 
tool_choice: Optional tool choice mode ("auto", "none", or specific tool name) messages: Optional list of message dicts (takes precedence over prompt) diff --git a/openviking/models/vlm/llm.py b/openviking/models/vlm/llm.py index e1cde6ccf..d28266a3a 100644 --- a/openviking/models/vlm/llm.py +++ b/openviking/models/vlm/llm.py @@ -183,7 +183,12 @@ def complete_json( if schema and not messages: prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" - response = self._get_vlm().get_completion(prompt, thinking, tools, messages) + response = self._get_vlm().get_completion( + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, + ) return parse_json_from_response(response) async def complete_json_async( @@ -191,7 +196,6 @@ async def complete_json_async( prompt: str = "", schema: Optional[Dict[str, Any]] = None, thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Optional[Dict[str, Any]]: @@ -200,7 +204,10 @@ async def complete_json_async( prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" response = await self._get_vlm().get_completion_async( - prompt, thinking, max_retries, tools, messages + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, ) return parse_json_from_response(response) @@ -227,12 +234,13 @@ async def complete_model_async( prompt: str, model_class: Type[T], thinking: bool = False, - max_retries: int = 0, ) -> Optional[T]: """Async version of complete_model.""" schema = model_class.model_json_schema() response = await self.complete_json_async( - prompt, schema=schema, thinking=thinking, max_retries=max_retries + prompt=prompt, + schema=schema, + thinking=thinking, ) if response is None: return None @@ -252,7 +260,13 @@ def get_vision_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get vision completion.""" - return self._get_vlm().get_vision_completion(prompt, images, thinking, 
tools, messages) + return self._get_vlm().get_vision_completion( + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, + ) async def get_vision_completion_async( self, @@ -264,5 +278,9 @@ async def get_vision_completion_async( ) -> Union[str, Any]: """Async vision completion.""" return await self._get_vlm().get_vision_completion_async( - prompt, images, thinking, tools, messages + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 8caad66a9..1377d5f4a 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -262,6 +262,7 @@ class EmbeddingConfig(BaseModel): sparse: Optional[EmbeddingModelConfig] = Field(default=None) hybrid: Optional[EmbeddingModelConfig] = Field(default=None) + max_retries: int = Field(default=3, description="Maximum retry attempts for transient errors") max_concurrent: int = Field( default=10, description="Maximum number of concurrent embedding requests" ) @@ -509,6 +510,13 @@ def _create_embedder( embedder_class, param_builder = factory_registry[key] params = param_builder(config) + + # Inject max_retries into the config dict so embedders pick it up + existing_config = params.get("config") or {} + if isinstance(existing_config, dict): + existing_config["max_retries"] = self.max_retries + params["config"] = existing_config + return embedder_class(**params) def get_embedder(self): diff --git a/openviking_cli/utils/config/vlm_config.py b/openviking_cli/utils/config/vlm_config.py index dd67e7c2c..5bff0a8e4 100644 --- a/openviking_cli/utils/config/vlm_config.py +++ b/openviking_cli/utils/config/vlm_config.py @@ -12,7 +12,7 @@ class VLMConfig(BaseModel): api_key: Optional[str] = Field(default=None, description="API key") api_base: Optional[str] = Field(default=None, description="API base URL") 
temperature: float = Field(default=0.0, description="Generation temperature") - max_retries: int = Field(default=2, description="Maximum retry attempts") + max_retries: int = Field(default=3, description="Maximum retry attempts") provider: Optional[str] = Field(default=None, description="Provider type") backend: Optional[str] = Field( @@ -181,19 +181,26 @@ def get_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get LLM completion.""" - return self.get_vlm_instance().get_completion(prompt, thinking, tools, messages) + return self.get_vlm_instance().get_completion( + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, + ) async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: - """Get LLM completion asynchronously, max_retries=0 means no retry.""" + """Get LLM completion asynchronously.""" return await self.get_vlm_instance().get_completion_async( - prompt, thinking, max_retries, tools, messages + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, ) def is_available(self) -> bool: @@ -210,7 +217,11 @@ def get_vision_completion( ) -> Union[str, Any]: """Get LLM completion with images.""" return self.get_vlm_instance().get_vision_completion( - prompt, images, thinking, tools, messages + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) async def get_vision_completion_async( @@ -223,5 +234,9 @@ async def get_vision_completion_async( ) -> Union[str, Any]: """Get LLM completion with images asynchronously.""" return await self.get_vlm_instance().get_vision_completion_async( - prompt, images, thinking, tools, messages + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) diff --git a/tests/models/test_vlm_strip_think_tags.py 
b/tests/models/test_vlm_strip_think_tags.py index c7a996cfb..95e45b682 100644 --- a/tests/models/test_vlm_strip_think_tags.py +++ b/tests/models/test_vlm_strip_think_tags.py @@ -18,7 +18,7 @@ class _Stub(VLMBase): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/unit/test_backward_compat.py b/tests/unit/test_backward_compat.py new file mode 100644 index 000000000..ba4d8a762 --- /dev/null +++ b/tests/unit/test_backward_compat.py @@ -0,0 +1,167 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Backward compatibility tests for the retry migration. + +Verifies that: +- exponential_backoff_retry is still importable from the old location (base.py) +- exponential_backoff_retry signature is unchanged +- exponential_backoff_retry behaviour still works (time-based) +- transient_retry is count-based (different semantics) +""" + +from __future__ import annotations + +import inspect +from unittest.mock import patch + +import pytest + + +class _HttpError(Exception): + """Fake HTTP error carrying a numeric status code.""" + + def __init__(self, status_code: int, message: str = ""): + super().__init__(message or f"HTTP {status_code}") + self.status_code = status_code + + +class TestExponentialBackoffRetryImportable: + def test_importable_from_old_location(self): + """exponential_backoff_retry should still be importable from base.py.""" + from openviking.models.embedder.base import exponential_backoff_retry + + assert callable(exponential_backoff_retry) + + +class TestExponentialBackoffRetrySignature: + def test_signature_unchanged(self): + """exponential_backoff_retry should retain its original signature.""" + from openviking.models.embedder.base import 
exponential_backoff_retry + + sig = inspect.signature(exponential_backoff_retry) + param_names = list(sig.parameters.keys()) + + expected_params = [ + "func", + "max_wait", + "base_delay", + "max_delay", + "jitter", + "is_retryable", + "logger", + ] + + assert param_names == expected_params, ( + f"exponential_backoff_retry signature changed.\n" + f"Expected: {expected_params}\n" + f"Got: {param_names}" + ) + + def test_defaults_unchanged(self): + """Default parameter values should be preserved.""" + from openviking.models.embedder.base import exponential_backoff_retry + + sig = inspect.signature(exponential_backoff_retry) + params = sig.parameters + + assert params["max_wait"].default == 10.0 + assert params["base_delay"].default == 0.5 + assert params["max_delay"].default == 2.0 + assert params["jitter"].default is True + assert params["is_retryable"].default is None + assert params["logger"].default is None + + +class TestExponentialBackoffRetryBehavior: + def test_success_first_try(self): + """Function succeeds on first attempt.""" + from openviking.models.embedder.base import exponential_backoff_retry + + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + return "ok" + + result = exponential_backoff_retry(fn) + assert result == "ok" + assert call_count == 1 + + def test_retries_on_failure(self): + """Function retries on failure until success.""" + from openviking.models.embedder.base import exponential_backoff_retry + + errors = [Exception("fail"), Exception("fail")] + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return "ok" + + with patch("time.sleep"): + result = exponential_backoff_retry(fn, max_wait=10.0) + + assert result == "ok" + assert call_count == 3 + + def test_is_time_based(self): + """exponential_backoff_retry should be time-based (uses max_wait, not count).""" + from openviking.models.embedder.base import exponential_backoff_retry + + sig = 
inspect.signature(exponential_backoff_retry) + param_names = list(sig.parameters.keys()) + + # Time-based: has max_wait, no max_retries + assert "max_wait" in param_names + assert "max_retries" not in param_names + + def test_respects_is_retryable(self): + """exponential_backoff_retry should respect is_retryable callback.""" + from openviking.models.embedder.base import exponential_backoff_retry + + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise ValueError("permanent") + + # is_retryable returns False => no retry + with patch("time.sleep"): + with pytest.raises(ValueError): + exponential_backoff_retry(fn, is_retryable=lambda e: False) + + assert call_count == 1 + + +class TestTransientRetryIsCountBased: + def test_is_count_based(self): + """transient_retry should be count-based (uses max_retries, not max_wait).""" + from openviking.models.retry import transient_retry + + sig = inspect.signature(transient_retry) + param_names = list(sig.parameters.keys()) + + # Count-based: has max_retries, no max_wait + assert "max_retries" in param_names + assert "max_wait" not in param_names + + def test_different_from_backoff_retry(self): + """transient_retry and exponential_backoff_retry should have different signatures.""" + from openviking.models.embedder.base import exponential_backoff_retry + from openviking.models.retry import transient_retry + + backoff_params = set(inspect.signature(exponential_backoff_retry).parameters.keys()) + retry_params = set(inspect.signature(transient_retry).parameters.keys()) + + # They share 'func', 'base_delay', 'max_delay', 'jitter', 'is_retryable' + # but differ on time vs count control params + assert "max_wait" in backoff_params + assert "max_wait" not in retry_params + assert "max_retries" in retry_params + assert "max_retries" not in backoff_params diff --git a/tests/unit/test_embedding_retry_integration.py b/tests/unit/test_embedding_retry_integration.py new file mode 100644 index 000000000..02011b4cb --- 
/dev/null +++ b/tests/unit/test_embedding_retry_integration.py @@ -0,0 +1,230 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for embedding providers with unified retry logic. + +Tests cover (using OpenAI and VikingDB as representatives): +- embed retries on transient error (mock API client) +- embed does NOT retry on permanent error +- uses config max_retries +- VikingDB now has retry +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _HttpError(Exception): + """Fake HTTP error carrying a numeric status code.""" + + def __init__(self, status_code: int, message: str = ""): + super().__init__(message or f"HTTP {status_code}") + self.status_code = status_code + + +def _make_fake_embedding_response(vector=None): + """Build a minimal fake OpenAI embeddings response.""" + if vector is None: + vector = [0.1] * 10 + item = SimpleNamespace(embedding=vector) + usage = SimpleNamespace(prompt_tokens=5, total_tokens=5) + return SimpleNamespace(data=[item], usage=usage) + + +# --------------------------------------------------------------------------- +# OpenAI Embedder Tests +# --------------------------------------------------------------------------- + + +class TestOpenAIEmbedderRetry: + @pytest.fixture() + def openai_embedder(self): + """Create an OpenAIDenseEmbedder with mocked client.""" + from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder + + embedder = OpenAIDenseEmbedder( + model_name="text-embedding-3-small", + api_key="sk-test", + dimension=10, + config={"max_retries": 2}, + ) + embedder.client = MagicMock() + return embedder + + def test_embed_retries_on_transient_error(self, 
openai_embedder): + """embed() should retry on 429 (transient) and succeed.""" + errors = [_HttpError(429)] + call_count = 0 + + def fake_create(**kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return _make_fake_embedding_response() + + openai_embedder.client.embeddings.create = fake_create + + with patch("time.sleep"): + result = openai_embedder.embed("test text") + + assert result.dense_vector == [0.1] * 10 + assert call_count == 2 # 1 failure + 1 success + + def test_embed_no_retry_on_permanent_error(self, openai_embedder): + """embed() should NOT retry on 401 (permanent).""" + call_count = 0 + + def fake_create(**kwargs): + nonlocal call_count + call_count += 1 + raise _HttpError(401, "Unauthorized") + + openai_embedder.client.embeddings.create = fake_create + + with patch("time.sleep"): + # 401 is permanent, transient_retry won't retry it. + # It will propagate and be caught by the except block, re-raised as RuntimeError. + with pytest.raises((RuntimeError, _HttpError)): + openai_embedder.embed("test text") + + assert call_count == 1 # no retries + + def test_uses_config_max_retries(self): + """Embedder should use self.max_retries from config.""" + from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder + + embedder = OpenAIDenseEmbedder( + model_name="text-embedding-3-small", + api_key="sk-test", + dimension=10, + config={"max_retries": 5}, + ) + assert embedder.max_retries == 5 + + # Default + embedder2 = OpenAIDenseEmbedder( + model_name="text-embedding-3-small", + api_key="sk-test", + dimension=10, + ) + assert embedder2.max_retries == 3 + + def test_openai_sdk_retry_disabled(self): + """OpenAI client should be created with max_retries=0.""" + from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder + + with patch("openai.OpenAI") as mock_openai: + mock_openai.return_value = MagicMock() + OpenAIDenseEmbedder( + model_name="text-embedding-3-small", + api_key="sk-test", + 
dimension=10, + ) + call_kwargs = mock_openai.call_args + assert call_kwargs.kwargs.get("max_retries") == 0 + + +# --------------------------------------------------------------------------- +# VikingDB Embedder Tests +# --------------------------------------------------------------------------- + + +class TestVikingDBEmbedderRetry: + @pytest.fixture() + def vikingdb_embedder(self): + """Create a VikingDBDenseEmbedder with mocked client.""" + from openviking.models.embedder.vikingdb_embedders import VikingDBDenseEmbedder + + with patch("openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi"): + embedder = VikingDBDenseEmbedder( + model_name="test-model", + model_version="1.0", + ak="test-ak", + sk="test-sk", + region="cn-beijing", + dimension=10, + config={"max_retries": 2}, + ) + return embedder + + def test_embed_retries_on_transient_error(self, vikingdb_embedder): + """embed() should retry on transient error and succeed.""" + errors = [_HttpError(503)] + call_count = 0 + + def fake_call_api(*args, **kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return [{"dense_embedding": [0.1] * 10}] + + vikingdb_embedder._call_api = fake_call_api + + with patch("time.sleep"): + result = vikingdb_embedder.embed("test text") + + assert result.dense_vector == [0.1] * 10 + assert call_count == 2 # 1 failure + 1 success + + def test_embed_no_retry_on_permanent_error(self, vikingdb_embedder): + """embed() should NOT retry on 401 (permanent).""" + call_count = 0 + + def fake_call_api(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise _HttpError(401, "Unauthorized") + + vikingdb_embedder._call_api = fake_call_api + + with patch("time.sleep"): + with pytest.raises(_HttpError): + vikingdb_embedder.embed("test text") + + assert call_count == 1 # no retries + + def test_uses_config_max_retries(self): + """VikingDB embedder should use self.max_retries from config.""" + from 
openviking.models.embedder.vikingdb_embedders import VikingDBDenseEmbedder + + with patch("openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi"): + embedder = VikingDBDenseEmbedder( + model_name="test-model", + model_version="1.0", + ak="test-ak", + sk="test-sk", + region="cn-beijing", + dimension=10, + config={"max_retries": 7}, + ) + assert embedder.max_retries == 7 + + def test_vikingdb_now_has_retry(self, vikingdb_embedder): + """VikingDB embed() should retry on 429 (was zero retry before unified retry).""" + errors = [_HttpError(429), _HttpError(429)] + call_count = 0 + + def fake_call_api(*args, **kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return [{"dense_embedding": [0.2] * 10}] + + vikingdb_embedder._call_api = fake_call_api + + with patch("time.sleep"): + result = vikingdb_embedder.embed("test text") + + assert result.dense_vector == [0.2] * 10 + assert call_count == 3 # 2 failures + 1 success diff --git a/tests/unit/test_extra_headers_vlm.py b/tests/unit/test_extra_headers_vlm.py index 97f6bea00..c087c2370 100644 --- a/tests/unit/test_extra_headers_vlm.py +++ b/tests/unit/test_extra_headers_vlm.py @@ -210,7 +210,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): @@ -236,7 +236,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py new file mode 100644 index 000000000..176412d48 --- /dev/null +++ b/tests/unit/test_retry.py @@ 
-0,0 +1,441 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Comprehensive tests for the core retry module (openviking.models.retry). + +Tests cover: +- is_transient_error: ~28 parametrized cases (14 transient, 14 permanent) +- transient_retry (sync): 8 behavioral tests +- transient_retry_async (async): 8 mirrored behavioral tests +""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +from openviking.models.retry import is_transient_error, transient_retry, transient_retry_async + +# --------------------------------------------------------------------------- +# Helper fake HTTP error with status_code attribute +# --------------------------------------------------------------------------- + + +class _HttpError(Exception): + """Fake HTTP error carrying a numeric status code for testing.""" + + def __init__(self, status_code: int, message: str = ""): + super().__init__(message or f"HTTP {status_code}") + self.status_code = status_code + + +# --------------------------------------------------------------------------- +# is_transient_error — parametrized cases +# --------------------------------------------------------------------------- + +_TRANSIENT_CASES = [ + # HTTP status codes via _HttpError.status_code + pytest.param(_HttpError(429), True, id="http_429"), + pytest.param(_HttpError(500), True, id="http_500"), + pytest.param(_HttpError(502), True, id="http_502"), + pytest.param(_HttpError(503), True, id="http_503"), + pytest.param(_HttpError(504), True, id="http_504"), + # Built-in connection exceptions + pytest.param(ConnectionError("connection failed"), True, id="ConnectionError"), + pytest.param(ConnectionResetError("reset"), True, id="ConnectionResetError"), + pytest.param(ConnectionRefusedError("refused"), True, id="ConnectionRefusedError"), + pytest.param(TimeoutError("timed out"), True, id="TimeoutError"), + 
pytest.param(asyncio.TimeoutError(), True, id="asyncio_TimeoutError"), + # String-pattern transient errors + pytest.param(Exception("TooManyRequests from server"), True, id="str_TooManyRequests"), + pytest.param(Exception("RateLimit exceeded"), True, id="str_RateLimit"), + pytest.param(Exception("RequestBurstTooFast"), True, id="str_RequestBurstTooFast"), + pytest.param(Exception("request timed out after 30s"), True, id="str_timed_out"), +] + +_PERMANENT_CASES = [ + # HTTP status codes via _HttpError.status_code + pytest.param(_HttpError(400), False, id="http_400"), + pytest.param(_HttpError(401), False, id="http_401"), + pytest.param(_HttpError(403), False, id="http_403"), + pytest.param(_HttpError(404), False, id="http_404"), + pytest.param(_HttpError(422), False, id="http_422"), + # Built-in value/type errors + pytest.param(ValueError("bad value"), False, id="ValueError"), + pytest.param(TypeError("wrong type"), False, id="TypeError"), + # String-pattern permanent errors + pytest.param( + Exception("InvalidRequestError: field missing"), False, id="str_InvalidRequestError" + ), + pytest.param( + Exception("AuthenticationError: invalid key"), False, id="str_AuthenticationError" + ), + # Unknown errors — conservative default False + pytest.param(Exception("some unknown error"), False, id="unknown_generic"), + pytest.param(RuntimeError("unexpected state"), False, id="RuntimeError_unknown"), + pytest.param(KeyError("missing key"), False, id="KeyError"), + pytest.param(AttributeError("no attr"), False, id="AttributeError"), + pytest.param( + Exception("config_value_out_of_range"), False, id="str_unknown_no_transient_keyword" + ), +] + + +@pytest.mark.parametrize("exc,expected", _TRANSIENT_CASES) +def test_is_transient_error_transient(exc, expected): + """Transient errors should be classified as retryable (True).""" + assert is_transient_error(exc) is expected + + +@pytest.mark.parametrize("exc,expected", _PERMANENT_CASES) +def test_is_transient_error_permanent(exc, 
expected): + """Permanent / unknown errors should not be retried (False).""" + assert is_transient_error(exc) is expected + + +# --------------------------------------------------------------------------- +# transient_retry (sync) +# --------------------------------------------------------------------------- + + +class TestTransientRetrySync: + """Sync retry behaviour tests.""" + + def test_success_first_try(self): + """Function succeeds on first attempt — call_count == 1.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + return "ok" + + result = transient_retry(fn, max_retries=3) + assert result == "ok" + assert call_count == 1 + + def test_retry_then_success(self): + """Two transient failures then success — call_count == 3.""" + errors = [_HttpError(429), _HttpError(503)] + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return "ok" + + with patch("time.sleep"): + result = transient_retry(fn, max_retries=3) + + assert result == "ok" + assert call_count == 3 + + def test_permanent_error_no_retry(self): + """Permanent error (401) should not be retried — call_count == 1 and raises.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(401) + + with patch("time.sleep"): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=3) + + assert call_count == 1 + + def test_max_retries_exhausted(self): + """4 consecutive 429 errors with max_retries=3 → raises after 4 calls.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("time.sleep"): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=3) + + assert call_count == 4 # 1 initial + 3 retries + + def test_max_retries_zero_raises_immediately(self): + """max_retries=0 disables retrying — call_count == 1.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with 
patch("time.sleep"): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=0) + + assert call_count == 1 + + def test_max_retries_one(self): + """max_retries=1: one failure then success → call_count == 2.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _HttpError(429) + return "done" + + with patch("time.sleep"): + result = transient_retry(fn, max_retries=1) + + assert result == "done" + assert call_count == 2 + + def test_backoff_delays_exponential(self): + """Verify exponential backoff: base_delay=1.0, jitter=False → 1.0, 2.0, 4.0.""" + call_count = 0 + sleep_calls = [] + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("time.sleep", side_effect=lambda d: sleep_calls.append(d)): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=3, base_delay=1.0, max_delay=100.0, jitter=False) + + assert len(sleep_calls) == 3 + assert sleep_calls[0] == pytest.approx(1.0) + assert sleep_calls[1] == pytest.approx(2.0) + assert sleep_calls[2] == pytest.approx(4.0) + + def test_delay_capped_at_max_delay(self): + """Delays must not exceed max_delay even with many retries.""" + sleep_calls = [] + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(503) + + with patch("time.sleep", side_effect=lambda d: sleep_calls.append(d)): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=10, base_delay=1.0, max_delay=8.0, jitter=False) + + assert all(d <= 8.0 for d in sleep_calls), f"Some delays exceed max_delay: {sleep_calls}" + + +# --------------------------------------------------------------------------- +# transient_retry_async (async) +# --------------------------------------------------------------------------- + + +class TestTransientRetryAsync: + """Async retry behaviour tests — mirrors sync suite.""" + + async def test_success_first_try(self): + """Async function succeeds on first attempt — call_count == 
1.""" + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + return "ok" + + result = await transient_retry_async(coro, max_retries=3) + assert result == "ok" + assert call_count == 1 + + async def test_retry_then_success(self): + """Two transient failures then success — call_count == 3.""" + errors = [_HttpError(429), _HttpError(503)] + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return "ok" + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await transient_retry_async(coro, max_retries=3) + + assert result == "ok" + assert call_count == 3 + + async def test_permanent_error_no_retry(self): + """Permanent error (401) should not be retried — call_count == 1 and raises.""" + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(401) + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(_HttpError): + await transient_retry_async(coro, max_retries=3) + + assert call_count == 1 + + async def test_max_retries_exhausted(self): + """4 consecutive 429 errors with max_retries=3 → raises after 4 calls.""" + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(_HttpError): + await transient_retry_async(coro, max_retries=3) + + assert call_count == 4 + + async def test_max_retries_zero_raises_immediately(self): + """max_retries=0 disables retrying — call_count == 1.""" + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(_HttpError): + await transient_retry_async(coro, max_retries=0) + + assert call_count == 1 + + async def test_max_retries_one(self): + """max_retries=1: one failure then success → call_count == 2.""" + call_count = 0 + + async def 
coro(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _HttpError(429) + return "done" + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await transient_retry_async(coro, max_retries=1) + + assert result == "done" + assert call_count == 2 + + async def test_backoff_delays_exponential(self): + """Verify exponential backoff: base_delay=1.0, jitter=False → 1.0, 2.0, 4.0.""" + call_count = 0 + sleep_calls = [] + + async def fake_sleep(d): + sleep_calls.append(d) + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("asyncio.sleep", side_effect=fake_sleep): + with pytest.raises(_HttpError): + await transient_retry_async( + coro, max_retries=3, base_delay=1.0, max_delay=100.0, jitter=False + ) + + assert len(sleep_calls) == 3 + assert sleep_calls[0] == pytest.approx(1.0) + assert sleep_calls[1] == pytest.approx(2.0) + assert sleep_calls[2] == pytest.approx(4.0) + + async def test_delay_capped_at_max_delay(self): + """Async delays must not exceed max_delay even with many retries.""" + sleep_calls = [] + call_count = 0 + + async def fake_sleep(d): + sleep_calls.append(d) + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(503) + + with patch("asyncio.sleep", side_effect=fake_sleep): + with pytest.raises(_HttpError): + await transient_retry_async( + coro, max_retries=10, base_delay=1.0, max_delay=8.0, jitter=False + ) + + assert all(d <= 8.0 for d in sleep_calls), f"Some delays exceed max_delay: {sleep_calls}" + + +# --------------------------------------------------------------------------- +# Additional edge-case tests +# --------------------------------------------------------------------------- + + +class TestIsTransientErrorEdgeCases: + """Edge cases for is_transient_error.""" + + def test_timeout_substring_in_message(self): + """'timeout' substring in message → transient.""" + err = Exception("connection timeout after 10s") + assert 
is_transient_error(err) is True + + def test_status_code_attribute_takes_priority(self): + """status_code=503 → transient, even if message says 'bad request'.""" + err = _HttpError(503, "bad request") + assert is_transient_error(err) is True + + def test_status_code_401_permanent_priority(self): + """status_code=401 → permanent, even if message contains 'timeout'.""" + err = _HttpError(401, "timeout auth failure") + assert is_transient_error(err) is False + + def test_custom_is_retryable_overrides(self): + """Custom is_retryable callback overrides default classification.""" + # 429 is normally transient but we pass a custom fn that returns False + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("time.sleep"): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=3, is_retryable=lambda e: False) + + assert call_count == 1 # no retries because custom fn says not retryable + + def test_http_status_attribute_variant(self): + """Objects with .http_status should be checked for transient status.""" + + class AltHttpError(Exception): + def __init__(self, http_status: int): + super().__init__(f"HTTP {http_status}") + self.http_status = http_status + + assert is_transient_error(AltHttpError(503)) is True + assert is_transient_error(AltHttpError(401)) is False + + def test_code_attribute_variant(self): + """Objects with .code should be checked for transient status.""" + + class CodeError(Exception): + def __init__(self, code: int): + super().__init__(f"Error code {code}") + self.code = code + + assert is_transient_error(CodeError(429)) is True + assert is_transient_error(CodeError(403)) is False diff --git a/tests/unit/test_retry_config.py b/tests/unit/test_retry_config.py new file mode 100644 index 000000000..24f28fb5e --- /dev/null +++ b/tests/unit/test_retry_config.py @@ -0,0 +1,81 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
# SPDX-License-Identifier: Apache-2.0

"""Tests for retry configuration fields on VLMConfig and EmbeddingConfig.

Verifies that:
- VLMConfig default max_retries = 3
- EmbeddingConfig has max_retries field, default = 3
- EmbeddingConfig accepts custom max_retries
"""

from __future__ import annotations


class TestVLMConfigMaxRetries:
    """max_retries behaviour on VLMConfig (default and explicit override)."""

    def test_default_max_retries(self):
        """VLMConfig should default max_retries to 3."""
        from openviking_cli.utils.config.vlm_config import VLMConfig

        cfg = VLMConfig(
            model="gpt-4o-mini",
            api_key="sk-test",
            provider="openai",
        )
        assert cfg.max_retries == 3

    def test_custom_max_retries(self):
        """VLMConfig should accept custom max_retries."""
        from openviking_cli.utils.config.vlm_config import VLMConfig

        cfg = VLMConfig(
            model="gpt-4o-mini",
            api_key="sk-test",
            provider="openai",
            max_retries=10,
        )
        assert cfg.max_retries == 10


class TestEmbeddingConfigMaxRetries:
    """max_retries behaviour on EmbeddingConfig (field presence, default, override)."""

    def test_has_max_retries_field(self):
        """EmbeddingConfig should have a max_retries field."""
        from openviking_cli.utils.config.embedding_config import EmbeddingConfig

        # model_fields is the pydantic v2 declared-field registry
        fields = EmbeddingConfig.model_fields
        assert "max_retries" in fields, (
            f"EmbeddingConfig is missing 'max_retries' field. Fields: {list(fields.keys())}"
        )

    def test_default_max_retries(self):
        """EmbeddingConfig should default max_retries to 3."""
        from openviking_cli.utils.config.embedding_config import (
            EmbeddingConfig,
            EmbeddingModelConfig,
        )

        cfg = EmbeddingConfig(
            dense=EmbeddingModelConfig(
                model="text-embedding-3-small",
                api_key="sk-test",
                provider="openai",
            ),
        )
        assert cfg.max_retries == 3

    def test_custom_max_retries(self):
        """EmbeddingConfig should accept custom max_retries."""
        from openviking_cli.utils.config.embedding_config import (
            EmbeddingConfig,
            EmbeddingModelConfig,
        )

        cfg = EmbeddingConfig(
            dense=EmbeddingModelConfig(
                model="text-embedding-3-small",
                api_key="sk-test",
                provider="openai",
            ),
            max_retries=7,
        )
        assert cfg.max_retries == 7
# SPDX-License-Identifier: Apache-2.0

"""Integration tests for VLM backends with unified retry logic.

Tests cover (using OpenAI backend as representative):
- completion retries on 429 (transient)
- completion does NOT retry on 401 (permanent)
- vision completion now retries (was zero before)
- uses config max_retries
- max_retries parameter removed from get_completion_async signature
"""

from __future__ import annotations

import inspect
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


class _HttpError(Exception):
    """Fake HTTP error carrying a numeric status code."""

    def __init__(self, status_code: int, message: str = ""):
        super().__init__(message or f"HTTP {status_code}")
        self.status_code = status_code


def _make_fake_response(content: str = "ok") -> SimpleNamespace:
    """Build a minimal fake OpenAI ChatCompletion response."""
    message = SimpleNamespace(content=content, tool_calls=None)
    choice = SimpleNamespace(message=message, finish_reason="stop")
    usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
    return SimpleNamespace(choices=[choice], usage=usage)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def openai_vlm():
    """Create an OpenAIVLM instance with mocked clients."""
    from openviking.models.vlm.backends.openai_vlm import OpenAIVLM

    vlm = OpenAIVLM(
        {
            "api_key": "sk-test",
            "model": "gpt-4o-mini",
            "provider": "openai",
            "max_retries": 2,
        }
    )

    # Mock sync client
    mock_sync = MagicMock()
    vlm._sync_client = mock_sync

    # Mock async client
    mock_async = MagicMock()
    vlm._async_client = mock_async

    return vlm


# ---------------------------------------------------------------------------
# Tests: get_completion_async retries on 429
# ---------------------------------------------------------------------------


class TestCompletionAsyncRetries:
    """Retry behaviour of get_completion_async against transient/permanent errors."""

    async def test_retries_on_429(self, openai_vlm):
        """get_completion_async should retry on 429 (transient) and succeed."""
        errors = [_HttpError(429), _HttpError(429)]
        call_count = 0

        async def fake_create(**kwargs):
            nonlocal call_count
            call_count += 1
            if errors:
                raise errors.pop(0)
            return _make_fake_response("success")

        openai_vlm._async_client.chat.completions.create = fake_create

        with patch("asyncio.sleep", new_callable=AsyncMock):
            result = await openai_vlm.get_completion_async(prompt="test")

        assert result == "success"
        assert call_count == 3  # 2 failures + 1 success

    async def test_no_retry_on_401(self, openai_vlm):
        """get_completion_async should NOT retry on 401 (permanent)."""
        call_count = 0

        async def fake_create(**kwargs):
            nonlocal call_count
            call_count += 1
            raise _HttpError(401, "Unauthorized")

        openai_vlm._async_client.chat.completions.create = fake_create

        with patch("asyncio.sleep", new_callable=AsyncMock):
            with pytest.raises(_HttpError):
                await openai_vlm.get_completion_async(prompt="test")

        assert call_count == 1  # no retries

    async def test_uses_config_max_retries(self):
        """Backend should use self.max_retries from config, not a param."""
        from openviking.models.vlm.backends.openai_vlm import OpenAIVLM

        vlm = OpenAIVLM(
            {
                "api_key": "sk-test",
                "model": "gpt-4o-mini",
                "provider": "openai",
                "max_retries": 5,
            }
        )
        assert vlm.max_retries == 5

        # Config default is now 3
        vlm2 = OpenAIVLM(
            {
                "api_key": "sk-test",
                "model": "gpt-4o-mini",
                "provider": "openai",
            }
        )
        assert vlm2.max_retries == 3
# ---------------------------------------------------------------------------
# Tests: get_vision_completion_async now retries
# ---------------------------------------------------------------------------


class TestVisionCompletionAsyncRetries:
    """Vision completion (async) now shares the unified retry path."""

    async def test_vision_retries_on_429(self, openai_vlm):
        """get_vision_completion_async should retry on 429 (was zero retry before)."""
        errors = [_HttpError(429)]
        call_count = 0

        async def fake_create(**kwargs):
            nonlocal call_count
            call_count += 1
            if errors:
                raise errors.pop(0)
            return _make_fake_response("vision ok")

        openai_vlm._async_client.chat.completions.create = fake_create

        with patch("asyncio.sleep", new_callable=AsyncMock):
            result = await openai_vlm.get_vision_completion_async(
                prompt="describe",
                images=["http://example.com/img.png"],
            )

        assert result == "vision ok"
        assert call_count == 2  # 1 failure + 1 success


# ---------------------------------------------------------------------------
# Tests: sync completion retries
# ---------------------------------------------------------------------------


class TestCompletionSyncRetries:
    """Sync completion and vision completion retry on transient HTTP errors."""

    def test_sync_retries_on_429(self, openai_vlm):
        """get_completion should retry on 429."""
        errors = [_HttpError(429)]
        call_count = 0

        def fake_create(**kwargs):
            nonlocal call_count
            call_count += 1
            if errors:
                raise errors.pop(0)
            return _make_fake_response("sync ok")

        openai_vlm._sync_client.chat.completions.create = fake_create

        with patch("time.sleep"):
            result = openai_vlm.get_completion(prompt="test")

        assert result == "sync ok"
        assert call_count == 2

    def test_sync_vision_retries_on_503(self, openai_vlm):
        """get_vision_completion should retry on 503."""
        errors = [_HttpError(503)]
        call_count = 0

        def fake_create(**kwargs):
            nonlocal call_count
            call_count += 1
            if errors:
                raise errors.pop(0)
            return _make_fake_response("vision sync ok")

        openai_vlm._sync_client.chat.completions.create = fake_create

        with patch("time.sleep"):
            result = openai_vlm.get_vision_completion(
                prompt="describe",
                images=["http://example.com/img.png"],
            )

        assert result == "vision sync ok"
        assert call_count == 2


# ---------------------------------------------------------------------------
# Tests: signature change verification
# ---------------------------------------------------------------------------


class TestSignatureChange:
    """max_retries must be gone from every backend's get_completion_async."""

    def test_no_max_retries_in_get_completion_async(self):
        """get_completion_async should no longer accept max_retries parameter."""
        from openviking.models.vlm.backends.openai_vlm import OpenAIVLM

        sig = inspect.signature(OpenAIVLM.get_completion_async)
        param_names = list(sig.parameters.keys())

        assert "max_retries" not in param_names, (
            f"max_retries should be removed from get_completion_async, got params: {param_names}"
        )

    def test_no_max_retries_in_base_get_completion_async(self):
        """VLMBase.get_completion_async should no longer accept max_retries parameter."""
        from openviking.models.vlm.base import VLMBase

        sig = inspect.signature(VLMBase.get_completion_async)
        param_names = list(sig.parameters.keys())

        assert "max_retries" not in param_names, (
            f"max_retries should be removed from VLMBase.get_completion_async, got params: {param_names}"
        )

    def test_no_max_retries_in_litellm_get_completion_async(self):
        """LiteLLMVLMProvider.get_completion_async should no longer accept max_retries."""
        from openviking.models.vlm.backends.litellm_vlm import LiteLLMVLMProvider

        sig = inspect.signature(LiteLLMVLMProvider.get_completion_async)
        param_names = list(sig.parameters.keys())

        assert "max_retries" not in param_names

    def test_no_max_retries_in_volcengine_get_completion_async(self):
        """VolcEngineVLM.get_completion_async should no longer accept max_retries."""
        from openviking.models.vlm.backends.volcengine_vlm import VolcEngineVLM

        sig = inspect.signature(VolcEngineVLM.get_completion_async)
        param_names = list(sig.parameters.keys())

        assert "max_retries" not in param_names


# ---------------------------------------------------------------------------
# Tests: OpenAI SDK retry disabled
# ---------------------------------------------------------------------------


class TestOpenAISDKRetryDisabled:
    """SDK-level retry must be off (max_retries=0) so unified retry owns backoff."""

    def test_sync_client_max_retries_zero(self):
        """OpenAI sync client should be created with max_retries=0."""
        from openviking.models.vlm.backends.openai_vlm import OpenAIVLM

        vlm = OpenAIVLM(
            {
                "api_key": "sk-test",
                "model": "gpt-4o-mini",
                "provider": "openai",
            }
        )

        with patch("openai.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            vlm._sync_client = None  # force re-creation
            vlm.get_client()
            call_kwargs = mock_openai.call_args
            # Accept max_retries=0 passed either positionally-empty + kwargs or kwargs-only
            assert call_kwargs[1].get("max_retries") == 0 or (
                len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0
            )

    def test_async_client_max_retries_zero(self):
        """OpenAI async client should be created with max_retries=0."""
        from openviking.models.vlm.backends.openai_vlm import OpenAIVLM

        vlm = OpenAIVLM(
            {
                "api_key": "sk-test",
                "model": "gpt-4o-mini",
                "provider": "openai",
            }
        )

        with patch("openai.AsyncOpenAI") as mock_async_openai:
            mock_async_openai.return_value = MagicMock()
            vlm._async_client = None  # force re-creation
            vlm.get_async_client()
            call_kwargs = mock_async_openai.call_args
            # Accept max_retries=0 passed either positionally-empty + kwargs or kwargs-only
            assert call_kwargs[1].get("max_retries") == 0 or (
                len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0
            )