From b6220527b07646f0a725dac498a074da8d8b8a10 Mon Sep 17 00:00:00 2001 From: burconsult Date: Thu, 4 Dec 2025 23:05:21 +0100 Subject: [PATCH 1/4] Add Google Gemini provider --- README.md | 6 +- ai_secbench/providers/__init__.py | 8 +- ai_secbench/providers/google.py | 130 ++++++++++++++++++++++++++++++ pyproject.toml | 1 + 4 files changed, 143 insertions(+), 2 deletions(-) create mode 100644 ai_secbench/providers/google.py diff --git a/README.md b/README.md index 0535cac..3e19a29 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ AI-SecBench evaluates AI models on security-adjacent reasoning tasks including c - 🌍 **Multi-language support** - English and Norwegian (extensible) - 🔄 **Reproducible** - Seeded generation + versioned official sets - 💰 **Cost tracking** - Estimate API costs per run -- 🔌 **Multi-provider** - Anthropic, OpenAI, HuggingFace, xAI (Grok) +- 🔌 **Multi-provider** - Anthropic, OpenAI, HuggingFace, xAI (Grok), Google Gemini - 📊 **LLM-as-Judge** - Rubric-based reasoning evaluation ## Installation @@ -35,6 +35,7 @@ pip install git+https://github.com/kelkalot/ai-secbench.git[anthropic] pip install git+https://github.com/kelkalot/ai-secbench.git[openai] pip install git+https://github.com/kelkalot/ai-secbench.git[huggingface] pip install git+https://github.com/kelkalot/ai-secbench.git[xai] +pip install git+https://github.com/kelkalot/ai-secbench.git[google] # All providers pip install git+https://github.com/kelkalot/ai-secbench.git[all] @@ -85,6 +86,8 @@ results = runner.run_sync() ```bash ai-secbench --provider xai --model grok-2-1212 --n-per-type 2 --language english --output results.json +# Google AI Studio (Gemini) +ai-secbench --provider google --model gemini-1.5-pro-latest --n-per-type 2 --language english --output results.json ``` ### Generate Challenges Only @@ -271,6 +274,7 @@ ANTHROPIC_API_KEY=sk-ant-... OPENAI_API_KEY=sk-... HF_TOKEN=hf_... XAI_API_KEY=xaikey_... +GOOGLE_API_KEY=AIza... ``` ## Contributing diff --git a/ai_secbench/providers/__init__.py b/ai_secbench/providers/__init__.py index 30af980..c906b1a 100644 --- a/ai_secbench/providers/__init__.py +++ b/ai_secbench/providers/__init__.py @@ -18,6 +18,7 @@ create_huggingface_provider, ) from ai_secbench.providers.xai import XAIProvider, create_xai_provider +from ai_secbench.providers.google import GoogleAIProvider, create_google_provider # Registry of available providers @@ -30,6 +31,8 @@ "hf": HuggingFaceProvider, # Alias "huggingface_local": LocalHuggingFaceProvider, "xai": XAIProvider, + "google": GoogleAIProvider, + "gemini": GoogleAIProvider, # Alias } # Default models for each provider @@ -38,12 +41,13 @@ "openai": "gpt-4o", "huggingface": "meta-llama/Llama-3.1-8B-Instruct", "xai": "grok-2-1212", + "google": "gemini-1.5-pro-latest", } def list_providers() -> List[str]: """List available provider names.""" - return ["anthropic", "openai", "huggingface", "huggingface_local", "xai"] + return ["anthropic", "openai", "huggingface", "huggingface_local", "xai", "google"] def get_provider( @@ -102,10 +106,12 @@ def get_provider( "HuggingFaceProvider", "LocalHuggingFaceProvider", "XAIProvider", + "GoogleAIProvider", "get_provider", "list_providers", "create_anthropic_provider", "create_openai_provider", "create_huggingface_provider", "create_xai_provider", + "create_google_provider", ] diff --git a/ai_secbench/providers/google.py b/ai_secbench/providers/google.py new file mode 100644 index 0000000..2e530ca --- /dev/null +++ b/ai_secbench/providers/google.py @@ -0,0 +1,130 @@ +""" +Google AI Studio (Gemini) provider implementation. + +Uses the public REST API (v1beta generateContent endpoint). +""" + +import os +from typing import List, Dict, Optional, Tuple, Union + +import json + +from ai_secbench.providers.base import BaseProvider, ProviderConfig + + +class GoogleAIProvider(BaseProvider): + """ + Provider for Google AI Studio / Gemini models. + + Requires: httpx (install with: pip install httpx) + """ + + def __init__(self, config: ProviderConfig): + super().__init__(config) + self._client = None + self._base_url = config.base_url or "https://generativelanguage.googleapis.com" + + @property + def client(self): + """Lazy-load the HTTP client.""" + if self._client is None: + try: + import httpx + except ImportError as exc: + raise ImportError( + "httpx package is required for GoogleAIProvider. Install with: pip install httpx" + ) from exc + + self._client = httpx.AsyncClient( + base_url=self._base_url, + timeout=self.config.timeout, + ) + return self._client + + def _messages_to_contents(self, messages: List[Dict[str, str]]) -> List[Dict[str, object]]: + """Convert OpenAI-style messages to Gemini contents format.""" + contents = [] + for msg in messages: + role = msg.get("role", "user") + # Gemini uses "model" for assistant responses + if role == "assistant": + role = "model" + parts = [{"text": msg.get("content", "")}] + contents.append({"role": role, "parts": parts}) + return contents + + async def complete( + self, + messages: List[Dict[str, str]], + return_usage: bool = False, + **kwargs, + ) -> Union[Tuple[str, Dict[str, int]], str]: + """ + Send completion request to Gemini. + """ + api_key = ( + self.config.api_key + or os.environ.get("GOOGLE_API_KEY") + or os.environ.get("GOOGLE_GENAI_API_KEY") + ) + if not api_key: + raise ValueError("Google AI API key required. Set GOOGLE_API_KEY env var or pass api_key to config.") + + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + contents = self._messages_to_contents(messages) + + payload = { + "contents": contents, + "generationConfig": { + "temperature": kwargs.get("temperature", self.config.temperature), + "maxOutputTokens": kwargs.get("max_tokens", self.config.max_tokens), + }, + } + + path = f"/v1beta/models/{self.config.model}:generateContent?key={api_key}" + response = await self.client.post( + path, + headers={"Content-Type": "application/json"}, + content=json.dumps(payload), + ) + response.raise_for_status() + data = response.json() + + candidates = data.get("candidates") or [] + text = "" + if candidates: + parts = candidates[0].get("content", {}).get("parts", []) + # concatenate text parts + text = "".join(part.get("text", "") for part in parts if isinstance(part, dict)) + + if return_usage: + usage = data.get("usageMetadata", {}) or {} + usage_out = { + "input_tokens": usage.get("promptTokenCount", 0), + "output_tokens": usage.get("candidatesTokenCount", 0), + } + return text, usage_out + + return text + + def count_tokens(self, text: str) -> int: + """Rough token estimate.""" + return len(text) // 4 + + +def create_google_provider( + model: str = "gemini-1.5-pro-latest", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + **kwargs, +) -> GoogleAIProvider: + """Factory to create a Google AI (Gemini) provider.""" + config = ProviderConfig( + model=model, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + return GoogleAIProvider(config) diff --git a/pyproject.toml b/pyproject.toml index 087e13b..58ffd01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ anthropic = ["anthropic>=0.18.0"] openai = ["openai>=1.0.0"] huggingface = ["huggingface_hub>=0.20.0"] xai = ["httpx>=0.26.0"] +google = ["httpx>=0.26.0"] local = ["transformers>=4.36.0", "torch>=2.0.0"] all = [ "anthropic>=0.18.0", From 19d7db3f3e376ab095695d57ad645d62680dd1cf Mon Sep 17 00:00:00 2001 From: burconsult Date: Thu, 4 Dec 2025 23:22:59 +0100 Subject: [PATCH 2/4] Add provider registry and API key presence tests --- tests/test_providers.py | 74 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/test_providers.py diff --git a/tests/test_providers.py b/tests/test_providers.py new file mode 100644 index 0000000..ca31c28 --- /dev/null +++ b/tests/test_providers.py @@ -0,0 +1,74 @@ +""" +Provider registry and basic connectivity tests. + +These tests avoid real network calls and instead check: +- Registry wiring (list_providers/get_provider) +- Helpful errors when API keys are missing + +If provider SDKs are not installed for a given provider, the test is skipped. +""" + +import asyncio +import importlib + +import pytest + +from ai_secbench.providers import get_provider, list_providers + + +def _is_installed(module_name: str) -> bool: + try: + importlib.import_module(module_name) + return True + except ImportError: + return False + + +def test_registry_contains_expected_providers(): + providers = list_providers() + for expected in ["anthropic", "openai", "huggingface", "huggingface_local", "xai", "google"]: + assert expected in providers + + +@pytest.mark.parametrize( + ("provider_name", "expected_cls_name"), + [ + ("anthropic", "AnthropicProvider"), + ("openai", "OpenAIProvider"), + ("huggingface", "HuggingFaceProvider"), + ("huggingface_local", "LocalHuggingFaceProvider"), + ("xai", "XAIProvider"), + ("google", "GoogleAIProvider"), + ], +) +def test_get_provider_returns_correct_class(provider_name, expected_cls_name): + provider = get_provider(provider_name, model="dummy-model", api_key="DUMMY") + assert provider.__class__.__name__ == expected_cls_name + # count_tokens should be callable without network + assert isinstance(provider.count_tokens("hello world"), int) + + +@pytest.mark.parametrize( + ("provider_name", "model", "env_vars", "import_check"), + [ + ("anthropic", "claude-3-5-sonnet-20241022", ["ANTHROPIC_API_KEY"], "anthropic"), + ("openai", "gpt-4o", ["OPENAI_API_KEY"], "openai"), + ("xai", "grok-4-1-fast", ["XAI_API_KEY", "XAI_KEY"], "httpx"), + ("google", "gemini-flash-latest", ["GOOGLE_API_KEY", "GOOGLE_GENAI_API_KEY"], "httpx"), + ], +) +def test_complete_raises_without_api_key(monkeypatch, provider_name, model, env_vars, import_check): + if import_check and not _is_installed(import_check): + pytest.skip(f"Required client library '{import_check}' not installed") + + # Ensure env vars are absent + for var in env_vars: + monkeypatch.delenv(var, raising=False) + + provider = get_provider(provider_name, model=model, api_key=None) + + async def _call(): + await provider.complete([{"role": "user", "content": "ping"}]) + + with pytest.raises(ValueError): + asyncio.run(_call()) From 8d675dbe168d740700faabd49778fe619c146692 Mon Sep 17 00:00:00 2001 From: burconsult Date: Thu, 4 Dec 2025 23:24:13 +0100 Subject: [PATCH 3/4] Improve provider tests with clearer assertions --- tests/test_providers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_providers.py b/tests/test_providers.py index ca31c28..4cccec7 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -27,7 +27,7 @@ def _is_installed(module_name: str) -> bool: def test_registry_contains_expected_providers(): providers = list_providers() for expected in ["anthropic", "openai", "huggingface", "huggingface_local", "xai", "google"]: - assert expected in providers + assert expected in providers, f"Provider '{expected}' missing from registry. Seen: {providers}" @pytest.mark.parametrize( @@ -43,7 +43,9 @@ def test_registry_contains_expected_providers(): ) def test_get_provider_returns_correct_class(provider_name, expected_cls_name): provider = get_provider(provider_name, model="dummy-model", api_key="DUMMY") - assert provider.__class__.__name__ == expected_cls_name + assert provider.__class__.__name__ == expected_cls_name, ( + f"Expected {expected_cls_name}, got {provider.__class__.__name__}" + ) # count_tokens should be callable without network assert isinstance(provider.count_tokens("hello world"), int) @@ -70,5 +72,5 @@ def test_complete_raises_without_api_key(monkeypatch, provider_name, model, env_ async def _call(): await provider.complete([{"role": "user", "content": "ping"}]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="API key"): asyncio.run(_call()) From 7786c0242ff4f4eb1e4079a74c303b49fa1c8c96 Mon Sep 17 00:00:00 2001 From: burconsult Date: Tue, 9 Dec 2025 11:45:22 +0100 Subject: [PATCH 4/4] Update default models for xAI and Google providers --- ai_secbench/providers/google.py | 2 +- ai_secbench/providers/xai.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ai_secbench/providers/google.py b/ai_secbench/providers/google.py index 2e530ca..372c728 100644 --- a/ai_secbench/providers/google.py +++ b/ai_secbench/providers/google.py @@ -115,7 +115,7 @@ def count_tokens(self, text: str) -> int: def create_google_provider( - model: str = "gemini-1.5-pro-latest", + model: str = "gemini-flash-latest", api_key: Optional[str] = None, base_url: Optional[str] = None, **kwargs, diff --git a/ai_secbench/providers/xai.py b/ai_secbench/providers/xai.py index 762860c..00ef647 100644 --- a/ai_secbench/providers/xai.py +++ b/ai_secbench/providers/xai.py @@ -98,7 +98,7 @@ def count_tokens(self, text: str) -> int: def create_xai_provider( - model: str = "grok-2-1212", + model: str = "grok-4-1-fast", api_key: Optional[str] = None, base_url: Optional[str] = None, **kwargs,