Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ AI-SecBench evaluates AI models on security-adjacent reasoning tasks including c
- 🌍 **Multi-language support** - English and Norwegian (extensible)
- 🔄 **Reproducible** - Seeded generation + versioned official sets
- 💰 **Cost tracking** - Estimate API costs per run
- 🔌 **Multi-provider** - Anthropic, OpenAI, HuggingFace, xAI (Grok)
- 🔌 **Multi-provider** - Anthropic, OpenAI, HuggingFace, xAI (Grok), Google Gemini
- 📊 **LLM-as-Judge** - Rubric-based reasoning evaluation

## Installation
Expand All @@ -35,6 +35,7 @@ pip install git+https://github.com/kelkalot/ai-secbench.git[anthropic]
pip install git+https://github.com/kelkalot/ai-secbench.git[openai]
pip install git+https://github.com/kelkalot/ai-secbench.git[huggingface]
pip install git+https://github.com/kelkalot/ai-secbench.git[xai]
pip install git+https://github.com/kelkalot/ai-secbench.git[google]

# All providers
pip install git+https://github.com/kelkalot/ai-secbench.git[all]
Expand Down Expand Up @@ -85,6 +86,8 @@ results = runner.run_sync()

```bash
ai-secbench --provider xai --model grok-2-1212 --n-per-type 2 --language english --output results.json
# Google AI Studio (Gemini)
ai-secbench --provider google --model gemini-1.5-pro-latest --n-per-type 2 --language english --output results.json
```

