diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d3c9b6..7a2823c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,7 +23,7 @@ jobs: python-version: '3.9' - name: Cache pip dependencies - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} @@ -68,7 +68,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Cache pip dependencies - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('**/requirements.txt') }} @@ -98,7 +98,7 @@ jobs: - name: Upload coverage to Codecov if: matrix.python-version == '3.9' - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: file: ./coverage.xml flags: unittests @@ -174,7 +174,7 @@ jobs: make html - name: Upload documentation artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: documentation path: docs/_build/html/ @@ -241,7 +241,7 @@ jobs: twine check dist/* - name: Upload package artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: dist-packages path: dist/ @@ -259,7 +259,7 @@ jobs: fetch-depth: 0 - name: Download package artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: dist-packages path: dist/ diff --git a/examples/basic_usage.py b/examples/basic_usage.py index a2dbb9e..ad4ce50 100644 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -5,34 +5,30 @@ import os import sys import logging -from typing import Optional import torch import warnings -# Suppress some common warnings for cleaner output -warnings.filterwarnings("ignore", category=FutureWarning) -warnings.filterwarnings("ignore", category=UserWarning) - -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +# Add parent directory to path for imports +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from src.benchmark import MEQBench from src.evaluator import MEQBenchEvaluator -from src.config import config + +# Suppress some common warnings for cleaner output +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) # Set up logging -logger = logging.getLogger('meq_bench.examples') +logger = logging.getLogger("meq_bench.examples") -def generate_with_huggingface( - prompt: str, - model_name: str = "mistralai/Mistral-7B-Instruct-v0.2" -) -> str: +def generate_with_huggingface(prompt: str, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2") -> str: """Generate text using a Hugging Face model. - + This function loads a pretrained language model from the Hugging Face Hub and generates a response to the given prompt. It's designed to work with instruction-tuned models that can follow the MEQ-Bench prompt format. - + Args: prompt: The input prompt containing medical content and audience instructions. model_name: Name of the Hugging Face model to use. Popular options include: @@ -40,24 +36,24 @@ def generate_with_huggingface( - "microsoft/DialoGPT-medium" (smaller, faster) - "meta-llama/Llama-2-7b-chat-hf" (requires approval) - "google/flan-t5-large" (encoder-decoder architecture) - + Returns: Generated text response as a string. - + Note: This function requires the transformers library to be installed: pip install transformers torch - + For larger models, ensure you have sufficient GPU memory or use CPU inference (which will be slower). The function automatically detects available devices. - + Example: ```python # Generate explanation using Mistral-7B prompt = "Medical Information: Diabetes is high blood sugar..." response = generate_with_huggingface(prompt, "mistralai/Mistral-7B-Instruct-v0.2") - + # Use a smaller model for faster inference response = generate_with_huggingface(prompt, "microsoft/DialoGPT-medium") ``` @@ -66,48 +62,45 @@ def generate_with_huggingface( # Import transformers here to make it optional from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from transformers import logging as hf_logging - + # Reduce transformers logging verbosity hf_logging.set_verbosity_error() - + except ImportError as e: logger.error("transformers library not installed. Install with: pip install transformers torch") raise ImportError( "transformers library is required for Hugging Face model integration. " "Install with: pip install transformers torch" ) from e - + logger.info(f"Loading Hugging Face model: {model_name}") - + try: # Determine device (GPU if available, otherwise CPU) device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device: {device}") - + # Load tokenizer logger.info("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(model_name) - + # Set pad token if not present (needed for some models) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - + # Load model with appropriate settings logger.info("Loading model...") model_kwargs = { "torch_dtype": torch.float16 if device == "cuda" else torch.float32, "low_cpu_mem_usage": True, } - + # For CPU inference or limited GPU memory, use smaller precision if device == "cpu": model_kwargs["torch_dtype"] = torch.float32 - - model = AutoModelForCausalLM.from_pretrained( - model_name, - **model_kwargs - ).to(device) - + + model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device) + # Create text generation pipeline generator = pipeline( "text-generation", @@ -115,20 +108,20 @@ def generate_with_huggingface( tokenizer=tokenizer, device=0 if device == "cuda" else -1, # 0 for first GPU, -1 for CPU return_full_text=False, # Only return generated text, not input prompt - pad_token_id=tokenizer.eos_token_id + pad_token_id=tokenizer.eos_token_id, ) - + # Generation parameters - adjust these for different models/quality trade-offs generation_params = { "max_new_tokens": 800, # Maximum tokens to generate - "temperature": 0.7, # Controls randomness (0.1 = deterministic, 1.0 = creative) - "top_p": 0.9, # Nucleus sampling parameter - "do_sample": True, # Enable sampling for more diverse outputs + "temperature": 0.7, # Controls randomness (0.1 = deterministic, 1.0 = creative) + "top_p": 0.9, # Nucleus sampling parameter + "do_sample": True, # Enable sampling for more diverse outputs "num_return_sequences": 1, "pad_token_id": tokenizer.eos_token_id, "eos_token_id": tokenizer.eos_token_id, } - + # For instruction-tuned models like Mistral, format the prompt appropriately if "instruct" in model_name.lower() or "chat" in model_name.lower(): # Use instruction format for better results @@ -136,41 +129,41 @@ def generate_with_huggingface( else: # Use prompt as-is for base models formatted_prompt = prompt - + logger.info("Generating response...") - + # Generate response result = generator(formatted_prompt, **generation_params) - + # Extract generated text if isinstance(result, list) and len(result) > 0: - generated_text = result[0].get('generated_text', '') + generated_text = result[0].get("generated_text", "") else: generated_text = str(result) - + # Clean up the response generated_text = generated_text.strip() - + # Remove potential instruction formatting artifacts - if generated_text.startswith('[/INST]'): + if generated_text.startswith("[/INST]"): generated_text = generated_text[7:].strip() - + logger.info(f"Generated {len(generated_text)} characters") - + return generated_text - + except Exception as e: logger.error(f"Error during model generation: {e}") - + # Fallback to dummy response to keep the example working logger.warning("Falling back to dummy response due to model loading error") return dummy_model_function(prompt) - + finally: # Clean up GPU memory if used - if 'model' in locals(): + if "model" in locals(): del model - if 'tokenizer' in locals(): + if "tokenizer" in locals(): del tokenizer if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -182,87 +175,110 @@ def dummy_model_function(prompt: str) -> str: In practice, this would call your actual LLM or use the Hugging Face function above. """ return """ - For a Physician: The patient presents with essential hypertension, likely multifactorial etiology including genetic predisposition and lifestyle factors. Recommend ACE inhibitor initiation with monitoring of renal function and electrolytes. Consider cardiovascular risk stratification. + For a Physician: The patient presents with essential hypertension, likely multifactorial etiology including + genetic predisposition and lifestyle factors. Recommend ACE inhibitor initiation with monitoring of renal + function and electrolytes. Consider cardiovascular risk stratification. - For a Nurse: Monitor blood pressure readings twice daily, document trends. Educate patient on medication compliance, dietary sodium restriction, and importance of regular follow-up. Watch for signs of medication side effects. + For a Nurse: Monitor blood pressure readings twice daily, document trends. Educate patient on medication + compliance, dietary sodium restriction, and importance of regular follow-up. Watch for signs of medication + side effects. - For a Patient: Your blood pressure is higher than normal, which means your heart is working harder than it should. We'll start you on medication to help lower it. It's important to take your medicine every day and eat less salt. + For a Patient: Your blood pressure is higher than normal, which means your heart is working harder than it + should. We'll start you on medication to help lower it. It's important to take your medicine every day and + eat less salt. - For a Caregiver: Help ensure they take their blood pressure medication at the same time each day. Monitor for dizziness or fatigue. Encourage a low-salt diet and regular gentle exercise. Call the doctor if blood pressure readings are very high. + For a Caregiver: Help ensure they take their blood pressure medication at the same time each day. Monitor for + dizziness or fatigue. Encourage a low-salt diet and regular gentle exercise. Call the doctor if blood pressure + readings are very high. """ def main(): """Main example function demonstrating both dummy and Hugging Face models""" + # Check if running in CI environment (non-interactive) + is_ci = os.getenv('CI') or os.getenv('GITHUB_ACTIONS') or not os.isatty(0) + print("MEQ-Bench Basic Usage Example with Hugging Face Integration") + if is_ci: + print("Running in CI mode (non-interactive)") print("=" * 60) - + # Initialize benchmark bench = MEQBench() - + # Create sample dataset print("Creating sample dataset...") sample_items = bench.create_sample_dataset() - + # Add items to benchmark for item in sample_items: bench.add_benchmark_item(item) - + # Get benchmark statistics stats = bench.get_benchmark_stats() - print(f"\nBenchmark Statistics:") + print("\nBenchmark Statistics:") print(f"Total items: {stats['total_items']}") print(f"Complexity distribution: {stats['complexity_distribution']}") print(f"Source distribution: {stats['source_distribution']}") - + # Test single explanation generation print("\nGenerating explanations for sample content...") - medical_content = "Diabetes is a condition where blood sugar levels are too high. It requires careful management through diet, exercise, and sometimes medication." - - # Ask user which model to use + medical_content = "Diabetes is a condition where blood sugar levels are too high." + + # Ask user which model to use (or use default in CI) print("\nChoose model for explanation generation:") print("1. Dummy model (fast, for testing)") print("2. Hugging Face model (requires transformers library)") - - choice = input("Enter choice (1 or 2, default=1): ").strip() - + + if is_ci: + choice = "1" # Use dummy model in CI + print("Using dummy model (CI mode)") + else: + choice = input("Enter choice (1 or 2, default=1): ").strip() + if choice == "2": print("\nUsing Hugging Face model...") print("Note: This requires 'transformers' and 'torch' libraries:") print("pip install transformers torch") - + # Option to specify model name - model_choice = input("\nChoose model (or press Enter for default):\n" - "1. mistralai/Mistral-7B-Instruct-v0.2 (default, ~7B params)\n" - "2. microsoft/DialoGPT-medium (smaller, faster)\n" - "3. Custom model name\n" - "Choice: ").strip() - - if model_choice == "2": - model_name = "microsoft/DialoGPT-medium" - elif model_choice == "3": - model_name = input("Enter model name: ").strip() + if is_ci: + model_name = "mistralai/Mistral-7B-Instruct-v0.2" # Use default in CI + print(f"Using default model: {model_name} (CI mode)") else: - model_name = "mistralai/Mistral-7B-Instruct-v0.2" - + model_choice = input( + "\nChoose model (or press Enter for default):\n" + "1. mistralai/Mistral-7B-Instruct-v0.2 (default, ~7B params)\n" + "2. microsoft/DialoGPT-medium (smaller, faster)\n" + "3. Custom model name\n" + "Choice: " + ).strip() + + if model_choice == "2": + model_name = "microsoft/DialoGPT-medium" + elif model_choice == "3": + model_name = input("Enter model name: ").strip() + else: + model_name = "mistralai/Mistral-7B-Instruct-v0.2" + print(f"\nLoading model: {model_name}") print("This may take a few minutes on first run...") - + # Create model function with specified model def hf_model_function(prompt: str) -> str: return generate_with_huggingface(prompt, model_name) - + model_function = hf_model_function model_type = f"Hugging Face ({model_name})" - + else: print("\nUsing dummy model...") model_function = dummy_model_function model_type = "Dummy model" - + print(f"\nGenerating explanations with {model_type}...") explanations = bench.generate_explanations(medical_content, model_function) - + print("\nGenerated Explanations:") for audience, explanation in explanations.items(): print(f"\n{audience.upper()}:") @@ -272,13 +288,13 @@ def hf_model_function(prompt: str) -> str: print(explanation[:max_length] + "...") else: print(explanation) - + # Evaluate explanations print("\nEvaluating explanations...") try: evaluator = MEQBenchEvaluator() results = evaluator.evaluate_all_audiences(medical_content, explanations) - + print("\nEvaluation Results:") for audience, score in results.items(): print(f"\n{audience.upper()}:") @@ -291,31 +307,35 @@ def hf_model_function(prompt: str) -> str: except Exception as e: print(f"Evaluation failed: {e}") print("This might be due to missing dependencies or configuration.") - + # Option to run full benchmark evaluation - run_full = input("\nRun full benchmark evaluation? (y/N): ").strip().lower() - - if run_full == 'y': + if is_ci: + run_full = "n" # Skip full evaluation in CI + print("\nSkipping full benchmark evaluation (CI mode)") + else: + run_full = input("\nRun full benchmark evaluation? (y/N): ").strip().lower() + + if run_full == "y": print("\nRunning full benchmark evaluation...") print("Note: This may take several minutes with real models...") - + try: full_results = bench.evaluate_model(model_function, max_items=2) - + print("\nFull Benchmark Results Summary:") - summary = full_results['summary'] + summary = full_results["summary"] for key, value in summary.items(): if isinstance(value, (int, float)): print(f"{key}: {value:.3f}") - + # Save results output_path = f"sample_results_{model_type.replace(' ', '_').replace('(', '').replace(')', '')}.json" bench.save_results(full_results, output_path) print(f"\nResults saved to: {output_path}") - + except Exception as e: print(f"Full evaluation failed: {e}") - + print("\n" + "=" * 60) print("Example completed!") print("\nTo use Hugging Face models in your own code:") @@ -328,4 +348,4 @@ def hf_model_function(prompt: str) -> str: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/__init__.py b/src/__init__.py index 4a6449b..927d332 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,8 +1,8 @@ """MEQ-Bench: A Resource-Efficient Benchmark for Evaluating Medical LLM Explanation Quality. -MEQ-Bench is the first benchmark specifically designed to assess an LLM's ability to generate -audience-adaptive medical explanations for four key stakeholders: physicians, nurses, patients, -and caregivers. This package provides a comprehensive evaluation framework that combines +MEQ-Bench is the first benchmark specifically designed to assess an LLM's ability to generate +audience-adaptive medical explanations for four key stakeholders: physicians, nurses, patients, +and caregivers. This package provides a comprehensive evaluation framework that combines automated metrics with LLM-as-a-judge evaluation. Key Features: @@ -14,11 +14,11 @@ Typical usage example: ```python from meq_bench import MEQBench, MEQBenchEvaluator - + # Initialize benchmark bench = MEQBench() evaluator = MEQBenchEvaluator() - + # Generate and evaluate explanations explanations = bench.generate_explanations(medical_content, model_func) scores = evaluator.evaluate_all_audiences(medical_content, explanations) @@ -31,17 +31,11 @@ # Initialize configuration and logging from .config import config -config.setup_logging() - from .benchmark import MEQBench from .evaluator import MEQBenchEvaluator from .prompt_templates import AudienceAdaptivePrompt from .strategies import StrategyFactory -__all__ = [ - "MEQBench", - "MEQBenchEvaluator", - "AudienceAdaptivePrompt", - "StrategyFactory", - "config" -] \ No newline at end of file +config.setup_logging() + +__all__ = ["MEQBench", "MEQBenchEvaluator", "AudienceAdaptivePrompt", "StrategyFactory", "config"] diff --git a/src/benchmark.py b/src/benchmark.py index b6f0da4..49d9a1b 100644 --- a/src/benchmark.py +++ b/src/benchmark.py @@ -1,7 +1,7 @@ """Main MEQ-Bench benchmark implementation. This module contains the core benchmark classes and functionality for MEQ-Bench. -It provides tools for loading medical datasets, generating audience-adaptive +It provides tools for loading medical datasets, generating audience-adaptive explanations, and running comprehensive evaluations. Key classes: @@ -13,6 +13,7 @@ import os import logging from typing import Dict, List, Any, Optional, Union, Callable + try: from typing_extensions import TypedDict except ImportError: @@ -22,14 +23,15 @@ from .config import config from .prompt_templates import AudienceAdaptivePrompt -from .evaluator import MEQBenchEvaluator, EvaluationScore +from .evaluator import MEQBenchEvaluator -logger = logging.getLogger('meq_bench.benchmark') +logger = logging.getLogger("meq_bench.benchmark") # TypedDict definitions for structured data class EvaluationResultDict(TypedDict): """Type definition for evaluation results.""" + model_name: str total_items: int audience_scores: Dict[str, List[float]] @@ -40,6 +42,7 @@ class EvaluationResultDict(TypedDict): class ItemResultDict(TypedDict): """Type definition for individual item results.""" + item_id: str complexity_level: str source_dataset: str @@ -49,6 +52,7 @@ class ItemResultDict(TypedDict): class BenchmarkStatsDict(TypedDict): """Type definition for benchmark statistics.""" + total_items: int complexity_distribution: Dict[str, int] source_distribution: Dict[str, int] @@ -57,11 +61,11 @@ class BenchmarkStatsDict(TypedDict): @dataclass class MEQBenchItem: """Represents a single benchmark item for evaluation. - + A benchmark item contains medical content to be explained, along with metadata about its complexity level and source dataset. Optionally includes reference explanations for different audiences. - + Attributes: id: Unique identifier for the benchmark item. medical_content: The medical information to be adapted for different audiences. @@ -70,6 +74,7 @@ class MEQBenchItem: reference_explanations: Optional reference explanations for each audience, mapping audience names to explanation text. """ + id: str medical_content: str complexity_level: str # "basic", "intermediate", "advanced" @@ -79,40 +84,40 @@ class MEQBenchItem: class MEQBench: """Main benchmark class for MEQ-Bench evaluation. - + This class provides the core functionality for running MEQ-Bench evaluations, including loading benchmark data, generating audience-adaptive explanations, and evaluating model performance across different audiences and complexity levels. - + The class manages benchmark items, interfaces with evaluation components, and provides comprehensive evaluation results with detailed statistics. - + Attributes: data_path: Path to the benchmark data directory. evaluator: MEQBenchEvaluator instance for scoring explanations. prompt_template: AudienceAdaptivePrompt instance for generating prompts. benchmark_items: List of loaded MEQBenchItem objects. - + Example: ```python # Initialize benchmark bench = MEQBench(data_path="/path/to/data") - + # Generate explanations for a model explanations = bench.generate_explanations(medical_content, model_func) - + # Run full evaluation results = bench.evaluate_model(model_func, max_items=100) ``` """ - + def __init__(self, data_path: Optional[str] = None) -> None: """Initialize MEQ-Bench instance. - + Sets up the benchmark with the specified data directory and initializes the evaluator and prompt template components. Automatically loads benchmark data if the data directory exists. - + Args: data_path: Path to benchmark data directory. If None, uses default 'data' directory relative to the package root. @@ -126,16 +131,16 @@ def __init__(self, data_path: Optional[str] = None) -> None: logger.info("Some evaluation features may be limited due to missing dependencies or configuration") self.prompt_template: AudienceAdaptivePrompt = AudienceAdaptivePrompt() self.benchmark_items: List[MEQBenchItem] = [] - + # Load benchmark data if available self._load_benchmark_data() - + def _resolve_data_path(self, data_path: Optional[str] = None) -> Path: """Resolve data directory path with fallback options. - + Args: data_path: Optional custom data path - + Returns: Resolved Path object for data directory """ @@ -150,35 +155,38 @@ def _resolve_data_path(self, data_path: Optional[str] = None) -> Path: # Current working directory Path.cwd() / "data", # Environment variable if set - Path(os.environ.get('MEQ_BENCH_DATA_PATH', '')) if os.environ.get('MEQ_BENCH_DATA_PATH') else None, + Path(os.environ.get("MEQ_BENCH_DATA_PATH", "")) if os.environ.get("MEQ_BENCH_DATA_PATH") else None, # Config-based path - Path(config.get_data_path()) if hasattr(config, 'get_data_path') else None + Path(config.get_data_path()) if hasattr(config, "get_data_path") else None, ] - + # Find first existing path or use first option as default resolved_path = None for path in possible_paths: if path and path.exists(): resolved_path = path.resolve() break - + if not resolved_path: # Default to package relative path resolved_path = (Path(__file__).parent.parent / "data").resolve() - + # Ensure directory exists - resolved_path.mkdir(parents=True, exist_ok=True) - logger.info(f"Using data directory: {resolved_path}") - return resolved_path - + if resolved_path is not None: + resolved_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Using data directory: {resolved_path}") + return resolved_path + else: + raise ValueError("Could not resolve data directory path") + def _load_benchmark_data(self) -> None: """Load benchmark data from JSON files with error handling. - + Loads benchmark items from benchmark_items.json in the data directory. Each item is converted to a MEQBenchItem object and added to the benchmark_items list. Includes comprehensive error handling for missing files, invalid JSON, and malformed data. - + The JSON file should contain a list of dictionaries with the following keys: - id: Unique identifier for the item - medical_content: Medical information to be explained @@ -188,99 +196,99 @@ def _load_benchmark_data(self) -> None: """ try: benchmark_file = self.data_path / "benchmark_items.json" - + if not benchmark_file.exists(): logger.warning(f"Benchmark data file not found: {benchmark_file}") logger.info("Use create_sample_dataset() to generate sample data") return - - with open(benchmark_file, 'r', encoding='utf-8') as f: + + with open(benchmark_file, "r", encoding="utf-8") as f: data = json.load(f) - + if not isinstance(data, list): raise ValueError("Benchmark data must be a list of items") - + for i, item_data in enumerate(data): try: # Validate required fields - required_fields = ['id', 'medical_content', 'complexity_level', 'source_dataset'] + required_fields = ["id", "medical_content", "complexity_level", "source_dataset"] for field in required_fields: if field not in item_data: raise KeyError(f"Missing required field '{field}' in item {i}") - + item = MEQBenchItem( - id=item_data['id'], - medical_content=item_data['medical_content'], - complexity_level=item_data['complexity_level'], - source_dataset=item_data['source_dataset'], - reference_explanations=item_data.get('reference_explanations') + id=item_data["id"], + medical_content=item_data["medical_content"], + complexity_level=item_data["complexity_level"], + source_dataset=item_data["source_dataset"], + reference_explanations=item_data.get("reference_explanations"), ) self.benchmark_items.append(item) except (KeyError, ValueError) as e: logger.error(f"Error loading benchmark item {i}: {e}") continue - + logger.info(f"Loaded {len(self.benchmark_items)} benchmark items from {benchmark_file}") - + except FileNotFoundError: logger.warning(f"Benchmark data directory not found: {self.data_path}") except json.JSONDecodeError as e: logger.error(f"Invalid JSON in benchmark data file: {e}") except Exception as e: logger.error(f"Unexpected error loading benchmark data: {e}") - + def add_benchmark_item(self, item: MEQBenchItem) -> None: """Add a new benchmark item to the evaluation set. - + Validates the item data and adds it to the benchmark items list. Includes checks for data integrity and duplicate IDs. - + Args: item: MEQBenchItem object to add to the benchmark. - + Raises: TypeError: If item is not an instance of MEQBenchItem. ValueError: If item data is invalid or ID already exists. """ if not isinstance(item, MEQBenchItem): raise TypeError("item must be an instance of MEQBenchItem") - + # Validate item data if not item.id or not isinstance(item.id, str): raise ValueError("item.id must be a non-empty string") - + if not item.medical_content or not isinstance(item.medical_content, str): raise ValueError("item.medical_content must be a non-empty string") - - if item.complexity_level not in ['basic', 'intermediate', 'advanced']: + + if item.complexity_level not in ["basic", "intermediate", "advanced"]: raise ValueError("item.complexity_level must be 'basic', 'intermediate', or 'advanced'") - + # Check for duplicate IDs if any(existing_item.id == item.id for existing_item in self.benchmark_items): raise ValueError(f"Item with ID '{item.id}' already exists") - + self.benchmark_items.append(item) logger.debug(f"Added benchmark item: {item.id}") - + def generate_explanations(self, medical_content: str, model_func: Callable[[str], str]) -> Dict[str, str]: """Generate audience-adaptive explanations using a model. - + Uses the configured prompt template to generate explanations tailored for different healthcare audiences (physicians, nurses, patients, caregivers). - + Args: medical_content: Medical information to be adapted for different audiences. model_func: Callable that takes a prompt string and returns the model's response as a string. - + Returns: Dictionary mapping audience names to their respective explanations. Keys are audience names (e.g., 'physician', 'nurse', 'patient', 'caregiver'). - + Raises: ValueError: If medical_content is empty or invalid TypeError: If model_func is not callable - + Example: ```python def my_model(prompt): @@ -288,9 +296,9 @@ def my_model(prompt): model="gpt-4", messages=[{"role": "user", "content": prompt}] ).choices[0].message.content - + explanations = bench.generate_explanations( - "Hypertension is high blood pressure", + "Hypertension is high blood pressure", my_model ) ``` @@ -298,67 +306,67 @@ def my_model(prompt): # Input validation if not medical_content or not isinstance(medical_content, str): raise ValueError("medical_content must be a non-empty string") - + if medical_content.strip() == "": raise ValueError("medical_content cannot be empty or contain only whitespace") - + if len(medical_content.strip()) < 10: raise ValueError("medical_content must be at least 10 characters long") - + if not callable(model_func): raise TypeError("model_func must be a callable function") - + # Additional content validation if len(medical_content) > 10000: # Reasonable upper limit logger.warning(f"Medical content is very long ({len(medical_content)} chars). Consider splitting.") - + # Sanitize content - remove excessive whitespace - sanitized_content = ' '.join(medical_content.split()) - + sanitized_content = " ".join(medical_content.split()) + try: prompt = self.prompt_template.format_prompt(sanitized_content) logger.debug(f"Generated prompt with {len(prompt)} characters") - + response = model_func(prompt) - + # Validate model response if not response or not isinstance(response, str): raise ValueError("Model function returned empty or invalid response") - + if response.strip() == "": raise ValueError("Model function returned empty response") - + explanations = self.prompt_template.parse_response(response) - + # Validate parsed explanations if not explanations: raise ValueError("Failed to parse any explanations from model response") - + # Log successful generation logger.info(f"Generated explanations for {len(explanations)} audiences") - + return explanations - + except Exception as e: logger.error(f"Error generating explanations: {e}") if isinstance(e, (ValueError, TypeError)): raise else: raise RuntimeError(f"Unexpected error during explanation generation: {e}") from e - + def evaluate_model(self, model_func: Callable[[str], str], max_items: Optional[int] = None) -> EvaluationResultDict: """Evaluate a model on the full benchmark. - + Runs comprehensive evaluation of a model's performance across all benchmark items and audiences. Generates explanations for each item and evaluates them using the configured evaluator. - + Args: model_func: Callable that takes a prompt string and returns the model's response as a string. max_items: Maximum number of benchmark items to evaluate. If None, evaluates all available items. Useful for testing with smaller subsets. - + Returns: Dictionary containing comprehensive evaluation results with the following keys: - model_name: Name of the evaluated model @@ -367,7 +375,7 @@ def evaluate_model(self, model_func: Callable[[str], str], max_items: Optional[i - complexity_scores: Scores grouped by complexity level - detailed_results: Per-item detailed evaluation results - summary: Summary statistics including means, standard deviations, etc. - + Example: ```python results = bench.evaluate_model(my_model_func, max_items=50) @@ -375,77 +383,67 @@ def evaluate_model(self, model_func: Callable[[str], str], max_items: Optional[i ``` """ results: EvaluationResultDict = { - 'model_name': getattr(model_func, '__name__', 'unknown'), - 'total_items': len(self.benchmark_items[:max_items]) if max_items else len(self.benchmark_items), - 'audience_scores': { - 'physician': [], - 'nurse': [], - 'patient': [], - 'caregiver': [] - }, - 'complexity_scores': { - 'basic': [], - 'intermediate': [], - 'advanced': [] - }, - 'detailed_results': [] + "model_name": getattr(model_func, "__name__", "unknown"), + "total_items": len(self.benchmark_items[:max_items]) if max_items else len(self.benchmark_items), + "audience_scores": {"physician": [], "nurse": [], "patient": [], "caregiver": []}, + "complexity_scores": {"basic": [], "intermediate": [], "advanced": []}, + "detailed_results": [], + "summary": {}, } - + items_to_evaluate = self.benchmark_items[:max_items] if max_items else self.benchmark_items - + for item in items_to_evaluate: # Generate explanations explanations = self.generate_explanations(item.medical_content, model_func) - + # Evaluate explanations - evaluation_results = self.evaluator.evaluate_all_audiences( - item.medical_content, explanations - ) - + evaluation_results = self.evaluator.evaluate_all_audiences(item.medical_content, explanations) + # Store detailed results item_result = { - 'item_id': item.id, - 'complexity_level': item.complexity_level, - 'source_dataset': item.source_dataset, - 'explanations': explanations, - 'scores': { + "item_id": item.id, + "complexity_level": item.complexity_level, + "source_dataset": item.source_dataset, + "explanations": explanations, + "scores": { audience: { - 'readability': score.readability, - 'terminology': score.terminology, - 'safety': score.safety, - 'coverage': score.coverage, - 'quality': score.quality, - 'overall': score.overall + "readability": score.readability, + "terminology": score.terminology, + "safety": score.safety, + "coverage": score.coverage, + "quality": score.quality, + "overall": score.overall, } for audience, score in evaluation_results.items() - } + }, } - results['detailed_results'].append(item_result) - + results["detailed_results"].append(item_result) + # Aggregate scores by audience for audience, score in evaluation_results.items(): - if audience in results['audience_scores']: - results['audience_scores'][audience].append(score.overall) - + if audience in results["audience_scores"]: + results["audience_scores"][audience].append(score.overall) + # Aggregate scores by complexity avg_overall = sum(score.overall for score in evaluation_results.values()) / len(evaluation_results) - results['complexity_scores'][item.complexity_level].append(avg_overall) - + results["complexity_scores"][item.complexity_level].append(avg_overall) + # Calculate summary statistics - results['summary'] = self._calculate_summary_stats(results) - + results["summary"] = self._calculate_summary_stats(dict(results)) + return results - + def _calculate_summary_stats(self, results: Dict[str, Any]) -> Dict[str, Any]: """Calculate summary statistics from evaluation results. - + Computes descriptive statistics for audience-level and complexity-level scores, including means, standard deviations, minimums, and maximums. - + Args: results: Dictionary containing evaluation results with audience_scores and complexity_scores keys. - + Returns: Dictionary containing summary statistics with keys like: - {audience}_mean: Mean score for each audience @@ -458,46 +456,48 @@ def _calculate_summary_stats(self, results: Dict[str, Any]) -> Dict[str, Any]: - overall_std: Overall standard deviation """ summary = {} - + # Audience-level statistics - for audience, scores in results['audience_scores'].items(): + for audience, scores in results["audience_scores"].items(): if scores: - summary[f'{audience}_mean'] = sum(scores) / len(scores) - summary[f'{audience}_std'] = (sum((x - summary[f'{audience}_mean']) ** 2 for x in scores) / len(scores)) ** 0.5 - summary[f'{audience}_min'] = min(scores) - summary[f'{audience}_max'] = max(scores) - + summary[f"{audience}_mean"] = sum(scores) / len(scores) + summary[f"{audience}_std"] = (sum((x - summary[f"{audience}_mean"]) ** 2 for x in scores) / len(scores)) ** 0.5 + summary[f"{audience}_min"] = min(scores) + summary[f"{audience}_max"] = max(scores) + # Complexity-level statistics - for complexity, scores in results['complexity_scores'].items(): + for complexity, scores in results["complexity_scores"].items(): if scores: - summary[f'{complexity}_mean'] = sum(scores) / len(scores) - summary[f'{complexity}_std'] = (sum((x - summary[f'{complexity}_mean']) ** 2 for x in scores) / len(scores)) ** 0.5 - + summary[f"{complexity}_mean"] = sum(scores) / len(scores) + summary[f"{complexity}_std"] = ( + sum((x - summary[f"{complexity}_mean"]) ** 2 for x in scores) / len(scores) + ) ** 0.5 + # Overall statistics all_scores = [] - for audience_scores in results['audience_scores'].values(): + for audience_scores in results["audience_scores"].values(): all_scores.extend(audience_scores) - + if all_scores: - summary['overall_mean'] = sum(all_scores) / len(all_scores) - summary['overall_std'] = (sum((x - summary['overall_mean']) ** 2 for x in all_scores) / len(all_scores)) ** 0.5 - + summary["overall_mean"] = sum(all_scores) / len(all_scores) + summary["overall_std"] = (sum((x - summary["overall_mean"]) ** 2 for x in all_scores) / len(all_scores)) ** 0.5 + return summary - + def save_results(self, results: Dict[str, Any], output_path: str) -> None: """Save evaluation results to JSON file. - + Serializes the evaluation results dictionary to a JSON file with proper formatting for readability. Includes proper path handling and error handling. - + Args: results: Dictionary containing evaluation results from evaluate_model(). output_path: File path where results should be saved. - + Raises: Exception: If file writing fails. - + Example: ```python results = bench.evaluate_model(model_func) @@ -505,37 +505,37 @@ def save_results(self, results: Dict[str, Any], output_path: str) -> None: ``` """ output_file = Path(output_path).resolve() - + # Ensure output directory exists output_file.parent.mkdir(parents=True, exist_ok=True) - + try: - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: json.dump(results, f, indent=2, ensure_ascii=False) logger.info(f"Results saved to: {output_file}") except Exception as e: logger.error(f"Failed to save results to {output_file}: {e}") raise - + def create_sample_dataset(self, output_path: Optional[str] = None) -> List[MEQBenchItem]: """Create a sample dataset for testing. - + Generates a small set of sample medical content items with different complexity levels for testing and demonstration purposes. - + Args: output_path: Optional path to save the sample dataset as JSON. If provided, saves the dataset to this file. - + Returns: List of MEQBenchItem objects containing sample medical content with basic, intermediate, and advanced complexity levels. - + Example: ```python # Create sample data sample_items = bench.create_sample_dataset() - + # Create and save sample data sample_items = bench.create_sample_dataset("data/sample_dataset.json") ``` @@ -543,63 +543,78 @@ def create_sample_dataset(self, output_path: Optional[str] = None) -> List[MEQBe sample_items = [ MEQBenchItem( id="sample_001", - medical_content="Hypertension, also known as high blood pressure, is a condition where the force of blood against artery walls is consistently too high. It can lead to heart disease, stroke, and kidney problems if left untreated. Treatment typically involves lifestyle changes and medication.", + medical_content=( + "Hypertension, also known as high blood pressure, is a condition where the force of blood " + "against artery walls is consistently too high. It can lead to heart disease, stroke, and " + "kidney problems if left untreated. Treatment typically involves lifestyle changes and medication." + ), complexity_level="basic", - source_dataset="sample" + source_dataset="sample", ), MEQBenchItem( id="sample_002", - medical_content="Myocardial infarction occurs when blood flow to a part of the heart muscle is blocked, usually by a blood clot in a coronary artery. This results in damage or death of heart muscle cells. Immediate treatment with medications to dissolve clots or procedures to restore blood flow is critical.", + medical_content=( + "Myocardial infarction occurs when blood flow to a part of the heart muscle is blocked, usually " + "by a blood clot in a coronary artery. This results in damage or death of heart muscle cells. " + "Immediate treatment with medications to dissolve clots or procedures to restore blood flow is critical." + ), complexity_level="intermediate", - source_dataset="sample" + source_dataset="sample", ), MEQBenchItem( id="sample_003", - medical_content="Diabetic ketoacidosis (DKA) is a serious complication of diabetes mellitus characterized by hyperglycemia, ketosis, and metabolic acidosis. It typically occurs in type 1 diabetes due to absolute insulin deficiency. Treatment involves IV fluids, insulin therapy, and electrolyte replacement while addressing underlying precipitating factors.", + medical_content=( + "Diabetic ketoacidosis (DKA) is a serious complication of diabetes mellitus characterized by " + "hyperglycemia, ketosis, and metabolic acidosis. It typically occurs in type 1 diabetes due to " + "absolute insulin deficiency. Treatment involves IV fluids, insulin therapy, and electrolyte " + "replacement while addressing underlying precipitating factors." + ), complexity_level="advanced", - source_dataset="sample" - ) + source_dataset="sample", + ), ] - + if output_path: # Save sample dataset with proper path handling output_file = Path(output_path).resolve() output_file.parent.mkdir(parents=True, exist_ok=True) - + data_to_save = [] for item in sample_items: - data_to_save.append({ - 'id': item.id, - 'medical_content': item.medical_content, - 'complexity_level': item.complexity_level, - 'source_dataset': item.source_dataset, - 'reference_explanations': item.reference_explanations - }) - + data_to_save.append( + { + "id": item.id, + "medical_content": item.medical_content, + "complexity_level": item.complexity_level, + "source_dataset": item.source_dataset, + "reference_explanations": item.reference_explanations, + } + ) + try: - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: json.dump(data_to_save, f, indent=2, ensure_ascii=False) logger.info(f"Sample dataset saved to: {output_file}") except Exception as e: logger.error(f"Failed to save sample dataset to {output_file}: {e}") raise - + return sample_items - + def get_benchmark_stats(self) -> Union[BenchmarkStatsDict, Dict[str, Union[int, str]]]: """Get statistics about the benchmark dataset. - + Provides summary statistics about the loaded benchmark items, including total counts and distributions by complexity level and source dataset. - + Returns: Dictionary containing benchmark statistics with keys: - total_items: Total number of benchmark items - complexity_distribution: Count of items by complexity level - source_distribution: Count of items by source dataset - message: Informational message if no items are loaded - + Example: ```python stats = bench.get_benchmark_stats() @@ -608,83 +623,80 @@ def get_benchmark_stats(self) -> Union[BenchmarkStatsDict, Dict[str, Union[int, ``` """ if not self.benchmark_items: - return {'total_items': 0, 'message': 'No benchmark items loaded'} - + return {"total_items": 0, "message": "No benchmark items loaded"} + stats: BenchmarkStatsDict = { - 'total_items': len(self.benchmark_items), - 'complexity_distribution': {}, - 'source_distribution': {} + "total_items": len(self.benchmark_items), + "complexity_distribution": {}, + "source_distribution": {}, } - + for item in self.benchmark_items: # Count by complexity - if item.complexity_level not in stats['complexity_distribution']: - stats['complexity_distribution'][item.complexity_level] = 0 - stats['complexity_distribution'][item.complexity_level] += 1 - + if item.complexity_level not in stats["complexity_distribution"]: + stats["complexity_distribution"][item.complexity_level] = 0 + stats["complexity_distribution"][item.complexity_level] += 1 + # Count by source - if item.source_dataset not in stats['source_distribution']: - stats['source_distribution'][item.source_dataset] = 0 - stats['source_distribution'][item.source_dataset] += 1 - + if item.source_dataset not in stats["source_distribution"]: + stats["source_distribution"][item.source_dataset] = 0 + stats["source_distribution"][item.source_dataset] += 1 + return stats - + def validate_benchmark(self) -> Dict[str, Any]: """Validate the benchmark dataset and return validation report. - + Returns: Dictionary containing validation results and any issues found """ - validation_report = { - 'valid': True, - 'total_items': len(self.benchmark_items), - 'issues': [], - 'warnings': [], - 'statistics': {} + validation_report: Dict[str, Any] = { + "valid": True, + "total_items": len(self.benchmark_items), + "issues": [], + "warnings": [], + "statistics": {}, } - + if not self.benchmark_items: - validation_report['valid'] = False - validation_report['issues'].append("No benchmark items loaded") + validation_report["valid"] = False + validation_report["issues"].append("No benchmark items loaded") return validation_report - + # Check for duplicate IDs ids = [item.id for item in self.benchmark_items] duplicate_ids = set([id for id in ids if ids.count(id) > 1]) if duplicate_ids: - validation_report['valid'] = False - validation_report['issues'].append(f"Duplicate item IDs found: {duplicate_ids}") - + validation_report["valid"] = False + validation_report["issues"].append(f"Duplicate item IDs found: {duplicate_ids}") + # Validate complexity level distribution - complexity_counts = {} + complexity_counts: Dict[str, int] = {} for item in self.benchmark_items: complexity = item.complexity_level complexity_counts[complexity] = complexity_counts.get(complexity, 0) + 1 - - validation_report['statistics']['complexity_distribution'] = complexity_counts - + + validation_report["statistics"]["complexity_distribution"] = complexity_counts + # Check for balanced distribution if len(complexity_counts) < 3: - validation_report['warnings'].append("Not all complexity levels represented") - + validation_report["warnings"].append("Not all complexity levels represented") + # Validate content length content_lengths = [len(item.medical_content) for item in self.benchmark_items] avg_length = sum(content_lengths) / len(content_lengths) - validation_report['statistics']['average_content_length'] = avg_length - + validation_report["statistics"]["average_content_length"] = avg_length + if avg_length < 50: - validation_report['warnings'].append("Average content length is quite short") + validation_report["warnings"].append("Average content length is quite short") elif avg_length > 2000: - validation_report['warnings'].append("Average content length is quite long") - + validation_report["warnings"].append("Average content length is quite long") + # Check for empty or very short content - short_content_items = [ - item.id for item in self.benchmark_items - if len(item.medical_content.strip()) < 20 - ] + short_content_items = [item.id for item in self.benchmark_items if len(item.medical_content.strip()) < 20] if short_content_items: - validation_report['valid'] = False - validation_report['issues'].append(f"Items with very short content: {short_content_items}") - + validation_report["valid"] = False + validation_report["issues"].append(f"Items with very short content: {short_content_items}") + logger.info(f"Benchmark validation completed. Valid: {validation_report['valid']}") - return validation_report \ No newline at end of file + return validation_report diff --git a/src/config.py b/src/config.py index 4fd5984..9205289 100644 --- a/src/config.py +++ b/src/config.py @@ -12,120 +12,185 @@ - Multiple configuration environments """ -import yaml +import yaml # type: ignore import os import logging from pathlib import Path -from typing import Dict, Any, List, Optional, Union, Type +from typing import Dict, Any, List, Optional class ConfigurationError(Exception): """Raised when there's an error in configuration. - + This exception is raised when: - Configuration files are missing or invalid - Required configuration sections are missing - YAML parsing fails - Required environment variables are not set """ + pass class Config: """Singleton configuration manager for MEQ-Bench. - + This class manages all configuration settings for the MEQ-Bench application. - It loads configuration from YAML files and provides methods to access + It loads configuration from YAML files and provides methods to access configuration values using dot notation. - + The class implements the singleton pattern to ensure that configuration is loaded once and shared across the entire application. - + Attributes: _instance: Singleton instance of the Config class. _config: Dictionary containing the loaded configuration data. - + Example: ```python from config import config - + # Get configuration values model_name = config.get('llm_judge.default_model') audiences = config.get_audiences() - + # Set up logging config.setup_logging() ``` """ - - _instance: Optional['Config'] = None + + _instance: Optional["Config"] = None _config: Optional[Dict[str, Any]] = None - - def __new__(cls) -> 'Config': + + def __new__(cls) -> "Config": if cls._instance is None: cls._instance = super(Config, cls).__new__(cls) return cls._instance - + def __init__(self) -> None: if self._config is None: self.load_config() - + def load_config(self, config_path: Optional[str] = None) -> None: """Load configuration from YAML file. - + Loads configuration settings from a YAML file and validates required sections. If no path is provided, looks for config.yaml in the project root directory. - + Args: config_path: Path to configuration file. If None, uses default config.yaml in the project root directory. - + Raises: ConfigurationError: If configuration file is not found or contains invalid YAML. """ if config_path is None: - # Look for config.yaml in the project root - current_dir = Path(__file__).parent.parent - config_path = str(current_dir / "config.yaml") - + # Search for project root by looking for indicators + config_path = self._find_project_config() + try: - with open(config_path, 'r') as f: + with open(config_path, "r") as f: self._config = yaml.safe_load(f) - + # Validate required sections self._validate_config() - + except FileNotFoundError: raise ConfigurationError(f"Configuration file not found: {config_path}") except yaml.YAMLError as e: raise ConfigurationError(f"Error parsing YAML configuration: {e}") - + + def _find_project_config(self) -> str: + """Find the project configuration file by searching upwards from this file. + + Searches upwards from the current file location until it finds a directory + containing project indicators like .git, pyproject.toml, or setup.py. + + Returns: + Path to config.yaml in the project root. + + Raises: + ConfigurationError: If project root cannot be found. + """ + current_path = Path(__file__).resolve() + + # Search upwards from the current file + for parent in [current_path.parent] + list(current_path.parents): + # Check for project root indicators + indicators = [".git", "pyproject.toml", "setup.py", "setup.cfg", "requirements.txt"] + if any((parent / indicator).exists() for indicator in indicators): + config_file = parent / "config.yaml" + if config_file.exists(): + return str(config_file) + # If config.yaml doesn't exist in project root, create a minimal one + return self._create_minimal_config(parent) + + # Fallback: look in the parent directory of this file + fallback_path = Path(__file__).parent.parent / "config.yaml" + if fallback_path.exists(): + return str(fallback_path) + + # Last resort: create minimal config in current directory + return self._create_minimal_config(Path.cwd()) + + def _create_minimal_config(self, project_root: Path) -> str: + """Create a minimal configuration file for testing/CI environments. + + Args: + project_root: Path to the project root directory. + + Returns: + Path to the created configuration file. + """ + config_path = project_root / "config.yaml" + + minimal_config = { + "app": {"name": "MEQ-Bench", "version": "1.0.0", "data_path": "data/", "output_path": "results/"}, + "audiences": ["physician", "nurse", "patient", "caregiver"], + "complexity_levels": ["basic", "intermediate", "advanced"], + "llm_judge": {"default_model": "gpt-3.5-turbo", "max_tokens": 800, "temperature": 0.3}, + "evaluation": {"metrics": ["readability", "terminology", "safety", "coverage", "quality"], "timeout": 30}, + "scoring": {"weights": {"readability": 0.2, "terminology": 0.2, "safety": 0.3, "coverage": 0.15, "quality": 0.15}}, + "logging": { + "version": 1, + "disable_existing_loggers": False, + "formatters": {"standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"}}, + "handlers": {"default": {"level": "INFO", "formatter": "standard", "class": "logging.StreamHandler"}}, + "loggers": {"": {"handlers": ["default"], "level": "INFO", "propagate": False}}, + }, + } + + try: + with open(config_path, "w") as f: + yaml.dump(minimal_config, f, default_flow_style=False) + return str(config_path) + except Exception: + # If we can't write to disk, return a path that will trigger default values + return str(config_path) + def _validate_config(self) -> None: """Validate that required configuration sections exist""" - required_sections = [ - 'app', 'audiences', 'complexity_levels', 'llm_judge', - 'evaluation', 'scoring', 'logging' - ] - + required_sections = ["app", "audiences", "complexity_levels", "llm_judge", "evaluation", "scoring", "logging"] + for section in required_sections: if self._config is None or section not in self._config: raise ConfigurationError(f"Missing required configuration section: {section}") - + def get(self, key: str, default: Optional[Any] = None) -> Any: """Get configuration value using dot notation. - + Retrieves configuration values using a dot-separated key path. For example, 'llm_judge.default_model' would access config['llm_judge']['default_model']. - + Args: - key: Configuration key using dot notation (e.g., 'app.name', + key: Configuration key using dot notation (e.g., 'app.name', 'llm_judge.default_model'). default: Default value to return if key is not found. - + Returns: The configuration value at the specified key path, or the default value if the key is not found. - + Raises: ConfigurationError: If key is not found and no default value is provided. """ @@ -133,10 +198,10 @@ def get(self, key: str, default: Optional[Any] = None) -> Any: if default is not None: return default raise ConfigurationError(f"Configuration not loaded: {key}") - - keys = key.split('.') + + keys = key.split(".") value = self._config - + try: for k in keys: value = value[k] @@ -145,108 +210,100 @@ def get(self, key: str, default: Optional[Any] = None) -> Any: if default is not None: return default raise ConfigurationError(f"Configuration key not found: {key}") - + def get_audiences(self) -> List[str]: """Get list of target audiences""" - return self.get('audiences') - + return self.get("audiences") + def get_complexity_levels(self) -> List[str]: """Get list of complexity levels""" - return self.get('complexity_levels') - + return self.get("complexity_levels") + def get_llm_config(self) -> Dict[str, Any]: """Get LLM judge configuration""" - return self.get('llm_judge') - + return self.get("llm_judge") + def get_evaluation_config(self) -> Dict[str, Any]: """Get evaluation configuration""" - return self.get('evaluation') - + return self.get("evaluation") + def get_scoring_config(self) -> Dict[str, Any]: """Get scoring configuration""" - return self.get('scoring') - + return self.get("scoring") + def get_api_config(self, provider: str) -> Dict[str, Any]: """ Get API configuration for a specific provider - + Args: provider: API provider name (e.g., 'openai', 'anthropic') """ - return self.get(f'api.{provider}') - + return self.get(f"api.{provider}") + def get_data_path(self) -> str: """Get data directory path""" - return self.get('app.data_path', 'data/') - + return self.get("app.data_path", "data/") + def get_output_path(self) -> str: """Get output directory path""" - return self.get('app.output_path', 'results/') - + return self.get("app.output_path", "results/") + def setup_logging(self) -> None: """Set up logging based on configuration""" import logging.config - + # Create logs directory if it doesn't exist logs_dir = Path("logs") logs_dir.mkdir(exist_ok=True) - + # Get logging configuration - logging_config = self.get('logging') - + logging_config = self.get("logging") + try: logging.config.dictConfig(logging_config) except Exception as e: # Fallback to basic logging if config fails - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(levelname)s] %(name)s: %(message)s' - ) + logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") logging.error(f"Failed to configure logging from config: {e}") - + def get_environment_variable(self, key: str, default: Optional[str] = None) -> Optional[str]: """ Get environment variable with optional default - + Args: key: Environment variable name default: Default value if not found - + Returns: Environment variable value or default """ return os.getenv(key, default) - + def get_api_key(self, provider: str) -> str: """ Get API key for provider from environment variables - + Args: provider: Provider name (e.g., 'openai', 'anthropic') - + Returns: API key - + Raises: ConfigurationError: If API key not found """ - env_var_map = { - 'openai': 'OPENAI_API_KEY', - 'anthropic': 'ANTHROPIC_API_KEY' - } - + env_var_map = {"openai": "OPENAI_API_KEY", "anthropic": "ANTHROPIC_API_KEY"} + env_var = env_var_map.get(provider.lower()) if not env_var: raise ConfigurationError(f"Unknown API provider: {provider}") - + api_key = self.get_environment_variable(env_var) if not api_key: - raise ConfigurationError( - f"API key not found. Please set {env_var} environment variable." - ) - + raise ConfigurationError(f"API key not found. Please set {env_var} environment variable.") + return api_key - + def reload(self) -> None: """Reload configuration from file""" self._config = None @@ -254,4 +311,4 @@ def reload(self) -> None: # Global config instance -config = Config() \ No newline at end of file +config = Config() diff --git a/src/data_loaders.py b/src/data_loaders.py index 1e17ae2..4427b64 100644 --- a/src/data_loaders.py +++ b/src/data_loaders.py @@ -24,12 +24,12 @@ Example: ```python from data_loaders import load_medqa_usmle, load_icliniq, load_cochrane_reviews - + # Load different datasets with automatic complexity stratification medqa_items = load_medqa_usmle('path/to/medqa.json', max_items=300) icliniq_items = load_icliniq('path/to/icliniq.json', max_items=400) cochrane_items = load_cochrane_reviews('path/to/cochrane.json', max_items=300) - + # Combine and save as benchmark dataset all_items = medqa_items + icliniq_items + cochrane_items save_benchmark_items(all_items, 'data/benchmark_items.json') @@ -40,54 +40,52 @@ import logging import re from pathlib import Path -from typing import List, Dict, Any, Optional, Union +from typing import List, Dict, Optional, Union try: - import textstat + import textstat # type: ignore except ImportError: textstat = None - + from .benchmark import MEQBenchItem -logger = logging.getLogger('meq_bench.data_loaders') +logger = logging.getLogger("meq_bench.data_loaders") def load_medquad( - data_path: Union[str, Path], - max_items: Optional[int] = None, - complexity_level: str = 'basic' + data_path: Union[str, Path], max_items: Optional[int] = None, complexity_level: str = "basic" ) -> List[MEQBenchItem]: """Load MedQuAD dataset and convert to MEQBenchItem objects. - + The MedQuAD (Medical Question Answering Dataset) contains consumer health questions and answers from various medical sources. This function loads the dataset and converts it to the MEQ-Bench format for evaluation. - + Args: data_path: Path to the MedQuAD JSON file. Can be a string or Path object. max_items: Maximum number of items to load. If None, loads all items. complexity_level: Complexity level to assign to all items. Defaults to 'basic' since MedQuAD primarily contains consumer health questions. - + Returns: List of MEQBenchItem objects converted from MedQuAD data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load all MedQuAD items items = load_medquad('data/medquad.json') - + # Load only first 100 items items = load_medquad('data/medquad.json', max_items=100) - + # Load with different complexity level items = load_medquad('data/medquad.json', complexity_level='intermediate') - + # Add to benchmark bench = MEQBench() for item in items: @@ -95,119 +93,114 @@ def load_medquad( ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"MedQuAD file not found: {data_file}") - + logger.info(f"Loading MedQuAD dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in MedQuAD file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in MedQuAD file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("MedQuAD data must be a list of items") - + # Validate complexity level - if complexity_level not in ['basic', 'intermediate', 'advanced']: + if complexity_level not in ["basic", "intermediate", "advanced"]: logger.warning(f"Invalid complexity level '{complexity_level}', using 'basic'") - complexity_level = 'basic' - + complexity_level = "basic" + # Convert to MEQBenchItem objects items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} MedQuAD items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid MedQuAD item {i}: not a dictionary") continue - + # Extract required fields - MedQuAD typically has 'question' and 'answer' - question = item_data.get('question', '') - answer = item_data.get('answer', '') - item_id = item_data.get('id', f"medquad_{i}") - + question = item_data.get("question", "") + answer = item_data.get("answer", "") + item_id = item_data.get("id", f"medquad_{i}") + if not question.strip() or not answer.strip(): logger.warning(f"Skipping MedQuAD item {i}: empty question or answer") continue - + # Combine question and answer to create medical content medical_content = f"Question: {question.strip()}\\n\\nAnswer: {answer.strip()}" - + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='MedQuAD', - reference_explanations=None # No reference explanations in MedQuAD + source_dataset="MedQuAD", + reference_explanations=None, # No reference explanations in MedQuAD ) - + # Basic validation _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing MedQuAD item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from MedQuAD") - + if len(items) == 0: logger.warning("No valid items were loaded from MedQuAD dataset") else: # Log some statistics avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"MedQuAD dataset statistics:") + logger.info("MedQuAD dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity level: {complexity_level}") - + return items def load_healthsearchqa( - data_path: Union[str, Path], - max_items: Optional[int] = None, - complexity_level: str = 'intermediate' + data_path: Union[str, Path], max_items: Optional[int] = None, complexity_level: str = "intermediate" ) -> List[MEQBenchItem]: """Load HealthSearchQA dataset and convert to MEQBenchItem objects. - + The HealthSearchQA dataset contains health-related search queries and answers from various health websites and search engines. This loader converts the dataset into MEQBenchItem objects for use in the benchmark. - + Args: data_path: Path to the HealthSearchQA JSON file. max_items: Maximum number of items to load. If None, loads all items. complexity_level: Complexity level to assign to all items. Defaults to 'intermediate' since HealthSearchQA contains more varied complexity levels. - + Returns: List of MEQBenchItem objects converted from HealthSearchQA data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load HealthSearchQA items items = load_healthsearchqa('data/healthsearchqa.json') - + # Load with custom complexity level items = load_healthsearchqa('data/healthsearchqa.json', complexity_level='advanced') - + # Add to benchmark bench = MEQBench() for item in items: @@ -215,82 +208,79 @@ def load_healthsearchqa( ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"HealthSearchQA file not found: {data_file}") - + logger.info(f"Loading HealthSearchQA dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in HealthSearchQA file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in HealthSearchQA file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("HealthSearchQA data must be a list of items") - + # Validate complexity level - if complexity_level not in ['basic', 'intermediate', 'advanced']: + if complexity_level not in ["basic", "intermediate", "advanced"]: logger.warning(f"Invalid complexity level '{complexity_level}', using 'intermediate'") - complexity_level = 'intermediate' - + complexity_level = "intermediate" + items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} HealthSearchQA items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid HealthSearchQA item {i}: not a dictionary") continue - + # HealthSearchQA might have different field names - query = item_data.get('query', item_data.get('question', '')) - answer = item_data.get('answer', item_data.get('response', '')) - item_id = item_data.get('id', f"healthsearch_{i}") - + query = item_data.get("query", item_data.get("question", "")) + answer = item_data.get("answer", item_data.get("response", "")) + item_id = item_data.get("id", f"healthsearch_{i}") + if not query.strip() or not answer.strip(): logger.warning(f"Skipping HealthSearchQA item {i}: empty query or answer") continue - + # Create medical content medical_content = f"Search Query: {query.strip()}\\n\\nAnswer: {answer.strip()}" - + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='HealthSearchQA', - reference_explanations=None + source_dataset="HealthSearchQA", + reference_explanations=None, ) - + # Basic validation _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing HealthSearchQA item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from HealthSearchQA") - + if len(items) == 0: logger.warning("No valid items were loaded from HealthSearchQA dataset") else: # Log some statistics avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"HealthSearchQA dataset statistics:") + logger.info("HealthSearchQA dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity level: {complexity_level}") - + return items @@ -298,52 +288,47 @@ def load_custom_dataset( data_path: Union[str, Path], field_mapping: Optional[Dict[str, str]] = None, max_items: Optional[int] = None, - complexity_level: str = 'basic' + complexity_level: str = "basic", ) -> List[MEQBenchItem]: """Load custom dataset and convert to MEQBenchItem objects. - + Args: data_path: Path to the JSON file containing the dataset. field_mapping: Dictionary mapping dataset fields to MEQBenchItem fields. Example: {'q': 'question', 'a': 'answer', 'topic': 'medical_content'} max_items: Maximum number of items to load. complexity_level: Complexity level to assign to all items. - + Returns: List of MEQBenchItem objects. """ # Default field mapping if field_mapping is None: - field_mapping = { - 'question': 'question', - 'answer': 'answer', - 'content': 'medical_content', - 'id': 'id' - } - + field_mapping = {"question": "question", "answer": "answer", "content": "medical_content", "id": "id"} + data_file = Path(data_path) if not data_file.exists(): raise FileNotFoundError(f"Custom dataset file not found: {data_file}") - + logger.info(f"Loading custom dataset from: {data_file}") - - with open(data_file, 'r', encoding='utf-8') as f: + + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) - + if not isinstance(data, list): raise ValueError("Custom dataset must be a list of items") - + items = [] items_to_process = data[:max_items] if max_items else data - + for i, item_data in enumerate(items_to_process): try: # Extract fields based on mapping - question = item_data.get(field_mapping.get('question', 'question'), '') - answer = item_data.get(field_mapping.get('answer', 'answer'), '') - content = item_data.get(field_mapping.get('content', 'content'), '') - item_id = item_data.get(field_mapping.get('id', 'id'), f"custom_{i}") - + question = item_data.get(field_mapping.get("question", "question"), "") + answer = item_data.get(field_mapping.get("answer", "answer"), "") + content = item_data.get(field_mapping.get("content", "content"), "") + item_id = item_data.get(field_mapping.get("id", "id"), f"custom_{i}") + # Create medical content if content: medical_content = content @@ -352,33 +337,29 @@ def load_custom_dataset( else: logger.warning(f"Skipping item {i}: no valid content found") continue - + item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='Custom', - reference_explanations=None + source_dataset="Custom", + reference_explanations=None, ) - + _validate_benchmark_item(item) items.append(item) - + except Exception as e: logger.error(f"Error processing custom dataset item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} items from custom dataset") return items -def save_benchmark_items( - items: List[MEQBenchItem], - output_path: Union[str, Path], - pretty_print: bool = True -) -> None: +def save_benchmark_items(items: List[MEQBenchItem], output_path: Union[str, Path], pretty_print: bool = True) -> None: """Save MEQBenchItem objects to a JSON file. - + Args: items: List of MEQBenchItem objects to save. output_path: Path where to save the JSON file. @@ -386,44 +367,44 @@ def save_benchmark_items( """ output_file = Path(output_path) output_file.parent.mkdir(parents=True, exist_ok=True) - + # Convert items to dictionaries items_data = [] for item in items: item_dict = { - 'id': item.id, - 'medical_content': item.medical_content, - 'complexity_level': item.complexity_level, - 'source_dataset': item.source_dataset, - 'reference_explanations': item.reference_explanations + "id": item.id, + "medical_content": item.medical_content, + "complexity_level": item.complexity_level, + "source_dataset": item.source_dataset, + "reference_explanations": item.reference_explanations, } items_data.append(item_dict) - + # Save to JSON - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: if pretty_print: json.dump(items_data, f, indent=2, ensure_ascii=False) else: json.dump(items_data, f, ensure_ascii=False) - + logger.info(f"Saved {len(items)} benchmark items to: {output_file}") def calculate_complexity_level(text: str) -> str: """Calculate complexity level based on Flesch-Kincaid Grade Level. - + Uses textstat library to compute Flesch-Kincaid Grade Level and categorizes the text into basic, intermediate, or advanced complexity levels. - + Args: text: Text content to analyze for complexity. - + Returns: Complexity level string: 'basic', 'intermediate', or 'advanced' - + Raises: ValueError: If text is empty or invalid. - + Example: ```python complexity = calculate_complexity_level("This is simple text.") @@ -432,31 +413,31 @@ def calculate_complexity_level(text: str) -> str: """ if not text or not isinstance(text, str): raise ValueError("Text must be a non-empty string") - + text = text.strip() if not text: raise ValueError("Text cannot be empty or whitespace only") - + # Clean text for analysis - remove extra whitespace and normalize - cleaned_text = ' '.join(text.split()) - + cleaned_text = " ".join(text.split()) + # Fallback if textstat is not available if textstat is None: logger.warning("textstat library not available, using fallback complexity calculation") return _calculate_complexity_fallback(cleaned_text) - + try: # Calculate Flesch-Kincaid Grade Level fk_score = textstat.flesch_kincaid().grade(cleaned_text) - + # Categorize based on grade level if fk_score <= 8: - return 'basic' + return "basic" elif fk_score <= 12: - return 'intermediate' + return "intermediate" else: - return 'advanced' - + return "advanced" + except Exception as e: logger.warning(f"Error calculating Flesch-Kincaid score: {e}, using fallback") return _calculate_complexity_fallback(cleaned_text) @@ -464,235 +445,240 @@ def calculate_complexity_level(text: str) -> str: def _calculate_complexity_fallback(text: str) -> str: """Fallback complexity calculation when textstat is unavailable. - + Uses simple heuristics based on sentence length, word length, and medical terminology density as approximations. - + Args: text: Cleaned text to analyze. - + Returns: Complexity level: 'basic', 'intermediate', or 'advanced' """ # Count sentences (approximate) - sentences = len(re.split(r'[.!?]+', text)) + sentences = len(re.split(r"[.!?]+", text)) if sentences == 0: sentences = 1 - + # Count words words = len(text.split()) if words == 0: - return 'basic' - + return "basic" + # Calculate average words per sentence avg_words_per_sentence = words / sentences - + # Count syllables (rough approximation) syllable_count = 0 for word in text.split(): # Simple syllable counting heuristic - vowels = 'aeiouyAEIOUY' - word = re.sub(r'[^a-zA-Z]', '', word) + word = re.sub(r"[^a-zA-Z]", "", word) if word: - syllables = len(re.findall(r'[aeiouyAEIOUY]+', word)) + syllables = len(re.findall(r"[aeiouyAEIOUY]+", word)) if syllables == 0: syllables = 1 syllable_count += syllables - + avg_syllables_per_word = syllable_count / words if words > 0 else 1 - + # Check for medical terminology (indicator of higher complexity) medical_terms = [ - 'diagnosis', 'treatment', 'therapy', 'syndrome', 'pathology', - 'etiology', 'prognosis', 'medication', 'prescription', 'dosage', - 'contraindication', 'adverse', 'efficacy', 'pharmacology', - 'clinical', 'therapeutic', 'intervention' + "diagnosis", + "treatment", + "therapy", + "syndrome", + "pathology", + "etiology", + "prognosis", + "medication", + "prescription", + "dosage", + "contraindication", + "adverse", + "efficacy", + "pharmacology", + "clinical", + "therapeutic", + "intervention", ] - + medical_term_count = sum(1 for term in medical_terms if term.lower() in text.lower()) medical_density = medical_term_count / words * 100 # percentage - + # Simple scoring algorithm - complexity_score = 0 + complexity_score: float = 0 complexity_score += avg_words_per_sentence * 0.5 complexity_score += avg_syllables_per_word * 3 complexity_score += medical_density * 0.3 - + if complexity_score <= 8: - return 'basic' + return "basic" elif complexity_score <= 15: - return 'intermediate' + return "intermediate" else: - return 'advanced' + return "advanced" def load_medqa_usmle( - data_path: Union[str, Path], - max_items: Optional[int] = None, - auto_complexity: bool = True + data_path: Union[str, Path], max_items: Optional[int] = None, auto_complexity: bool = True ) -> List[MEQBenchItem]: """Load MedQA-USMLE dataset and convert to MEQBenchItem objects. - + The MedQA-USMLE dataset contains medical questions based on USMLE exam format with multiple choice questions and explanations. This loader processes the dataset and optionally applies automatic complexity stratification. - + Args: data_path: Path to the MedQA-USMLE JSON file. max_items: Maximum number of items to load. If None, loads all items. auto_complexity: Whether to automatically calculate complexity levels using Flesch-Kincaid scores. If False, assigns 'intermediate' to all. - + Returns: List of MEQBenchItem objects converted from MedQA-USMLE data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load with automatic complexity calculation items = load_medqa_usmle('data/medqa_usmle.json', max_items=300) - + # Load without complexity calculation (all marked as intermediate) items = load_medqa_usmle('data/medqa_usmle.json', auto_complexity=False) ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"MedQA-USMLE file not found: {data_file}") - + logger.info(f"Loading MedQA-USMLE dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in MedQA-USMLE file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in MedQA-USMLE file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("MedQA-USMLE data must be a list of items") - + items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} MedQA-USMLE items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid MedQA-USMLE item {i}: not a dictionary") continue - + # Extract required fields - MedQA typically has 'question', 'options', 'answer', 'explanation' - question = item_data.get('question', '') - options = item_data.get('options', {}) - answer = item_data.get('answer', '') - explanation = item_data.get('explanation', '') - item_id = item_data.get('id', f"medqa_usmle_{i}") - + question = item_data.get("question", "") + options = item_data.get("options", {}) + answer = item_data.get("answer", "") + explanation = item_data.get("explanation", "") + item_id = item_data.get("id", f"medqa_usmle_{i}") + if not question.strip(): logger.warning(f"Skipping MedQA-USMLE item {i}: empty question") continue - + # Format options if available options_text = "" if isinstance(options, dict): options_text = "\n".join([f"{k}. {v}" for k, v in options.items() if v]) elif isinstance(options, list): options_text = "\n".join([f"{chr(65+j)}. {opt}" for j, opt in enumerate(options) if opt]) - + # Create comprehensive medical content medical_content_parts = [f"Question: {question.strip()}"] - + if options_text: medical_content_parts.append(f"Options:\n{options_text}") - + if answer.strip(): medical_content_parts.append(f"Correct Answer: {answer.strip()}") - + if explanation.strip(): medical_content_parts.append(f"Explanation: {explanation.strip()}") - + medical_content = "\n\n".join(medical_content_parts) - + # Calculate complexity level if auto_complexity: try: complexity_level = calculate_complexity_level(medical_content) except Exception as e: logger.warning(f"Error calculating complexity for item {i}: {e}, using 'intermediate'") - complexity_level = 'intermediate' + complexity_level = "intermediate" else: - complexity_level = 'intermediate' - + complexity_level = "intermediate" + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='MedQA-USMLE', - reference_explanations=None + source_dataset="MedQA-USMLE", + reference_explanations=None, ) - + # Validate the item _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing MedQA-USMLE item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from MedQA-USMLE") - + if len(items) == 0: logger.warning("No valid items were loaded from MedQA-USMLE dataset") else: # Log complexity distribution - complexity_dist = {} + complexity_dist: Dict[str, int] = {} for item in items: complexity_dist[item.complexity_level] = complexity_dist.get(item.complexity_level, 0) + 1 - + avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"MedQA-USMLE dataset statistics:") + logger.info("MedQA-USMLE dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity distribution: {complexity_dist}") - + return items def load_icliniq( - data_path: Union[str, Path], - max_items: Optional[int] = None, - auto_complexity: bool = True + data_path: Union[str, Path], max_items: Optional[int] = None, auto_complexity: bool = True ) -> List[MEQBenchItem]: """Load iCliniq dataset and convert to MEQBenchItem objects. - + The iCliniq dataset contains real clinical questions from patients and answers from medical professionals. This loader processes the dataset and optionally applies automatic complexity stratification. - + Args: data_path: Path to the iCliniq JSON file. max_items: Maximum number of items to load. If None, loads all items. auto_complexity: Whether to automatically calculate complexity levels. - + Returns: List of MEQBenchItem objects converted from iCliniq data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load iCliniq dataset with complexity stratification @@ -700,126 +686,121 @@ def load_icliniq( ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"iCliniq file not found: {data_file}") - + logger.info(f"Loading iCliniq dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in iCliniq file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in iCliniq file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("iCliniq data must be a list of items") - + items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} iCliniq items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid iCliniq item {i}: not a dictionary") continue - + # Extract fields - iCliniq typically has 'patient_question', 'doctor_answer', 'speciality' - patient_question = item_data.get('patient_question', item_data.get('question', '')) - doctor_answer = item_data.get('doctor_answer', item_data.get('answer', '')) - specialty = item_data.get('speciality', item_data.get('specialty', '')) - item_id = item_data.get('id', f"icliniq_{i}") - + patient_question = item_data.get("patient_question", item_data.get("question", "")) + doctor_answer = item_data.get("doctor_answer", item_data.get("answer", "")) + specialty = item_data.get("speciality", item_data.get("specialty", "")) + item_id = item_data.get("id", f"icliniq_{i}") + if not patient_question.strip() or not doctor_answer.strip(): logger.warning(f"Skipping iCliniq item {i}: empty question or answer") continue - + # Create medical content medical_content_parts = [f"Patient Question: {patient_question.strip()}"] - + if specialty.strip(): medical_content_parts.append(f"Medical Specialty: {specialty.strip()}") - + medical_content_parts.append(f"Doctor's Answer: {doctor_answer.strip()}") - + medical_content = "\n\n".join(medical_content_parts) - + # Calculate complexity level if auto_complexity: try: complexity_level = calculate_complexity_level(medical_content) except Exception as e: logger.warning(f"Error calculating complexity for item {i}: {e}, using 'basic'") - complexity_level = 'basic' + complexity_level = "basic" else: - complexity_level = 'basic' # iCliniq tends to be more patient-focused - + complexity_level = "basic" # iCliniq tends to be more patient-focused + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='iCliniq', - reference_explanations=None + source_dataset="iCliniq", + reference_explanations=None, ) - + # Validate the item _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing iCliniq item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from iCliniq") - + if len(items) == 0: logger.warning("No valid items were loaded from iCliniq dataset") else: # Log complexity distribution - complexity_dist = {} + complexity_dist: Dict[str, int] = {} for item in items: complexity_dist[item.complexity_level] = complexity_dist.get(item.complexity_level, 0) + 1 - + avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"iCliniq dataset statistics:") + logger.info("iCliniq dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity distribution: {complexity_dist}") - + return items def load_cochrane_reviews( - data_path: Union[str, Path], - max_items: Optional[int] = None, - auto_complexity: bool = True + data_path: Union[str, Path], max_items: Optional[int] = None, auto_complexity: bool = True ) -> List[MEQBenchItem]: """Load Cochrane Reviews dataset and convert to MEQBenchItem objects. - + The Cochrane Reviews dataset contains evidence-based medical reviews and systematic analyses. This loader processes the dataset and optionally applies automatic complexity stratification. - + Args: data_path: Path to the Cochrane Reviews JSON file. max_items: Maximum number of items to load. If None, loads all items. auto_complexity: Whether to automatically calculate complexity levels. - + Returns: List of MEQBenchItem objects converted from Cochrane Reviews data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load Cochrane Reviews with complexity stratification @@ -827,128 +808,125 @@ def load_cochrane_reviews( ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"Cochrane Reviews file not found: {data_file}") - + logger.info(f"Loading Cochrane Reviews dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in Cochrane Reviews file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in Cochrane Reviews file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("Cochrane Reviews data must be a list of items") - + items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} Cochrane Reviews items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid Cochrane Reviews item {i}: not a dictionary") continue - + # Extract fields - Cochrane typically has 'title', 'abstract', 'conclusions', 'background' - title = item_data.get('title', '') - abstract = item_data.get('abstract', '') - conclusions = item_data.get('conclusions', item_data.get('main_results', '')) - background = item_data.get('background', item_data.get('objectives', '')) - item_id = item_data.get('id', f"cochrane_{i}") - + title = item_data.get("title", "") + abstract = item_data.get("abstract", "") + conclusions = item_data.get("conclusions", item_data.get("main_results", "")) + background = item_data.get("background", item_data.get("objectives", "")) + item_id = item_data.get("id", f"cochrane_{i}") + if not title.strip() and not abstract.strip(): logger.warning(f"Skipping Cochrane Reviews item {i}: no title or abstract") continue - + # Create medical content from available fields medical_content_parts = [] - + if title.strip(): medical_content_parts.append(f"Title: {title.strip()}") - + if background.strip(): medical_content_parts.append(f"Background: {background.strip()}") - + if abstract.strip(): medical_content_parts.append(f"Abstract: {abstract.strip()}") - + if conclusions.strip(): medical_content_parts.append(f"Conclusions: {conclusions.strip()}") - + if not medical_content_parts: logger.warning(f"Skipping Cochrane Reviews item {i}: no valid content") continue - + medical_content = "\n\n".join(medical_content_parts) - + # Calculate complexity level if auto_complexity: try: complexity_level = calculate_complexity_level(medical_content) except Exception as e: logger.warning(f"Error calculating complexity for item {i}: {e}, using 'advanced'") - complexity_level = 'advanced' + complexity_level = "advanced" else: - complexity_level = 'advanced' # Cochrane reviews tend to be more technical - + complexity_level = "advanced" # Cochrane reviews tend to be more technical + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='Cochrane Reviews', - reference_explanations=None + source_dataset="Cochrane Reviews", + reference_explanations=None, ) - + # Validate the item _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing Cochrane Reviews item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from Cochrane Reviews") - + if len(items) == 0: logger.warning("No valid items were loaded from Cochrane Reviews dataset") else: # Log complexity distribution - complexity_dist = {} + complexity_dist: Dict[str, int] = {} for item in items: complexity_dist[item.complexity_level] = complexity_dist.get(item.complexity_level, 0) + 1 - + avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"Cochrane Reviews dataset statistics:") + logger.info("Cochrane Reviews dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity distribution: {complexity_dist}") - + return items def _validate_benchmark_item(item: MEQBenchItem) -> None: """Validate a MEQBenchItem object for basic requirements. - + Args: item: MEQBenchItem to validate - + Raises: ValueError: If the item doesn't meet basic requirements """ if not item.id or not isinstance(item.id, str): raise ValueError("Item ID must be a non-empty string") - + if not item.medical_content or not isinstance(item.medical_content, str): raise ValueError("Medical content must be a non-empty string") - + if len(item.medical_content.strip()) < 20: - raise ValueError("Medical content is too short (less than 20 characters)") \ No newline at end of file + raise ValueError("Medical content is too short (less than 20 characters)") diff --git a/src/evaluator.py b/src/evaluator.py index 3240983..bd4c045 100644 --- a/src/evaluator.py +++ b/src/evaluator.py @@ -3,13 +3,13 @@ """ import re -import os import time import json -import requests +import requests # type: ignore import numpy as np -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Dict, List, Any, Optional, Protocol, Union, Callable, Type, Set + try: from typing_extensions import TypedDict except ImportError: @@ -21,18 +21,20 @@ from .strategies import StrategyFactory, AudienceStrategy # Set up logging -logger = logging.getLogger('meq_bench.evaluator') +logger = logging.getLogger("meq_bench.evaluator") # TypedDict definitions for better structure typing class APIConfigDict(TypedDict): """Type definition for API configuration.""" + base_url: str timeout: Optional[int] class EvaluationConfigDict(TypedDict): """Type definition for evaluation configuration.""" + safety: Dict[str, List[str]] medical_terms: List[str] readability_targets: Dict[str, Dict[str, float]] @@ -41,14 +43,16 @@ class EvaluationConfigDict(TypedDict): class ScoringConfigDict(TypedDict): """Type definition for scoring configuration.""" + weights: Dict[str, float] parameters: Dict[str, Any] + try: - import textstat - from sentence_transformers import SentenceTransformer - import spacy - from transformers import pipeline + import textstat # type: ignore + from sentence_transformers import SentenceTransformer # type: ignore + import spacy # type: ignore + from transformers import pipeline # type: ignore except ImportError as e: logger.warning(f"Some evaluation dependencies not installed: {e}") logger.info("Install with: pip install -r requirements.txt") @@ -56,12 +60,14 @@ class ScoringConfigDict(TypedDict): class EvaluationError(Exception): """Raised when there's an error during evaluation""" + pass @dataclass class EvaluationScore: """Container for evaluation scores with detailed breakdown""" + readability: float terminology: float safety: float @@ -72,26 +78,26 @@ class EvaluationScore: hallucination: float overall: float details: Optional[Dict[str, Any]] = None - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for serialization""" return { - 'readability': self.readability, - 'terminology': self.terminology, - 'safety': self.safety, - 'coverage': self.coverage, - 'quality': self.quality, - 'contradiction': self.contradiction, - 'information_preservation': self.information_preservation, - 'hallucination': self.hallucination, - 'overall': self.overall, - 'details': self.details or {} + "readability": self.readability, + "terminology": self.terminology, + "safety": self.safety, + "coverage": self.coverage, + "quality": self.quality, + "contradiction": self.contradiction, + "information_preservation": self.information_preservation, + "hallucination": self.hallucination, + "overall": self.overall, + "details": self.details or {}, } class MetricCalculator(Protocol): """Protocol for metric calculators""" - + def calculate(self, text: str, audience: str, **kwargs) -> float: """Calculate metric score""" ... @@ -99,20 +105,20 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: class ReadabilityCalculator: """Calculator for readability metrics using dependency injection""" - + def __init__(self, strategy_factory: StrategyFactory) -> None: self.strategy_factory: StrategyFactory = strategy_factory logger.debug("Initialized ReadabilityCalculator") - + def calculate(self, text: str, audience: str, **kwargs) -> float: """ Calculate readability score for given audience - + Args: text: Text to analyze audience: Target audience **kwargs: Additional parameters - + Returns: Readability score (0-1) """ @@ -120,24 +126,24 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: if not text.strip(): logger.warning("Empty text provided for readability calculation") return 0.0 - + # Get grade level using textstat try: grade_level = textstat.flesch_kincaid().grade(text) except Exception as e: logger.error(f"Error calculating Flesch-Kincaid score: {e}") # Fallback: estimate based on sentence length - sentences = text.split('.') + sentences = text.split(".") avg_sentence_length = sum(len(s.split()) for s in sentences) / max(len(sentences), 1) grade_level = min(16, max(6, avg_sentence_length / 4)) - + # Use strategy pattern for audience-specific scoring strategy = self.strategy_factory.create_strategy(audience) score = strategy.calculate_readability_score(text, grade_level) - + logger.debug(f"Readability score for {audience}: {score:.3f} (grade level: {grade_level:.1f})") return score - + except Exception as e: logger.error(f"Error calculating readability for {audience}: {e}") raise EvaluationError(f"Readability calculation failed: {e}") @@ -145,21 +151,21 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: class TerminologyCalculator: """Calculator for medical terminology appropriateness""" - + def __init__(self, strategy_factory: StrategyFactory) -> None: self.strategy_factory: StrategyFactory = strategy_factory - self.medical_terms: Set[str] = set(config.get('evaluation.medical_terms', [])) + self.medical_terms: Set[str] = set(config.get("evaluation.medical_terms", [])) logger.debug(f"Initialized TerminologyCalculator with {len(self.medical_terms)} medical terms") - + def calculate(self, text: str, audience: str, **kwargs) -> float: """ Calculate terminology appropriateness score - + Args: text: Text to analyze audience: Target audience **kwargs: Additional parameters - + Returns: Terminology score (0-1) """ @@ -167,22 +173,22 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: if not text.strip(): logger.warning("Empty text provided for terminology calculation") return 0.0 - + words = text.lower().split() if not words: return 0.0 - + # Count medical terms medical_count = sum(1 for word in words if any(term in word for term in self.medical_terms)) term_density = medical_count / len(words) - + # Use strategy pattern for audience-specific scoring strategy = self.strategy_factory.create_strategy(audience) score = strategy.calculate_terminology_score(text, term_density) - + logger.debug(f"Terminology score for {audience}: {score:.3f} (density: {term_density:.3f})") return score - + except Exception as e: logger.error(f"Error calculating terminology for {audience}: {e}") raise EvaluationError(f"Terminology calculation failed: {e}") @@ -190,22 +196,22 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: class SafetyChecker: """Medical safety and factual consistency checker""" - + def __init__(self) -> None: eval_config = config.get_evaluation_config() # type: ignore[misc] - self.danger_words: List[str] = eval_config['safety']['danger_words'] - self.safety_words: List[str] = eval_config['safety']['safety_words'] + self.danger_words: List[str] = eval_config["safety"]["danger_words"] + self.safety_words: List[str] = eval_config["safety"]["safety_words"] logger.debug(f"Initialized SafetyChecker with {len(self.danger_words)} danger words") - + def calculate(self, text: str, audience: str, **kwargs) -> float: """ Check text for safety compliance - + Args: text: Text to check audience: Target audience (not used in current implementation) **kwargs: Additional parameters - + Returns: Safety compliance score (0-1) """ @@ -213,30 +219,30 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: if not text.strip(): logger.warning("Empty text provided for safety check") return 0.5 # Neutral score for empty text - + text_lower = text.lower() - + # Check for dangerous advice danger_count = sum(1 for word in self.danger_words if word in text_lower) if danger_count > 0: logger.warning(f"Dangerous content detected: {danger_count} danger words found") return 0.0 - + # Check for appropriate safety language safety_count = sum(1 for word in self.safety_words if word in text_lower) - + # Calculate safety score safety_score = min(1.0, safety_count * 0.3) - + # Bonus for mentioning healthcare professionals - professional_terms = ['doctor', 'physician', 'healthcare provider', 'medical professional'] + professional_terms = ["doctor", "physician", "healthcare provider", "medical professional"] professional_mentions = sum(1 for term in professional_terms if term in text_lower) if professional_mentions > 0: safety_score = min(1.0, safety_score + 0.2) - + logger.debug(f"Safety score: {safety_score:.3f} (safety words: {safety_count})") return safety_score - + except Exception as e: logger.error(f"Error in safety check: {e}") raise EvaluationError(f"Safety check failed: {e}") @@ -244,25 +250,25 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: class CoverageAnalyzer: """Analyzer for information coverage and completeness""" - + def __init__(self) -> None: try: - self.sentence_model: Optional[Any] = SentenceTransformer('all-MiniLM-L6-v2') + self.sentence_model: Optional[Any] = SentenceTransformer("all-MiniLM-L6-v2") logger.debug("Initialized CoverageAnalyzer with SentenceTransformer") except Exception as e: logger.warning(f"Failed to load SentenceTransformer: {e}") self.sentence_model = None - + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Measure information coverage using semantic similarity - + Args: text: Generated explanation text audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Coverage score (0-1) """ @@ -270,118 +276,109 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f if not text.strip() or not original.strip(): logger.warning("Empty text or original provided for coverage analysis") return 0.0 - + if not self.sentence_model: # Fallback to simple word overlap return self._calculate_word_overlap(original, text) - + # Use sentence transformers for semantic similarity try: orig_embedding = self.sentence_model.encode([original]) gen_embedding = self.sentence_model.encode([text]) - + similarity = np.dot(orig_embedding[0], gen_embedding[0]) / ( np.linalg.norm(orig_embedding[0]) * np.linalg.norm(gen_embedding[0]) ) - + coverage_score = max(0.0, min(1.0, similarity)) - + logger.debug(f"Coverage score: {coverage_score:.3f} (semantic similarity)") return coverage_score - + except Exception as e: logger.warning(f"Semantic similarity calculation failed: {e}") return self._calculate_word_overlap(original, text) - + except Exception as e: logger.error(f"Error calculating coverage: {e}") raise EvaluationError(f"Coverage calculation failed: {e}") - + def _calculate_word_overlap(self, original: str, generated: str) -> float: """Fallback method using word overlap""" orig_words = set(original.lower().split()) gen_words = set(generated.lower().split()) - + if not orig_words: return 0.0 - + overlap = len(orig_words.intersection(gen_words)) coverage = min(1.0, overlap / len(orig_words)) - + logger.debug(f"Coverage score: {coverage:.3f} (word overlap)") return coverage class ContradictionDetection: """Medical contradiction detection against knowledge base""" - + def __init__(self) -> None: self.medical_knowledge_base = self._load_medical_knowledge() self.contradiction_patterns = self._load_contradiction_patterns() logger.debug("Initialized ContradictionDetection") - + def _load_medical_knowledge(self) -> Dict[str, List[str]]: """Load basic medical knowledge base for contradiction detection""" # This would ideally load from a comprehensive medical knowledge base # For now, using a simplified version with common medical facts return { - 'hypertension': [ - 'high blood pressure', - 'systolic over 140 or diastolic over 90', - 'can lead to heart disease and stroke', - 'treated with lifestyle changes and medication' + "hypertension": [ + "high blood pressure", + "systolic over 140 or diastolic over 90", + "can lead to heart disease and stroke", + "treated with lifestyle changes and medication", ], - 'diabetes': [ - 'high blood sugar', - 'insulin resistance or deficiency', - 'requires blood sugar monitoring', - 'managed with diet, exercise, and medication' + "diabetes": [ + "high blood sugar", + "insulin resistance or deficiency", + "requires blood sugar monitoring", + "managed with diet, exercise, and medication", ], - 'antibiotics': [ - 'treat bacterial infections', - 'do not work against viruses', - 'should be taken as prescribed', - 'resistance can develop from misuse' + "antibiotics": [ + "treat bacterial infections", + "do not work against viruses", + "should be taken as prescribed", + "resistance can develop from misuse", + ], + "aspirin": [ + "pain reliever and blood thinner", + "can cause stomach bleeding", + "contraindicated with certain conditions", + "requires medical supervision for daily use", ], - 'aspirin': [ - 'pain reliever and blood thinner', - 'can cause stomach bleeding', - 'contraindicated with certain conditions', - 'requires medical supervision for daily use' - ] } - + def _load_contradiction_patterns(self) -> List[Dict[str, str]]: """Load patterns that indicate medical contradictions""" return [ + {"pattern": r"antibiotics.*treat.*virus", "description": "Antibiotics do not treat viral infections"}, { - 'pattern': r'antibiotics.*treat.*virus', - 'description': 'Antibiotics do not treat viral infections' + "pattern": r"stop.*medication.*immediately.*feel.*better", + "description": "Medications should not be stopped without medical advice", }, - { - 'pattern': r'stop.*medication.*immediately.*feel.*better', - 'description': 'Medications should not be stopped without medical advice' - }, - { - 'pattern': r'aspirin.*safe.*everyone', - 'description': 'Aspirin has contraindications and side effects' - }, - { - 'pattern': r'blood pressure.*normal.*140/90', - 'description': '140/90 is not normal blood pressure' - } + {"pattern": r"aspirin.*safe.*everyone", "description": "Aspirin has contraindications and side effects"}, + {"pattern": r"blood pressure.*normal.*140/90", "description": "140/90 is not normal blood pressure"}, ] - + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Detect contradictions against medical knowledge base - + Args: text: Generated explanation text audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Contradiction score (0-1, where 1 means no contradictions) """ @@ -389,44 +386,39 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f if not text.strip(): logger.warning("Empty text provided for contradiction detection") return 0.5 - + text_lower = text.lower() contradiction_count = 0 - + # Check for pattern-based contradictions for pattern_info in self.contradiction_patterns: - pattern = pattern_info['pattern'] + pattern = pattern_info["pattern"] if re.search(pattern, text_lower): contradiction_count += 1 logger.warning(f"Contradiction detected: {pattern_info['description']}") - + # Check for factual contradictions against knowledge base for topic, facts in self.medical_knowledge_base.items(): if topic in text_lower: # Simple contradiction detection: look for negations of known facts for fact in facts: # Check for explicit contradictions - contradiction_patterns = [ - f"not {fact}", - f"never {fact}", - f"don't {fact}", - f"doesn't {fact}" - ] + contradiction_patterns = [f"not {fact}", f"never {fact}", f"don't {fact}", f"doesn't {fact}"] for contradiction in contradiction_patterns: if contradiction in text_lower: contradiction_count += 1 logger.warning(f"Knowledge base contradiction: {contradiction}") - + # Calculate score (higher is better, no contradictions = 1.0) if contradiction_count == 0: score = 1.0 else: # Penalize based on number of contradictions score = max(0.0, 1.0 - (contradiction_count * 0.3)) - + logger.debug(f"Contradiction score: {score:.3f} ({contradiction_count} contradictions found)") return score - + except Exception as e: logger.error(f"Error in contradiction detection: {e}") raise EvaluationError(f"Contradiction detection failed: {e}") @@ -434,65 +426,65 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f class InformationPreservation: """Check preservation of critical medical information""" - + def __init__(self) -> None: self.critical_info_patterns = self._load_critical_patterns() logger.debug("Initialized InformationPreservation") - + def _load_critical_patterns(self) -> Dict[str, List[str]]: """Load patterns for critical medical information""" return { - 'dosages': [ - r'\d+\s*mg', - r'\d+\s*ml', - r'\d+\s*units', - r'\d+\s*tablets?', - r'\d+\s*times?\s*(?:per\s+)?day', - r'once\s+daily', - r'twice\s+daily', - r'every\s+\d+\s+hours' + "dosages": [ + r"\d+\s*mg", + r"\d+\s*ml", + r"\d+\s*units", + r"\d+\s*tablets?", + r"\d+\s*times?\s*(?:per\s+)?day", + r"once\s+daily", + r"twice\s+daily", + r"every\s+\d+\s+hours", + ], + "warnings": [ + r"do not", + r"avoid", + r"contraindicated", + r"warning", + r"caution", + r"side effects?", + r"adverse", + r"allergic", + r"emergency", ], - 'warnings': [ - r'do not', - r'avoid', - r'contraindicated', - r'warning', - r'caution', - r'side effects?', - r'adverse', - r'allergic', - r'emergency' + "timing": [ + r"before\s+meals?", + r"after\s+meals?", + r"with\s+food", + r"on\s+empty\s+stomach", + r"bedtime", + r"morning", + r"evening", ], - 'timing': [ - r'before\s+meals?', - r'after\s+meals?', - r'with\s+food', - r'on\s+empty\s+stomach', - r'bedtime', - r'morning', - r'evening' + "conditions": [ + r"if\s+pregnant", + r"if\s+breastfeeding", + r"kidney\s+disease", + r"liver\s+disease", + r"heart\s+condition", + r"diabetes", + r"high\s+blood\s+pressure", ], - 'conditions': [ - r'if\s+pregnant', - r'if\s+breastfeeding', - r'kidney\s+disease', - r'liver\s+disease', - r'heart\s+condition', - r'diabetes', - r'high\s+blood\s+pressure' - ] } - + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Check if critical information is preserved from original to generated text - + Args: text: Generated explanation text audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Information preservation score (0-1) """ @@ -500,22 +492,22 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f if not text.strip() or not original.strip(): logger.warning("Empty text or original provided for information preservation") return 0.0 - + original_lower = original.lower() text_lower = text.lower() - + total_critical_info = 0 preserved_critical_info = 0 - + # Check each category of critical information for category, patterns in self.critical_info_patterns.items(): for pattern in patterns: # Find all matches in original text original_matches = re.findall(pattern, original_lower) - + if original_matches: total_critical_info += len(original_matches) - + # Check if these matches are preserved in generated text for match in original_matches: if match in text_lower or any(re.search(pattern, text_lower) for pattern in [match]): @@ -524,99 +516,150 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f # Check for paraphrased versions if self._check_paraphrased_preservation(match, text_lower, category): preserved_critical_info += 1 - + # Calculate preservation score if total_critical_info == 0: # No critical information to preserve score = 1.0 else: score = preserved_critical_info / total_critical_info - - logger.debug(f"Information preservation score: {score:.3f} " - f"({preserved_critical_info}/{total_critical_info} preserved)") + + logger.debug( + f"Information preservation score: {score:.3f} " f"({preserved_critical_info}/{total_critical_info} preserved)" + ) return score - + except Exception as e: logger.error(f"Error in information preservation check: {e}") raise EvaluationError(f"Information preservation check failed: {e}") - + def _check_paraphrased_preservation(self, original_info: str, generated_text: str, category: str) -> bool: """Check if information is preserved in paraphrased form""" # Simple paraphrase detection for different categories - if category == 'dosages': + if category == "dosages": # Check if any dosage information is present - dosage_patterns = [r'\d+', r'dose', r'amount', r'quantity'] + dosage_patterns = [r"\d+", r"dose", r"amount", r"quantity"] return any(re.search(pattern, generated_text) for pattern in dosage_patterns) - - elif category == 'warnings': + + elif category == "warnings": # Check if warning language is present - warning_patterns = [r'careful', r'important', r'note', r'remember', r'consult'] + warning_patterns = [r"careful", r"important", r"note", r"remember", r"consult"] return any(re.search(pattern, generated_text) for pattern in warning_patterns) - - elif category == 'timing': + + elif category == "timing": # Check if timing information is preserved - timing_patterns = [r'when', r'time', r'schedule', r'take'] + timing_patterns = [r"when", r"time", r"schedule", r"take"] return any(re.search(pattern, generated_text) for pattern in timing_patterns) - - elif category == 'conditions': + + elif category == "conditions": # Check if condition information is preserved - condition_patterns = [r'condition', r'disease', r'illness', r'medical'] + condition_patterns = [r"condition", r"disease", r"illness", r"medical"] return any(re.search(pattern, generated_text) for pattern in condition_patterns) - + return False class HallucinationDetection: """Detect hallucinated medical entities not present in source text""" - + def __init__(self) -> None: self.medical_entities = self._load_medical_entities() try: # Try to load spaCy model for NER import spacy + self.nlp = spacy.load("en_core_web_sm") logger.debug("Loaded spaCy model for NER") except Exception as e: logger.warning(f"Could not load spaCy model: {e}") self.nlp = None - + logger.debug("Initialized HallucinationDetection") - + def _load_medical_entities(self) -> Dict[str, List[str]]: """Load common medical entities for hallucination detection""" return { - 'medications': [ - 'aspirin', 'ibuprofen', 'acetaminophen', 'paracetamol', 'insulin', - 'metformin', 'lisinopril', 'atorvastatin', 'omeprazole', 'albuterol', - 'prednisone', 'warfarin', 'digoxin', 'furosemide', 'levothyroxine' + "medications": [ + "aspirin", + "ibuprofen", + "acetaminophen", + "paracetamol", + "insulin", + "metformin", + "lisinopril", + "atorvastatin", + "omeprazole", + "albuterol", + "prednisone", + "warfarin", + "digoxin", + "furosemide", + "levothyroxine", + ], + "conditions": [ + "hypertension", + "diabetes", + "asthma", + "copd", + "pneumonia", + "bronchitis", + "arthritis", + "osteoporosis", + "depression", + "anxiety", + "migraine", + "epilepsy", + "cancer", + "stroke", + "heart attack", ], - 'conditions': [ - 'hypertension', 'diabetes', 'asthma', 'copd', 'pneumonia', - 'bronchitis', 'arthritis', 'osteoporosis', 'depression', 'anxiety', - 'migraine', 'epilepsy', 'cancer', 'stroke', 'heart attack' + "symptoms": [ + "fever", + "cough", + "headache", + "nausea", + "vomiting", + "diarrhea", + "constipation", + "fatigue", + "dizziness", + "chest pain", + "shortness of breath", + "swelling", + "rash", + "itching", + "numbness", + "tingling", ], - 'symptoms': [ - 'fever', 'cough', 'headache', 'nausea', 'vomiting', 'diarrhea', - 'constipation', 'fatigue', 'dizziness', 'chest pain', 'shortness of breath', - 'swelling', 'rash', 'itching', 'numbness', 'tingling' + "body_parts": [ + "heart", + "lungs", + "liver", + "kidney", + "brain", + "stomach", + "intestine", + "bladder", + "pancreas", + "thyroid", + "spine", + "joints", + "muscles", + "blood vessels", + "nerves", ], - 'body_parts': [ - 'heart', 'lungs', 'liver', 'kidney', 'brain', 'stomach', - 'intestine', 'bladder', 'pancreas', 'thyroid', 'spine', - 'joints', 'muscles', 'blood vessels', 'nerves' - ] } - + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Detect medical entities in generated text that are not in source - + Args: text: Generated explanation text audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Hallucination score (0-1, where 1 means no hallucinations) """ @@ -624,14 +667,14 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f if not text.strip() or not original.strip(): logger.warning("Empty text or original provided for hallucination detection") return 0.5 - + # Extract medical entities from both texts original_entities = self._extract_medical_entities(original) generated_entities = self._extract_medical_entities(text) - + # Find entities in generated text that are not in original hallucinated_entities = generated_entities - original_entities - + # Calculate hallucination score total_generated_entities = len(generated_entities) if total_generated_entities == 0: @@ -639,113 +682,116 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f else: hallucination_rate = len(hallucinated_entities) / total_generated_entities score = max(0.0, 1.0 - hallucination_rate) - + if hallucinated_entities: logger.warning(f"Hallucinated entities detected: {hallucinated_entities}") - - logger.debug(f"Hallucination score: {score:.3f} " - f"({len(hallucinated_entities)}/{total_generated_entities} hallucinated)") + + logger.debug( + f"Hallucination score: {score:.3f} " f"({len(hallucinated_entities)}/{total_generated_entities} hallucinated)" + ) return score - + except Exception as e: logger.error(f"Error in hallucination detection: {e}") raise EvaluationError(f"Hallucination detection failed: {e}") - + def _extract_medical_entities(self, text: str) -> set: """Extract medical entities from text""" entities = set() text_lower = text.lower() - + # Extract entities from predefined lists for category, entity_list in self.medical_entities.items(): for entity in entity_list: if entity.lower() in text_lower: entities.add(entity.lower()) - + # Use spaCy NER if available if self.nlp: try: doc = self.nlp(text) for ent in doc.ents: # Focus on medical-related entity types - if ent.label_ in ['PERSON', 'ORG', 'PRODUCT', 'SUBSTANCE']: + if ent.label_ in ["PERSON", "ORG", "PRODUCT", "SUBSTANCE"]: # Filter for likely medical entities entity_text = ent.text.lower() - if any(medical_term in entity_text - for medical_terms in self.medical_entities.values() - for medical_term in medical_terms): + if any( + medical_term in entity_text + for medical_terms in self.medical_entities.values() + for medical_term in medical_terms + ): entities.add(entity_text) except Exception as e: logger.warning(f"spaCy NER failed: {e}") - + return entities class LLMJudge: """LLM-as-a-judge evaluator with full API integration""" - + def __init__(self, model: Optional[str] = None) -> None: - self.model: str = model or config.get('llm_judge.default_model') - self.timeout: int = config.get('llm_judge.timeout', 30) - self.max_retries: int = config.get('llm_judge.max_retries', 3) - self.temperature: float = config.get('llm_judge.temperature', 0.1) - self.max_tokens: int = config.get('llm_judge.max_tokens', 1000) - + self.model: str = model or config.get("llm_judge.default_model") + self.timeout: int = config.get("llm_judge.timeout", 30) + self.max_retries: int = config.get("llm_judge.max_retries", 3) + self.temperature: float = config.get("llm_judge.temperature", 0.1) + self.max_tokens: int = config.get("llm_judge.max_tokens", 1000) + # Determine API provider from model name self.provider: str = self._determine_provider(self.model) self.api_key: str = config.get_api_key(self.provider) - + logger.info(f"Initialized LLMJudge with model: {self.model} (provider: {self.provider})") - + def _determine_provider(self, model: str) -> str: """Determine API provider from model name""" - if 'gpt' in model.lower(): - return 'openai' - elif 'claude' in model.lower(): - return 'anthropic' + if "gpt" in model.lower(): + return "openai" + elif "claude" in model.lower(): + return "anthropic" else: logger.warning(f"Unknown model provider for {model}, defaulting to openai") - return 'openai' - + return "openai" + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Evaluate using LLM as judge - + Args: text: Generated explanation audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Quality score (0-1) """ try: prompt = self._create_evaluation_prompt(original, text, audience) - + for attempt in range(self.max_retries): try: response = self._call_llm_api(prompt) score = self._parse_llm_response(response) - + logger.debug(f"LLM Judge score for {audience}: {score:.3f} (attempt {attempt + 1})") return score - + except Exception as e: logger.warning(f"LLM API call failed (attempt {attempt + 1}): {e}") if attempt == self.max_retries - 1: raise - time.sleep(2 ** attempt) # Exponential backoff - + time.sleep(2**attempt) # Exponential backoff + # If we reach here, all retries failed but no exception was raised logger.error("All LLM API retry attempts failed") return 0.6 - + except Exception as e: logger.error(f"LLM Judge evaluation failed: {e}") # Return reasonable default instead of failing completely return 0.6 - + def _create_evaluation_prompt(self, original: str, generated: str, audience: str) -> str: """Create evaluation prompt for LLM judge""" return f"""Evaluate the following 'Generated' explanation, which was adapted from the 'Original' medical information for the specified {audience}. @@ -766,94 +812,87 @@ def _create_evaluation_prompt(self, original: str, generated: str, audience: str {{"score1": X, "score2": Y, "score3": Z, "score4": A, "score5": B, "score6": C, "overall": D}} Where each score is a number from 1-5, and overall is the average.""" - + def _call_llm_api(self, prompt: str) -> str: """Make API call to LLM service""" - if self.provider == 'openai': + if self.provider == "openai": return self._call_openai_api(prompt) - elif self.provider == 'anthropic': + elif self.provider == "anthropic": return self._call_anthropic_api(prompt) else: raise EvaluationError(f"Unsupported provider: {self.provider}") - + def _call_openai_api(self, prompt: str) -> str: """Call OpenAI API""" - api_config = config.get_api_config('openai') + api_config = config.get_api_config("openai") url = f"{api_config['base_url']}/chat/completions" - - headers = { - 'Authorization': f'Bearer {self.api_key}', - 'Content-Type': 'application/json' - } - + + headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} + data = { - 'model': self.model, - 'messages': [{'role': 'user', 'content': prompt}], - 'temperature': self.temperature, - 'max_tokens': self.max_tokens + "model": self.model, + "messages": [{"role": "user", "content": prompt}], + "temperature": self.temperature, + "max_tokens": self.max_tokens, } - + response = requests.post(url, headers=headers, json=data, timeout=self.timeout) response.raise_for_status() - + result = response.json() - return result['choices'][0]['message']['content'] - + return result["choices"][0]["message"]["content"] + def _call_anthropic_api(self, prompt: str) -> str: """Call Anthropic API""" - api_config = config.get_api_config('anthropic') + api_config = config.get_api_config("anthropic") url = f"{api_config['base_url']}/v1/messages" - - headers = { - 'x-api-key': self.api_key, - 'Content-Type': 'application/json', - 'anthropic-version': '2023-06-01' - } - + + headers = {"x-api-key": self.api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01"} + data = { - 'model': self.model, - 'max_tokens': self.max_tokens, - 'temperature': self.temperature, - 'messages': [{'role': 'user', 'content': prompt}] + "model": self.model, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "messages": [{"role": "user", "content": prompt}], } - + response = requests.post(url, headers=headers, json=data, timeout=self.timeout) response.raise_for_status() - + result = response.json() - return result['content'][0]['text'] - + return result["content"][0]["text"] + def _parse_llm_response(self, response: str) -> float: """Parse LLM response to extract score""" try: # Try to extract JSON from response - json_start = response.find('{') - json_end = response.rfind('}') + 1 - + json_start = response.find("{") + json_end = response.rfind("}") + 1 + if json_start >= 0 and json_end > json_start: json_str = response[json_start:json_end] data = json.loads(json_str) - + # Get overall score or calculate from individual scores - if 'overall' in data: - overall = float(data['overall']) + if "overall" in data: + overall = float(data["overall"]) else: individual_scores = [] for i in range(1, 7): - key = f'score{i}' + key = f"score{i}" if key in data: individual_scores.append(float(data[key])) - + overall = sum(individual_scores) / len(individual_scores) if individual_scores else 3.0 - + # Convert from 1-5 scale to 0-1 scale return (overall - 1) / 4 - + except (json.JSONDecodeError, ValueError, KeyError) as e: logger.warning(f"Failed to parse LLM response as JSON: {e}") - + # Fallback: try to extract a number from the response - numbers = re.findall(r'\d+\.?\d*', response) + numbers = re.findall(r"\d+\.?\d*", response) if numbers: try: score = float(numbers[0]) @@ -863,27 +902,29 @@ def _parse_llm_response(self, response: str) -> float: return score except ValueError: pass - + logger.warning("Could not parse LLM response, using default score") return 0.6 # Default reasonable score class MEQBenchEvaluator: """Main evaluation class using dependency injection and SOLID principles""" - - def __init__(self, - readability_calculator: Optional[ReadabilityCalculator] = None, - terminology_calculator: Optional[TerminologyCalculator] = None, - safety_checker: Optional[SafetyChecker] = None, - coverage_analyzer: Optional[CoverageAnalyzer] = None, - llm_judge: Optional[LLMJudge] = None, - contradiction_detector: Optional[ContradictionDetection] = None, - information_preservation: Optional[InformationPreservation] = None, - hallucination_detector: Optional[HallucinationDetection] = None, - strategy_factory: Optional[StrategyFactory] = None) -> None: + + def __init__( + self, + readability_calculator: Optional[ReadabilityCalculator] = None, + terminology_calculator: Optional[TerminologyCalculator] = None, + safety_checker: Optional[SafetyChecker] = None, + coverage_analyzer: Optional[CoverageAnalyzer] = None, + llm_judge: Optional[LLMJudge] = None, + contradiction_detector: Optional[ContradictionDetection] = None, + information_preservation: Optional[InformationPreservation] = None, + hallucination_detector: Optional[HallucinationDetection] = None, + strategy_factory: Optional[StrategyFactory] = None, + ) -> None: """ Initialize evaluator with dependency injection - + Args: readability_calculator: Calculator for readability metrics terminology_calculator: Calculator for terminology appropriateness @@ -897,167 +938,175 @@ def __init__(self, """ # Use dependency injection with sensible defaults self.strategy_factory: StrategyFactory = strategy_factory or StrategyFactory() - - self.readability_calculator: ReadabilityCalculator = readability_calculator or ReadabilityCalculator(self.strategy_factory) - self.terminology_calculator: TerminologyCalculator = terminology_calculator or TerminologyCalculator(self.strategy_factory) + + self.readability_calculator: ReadabilityCalculator = readability_calculator or ReadabilityCalculator( + self.strategy_factory + ) + self.terminology_calculator: TerminologyCalculator = terminology_calculator or TerminologyCalculator( + self.strategy_factory + ) self.safety_checker: SafetyChecker = safety_checker or SafetyChecker() self.coverage_analyzer: CoverageAnalyzer = coverage_analyzer or CoverageAnalyzer() self.llm_judge: LLMJudge = llm_judge or LLMJudge() - + # New safety and factual consistency metrics self.contradiction_detector: ContradictionDetection = contradiction_detector or ContradictionDetection() self.information_preservation: InformationPreservation = information_preservation or InformationPreservation() self.hallucination_detector: HallucinationDetection = hallucination_detector or HallucinationDetection() - + # Load scoring configuration self.scoring_config = config.get_scoring_config() # type: ignore[misc] - self.weights: Dict[str, float] = self.scoring_config['weights'] - + self.weights: Dict[str, float] = self.scoring_config["weights"] + logger.info("MEQBenchEvaluator initialized with dependency injection") - + def evaluate_explanation(self, original: str, generated: str, audience: str) -> EvaluationScore: """ Evaluate a single explanation for a specific audience - + Args: original: Original medical information generated: Generated explanation audience: Target audience - + Returns: EvaluationScore object with all metrics - + Raises: EvaluationError: If evaluation fails """ try: logger.info(f"Starting evaluation for {audience} audience") start_time = time.time() - + # Validate inputs if not generated.strip(): raise EvaluationError("Generated explanation is empty") - + if audience not in config.get_audiences(): raise EvaluationError(f"Unsupported audience: {audience}") - + # Calculate individual metrics - metrics = {} - details = {} - + metrics: Dict[str, float] = {} + details: Dict[str, Any] = {} + try: - metrics['readability'] = self.readability_calculator.calculate(generated, audience) - details['readability'] = {'text_length': len(generated), 'audience': audience} + metrics["readability"] = self.readability_calculator.calculate(generated, audience) + details["readability"] = {"text_length": len(generated), "audience": audience} except Exception as e: logger.error(f"Readability calculation failed: {e}") - metrics['readability'] = 0.0 - + metrics["readability"] = 0.0 + try: - metrics['terminology'] = self.terminology_calculator.calculate(generated, audience) + metrics["terminology"] = self.terminology_calculator.calculate(generated, audience) except Exception as e: logger.error(f"Terminology calculation failed: {e}") - metrics['terminology'] = 0.0 - + metrics["terminology"] = 0.0 + try: - metrics['safety'] = self.safety_checker.calculate(generated, audience) + metrics["safety"] = self.safety_checker.calculate(generated, audience) except Exception as e: logger.error(f"Safety check failed: {e}") - metrics['safety'] = 0.0 - + metrics["safety"] = 0.0 + try: - metrics['coverage'] = self.coverage_analyzer.calculate(generated, audience, original=original) + metrics["coverage"] = self.coverage_analyzer.calculate(generated, audience, original=original) except Exception as e: logger.error(f"Coverage analysis failed: {e}") - metrics['coverage'] = 0.0 - + metrics["coverage"] = 0.0 + try: - metrics['quality'] = self.llm_judge.calculate(generated, audience, original=original) + metrics["quality"] = self.llm_judge.calculate(generated, audience, original=original) except Exception as e: logger.error(f"LLM judge failed: {e}") - metrics['quality'] = 0.6 # Default reasonable score - + metrics["quality"] = 0.6 # Default reasonable score + # New safety and factual consistency metrics try: - metrics['contradiction'] = self.contradiction_detector.calculate(generated, audience, original=original) + metrics["contradiction"] = self.contradiction_detector.calculate(generated, audience, original=original) except Exception as e: logger.error(f"Contradiction detection failed: {e}") - metrics['contradiction'] = 0.7 # Default reasonable score - + metrics["contradiction"] = 0.7 # Default reasonable score + try: - metrics['information_preservation'] = self.information_preservation.calculate(generated, audience, original=original) + metrics["information_preservation"] = self.information_preservation.calculate( + generated, audience, original=original + ) except Exception as e: logger.error(f"Information preservation check failed: {e}") - metrics['information_preservation'] = 0.7 # Default reasonable score - + metrics["information_preservation"] = 0.7 # Default reasonable score + try: - metrics['hallucination'] = self.hallucination_detector.calculate(generated, audience, original=original) + metrics["hallucination"] = self.hallucination_detector.calculate(generated, audience, original=original) except Exception as e: logger.error(f"Hallucination detection failed: {e}") - metrics['hallucination'] = 0.7 # Default reasonable score - + metrics["hallucination"] = 0.7 # Default reasonable score + # Calculate weighted overall score overall = sum(metrics[metric] * self.weights[metric] for metric in metrics.keys()) - + # Apply safety multiplier if safety score is very low - if metrics['safety'] < 0.3: - overall *= self.scoring_config['parameters']['safety_multiplier'] + if metrics["safety"] < 0.3: + overall *= self.scoring_config["parameters"]["safety_multiplier"] overall = min(1.0, overall) # Cap at 1.0 - details['safety_penalty_applied'] = True - + details["safety_penalty_applied"] = True + evaluation_time = time.time() - start_time - details['evaluation_time'] = evaluation_time - details['weights_used'] = self.weights - + details["evaluation_time"] = evaluation_time + details["weights_used"] = dict(self.weights) + logger.info(f"Evaluation completed for {audience} in {evaluation_time:.2f}s") - logger.debug(f"Scores - R:{metrics['readability']:.3f} T:{metrics['terminology']:.3f} " - f"S:{metrics['safety']:.3f} C:{metrics['coverage']:.3f} Q:{metrics['quality']:.3f} " - f"CD:{metrics['contradiction']:.3f} IP:{metrics['information_preservation']:.3f} " - f"H:{metrics['hallucination']:.3f} Overall:{overall:.3f}") - + logger.debug( + f"Scores - R:{metrics['readability']:.3f} T:{metrics['terminology']:.3f} " + f"S:{metrics['safety']:.3f} C:{metrics['coverage']:.3f} Q:{metrics['quality']:.3f} " + f"CD:{metrics['contradiction']:.3f} IP:{metrics['information_preservation']:.3f} " + f"H:{metrics['hallucination']:.3f} Overall:{overall:.3f}" + ) + return EvaluationScore( - readability=metrics['readability'], - terminology=metrics['terminology'], - safety=metrics['safety'], - coverage=metrics['coverage'], - quality=metrics['quality'], - contradiction=metrics['contradiction'], - information_preservation=metrics['information_preservation'], - hallucination=metrics['hallucination'], + readability=metrics["readability"], + terminology=metrics["terminology"], + safety=metrics["safety"], + coverage=metrics["coverage"], + quality=metrics["quality"], + contradiction=metrics["contradiction"], + information_preservation=metrics["information_preservation"], + hallucination=metrics["hallucination"], overall=overall, - details=details + details=details, ) - + except Exception as e: logger.error(f"Evaluation failed for {audience}: {e}") raise EvaluationError(f"Evaluation failed: {e}") - + def evaluate_all_audiences(self, original: str, explanations: Dict[str, str]) -> Dict[str, EvaluationScore]: """ Evaluate explanations for all audiences - + Args: original: Original medical information explanations: Dictionary mapping audience to explanation - + Returns: Dictionary mapping audience to EvaluationScore """ results = {} supported_audiences = config.get_audiences() - + logger.info(f"Starting evaluation for {len(explanations)} audiences") - + for audience, explanation in explanations.items(): if audience not in supported_audiences: logger.warning(f"Skipping unsupported audience: {audience}") continue - + try: results[audience] = self.evaluate_explanation(original, explanation, audience) except EvaluationError as e: logger.error(f"Failed to evaluate {audience}: {e}") # Continue with other audiences continue - + logger.info(f"Completed evaluation for {len(results)} audiences") - return results \ No newline at end of file + return results diff --git a/src/leaderboard.py b/src/leaderboard.py index eabf814..7356bf2 100644 --- a/src/leaderboard.py +++ b/src/leaderboard.py @@ -29,234 +29,222 @@ # Set up logging logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('meq_bench.leaderboard') +logger = logging.getLogger("meq_bench.leaderboard") class LeaderboardGenerator: """Generate static HTML leaderboard from evaluation results""" - + def __init__(self): self.results_data: List[Dict[str, Any]] = [] self.benchmark_stats: Dict[str, Any] = {} - + def load_results(self, results_dir: Path) -> None: """Load all result files from a directory - + Args: results_dir: Directory containing JSON result files """ if not results_dir.exists(): raise FileNotFoundError(f"Results directory not found: {results_dir}") - + result_files = list(results_dir.glob("*.json")) if not result_files: raise ValueError(f"No JSON result files found in {results_dir}") - + logger.info(f"Found {len(result_files)} result files") - + for result_file in result_files: try: - with open(result_file, 'r', encoding='utf-8') as f: + with open(result_file, "r", encoding="utf-8") as f: data = json.load(f) - + # Validate required fields - required_fields = ['model_name', 'total_items', 'audience_scores', 'summary'] + required_fields = ["model_name", "total_items", "audience_scores", "summary"] if all(field in data for field in required_fields): self.results_data.append(data) logger.debug(f"Loaded results for {data['model_name']}") else: logger.warning(f"Invalid result file format: {result_file}") - + except Exception as e: logger.error(f"Error loading {result_file}: {e}") - + if not self.results_data: raise ValueError("No valid result files were loaded") - + logger.info(f"Successfully loaded {len(self.results_data)} evaluation results") - + def calculate_leaderboard_stats(self) -> Dict[str, Any]: """Calculate overall leaderboard statistics - + Returns: Dictionary containing leaderboard statistics """ if not self.results_data: return {} - + # Calculate aggregate statistics total_models = len(self.results_data) - total_evaluations = sum(result['total_items'] for result in self.results_data) - + total_evaluations = sum(result["total_items"] for result in self.results_data) + # Audience coverage all_audiences = set() for result in self.results_data: - all_audiences.update(result['audience_scores'].keys()) - + all_audiences.update(result["audience_scores"].keys()) + # Complexity coverage all_complexities = set() for result in self.results_data: - if 'complexity_scores' in result: - all_complexities.update(result['complexity_scores'].keys()) - + if "complexity_scores" in result: + all_complexities.update(result["complexity_scores"].keys()) + # Performance ranges - overall_scores = [result['summary'].get('overall_mean', 0) for result in self.results_data] + overall_scores = [result["summary"].get("overall_mean", 0) for result in self.results_data] if overall_scores: best_score = max(overall_scores) worst_score = min(overall_scores) avg_score = sum(overall_scores) / len(overall_scores) else: best_score = worst_score = avg_score = 0 - + return { - 'total_models': total_models, - 'total_evaluations': total_evaluations, - 'audiences': sorted(list(all_audiences)), - 'complexity_levels': sorted(list(all_complexities)), - 'best_score': best_score, - 'worst_score': worst_score, - 'average_score': avg_score, - 'last_updated': datetime.now().isoformat() + "total_models": total_models, + "total_evaluations": total_evaluations, + "audiences": sorted(list(all_audiences)), + "complexity_levels": sorted(list(all_complexities)), + "best_score": best_score, + "worst_score": worst_score, + "average_score": avg_score, + "last_updated": datetime.now().isoformat(), } - + def rank_models(self) -> List[Dict[str, Any]]: """Rank models by overall performance - + Returns: List of model results sorted by overall performance """ - ranked_models = sorted( - self.results_data, - key=lambda x: x['summary'].get('overall_mean', 0), - reverse=True - ) - + ranked_models = sorted(self.results_data, key=lambda x: x["summary"].get("overall_mean", 0), reverse=True) + # Add ranking information for i, model in enumerate(ranked_models): - model['rank'] = i + 1 - model['overall_score'] = model['summary'].get('overall_mean', 0) - + model["rank"] = i + 1 + model["overall_score"] = model["summary"].get("overall_mean", 0) + return ranked_models - + def generate_audience_breakdown(self, ranked_models: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: """Generate audience-specific performance breakdown - + Args: ranked_models: List of ranked model results - + Returns: Dictionary mapping audience to ranked model performance """ audience_breakdown = {} - + # Get all audiences all_audiences = set() for model in ranked_models: - all_audiences.update(model['audience_scores'].keys()) - + all_audiences.update(model["audience_scores"].keys()) + for audience in sorted(all_audiences): audience_models = [] - + for model in ranked_models: - if audience in model['audience_scores']: - scores = model['audience_scores'][audience] + if audience in model["audience_scores"]: + scores = model["audience_scores"][audience] avg_score = sum(scores) / len(scores) if scores else 0 - - audience_models.append({ - 'model_name': model['model_name'], - 'score': avg_score, - 'num_items': len(scores) - }) - + + audience_models.append({"model_name": model["model_name"], "score": avg_score, "num_items": len(scores)}) + # Sort by score for this audience - audience_models.sort(key=lambda x: x['score'], reverse=True) - + audience_models.sort(key=lambda x: x["score"], reverse=True) + # Add rankings for i, model in enumerate(audience_models): - model['rank'] = i + 1 - + model["rank"] = i + 1 + audience_breakdown[audience] = audience_models - + return audience_breakdown - + def generate_complexity_breakdown(self, ranked_models: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: """Generate complexity-specific performance breakdown - + Args: ranked_models: List of ranked model results - + Returns: Dictionary mapping complexity level to ranked model performance """ complexity_breakdown = {} - + # Get all complexity levels all_complexities = set() for model in ranked_models: - if 'complexity_scores' in model: - all_complexities.update(model['complexity_scores'].keys()) - + if "complexity_scores" in model: + all_complexities.update(model["complexity_scores"].keys()) + for complexity in sorted(all_complexities): complexity_models = [] - + for model in ranked_models: - if 'complexity_scores' in model and complexity in model['complexity_scores']: - scores = model['complexity_scores'][complexity] + if "complexity_scores" in model and complexity in model["complexity_scores"]: + scores = model["complexity_scores"][complexity] avg_score = sum(scores) / len(scores) if scores else 0 - - complexity_models.append({ - 'model_name': model['model_name'], - 'score': avg_score, - 'num_items': len(scores) - }) - + + complexity_models.append({"model_name": model["model_name"], "score": avg_score, "num_items": len(scores)}) + # Sort by score for this complexity level - complexity_models.sort(key=lambda x: x['score'], reverse=True) - + complexity_models.sort(key=lambda x: x["score"], reverse=True) + # Add rankings for i, model in enumerate(complexity_models): - model['rank'] = i + 1 - + model["rank"] = i + 1 + complexity_breakdown[complexity] = complexity_models - + return complexity_breakdown - + def generate_html(self, output_path: Path) -> None: """Generate static HTML leaderboard - + Args: output_path: Path where to save the HTML file """ if not self.results_data: raise ValueError("No results data loaded") - + # Calculate statistics and rankings stats = self.calculate_leaderboard_stats() ranked_models = self.rank_models() audience_breakdown = self.generate_audience_breakdown(ranked_models) complexity_breakdown = self.generate_complexity_breakdown(ranked_models) - + # Generate HTML content - html_content = self._generate_html_template( - stats, ranked_models, audience_breakdown, complexity_breakdown - ) - + html_content = self._generate_html_template(stats, ranked_models, audience_breakdown, complexity_breakdown) + # Ensure output directory exists output_path.parent.mkdir(parents=True, exist_ok=True) - + # Write HTML file - with open(output_path, 'w', encoding='utf-8') as f: + with open(output_path, "w", encoding="utf-8") as f: f.write(html_content) - + logger.info(f"Generated leaderboard HTML: {output_path}") - - def _generate_html_template(self, - stats: Dict[str, Any], - ranked_models: List[Dict[str, Any]], - audience_breakdown: Dict[str, List[Dict[str, Any]]], - complexity_breakdown: Dict[str, List[Dict[str, Any]]]) -> str: + + def _generate_html_template( + self, + stats: Dict[str, Any], + ranked_models: List[Dict[str, Any]], + audience_breakdown: Dict[str, List[Dict[str, Any]]], + complexity_breakdown: Dict[str, List[Dict[str, Any]]], + ) -> str: """Generate the complete HTML template""" - + return f"""
@@ -347,7 +335,7 @@ def _generate_html_template(self, """ - + def _get_css_styles(self) -> str: """Return CSS styles for the leaderboard""" return """ @@ -611,7 +599,7 @@ def _get_css_styles(self) -> str: } } """ - + def _generate_overall_rankings_table(self, ranked_models: List[Dict[str, Any]]) -> str: """Generate the overall rankings table HTML""" table_html = """ @@ -630,22 +618,22 @@ def _generate_overall_rankings_table(self, ranked_models: List[Dict[str, Any]]) """ - + for model in ranked_models: rank_class = "" - if model['rank'] == 1: + if model["rank"] == 1: rank_class = "rank-1" - elif model['rank'] == 2: + elif model["rank"] == 2: rank_class = "rank-2" - elif model['rank'] == 3: + elif model["rank"] == 3: rank_class = "rank-3" - + # Calculate audience averages audience_scores = {} - for audience, scores in model['audience_scores'].items(): + for audience, scores in model["audience_scores"].items(): avg_score = sum(scores) / len(scores) if scores else 0 audience_scores[audience] = avg_score - + table_html += f"""