diff --git a/.github/scripts/model-discovery.py b/.github/scripts/model-discovery.py new file mode 100755 index 000000000..d18a32a4b --- /dev/null +++ b/.github/scripts/model-discovery.py @@ -0,0 +1,638 @@ +#!/usr/bin/env python3 +"""Automated Vertex AI model discovery. + +Discovers models from Vertex AI publishers via the Model Garden list API, +filters by configured prefix patterns, resolves versions, probes each to +confirm availability, and updates the model manifest. Never removes models +— only adds new ones or updates the ``available`` / ``vertexId`` fields. + +New models matching a prefix are auto-discovered without code changes. +For example, if Anthropic releases ``claude-opus-4-7``, it will be picked +up automatically because it matches the ``claude-`` prefix under the +``anthropic`` publisher. + +Required env vars: + GCP_REGION - GCP region (e.g. us-east5) + GCP_PROJECT - GCP project ID + +Optional env vars: + GOOGLE_APPLICATION_CREDENTIALS - Path to SA key (uses ADC otherwise) + MANIFEST_PATH - Override default manifest location +""" + +import json +import os +import re +import subprocess +import sys +import time +import urllib.error +import urllib.parse +import urllib.request +from collections import defaultdict +from typing import NotRequired, TypedDict +from pathlib import Path + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +DEFAULT_MANIFEST = ( + Path(__file__).resolve().parent.parent.parent + / "components" + / "manifests" + / "base" + / "models.json" +) + +# Keep only the N most recent versions per model family. +# e.g. claude-opus-4-6 and claude-opus-4-5 are kept, claude-opus-4-1 is dropped. +MAX_VERSIONS_PER_FAMILY = 2 + +# Model Garden list API pagination settings. +LIST_PAGE_SIZE = 100 +MAX_LIST_PAGES = 20 + + +# Publisher discovery configuration. 
# prefixes: only models whose ID starts with one of these are included.
# exclude: model IDs matching these regex patterns are skipped (embeddings,
# image models, legacy versions, etc.).
class PublisherConfig(TypedDict):
    """Discovery settings for one Model Garden publisher."""

    # Publisher path segment used in the list/probe API URLs (e.g. "anthropic").
    publisher: str
    # Provider value written into new manifest entries (may match publisher).
    provider: str
    # Only model IDs starting with one of these prefixes are considered.
    prefixes: list[str]
    # Regex patterns (searched, not anchored) that exclude a model ID.
    exclude: list[str]
    version_cutoff: NotRequired[
        tuple[int, ...]
    ]  # models with version <= this are excluded


PUBLISHERS: list[PublisherConfig] = [
    {
        "publisher": "anthropic",
        "provider": "anthropic",
        "prefixes": ["claude-"],
        "exclude": [
            r"^claude-[a-z]+-\d+$",  # base aliases without minor version (claude-opus-4)
        ],
    },
    {
        "publisher": "google",
        "provider": "google",
        "prefixes": ["gemini-"],
        "exclude": [
            r"-\d{3}$",  # pinned versions like gemini-2.5-flash-001
            r"exp",  # experimental models
            r"embedding",
            r"imagen",
            r"veo",
            r"chirp",
            r"codey",
            r"medlm",
        ],
        "version_cutoff": (2, 0),  # exclude gemini 2.0 and older
    },
]

