Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 143 additions & 5 deletions run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

This script provides a command-line interface for running comprehensive MEQ-Bench
evaluations on different language models. It supports various model backends
including Hugging Face models, OpenAI API, Anthropic API, MLX (Apple Silicon),
and custom model functions.
including Hugging Face models, OpenAI API, Anthropic API, Google Gemini API,
MLX (Apple Silicon), and custom model functions.

The script handles model loading, benchmark execution, results saving, and provides
detailed progress reporting for long-running evaluations.
Expand All @@ -19,9 +19,12 @@
# Run with OpenAI API (requires OPENAI_API_KEY environment variable)
python run_benchmark.py --model_name openai:gpt-4 --max_items 100 --output_dir results/gpt4/

# Run with custom data and specific configuration
# Run with Anthropic API (requires ANTHROPIC_API_KEY environment variable)
python run_benchmark.py --model_name anthropic:claude-3-opus --data_path data/custom_dataset.json --config config/custom.yaml

# Run with Google Gemini API (requires GOOGLE_API_KEY environment variable)
python run_benchmark.py --model_name gemini:gemini-pro --max_items 100 --output_dir results/gemini/

# Use MLX for optimized inference on Apple Silicon
python run_benchmark.py --model_name mlx:mistralai/Mistral-7B-Instruct-v0.2 --max_items 100
"""
Expand Down Expand Up @@ -67,6 +70,7 @@ def create_model_function(model_name: str) -> Callable[[str], str]:
- "huggingface:model_id" - Uses Hugging Face model (requires transformers)
- "openai:model_id" - Uses OpenAI API (requires openai and API key)
- "anthropic:model_id" - Uses Anthropic API (requires anthropic and API key)
- "gemini:model_id" - Uses Google Gemini API (requires google-generativeai and API key)
- "mlx:model_id" - Uses MLX optimized model (Apple Silicon only)
- Custom format can be added for other providers

Expand All @@ -84,6 +88,7 @@ def create_model_function(model_name: str) -> Callable[[str], str]:
dummy_func = create_model_function("dummy")
hf_func = create_model_function("huggingface:mistralai/Mistral-7B-Instruct-v0.2")
openai_func = create_model_function("openai:gpt-4")
gemini_func = create_model_function("gemini:gemini-pro")
mlx_func = create_model_function("mlx:mistralai/Mistral-7B-Instruct-v0.2")
```
"""
Expand Down Expand Up @@ -115,14 +120,18 @@ def create_model_function(model_name: str) -> Callable[[str], str]:
logger.info(f"Creating Anthropic model: {model_id}")
return _create_anthropic_model(model_id)

elif backend == "gemini":
logger.info(f"Creating Gemini model: {model_id}")
return _create_gemini_model(model_id)

elif backend == "mlx":
logger.info(f"Creating MLX model: {model_id}")
return _create_mlx_model(model_id)

else:
raise ValueError(
f"Unknown model backend: {backend}. "
f"Supported backends: dummy, huggingface, openai, anthropic, mlx"
f"Supported backends: dummy, huggingface, openai, anthropic, gemini, mlx"
)


Expand Down Expand Up @@ -336,6 +345,130 @@ def anthropic_model(prompt: str) -> str:
return anthropic_model