### Generate Challenges Only
Expand Down Expand Up @@ -271,6 +274,7 @@ ANTHROPIC_API_KEY=sk-ant-...
OPENAI_API_KEY=sk-...
HF_TOKEN=hf_...
XAI_API_KEY=xaikey_...
GOOGLE_API_KEY=AIza...
```

## Contributing
Expand Down
8 changes: 7 additions & 1 deletion ai_secbench/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
create_huggingface_provider,
)
from ai_secbench.providers.xai import XAIProvider, create_xai_provider
from ai_secbench.providers.google import GoogleAIProvider, create_google_provider


# Registry of available providers
Expand All @@ -30,6 +31,8 @@
"hf": HuggingFaceProvider, # Alias
"huggingface_local": LocalHuggingFaceProvider,
"xai": XAIProvider,
"google": GoogleAIProvider,
"gemini": GoogleAIProvider, # Alias
}

# Default models for each provider
Expand All @@ -38,12 +41,13 @@
"openai": "gpt-4o",
"huggingface": "meta-llama/Llama-3.1-8B-Instruct",
"xai": "grok-2-1212",
"google": "gemini-1.5-pro-latest",
}


def list_providers() -> List[str]:
"""List available provider names."""
return ["anthropic", "openai", "huggingface", "huggingface_local", "xai"]
return ["anthropic", "openai", "huggingface", "huggingface_local", "xai", "google"]


def get_provider(
Expand Down Expand Up @@ -102,10 +106,12 @@ def get_provider(
"HuggingFaceProvider",
"LocalHuggingFaceProvider",
"XAIProvider",
"GoogleAIProvider",
"get_provider",
"list_providers",
"create_anthropic_provider",
"create_openai_provider",
"create_huggingface_provider",
"create_xai_provider",
"create_google_provider",
]
130 changes: 130 additions & 0 deletions ai_secbench/providers/google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Google AI Studio (Gemini) provider implementation.

Uses the public REST API (v1beta generateContent endpoint).
"""

import os
from typing import List, Dict, Optional, Tuple, Union

import json

from ai_secbench.providers.base import BaseProvider, ProviderConfig


class GoogleAIProvider(BaseProvider):
    """
    Provider for Google AI Studio / Gemini models.

    Talks to the public REST API (v1beta ``generateContent`` endpoint).

    Requires: httpx (install with: pip install httpx)
    """

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self._client = None
        # Allow overriding the host (e.g. proxies/test doubles); default to
        # the public Generative Language API endpoint.
        self._base_url = config.base_url or "https://generativelanguage.googleapis.com"

    @property
    def client(self):
        """Lazy-load the async HTTP client so httpx is only needed at call time."""
        if self._client is None:
            try:
                import httpx
            except ImportError as exc:
                raise ImportError(
                    "httpx package is required for GoogleAIProvider. Install with: pip install httpx"
                ) from exc

            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                timeout=self.config.timeout,
            )
        return self._client

    def _messages_to_contents(
        self, messages: List[Dict[str, str]]
    ) -> Tuple[List[Dict[str, object]], Optional[str]]:
        """
        Convert OpenAI-style messages to the Gemini request format.

        Returns:
            A tuple ``(contents, system_instruction)``. Gemini's ``contents``
            array only accepts the roles "user" and "model"; a role of
            "system" is rejected by the API. System messages are therefore
            collected separately and returned as the second element (joined
            with newlines, or None when absent) so the caller can put them in
            the dedicated ``systemInstruction`` field.
        """
        contents: List[Dict[str, object]] = []
        system_parts: List[str] = []
        for msg in messages:
            role = msg.get("role", "user")
            text = msg.get("content", "")
            if role == "system":
                # Not a valid role inside `contents` — hoist into systemInstruction.
                system_parts.append(text)
                continue
            # Gemini uses "model" for assistant responses
            if role == "assistant":
                role = "model"
            contents.append({"role": role, "parts": [{"text": text}]})
        system_instruction = "\n".join(system_parts) if system_parts else None
        return contents, system_instruction

    async def complete(
        self,
        messages: List[Dict[str, str]],
        return_usage: bool = False,
        **kwargs,
    ) -> Union[Tuple[str, Dict[str, int]], str]:
        """
        Send completion request to Gemini.

        Args:
            messages: OpenAI-style message dicts (a bare prompt string is
                also accepted and wrapped as a single user message).
            return_usage: When True, also return token usage counts.
            **kwargs: Per-call overrides (``temperature``, ``max_tokens``).

        Returns:
            The generated text, or ``(text, usage_dict)`` when
            ``return_usage`` is True.

        Raises:
            ValueError: If no API key is configured.
            httpx.HTTPStatusError: On non-2xx API responses.
        """
        api_key = (
            self.config.api_key
            or os.environ.get("GOOGLE_API_KEY")
            or os.environ.get("GOOGLE_GENAI_API_KEY")
        )
        if not api_key:
            raise ValueError("Google AI API key required. Set GOOGLE_API_KEY env var or pass api_key to config.")

        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        contents, system_instruction = self._messages_to_contents(messages)

        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.config.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.config.max_tokens),
            },
        }
        if system_instruction:
            # v1beta expects system prompts in a dedicated top-level field,
            # not inside `contents`.
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}

        path = f"/v1beta/models/{self.config.model}:generateContent"
        # Send the key via header rather than a `?key=` query parameter so it
        # does not leak into server/proxy access logs. httpx's `json=` sets
        # the Content-Type header and serializes the body for us.
        response = await self.client.post(
            path,
            headers={"x-goog-api-key": api_key},
            json=payload,
        )
        response.raise_for_status()
        data = response.json()

        candidates = data.get("candidates") or []
        text = ""
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            # A candidate may carry multiple text parts; concatenate them.
            text = "".join(part.get("text", "") for part in parts if isinstance(part, dict))

        if return_usage:
            usage = data.get("usageMetadata", {}) or {}
            usage_out = {
                "input_tokens": usage.get("promptTokenCount", 0),
                "output_tokens": usage.get("candidatesTokenCount", 0),
            }
            return text, usage_out

        return text

    def count_tokens(self, text: str) -> int:
        """Rough token estimate (~4 characters per token)."""
        return len(text) // 4


def create_google_provider(
model: str = "gemini-flash-latest",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
**kwargs,
) -> GoogleAIProvider:
"""Factory to create a Google AI (Gemini) provider."""
config = ProviderConfig(
model=model,
api_key=api_key,
base_url=base_url,
**kwargs,
)
return GoogleAIProvider(config)
2 changes: 1 addition & 1 deletion ai_secbench/providers/xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def count_tokens(self, text: str) -> int:


def create_xai_provider(
model: str = "grok-2-1212",
model: str = "grok-4-1-fast",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
**kwargs,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ anthropic = ["anthropic>=0.18.0"]
openai = ["openai>=1.0.0"]
huggingface = ["huggingface_hub>=0.20.0"]
xai = ["httpx>=0.26.0"]
google = ["httpx>=0.26.0"]
local = ["transformers>=4.36.0", "torch>=2.0.0"]
all = [
"anthropic>=0.18.0",
Expand Down
76 changes: 76 additions & 0 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Provider registry and basic connectivity tests.

These tests avoid real network calls and instead check:
- Registry wiring (list_providers/get_provider)
- Helpful errors when API keys are missing

If provider SDKs are not installed for a given provider, the test is skipped.
"""

import asyncio
import importlib

import pytest

from ai_secbench.providers import get_provider, list_providers


def _is_installed(module_name: str) -> bool:
try:
importlib.import_module(module_name)
return True
except ImportError:
return False


def test_registry_contains_expected_providers():
    """Every supported provider name must be advertised by list_providers()."""
    providers = list_providers()
    expected_names = (
        "anthropic",
        "openai",
        "huggingface",
        "huggingface_local",
        "xai",
        "google",
    )
    for expected in expected_names:
        assert expected in providers, f"Provider '{expected}' missing from registry. Seen: {providers}"


@pytest.mark.parametrize(
    ("provider_name", "expected_cls_name"),
    [
        ("anthropic", "AnthropicProvider"),
        ("openai", "OpenAIProvider"),
        ("huggingface", "HuggingFaceProvider"),
        ("huggingface_local", "LocalHuggingFaceProvider"),
        ("xai", "XAIProvider"),
        ("google", "GoogleAIProvider"),
    ],
)
def test_get_provider_returns_correct_class(provider_name, expected_cls_name):
    """The registry must map each provider name to its concrete class."""
    provider = get_provider(provider_name, model="dummy-model", api_key="DUMMY")

    actual_cls_name = provider.__class__.__name__
    assert actual_cls_name == expected_cls_name, (
        f"Expected {expected_cls_name}, got {provider.__class__.__name__}"
    )

    # Token counting must work fully offline, with no network round-trip.
    token_count = provider.count_tokens("hello world")
    assert isinstance(token_count, int)


@pytest.mark.parametrize(
    ("provider_name", "model", "env_vars", "import_check"),
    [
        ("anthropic", "claude-3-5-sonnet-20241022", ["ANTHROPIC_API_KEY"], "anthropic"),
        ("openai", "gpt-4o", ["OPENAI_API_KEY"], "openai"),
        ("xai", "grok-4-1-fast", ["XAI_API_KEY", "XAI_KEY"], "httpx"),
        ("google", "gemini-flash-latest", ["GOOGLE_API_KEY", "GOOGLE_GENAI_API_KEY"], "httpx"),
    ],
)
def test_complete_raises_without_api_key(monkeypatch, provider_name, model, env_vars, import_check):
    """complete() must fail fast with a helpful ValueError when no key is set."""
    if import_check and not _is_installed(import_check):
        pytest.skip(f"Required client library '{import_check}' not installed")

    # Scrub every key-bearing env var so the provider cannot fall back to it.
    for var in env_vars:
        monkeypatch.delenv(var, raising=False)

    provider = get_provider(provider_name, model=model, api_key=None)

    # complete() is async; drive the coroutine to completion directly.
    with pytest.raises(ValueError, match="API key"):
        asyncio.run(provider.complete([{"role": "user", "content": "ping"}]))