# Fallback seed list used when the list API is unavailable.
# Once the list API works, this is only used for models it might miss.
+SEED_MODELS: list[tuple[str, str, str]] = [ + ("claude-sonnet-4-6", "anthropic", "anthropic"), + ("claude-sonnet-4-5", "anthropic", "anthropic"), + ("claude-opus-4-6", "anthropic", "anthropic"), + ("claude-opus-4-5", "anthropic", "anthropic"), + ("claude-haiku-4-5", "anthropic", "anthropic"), + ("gemini-2.5-flash", "google", "google"), + ("gemini-2.5-pro", "google", "google"), +] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def get_access_token() -> str: + """Get a GCP access token via gcloud.""" + try: + result = subprocess.run( + ["gcloud", "auth", "print-access-token"], + capture_output=True, + text=True, + check=True, + timeout=30, + ) + except subprocess.TimeoutExpired as err: + raise RuntimeError("Timed out getting GCP access token via gcloud") from err + except subprocess.CalledProcessError as err: + raise RuntimeError("Failed to get GCP access token via gcloud") from err + return result.stdout.strip() + + +def list_publisher_models(publisher: str, token: str) -> list[tuple[str, str | None]]: + """List models from the Model Garden for a publisher. + + Uses the v1beta1 API: GET /publishers/{publisher}/models + Returns a list of (model_id, version_id) tuples. version_id is the + versionId from the API response (e.g. "20250929") or None if absent. + Returns an empty list on failure (caller falls back to seed list). 
+ """ + base_url = "https://aiplatform.googleapis.com/v1beta1" + all_models: list[tuple[str, str | None]] = [] + page_token = "" + + for _ in range(MAX_LIST_PAGES): + params = {"pageSize": str(LIST_PAGE_SIZE)} + if page_token: + params["pageToken"] = page_token + + url = ( + f"{base_url}/publishers/{urllib.parse.quote(publisher, safe='')}" + f"/models?{urllib.parse.urlencode(params)}" + ) + + data = None + last_err: Exception | None = None + for attempt in range(3): + req = urllib.request.Request( + url, + headers={"Authorization": f"Bearer {token}"}, + method="GET", + ) + try: + with urllib.request.urlopen(req, timeout=30) as resp: + data = json.loads(resp.read().decode()) + break + except urllib.error.HTTPError as e: + # Auth failures are fatal — don't fall back to seeds with bad credentials + if e.code in (401, 403): + raise RuntimeError( + f"list models for {publisher} failed (HTTP {e.code}): " + f"check GCP credentials and IAM permissions" + ) from e + # Not found — retrying won't help + if e.code == 404: + print( + f" WARNING: list models for {publisher} returned 404", + file=sys.stderr, + ) + return [] + last_err = e + except Exception as e: + last_err = e + + if attempt < 2: + time.sleep(2**attempt) + + if data is None: + print( + f" WARNING: list models for {publisher} failed after 3 attempts ({last_err})", + file=sys.stderr, + ) + return [] + + for model in data.get("publisherModels", []): + # name is like "publishers/google/models/gemini-2.5-flash" + name = model.get("name", "") + model_id = name.rsplit("/", 1)[-1] if "/" in name else name + version_id = model.get("versionId") + if model_id: + all_models.append((model_id, version_id)) + + page_token = data.get("nextPageToken", "") + if not page_token: + break + + return all_models + + +def discover_models( + token: str, manifest: dict[str, object] +) -> list[tuple[str, str, str, str | None]]: + """Discover models from all configured publishers. 
+ + Queries the Model Garden list API for each publisher, filters by + prefix patterns, and excludes unwanted model types. Falls back to + the SEED_MODELS list for any publisher where the API fails. + + Provider default models (defaultModel + providerDefaults values) are + exempt from version limiting and always kept. + + Returns a deduplicated list of (model_id, publisher, provider, version_id) + tuples. version_id comes from the list API response and may be None for + seed models or when the API doesn't provide it. + """ + seen: set[str] = set() + result: list[tuple[str, str, str, str | None]] = [] + + # Collect per-publisher: (model_id, reason) for the summary table + publisher_log: list[tuple[str, list[tuple[str, str]]]] = [] + + for pub in PUBLISHERS: + publisher = pub["publisher"] + provider = pub["provider"] + prefixes = pub["prefixes"] + excludes = [re.compile(p) for p in pub["exclude"]] + min_ver = pub.get("version_cutoff") + + api_models = list_publisher_models(publisher, token) + log_entries: list[tuple[str, str]] = [] + + if api_models: + for model_id, version_id in sorted(api_models, key=lambda x: x[0]): + if not any(model_id.startswith(p) for p in prefixes): + log_entries.append((model_id, "SKIP (prefix)")) + continue + if any(pat.search(model_id) for pat in excludes): + log_entries.append((model_id, "EXCLUDE")) + continue + if min_ver: + _, parsed_ver = parse_model_family(model_id) + if parsed_ver and parsed_ver <= min_ver: + log_entries.append((model_id, "EXCLUDE (version)")) + continue + log_entries.append((model_id, "KEEP")) + if model_id not in seen: + seen.add(model_id) + result.append((model_id, publisher, provider, version_id)) + else: + print( + f" {publisher}: API unavailable, using seed list", + file=sys.stderr, + ) + + publisher_log.append((publisher, log_entries)) + + # Merge in seed models that weren't discovered by the API. + # Apply version_cutoff so seed models respect the same filtering as API models. 
+ pub_by_name = {p["publisher"]: p for p in PUBLISHERS} + for model_id, publisher, provider in SEED_MODELS: + if model_id not in seen: + cutoff = pub_by_name.get(publisher, {}).get("version_cutoff") + if cutoff: + _, parsed_ver = parse_model_family(model_id) + if parsed_ver and parsed_ver <= cutoff: + continue + seen.add(model_id) + result.append((model_id, publisher, provider, None)) + + # Build the set of protected model IDs (defaults are never dropped) + protected: set[str] = set() + default_model = manifest.get("defaultModel", "") + if default_model: + protected.add(default_model) + for model_id in manifest.get("providerDefaults", {}).values(): + if model_id: + protected.add(model_id) + + # Keep only the N most recent versions per model family + result = keep_latest_versions(result, MAX_VERSIONS_PER_FAMILY, protected) + kept_ids = {entry[0] for entry in result} + + # Print the summary table with accurate final disposition + for publisher, log_entries in publisher_log: + if not log_entries: + continue + print(f" {publisher}: {len(log_entries)} model(s) from API") + for model_id, reason in log_entries: + if reason == "KEEP" and model_id in protected: + reason = "KEEP (default)" + elif reason == "KEEP" and model_id not in kept_ids: + reason = "SKIP (version limit)" + print(f" {model_id:<50s} {reason}") + + return sorted(result, key=lambda x: x[0]) + + +def model_id_to_label(model_id: str) -> str: + """Convert a model ID like 'claude-opus-4-6' to 'Claude Opus 4.6'.""" + parts = model_id.split("-") + result = [] + for part in parts: + if part and part[0].isdigit(): + if result and result[-1][-1].isdigit(): + result[-1] += f".{part}" + else: + result.append(part) + elif part: + result.append(part.capitalize()) + return " ".join(result) + + +# Temporal qualifiers stripped from model names before determining family. +# These are release stages and date stamps, not part of the model identity. +# Applied to individual dash-segments after splitting. 
_QUALIFIER_PATTERNS = [
    re.compile(r"^preview$"),
    re.compile(r"^exp$"),
    re.compile(r"^\d{2}$"),  # date segments like 04, 17 (from stamps like 04-17)
]


