diff --git a/run_benchmark.py b/run_benchmark.py index 1dcda24..4c87aac 100644 --- a/run_benchmark.py +++ b/run_benchmark.py @@ -3,8 +3,8 @@ This script provides a command-line interface for running comprehensive MEQ-Bench evaluations on different language models. It supports various model backends -including Hugging Face models, OpenAI API, Anthropic API, MLX (Apple Silicon), -and custom model functions. +including Hugging Face models, OpenAI API, Anthropic API, Google Gemini API, +MLX (Apple Silicon), and custom model functions. The script handles model loading, benchmark execution, results saving, and provides detailed progress reporting for long-running evaluations. @@ -19,9 +19,12 @@ # Run with OpenAI API (requires OPENAI_API_KEY environment variable) python run_benchmark.py --model_name openai:gpt-4 --max_items 100 --output_dir results/gpt4/ - # Run with custom data and specific configuration + # Run with Anthropic API (requires ANTHROPIC_API_KEY environment variable) python run_benchmark.py --model_name anthropic:claude-3-opus --data_path data/custom_dataset.json --config config/custom.yaml + # Run with Google Gemini API (requires GOOGLE_API_KEY environment variable) + python run_benchmark.py --model_name gemini:gemini-pro --max_items 100 --output_dir results/gemini/ + # Use MLX for optimized inference on Apple Silicon python run_benchmark.py --model_name mlx:mistralai/Mistral-7B-Instruct-v0.2 --max_items 100 """ @@ -67,6 +70,7 @@ def create_model_function(model_name: str) -> Callable[[str], str]: - "huggingface:model_id" - Uses Hugging Face model (requires transformers) - "openai:model_id" - Uses OpenAI API (requires openai and API key) - "anthropic:model_id" - Uses Anthropic API (requires anthropic and API key) + - "gemini:model_id" - Uses Google Gemini API (requires google-generativeai and API key) - "mlx:model_id" - Uses MLX optimized model (Apple Silicon only) - Custom format can be added for other providers @@ -84,6 +88,7 @@ def 
def _create_gemini_model(model_id: str) -> Callable[[str], str]:
    """Create a Google Gemini model function.

    Builds a closure that sends a prompt to the Gemini API through the
    ``google-generativeai`` client and returns the generated text. The client,
    the generation parameters and the safety settings are set up once here.
    The returned function retries failed API calls with exponential backoff
    and reports every failure as an ``"Error: ..."`` string instead of
    raising, so one bad item does not abort a long benchmark run.

    Args:
        model_id: The Gemini model identifier (e.g., "gemini-pro").

    Returns:
        Callable function that takes a prompt string and returns response string.

    Raises:
        ImportError: If google-generativeai library is not installed.
        EnvironmentError: If GOOGLE_API_KEY environment variable is not set.
        Exception: If model initialization fails (logged, then re-raised).

    Example:
        ```python
        model_func = _create_gemini_model("gemini-pro")
        response = model_func("Explain diabetes symptoms for a patient")
        ```
    """
    try:
        import google.generativeai as genai
    except ImportError as e:
        raise ImportError(
            "Gemini models require 'google-generativeai' library. "
            "Install with: pip install google-generativeai"
        ) from e

    # Fail fast on a missing key instead of on the first benchmark item.
    api_key = os.getenv('GOOGLE_API_KEY')
    if not api_key:
        raise EnvironmentError(
            "GOOGLE_API_KEY environment variable is required for Gemini models. "
            "Get your API key from: https://makersuite.google.com/app/apikey"
        )

    genai.configure(api_key=api_key)

    # Retry policy for transient API failures: waits of 1s, 2s between the
    # three attempts (base_delay * 2 ** attempt).
    max_retries = 3
    base_delay = 1.0

    try:
        model = genai.GenerativeModel(model_id)

        # These settings are identical for every request, so they are built
        # once here rather than on every retry of every prompt.
        generation_config = genai.types.GenerationConfig(
            temperature=0.7,
            top_p=0.8,
            top_k=40,
            max_output_tokens=800,
        )
        harm_category = genai.types.HarmCategory
        block_level = genai.types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
        safety_settings = {
            harm_category.HARM_CATEGORY_HATE_SPEECH: block_level,
            harm_category.HARM_CATEGORY_DANGEROUS_CONTENT: block_level,
            harm_category.HARM_CATEGORY_HARASSMENT: block_level,
            harm_category.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_level,
        }
    except Exception as e:
        logger.error(f"Failed to initialize Gemini model {model_id}: {e}")
        raise

    logger.info(f"Initialized Gemini client for model: {model_id}")

    def gemini_model(prompt: str) -> str:
        """Generate response using Google Gemini API.

        Args:
            prompt: Input prompt for the model.

        Returns:
            Generated response text from the Gemini model, stripped of
            surrounding whitespace. On failure an ``"Error: ..."`` string is
            returned instead: blocked content and empty responses are
            reported immediately, API exceptions only after all retries.
        """
        last_error = None

        for attempt in range(max_retries):
            try:
                logger.debug(f"Gemini generation attempt {attempt + 1}/{max_retries}")

                response = model.generate_content(
                    prompt,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                )

                # A usable answer has at least one candidate with content parts.
                if response.candidates and response.candidates[0].content.parts:
                    generated_text = response.text.strip()
                    logger.debug(
                        f"Successfully generated {len(generated_text)} characters on attempt {attempt + 1}"
                    )
                    return generated_text

                # No text came back. This is not retried: the same prompt
                # would presumably be blocked or come back empty again.
                # Only call it "blocked" when a block reason is actually set;
                # the mere presence of prompt_feedback does not imply a block.
                # NOTE(review): assumes the SDK's "unspecified" block reason
                # is falsy (enum value 0) - confirm against the installed SDK.
                feedback = getattr(response, 'prompt_feedback', None)
                block_reason = getattr(feedback, 'block_reason', None)
                if block_reason:
                    logger.warning(f"Gemini blocked content: {block_reason}")
                    return f"Error: Content was blocked due to safety filters ({block_reason})"

                logger.warning("Gemini returned empty response")
                return "Error: Empty response from Gemini model"

            except Exception as e:
                # Broad on purpose: any API/client failure is retried and
                # finally converted to an error string, never propagated.
                last_error = e
                logger.warning(f"Gemini generation attempt {attempt + 1} failed: {e}")

                if attempt < max_retries - 1:
                    # Exponential backoff before the next attempt.
                    delay = base_delay * (2 ** attempt)
                    logger.info(
                        f"Retrying in {delay:.1f} seconds... (attempt {attempt + 2}/{max_retries})"
                    )
                    time.sleep(delay)

        # Reached only when every attempt raised.
        logger.error(
            f"All {max_retries} Gemini generation attempts failed. Final error: {last_error}"
        )
        return f"Error: Gemini API call failed after {max_retries} attempts"

    return gemini_model