4 changes: 4 additions & 0 deletions src/lightspeed_evaluation/core/models/__init__.py
@@ -18,7 +18,9 @@
APIConfig,
CoreConfig,
EmbeddingConfig,
JudgePanelConfig,
LLMConfig,
LLMPoolConfig,
LoggingConfig,
OutputConfig,
SystemConfig,
@@ -35,7 +37,9 @@
"EvaluationScope",
# System config models
"CoreConfig",
"JudgePanelConfig",
"LLMConfig",
"LLMPoolConfig",
"EmbeddingConfig",
"APIConfig",
"OutputConfig",
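With these re-exports in place, the new models can be imported from the package root; a minimal sketch:

from lightspeed_evaluation.core.models import JudgePanelConfig, LLMPoolConfig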
350 changes: 350 additions & 0 deletions src/lightspeed_evaluation/core/models/system.py
@@ -5,6 +5,7 @@

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from lightspeed_evaluation.core.system.exceptions import ConfigurationError
from lightspeed_evaluation.core.constants import (
DEFAULT_API_BASE,
DEFAULT_API_CACHE_DIR,
@@ -318,6 +319,289 @@ class CoreConfig(BaseModel):
)


class LLMParametersConfig(BaseModel):
"""Dynamic parameters passed to LLM API calls.

These parameters are passed directly to the LLM provider.
    All fields are optional; unset fields inherit from the parent level.
Uses extra="allow" to pass through any provider-specific parameters.
"""

model_config = ConfigDict(extra="allow")

temperature: Optional[float] = Field(
default=None,
ge=0.0,
le=2.0,
description="Sampling temperature",
)
max_completion_tokens: Optional[int] = Field(
default=None,
ge=1,
description="Maximum tokens in response",
)

def to_dict(self, exclude_none: bool = True) -> dict[str, Any]:
"""Convert parameters to dict for passing to LLM.

Args:
exclude_none: If True, exclude None values from output

Returns:
Dict of parameters ready for LLM API call
"""
params = self.model_dump()
if exclude_none:
return {k: v for k, v in params.items() if v is not None}
return params

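A usage sketch of the pass-through behavior; top_p here is a hypothetical provider-specific parameter, not a declared field:

params = LLMParametersConfig(temperature=0.2, top_p=0.9)  # top_p kept via extra="allow"
assert params.to_dict() == {"temperature": 0.2, "top_p": 0.9}  # None fields dropped
full = params.to_dict(exclude_none=False)  # also carries max_completion_tokens=None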

class LLMDefaultsConfig(BaseModel):
"""Global default settings for all LLMs in the pool.

These are shared defaults that apply to all LLMs unless overridden
at the provider or model level.
"""

model_config = ConfigDict(extra="forbid")

cache_enabled: bool = Field(
default=True,
description="Is caching of LLM queries enabled?",
)
cache_dir: str = Field(
default=DEFAULT_LLM_CACHE_DIR,
min_length=1,
description="Base cache directory",
)

timeout: int = Field(
default=DEFAULT_API_TIMEOUT,
ge=1,
description="Request timeout in seconds",
)

num_retries: int = Field(
default=DEFAULT_LLM_RETRIES,
ge=0,
description="Retry attempts for failed requests",
)

# Default dynamic parameters
parameters: LLMParametersConfig = Field(
default_factory=lambda: LLMParametersConfig(
temperature=DEFAULT_LLM_TEMPERATURE,
max_completion_tokens=DEFAULT_LLM_MAX_TOKENS,
),
description="Default dynamic parameters for LLM calls",
)

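A fallback sketch, assuming only the constants imported above; an empty config resolves entirely from them:

defaults = LLMDefaultsConfig()
assert defaults.cache_enabled is True
assert defaults.timeout == DEFAULT_API_TIMEOUT
assert defaults.parameters.temperature == DEFAULT_LLM_TEMPERATURE  # seeded by the factory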

class LLMProviderConfig(BaseModel):
"""Configuration for a single LLM provider/model in the pool.

Contains model-specific settings. Cache and retry settings are managed
at the pool defaults level, not per-model.

The dict key is the unique model ID used for referencing.
"""

model_config = ConfigDict(extra="forbid")

# Required: Provider type
provider: str = Field(
min_length=1,
description="Provider type (e.g., openai, watsonx, gemini, hosted_vllm)",
)

# Model identity (optional - defaults to dict key)
model: Optional[str] = Field(
default=None,
min_length=1,
description="Actual model name. If not set, uses the dict key as model name.",
)

# SSL settings (optional - inherit from defaults or use system defaults)
ssl_verify: Optional[bool] = Field(
default=None,
description="Verify SSL certificates. Inherits from defaults if not set.",
)
ssl_cert_file: Optional[str] = Field(
default=None,
description="Path to custom CA certificate file",
)

# API endpoint/key configuration (optional - falls back to environment variable)
api_base: Optional[str] = Field(
default=None,
min_length=1,
description=(
"Base URL for the API endpoint. "
"If not set, falls back to provider-specific environment variable."
),
)
api_key_path: Optional[str] = Field(
default=None,
min_length=1,
description=(
"Path to text file containing the API key for this model. "
"If not set, falls back to provider-specific environment variable."
),
)

# Dynamic parameters (passed to LLM API)
parameters: LLMParametersConfig = Field(
default_factory=LLMParametersConfig,
description="Dynamic parameters for this model (merged with defaults)",
)

# Timeout can be model-specific (some models are slower)
timeout: Optional[int] = Field(
default=None,
ge=1,
description="Override timeout for this model",
)

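Two hypothetical pool entries as a sketch; only provider is required, and an omitted model falls back to the dict key during resolution:

fast = LLMProviderConfig(provider="openai")  # model name comes from the dict key
slow = LLMProviderConfig(
    provider="hosted_vllm",
    api_base="https://vllm.example.com/v1",  # hypothetical endpoint
    timeout=300,  # per-model override for a slower deployment
)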

class LLMPoolConfig(BaseModel):
"""Pool of LLM configurations for reuse across the system.

Provides a centralized place to define all LLM configurations,
which can be referenced by judge_panel, agents, or other components.

Cache and retry settings are managed at the defaults level only.
Model entries contain model-specific settings (provider, parameters, SSL).
"""

model_config = ConfigDict(extra="forbid")

defaults: LLMDefaultsConfig = Field(
default_factory=LLMDefaultsConfig,
description="Global default settings for all LLMs (cache, retry, parameters)",
)
models: dict[str, LLMProviderConfig] = Field(
default_factory=dict,
description="Model configurations. Key is unique model ID for referencing.",
)

def get_model_ids(self) -> list[str]:
"""Get all available model IDs."""
return list(self.models.keys())

def resolve_llm_config(
self, model_id: str, cache_suffix: Optional[str] = None
) -> LLMConfig:
"""Resolve a model ID to a fully configured LLMConfig.

Resolution order: defaults -> model entry (for model-specific fields)

Args:
model_id: Model identifier (key in models dict)
cache_suffix: Optional suffix for cache directory (e.g., "judge_0")

Returns:
Fully resolved LLMConfig

Raises:
ValueError: If model_id not found
"""
if model_id not in self.models:
raise ValueError(
f"Model '{model_id}' not found in llm_pool.models. "
f"Available: {list(self.models.keys())}"
)
entry = self.models[model_id]

# Merge parameters: defaults -> model entry
merged_params: dict[str, Any] = {}
merged_params.update(self.defaults.parameters.to_dict(exclude_none=True))
merged_params.update(entry.parameters.to_dict(exclude_none=True))

# Build cache_dir from defaults with model-specific suffix
suffix = cache_suffix if cache_suffix else model_id
cache_dir = os.path.join(self.defaults.cache_dir, suffix)

return LLMConfig(
provider=entry.provider,
model=entry.model or model_id,
temperature=merged_params.get("temperature", DEFAULT_LLM_TEMPERATURE),
max_tokens=merged_params.get(
"max_completion_tokens", DEFAULT_LLM_MAX_TOKENS
),
timeout=(
entry.timeout if entry.timeout is not None else self.defaults.timeout
),
num_retries=self.defaults.num_retries,
ssl_verify=(
entry.ssl_verify if entry.ssl_verify is not None else DEFAULT_SSL_VERIFY
),
ssl_cert_file=entry.ssl_cert_file,
cache_enabled=self.defaults.cache_enabled,
cache_dir=cache_dir,
            # Note: api_base and api_key_path are not propagated yet; requires extending LLMConfig
)

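A resolution sketch; "gpt4-judge" is a hypothetical model ID:

pool = LLMPoolConfig(
    defaults=LLMDefaultsConfig(timeout=120),
    models={"gpt4-judge": LLMProviderConfig(provider="openai", model="gpt-4o")},
)
cfg = pool.resolve_llm_config("gpt4-judge", cache_suffix="judge_0")
assert cfg.timeout == 120  # inherited from defaults (no per-model override)
assert cfg.cache_dir == os.path.join(DEFAULT_LLM_CACHE_DIR, "judge_0")  # suffix applied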

class JudgePanelConfig(BaseModel):
"""Judge panel configuration for multi-LLM evaluation.

References models from LLM pool by model ID (the key in llm_pool.models).
Each judge ID must correspond to a key in the llm_pool.models dictionary.
"""

model_config = ConfigDict(extra="forbid")

judges: list[str] = Field(
...,
min_length=1,
description="List of model IDs (keys from llm_pool.models). At least one required.",
)
enabled_metrics: Optional[list[str]] = Field(
default=None,
description=(
"Metrics that should use the judge panel. "
"If None, all metrics use the panel. "
"If empty list, no metrics use the panel."
),
)
aggregation_strategy: str = Field(
default="average",
description=(
"Strategy for aggregating scores from multiple judges. "
"Options: 'max', 'average', 'majority_vote'. "
"Note: Currently unused - will be implemented later."
),
)

@field_validator("enabled_metrics")
@classmethod
def validate_enabled_metrics(cls, v: Optional[list[str]]) -> Optional[list[str]]:
"""Validate enabled_metrics format (framework:metric_name)."""
        if v is not None:
            for metric in v:
                parts = metric.split(":", 1)
                if len(parts) != 2 or not parts[0].strip() or not parts[1].strip():
                    raise ValueError(
                        f'Metric "{metric}" must be in format "framework:metric_name"'
                    )
return v

@field_validator("aggregation_strategy")
@classmethod
def validate_aggregation_strategy(cls, v: str) -> str:
"""Validate aggregation_strategy is a supported value."""
allowed = ["max", "average", "majority_vote"]
if v not in allowed:
raise ValueError(
f"Unsupported aggregation_strategy '{v}'. Allowed: {allowed}"
)
return v

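A validation sketch; judge IDs and metric names are illustrative:

panel = JudgePanelConfig(
    judges=["gpt4-judge", "granite-judge"],
    enabled_metrics=["ragas:faithfulness"],  # must match "framework:metric_name"
)
# JudgePanelConfig(judges=["gpt4-judge"], enabled_metrics=["faithfulness"])
# fails validation: no "framework:" prefix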

class SystemConfig(BaseModel):
"""System configuration using individual config models."""

@@ -328,6 +612,25 @@ class SystemConfig(BaseModel):
default_factory=CoreConfig, description="Core eval configuration"
)
llm: LLMConfig = Field(default_factory=LLMConfig, description="LLM configuration")

# LLM Pool - shared pool of LLM configurations
llm_pool: Optional[LLMPoolConfig] = Field(
default=None,
description=(
"Pool of LLM configurations. Define models once, "
"reference by ID in judge_panel or other components."
),
)

# Judge Panel - references models from llm_pool
judge_panel: Optional[JudgePanelConfig] = Field(
default=None,
description=(
"Optional judge panel configuration. "
"References models from 'llm_pool' by ID. "
"If not provided, the single 'llm' configuration is used."
),
)
embedding: EmbeddingConfig = Field(
default_factory=EmbeddingConfig, description="Embeddings configuration"
)
@@ -349,3 +652,50 @@
default_conversation_metrics_metadata: dict[str, dict[str, Any]] = Field(
default_factory=dict, description="Default conversation metrics metadata"
)

def get_judge_configs(self) -> list[LLMConfig]:
"""Get resolved LLMConfig for all judges.

Returns:
List of LLMConfig objects for each judge.
If judge_panel is configured, resolves from llm_pool.
Otherwise, returns single llm config.
"""
if not self.judge_panel:
return [self.llm]

if not self.llm_pool:
raise ConfigurationError(
"judge_panel is configured but 'llm_pool' is not defined. "
"Please define the llm_pool section with models."
)

configs = []
for idx, judge_id in enumerate(self.judge_panel.judges):
cache_suffix = f"judge_{idx}"
config = self.llm_pool.resolve_llm_config(
judge_id, cache_suffix=cache_suffix
)
configs.append(config)
return configs

def get_llm_config(
self, model_id: str, cache_suffix: Optional[str] = None
) -> LLMConfig:
"""Get resolved LLMConfig for a specific model from the pool.

Args:
model_id: Model identifier (key in llm_pool.models)
cache_suffix: Optional suffix for cache directory

Returns:
Fully resolved LLMConfig

Raises:
ConfigurationError: If llm_pool not configured or model not found
"""
if not self.llm_pool:
raise ConfigurationError(
f"Cannot resolve model '{model_id}' - 'llm_pool' is not configured."
)
return self.llm_pool.resolve_llm_config(model_id, cache_suffix=cache_suffix)
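An end-to-end sketch of judge resolution, assuming all other SystemConfig sections keep their defaults; model IDs and names are hypothetical:

system = SystemConfig(
    llm_pool=LLMPoolConfig(
        models={
            "gpt4-judge": LLMProviderConfig(provider="openai"),
            "granite-judge": LLMProviderConfig(provider="watsonx", model="granite-3-8b"),
        }
    ),
    judge_panel=JudgePanelConfig(judges=["gpt4-judge", "granite-judge"]),
)
judges = system.get_judge_configs()  # two LLMConfigs, cached under judge_0 / judge_1
fallback = SystemConfig().get_judge_configs()  # no panel configured -> [self.llm]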