diff --git a/pyproject.toml b/pyproject.toml index 4eaec6aa1..6933ca12c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "paho-mqtt>=2.0.0", "wecom-aibot-sdk @ https://agentscope.oss-cn-zhangjiakou.aliyuncs.com/pre_whl/wecom_aibot_sdk-1.0.0-py3-none-any.whl", "matrix-nio>=0.24.0", + "google-genai>=1.67.0", "tzdata>=2024.1", ] diff --git a/src/copaw/agents/model_factory.py b/src/copaw/agents/model_factory.py index 658e5c8f7..4f5baca66 100644 --- a/src/copaw/agents/model_factory.py +++ b/src/copaw/agents/model_factory.py @@ -26,6 +26,13 @@ AnthropicChatFormatter = None AnthropicChatModel = None +try: + from agentscope.formatter import GeminiChatFormatter + from agentscope.model import GeminiChatModel +except ImportError: # pragma: no cover - compatibility fallback + GeminiChatFormatter = None + GeminiChatModel = None + from .utils.tool_message_utils import _sanitize_tool_messages from ..providers import ProviderManager from ..providers.retry_chat_model import RetryChatModel @@ -82,6 +89,8 @@ async def wrapper( } if AnthropicChatModel is not None and AnthropicChatFormatter is not None: _CHAT_MODEL_FORMATTER_MAP[AnthropicChatModel] = AnthropicChatFormatter +if GeminiChatModel is not None and GeminiChatFormatter is not None: + _CHAT_MODEL_FORMATTER_MAP[GeminiChatModel] = GeminiChatFormatter def _get_formatter_for_chat_model( diff --git a/src/copaw/app/routers/providers.py b/src/copaw/app/routers/providers.py index 1fba69b5b..2a912e98d 100644 --- a/src/copaw/app/routers/providers.py +++ b/src/copaw/app/routers/providers.py @@ -14,7 +14,11 @@ router = APIRouter(prefix="/models", tags=["models"]) -ChatModelName = Literal["OpenAIChatModel", "AnthropicChatModel"] +ChatModelName = Literal[ + "OpenAIChatModel", + "AnthropicChatModel", + "GeminiChatModel", +] def get_provider_manager(request: Request) -> ProviderManager: diff --git a/src/copaw/providers/gemini_provider.py b/src/copaw/providers/gemini_provider.py new file mode 100644 index 000000000..9788678ba --- /dev/null +++ b/src/copaw/providers/gemini_provider.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +"""A Google Gemini provider implementation using AgentScope's native +GeminiChatModel.""" + +from __future__ import annotations + +from typing import Any, List + +from agentscope.model import ChatModelBase +from google import genai +from google.genai import errors as genai_errors +from google.genai import types as genai_types + +from copaw.providers.provider import ModelInfo, Provider + + +class GeminiProvider(Provider): + """Provider implementation for Google Gemini API.""" + + def _client(self, timeout: float = 10) -> Any: + return genai.Client( + api_key=self.api_key, + http_options=genai_types.HttpOptions(timeout=int(timeout * 1000)), + ) + + @staticmethod + def _normalize_models_payload(payload: Any) -> List[ModelInfo]: + models: List[ModelInfo] = [] + for row in payload or []: + model_id = str(getattr(row, "name", "") or "").strip() + + if not model_id: + continue + + # Gemini API returns model names like "models/gemini-2.5-flash" + # Strip the "models/" prefix for cleaner IDs + if model_id.startswith("models/"): + model_id = model_id[len("models/") :] + + display_name = str( + getattr(row, "display_name", "") or model_id, + ).strip() + + if not display_name or display_name.startswith("models/"): + display_name = model_id + + models.append(ModelInfo(id=model_id, name=display_name)) + + deduped: List[ModelInfo] = [] + seen: set[str] = set() + for model in models: + if model.id in seen: + continue + seen.add(model.id) + deduped.append(model) + return deduped + + async def check_connection(self, timeout: float = 10) -> tuple[bool, str]: + """Check if Google Gemini provider is reachable.""" + try: + client = self._client(timeout=timeout) + # Use the async list models endpoint to verify connectivity + async for _ in await client.aio.models.list(): + break + return True, "" + except genai_errors.APIError: + return ( + False, + "Failed to connect to Google Gemini API. " + "Check your API key.", + ) + except Exception: + return ( + False, + "Unknown exception when connecting to Google Gemini API.", + ) + + async def fetch_models(self, timeout: float = 10) -> List[ModelInfo]: + """Fetch available models from Gemini API.""" + try: + client = self._client(timeout=timeout) + payload = [] + async for model in await client.aio.models.list(): + payload.append(model) + models = self._normalize_models_payload(payload) + return models + except genai_errors.APIError: + return [] + except Exception: + return [] + + async def check_model_connection( + self, + model_id: str, + timeout: float = 10, + ) -> tuple[bool, str]: + """Check if a specific Gemini model is reachable/usable.""" + target = (model_id or "").strip() + if not target: + return False, "Empty model ID" + + try: + client = self._client(timeout=timeout) + response = await client.aio.models.generate_content_stream( + model=target, + contents="ping", + ) + async for _ in response: + break + return True, "" + except genai_errors.APIError: + return ( + False, + f"Model '{model_id}' is not reachable or usable", + ) + except Exception: + return ( + False, + f"Unknown exception when connecting to model '{model_id}'", + ) + + def get_chat_model_instance(self, model_id: str) -> ChatModelBase: + from agentscope.model import GeminiChatModel + + return GeminiChatModel( + model_name=model_id, + stream=True, + api_key=self.api_key, + generate_kwargs=self.generate_kwargs, + ) diff --git a/src/copaw/providers/provider_manager.py b/src/copaw/providers/provider_manager.py index 62a329c1b..34fcf81e9 100644 --- a/src/copaw/providers/provider_manager.py +++ b/src/copaw/providers/provider_manager.py @@ -21,6 +21,7 @@ ) from copaw.providers.openai_provider import OpenAIProvider from copaw.providers.anthropic_provider import AnthropicProvider +from copaw.providers.gemini_provider import GeminiProvider from copaw.providers.ollama_provider import OllamaProvider from copaw.constant import SECRET_DIR from copaw.local_models import create_local_chat_model @@ -97,6 +98,19 @@ ANTHROPIC_MODELS: List[ModelInfo] = [] +GEMINI_MODELS: List[ModelInfo] = [ + ModelInfo(id="gemini-3.1-pro-preview", name="Gemini 3.1 Pro Preview"), + ModelInfo(id="gemini-3-flash-preview", name="Gemini 3 Flash Preview"), + ModelInfo( + id="gemini-3.1-flash-lite-preview", + name="Gemini 3.1 Flash Lite Preview", + ), + ModelInfo(id="gemini-2.5-pro", name="Gemini 2.5 Pro"), + ModelInfo(id="gemini-2.5-flash", name="Gemini 2.5 Flash"), + ModelInfo(id="gemini-2.5-flash-lite", name="Gemini 2.5 Flash Lite"), + ModelInfo(id="gemini-2.0-flash", name="Gemini 2.0 Flash"), +] + PROVIDER_MODELSCOPE = OpenAIProvider( id="modelscope", name="ModelScope", @@ -183,6 +197,17 @@ freeze_url=True, ) +PROVIDER_GEMINI = GeminiProvider( + id="gemini", + name="Google Gemini", + base_url="https://generativelanguage.googleapis.com", + api_key_prefix="", + models=GEMINI_MODELS, + chat_model="GeminiChatModel", + freeze_url=True, + support_model_discovery=True, +) + PROVIDER_OLLAMA = OllamaProvider( id="ollama", name="Ollama", @@ -259,6 +284,7 @@ def _init_builtins(self): self._add_builtin(PROVIDER_MINIMAX) self._add_builtin(PROVIDER_DEEPSEEK) self._add_builtin(PROVIDER_ANTHROPIC) + self._add_builtin(PROVIDER_GEMINI) self._add_builtin(PROVIDER_OLLAMA) self._add_builtin(PROVIDER_LMSTUDIO) self._add_builtin(PROVIDER_LLAMACPP) @@ -469,6 +495,8 @@ def _provider_from_data(self, data: Dict) -> Provider: if provider_id == "anthropic" or chat_model == "AnthropicChatModel": return AnthropicProvider.model_validate(data) + if provider_id == "gemini" or chat_model == "GeminiChatModel": + return GeminiProvider.model_validate(data) if provider_id == "ollama": return OllamaProvider.model_validate(data) if data.get("is_local", False): diff --git a/tests/unit/providers/test_gemini_provider.py b/tests/unit/providers/test_gemini_provider.py new file mode 100644 index 000000000..c784b89c6 --- /dev/null +++ b/tests/unit/providers/test_gemini_provider.py @@ -0,0 +1,340 @@ +# -*- coding: utf-8 -*- +# pylint: disable=redefined-outer-name,unused-argument,protected-access +from __future__ import annotations + +from types import SimpleNamespace + +from google.genai import errors as genai_errors + +from copaw.providers.gemini_provider import GeminiProvider + + +def _make_provider() -> GeminiProvider: + return GeminiProvider( + id="gemini", + name="Gemini", + base_url="https://generativelanguage.googleapis.com", + api_key="gem-test", + chat_model="GeminiChatModel", + ) + + +class _AsyncIter: + """Helper that turns a list into an async iterator.""" + + def __init__(self, items): + self._items = iter(items) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._items) + except StopIteration as exc: + raise StopAsyncIteration from exc + + +# -- check_connection -------------------------------------------------------- + + +async def test_check_connection_success(monkeypatch) -> None: + provider = _make_provider() + + class FakeModels: + async def list(self): + return _AsyncIter( + [SimpleNamespace(name="models/gemini-2.5-flash")], + ) + + fake_client = SimpleNamespace( + aio=SimpleNamespace(models=FakeModels()), + ) + monkeypatch.setattr(provider, "_client", lambda timeout=5: fake_client) + + ok, msg = await provider.check_connection(timeout=2.0) + + assert ok is True + assert msg == "" + + +async def test_check_connection_api_error_returns_false(monkeypatch) -> None: + provider = _make_provider() + + class FakeModels: + async def list(self): + raise genai_errors.APIError(403, {"error": "forbidden"}) + + fake_client = SimpleNamespace( + aio=SimpleNamespace(models=FakeModels()), + ) + monkeypatch.setattr(provider, "_client", lambda timeout=5: fake_client) + + ok, msg = await provider.check_connection(timeout=1.0) + + assert ok is False + assert "Failed to connect to Google Gemini API" in msg + + +async def test_check_connection_generic_exception_returns_false( + monkeypatch, +) -> None: + provider = _make_provider() + + class FakeModels: + async def list(self): + raise ConnectionError("DNS resolution failed") + + fake_client = SimpleNamespace( + aio=SimpleNamespace(models=FakeModels()), + ) + monkeypatch.setattr(provider, "_client", lambda timeout=5: fake_client) + + ok, msg = await provider.check_connection(timeout=1.0) + + assert ok is False + assert "Unknown exception" in msg + + +# -- fetch_models ------------------------------------------------------------ + + +async def test_fetch_models_normalizes_and_deduplicates(monkeypatch) -> None: + provider = _make_provider() + rows = [ + SimpleNamespace( + name="models/gemini-2.5-flash", + display_name="Gemini 2.5 Flash", + ), + SimpleNamespace( + name="models/gemini-2.5-flash", + display_name="duplicate", + ), + SimpleNamespace( + name="models/gemini-2.5-pro", + display_name="", + ), + SimpleNamespace(name=" ", display_name="invalid"), + ] + + class FakeModels: + async def list(self): + return _AsyncIter(rows) + + fake_client = SimpleNamespace( + aio=SimpleNamespace(models=FakeModels()), + ) + monkeypatch.setattr(provider, "_client", lambda timeout=5: fake_client) + + models = await provider.fetch_models(timeout=3.0) + + assert [m.id for m in models] == ["gemini-2.5-flash", "gemini-2.5-pro"] + assert [m.name for m in models] == ["Gemini 2.5 Flash", "gemini-2.5-pro"] + assert provider.models == [] + + +async def test_fetch_models_api_error_returns_empty(monkeypatch) -> None: + provider = _make_provider() + + class FakeModels: + async def list(self): + raise genai_errors.APIError(500, {"error": "internal"}) + + fake_client = SimpleNamespace( + aio=SimpleNamespace(models=FakeModels()), + ) + monkeypatch.setattr(provider, "_client", lambda timeout=5: fake_client) + + models = await provider.fetch_models(timeout=3.0) + + assert models == [] + + +async def test_fetch_models_generic_exception_returns_empty( + monkeypatch, +) -> None: + provider = _make_provider() + + class FakeModels: + async def list(self): + raise OSError("network unreachable") + + fake_client = SimpleNamespace( + aio=SimpleNamespace(models=FakeModels()), + ) + monkeypatch.setattr(provider, "_client", lambda timeout=5: fake_client) + + models = await provider.fetch_models(timeout=3.0) + + assert models == [] + + +# -- check_model_connection --------------------------------------------------- + + +async def test_check_model_connection_success(monkeypatch) -> None: + provider = _make_provider() + captured: list[dict] = [] + + class FakeModels: + async def generate_content_stream(self, **kwargs): + captured.append(kwargs) + return _AsyncIter([]) + + fake_client = SimpleNamespace( + aio=SimpleNamespace(models=FakeModels()), + ) + monkeypatch.setattr(provider, "_client", lambda timeout=5: fake_client) + + ok, msg = await provider.check_model_connection( + "gemini-2.5-flash", + timeout=4.0, + ) + + assert ok is True + assert msg == "" + assert len(captured) == 1 + assert captured[0]["model"] == "gemini-2.5-flash" + assert captured[0]["contents"] == "ping" + + +async def test_check_model_connection_empty_model_id_returns_false() -> None: + provider = _make_provider() + + ok, msg = await provider.check_model_connection(" ", timeout=4.0) + + assert ok is False + assert msg == "Empty model ID" + + +async def test_check_model_connection_api_error_returns_false( + monkeypatch, +) -> None: + provider = _make_provider() + + class FakeModels: + async def generate_content_stream(self, **kwargs): + raise genai_errors.APIError(404, {"error": "not found"}) + + fake_client = SimpleNamespace( + aio=SimpleNamespace(models=FakeModels()), + ) + monkeypatch.setattr(provider, "_client", lambda timeout=5: fake_client) + + ok, msg = await provider.check_model_connection( + "gemini-2.5-flash", + timeout=4.0, + ) + + assert ok is False + assert "not reachable or usable" in msg + + +async def test_check_model_connection_generic_exception_returns_false( + monkeypatch, +) -> None: + provider = _make_provider() + + class FakeModels: + async def generate_content_stream(self, **kwargs): + raise TimeoutError("connection timed out") + + fake_client = SimpleNamespace( + aio=SimpleNamespace(models=FakeModels()), + ) + monkeypatch.setattr(provider, "_client", lambda timeout=5: fake_client) + + ok, msg = await provider.check_model_connection( + "gemini-2.5-flash", + timeout=4.0, + ) + + assert ok is False + assert "Unknown exception" in msg + + +# -- _normalize_models_payload ------------------------------------------------ + + +def test_normalize_models_strips_prefix_and_deduplicates() -> None: + rows = [ + SimpleNamespace( + name="models/gemini-2.5-flash", + display_name="Gemini 2.5 Flash", + ), + SimpleNamespace( + name="models/gemini-2.5-flash", + display_name="dup", + ), + SimpleNamespace( + name="gemini-2.0-flash", + display_name="No Prefix", + ), + ] + + models = GeminiProvider._normalize_models_payload(rows) + + assert [m.id for m in models] == ["gemini-2.5-flash", "gemini-2.0-flash"] + assert [m.name for m in models] == [ + "Gemini 2.5 Flash", + "No Prefix", + ] + + +def test_normalize_models_empty_and_none() -> None: + assert not GeminiProvider._normalize_models_payload(None) + assert not GeminiProvider._normalize_models_payload([]) + + +def test_normalize_models_display_name_with_models_prefix() -> None: + rows = [ + SimpleNamespace( + name="models/gemini-2.5-pro", + display_name="models/gemini-2.5-pro", + ), + ] + + models = GeminiProvider._normalize_models_payload(rows) + + assert models[0].id == "gemini-2.5-pro" + assert models[0].name == "gemini-2.5-pro" + + +# -- update_config ------------------------------------------------------------ + + +async def test_update_config_updates_non_none_values() -> None: + provider = _make_provider() + + provider.update_config( + { + "name": "Gemini Custom", + "base_url": "https://new.example", + "api_key": "gem-new", + "chat_model": "GeminiChatModel", + "api_key_prefix": "gem-", + "generate_kwargs": {"temperature": 0.5}, + }, + ) + + info = await provider.get_info(mock_secret=False) + + assert provider.name == "Gemini Custom" + assert provider.api_key == "gem-new" + assert provider.generate_kwargs == {"temperature": 0.5} + assert info.name == "Gemini Custom" + assert info.api_key == "gem-new" + + +async def test_update_config_skips_none_values() -> None: + provider = _make_provider() + + provider.update_config( + { + "name": None, + "api_key": None, + }, + ) + + assert provider.name == "Gemini" + assert provider.api_key == "gem-test" diff --git a/website/public/docs/models.en.md b/website/public/docs/models.en.md index 702167f01..3341739c7 100644 --- a/website/public/docs/models.en.md +++ b/website/public/docs/models.en.md @@ -4,13 +4,13 @@ You need to configure a model before chatting with CoPaw. You can do this under ![Console models](https://img.alicdn.com/imgextra/i1/O1CN01zHAE1Z26w6jXl2xbr_!!6000000007725-2-tps-3802-1968.png) -CoPaw supports multiple LLM providers: **cloud providers** (require API Key), **local providers** (llama.cpp / MLX), **Ollama provider**, **LM Studio provider**, and you can add **custom providers**. This page explains how to configure each type. +CoPaw supports multiple LLM providers: **cloud providers** (require API Key, including Google Gemini), **local providers** (llama.cpp / MLX), **Ollama provider**, **LM Studio provider**, and you can add **custom providers**. This page explains how to configure each type. --- ## Configure cloud providers -Cloud providers (including ModelScope, DashScope, Aliyun Coding Plan, OpenAI, and Azure OpenAI) call remote models via API and require an **API Key**. +Cloud providers (including ModelScope, DashScope, Aliyun Coding Plan, OpenAI, Azure OpenAI, Google Gemini, and MiniMax) call remote models via API and require an **API Key**. **In the console:** @@ -35,6 +35,33 @@ Cloud providers (including ModelScope, DashScope, Aliyun Coding Plan, OpenAI, an > > ![cancel](https://img.alicdn.com/imgextra/i2/O1CN01A8j1IR1n8fHGnio0q_!!6000000005045-2-tps-3802-1968.png) +## Google Gemini provider + +The Google Gemini provider uses Google's native Gemini API (via the `google-genai` SDK) to access Gemini models. Pre-configured models include Gemini 3.1 Pro Preview, Gemini 3 Flash Preview, Gemini 3.1 Flash Lite Preview, Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 2.5 Flash Lite, and Gemini 2.0 Flash. Additional models can be auto-discovered from the API. + +**Prerequisites:** + +- Obtain a Gemini API key from [Google AI Studio](https://aistudio.google.com/apikey). + +**In the console:** + +1. Open the console and go to **Settings → Models**. +2. Find the **Google Gemini** provider card and click **Settings**. Enter your **API key** and click **Save**. +3. After saving, the card status becomes **Available**. The provider supports **model discovery** — click **Models** to auto-discover available Gemini models from the API. +4. In the **LLM Configuration** section at the top, select **Google Gemini** in the **Provider** dropdown and choose a model (e.g. `gemini-2.5-flash`), then click **Save**. + +**Using the CLI:** + +```bash +# Configure the API key +copaw models config-key gemini + +# Set Gemini as the active LLM +copaw models set-llm +``` + +> **Tip:** Gemini models with thinking capabilities (e.g. Gemini 3.1 Pro, Gemini 2.5 Pro, Gemini 2.5 Flash) support extended reasoning. CoPaw automatically handles thinking blocks and thought signatures from these models. + ## Local providers (llama.cpp / MLX) Local providers run models on your machine with **no API Key**; data stays on-device. diff --git a/website/public/docs/models.zh.md b/website/public/docs/models.zh.md index cfce57af6..bfac28411 100644 --- a/website/public/docs/models.zh.md +++ b/website/public/docs/models.zh.md @@ -4,13 +4,13 @@ ![控制台模型](https://img.alicdn.com/imgextra/i4/O1CN01XnOPPQ1c99vox3I88_!!6000000003557-2-tps-3786-1980.png) -CoPaw 支持多种 LLM 提供商:**云提供商**(需 API Key)、**本地提供商**(llama.cpp / MLX)、**Ollama 提供商**、**LM Studio 提供商**,且支持添加自定义 **提供商**。本文介绍这几类提供商的配置方式。 +CoPaw 支持多种 LLM 提供商:**云提供商**(需 API Key,包括 Google Gemini)、**本地提供商**(llama.cpp / MLX)、**Ollama 提供商**、**LM Studio 提供商**,且支持添加自定义 **提供商**。本文介绍这几类提供商的配置方式。 --- ## 配置云提供商 -云提供商(包括 ModelScope、DashScope、Aliyun Coding Plan、OpenAI 和 Azure OpenAI)通过 API 调用远程模型,需要配置 **API Key**。 +云提供商(包括 ModelScope、DashScope、Aliyun Coding Plan、OpenAI、Azure OpenAI、Google Gemini 和 MiniMax)通过 API 调用远程模型,需要配置 **API Key**。 **在控制台中配置:** @@ -35,6 +35,33 @@ CoPaw 支持多种 LLM 提供商:**云提供商**(需 API Key)、**本地 > > ![cancel](https://img.alicdn.com/imgextra/i2/O1CN01LM3rBG1MejNjEeXs1_!!6000000001460-2-tps-3412-1952.png) +## Google Gemini 提供商 + +Google Gemini 提供商通过 Google 原生 Gemini API(使用 `google-genai` SDK)访问 Gemini 模型。内置模型包括 Gemini 3.1 Pro Preview、Gemini 3 Flash Preview、Gemini 3.1 Flash Lite Preview、Gemini 2.5 Pro、Gemini 2.5 Flash、Gemini 2.5 Flash Lite 和 Gemini 2.0 Flash。还可通过 API 自动发现更多模型。 + +**前置条件:** + +- 从 [Google AI Studio](https://aistudio.google.com/apikey) 获取 Gemini API Key。 + +**在控制台中配置:** + +1. 打开控制台,进入 **设置 → 模型**。 +2. 找到 **Google Gemini** 提供商卡片,点击 **设置**。输入你的 **API Key**,点击 **保存**。 +3. 保存后卡片状态变为 **可用**。该提供商支持 **模型发现** — 点击 **模型** 可自动从 API 发现可用的 Gemini 模型。 +4. 在上方的 **LLM 配置** 中,**提供商** 下拉菜单选择 **Google Gemini**,**模型** 下拉菜单选择目标模型(如 `gemini-2.5-flash`),点击 **保存**。 + +**使用 CLI 配置:** + +```bash +# 配置 API Key +copaw models config-key gemini + +# 将 Gemini 设为活跃 LLM +copaw models set-llm +``` + +> **提示:** 具有思考能力的 Gemini 模型(如 Gemini 3.1 Pro、Gemini 2.5 Pro、Gemini 2.5 Flash)支持扩展推理。CoPaw 会自动处理这些模型返回的思考块和思考签名。 + ## 本地提供商(llama.cpp / MLX) 本地提供商在本地运行模型,**无需 API Key**,数据不出本机。