From 8aacdca41eb1e690d97c3576ac49f45c647408a4 Mon Sep 17 00:00:00 2001 From: Asutosh Samal Date: Mon, 9 Feb 2026 13:19:11 +0530 Subject: [PATCH] add llm pool & judge panel config --- .../core/models/__init__.py | 4 + .../core/models/system.py | 350 +++++++++++++++ tests/unit/core/models/test_system.py | 419 ++++++++++++++++++ .../core/models/test_system_additional.py | 242 ---------- 4 files changed, 773 insertions(+), 242 deletions(-) create mode 100644 tests/unit/core/models/test_system.py delete mode 100644 tests/unit/core/models/test_system_additional.py diff --git a/src/lightspeed_evaluation/core/models/__init__.py b/src/lightspeed_evaluation/core/models/__init__.py index b6ffd47..8e4e61f 100644 --- a/src/lightspeed_evaluation/core/models/__init__.py +++ b/src/lightspeed_evaluation/core/models/__init__.py @@ -18,7 +18,9 @@ APIConfig, CoreConfig, EmbeddingConfig, + JudgePanelConfig, LLMConfig, + LLMPoolConfig, LoggingConfig, OutputConfig, SystemConfig, @@ -35,7 +37,9 @@ "EvaluationScope", # System config models "CoreConfig", + "JudgePanelConfig", "LLMConfig", + "LLMPoolConfig", "EmbeddingConfig", "APIConfig", "OutputConfig", diff --git a/src/lightspeed_evaluation/core/models/system.py b/src/lightspeed_evaluation/core/models/system.py index 5d11ca2..fed3f63 100644 --- a/src/lightspeed_evaluation/core/models/system.py +++ b/src/lightspeed_evaluation/core/models/system.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from lightspeed_evaluation.core.system.exceptions import ConfigurationError from lightspeed_evaluation.core.constants import ( DEFAULT_API_BASE, DEFAULT_API_CACHE_DIR, @@ -318,6 +319,289 @@ class CoreConfig(BaseModel): ) +class LLMParametersConfig(BaseModel): + """Dynamic parameters passed to LLM API calls. + + These parameters are passed directly to the LLM provider. + All fields are optional - unset fields inherit from parent level. + Uses extra="allow" to pass through any provider-specific parameters. + """ + + model_config = ConfigDict(extra="allow") + + temperature: Optional[float] = Field( + default=None, + ge=0.0, + le=2.0, + description="Sampling temperature", + ) + max_completion_tokens: Optional[int] = Field( + default=None, + ge=1, + description="Maximum tokens in response", + ) + + def to_dict(self, exclude_none: bool = True) -> dict[str, Any]: + """Convert parameters to dict for passing to LLM. + + Args: + exclude_none: If True, exclude None values from output + + Returns: + Dict of parameters ready for LLM API call + """ + params = self.model_dump() + if exclude_none: + return {k: v for k, v in params.items() if v is not None} + return params + + +class LLMDefaultsConfig(BaseModel): + """Global default settings for all LLMs in the pool. + + These are shared defaults that apply to all LLMs unless overridden + at the provider or model level. 
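
    Illustrative example (a sketch; the field names are those defined below,
    but the values are placeholders rather than the actual project defaults):

        defaults = LLMDefaultsConfig(
            cache_enabled=True,
            cache_dir=".caches/llm",
            timeout=300,
            num_retries=3,
            parameters=LLMParametersConfig(
                temperature=0.0,
                max_completion_tokens=512,
            ),
        )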
+ """ + + model_config = ConfigDict(extra="forbid") + + cache_enabled: bool = Field( + default=True, + description="Is caching of LLM queries enabled?", + ) + cache_dir: str = Field( + default=DEFAULT_LLM_CACHE_DIR, + min_length=1, + description="Base cache directory", + ) + + timeout: int = Field( + default=DEFAULT_API_TIMEOUT, + ge=1, + description="Request timeout in seconds", + ) + + num_retries: int = Field( + default=DEFAULT_LLM_RETRIES, + ge=0, + description="Retry attempts for failed requests", + ) + + # Default dynamic parameters + parameters: LLMParametersConfig = Field( + default_factory=lambda: LLMParametersConfig( + temperature=DEFAULT_LLM_TEMPERATURE, + max_completion_tokens=DEFAULT_LLM_MAX_TOKENS, + ), + description="Default dynamic parameters for LLM calls", + ) + + +class LLMProviderConfig(BaseModel): + """Configuration for a single LLM provider/model in the pool. + + Contains model-specific settings. Cache and retry settings are managed + at the pool defaults level, not per-model. + + The dict key is the unique model ID used for referencing. + """ + + model_config = ConfigDict(extra="forbid") + + # Required: Provider type + provider: str = Field( + min_length=1, + description="Provider type (e.g., openai, watsonx, gemini, hosted_vllm)", + ) + + # Model identity (optional - defaults to dict key) + model: Optional[str] = Field( + default=None, + min_length=1, + description="Actual model name. If not set, uses the dict key as model name.", + ) + + # SSL settings (optional - inherit from defaults or use system defaults) + ssl_verify: Optional[bool] = Field( + default=None, + description="Verify SSL certificates. Inherits from defaults if not set.", + ) + ssl_cert_file: Optional[str] = Field( + default=None, + description="Path to custom CA certificate file", + ) + + # API endpoint/key configuration (optional - falls back to environment variable) + api_base: Optional[str] = Field( + default=None, + min_length=1, + description=( + "Base URL for the API endpoint. " + "If not set, falls back to provider-specific environment variable." + ), + ) + api_key_path: Optional[str] = Field( + default=None, + min_length=1, + description=( + "Path to text file containing the API key for this model. " + "If not set, falls back to provider-specific environment variable." + ), + ) + + # Dynamic parameters (passed to LLM API) + parameters: LLMParametersConfig = Field( + default_factory=LLMParametersConfig, + description="Dynamic parameters for this model (merged with defaults)", + ) + + # Timeout can be model-specific (some models are slower) + timeout: Optional[int] = Field( + default=None, + ge=1, + description="Override timeout for this model", + ) + + +class LLMPoolConfig(BaseModel): + """Pool of LLM configurations for reuse across the system. + + Provides a centralized place to define all LLM configurations, + which can be referenced by judge_panel, agents, or other components. + + Cache and retry settings are managed at the defaults level only. + Model entries contain model-specific settings (provider, parameters, SSL). + """ + + model_config = ConfigDict(extra="forbid") + + defaults: LLMDefaultsConfig = Field( + default_factory=LLMDefaultsConfig, + description="Global default settings for all LLMs (cache, retry, parameters)", + ) + models: dict[str, LLMProviderConfig] = Field( + default_factory=dict, + description="Model configurations. 
Key is unique model ID for referencing.", + ) + + def get_model_ids(self) -> list[str]: + """Get all available model IDs.""" + return list(self.models.keys()) + + def resolve_llm_config( + self, model_id: str, cache_suffix: Optional[str] = None + ) -> LLMConfig: + """Resolve a model ID to a fully configured LLMConfig. + + Resolution order: defaults -> model entry (for model-specific fields) + + Args: + model_id: Model identifier (key in models dict) + cache_suffix: Optional suffix for cache directory (e.g., "judge_0") + + Returns: + Fully resolved LLMConfig + + Raises: + ValueError: If model_id not found + """ + if model_id not in self.models: + raise ValueError( + f"Model '{model_id}' not found in llm_pool.models. " + f"Available: {list(self.models.keys())}" + ) + entry = self.models[model_id] + + # Merge parameters: defaults -> model entry + merged_params: dict[str, Any] = {} + merged_params.update(self.defaults.parameters.to_dict(exclude_none=True)) + merged_params.update(entry.parameters.to_dict(exclude_none=True)) + + # Build cache_dir from defaults with model-specific suffix + suffix = cache_suffix if cache_suffix else model_id + cache_dir = os.path.join(self.defaults.cache_dir, suffix) + + return LLMConfig( + provider=entry.provider, + model=entry.model or model_id, + temperature=merged_params.get("temperature", DEFAULT_LLM_TEMPERATURE), + max_tokens=merged_params.get( + "max_completion_tokens", DEFAULT_LLM_MAX_TOKENS + ), + timeout=( + entry.timeout if entry.timeout is not None else self.defaults.timeout + ), + num_retries=self.defaults.num_retries, + ssl_verify=( + entry.ssl_verify if entry.ssl_verify is not None else DEFAULT_SSL_VERIFY + ), + ssl_cert_file=entry.ssl_cert_file, + cache_enabled=self.defaults.cache_enabled, + cache_dir=cache_dir, + # Note: api_base and api_key_path are not propagated yet - requires LLMConfig extension + ) + + +class JudgePanelConfig(BaseModel): + """Judge panel configuration for multi-LLM evaluation. + + References models from LLM pool by model ID (the key in llm_pool.models). + Each judge ID must correspond to a key in the llm_pool.models dictionary. + """ + + model_config = ConfigDict(extra="forbid") + + judges: list[str] = Field( + ..., + min_length=1, + description="List of model IDs (keys from llm_pool.models). At least one required.", + ) + enabled_metrics: Optional[list[str]] = Field( + default=None, + description=( + "Metrics that should use the judge panel. " + "If None, all metrics use the panel. " + "If empty list, no metrics use the panel." + ), + ) + aggregation_strategy: str = Field( + default="average", + description=( + "Strategy for aggregating scores from multiple judges. " + "Options: 'max', 'average', 'majority_vote'. " + "Note: Currently unused - will be implemented later." 
+ ), + ) + + @field_validator("enabled_metrics") + @classmethod + def validate_enabled_metrics(cls, v: Optional[list[str]]) -> Optional[list[str]]: + """Validate enabled_metrics format (framework:metric_name).""" + if v is not None: + for metric in v: + if not metric or ":" not in metric: + raise ValueError( + f'Metric "{metric}" must be in format "framework:metric_name"' + ) + parts = metric.split(":", 1) + if len(parts) != 2 or not parts[0].strip() or not parts[1].strip(): + raise ValueError( + f'Metric "{metric}" must be in format "framework:metric_name"' + ) + return v + + @field_validator("aggregation_strategy") + @classmethod + def validate_aggregation_strategy(cls, v: str) -> str: + """Validate aggregation_strategy is a supported value.""" + allowed = ["max", "average", "majority_vote"] + if v not in allowed: + raise ValueError( + f"Unsupported aggregation_strategy '{v}'. Allowed: {allowed}" + ) + return v + + class SystemConfig(BaseModel): """System configuration using individual config models.""" @@ -328,6 +612,25 @@ class SystemConfig(BaseModel): default_factory=CoreConfig, description="Core eval configuration" ) llm: LLMConfig = Field(default_factory=LLMConfig, description="LLM configuration") + + # LLM Pool - shared pool of LLM configurations + llm_pool: Optional[LLMPoolConfig] = Field( + default=None, + description=( + "Pool of LLM configurations. Define models once, " + "reference by ID in judge_panel or other components." + ), + ) + + # Judge Panel - references models from llm_pool + judge_panel: Optional[JudgePanelConfig] = Field( + default=None, + description=( + "Optional judge panel configuration. " + "References models from 'llm_pool' by ID. " + "If not provided, the single 'llm' configuration is used." + ), + ) embedding: EmbeddingConfig = Field( default_factory=EmbeddingConfig, description="Embeddings configuration" ) @@ -349,3 +652,50 @@ class SystemConfig(BaseModel): default_conversation_metrics_metadata: dict[str, dict[str, Any]] = Field( default_factory=dict, description="Default conversation metrics metadata" ) + + def get_judge_configs(self) -> list[LLMConfig]: + """Get resolved LLMConfig for all judges. + + Returns: + List of LLMConfig objects for each judge. + If judge_panel is configured, resolves from llm_pool. + Otherwise, returns single llm config. + """ + if not self.judge_panel: + return [self.llm] + + if not self.llm_pool: + raise ConfigurationError( + "judge_panel is configured but 'llm_pool' is not defined. " + "Please define the llm_pool section with models." + ) + + configs = [] + for idx, judge_id in enumerate(self.judge_panel.judges): + cache_suffix = f"judge_{idx}" + config = self.llm_pool.resolve_llm_config( + judge_id, cache_suffix=cache_suffix + ) + configs.append(config) + return configs + + def get_llm_config( + self, model_id: str, cache_suffix: Optional[str] = None + ) -> LLMConfig: + """Get resolved LLMConfig for a specific model from the pool. + + Args: + model_id: Model identifier (key in llm_pool.models) + cache_suffix: Optional suffix for cache directory + + Returns: + Fully resolved LLMConfig + + Raises: + ConfigurationError: If llm_pool not configured or model not found + """ + if not self.llm_pool: + raise ConfigurationError( + f"Cannot resolve model '{model_id}' - 'llm_pool' is not configured." 
+ ) + return self.llm_pool.resolve_llm_config(model_id, cache_suffix=cache_suffix) diff --git a/tests/unit/core/models/test_system.py b/tests/unit/core/models/test_system.py new file mode 100644 index 0000000..be9a6a4 --- /dev/null +++ b/tests/unit/core/models/test_system.py @@ -0,0 +1,419 @@ +"""Tests for system configuration models.""" + +import os +import tempfile +import pytest +from pydantic import ValidationError +from pytest_mock import MockerFixture + +from lightspeed_evaluation.core.models import ( + JudgePanelConfig, + LLMConfig, + LLMPoolConfig, + SystemConfig, + EmbeddingConfig, + APIConfig, + OutputConfig, + VisualizationConfig, +) +from lightspeed_evaluation.core.models.system import ( + LLMDefaultsConfig, + LLMParametersConfig, + LLMProviderConfig, + LoggingConfig, +) +from lightspeed_evaluation.core.system.exceptions import ConfigurationError + + +class TestLLMConfig: + """Tests for LLMConfig model.""" + + def test_defaults_and_validation(self) -> None: + """Test default values and field validations.""" + # Test defaults + config = LLMConfig() + assert config.ssl_verify is True + assert config.ssl_cert_file is None + + # Test validation bounds + with pytest.raises(ValidationError): + LLMConfig(temperature=-0.1) + with pytest.raises(ValidationError): + LLMConfig(temperature=2.1) + with pytest.raises(ValidationError): + LLMConfig(max_tokens=0) + with pytest.raises(ValidationError): + LLMConfig(timeout=0) + with pytest.raises(ValidationError): + LLMConfig(num_retries=-1) + + def test_ssl_cert_file_handling(self, mocker: MockerFixture) -> None: + """Test ssl_cert_file validation and path expansion.""" + # Create temp cert file + with tempfile.NamedTemporaryFile(mode="w", suffix=".crt", delete=False) as f: + cert_path = f.name + f.write("-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----\n") + + try: + # Valid path - converts to absolute + config = LLMConfig(ssl_cert_file=cert_path) + assert config.ssl_cert_file == os.path.abspath(cert_path) + + # Environment variable expansion + test_dir = os.path.dirname(cert_path) + test_filename = os.path.basename(cert_path) + mocker.patch.dict(os.environ, {"TEST_CERT_DIR": test_dir}) + config = LLMConfig(ssl_cert_file=f"$TEST_CERT_DIR/{test_filename}") + assert config.ssl_cert_file == os.path.abspath(cert_path) + finally: + os.unlink(cert_path) + + # Non-existent file fails + with pytest.raises(ValidationError, match="(?i)not found"): + LLMConfig(ssl_cert_file="/tmp/nonexistent_cert_12345.crt") + + # Directory fails + with pytest.raises(ValidationError): + LLMConfig(ssl_cert_file=tempfile.gettempdir()) + + +class TestBasicConfigModels: + """Tests for EmbeddingConfig, APIConfig, OutputConfig, VisualizationConfig, LoggingConfig.""" + + def test_embedding_config(self) -> None: + """Test EmbeddingConfig defaults and custom values.""" + default = EmbeddingConfig() + assert default.provider is not None + assert default.cache_enabled is True + + custom = EmbeddingConfig(provider="openai", model="text-embedding-3-small") + assert custom.provider == "openai" + assert custom.model == "text-embedding-3-small" + + def test_api_config(self) -> None: + """Test APIConfig defaults and validation.""" + default = APIConfig() + assert isinstance(default.enabled, bool) + assert default.timeout > 0 + + custom = APIConfig(enabled=True, api_base="https://custom.api.com", timeout=300) + assert custom.api_base == "https://custom.api.com" + + with pytest.raises(ValidationError): + APIConfig(timeout=0) + + def test_output_config(self) -> None: + """Test 
OutputConfig defaults and custom values.""" + default = OutputConfig() + assert "csv" in default.enabled_outputs + assert len(default.csv_columns) > 0 + + custom = OutputConfig(enabled_outputs=["json"], csv_columns=["result"]) + assert custom.enabled_outputs == ["json"] + + def test_visualization_config(self) -> None: + """Test VisualizationConfig defaults and validation.""" + default = VisualizationConfig() + assert default.dpi > 0 + assert len(default.figsize) == 2 + + custom = VisualizationConfig(dpi=150, figsize=[12, 8]) + assert custom.dpi == 150 + + with pytest.raises(ValidationError): + VisualizationConfig(dpi=0) + + def test_logging_config(self) -> None: + """Test LoggingConfig defaults and custom values.""" + default = LoggingConfig() + assert default.source_level is not None + + custom = LoggingConfig( + source_level="DEBUG", + package_overrides={"httpx": "CRITICAL"}, + show_timestamps=True, + ) + assert custom.source_level == "DEBUG" + assert custom.package_overrides["httpx"] == "CRITICAL" + + +class TestLLMParametersConfig: + """Tests for LLMParametersConfig model.""" + + def test_defaults_and_validation(self) -> None: + """Test defaults and field validations.""" + params = LLMParametersConfig() + assert params.temperature is None + assert params.max_completion_tokens is None + + # Validation + with pytest.raises(ValidationError): + LLMParametersConfig(temperature=2.5) + with pytest.raises(ValidationError): + LLMParametersConfig(temperature=-0.1) + with pytest.raises(ValidationError): + LLMParametersConfig(max_completion_tokens=0) + + def test_extra_params_and_to_dict(self) -> None: + """Test extra parameters allowed and to_dict method.""" + params = LLMParametersConfig.model_validate( + {"temperature": 0.5, "top_p": 0.9, "frequency_penalty": 0.5} + ) + assert params.temperature == 0.5 + # Access extra fields via model_dump + dump = params.model_dump() + assert dump["top_p"] == 0.9 + assert dump["frequency_penalty"] == 0.5 + + # to_dict excludes None by default + result = params.to_dict() + assert "temperature" in result + assert "max_completion_tokens" not in result + + # to_dict can include None + result = params.to_dict(exclude_none=False) + assert result["max_completion_tokens"] is None + + +class TestLLMProviderConfig: + """Tests for LLMProviderConfig model.""" + + def test_minimal_and_full_config(self) -> None: + """Test minimal required fields and full config.""" + # Minimal - only provider is required + entry = LLMProviderConfig(provider="openai") + assert entry.provider == "openai" + assert entry.model is None + assert entry.ssl_verify is None + assert entry.api_base is None + assert entry.api_key_path is None + assert entry.timeout is None + + # Full config with all fields + entry = LLMProviderConfig( + provider="hosted_vllm", + model="gpt-oss-20b", + ssl_verify=True, + api_base="https://vllm.example.com/v1", + api_key_path="/secrets/key.txt", + parameters=LLMParametersConfig(temperature=0.5), + timeout=600, + ) + assert entry.model == "gpt-oss-20b" + assert entry.api_base == "https://vllm.example.com/v1" + assert entry.api_key_path == "/secrets/key.txt" + assert entry.parameters.temperature == 0.5 + + def test_provider_required(self) -> None: + """Test that provider field is required.""" + with pytest.raises(ValidationError): + LLMProviderConfig.model_validate({}) + + +class TestLLMPoolConfig: + """Tests for LLMPoolConfig model.""" + + def test_pool_basics(self) -> None: + """Test pool creation and model ID retrieval.""" + pool = LLMPoolConfig( + models={ + 
"gpt-4o-mini": LLMProviderConfig(provider="openai"), + "gpt-4o": LLMProviderConfig(provider="openai"), + } + ) + assert pool.defaults is not None + assert len(pool.get_model_ids()) == 2 + + def test_resolve_llm_config(self) -> None: + """Test resolving LLMConfig with defaults and overrides.""" + pool = LLMPoolConfig( + defaults=LLMDefaultsConfig( + timeout=600, + cache_dir=".caches/llm", + parameters=LLMParametersConfig( + temperature=0.1, max_completion_tokens=512 + ), + ), + models={ + "gpt-4o-mini": LLMProviderConfig(provider="openai"), + "gpt-4o": LLMProviderConfig( + provider="openai", + parameters=LLMParametersConfig( + temperature=0.5, max_completion_tokens=2048 + ), + timeout=300, + ), + }, + ) + + # Uses defaults + resolved = pool.resolve_llm_config("gpt-4o-mini") + assert isinstance(resolved, LLMConfig) + assert resolved.provider == "openai" + assert resolved.model == "gpt-4o-mini" # Defaults to key + assert resolved.timeout == 600 + assert resolved.temperature == 0.1 + assert resolved.cache_dir == ".caches/llm/gpt-4o-mini" + + # With cache suffix + resolved = pool.resolve_llm_config("gpt-4o-mini", cache_suffix="judge_0") + assert resolved.cache_dir == ".caches/llm/judge_0" + + # Overrides take precedence + resolved = pool.resolve_llm_config("gpt-4o") + assert resolved.temperature == 0.5 + assert resolved.max_tokens == 2048 + assert resolved.timeout == 300 + + # Unknown model raises error + with pytest.raises(ValueError, match="Model 'unknown' not found"): + pool.resolve_llm_config("unknown") + + def test_custom_model_id_and_ssl(self) -> None: + """Test custom model IDs and SSL settings.""" + pool = LLMPoolConfig( + models={ + "gpt-4o-eval": LLMProviderConfig( + provider="openai", + model="gpt-4o", # Actual model differs from key + parameters=LLMParametersConfig(temperature=0.0), + ), + "gpt-oss-prod": LLMProviderConfig( + provider="hosted_vllm", model="gpt-oss-20b", ssl_verify=True + ), + "gpt-oss-staging": LLMProviderConfig( + provider="hosted_vllm", model="gpt-oss-20b", ssl_verify=False + ), + } + ) + + # Custom model ID + eval_config = pool.resolve_llm_config("gpt-4o-eval") + assert eval_config.model == "gpt-4o" + assert eval_config.temperature == 0.0 + + # SSL settings + assert pool.resolve_llm_config("gpt-oss-prod").ssl_verify is True + assert pool.resolve_llm_config("gpt-oss-staging").ssl_verify is False + + +class TestJudgePanelConfig: + """Tests for JudgePanelConfig model.""" + + def test_valid_configurations(self) -> None: + """Test valid judge panel configurations.""" + # Single judge + panel = JudgePanelConfig(judges=["gpt-4o-mini"]) + assert len(panel.judges) == 1 + assert panel.enabled_metrics is None + assert panel.aggregation_strategy == "average" + + # Multiple judges with metrics + panel = JudgePanelConfig( + judges=["gpt-4o-mini", "gpt-4o"], + enabled_metrics=["ragas:faithfulness", "custom:correctness"], + aggregation_strategy="max", + ) + assert len(panel.judges) == 2 + assert panel.enabled_metrics is not None + assert len(panel.enabled_metrics) == 2 + + # All aggregation strategies + for strategy in ["max", "average", "majority_vote"]: + panel = JudgePanelConfig( + judges=["gpt-4o-mini"], aggregation_strategy=strategy + ) + assert panel.aggregation_strategy == strategy + + def test_invalid_configurations(self) -> None: + """Test invalid configurations are rejected.""" + # Empty judges + with pytest.raises(ValidationError, match="(?i)at least 1 item"): + JudgePanelConfig(judges=[]) + + # Dict format rejected (must be string IDs) + with 
pytest.raises(ValidationError): + JudgePanelConfig.model_validate({"judges": [{"provider": "openai"}]}) + + # Invalid metric format - no colon + with pytest.raises(ValidationError, match="framework:metric_name"): + JudgePanelConfig(judges=["gpt-4o-mini"], enabled_metrics=["invalid"]) + + # Invalid metric format - empty parts + with pytest.raises(ValidationError, match="framework:metric_name"): + JudgePanelConfig(judges=["gpt-4o-mini"], enabled_metrics=[":metric"]) + with pytest.raises(ValidationError, match="framework:metric_name"): + JudgePanelConfig(judges=["gpt-4o-mini"], enabled_metrics=["framework:"]) + + # Invalid aggregation strategy + with pytest.raises(ValidationError, match="(?i)aggregation_strategy"): + JudgePanelConfig(judges=["gpt-4o-mini"], aggregation_strategy="invalid") + + +class TestSystemConfigWithLLMPoolAndJudgePanel: + """Tests for SystemConfig with llm_pool and judge_panel.""" + + def test_new_fields_are_optional(self) -> None: + """Test SystemConfig works without new fields (llm_pool, judge_panel).""" + config = SystemConfig() + assert config.llm is not None # Existing field still works + assert config.llm_pool is None + assert config.judge_panel is None + + def test_with_pool_and_panel(self) -> None: + """Test SystemConfig with llm_pool and judge_panel.""" + pool = LLMPoolConfig( + defaults=LLMDefaultsConfig( + parameters=LLMParametersConfig( + temperature=0.0, max_completion_tokens=512 + ) + ), + models={ + "gpt-4o-mini": LLMProviderConfig(provider="openai"), + "gpt-4o": LLMProviderConfig( + provider="openai", + parameters=LLMParametersConfig(max_completion_tokens=1024), + ), + }, + ) + panel = JudgePanelConfig(judges=["gpt-4o-mini", "gpt-4o"]) + + config = SystemConfig(llm_pool=pool, judge_panel=panel) + + # Pool and panel configured + assert config.llm_pool is not None + assert config.judge_panel is not None + + # get_judge_configs returns resolved configs + judge_configs = config.get_judge_configs() + assert len(judge_configs) == 2 + assert all(isinstance(c, LLMConfig) for c in judge_configs) + assert judge_configs[0].model == "gpt-4o-mini" + assert judge_configs[0].cache_dir.endswith("judge_0") + assert judge_configs[1].max_tokens == 1024 + assert judge_configs[1].cache_dir.endswith("judge_1") + + # get_llm_config works + llm_config = config.get_llm_config("gpt-4o-mini") + assert llm_config.provider == "openai" + + def test_error_branches(self) -> None: + """Test error handling in get_judge_configs and get_llm_config.""" + # get_llm_config without pool raises ConfigurationError + config = SystemConfig() + with pytest.raises(ConfigurationError, match="llm_pool.*not configured"): + config.get_llm_config("gpt-4o-mini") + + # get_judge_configs with panel but no pool raises ConfigurationError + config = SystemConfig(judge_panel=JudgePanelConfig(judges=["gpt-4o-mini"])) + with pytest.raises(ConfigurationError, match="llm_pool.*not defined"): + config.get_judge_configs() + + # get_judge_configs with invalid judge ID raises ValueError + pool = LLMPoolConfig( + models={"gpt-4o-mini": LLMProviderConfig(provider="openai")} + ) + panel = JudgePanelConfig(judges=["gpt-4o-mini", "nonexistent"]) + config = SystemConfig(llm_pool=pool, judge_panel=panel) + with pytest.raises(ValueError, match="Model 'nonexistent' not found"): + config.get_judge_configs() diff --git a/tests/unit/core/models/test_system_additional.py b/tests/unit/core/models/test_system_additional.py deleted file mode 100644 index 1a860a2..0000000 --- a/tests/unit/core/models/test_system_additional.py +++ 
/dev/null @@ -1,242 +0,0 @@ -"""Additional tests for system configuration models.""" - -import os -import tempfile -import pytest -from pydantic import ValidationError -from pytest_mock import MockerFixture - -from lightspeed_evaluation.core.models import ( - LLMConfig, - EmbeddingConfig, - APIConfig, - OutputConfig, - VisualizationConfig, - LoggingConfig, -) - - -class TestLLMConfig: - """Additional tests for LLMConfig.""" - - def test_temperature_validation_min(self) -> None: - """Test temperature minimum validation.""" - with pytest.raises(ValidationError): - LLMConfig(temperature=-0.1) - - def test_temperature_validation_max(self) -> None: - """Test temperature maximum validation.""" - with pytest.raises(ValidationError): - LLMConfig(temperature=2.1) - - def test_max_tokens_validation(self) -> None: - """Test max_tokens minimum validation.""" - with pytest.raises(ValidationError): - LLMConfig(max_tokens=0) - - def test_timeout_validation(self) -> None: - """Test timeout minimum validation.""" - with pytest.raises(ValidationError): - LLMConfig(timeout=0) - - def test_num_retries_validation(self) -> None: - """Test num_retries minimum validation.""" - with pytest.raises(ValidationError): - LLMConfig(num_retries=-1) - - def test_ssl_verify_default(self) -> None: - """Test ssl_verify has correct default value.""" - config = LLMConfig() - assert config.ssl_verify is True - - def test_ssl_verify_false(self) -> None: - """Test ssl_verify can be set to False.""" - config = LLMConfig(ssl_verify=False) - assert config.ssl_verify is False - - def test_ssl_cert_file_default(self) -> None: - """Test ssl_cert_file defaults to None.""" - config = LLMConfig() - assert config.ssl_cert_file is None - - def test_ssl_cert_file_valid_path(self) -> None: - """Test ssl_cert_file with valid certificate file.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".crt", delete=False) as f: - cert_path = f.name - f.write("-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----\n") - - try: - config = LLMConfig(ssl_cert_file=cert_path) - assert config.ssl_cert_file is not None - assert config.ssl_cert_file == os.path.abspath(cert_path) - assert os.path.isabs(config.ssl_cert_file) - finally: - os.unlink(cert_path) - - def test_ssl_cert_file_expands_env_variables(self, mocker: MockerFixture) -> None: - """Test ssl_cert_file expands environment variables.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".crt", delete=False) as f: - cert_path = f.name - - try: - test_dir = os.path.dirname(cert_path) - test_filename = os.path.basename(cert_path) - mocker.patch.dict(os.environ, {"TEST_CERT_DIR": test_dir}) - - env_path = f"$TEST_CERT_DIR/{test_filename}" - config = LLMConfig(ssl_cert_file=env_path) - assert config.ssl_cert_file == os.path.abspath(cert_path) - finally: - os.unlink(cert_path) - - def test_ssl_cert_file_nonexistent_raises_error(self) -> None: - """Test ssl_cert_file validation fails for non-existent file.""" - with pytest.raises(ValidationError) as exc_info: - LLMConfig(ssl_cert_file="/tmp/nonexistent_cert_12345.crt") - - assert "not found" in str(exc_info.value).lower() - - def test_ssl_cert_file_directory_raises_error(self) -> None: - """Test ssl_cert_file validation fails for directory paths.""" - temp_dir = tempfile.gettempdir() - with pytest.raises(ValidationError): - LLMConfig(ssl_cert_file=temp_dir) - - -class TestEmbeddingConfig: - """Tests for EmbeddingConfig.""" - - def test_default_values(self) -> None: - """Test default embedding configuration.""" - config = EmbeddingConfig() - - 
assert config.provider is not None - assert config.model is not None - assert config.cache_enabled is True - - def test_custom_embedding_model(self) -> None: - """Test custom embedding model configuration.""" - config = EmbeddingConfig( - provider="openai", - model="text-embedding-3-small", - ) - - assert config.provider == "openai" - assert config.model == "text-embedding-3-small" - - -class TestAPIConfig: - """Tests for APIConfig.""" - - def test_default_api_config(self) -> None: - """Test default API configuration.""" - config = APIConfig() - - assert isinstance(config.enabled, bool) - assert isinstance(config.cache_enabled, bool) - assert config.timeout > 0 - - def test_custom_api_config(self) -> None: - """Test custom API configuration.""" - config = APIConfig( - enabled=True, - api_base="https://custom.api.com", - timeout=300, - ) - - assert config.enabled is True - assert config.api_base == "https://custom.api.com" - assert config.timeout == 300 - - def test_timeout_validation(self) -> None: - """Test API timeout validation.""" - with pytest.raises(ValidationError): - APIConfig(timeout=0) - - -class TestOutputConfig: - """Tests for OutputConfig.""" - - def test_default_output_config(self) -> None: - """Test default output configuration.""" - config = OutputConfig() - - assert "csv" in config.enabled_outputs - assert len(config.csv_columns) > 0 - - def test_custom_output_config(self) -> None: - """Test custom output configuration.""" - config = OutputConfig( - enabled_outputs=["json"], - csv_columns=["conversation_group_id", "result"], - ) - - assert config.enabled_outputs == ["json"] - assert len(config.csv_columns) == 2 - - def test_minimal_csv_columns(self) -> None: - """Test with minimal CSV columns.""" - config = OutputConfig(csv_columns=["result"]) - assert len(config.csv_columns) >= 1 - - -class TestVisualizationConfig: - """Tests for VisualizationConfig.""" - - def test_default_visualization_config(self) -> None: - """Test default visualization configuration.""" - config = VisualizationConfig() - - assert isinstance(config.enabled_graphs, list) - assert config.dpi > 0 - assert len(config.figsize) == 2 - - def test_custom_visualization_config(self) -> None: - """Test custom visualization configuration.""" - config = VisualizationConfig( - enabled_graphs=["pass_rates", "score_distribution"], - dpi=150, - figsize=(12, 8), # pyright: ignore[reportArgumentType] - ) - - assert "pass_rates" in config.enabled_graphs - assert "score_distribution" in config.enabled_graphs - assert config.dpi == 150 - assert config.figsize == [12, 8] # Pydantic converts tuple to list - - def test_dpi_validation(self) -> None: - """Test DPI validation.""" - with pytest.raises(ValidationError): - VisualizationConfig(dpi=0) - - -class TestLoggingConfig: - """Tests for LoggingConfig.""" - - def test_default_logging_config(self) -> None: - """Test default logging configuration.""" - config = LoggingConfig() - - assert config.source_level is not None - assert config.package_level is not None - assert isinstance(config.package_overrides, dict) - - def test_custom_logging_config(self) -> None: - """Test custom logging configuration.""" - config = LoggingConfig( - source_level="DEBUG", - package_level="ERROR", - package_overrides={"httpx": "CRITICAL"}, - ) - - assert config.source_level == "DEBUG" - assert config.package_level == "ERROR" - assert config.package_overrides["httpx"] == "CRITICAL" - - def test_show_timestamps_toggle(self) -> None: - """Test show_timestamps configuration.""" - config1 = 
LoggingConfig(show_timestamps=True) - config2 = LoggingConfig(show_timestamps=False) - - assert config1.show_timestamps is True - assert config2.show_timestamps is False
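
Illustrative usage of the new configuration models (a sketch, not part of the
patch; model IDs and parameter values are placeholders mirroring the unit tests
above):

    from lightspeed_evaluation.core.models import (
        JudgePanelConfig,
        LLMPoolConfig,
        SystemConfig,
    )
    from lightspeed_evaluation.core.models.system import (
        LLMDefaultsConfig,
        LLMParametersConfig,
        LLMProviderConfig,
    )

    # Define the shared pool once: global defaults plus per-model entries keyed
    # by a unique model ID.
    pool = LLMPoolConfig(
        defaults=LLMDefaultsConfig(
            cache_dir=".caches/llm",
            parameters=LLMParametersConfig(
                temperature=0.0, max_completion_tokens=512
            ),
        ),
        models={
            "gpt-4o-mini": LLMProviderConfig(provider="openai"),
            "gpt-4o": LLMProviderConfig(
                provider="openai",
                parameters=LLMParametersConfig(max_completion_tokens=1024),
            ),
        },
    )

    # The judge panel references pool entries by their dict key.
    config = SystemConfig(
        llm_pool=pool,
        judge_panel=JudgePanelConfig(judges=["gpt-4o-mini", "gpt-4o"]),
    )

    # Each judge resolves to a fully merged LLMConfig with its own cache
    # directory (".caches/llm/judge_0" and ".caches/llm/judge_1" here).
    for judge in config.get_judge_configs():
        print(judge.model, judge.max_tokens, judge.cache_dir)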