def parse_model_family(model_id: str) -> tuple[str, tuple[int, ...]]:
    """Split a model ID into (family, version_tuple).

    Handles two naming conventions:

    1. Semver segment (e.g. "2.5" in "gemini-2.5-flash"):
       The first segment matching ``\\d+\\.\\d+`` is extracted as the version
       and removed from the family name. Temporal qualifiers (preview, exp,
       date stamps) are also stripped so that preview variants group with
       their stable counterpart.
       "gemini-2.5-flash" -> ("gemini-flash", (2, 5))
       "gemini-2.5-flash-lite" -> ("gemini-flash-lite", (2, 5))
       "gemini-2.5-flash-preview-04-17" -> ("gemini-flash", (2, 5))
       "gemini-2.0-flash-preview-image-generation"
           -> ("gemini-flash-image-generation", (2, 0))
       "gemini-3.1-flash-image-preview" -> ("gemini-flash-image", (3, 1))

    2. Trailing digits (e.g. "claude-opus-4-6"):
       Trailing numeric dash-segments form the version.
       "claude-opus-4-6" -> ("claude-opus", (4, 6))
       "claude-haiku-4-5" -> ("claude-haiku", (4, 5))

    Returns an empty version tuple when neither convention matches.
    """
    parts = model_id.split("-")

    # Check for a semver segment (e.g. "2.5", "3.1")
    for i, part in enumerate(parts):
        if re.fullmatch(r"\d+\.\d+", part):
            version = tuple(int(x) for x in part.split("."))
            family_parts = parts[:i] + parts[i + 1 :]
            # Strip temporal qualifiers from family name
            family_parts = [
                p
                for p in family_parts
                if not any(q.match(p) for q in _QUALIFIER_PATTERNS)
            ]
            return "-".join(family_parts), version

    # Fall back to trailing numeric segments
    version_parts: list[int] = []
    while parts and parts[-1].isdigit():
        version_parts.insert(0, int(parts.pop()))
    family = "-".join(parts) if parts else model_id
    return family, tuple(version_parts)


def keep_latest_versions(
    models: list[tuple[str, str, str, str | None]],
    max_versions: int,
    protected: set[str] | None = None,
) -> list[tuple[str, str, str, str | None]]:
    """Keep only the N most recent versions per model family.

    Models without a parseable version (no semver or trailing digits) are always kept.
    Provider default models (from providerDefaults in the manifest) are exempt
    from version limiting and always kept.

    Returns the surviving entries sorted by model ID.
    """
    protected = protected or set()

    # Group by family
    families: dict[
        str, list[tuple[tuple[int, ...], tuple[str, str, str, str | None]]]
    ] = defaultdict(list)
    no_version: list[tuple[str, str, str, str | None]] = []

    for entry in models:
        model_id = entry[0]
        # Protected entries bypass family grouping entirely.
        if model_id in protected:
            no_version.append(entry)
            continue
        family, version = parse_model_family(model_id)
        if version:
            families[family].append((version, entry))
        else:
            no_version.append(entry)

    result: list[tuple[str, str, str, str | None]] = list(no_version)
    # Iterate families in sorted order so the "dropping" log output is stable.
    for family, versioned in sorted(families.items()):
        # Sort by version descending, keep top N
        versioned.sort(key=lambda x: x[0], reverse=True)
        kept = [entry for _, entry in versioned[:max_versions]]
        dropped = [entry[0] for _, entry in versioned[max_versions:]]
        if dropped:
            print(f" {family}: keeping {max_versions} latest, dropping {dropped}")
        result.extend(kept)

    return sorted(result, key=lambda x: x[0])


