diff --git a/docs/en/guides/01-configuration.md b/docs/en/guides/01-configuration.md index 136ec1277..4a71fc4d7 100644 --- a/docs/en/guides/01-configuration.md +++ b/docs/en/guides/01-configuration.md @@ -99,6 +99,7 @@ Embedding model configuration for vector search, supporting dense, sparse, and h { "embedding": { "max_concurrent": 10, + "max_retries": 3, "dense": { "provider": "volcengine", "api_key": "your-api-key", @@ -115,6 +116,7 @@ Embedding model configuration for vector search, supporting dense, sparse, and h | Parameter | Type | Description | |-----------|------|-------------| | `max_concurrent` | int | Maximum concurrent embedding requests (`embedding.max_concurrent`, default: `10`) | +| `max_retries` | int | Maximum retry attempts for transient embedding provider errors (`embedding.max_retries`, default: `3`; `0` disables retry) | | `provider` | str | `"volcengine"`, `"openai"`, `"vikingdb"`, `"jina"`, `"voyage"`, or `"gemini"` | | `api_key` | str | API key | | `model` | str | Model name | @@ -122,6 +124,8 @@ Embedding model configuration for vector search, supporting dense, sparse, and h | `input` | str | Input type: `"text"` or `"multimodal"` | | `batch_size` | int | Batch size for embedding requests | +`embedding.max_retries` only applies to transient errors such as `429`, `5xx`, timeouts, and connection failures. Permanent errors such as `400`, `401`, `403`, and `AccountOverdue` are not retried automatically. The backoff strategy is exponential backoff with jitter, starting at `0.5s` and capped at `8s`. + **Available Models** | Model | Dimension | Input Type | Notes | @@ -355,7 +359,8 @@ Vision Language Model for semantic extraction (L0/L1 generation). "vlm": { "api_key": "your-api-key", "model": "doubao-seed-2-0-pro-260215", - "api_base": "https://ark.cn-beijing.volces.com/api/v3" + "api_base": "https://ark.cn-beijing.volces.com/api/v3", + "max_retries": 3 } } ``` @@ -369,9 +374,12 @@ Vision Language Model for semantic extraction (L0/L1 generation). | `api_base` | str | API endpoint (optional) | | `thinking` | bool | Enable thinking mode for VolcEngine models (default: `false`) | | `max_concurrent` | int | Maximum concurrent semantic LLM calls (default: `100`) | +| `max_retries` | int | Maximum retry attempts for transient VLM provider errors (default: `3`; `0` disables retry) | | `extra_headers` | object | Custom HTTP headers (for OpenAI-compatible providers, optional) | | `stream` | bool | Enable streaming mode (for OpenAI-compatible providers, default: `false`) | +`vlm.max_retries` only applies to transient errors such as `429`, `5xx`, timeouts, and connection failures. Permanent authentication, authorization, and billing errors are not retried automatically. The backoff strategy is exponential backoff with jitter, starting at `0.5s` and capped at `8s`. + **Available Models** | Model | Notes | @@ -956,6 +964,7 @@ For detailed encryption explanations, see [Data Encryption](../concepts/10-encry { "embedding": { "max_concurrent": 10, + "max_retries": 3, "dense": { "provider": "volcengine", "api_key": "string", @@ -971,6 +980,7 @@ For detailed encryption explanations, see [Data Encryption](../concepts/10-encry "api_base": "string", "thinking": false, "max_concurrent": 100, + "max_retries": 3, "extra_headers": {}, "stream": false }, @@ -1058,7 +1068,9 @@ Error: VLM request timeout - Check network connectivity - Increase timeout in config +- For intermittent timeouts, increase `vlm.max_retries` moderately - Try a smaller model +- For bulk ingestion, consider lowering `vlm.max_concurrent` ### Rate Limiting @@ -1067,6 +1079,8 @@ Error: Rate limit exceeded ``` Volcengine has rate limits. Consider batch processing with delays or upgrading your plan. +- Lower `embedding.max_concurrent` / `vlm.max_concurrent` first +- Keep a small `max_retries` value for occasional `429`s; set it to `0` if you prefer fail-fast behavior ## Related Documentation diff --git a/docs/zh/guides/01-configuration.md b/docs/zh/guides/01-configuration.md index 5c2396883..65b6040ba 100644 --- a/docs/zh/guides/01-configuration.md +++ b/docs/zh/guides/01-configuration.md @@ -104,6 +104,7 @@ OpenViking 使用 JSON 配置文件(`ov.conf`)进行设置。配置文件支 { "embedding": { "max_concurrent": 10, + "max_retries": 3, "dense": { "provider": "volcengine", "api_key": "your-api-key", @@ -121,6 +122,7 @@ OpenViking 使用 JSON 配置文件(`ov.conf`)进行设置。配置文件支 | 参数 | 类型 | 说明 | |------|------|------| | `max_concurrent` | int | 最大并发 Embedding 请求数(`embedding.max_concurrent`,默认:`10`) | +| `max_retries` | int | Embedding provider 瞬时错误的最大重试次数(`embedding.max_retries`,默认:`3`;`0` 表示禁用重试) | | `provider` | str | `"volcengine"`、`"openai"`、`"vikingdb"`、`"jina"`、`"voyage"`、`"minimax"` 或 `"gemini"` | | `api_key` | str | API Key | | `model` | str | 模型名称 | @@ -128,6 +130,8 @@ OpenViking 使用 JSON 配置文件(`ov.conf`)进行设置。配置文件支 | `input` | str | 输入类型:`"text"` 或 `"multimodal"` | | `batch_size` | int | 批量请求大小 | +`embedding.max_retries` 仅对瞬时错误生效,例如 `429`、`5xx`、超时和连接错误;`400`、`401`、`403`、`AccountOverdue` 这类永久错误不会自动重试。退避策略为指数退避,初始延迟 `0.5s`,上限 `8s`,并带随机抖动。 + **可用模型** | 模型 | 维度 | 输入类型 | 说明 | @@ -330,7 +334,8 @@ OpenViking 使用 JSON 配置文件(`ov.conf`)进行设置。配置文件支 "provider": "volcengine", "api_key": "your-api-key", "model": "doubao-seed-2-0-pro-260215", - "api_base": "https://ark.cn-beijing.volces.com/api/v3" + "api_base": "https://ark.cn-beijing.volces.com/api/v3", + "max_retries": 3 } } ``` @@ -344,9 +349,12 @@ OpenViking 使用 JSON 配置文件(`ov.conf`)进行设置。配置文件支 | `api_base` | str | API 端点(可选) | | `thinking` | bool | 启用思考模式(仅对部分火山模型生效,默认:`false`) | | `max_concurrent` | int | 语义处理阶段 LLM 最大并发调用数(默认:`100`) | +| `max_retries` | int | VLM provider 瞬时错误的最大重试次数(默认:`3`;`0` 表示禁用重试) | | `extra_headers` | object | 自定义 HTTP 请求头(OpenAI 兼容 provider 可用,可选) | | `stream` | bool | 启用流式模式(OpenAI 兼容 provider 可用,默认:`false`) | +`vlm.max_retries` 仅对瞬时错误生效,例如 `429`、`5xx`、超时和连接错误;认证、鉴权、欠费等永久错误不会自动重试。退避策略为指数退避,初始延迟 `0.5s`,上限 `8s`,并带随机抖动。 + **可用模型** | 模型 | 说明 | @@ -933,6 +941,7 @@ openviking --account acme --user alice --agent-id assistant-2 ls viking:// { "embedding": { "max_concurrent": 10, + "max_retries": 3, "dense": { "provider": "volcengine", "api_key": "string", @@ -948,6 +957,7 @@ openviking --account acme --user alice --agent-id assistant-2 ls viking:// "api_base": "string", "thinking": false, "max_concurrent": 100, + "max_retries": 3, "extra_headers": {}, "stream": false }, @@ -1035,7 +1045,9 @@ Error: VLM request timeout - 检查网络连接 - 增加配置中的超时时间 +- 对偶发超时,适当增大 `vlm.max_retries` - 尝试更小的模型 +- 如为批量导入场景,结合降低 `vlm.max_concurrent` ### 速率限制 @@ -1044,6 +1056,8 @@ Error: Rate limit exceeded ``` 火山引擎有速率限制。考虑批量处理时添加延迟或升级套餐。 +- 优先降低 `embedding.max_concurrent` / `vlm.max_concurrent` +- 对偶发 `429` 可保留少量 `max_retries`;若希望快速失败,可将其设为 `0` ## 相关文档 diff --git a/openviking/models/embedder/base.py b/openviking/models/embedder/base.py index 9a9e5df3e..4aa101c81 100644 --- a/openviking/models/embedder/base.py +++ b/openviking/models/embedder/base.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, TypeVar +from openviking.utils.model_retry import retry_sync + T = TypeVar("T") @@ -74,6 +76,7 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): """ self.model_name = model_name self.config = config or {} + self.max_retries = int(self.config.get("max_retries", 3)) @abstractmethod def embed(self, text: str, is_query: bool = False) -> EmbedResult: @@ -104,6 +107,14 @@ def close(self): """Release resources, subclasses can override as needed""" pass + def _run_with_retry(self, func: Callable[[], T], *, logger=None, operation_name: str) -> T: + return retry_sync( + func, + max_retries=self.max_retries, + logger=logger, + operation_name=operation_name, + ) + @property def is_dense(self) -> bool: """Check if result contains dense vector""" @@ -255,7 +266,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes return [ EmbedResult(dense_vector=d.dense_vector, sparse_vector=s.sparse_vector) - for d, s in zip(dense_results, sparse_results) + for d, s in zip(dense_results, sparse_results, strict=True) ] def get_dimension(self) -> int: diff --git a/openviking/models/embedder/gemini_embedders.py b/openviking/models/embedder/gemini_embedders.py index c11d73e47..c233fc98e 100644 --- a/openviking/models/embedder/gemini_embedders.py +++ b/openviking/models/embedder/gemini_embedders.py @@ -151,9 +151,9 @@ def __init__( api_key=api_key, http_options=HttpOptions( retry_options=HttpRetryOptions( - attempts=3, - initial_delay=1.0, - max_delay=30.0, + attempts=max(self.max_retries + 1, 1), + initial_delay=0.5, + max_delay=8.0, exp_base=2.0, ) ), @@ -207,8 +207,9 @@ def embed( task_type = self.query_param elif not is_query and self.document_param: task_type = self.document_param + # SDK accepts plain str; converts to REST Parts format internally. - try: + def _call() -> EmbedResult: result = self.client.models.embed_content( model=self.model_name, contents=text, @@ -216,6 +217,15 @@ def embed( ) vector = truncate_and_normalize(list(result.embeddings[0].values), self._dimension) return EmbedResult(dense_vector=vector) + + try: + if _HTTP_RETRY_AVAILABLE: + return _call() + return self._run_with_retry( + _call, + logger=logger, + operation_name="Gemini embedding", + ) except (APIError, ClientError) as e: _raise_api_error(e, self.model_name) @@ -233,7 +243,7 @@ def embed_batch( if titles is not None: return [ self.embed(text, is_query=is_query, task_type=task_type, title=title) - for text, title in zip(texts, titles) + for text, title in zip(texts, titles, strict=True) ] # Resolve effective task_type from is_query when no explicit override if task_type is None: @@ -253,14 +263,29 @@ def embed_batch( continue non_empty_texts = [batch[j] for j in non_empty_indices] - try: + + def _call_batch( + non_empty_texts: List[str] = non_empty_texts, + config: types.EmbedContentConfig = config, + ) -> Any: response = self.client.models.embed_content( model=self.model_name, contents=non_empty_texts, config=config, ) + return response + + try: + if _HTTP_RETRY_AVAILABLE: + response = _call_batch() + else: + response = self._run_with_retry( + _call_batch, + logger=logger, + operation_name="Gemini batch embedding", + ) batch_results = [None] * len(batch) - for j, emb in zip(non_empty_indices, response.embeddings): + for j, emb in zip(non_empty_indices, response.embeddings, strict=True): batch_results[j] = EmbedResult( dense_vector=truncate_and_normalize(list(emb.values), self._dimension) ) diff --git a/openviking/models/embedder/jina_embedders.py b/openviking/models/embedder/jina_embedders.py index f94650fcc..d17691d9c 100644 --- a/openviking/models/embedder/jina_embedders.py +++ b/openviking/models/embedder/jina_embedders.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: AGPL-3.0 """Jina AI Embedder Implementation""" +import logging from typing import Any, Dict, List, Optional import openai @@ -11,6 +12,8 @@ EmbedResult, ) +logger = logging.getLogger(__name__) + # Default dimensions for Jina embedding models JINA_MODEL_DIMENSIONS = { "jina-embeddings-v5-text-small": 1024, # 677M params, max seq 32768 @@ -165,7 +168,8 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: Raises: RuntimeError: When API call fails """ - try: + + def _call() -> EmbedResult: kwargs: Dict[str, Any] = {"input": text, "model": self.model_name} if self.dimension: kwargs["dimensions"] = self.dimension @@ -178,6 +182,13 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: vector = response.data[0].embedding return EmbedResult(dense_vector=vector) + + try: + return self._run_with_retry( + _call, + logger=logger, + operation_name="Jina embedding", + ) except openai.APIError as e: self._raise_task_error(e) raise RuntimeError(f"Jina API error: {e.message}") from e @@ -200,7 +211,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - try: + def _call() -> List[EmbedResult]: kwargs: Dict[str, Any] = {"input": texts, "model": self.model_name} if self.dimension: kwargs["dimensions"] = self.dimension @@ -212,6 +223,13 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes response = self.client.embeddings.create(**kwargs) return [EmbedResult(dense_vector=item.embedding) for item in response.data] + + try: + return self._run_with_retry( + _call, + logger=logger, + operation_name="Jina batch embedding", + ) except openai.APIError as e: self._raise_task_error(e) raise RuntimeError(f"Jina API error: {e.message}") from e diff --git a/openviking/models/embedder/litellm_embedders.py b/openviking/models/embedder/litellm_embedders.py index ea24c8141..9fd76f7fe 100644 --- a/openviking/models/embedder/litellm_embedders.py +++ b/openviking/models/embedder/litellm_embedders.py @@ -154,13 +154,21 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: Raises: RuntimeError: When embedding call fails """ - try: + + def _call() -> EmbedResult: kwargs = self._build_kwargs(is_query=is_query) kwargs["input"] = [text] response = litellm.embedding(**kwargs) self._update_telemetry_token_usage(response) vector = response.data[0]["embedding"] return EmbedResult(dense_vector=vector) + + try: + return self._run_with_retry( + _call, + logger=logger, + operation_name="LiteLLM embedding", + ) except Exception as e: raise RuntimeError(f"LiteLLM embedding failed: {e}") from e @@ -180,12 +188,19 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - try: + def _call() -> List[EmbedResult]: kwargs = self._build_kwargs(is_query=is_query) kwargs["input"] = texts response = litellm.embedding(**kwargs) self._update_telemetry_token_usage(response) return [EmbedResult(dense_vector=item["embedding"]) for item in response.data] + + try: + return self._run_with_retry( + _call, + logger=logger, + operation_name="LiteLLM batch embedding", + ) except Exception as e: raise RuntimeError(f"LiteLLM batch embedding failed: {e}") from e diff --git a/openviking/models/embedder/minimax_embedders.py b/openviking/models/embedder/minimax_embedders.py index aba462968..db143fe47 100644 --- a/openviking/models/embedder/minimax_embedders.py +++ b/openviking/models/embedder/minimax_embedders.py @@ -90,8 +90,8 @@ def _create_session(self) -> requests.Session: """Create a requests session with retry logic""" session = requests.Session() retry_strategy = Retry( - total=6, - backoff_factor=1, # 1s, 2s, 4s, 8s, 16s, 32s + total=self.max_retries, + backoff_factor=0.5, status_forcelist=[429, 500, 502, 503, 504], allowed_methods=["POST"], ) diff --git a/openviking/models/embedder/openai_embedders.py b/openviking/models/embedder/openai_embedders.py index 0ebabbeee..09d0d855a 100644 --- a/openviking/models/embedder/openai_embedders.py +++ b/openviking/models/embedder/openai_embedders.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: AGPL-3.0 """OpenAI Embedder Implementation""" +import logging from typing import Any, Dict, List, Optional import openai @@ -15,6 +16,8 @@ from openviking.models.vlm.registry import DEFAULT_AZURE_API_VERSION from openviking.telemetry import get_current_telemetry +logger = logging.getLogger(__name__) + class OpenAIDenseEmbedder(DenseEmbedderBase): """OpenAI-Compatible Dense Embedder Implementation @@ -235,7 +238,8 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: Raises: RuntimeError: When API call fails """ - try: + + def _call() -> EmbedResult: kwargs: Dict[str, Any] = {"input": text, "model": self.model_name} extra_body = self._build_extra_body(is_query=is_query) @@ -247,6 +251,13 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: vector = response.data[0].embedding return EmbedResult(dense_vector=vector) + + try: + return self._run_with_retry( + _call, + logger=logger, + operation_name="OpenAI embedding", + ) except openai.APIError as e: raise RuntimeError(f"OpenAI API error: {e.message}") from e except Exception as e: @@ -268,7 +279,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - try: + def _call() -> List[EmbedResult]: kwargs: Dict[str, Any] = {"input": texts, "model": self.model_name} if self.dimension: kwargs["dimensions"] = self.dimension @@ -281,6 +292,13 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes self._update_telemetry_token_usage(response) return [EmbedResult(dense_vector=item.embedding) for item in response.data] + + try: + return self._run_with_retry( + _call, + logger=logger, + operation_name="OpenAI batch embedding", + ) except openai.APIError as e: raise RuntimeError(f"OpenAI API error: {e.message}") from e except Exception as e: diff --git a/openviking/models/embedder/vikingdb_embedders.py b/openviking/models/embedder/vikingdb_embedders.py index 0253af9dc..8d31669d5 100644 --- a/openviking/models/embedder/vikingdb_embedders.py +++ b/openviking/models/embedder/vikingdb_embedders.py @@ -124,29 +124,44 @@ def __init__( self.dense_model = {"name": model_name, "version": model_version, "dim": dimension} def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = self._call_api([text], dense_model=self.dense_model) - if not results: - return EmbedResult(dense_vector=[]) - - item = results[0] - dense_vector = [] - if "dense_embedding" in item: - dense_vector = self._truncate_and_normalize(item["dense_embedding"], self.dimension) - - return EmbedResult(dense_vector=dense_vector) + def _call() -> EmbedResult: + results = self._call_api([text], dense_model=self.dense_model) + if not results: + return EmbedResult(dense_vector=[]) + + item = results[0] + dense_vector = [] + if "dense_embedding" in item: + dense_vector = self._truncate_and_normalize(item["dense_embedding"], self.dimension) + + return EmbedResult(dense_vector=dense_vector) + + return self._run_with_retry( + _call, + logger=logger, + operation_name="VikingDB embedding", + ) def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = self._call_api(texts, dense_model=self.dense_model) - return [ - EmbedResult( - dense_vector=self._truncate_and_normalize( - item.get("dense_embedding", []), self.dimension + + def _call() -> List[EmbedResult]: + raw_results = self._call_api(texts, dense_model=self.dense_model) + return [ + EmbedResult( + dense_vector=self._truncate_and_normalize( + item.get("dense_embedding", []), self.dimension + ) ) - ) - for item in raw_results - ] + for item in raw_results + ] + + return self._run_with_retry( + _call, + logger=logger, + operation_name="VikingDB batch embedding", + ) def get_dimension(self) -> int: return self.dimension if self.dimension else 2048 @@ -174,27 +189,42 @@ def __init__( } def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = self._call_api([text], sparse_model=self.sparse_model) - if not results: - return EmbedResult(sparse_vector={}) + def _call() -> EmbedResult: + results = self._call_api([text], sparse_model=self.sparse_model) + if not results: + return EmbedResult(sparse_vector={}) - item = results[0] - sparse_vector = {} - if "sparse" in item: - sparse_vector = item["sparse"] + item = results[0] + sparse_vector = {} + if "sparse" in item: + sparse_vector = item["sparse"] + + return EmbedResult(sparse_vector=sparse_vector) - return EmbedResult(sparse_vector=sparse_vector) + return self._run_with_retry( + _call, + logger=logger, + operation_name="VikingDB sparse embedding", + ) def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = self._call_api(texts, sparse_model=self.sparse_model) - return [ - EmbedResult( - sparse_vector=self._process_sparse_embedding(item.get("sparse_embedding", {})) - ) - for item in raw_results - ] + + def _call() -> List[EmbedResult]: + raw_results = self._call_api(texts, sparse_model=self.sparse_model) + return [ + EmbedResult( + sparse_vector=self._process_sparse_embedding(item.get("sparse_embedding", {})) + ) + for item in raw_results + ] + + return self._run_with_retry( + _call, + logger=logger, + operation_name="VikingDB sparse batch embedding", + ) class VikingDBHybridEmbedder(HybridEmbedderBase, VikingDBClientMixin): @@ -224,37 +254,54 @@ def __init__( } def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = self._call_api( - [text], dense_model=self.dense_model, sparse_model=self.sparse_model - ) - if not results: - return EmbedResult(dense_vector=[], sparse_vector={}) + def _call() -> EmbedResult: + results = self._call_api( + [text], dense_model=self.dense_model, sparse_model=self.sparse_model + ) + if not results: + return EmbedResult(dense_vector=[], sparse_vector={}) - item = results[0] - dense_vector = [] - sparse_vector = {} + item = results[0] + dense_vector = [] + sparse_vector = {} - if "dense" in item: - dense_vector = self._truncate_and_normalize(item["dense"], self.dimension) - if "sparse" in item: - sparse_vector = item["sparse"] + if "dense" in item: + dense_vector = self._truncate_and_normalize(item["dense"], self.dimension) + if "sparse" in item: + sparse_vector = item["sparse"] - return EmbedResult(dense_vector=dense_vector, sparse_vector=sparse_vector) + return EmbedResult(dense_vector=dense_vector, sparse_vector=sparse_vector) + + return self._run_with_retry( + _call, + logger=logger, + operation_name="VikingDB hybrid embedding", + ) def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = self._call_api( - texts, dense_model=self.dense_model, sparse_model=self.sparse_model + + def _call() -> List[EmbedResult]: + raw_results = self._call_api( + texts, dense_model=self.dense_model, sparse_model=self.sparse_model + ) + results = [] + for item in raw_results: + dense_vector = [] + sparse_vector = {} + if "dense" in item: + dense_vector = self._truncate_and_normalize(item["dense"], self.dimension) + if "sparse" in item: + sparse_vector = item["sparse"] + results.append(EmbedResult(dense_vector=dense_vector, sparse_vector=sparse_vector)) + return results + + return self._run_with_retry( + _call, + logger=logger, + operation_name="VikingDB hybrid batch embedding", ) - results = [] - for item in raw_results: - if "dense" in item: - dense_vector = self._truncate_and_normalize(item["dense"], self.dimension) - if "sparse" in item: - sparse_vector = item["sparse"] - results.append(EmbedResult(dense_vector=dense_vector, sparse_vector=sparse_vector)) - return results def get_dimension(self) -> int: return self.dimension if self.dimension else 2048 diff --git a/openviking/models/embedder/volcengine_embedders.py b/openviking/models/embedder/volcengine_embedders.py index 7a3ef6f4e..2fbba2d61 100644 --- a/openviking/models/embedder/volcengine_embedders.py +++ b/openviking/models/embedder/volcengine_embedders.py @@ -11,29 +11,12 @@ EmbedResult, HybridEmbedderBase, SparseEmbedderBase, - exponential_backoff_retry, truncate_and_normalize, ) from openviking.telemetry import get_current_telemetry from openviking_cli.utils.logger import default_logger as logger -def is_429_error(exception: Exception) -> bool: - """ - 判断异常是否为 429 限流错误 - - Args: - exception: 要检查的异常 - - Returns: - 如果是 429 错误则返回 True,否则返回 False - """ - exception_str = str(exception) - return ( - "429" in exception_str or "TooManyRequests" in exception_str or "RateLimit" in exception_str - ) - - def process_sparse_embedding(sparse_data: Any) -> Dict[str, float]: """Process sparse embedding data from SDK response""" if not sparse_data: @@ -177,14 +160,10 @@ def _embed_call(): return EmbedResult(dense_vector=vector) try: - return exponential_backoff_retry( + return self._run_with_retry( _embed_call, - max_wait=10.0, - base_delay=0.5, - max_delay=2.0, - jitter=True, - is_retryable=is_429_error, logger=logger, + operation_name="Volcengine embedding", ) except Exception as e: raise RuntimeError(f"Volcengine embedding failed: {str(e)}") from e @@ -205,7 +184,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - try: + def _call() -> List[EmbedResult]: if self.input_type == "multimodal": multimodal_inputs = [{"type": "text", "text": text} for text in texts] response = self.client.multimodal_embeddings.create( @@ -222,6 +201,13 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes EmbedResult(dense_vector=truncate_and_normalize(item.embedding, self.dimension)) for item in data ] + + try: + return self._run_with_retry( + _call, + logger=logger, + operation_name="Volcengine batch embedding", + ) except Exception as e: logger.error( f"Volcengine batch embedding failed, texts length: {len(texts)}, input_type: {self.input_type}, model_name: {self.model_name}" @@ -295,14 +281,10 @@ def _embed_call(): return EmbedResult(sparse_vector=process_sparse_embedding(sparse_vector)) try: - return exponential_backoff_retry( + return self._run_with_retry( _embed_call, - max_wait=10.0, - base_delay=0.5, - max_delay=2.0, - jitter=True, - is_retryable=is_429_error, logger=logger, + operation_name="Volcengine sparse embedding", ) except Exception as e: raise RuntimeError(f"Volcengine sparse embedding failed: {str(e)}") from e @@ -400,14 +382,10 @@ def _embed_call(): ) try: - return exponential_backoff_retry( + return self._run_with_retry( _embed_call, - max_wait=10.0, - base_delay=0.5, - max_delay=2.0, - jitter=True, - is_retryable=is_429_error, logger=logger, + operation_name="Volcengine hybrid embedding", ) except Exception as e: raise RuntimeError(f"Volcengine hybrid embedding failed: {str(e)}") from e diff --git a/openviking/models/embedder/voyage_embedders.py b/openviking/models/embedder/voyage_embedders.py index ed8d49f04..ff0c65a08 100644 --- a/openviking/models/embedder/voyage_embedders.py +++ b/openviking/models/embedder/voyage_embedders.py @@ -2,12 +2,15 @@ # SPDX-License-Identifier: AGPL-3.0 """Voyage AI dense embedder implementation.""" +import logging from typing import Any, Dict, List, Optional import openai from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult +logger = logging.getLogger(__name__) + VOYAGE_MODEL_DIMENSIONS = { "voyage-3": 1024, "voyage-3-large": 1024, @@ -83,7 +86,8 @@ def __init__( def embed(self, text: str, is_query: bool = False) -> EmbedResult: """Perform dense embedding on text.""" - try: + + def _call() -> EmbedResult: kwargs: Dict[str, Any] = {"input": text, "model": self.model_name} if self.dimension is not None: kwargs["extra_body"] = {"output_dimension": self.dimension} @@ -91,6 +95,13 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: response = self.client.embeddings.create(**kwargs) vector = response.data[0].embedding return EmbedResult(dense_vector=vector) + + try: + return self._run_with_retry( + _call, + logger=logger, + operation_name="Voyage embedding", + ) except openai.APIError as e: raise RuntimeError(f"Voyage API error: {e.message}") from e except Exception as e: @@ -101,13 +112,20 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - try: + def _call() -> List[EmbedResult]: kwargs: Dict[str, Any] = {"input": texts, "model": self.model_name} if self.dimension is not None: kwargs["extra_body"] = {"output_dimension": self.dimension} response = self.client.embeddings.create(**kwargs) return [EmbedResult(dense_vector=item.embedding) for item in response.data] + + try: + return self._run_with_retry( + _call, + logger=logger, + operation_name="Voyage batch embedding", + ) except openai.APIError as e: raise RuntimeError(f"Voyage API error: {e.message}") from e except Exception as e: diff --git a/openviking/models/vlm/backends/litellm_vlm.py b/openviking/models/vlm/backends/litellm_vlm.py index ca4a36aa7..72a746238 100644 --- a/openviking/models/vlm/backends/litellm_vlm.py +++ b/openviking/models/vlm/backends/litellm_vlm.py @@ -2,21 +2,21 @@ # SPDX-License-Identifier: AGPL-3.0 """LiteLLM VLM Provider implementation with multi-provider support.""" +import base64 import json import logging import os - -os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - -import asyncio -import base64 import time from pathlib import Path from typing import Any, Dict, List, Optional, Union +os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + import litellm from litellm import acompletion, completion +from openviking.utils.model_retry import retry_async, retry_sync + from ..base import ToolCall, VLMBase, VLMResponse logger = logging.getLogger(__name__) @@ -107,7 +107,6 @@ def __init__(self, config: Dict[str, Any]): if self.api_key: self._setup_env(self.api_key, self.model) - # Configure LiteLLM behavior (these are global but safe to re-set) litellm.suppress_debug_info = True litellm.drop_params = True @@ -124,7 +123,6 @@ def _setup_env(self, api_key: str, model: str | None) -> None: os.environ[env_key] = api_key self._detected_provider = provider else: - # Fallback to OpenAI if provider is unknown or literal litellm os.environ["OPENAI_API_KEY"] = api_key def _resolve_model(self, model: str) -> str: @@ -133,7 +131,6 @@ def _resolve_model(self, model: str) -> str: if provider and provider in PROVIDER_CONFIGS: prefix = PROVIDER_CONFIGS[provider]["litellm_prefix"] - # LiteLLM uses the `zai/` prefix for Zhipu GLM; do not prepend `zhipu/` (see #784). is_zhipu_zai_model = provider == "zhipu" and model.startswith("zai/") if prefix and not model.startswith(f"{prefix}/") and not is_zhipu_zai_model: return f"{prefix}/{model}" @@ -155,11 +152,11 @@ def _detect_image_format(self, data: bytes) -> str: if data[:8] == b"\x89PNG\r\n\x1a\n": return "image/png" - elif data[:2] == b"\xff\xd8": + if data[:2] == b"\xff\xd8": return "image/jpeg" - elif data[:6] in (b"GIF87a", b"GIF89a"): + if data[:6] in (b"GIF87a", b"GIF89a"): return "image/gif" - elif data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": + if data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": return "image/webp" logger.warning(f"[LiteLLMVLM] Unknown image format, magic bytes: {data[:8].hex()}") @@ -174,7 +171,7 @@ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}, } - elif isinstance(image, Path) or ( + if isinstance(image, Path) or ( isinstance(image, str) and not image.startswith(("http://", "https://")) ): path = Path(image) @@ -193,8 +190,7 @@ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}, } - else: - return {"type": "image_url", "image_url": {"url": image}} + return {"type": "image_url", "image_url": {"url": image}} def _build_kwargs( self, @@ -216,9 +212,6 @@ def _build_kwargs( if self.api_key: kwargs["api_key"] = self.api_key if self.api_base: - # For Gemini, LiteLLM constructs the URL itself. If user provides a full Google endpoint - # as api_base, it might break the URL construction in LiteLLM. - # We only pass api_base if it doesn't look like a standard Google endpoint versioned URL. is_google_endpoint = "generativelanguage.googleapis.com" in self.api_base and ( "/v1" in self.api_base or "/v1beta" in self.api_base ) @@ -274,69 +267,94 @@ def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMRespon finish_reason=choice.finish_reason or "stop", usage=usage, ) - else: - return message.content or "" + return message.content or "" - def get_completion( + def _build_text_kwargs( self, prompt: str = "", thinking: bool = False, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, - ) -> Union[str, VLMResponse]: - """Get text completion synchronously.""" + ) -> dict[str, Any]: + model = self._resolve_model(self.model or "gpt-4o-mini") + kwargs_messages = messages or [{"role": "user", "content": prompt}] + return self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) + + def _build_vision_kwargs( + self, + prompt: str = "", + images: Optional[List[Union[str, Path, bytes]]] = None, + thinking: bool = False, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + messages: Optional[List[Dict[str, Any]]] = None, + ) -> dict[str, Any]: model = self._resolve_model(self.model or "gpt-4o-mini") if messages: kwargs_messages = messages else: - kwargs_messages = [{"role": "user", "content": prompt}] - - kwargs = self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) + content = [] + if images: + content.extend(self._prepare_image(img) for img in images) + if prompt: + content.append({"type": "text", "text": prompt}) + kwargs_messages = [{"role": "user", "content": content}] + return self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) - t0 = time.perf_counter() - response = completion(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response(response, duration_seconds=elapsed) - return self._build_vlm_response(response, has_tools=bool(tools)) + def get_completion( + self, + prompt: str = "", + thinking: bool = False, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + messages: Optional[List[Dict[str, Any]]] = None, + ) -> Union[str, VLMResponse]: + """Get text completion synchronously.""" + kwargs = self._build_text_kwargs(prompt, thinking, tools, tool_choice, messages) + + def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = completion(**kwargs) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response(response, duration_seconds=elapsed) + if tools: + return self._build_vlm_response(response, has_tools=True) + return self._clean_response(self._extract_content_from_response(response)) + + return retry_sync( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="LiteLLM VLM completion", + ) async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: """Get text completion asynchronously.""" - model = self._resolve_model(self.model or "gpt-4o-mini") - if messages: - kwargs_messages = messages - else: - kwargs_messages = [{"role": "user", "content": prompt}] - - kwargs = self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) - - last_error = None - for attempt in range(max_retries + 1): - try: - t0 = time.perf_counter() - response = await acompletion(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response( - response, - duration_seconds=elapsed, - ) - return self._build_vlm_response(response, has_tools=bool(tools)) - except Exception as e: - last_error = e - if attempt < max_retries: - await asyncio.sleep(2**attempt) - - if last_error: - raise last_error - raise RuntimeError("Unknown error in async completion") + kwargs = self._build_text_kwargs(prompt, thinking, tools, tool_choice, messages) + + async def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = await acompletion(**kwargs) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response(response, duration_seconds=elapsed) + if tools: + return self._build_vlm_response(response, has_tools=True) + return self._clean_response(self._extract_content_from_response(response)) + + return await retry_async( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="LiteLLM VLM async completion", + ) def get_vision_completion( self, @@ -347,26 +365,23 @@ def get_vision_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: """Get vision completion synchronously.""" - model = self._resolve_model(self.model or "gpt-4o-mini") - - if messages: - kwargs_messages = messages - else: - content = [] - if images: - for img in images: - content.append(self._prepare_image(img)) - if prompt: - content.append({"type": "text", "text": prompt}) - kwargs_messages = [{"role": "user", "content": content}] - - kwargs = self._build_kwargs(model, kwargs_messages, tools, thinking=thinking) - - t0 = time.perf_counter() - response = completion(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response(response, duration_seconds=elapsed) - return self._build_vlm_response(response, has_tools=bool(tools)) + kwargs = self._build_vision_kwargs(prompt, images, thinking, tools, None, messages) + + def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = completion(**kwargs) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response(response, duration_seconds=elapsed) + if tools: + return self._build_vlm_response(response, has_tools=True) + return self._clean_response(self._extract_content_from_response(response)) + + return retry_sync( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="LiteLLM VLM vision completion", + ) async def get_vision_completion_async( self, @@ -377,26 +392,23 @@ async def get_vision_completion_async( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: """Get vision completion asynchronously.""" - model = self._resolve_model(self.model or "gpt-4o-mini") - - if messages: - kwargs_messages = messages - else: - content = [] - if images: - for img in images: - content.append(self._prepare_image(img)) - if prompt: - content.append({"type": "text", "text": prompt}) - kwargs_messages = [{"role": "user", "content": content}] - - kwargs = self._build_kwargs(model, kwargs_messages, tools, thinking=thinking) - - t0 = time.perf_counter() - response = await acompletion(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response(response, duration_seconds=elapsed) - return self._build_vlm_response(response, has_tools=bool(tools)) + kwargs = self._build_vision_kwargs(prompt, images, thinking, tools, None, messages) + + async def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = await acompletion(**kwargs) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response(response, duration_seconds=elapsed) + if tools: + return self._build_vlm_response(response, has_tools=True) + return self._clean_response(self._extract_content_from_response(response)) + + return await retry_async( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="LiteLLM VLM async vision completion", + ) def _update_token_usage_from_response( self, diff --git a/openviking/models/vlm/backends/openai_vlm.py b/openviking/models/vlm/backends/openai_vlm.py index 05c28c768..8934bcfb3 100644 --- a/openviking/models/vlm/backends/openai_vlm.py +++ b/openviking/models/vlm/backends/openai_vlm.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: AGPL-3.0 """OpenAI VLM backend implementation""" -import asyncio import base64 import json import logging @@ -11,6 +10,13 @@ from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse +try: + import openai +except ImportError: + openai = None + +from openviking.utils.model_retry import retry_async, retry_sync + from ..base import ToolCall, VLMBase, VLMResponse from ..registry import DEFAULT_AZURE_API_VERSION @@ -58,9 +64,7 @@ def __init__(self, config: Dict[str, Any]): def get_client(self): """Get sync client""" if self._sync_client is None: - try: - import openai - except ImportError: + if openai is None: raise ImportError("Please install openai: pip install openai") kwargs = _build_openai_client_kwargs( self.provider, @@ -78,9 +82,7 @@ def get_client(self): def get_async_client(self): """Get async client""" if self._async_client is None: - try: - import openai - except ImportError: + if openai is None: raise ImportError("Please install openai: pip install openai") kwargs = _build_openai_client_kwargs( self.provider, @@ -170,8 +172,7 @@ def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMRespon finish_reason=choice.finish_reason or "stop", usage=usage, ) - else: - return message.content or "" + return message.content or "" def _extract_from_chunk(self, chunk): """Extract content and usage from a single chunk. @@ -183,11 +184,9 @@ def _extract_from_chunk(self, chunk): prompt_tokens = 0 completion_tokens = 0 - # Extract content from delta if chunk.choices and chunk.choices[0].delta: content = getattr(chunk.choices[0].delta, "content", None) - # Extract usage from chunk if available if hasattr(chunk, "usage") and chunk.usage: prompt_tokens = chunk.usage.prompt_tokens or 0 completion_tokens = chunk.usage.completion_tokens or 0 @@ -195,14 +194,7 @@ def _extract_from_chunk(self, chunk): return content, prompt_tokens, completion_tokens def _process_streaming_response(self, response): - """Process streaming response and extract content and token usage. - - Args: - response: Streaming response iterator from OpenAI client - - Returns: - str: Extracted content - """ + """Process streaming response and extract content and token usage.""" content_parts = [] prompt_tokens = 0 completion_tokens = 0 @@ -216,7 +208,6 @@ def _process_streaming_response(self, response): if ct > 0: completion_tokens = ct - # Update token usage if we got it from streaming chunks if prompt_tokens > 0 or completion_tokens > 0: self.update_token_usage( model_name=self.model or "gpt-4o-mini", @@ -228,14 +219,7 @@ def _process_streaming_response(self, response): return "".join(content_parts) async def _process_streaming_response_async(self, response): - """Process async streaming response and extract content and token usage. - - Args: - response: Async streaming response iterator from OpenAI client - - Returns: - str: Extracted content - """ + """Process async streaming response and extract content and token usage.""" content_parts = [] prompt_tokens = 0 completion_tokens = 0 @@ -249,7 +233,6 @@ async def _process_streaming_response_async(self, response): if ct > 0: completion_tokens = ct - # Update token usage if we got it from streaming chunks if prompt_tokens > 0 or completion_tokens > 0: self.update_token_usage( model_name=self.model or "gpt-4o-mini", @@ -260,21 +243,15 @@ async def _process_streaming_response_async(self, response): return "".join(content_parts) - def get_completion( + def _build_text_kwargs( self, prompt: str = "", - thinking: bool = False, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, - ) -> Union[str, VLMResponse]: - """Get text completion""" - client = self.get_client() - if messages: - kwargs_messages = messages - else: - kwargs_messages = [{"role": "user", "content": prompt}] - + thinking: bool = False, + ) -> Dict[str, Any]: + kwargs_messages = messages or [{"role": "user", "content": prompt}] kwargs = { "model": self.model or "gpt-4o-mini", "messages": kwargs_messages, @@ -284,42 +261,29 @@ def get_completion( self._apply_provider_specific_extra_body(kwargs, thinking) if self.max_tokens is not None: kwargs["max_tokens"] = self.max_tokens - if tools: kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" + return kwargs - t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - - if tools: - self._update_token_usage_from_response(response) - return self._build_vlm_response(response, has_tools=bool(tools)) - - if self.stream: - content = self._process_streaming_response(response) - else: - self._update_token_usage_from_response(response, duration_seconds=elapsed) - content = self._extract_content_from_response(response) - - return self._clean_response(content) - - async def get_completion_async( + def _build_vision_kwargs( self, prompt: str = "", - thinking: bool = False, - max_retries: int = 0, + images: Optional[List[Union[str, Path, bytes]]] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, - ) -> Union[str, VLMResponse]: - """Get text completion asynchronously""" - client = self.get_async_client() + thinking: bool = False, + ) -> Dict[str, Any]: if messages: kwargs_messages = messages else: - kwargs_messages = [{"role": "user", "content": prompt}] + content = [] + if images: + content.extend(self._prepare_image(img) for img in images) + if prompt: + content.append({"type": "text", "text": prompt}) + kwargs_messages = [{"role": "user", "content": content}] kwargs = { "model": self.model or "gpt-4o-mini", @@ -330,41 +294,82 @@ async def get_completion_async( self._apply_provider_specific_extra_body(kwargs, thinking) if self.max_tokens is not None: kwargs["max_tokens"] = self.max_tokens - if tools: kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" + return kwargs - last_error = None - for attempt in range(max_retries + 1): - try: - t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - - if tools: - self._update_token_usage_from_response(response) - return self._build_vlm_response(response, has_tools=bool(tools)) - - if self.stream: - content = await self._process_streaming_response_async(response) - else: - self._update_token_usage_from_response( - response, - duration_seconds=elapsed, - ) - content = self._extract_content_from_response(response) - - return self._clean_response(content) - except Exception as e: - last_error = e - if attempt < max_retries: - await asyncio.sleep(2**attempt) - - if last_error: - raise last_error + def _extract_completion_content(self, response, elapsed: float) -> str: + if self.stream: + content = self._process_streaming_response(response) else: - raise RuntimeError("Unknown error in async completion") + self._update_token_usage_from_response(response, duration_seconds=elapsed) + content = self._extract_content_from_response(response) + return self._clean_response(content) + + async def _extract_completion_content_async(self, response, elapsed: float) -> str: + if self.stream: + content = await self._process_streaming_response_async(response) + else: + self._update_token_usage_from_response(response, duration_seconds=elapsed) + content = self._extract_content_from_response(response) + return self._clean_response(content) + + def get_completion( + self, + prompt: str = "", + thinking: bool = False, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + messages: Optional[List[Dict[str, Any]]] = None, + ) -> Union[str, VLMResponse]: + """Get text completion""" + client = self.get_client() + kwargs = self._build_text_kwargs(prompt, tools, tool_choice, messages, thinking) + + def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = client.chat.completions.create(**kwargs) + elapsed = time.perf_counter() - t0 + if tools: + self._update_token_usage_from_response(response, duration_seconds=elapsed) + return self._build_vlm_response(response, has_tools=True) + return self._extract_completion_content(response, elapsed) + + return retry_sync( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="OpenAI VLM completion", + ) + + async def get_completion_async( + self, + prompt: str = "", + thinking: bool = False, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + messages: Optional[List[Dict[str, Any]]] = None, + ) -> Union[str, VLMResponse]: + """Get text completion asynchronously""" + client = self.get_async_client() + kwargs = self._build_text_kwargs(prompt, tools, tool_choice, messages, thinking) + + async def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = await client.chat.completions.create(**kwargs) + elapsed = time.perf_counter() - t0 + if tools: + self._update_token_usage_from_response(response, duration_seconds=elapsed) + return self._build_vlm_response(response, has_tools=True) + return await self._extract_completion_content_async(response, elapsed) + + return await retry_async( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="OpenAI VLM async completion", + ) def _detect_image_format(self, data: bytes) -> str: """Detect image format from magic bytes. @@ -377,11 +382,11 @@ def _detect_image_format(self, data: bytes) -> str: if data[:8] == b"\x89PNG\r\n\x1a\n": return "image/png" - elif data[:2] == b"\xff\xd8": + if data[:2] == b"\xff\xd8": return "image/jpeg" - elif data[:6] in (b"GIF87a", b"GIF89a"): + if data[:6] in (b"GIF87a", b"GIF89a"): return "image/gif" - elif data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": + if data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": return "image/webp" logger.warning(f"[OpenAIVLM] Unknown image format, magic bytes: {data[:8].hex()}") @@ -396,7 +401,7 @@ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}, } - elif isinstance(image, Path) or ( + if isinstance(image, Path) or ( isinstance(image, str) and not image.startswith(("http://", "https://")) ): path = Path(image) @@ -415,8 +420,7 @@ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}, } - else: - return {"type": "image_url", "image_url": {"url": image}} + return {"type": "image_url", "image_url": {"url": image}} def get_vision_completion( self, @@ -428,47 +432,23 @@ def get_vision_completion( ) -> Union[str, VLMResponse]: """Get vision completion""" client = self.get_client() - - if messages: - kwargs_messages = messages - else: - content = [] - if images: - for img in images: - content.append(self._prepare_image(img)) - if prompt: - content.append({"type": "text", "text": prompt}) - kwargs_messages = [{"role": "user", "content": content}] - - kwargs = { - "model": self.model or "gpt-4o-mini", - "messages": kwargs_messages, - "temperature": self.temperature, - "stream": self.stream, - } - self._apply_provider_specific_extra_body(kwargs, thinking) - if self.max_tokens is not None: - kwargs["max_tokens"] = self.max_tokens - - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = "auto" - - t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - - if tools: - self._update_token_usage_from_response(response) - return self._build_vlm_response(response, has_tools=bool(tools)) - - if self.stream: - content = self._process_streaming_response(response) - else: - self._update_token_usage_from_response(response, duration_seconds=elapsed) - content = self._extract_content_from_response(response) - - return self._clean_response(content) + kwargs = self._build_vision_kwargs(prompt, images, tools, None, messages, thinking) + + def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = client.chat.completions.create(**kwargs) + elapsed = time.perf_counter() - t0 + if tools: + self._update_token_usage_from_response(response, duration_seconds=elapsed) + return self._build_vlm_response(response, has_tools=True) + return self._extract_completion_content(response, elapsed) + + return retry_sync( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="OpenAI VLM vision completion", + ) async def get_vision_completion_async( self, @@ -480,44 +460,20 @@ async def get_vision_completion_async( ) -> Union[str, VLMResponse]: """Get vision completion asynchronously""" client = self.get_async_client() - - if messages: - kwargs_messages = messages - else: - content = [] - if images: - for img in images: - content.append(self._prepare_image(img)) - if prompt: - content.append({"type": "text", "text": prompt}) - kwargs_messages = [{"role": "user", "content": content}] - - kwargs = { - "model": self.model or "gpt-4o-mini", - "messages": kwargs_messages, - "temperature": self.temperature, - "stream": self.stream, - } - self._apply_provider_specific_extra_body(kwargs, thinking) - if self.max_tokens is not None: - kwargs["max_tokens"] = self.max_tokens - - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = "auto" - - t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - - if tools: - self._update_token_usage_from_response(response) - return self._build_vlm_response(response, has_tools=bool(tools)) - - if self.stream: - content = await self._process_streaming_response_async(response) - else: - self._update_token_usage_from_response(response, duration_seconds=elapsed) - content = self._extract_content_from_response(response) - - return self._clean_response(content) + kwargs = self._build_vision_kwargs(prompt, images, tools, None, messages, thinking) + + async def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = await client.chat.completions.create(**kwargs) + elapsed = time.perf_counter() - t0 + if tools: + self._update_token_usage_from_response(response, duration_seconds=elapsed) + return self._build_vlm_response(response, has_tools=True) + return await self._extract_completion_content_async(response, elapsed) + + return await retry_async( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="OpenAI VLM async vision completion", + ) diff --git a/openviking/models/vlm/backends/volcengine_vlm.py b/openviking/models/vlm/backends/volcengine_vlm.py index 06616def1..d8ecc7010 100644 --- a/openviking/models/vlm/backends/volcengine_vlm.py +++ b/openviking/models/vlm/backends/volcengine_vlm.py @@ -1,15 +1,16 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: AGPL-3.0 -"""VolcEngine VLM backend implementation""" +"""VolcEngine VLM backend implementation.""" import base64 import json import logging +import time from collections import OrderedDict from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union -# Import run_async for sync-to-async calls +from openviking.utils.model_retry import retry_async, retry_sync from openviking_cli.utils import run_async from ..base import ToolCall, VLMResponse @@ -49,21 +50,16 @@ def __init__(self, config: Dict[str, Any]): super().__init__(config) self._sync_client = None self._async_client = None - # Ensure provider type is correct self.provider = "volcengine" - - # Prompt caching: message content -> response_id self._response_cache = LRUCache(maxsize=100) - # VolcEngine-specific defaults if not self.api_base: self.api_base = "https://ark.cn-beijing.volces.com/api/v3" if not self.model: self.model = "doubao-seed-2-0-pro-260215" def _get_response_id_cache_key(self, messages: List[Dict[str, Any]]) -> str: - """Generate cache key for response_id using simple JSON serialization.""" - # Filter out cache_control from messages for cache key + """Generate cache key for response_id using JSON serialization.""" key_messages = [] for msg in messages: filtered = {k: v for k, v in msg.items() if k != "cache_control"} @@ -73,60 +69,36 @@ def _get_response_id_cache_key(self, messages: List[Dict[str, Any]]) -> str: def _parse_messages_with_breakpoints( self, messages: List[Dict[str, Any]] ) -> Tuple[List[List[Dict[str, Any]]], List[Dict[str, Any]]]: - """Parse messages into static segment and dynamic messages. - - Only the content BEFORE the first cache_control becomes the static segment. - All messages after (including the one with cache_control) become dynamic. - """ - # 找到第一个 cache_control 的位置 + """Split messages into cacheable prefixes and dynamic suffix.""" first_breakpoint_idx = -1 for i, msg in enumerate(messages): if msg.get("cache_control"): first_breakpoint_idx = i - # print(f'cache_control={msg}') break if first_breakpoint_idx > 0: - # 有 cache_control,取其之前的内容作为 static segment static_segment = messages[: first_breakpoint_idx + 1] dynamic_messages = messages[first_breakpoint_idx + 1 :] - static_segments = [static_segment] - print(f"static_segment={len(static_segment)}") - print(f"dynamic_messages={len(dynamic_messages)}") - else: - # 没有 cache_control 或在第一个位置,全部作为 dynamic - static_segments = [] - dynamic_messages = messages + return [static_segment], dynamic_messages - return static_segments, dynamic_messages + return [], messages async def _get_or_create_from_segments( self, segments: List[List[Dict[str, Any]]], end_idx: int ) -> Optional[str]: - """递归获取或创建 cache,从长到短尝试。 - - Args: - segments: static 消息分段,每段以 cache_control 结尾 - end_idx: 尝试的前缀长度(包含的 segment 数量) - - Returns: - response_id for the prefix - """ + """Recursively get or create cached prefixes.""" if end_idx <= 0: return None - def segments_to_messages(segs): - # 拼接前 end_idx 个 segments - msgs = [] + def segments_to_messages(segs: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]: + msgs: List[Dict[str, Any]] = [] for seg in segs: msgs.extend(seg) return msgs prefix = segments_to_messages(segments[:end_idx]) - if end_idx == 1: - response_id = await self._get_or_create_from_messages(prefix) - return response_id + return await self._get_or_create_from_messages(prefix) previous_response_id = await self._get_or_create_from_segments(segments, end_idx - 1) return await self._get_or_create_from_messages( @@ -135,11 +107,9 @@ def segments_to_messages(segs): ) async def _get_or_create_from_messages( - self, messages: List[Dict[str, Any]], previous_response_id=None + self, messages: List[Dict[str, Any]], previous_response_id: Optional[str] = None ) -> Optional[str]: - """从头创建新 cache。""" - - # Check cache first + """Create a cached prefix and return its response id.""" cache_key = self._get_response_id_cache_key(messages) cached_id = self._response_cache.get(cache_key) if cached_id is not None: @@ -159,65 +129,48 @@ async def _get_or_create_from_messages( self._response_cache.set(cache_key, cached_id) return cached_id except Exception as e: - logger.warning(f"[VolcEngineVLM] Failed to create new cache: {e}") + logger.warning("[VolcEngineVLM] Failed to create cached prefix: %s", e) return None async def responseapi_prefixcache_completion( self, static_segments: List[List[Dict[str, Any]]], dynamic_messages: List[Dict[str, Any]], - response_format: Optional[Dict] = None, + response_format: Optional[Dict[str, Any]] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, + thinking: bool = False, ) -> Any: - """Use cached response_id for completion with dynamic messages. - - Args: - static_segments: Multiple static segments, each ending with cache_control - dynamic_messages: New messages for this request - response_format: Response format for structured output - tools: Tool definitions - tool_choice: Tool choice setting - """ - # 使用多段缓存获取 response_id + """Call VolcEngine Responses API with optional prefix caching.""" if static_segments: response_id = await self._get_or_create_from_segments( static_segments, len(static_segments) ) else: response_id = None - client = self.get_async_client() - input_data = self._convert_messages_to_input(dynamic_messages) - kwargs = { + client = self.get_async_client() + kwargs: Dict[str, Any] = { "model": self.model, - "input": input_data, + "input": self._convert_messages_to_input(dynamic_messages), "temperature": self.temperature, - "thinking": {"type": "disabled"}, + "thinking": {"type": "enabled" if thinking else "disabled"}, + "caching": {"type": "enabled"}, } if self.max_tokens is not None: kwargs["max_tokens"] = self.max_tokens if response_format: kwargs["text"] = {"format": response_format} - if response_id: kwargs["previous_response_id"] = response_id - kwargs["caching"] = {"type": "enabled"} - elif tools: - # First call with tools: enable caching - converted_tools = self._convert_tools(tools) - kwargs["tools"] = converted_tools + if tools: + kwargs["tools"] = self._convert_tools(tools) kwargs["tool_choice"] = tool_choice or "auto" - kwargs["caching"] = {"type": "enabled"} - else: - # Enable caching by default - kwargs["caching"] = {"type": "enabled"} - response = await client.responses.create(**kwargs) - return response + return await client.responses.create(**kwargs) def get_client(self): - """Get sync client""" + """Get sync client.""" if self._sync_client is None: try: import volcenginesdkarkruntime @@ -232,7 +185,7 @@ def get_client(self): return self._sync_client def get_async_client(self): - """Get async client""" + """Get async client.""" if self._async_client is None: try: import volcenginesdkarkruntime @@ -251,51 +204,35 @@ def _update_token_usage_from_response( response, duration_seconds: float = 0.0, ) -> None: - """Update token usage from VolcEngine Responses API response.""" + """Update token usage from either Responses API or chat completions.""" if hasattr(response, "usage") and response.usage: - u = response.usage - # Responses API uses input_tokens/output_tokens instead of prompt_tokens/completion_tokens - prompt_tokens = getattr(u, "input_tokens", 0) or 0 - completion_tokens = getattr(u, "output_tokens", 0) or 0 - self.update_token_usage( - model_name=self.model or "unknown", - provider=self.provider, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - duration_seconds=duration_seconds, - ) - return + usage = response.usage + if hasattr(usage, "input_tokens") or hasattr(usage, "output_tokens"): + prompt_tokens = getattr(usage, "input_tokens", 0) or 0 + completion_tokens = getattr(usage, "output_tokens", 0) or 0 + self.update_token_usage( + model_name=self.model or "unknown", + provider=self.provider, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + duration_seconds=duration_seconds, + ) + return + super()._update_token_usage_from_response(response, duration_seconds=duration_seconds) def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMResponse]: - """Build response from VolcEngine Responses API response. - - Responses API returns: - - response.output: list of output items - - response.id: response ID - - response.usage: token usage - """ - # Debug: print response structure - # logger.debug(f"[VolcEngineVLM] Response type: {type(response)}") - # logger.info(f"[VolcEngineVLM] Full response: {response}") - if hasattr(response, "output"): - # logger.debug(f"[VolcEngineVLM] Output items: {len(response.output)}") - for i, item in enumerate(response.output): - # logger.debug(f"[VolcEngineVLM] Item {i}: type={getattr(item, 'type', 'unknown')}") - # Print full item for debugging - # logger.info(f"[VolcEngineVLM] Item {i} full: {item}") - pass - - # Extract content from Responses API format + """Build a VLM response from Responses API or chat completions payloads.""" + if hasattr(response, "choices"): + return super()._build_vlm_response(response, has_tools) + content = "" - tool_calls = [] + tool_calls: List[ToolCall] = [] finish_reason = "stop" if hasattr(response, "output") and response.output: for item in response.output: item_type = getattr(item, "type", None) - # Check if it's a function_call item (Responses API format) if item_type == "function_call": - # logger.debug(f"[VolcEngineVLM] Found function_call tool call") args = item.arguments if isinstance(args, str): try: @@ -306,45 +243,43 @@ def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMRespon ToolCall(id=item.call_id or "", name=item.name or "", arguments=args) ) finish_reason = "tool_calls" - # Check if it's a message item (Chat API compatibility) - elif item_type == "message": - message = item - if hasattr(message, "content"): - # Content can be a list or string - if isinstance(message.content, list): - for block in message.content: - if hasattr(block, "type") and block.type == "output_text": - content = block.text or "" - elif hasattr(block, "text"): - content = block.text or "" - else: - content = message.content or "" - - # Parse tool calls from message - if hasattr(message, "tool_calls") and message.tool_calls: - # logger.debug(f"[VolcEngineVLM] Found {len(message.tool_calls)} tool calls in message") - for tc in message.tool_calls: - args = tc.arguments - if isinstance(args, str): - try: - args = json.loads(args) - except json.JSONDecodeError: - args = {"raw": args} - # Handle both tc.name and tc.function.name (Responses API vs Chat API) + continue + + if item_type != "message": + continue + + if hasattr(item, "content"): + if isinstance(item.content, list): + text_parts = [] + for block in item.content: + if getattr(block, "type", None) == "output_text": + text_parts.append(block.text or "") + elif hasattr(block, "text"): + text_parts.append(block.text or "") + content = "".join(text_parts) + else: + content = item.content or "" + + if hasattr(item, "tool_calls") and item.tool_calls: + for tc in item.tool_calls: + args = tc.arguments + if isinstance(args, str): try: - tool_name = tc.name - if not tool_name: - tool_name = tc.function.name - except AttributeError: - tool_name = tc.function.name if hasattr(tc, "function") else "" - tool_calls.append( - ToolCall(id=tc.id or "", name=tool_name or "", arguments=args) + args = json.loads(args) + except json.JSONDecodeError: + args = {"raw": args} + tool_name = getattr(tc, "name", None) + if not tool_name and hasattr(tc, "function"): + tool_name = tc.function.name + tool_calls.append( + ToolCall( + id=getattr(tc, "id", "") or "", name=tool_name or "", arguments=args ) + ) - finish_reason = getattr(message, "finish_reason", "stop") or "stop" + finish_reason = getattr(item, "finish_reason", "stop") or "stop" - # Extract usage - usage = {} + usage: Dict[str, Any] = {} if hasattr(response, "usage") and response.usage: u = response.usage usage = { @@ -352,7 +287,6 @@ def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMRespon "completion_tokens": getattr(u, "output_tokens", 0), "total_tokens": getattr(u, "total_tokens", 0), } - # Handle cached tokens input_details = getattr(u, "input_tokens_details", None) if input_details: usage["prompt_tokens_details"] = { @@ -366,8 +300,7 @@ def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMRespon finish_reason=finish_reason, usage=usage, ) - else: - return content + return content def get_completion( self, @@ -377,11 +310,7 @@ def get_completion( tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: - """Get text completion with prompt caching support. - - Uses VolcEngine Responses API with prefix cache. - Delegates to async implementation. - """ + """Get text completion via the async Responses API implementation.""" return run_async( self.get_completion_async( prompt=prompt, @@ -393,63 +322,41 @@ def get_completion( ) def _convert_messages_to_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Convert OpenAI-style messages to VolcEngine Responses API input format. - - VolcEngine Responses API format (no "type" field needed): - [ - {"role": "system", "content": "..."}, - {"role": "user", "content": "..."}, - ] - - Note: Responses API doesn't support 'tool' role, so we convert tool results - to user messages with a prefix indicating it's a tool result. - """ + """Convert OpenAI-style messages to VolcEngine Responses API input format.""" input_messages = [] for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") - # Handle tool_call role with content as dict {name, args, result} if role == "tool_call" and isinstance(content, dict): - import json - content_str = json.dumps(content, ensure_ascii=False) - role = "user" # Convert tool_call to user + role = "user" else: - # Handle content - check if it contains images - has_images = False if isinstance(content, list): text_parts = [] image_urls = [] for block in content: - if isinstance(block, dict): - block_type = block.get("type", "") - # Handle text blocks - if block_type == "text" or "text" in block: - text = block.get("text", "") - if text: - text_parts.append(text) - # Handle image_url blocks - elif block_type == "image_url" or "image_url" in block: - image_url = block.get("image_url", {}) - if isinstance(image_url, dict): - url = image_url.get("url", "") - if url: - image_urls.append(url) - has_images = True - # Handle other block types - else: - # Try to extract text from any dict block - text = block.get("text", "") - if text: - text_parts.append(text) + if not isinstance(block, dict): + continue + block_type = block.get("type", "") + if block_type == "text" or "text" in block: + text = block.get("text", "") + if text: + text_parts.append(text) + elif block_type == "image_url" or "image_url" in block: + image_url = block.get("image_url", {}) + if isinstance(image_url, dict): + url = image_url.get("url", "") + if url: + image_urls.append(url) + else: + text = block.get("text", "") + if text: + text_parts.append(text) content = " ".join(text_parts) - # If there were images, include them as base64 data URLs in content if image_urls: - # Filter out non-data URLs (keep only data: URLs) data_urls = [u for u in image_urls if u.startswith("data:")] if data_urls: - # Append image references to content content = ( content + "\n[Images: " @@ -457,57 +364,37 @@ def _convert_messages_to_input(self, messages: List[Dict[str, Any]]) -> List[Dic + "]" ) - # Ensure content is a string, use placeholder if empty content_str = str(content) if content else "[empty]" - # Skip messages with empty content (API requirement) if not content_str or content_str == "[empty]": continue - # Handle role conversion - # Responses API supports: system, user, assistant - # Convert 'tool' role to user with prefix (preserve the tool result context) if role == "tool": - # Prefix with tool result indicator content_str = f"[Tool Result]\n{content_str}" role = "user" - # Simple format: role + content (no type field) - input_messages.append( - { - "role": role, - "content": content_str, - } - ) + input_messages.append({"role": role, "content": content_str}) return input_messages def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Convert OpenAI-style tool format to VolcEngine Responses API format. - - OpenAI format: {"type": "function", "function": {"name": ..., "parameters": ...}} - VolcEngine format: {"type": "function", "name": ..., "description": ..., "parameters": ...} - - Note: VolcEngine Responses API requires "type": "function" and name at top level. - """ + """Convert OpenAI-style tool format to VolcEngine Responses API format.""" converted = [] for tool in tools: if not isinstance(tool, dict): converted.append(tool) continue - # Check if it's OpenAI format: {"type": "function", "function": {...}} if tool.get("type") == "function" and "function" in tool: func = tool["function"] converted.append( { - "type": "function", # Keep the type field + "type": "function", "name": func.get("name", ""), "description": func.get("description", ""), "parameters": func.get("parameters", {}), } ) elif "function" in tool: - # Has function but no type func = tool["function"] converted.append( { @@ -517,21 +404,17 @@ def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: "parameters": func.get("parameters", {}), } ) + elif tool.get("type") != "function": + converted.append( + { + "type": "function", + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + } + ) else: - # Already in correct format or other format - # Ensure it has type: function - if tool.get("type") != "function": - converted.append( - { - "type": "function", - "name": tool.get("name", ""), - "description": tool.get("description", ""), - "parameters": tool.get("parameters", {}), - } - ) - else: - # Keep as is - converted.append(tool) + converted.append(tool) return converted @@ -539,134 +422,117 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: - """Get text completion with prompt caching support. - - Uses VolcEngine Responses API with prefix cache. - Separates messages into static (cached) and dynamic parts. - """ - if messages: - kwargs_messages = messages - else: - kwargs_messages = [{"role": "user", "content": prompt}] - - # Parse messages into multiple static segments and dynamic messages - # Each segment ends with cache_control, dynamic is the rest + """Get text completion with prompt caching support.""" + kwargs_messages = messages or [{"role": "user", "content": prompt}] static_segments, dynamic_messages = self._parse_messages_with_breakpoints(kwargs_messages) - # If we have static segments, try prefix cache - response_format = None # Can be extended for structured output - - try: - # Use prefix cache with multiple segments + async def _call() -> Union[str, VLMResponse]: response = await self.responseapi_prefixcache_completion( static_segments=static_segments, dynamic_messages=dynamic_messages, - response_format=response_format, + response_format=None, tools=tools, tool_choice=tool_choice, + thinking=thinking, ) - elapsed = 0 # Timing handled in responseapi methods - self._update_token_usage_from_response(response, duration_seconds=elapsed) - return self._build_vlm_response(response, has_tools=bool(tools)) - - except Exception as e: - last_error = e - # Log token info from error response if available - error_response = getattr(e, "response", None) - if error_response and hasattr(error_response, "usage"): - u = error_response.usage - prompt_tokens = getattr(u, "input_tokens", 0) or 0 - completion_tokens = getattr(u, "output_tokens", 0) or 0 - logger.info( - f"[VolcEngineVLM] Error response - Input tokens: {prompt_tokens}, Output tokens: {completion_tokens}" - ) - logger.warning(f"[VolcEngineVLM] Request failed: {e}") - raise last_error + self._update_token_usage_from_response(response, duration_seconds=0.0) + result = self._build_vlm_response(response, has_tools=bool(tools)) + if tools: + return result + return self._clean_response(str(result)) + + return await retry_async( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="VolcEngine VLM async completion", + ) - def _detect_image_format(self, data: bytes) -> str: - """Detect image format from magic bytes. + def _build_vision_kwargs( + self, + prompt: str = "", + images: Optional[List[Union[str, Path, bytes]]] = None, + thinking: bool = False, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + messages: Optional[List[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + if messages: + kwargs_messages = messages + else: + content = [] + if images: + content.extend(self._prepare_image(img) for img in images) + if prompt: + content.append({"type": "text", "text": prompt}) + kwargs_messages = [{"role": "user", "content": content}] - Returns the MIME type, or raises ValueError for unsupported formats like SVG. + kwargs = { + "model": self.model or "doubao-seed-2-0-pro-260215", + "messages": kwargs_messages, + "temperature": self.temperature, + "thinking": {"type": "disabled" if not thinking else "enabled"}, + } + if self.max_tokens is not None: + kwargs["max_tokens"] = self.max_tokens + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = tool_choice or "auto" + return kwargs - Supported formats per VolcEngine docs: - https://www.volcengine.com/docs/82379/1362931 - - JPEG, PNG, GIF, WEBP, BMP, TIFF, ICO, DIB, ICNS, SGI, JPEG2000, HEIC, HEIF - """ + def _detect_image_format(self, data: bytes) -> str: + """Detect image format from magic bytes.""" if len(data) < 12: - # logger.warning(f"[VolcEngineVLM] Image data too small: {len(data)} bytes") return "image/png" - # PNG: 89 50 4E 47 0D 0A 1A 0A if data[:8] == b"\x89PNG\r\n\x1a\n": return "image/png" - # JPEG: FF D8 - elif data[:2] == b"\xff\xd8": + if data[:2] == b"\xff\xd8": return "image/jpeg" - # GIF: GIF87a or GIF89a - elif data[:6] in (b"GIF87a", b"GIF89a"): + if data[:6] in (b"GIF87a", b"GIF89a"): return "image/gif" - # WEBP: RIFF....WEBP - elif data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": + if data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": return "image/webp" - # BMP: BM - elif data[:2] == b"BM": + if data[:2] == b"BM": return "image/bmp" - # TIFF (little-endian): 49 49 2A 00 - # TIFF (big-endian): 4D 4D 00 2A - elif data[:4] == b"II*\x00" or data[:4] == b"MM\x00*": + if data[:4] == b"II*\x00" or data[:4] == b"MM\x00*": return "image/tiff" - # ICO: 00 00 01 00 - elif data[:4] == b"\x00\x00\x01\x00": + if data[:4] == b"\x00\x00\x01\x00": return "image/ico" - # ICNS: 69 63 6E 73 ("icns") - elif data[:4] == b"icns": + if data[:4] == b"icns": return "image/icns" - # SGI: 01 DA - elif data[:2] == b"\x01\xda": + if data[:2] == b"\x01\xda": return "image/sgi" - # JPEG2000: 00 00 00 0C 6A 50 20 20 (JP2 signature) - elif data[:8] == b"\x00\x00\x00\x0cjP " or data[:4] == b"\xff\x4f\xff\x51": + if data[:8] == b"\x00\x00\x00\x0cjP " or data[:4] == b"\xff\x4f\xff\x51": return "image/jp2" - # HEIC/HEIF: ftyp box with heic/heif brand - # 00 00 00 XX 66 74 79 70 68 65 69 63 (heic) - # 00 00 00 XX 66 74 79 70 68 65 69 66 (heif) - elif len(data) >= 12 and data[4:8] == b"ftyp": + if len(data) >= 12 and data[4:8] == b"ftyp": brand = data[8:12] if brand == b"heic": return "image/heic" - elif brand == b"heif": + if brand == b"heif" or brand[:3] == b"mif": return "image/heif" - elif brand[:3] == b"mif": - return "image/heif" - # SVG (not supported) - elif data[:4] == b" Dict[str, Any]: - """Prepare image data""" + """Prepare image data for vision completion.""" if isinstance(image, bytes): b64 = base64.b64encode(image).decode("utf-8") mime_type = self._detect_image_format(image) - # logger.info( - # f"[VolcEngineVLM] Preparing image from bytes, size={len(image)}, detected mime={mime_type}" - # ) return { "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}, } - elif isinstance(image, Path) or ( + if isinstance(image, Path) or ( isinstance(image, str) and not image.startswith(("http://", "https://")) ): path = Path(image) @@ -699,8 +565,7 @@ def _prepare_image(self, image: Union[str, Path, bytes]) -> Dict[str, Any]: "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}, } - else: - return {"type": "image_url", "image_url": {"url": image}} + return {"type": "image_url", "image_url": {"url": image}} def get_vision_completion( self, @@ -710,19 +575,25 @@ def get_vision_completion( tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: - """Get vision completion with prompt caching support. - - Uses VolcEngine Responses API with prefix cache. - Delegates to async implementation. - """ - return run_async( - self.get_vision_completion_async( - prompt=prompt, - images=images, - thinking=thinking, - tools=tools, - messages=messages, - ) + """Get vision completion through chat completions.""" + client = self.get_client() + kwargs = self._build_vision_kwargs(prompt, images, thinking, tools, None, messages) + + def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = client.chat.completions.create(**kwargs) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response(response, duration_seconds=elapsed) + result = self._build_vlm_response(response, has_tools=bool(tools)) + if tools: + return result + return self._clean_response(str(result)) + + return retry_sync( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="VolcEngine VLM vision completion", ) async def get_vision_completion_async( @@ -733,25 +604,23 @@ async def get_vision_completion_async( tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, VLMResponse]: - """Get vision completion with prompt caching support. - - Uses VolcEngine Responses API with prefix cache. - """ - if messages: - kwargs_messages = messages - else: - content = [] - if images: - for img in images: - content.append(self._prepare_image(img)) - if prompt: - content.append({"type": "text", "text": prompt}) - kwargs_messages = [{"role": "user", "content": content}] + """Get vision completion asynchronously through chat completions.""" + client = self.get_async_client() + kwargs = self._build_vision_kwargs(prompt, images, thinking, tools, None, messages) - # 复用 get_completion_async 的逻辑 - return await self.get_completion_async( - prompt=prompt, - thinking=thinking, - tools=tools, - messages=kwargs_messages, + async def _call() -> Union[str, VLMResponse]: + t0 = time.perf_counter() + response = await client.chat.completions.create(**kwargs) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response(response, duration_seconds=elapsed) + result = self._build_vlm_response(response, has_tools=bool(tools)) + if tools: + return result + return self._clean_response(str(result)) + + return await retry_async( + _call, + max_retries=self.max_retries, + logger=logger, + operation_name="VolcEngine VLM async vision completion", ) diff --git a/openviking/models/vlm/base.py b/openviking/models/vlm/base.py index c54285ad4..270444283 100644 --- a/openviking/models/vlm/base.py +++ b/openviking/models/vlm/base.py @@ -58,7 +58,7 @@ def __init__(self, config: Dict[str, Any]): self.api_key = config.get("api_key") self.api_base = config.get("api_base") self.temperature = config.get("temperature", 0.0) - self.max_retries = config.get("max_retries", 2) + self.max_retries = config.get("max_retries", 3) self.max_tokens = config.get("max_tokens") self.extra_headers = config.get("extra_headers") self.stream = config.get("stream", False) @@ -94,7 +94,6 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -104,7 +103,6 @@ async def get_completion_async( Args: prompt: Text prompt (used if messages not provided) thinking: Whether to enable thinking mode - max_retries: Maximum number of retries tools: Optional list of tool definitions in OpenAI function format tool_choice: Optional tool choice mode ("auto", "none", or specific tool name) messages: Optional list of message dicts (takes precedence over prompt) diff --git a/openviking/models/vlm/llm.py b/openviking/models/vlm/llm.py index e1cde6ccf..422bbe6bb 100644 --- a/openviking/models/vlm/llm.py +++ b/openviking/models/vlm/llm.py @@ -183,7 +183,12 @@ def complete_json( if schema and not messages: prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" - response = self._get_vlm().get_completion(prompt, thinking, tools, messages) + response = self._get_vlm().get_completion( + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, + ) return parse_json_from_response(response) async def complete_json_async( @@ -191,7 +196,6 @@ async def complete_json_async( prompt: str = "", schema: Optional[Dict[str, Any]] = None, thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Optional[Dict[str, Any]]: @@ -200,7 +204,10 @@ async def complete_json_async( prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" response = await self._get_vlm().get_completion_async( - prompt, thinking, max_retries, tools, messages + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, ) return parse_json_from_response(response) @@ -227,13 +234,10 @@ async def complete_model_async( prompt: str, model_class: Type[T], thinking: bool = False, - max_retries: int = 0, ) -> Optional[T]: """Async version of complete_model.""" schema = model_class.model_json_schema() - response = await self.complete_json_async( - prompt, schema=schema, thinking=thinking, max_retries=max_retries - ) + response = await self.complete_json_async(prompt, schema=schema, thinking=thinking) if response is None: return None @@ -252,7 +256,13 @@ def get_vision_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get vision completion.""" - return self._get_vlm().get_vision_completion(prompt, images, thinking, tools, messages) + return self._get_vlm().get_vision_completion( + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, + ) async def get_vision_completion_async( self, @@ -264,5 +274,9 @@ async def get_vision_completion_async( ) -> Union[str, Any]: """Async vision completion.""" return await self._get_vlm().get_vision_completion_async( - prompt, images, thinking, tools, messages + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) diff --git a/openviking/utils/circuit_breaker.py b/openviking/utils/circuit_breaker.py index af511aebf..cddd2a7fe 100644 --- a/openviking/utils/circuit_breaker.py +++ b/openviking/utils/circuit_breaker.py @@ -7,55 +7,11 @@ import threading import time +from openviking.utils.model_retry import classify_api_error from openviking_cli.utils.logger import get_logger logger = get_logger(__name__) -# --- Error classification --- - -_PERMANENT_PATTERNS = ("403", "401", "Forbidden", "Unauthorized", "AccountOverdue") -_TRANSIENT_PATTERNS = ( - "429", - "500", - "502", - "503", - "504", - "TooManyRequests", - "RateLimit", - "timeout", - "Timeout", - "ConnectionError", - "Connection refused", - "Connection reset", -) - - -def classify_api_error(error: Exception) -> str: - """Classify an API error as permanent, transient, or unknown. - - Checks both str(error) and str(error.__cause__) for known patterns. - - Returns: - "permanent" — 403/401, never retry. - "transient" — 429/5xx/timeout, safe to retry. - "unknown" — unrecognized, treated as transient by callers. - """ - texts = [str(error)] - if error.__cause__ is not None: - texts.append(str(error.__cause__)) - - for text in texts: - for pattern in _PERMANENT_PATTERNS: - if pattern in text: - return "permanent" - - for text in texts: - for pattern in _TRANSIENT_PATTERNS: - if pattern in text: - return "transient" - - return "unknown" - # --- Circuit breaker --- diff --git a/openviking/utils/model_retry.py b/openviking/utils/model_retry.py new file mode 100644 index 000000000..2e7cfa957 --- /dev/null +++ b/openviking/utils/model_retry.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import asyncio +import random +import time +from typing import Awaitable, Callable, TypeVar + +T = TypeVar("T") + +PERMANENT_API_ERROR_PATTERNS = ( + "400", + "401", + "403", + "Forbidden", + "Unauthorized", + "AccountOverdue", +) + +TRANSIENT_API_ERROR_PATTERNS = ( + "429", + "500", + "502", + "503", + "504", + "TooManyRequests", + "RateLimit", + "RequestBurstTooFast", + "timeout", + "Timeout", + "ConnectionError", + "Connection refused", + "Connection reset", +) + + +def classify_api_error(error: Exception) -> str: + """Classify an API error as permanent, transient, or unknown.""" + texts = [str(error)] + if error.__cause__ is not None: + texts.append(str(error.__cause__)) + + for text in texts: + for pattern in PERMANENT_API_ERROR_PATTERNS: + if pattern in text: + return "permanent" + + for text in texts: + for pattern in TRANSIENT_API_ERROR_PATTERNS: + if pattern in text: + return "transient" + + return "unknown" + + +def is_retryable_api_error(error: Exception) -> bool: + """Return True if the error should be retried.""" + return classify_api_error(error) == "transient" + + +def _compute_delay( + attempt: int, + *, + base_delay: float, + max_delay: float, + jitter: bool, +) -> float: + delay = min(base_delay * (2**attempt), max_delay) + if jitter: + delay += random.uniform(0.0, min(base_delay, delay)) + return delay + + +def retry_sync( + func: Callable[[], T], + *, + max_retries: int, + base_delay: float = 0.5, + max_delay: float = 8.0, + jitter: bool = True, + is_retryable: Callable[[Exception], bool] = is_retryable_api_error, + logger=None, + operation_name: str = "operation", +) -> T: + """Retry a sync function on known transient errors.""" + attempt = 0 + + while True: + try: + return func() + except Exception as e: + if max_retries <= 0 or attempt >= max_retries or not is_retryable(e): + raise + + delay = _compute_delay( + attempt, + base_delay=base_delay, + max_delay=max_delay, + jitter=jitter, + ) + if logger: + logger.warning( + "%s failed with retryable error (retry %d/%d): %s; retrying in %.2fs", + operation_name, + attempt + 1, + max_retries, + e, + delay, + ) + time.sleep(delay) + attempt += 1 + + +async def retry_async( + func: Callable[[], Awaitable[T]], + *, + max_retries: int, + base_delay: float = 0.5, + max_delay: float = 8.0, + jitter: bool = True, + is_retryable: Callable[[Exception], bool] = is_retryable_api_error, + logger=None, + operation_name: str = "operation", +) -> T: + """Retry an async function on known transient errors.""" + attempt = 0 + + while True: + try: + return await func() + except Exception as e: + if max_retries <= 0 or attempt >= max_retries or not is_retryable(e): + raise + + delay = _compute_delay( + attempt, + base_delay=base_delay, + max_delay=max_delay, + jitter=jitter, + ) + if logger: + logger.warning( + "%s failed with retryable error (retry %d/%d): %s; retrying in %.2fs", + operation_name, + attempt + 1, + max_retries, + e, + delay, + ) + await asyncio.sleep(delay) + attempt += 1 diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 8caad66a9..2392198c9 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -265,6 +265,10 @@ class EmbeddingConfig(BaseModel): max_concurrent: int = Field( default=10, description="Maximum number of concurrent embedding requests" ) + max_retries: int = Field( + default=3, + description="Maximum retry attempts for embedding provider calls (0 disables retry)", + ) text_source: str = Field( default="summary_first", description="Text source for file vectorization: summary_first|summary_only|content_only", @@ -340,6 +344,7 @@ def _create_embedder( "api_version": cfg.api_version, "dimension": cfg.dimension, "provider": "openai", + "config": {"max_retries": self.max_retries}, **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), @@ -354,6 +359,7 @@ def _create_embedder( "api_version": cfg.api_version, "dimension": cfg.dimension, "provider": "azure", + "config": {"max_retries": self.max_retries}, **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), @@ -367,6 +373,7 @@ def _create_embedder( "api_base": cfg.api_base, "dimension": cfg.dimension, "input_type": cfg.input, + "config": {"max_retries": self.max_retries}, }, ), ("volcengine", "sparse"): ( @@ -375,6 +382,7 @@ def _create_embedder( "model_name": cfg.model, "api_key": cfg.api_key, "api_base": cfg.api_base, + "config": {"max_retries": self.max_retries}, }, ), ("volcengine", "hybrid"): ( @@ -385,6 +393,7 @@ def _create_embedder( "api_base": cfg.api_base, "dimension": cfg.dimension, "input_type": cfg.input, + "config": {"max_retries": self.max_retries}, }, ), ("vikingdb", "dense"): ( @@ -398,6 +407,7 @@ def _create_embedder( "host": cfg.host, "dimension": cfg.dimension, "input_type": cfg.input, + "config": {"max_retries": self.max_retries}, }, ), ("vikingdb", "sparse"): ( @@ -409,6 +419,7 @@ def _create_embedder( "sk": cfg.sk, "region": cfg.region, "host": cfg.host, + "config": {"max_retries": self.max_retries}, }, ), ("vikingdb", "hybrid"): ( @@ -422,6 +433,7 @@ def _create_embedder( "host": cfg.host, "dimension": cfg.dimension, "input_type": cfg.input, + "config": {"max_retries": self.max_retries}, }, ), ("jina", "dense"): ( @@ -431,6 +443,7 @@ def _create_embedder( "api_key": cfg.api_key, "api_base": cfg.api_base, "dimension": cfg.dimension, + "config": {"max_retries": self.max_retries}, **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), }, @@ -441,6 +454,7 @@ def _create_embedder( "model_name": cfg.model, "api_key": cfg.api_key, "dimension": cfg.dimension, + "config": {"max_retries": self.max_retries}, **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), }, @@ -454,6 +468,7 @@ def _create_embedder( or "no-key", # Ollama ignores the key, but client requires non-empty "api_base": cfg.api_base or "http://localhost:11434/v1", "dimension": cfg.dimension, + "config": {"max_retries": self.max_retries}, }, ), ("voyage", "dense"): ( @@ -463,6 +478,7 @@ def _create_embedder( "api_key": cfg.api_key, "api_base": cfg.api_base, "dimension": cfg.dimension, + "config": {"max_retries": self.max_retries}, }, ), ("minimax", "dense"): ( @@ -472,6 +488,7 @@ def _create_embedder( "api_key": cfg.api_key, "api_base": cfg.api_base, "dimension": cfg.dimension, + "config": {"max_retries": self.max_retries}, **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), @@ -493,6 +510,7 @@ def _create_embedder( "api_key": cfg.api_key, "api_base": cfg.api_base, "dimension": cfg.dimension, + "config": {"max_retries": self.max_retries}, **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), diff --git a/openviking_cli/utils/config/vlm_config.py b/openviking_cli/utils/config/vlm_config.py index dd67e7c2c..5bff0a8e4 100644 --- a/openviking_cli/utils/config/vlm_config.py +++ b/openviking_cli/utils/config/vlm_config.py @@ -12,7 +12,7 @@ class VLMConfig(BaseModel): api_key: Optional[str] = Field(default=None, description="API key") api_base: Optional[str] = Field(default=None, description="API base URL") temperature: float = Field(default=0.0, description="Generation temperature") - max_retries: int = Field(default=2, description="Maximum retry attempts") + max_retries: int = Field(default=3, description="Maximum retry attempts") provider: Optional[str] = Field(default=None, description="Provider type") backend: Optional[str] = Field( @@ -181,19 +181,26 @@ def get_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get LLM completion.""" - return self.get_vlm_instance().get_completion(prompt, thinking, tools, messages) + return self.get_vlm_instance().get_completion( + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, + ) async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: - """Get LLM completion asynchronously, max_retries=0 means no retry.""" + """Get LLM completion asynchronously.""" return await self.get_vlm_instance().get_completion_async( - prompt, thinking, max_retries, tools, messages + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, ) def is_available(self) -> bool: @@ -210,7 +217,11 @@ def get_vision_completion( ) -> Union[str, Any]: """Get LLM completion with images.""" return self.get_vlm_instance().get_vision_completion( - prompt, images, thinking, tools, messages + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) async def get_vision_completion_async( @@ -223,5 +234,9 @@ async def get_vision_completion_async( ) -> Union[str, Any]: """Get LLM completion with images asynchronously.""" return await self.get_vlm_instance().get_vision_completion_async( - prompt, images, thinking, tools, messages + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 19f1d9488..87572a28d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -7,6 +7,7 @@ AsyncHTTPClient integration tests can run without a manually started server process. """ +import copy import math import os import shutil @@ -17,12 +18,15 @@ import httpx import pytest +import pytest_asyncio import uvicorn +from openviking import AsyncOpenViking from openviking.server.app import create_app from openviking.server.config import ServerConfig from openviking.service.core import OpenVikingService from openviking_cli.session.user_id import UserIdentifier +from openviking_cli.utils.config.open_viking_config import OpenVikingConfigSingleton PROJECT_ROOT = Path(__file__).parent.parent.parent TEST_TMP_DIR = PROJECT_ROOT / "test_data" / "tmp_integration" @@ -52,6 +56,20 @@ ] +def _local_engine_available() -> bool: + try: + from openviking.storage.vectordb.engine import ENGINE_VARIANT + except Exception: + return False + return ENGINE_VARIANT != "unavailable" + + +requires_engine = pytest.mark.skipif( + not _local_engine_available(), + reason="local vectordb engine unavailable", +) + + def l2_norm(vec: list[float]) -> float: """Compute L2 norm of a vector.""" return math.sqrt(sum(v * v for v in vec)) @@ -69,6 +87,85 @@ def gemini_embedder(): return GeminiDenseEmbedder("gemini-embedding-2-preview", api_key=GOOGLE_API_KEY, dimension=768) +def gemini_config_dict( + model: str, + dim: int, + query_param: str | None = None, + doc_param: str | None = None, +) -> dict: + """Build a minimal embedded-mode config for Gemini-backed integration tests.""" + return { + "storage": { + "workspace": str(TEST_TMP_DIR / "gemini"), + "agfs": {"backend": "local"}, + "vectordb": {"name": "test", "backend": "local", "project": "default"}, + }, + "embedding": { + "dense": { + "provider": "gemini", + "api_key": GOOGLE_API_KEY, + "model": model, + "dimension": dim, + **({"query_param": query_param} if query_param else {}), + **({"document_param": doc_param} if doc_param else {}), + } + }, + } + + +async def teardown_ov_client() -> None: + """Reset singleton client/config state used by embedded integration tests.""" + await AsyncOpenViking.reset() + OpenVikingConfigSingleton.reset_instance() + + +async def make_ov_client(config_dict: dict, data_path: str) -> AsyncOpenViking: + """Create an AsyncOpenViking client from an explicit config dict.""" + if not GOOGLE_API_KEY: + pytest.skip("GOOGLE_API_KEY not set") + try: + from openviking.models.embedder.gemini_embedders import GeminiDenseEmbedder # noqa: F401 + except (ImportError, ModuleNotFoundError, AttributeError): + pytest.skip("google-genai not installed") + + await teardown_ov_client() + + workspace = Path(data_path) + shutil.rmtree(workspace, ignore_errors=True) + workspace.mkdir(parents=True, exist_ok=True) + + effective_config = copy.deepcopy(config_dict) + storage = effective_config.setdefault("storage", {}) + storage["workspace"] = str(workspace) + storage.setdefault("agfs", {"backend": "local"}) + storage.setdefault("vectordb", {"name": "test", "backend": "local", "project": "default"}) + + OpenVikingConfigSingleton.initialize(config_dict=effective_config) + + client = AsyncOpenViking(path=str(workspace)) + await client.initialize() + return client + + +def sample_markdown(base_dir: Path, slug: str, content: str) -> Path: + """Write a markdown file for an integration test case.""" + path = base_dir / f"{slug}.md" + path.write_text(content, encoding="utf-8") + return path + + +@pytest_asyncio.fixture(scope="function") +async def gemini_ov_client(tmp_path): + """Provide a Gemini-backed OpenViking client and its model metadata.""" + model = "gemini-embedding-2-preview" + dim = 768 + client = await make_ov_client(gemini_config_dict(model, dim), str(tmp_path / "ov_gemini")) + try: + yield client, model, dim + finally: + await teardown_ov_client() + + @pytest.fixture(scope="session") def temp_dir(): """Create temp directory for the whole test session.""" diff --git a/tests/models/test_vlm_strip_think_tags.py b/tests/models/test_vlm_strip_think_tags.py index c7a996cfb..95e45b682 100644 --- a/tests/models/test_vlm_strip_think_tags.py +++ b/tests/models/test_vlm_strip_think_tags.py @@ -18,7 +18,7 @@ class _Stub(VLMBase): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/server/conftest.py b/tests/server/conftest.py index 875ddb502..577420d56 100644 --- a/tests/server/conftest.py +++ b/tests/server/conftest.py @@ -73,7 +73,7 @@ def get_dimension(self) -> int: def _install_fake_vlm(monkeypatch): """Use a fake VLM so server tests never hit external LLM APIs.""" - async def _fake_get_completion(self, prompt, thinking=False, max_retries=0): + async def _fake_get_completion(self, prompt, thinking=False): return "# Test Summary\n\nFake summary for testing.\n\n## Details\nTest content." async def _fake_get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/unit/test_extra_headers_embedding.py b/tests/unit/test_extra_headers_embedding.py index 91e2b7a90..d345c8c84 100644 --- a/tests/unit/test_extra_headers_embedding.py +++ b/tests/unit/test_extra_headers_embedding.py @@ -98,6 +98,23 @@ def test_factory_omits_extra_headers_when_none(self, mock_openai_class): call_kwargs = mock_openai_class.call_args[1] assert "default_headers" not in call_kwargs + @patch("openai.OpenAI") + def test_factory_injects_embedding_max_retries(self, mock_openai_class): + """Factory should inject top-level embedding.max_retries into embedder config.""" + mock_openai_class.return_value = _make_mock_client() + + cfg = EmbeddingModelConfig( + provider="openai", + model="text-embedding-3-small", + api_key="sk-test", + dimension=8, + ) + embedder = EmbeddingConfig(dense=cfg, max_retries=0)._create_embedder( + "openai", "dense", cfg + ) + + assert embedder.max_retries == 0 + class TestEmbeddingModelConfigExtraHeaders: """Test that EmbeddingModelConfig accepts and stores the extra_headers field.""" diff --git a/tests/unit/test_extra_headers_vlm.py b/tests/unit/test_extra_headers_vlm.py index 97f6bea00..c087c2370 100644 --- a/tests/unit/test_extra_headers_vlm.py +++ b/tests/unit/test_extra_headers_vlm.py @@ -210,7 +210,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): @@ -236,7 +236,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/unit/test_model_retry.py b/tests/unit/test_model_retry.py new file mode 100644 index 000000000..7f0360805 --- /dev/null +++ b/tests/unit/test_model_retry.py @@ -0,0 +1,38 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for shared model retry helpers.""" + +import pytest + +from openviking.utils.model_retry import classify_api_error, retry_async, retry_sync + + +def test_classify_api_error_recognizes_request_burst_too_fast(): + assert classify_api_error(RuntimeError("RequestBurstTooFast")) == "transient" + + +def test_retry_sync_retries_transient_error_until_success(): + attempts = {"count": 0} + + def _call(): + attempts["count"] += 1 + if attempts["count"] < 3: + raise RuntimeError("429 TooManyRequests") + return "ok" + + assert retry_sync(_call, max_retries=3) == "ok" + assert attempts["count"] == 3 + + +@pytest.mark.asyncio +async def test_retry_async_does_not_retry_unknown_error(): + attempts = {"count": 0} + + async def _call(): + attempts["count"] += 1 + raise RuntimeError("some unexpected validation failure") + + with pytest.raises(RuntimeError): + await retry_async(_call, max_retries=3) + + assert attempts["count"] == 1 diff --git a/tests/unit/test_stream_config_vlm.py b/tests/unit/test_stream_config_vlm.py index 64b2f81c2..ece01b97e 100644 --- a/tests/unit/test_stream_config_vlm.py +++ b/tests/unit/test_stream_config_vlm.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: AGPL-3.0 """Tests for VLM stream configuration support.""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -131,10 +131,7 @@ async def async_generator(): for chunk in chunks: yield chunk - async def mock_create(*args, **kwargs): - return async_generator() - - mock_client.chat.completions.create = mock_create + mock_client.chat.completions.create = AsyncMock(return_value=async_generator()) mock_async_openai_class.return_value = mock_client vlm = OpenAIVLM( @@ -220,10 +217,7 @@ async def async_generator(): for chunk in chunks: yield chunk - async def mock_create(*args, **kwargs): - return async_generator() - - mock_client.chat.completions.create = mock_create + mock_client.chat.completions.create = AsyncMock(return_value=async_generator()) mock_async_openai_class.return_value = mock_client vlm = OpenAIVLM( @@ -253,7 +247,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): @@ -277,7 +271,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): @@ -389,6 +383,23 @@ def test_vlm_config_stream_in_providers_takes_precedence(self): result = config._build_vlm_config_dict() assert result["stream"] is True + def test_vlm_config_max_retries_defaults_to_three(self): + """VLMConfig should default max_retries to 3.""" + from openviking_cli.utils.config.vlm_config import VLMConfig + + config = VLMConfig( + model="gpt-4o", + provider="openai", + providers={ + "openai": { + "api_key": "sk-test", + } + }, + ) + + assert config.max_retries == 3 + assert config._build_vlm_config_dict()["max_retries"] == 3 + class TestStreamingResponseProcessing: """Test streaming response processing logic.""" diff --git a/tests/unit/test_vlm_response_formats.py b/tests/unit/test_vlm_response_formats.py index 007b24b86..3d0eaa360 100644 --- a/tests/unit/test_vlm_response_formats.py +++ b/tests/unit/test_vlm_response_formats.py @@ -18,9 +18,7 @@ class ConcreteVLM(VLMBase): def get_completion(self, prompt: str, thinking: bool = False) -> str: pass - async def get_completion_async( - self, prompt: str, thinking: bool = False, max_retries: int = 0 - ) -> str: + async def get_completion_async(self, prompt: str, thinking: bool = False) -> str: pass def get_vision_completion(