diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d3c9b6..7a2823c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,7 +23,7 @@ jobs: python-version: '3.9' - name: Cache pip dependencies - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} @@ -68,7 +68,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Cache pip dependencies - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('**/requirements.txt') }} @@ -98,7 +98,7 @@ jobs: - name: Upload coverage to Codecov if: matrix.python-version == '3.9' - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: file: ./coverage.xml flags: unittests @@ -174,7 +174,7 @@ jobs: make html - name: Upload documentation artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: documentation path: docs/_build/html/ @@ -241,7 +241,7 @@ jobs: twine check dist/* - name: Upload package artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: dist-packages path: dist/ @@ -259,7 +259,7 @@ jobs: fetch-depth: 0 - name: Download package artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: dist-packages path: dist/ diff --git a/examples/basic_usage.py b/examples/basic_usage.py index a2dbb9e..ad4ce50 100644 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -5,34 +5,30 @@ import os import sys import logging -from typing import Optional import torch import warnings -# Suppress some common warnings for cleaner output -warnings.filterwarnings("ignore", category=FutureWarning) -warnings.filterwarnings("ignore", category=UserWarning) - -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +# Add parent directory to path for imports +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from src.benchmark import MEQBench from src.evaluator import MEQBenchEvaluator -from src.config import config + +# Suppress some common warnings for cleaner output +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) # Set up logging -logger = logging.getLogger('meq_bench.examples') +logger = logging.getLogger("meq_bench.examples") -def generate_with_huggingface( - prompt: str, - model_name: str = "mistralai/Mistral-7B-Instruct-v0.2" -) -> str: +def generate_with_huggingface(prompt: str, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2") -> str: """Generate text using a Hugging Face model. - + This function loads a pretrained language model from the Hugging Face Hub and generates a response to the given prompt. It's designed to work with instruction-tuned models that can follow the MEQ-Bench prompt format. - + Args: prompt: The input prompt containing medical content and audience instructions. model_name: Name of the Hugging Face model to use. Popular options include: @@ -40,24 +36,24 @@ def generate_with_huggingface( - "microsoft/DialoGPT-medium" (smaller, faster) - "meta-llama/Llama-2-7b-chat-hf" (requires approval) - "google/flan-t5-large" (encoder-decoder architecture) - + Returns: Generated text response as a string. - + Note: This function requires the transformers library to be installed: pip install transformers torch - + For larger models, ensure you have sufficient GPU memory or use CPU inference (which will be slower). The function automatically detects available devices. - + Example: ```python # Generate explanation using Mistral-7B prompt = "Medical Information: Diabetes is high blood sugar..." response = generate_with_huggingface(prompt, "mistralai/Mistral-7B-Instruct-v0.2") - + # Use a smaller model for faster inference response = generate_with_huggingface(prompt, "microsoft/DialoGPT-medium") ``` @@ -66,48 +62,45 @@ def generate_with_huggingface( # Import transformers here to make it optional from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from transformers import logging as hf_logging - + # Reduce transformers logging verbosity hf_logging.set_verbosity_error() - + except ImportError as e: logger.error("transformers library not installed. Install with: pip install transformers torch") raise ImportError( "transformers library is required for Hugging Face model integration. " "Install with: pip install transformers torch" ) from e - + logger.info(f"Loading Hugging Face model: {model_name}") - + try: # Determine device (GPU if available, otherwise CPU) device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device: {device}") - + # Load tokenizer logger.info("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(model_name) - + # Set pad token if not present (needed for some models) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - + # Load model with appropriate settings logger.info("Loading model...") model_kwargs = { "torch_dtype": torch.float16 if device == "cuda" else torch.float32, "low_cpu_mem_usage": True, } - + # For CPU inference or limited GPU memory, use smaller precision if device == "cpu": model_kwargs["torch_dtype"] = torch.float32 - - model = AutoModelForCausalLM.from_pretrained( - model_name, - **model_kwargs - ).to(device) - + + model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device) + # Create text generation pipeline generator = pipeline( "text-generation", @@ -115,20 +108,20 @@ def generate_with_huggingface( tokenizer=tokenizer, device=0 if device == "cuda" else -1, # 0 for first GPU, -1 for CPU return_full_text=False, # Only return generated text, not input prompt - pad_token_id=tokenizer.eos_token_id + pad_token_id=tokenizer.eos_token_id, ) - + # Generation parameters - adjust these for different models/quality trade-offs generation_params = { "max_new_tokens": 800, # Maximum tokens to generate - "temperature": 0.7, # Controls randomness (0.1 = deterministic, 1.0 = creative) - "top_p": 0.9, # Nucleus sampling parameter - "do_sample": True, # Enable sampling for more diverse outputs + "temperature": 0.7, # Controls randomness (0.1 = deterministic, 1.0 = creative) + "top_p": 0.9, # Nucleus sampling parameter + "do_sample": True, # Enable sampling for more diverse outputs "num_return_sequences": 1, "pad_token_id": tokenizer.eos_token_id, "eos_token_id": tokenizer.eos_token_id, } - + # For instruction-tuned models like Mistral, format the prompt appropriately if "instruct" in model_name.lower() or "chat" in model_name.lower(): # Use instruction format for better results @@ -136,41 +129,41 @@ def generate_with_huggingface( else: # Use prompt as-is for base models formatted_prompt = prompt - + logger.info("Generating response...") - + # Generate response result = generator(formatted_prompt, **generation_params) - + # Extract generated text if isinstance(result, list) and len(result) > 0: - generated_text = result[0].get('generated_text', '') + generated_text = result[0].get("generated_text", "") else: generated_text = str(result) - + # Clean up the response generated_text = generated_text.strip() - + # Remove potential instruction formatting artifacts - if generated_text.startswith('[/INST]'): + if generated_text.startswith("[/INST]"): generated_text = generated_text[7:].strip() - + logger.info(f"Generated {len(generated_text)} characters") - + return generated_text - + except Exception as e: logger.error(f"Error during model generation: {e}") - + # Fallback to dummy response to keep the example working logger.warning("Falling back to dummy response due to model loading error") return dummy_model_function(prompt) - + finally: # Clean up GPU memory if used - if 'model' in locals(): + if "model" in locals(): del model - if 'tokenizer' in locals(): + if "tokenizer" in locals(): del tokenizer if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -182,87 +175,110 @@ def dummy_model_function(prompt: str) -> str: In practice, this would call your actual LLM or use the Hugging Face function above. """ return """ - For a Physician: The patient presents with essential hypertension, likely multifactorial etiology including genetic predisposition and lifestyle factors. Recommend ACE inhibitor initiation with monitoring of renal function and electrolytes. Consider cardiovascular risk stratification. + For a Physician: The patient presents with essential hypertension, likely multifactorial etiology including + genetic predisposition and lifestyle factors. Recommend ACE inhibitor initiation with monitoring of renal + function and electrolytes. Consider cardiovascular risk stratification. - For a Nurse: Monitor blood pressure readings twice daily, document trends. Educate patient on medication compliance, dietary sodium restriction, and importance of regular follow-up. Watch for signs of medication side effects. + For a Nurse: Monitor blood pressure readings twice daily, document trends. Educate patient on medication + compliance, dietary sodium restriction, and importance of regular follow-up. Watch for signs of medication + side effects. - For a Patient: Your blood pressure is higher than normal, which means your heart is working harder than it should. We'll start you on medication to help lower it. It's important to take your medicine every day and eat less salt. + For a Patient: Your blood pressure is higher than normal, which means your heart is working harder than it + should. We'll start you on medication to help lower it. It's important to take your medicine every day and + eat less salt. - For a Caregiver: Help ensure they take their blood pressure medication at the same time each day. Monitor for dizziness or fatigue. Encourage a low-salt diet and regular gentle exercise. Call the doctor if blood pressure readings are very high. + For a Caregiver: Help ensure they take their blood pressure medication at the same time each day. Monitor for + dizziness or fatigue. Encourage a low-salt diet and regular gentle exercise. Call the doctor if blood pressure + readings are very high. """ def main(): """Main example function demonstrating both dummy and Hugging Face models""" + # Check if running in CI environment (non-interactive) + is_ci = os.getenv('CI') or os.getenv('GITHUB_ACTIONS') or not os.isatty(0) + print("MEQ-Bench Basic Usage Example with Hugging Face Integration") + if is_ci: + print("Running in CI mode (non-interactive)") print("=" * 60) - + # Initialize benchmark bench = MEQBench() - + # Create sample dataset print("Creating sample dataset...") sample_items = bench.create_sample_dataset() - + # Add items to benchmark for item in sample_items: bench.add_benchmark_item(item) - + # Get benchmark statistics stats = bench.get_benchmark_stats() - print(f"\nBenchmark Statistics:") + print("\nBenchmark Statistics:") print(f"Total items: {stats['total_items']}") print(f"Complexity distribution: {stats['complexity_distribution']}") print(f"Source distribution: {stats['source_distribution']}") - + # Test single explanation generation print("\nGenerating explanations for sample content...") - medical_content = "Diabetes is a condition where blood sugar levels are too high. It requires careful management through diet, exercise, and sometimes medication." - - # Ask user which model to use + medical_content = "Diabetes is a condition where blood sugar levels are too high." + + # Ask user which model to use (or use default in CI) print("\nChoose model for explanation generation:") print("1. Dummy model (fast, for testing)") print("2. Hugging Face model (requires transformers library)") - - choice = input("Enter choice (1 or 2, default=1): ").strip() - + + if is_ci: + choice = "1" # Use dummy model in CI + print("Using dummy model (CI mode)") + else: + choice = input("Enter choice (1 or 2, default=1): ").strip() + if choice == "2": print("\nUsing Hugging Face model...") print("Note: This requires 'transformers' and 'torch' libraries:") print("pip install transformers torch") - + # Option to specify model name - model_choice = input("\nChoose model (or press Enter for default):\n" - "1. mistralai/Mistral-7B-Instruct-v0.2 (default, ~7B params)\n" - "2. microsoft/DialoGPT-medium (smaller, faster)\n" - "3. Custom model name\n" - "Choice: ").strip() - - if model_choice == "2": - model_name = "microsoft/DialoGPT-medium" - elif model_choice == "3": - model_name = input("Enter model name: ").strip() + if is_ci: + model_name = "mistralai/Mistral-7B-Instruct-v0.2" # Use default in CI + print(f"Using default model: {model_name} (CI mode)") else: - model_name = "mistralai/Mistral-7B-Instruct-v0.2" - + model_choice = input( + "\nChoose model (or press Enter for default):\n" + "1. mistralai/Mistral-7B-Instruct-v0.2 (default, ~7B params)\n" + "2. microsoft/DialoGPT-medium (smaller, faster)\n" + "3. Custom model name\n" + "Choice: " + ).strip() + + if model_choice == "2": + model_name = "microsoft/DialoGPT-medium" + elif model_choice == "3": + model_name = input("Enter model name: ").strip() + else: + model_name = "mistralai/Mistral-7B-Instruct-v0.2" + print(f"\nLoading model: {model_name}") print("This may take a few minutes on first run...") - + # Create model function with specified model def hf_model_function(prompt: str) -> str: return generate_with_huggingface(prompt, model_name) - + model_function = hf_model_function model_type = f"Hugging Face ({model_name})" - + else: print("\nUsing dummy model...") model_function = dummy_model_function model_type = "Dummy model" - + print(f"\nGenerating explanations with {model_type}...") explanations = bench.generate_explanations(medical_content, model_function) - + print("\nGenerated Explanations:") for audience, explanation in explanations.items(): print(f"\n{audience.upper()}:") @@ -272,13 +288,13 @@ def hf_model_function(prompt: str) -> str: print(explanation[:max_length] + "...") else: print(explanation) - + # Evaluate explanations print("\nEvaluating explanations...") try: evaluator = MEQBenchEvaluator() results = evaluator.evaluate_all_audiences(medical_content, explanations) - + print("\nEvaluation Results:") for audience, score in results.items(): print(f"\n{audience.upper()}:") @@ -291,31 +307,35 @@ def hf_model_function(prompt: str) -> str: except Exception as e: print(f"Evaluation failed: {e}") print("This might be due to missing dependencies or configuration.") - + # Option to run full benchmark evaluation - run_full = input("\nRun full benchmark evaluation? (y/N): ").strip().lower() - - if run_full == 'y': + if is_ci: + run_full = "n" # Skip full evaluation in CI + print("\nSkipping full benchmark evaluation (CI mode)") + else: + run_full = input("\nRun full benchmark evaluation? (y/N): ").strip().lower() + + if run_full == "y": print("\nRunning full benchmark evaluation...") print("Note: This may take several minutes with real models...") - + try: full_results = bench.evaluate_model(model_function, max_items=2) - + print("\nFull Benchmark Results Summary:") - summary = full_results['summary'] + summary = full_results["summary"] for key, value in summary.items(): if isinstance(value, (int, float)): print(f"{key}: {value:.3f}") - + # Save results output_path = f"sample_results_{model_type.replace(' ', '_').replace('(', '').replace(')', '')}.json" bench.save_results(full_results, output_path) print(f"\nResults saved to: {output_path}") - + except Exception as e: print(f"Full evaluation failed: {e}") - + print("\n" + "=" * 60) print("Example completed!") print("\nTo use Hugging Face models in your own code:") @@ -328,4 +348,4 @@ def hf_model_function(prompt: str) -> str: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/__init__.py b/src/__init__.py index 4a6449b..927d332 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,8 +1,8 @@ """MEQ-Bench: A Resource-Efficient Benchmark for Evaluating Medical LLM Explanation Quality. -MEQ-Bench is the first benchmark specifically designed to assess an LLM's ability to generate -audience-adaptive medical explanations for four key stakeholders: physicians, nurses, patients, -and caregivers. This package provides a comprehensive evaluation framework that combines +MEQ-Bench is the first benchmark specifically designed to assess an LLM's ability to generate +audience-adaptive medical explanations for four key stakeholders: physicians, nurses, patients, +and caregivers. This package provides a comprehensive evaluation framework that combines automated metrics with LLM-as-a-judge evaluation. Key Features: @@ -14,11 +14,11 @@ Typical usage example: ```python from meq_bench import MEQBench, MEQBenchEvaluator - + # Initialize benchmark bench = MEQBench() evaluator = MEQBenchEvaluator() - + # Generate and evaluate explanations explanations = bench.generate_explanations(medical_content, model_func) scores = evaluator.evaluate_all_audiences(medical_content, explanations) @@ -31,17 +31,11 @@ # Initialize configuration and logging from .config import config -config.setup_logging() - from .benchmark import MEQBench from .evaluator import MEQBenchEvaluator from .prompt_templates import AudienceAdaptivePrompt from .strategies import StrategyFactory -__all__ = [ - "MEQBench", - "MEQBenchEvaluator", - "AudienceAdaptivePrompt", - "StrategyFactory", - "config" -] \ No newline at end of file +config.setup_logging() + +__all__ = ["MEQBench", "MEQBenchEvaluator", "AudienceAdaptivePrompt", "StrategyFactory", "config"] diff --git a/src/benchmark.py b/src/benchmark.py index b6f0da4..49d9a1b 100644 --- a/src/benchmark.py +++ b/src/benchmark.py @@ -1,7 +1,7 @@ """Main MEQ-Bench benchmark implementation. This module contains the core benchmark classes and functionality for MEQ-Bench. -It provides tools for loading medical datasets, generating audience-adaptive +It provides tools for loading medical datasets, generating audience-adaptive explanations, and running comprehensive evaluations. Key classes: @@ -13,6 +13,7 @@ import os import logging from typing import Dict, List, Any, Optional, Union, Callable + try: from typing_extensions import TypedDict except ImportError: @@ -22,14 +23,15 @@ from .config import config from .prompt_templates import AudienceAdaptivePrompt -from .evaluator import MEQBenchEvaluator, EvaluationScore +from .evaluator import MEQBenchEvaluator -logger = logging.getLogger('meq_bench.benchmark') +logger = logging.getLogger("meq_bench.benchmark") # TypedDict definitions for structured data class EvaluationResultDict(TypedDict): """Type definition for evaluation results.""" + model_name: str total_items: int audience_scores: Dict[str, List[float]] @@ -40,6 +42,7 @@ class EvaluationResultDict(TypedDict): class ItemResultDict(TypedDict): """Type definition for individual item results.""" + item_id: str complexity_level: str source_dataset: str @@ -49,6 +52,7 @@ class ItemResultDict(TypedDict): class BenchmarkStatsDict(TypedDict): """Type definition for benchmark statistics.""" + total_items: int complexity_distribution: Dict[str, int] source_distribution: Dict[str, int] @@ -57,11 +61,11 @@ class BenchmarkStatsDict(TypedDict): @dataclass class MEQBenchItem: """Represents a single benchmark item for evaluation. - + A benchmark item contains medical content to be explained, along with metadata about its complexity level and source dataset. Optionally includes reference explanations for different audiences. - + Attributes: id: Unique identifier for the benchmark item. medical_content: The medical information to be adapted for different audiences. @@ -70,6 +74,7 @@ class MEQBenchItem: reference_explanations: Optional reference explanations for each audience, mapping audience names to explanation text. """ + id: str medical_content: str complexity_level: str # "basic", "intermediate", "advanced" @@ -79,40 +84,40 @@ class MEQBenchItem: class MEQBench: """Main benchmark class for MEQ-Bench evaluation. - + This class provides the core functionality for running MEQ-Bench evaluations, including loading benchmark data, generating audience-adaptive explanations, and evaluating model performance across different audiences and complexity levels. - + The class manages benchmark items, interfaces with evaluation components, and provides comprehensive evaluation results with detailed statistics. - + Attributes: data_path: Path to the benchmark data directory. evaluator: MEQBenchEvaluator instance for scoring explanations. prompt_template: AudienceAdaptivePrompt instance for generating prompts. benchmark_items: List of loaded MEQBenchItem objects. - + Example: ```python # Initialize benchmark bench = MEQBench(data_path="/path/to/data") - + # Generate explanations for a model explanations = bench.generate_explanations(medical_content, model_func) - + # Run full evaluation results = bench.evaluate_model(model_func, max_items=100) ``` """ - + def __init__(self, data_path: Optional[str] = None) -> None: """Initialize MEQ-Bench instance. - + Sets up the benchmark with the specified data directory and initializes the evaluator and prompt template components. Automatically loads benchmark data if the data directory exists. - + Args: data_path: Path to benchmark data directory. If None, uses default 'data' directory relative to the package root. @@ -126,16 +131,16 @@ def __init__(self, data_path: Optional[str] = None) -> None: logger.info("Some evaluation features may be limited due to missing dependencies or configuration") self.prompt_template: AudienceAdaptivePrompt = AudienceAdaptivePrompt() self.benchmark_items: List[MEQBenchItem] = [] - + # Load benchmark data if available self._load_benchmark_data() - + def _resolve_data_path(self, data_path: Optional[str] = None) -> Path: """Resolve data directory path with fallback options. - + Args: data_path: Optional custom data path - + Returns: Resolved Path object for data directory """ @@ -150,35 +155,38 @@ def _resolve_data_path(self, data_path: Optional[str] = None) -> Path: # Current working directory Path.cwd() / "data", # Environment variable if set - Path(os.environ.get('MEQ_BENCH_DATA_PATH', '')) if os.environ.get('MEQ_BENCH_DATA_PATH') else None, + Path(os.environ.get("MEQ_BENCH_DATA_PATH", "")) if os.environ.get("MEQ_BENCH_DATA_PATH") else None, # Config-based path - Path(config.get_data_path()) if hasattr(config, 'get_data_path') else None + Path(config.get_data_path()) if hasattr(config, "get_data_path") else None, ] - + # Find first existing path or use first option as default resolved_path = None for path in possible_paths: if path and path.exists(): resolved_path = path.resolve() break - + if not resolved_path: # Default to package relative path resolved_path = (Path(__file__).parent.parent / "data").resolve() - + # Ensure directory exists - resolved_path.mkdir(parents=True, exist_ok=True) - logger.info(f"Using data directory: {resolved_path}") - return resolved_path - + if resolved_path is not None: + resolved_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Using data directory: {resolved_path}") + return resolved_path + else: + raise ValueError("Could not resolve data directory path") + def _load_benchmark_data(self) -> None: """Load benchmark data from JSON files with error handling. - + Loads benchmark items from benchmark_items.json in the data directory. Each item is converted to a MEQBenchItem object and added to the benchmark_items list. Includes comprehensive error handling for missing files, invalid JSON, and malformed data. - + The JSON file should contain a list of dictionaries with the following keys: - id: Unique identifier for the item - medical_content: Medical information to be explained @@ -188,99 +196,99 @@ def _load_benchmark_data(self) -> None: """ try: benchmark_file = self.data_path / "benchmark_items.json" - + if not benchmark_file.exists(): logger.warning(f"Benchmark data file not found: {benchmark_file}") logger.info("Use create_sample_dataset() to generate sample data") return - - with open(benchmark_file, 'r', encoding='utf-8') as f: + + with open(benchmark_file, "r", encoding="utf-8") as f: data = json.load(f) - + if not isinstance(data, list): raise ValueError("Benchmark data must be a list of items") - + for i, item_data in enumerate(data): try: # Validate required fields - required_fields = ['id', 'medical_content', 'complexity_level', 'source_dataset'] + required_fields = ["id", "medical_content", "complexity_level", "source_dataset"] for field in required_fields: if field not in item_data: raise KeyError(f"Missing required field '{field}' in item {i}") - + item = MEQBenchItem( - id=item_data['id'], - medical_content=item_data['medical_content'], - complexity_level=item_data['complexity_level'], - source_dataset=item_data['source_dataset'], - reference_explanations=item_data.get('reference_explanations') + id=item_data["id"], + medical_content=item_data["medical_content"], + complexity_level=item_data["complexity_level"], + source_dataset=item_data["source_dataset"], + reference_explanations=item_data.get("reference_explanations"), ) self.benchmark_items.append(item) except (KeyError, ValueError) as e: logger.error(f"Error loading benchmark item {i}: {e}") continue - + logger.info(f"Loaded {len(self.benchmark_items)} benchmark items from {benchmark_file}") - + except FileNotFoundError: logger.warning(f"Benchmark data directory not found: {self.data_path}") except json.JSONDecodeError as e: logger.error(f"Invalid JSON in benchmark data file: {e}") except Exception as e: logger.error(f"Unexpected error loading benchmark data: {e}") - + def add_benchmark_item(self, item: MEQBenchItem) -> None: """Add a new benchmark item to the evaluation set. - + Validates the item data and adds it to the benchmark items list. Includes checks for data integrity and duplicate IDs. - + Args: item: MEQBenchItem object to add to the benchmark. - + Raises: TypeError: If item is not an instance of MEQBenchItem. ValueError: If item data is invalid or ID already exists. """ if not isinstance(item, MEQBenchItem): raise TypeError("item must be an instance of MEQBenchItem") - + # Validate item data if not item.id or not isinstance(item.id, str): raise ValueError("item.id must be a non-empty string") - + if not item.medical_content or not isinstance(item.medical_content, str): raise ValueError("item.medical_content must be a non-empty string") - - if item.complexity_level not in ['basic', 'intermediate', 'advanced']: + + if item.complexity_level not in ["basic", "intermediate", "advanced"]: raise ValueError("item.complexity_level must be 'basic', 'intermediate', or 'advanced'") - + # Check for duplicate IDs if any(existing_item.id == item.id for existing_item in self.benchmark_items): raise ValueError(f"Item with ID '{item.id}' already exists") - + self.benchmark_items.append(item) logger.debug(f"Added benchmark item: {item.id}") - + def generate_explanations(self, medical_content: str, model_func: Callable[[str], str]) -> Dict[str, str]: """Generate audience-adaptive explanations using a model. - + Uses the configured prompt template to generate explanations tailored for different healthcare audiences (physicians, nurses, patients, caregivers). - + Args: medical_content: Medical information to be adapted for different audiences. model_func: Callable that takes a prompt string and returns the model's response as a string. - + Returns: Dictionary mapping audience names to their respective explanations. Keys are audience names (e.g., 'physician', 'nurse', 'patient', 'caregiver'). - + Raises: ValueError: If medical_content is empty or invalid TypeError: If model_func is not callable - + Example: ```python def my_model(prompt): @@ -288,9 +296,9 @@ def my_model(prompt): model="gpt-4", messages=[{"role": "user", "content": prompt}] ).choices[0].message.content - + explanations = bench.generate_explanations( - "Hypertension is high blood pressure", + "Hypertension is high blood pressure", my_model ) ``` @@ -298,67 +306,67 @@ def my_model(prompt): # Input validation if not medical_content or not isinstance(medical_content, str): raise ValueError("medical_content must be a non-empty string") - + if medical_content.strip() == "": raise ValueError("medical_content cannot be empty or contain only whitespace") - + if len(medical_content.strip()) < 10: raise ValueError("medical_content must be at least 10 characters long") - + if not callable(model_func): raise TypeError("model_func must be a callable function") - + # Additional content validation if len(medical_content) > 10000: # Reasonable upper limit logger.warning(f"Medical content is very long ({len(medical_content)} chars). Consider splitting.") - + # Sanitize content - remove excessive whitespace - sanitized_content = ' '.join(medical_content.split()) - + sanitized_content = " ".join(medical_content.split()) + try: prompt = self.prompt_template.format_prompt(sanitized_content) logger.debug(f"Generated prompt with {len(prompt)} characters") - + response = model_func(prompt) - + # Validate model response if not response or not isinstance(response, str): raise ValueError("Model function returned empty or invalid response") - + if response.strip() == "": raise ValueError("Model function returned empty response") - + explanations = self.prompt_template.parse_response(response) - + # Validate parsed explanations if not explanations: raise ValueError("Failed to parse any explanations from model response") - + # Log successful generation logger.info(f"Generated explanations for {len(explanations)} audiences") - + return explanations - + except Exception as e: logger.error(f"Error generating explanations: {e}") if isinstance(e, (ValueError, TypeError)): raise else: raise RuntimeError(f"Unexpected error during explanation generation: {e}") from e - + def evaluate_model(self, model_func: Callable[[str], str], max_items: Optional[int] = None) -> EvaluationResultDict: """Evaluate a model on the full benchmark. - + Runs comprehensive evaluation of a model's performance across all benchmark items and audiences. Generates explanations for each item and evaluates them using the configured evaluator. - + Args: model_func: Callable that takes a prompt string and returns the model's response as a string. max_items: Maximum number of benchmark items to evaluate. If None, evaluates all available items. Useful for testing with smaller subsets. - + Returns: Dictionary containing comprehensive evaluation results with the following keys: - model_name: Name of the evaluated model @@ -367,7 +375,7 @@ def evaluate_model(self, model_func: Callable[[str], str], max_items: Optional[i - complexity_scores: Scores grouped by complexity level - detailed_results: Per-item detailed evaluation results - summary: Summary statistics including means, standard deviations, etc. - + Example: ```python results = bench.evaluate_model(my_model_func, max_items=50) @@ -375,77 +383,67 @@ def evaluate_model(self, model_func: Callable[[str], str], max_items: Optional[i ``` """ results: EvaluationResultDict = { - 'model_name': getattr(model_func, '__name__', 'unknown'), - 'total_items': len(self.benchmark_items[:max_items]) if max_items else len(self.benchmark_items), - 'audience_scores': { - 'physician': [], - 'nurse': [], - 'patient': [], - 'caregiver': [] - }, - 'complexity_scores': { - 'basic': [], - 'intermediate': [], - 'advanced': [] - }, - 'detailed_results': [] + "model_name": getattr(model_func, "__name__", "unknown"), + "total_items": len(self.benchmark_items[:max_items]) if max_items else len(self.benchmark_items), + "audience_scores": {"physician": [], "nurse": [], "patient": [], "caregiver": []}, + "complexity_scores": {"basic": [], "intermediate": [], "advanced": []}, + "detailed_results": [], + "summary": {}, } - + items_to_evaluate = self.benchmark_items[:max_items] if max_items else self.benchmark_items - + for item in items_to_evaluate: # Generate explanations explanations = self.generate_explanations(item.medical_content, model_func) - + # Evaluate explanations - evaluation_results = self.evaluator.evaluate_all_audiences( - item.medical_content, explanations - ) - + evaluation_results = self.evaluator.evaluate_all_audiences(item.medical_content, explanations) + # Store detailed results item_result = { - 'item_id': item.id, - 'complexity_level': item.complexity_level, - 'source_dataset': item.source_dataset, - 'explanations': explanations, - 'scores': { + "item_id": item.id, + "complexity_level": item.complexity_level, + "source_dataset": item.source_dataset, + "explanations": explanations, + "scores": { audience: { - 'readability': score.readability, - 'terminology': score.terminology, - 'safety': score.safety, - 'coverage': score.coverage, - 'quality': score.quality, - 'overall': score.overall + "readability": score.readability, + "terminology": score.terminology, + "safety": score.safety, + "coverage": score.coverage, + "quality": score.quality, + "overall": score.overall, } for audience, score in evaluation_results.items() - } + }, } - results['detailed_results'].append(item_result) - + results["detailed_results"].append(item_result) + # Aggregate scores by audience for audience, score in evaluation_results.items(): - if audience in results['audience_scores']: - results['audience_scores'][audience].append(score.overall) - + if audience in results["audience_scores"]: + results["audience_scores"][audience].append(score.overall) + # Aggregate scores by complexity avg_overall = sum(score.overall for score in evaluation_results.values()) / len(evaluation_results) - results['complexity_scores'][item.complexity_level].append(avg_overall) - + results["complexity_scores"][item.complexity_level].append(avg_overall) + # Calculate summary statistics - results['summary'] = self._calculate_summary_stats(results) - + results["summary"] = self._calculate_summary_stats(dict(results)) + return results - + def _calculate_summary_stats(self, results: Dict[str, Any]) -> Dict[str, Any]: """Calculate summary statistics from evaluation results. - + Computes descriptive statistics for audience-level and complexity-level scores, including means, standard deviations, minimums, and maximums. - + Args: results: Dictionary containing evaluation results with audience_scores and complexity_scores keys. - + Returns: Dictionary containing summary statistics with keys like: - {audience}_mean: Mean score for each audience @@ -458,46 +456,48 @@ def _calculate_summary_stats(self, results: Dict[str, Any]) -> Dict[str, Any]: - overall_std: Overall standard deviation """ summary = {} - + # Audience-level statistics - for audience, scores in results['audience_scores'].items(): + for audience, scores in results["audience_scores"].items(): if scores: - summary[f'{audience}_mean'] = sum(scores) / len(scores) - summary[f'{audience}_std'] = (sum((x - summary[f'{audience}_mean']) ** 2 for x in scores) / len(scores)) ** 0.5 - summary[f'{audience}_min'] = min(scores) - summary[f'{audience}_max'] = max(scores) - + summary[f"{audience}_mean"] = sum(scores) / len(scores) + summary[f"{audience}_std"] = (sum((x - summary[f"{audience}_mean"]) ** 2 for x in scores) / len(scores)) ** 0.5 + summary[f"{audience}_min"] = min(scores) + summary[f"{audience}_max"] = max(scores) + # Complexity-level statistics - for complexity, scores in results['complexity_scores'].items(): + for complexity, scores in results["complexity_scores"].items(): if scores: - summary[f'{complexity}_mean'] = sum(scores) / len(scores) - summary[f'{complexity}_std'] = (sum((x - summary[f'{complexity}_mean']) ** 2 for x in scores) / len(scores)) ** 0.5 - + summary[f"{complexity}_mean"] = sum(scores) / len(scores) + summary[f"{complexity}_std"] = ( + sum((x - summary[f"{complexity}_mean"]) ** 2 for x in scores) / len(scores) + ) ** 0.5 + # Overall statistics all_scores = [] - for audience_scores in results['audience_scores'].values(): + for audience_scores in results["audience_scores"].values(): all_scores.extend(audience_scores) - + if all_scores: - summary['overall_mean'] = sum(all_scores) / len(all_scores) - summary['overall_std'] = (sum((x - summary['overall_mean']) ** 2 for x in all_scores) / len(all_scores)) ** 0.5 - + summary["overall_mean"] = sum(all_scores) / len(all_scores) + summary["overall_std"] = (sum((x - summary["overall_mean"]) ** 2 for x in all_scores) / len(all_scores)) ** 0.5 + return summary - + def save_results(self, results: Dict[str, Any], output_path: str) -> None: """Save evaluation results to JSON file. - + Serializes the evaluation results dictionary to a JSON file with proper formatting for readability. Includes proper path handling and error handling. - + Args: results: Dictionary containing evaluation results from evaluate_model(). output_path: File path where results should be saved. - + Raises: Exception: If file writing fails. - + Example: ```python results = bench.evaluate_model(model_func) @@ -505,37 +505,37 @@ def save_results(self, results: Dict[str, Any], output_path: str) -> None: ``` """ output_file = Path(output_path).resolve() - + # Ensure output directory exists output_file.parent.mkdir(parents=True, exist_ok=True) - + try: - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: json.dump(results, f, indent=2, ensure_ascii=False) logger.info(f"Results saved to: {output_file}") except Exception as e: logger.error(f"Failed to save results to {output_file}: {e}") raise - + def create_sample_dataset(self, output_path: Optional[str] = None) -> List[MEQBenchItem]: """Create a sample dataset for testing. - + Generates a small set of sample medical content items with different complexity levels for testing and demonstration purposes. - + Args: output_path: Optional path to save the sample dataset as JSON. If provided, saves the dataset to this file. - + Returns: List of MEQBenchItem objects containing sample medical content with basic, intermediate, and advanced complexity levels. - + Example: ```python # Create sample data sample_items = bench.create_sample_dataset() - + # Create and save sample data sample_items = bench.create_sample_dataset("data/sample_dataset.json") ``` @@ -543,63 +543,78 @@ def create_sample_dataset(self, output_path: Optional[str] = None) -> List[MEQBe sample_items = [ MEQBenchItem( id="sample_001", - medical_content="Hypertension, also known as high blood pressure, is a condition where the force of blood against artery walls is consistently too high. It can lead to heart disease, stroke, and kidney problems if left untreated. Treatment typically involves lifestyle changes and medication.", + medical_content=( + "Hypertension, also known as high blood pressure, is a condition where the force of blood " + "against artery walls is consistently too high. It can lead to heart disease, stroke, and " + "kidney problems if left untreated. Treatment typically involves lifestyle changes and medication." + ), complexity_level="basic", - source_dataset="sample" + source_dataset="sample", ), MEQBenchItem( id="sample_002", - medical_content="Myocardial infarction occurs when blood flow to a part of the heart muscle is blocked, usually by a blood clot in a coronary artery. This results in damage or death of heart muscle cells. Immediate treatment with medications to dissolve clots or procedures to restore blood flow is critical.", + medical_content=( + "Myocardial infarction occurs when blood flow to a part of the heart muscle is blocked, usually " + "by a blood clot in a coronary artery. This results in damage or death of heart muscle cells. " + "Immediate treatment with medications to dissolve clots or procedures to restore blood flow is critical." + ), complexity_level="intermediate", - source_dataset="sample" + source_dataset="sample", ), MEQBenchItem( id="sample_003", - medical_content="Diabetic ketoacidosis (DKA) is a serious complication of diabetes mellitus characterized by hyperglycemia, ketosis, and metabolic acidosis. It typically occurs in type 1 diabetes due to absolute insulin deficiency. Treatment involves IV fluids, insulin therapy, and electrolyte replacement while addressing underlying precipitating factors.", + medical_content=( + "Diabetic ketoacidosis (DKA) is a serious complication of diabetes mellitus characterized by " + "hyperglycemia, ketosis, and metabolic acidosis. It typically occurs in type 1 diabetes due to " + "absolute insulin deficiency. Treatment involves IV fluids, insulin therapy, and electrolyte " + "replacement while addressing underlying precipitating factors." + ), complexity_level="advanced", - source_dataset="sample" - ) + source_dataset="sample", + ), ] - + if output_path: # Save sample dataset with proper path handling output_file = Path(output_path).resolve() output_file.parent.mkdir(parents=True, exist_ok=True) - + data_to_save = [] for item in sample_items: - data_to_save.append({ - 'id': item.id, - 'medical_content': item.medical_content, - 'complexity_level': item.complexity_level, - 'source_dataset': item.source_dataset, - 'reference_explanations': item.reference_explanations - }) - + data_to_save.append( + { + "id": item.id, + "medical_content": item.medical_content, + "complexity_level": item.complexity_level, + "source_dataset": item.source_dataset, + "reference_explanations": item.reference_explanations, + } + ) + try: - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: json.dump(data_to_save, f, indent=2, ensure_ascii=False) logger.info(f"Sample dataset saved to: {output_file}") except Exception as e: logger.error(f"Failed to save sample dataset to {output_file}: {e}") raise - + return sample_items - + def get_benchmark_stats(self) -> Union[BenchmarkStatsDict, Dict[str, Union[int, str]]]: """Get statistics about the benchmark dataset. - + Provides summary statistics about the loaded benchmark items, including total counts and distributions by complexity level and source dataset. - + Returns: Dictionary containing benchmark statistics with keys: - total_items: Total number of benchmark items - complexity_distribution: Count of items by complexity level - source_distribution: Count of items by source dataset - message: Informational message if no items are loaded - + Example: ```python stats = bench.get_benchmark_stats() @@ -608,83 +623,80 @@ def get_benchmark_stats(self) -> Union[BenchmarkStatsDict, Dict[str, Union[int, ``` """ if not self.benchmark_items: - return {'total_items': 0, 'message': 'No benchmark items loaded'} - + return {"total_items": 0, "message": "No benchmark items loaded"} + stats: BenchmarkStatsDict = { - 'total_items': len(self.benchmark_items), - 'complexity_distribution': {}, - 'source_distribution': {} + "total_items": len(self.benchmark_items), + "complexity_distribution": {}, + "source_distribution": {}, } - + for item in self.benchmark_items: # Count by complexity - if item.complexity_level not in stats['complexity_distribution']: - stats['complexity_distribution'][item.complexity_level] = 0 - stats['complexity_distribution'][item.complexity_level] += 1 - + if item.complexity_level not in stats["complexity_distribution"]: + stats["complexity_distribution"][item.complexity_level] = 0 + stats["complexity_distribution"][item.complexity_level] += 1 + # Count by source - if item.source_dataset not in stats['source_distribution']: - stats['source_distribution'][item.source_dataset] = 0 - stats['source_distribution'][item.source_dataset] += 1 - + if item.source_dataset not in stats["source_distribution"]: + stats["source_distribution"][item.source_dataset] = 0 + stats["source_distribution"][item.source_dataset] += 1 + return stats - + def validate_benchmark(self) -> Dict[str, Any]: """Validate the benchmark dataset and return validation report. - + Returns: Dictionary containing validation results and any issues found """ - validation_report = { - 'valid': True, - 'total_items': len(self.benchmark_items), - 'issues': [], - 'warnings': [], - 'statistics': {} + validation_report: Dict[str, Any] = { + "valid": True, + "total_items": len(self.benchmark_items), + "issues": [], + "warnings": [], + "statistics": {}, } - + if not self.benchmark_items: - validation_report['valid'] = False - validation_report['issues'].append("No benchmark items loaded") + validation_report["valid"] = False + validation_report["issues"].append("No benchmark items loaded") return validation_report - + # Check for duplicate IDs ids = [item.id for item in self.benchmark_items] duplicate_ids = set([id for id in ids if ids.count(id) > 1]) if duplicate_ids: - validation_report['valid'] = False - validation_report['issues'].append(f"Duplicate item IDs found: {duplicate_ids}") - + validation_report["valid"] = False + validation_report["issues"].append(f"Duplicate item IDs found: {duplicate_ids}") + # Validate complexity level distribution - complexity_counts = {} + complexity_counts: Dict[str, int] = {} for item in self.benchmark_items: complexity = item.complexity_level complexity_counts[complexity] = complexity_counts.get(complexity, 0) + 1 - - validation_report['statistics']['complexity_distribution'] = complexity_counts - + + validation_report["statistics"]["complexity_distribution"] = complexity_counts + # Check for balanced distribution if len(complexity_counts) < 3: - validation_report['warnings'].append("Not all complexity levels represented") - + validation_report["warnings"].append("Not all complexity levels represented") + # Validate content length content_lengths = [len(item.medical_content) for item in self.benchmark_items] avg_length = sum(content_lengths) / len(content_lengths) - validation_report['statistics']['average_content_length'] = avg_length - + validation_report["statistics"]["average_content_length"] = avg_length + if avg_length < 50: - validation_report['warnings'].append("Average content length is quite short") + validation_report["warnings"].append("Average content length is quite short") elif avg_length > 2000: - validation_report['warnings'].append("Average content length is quite long") - + validation_report["warnings"].append("Average content length is quite long") + # Check for empty or very short content - short_content_items = [ - item.id for item in self.benchmark_items - if len(item.medical_content.strip()) < 20 - ] + short_content_items = [item.id for item in self.benchmark_items if len(item.medical_content.strip()) < 20] if short_content_items: - validation_report['valid'] = False - validation_report['issues'].append(f"Items with very short content: {short_content_items}") - + validation_report["valid"] = False + validation_report["issues"].append(f"Items with very short content: {short_content_items}") + logger.info(f"Benchmark validation completed. Valid: {validation_report['valid']}") - return validation_report \ No newline at end of file + return validation_report diff --git a/src/config.py b/src/config.py index 4fd5984..9205289 100644 --- a/src/config.py +++ b/src/config.py @@ -12,120 +12,185 @@ - Multiple configuration environments """ -import yaml +import yaml # type: ignore import os import logging from pathlib import Path -from typing import Dict, Any, List, Optional, Union, Type +from typing import Dict, Any, List, Optional class ConfigurationError(Exception): """Raised when there's an error in configuration. - + This exception is raised when: - Configuration files are missing or invalid - Required configuration sections are missing - YAML parsing fails - Required environment variables are not set """ + pass class Config: """Singleton configuration manager for MEQ-Bench. - + This class manages all configuration settings for the MEQ-Bench application. - It loads configuration from YAML files and provides methods to access + It loads configuration from YAML files and provides methods to access configuration values using dot notation. - + The class implements the singleton pattern to ensure that configuration is loaded once and shared across the entire application. - + Attributes: _instance: Singleton instance of the Config class. _config: Dictionary containing the loaded configuration data. - + Example: ```python from config import config - + # Get configuration values model_name = config.get('llm_judge.default_model') audiences = config.get_audiences() - + # Set up logging config.setup_logging() ``` """ - - _instance: Optional['Config'] = None + + _instance: Optional["Config"] = None _config: Optional[Dict[str, Any]] = None - - def __new__(cls) -> 'Config': + + def __new__(cls) -> "Config": if cls._instance is None: cls._instance = super(Config, cls).__new__(cls) return cls._instance - + def __init__(self) -> None: if self._config is None: self.load_config() - + def load_config(self, config_path: Optional[str] = None) -> None: """Load configuration from YAML file. - + Loads configuration settings from a YAML file and validates required sections. If no path is provided, looks for config.yaml in the project root directory. - + Args: config_path: Path to configuration file. If None, uses default config.yaml in the project root directory. - + Raises: ConfigurationError: If configuration file is not found or contains invalid YAML. """ if config_path is None: - # Look for config.yaml in the project root - current_dir = Path(__file__).parent.parent - config_path = str(current_dir / "config.yaml") - + # Search for project root by looking for indicators + config_path = self._find_project_config() + try: - with open(config_path, 'r') as f: + with open(config_path, "r") as f: self._config = yaml.safe_load(f) - + # Validate required sections self._validate_config() - + except FileNotFoundError: raise ConfigurationError(f"Configuration file not found: {config_path}") except yaml.YAMLError as e: raise ConfigurationError(f"Error parsing YAML configuration: {e}") - + + def _find_project_config(self) -> str: + """Find the project configuration file by searching upwards from this file. + + Searches upwards from the current file location until it finds a directory + containing project indicators like .git, pyproject.toml, or setup.py. + + Returns: + Path to config.yaml in the project root. + + Raises: + ConfigurationError: If project root cannot be found. + """ + current_path = Path(__file__).resolve() + + # Search upwards from the current file + for parent in [current_path.parent] + list(current_path.parents): + # Check for project root indicators + indicators = [".git", "pyproject.toml", "setup.py", "setup.cfg", "requirements.txt"] + if any((parent / indicator).exists() for indicator in indicators): + config_file = parent / "config.yaml" + if config_file.exists(): + return str(config_file) + # If config.yaml doesn't exist in project root, create a minimal one + return self._create_minimal_config(parent) + + # Fallback: look in the parent directory of this file + fallback_path = Path(__file__).parent.parent / "config.yaml" + if fallback_path.exists(): + return str(fallback_path) + + # Last resort: create minimal config in current directory + return self._create_minimal_config(Path.cwd()) + + def _create_minimal_config(self, project_root: Path) -> str: + """Create a minimal configuration file for testing/CI environments. + + Args: + project_root: Path to the project root directory. + + Returns: + Path to the created configuration file. + """ + config_path = project_root / "config.yaml" + + minimal_config = { + "app": {"name": "MEQ-Bench", "version": "1.0.0", "data_path": "data/", "output_path": "results/"}, + "audiences": ["physician", "nurse", "patient", "caregiver"], + "complexity_levels": ["basic", "intermediate", "advanced"], + "llm_judge": {"default_model": "gpt-3.5-turbo", "max_tokens": 800, "temperature": 0.3}, + "evaluation": {"metrics": ["readability", "terminology", "safety", "coverage", "quality"], "timeout": 30}, + "scoring": {"weights": {"readability": 0.2, "terminology": 0.2, "safety": 0.3, "coverage": 0.15, "quality": 0.15}}, + "logging": { + "version": 1, + "disable_existing_loggers": False, + "formatters": {"standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"}}, + "handlers": {"default": {"level": "INFO", "formatter": "standard", "class": "logging.StreamHandler"}}, + "loggers": {"": {"handlers": ["default"], "level": "INFO", "propagate": False}}, + }, + } + + try: + with open(config_path, "w") as f: + yaml.dump(minimal_config, f, default_flow_style=False) + return str(config_path) + except Exception: + # If we can't write to disk, return a path that will trigger default values + return str(config_path) + def _validate_config(self) -> None: """Validate that required configuration sections exist""" - required_sections = [ - 'app', 'audiences', 'complexity_levels', 'llm_judge', - 'evaluation', 'scoring', 'logging' - ] - + required_sections = ["app", "audiences", "complexity_levels", "llm_judge", "evaluation", "scoring", "logging"] + for section in required_sections: if self._config is None or section not in self._config: raise ConfigurationError(f"Missing required configuration section: {section}") - + def get(self, key: str, default: Optional[Any] = None) -> Any: """Get configuration value using dot notation. - + Retrieves configuration values using a dot-separated key path. For example, 'llm_judge.default_model' would access config['llm_judge']['default_model']. - + Args: - key: Configuration key using dot notation (e.g., 'app.name', + key: Configuration key using dot notation (e.g., 'app.name', 'llm_judge.default_model'). default: Default value to return if key is not found. - + Returns: The configuration value at the specified key path, or the default value if the key is not found. - + Raises: ConfigurationError: If key is not found and no default value is provided. """ @@ -133,10 +198,10 @@ def get(self, key: str, default: Optional[Any] = None) -> Any: if default is not None: return default raise ConfigurationError(f"Configuration not loaded: {key}") - - keys = key.split('.') + + keys = key.split(".") value = self._config - + try: for k in keys: value = value[k] @@ -145,108 +210,100 @@ def get(self, key: str, default: Optional[Any] = None) -> Any: if default is not None: return default raise ConfigurationError(f"Configuration key not found: {key}") - + def get_audiences(self) -> List[str]: """Get list of target audiences""" - return self.get('audiences') - + return self.get("audiences") + def get_complexity_levels(self) -> List[str]: """Get list of complexity levels""" - return self.get('complexity_levels') - + return self.get("complexity_levels") + def get_llm_config(self) -> Dict[str, Any]: """Get LLM judge configuration""" - return self.get('llm_judge') - + return self.get("llm_judge") + def get_evaluation_config(self) -> Dict[str, Any]: """Get evaluation configuration""" - return self.get('evaluation') - + return self.get("evaluation") + def get_scoring_config(self) -> Dict[str, Any]: """Get scoring configuration""" - return self.get('scoring') - + return self.get("scoring") + def get_api_config(self, provider: str) -> Dict[str, Any]: """ Get API configuration for a specific provider - + Args: provider: API provider name (e.g., 'openai', 'anthropic') """ - return self.get(f'api.{provider}') - + return self.get(f"api.{provider}") + def get_data_path(self) -> str: """Get data directory path""" - return self.get('app.data_path', 'data/') - + return self.get("app.data_path", "data/") + def get_output_path(self) -> str: """Get output directory path""" - return self.get('app.output_path', 'results/') - + return self.get("app.output_path", "results/") + def setup_logging(self) -> None: """Set up logging based on configuration""" import logging.config - + # Create logs directory if it doesn't exist logs_dir = Path("logs") logs_dir.mkdir(exist_ok=True) - + # Get logging configuration - logging_config = self.get('logging') - + logging_config = self.get("logging") + try: logging.config.dictConfig(logging_config) except Exception as e: # Fallback to basic logging if config fails - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(levelname)s] %(name)s: %(message)s' - ) + logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") logging.error(f"Failed to configure logging from config: {e}") - + def get_environment_variable(self, key: str, default: Optional[str] = None) -> Optional[str]: """ Get environment variable with optional default - + Args: key: Environment variable name default: Default value if not found - + Returns: Environment variable value or default """ return os.getenv(key, default) - + def get_api_key(self, provider: str) -> str: """ Get API key for provider from environment variables - + Args: provider: Provider name (e.g., 'openai', 'anthropic') - + Returns: API key - + Raises: ConfigurationError: If API key not found """ - env_var_map = { - 'openai': 'OPENAI_API_KEY', - 'anthropic': 'ANTHROPIC_API_KEY' - } - + env_var_map = {"openai": "OPENAI_API_KEY", "anthropic": "ANTHROPIC_API_KEY"} + env_var = env_var_map.get(provider.lower()) if not env_var: raise ConfigurationError(f"Unknown API provider: {provider}") - + api_key = self.get_environment_variable(env_var) if not api_key: - raise ConfigurationError( - f"API key not found. Please set {env_var} environment variable." - ) - + raise ConfigurationError(f"API key not found. Please set {env_var} environment variable.") + return api_key - + def reload(self) -> None: """Reload configuration from file""" self._config = None @@ -254,4 +311,4 @@ def reload(self) -> None: # Global config instance -config = Config() \ No newline at end of file +config = Config() diff --git a/src/data_loaders.py b/src/data_loaders.py index 1e17ae2..4427b64 100644 --- a/src/data_loaders.py +++ b/src/data_loaders.py @@ -24,12 +24,12 @@ Example: ```python from data_loaders import load_medqa_usmle, load_icliniq, load_cochrane_reviews - + # Load different datasets with automatic complexity stratification medqa_items = load_medqa_usmle('path/to/medqa.json', max_items=300) icliniq_items = load_icliniq('path/to/icliniq.json', max_items=400) cochrane_items = load_cochrane_reviews('path/to/cochrane.json', max_items=300) - + # Combine and save as benchmark dataset all_items = medqa_items + icliniq_items + cochrane_items save_benchmark_items(all_items, 'data/benchmark_items.json') @@ -40,54 +40,52 @@ import logging import re from pathlib import Path -from typing import List, Dict, Any, Optional, Union +from typing import List, Dict, Optional, Union try: - import textstat + import textstat # type: ignore except ImportError: textstat = None - + from .benchmark import MEQBenchItem -logger = logging.getLogger('meq_bench.data_loaders') +logger = logging.getLogger("meq_bench.data_loaders") def load_medquad( - data_path: Union[str, Path], - max_items: Optional[int] = None, - complexity_level: str = 'basic' + data_path: Union[str, Path], max_items: Optional[int] = None, complexity_level: str = "basic" ) -> List[MEQBenchItem]: """Load MedQuAD dataset and convert to MEQBenchItem objects. - + The MedQuAD (Medical Question Answering Dataset) contains consumer health questions and answers from various medical sources. This function loads the dataset and converts it to the MEQ-Bench format for evaluation. - + Args: data_path: Path to the MedQuAD JSON file. Can be a string or Path object. max_items: Maximum number of items to load. If None, loads all items. complexity_level: Complexity level to assign to all items. Defaults to 'basic' since MedQuAD primarily contains consumer health questions. - + Returns: List of MEQBenchItem objects converted from MedQuAD data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load all MedQuAD items items = load_medquad('data/medquad.json') - + # Load only first 100 items items = load_medquad('data/medquad.json', max_items=100) - + # Load with different complexity level items = load_medquad('data/medquad.json', complexity_level='intermediate') - + # Add to benchmark bench = MEQBench() for item in items: @@ -95,119 +93,114 @@ def load_medquad( ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"MedQuAD file not found: {data_file}") - + logger.info(f"Loading MedQuAD dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in MedQuAD file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in MedQuAD file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("MedQuAD data must be a list of items") - + # Validate complexity level - if complexity_level not in ['basic', 'intermediate', 'advanced']: + if complexity_level not in ["basic", "intermediate", "advanced"]: logger.warning(f"Invalid complexity level '{complexity_level}', using 'basic'") - complexity_level = 'basic' - + complexity_level = "basic" + # Convert to MEQBenchItem objects items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} MedQuAD items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid MedQuAD item {i}: not a dictionary") continue - + # Extract required fields - MedQuAD typically has 'question' and 'answer' - question = item_data.get('question', '') - answer = item_data.get('answer', '') - item_id = item_data.get('id', f"medquad_{i}") - + question = item_data.get("question", "") + answer = item_data.get("answer", "") + item_id = item_data.get("id", f"medquad_{i}") + if not question.strip() or not answer.strip(): logger.warning(f"Skipping MedQuAD item {i}: empty question or answer") continue - + # Combine question and answer to create medical content medical_content = f"Question: {question.strip()}\\n\\nAnswer: {answer.strip()}" - + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='MedQuAD', - reference_explanations=None # No reference explanations in MedQuAD + source_dataset="MedQuAD", + reference_explanations=None, # No reference explanations in MedQuAD ) - + # Basic validation _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing MedQuAD item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from MedQuAD") - + if len(items) == 0: logger.warning("No valid items were loaded from MedQuAD dataset") else: # Log some statistics avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"MedQuAD dataset statistics:") + logger.info("MedQuAD dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity level: {complexity_level}") - + return items def load_healthsearchqa( - data_path: Union[str, Path], - max_items: Optional[int] = None, - complexity_level: str = 'intermediate' + data_path: Union[str, Path], max_items: Optional[int] = None, complexity_level: str = "intermediate" ) -> List[MEQBenchItem]: """Load HealthSearchQA dataset and convert to MEQBenchItem objects. - + The HealthSearchQA dataset contains health-related search queries and answers from various health websites and search engines. This loader converts the dataset into MEQBenchItem objects for use in the benchmark. - + Args: data_path: Path to the HealthSearchQA JSON file. max_items: Maximum number of items to load. If None, loads all items. complexity_level: Complexity level to assign to all items. Defaults to 'intermediate' since HealthSearchQA contains more varied complexity levels. - + Returns: List of MEQBenchItem objects converted from HealthSearchQA data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load HealthSearchQA items items = load_healthsearchqa('data/healthsearchqa.json') - + # Load with custom complexity level items = load_healthsearchqa('data/healthsearchqa.json', complexity_level='advanced') - + # Add to benchmark bench = MEQBench() for item in items: @@ -215,82 +208,79 @@ def load_healthsearchqa( ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"HealthSearchQA file not found: {data_file}") - + logger.info(f"Loading HealthSearchQA dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in HealthSearchQA file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in HealthSearchQA file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("HealthSearchQA data must be a list of items") - + # Validate complexity level - if complexity_level not in ['basic', 'intermediate', 'advanced']: + if complexity_level not in ["basic", "intermediate", "advanced"]: logger.warning(f"Invalid complexity level '{complexity_level}', using 'intermediate'") - complexity_level = 'intermediate' - + complexity_level = "intermediate" + items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} HealthSearchQA items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid HealthSearchQA item {i}: not a dictionary") continue - + # HealthSearchQA might have different field names - query = item_data.get('query', item_data.get('question', '')) - answer = item_data.get('answer', item_data.get('response', '')) - item_id = item_data.get('id', f"healthsearch_{i}") - + query = item_data.get("query", item_data.get("question", "")) + answer = item_data.get("answer", item_data.get("response", "")) + item_id = item_data.get("id", f"healthsearch_{i}") + if not query.strip() or not answer.strip(): logger.warning(f"Skipping HealthSearchQA item {i}: empty query or answer") continue - + # Create medical content medical_content = f"Search Query: {query.strip()}\\n\\nAnswer: {answer.strip()}" - + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='HealthSearchQA', - reference_explanations=None + source_dataset="HealthSearchQA", + reference_explanations=None, ) - + # Basic validation _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing HealthSearchQA item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from HealthSearchQA") - + if len(items) == 0: logger.warning("No valid items were loaded from HealthSearchQA dataset") else: # Log some statistics avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"HealthSearchQA dataset statistics:") + logger.info("HealthSearchQA dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity level: {complexity_level}") - + return items @@ -298,52 +288,47 @@ def load_custom_dataset( data_path: Union[str, Path], field_mapping: Optional[Dict[str, str]] = None, max_items: Optional[int] = None, - complexity_level: str = 'basic' + complexity_level: str = "basic", ) -> List[MEQBenchItem]: """Load custom dataset and convert to MEQBenchItem objects. - + Args: data_path: Path to the JSON file containing the dataset. field_mapping: Dictionary mapping dataset fields to MEQBenchItem fields. Example: {'q': 'question', 'a': 'answer', 'topic': 'medical_content'} max_items: Maximum number of items to load. complexity_level: Complexity level to assign to all items. - + Returns: List of MEQBenchItem objects. """ # Default field mapping if field_mapping is None: - field_mapping = { - 'question': 'question', - 'answer': 'answer', - 'content': 'medical_content', - 'id': 'id' - } - + field_mapping = {"question": "question", "answer": "answer", "content": "medical_content", "id": "id"} + data_file = Path(data_path) if not data_file.exists(): raise FileNotFoundError(f"Custom dataset file not found: {data_file}") - + logger.info(f"Loading custom dataset from: {data_file}") - - with open(data_file, 'r', encoding='utf-8') as f: + + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) - + if not isinstance(data, list): raise ValueError("Custom dataset must be a list of items") - + items = [] items_to_process = data[:max_items] if max_items else data - + for i, item_data in enumerate(items_to_process): try: # Extract fields based on mapping - question = item_data.get(field_mapping.get('question', 'question'), '') - answer = item_data.get(field_mapping.get('answer', 'answer'), '') - content = item_data.get(field_mapping.get('content', 'content'), '') - item_id = item_data.get(field_mapping.get('id', 'id'), f"custom_{i}") - + question = item_data.get(field_mapping.get("question", "question"), "") + answer = item_data.get(field_mapping.get("answer", "answer"), "") + content = item_data.get(field_mapping.get("content", "content"), "") + item_id = item_data.get(field_mapping.get("id", "id"), f"custom_{i}") + # Create medical content if content: medical_content = content @@ -352,33 +337,29 @@ def load_custom_dataset( else: logger.warning(f"Skipping item {i}: no valid content found") continue - + item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='Custom', - reference_explanations=None + source_dataset="Custom", + reference_explanations=None, ) - + _validate_benchmark_item(item) items.append(item) - + except Exception as e: logger.error(f"Error processing custom dataset item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} items from custom dataset") return items -def save_benchmark_items( - items: List[MEQBenchItem], - output_path: Union[str, Path], - pretty_print: bool = True -) -> None: +def save_benchmark_items(items: List[MEQBenchItem], output_path: Union[str, Path], pretty_print: bool = True) -> None: """Save MEQBenchItem objects to a JSON file. - + Args: items: List of MEQBenchItem objects to save. output_path: Path where to save the JSON file. @@ -386,44 +367,44 @@ def save_benchmark_items( """ output_file = Path(output_path) output_file.parent.mkdir(parents=True, exist_ok=True) - + # Convert items to dictionaries items_data = [] for item in items: item_dict = { - 'id': item.id, - 'medical_content': item.medical_content, - 'complexity_level': item.complexity_level, - 'source_dataset': item.source_dataset, - 'reference_explanations': item.reference_explanations + "id": item.id, + "medical_content": item.medical_content, + "complexity_level": item.complexity_level, + "source_dataset": item.source_dataset, + "reference_explanations": item.reference_explanations, } items_data.append(item_dict) - + # Save to JSON - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: if pretty_print: json.dump(items_data, f, indent=2, ensure_ascii=False) else: json.dump(items_data, f, ensure_ascii=False) - + logger.info(f"Saved {len(items)} benchmark items to: {output_file}") def calculate_complexity_level(text: str) -> str: """Calculate complexity level based on Flesch-Kincaid Grade Level. - + Uses textstat library to compute Flesch-Kincaid Grade Level and categorizes the text into basic, intermediate, or advanced complexity levels. - + Args: text: Text content to analyze for complexity. - + Returns: Complexity level string: 'basic', 'intermediate', or 'advanced' - + Raises: ValueError: If text is empty or invalid. - + Example: ```python complexity = calculate_complexity_level("This is simple text.") @@ -432,31 +413,31 @@ def calculate_complexity_level(text: str) -> str: """ if not text or not isinstance(text, str): raise ValueError("Text must be a non-empty string") - + text = text.strip() if not text: raise ValueError("Text cannot be empty or whitespace only") - + # Clean text for analysis - remove extra whitespace and normalize - cleaned_text = ' '.join(text.split()) - + cleaned_text = " ".join(text.split()) + # Fallback if textstat is not available if textstat is None: logger.warning("textstat library not available, using fallback complexity calculation") return _calculate_complexity_fallback(cleaned_text) - + try: # Calculate Flesch-Kincaid Grade Level fk_score = textstat.flesch_kincaid().grade(cleaned_text) - + # Categorize based on grade level if fk_score <= 8: - return 'basic' + return "basic" elif fk_score <= 12: - return 'intermediate' + return "intermediate" else: - return 'advanced' - + return "advanced" + except Exception as e: logger.warning(f"Error calculating Flesch-Kincaid score: {e}, using fallback") return _calculate_complexity_fallback(cleaned_text) @@ -464,235 +445,240 @@ def calculate_complexity_level(text: str) -> str: def _calculate_complexity_fallback(text: str) -> str: """Fallback complexity calculation when textstat is unavailable. - + Uses simple heuristics based on sentence length, word length, and medical terminology density as approximations. - + Args: text: Cleaned text to analyze. - + Returns: Complexity level: 'basic', 'intermediate', or 'advanced' """ # Count sentences (approximate) - sentences = len(re.split(r'[.!?]+', text)) + sentences = len(re.split(r"[.!?]+", text)) if sentences == 0: sentences = 1 - + # Count words words = len(text.split()) if words == 0: - return 'basic' - + return "basic" + # Calculate average words per sentence avg_words_per_sentence = words / sentences - + # Count syllables (rough approximation) syllable_count = 0 for word in text.split(): # Simple syllable counting heuristic - vowels = 'aeiouyAEIOUY' - word = re.sub(r'[^a-zA-Z]', '', word) + word = re.sub(r"[^a-zA-Z]", "", word) if word: - syllables = len(re.findall(r'[aeiouyAEIOUY]+', word)) + syllables = len(re.findall(r"[aeiouyAEIOUY]+", word)) if syllables == 0: syllables = 1 syllable_count += syllables - + avg_syllables_per_word = syllable_count / words if words > 0 else 1 - + # Check for medical terminology (indicator of higher complexity) medical_terms = [ - 'diagnosis', 'treatment', 'therapy', 'syndrome', 'pathology', - 'etiology', 'prognosis', 'medication', 'prescription', 'dosage', - 'contraindication', 'adverse', 'efficacy', 'pharmacology', - 'clinical', 'therapeutic', 'intervention' + "diagnosis", + "treatment", + "therapy", + "syndrome", + "pathology", + "etiology", + "prognosis", + "medication", + "prescription", + "dosage", + "contraindication", + "adverse", + "efficacy", + "pharmacology", + "clinical", + "therapeutic", + "intervention", ] - + medical_term_count = sum(1 for term in medical_terms if term.lower() in text.lower()) medical_density = medical_term_count / words * 100 # percentage - + # Simple scoring algorithm - complexity_score = 0 + complexity_score: float = 0 complexity_score += avg_words_per_sentence * 0.5 complexity_score += avg_syllables_per_word * 3 complexity_score += medical_density * 0.3 - + if complexity_score <= 8: - return 'basic' + return "basic" elif complexity_score <= 15: - return 'intermediate' + return "intermediate" else: - return 'advanced' + return "advanced" def load_medqa_usmle( - data_path: Union[str, Path], - max_items: Optional[int] = None, - auto_complexity: bool = True + data_path: Union[str, Path], max_items: Optional[int] = None, auto_complexity: bool = True ) -> List[MEQBenchItem]: """Load MedQA-USMLE dataset and convert to MEQBenchItem objects. - + The MedQA-USMLE dataset contains medical questions based on USMLE exam format with multiple choice questions and explanations. This loader processes the dataset and optionally applies automatic complexity stratification. - + Args: data_path: Path to the MedQA-USMLE JSON file. max_items: Maximum number of items to load. If None, loads all items. auto_complexity: Whether to automatically calculate complexity levels using Flesch-Kincaid scores. If False, assigns 'intermediate' to all. - + Returns: List of MEQBenchItem objects converted from MedQA-USMLE data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load with automatic complexity calculation items = load_medqa_usmle('data/medqa_usmle.json', max_items=300) - + # Load without complexity calculation (all marked as intermediate) items = load_medqa_usmle('data/medqa_usmle.json', auto_complexity=False) ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"MedQA-USMLE file not found: {data_file}") - + logger.info(f"Loading MedQA-USMLE dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in MedQA-USMLE file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in MedQA-USMLE file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("MedQA-USMLE data must be a list of items") - + items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} MedQA-USMLE items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid MedQA-USMLE item {i}: not a dictionary") continue - + # Extract required fields - MedQA typically has 'question', 'options', 'answer', 'explanation' - question = item_data.get('question', '') - options = item_data.get('options', {}) - answer = item_data.get('answer', '') - explanation = item_data.get('explanation', '') - item_id = item_data.get('id', f"medqa_usmle_{i}") - + question = item_data.get("question", "") + options = item_data.get("options", {}) + answer = item_data.get("answer", "") + explanation = item_data.get("explanation", "") + item_id = item_data.get("id", f"medqa_usmle_{i}") + if not question.strip(): logger.warning(f"Skipping MedQA-USMLE item {i}: empty question") continue - + # Format options if available options_text = "" if isinstance(options, dict): options_text = "\n".join([f"{k}. {v}" for k, v in options.items() if v]) elif isinstance(options, list): options_text = "\n".join([f"{chr(65+j)}. {opt}" for j, opt in enumerate(options) if opt]) - + # Create comprehensive medical content medical_content_parts = [f"Question: {question.strip()}"] - + if options_text: medical_content_parts.append(f"Options:\n{options_text}") - + if answer.strip(): medical_content_parts.append(f"Correct Answer: {answer.strip()}") - + if explanation.strip(): medical_content_parts.append(f"Explanation: {explanation.strip()}") - + medical_content = "\n\n".join(medical_content_parts) - + # Calculate complexity level if auto_complexity: try: complexity_level = calculate_complexity_level(medical_content) except Exception as e: logger.warning(f"Error calculating complexity for item {i}: {e}, using 'intermediate'") - complexity_level = 'intermediate' + complexity_level = "intermediate" else: - complexity_level = 'intermediate' - + complexity_level = "intermediate" + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='MedQA-USMLE', - reference_explanations=None + source_dataset="MedQA-USMLE", + reference_explanations=None, ) - + # Validate the item _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing MedQA-USMLE item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from MedQA-USMLE") - + if len(items) == 0: logger.warning("No valid items were loaded from MedQA-USMLE dataset") else: # Log complexity distribution - complexity_dist = {} + complexity_dist: Dict[str, int] = {} for item in items: complexity_dist[item.complexity_level] = complexity_dist.get(item.complexity_level, 0) + 1 - + avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"MedQA-USMLE dataset statistics:") + logger.info("MedQA-USMLE dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity distribution: {complexity_dist}") - + return items def load_icliniq( - data_path: Union[str, Path], - max_items: Optional[int] = None, - auto_complexity: bool = True + data_path: Union[str, Path], max_items: Optional[int] = None, auto_complexity: bool = True ) -> List[MEQBenchItem]: """Load iCliniq dataset and convert to MEQBenchItem objects. - + The iCliniq dataset contains real clinical questions from patients and answers from medical professionals. This loader processes the dataset and optionally applies automatic complexity stratification. - + Args: data_path: Path to the iCliniq JSON file. max_items: Maximum number of items to load. If None, loads all items. auto_complexity: Whether to automatically calculate complexity levels. - + Returns: List of MEQBenchItem objects converted from iCliniq data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load iCliniq dataset with complexity stratification @@ -700,126 +686,121 @@ def load_icliniq( ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"iCliniq file not found: {data_file}") - + logger.info(f"Loading iCliniq dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in iCliniq file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in iCliniq file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("iCliniq data must be a list of items") - + items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} iCliniq items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid iCliniq item {i}: not a dictionary") continue - + # Extract fields - iCliniq typically has 'patient_question', 'doctor_answer', 'speciality' - patient_question = item_data.get('patient_question', item_data.get('question', '')) - doctor_answer = item_data.get('doctor_answer', item_data.get('answer', '')) - specialty = item_data.get('speciality', item_data.get('specialty', '')) - item_id = item_data.get('id', f"icliniq_{i}") - + patient_question = item_data.get("patient_question", item_data.get("question", "")) + doctor_answer = item_data.get("doctor_answer", item_data.get("answer", "")) + specialty = item_data.get("speciality", item_data.get("specialty", "")) + item_id = item_data.get("id", f"icliniq_{i}") + if not patient_question.strip() or not doctor_answer.strip(): logger.warning(f"Skipping iCliniq item {i}: empty question or answer") continue - + # Create medical content medical_content_parts = [f"Patient Question: {patient_question.strip()}"] - + if specialty.strip(): medical_content_parts.append(f"Medical Specialty: {specialty.strip()}") - + medical_content_parts.append(f"Doctor's Answer: {doctor_answer.strip()}") - + medical_content = "\n\n".join(medical_content_parts) - + # Calculate complexity level if auto_complexity: try: complexity_level = calculate_complexity_level(medical_content) except Exception as e: logger.warning(f"Error calculating complexity for item {i}: {e}, using 'basic'") - complexity_level = 'basic' + complexity_level = "basic" else: - complexity_level = 'basic' # iCliniq tends to be more patient-focused - + complexity_level = "basic" # iCliniq tends to be more patient-focused + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='iCliniq', - reference_explanations=None + source_dataset="iCliniq", + reference_explanations=None, ) - + # Validate the item _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing iCliniq item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from iCliniq") - + if len(items) == 0: logger.warning("No valid items were loaded from iCliniq dataset") else: # Log complexity distribution - complexity_dist = {} + complexity_dist: Dict[str, int] = {} for item in items: complexity_dist[item.complexity_level] = complexity_dist.get(item.complexity_level, 0) + 1 - + avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"iCliniq dataset statistics:") + logger.info("iCliniq dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity distribution: {complexity_dist}") - + return items def load_cochrane_reviews( - data_path: Union[str, Path], - max_items: Optional[int] = None, - auto_complexity: bool = True + data_path: Union[str, Path], max_items: Optional[int] = None, auto_complexity: bool = True ) -> List[MEQBenchItem]: """Load Cochrane Reviews dataset and convert to MEQBenchItem objects. - + The Cochrane Reviews dataset contains evidence-based medical reviews and systematic analyses. This loader processes the dataset and optionally applies automatic complexity stratification. - + Args: data_path: Path to the Cochrane Reviews JSON file. max_items: Maximum number of items to load. If None, loads all items. auto_complexity: Whether to automatically calculate complexity levels. - + Returns: List of MEQBenchItem objects converted from Cochrane Reviews data. - + Raises: FileNotFoundError: If the data file does not exist. json.JSONDecodeError: If the JSON file is malformed. ValueError: If the data format is invalid. - + Example: ```python # Load Cochrane Reviews with complexity stratification @@ -827,128 +808,125 @@ def load_cochrane_reviews( ``` """ data_file = Path(data_path) - + if not data_file.exists(): raise FileNotFoundError(f"Cochrane Reviews file not found: {data_file}") - + logger.info(f"Loading Cochrane Reviews dataset from: {data_file}") - + try: - with open(data_file, 'r', encoding='utf-8') as f: + with open(data_file, "r", encoding="utf-8") as f: data = json.load(f) except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON in Cochrane Reviews file: {e}", - e.doc, e.pos - ) - + raise json.JSONDecodeError(f"Invalid JSON in Cochrane Reviews file: {e}", e.doc, e.pos) + if not isinstance(data, list): raise ValueError("Cochrane Reviews data must be a list of items") - + items = [] items_to_process = data[:max_items] if max_items else data - + logger.info(f"Processing {len(items_to_process)} Cochrane Reviews items") - + for i, item_data in enumerate(items_to_process): try: if not isinstance(item_data, dict): logger.warning(f"Skipping invalid Cochrane Reviews item {i}: not a dictionary") continue - + # Extract fields - Cochrane typically has 'title', 'abstract', 'conclusions', 'background' - title = item_data.get('title', '') - abstract = item_data.get('abstract', '') - conclusions = item_data.get('conclusions', item_data.get('main_results', '')) - background = item_data.get('background', item_data.get('objectives', '')) - item_id = item_data.get('id', f"cochrane_{i}") - + title = item_data.get("title", "") + abstract = item_data.get("abstract", "") + conclusions = item_data.get("conclusions", item_data.get("main_results", "")) + background = item_data.get("background", item_data.get("objectives", "")) + item_id = item_data.get("id", f"cochrane_{i}") + if not title.strip() and not abstract.strip(): logger.warning(f"Skipping Cochrane Reviews item {i}: no title or abstract") continue - + # Create medical content from available fields medical_content_parts = [] - + if title.strip(): medical_content_parts.append(f"Title: {title.strip()}") - + if background.strip(): medical_content_parts.append(f"Background: {background.strip()}") - + if abstract.strip(): medical_content_parts.append(f"Abstract: {abstract.strip()}") - + if conclusions.strip(): medical_content_parts.append(f"Conclusions: {conclusions.strip()}") - + if not medical_content_parts: logger.warning(f"Skipping Cochrane Reviews item {i}: no valid content") continue - + medical_content = "\n\n".join(medical_content_parts) - + # Calculate complexity level if auto_complexity: try: complexity_level = calculate_complexity_level(medical_content) except Exception as e: logger.warning(f"Error calculating complexity for item {i}: {e}, using 'advanced'") - complexity_level = 'advanced' + complexity_level = "advanced" else: - complexity_level = 'advanced' # Cochrane reviews tend to be more technical - + complexity_level = "advanced" # Cochrane reviews tend to be more technical + # Create MEQBenchItem item = MEQBenchItem( id=str(item_id), medical_content=medical_content, complexity_level=complexity_level, - source_dataset='Cochrane Reviews', - reference_explanations=None + source_dataset="Cochrane Reviews", + reference_explanations=None, ) - + # Validate the item _validate_benchmark_item(item) - + items.append(item) - + except Exception as e: logger.error(f"Error processing Cochrane Reviews item {i}: {e}") continue - + logger.info(f"Successfully loaded {len(items)} MEQBenchItem objects from Cochrane Reviews") - + if len(items) == 0: logger.warning("No valid items were loaded from Cochrane Reviews dataset") else: # Log complexity distribution - complexity_dist = {} + complexity_dist: Dict[str, int] = {} for item in items: complexity_dist[item.complexity_level] = complexity_dist.get(item.complexity_level, 0) + 1 - + avg_length = sum(len(item.medical_content) for item in items) / len(items) - logger.info(f"Cochrane Reviews dataset statistics:") + logger.info("Cochrane Reviews dataset statistics:") logger.info(f" - Total items: {len(items)}") logger.info(f" - Average content length: {avg_length:.1f} characters") logger.info(f" - Complexity distribution: {complexity_dist}") - + return items def _validate_benchmark_item(item: MEQBenchItem) -> None: """Validate a MEQBenchItem object for basic requirements. - + Args: item: MEQBenchItem to validate - + Raises: ValueError: If the item doesn't meet basic requirements """ if not item.id or not isinstance(item.id, str): raise ValueError("Item ID must be a non-empty string") - + if not item.medical_content or not isinstance(item.medical_content, str): raise ValueError("Medical content must be a non-empty string") - + if len(item.medical_content.strip()) < 20: - raise ValueError("Medical content is too short (less than 20 characters)") \ No newline at end of file + raise ValueError("Medical content is too short (less than 20 characters)") diff --git a/src/evaluator.py b/src/evaluator.py index 3240983..bd4c045 100644 --- a/src/evaluator.py +++ b/src/evaluator.py @@ -3,13 +3,13 @@ """ import re -import os import time import json -import requests +import requests # type: ignore import numpy as np -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Dict, List, Any, Optional, Protocol, Union, Callable, Type, Set + try: from typing_extensions import TypedDict except ImportError: @@ -21,18 +21,20 @@ from .strategies import StrategyFactory, AudienceStrategy # Set up logging -logger = logging.getLogger('meq_bench.evaluator') +logger = logging.getLogger("meq_bench.evaluator") # TypedDict definitions for better structure typing class APIConfigDict(TypedDict): """Type definition for API configuration.""" + base_url: str timeout: Optional[int] class EvaluationConfigDict(TypedDict): """Type definition for evaluation configuration.""" + safety: Dict[str, List[str]] medical_terms: List[str] readability_targets: Dict[str, Dict[str, float]] @@ -41,14 +43,16 @@ class EvaluationConfigDict(TypedDict): class ScoringConfigDict(TypedDict): """Type definition for scoring configuration.""" + weights: Dict[str, float] parameters: Dict[str, Any] + try: - import textstat - from sentence_transformers import SentenceTransformer - import spacy - from transformers import pipeline + import textstat # type: ignore + from sentence_transformers import SentenceTransformer # type: ignore + import spacy # type: ignore + from transformers import pipeline # type: ignore except ImportError as e: logger.warning(f"Some evaluation dependencies not installed: {e}") logger.info("Install with: pip install -r requirements.txt") @@ -56,12 +60,14 @@ class ScoringConfigDict(TypedDict): class EvaluationError(Exception): """Raised when there's an error during evaluation""" + pass @dataclass class EvaluationScore: """Container for evaluation scores with detailed breakdown""" + readability: float terminology: float safety: float @@ -72,26 +78,26 @@ class EvaluationScore: hallucination: float overall: float details: Optional[Dict[str, Any]] = None - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for serialization""" return { - 'readability': self.readability, - 'terminology': self.terminology, - 'safety': self.safety, - 'coverage': self.coverage, - 'quality': self.quality, - 'contradiction': self.contradiction, - 'information_preservation': self.information_preservation, - 'hallucination': self.hallucination, - 'overall': self.overall, - 'details': self.details or {} + "readability": self.readability, + "terminology": self.terminology, + "safety": self.safety, + "coverage": self.coverage, + "quality": self.quality, + "contradiction": self.contradiction, + "information_preservation": self.information_preservation, + "hallucination": self.hallucination, + "overall": self.overall, + "details": self.details or {}, } class MetricCalculator(Protocol): """Protocol for metric calculators""" - + def calculate(self, text: str, audience: str, **kwargs) -> float: """Calculate metric score""" ... @@ -99,20 +105,20 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: class ReadabilityCalculator: """Calculator for readability metrics using dependency injection""" - + def __init__(self, strategy_factory: StrategyFactory) -> None: self.strategy_factory: StrategyFactory = strategy_factory logger.debug("Initialized ReadabilityCalculator") - + def calculate(self, text: str, audience: str, **kwargs) -> float: """ Calculate readability score for given audience - + Args: text: Text to analyze audience: Target audience **kwargs: Additional parameters - + Returns: Readability score (0-1) """ @@ -120,24 +126,24 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: if not text.strip(): logger.warning("Empty text provided for readability calculation") return 0.0 - + # Get grade level using textstat try: grade_level = textstat.flesch_kincaid().grade(text) except Exception as e: logger.error(f"Error calculating Flesch-Kincaid score: {e}") # Fallback: estimate based on sentence length - sentences = text.split('.') + sentences = text.split(".") avg_sentence_length = sum(len(s.split()) for s in sentences) / max(len(sentences), 1) grade_level = min(16, max(6, avg_sentence_length / 4)) - + # Use strategy pattern for audience-specific scoring strategy = self.strategy_factory.create_strategy(audience) score = strategy.calculate_readability_score(text, grade_level) - + logger.debug(f"Readability score for {audience}: {score:.3f} (grade level: {grade_level:.1f})") return score - + except Exception as e: logger.error(f"Error calculating readability for {audience}: {e}") raise EvaluationError(f"Readability calculation failed: {e}") @@ -145,21 +151,21 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: class TerminologyCalculator: """Calculator for medical terminology appropriateness""" - + def __init__(self, strategy_factory: StrategyFactory) -> None: self.strategy_factory: StrategyFactory = strategy_factory - self.medical_terms: Set[str] = set(config.get('evaluation.medical_terms', [])) + self.medical_terms: Set[str] = set(config.get("evaluation.medical_terms", [])) logger.debug(f"Initialized TerminologyCalculator with {len(self.medical_terms)} medical terms") - + def calculate(self, text: str, audience: str, **kwargs) -> float: """ Calculate terminology appropriateness score - + Args: text: Text to analyze audience: Target audience **kwargs: Additional parameters - + Returns: Terminology score (0-1) """ @@ -167,22 +173,22 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: if not text.strip(): logger.warning("Empty text provided for terminology calculation") return 0.0 - + words = text.lower().split() if not words: return 0.0 - + # Count medical terms medical_count = sum(1 for word in words if any(term in word for term in self.medical_terms)) term_density = medical_count / len(words) - + # Use strategy pattern for audience-specific scoring strategy = self.strategy_factory.create_strategy(audience) score = strategy.calculate_terminology_score(text, term_density) - + logger.debug(f"Terminology score for {audience}: {score:.3f} (density: {term_density:.3f})") return score - + except Exception as e: logger.error(f"Error calculating terminology for {audience}: {e}") raise EvaluationError(f"Terminology calculation failed: {e}") @@ -190,22 +196,22 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: class SafetyChecker: """Medical safety and factual consistency checker""" - + def __init__(self) -> None: eval_config = config.get_evaluation_config() # type: ignore[misc] - self.danger_words: List[str] = eval_config['safety']['danger_words'] - self.safety_words: List[str] = eval_config['safety']['safety_words'] + self.danger_words: List[str] = eval_config["safety"]["danger_words"] + self.safety_words: List[str] = eval_config["safety"]["safety_words"] logger.debug(f"Initialized SafetyChecker with {len(self.danger_words)} danger words") - + def calculate(self, text: str, audience: str, **kwargs) -> float: """ Check text for safety compliance - + Args: text: Text to check audience: Target audience (not used in current implementation) **kwargs: Additional parameters - + Returns: Safety compliance score (0-1) """ @@ -213,30 +219,30 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: if not text.strip(): logger.warning("Empty text provided for safety check") return 0.5 # Neutral score for empty text - + text_lower = text.lower() - + # Check for dangerous advice danger_count = sum(1 for word in self.danger_words if word in text_lower) if danger_count > 0: logger.warning(f"Dangerous content detected: {danger_count} danger words found") return 0.0 - + # Check for appropriate safety language safety_count = sum(1 for word in self.safety_words if word in text_lower) - + # Calculate safety score safety_score = min(1.0, safety_count * 0.3) - + # Bonus for mentioning healthcare professionals - professional_terms = ['doctor', 'physician', 'healthcare provider', 'medical professional'] + professional_terms = ["doctor", "physician", "healthcare provider", "medical professional"] professional_mentions = sum(1 for term in professional_terms if term in text_lower) if professional_mentions > 0: safety_score = min(1.0, safety_score + 0.2) - + logger.debug(f"Safety score: {safety_score:.3f} (safety words: {safety_count})") return safety_score - + except Exception as e: logger.error(f"Error in safety check: {e}") raise EvaluationError(f"Safety check failed: {e}") @@ -244,25 +250,25 @@ def calculate(self, text: str, audience: str, **kwargs) -> float: class CoverageAnalyzer: """Analyzer for information coverage and completeness""" - + def __init__(self) -> None: try: - self.sentence_model: Optional[Any] = SentenceTransformer('all-MiniLM-L6-v2') + self.sentence_model: Optional[Any] = SentenceTransformer("all-MiniLM-L6-v2") logger.debug("Initialized CoverageAnalyzer with SentenceTransformer") except Exception as e: logger.warning(f"Failed to load SentenceTransformer: {e}") self.sentence_model = None - + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Measure information coverage using semantic similarity - + Args: text: Generated explanation text audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Coverage score (0-1) """ @@ -270,118 +276,109 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f if not text.strip() or not original.strip(): logger.warning("Empty text or original provided for coverage analysis") return 0.0 - + if not self.sentence_model: # Fallback to simple word overlap return self._calculate_word_overlap(original, text) - + # Use sentence transformers for semantic similarity try: orig_embedding = self.sentence_model.encode([original]) gen_embedding = self.sentence_model.encode([text]) - + similarity = np.dot(orig_embedding[0], gen_embedding[0]) / ( np.linalg.norm(orig_embedding[0]) * np.linalg.norm(gen_embedding[0]) ) - + coverage_score = max(0.0, min(1.0, similarity)) - + logger.debug(f"Coverage score: {coverage_score:.3f} (semantic similarity)") return coverage_score - + except Exception as e: logger.warning(f"Semantic similarity calculation failed: {e}") return self._calculate_word_overlap(original, text) - + except Exception as e: logger.error(f"Error calculating coverage: {e}") raise EvaluationError(f"Coverage calculation failed: {e}") - + def _calculate_word_overlap(self, original: str, generated: str) -> float: """Fallback method using word overlap""" orig_words = set(original.lower().split()) gen_words = set(generated.lower().split()) - + if not orig_words: return 0.0 - + overlap = len(orig_words.intersection(gen_words)) coverage = min(1.0, overlap / len(orig_words)) - + logger.debug(f"Coverage score: {coverage:.3f} (word overlap)") return coverage class ContradictionDetection: """Medical contradiction detection against knowledge base""" - + def __init__(self) -> None: self.medical_knowledge_base = self._load_medical_knowledge() self.contradiction_patterns = self._load_contradiction_patterns() logger.debug("Initialized ContradictionDetection") - + def _load_medical_knowledge(self) -> Dict[str, List[str]]: """Load basic medical knowledge base for contradiction detection""" # This would ideally load from a comprehensive medical knowledge base # For now, using a simplified version with common medical facts return { - 'hypertension': [ - 'high blood pressure', - 'systolic over 140 or diastolic over 90', - 'can lead to heart disease and stroke', - 'treated with lifestyle changes and medication' + "hypertension": [ + "high blood pressure", + "systolic over 140 or diastolic over 90", + "can lead to heart disease and stroke", + "treated with lifestyle changes and medication", ], - 'diabetes': [ - 'high blood sugar', - 'insulin resistance or deficiency', - 'requires blood sugar monitoring', - 'managed with diet, exercise, and medication' + "diabetes": [ + "high blood sugar", + "insulin resistance or deficiency", + "requires blood sugar monitoring", + "managed with diet, exercise, and medication", ], - 'antibiotics': [ - 'treat bacterial infections', - 'do not work against viruses', - 'should be taken as prescribed', - 'resistance can develop from misuse' + "antibiotics": [ + "treat bacterial infections", + "do not work against viruses", + "should be taken as prescribed", + "resistance can develop from misuse", + ], + "aspirin": [ + "pain reliever and blood thinner", + "can cause stomach bleeding", + "contraindicated with certain conditions", + "requires medical supervision for daily use", ], - 'aspirin': [ - 'pain reliever and blood thinner', - 'can cause stomach bleeding', - 'contraindicated with certain conditions', - 'requires medical supervision for daily use' - ] } - + def _load_contradiction_patterns(self) -> List[Dict[str, str]]: """Load patterns that indicate medical contradictions""" return [ + {"pattern": r"antibiotics.*treat.*virus", "description": "Antibiotics do not treat viral infections"}, { - 'pattern': r'antibiotics.*treat.*virus', - 'description': 'Antibiotics do not treat viral infections' + "pattern": r"stop.*medication.*immediately.*feel.*better", + "description": "Medications should not be stopped without medical advice", }, - { - 'pattern': r'stop.*medication.*immediately.*feel.*better', - 'description': 'Medications should not be stopped without medical advice' - }, - { - 'pattern': r'aspirin.*safe.*everyone', - 'description': 'Aspirin has contraindications and side effects' - }, - { - 'pattern': r'blood pressure.*normal.*140/90', - 'description': '140/90 is not normal blood pressure' - } + {"pattern": r"aspirin.*safe.*everyone", "description": "Aspirin has contraindications and side effects"}, + {"pattern": r"blood pressure.*normal.*140/90", "description": "140/90 is not normal blood pressure"}, ] - + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Detect contradictions against medical knowledge base - + Args: text: Generated explanation text audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Contradiction score (0-1, where 1 means no contradictions) """ @@ -389,44 +386,39 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f if not text.strip(): logger.warning("Empty text provided for contradiction detection") return 0.5 - + text_lower = text.lower() contradiction_count = 0 - + # Check for pattern-based contradictions for pattern_info in self.contradiction_patterns: - pattern = pattern_info['pattern'] + pattern = pattern_info["pattern"] if re.search(pattern, text_lower): contradiction_count += 1 logger.warning(f"Contradiction detected: {pattern_info['description']}") - + # Check for factual contradictions against knowledge base for topic, facts in self.medical_knowledge_base.items(): if topic in text_lower: # Simple contradiction detection: look for negations of known facts for fact in facts: # Check for explicit contradictions - contradiction_patterns = [ - f"not {fact}", - f"never {fact}", - f"don't {fact}", - f"doesn't {fact}" - ] + contradiction_patterns = [f"not {fact}", f"never {fact}", f"don't {fact}", f"doesn't {fact}"] for contradiction in contradiction_patterns: if contradiction in text_lower: contradiction_count += 1 logger.warning(f"Knowledge base contradiction: {contradiction}") - + # Calculate score (higher is better, no contradictions = 1.0) if contradiction_count == 0: score = 1.0 else: # Penalize based on number of contradictions score = max(0.0, 1.0 - (contradiction_count * 0.3)) - + logger.debug(f"Contradiction score: {score:.3f} ({contradiction_count} contradictions found)") return score - + except Exception as e: logger.error(f"Error in contradiction detection: {e}") raise EvaluationError(f"Contradiction detection failed: {e}") @@ -434,65 +426,65 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f class InformationPreservation: """Check preservation of critical medical information""" - + def __init__(self) -> None: self.critical_info_patterns = self._load_critical_patterns() logger.debug("Initialized InformationPreservation") - + def _load_critical_patterns(self) -> Dict[str, List[str]]: """Load patterns for critical medical information""" return { - 'dosages': [ - r'\d+\s*mg', - r'\d+\s*ml', - r'\d+\s*units', - r'\d+\s*tablets?', - r'\d+\s*times?\s*(?:per\s+)?day', - r'once\s+daily', - r'twice\s+daily', - r'every\s+\d+\s+hours' + "dosages": [ + r"\d+\s*mg", + r"\d+\s*ml", + r"\d+\s*units", + r"\d+\s*tablets?", + r"\d+\s*times?\s*(?:per\s+)?day", + r"once\s+daily", + r"twice\s+daily", + r"every\s+\d+\s+hours", + ], + "warnings": [ + r"do not", + r"avoid", + r"contraindicated", + r"warning", + r"caution", + r"side effects?", + r"adverse", + r"allergic", + r"emergency", ], - 'warnings': [ - r'do not', - r'avoid', - r'contraindicated', - r'warning', - r'caution', - r'side effects?', - r'adverse', - r'allergic', - r'emergency' + "timing": [ + r"before\s+meals?", + r"after\s+meals?", + r"with\s+food", + r"on\s+empty\s+stomach", + r"bedtime", + r"morning", + r"evening", ], - 'timing': [ - r'before\s+meals?', - r'after\s+meals?', - r'with\s+food', - r'on\s+empty\s+stomach', - r'bedtime', - r'morning', - r'evening' + "conditions": [ + r"if\s+pregnant", + r"if\s+breastfeeding", + r"kidney\s+disease", + r"liver\s+disease", + r"heart\s+condition", + r"diabetes", + r"high\s+blood\s+pressure", ], - 'conditions': [ - r'if\s+pregnant', - r'if\s+breastfeeding', - r'kidney\s+disease', - r'liver\s+disease', - r'heart\s+condition', - r'diabetes', - r'high\s+blood\s+pressure' - ] } - + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Check if critical information is preserved from original to generated text - + Args: text: Generated explanation text audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Information preservation score (0-1) """ @@ -500,22 +492,22 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f if not text.strip() or not original.strip(): logger.warning("Empty text or original provided for information preservation") return 0.0 - + original_lower = original.lower() text_lower = text.lower() - + total_critical_info = 0 preserved_critical_info = 0 - + # Check each category of critical information for category, patterns in self.critical_info_patterns.items(): for pattern in patterns: # Find all matches in original text original_matches = re.findall(pattern, original_lower) - + if original_matches: total_critical_info += len(original_matches) - + # Check if these matches are preserved in generated text for match in original_matches: if match in text_lower or any(re.search(pattern, text_lower) for pattern in [match]): @@ -524,99 +516,150 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f # Check for paraphrased versions if self._check_paraphrased_preservation(match, text_lower, category): preserved_critical_info += 1 - + # Calculate preservation score if total_critical_info == 0: # No critical information to preserve score = 1.0 else: score = preserved_critical_info / total_critical_info - - logger.debug(f"Information preservation score: {score:.3f} " - f"({preserved_critical_info}/{total_critical_info} preserved)") + + logger.debug( + f"Information preservation score: {score:.3f} " f"({preserved_critical_info}/{total_critical_info} preserved)" + ) return score - + except Exception as e: logger.error(f"Error in information preservation check: {e}") raise EvaluationError(f"Information preservation check failed: {e}") - + def _check_paraphrased_preservation(self, original_info: str, generated_text: str, category: str) -> bool: """Check if information is preserved in paraphrased form""" # Simple paraphrase detection for different categories - if category == 'dosages': + if category == "dosages": # Check if any dosage information is present - dosage_patterns = [r'\d+', r'dose', r'amount', r'quantity'] + dosage_patterns = [r"\d+", r"dose", r"amount", r"quantity"] return any(re.search(pattern, generated_text) for pattern in dosage_patterns) - - elif category == 'warnings': + + elif category == "warnings": # Check if warning language is present - warning_patterns = [r'careful', r'important', r'note', r'remember', r'consult'] + warning_patterns = [r"careful", r"important", r"note", r"remember", r"consult"] return any(re.search(pattern, generated_text) for pattern in warning_patterns) - - elif category == 'timing': + + elif category == "timing": # Check if timing information is preserved - timing_patterns = [r'when', r'time', r'schedule', r'take'] + timing_patterns = [r"when", r"time", r"schedule", r"take"] return any(re.search(pattern, generated_text) for pattern in timing_patterns) - - elif category == 'conditions': + + elif category == "conditions": # Check if condition information is preserved - condition_patterns = [r'condition', r'disease', r'illness', r'medical'] + condition_patterns = [r"condition", r"disease", r"illness", r"medical"] return any(re.search(pattern, generated_text) for pattern in condition_patterns) - + return False class HallucinationDetection: """Detect hallucinated medical entities not present in source text""" - + def __init__(self) -> None: self.medical_entities = self._load_medical_entities() try: # Try to load spaCy model for NER import spacy + self.nlp = spacy.load("en_core_web_sm") logger.debug("Loaded spaCy model for NER") except Exception as e: logger.warning(f"Could not load spaCy model: {e}") self.nlp = None - + logger.debug("Initialized HallucinationDetection") - + def _load_medical_entities(self) -> Dict[str, List[str]]: """Load common medical entities for hallucination detection""" return { - 'medications': [ - 'aspirin', 'ibuprofen', 'acetaminophen', 'paracetamol', 'insulin', - 'metformin', 'lisinopril', 'atorvastatin', 'omeprazole', 'albuterol', - 'prednisone', 'warfarin', 'digoxin', 'furosemide', 'levothyroxine' + "medications": [ + "aspirin", + "ibuprofen", + "acetaminophen", + "paracetamol", + "insulin", + "metformin", + "lisinopril", + "atorvastatin", + "omeprazole", + "albuterol", + "prednisone", + "warfarin", + "digoxin", + "furosemide", + "levothyroxine", + ], + "conditions": [ + "hypertension", + "diabetes", + "asthma", + "copd", + "pneumonia", + "bronchitis", + "arthritis", + "osteoporosis", + "depression", + "anxiety", + "migraine", + "epilepsy", + "cancer", + "stroke", + "heart attack", ], - 'conditions': [ - 'hypertension', 'diabetes', 'asthma', 'copd', 'pneumonia', - 'bronchitis', 'arthritis', 'osteoporosis', 'depression', 'anxiety', - 'migraine', 'epilepsy', 'cancer', 'stroke', 'heart attack' + "symptoms": [ + "fever", + "cough", + "headache", + "nausea", + "vomiting", + "diarrhea", + "constipation", + "fatigue", + "dizziness", + "chest pain", + "shortness of breath", + "swelling", + "rash", + "itching", + "numbness", + "tingling", ], - 'symptoms': [ - 'fever', 'cough', 'headache', 'nausea', 'vomiting', 'diarrhea', - 'constipation', 'fatigue', 'dizziness', 'chest pain', 'shortness of breath', - 'swelling', 'rash', 'itching', 'numbness', 'tingling' + "body_parts": [ + "heart", + "lungs", + "liver", + "kidney", + "brain", + "stomach", + "intestine", + "bladder", + "pancreas", + "thyroid", + "spine", + "joints", + "muscles", + "blood vessels", + "nerves", ], - 'body_parts': [ - 'heart', 'lungs', 'liver', 'kidney', 'brain', 'stomach', - 'intestine', 'bladder', 'pancreas', 'thyroid', 'spine', - 'joints', 'muscles', 'blood vessels', 'nerves' - ] } - + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Detect medical entities in generated text that are not in source - + Args: text: Generated explanation text audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Hallucination score (0-1, where 1 means no hallucinations) """ @@ -624,14 +667,14 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f if not text.strip() or not original.strip(): logger.warning("Empty text or original provided for hallucination detection") return 0.5 - + # Extract medical entities from both texts original_entities = self._extract_medical_entities(original) generated_entities = self._extract_medical_entities(text) - + # Find entities in generated text that are not in original hallucinated_entities = generated_entities - original_entities - + # Calculate hallucination score total_generated_entities = len(generated_entities) if total_generated_entities == 0: @@ -639,113 +682,116 @@ def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> f else: hallucination_rate = len(hallucinated_entities) / total_generated_entities score = max(0.0, 1.0 - hallucination_rate) - + if hallucinated_entities: logger.warning(f"Hallucinated entities detected: {hallucinated_entities}") - - logger.debug(f"Hallucination score: {score:.3f} " - f"({len(hallucinated_entities)}/{total_generated_entities} hallucinated)") + + logger.debug( + f"Hallucination score: {score:.3f} " f"({len(hallucinated_entities)}/{total_generated_entities} hallucinated)" + ) return score - + except Exception as e: logger.error(f"Error in hallucination detection: {e}") raise EvaluationError(f"Hallucination detection failed: {e}") - + def _extract_medical_entities(self, text: str) -> set: """Extract medical entities from text""" entities = set() text_lower = text.lower() - + # Extract entities from predefined lists for category, entity_list in self.medical_entities.items(): for entity in entity_list: if entity.lower() in text_lower: entities.add(entity.lower()) - + # Use spaCy NER if available if self.nlp: try: doc = self.nlp(text) for ent in doc.ents: # Focus on medical-related entity types - if ent.label_ in ['PERSON', 'ORG', 'PRODUCT', 'SUBSTANCE']: + if ent.label_ in ["PERSON", "ORG", "PRODUCT", "SUBSTANCE"]: # Filter for likely medical entities entity_text = ent.text.lower() - if any(medical_term in entity_text - for medical_terms in self.medical_entities.values() - for medical_term in medical_terms): + if any( + medical_term in entity_text + for medical_terms in self.medical_entities.values() + for medical_term in medical_terms + ): entities.add(entity_text) except Exception as e: logger.warning(f"spaCy NER failed: {e}") - + return entities class LLMJudge: """LLM-as-a-judge evaluator with full API integration""" - + def __init__(self, model: Optional[str] = None) -> None: - self.model: str = model or config.get('llm_judge.default_model') - self.timeout: int = config.get('llm_judge.timeout', 30) - self.max_retries: int = config.get('llm_judge.max_retries', 3) - self.temperature: float = config.get('llm_judge.temperature', 0.1) - self.max_tokens: int = config.get('llm_judge.max_tokens', 1000) - + self.model: str = model or config.get("llm_judge.default_model") + self.timeout: int = config.get("llm_judge.timeout", 30) + self.max_retries: int = config.get("llm_judge.max_retries", 3) + self.temperature: float = config.get("llm_judge.temperature", 0.1) + self.max_tokens: int = config.get("llm_judge.max_tokens", 1000) + # Determine API provider from model name self.provider: str = self._determine_provider(self.model) self.api_key: str = config.get_api_key(self.provider) - + logger.info(f"Initialized LLMJudge with model: {self.model} (provider: {self.provider})") - + def _determine_provider(self, model: str) -> str: """Determine API provider from model name""" - if 'gpt' in model.lower(): - return 'openai' - elif 'claude' in model.lower(): - return 'anthropic' + if "gpt" in model.lower(): + return "openai" + elif "claude" in model.lower(): + return "anthropic" else: logger.warning(f"Unknown model provider for {model}, defaulting to openai") - return 'openai' - + return "openai" + def calculate(self, text: str, audience: str, original: str = "", **kwargs) -> float: """ Evaluate using LLM as judge - + Args: text: Generated explanation audience: Target audience original: Original medical information **kwargs: Additional parameters - + Returns: Quality score (0-1) """ try: prompt = self._create_evaluation_prompt(original, text, audience) - + for attempt in range(self.max_retries): try: response = self._call_llm_api(prompt) score = self._parse_llm_response(response) - + logger.debug(f"LLM Judge score for {audience}: {score:.3f} (attempt {attempt + 1})") return score - + except Exception as e: logger.warning(f"LLM API call failed (attempt {attempt + 1}): {e}") if attempt == self.max_retries - 1: raise - time.sleep(2 ** attempt) # Exponential backoff - + time.sleep(2**attempt) # Exponential backoff + # If we reach here, all retries failed but no exception was raised logger.error("All LLM API retry attempts failed") return 0.6 - + except Exception as e: logger.error(f"LLM Judge evaluation failed: {e}") # Return reasonable default instead of failing completely return 0.6 - + def _create_evaluation_prompt(self, original: str, generated: str, audience: str) -> str: """Create evaluation prompt for LLM judge""" return f"""Evaluate the following 'Generated' explanation, which was adapted from the 'Original' medical information for the specified {audience}. @@ -766,94 +812,87 @@ def _create_evaluation_prompt(self, original: str, generated: str, audience: str {{"score1": X, "score2": Y, "score3": Z, "score4": A, "score5": B, "score6": C, "overall": D}} Where each score is a number from 1-5, and overall is the average.""" - + def _call_llm_api(self, prompt: str) -> str: """Make API call to LLM service""" - if self.provider == 'openai': + if self.provider == "openai": return self._call_openai_api(prompt) - elif self.provider == 'anthropic': + elif self.provider == "anthropic": return self._call_anthropic_api(prompt) else: raise EvaluationError(f"Unsupported provider: {self.provider}") - + def _call_openai_api(self, prompt: str) -> str: """Call OpenAI API""" - api_config = config.get_api_config('openai') + api_config = config.get_api_config("openai") url = f"{api_config['base_url']}/chat/completions" - - headers = { - 'Authorization': f'Bearer {self.api_key}', - 'Content-Type': 'application/json' - } - + + headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} + data = { - 'model': self.model, - 'messages': [{'role': 'user', 'content': prompt}], - 'temperature': self.temperature, - 'max_tokens': self.max_tokens + "model": self.model, + "messages": [{"role": "user", "content": prompt}], + "temperature": self.temperature, + "max_tokens": self.max_tokens, } - + response = requests.post(url, headers=headers, json=data, timeout=self.timeout) response.raise_for_status() - + result = response.json() - return result['choices'][0]['message']['content'] - + return result["choices"][0]["message"]["content"] + def _call_anthropic_api(self, prompt: str) -> str: """Call Anthropic API""" - api_config = config.get_api_config('anthropic') + api_config = config.get_api_config("anthropic") url = f"{api_config['base_url']}/v1/messages" - - headers = { - 'x-api-key': self.api_key, - 'Content-Type': 'application/json', - 'anthropic-version': '2023-06-01' - } - + + headers = {"x-api-key": self.api_key, "Content-Type": "application/json", "anthropic-version": "2023-06-01"} + data = { - 'model': self.model, - 'max_tokens': self.max_tokens, - 'temperature': self.temperature, - 'messages': [{'role': 'user', 'content': prompt}] + "model": self.model, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "messages": [{"role": "user", "content": prompt}], } - + response = requests.post(url, headers=headers, json=data, timeout=self.timeout) response.raise_for_status() - + result = response.json() - return result['content'][0]['text'] - + return result["content"][0]["text"] + def _parse_llm_response(self, response: str) -> float: """Parse LLM response to extract score""" try: # Try to extract JSON from response - json_start = response.find('{') - json_end = response.rfind('}') + 1 - + json_start = response.find("{") + json_end = response.rfind("}") + 1 + if json_start >= 0 and json_end > json_start: json_str = response[json_start:json_end] data = json.loads(json_str) - + # Get overall score or calculate from individual scores - if 'overall' in data: - overall = float(data['overall']) + if "overall" in data: + overall = float(data["overall"]) else: individual_scores = [] for i in range(1, 7): - key = f'score{i}' + key = f"score{i}" if key in data: individual_scores.append(float(data[key])) - + overall = sum(individual_scores) / len(individual_scores) if individual_scores else 3.0 - + # Convert from 1-5 scale to 0-1 scale return (overall - 1) / 4 - + except (json.JSONDecodeError, ValueError, KeyError) as e: logger.warning(f"Failed to parse LLM response as JSON: {e}") - + # Fallback: try to extract a number from the response - numbers = re.findall(r'\d+\.?\d*', response) + numbers = re.findall(r"\d+\.?\d*", response) if numbers: try: score = float(numbers[0]) @@ -863,27 +902,29 @@ def _parse_llm_response(self, response: str) -> float: return score except ValueError: pass - + logger.warning("Could not parse LLM response, using default score") return 0.6 # Default reasonable score class MEQBenchEvaluator: """Main evaluation class using dependency injection and SOLID principles""" - - def __init__(self, - readability_calculator: Optional[ReadabilityCalculator] = None, - terminology_calculator: Optional[TerminologyCalculator] = None, - safety_checker: Optional[SafetyChecker] = None, - coverage_analyzer: Optional[CoverageAnalyzer] = None, - llm_judge: Optional[LLMJudge] = None, - contradiction_detector: Optional[ContradictionDetection] = None, - information_preservation: Optional[InformationPreservation] = None, - hallucination_detector: Optional[HallucinationDetection] = None, - strategy_factory: Optional[StrategyFactory] = None) -> None: + + def __init__( + self, + readability_calculator: Optional[ReadabilityCalculator] = None, + terminology_calculator: Optional[TerminologyCalculator] = None, + safety_checker: Optional[SafetyChecker] = None, + coverage_analyzer: Optional[CoverageAnalyzer] = None, + llm_judge: Optional[LLMJudge] = None, + contradiction_detector: Optional[ContradictionDetection] = None, + information_preservation: Optional[InformationPreservation] = None, + hallucination_detector: Optional[HallucinationDetection] = None, + strategy_factory: Optional[StrategyFactory] = None, + ) -> None: """ Initialize evaluator with dependency injection - + Args: readability_calculator: Calculator for readability metrics terminology_calculator: Calculator for terminology appropriateness @@ -897,167 +938,175 @@ def __init__(self, """ # Use dependency injection with sensible defaults self.strategy_factory: StrategyFactory = strategy_factory or StrategyFactory() - - self.readability_calculator: ReadabilityCalculator = readability_calculator or ReadabilityCalculator(self.strategy_factory) - self.terminology_calculator: TerminologyCalculator = terminology_calculator or TerminologyCalculator(self.strategy_factory) + + self.readability_calculator: ReadabilityCalculator = readability_calculator or ReadabilityCalculator( + self.strategy_factory + ) + self.terminology_calculator: TerminologyCalculator = terminology_calculator or TerminologyCalculator( + self.strategy_factory + ) self.safety_checker: SafetyChecker = safety_checker or SafetyChecker() self.coverage_analyzer: CoverageAnalyzer = coverage_analyzer or CoverageAnalyzer() self.llm_judge: LLMJudge = llm_judge or LLMJudge() - + # New safety and factual consistency metrics self.contradiction_detector: ContradictionDetection = contradiction_detector or ContradictionDetection() self.information_preservation: InformationPreservation = information_preservation or InformationPreservation() self.hallucination_detector: HallucinationDetection = hallucination_detector or HallucinationDetection() - + # Load scoring configuration self.scoring_config = config.get_scoring_config() # type: ignore[misc] - self.weights: Dict[str, float] = self.scoring_config['weights'] - + self.weights: Dict[str, float] = self.scoring_config["weights"] + logger.info("MEQBenchEvaluator initialized with dependency injection") - + def evaluate_explanation(self, original: str, generated: str, audience: str) -> EvaluationScore: """ Evaluate a single explanation for a specific audience - + Args: original: Original medical information generated: Generated explanation audience: Target audience - + Returns: EvaluationScore object with all metrics - + Raises: EvaluationError: If evaluation fails """ try: logger.info(f"Starting evaluation for {audience} audience") start_time = time.time() - + # Validate inputs if not generated.strip(): raise EvaluationError("Generated explanation is empty") - + if audience not in config.get_audiences(): raise EvaluationError(f"Unsupported audience: {audience}") - + # Calculate individual metrics - metrics = {} - details = {} - + metrics: Dict[str, float] = {} + details: Dict[str, Any] = {} + try: - metrics['readability'] = self.readability_calculator.calculate(generated, audience) - details['readability'] = {'text_length': len(generated), 'audience': audience} + metrics["readability"] = self.readability_calculator.calculate(generated, audience) + details["readability"] = {"text_length": len(generated), "audience": audience} except Exception as e: logger.error(f"Readability calculation failed: {e}") - metrics['readability'] = 0.0 - + metrics["readability"] = 0.0 + try: - metrics['terminology'] = self.terminology_calculator.calculate(generated, audience) + metrics["terminology"] = self.terminology_calculator.calculate(generated, audience) except Exception as e: logger.error(f"Terminology calculation failed: {e}") - metrics['terminology'] = 0.0 - + metrics["terminology"] = 0.0 + try: - metrics['safety'] = self.safety_checker.calculate(generated, audience) + metrics["safety"] = self.safety_checker.calculate(generated, audience) except Exception as e: logger.error(f"Safety check failed: {e}") - metrics['safety'] = 0.0 - + metrics["safety"] = 0.0 + try: - metrics['coverage'] = self.coverage_analyzer.calculate(generated, audience, original=original) + metrics["coverage"] = self.coverage_analyzer.calculate(generated, audience, original=original) except Exception as e: logger.error(f"Coverage analysis failed: {e}") - metrics['coverage'] = 0.0 - + metrics["coverage"] = 0.0 + try: - metrics['quality'] = self.llm_judge.calculate(generated, audience, original=original) + metrics["quality"] = self.llm_judge.calculate(generated, audience, original=original) except Exception as e: logger.error(f"LLM judge failed: {e}") - metrics['quality'] = 0.6 # Default reasonable score - + metrics["quality"] = 0.6 # Default reasonable score + # New safety and factual consistency metrics try: - metrics['contradiction'] = self.contradiction_detector.calculate(generated, audience, original=original) + metrics["contradiction"] = self.contradiction_detector.calculate(generated, audience, original=original) except Exception as e: logger.error(f"Contradiction detection failed: {e}") - metrics['contradiction'] = 0.7 # Default reasonable score - + metrics["contradiction"] = 0.7 # Default reasonable score + try: - metrics['information_preservation'] = self.information_preservation.calculate(generated, audience, original=original) + metrics["information_preservation"] = self.information_preservation.calculate( + generated, audience, original=original + ) except Exception as e: logger.error(f"Information preservation check failed: {e}") - metrics['information_preservation'] = 0.7 # Default reasonable score - + metrics["information_preservation"] = 0.7 # Default reasonable score + try: - metrics['hallucination'] = self.hallucination_detector.calculate(generated, audience, original=original) + metrics["hallucination"] = self.hallucination_detector.calculate(generated, audience, original=original) except Exception as e: logger.error(f"Hallucination detection failed: {e}") - metrics['hallucination'] = 0.7 # Default reasonable score - + metrics["hallucination"] = 0.7 # Default reasonable score + # Calculate weighted overall score overall = sum(metrics[metric] * self.weights[metric] for metric in metrics.keys()) - + # Apply safety multiplier if safety score is very low - if metrics['safety'] < 0.3: - overall *= self.scoring_config['parameters']['safety_multiplier'] + if metrics["safety"] < 0.3: + overall *= self.scoring_config["parameters"]["safety_multiplier"] overall = min(1.0, overall) # Cap at 1.0 - details['safety_penalty_applied'] = True - + details["safety_penalty_applied"] = True + evaluation_time = time.time() - start_time - details['evaluation_time'] = evaluation_time - details['weights_used'] = self.weights - + details["evaluation_time"] = evaluation_time + details["weights_used"] = dict(self.weights) + logger.info(f"Evaluation completed for {audience} in {evaluation_time:.2f}s") - logger.debug(f"Scores - R:{metrics['readability']:.3f} T:{metrics['terminology']:.3f} " - f"S:{metrics['safety']:.3f} C:{metrics['coverage']:.3f} Q:{metrics['quality']:.3f} " - f"CD:{metrics['contradiction']:.3f} IP:{metrics['information_preservation']:.3f} " - f"H:{metrics['hallucination']:.3f} Overall:{overall:.3f}") - + logger.debug( + f"Scores - R:{metrics['readability']:.3f} T:{metrics['terminology']:.3f} " + f"S:{metrics['safety']:.3f} C:{metrics['coverage']:.3f} Q:{metrics['quality']:.3f} " + f"CD:{metrics['contradiction']:.3f} IP:{metrics['information_preservation']:.3f} " + f"H:{metrics['hallucination']:.3f} Overall:{overall:.3f}" + ) + return EvaluationScore( - readability=metrics['readability'], - terminology=metrics['terminology'], - safety=metrics['safety'], - coverage=metrics['coverage'], - quality=metrics['quality'], - contradiction=metrics['contradiction'], - information_preservation=metrics['information_preservation'], - hallucination=metrics['hallucination'], + readability=metrics["readability"], + terminology=metrics["terminology"], + safety=metrics["safety"], + coverage=metrics["coverage"], + quality=metrics["quality"], + contradiction=metrics["contradiction"], + information_preservation=metrics["information_preservation"], + hallucination=metrics["hallucination"], overall=overall, - details=details + details=details, ) - + except Exception as e: logger.error(f"Evaluation failed for {audience}: {e}") raise EvaluationError(f"Evaluation failed: {e}") - + def evaluate_all_audiences(self, original: str, explanations: Dict[str, str]) -> Dict[str, EvaluationScore]: """ Evaluate explanations for all audiences - + Args: original: Original medical information explanations: Dictionary mapping audience to explanation - + Returns: Dictionary mapping audience to EvaluationScore """ results = {} supported_audiences = config.get_audiences() - + logger.info(f"Starting evaluation for {len(explanations)} audiences") - + for audience, explanation in explanations.items(): if audience not in supported_audiences: logger.warning(f"Skipping unsupported audience: {audience}") continue - + try: results[audience] = self.evaluate_explanation(original, explanation, audience) except EvaluationError as e: logger.error(f"Failed to evaluate {audience}: {e}") # Continue with other audiences continue - + logger.info(f"Completed evaluation for {len(results)} audiences") - return results \ No newline at end of file + return results diff --git a/src/leaderboard.py b/src/leaderboard.py index eabf814..7356bf2 100644 --- a/src/leaderboard.py +++ b/src/leaderboard.py @@ -29,234 +29,222 @@ # Set up logging logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('meq_bench.leaderboard') +logger = logging.getLogger("meq_bench.leaderboard") class LeaderboardGenerator: """Generate static HTML leaderboard from evaluation results""" - + def __init__(self): self.results_data: List[Dict[str, Any]] = [] self.benchmark_stats: Dict[str, Any] = {} - + def load_results(self, results_dir: Path) -> None: """Load all result files from a directory - + Args: results_dir: Directory containing JSON result files """ if not results_dir.exists(): raise FileNotFoundError(f"Results directory not found: {results_dir}") - + result_files = list(results_dir.glob("*.json")) if not result_files: raise ValueError(f"No JSON result files found in {results_dir}") - + logger.info(f"Found {len(result_files)} result files") - + for result_file in result_files: try: - with open(result_file, 'r', encoding='utf-8') as f: + with open(result_file, "r", encoding="utf-8") as f: data = json.load(f) - + # Validate required fields - required_fields = ['model_name', 'total_items', 'audience_scores', 'summary'] + required_fields = ["model_name", "total_items", "audience_scores", "summary"] if all(field in data for field in required_fields): self.results_data.append(data) logger.debug(f"Loaded results for {data['model_name']}") else: logger.warning(f"Invalid result file format: {result_file}") - + except Exception as e: logger.error(f"Error loading {result_file}: {e}") - + if not self.results_data: raise ValueError("No valid result files were loaded") - + logger.info(f"Successfully loaded {len(self.results_data)} evaluation results") - + def calculate_leaderboard_stats(self) -> Dict[str, Any]: """Calculate overall leaderboard statistics - + Returns: Dictionary containing leaderboard statistics """ if not self.results_data: return {} - + # Calculate aggregate statistics total_models = len(self.results_data) - total_evaluations = sum(result['total_items'] for result in self.results_data) - + total_evaluations = sum(result["total_items"] for result in self.results_data) + # Audience coverage all_audiences = set() for result in self.results_data: - all_audiences.update(result['audience_scores'].keys()) - + all_audiences.update(result["audience_scores"].keys()) + # Complexity coverage all_complexities = set() for result in self.results_data: - if 'complexity_scores' in result: - all_complexities.update(result['complexity_scores'].keys()) - + if "complexity_scores" in result: + all_complexities.update(result["complexity_scores"].keys()) + # Performance ranges - overall_scores = [result['summary'].get('overall_mean', 0) for result in self.results_data] + overall_scores = [result["summary"].get("overall_mean", 0) for result in self.results_data] if overall_scores: best_score = max(overall_scores) worst_score = min(overall_scores) avg_score = sum(overall_scores) / len(overall_scores) else: best_score = worst_score = avg_score = 0 - + return { - 'total_models': total_models, - 'total_evaluations': total_evaluations, - 'audiences': sorted(list(all_audiences)), - 'complexity_levels': sorted(list(all_complexities)), - 'best_score': best_score, - 'worst_score': worst_score, - 'average_score': avg_score, - 'last_updated': datetime.now().isoformat() + "total_models": total_models, + "total_evaluations": total_evaluations, + "audiences": sorted(list(all_audiences)), + "complexity_levels": sorted(list(all_complexities)), + "best_score": best_score, + "worst_score": worst_score, + "average_score": avg_score, + "last_updated": datetime.now().isoformat(), } - + def rank_models(self) -> List[Dict[str, Any]]: """Rank models by overall performance - + Returns: List of model results sorted by overall performance """ - ranked_models = sorted( - self.results_data, - key=lambda x: x['summary'].get('overall_mean', 0), - reverse=True - ) - + ranked_models = sorted(self.results_data, key=lambda x: x["summary"].get("overall_mean", 0), reverse=True) + # Add ranking information for i, model in enumerate(ranked_models): - model['rank'] = i + 1 - model['overall_score'] = model['summary'].get('overall_mean', 0) - + model["rank"] = i + 1 + model["overall_score"] = model["summary"].get("overall_mean", 0) + return ranked_models - + def generate_audience_breakdown(self, ranked_models: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: """Generate audience-specific performance breakdown - + Args: ranked_models: List of ranked model results - + Returns: Dictionary mapping audience to ranked model performance """ audience_breakdown = {} - + # Get all audiences all_audiences = set() for model in ranked_models: - all_audiences.update(model['audience_scores'].keys()) - + all_audiences.update(model["audience_scores"].keys()) + for audience in sorted(all_audiences): audience_models = [] - + for model in ranked_models: - if audience in model['audience_scores']: - scores = model['audience_scores'][audience] + if audience in model["audience_scores"]: + scores = model["audience_scores"][audience] avg_score = sum(scores) / len(scores) if scores else 0 - - audience_models.append({ - 'model_name': model['model_name'], - 'score': avg_score, - 'num_items': len(scores) - }) - + + audience_models.append({"model_name": model["model_name"], "score": avg_score, "num_items": len(scores)}) + # Sort by score for this audience - audience_models.sort(key=lambda x: x['score'], reverse=True) - + audience_models.sort(key=lambda x: x["score"], reverse=True) + # Add rankings for i, model in enumerate(audience_models): - model['rank'] = i + 1 - + model["rank"] = i + 1 + audience_breakdown[audience] = audience_models - + return audience_breakdown - + def generate_complexity_breakdown(self, ranked_models: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: """Generate complexity-specific performance breakdown - + Args: ranked_models: List of ranked model results - + Returns: Dictionary mapping complexity level to ranked model performance """ complexity_breakdown = {} - + # Get all complexity levels all_complexities = set() for model in ranked_models: - if 'complexity_scores' in model: - all_complexities.update(model['complexity_scores'].keys()) - + if "complexity_scores" in model: + all_complexities.update(model["complexity_scores"].keys()) + for complexity in sorted(all_complexities): complexity_models = [] - + for model in ranked_models: - if 'complexity_scores' in model and complexity in model['complexity_scores']: - scores = model['complexity_scores'][complexity] + if "complexity_scores" in model and complexity in model["complexity_scores"]: + scores = model["complexity_scores"][complexity] avg_score = sum(scores) / len(scores) if scores else 0 - - complexity_models.append({ - 'model_name': model['model_name'], - 'score': avg_score, - 'num_items': len(scores) - }) - + + complexity_models.append({"model_name": model["model_name"], "score": avg_score, "num_items": len(scores)}) + # Sort by score for this complexity level - complexity_models.sort(key=lambda x: x['score'], reverse=True) - + complexity_models.sort(key=lambda x: x["score"], reverse=True) + # Add rankings for i, model in enumerate(complexity_models): - model['rank'] = i + 1 - + model["rank"] = i + 1 + complexity_breakdown[complexity] = complexity_models - + return complexity_breakdown - + def generate_html(self, output_path: Path) -> None: """Generate static HTML leaderboard - + Args: output_path: Path where to save the HTML file """ if not self.results_data: raise ValueError("No results data loaded") - + # Calculate statistics and rankings stats = self.calculate_leaderboard_stats() ranked_models = self.rank_models() audience_breakdown = self.generate_audience_breakdown(ranked_models) complexity_breakdown = self.generate_complexity_breakdown(ranked_models) - + # Generate HTML content - html_content = self._generate_html_template( - stats, ranked_models, audience_breakdown, complexity_breakdown - ) - + html_content = self._generate_html_template(stats, ranked_models, audience_breakdown, complexity_breakdown) + # Ensure output directory exists output_path.parent.mkdir(parents=True, exist_ok=True) - + # Write HTML file - with open(output_path, 'w', encoding='utf-8') as f: + with open(output_path, "w", encoding="utf-8") as f: f.write(html_content) - + logger.info(f"Generated leaderboard HTML: {output_path}") - - def _generate_html_template(self, - stats: Dict[str, Any], - ranked_models: List[Dict[str, Any]], - audience_breakdown: Dict[str, List[Dict[str, Any]]], - complexity_breakdown: Dict[str, List[Dict[str, Any]]]) -> str: + + def _generate_html_template( + self, + stats: Dict[str, Any], + ranked_models: List[Dict[str, Any]], + audience_breakdown: Dict[str, List[Dict[str, Any]]], + complexity_breakdown: Dict[str, List[Dict[str, Any]]], + ) -> str: """Generate the complete HTML template""" - + return f""" @@ -347,7 +335,7 @@ def _generate_html_template(self, """ - + def _get_css_styles(self) -> str: """Return CSS styles for the leaderboard""" return """ @@ -611,7 +599,7 @@ def _get_css_styles(self) -> str: } } """ - + def _generate_overall_rankings_table(self, ranked_models: List[Dict[str, Any]]) -> str: """Generate the overall rankings table HTML""" table_html = """ @@ -630,22 +618,22 @@ def _generate_overall_rankings_table(self, ranked_models: List[Dict[str, Any]]) """ - + for model in ranked_models: rank_class = "" - if model['rank'] == 1: + if model["rank"] == 1: rank_class = "rank-1" - elif model['rank'] == 2: + elif model["rank"] == 2: rank_class = "rank-2" - elif model['rank'] == 3: + elif model["rank"] == 3: rank_class = "rank-3" - + # Calculate audience averages audience_scores = {} - for audience, scores in model['audience_scores'].items(): + for audience, scores in model["audience_scores"].items(): avg_score = sum(scores) / len(scores) if scores else 0 audience_scores[audience] = avg_score - + table_html += f""" #{model['rank']} @@ -658,18 +646,18 @@ def _generate_overall_rankings_table(self, ranked_models: List[Dict[str, Any]]) {audience_scores.get('caregiver', 0):.3f} """ - + table_html += """ """ - + return table_html - + def _generate_audience_breakdown_section(self, audience_breakdown: Dict[str, List[Dict[str, Any]]]) -> str: """Generate the audience breakdown section HTML""" html = "" - + for audience, models in audience_breakdown.items(): html += f"""
@@ -685,16 +673,16 @@ def _generate_audience_breakdown_section(self, audience_breakdown: Dict[str, Lis """ - + for model in models[:10]: # Show top 10 rank_class = "" - if model['rank'] == 1: + if model["rank"] == 1: rank_class = "rank-1" - elif model['rank'] == 2: + elif model["rank"] == 2: rank_class = "rank-2" - elif model['rank'] == 3: + elif model["rank"] == 3: rank_class = "rank-3" - + html += f""" #{model['rank']} @@ -703,19 +691,19 @@ def _generate_audience_breakdown_section(self, audience_breakdown: Dict[str, Lis {model['num_items']} """ - + html += """
""" - + return html - + def _generate_complexity_breakdown_section(self, complexity_breakdown: Dict[str, List[Dict[str, Any]]]) -> str: """Generate the complexity breakdown section HTML""" html = "" - + for complexity, models in complexity_breakdown.items(): html += f"""
@@ -731,16 +719,16 @@ def _generate_complexity_breakdown_section(self, complexity_breakdown: Dict[str, """ - + for model in models[:10]: # Show top 10 rank_class = "" - if model['rank'] == 1: + if model["rank"] == 1: rank_class = "rank-1" - elif model['rank'] == 2: + elif model["rank"] == 2: rank_class = "rank-2" - elif model['rank'] == 3: + elif model["rank"] == 3: rank_class = "rank-3" - + html += f""" #{model['rank']} @@ -749,34 +737,33 @@ def _generate_complexity_breakdown_section(self, complexity_breakdown: Dict[str, {model['num_items']} """ - + html += """
""" - + return html - - def _generate_javascript(self, - ranked_models: List[Dict[str, Any]], - audience_breakdown: Dict[str, List[Dict[str, Any]]], - stats: Dict[str, Any]) -> str: + + def _generate_javascript( + self, ranked_models: List[Dict[str, Any]], audience_breakdown: Dict[str, List[Dict[str, Any]]], stats: Dict[str, Any] + ) -> str: """Generate JavaScript for interactive features""" - + # Prepare data for charts - model_names = [model['model_name'][:20] for model in ranked_models[:8]] # Top 8 models - model_scores = [model['overall_score'] for model in ranked_models[:8]] - + model_names = [model["model_name"][:20] for model in ranked_models[:8]] # Top 8 models + model_scores = [model["overall_score"] for model in ranked_models[:8]] + audience_labels = list(audience_breakdown.keys()) audience_data = [] for audience in audience_labels: if audience_breakdown[audience]: - avg_score = sum(model['score'] for model in audience_breakdown[audience]) / len(audience_breakdown[audience]) + avg_score = sum(model["score"] for model in audience_breakdown[audience]) / len(audience_breakdown[audience]) audience_data.append(avg_score) else: audience_data.append(0) - + return f""" function showTab(tabName) {{ // Hide all tab contents @@ -872,36 +859,23 @@ def setup_argument_parser() -> argparse.ArgumentParser: # Generate with custom title python -m src.leaderboard --input results/ --output leaderboard.html --title "Custom MEQ-Bench Results" - """ - ) - - parser.add_argument( - '--input', '-i', - type=str, - required=True, - help='Directory containing JSON evaluation result files' + """, ) - - parser.add_argument( - '--output', '-o', - type=str, - default='docs/index.html', - help='Output path for the HTML leaderboard (default: docs/index.html)' - ) - + + parser.add_argument("--input", "-i", type=str, required=True, help="Directory containing JSON evaluation result files") + parser.add_argument( - '--title', + "--output", + "-o", type=str, - default='MEQ-Bench Leaderboard', - help='Custom title for the leaderboard page' - ) - - parser.add_argument( - '--verbose', '-v', - action='store_true', - help='Enable verbose logging' + default="docs/index.html", + help="Output path for the HTML leaderboard (default: docs/index.html)", ) - + + parser.add_argument("--title", type=str, default="MEQ-Bench Leaderboard", help="Custom title for the leaderboard page") + + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging") + return parser @@ -909,30 +883,30 @@ def main(): """Main function for command-line usage""" parser = setup_argument_parser() args = parser.parse_args() - + # Set logging level if args.verbose: logging.getLogger().setLevel(logging.DEBUG) - + try: # Initialize leaderboard generator generator = LeaderboardGenerator() - + # Load results results_dir = Path(args.input) generator.load_results(results_dir) - + # Generate HTML leaderboard output_path = Path(args.output) generator.generate_html(output_path) - + logger.info(f"✅ Leaderboard generated successfully: {output_path}") logger.info(f"📊 Processed {len(generator.results_data)} model results") - + except Exception as e: logger.error(f"❌ Error generating leaderboard: {e}") sys.exit(1) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/prompt_templates.py b/src/prompt_templates.py index eddfe65..c4fc062 100644 --- a/src/prompt_templates.py +++ b/src/prompt_templates.py @@ -33,28 +33,28 @@ from .config import config -logger = logging.getLogger('meq_bench.prompts') +logger = logging.getLogger("meq_bench.prompts") class AudienceAdaptivePrompt: """Standardized prompt template for generating audience-adaptive medical explanations. - + This class provides a consistent template for instructing language models to generate medical explanations adapted for four distinct healthcare audiences: physicians, nurses, patients, and caregivers. Each explanation is tailored to the specific needs, knowledge level, and communication preferences of the target audience. - + The template ensures that explanations are: - Appropriately technical for the audience - Focused on relevant aspects for each role - Formatted consistently for easy parsing - + Attributes: base_template: The standardized prompt template string that instructs the model to generate audience-specific explanations. """ - + base_template: str = """Medical Information: {medical_content} Transform this explanation for four distinct audiences. Ensure each explanation is self-contained and clearly labeled: @@ -72,18 +72,18 @@ class AudienceAdaptivePrompt: @classmethod def format_prompt(cls, medical_content: str) -> str: """Format the prompt template with medical content. - + Takes medical information and inserts it into the standardized template to create a complete prompt for the language model. - + Args: medical_content: The medical information to be adapted for different audiences. This should be the original medical content that needs to be explained. - + Returns: Formatted prompt string ready to be sent to a language model. - + Example: ```python medical_info = "Hypertension is high blood pressure that can damage organs." @@ -91,28 +91,28 @@ def format_prompt(cls, medical_content: str) -> str: ``` """ return cls.base_template.format(medical_content=medical_content) - + @staticmethod def parse_response(response: str) -> Dict[str, str]: """Parse the model response to extract audience-specific explanations. - + Extracts individual explanations for each audience from the model's - response using robust regex patterns. The parsing logic looks for + response using robust regex patterns. The parsing logic looks for audience indicators in the text and separates the explanations accordingly. - + Uses multiple pattern variations to handle different response formats and includes fallback parsing for improved robustness. - + Args: response: Model response containing explanations for all audiences. Expected to contain sections labeled with audience names (physician, nurse, patient, caregiver). - + Returns: Dictionary mapping audience names to their respective explanations. Keys are audience names (e.g., 'physician', 'nurse', 'patient', 'caregiver'). Values are the explanation text for each audience. - + Example: ```python model_response = \"\"\"For a Physician: Technical explanation... @@ -122,185 +122,184 @@ def parse_response(response: str) -> Dict[str, str]: ``` """ explanations: Dict[str, str] = {} - + # Get supported audiences from configuration try: audiences = config.get_audiences() except Exception: # Fallback to default audiences if config fails audiences = ["physician", "nurse", "patient", "caregiver"] - + logger.debug(f"Parsing response for audiences: {audiences}") - + # Define robust regex patterns for each audience # These patterns are case-insensitive and handle various formatting variations patterns = {} - + for audience in audiences: # Create multiple pattern variations for robustness # Build audience alternation pattern for lookahead - audience_alternation = '|'.join([re.escape(aud) for aud in audiences]) - + audience_alternation = "|".join([re.escape(aud) for aud in audiences]) + audience_patterns = [ # Standard format: "For a Physician:" or "For the Physician:" rf"(?:for\s+(?:a|the)\s+{re.escape(audience)}\s*:)(.*?)(?=for\s+(?:a|the)\s+(?:{audience_alternation})\s*:|$)", - # Alternative format: "Physician:" or "PHYSICIAN:" rf"(?:^|\n)\s*{re.escape(audience)}\s*:(.*?)(?=(?:^|\n)\s*(?:{audience_alternation})\s*:|$)", - # Numbered format: "1. Physician:" or "- Physician:" rf"(?:^|\n)\s*(?:\d+\.|\-|\*)\s*{re.escape(audience)}\s*:(.*?)(?=(?:^|\n)\s*(?:\d+\.|\-|\*)\s*(?:{audience_alternation})\s*:|$)", - # Section header format: "## Physician" or "### For Physician" rf"(?:^|\n)\s*#{1,4}\s*(?:for\s+)?{re.escape(audience)}\s*#{0,4}\s*\n?(.*?)(?=(?:^|\n)\s*#{1,4}\s*(?:for\s+)?(?:{audience_alternation})\s*#{0,4}|$)", - - # Bold/emphasis format: "**Physician:**" or "*For Physician:*" - rf"(?:\*{{1,2}}|_{{1,2}})\s*(?:for\s+)?{re.escape(audience)}\s*:?\s*(?:\*{{1,2}}|_{{1,2}})(.*?)(?=(?:\*{{1,2}}|_{{1,2}})\s*(?:for\s+)?(?:{audience_alternation})\s*:?\s*(?:\*{{1,2}}|_{{1,2}})|$)" + # Bold/emphasis format: "**Physician:**" or "*For Physician:*" + rf"(?:\*{{1,2}}|_{{1,2}})\s*(?:for\s+)?{re.escape(audience)}\s*:?\s*(?:\*{{1,2}}|_{{1,2}})(.*?)(?=(?:\*{{1,2}}|_{{1,2}})\s*(?:for\s+)?(?:{audience_alternation})\s*:?\s*(?:\*{{1,2}}|_{{1,2}})|$)", ] - + patterns[audience] = audience_patterns - + # Try each pattern for each audience and use the first match found for audience in audiences: explanation = None - + for pattern in patterns[audience]: try: # Use DOTALL flag to match across newlines and IGNORECASE for case-insensitive matching matches = re.finditer(pattern, response, re.DOTALL | re.IGNORECASE | re.MULTILINE) - + for match in matches: content = match.group(1).strip() if content and len(content) > 10: # Ensure we have substantial content explanation = content logger.debug(f"Found {audience} explanation using pattern: {pattern[:50]}...") break - + if explanation: break - + except re.error as e: logger.warning(f"Regex error for {audience} with pattern {pattern}: {e}") continue - + if explanation: # Clean up the extracted explanation explanation = AudienceAdaptivePrompt._clean_explanation(explanation) explanations[audience] = explanation else: logger.warning(f"Could not extract explanation for {audience}") - + # Fallback: If we have fewer than expected explanations, try to extract missing ones if len(explanations) < len(audiences): logger.info(f"Only found {len(explanations)}/{len(audiences)} explanations, trying fallback parsing") fallback_explanations = AudienceAdaptivePrompt._fallback_parse(response, audiences) - + # Add any missing explanations from fallback for audience in audiences: if audience not in explanations and audience in fallback_explanations: explanations[audience] = fallback_explanations[audience] logger.info(f"Added {audience} explanation from fallback parsing") - + # Validate that we have explanations for all audiences missing_audiences = [aud for aud in audiences if aud not in explanations or not explanations[aud].strip()] if missing_audiences: logger.warning(f"Missing explanations for audiences: {missing_audiences}") - + logger.info(f"Successfully parsed explanations for {len(explanations)} audiences") return explanations - + @staticmethod def _clean_explanation(text: str) -> str: """ Clean and normalize extracted explanation text - + Args: text: Raw extracted text - + Returns: Cleaned explanation text """ # Remove leading/trailing whitespace text = text.strip() - + # Remove common prefixes that might be captured prefixes_to_remove = [ - r'^\s*:\s*', # Leading colon - r'^\s*\-\s*', # Leading dash - r'^\s*\*\s*', # Leading asterisk - r'^\s*\d+\.\s*', # Leading number + r"^\s*:\s*", # Leading colon + r"^\s*\-\s*", # Leading dash + r"^\s*\*\s*", # Leading asterisk + r"^\s*\d+\.\s*", # Leading number ] - + for prefix_pattern in prefixes_to_remove: - text = re.sub(prefix_pattern, '', text, flags=re.MULTILINE) - + text = re.sub(prefix_pattern, "", text, flags=re.MULTILINE) + # Normalize whitespace - convert multiple spaces/newlines to single - text = re.sub(r'\n\s*\n\s*\n', '\n\n', text) # Max 2 consecutive newlines - text = re.sub(r'[ \t]+', ' ', text) # Multiple spaces to single space - + text = re.sub(r"\n\s*\n\s*\n", "\n\n", text) # Max 2 consecutive newlines + text = re.sub(r"[ \t]+", " ", text) # Multiple spaces to single space + # Remove markdown formatting artifacts - text = re.sub(r'\*{1,2}([^*]+)\*{1,2}', r'\1', text) # Remove bold/italic - text = re.sub(r'_{1,2}([^_]+)_{1,2}', r'\1', text) # Remove underscore emphasis - text = re.sub(r'`([^`]+)`', r'\1', text) # Remove code formatting - + text = re.sub(r"\*{1,2}([^*]+)\*{1,2}", r"\1", text) # Remove bold/italic + text = re.sub(r"_{1,2}([^_]+)_{1,2}", r"\1", text) # Remove underscore emphasis + text = re.sub(r"`([^`]+)`", r"\1", text) # Remove code formatting + return text.strip() - + @staticmethod def _fallback_parse(response: str, audiences: List[str]) -> Dict[str, str]: """ Fallback parsing method using simple keyword search - + Args: response: Model response text audiences: List of audience names - + Returns: Dictionary with extracted explanations """ explanations: Dict[str, str] = {} - + # Split response into sections based on audience keywords - lines = response.split('\n') + lines = response.split("\n") current_audience: Optional[str] = None current_text: List[str] = [] - + for line in lines: line_lower = line.lower().strip() - + # Check if line contains an audience keyword matched_audience = None for audience in audiences: if audience.lower() in line_lower: # Additional check to ensure it's likely a section header - if any(indicator in line_lower for indicator in ['for', ':', '#']) or line_lower.strip() == audience.lower(): + if ( + any(indicator in line_lower for indicator in ["for", ":", "#"]) + or line_lower.strip() == audience.lower() + ): matched_audience = audience break - + if matched_audience: # Save previous section if current_audience and current_text: - content = '\n'.join(current_text).strip() + content = "\n".join(current_text).strip() if content: explanations[current_audience] = content - + # Start new section current_audience = matched_audience current_text = [] - + # Check if the explanation starts on the same line after colon - colon_pos = line.find(':') + colon_pos = line.find(":") if colon_pos != -1 and colon_pos < len(line) - 1: - remaining_text = line[colon_pos + 1:].strip() + remaining_text = line[colon_pos + 1 :].strip() if remaining_text: current_text.append(remaining_text) - + elif current_audience and line.strip(): current_text.append(line.strip()) - + # Add the last section if current_audience and current_text: - content = '\n'.join(current_text).strip() + content = "\n".join(current_text).strip() if content: explanations[current_audience] = content - - return explanations \ No newline at end of file + + return explanations diff --git a/src/strategies.py b/src/strategies.py index b6515db..5303f14 100644 --- a/src/strategies.py +++ b/src/strategies.py @@ -35,122 +35,122 @@ class AudienceStrategy(ABC): """Abstract base class for audience-specific scoring strategies. - + This abstract class defines the interface for audience-specific scoring strategies used in MEQ-Bench evaluation. Each concrete strategy implements scoring logic tailored to the expectations and needs of a specific healthcare audience. - + Attributes: audience: The target audience name (e.g., 'physician', 'nurse'). eval_config: Evaluation configuration loaded from the config system. - + Abstract Methods: calculate_readability_score: Calculate readability score for the audience. calculate_terminology_score: Calculate terminology appropriateness score. get_expected_explanation_length: Get expected explanation length range. """ - + def __init__(self, audience: str) -> None: """Initialize audience strategy. - + Args: audience: Target audience name (e.g., 'physician', 'nurse', 'patient', 'caregiver'). """ self.audience: str = audience self.eval_config: Dict[str, Any] = config.get_evaluation_config() - + @abstractmethod def calculate_readability_score(self, text: str, grade_level: float) -> float: """Calculate readability score for the audience. - + Args: text: Text to evaluate for readability. grade_level: Computed grade level (e.g., Flesch-Kincaid score). - + Returns: Readability score between 0.0 and 1.0, where 1.0 indicates optimal readability for the target audience. """ pass - + @abstractmethod def calculate_terminology_score(self, text: str, term_density: float) -> float: """Calculate terminology appropriateness score for the audience. - + Args: text: Text to evaluate for terminology usage. term_density: Ratio of medical terms to total words in the text. - + Returns: Terminology score between 0.0 and 1.0, where 1.0 indicates optimal terminology usage for the target audience. """ pass - + @abstractmethod def get_expected_explanation_length(self) -> Dict[str, int]: """Get expected explanation length range for the audience. - + Returns: Dictionary with 'min' and 'max' keys indicating the expected word count range for explanations targeting this audience. """ pass - + def get_readability_targets(self) -> Dict[str, float]: """Get readability targets for this audience. - + Returns: Dictionary containing readability targets including minimum and maximum grade levels appropriate for this audience. """ - return self.eval_config['readability_targets'][self.audience] - + return self.eval_config["readability_targets"][self.audience] + def get_terminology_targets(self) -> Dict[str, float]: """Get terminology density targets for this audience. - + Returns: Dictionary containing terminology density targets including target density and acceptable tolerance range. """ - return self.eval_config['terminology_density'][self.audience] + return self.eval_config["terminology_density"][self.audience] class PhysicianStrategy(AudienceStrategy): """Strategy for physician audience scoring. - + Physicians expect technical, evidence-based explanations with precise medical terminology. This strategy scores explanations based on graduate-level complexity (12-16 grade level) and high medical terminology density. - + Scoring characteristics: - Readability: Favors higher complexity (graduate level) - Terminology: Expects high medical term density - Length: Accommodates longer, detailed explanations """ - + def __init__(self) -> None: - super().__init__('physician') - + super().__init__("physician") + def calculate_readability_score(self, text: str, grade_level: float) -> float: """Calculate readability score for physician audience. - + Physicians expect graduate-level complexity (12-16 grade level). Higher complexity is generally better for physicians, but extremely high complexity (>16) may still be penalized. - + Args: text: Text to evaluate for readability. grade_level: Computed grade level (e.g., Flesch-Kincaid score). - + Returns: Readability score between 0.0 and 1.0, with 1.0 for optimal complexity. """ targets = self.get_readability_targets() - min_level = targets['min_grade_level'] - max_level = targets['max_grade_level'] - + min_level = targets["min_grade_level"] + max_level = targets["max_grade_level"] + if grade_level < min_level: # Too simple for physicians return max(0.0, grade_level / min_level) @@ -160,24 +160,24 @@ def calculate_readability_score(self, text: str, grade_level: float) -> float: else: # In the sweet spot return 1.0 - + def calculate_terminology_score(self, text: str, term_density: float) -> float: """Calculate terminology score for physician audience. - + Physicians expect high medical terminology density as they are comfortable with technical medical language and precise terminology. - + Args: text: Text to evaluate for terminology usage. term_density: Ratio of medical terms to total words in the text. - + Returns: Terminology score between 0.0 and 1.0, with 1.0 for optimal density. """ targets = self.get_terminology_targets() - target = targets['target'] - tolerance = targets['tolerance'] - + target = targets["target"] + tolerance = targets["tolerance"] + if abs(term_density - target) <= tolerance: return 1.0 elif term_density < target: @@ -187,136 +187,136 @@ def calculate_terminology_score(self, text: str, term_density: float) -> float: # Too much terminology (even for physicians) excess = term_density - target - tolerance return max(0.0, 1.0 - excess * 2.0) - + def get_expected_explanation_length(self) -> Dict[str, int]: """Get expected explanation length for physician audience. - + Physicians can handle longer, detailed explanations with comprehensive medical information and technical details. - + Returns: Dictionary with 'min' and 'max' keys for expected word count range. """ scoring_config = config.get_scoring_config() - max_length = scoring_config['parameters']['max_explanation_length']['physician'] - min_length = scoring_config['parameters']['min_explanation_length'] - - return {'min': min_length, 'max': max_length} + max_length = scoring_config["parameters"]["max_explanation_length"]["physician"] + min_length = scoring_config["parameters"]["min_explanation_length"] + + return {"min": min_length, "max": max_length} class NurseStrategy(AudienceStrategy): """Strategy for nurse audience scoring. - + Nurses expect moderate complexity explanations that balance technical accuracy with practical application. They need information that supports patient care and education responsibilities. - + Scoring characteristics: - Readability: Moderate complexity (10-14 grade level) - Terminology: Balanced medical terminology with practical language - Length: Practical, actionable explanations """ - + def __init__(self) -> None: - super().__init__('nurse') - + super().__init__("nurse") + def calculate_readability_score(self, text: str, grade_level: float) -> float: """Calculate readability score for nurse audience. - + Nurses expect moderate complexity (10-14 grade level) that balances technical accuracy with practical application in patient care settings. - + Args: text: Text to evaluate for readability. grade_level: Computed grade level (e.g., Flesch-Kincaid score). - + Returns: Readability score between 0.0 and 1.0, with 1.0 for optimal complexity. """ targets = self.get_readability_targets() - min_level = targets['min_grade_level'] - max_level = targets['max_grade_level'] - + min_level = targets["min_grade_level"] + max_level = targets["max_grade_level"] + if min_level <= grade_level <= max_level: return 1.0 elif grade_level < min_level: return max(0.0, grade_level / min_level) else: return max(0.0, 1.0 - (grade_level - max_level) / 6.0) - + def calculate_terminology_score(self, text: str, term_density: float) -> float: """Calculate terminology score for nurse audience. - + Nurses expect moderate medical terminology that includes technical terms but also incorporates practical language for patient care contexts. - + Args: text: Text to evaluate for terminology usage. term_density: Ratio of medical terms to total words in the text. - + Returns: Terminology score between 0.0 and 1.0, with 1.0 for optimal density. """ targets = self.get_terminology_targets() - target = targets['target'] - tolerance = targets['tolerance'] - + target = targets["target"] + tolerance = targets["tolerance"] + if abs(term_density - target) <= tolerance: return 1.0 else: deviation = abs(term_density - target) - tolerance return max(0.0, 1.0 - deviation * 3.0) - + def get_expected_explanation_length(self) -> Dict[str, int]: """Get expected explanation length for nurse audience. - + Nurses need practical, actionable explanations that provide clear guidance for patient care and education. - + Returns: Dictionary with 'min' and 'max' keys for expected word count range. """ scoring_config = config.get_scoring_config() - max_length = scoring_config['parameters']['max_explanation_length']['nurse'] - min_length = scoring_config['parameters']['min_explanation_length'] - - return {'min': min_length, 'max': max_length} + max_length = scoring_config["parameters"]["max_explanation_length"]["nurse"] + min_length = scoring_config["parameters"]["min_explanation_length"] + + return {"min": min_length, "max": max_length} class PatientStrategy(AudienceStrategy): """Strategy for patient audience scoring. - + Patients need simple, accessible explanations that avoid medical jargon and focus on understanding their condition and next steps. Explanations should be empathetic and reassuring. - + Scoring characteristics: - Readability: Simple language (6-10 grade level) - Terminology: Minimal medical terminology, jargon explained - Length: Concise, clear explanations """ - + def __init__(self) -> None: - super().__init__('patient') - + super().__init__("patient") + def calculate_readability_score(self, text: str, grade_level: float) -> float: """Calculate readability score for patient audience. - + Patients need simple, accessible language (6-10 grade level). Lower complexity is generally better for patients to ensure understanding and engagement. - + Args: text: Text to evaluate for readability. grade_level: Computed grade level (e.g., Flesch-Kincaid score). - + Returns: Readability score between 0.0 and 1.0, with 1.0 for optimal simplicity. """ targets = self.get_readability_targets() - min_level = targets['min_grade_level'] - max_level = targets['max_grade_level'] - + min_level = targets["min_grade_level"] + max_level = targets["max_grade_level"] + if min_level <= grade_level <= max_level: return 1.0 elif grade_level < min_level: @@ -325,162 +325,162 @@ def calculate_readability_score(self, text: str, grade_level: float) -> float: else: # Too complex for patients return max(0.0, 1.0 - (grade_level - max_level) / 4.0) - + def calculate_terminology_score(self, text: str, term_density: float) -> float: """Calculate terminology score for patient audience. - + Patients should have minimal medical terminology in their explanations. Medical jargon should be avoided or clearly explained in simple terms. - + Args: text: Text to evaluate for terminology usage. term_density: Ratio of medical terms to total words in the text. - + Returns: Terminology score between 0.0 and 1.0, with 1.0 for minimal jargon. """ targets = self.get_terminology_targets() - target = targets['target'] - tolerance = targets['tolerance'] - + target = targets["target"] + tolerance = targets["tolerance"] + if term_density <= target + tolerance: return 1.0 else: # Penalty for too much medical terminology excess = term_density - target - tolerance return max(0.0, 1.0 - excess * 10.0) - + def get_expected_explanation_length(self) -> Dict[str, int]: """Get expected explanation length for patient audience. - + Patients need concise, clear explanations that don't overwhelm with too much information at once. - + Returns: Dictionary with 'min' and 'max' keys for expected word count range. """ scoring_config = config.get_scoring_config() - max_length = scoring_config['parameters']['max_explanation_length']['patient'] - min_length = scoring_config['parameters']['min_explanation_length'] - - return {'min': min_length, 'max': max_length} + max_length = scoring_config["parameters"]["max_explanation_length"]["patient"] + min_length = scoring_config["parameters"]["min_explanation_length"] + + return {"min": min_length, "max": max_length} class CaregiverStrategy(AudienceStrategy): """Strategy for caregiver audience scoring. - + Caregivers need actionable, practical explanations that focus on observable symptoms, clear instructions, and when to seek help. They need simple language but with specific guidance. - + Scoring characteristics: - Readability: Clear, actionable language (6-10 grade level) - Terminology: Minimal medical terminology, focus on observable signs - Length: Practical, step-by-step guidance """ - + def __init__(self) -> None: - super().__init__('caregiver') - + super().__init__("caregiver") + def calculate_readability_score(self, text: str, grade_level: float) -> float: """Calculate readability score for caregiver audience. - + Caregivers need actionable, clear language (6-10 grade level) with a focus on practical instructions and observable guidance. - + Args: text: Text to evaluate for readability. grade_level: Computed grade level (e.g., Flesch-Kincaid score). - + Returns: Readability score between 0.0 and 1.0, with 1.0 for optimal clarity. """ targets = self.get_readability_targets() - min_level = targets['min_grade_level'] - max_level = targets['max_grade_level'] - + min_level = targets["min_grade_level"] + max_level = targets["max_grade_level"] + if min_level <= grade_level <= max_level: return 1.0 elif grade_level < min_level: return 0.9 # Simple is good for caregivers else: return max(0.0, 1.0 - (grade_level - max_level) / 4.0) - + def calculate_terminology_score(self, text: str, term_density: float) -> float: """Calculate terminology score for caregiver audience. - + Caregivers need minimal medical terminology with focus on observable symptoms and clear actions they can take or monitor. - + Args: text: Text to evaluate for terminology usage. term_density: Ratio of medical terms to total words in the text. - + Returns: Terminology score between 0.0 and 1.0, with 1.0 for optimal practical language. """ targets = self.get_terminology_targets() - target = targets['target'] - tolerance = targets['tolerance'] - + target = targets["target"] + tolerance = targets["tolerance"] + if term_density <= target + tolerance: return 1.0 else: excess = term_density - target - tolerance return max(0.0, 1.0 - excess * 8.0) - + def get_expected_explanation_length(self) -> Dict[str, int]: """Get expected explanation length for caregiver audience. - + Caregivers need practical, step-by-step guidance that provides clear instructions and actionable information. - + Returns: Dictionary with 'min' and 'max' keys for expected word count range. """ scoring_config = config.get_scoring_config() - max_length = scoring_config['parameters']['max_explanation_length']['caregiver'] - min_length = scoring_config['parameters']['min_explanation_length'] - - return {'min': min_length, 'max': max_length} + max_length = scoring_config["parameters"]["max_explanation_length"]["caregiver"] + min_length = scoring_config["parameters"]["min_explanation_length"] + + return {"min": min_length, "max": max_length} class StrategyFactory: """Factory for creating audience strategies. - + This factory class implements the Factory pattern to create appropriate audience strategy instances based on the target audience. It provides a centralized way to instantiate strategies and maintains a registry of supported audiences. - + Attributes: _strategies: Dictionary mapping audience names to strategy classes. """ - + _strategies: Dict[str, Type[AudienceStrategy]] = { - 'physician': PhysicianStrategy, - 'nurse': NurseStrategy, - 'patient': PatientStrategy, - 'caregiver': CaregiverStrategy + "physician": PhysicianStrategy, + "nurse": NurseStrategy, + "patient": PatientStrategy, + "caregiver": CaregiverStrategy, } - + @classmethod def create_strategy(cls, audience: str) -> AudienceStrategy: """Create strategy for given audience. - + Instantiates the appropriate strategy class for the specified audience. The strategy encapsulates audience-specific scoring logic and expectations. - + Args: audience: Target audience name (e.g., 'physician', 'nurse', 'patient', 'caregiver'). - + Returns: AudienceStrategy instance configured for the specified audience. - + Raises: ValueError: If the audience is not supported. Use get_supported_audiences() to see available options. - + Example: ```python strategy = StrategyFactory.create_strategy('patient') @@ -490,15 +490,15 @@ def create_strategy(cls, audience: str) -> AudienceStrategy: if audience not in cls._strategies: supported = list(cls._strategies.keys()) raise ValueError(f"Unsupported audience: {audience}. Supported: {supported}") - + strategy_class = cls._strategies[audience] - return strategy_class() # type: ignore[misc] - + return strategy_class(audience) # type: ignore[misc] + @classmethod def get_supported_audiences(cls) -> List[str]: """Get list of supported audiences. - + Returns: List of audience names that have corresponding strategy implementations. """ - return list(cls._strategies.keys()) \ No newline at end of file + return list(cls._strategies.keys()) diff --git a/tests/conftest.py b/tests/conftest.py index 958692c..4b400e8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,8 +13,8 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "src")) # Mock environment variables for testing -os.environ['OPENAI_API_KEY'] = 'test-key' -os.environ['ANTHROPIC_API_KEY'] = 'test-key' +os.environ["OPENAI_API_KEY"] = "test-key" +os.environ["ANTHROPIC_API_KEY"] = "test-key" from src.benchmark import MEQBench, MEQBenchItem from src.evaluator import MEQBenchEvaluator @@ -37,7 +37,7 @@ def sample_benchmark_item(): id="test_001", medical_content="Diabetes is a metabolic disorder characterized by high blood sugar levels.", complexity_level="basic", - source_dataset="test" + source_dataset="test", ) @@ -48,7 +48,7 @@ def sample_explanations(): "physician": "Essential hypertension with systolic BP >140 mmHg or diastolic >90 mmHg, requiring antihypertensive therapy and cardiovascular risk stratification.", "nurse": "Patient has high blood pressure requiring medication monitoring, lifestyle education, and regular BP checks. Watch for medication side effects.", "patient": "You have high blood pressure, which means your heart is working harder than it should. We'll give you medicine to help lower it.", - "caregiver": "Their blood pressure is too high. Make sure they take their medicine daily and watch for dizziness or headaches." + "caregiver": "Their blood pressure is too high. Make sure they take their medicine daily and watch for dizziness or headaches.", } @@ -67,6 +67,7 @@ def evaluator_instance(): @pytest.fixture def dummy_model_function(): """Dummy model function for testing""" + def model_func(prompt): return """ For a Physician: Technical medical explanation with proper terminology. @@ -74,6 +75,7 @@ def model_func(prompt): For a Patient: Simple, clear explanation without medical jargon. For a Caregiver: Concrete instructions and warning signs to watch for. """ + return model_func @@ -86,4 +88,4 @@ def test_data_dir(): @pytest.fixture def temp_output_dir(tmp_path): """Temporary directory for test outputs""" - return tmp_path / "test_outputs" \ No newline at end of file + return tmp_path / "test_outputs" diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 98ff8a0..4c38f6b 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -11,217 +11,206 @@ class TestMEQBenchItem: """Test MEQBenchItem dataclass""" - + def test_creation(self): """Test basic item creation""" - item = MEQBenchItem( - id="test_001", - medical_content="Test content", - complexity_level="basic", - source_dataset="test" - ) - + item = MEQBenchItem(id="test_001", medical_content="Test content", complexity_level="basic", source_dataset="test") + assert item.id == "test_001" assert item.medical_content == "Test content" assert item.complexity_level == "basic" assert item.source_dataset == "test" assert item.reference_explanations is None - + def test_creation_with_references(self): """Test item creation with reference explanations""" - references = { - "physician": "Technical explanation", - "patient": "Simple explanation" - } - + references = {"physician": "Technical explanation", "patient": "Simple explanation"} + item = MEQBenchItem( id="test_002", medical_content="Test content", complexity_level="intermediate", source_dataset="test", - reference_explanations=references + reference_explanations=references, ) - + assert item.reference_explanations == references class TestMEQBench: """Test MEQBench class""" - + def test_initialization(self): """Test benchmark initialization""" bench = MEQBench() assert bench.benchmark_items == [] assert bench.evaluator is not None assert bench.prompt_template is not None - + def test_add_benchmark_item(self, sample_benchmark_item): """Test adding benchmark items""" bench = MEQBench() bench.add_benchmark_item(sample_benchmark_item) - + assert len(bench.benchmark_items) == 1 assert bench.benchmark_items[0] == sample_benchmark_item - + def test_generate_explanations(self, sample_medical_content, dummy_model_function): """Test explanation generation""" bench = MEQBench() explanations = bench.generate_explanations(sample_medical_content, dummy_model_function) - + assert isinstance(explanations, dict) # Should contain at least some audience explanations assert len(explanations) > 0 - + def test_create_sample_dataset(self, tmp_path): """Test sample dataset creation""" bench = MEQBench() output_path = tmp_path / "sample_dataset.json" - + sample_items = bench.create_sample_dataset(str(output_path)) - + # Check items were created assert len(sample_items) > 0 assert all(isinstance(item, MEQBenchItem) for item in sample_items) - + # Check file was saved assert output_path.exists() - + # Check file contents - with open(output_path, 'r') as f: + with open(output_path, "r") as f: data = json.load(f) - + assert len(data) == len(sample_items) - assert all('id' in item for item in data) - assert all('medical_content' in item for item in data) - + assert all("id" in item for item in data) + assert all("medical_content" in item for item in data) + def test_get_benchmark_stats_empty(self): """Test stats for empty benchmark""" bench = MEQBench() stats = bench.get_benchmark_stats() - - assert stats['total_items'] == 0 - assert 'message' in stats - + + assert stats["total_items"] == 0 + assert "message" in stats + def test_get_benchmark_stats_with_items(self, sample_benchmark_item): """Test stats with benchmark items""" bench = MEQBench() bench.add_benchmark_item(sample_benchmark_item) - + stats = bench.get_benchmark_stats() - - assert stats['total_items'] == 1 - assert 'complexity_distribution' in stats - assert 'source_distribution' in stats - assert stats['complexity_distribution']['basic'] == 1 - assert stats['source_distribution']['test'] == 1 - + + assert stats["total_items"] == 1 + assert "complexity_distribution" in stats + assert "source_distribution" in stats + assert stats["complexity_distribution"]["basic"] == 1 + assert stats["source_distribution"]["test"] == 1 + def test_evaluate_model_basic(self, sample_benchmark_item, dummy_model_function): """Test basic model evaluation""" bench = MEQBench() bench.add_benchmark_item(sample_benchmark_item) - + results = bench.evaluate_model(dummy_model_function, max_items=1) - - assert 'total_items' in results - assert 'audience_scores' in results - assert 'complexity_scores' in results - assert 'detailed_results' in results - assert 'summary' in results - - assert results['total_items'] == 1 - assert len(results['detailed_results']) == 1 - + + assert "total_items" in results + assert "audience_scores" in results + assert "complexity_scores" in results + assert "detailed_results" in results + assert "summary" in results + + assert results["total_items"] == 1 + assert len(results["detailed_results"]) == 1 + def test_save_results(self, tmp_path): """Test results saving""" bench = MEQBench() output_path = tmp_path / "test_results.json" - - test_results = { - 'total_items': 1, - 'test_data': 'test_value' - } - + + test_results = {"total_items": 1, "test_data": "test_value"} + bench.save_results(test_results, str(output_path)) - + assert output_path.exists() - - with open(output_path, 'r') as f: + + with open(output_path, "r") as f: loaded_results = json.load(f) - + assert loaded_results == test_results - + def test_add_duplicate_benchmark_item(self, sample_benchmark_item): """Test that adding a benchmark item with a duplicate ID raises a ValueError""" bench = MEQBench() - + # Add the first item bench.add_benchmark_item(sample_benchmark_item) assert len(bench.benchmark_items) == 1 - + # Try to add another item with the same ID duplicate_item = MEQBenchItem( id=sample_benchmark_item.id, # Same ID as the first item medical_content="Different content", complexity_level="intermediate", - source_dataset="different_source" + source_dataset="different_source", ) - + # Should raise ValueError due to duplicate ID with pytest.raises(ValueError, match=f"Item with ID '{sample_benchmark_item.id}' already exists"): bench.add_benchmark_item(duplicate_item) - + # Verify the original item count remains the same assert len(bench.benchmark_items) == 1 - + def test_generate_explanations_empty_content(self, dummy_model_function): """Test that generate_explanations raises a ValueError when medical_content is empty""" bench = MEQBench() - + # Test with completely empty string with pytest.raises(ValueError, match="medical_content cannot be empty or contain only whitespace"): bench.generate_explanations("", dummy_model_function) - + # Test with whitespace-only string with pytest.raises(ValueError, match="medical_content cannot be empty or contain only whitespace"): bench.generate_explanations(" \n\t ", dummy_model_function) - + # Test with very short content (less than 10 characters) with pytest.raises(ValueError, match="medical_content must be at least 10 characters long"): bench.generate_explanations("short", dummy_model_function) - + def test_evaluate_model_no_items(self, dummy_model_function): """Test that evaluate_model returns appropriate result when there are no benchmark items""" bench = MEQBench() - + # Ensure no items are loaded assert len(bench.benchmark_items) == 0 - + # Evaluate model with no items results = bench.evaluate_model(dummy_model_function) - + # Should return a valid results structure but with zero items assert isinstance(results, dict) - assert results['total_items'] == 0 - assert 'audience_scores' in results - assert 'complexity_scores' in results - assert 'detailed_results' in results - assert 'summary' in results - + assert results["total_items"] == 0 + assert "audience_scores" in results + assert "complexity_scores" in results + assert "detailed_results" in results + assert "summary" in results + # All audience scores should be empty lists - for audience in ['physician', 'nurse', 'patient', 'caregiver']: - assert results['audience_scores'][audience] == [] - + for audience in ["physician", "nurse", "patient", "caregiver"]: + assert results["audience_scores"][audience] == [] + # All complexity scores should be empty lists - for complexity in ['basic', 'intermediate', 'advanced']: - assert results['complexity_scores'][complexity] == [] - + for complexity in ["basic", "intermediate", "advanced"]: + assert results["complexity_scores"][complexity] == [] + # Detailed results should be empty - assert results['detailed_results'] == [] - + assert results["detailed_results"] == [] + # Summary should handle empty data gracefully - summary = results['summary'] + summary = results["summary"] assert isinstance(summary, dict) # Most summary stats should be absent or 0 for empty data - if 'overall_mean' in summary: + if "overall_mean" in summary: # If present, should be a reasonable default or empty value - assert summary['overall_mean'] is None or isinstance(summary['overall_mean'], (int, float)) \ No newline at end of file + assert summary["overall_mean"] is None or isinstance(summary["overall_mean"], (int, float)) diff --git a/tests/test_data_loaders.py b/tests/test_data_loaders.py index 48d27fd..cfbe130 100644 --- a/tests/test_data_loaders.py +++ b/tests/test_data_loaders.py @@ -10,130 +10,114 @@ from src.data_loaders import ( load_medqa_usmle, - load_icliniq, + load_icliniq, load_cochrane_reviews, save_benchmark_items, calculate_complexity_level, - _validate_benchmark_item + _validate_benchmark_item, ) from src.benchmark import MEQBenchItem class TestCalculateComplexityLevel: """Test complexity level calculation using Flesch-Kincaid scores""" - + def test_empty_text_raises_error(self): """Test that empty text raises ValueError""" with pytest.raises(ValueError, match="Text must be a non-empty string"): calculate_complexity_level("") - + with pytest.raises(ValueError, match="Text must be a non-empty string"): calculate_complexity_level(None) - + def test_whitespace_only_raises_error(self): """Test that whitespace-only text raises ValueError""" with pytest.raises(ValueError, match="Text cannot be empty or whitespace only"): calculate_complexity_level(" \n\t ") - - @patch('src.data_loaders.textstat', None) + + @patch("src.data_loaders.textstat", None) def test_fallback_complexity_calculation(self): """Test fallback complexity calculation when textstat is unavailable""" # Simple text should be basic simple_text = "This is simple. It has short words." complexity = calculate_complexity_level(simple_text) - assert complexity in ['basic', 'intermediate', 'advanced'] - + assert complexity in ["basic", "intermediate", "advanced"] + # Complex medical text should be advanced - complex_text = ("Pharmacokinetic interactions involving cytochrome P450 enzymes can significantly " - "alter therapeutic drug concentrations, potentially leading to adverse effects or " - "therapeutic failure in clinical practice.") + complex_text = ( + "Pharmacokinetic interactions involving cytochrome P450 enzymes can significantly " + "alter therapeutic drug concentrations, potentially leading to adverse effects or " + "therapeutic failure in clinical practice." + ) complexity = calculate_complexity_level(complex_text) - assert complexity in ['basic', 'intermediate', 'advanced'] - - @patch('src.data_loaders.textstat') + assert complexity in ["basic", "intermediate", "advanced"] + + @patch("src.data_loaders.textstat") def test_with_textstat_available(self, mock_textstat): """Test complexity calculation when textstat is available""" # Mock textstat to return specific grade levels mock_textstat.flesch_kincaid.return_value.grade.return_value = 6.0 - + text = "Simple medical text for testing." complexity = calculate_complexity_level(text) - assert complexity == 'basic' - + assert complexity == "basic" + # Test intermediate level mock_textstat.flesch_kincaid.return_value.grade.return_value = 10.0 complexity = calculate_complexity_level(text) - assert complexity == 'intermediate' - + assert complexity == "intermediate" + # Test advanced level mock_textstat.flesch_kincaid.return_value.grade.return_value = 15.0 complexity = calculate_complexity_level(text) - assert complexity == 'advanced' - - @patch('src.data_loaders.textstat') + assert complexity == "advanced" + + @patch("src.data_loaders.textstat") def test_textstat_error_fallback(self, mock_textstat): """Test fallback when textstat raises an exception""" mock_textstat.flesch_kincaid.return_value.grade.side_effect = Exception("Textstat error") - + text = "Test text for error handling." complexity = calculate_complexity_level(text) - assert complexity in ['basic', 'intermediate', 'advanced'] + assert complexity in ["basic", "intermediate", "advanced"] class TestValidateBenchmarkItem: """Test benchmark item validation""" - + def test_valid_item(self): """Test validation of a valid item""" item = MEQBenchItem( id="test_001", medical_content="This is valid medical content for testing purposes.", complexity_level="basic", - source_dataset="test" + source_dataset="test", ) # Should not raise any exception _validate_benchmark_item(item) - + def test_empty_id_raises_error(self): """Test that empty ID raises ValueError""" - item = MEQBenchItem( - id="", - medical_content="Valid content", - complexity_level="basic", - source_dataset="test" - ) + item = MEQBenchItem(id="", medical_content="Valid content", complexity_level="basic", source_dataset="test") with pytest.raises(ValueError, match="Item ID must be a non-empty string"): _validate_benchmark_item(item) - + def test_non_string_id_raises_error(self): """Test that non-string ID raises ValueError""" - item = MEQBenchItem( - id=123, - medical_content="Valid content", - complexity_level="basic", - source_dataset="test" - ) + item = MEQBenchItem(id=123, medical_content="Valid content", complexity_level="basic", source_dataset="test") with pytest.raises(ValueError, match="Item ID must be a non-empty string"): _validate_benchmark_item(item) - + def test_empty_content_raises_error(self): """Test that empty medical content raises ValueError""" - item = MEQBenchItem( - id="test_001", - medical_content="", - complexity_level="basic", - source_dataset="test" - ) + item = MEQBenchItem(id="test_001", medical_content="", complexity_level="basic", source_dataset="test") with pytest.raises(ValueError, match="Medical content must be a non-empty string"): _validate_benchmark_item(item) - + def test_short_content_raises_error(self): """Test that very short content raises ValueError""" item = MEQBenchItem( - id="test_001", - medical_content="Short", # Less than 20 characters - complexity_level="basic", - source_dataset="test" + id="test_001", medical_content="Short", complexity_level="basic", source_dataset="test" # Less than 20 characters ) with pytest.raises(ValueError, match="Medical content is too short"): _validate_benchmark_item(item) @@ -141,34 +125,34 @@ def test_short_content_raises_error(self): class TestLoadMedQAUSMLE: """Test MedQA-USMLE dataset loading""" - + def test_file_not_found_raises_error(self): """Test that non-existent file raises FileNotFoundError""" with pytest.raises(FileNotFoundError): load_medqa_usmle("/nonexistent/file.json") - + def test_invalid_json_raises_error(self): """Test that invalid JSON raises JSONDecodeError""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: f.write("invalid json content") f.flush() - + with pytest.raises(json.JSONDecodeError): load_medqa_usmle(f.name) - + Path(f.name).unlink() # Clean up - + def test_non_list_data_raises_error(self): """Test that non-list JSON data raises ValueError""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump({"not": "a list"}, f) f.flush() - + with pytest.raises(ValueError, match="MedQA-USMLE data must be a list"): load_medqa_usmle(f.name) - + Path(f.name).unlink() # Clean up - + def test_load_valid_data(self): """Test loading valid MedQA-USMLE data""" sample_data = [ @@ -176,84 +160,75 @@ def test_load_valid_data(self): "id": "medqa_001", "question": "What is the most common cause of hypertension?", "options": { - "A": "Primary hypertension", + "A": "Primary hypertension", "B": "Secondary hypertension", "C": "White coat hypertension", - "D": "Malignant hypertension" + "D": "Malignant hypertension", }, "answer": "A", - "explanation": "Primary hypertension accounts for 90-95% of cases." + "explanation": "Primary hypertension accounts for 90-95% of cases.", }, { "id": "medqa_002", "question": "Which medication is first-line for diabetes?", - "options": { - "A": "Insulin", - "B": "Metformin", - "C": "Sulfonylureas", - "D": "Thiazolidinediones" - }, + "options": {"A": "Insulin", "B": "Metformin", "C": "Sulfonylureas", "D": "Thiazolidinediones"}, "answer": "B", - "explanation": "Metformin is the first-line treatment for type 2 diabetes." - } + "explanation": "Metformin is the first-line treatment for type 2 diabetes.", + }, ] - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(sample_data, f) f.flush() - + items = load_medqa_usmle(f.name, auto_complexity=False) - + assert len(items) == 2 assert all(isinstance(item, MEQBenchItem) for item in items) assert items[0].id == "medqa_001" assert items[0].source_dataset == "MedQA-USMLE" assert items[0].complexity_level == "intermediate" # Default when auto_complexity=False assert "What is the most common cause of hypertension?" in items[0].medical_content - + Path(f.name).unlink() # Clean up - + def test_max_items_limit(self): """Test that max_items parameter limits the number of loaded items""" - sample_data = [{"id": f"test_{i}", "question": f"Question {i}", "answer": "A"} - for i in range(5)] - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + sample_data = [{"id": f"test_{i}", "question": f"Question {i}", "answer": "A"} for i in range(5)] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(sample_data, f) f.flush() - + items = load_medqa_usmle(f.name, max_items=3, auto_complexity=False) assert len(items) == 3 - + Path(f.name).unlink() # Clean up - - @patch('src.data_loaders.calculate_complexity_level') + + @patch("src.data_loaders.calculate_complexity_level") def test_auto_complexity_calculation(self, mock_calc_complexity): """Test automatic complexity level calculation""" - mock_calc_complexity.return_value = 'advanced' - - sample_data = [{ - "id": "test_001", - "question": "Complex medical question", - "answer": "A", - "explanation": "Detailed explanation" - }] - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + mock_calc_complexity.return_value = "advanced" + + sample_data = [ + {"id": "test_001", "question": "Complex medical question", "answer": "A", "explanation": "Detailed explanation"} + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(sample_data, f) f.flush() - + items = load_medqa_usmle(f.name, auto_complexity=True) assert len(items) == 1 - assert items[0].complexity_level == 'advanced' + assert items[0].complexity_level == "advanced" mock_calc_complexity.assert_called_once() - + Path(f.name).unlink() # Clean up class TestLoadiCliniq: """Test iCliniq dataset loading""" - + def test_load_valid_icliniq_data(self): """Test loading valid iCliniq data""" sample_data = [ @@ -261,55 +236,57 @@ def test_load_valid_icliniq_data(self): "id": "icliniq_001", "patient_question": "I have been experiencing chest pain. Should I be worried?", "doctor_answer": "Chest pain can have various causes. Please consult a cardiologist for proper evaluation.", - "speciality": "Cardiology" + "speciality": "Cardiology", }, { - "id": "icliniq_002", + "id": "icliniq_002", "patient_question": "What are the side effects of aspirin?", "doctor_answer": "Common side effects include stomach irritation and increased bleeding risk.", - "speciality": "General Medicine" - } + "speciality": "General Medicine", + }, ] - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(sample_data, f) f.flush() - + items = load_icliniq(f.name, auto_complexity=False) - + assert len(items) == 2 assert all(isinstance(item, MEQBenchItem) for item in items) assert items[0].source_dataset == "iCliniq" assert "Patient Question:" in items[0].medical_content assert "Doctor's Answer:" in items[0].medical_content assert "Cardiology" in items[0].medical_content - + Path(f.name).unlink() # Clean up - + def test_alternative_field_names(self): """Test loading iCliniq data with alternative field names""" - sample_data = [{ - "id": "icliniq_001", - "question": "Alternative question field", # Using 'question' instead of 'patient_question' - "answer": "Alternative answer field", # Using 'answer' instead of 'doctor_answer' - "specialty": "Dermatology" # Using 'specialty' instead of 'speciality' - }] - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + sample_data = [ + { + "id": "icliniq_001", + "question": "Alternative question field", # Using 'question' instead of 'patient_question' + "answer": "Alternative answer field", # Using 'answer' instead of 'doctor_answer' + "specialty": "Dermatology", # Using 'specialty' instead of 'speciality' + } + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(sample_data, f) f.flush() - + items = load_icliniq(f.name, auto_complexity=False) assert len(items) == 1 assert "Alternative question field" in items[0].medical_content assert "Alternative answer field" in items[0].medical_content - + Path(f.name).unlink() # Clean up class TestLoadCochraneReviews: """Test Cochrane Reviews dataset loading""" - + def test_load_valid_cochrane_data(self): """Test loading valid Cochrane Reviews data""" sample_data = [ @@ -318,23 +295,23 @@ def test_load_valid_cochrane_data(self): "title": "Effectiveness of statins for cardiovascular disease prevention", "abstract": "This systematic review evaluates the effectiveness of statins in preventing cardiovascular events.", "conclusions": "Statins significantly reduce cardiovascular events in high-risk patients.", - "background": "Cardiovascular disease is a leading cause of mortality worldwide." + "background": "Cardiovascular disease is a leading cause of mortality worldwide.", }, { "id": "cochrane_002", "title": "Antibiotics for acute respiratory infections", "abstract": "Review of antibiotic effectiveness for respiratory tract infections.", "main_results": "Limited benefit of antibiotics for viral respiratory infections.", # Alternative field name - "objectives": "To assess antibiotic effectiveness in respiratory infections." # Alternative field name - } + "objectives": "To assess antibiotic effectiveness in respiratory infections.", # Alternative field name + }, ] - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(sample_data, f) f.flush() - + items = load_cochrane_reviews(f.name, auto_complexity=False) - + assert len(items) == 2 assert all(isinstance(item, MEQBenchItem) for item in items) assert items[0].source_dataset == "Cochrane Reviews" @@ -342,48 +319,44 @@ def test_load_valid_cochrane_data(self): assert "Title:" in items[0].medical_content assert "Abstract:" in items[0].medical_content assert "statins" in items[0].medical_content.lower() - + Path(f.name).unlink() # Clean up - + def test_missing_title_and_abstract_skipped(self): """Test that items without title and abstract are skipped""" sample_data = [ - { - "id": "cochrane_001", - "title": "Valid title", - "abstract": "Valid abstract" - }, + {"id": "cochrane_001", "title": "Valid title", "abstract": "Valid abstract"}, { "id": "cochrane_002", # Missing both title and abstract - "conclusions": "Only conclusions available" - } + "conclusions": "Only conclusions available", + }, ] - - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(sample_data, f) f.flush() - + items = load_cochrane_reviews(f.name, auto_complexity=False) assert len(items) == 1 # Only the valid item should be loaded - + Path(f.name).unlink() # Clean up class TestSaveBenchmarkItems: """Test saving benchmark items to JSON""" - + def test_save_empty_list(self): """Test saving an empty list of items""" with tempfile.TemporaryDirectory() as temp_dir: output_path = Path(temp_dir) / "test_items.json" save_benchmark_items([], output_path) - + assert output_path.exists() - with open(output_path, 'r') as f: + with open(output_path, "r") as f: data = json.load(f) assert data == [] - + def test_save_valid_items(self): """Test saving valid benchmark items""" items = [ @@ -392,108 +365,93 @@ def test_save_valid_items(self): medical_content="Test content 1", complexity_level="basic", source_dataset="test", - reference_explanations={"physician": "Technical explanation"} + reference_explanations={"physician": "Technical explanation"}, ), - MEQBenchItem( - id="test_002", - medical_content="Test content 2", - complexity_level="advanced", - source_dataset="test" - ) + MEQBenchItem(id="test_002", medical_content="Test content 2", complexity_level="advanced", source_dataset="test"), ] - + with tempfile.TemporaryDirectory() as temp_dir: output_path = Path(temp_dir) / "test_items.json" save_benchmark_items(items, output_path, pretty_print=True) - + assert output_path.exists() - with open(output_path, 'r') as f: + with open(output_path, "r") as f: data = json.load(f) - + assert len(data) == 2 - assert data[0]['id'] == "test_001" - assert data[0]['medical_content'] == "Test content 1" - assert data[0]['complexity_level'] == "basic" - assert data[0]['source_dataset'] == "test" - assert data[0]['reference_explanations'] == {"physician": "Technical explanation"} - - assert data[1]['id'] == "test_002" - assert data[1]['reference_explanations'] is None - + assert data[0]["id"] == "test_001" + assert data[0]["medical_content"] == "Test content 1" + assert data[0]["complexity_level"] == "basic" + assert data[0]["source_dataset"] == "test" + assert data[0]["reference_explanations"] == {"physician": "Technical explanation"} + + assert data[1]["id"] == "test_002" + assert data[1]["reference_explanations"] is None + def test_save_creates_directory(self): """Test that saving creates parent directories if they don't exist""" with tempfile.TemporaryDirectory() as temp_dir: output_path = Path(temp_dir) / "subdir" / "test_items.json" - items = [MEQBenchItem( - id="test_001", - medical_content="Test content", - complexity_level="basic", - source_dataset="test" - )] - + items = [ + MEQBenchItem(id="test_001", medical_content="Test content", complexity_level="basic", source_dataset="test") + ] + save_benchmark_items(items, output_path) - + assert output_path.exists() assert output_path.parent.exists() class TestDataLoadersIntegration: """Integration tests for data loaders""" - + def test_load_multiple_datasets(self): """Test loading and combining multiple datasets""" # Create sample data for each dataset type - medqa_data = [{ - "id": "medqa_001", - "question": "MedQA question", - "answer": "A", - "explanation": "MedQA explanation" - }] - - icliniq_data = [{ - "id": "icliniq_001", - "patient_question": "iCliniq question", - "doctor_answer": "iCliniq answer", - "speciality": "General" - }] - - cochrane_data = [{ - "id": "cochrane_001", - "title": "Cochrane title", - "abstract": "Cochrane abstract" - }] - + medqa_data = [{"id": "medqa_001", "question": "MedQA question", "answer": "A", "explanation": "MedQA explanation"}] + + icliniq_data = [ + { + "id": "icliniq_001", + "patient_question": "iCliniq question", + "doctor_answer": "iCliniq answer", + "speciality": "General", + } + ] + + cochrane_data = [{"id": "cochrane_001", "title": "Cochrane title", "abstract": "Cochrane abstract"}] + with tempfile.TemporaryDirectory() as temp_dir: # Save sample data files medqa_file = Path(temp_dir) / "medqa.json" icliniq_file = Path(temp_dir) / "icliniq.json" cochrane_file = Path(temp_dir) / "cochrane.json" - - with open(medqa_file, 'w') as f: + + with open(medqa_file, "w") as f: json.dump(medqa_data, f) - with open(icliniq_file, 'w') as f: + with open(icliniq_file, "w") as f: json.dump(icliniq_data, f) - with open(cochrane_file, 'w') as f: + with open(cochrane_file, "w") as f: json.dump(cochrane_data, f) - + # Load all datasets medqa_items = load_medqa_usmle(medqa_file, auto_complexity=False) icliniq_items = load_icliniq(icliniq_file, auto_complexity=False) cochrane_items = load_cochrane_reviews(cochrane_file, auto_complexity=False) - + # Combine all items all_items = medqa_items + icliniq_items + cochrane_items - + assert len(all_items) == 3 assert all_items[0].source_dataset == "MedQA-USMLE" assert all_items[1].source_dataset == "iCliniq" assert all_items[2].source_dataset == "Cochrane Reviews" - + # Save combined dataset combined_file = Path(temp_dir) / "combined.json" save_benchmark_items(all_items, combined_file) - + # Verify saved file - with open(combined_file, 'r') as f: + with open(combined_file, "r") as f: saved_data = json.load(f) - assert len(saved_data) == 3 \ No newline at end of file + assert len(saved_data) == 3 diff --git a/tests/test_evaluator_metrics.py b/tests/test_evaluator_metrics.py index eaa3a53..36ce8c4 100644 --- a/tests/test_evaluator_metrics.py +++ b/tests/test_evaluator_metrics.py @@ -7,225 +7,227 @@ from src.evaluator import ( ContradictionDetection, - InformationPreservation, + InformationPreservation, HallucinationDetection, EvaluationScore, - EvaluationError + EvaluationError, ) class TestContradictionDetection: """Test ContradictionDetection metric""" - + @pytest.fixture def contradiction_detector(self): """Create ContradictionDetection instance for testing""" return ContradictionDetection() - + def test_initialization(self, contradiction_detector): """Test proper initialization of ContradictionDetection""" - assert hasattr(contradiction_detector, 'medical_knowledge_base') - assert hasattr(contradiction_detector, 'contradiction_patterns') + assert hasattr(contradiction_detector, "medical_knowledge_base") + assert hasattr(contradiction_detector, "contradiction_patterns") assert isinstance(contradiction_detector.medical_knowledge_base, dict) assert isinstance(contradiction_detector.contradiction_patterns, list) - + def test_no_contradictions_returns_high_score(self, contradiction_detector): """Test that text without contradictions returns high score""" clean_text = "Hypertension is treated with lifestyle changes and medication as prescribed by a doctor." score = contradiction_detector.calculate(clean_text, "patient") assert score == 1.0 - + def test_pattern_based_contradiction_detection(self, contradiction_detector): """Test detection of pattern-based contradictions""" contradictory_text = "You can treat viral infections with antibiotics effectively." score = contradiction_detector.calculate(contradictory_text, "patient") assert score < 1.0 # Should detect contradiction about antibiotics treating viruses - + def test_knowledge_base_contradiction_detection(self, contradiction_detector): """Test detection of contradictions against knowledge base""" contradictory_text = "High blood pressure is not related to heart disease or stroke." score = contradiction_detector.calculate(contradictory_text, "patient") assert score < 1.0 # Should detect contradiction about hypertension consequences - + def test_empty_text_returns_neutral_score(self, contradiction_detector): """Test that empty text returns neutral score""" score = contradiction_detector.calculate("", "patient") assert score == 0.5 - + score = contradiction_detector.calculate(" ", "patient") assert score == 0.5 - + def test_multiple_contradictions_lower_score(self, contradiction_detector): """Test that multiple contradictions result in lower scores""" single_contradiction = "Antibiotics treat viral infections." - multiple_contradictions = ("Antibiotics treat viral infections. " - "You should stop medication immediately when you feel better. " - "Aspirin is safe for everyone to use daily.") - + multiple_contradictions = ( + "Antibiotics treat viral infections. " + "You should stop medication immediately when you feel better. " + "Aspirin is safe for everyone to use daily." + ) + single_score = contradiction_detector.calculate(single_contradiction, "patient") multiple_score = contradiction_detector.calculate(multiple_contradictions, "patient") - + assert multiple_score < single_score - + def test_case_insensitive_detection(self, contradiction_detector): """Test that contradiction detection is case-insensitive""" upper_text = "ANTIBIOTICS TREAT VIRUS INFECTIONS" lower_text = "antibiotics treat virus infections" - + upper_score = contradiction_detector.calculate(upper_text, "patient") lower_score = contradiction_detector.calculate(lower_text, "patient") - + assert upper_score == lower_score assert upper_score < 1.0 class TestInformationPreservation: """Test InformationPreservation metric""" - + @pytest.fixture def info_preservation(self): """Create InformationPreservation instance for testing""" return InformationPreservation() - + def test_initialization(self, info_preservation): """Test proper initialization of InformationPreservation""" - assert hasattr(info_preservation, 'critical_info_patterns') + assert hasattr(info_preservation, "critical_info_patterns") assert isinstance(info_preservation.critical_info_patterns, dict) - assert 'dosages' in info_preservation.critical_info_patterns - assert 'warnings' in info_preservation.critical_info_patterns - + assert "dosages" in info_preservation.critical_info_patterns + assert "warnings" in info_preservation.critical_info_patterns + def test_perfect_preservation_returns_high_score(self, info_preservation): """Test that perfect information preservation returns high score""" original = "Take 10 mg twice daily with food. Do not drink alcohol while taking this medication." generated = "You should take 10 mg two times per day with food. Avoid alcohol while on this medication." - + score = info_preservation.calculate(generated, "patient", original=original) assert score > 0.5 # Should preserve most critical information - + def test_dosage_preservation(self, info_preservation): """Test preservation of dosage information""" original = "Take 20 mg once daily before breakfast." generated_good = "Take 20 mg one time each day before breakfast." generated_bad = "Take medication before breakfast." - + good_score = info_preservation.calculate(generated_good, "patient", original=original) bad_score = info_preservation.calculate(generated_bad, "patient", original=original) - + assert good_score > bad_score - + def test_warning_preservation(self, info_preservation): """Test preservation of warning information""" original = "Do not take with alcohol. Avoid driving. Side effects may include dizziness." generated_good = "Don't drink alcohol. Be careful driving. May cause dizziness." generated_bad = "Take as directed." - + good_score = info_preservation.calculate(generated_good, "patient", original=original) bad_score = info_preservation.calculate(generated_bad, "patient", original=original) - + assert good_score > bad_score - + def test_timing_preservation(self, info_preservation): """Test preservation of timing information""" original = "Take before meals with water on empty stomach." generated_good = "Take before eating with water when stomach is empty." generated_bad = "Take with water." - + good_score = info_preservation.calculate(generated_good, "patient", original=original) bad_score = info_preservation.calculate(generated_bad, "patient", original=original) - + assert good_score > bad_score - + def test_empty_texts_return_zero(self, info_preservation): """Test that empty texts return zero score""" score = info_preservation.calculate("", "patient", original="Some content") assert score == 0.0 - + score = info_preservation.calculate("Some content", "patient", original="") assert score == 0.0 - + def test_no_critical_info_returns_high_score(self, info_preservation): """Test that texts without critical info return high score""" original = "This is general medical information about health." generated = "This discusses general health topics." - + score = info_preservation.calculate(generated, "patient", original=original) assert score == 1.0 # No critical info to preserve - + def test_paraphrased_preservation_detection(self, info_preservation): """Test detection of paraphrased information preservation""" original = "Take 5 mg tablets twice daily" generated = "Take the dose two times each day" # Paraphrased but preserves key info - + score = info_preservation.calculate(generated, "patient", original=original) assert score > 0.0 # Should detect some preservation through paraphrasing class TestHallucinationDetection: """Test HallucinationDetection metric""" - + @pytest.fixture def hallucination_detector(self): """Create HallucinationDetection instance for testing""" return HallucinationDetection() - + def test_initialization(self, hallucination_detector): """Test proper initialization of HallucinationDetection""" - assert hasattr(hallucination_detector, 'medical_entities') + assert hasattr(hallucination_detector, "medical_entities") assert isinstance(hallucination_detector.medical_entities, dict) - assert 'medications' in hallucination_detector.medical_entities - assert 'conditions' in hallucination_detector.medical_entities - + assert "medications" in hallucination_detector.medical_entities + assert "conditions" in hallucination_detector.medical_entities + def test_no_hallucinations_returns_high_score(self, hallucination_detector): """Test that text without hallucinations returns high score""" original = "Patient has hypertension and takes aspirin daily." generated = "The patient has high blood pressure and takes aspirin every day." - + score = hallucination_detector.calculate(generated, "patient", original=original) assert score >= 0.8 # Should be high since no new entities are introduced - + def test_hallucinated_medication_detection(self, hallucination_detector): """Test detection of hallucinated medications""" original = "Patient has headache." generated = "Patient has headache and should take ibuprofen and metformin." # Metformin not related to headache - + score = hallucination_detector.calculate(generated, "patient", original=original) assert score < 1.0 # Should detect hallucinated medications - + def test_hallucinated_condition_detection(self, hallucination_detector): """Test detection of hallucinated medical conditions""" original = "Patient reports fatigue." generated = "Patient has diabetes and hypertension causing fatigue." # New conditions not in original - + score = hallucination_detector.calculate(generated, "patient", original=original) assert score < 1.0 # Should detect hallucinated conditions - + def test_empty_texts_return_neutral_score(self, hallucination_detector): """Test that empty texts return neutral score""" score = hallucination_detector.calculate("", "patient", original="Some content") assert score == 0.5 - + score = hallucination_detector.calculate("Some content", "patient", original="") assert score == 0.5 - + def test_no_entities_returns_high_score(self, hallucination_detector): """Test that text without medical entities returns high score""" original = "This is general health advice." generated = "This provides general health guidance." - + score = hallucination_detector.calculate(generated, "patient", original=original) assert score == 1.0 # No entities, so no hallucinations - + def test_entity_extraction_methods(self, hallucination_detector): """Test different entity extraction methods""" text = "Patient has diabetes and takes insulin and aspirin." entities = hallucination_detector._extract_medical_entities(text) - + # Should extract known medical entities - assert 'diabetes' in entities - assert 'insulin' in entities - assert 'aspirin' in entities - - @patch('src.evaluator.spacy') + assert "diabetes" in entities + assert "insulin" in entities + assert "aspirin" in entities + + @patch("src.evaluator.spacy") def test_spacy_ner_integration(self, mock_spacy, hallucination_detector): """Test spaCy NER integration when available""" # Mock spaCy NLP pipeline @@ -236,33 +238,33 @@ def test_spacy_ner_integration(self, mock_spacy, hallucination_detector): mock_ent.label_ = "PRODUCT" mock_doc.ents = [mock_ent] mock_nlp.return_value = mock_doc - + # Create detector with mocked spaCy detector = HallucinationDetection() detector.nlp = mock_nlp - + text = "Patient takes custom_medication" entities = detector._extract_medical_entities(text) - + # Should include spaCy-detected entities if they match medical terms mock_nlp.assert_called_once_with(text) - + def test_case_insensitive_entity_extraction(self, hallucination_detector): """Test that entity extraction is case-insensitive""" text_upper = "PATIENT HAS DIABETES AND TAKES INSULIN" text_lower = "patient has diabetes and takes insulin" - + entities_upper = hallucination_detector._extract_medical_entities(text_upper) entities_lower = hallucination_detector._extract_medical_entities(text_lower) - + assert entities_upper == entities_lower - assert 'diabetes' in entities_upper - assert 'insulin' in entities_upper + assert "diabetes" in entities_upper + assert "insulin" in entities_upper class TestEvaluatorIntegration: """Integration tests for new metrics with MEQBenchEvaluator""" - + def test_evaluation_score_with_new_metrics(self): """Test that EvaluationScore includes new metrics""" score = EvaluationScore( @@ -274,13 +276,13 @@ def test_evaluation_score_with_new_metrics(self): contradiction=0.9, information_preservation=0.8, hallucination=0.8, - overall=0.8 + overall=0.8, ) - + assert score.contradiction == 0.9 assert score.information_preservation == 0.8 assert score.hallucination == 0.8 - + def test_evaluation_score_to_dict_includes_new_metrics(self): """Test that to_dict method includes new metrics""" score = EvaluationScore( @@ -292,115 +294,115 @@ def test_evaluation_score_to_dict_includes_new_metrics(self): contradiction=0.9, information_preservation=0.8, hallucination=0.8, - overall=0.8 + overall=0.8, ) - + score_dict = score.to_dict() - assert 'contradiction' in score_dict - assert 'information_preservation' in score_dict - assert 'hallucination' in score_dict - assert score_dict['contradiction'] == 0.9 - assert score_dict['information_preservation'] == 0.8 - assert score_dict['hallucination'] == 0.8 - - @patch('src.evaluator.config') + assert "contradiction" in score_dict + assert "information_preservation" in score_dict + assert "hallucination" in score_dict + assert score_dict["contradiction"] == 0.9 + assert score_dict["information_preservation"] == 0.8 + assert score_dict["hallucination"] == 0.8 + + @patch("src.evaluator.config") def test_new_metrics_error_handling(self, mock_config): """Test error handling in new metrics""" # Mock config to avoid configuration issues - mock_config.get_evaluation_config.return_value = { - 'safety': {'danger_words': [], 'safety_words': []} - } + mock_config.get_evaluation_config.return_value = {"safety": {"danger_words": [], "safety_words": []}} mock_config.get_scoring_config.return_value = { - 'weights': {'readability': 0.2, 'terminology': 0.2, 'safety': 0.2, - 'coverage': 0.2, 'quality': 0.2}, - 'parameters': {'safety_multiplier': 0.5} + "weights": {"readability": 0.2, "terminology": 0.2, "safety": 0.2, "coverage": 0.2, "quality": 0.2}, + "parameters": {"safety_multiplier": 0.5}, } - mock_config.get_audiences.return_value = ['physician', 'nurse', 'patient', 'caregiver'] - + mock_config.get_audiences.return_value = ["physician", "nurse", "patient", "caregiver"] + # Create evaluator instance from src.evaluator import MEQBenchEvaluator + evaluator = MEQBenchEvaluator() - + # Test that new metrics are properly initialized - assert hasattr(evaluator, 'contradiction_detector') - assert hasattr(evaluator, 'information_preservation') - assert hasattr(evaluator, 'hallucination_detector') - + assert hasattr(evaluator, "contradiction_detector") + assert hasattr(evaluator, "information_preservation") + assert hasattr(evaluator, "hallucination_detector") + def test_metric_calculation_robustness(self): """Test that metrics handle edge cases robustly""" # Test all metrics with edge cases contradiction_detector = ContradictionDetection() info_preservation = InformationPreservation() hallucination_detector = HallucinationDetection() - + # Edge case: very short text short_text = "Hi" original_short = "Hello" - + # All metrics should handle short text without crashing try: contradiction_score = contradiction_detector.calculate(short_text, "patient") info_score = info_preservation.calculate(short_text, "patient", original=original_short) hallucination_score = hallucination_detector.calculate(short_text, "patient", original=original_short) - + # Scores should be valid floats between 0 and 1 assert 0 <= contradiction_score <= 1 - assert 0 <= info_score <= 1 + assert 0 <= info_score <= 1 assert 0 <= hallucination_score <= 1 - + except Exception as e: pytest.fail(f"Metrics should handle short text without crashing: {e}") - + # Edge case: text with special characters special_text = "Take 10mg @#$%^&*() twice daily!!!" original_special = "Medication: 10mg dosage instructions." - + try: contradiction_score = contradiction_detector.calculate(special_text, "patient") info_score = info_preservation.calculate(special_text, "patient", original=original_special) hallucination_score = hallucination_detector.calculate(special_text, "patient", original=original_special) - + assert 0 <= contradiction_score <= 1 assert 0 <= info_score <= 1 assert 0 <= hallucination_score <= 1 - + except Exception as e: pytest.fail(f"Metrics should handle special characters without crashing: {e}") class TestMetricPerformance: """Test performance characteristics of new metrics""" - + def test_metrics_complete_quickly(self): """Test that metrics complete within reasonable time""" import time - + # Create large text for performance testing - large_text = ("This is a medical explanation about hypertension and diabetes. " * 100 + - "Patient should take medication as prescribed. " * 50 + - "Avoid alcohol and maintain healthy diet. " * 50) - + large_text = ( + "This is a medical explanation about hypertension and diabetes. " * 100 + + "Patient should take medication as prescribed. " * 50 + + "Avoid alcohol and maintain healthy diet. " * 50 + ) + original_text = "Patient has hypertension and diabetes requiring medication management." - + # Initialize metrics contradiction_detector = ContradictionDetection() info_preservation = InformationPreservation() hallucination_detector = HallucinationDetection() - + # Time each metric start_time = time.time() contradiction_detector.calculate(large_text, "patient") contradiction_time = time.time() - start_time - + start_time = time.time() info_preservation.calculate(large_text, "patient", original=original_text) info_time = time.time() - start_time - + start_time = time.time() hallucination_detector.calculate(large_text, "patient", original=original_text) hallucination_time = time.time() - start_time - + # Each metric should complete within 5 seconds for large text assert contradiction_time < 5.0, f"ContradictionDetection took {contradiction_time:.2f}s" assert info_time < 5.0, f"InformationPreservation took {info_time:.2f}s" - assert hallucination_time < 5.0, f"HallucinationDetection took {hallucination_time:.2f}s" \ No newline at end of file + assert hallucination_time < 5.0, f"HallucinationDetection took {hallucination_time:.2f}s" diff --git a/tests/test_leaderboard.py b/tests/test_leaderboard.py index bbb7cac..df03b2d 100644 --- a/tests/test_leaderboard.py +++ b/tests/test_leaderboard.py @@ -13,7 +13,7 @@ class TestLeaderboardGenerator: """Test LeaderboardGenerator class""" - + @pytest.fixture def sample_results_data(self): """Sample evaluation results data for testing""" @@ -25,20 +25,16 @@ def sample_results_data(self): "physician": [0.9, 0.8, 0.85], "nurse": [0.8, 0.75, 0.8], "patient": [0.7, 0.65, 0.7], - "caregiver": [0.75, 0.7, 0.75] - }, - "complexity_scores": { - "basic": [0.8, 0.75], - "intermediate": [0.7, 0.8], - "advanced": [0.6, 0.65] + "caregiver": [0.75, 0.7, 0.75], }, + "complexity_scores": {"basic": [0.8, 0.75], "intermediate": [0.7, 0.8], "advanced": [0.6, 0.65]}, "summary": { "overall_mean": 0.75, "physician_mean": 0.85, "nurse_mean": 0.78, "patient_mean": 0.68, - "caregiver_mean": 0.73 - } + "caregiver_mean": 0.73, + }, }, { "model_name": "Claude-3", @@ -47,20 +43,16 @@ def sample_results_data(self): "physician": [0.85, 0.9, 0.88], "nurse": [0.82, 0.85, 0.83], "patient": [0.75, 0.8, 0.78], - "caregiver": [0.8, 0.82, 0.81] - }, - "complexity_scores": { - "basic": [0.85, 0.9], - "intermediate": [0.8, 0.85], - "advanced": [0.75, 0.8] + "caregiver": [0.8, 0.82, 0.81], }, + "complexity_scores": {"basic": [0.85, 0.9], "intermediate": [0.8, 0.85], "advanced": [0.75, 0.8]}, "summary": { "overall_mean": 0.82, "physician_mean": 0.88, "nurse_mean": 0.83, "patient_mean": 0.78, - "caregiver_mean": 0.81 - } + "caregiver_mean": 0.81, + }, }, { "model_name": "LLaMA-2", @@ -69,191 +61,187 @@ def sample_results_data(self): "physician": [0.7, 0.75, 0.72], "nurse": [0.68, 0.7, 0.69], "patient": [0.6, 0.65, 0.62], - "caregiver": [0.65, 0.68, 0.66] - }, - "complexity_scores": { - "basic": [0.7, 0.75], - "intermediate": [0.65, 0.7], - "advanced": [0.55, 0.6] + "caregiver": [0.65, 0.68, 0.66], }, + "complexity_scores": {"basic": [0.7, 0.75], "intermediate": [0.65, 0.7], "advanced": [0.55, 0.6]}, "summary": { "overall_mean": 0.67, "physician_mean": 0.72, "nurse_mean": 0.69, "patient_mean": 0.62, - "caregiver_mean": 0.66 - } - } + "caregiver_mean": 0.66, + }, + }, ] - + @pytest.fixture def leaderboard_generator(self): """Create LeaderboardGenerator instance for testing""" return LeaderboardGenerator() - + def test_initialization(self, leaderboard_generator): """Test proper initialization of LeaderboardGenerator""" - assert hasattr(leaderboard_generator, 'results_data') - assert hasattr(leaderboard_generator, 'benchmark_stats') + assert hasattr(leaderboard_generator, "results_data") + assert hasattr(leaderboard_generator, "benchmark_stats") assert leaderboard_generator.results_data == [] assert leaderboard_generator.benchmark_stats == {} - + def test_load_results_file_not_found(self, leaderboard_generator): """Test error handling when results directory doesn't exist""" with pytest.raises(FileNotFoundError): leaderboard_generator.load_results(Path("/nonexistent/directory")) - + def test_load_results_no_json_files(self, leaderboard_generator): """Test error handling when no JSON files found""" with tempfile.TemporaryDirectory() as temp_dir: # Create directory with no JSON files Path(temp_dir, "not_json.txt").write_text("not json") - + with pytest.raises(ValueError, match="No JSON result files found"): leaderboard_generator.load_results(Path(temp_dir)) - + def test_load_results_invalid_json(self, leaderboard_generator): """Test handling of invalid JSON files""" with tempfile.TemporaryDirectory() as temp_dir: # Create invalid JSON file invalid_file = Path(temp_dir) / "invalid.json" invalid_file.write_text("invalid json content") - + # Should log error but not crash leaderboard_generator.load_results(Path(temp_dir)) assert len(leaderboard_generator.results_data) == 0 - + def test_load_results_missing_required_fields(self, leaderboard_generator): """Test handling of JSON files missing required fields""" with tempfile.TemporaryDirectory() as temp_dir: # Create JSON file missing required fields incomplete_data = {"model_name": "TestModel"} # Missing other required fields incomplete_file = Path(temp_dir) / "incomplete.json" - with open(incomplete_file, 'w') as f: + with open(incomplete_file, "w") as f: json.dump(incomplete_data, f) - + # Should skip invalid files leaderboard_generator.load_results(Path(temp_dir)) assert len(leaderboard_generator.results_data) == 0 - + def test_load_results_valid_data(self, leaderboard_generator, sample_results_data): """Test loading valid results data""" with tempfile.TemporaryDirectory() as temp_dir: # Create valid JSON files for i, result in enumerate(sample_results_data): result_file = Path(temp_dir) / f"result_{i}.json" - with open(result_file, 'w') as f: + with open(result_file, "w") as f: json.dump(result, f) - + leaderboard_generator.load_results(Path(temp_dir)) - + assert len(leaderboard_generator.results_data) == 3 - assert leaderboard_generator.results_data[0]['model_name'] == "GPT-4" - assert leaderboard_generator.results_data[1]['model_name'] == "Claude-3" - assert leaderboard_generator.results_data[2]['model_name'] == "LLaMA-2" - + assert leaderboard_generator.results_data[0]["model_name"] == "GPT-4" + assert leaderboard_generator.results_data[1]["model_name"] == "Claude-3" + assert leaderboard_generator.results_data[2]["model_name"] == "LLaMA-2" + def test_calculate_leaderboard_stats_empty_data(self, leaderboard_generator): """Test stats calculation with empty data""" stats = leaderboard_generator.calculate_leaderboard_stats() assert stats == {} - + def test_calculate_leaderboard_stats_valid_data(self, leaderboard_generator, sample_results_data): """Test stats calculation with valid data""" leaderboard_generator.results_data = sample_results_data stats = leaderboard_generator.calculate_leaderboard_stats() - - assert stats['total_models'] == 3 - assert stats['total_evaluations'] == 300 # 3 models * 100 items each - assert set(stats['audiences']) == {'physician', 'nurse', 'patient', 'caregiver'} - assert set(stats['complexity_levels']) == {'basic', 'intermediate', 'advanced'} - assert stats['best_score'] == 0.82 # Claude-3's score - assert stats['worst_score'] == 0.67 # LLaMA-2's score - assert 0.6 < stats['average_score'] < 0.8 - assert 'last_updated' in stats - + + assert stats["total_models"] == 3 + assert stats["total_evaluations"] == 300 # 3 models * 100 items each + assert set(stats["audiences"]) == {"physician", "nurse", "patient", "caregiver"} + assert set(stats["complexity_levels"]) == {"basic", "intermediate", "advanced"} + assert stats["best_score"] == 0.82 # Claude-3's score + assert stats["worst_score"] == 0.67 # LLaMA-2's score + assert 0.6 < stats["average_score"] < 0.8 + assert "last_updated" in stats + def test_rank_models(self, leaderboard_generator, sample_results_data): """Test model ranking functionality""" leaderboard_generator.results_data = sample_results_data ranked_models = leaderboard_generator.rank_models() - + # Should be ranked by overall_mean in descending order assert len(ranked_models) == 3 - assert ranked_models[0]['model_name'] == "Claude-3" # Highest score (0.82) - assert ranked_models[0]['rank'] == 1 - assert ranked_models[0]['overall_score'] == 0.82 - - assert ranked_models[1]['model_name'] == "GPT-4" # Middle score (0.75) - assert ranked_models[1]['rank'] == 2 - assert ranked_models[1]['overall_score'] == 0.75 - - assert ranked_models[2]['model_name'] == "LLaMA-2" # Lowest score (0.67) - assert ranked_models[2]['rank'] == 3 - assert ranked_models[2]['overall_score'] == 0.67 - + assert ranked_models[0]["model_name"] == "Claude-3" # Highest score (0.82) + assert ranked_models[0]["rank"] == 1 + assert ranked_models[0]["overall_score"] == 0.82 + + assert ranked_models[1]["model_name"] == "GPT-4" # Middle score (0.75) + assert ranked_models[1]["rank"] == 2 + assert ranked_models[1]["overall_score"] == 0.75 + + assert ranked_models[2]["model_name"] == "LLaMA-2" # Lowest score (0.67) + assert ranked_models[2]["rank"] == 3 + assert ranked_models[2]["overall_score"] == 0.67 + def test_generate_audience_breakdown(self, leaderboard_generator, sample_results_data): """Test audience-specific performance breakdown""" leaderboard_generator.results_data = sample_results_data ranked_models = leaderboard_generator.rank_models() audience_breakdown = leaderboard_generator.generate_audience_breakdown(ranked_models) - + # Should have breakdown for each audience - assert set(audience_breakdown.keys()) == {'physician', 'nurse', 'patient', 'caregiver'} - + assert set(audience_breakdown.keys()) == {"physician", "nurse", "patient", "caregiver"} + # Check physician audience breakdown - physician_breakdown = audience_breakdown['physician'] + physician_breakdown = audience_breakdown["physician"] assert len(physician_breakdown) == 3 - + # Models should be ranked by their physician-specific scores - physician_scores = [model['score'] for model in physician_breakdown] + physician_scores = [model["score"] for model in physician_breakdown] assert physician_scores == sorted(physician_scores, reverse=True) - + # Check that rankings are assigned correctly for i, model in enumerate(physician_breakdown): - assert model['rank'] == i + 1 - assert 'model_name' in model - assert 'score' in model - assert 'num_items' in model - + assert model["rank"] == i + 1 + assert "model_name" in model + assert "score" in model + assert "num_items" in model + def test_generate_complexity_breakdown(self, leaderboard_generator, sample_results_data): """Test complexity-specific performance breakdown""" leaderboard_generator.results_data = sample_results_data ranked_models = leaderboard_generator.rank_models() complexity_breakdown = leaderboard_generator.generate_complexity_breakdown(ranked_models) - + # Should have breakdown for each complexity level - assert set(complexity_breakdown.keys()) == {'basic', 'intermediate', 'advanced'} - + assert set(complexity_breakdown.keys()) == {"basic", "intermediate", "advanced"} + # Check basic complexity breakdown - basic_breakdown = complexity_breakdown['basic'] + basic_breakdown = complexity_breakdown["basic"] assert len(basic_breakdown) == 3 - + # Models should be ranked by their basic-specific scores - basic_scores = [model['score'] for model in basic_breakdown] + basic_scores = [model["score"] for model in basic_breakdown] assert basic_scores == sorted(basic_scores, reverse=True) - + # Check that rankings are assigned correctly for i, model in enumerate(basic_breakdown): - assert model['rank'] == i + 1 - + assert model["rank"] == i + 1 + def test_generate_html_no_data_raises_error(self, leaderboard_generator): """Test that HTML generation raises error with no data""" with tempfile.TemporaryDirectory() as temp_dir: output_path = Path(temp_dir) / "test.html" - + with pytest.raises(ValueError, match="No results data loaded"): leaderboard_generator.generate_html(output_path) - + def test_generate_html_creates_file(self, leaderboard_generator, sample_results_data): """Test that HTML generation creates file""" leaderboard_generator.results_data = sample_results_data - + with tempfile.TemporaryDirectory() as temp_dir: output_path = Path(temp_dir) / "leaderboard.html" leaderboard_generator.generate_html(output_path) - + assert output_path.exists() html_content = output_path.read_text() - + # Check that HTML contains expected elements assert "" in html_content assert "MEQ-Bench Leaderboard" in html_content @@ -262,134 +250,134 @@ def test_generate_html_creates_file(self, leaderboard_generator, sample_results_ assert "LLaMA-2" in html_content assert "physician" in html_content assert "patient" in html_content - + def test_generate_html_creates_parent_directory(self, leaderboard_generator, sample_results_data): """Test that HTML generation creates parent directories""" leaderboard_generator.results_data = sample_results_data - + with tempfile.TemporaryDirectory() as temp_dir: output_path = Path(temp_dir) / "subdir" / "leaderboard.html" leaderboard_generator.generate_html(output_path) - + assert output_path.exists() assert output_path.parent.exists() - + def test_html_template_structure(self, leaderboard_generator, sample_results_data): """Test that generated HTML has proper structure""" leaderboard_generator.results_data = sample_results_data - + with tempfile.TemporaryDirectory() as temp_dir: output_path = Path(temp_dir) / "test.html" leaderboard_generator.generate_html(output_path) - + html_content = output_path.read_text() - + # Check HTML structure assert "" in html_content assert "" in html_content assert "" in html_content assert "</html>" in html_content - + # Check CSS inclusion assert "<style>" in html_content assert "font-family" in html_content - + # Check JavaScript inclusion assert "<script>" in html_content assert "function showTab" in html_content - + # Check Chart.js inclusion assert "chart.js" in html_content - + def test_overall_rankings_table_generation(self, leaderboard_generator, sample_results_data): """Test generation of overall rankings table""" leaderboard_generator.results_data = sample_results_data ranked_models = leaderboard_generator.rank_models() - + table_html = leaderboard_generator._generate_overall_rankings_table(ranked_models) - + # Check table structure assert "<table>" in table_html assert "<thead>" in table_html assert "<tbody>" in table_html assert "</table>" in table_html - + # Check table headers assert "Rank" in table_html assert "Model" in table_html assert "Overall Score" in table_html assert "Physician" in table_html assert "Patient" in table_html - + # Check model data assert "Claude-3" in table_html assert "#1" in table_html # First rank assert "rank-1" in table_html # CSS class for first place - + def test_audience_breakdown_section_generation(self, leaderboard_generator, sample_results_data): """Test generation of audience breakdown section""" leaderboard_generator.results_data = sample_results_data ranked_models = leaderboard_generator.rank_models() audience_breakdown = leaderboard_generator.generate_audience_breakdown(ranked_models) - + section_html = leaderboard_generator._generate_audience_breakdown_section(audience_breakdown) - + # Check section structure assert "audience-section" in section_html assert "Physician Audience Rankings" in section_html assert "Patient Audience Rankings" in section_html - + # Check that all models appear in sections assert "Claude-3" in section_html assert "GPT-4" in section_html assert "LLaMA-2" in section_html - + def test_complexity_breakdown_section_generation(self, leaderboard_generator, sample_results_data): """Test generation of complexity breakdown section""" leaderboard_generator.results_data = sample_results_data ranked_models = leaderboard_generator.rank_models() complexity_breakdown = leaderboard_generator.generate_complexity_breakdown(ranked_models) - + section_html = leaderboard_generator._generate_complexity_breakdown_section(complexity_breakdown) - + # Check section structure assert "complexity-section" in section_html assert "Basic Complexity Level Rankings" in section_html assert "Advanced Complexity Level Rankings" in section_html - + # Check that all models appear in sections assert "Claude-3" in section_html assert "GPT-4" in section_html assert "LLaMA-2" in section_html - + def test_javascript_generation(self, leaderboard_generator, sample_results_data): """Test JavaScript generation for interactive features""" leaderboard_generator.results_data = sample_results_data ranked_models = leaderboard_generator.rank_models() audience_breakdown = leaderboard_generator.generate_audience_breakdown(ranked_models) stats = leaderboard_generator.calculate_leaderboard_stats() - + js_content = leaderboard_generator._generate_javascript(ranked_models, audience_breakdown, stats) - + # Check JavaScript structure assert "function showTab" in js_content assert "function initCharts" in js_content assert "Chart(" in js_content - + # Check data inclusion assert "Claude-3" in js_content or "Claude" in js_content # Model names assert "physician" in js_content assert "patient" in js_content - + # Check chart types assert "'bar'" in js_content assert "'radar'" in js_content - + def test_css_styles_generation(self, leaderboard_generator): """Test CSS styles generation""" css_content = leaderboard_generator._get_css_styles() - + # Check CSS structure and important styles assert "body {" in css_content assert "font-family" in css_content @@ -398,7 +386,7 @@ def test_css_styles_generation(self, leaderboard_generator): assert ".rank-1" in css_content assert ".rank-2" in css_content assert ".rank-3" in css_content - + # Check responsive design assert "@media" in css_content assert "max-width" in css_content @@ -406,49 +394,49 @@ def test_css_styles_generation(self, leaderboard_generator): class TestLeaderboardIntegration: """Integration tests for leaderboard functionality""" - + def test_end_to_end_leaderboard_generation(self, sample_results_data): """Test complete end-to-end leaderboard generation""" with tempfile.TemporaryDirectory() as temp_dir: # Create sample result files results_dir = Path(temp_dir) / "results" results_dir.mkdir() - + for i, result in enumerate(sample_results_data): result_file = results_dir / f"model_{i}_results.json" - with open(result_file, 'w') as f: + with open(result_file, "w") as f: json.dump(result, f) - + # Generate leaderboard generator = LeaderboardGenerator() generator.load_results(results_dir) - + output_path = Path(temp_dir) / "leaderboard.html" generator.generate_html(output_path) - + # Verify the complete leaderboard assert output_path.exists() html_content = output_path.read_text() - + # Check that all major components are present assert "MEQ-Bench Leaderboard" in html_content assert "Overall Rankings" in html_content assert "By Audience" in html_content assert "By Complexity" in html_content assert "Analytics" in html_content - + # Check that all models are represented for result in sample_results_data: - assert result['model_name'] in html_content - + assert result["model_name"] in html_content + # Check that audience types are represented - for audience in ['physician', 'nurse', 'patient', 'caregiver']: + for audience in ["physician", "nurse", "patient", "caregiver"]: assert audience in html_content - + # Check that complexity levels are represented - for complexity in ['basic', 'intermediate', 'advanced']: + for complexity in ["basic", "intermediate", "advanced"]: assert complexity in html_content - + def test_leaderboard_with_mixed_data_quality(self): """Test leaderboard generation with mixed quality data""" mixed_results = [ @@ -458,52 +446,52 @@ def test_leaderboard_with_mixed_data_quality(self): "total_items": 100, "audience_scores": {"physician": [0.8, 0.9], "patient": [0.7, 0.8]}, "complexity_scores": {"basic": [0.8], "advanced": [0.7]}, - "summary": {"overall_mean": 0.8} + "summary": {"overall_mean": 0.8}, }, # Missing complexity scores { "model_name": "PartialModel", "total_items": 50, "audience_scores": {"physician": [0.7], "patient": [0.6]}, - "summary": {"overall_mean": 0.65} + "summary": {"overall_mean": 0.65}, }, # Minimal data { "model_name": "MinimalModel", "total_items": 25, "audience_scores": {"patient": [0.5]}, - "summary": {"overall_mean": 0.5} - } + "summary": {"overall_mean": 0.5}, + }, ] - + with tempfile.TemporaryDirectory() as temp_dir: results_dir = Path(temp_dir) / "results" results_dir.mkdir() - + for i, result in enumerate(mixed_results): result_file = results_dir / f"result_{i}.json" - with open(result_file, 'w') as f: + with open(result_file, "w") as f: json.dump(result, f) - + # Should handle mixed data gracefully generator = LeaderboardGenerator() generator.load_results(results_dir) - + output_path = Path(temp_dir) / "leaderboard.html" generator.generate_html(output_path) - + assert output_path.exists() html_content = output_path.read_text() - + # All models should still appear assert "CompleteModel" in html_content assert "PartialModel" in html_content assert "MinimalModel" in html_content - + def test_leaderboard_performance_with_large_dataset(self): """Test leaderboard performance with larger dataset""" import time - + # Create larger dataset large_results = [] for i in range(20): # 20 models @@ -514,43 +502,43 @@ def test_leaderboard_performance_with_large_dataset(self): "physician": [0.5 + (i * 0.02)] * 100, "nurse": [0.6 + (i * 0.015)] * 100, "patient": [0.4 + (i * 0.025)] * 100, - "caregiver": [0.55 + (i * 0.02)] * 100 + "caregiver": [0.55 + (i * 0.02)] * 100, }, "complexity_scores": { "basic": [0.6 + (i * 0.02)] * 100, "intermediate": [0.5 + (i * 0.02)] * 100, - "advanced": [0.4 + (i * 0.02)] * 100 + "advanced": [0.4 + (i * 0.02)] * 100, }, - "summary": {"overall_mean": 0.5 + (i * 0.02)} + "summary": {"overall_mean": 0.5 + (i * 0.02)}, } large_results.append(result) - + with tempfile.TemporaryDirectory() as temp_dir: results_dir = Path(temp_dir) / "results" results_dir.mkdir() - + for i, result in enumerate(large_results): result_file = results_dir / f"result_{i}.json" - with open(result_file, 'w') as f: + with open(result_file, "w") as f: json.dump(result, f) - + # Time the leaderboard generation start_time = time.time() - + generator = LeaderboardGenerator() generator.load_results(results_dir) - + output_path = Path(temp_dir) / "leaderboard.html" generator.generate_html(output_path) - + end_time = time.time() generation_time = end_time - start_time - + # Should complete within reasonable time (30 seconds for large dataset) assert generation_time < 30.0, f"Leaderboard generation took {generation_time:.2f}s" - + # Verify output quality assert output_path.exists() html_content = output_path.read_text() assert len(html_content) > 10000 # Should be substantial HTML content - assert "Model_19" in html_content # Highest ranked model should appear \ No newline at end of file + assert "Model_19" in html_content # Highest ranked model should appear diff --git a/tests/test_process_datasets.py b/tests/test_process_datasets.py index 2257543..6b9f016 100644 --- a/tests/test_process_datasets.py +++ b/tests/test_process_datasets.py @@ -19,7 +19,7 @@ balance_complexity_distribution, validate_dataset, print_dataset_statistics, - setup_argument_parser + setup_argument_parser, ) except ImportError: # If direct import fails, we'll test through subprocess calls @@ -30,54 +30,54 @@ class TestDatasetLimitsCalculation: """Test dataset limits calculation""" - + def test_equal_distribution_three_datasets(self): """Test equal distribution across three datasets""" limits = calculate_dataset_limits(1000, 3) - + assert len(limits) == 3 - assert 'medqa' in limits - assert 'icliniq' in limits - assert 'cochrane' in limits - + assert "medqa" in limits + assert "icliniq" in limits + assert "cochrane" in limits + # Should distribute as evenly as possible total = sum(limits.values()) assert total == 1000 - + # Each should get approximately 333-334 items for limit in limits.values(): assert 333 <= limit <= 334 - + def test_equal_distribution_two_datasets(self): """Test equal distribution across two datasets""" limits = calculate_dataset_limits(1000, 2) - + assert len(limits) == 2 total = sum(limits.values()) assert total == 1000 - + # Each should get 500 items for limit in limits.values(): assert limit == 500 - + def test_uneven_distribution_handling(self): """Test handling of uneven divisions""" limits = calculate_dataset_limits(1001, 3) - + total = sum(limits.values()) assert total == 1001 - + # Should handle remainder correctly limit_values = list(limits.values()) assert max(limit_values) - min(limit_values) <= 1 # Difference should be at most 1 - + def test_small_total_items(self): """Test handling of small total item counts""" limits = calculate_dataset_limits(5, 3) - + total = sum(limits.values()) assert total == 5 - + # Some datasets might get 1 item, others 2 for limit in limits.values(): assert limit >= 1 @@ -85,83 +85,87 @@ def test_small_total_items(self): class TestComplexityBalancing: """Test complexity distribution balancing""" - + @pytest.fixture def sample_items(self): """Create sample items with different complexity levels""" items = [] - + # Create 10 basic items for i in range(10): - items.append(MEQBenchItem( - id=f"basic_{i}", - medical_content=f"Basic medical content {i}", - complexity_level="basic", - source_dataset="test" - )) - + items.append( + MEQBenchItem( + id=f"basic_{i}", + medical_content=f"Basic medical content {i}", + complexity_level="basic", + source_dataset="test", + ) + ) + # Create 5 intermediate items for i in range(5): - items.append(MEQBenchItem( - id=f"intermediate_{i}", - medical_content=f"Intermediate medical content {i}", - complexity_level="intermediate", - source_dataset="test" - )) - + items.append( + MEQBenchItem( + id=f"intermediate_{i}", + medical_content=f"Intermediate medical content {i}", + complexity_level="intermediate", + source_dataset="test", + ) + ) + # Create 2 advanced items for i in range(2): - items.append(MEQBenchItem( - id=f"advanced_{i}", - medical_content=f"Advanced medical content {i}", - complexity_level="advanced", - source_dataset="test" - )) - + items.append( + MEQBenchItem( + id=f"advanced_{i}", + medical_content=f"Advanced medical content {i}", + complexity_level="advanced", + source_dataset="test", + ) + ) + return items - - @patch('process_datasets.random.sample') - @patch('process_datasets.random.seed') + + @patch("process_datasets.random.sample") + @patch("process_datasets.random.seed") def test_balance_complexity_distribution(self, mock_seed, mock_sample, sample_items): """Test complexity balancing functionality""" + # Mock random.sample to return predictable results def side_effect(population, k): return population[:k] # Return first k items - + mock_sample.side_effect = side_effect - + balanced_items = balance_complexity_distribution(sample_items) - + # Should have roughly equal distribution complexity_counts = {} for item in balanced_items: complexity_counts[item.complexity_level] = complexity_counts.get(item.complexity_level, 0) + 1 - + # Check that balancing was attempted assert len(balanced_items) <= len(sample_items) - assert 'basic' in complexity_counts - assert 'intermediate' in complexity_counts - assert 'advanced' in complexity_counts - + assert "basic" in complexity_counts + assert "intermediate" in complexity_counts + assert "advanced" in complexity_counts + def test_balance_empty_items(self): """Test balancing with empty item list""" balanced_items = balance_complexity_distribution([]) assert balanced_items == [] - + def test_balance_single_complexity_level(self): """Test balancing when only one complexity level exists""" items = [ MEQBenchItem( - id=f"basic_{i}", - medical_content=f"Basic content {i}", - complexity_level="basic", - source_dataset="test" + id=f"basic_{i}", medical_content=f"Basic content {i}", complexity_level="basic", source_dataset="test" ) for i in range(5) ] - + balanced_items = balance_complexity_distribution(items) - + # Should still return items even if balancing isn't possible assert len(balanced_items) > 0 assert all(item.complexity_level == "basic" for item in balanced_items) @@ -169,15 +173,15 @@ def test_balance_single_complexity_level(self): class TestDatasetValidation: """Test dataset validation functionality""" - + def test_validate_empty_dataset(self): """Test validation of empty dataset""" report = validate_dataset([]) - - assert report['valid'] is False - assert report['total_items'] == 0 - assert "Dataset is empty" in report['issues'] - + + assert report["valid"] is False + assert report["total_items"] == 0 + assert "Dataset is empty" in report["issues"] + def test_validate_valid_dataset(self): """Test validation of valid dataset""" items = [ @@ -185,24 +189,24 @@ def test_validate_valid_dataset(self): id="test_001", medical_content="This is valid medical content for testing purposes and is long enough.", complexity_level="basic", - source_dataset="test" + source_dataset="test", ), MEQBenchItem( id="test_002", medical_content="This is another valid medical content item for comprehensive testing.", complexity_level="intermediate", - source_dataset="test" - ) + source_dataset="test", + ), ] - + report = validate_dataset(items) - - assert report['valid'] is True - assert report['total_items'] == 2 - assert len(report['issues']) == 0 - assert 'complexity_distribution' in report['statistics'] - assert 'source_distribution' in report['statistics'] - + + assert report["valid"] is True + assert report["total_items"] == 2 + assert len(report["issues"]) == 0 + assert "complexity_distribution" in report["statistics"] + assert "source_distribution" in report["statistics"] + def test_validate_duplicate_ids(self): """Test detection of duplicate IDs""" items = [ @@ -210,37 +214,32 @@ def test_validate_duplicate_ids(self): id="duplicate_id", medical_content="First item with duplicate ID and sufficient content length.", complexity_level="basic", - source_dataset="test" + source_dataset="test", ), MEQBenchItem( id="duplicate_id", # Same ID medical_content="Second item with duplicate ID and sufficient content length.", complexity_level="intermediate", - source_dataset="test" - ) + source_dataset="test", + ), ] - + report = validate_dataset(items) - - assert report['valid'] is False - assert any("Duplicate item IDs" in issue for issue in report['issues']) - + + assert report["valid"] is False + assert any("Duplicate item IDs" in issue for issue in report["issues"]) + def test_validate_short_content(self): """Test detection of very short content""" items = [ - MEQBenchItem( - id="test_001", - medical_content="Short", # Too short - complexity_level="basic", - source_dataset="test" - ) + MEQBenchItem(id="test_001", medical_content="Short", complexity_level="basic", source_dataset="test") # Too short ] - + report = validate_dataset(items) - - assert report['valid'] is False - assert any("very short content" in issue for issue in report['issues']) - + + assert report["valid"] is False + assert any("very short content" in issue for issue in report["issues"]) + def test_validate_missing_complexity_levels(self): """Test warning for missing complexity levels""" items = [ @@ -248,15 +247,15 @@ def test_validate_missing_complexity_levels(self): id="test_001", medical_content="This is valid medical content with sufficient length for testing.", complexity_level="basic", # Only basic level - source_dataset="test" + source_dataset="test", ) ] - + report = validate_dataset(items) - - assert report['valid'] is True # Still valid, just warning - assert any("Not all complexity levels represented" in warning for warning in report['warnings']) - + + assert report["valid"] is True # Still valid, just warning + assert any("Not all complexity levels represented" in warning for warning in report["warnings"]) + def test_validate_content_length_statistics(self): """Test content length statistics calculation""" items = [ @@ -264,38 +263,39 @@ def test_validate_content_length_statistics(self): id="test_001", medical_content="Short but valid medical content for testing purposes.", # ~50 chars complexity_level="basic", - source_dataset="test" + source_dataset="test", ), MEQBenchItem( id="test_002", - medical_content="This is a much longer medical content item that contains significantly more text to test the average content length calculation functionality." * 3, # ~400+ chars + medical_content="This is a much longer medical content item that contains significantly more text to test the average content length calculation functionality." + * 3, # ~400+ chars complexity_level="intermediate", - source_dataset="test" - ) + source_dataset="test", + ), ] - + report = validate_dataset(items) - - assert 'content_length' in report['statistics'] - assert 'average' in report['statistics']['content_length'] - assert 'minimum' in report['statistics']['content_length'] - assert 'maximum' in report['statistics']['content_length'] - + + assert "content_length" in report["statistics"] + assert "average" in report["statistics"]["content_length"] + assert "minimum" in report["statistics"]["content_length"] + assert "maximum" in report["statistics"]["content_length"] + # Average should be reasonable - avg_length = report['statistics']['content_length']['average'] + avg_length = report["statistics"]["content_length"]["average"] assert 50 < avg_length < 500 class TestStatisticsPrinting: """Test statistics printing functionality""" - + def test_print_empty_dataset_statistics(self, capsys): """Test printing statistics for empty dataset""" print_dataset_statistics([]) - + captured = capsys.readouterr() assert "Dataset is empty" in captured.out - + def test_print_valid_dataset_statistics(self, capsys): """Test printing statistics for valid dataset""" items = [ @@ -303,43 +303,43 @@ def test_print_valid_dataset_statistics(self, capsys): id="test_001", medical_content="Valid medical content for testing statistics display functionality.", complexity_level="basic", - source_dataset="MedQA-USMLE" + source_dataset="MedQA-USMLE", ), MEQBenchItem( - id="test_002", + id="test_002", medical_content="Another valid medical content item for comprehensive statistics testing.", complexity_level="intermediate", - source_dataset="iCliniq" + source_dataset="iCliniq", ), MEQBenchItem( id="test_003", medical_content="Third valid medical content item for advanced complexity testing and statistics.", complexity_level="advanced", - source_dataset="Cochrane Reviews" - ) + source_dataset="Cochrane Reviews", + ), ] - + print_dataset_statistics(items) - + captured = capsys.readouterr() output = captured.out - + # Check that all expected sections are present assert "Dataset Statistics" in output assert "Complexity Distribution" in output assert "Source Distribution" in output assert "Content Length Statistics" in output - + # Check that complexity levels are shown assert "Basic" in output assert "Intermediate" in output assert "Advanced" in output - + # Check that sources are shown assert "MedQA-USMLE" in output assert "iCliniq" in output assert "Cochrane Reviews" in output - + # Check that statistics are shown assert "Total items: 3" in output assert "Average:" in output @@ -349,53 +349,56 @@ def test_print_valid_dataset_statistics(self, capsys): class TestArgumentParser: """Test command line argument parser""" - + def test_argument_parser_setup(self): """Test that argument parser is set up correctly""" parser = setup_argument_parser() - + # Test that parser exists and has expected arguments - args = parser.parse_args([ - '--medqa', 'medqa.json', - '--icliniq', 'icliniq.json', - '--cochrane', 'cochrane.json', - '--output', 'output.json' - ]) - - assert args.medqa == 'medqa.json' - assert args.icliniq == 'icliniq.json' - assert args.cochrane == 'cochrane.json' - assert args.output == 'output.json' + args = parser.parse_args( + ["--medqa", "medqa.json", "--icliniq", "icliniq.json", "--cochrane", "cochrane.json", "--output", "output.json"] + ) + + assert args.medqa == "medqa.json" + assert args.icliniq == "icliniq.json" + assert args.cochrane == "cochrane.json" + assert args.output == "output.json" assert args.max_items == 1000 # Default value - + def test_argument_parser_defaults(self): """Test default argument values""" parser = setup_argument_parser() - + # Test with minimal arguments - args = parser.parse_args(['--medqa', 'test.json']) - - assert args.output == 'data/benchmark_items.json' # Default output + args = parser.parse_args(["--medqa", "test.json"]) + + assert args.output == "data/benchmark_items.json" # Default output assert args.max_items == 1000 # Default max items assert args.auto_complexity is True # Default auto complexity assert args.balance_complexity is True # Default balance assert args.seed == 42 # Default seed - + def test_argument_parser_custom_values(self): """Test custom argument values""" parser = setup_argument_parser() - - args = parser.parse_args([ - '--medqa', 'medqa.json', - '--max-items', '500', - '--medqa-items', '200', - '--no-auto-complexity', - '--seed', '123', - '--validate', - '--stats', - '--verbose' - ]) - + + args = parser.parse_args( + [ + "--medqa", + "medqa.json", + "--max-items", + "500", + "--medqa-items", + "200", + "--no-auto-complexity", + "--seed", + "123", + "--validate", + "--stats", + "--verbose", + ] + ) + assert args.max_items == 500 assert args.medqa_items == 200 assert args.auto_complexity is False @@ -407,100 +410,109 @@ def test_argument_parser_custom_values(self): class TestScriptIntegration: """Integration tests for the complete script""" - + @pytest.mark.skipif(process_datasets is None, reason="process_datasets module not available") def test_script_with_sample_data(self): """Test the complete script with sample data""" with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) - + # Create sample dataset files - medqa_data = [{ - "id": "medqa_001", - "question": "What is hypertension?", - "answer": "A", - "explanation": "High blood pressure condition." - }] - - icliniq_data = [{ - "id": "icliniq_001", - "patient_question": "I have chest pain", - "doctor_answer": "Consult a cardiologist", - "speciality": "Cardiology" - }] - - cochrane_data = [{ - "id": "cochrane_001", - "title": "Statin effectiveness", - "abstract": "Systematic review of statins" - }] - + medqa_data = [ + { + "id": "medqa_001", + "question": "What is hypertension?", + "answer": "A", + "explanation": "High blood pressure condition.", + } + ] + + icliniq_data = [ + { + "id": "icliniq_001", + "patient_question": "I have chest pain", + "doctor_answer": "Consult a cardiologist", + "speciality": "Cardiology", + } + ] + + cochrane_data = [ + {"id": "cochrane_001", "title": "Statin effectiveness", "abstract": "Systematic review of statins"} + ] + # Save sample files medqa_file = temp_path / "medqa.json" icliniq_file = temp_path / "icliniq.json" cochrane_file = temp_path / "cochrane.json" output_file = temp_path / "output.json" - - with open(medqa_file, 'w') as f: + + with open(medqa_file, "w") as f: json.dump(medqa_data, f) - with open(icliniq_file, 'w') as f: + with open(icliniq_file, "w") as f: json.dump(icliniq_data, f) - with open(cochrane_file, 'w') as f: + with open(cochrane_file, "w") as f: json.dump(cochrane_data, f) - + # Run the script through subprocess to test CLI script_path = Path(__file__).parent.parent / "scripts" / "process_datasets.py" - - result = subprocess.run([ - sys.executable, str(script_path), - '--medqa', str(medqa_file), - '--icliniq', str(icliniq_file), - '--cochrane', str(cochrane_file), - '--output', str(output_file), - '--max-items', '10', - '--no-auto-complexity', - '--validate', - '--stats' - ], capture_output=True, text=True) - + + result = subprocess.run( + [ + sys.executable, + str(script_path), + "--medqa", + str(medqa_file), + "--icliniq", + str(icliniq_file), + "--cochrane", + str(cochrane_file), + "--output", + str(output_file), + "--max-items", + "10", + "--no-auto-complexity", + "--validate", + "--stats", + ], + capture_output=True, + text=True, + ) + # Check that script ran successfully assert result.returncode == 0, f"Script failed with error: {result.stderr}" - + # Check that output file was created assert output_file.exists() - + # Verify output content - with open(output_file, 'r') as f: + with open(output_file, "r") as f: output_data = json.load(f) - + assert len(output_data) == 3 # Should have loaded all 3 items - assert any(item['source_dataset'] == 'MedQA-USMLE' for item in output_data) - assert any(item['source_dataset'] == 'iCliniq' for item in output_data) - assert any(item['source_dataset'] == 'Cochrane Reviews' for item in output_data) - + assert any(item["source_dataset"] == "MedQA-USMLE" for item in output_data) + assert any(item["source_dataset"] == "iCliniq" for item in output_data) + assert any(item["source_dataset"] == "Cochrane Reviews" for item in output_data) + def test_script_error_handling(self): """Test script error handling with invalid inputs""" script_path = Path(__file__).parent.parent / "scripts" / "process_datasets.py" - + # Test with non-existent file - result = subprocess.run([ - sys.executable, str(script_path), - '--medqa', '/nonexistent/file.json', - '--output', 'output.json' - ], capture_output=True, text=True) - + result = subprocess.run( + [sys.executable, str(script_path), "--medqa", "/nonexistent/file.json", "--output", "output.json"], + capture_output=True, + text=True, + ) + # Should exit with error code assert result.returncode != 0 - + def test_script_help_output(self): """Test that script provides help output""" script_path = Path(__file__).parent.parent / "scripts" / "process_datasets.py" - - result = subprocess.run([ - sys.executable, str(script_path), - '--help' - ], capture_output=True, text=True) - + + result = subprocess.run([sys.executable, str(script_path), "--help"], capture_output=True, text=True) + assert result.returncode == 0 assert "Process medical datasets for MEQ-Bench" in result.stdout assert "--medqa" in result.stdout @@ -510,69 +522,78 @@ def test_script_help_output(self): class TestScriptPerformance: """Test script performance characteristics""" - + @pytest.mark.skipif(process_datasets is None, reason="process_datasets module not available") def test_script_performance_with_large_dataset(self): """Test script performance with larger dataset""" import time - + with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) - + # Create larger sample datasets large_medqa_data = [ { "id": f"medqa_{i:03d}", "question": f"Medical question {i}", "answer": "A", - "explanation": f"Medical explanation {i}" + "explanation": f"Medical explanation {i}", } for i in range(200) ] - + large_icliniq_data = [ { "id": f"icliniq_{i:03d}", "patient_question": f"Patient question {i}", "doctor_answer": f"Doctor answer {i}", - "speciality": "General" + "speciality": "General", } for i in range(200) ] - + # Save large files medqa_file = temp_path / "large_medqa.json" icliniq_file = temp_path / "large_icliniq.json" output_file = temp_path / "large_output.json" - - with open(medqa_file, 'w') as f: + + with open(medqa_file, "w") as f: json.dump(large_medqa_data, f) - with open(icliniq_file, 'w') as f: + with open(icliniq_file, "w") as f: json.dump(large_icliniq_data, f) - + # Time the script execution start_time = time.time() - + script_path = Path(__file__).parent.parent / "scripts" / "process_datasets.py" - - result = subprocess.run([ - sys.executable, str(script_path), - '--medqa', str(medqa_file), - '--icliniq', str(icliniq_file), - '--output', str(output_file), - '--max-items', '100', - '--no-auto-complexity' # Skip complexity calculation for speed - ], capture_output=True, text=True) - + + result = subprocess.run( + [ + sys.executable, + str(script_path), + "--medqa", + str(medqa_file), + "--icliniq", + str(icliniq_file), + "--output", + str(output_file), + "--max-items", + "100", + "--no-auto-complexity", # Skip complexity calculation for speed + ], + capture_output=True, + text=True, + ) + end_time = time.time() execution_time = end_time - start_time - + # Should complete within reasonable time (30 seconds for large dataset) assert execution_time < 30.0, f"Script took {execution_time:.2f}s" assert result.returncode == 0 assert output_file.exists() - + # Verify output quality - with open(output_file, 'r') as f: + with open(output_file, "r") as f: output_data = json.load(f) - assert len(output_data) == 100 # Should respect max_items limit \ No newline at end of file + assert len(output_data) == 100 # Should respect max_items limit diff --git a/tests/test_robust_parser.py b/tests/test_robust_parser.py index 54a2fe4..abdab61 100644 --- a/tests/test_robust_parser.py +++ b/tests/test_robust_parser.py @@ -8,7 +8,7 @@ class TestRobustParser: """Test the robust LLM response parser with various input formats""" - + def test_standard_format(self): """Test parsing with standard format: 'For a Physician:'""" response = """ @@ -20,20 +20,20 @@ def test_standard_format(self): For a Caregiver: This provides concrete tasks, symptoms to watch for, and clear guidance on when to seek help. """ - + result = AudienceAdaptivePrompt.parse_response(response) - + assert len(result) == 4 - assert 'physician' in result - assert 'nurse' in result - assert 'patient' in result - assert 'caregiver' in result - - assert 'technical explanation' in result['physician'] - assert 'practical care implications' in result['nurse'] - assert 'simple, jargon-free' in result['patient'] - assert 'concrete tasks' in result['caregiver'] - + assert "physician" in result + assert "nurse" in result + assert "patient" in result + assert "caregiver" in result + + assert "technical explanation" in result["physician"] + assert "practical care implications" in result["nurse"] + assert "simple, jargon-free" in result["patient"] + assert "concrete tasks" in result["caregiver"] + def test_colon_format(self): """Test parsing with simple colon format: 'Physician:'""" response = """ @@ -45,15 +45,15 @@ def test_colon_format(self): Caregiver: Clear instructions for family caregivers. """ - + result = AudienceAdaptivePrompt.parse_response(response) - + assert len(result) == 4 - assert 'Technical medical explanation' in result['physician'] - assert 'Practical nursing care' in result['nurse'] - assert 'Easy-to-understand' in result['patient'] - assert 'Clear instructions' in result['caregiver'] - + assert "Technical medical explanation" in result["physician"] + assert "Practical nursing care" in result["nurse"] + assert "Easy-to-understand" in result["patient"] + assert "Clear instructions" in result["caregiver"] + def test_numbered_format(self): """Test parsing with numbered format: '1. Physician:'""" response = """ @@ -65,15 +65,15 @@ def test_numbered_format(self): 4. Caregiver: Supportive care instructions. """ - + result = AudienceAdaptivePrompt.parse_response(response) - + assert len(result) == 4 - assert 'Advanced medical explanation' in result['physician'] - assert 'Care plan and monitoring' in result['nurse'] - assert 'Simple explanation' in result['patient'] - assert 'Supportive care instructions' in result['caregiver'] - + assert "Advanced medical explanation" in result["physician"] + assert "Care plan and monitoring" in result["nurse"] + assert "Simple explanation" in result["patient"] + assert "Supportive care instructions" in result["caregiver"] + def test_markdown_header_format(self): """Test parsing with markdown headers: '## Physician'""" response = """ @@ -89,15 +89,15 @@ def test_markdown_header_format(self): ## Caregiver Family support and care coordination. """ - + result = AudienceAdaptivePrompt.parse_response(response) - + assert len(result) == 4 - assert 'pathophysiology' in result['physician'] - assert 'assessment and intervention' in result['nurse'] - assert 'Patient-friendly' in result['patient'] - assert 'Family support' in result['caregiver'] - + assert "pathophysiology" in result["physician"] + assert "assessment and intervention" in result["nurse"] + assert "Patient-friendly" in result["patient"] + assert "Family support" in result["caregiver"] + def test_bold_format(self): """Test parsing with bold formatting: '**Physician:**'""" response = """ @@ -109,15 +109,15 @@ def test_bold_format(self): **Caregiver:** Practical guidance for family members. """ - + result = AudienceAdaptivePrompt.parse_response(response) - + assert len(result) == 4 - assert 'clinical focus' in result['physician'] - assert 'care priorities' in result['nurse'] - assert 'Accessible explanation' in result['patient'] - assert 'Practical guidance' in result['caregiver'] - + assert "clinical focus" in result["physician"] + assert "care priorities" in result["nurse"] + assert "Accessible explanation" in result["patient"] + assert "Practical guidance" in result["caregiver"] + def test_case_insensitive(self): """Test that parsing is case-insensitive""" response = """ @@ -129,15 +129,15 @@ def test_case_insensitive(self): for a Caregiver: Mixed case format explanation. """ - + result = AudienceAdaptivePrompt.parse_response(response) - + assert len(result) == 4 - assert 'Upper case format' in result['physician'] - assert 'Lower case format' in result['nurse'] - assert 'Mixed case format' in result['patient'] - assert 'Mixed case format' in result['caregiver'] - + assert "Upper case format" in result["physician"] + assert "Lower case format" in result["nurse"] + assert "Mixed case format" in result["patient"] + assert "Mixed case format" in result["caregiver"] + def test_whitespace_variations(self): """Test parsing with various whitespace variations""" response = """ @@ -153,15 +153,15 @@ def test_whitespace_variations(self): Explanation with blank line. """ - + result = AudienceAdaptivePrompt.parse_response(response) - + assert len(result) == 4 - assert 'extra spaces' in result['physician'] - assert 'new line' in result['nurse'] - assert 'tabs and spaces' in result['patient'] - assert 'blank line' in result['caregiver'] - + assert "extra spaces" in result["physician"] + assert "new line" in result["nurse"] + assert "tabs and spaces" in result["patient"] + assert "blank line" in result["caregiver"] + def test_multiline_explanations(self): """Test parsing with multi-line explanations""" response = """ @@ -182,17 +182,17 @@ def test_multiline_explanations(self): Caregiver instruction line one. Caregiver instruction line two. """ - + result = AudienceAdaptivePrompt.parse_response(response) - + assert len(result) == 4 - assert 'first line' in result['physician'] - assert 'second line' in result['physician'] - assert 'third line' in result['physician'] - assert 'care instructions' in result['nurse'] - assert 'Simple first line' in result['patient'] - assert 'instruction line one' in result['caregiver'] - + assert "first line" in result["physician"] + assert "second line" in result["physician"] + assert "third line" in result["physician"] + assert "care instructions" in result["nurse"] + assert "Simple first line" in result["patient"] + assert "instruction line one" in result["caregiver"] + def test_fallback_parsing(self): """Test fallback parsing when regex fails""" response = """ @@ -206,30 +206,30 @@ def test_fallback_parsing(self): caregiver details included """ - + result = AudienceAdaptivePrompt.parse_response(response) - + # Should still extract some content even with poor formatting assert isinstance(result, dict) - + def test_empty_response(self): """Test parsing with empty response""" response = "" - + result = AudienceAdaptivePrompt.parse_response(response) - + assert isinstance(result, dict) assert len(result) == 0 - + def test_malformed_response(self): """Test parsing with malformed response""" response = "This is just random text without any audience markers." - + result = AudienceAdaptivePrompt.parse_response(response) - + assert isinstance(result, dict) # Should return empty dict if no audiences found - + def test_partial_audiences(self): """Test parsing when only some audiences are present""" response = """ @@ -237,15 +237,15 @@ def test_partial_audiences(self): For a Patient: Simple explanation only for patients. """ - + result = AudienceAdaptivePrompt.parse_response(response) - + assert len(result) == 2 - assert 'physician' in result - assert 'patient' in result - assert 'nurse' not in result - assert 'caregiver' not in result - + assert "physician" in result + assert "patient" in result + assert "nurse" not in result + assert "caregiver" not in result + def test_text_cleaning(self): """Test that extracted text is properly cleaned""" response = """ @@ -257,17 +257,17 @@ def test_text_cleaning(self): For a Patient: - Leading dash should be removed """ - + result = AudienceAdaptivePrompt.parse_response(response) - + # Check that markdown formatting is removed - assert '**' not in result['physician'] - assert '*' not in result['physician'] - assert '`' not in result['physician'] - assert 'Bold text' in result['physician'] - assert 'italic text' in result['physician'] - assert 'code' in result['physician'] - + assert "**" not in result["physician"] + assert "*" not in result["physician"] + assert "`" not in result["physician"] + assert "Bold text" in result["physician"] + assert "italic text" in result["physician"] + assert "code" in result["physician"] + # Check that leading dash is removed - assert not result['patient'].startswith('-') - assert 'Leading dash should be removed' in result['patient'] \ No newline at end of file + assert not result["patient"].startswith("-") + assert "Leading dash should be removed" in result["patient"]