def _build_probe_request(
    region: str, project_id: str, vertex_id: str, publisher: str, token: str
) -> urllib.request.Request:
    """Build the probe HTTP request for a given publisher.

    Raises:
        ValueError: for publishers other than "google"/"anthropic"
            (probe_model pre-filters, so callers normally never see this).
    """
    # "@" is kept unescaped so versioned IDs like "model@20250929" survive.
    safe_vid = urllib.parse.quote(vertex_id, safe="@")
    if publisher == "google":
        url = (
            f"https://{region}-aiplatform.googleapis.com/v1/"
            f"projects/{project_id}/locations/{region}/"
            f"publishers/google/models/{safe_vid}:generateContent"
        )
        # Minimal 1-token request: we only care whether the endpoint exists.
        body = json.dumps(
            {
                "contents": [{"parts": [{"text": "hi"}]}],
                "generationConfig": {"maxOutputTokens": 1},
            }
        ).encode()
    elif publisher == "anthropic":
        url = (
            f"https://{region}-aiplatform.googleapis.com/v1/"
            f"projects/{project_id}/locations/{region}/"
            f"publishers/anthropic/models/{safe_vid}:rawPredict"
        )
        body = json.dumps(
            {
                "anthropic_version": "vertex-2023-10-16",
                "max_tokens": 1,
                "messages": [{"role": "user", "content": "hi"}],
            }
        ).encode()
    else:
        raise ValueError(f"Unknown publisher: {publisher!r}")

    return urllib.request.Request(
        url,
        data=body,
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


def probe_model(
    region: str, project_id: str, vertex_id: str, publisher: str, token: str
) -> str:
    """Probe a Vertex AI model endpoint.

    Returns:
        "available" - 200 or 400 (model exists, endpoint responds)
        "unavailable" - 404 (model not found)
        "unknown" - any other status (transient error, leave unchanged)
    """
    if publisher not in ("anthropic", "google"):
        print(
            f" {vertex_id}: unsupported publisher {publisher!r}",
            file=sys.stderr,
        )
        return "unknown"

    last_err = None
    # Up to 3 attempts with exponential backoff; only 429/5xx are retried.
    for attempt in range(3):
        req = _build_probe_request(region, project_id, vertex_id, publisher, token)

        try:
            with urllib.request.urlopen(req, timeout=30):
                return "available"
        except urllib.error.HTTPError as e:
            # 400 means the endpoint parsed our (deliberately minimal) request
            # and rejected it — the model exists.
            if e.code == 400:
                return "available"
            if e.code == 404:
                return "unavailable"
            if e.code in (429, 500, 502, 503, 504):
                last_err = e
            else:
                print(
                    f" WARNING: unexpected HTTP {e.code} for {vertex_id}",
                    file=sys.stderr,
                )
                return "unknown"
        except Exception as e:
            last_err = e

        if attempt < 2:
            time.sleep(2**attempt)

    print(
        f" WARNING: probe failed after 3 attempts for {vertex_id} ({last_err})",
        file=sys.stderr,
    )
    return "unknown"


def load_manifest(path: Path) -> dict:
    """Load the model manifest JSON, or return a blank manifest if missing.

    Raises on malformed JSON to prevent overwriting a corrupt file.
    Returns a blank manifest only when the file does not exist yet.
    """
    if not path.exists():
        # NOTE(review): blank-manifest default model is hard-coded here —
        # confirm it should track the current flagship rather than a constant.
        return {"version": 1, "defaultModel": "claude-sonnet-4-5", "models": []}

    with open(path) as f:
        data = json.load(f)

    if not isinstance(data, dict) or "models" not in data:
        raise ValueError(
            f"manifest at {path} is missing required 'models' key — "
            f"fix the file manually or delete it to start fresh"
        )

    return data


def save_manifest(path: Path, manifest: dict) -> None:
    """Save the model manifest JSON with consistent formatting."""
    with open(path, "w") as f:
        json.dump(manifest, f, indent=2)
        # Trailing newline keeps the file POSIX-friendly and diff-stable.
        f.write("\n")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> int:
    """Entry point: discover models, probe availability, update the manifest."""
    region = os.environ.get("GCP_REGION", "").strip()
    project_id = os.environ.get("GCP_PROJECT", "").strip()

    if not region or not project_id:
        print(
            "ERROR: GCP_REGION and GCP_PROJECT must be set",
            file=sys.stderr,
        )
        return 1

    manifest_path = Path(os.environ.get("MANIFEST_PATH", str(DEFAULT_MANIFEST)))
    manifest = load_manifest(manifest_path)
    token = get_access_token()

    # Discover models from the Model Garden API + seed list fallback
    print("Discovering models from Vertex AI Model Garden...")
    models_to_process = discover_models(token, manifest)
    print(f"Processing {len(models_to_process)} model(s) in {region}/{project_id}...")

    changes = []

    for model_id, publisher, provider, version_id in models_to_process:
        # Find existing entry in manifest
        existing = next((m for m in manifest["models"] if m["id"] == model_id), None)

        # Determine the vertex ID to probe.
        # version_id comes from the list API; fall back to existing manifest
        # entry or @default if neither is available.
+ if version_id: + vertex_id = f"{model_id}@{version_id}" + elif existing and existing.get("vertexId"): + vertex_id = existing["vertexId"] + else: + vertex_id = f"{model_id}@default" + + # Probe availability + status = probe_model(region, project_id, vertex_id, publisher, token) + is_available = status == "available" + + if existing: + # Update vertexId if version resolution found a newer one + if existing.get("vertexId") != vertex_id and version_id: + old_vid = existing.get("vertexId", "") + existing["vertexId"] = vertex_id + changes.append( + f" {model_id}: vertexId updated {old_vid} -> {vertex_id}" + ) + print(f" {model_id}: vertexId updated -> {vertex_id}") + + if status == "unknown": + print( + f" {model_id}: probe inconclusive, " + f"leaving available={existing['available']}" + ) + continue + if existing["available"] != is_available: + existing["available"] = is_available + changes.append(f" {model_id}: available changed to {is_available}") + print(f" {model_id}: available -> {is_available}") + else: + print(f" {model_id}: unchanged (available={is_available})") + else: + if status == "unknown": + print(f" {model_id}: new model but probe inconclusive, skipping") + continue + new_entry = { + "id": model_id, + "label": model_id_to_label(model_id), + "vertexId": vertex_id, + "provider": provider, + "available": is_available, + "featureGated": True, # New models require explicit opt-in via feature flag + } + manifest["models"].append(new_entry) + changes.append(f" {model_id}: added (available={is_available})") + print(f" {model_id}: NEW model added (available={is_available})") + + if changes: + save_manifest(manifest_path, manifest) + print(f"\n{len(changes)} change(s) written to {manifest_path}:") + for c in changes: + print(c) + else: + print("\nNo changes detected.") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/workflows/model-discovery.yml b/.github/workflows/model-discovery.yml index 2b0ed7da4..e244059da 100644 --- 
a/.github/workflows/model-discovery.yml +++ b/.github/workflows/model-discovery.yml @@ -40,7 +40,7 @@ jobs: env: GCP_REGION: ${{ secrets.GCP_REGION }} GCP_PROJECT: ${{ secrets.GCP_PROJECT }} - run: python scripts/model-discovery.py + run: python .github/scripts/model-discovery.py - name: Check for changes id: diff diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 14e9bd304..937a712f1 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -10,6 +10,8 @@ on: - 'components/ambient-cli/**' - 'components/ambient-sdk/go-sdk/**' - 'components/frontend/**' + - '.github/scripts/**' + - 'tests/**' - '.github/workflows/unit-tests.yml' - '!**/*.md' @@ -21,6 +23,8 @@ on: - 'components/ambient-cli/**' - 'components/ambient-sdk/go-sdk/**' - 'components/frontend/**' + - '.github/scripts/**' + - 'tests/**' - '.github/workflows/unit-tests.yml' - '!**/*.md' @@ -50,6 +54,7 @@ jobs: runner: ${{ steps.filter.outputs.runner }} cli: ${{ steps.filter.outputs.cli }} frontend: ${{ steps.filter.outputs.frontend }} + scripts: ${{ steps.filter.outputs.scripts }} steps: - name: Checkout code uses: actions/checkout@v6 @@ -70,6 +75,9 @@ jobs: - 'components/ambient-sdk/go-sdk/**' frontend: - 'components/frontend/**' + scripts: + - '.github/scripts/**' + - 'tests/test_model_discovery.py' backend: runs-on: ubuntu-latest @@ -300,9 +308,26 @@ jobs: - name: Run unit tests with coverage run: npx vitest run --coverage + scripts: + runs-on: ubuntu-latest + needs: detect-changes + if: needs.detect-changes.outputs.scripts == 'true' || github.event_name == 'workflow_dispatch' + name: Script Tests (model-discovery) + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.11' + + - name: Run tests + run: python tests/test_model_discovery.py -v + summary: runs-on: ubuntu-latest - needs: [detect-changes, backend, api-server, runner, cli, frontend] + 
needs: [detect-changes, backend, api-server, runner, cli, frontend, scripts] if: always() steps: - name: Check overall status @@ -313,7 +338,8 @@ jobs: "${{ needs.api-server.result }}" \ "${{ needs.runner.result }}" \ "${{ needs.cli.result }}" \ - "${{ needs.frontend.result }}"; do + "${{ needs.frontend.result }}" \ + "${{ needs.scripts.result }}"; do if [ "$result" == "failure" ] || [ "$result" == "cancelled" ]; then failed=true fi @@ -325,6 +351,7 @@ jobs: echo " runner: ${{ needs.runner.result }}" echo " cli: ${{ needs.cli.result }}" echo " frontend: ${{ needs.frontend.result }}" + echo " scripts: ${{ needs.scripts.result }}" exit 1 fi echo "All unit tests passed!" diff --git a/scripts/model-discovery.py b/scripts/model-discovery.py deleted file mode 100755 index c24f564cb..000000000 --- a/scripts/model-discovery.py +++ /dev/null @@ -1,322 +0,0 @@ -#!/usr/bin/env python3 -"""Automated Vertex AI model discovery. - -Maintains a curated list of Anthropic model base names, resolves their -latest Vertex AI version via the Model Garden API, probes each to confirm -availability, and updates the model manifest. Never removes models — only -adds new ones or updates the ``available`` / ``vertexId`` fields. - -Required env vars: - GCP_REGION - GCP region (e.g. us-east5) - GCP_PROJECT - GCP project ID - -Optional env vars: - GOOGLE_APPLICATION_CREDENTIALS - Path to SA key (uses ADC otherwise) - MANIFEST_PATH - Override default manifest location -""" - -import json -import os -import subprocess -import sys -import time -import urllib.error -import urllib.request -from pathlib import Path - -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- - -DEFAULT_MANIFEST = ( - Path(__file__).resolve().parent.parent - / "components" - / "manifests" - / "base" - / "models.json" -) - -# Known Anthropic model base names. Add new models here as they are released. 
-# Version resolution and availability probing are automatic. -KNOWN_MODELS = [ - "claude-sonnet-4-6", - "claude-sonnet-4-5", - "claude-opus-4-6", - "claude-opus-4-5", - "claude-haiku-4-5", -] - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def get_access_token() -> str: - """Get a GCP access token via gcloud.""" - try: - result = subprocess.run( - ["gcloud", "auth", "print-access-token"], - capture_output=True, - text=True, - check=True, - timeout=30, - ) - except subprocess.TimeoutExpired: - raise RuntimeError("Timed out getting GCP access token via gcloud") - except subprocess.CalledProcessError: - raise RuntimeError("Failed to get GCP access token via gcloud") - return result.stdout.strip() - - -def resolve_version(region: str, model_id: str, token: str) -> str | None: - """Resolve the latest version for a model via the Model Garden API. - - Returns the version string (e.g. "20250929") or None if the API call - fails (permissions, model not found, etc.). - - Note: requires ``roles/serviceusage.serviceUsageConsumer`` on the GCP - project. Works in CI via the Workload Identity service account; may - return None locally if the user lacks this role. 
- """ - url = ( - f"https://{region}-aiplatform.googleapis.com/v1/" - f"publishers/anthropic/models/{model_id}" - ) - - last_err = None - for attempt in range(3): - req = urllib.request.Request( - url, - headers={"Authorization": f"Bearer {token}"}, - method="GET", - ) - try: - with urllib.request.urlopen(req, timeout=30) as resp: - data = json.loads(resp.read().decode()) - - name = data.get("name", "") - if "@" in name: - return name.split("@", 1)[1] - return data.get("versionId") - - except urllib.error.HTTPError as e: - if e.code in (403, 404): - # Permission denied or not found — retrying won't help - print( - f" {model_id}: version resolution unavailable (HTTP {e.code})", - file=sys.stderr, - ) - return None - last_err = e - except Exception as e: - last_err = e - - if attempt < 2: - time.sleep(2**attempt) # 1s, 2s backoff - - print( - f" {model_id}: version resolution failed after 3 attempts ({last_err})", - file=sys.stderr, - ) - return None - - -def model_id_to_label(model_id: str) -> str: - """Convert a model ID like 'claude-opus-4-6' to 'Claude Opus 4.6'.""" - parts = model_id.split("-") - result = [] - for part in parts: - if part and part[0].isdigit(): - if result and result[-1][-1].isdigit(): - result[-1] += f".{part}" - else: - result.append(part) - elif part: - result.append(part.capitalize()) - return " ".join(result) - - -def probe_model(region: str, project_id: str, vertex_id: str, token: str) -> str: - """Probe a Vertex AI model endpoint. 
- - Returns: - "available" - 200 or 400 (model exists, endpoint responds) - "unavailable" - 404 (model not found) - "unknown" - any other status (transient error, leave unchanged) - """ - url = ( - f"https://{region}-aiplatform.googleapis.com/v1/" - f"projects/{project_id}/locations/{region}/" - f"publishers/anthropic/models/{vertex_id}:rawPredict" - ) - - body = json.dumps( - { - "anthropic_version": "vertex-2023-10-16", - "max_tokens": 1, - "messages": [{"role": "user", "content": "hi"}], - } - ).encode() - - last_err = None - for attempt in range(3): - req = urllib.request.Request( - url, - data=body, - headers={ - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - }, - method="POST", - ) - - try: - with urllib.request.urlopen(req, timeout=30): - return "available" - except urllib.error.HTTPError as e: - if e.code == 400: - return "available" - if e.code == 404: - return "unavailable" - if e.code in (429, 500, 502, 503, 504): - last_err = e - else: - print( - f" WARNING: unexpected HTTP {e.code} for {vertex_id}", - file=sys.stderr, - ) - return "unknown" - except Exception as e: - last_err = e - - if attempt < 2: - time.sleep(2**attempt) - - print( - f" WARNING: probe failed after 3 attempts for {vertex_id} ({last_err})", - file=sys.stderr, - ) - return "unknown" - - -def load_manifest(path: Path) -> dict: - """Load the model manifest JSON, or return a blank manifest if missing/empty.""" - blank = {"version": 1, "defaultModel": "claude-sonnet-4-5", "models": []} - if not path.exists(): - return blank - try: - with open(path) as f: - data = json.load(f) - if not isinstance(data, dict) or "models" not in data: - return blank - return data - except (json.JSONDecodeError, ValueError) as e: - print( - f"WARNING: malformed manifest at {path}, starting fresh ({e})", - file=sys.stderr, - ) - return blank - - -def save_manifest(path: Path, manifest: dict) -> None: - """Save the model manifest JSON with consistent formatting.""" - with open(path, 
"w") as f: - json.dump(manifest, f, indent=2) - f.write("\n") - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -def main() -> int: - region = os.environ.get("GCP_REGION", "").strip() - project_id = os.environ.get("GCP_PROJECT", "").strip() - - if not region or not project_id: - print( - "ERROR: GCP_REGION and GCP_PROJECT must be set", - file=sys.stderr, - ) - return 1 - - manifest_path = Path(os.environ.get("MANIFEST_PATH", str(DEFAULT_MANIFEST))) - manifest = load_manifest(manifest_path) - token = get_access_token() - - print(f"Processing {len(KNOWN_MODELS)} known model(s) in {region}/{project_id}...") - - changes = [] - - for model_id in KNOWN_MODELS: - # Try to resolve the latest version via Model Garden API - resolved_version = resolve_version(region, model_id, token) - - # Find existing entry in manifest - existing = next((m for m in manifest["models"] if m["id"] == model_id), None) - - # Determine the vertex ID to probe - if resolved_version: - vertex_id = f"{model_id}@{resolved_version}" - elif existing and existing.get("vertexId"): - vertex_id = existing["vertexId"] - else: - vertex_id = f"{model_id}@default" - - # Probe availability - status = probe_model(region, project_id, vertex_id, token) - is_available = status == "available" - - if existing: - # Update vertexId if version resolution found a newer one - if existing.get("vertexId") != vertex_id and resolved_version: - old_vid = existing.get("vertexId", "") - existing["vertexId"] = vertex_id - changes.append( - f" {model_id}: vertexId updated {old_vid} -> {vertex_id}" - ) - print(f" {model_id}: vertexId updated -> {vertex_id}") - - if status == "unknown": - print( - f" {model_id}: probe inconclusive, " - f"leaving available={existing['available']}" - ) - continue - if existing["available"] != is_available: - existing["available"] = is_available - changes.append(f" {model_id}: 
"""Unit tests for model-discovery.py pure functions.

Loads ``.github/scripts/model-discovery.py`` via importlib (the hyphen in
the filename prevents a normal ``import``) and exercises its pure helpers:
family parsing, label generation, version pruning, and model discovery
(with the list API mocked out).
"""

import importlib.util
import sys
import unittest
from pathlib import Path
from unittest.mock import patch

# Import model-discovery.py as a module (it has a hyphen in the name).
_spec = importlib.util.spec_from_file_location(
    "model_discovery",
    Path(__file__).resolve().parent.parent / ".github" / "scripts" / "model-discovery.py",
)
# spec_from_file_location returns None (or a spec with loader=None) when the
# target file is missing/unloadable; fail loudly instead of with an opaque
# AttributeError on the next line.
if _spec is None or _spec.loader is None:
    raise ImportError("cannot load .github/scripts/model-discovery.py")
_mod = importlib.util.module_from_spec(_spec)
# Register under a stable name so @patch("model_discovery.…") can resolve it.
sys.modules["model_discovery"] = _mod
_spec.loader.exec_module(_mod)

parse_model_family = _mod.parse_model_family
model_id_to_label = _mod.model_id_to_label
keep_latest_versions = _mod.keep_latest_versions
discover_models = _mod.discover_models


class TestParseModelFamily(unittest.TestCase):
    """Test parse_model_family with both naming conventions."""

    # -- Claude: trailing numeric segments --

    def test_claude_opus(self):
        self.assertEqual(parse_model_family("claude-opus-4-6"), ("claude-opus", (4, 6)))

    def test_claude_sonnet(self):
        self.assertEqual(
            parse_model_family("claude-sonnet-4-5"), ("claude-sonnet", (4, 5))
        )

    def test_claude_haiku(self):
        self.assertEqual(
            parse_model_family("claude-haiku-4-5"), ("claude-haiku", (4, 5))
        )

    # -- Gemini: semver segment --

    def test_gemini_flash(self):
        self.assertEqual(
            parse_model_family("gemini-2.5-flash"), ("gemini-flash", (2, 5))
        )

    def test_gemini_flash_lite(self):
        self.assertEqual(
            parse_model_family("gemini-2.5-flash-lite"), ("gemini-flash-lite", (2, 5))
        )

    def test_gemini_pro(self):
        self.assertEqual(parse_model_family("gemini-2.5-pro"), ("gemini-pro", (2, 5)))

    # -- Qualifier stripping --

    def test_strips_preview(self):
        self.assertEqual(
            parse_model_family("gemini-2.5-flash-preview-04-17"),
            ("gemini-flash", (2, 5)),
        )

    def test_strips_exp_and_date(self):
        self.assertEqual(
            parse_model_family("gemini-2.5-pro-exp-03-25"), ("gemini-pro", (2, 5))
        )

    def test_strips_preview_from_image_model(self):
        self.assertEqual(
            parse_model_family("gemini-3.1-flash-image-preview"),
            ("gemini-flash-image", (3, 1)),
        )

    # -- No version --

    def test_no_version_segments(self):
        self.assertEqual(parse_model_family("some-model"), ("some-model", ()))


class TestModelIdToLabel(unittest.TestCase):
    """Test human-readable label generation for both naming conventions."""

    def test_claude_opus(self):
        self.assertEqual(model_id_to_label("claude-opus-4-6"), "Claude Opus 4.6")

    def test_claude_sonnet(self):
        self.assertEqual(model_id_to_label("claude-sonnet-4-5"), "Claude Sonnet 4.5")

    def test_gemini_flash(self):
        self.assertEqual(model_id_to_label("gemini-2.5-flash"), "Gemini 2.5 Flash")

    def test_gemini_flash_lite(self):
        self.assertEqual(
            model_id_to_label("gemini-2.5-flash-lite"), "Gemini 2.5 Flash Lite"
        )


class TestKeepLatestVersions(unittest.TestCase):
    """Test per-family version pruning (keep the N newest per family)."""

    def test_keeps_latest_two(self):
        models = [
            ("claude-opus-4-1", "anthropic", "anthropic", None),
            ("claude-opus-4-5", "anthropic", "anthropic", None),
            ("claude-opus-4-6", "anthropic", "anthropic", None),
        ]
        result = keep_latest_versions(models, 2)
        ids = [r[0] for r in result]
        self.assertIn("claude-opus-4-6", ids)
        self.assertIn("claude-opus-4-5", ids)
        self.assertNotIn("claude-opus-4-1", ids)

    def test_versionless_always_kept(self):
        models = [
            ("gemini-2.5-flash", "google", "google", None),
            ("some-model", "x", "x", None),
        ]
        result = keep_latest_versions(models, 1)
        ids = [r[0] for r in result]
        # versionless "some-model" has no trailing digits or semver
        self.assertIn("some-model", ids)
        # single version in its family — must be kept as the latest
        self.assertIn("gemini-2.5-flash", ids)

    def test_protected_models_exempt(self):
        models = [
            ("claude-opus-4-1", "anthropic", "anthropic", None),
            ("claude-opus-4-5", "anthropic", "anthropic", None),
            ("claude-opus-4-6", "anthropic", "anthropic", None),
        ]
        result = keep_latest_versions(models, 1, protected={"claude-opus-4-1"})
        ids = [r[0] for r in result]
        # 4-1 is protected so kept despite version limit of 1
        self.assertIn("claude-opus-4-1", ids)
        self.assertIn("claude-opus-4-6", ids)
        self.assertEqual(len(ids), 2)

    def test_gemini_semver_grouping(self):
        models = [
            ("gemini-2.0-flash", "google", "google", None),
            ("gemini-2.5-flash", "google", "google", None),
            ("gemini-3.0-flash", "google", "google", None),
        ]
        result = keep_latest_versions(models, 2)
        ids = [r[0] for r in result]
        self.assertIn("gemini-3.0-flash", ids)
        self.assertIn("gemini-2.5-flash", ids)
        self.assertNotIn("gemini-2.0-flash", ids)

    def test_empty_input(self):
        self.assertEqual(keep_latest_versions([], 2), [])


class TestDiscoverModels(unittest.TestCase):
    """Test discover_models with API discovery, seed fallback, and filtering."""

    # Minimal manifest: one defaultModel plus one per-provider default.
    # These are the "protected" IDs discover_models must never prune.
    _default_manifest = {
        "defaultModel": "claude-sonnet-4-5",
        "providerDefaults": {"google": "gemini-2.5-flash"},
    }

    @patch("model_discovery.list_publisher_models", return_value=[])
    @patch(
        "model_discovery.SEED_MODELS",
        _mod.SEED_MODELS + [("gemini-2.0-flash", "google", "google")],
    )
    def test_seed_models_respect_version_cutoff(self, _mock_list):
        """Seed models older than version_cutoff should be excluded."""
        result = discover_models("fake-token", self._default_manifest)
        ids = [r[0] for r in result]
        self.assertNotIn("gemini-2.0-flash", ids)
        self.assertIn("gemini-2.5-flash", ids)

    @patch("model_discovery.list_publisher_models")
    def test_api_discovered_models_included(self, mock_list):
        """Models returned by the list API should appear in results."""
        def fake_list(publisher, token):
            if publisher == "anthropic":
                return [("claude-sonnet-4-5", "20250929"), ("claude-opus-4-6", None)]
            if publisher == "google":
                return [("gemini-2.5-flash", None), ("gemini-2.5-pro", None)]
            return []

        mock_list.side_effect = fake_list
        result = discover_models("fake-token", self._default_manifest)
        ids = [r[0] for r in result]
        self.assertIn("claude-sonnet-4-5", ids)
        self.assertIn("claude-opus-4-6", ids)
        self.assertIn("gemini-2.5-flash", ids)
        self.assertIn("gemini-2.5-pro", ids)

    @patch("model_discovery.list_publisher_models")
    def test_protected_models_exempt_from_pruning(self, mock_list):
        """Default model and provider defaults are never pruned by version limiting."""
        def fake_list(publisher, token):
            if publisher == "anthropic":
                return [
                    ("claude-sonnet-4-5", None),  # defaultModel
                    ("claude-sonnet-4-6", None),
                    ("claude-opus-4-6", None),
                    ("claude-opus-4-5", None),
                    ("claude-haiku-4-5", None),
                ]
            if publisher == "google":
                return [("gemini-2.5-flash", None)]  # providerDefault
            return []

        mock_list.side_effect = fake_list
        result = discover_models("fake-token", self._default_manifest)
        ids = [r[0] for r in result]
        # Protected models must always be present
        self.assertIn("claude-sonnet-4-5", ids)  # defaultModel
        self.assertIn("gemini-2.5-flash", ids)  # providerDefault for google

    @patch("model_discovery.list_publisher_models")
    def test_prefix_and_exclude_filters(self, mock_list):
        """Prefix filtering keeps matching models; exclude patterns remove unwanted ones."""
        def fake_list(publisher, token):
            if publisher == "anthropic":
                return [
                    ("claude-sonnet-4-5", None),  # matches prefix, no exclude
                    ("claude-opus-4", None),  # matches exclude: base alias without minor
                    ("not-claude-model", None),  # doesn't match prefix
                ]
            if publisher == "google":
                return [
                    ("gemini-2.5-flash", None),
                    ("gemini-2.5-flash-001", None),  # matches exclude: pinned version
                    ("gemini-2.0-flash-preview-image-generation", None),  # excluded by version_cutoff
                ]
            return []

        mock_list.side_effect = fake_list
        result = discover_models("fake-token", self._default_manifest)
        ids = [r[0] for r in result]
        self.assertIn("claude-sonnet-4-5", ids)
        self.assertNotIn("claude-opus-4", ids)  # excluded: base alias
        self.assertNotIn("not-claude-model", ids)  # excluded: wrong prefix
        self.assertIn("gemini-2.5-flash", ids)
        self.assertNotIn("gemini-2.5-flash-001", ids)  # excluded: pinned version
        self.assertNotIn("gemini-2.0-flash-preview-image-generation", ids)  # excluded: version_cutoff

    @patch("model_discovery.list_publisher_models")
    def test_auth_error_propagates(self, mock_list):
        """Auth errors from list_publisher_models should propagate, not fall back to seeds."""
        mock_list.side_effect = RuntimeError("HTTP 401: check GCP credentials")
        with self.assertRaises(RuntimeError):
            discover_models("fake-token", self._default_manifest)


if __name__ == "__main__":
    unittest.main()