def _create_gemini_model(model_id: str) -> Callable[[str], str]:
    """Create a Google Gemini model function.

    Builds a callable that sends a prompt to Google's Gemini API and returns
    the generated text. Failed API calls are retried with exponential backoff;
    after the final failure (or when the API returns no usable text) the
    callable returns a string starting with ``"Error:"`` instead of raising,
    so a long benchmark run is not aborted by a single bad item.

    No request is sent while the function is being created; credentials and
    the model name are first exercised by the first generation call.

    Args:
        model_id: The Gemini model identifier (e.g., "gemini-pro").

    Returns:
        Callable function that takes a prompt string and returns response string.

    Raises:
        ImportError: If google-generativeai library is not installed.
        EnvironmentError: If GOOGLE_API_KEY environment variable is not set.
        Exception: If model initialization fails.

    Example:
        ```python
        model_func = _create_gemini_model("gemini-pro")
        response = model_func("Explain diabetes symptoms for a patient")
        ```
    """
    try:
        import google.generativeai as genai
    except ImportError as e:
        raise ImportError(
            "Gemini models require 'google-generativeai' library. "
            "Install with: pip install google-generativeai"
        ) from e

    # Check for API key
    api_key = os.getenv('GOOGLE_API_KEY')
    if not api_key:
        raise EnvironmentError(
            "GOOGLE_API_KEY environment variable is required for Gemini models. "
            "Get your API key from: https://makersuite.google.com/app/apikey"
        )

    # Configure the Gemini client
    genai.configure(api_key=api_key)

    try:
        model = genai.GenerativeModel(model_id)

        # These settings never change between calls, so build them once here
        # instead of on every retry of every prompt. A library/version problem
        # therefore surfaces immediately rather than as per-item error strings.
        generation_config = genai.types.GenerationConfig(
            temperature=0.7,
            top_p=0.8,
            top_k=40,
            max_output_tokens=800,
        )
        harm_category = genai.types.HarmCategory
        threshold = genai.types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
        safety_settings = {
            harm_category.HARM_CATEGORY_HATE_SPEECH: threshold,
            harm_category.HARM_CATEGORY_DANGEROUS_CONTENT: threshold,
            harm_category.HARM_CATEGORY_HARASSMENT: threshold,
            harm_category.HARM_CATEGORY_SEXUALLY_EXPLICIT: threshold,
        }
    except Exception as e:
        logger.error(f"Failed to initialize Gemini model {model_id}: {e}")
        raise

    logger.info(f"Initialized Gemini client for model: {model_id}")

    def gemini_model(prompt: str) -> str:
        """Generate response using Google Gemini API.

        Args:
            prompt: Input prompt for the model.

        Returns:
            Generated response text from the Gemini model, or a string
            starting with "Error:" when the content was blocked, the response
            was empty, or every retry failed.
        """
        max_retries = 3
        base_delay = 1.0

        for attempt in range(1, max_retries + 1):
            try:
                logger.debug(f"Gemini generation attempt {attempt}/{max_retries}")

                response = model.generate_content(
                    prompt,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                )

                candidates = response.candidates
                if candidates and candidates[0].content.parts:
                    generated_text = response.text.strip()
                    logger.debug(
                        f"Successfully generated {len(generated_text)} characters on attempt {attempt}"
                    )
                    return generated_text

                # No usable text. The prompt itself may have been rejected:
                # prompt_feedback is always present on a response, so test the
                # block reason's value (unspecified is falsy), not its existence.
                prompt_feedback = getattr(response, 'prompt_feedback', None)
                block_reason = getattr(prompt_feedback, 'block_reason', None)
                if block_reason:
                    logger.warning(f"Gemini blocked content: {block_reason}")
                    return f"Error: Content was blocked due to safety filters ({block_reason})"

                # Otherwise the generated candidate may have been stopped by
                # the safety filters; that is reported via its finish_reason.
                finish_reason = getattr(candidates[0], 'finish_reason', None) if candidates else None
                if getattr(finish_reason, 'name', None) == 'SAFETY':
                    logger.warning(f"Gemini blocked content: {finish_reason}")
                    return f"Error: Content was blocked due to safety filters ({finish_reason})"

                logger.warning("Gemini returned empty response")
                return "Error: Empty response from Gemini model"

            except Exception as e:
                logger.warning(f"Gemini generation attempt {attempt} failed: {e}")

                # If this is the last attempt, log error and return fallback
                if attempt == max_retries:
                    logger.error(f"All {max_retries} Gemini generation attempts failed. Final error: {e}")
                    return f"Error: Gemini API call failed after {max_retries} attempts"

                # Exponential backoff: 1s, 2s, 4s, ...
                delay = base_delay * (2 ** (attempt - 1))
                logger.info(f"Retrying in {delay:.1f} seconds... (attempt {attempt + 1}/{max_retries})")
                time.sleep(delay)

        # This should never be reached due to the logic above, but included for safety
        return "Error: Unexpected failure in Gemini retry mechanism"

    return gemini_model


def _create_mlx_model(model_id: str) -> Callable[[str], str]:
"""Create an MLX model function for optimized inference on Apple Silicon.

Expand Down Expand Up @@ -639,9 +772,12 @@ def main():
# Evaluate OpenAI model (requires OPENAI_API_KEY)
python run_benchmark.py --model_name openai:gpt-4 --max_items 100 --output_dir results/openai/

# Use custom data and configuration
# Evaluate Anthropic model (requires ANTHROPIC_API_KEY)
python run_benchmark.py --model_name anthropic:claude-3-opus --data_path data/custom.json --config config/custom.yaml

# Evaluate Google Gemini model (requires GOOGLE_API_KEY)
python run_benchmark.py --model_name gemini:gemini-pro --max_items 100 --output_dir results/gemini/

# Use MLX for optimized inference on Apple Silicon
python run_benchmark.py --model_name mlx:mistralai/Mistral-7B-Instruct-v0.2 --max_items 100

Expand All @@ -650,11 +786,13 @@ def main():
huggingface:model_id - Hugging Face model
openai:model_id - OpenAI API model
anthropic:model_id - Anthropic API model
gemini:model_id - Google Gemini API model
mlx:model_id - MLX optimized model (Apple Silicon only)

Required Environment Variables:
OPENAI_API_KEY - For OpenAI models
ANTHROPIC_API_KEY - For Anthropic models
GOOGLE_API_KEY - For Google Gemini models

Notes:
MLX backend requires Apple Silicon (M1/M2/M3) and mlx-lm package
Expand Down
Loading