diff --git a/.gitignore b/.gitignore index c7c3e5f65..6228da445 100644 --- a/.gitignore +++ b/.gitignore @@ -89,3 +89,7 @@ CLAUDE.md # Build cache .cache/ # Includes conda_unpack_wheels/ for Windows packaging workaround +/assets/ +/console/ +src/copaw/console/ +src/copaw/assets/ diff --git a/console/src/api/modules/provider.ts b/console/src/api/modules/provider.ts index ee289dba3..060a96d60 100644 --- a/console/src/api/modules/provider.ts +++ b/console/src/api/modules/provider.ts @@ -10,6 +10,10 @@ import type { TestProviderRequest, TestModelRequest, DiscoverModelsResponse, + SeriesResponse, + DiscoverExtendedResponse, + FilterModelsRequest, + FilterModelsResponse, } from "../types"; export const providerApi = { @@ -87,4 +91,21 @@ export const providerApi = { body: body ? JSON.stringify(body) : undefined, }, ), + + /* ---- OpenRouter specific endpoints ---- */ + + getOpenRouterSeries: () => + request("/models/openrouter/series"), + + discoverOpenRouterExtended: (body?: TestProviderRequest) => + request("/models/openrouter/discover-extended", { + method: "POST", + body: body ? 
JSON.stringify(body) : undefined, + }), + + filterOpenRouterModels: (body: FilterModelsRequest) => + request("/models/openrouter/models/filter", { + method: "POST", + body: JSON.stringify(body), + }), }; diff --git a/console/src/api/types/provider.ts b/console/src/api/types/provider.ts index 1b31a5840..309b61290 100644 --- a/console/src/api/types/provider.ts +++ b/console/src/api/types/provider.ts @@ -126,6 +126,7 @@ export interface TestProviderRequest { base_url?: string; chat_model?: string; generate_kwargs?: Record; + include_extended?: boolean; } export interface TestModelRequest { @@ -138,3 +139,38 @@ export interface DiscoverModelsResponse { models: ModelInfo[]; added_count: number; } + +/* ---- OpenRouter extended model types ---- */ + +export interface ExtendedModelInfo { + id: string; + name: string; + provider: string; + input_modalities: string[]; + output_modalities: string[]; + pricing: Record; +} + +export interface FilterModelsRequest { + providers?: string[]; + input_modalities?: string[]; + output_modalities?: string[]; + max_prompt_price?: number; +} + +export interface SeriesResponse { + series: string[]; +} + +export interface DiscoverExtendedResponse { + success: boolean; + models: ExtendedModelInfo[]; + providers: string[]; + total_count: number; +} + +export interface FilterModelsResponse { + success: boolean; + models: ExtendedModelInfo[]; + total_count: number; +} diff --git a/console/src/locales/en.json b/console/src/locales/en.json index e64708475..6e8a05473 100644 --- a/console/src/locales/en.json +++ b/console/src/locales/en.json @@ -485,6 +485,19 @@ "testConnectionError": "An error occurred while testing connection", "discoverModels": "Discover Models", "discoverModelsFailed": "Failed to discover models", + "filterModels": "Filter Models", + "filterByProvider": "Provider:", + "filterByModality": "Input Modality:", + "getModels": "Get Models", + "discovered": "Available Models:", + "add": "Add", + "filteredModelsLoaded": "Filtered 
models loaded: {count}", + "filterFailed": "Failed to filter models", + "clearAll": "Clear All", + "clearAllModels": "Clear All Models", + "clearAllModelsConfirm": "Are you sure you want to remove all {count} added models? This action cannot be undone.", + "allModelsCleared": "Removed {count} models", + "modelList": "Models", "modelTestFailed": "Model validation failed, please check if the model ID is correct", "modelTestFailedConfirm": "Model connection test failed: {{message}}. Do you still want to add this model?", "autoDiscoveredAndAdded": "Auto-discovered {{count}} models and added {{added}} new model(s)", diff --git a/console/src/locales/ja.json b/console/src/locales/ja.json index 6a09faeed..7fb156495 100644 --- a/console/src/locales/ja.json +++ b/console/src/locales/ja.json @@ -479,6 +479,19 @@ "testConnectionError": "接続テスト中にエラーが発生しました", "discoverModels": "モデルを検出", "discoverModelsFailed": "モデルの検出に失敗しました", + "filterModels": "モデルフィルター", + "filterByProvider": "プロバイダー:", + "filterByModality": "入力モダリティ:", + "getModels": "モデルを取得", + "discovered": "利用可能なモデル:", + "add": "追加", + "filteredModelsLoaded": "フィルタリングされたモデル: {count}", + "filterFailed": "モデルのフィルタリングに失敗しました", + "clearAll": "すべてクリア", + "clearAllModels": "すべてのモデルを消去", + "clearAllModelsConfirm": "追加された{count}個のモデルをすべて削除してもよろしいですか?この操作は取り消せません。", + "allModelsCleared": "{count}個のモデルを削除しました", + "modelList": "モデル", "modelTestFailed": "モデルの検証に失敗しました。モデルIDが正しいか確認してください", "modelTestFailedConfirm": "モデルの接続テストに失敗しました: {{message}}。このモデルを追加しますか?", "autoDiscoveredAndAdded": "{{count}} 件のモデルを自動検出し、{{added}} 件の新規モデルを追加しました", diff --git a/console/src/locales/ru.json b/console/src/locales/ru.json index 7fcd25559..3abfb766a 100644 --- a/console/src/locales/ru.json +++ b/console/src/locales/ru.json @@ -484,6 +484,19 @@ "testConnectionError": "Произошла ошибка при проверке подключения", "discoverModels": "Обнаружить модели", "discoverModelsFailed": "Не удалось обнаружить модели", + "filterModels": "Фильтр моделей", + 
"filterByProvider": "Провайдер:", + "filterByModality": "Тип ввода:", + "getModels": "Получить модели", + "discovered": "Доступные модели:", + "add": "Добавить", + "filteredModelsLoaded": "Загружено моделей: {count}", + "filterFailed": "Ошибка фильтрации моделей", + "clearAll": "Очистить все", + "clearAllModels": "Очистить все модели", + "clearAllModelsConfirm": "Вы уверены, что хотите удалить все добавленные модели ({count})? Это действие нельзя отменить.", + "allModelsCleared": "Удалено моделей: {count}", + "modelList": "Модели", "modelTestFailed": "Проверка модели не пройдена, проверьте корректность ID модели", "modelTestFailedConfirm": "Тест подключения модели не пройден: {{message}}. Всё равно добавить эту модель?", "autoDiscoveredAndAdded": "Автоматически обнаружено {{count}} моделей и добавлено {{added}} новых", diff --git a/console/src/locales/zh.json b/console/src/locales/zh.json index 24995172f..d7e5f84df 100644 --- a/console/src/locales/zh.json +++ b/console/src/locales/zh.json @@ -485,6 +485,19 @@ "testConnectionError": "测试连接时发生错误", "discoverModels": "自动获取模型", "discoverModelsFailed": "自动获取模型失败", + "filterModels": "模型过滤器", + "filterByProvider": "提供商:", + "filterByModality": "输入模式:", + "getModels": "获取模型", + "discovered": "可用模型:", + "add": "添加", + "filteredModelsLoaded": "已加载模型: {count}", + "filterFailed": "模型过滤失败", + "clearAll": "清空全部", + "clearAllModels": "清空所有模型", + "clearAllModelsConfirm": "确定要删除所有 {count} 个已添加的模型吗?此操作无法撤销。", + "allModelsCleared": "已删除 {count} 个模型", + "modelList": "模型列表", "modelTestFailed": "模型验证失败,请检查模型ID是否正确", "modelTestFailedConfirm": "模型连接测试失败:{{message}}。是否仍要添加此模型?", "autoDiscoveredAndAdded": "已自动获取 {{count}} 个模型,并新增 {{added}} 个到可选列表", diff --git a/console/src/pages/Settings/Models/components/modals/RemoteModelManageModal.tsx b/console/src/pages/Settings/Models/components/modals/RemoteModelManageModal.tsx index e4cdb1196..08f342c50 100644 --- a/console/src/pages/Settings/Models/components/modals/RemoteModelManageModal.tsx +++ 
b/console/src/pages/Settings/Models/components/modals/RemoteModelManageModal.tsx @@ -6,14 +6,26 @@ import { Modal, Tag, message, + Checkbox, + Collapse, } from "@agentscope-ai/design"; import { DeleteOutlined, PlusOutlined, ApiOutlined, SyncOutlined, + FilterOutlined, + ClearOutlined, } from "@ant-design/icons"; -import type { ProviderInfo } from "../../../../../api/types"; +import { + SparkTextLine, + SparkImageuploadLine, + SparkAudiouploadLine, + SparkVideouploadLine, + SparkFilePdfLine, + SparkTextImageLine, +} from "@agentscope-ai/icons"; +import type { ProviderInfo, SeriesResponse } from "../../../../../api/types"; import api from "../../../../../api"; import { useTranslation } from "react-i18next"; import styles from "../../index.module.less"; @@ -37,7 +49,83 @@ export function RemoteModelManageModal({ const [discovering, setDiscovering] = useState(false); const [testingModelId, setTestingModelId] = useState(null); const [form] = Form.useForm(); - const canDiscover = provider.support_model_discovery; + + // OpenRouter filter state + const isOpenRouter = provider.id === "openrouter"; + const [showFilters, setShowFilters] = useState(false); + const [availableSeries, setAvailableSeries] = useState([]); + const [discoveredModels, setDiscoveredModels] = useState([]); + const [selectedSeries, setSelectedSeries] = useState([]); + const [selectedInputModality, setSelectedInputModality] = useState< + string | null + >(null); + const [loadingFilters, setLoadingFilters] = useState(false); + + // Enable discover for providers that support it + // For local providers (ollama, llama.cpp, mlx) - check base_url + // For built-in providers with frozen URL - check api_key + const canDiscover = provider.is_local + ? 
!!provider.base_url + : !!provider.api_key; + + // Load available series for OpenRouter + useEffect(() => { + if (isOpenRouter && canDiscover) { + api + .getOpenRouterSeries() + .then((res: SeriesResponse) => { + setAvailableSeries(res.series || []); + }) + .catch(() => { + setAvailableSeries([]); + }); + } + }, [isOpenRouter, canDiscover]); + + // Fetch models with current filters + const handleFetchModels = async () => { + if (!isOpenRouter) return; + + setLoadingFilters(true); + try { + const filterBody: Record = {}; + if (selectedSeries.length > 0) { + filterBody.providers = selectedSeries; + } + if (selectedInputModality) { + filterBody.input_modalities = [selectedInputModality]; + } + + const result = await api.filterOpenRouterModels(filterBody); + if (result.success) { + setDiscoveredModels(result.models || []); + message.success( + t("models.filteredModelsLoaded", { count: result.total_count }), + ); + } else { + message.error(t("models.filterFailed")); + } + } catch { + message.error(t("models.filterFailed")); + } finally { + setLoadingFilters(false); + } + }; + + const handleAddFilteredModel = async (model: any) => { + setSaving(true); + try { + await api.addModel(provider.id, { id: model.id, name: model.name }); + message.success(t("models.modelAdded", { name: model.name })); + await onSaved(); + // Remove from discovered list + setDiscoveredModels((prev) => prev.filter((m) => m.id !== model.id)); + } catch { + message.error(t("models.modelAddFailed")); + } finally { + setSaving(false); + } + }; // For custom providers ALL models are deletable. // For built-in providers only extra_models are deletable. 
@@ -151,6 +239,41 @@ export function RemoteModelManageModal({ }); }; + const handleClearAllModels = () => { + const extraModels = provider.extra_models || []; + if (extraModels.length === 0) return; + + Modal.confirm({ + title: t("models.clearAllModels"), + content: t("models.clearAllModelsConfirm", { + count: extraModels.length, + }), + okText: t("common.delete"), + okButtonProps: { danger: true }, + cancelText: t("models.cancel"), + onOk: async () => { + setSaving(true); + try { + for (const model of extraModels) { + await api.removeModel(provider.id, model.id); + } + message.success( + t("models.allModelsCleared", { count: extraModels.length }), + ); + await onSaved(); + } catch (error) { + const errMsg = + error instanceof Error + ? error.message + : t("models.modelRemoveFailed"); + message.error(errMsg); + } finally { + setSaving(false); + } + }, + }); + }; + const handleClose = () => { setAdding(false); form.resetFields(); @@ -218,71 +341,379 @@ export function RemoteModelManageModal({ width={560} destroyOnHidden > - {/* Model list */} -
- {all_models.length === 0 ? ( -
{t("models.noModels")}
- ) : ( - all_models.map((m) => { - const isDeletable = extraModelIds.has(m.id); - return ( -
-
- {m.name} - {m.id} + {/* Model list - collapsible */} + + {t("models.modelList")} ({all_models.length}) + + ), + extra: + (provider.extra_models?.length ?? 0) > 0 ? ( + + ) : null, + children: ( + <> + {all_models.length === 0 ? ( +
+ {t("models.noModels")} +
+ ) : ( + all_models.map((m) => { + const isDeletable = extraModelIds.has(m.id); + // Check if it's an extended model (has input_modalities) + const hasExtendedInfo = (m as any).input_modalities; + return ( +
+
+ + {m.name} + + + {m.id} + {/* Show modalities and price for extended models */} + {hasExtendedInfo && ( + + {(m as any).input_modalities?.includes( + "text", + ) && } + {(m as any).input_modalities?.includes( + "image", + ) && ( + + )} + {(m as any).input_modalities?.includes( + "audio", + ) && ( + + )} + {(m as any).input_modalities?.includes( + "video", + ) && ( + + )} + {(m as any).input_modalities?.includes( + "file", + ) && ( + + )} + {(m as any).output_modalities?.includes( + "image", + ) && ( + + )} + {(m as any).pricing?.prompt && ( + + $ + {( + (m as any).pricing.prompt * 1_000_000 + ).toFixed(2)} + /1M in + {(m as any).pricing.completion && ( + + {" "} + · $ + {( + (m as any).pricing.completion * + 1_000_000 + ).toFixed(2)} + /1M out + + )} + + )} + + )} + +
+
+ {isDeletable ? ( + <> + + {t("models.userAdded")} + + + + + )} +
+
+ ); + }) + )} + + ), + }, + ]} + /> + + {/* OpenRouter Filter Section */} + {isOpenRouter && ( +
+ + + {showFilters && ( +
+ {/* Provider/Series Filter */} +
+
+ {t("models.filterByProvider") || "Provider:"}
-
- {isDeletable ? ( - <> - - {t("models.userAdded")} - - -
+ + {/* Input Modality Filter */} +
+
+ {t("models.filterByModality") || "Input Modality:"} +
+ + Vision (image) + + ), + value: "image", + }, + { + label: ( + <> + Audio + + ), + value: "audio", + }, + { + label: ( + <> + Video + + ), + value: "video", + }, + { + label: ( + <> + File + + ), + value: "file", + }, + { + label: ( + <> + Text only + + ), + value: "text", + }, + ]} + value={selectedInputModality ? [selectedInputModality] : []} + onChange={(vals) => + setSelectedInputModality( + vals.length > 0 ? (vals[0] as string) : null, + ) + } + style={{ display: "flex", flexWrap: "wrap", gap: 8 }} + /> +
+ + {/* Fetch Button */} + + + {/* Discovered Models List */} + {discoveredModels.length > 0 && ( +
+
+ {t("models.discovered") || "Available Models:"} +
+ {discoveredModels.map((model: any) => ( +
+
+
{model.name}
+
+ {model.provider} + {/* Input modalities icons */} + {model.input_modalities?.includes("text") && ( + + )} + {model.input_modalities?.includes("image") && ( + + )} + {model.input_modalities?.includes("audio") && ( + + )} + {model.input_modalities?.includes("video") && ( + + )} + {model.input_modalities?.includes("file") && ( + + )} + {model.output_modalities?.includes("image") && ( + + )} + {/* Price */} + {model.pricing?.prompt && ( + + $ + {( + parseFloat(model.pricing.prompt) * 1_000_000 + ).toFixed(2)} + /1M in + {model.pricing?.completion && ( + + {" "} + · $ + {( + parseFloat(model.pricing.completion) * + 1_000_000 + ).toFixed(2)} + /1M out + + )} + + )} +
+
- - )} +
+ ))}
-
- ); - }) - )} -
+ )} +
+ )} +
+ )} {/* Add model section */} {adding ? ( diff --git a/plans/openrouter_enhancement_plan.md b/plans/openrouter_enhancement_plan.md new file mode 100644 index 000000000..cfab26569 --- /dev/null +++ b/plans/openrouter_enhancement_plan.md @@ -0,0 +1,180 @@ +# OpenRouter Provider Enhancement Plan + +## Overview + +This plan outlines the improvements needed for the OpenRouter provider to support: +1. Loading models from OpenRouter API with proper metadata +2. Filtering models by provider/series (e.g., Google, Anthropic, OpenAI) +3. Extracting model name after the slash (e.g., "gpt-4o" from "openai/gpt-4o") +4. Supporting modality filtering (input/output) +5. Adding UI filters in frontend for selecting series/providers + +## Current Implementation Analysis + +### Files to Modify + +| File | Purpose | +|------|---------| +| [`src/copaw/providers/provider.py`](src/copaw/providers/provider.py) | Base `ModelInfo` class - needs extended metadata fields | +| [`src/copaw/providers/openrouter_provider.py`](src/copaw/providers/openrouter_provider.py) | Main OpenRouter provider - needs enhanced model fetching | +| [`src/copaw/providers/provider_manager.py`](src/copaw/providers/provider_manager.py) | Provider manager - needs new methods for filtering | +| [`src/copaw/app/routers/providers.py`](src/copaw/app/routers/providers.py) | API routes - needs new endpoints | + +--- + +## Implementation Steps + +### Step 1: Extend ModelInfo with Extended Metadata + +Create an `ExtendedModelInfo` class that includes: +- `id`: Model ID (e.g., "openai/gpt-4o") +- `name`: Human-readable name (e.g., "gpt-4o") +- `provider`: Provider/series (e.g., "openai", "google", "anthropic") +- `input_modalities`: List of supported input types (text, image, audio, video, file) +- `output_modalities`: List of supported output types (text, image, audio) +- `pricing`: Pricing information (prompt/completion costs) + +```python +# New class in provider.py or models.py +class ExtendedModelInfo(ModelInfo): + 
provider: str = "" # e.g., "openai", "google" + input_modalities: List[str] = Field(default_factory=list) + output_modalities: List[str] = Field(default_factory=list) + pricing: Dict[str, str] = Field(default_factory=dict) # {"prompt": "0.000005", "completion": "0.000015"} +``` + +### Step 2: Enhance OpenRouterProvider + +Modify [`openrouter_provider.py`](src/copaw/providers/openrouter_provider.py): + +1. **Update `_normalize_models_payload`** to: + - Extract provider from model ID (part before `/`) + - Extract model name from model ID (part after `/`) + - Store input/output modalities + - Store pricing information + +2. **Add new methods**: + - `fetch_extended_models()`: Fetch all models with full metadata + - `get_available_providers()`: Get list of unique providers from loaded models + +```python +@staticmethod +def _extract_provider(model_id: str) -> str: + """Extract provider from model ID (e.g., 'openai' from 'openai/gpt-4o')""" + return model_id.split("/")[0] if "/" in model_id else "" + +@staticmethod +def _extract_model_name(model_id: str) -> str: + """Extract model name from model ID (e.g., 'gpt-4o' from 'openai/gpt-4o')""" + return model_id.split("/")[-1] if "/" in model_id else model_id +``` + +### Step 3: Enhance ProviderManager + +Modify [`provider_manager.py`](src/copaw/providers/provider_manager.py): + +Add new methods: +- `fetch_provider_models_extended()`: Fetch models with extended metadata +- `filter_models_by_provider()`: Filter models by selected providers +- `filter_models_by_modalities()`: Filter models by input/output modalities +- `get_available_series()`: Get unique providers/series from OpenRouter models + +### Step 4: Add New API Endpoints + +Modify [`providers.py`](src/copaw/app/routers/providers.py): + +Add new endpoints: + +``` +GET /providers/openrouter/series - Get available provider series +GET /providers/openrouter/models/filter - Get filtered models +POST /providers/openrouter/discover-extended - Discover models with full 
metadata +``` + +Request/Response models: +```python +class FilterModelsRequest(BaseModel): + providers: List[str] = [] # e.g., ["openai", "google", "anthropic"] + input_modalities: List[str] = [] # e.g., ["image"] + output_modalities: List[str] = [] # e.g., ["text"] + min_price: Optional[float] = None + max_price: Optional[float] = None + +class SeriesResponse(BaseModel): + series: List[str] = Field(default_factory=list) # e.g., ["openai", "google", "anthropic"] +``` + +### Step 5: Frontend Integration (Summary) + +The frontend needs to be updated to: +1. Add checkboxes for selecting providers/series +2. Add dropdown/checkbox for modality filtering +3. Add "Fetch Models" button to load filtered models + +Since the frontend is built separately, the backend API changes above will enable this functionality. + +--- + +## API Flow Diagram + +```mermaid +sequenceDiagram + participant User + participant Frontend + participant API + participant OpenRouterProvider + participant OpenRouterAPI + + User->>Frontend: Selects providers (Google, Anthropic) + Frontend->>API: GET /providers/openrouter/series + API->>OpenRouterProvider: get_available_providers() + OpenRouterProvider-->>API: ["openai", "google", "anthropic", ...] 
+ API-->>Frontend: List of available series + + User->>Frontend: Clicks "Get Models" + Frontend->>API: POST /providers/openrouter/models/filter + API->>OpenRouterProvider: fetch_extended_models() + OpenRouterProvider->>OpenRouterAPI: GET /api/v1/models + OpenRouterAPI-->>OpenRouterProvider: Full model list + OpenRouterProvider-->>OpenRouterProvider: Filter by selected providers + OpenRouterProvider-->>API: Filtered models + API-->>Frontend: Filtered model list + + User->>Frontend: Selects model + Frontend->>API: POST /providers/{id}/models + API->>ProviderManager: add_model_to_provider() + ProviderManager-->>Frontend: Success +``` + +--- + +## Key Implementation Details + +### Model ID Parsing + +| Model ID | Provider | Model Name | +|----------|----------|------------| +| openai/gpt-4o | openai | gpt-4o | +| anthropic/claude-3.5-sonnet | anthropic | claude-3.5-sonnet | +| google/gemini-2.5-flash | google | gemini-2.5-flash | + +### Modality Types + +- **input_modalities**: text, image, audio, video, file +- **output_modalities**: text, image, audio + +### Price Filtering + +Price is stored as string in API response (e.g., "0.000005" per token). +For filtering, convert to float and compare against threshold (e.g., $1/1M tokens = 0.000001). + +--- + +## Testing Considerations + +1. Test model fetching with valid API key +2. Test filtering by single provider +3. Test filtering by multiple providers +4. Test modality filtering (e.g., vision models with image input) +5. Test price filtering +6. 
Test edge cases (empty results, invalid provider names) diff --git a/src/copaw/app/routers/providers.py b/src/copaw/app/routers/providers.py index 1fba69b5b..f0ef3877c 100644 --- a/src/copaw/app/routers/providers.py +++ b/src/copaw/app/routers/providers.py @@ -11,6 +11,7 @@ from ...providers.provider import ProviderInfo, ModelInfo from ...providers.provider_manager import ActiveModelsInfo, ProviderManager +from ...providers.openrouter_provider import OpenRouterProvider router = APIRouter(prefix="/models", tags=["models"]) @@ -174,6 +175,10 @@ class DiscoverModelsRequest(BaseModel): default=None, description="Optional chat model class to use for discovery", ) + include_extended: bool = Field( + default=False, + description="Include extended metadata for OpenRouter", + ) class DiscoverModelsResponse(BaseModel): @@ -248,10 +253,36 @@ async def discover_models( detail=f"Provider '{provider_id}' not found", ) try: - result = await manager.fetch_provider_models( - provider_id, - ) - success = True + # Check if we should fetch extended info (for OpenRouter) + include_extended = body.include_extended if body else False + + # For OpenRouter, use fetch_extended_models if requested + if provider_id == "openrouter" and include_extended: + provider = manager.get_provider(provider_id) + if provider and isinstance(provider, OpenRouterProvider): + result = await provider.fetch_extended_models() + # Add models to provider + for model in result: + await provider.add_model( + model, + target="extra_models", + ) + # pylint: disable=protected-access + manager._save_provider( + provider, + is_builtin=True, + ) + success = True + else: + result = await manager.fetch_provider_models( + provider_id, + ) + success = True + else: + result = await manager.fetch_provider_models( + provider_id, + ) + success = True except Exception: result = [] success = False @@ -378,3 +409,243 @@ async def set_active_model( # Invalid model, unreachable provider, or other configuration error raise 
HTTPException(status_code=400, detail=message) from exc return ActiveModelsInfo(active_llm=manager.get_active_model()) + + +# ============================================================================= +# OpenRouter-specific endpoints for model discovery with filtering +# ============================================================================= + + +class FilterModelsRequest(BaseModel): + """Request model for filtering OpenRouter models.""" + + providers: List[str] = Field( + default_factory=list, + description="Filter by provider/series (e.g., ['openai', 'google'])", + ) + input_modalities: List[str] = Field( + default_factory=list, + description="Required input modalities (e.g., ['image'])", + ) + output_modalities: List[str] = Field( + default_factory=list, + description="Required output modalities (e.g., ['text'])", + ) + max_prompt_price: Optional[float] = Field( + default=None, + description="Maximum prompt price per 1M tokens (e.g., 0.000001)", + ) + + +class SeriesResponse(BaseModel): + """Response model for available series/providers.""" + + series: List[str] = Field( + default_factory=list, + description="Provider series (e.g., ['openai', 'google'])", + ) + + +class DiscoverExtendedResponse(BaseModel): + """Response model for extended model discovery.""" + + success: bool = Field(..., description="Whether discovery succeeded") + models: List[dict] = Field( + default_factory=list, + description="Discovered models with extended metadata", + ) + providers: List[str] = Field( + default_factory=list, + description="Available provider series", + ) + total_count: int = Field( + default=0, + description="Total number of models discovered", + ) + + +class FilterModelsResponse(BaseModel): + """Response model for filtered models.""" + + success: bool = Field(..., description="Whether filtering succeeded") + models: List[dict] = Field( + default_factory=list, + description="Filtered models with extended metadata", + ) + total_count: int = Field( + default=0, + 
description="Total number of models matching filters", + ) + + +@router.get( + "/openrouter/series", + response_model=SeriesResponse, + summary="Get available OpenRouter provider series", +) +async def get_openrouter_series( + manager: ProviderManager = Depends(get_provider_manager), +) -> SeriesResponse: + """Get list of available provider/series from OpenRouter. + + This endpoint fetches all available models from OpenRouter and returns + the unique provider/series names (e.g., 'openai', 'google', 'anthropic'). + """ + provider = manager.get_provider("openrouter") + if provider is None: + raise HTTPException( + status_code=404, + detail="OpenRouter provider not found", + ) + + if not isinstance(provider, OpenRouterProvider): + raise HTTPException( + status_code=400, + detail="Provider is not an OpenRouter provider", + ) + + try: + series = await provider.get_available_providers() + return SeriesResponse(series=series) + except Exception as exc: + raise HTTPException( + status_code=500, + detail=f"Failed to fetch series: {str(exc)}", + ) from exc + + +@router.post( + "/openrouter/discover-extended", + response_model=DiscoverExtendedResponse, + summary="Discover OpenRouter models with extended metadata", +) +async def discover_openrouter_extended( + manager: ProviderManager = Depends(get_provider_manager), + body: Optional[DiscoverModelsRequest] = Body(default=None), +) -> DiscoverExtendedResponse: + """Discover available models from OpenRouter with full metadata. + + This endpoint fetches all available models with extended information + including provider, modalities, and pricing. 
+ """ + provider = manager.get_provider("openrouter") + if provider is None: + raise HTTPException( + status_code=404, + detail="OpenRouter provider not found", + ) + + if not isinstance(provider, OpenRouterProvider): + raise HTTPException( + status_code=400, + detail="Provider is not an OpenRouter provider", + ) + + # Update provider config if API key provided + if body and body.api_key: + manager.update_provider("openrouter", {"api_key": body.api_key}) + + try: + models = await provider.fetch_extended_models() + + # Get available providers + providers = await provider.get_available_providers() + + # Convert to dict for JSON response + models_dict = [ + { + "id": m.id, + "name": m.name, + "provider": m.provider, + "input_modalities": m.input_modalities, + "output_modalities": m.output_modalities, + "pricing": m.pricing, + } + for m in models + ] + + return DiscoverExtendedResponse( + success=True, + models=models_dict, + providers=providers, + total_count=len(models_dict), + ) + except Exception: + return DiscoverExtendedResponse( + success=False, + models=[], + providers=[], + total_count=0, + ) + + +@router.post( + "/openrouter/models/filter", + response_model=FilterModelsResponse, + summary="Filter OpenRouter models by criteria", +) +async def filter_openrouter_models( + manager: ProviderManager = Depends(get_provider_manager), + body: FilterModelsRequest = Body(...), +) -> FilterModelsResponse: + """Filter OpenRouter models by provider, modalities, and price. 
+ + This endpoint fetches models and applies the specified filters: + - providers: Filter by provider/series (e.g., ['openai', 'google']) + - input_modalities: Required input types (e.g., ['image']) + - output_modalities: Required output types (e.g., ['text']) + - max_prompt_price: Maximum price per 1M input tokens + """ + provider = manager.get_provider("openrouter") + if provider is None: + raise HTTPException( + status_code=404, + detail="OpenRouter provider not found", + ) + + if not isinstance(provider, OpenRouterProvider): + raise HTTPException( + status_code=400, + detail="Provider is not an OpenRouter provider", + ) + + try: + # Fetch all extended models + models = await provider.fetch_extended_models() + + # Apply filters + filtered_models = provider.filter_models( + models=models, + providers=body.providers if body.providers else None, + input_modalities=( + body.input_modalities if body.input_modalities else None + ), + output_modalities=( + body.output_modalities if body.output_modalities else None + ), + max_prompt_price=body.max_prompt_price, + ) + + # Convert to dict for JSON response + models_dict = [ + { + "id": m.id, + "name": m.name, + "provider": m.provider, + "input_modalities": m.input_modalities, + "output_modalities": m.output_modalities, + "pricing": m.pricing, + } + for m in filtered_models + ] + + return FilterModelsResponse( + success=True, + models=models_dict, + total_count=len(models_dict), + ) + except Exception as exc: + raise HTTPException( + status_code=500, + detail=f"Failed to filter models: {str(exc)}", + ) from exc diff --git a/src/copaw/providers/openrouter_provider.py b/src/copaw/providers/openrouter_provider.py new file mode 100644 index 000000000..4e0a65006 --- /dev/null +++ b/src/copaw/providers/openrouter_provider.py @@ -0,0 +1,297 @@ +# -*- coding: utf-8 -*- +"""An OpenRouter provider implementation.""" + +from __future__ import annotations + +from typing import Any, List, Optional + +from agentscope.model import 
class OpenRouterProvider(Provider):
    """OpenRouter provider with required HTTP-Referer and X-Title headers.

    OpenRouter asks clients to send ``HTTP-Referer`` and ``X-Title`` headers
    for app attribution; they are attached to every request through the
    client's ``default_headers``.
    """

    # Attribution headers OpenRouter uses to identify the calling app.
    _DEFAULT_HEADERS = {
        "HTTP-Referer": "https://copaw.ai",
        "X-Title": "CoPaw",
    }

    def _client(self, timeout: float = 30) -> AsyncOpenAI:
        """Build an OpenAI-compatible async client for the OpenRouter API.

        Args:
            timeout: Default request timeout in seconds for the client.
        """
        return AsyncOpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            timeout=timeout,
            default_headers=self._DEFAULT_HEADERS,
        )

    @staticmethod
    def _extract_provider(model_id: str) -> str:
        """Extract the provider prefix from a model ID.

        Examples:
            'openai/gpt-4o' -> 'openai'
            'anthropic/claude-3.5-sonnet' -> 'anthropic'
            'google/gemini-2.5-flash' -> 'google'
            'gpt-4o' -> '' (no provider prefix present)
        """
        if "/" in model_id:
            return model_id.split("/")[0]
        return ""

    @staticmethod
    def _extract_model_name(model_id: str) -> str:
        """Extract the model name from a model ID (part after the slash).

        Examples:
            'openai/gpt-4o' -> 'gpt-4o'
            'anthropic/claude-3.5-sonnet' -> 'claude-3.5-sonnet'
            'google/gemini-2.5-flash' -> 'gemini-2.5-flash'
            'gpt-4o' -> 'gpt-4o' (unchanged when there is no slash)
        """
        if "/" in model_id:
            return model_id.split("/")[-1]
        return model_id

    @staticmethod
    def _normalize_models_payload(
        payload: Any,
        include_extended: bool = False,
    ) -> List[ModelInfo] | List[ExtendedModelInfo]:
        """Normalize the models payload from the OpenRouter API.

        Args:
            payload: The raw API response payload. Expected to be an OpenAI
                ``AsyncPage``-like object exposing a ``.data`` list of model
                rows; each row exposes ``id`` and optionally ``name``,
                ``architecture`` and ``pricing`` attributes.
            include_extended: If True, return ``ExtendedModelInfo`` carrying
                provider, modalities, and pricing metadata.

        Returns:
            De-duplicated list of ``ModelInfo`` (or ``ExtendedModelInfo``),
            keeping the first occurrence of each model ID.
        """
        models: dict[str, ModelInfo | ExtendedModelInfo] = {}
        rows = getattr(payload, "data", []) or []
        for row in rows:
            # Rows are OpenAI Model objects, so read fields via getattr.
            model_id = str(getattr(row, "id", "") or "").strip()
            if not model_id:
                continue
            if model_id in models:
                # Deduplication: keep the first occurrence by model_id.
                continue

            provider = OpenRouterProvider._extract_provider(model_id)
            model_name = OpenRouterProvider._extract_model_name(model_id)

            # Prefer the API-supplied display name only when the ID carries
            # no provider prefix (otherwise the suffix is the model name).
            attr_name = str(getattr(row, "name", "") or "").strip()
            if attr_name and "/" not in model_id:
                model_name = attr_name

            if not include_extended:
                models[model_id] = ModelInfo(id=model_id, name=model_name)
                continue

            # Extended path: pull modality/pricing metadata off the row.
            architecture = getattr(row, "architecture", None) or {}
            pricing = getattr(row, "pricing", None) or {}

            input_modalities = list(
                architecture.get("input_modalities", []) or [],
            )
            output_modalities = list(
                architecture.get("output_modalities", []) or [],
            )

            # Normalize pricing values to strings; None becomes "0".
            pricing_dict: dict[str, str] = {}
            if isinstance(pricing, dict):
                pricing_dict = {
                    k: str(v) if v is not None else "0"
                    for k, v in pricing.items()
                }

            models[model_id] = ExtendedModelInfo(
                id=model_id,
                name=model_name,
                provider=provider,
                input_modalities=input_modalities,
                output_modalities=output_modalities,
                pricing=pricing_dict,
            )

        return list(models.values())

    async def check_connection(self, timeout: float = 30) -> bool:
        """Check if the OpenRouter API is reachable with the current key."""
        # Forward the timeout to the client as well, matching fetch_models.
        client = self._client(timeout=timeout)
        try:
            await client.models.list(timeout=timeout)
            return True
        except APIError:
            return False

    async def fetch_models(
        self,
        timeout: float = 30,
        include_extended: bool = False,
    ) -> List[ModelInfo]:
        """Fetch available models.

        Args:
            timeout: Request timeout in seconds.
            include_extended: If True, fetch extended model info with
                modalities and pricing.

        Returns:
            List of ModelInfo (or ExtendedModelInfo if include_extended=True).
            An empty list on API failure (best-effort, does not raise).
        """
        try:
            client = self._client(timeout=timeout)
            payload = await client.models.list(timeout=timeout)
            return self._normalize_models_payload(
                payload,
                include_extended=include_extended,
            )
        except APIError:
            return []

    async def fetch_extended_models(
        self,
        timeout: float = 30,
    ) -> List[ExtendedModelInfo]:
        """Fetch available models with extended metadata.

        Convenience wrapper around :meth:`fetch_models` that always requests
        provider, modality, and pricing information.

        Args:
            timeout: Request timeout in seconds.

        Returns:
            List of ExtendedModelInfo objects.
        """
        return await self.fetch_models(
            timeout=timeout,
            include_extended=True,
        )  # type: ignore

    def filter_models(
        self,
        models: List[ExtendedModelInfo],
        providers: Optional[List[str]] = None,
        input_modalities: Optional[List[str]] = None,
        output_modalities: Optional[List[str]] = None,
        max_prompt_price: Optional[float] = None,
    ) -> List[ExtendedModelInfo]:
        """Filter models by the given criteria.

        Args:
            models: List of models to filter.
            providers: Filter by provider/series (e.g., ["openai", "google"]),
                matched case-insensitively.
            input_modalities: Required input modalities (e.g., ["image"]);
                a model matches if it supports any of them.
            output_modalities: Required output modalities (e.g., ["text"]);
                a model matches if it supports any of them.
            max_prompt_price: Maximum prompt price per 1M tokens. Models with
                a missing or non-numeric prompt price are excluded.

        Returns:
            Filtered list of models.
        """

        def _prompt_price(model: ExtendedModelInfo) -> Optional[float]:
            """Parse the prompt price, returning None when absent/invalid."""
            raw = model.pricing.get("prompt")
            if not raw:
                return None
            try:
                return float(raw)
            except (TypeError, ValueError):
                # Non-numeric pricing strings must not abort the filter.
                return None

        result = models

        if providers:
            wanted = {p.lower() for p in providers}
            result = [m for m in result if m.provider.lower() in wanted]

        if input_modalities:
            result = [
                m
                for m in result
                if any(mod in m.input_modalities for mod in input_modalities)
            ]

        if output_modalities:
            result = [
                m
                for m in result
                if any(mod in m.output_modalities for mod in output_modalities)
            ]

        if max_prompt_price is not None:
            result = [
                m
                for m in result
                if (price := _prompt_price(m)) is not None
                and price <= max_prompt_price
            ]

        return result

    async def get_available_providers(
        self,
        timeout: float = 30,
    ) -> List[str]:
        """Get the list of available providers/series from OpenRouter.

        Args:
            timeout: Request timeout in seconds.

        Returns:
            Sorted list of unique provider names (e.g., ['google', 'openai']).
        """
        models = await self.fetch_extended_models(timeout=timeout)
        return sorted({model.provider for model in models if model.provider})

    async def check_model_connection(
        self,
        model_id: str,
        timeout: float = 30,
    ) -> bool:
        """Check if a specific model is reachable/usable.

        Sends a minimal streaming "ping" completion and consumes the first
        chunk to ensure the model actually responds.
        """
        try:
            client = self._client(timeout=timeout)
            res = await client.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": "ping"}],
                timeout=timeout,
                max_tokens=1,
                stream=True,
            )
            # Consume one chunk to confirm the model is actually responsive.
            async for _ in res:
                break
            return True
        except APIError:
            return False

    def get_chat_model_instance(self, model_id: str) -> ChatModelBase:
        """Build a streaming chat-model instance bound to this provider."""
        from .openai_chat_model_compat import OpenAIChatModelCompat

        return OpenAIChatModelCompat(
            model_name=model_id,
            stream=True,
            api_key=self.api_key,
            client_kwargs={
                "base_url": self.base_url,
                "default_headers": self._DEFAULT_HEADERS,
            },
        )
(prompt/completion)", + ) + + class ProviderInfo(BaseModel): id: str = Field(..., description="Provider identifier") name: str = Field(..., description="Human-readable provider name") diff --git a/src/copaw/providers/provider_manager.py b/src/copaw/providers/provider_manager.py index 0d137735f..da57183dc 100644 --- a/src/copaw/providers/provider_manager.py +++ b/src/copaw/providers/provider_manager.py @@ -22,6 +22,7 @@ from copaw.providers.openai_provider import OpenAIProvider from copaw.providers.anthropic_provider import AnthropicProvider from copaw.providers.ollama_provider import OllamaProvider +from copaw.providers.openrouter_provider import OpenRouterProvider from copaw.constant import SECRET_DIR from copaw.local_models import create_local_chat_model @@ -92,6 +93,9 @@ ANTHROPIC_MODELS: List[ModelInfo] = [] +# OpenRouter models are fetched dynamically - no hardcoded list +OPENROUTER_MODELS: List[ModelInfo] = [] + PROVIDER_MODELSCOPE = OpenAIProvider( id="modelscope", name="ModelScope", @@ -169,6 +173,15 @@ freeze_url=True, ) +PROVIDER_OPENROUTER = OpenRouterProvider( + id="openrouter", + name="OpenRouter", + base_url="https://openrouter.ai/api/v1", + api_key_prefix="sk-or-v1-", + models=OPENROUTER_MODELS, + freeze_url=True, +) + PROVIDER_OLLAMA = OllamaProvider( id="ollama", name="Ollama", @@ -244,6 +257,7 @@ def _init_builtins(self): self._add_builtin(PROVIDER_AZURE_OPENAI) self._add_builtin(PROVIDER_MINIMAX) self._add_builtin(PROVIDER_ANTHROPIC) + self._add_builtin(PROVIDER_OPENROUTER) self._add_builtin(PROVIDER_OLLAMA) self._add_builtin(PROVIDER_LMSTUDIO) self._add_builtin(PROVIDER_LLAMACPP) @@ -452,6 +466,8 @@ def _provider_from_data(self, data: Dict) -> Provider: provider_id = str(data.get("id", "")) chat_model = str(data.get("chat_model", "")) + if provider_id == "openrouter": + return OpenRouterProvider.model_validate(data) if provider_id == "anthropic" or chat_model == "AnthropicChatModel": return AnthropicProvider.model_validate(data) if provider_id 
== "ollama":