diff --git a/openviking/models/vlm/backends/litellm_vlm.py b/openviking/models/vlm/backends/litellm_vlm.py
index 7a951d732..3b946c244 100644
--- a/openviking/models/vlm/backends/litellm_vlm.py
+++ b/openviking/models/vlm/backends/litellm_vlm.py
@@ -199,6 +199,8 @@ def _build_kwargs(self, model: str, messages: list) -> dict[str, Any]:
             "messages": messages,
             "temperature": self.temperature,
         }
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens
         if self.api_key:
             kwargs["api_key"] = self.api_key
diff --git a/openviking/models/vlm/backends/openai_vlm.py b/openviking/models/vlm/backends/openai_vlm.py
index de59b8dc7..643a57c4b 100644
--- a/openviking/models/vlm/backends/openai_vlm.py
+++ b/openviking/models/vlm/backends/openai_vlm.py
@@ -62,6 +62,8 @@ def get_completion(self, prompt: str, thinking: bool = False) -> str:
             "messages": [{"role": "user", "content": prompt}],
             "temperature": self.temperature,
         }
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens

         response = client.chat.completions.create(**kwargs)
         self._update_token_usage_from_response(response)
@@ -77,6 +79,8 @@ async def get_completion_async(
             "messages": [{"role": "user", "content": prompt}],
             "temperature": self.temperature,
         }
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens

         last_error = None
         for attempt in range(max_retries + 1):
@@ -165,6 +169,8 @@ def get_vision_completion(
             "messages": [{"role": "user", "content": content}],
             "temperature": self.temperature,
         }
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens

         response = client.chat.completions.create(**kwargs)
         self._update_token_usage_from_response(response)
@@ -189,6 +195,8 @@ async def get_vision_completion_async(
             "messages": [{"role": "user", "content": content}],
             "temperature": self.temperature,
         }
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens

         response = await client.chat.completions.create(**kwargs)
         self._update_token_usage_from_response(response)
diff --git a/openviking/models/vlm/backends/volcengine_vlm.py b/openviking/models/vlm/backends/volcengine_vlm.py
index 985025cac..604d51bb7 100644
--- a/openviking/models/vlm/backends/volcengine_vlm.py
+++ b/openviking/models/vlm/backends/volcengine_vlm.py
@@ -68,6 +68,8 @@ def get_completion(self, prompt: str, thinking: bool = False) -> str:
             "temperature": self.temperature,
             "thinking": {"type": "disabled" if not thinking else "enabled"},
         }
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens

         response = client.chat.completions.create(**kwargs)
         self._update_token_usage_from_response(response)
@@ -84,6 +86,8 @@ async def get_completion_async(
             "temperature": self.temperature,
             "thinking": {"type": "disabled" if not thinking else "enabled"},
         }
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens

         last_error = None
         for attempt in range(max_retries + 1):
@@ -235,6 +239,8 @@ def get_vision_completion(
             "temperature": self.temperature,
             "thinking": {"type": "disabled" if not thinking else "enabled"},
         }
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens

         response = client.chat.completions.create(**kwargs)
         self._update_token_usage_from_response(response)
@@ -260,6 +266,8 @@ async def get_vision_completion_async(
             "temperature": self.temperature,
             "thinking": {"type": "disabled" if not thinking else "enabled"},
         }
+        if self.max_tokens is not None:
+            kwargs["max_tokens"] = self.max_tokens

         response = await client.chat.completions.create(**kwargs)
         self._update_token_usage_from_response(response)
diff --git a/openviking/models/vlm/base.py b/openviking/models/vlm/base.py
index 66b859ba1..0bb6c4188 100644
--- a/openviking/models/vlm/base.py
+++ b/openviking/models/vlm/base.py
@@ -22,6 +22,7 @@ def __init__(self, config: Dict[str, Any]):
         self.api_base = config.get("api_base")
         self.temperature = config.get("temperature", 0.0)
         self.max_retries = config.get("max_retries", 2)
+        self.max_tokens = config.get("max_tokens")

         # Token usage tracking
         self._token_tracker = TokenUsageTracker()
diff --git a/openviking_cli/utils/config/vlm_config.py b/openviking_cli/utils/config/vlm_config.py
index eea7682bf..6c42d6436 100644
--- a/openviking_cli/utils/config/vlm_config.py
+++ b/openviking_cli/utils/config/vlm_config.py
@@ -26,6 +26,10 @@ class VLMConfig(BaseModel):
     default_provider: Optional[str] = Field(default=None, description="Default provider name")

+    max_tokens: Optional[int] = Field(
+        default=None, description="Maximum tokens for VLM completion output (None = provider default)"
+    )
+
     thinking: bool = Field(default=False, description="Enable thinking mode for VolcEngine models")

     max_concurrent: int = Field(
@@ -134,6 +138,7 @@ def _build_vlm_config_dict(self) -> Dict[str, Any]:
             "max_retries": self.max_retries,
             "provider": name,
             "thinking": self.thinking,
+            "max_tokens": self.max_tokens,
         }

         if config:
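
Note: a minimal standalone sketch of the behavior this patch adds across all four backends (build_request_kwargs below is illustrative, not part of the codebase). Because VLMConfig reads max_tokens with no default, it stays None unless the user sets it, and the request key is only forwarded to the provider when present; when omitted, the provider's own default output limit applies unchanged.

from typing import Any, Dict, Optional

def build_request_kwargs(model: str, prompt: str, temperature: float,
                         max_tokens: Optional[int] = None) -> Dict[str, Any]:
    # Mirrors the patched backends: the max_tokens key is omitted
    # entirely when unset, leaving the provider default in effect.
    kwargs: Dict[str, Any] = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
    }
    if max_tokens is not None:
        kwargs["max_tokens"] = max_tokens
    return kwargs

# Unset -> key absent; set -> value forwarded (model name is a placeholder).
assert "max_tokens" not in build_request_kwargs("some-model", "hi", 0.0)
assert build_request_kwargs("some-model", "hi", 0.0, max_tokens=512)["max_tokens"] == 512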