"""
Model configuration resolution utilities.

Layers explicit model parameters over the global ``ScenarioConfig`` defaults,
with explicit values always taking precedence. Numeric fields are compared
against ``None`` (not truthiness) so values such as ``temperature=0.0``
survive as deliberate overrides.
"""

from dataclasses import dataclass
from typing import Optional

from scenario.config import ModelConfig, ScenarioConfig


@dataclass
class ResolvedModelConfig:
    """
    Fully-resolved model settings after merging explicit arguments with the
    global default configuration.

    Attributes:
        model: Resolved model identifier.
        api_base: Resolved API base URL, if any.
        api_key: Resolved API key, if any.
        temperature: Resolved sampling temperature, if any.
        max_tokens: Resolved completion token limit, if any.
        extra_params: Remaining provider-specific parameters.
    """

    model: str
    api_base: Optional[str]
    api_key: Optional[str]
    temperature: Optional[float]
    max_tokens: Optional[int]
    extra_params: dict


def resolve_model_config(
    *,
    model: Optional[str] = None,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    **extra_kwargs,
) -> ResolvedModelConfig:
    """
    Resolve model configuration by merging explicit params with global config.

    Explicit arguments always win over ``ScenarioConfig.default_config``.
    ``temperature`` and ``max_tokens`` use ``is None`` checks so falsy-but-valid
    values (``0.0``, ``0``) override config defaults; the string fields
    (``model``, ``api_base``, ``api_key``) fall back on any falsy value.

    Args:
        model: Explicit model identifier (e.g., "openai/gpt-4").
        api_base: Explicit API base URL.
        api_key: Explicit API key.
        temperature: Explicit temperature; ``None`` means "not specified".
        max_tokens: Explicit max tokens; ``None`` means "not specified".
        **extra_kwargs: Additional parameters forwarded to the model provider
            (e.g., timeout, headers, client).

    Returns:
        ResolvedModelConfig with every field resolved from explicit params
        and the global config.

    Raises:
        ValueError: If no model is configured either explicitly or in the
            global config.

    Example:
        ```python
        ScenarioConfig.default_config = ScenarioConfig(
            default_model=ModelConfig(
                model="openai/gpt-4", temperature=0.7, timeout=30
            )
        )
        config = resolve_model_config(temperature=0.0, timeout=60)
        # -> model="openai/gpt-4", temperature=0.0, extra_params={"timeout": 60}
        ```
    """
    final_model = model
    final_api_base = api_base
    final_api_key = api_key
    final_temperature = temperature
    final_max_tokens = max_tokens
    final_extra = dict(extra_kwargs)

    global_config = ScenarioConfig.default_config
    if global_config is not None:
        defaults = global_config.default_model

        if isinstance(defaults, str):
            # A plain string config only supplies a model name.
            final_model = final_model or defaults
        elif isinstance(defaults, ModelConfig):
            # Full ModelConfig: merge every standard field.
            final_model = final_model or defaults.model
            final_api_base = final_api_base or defaults.api_base
            final_api_key = final_api_key or defaults.api_key

            # None-checks (not truthiness) so explicit 0.0 / 0 win over config.
            if final_temperature is None and defaults.temperature is not None:
                final_temperature = defaults.temperature
            if final_max_tokens is None and defaults.max_tokens is not None:
                final_max_tokens = defaults.max_tokens

            # Anything beyond the standard fields is a provider extra;
            # explicit kwargs take precedence over config extras.
            leftovers = defaults.model_dump(exclude_none=True)
            for standard_field in (
                "model",
                "api_base",
                "api_key",
                "temperature",
                "max_tokens",
            ):
                leftovers.pop(standard_field, None)
            final_extra = {**leftovers, **extra_kwargs}

    if not final_model:
        raise ValueError(
            "Model must be configured either explicitly or in ScenarioConfig.default_config"
        )

    return ResolvedModelConfig(
        model=final_model,
        api_base=final_api_base,
        api_key=final_api_key,
        temperature=final_temperature,
        max_tokens=final_max_tokens,
        extra_params=final_extra,
    )
criteria: List[str] system_prompt: Optional[str] @@ -114,7 +115,7 @@ def __init__( model: Optional[str] = None, api_base: Optional[str] = None, api_key: Optional[str] = None, - temperature: float = 0.0, + temperature: Optional[float] = None, max_tokens: Optional[int] = None, system_prompt: Optional[str] = None, **extra_params, @@ -168,51 +169,24 @@ def __init__( experimental and may not be supported in future versions. """ self.criteria = criteria or [] - self.api_base = api_base - self.api_key = api_key - self.temperature = temperature - self.max_tokens = max_tokens self.system_prompt = system_prompt - if model: - self.model = model - - if ScenarioConfig.default_config is not None and isinstance( - ScenarioConfig.default_config.default_model, str - ): - self.model = model or ScenarioConfig.default_config.default_model - self._extra_params = extra_params - elif ScenarioConfig.default_config is not None and isinstance( - ScenarioConfig.default_config.default_model, ModelConfig - ): - self.model = model or ScenarioConfig.default_config.default_model.model - self.api_base = ( - api_base or ScenarioConfig.default_config.default_model.api_base + try: + config = resolve_model_config( + model=model, + api_base=api_base, + api_key=api_key, + temperature=temperature, + max_tokens=max_tokens, + **extra_params, ) - self.api_key = ( - api_key or ScenarioConfig.default_config.default_model.api_key - ) - self.temperature = ( - temperature or ScenarioConfig.default_config.default_model.temperature - ) - self.max_tokens = ( - max_tokens or ScenarioConfig.default_config.default_model.max_tokens - ) - # Extract extra params from ModelConfig - config_dict = ScenarioConfig.default_config.default_model.model_dump( - exclude_none=True - ) - config_dict.pop("model", None) - config_dict.pop("api_base", None) - config_dict.pop("api_key", None) - config_dict.pop("temperature", None) - config_dict.pop("max_tokens", None) - # Merge: config extras < agent extra_params - self._extra_params = 
{**config_dict, **extra_params} - else: - self._extra_params = extra_params - - if not hasattr(self, "model"): + self.model = config.model + self.api_base = config.api_base + self.api_key = config.api_key + self.temperature = config.temperature + self.max_tokens = config.max_tokens + self._extra_params = config.extra_params + except ValueError: raise Exception(agent_not_configured_error_message("JudgeAgent")) @scenario_cache() diff --git a/python/scenario/user_simulator_agent.py b/python/scenario/user_simulator_agent.py index 5364b59e..ef1b7a95 100644 --- a/python/scenario/user_simulator_agent.py +++ b/python/scenario/user_simulator_agent.py @@ -18,6 +18,7 @@ from scenario.agent_adapter import AgentAdapter from scenario._utils.utils import reverse_roles from scenario.config import ModelConfig, ScenarioConfig +from scenario.config.model_config_resolver import resolve_model_config from ._error_messages import agent_not_configured_error_message from .types import AgentInput, AgentReturnTypes, AgentRole @@ -84,7 +85,7 @@ class UserSimulatorAgent(AgentAdapter): model: str api_base: Optional[str] api_key: Optional[str] - temperature: float + temperature: Optional[float] max_tokens: Optional[int] system_prompt: Optional[str] _extra_params: dict @@ -95,7 +96,7 @@ def __init__( model: Optional[str] = None, api_base: Optional[str] = None, api_key: Optional[str] = None, - temperature: float = 0.0, + temperature: Optional[float] = None, max_tokens: Optional[int] = None, system_prompt: Optional[str] = None, **extra_params, @@ -141,51 +142,24 @@ def __init__( (e.g., headers, timeout, client) for specialized configurations. These are experimental and may not be supported in future versions. 
""" - self.api_base = api_base - self.api_key = api_key - self.temperature = temperature - self.max_tokens = max_tokens self.system_prompt = system_prompt - if model: - self.model = model - - if ScenarioConfig.default_config is not None and isinstance( - ScenarioConfig.default_config.default_model, str - ): - self.model = model or ScenarioConfig.default_config.default_model - self._extra_params = extra_params - elif ScenarioConfig.default_config is not None and isinstance( - ScenarioConfig.default_config.default_model, ModelConfig - ): - self.model = model or ScenarioConfig.default_config.default_model.model - self.api_base = ( - api_base or ScenarioConfig.default_config.default_model.api_base + try: + config = resolve_model_config( + model=model, + api_base=api_base, + api_key=api_key, + temperature=temperature, + max_tokens=max_tokens, + **extra_params, ) - self.api_key = ( - api_key or ScenarioConfig.default_config.default_model.api_key - ) - self.temperature = ( - temperature or ScenarioConfig.default_config.default_model.temperature - ) - self.max_tokens = ( - max_tokens or ScenarioConfig.default_config.default_model.max_tokens - ) - # Extract extra params from ModelConfig - config_dict = ScenarioConfig.default_config.default_model.model_dump( - exclude_none=True - ) - config_dict.pop("model", None) - config_dict.pop("api_base", None) - config_dict.pop("api_key", None) - config_dict.pop("temperature", None) - config_dict.pop("max_tokens", None) - # Merge: config extras < agent extra_params - self._extra_params = {**config_dict, **extra_params} - else: - self._extra_params = extra_params - - if not hasattr(self, "model"): + self.model = config.model + self.api_base = config.api_base + self.api_key = config.api_key + self.temperature = config.temperature + self.max_tokens = config.max_tokens + self._extra_params = config.extra_params + except ValueError: raise Exception(agent_not_configured_error_message("UserSimulatorAgent")) @scenario_cache() diff --git 
"""Tests for model configuration resolution logic."""

import pytest
from scenario.config import ModelConfig, ScenarioConfig
from scenario.config.model_config_resolver import (
    resolve_model_config,
)


class TestResolveModelConfigNoGlobalConfig:
    """Test resolution when ScenarioConfig.default_config is None."""

    def setup_method(self):
        """Ensure no global config before each test."""
        ScenarioConfig.default_config = None

    def teardown_method(self):
        """Clean up after each test."""
        ScenarioConfig.default_config = None

    def test_explicit_model_only(self):
        """Resolver returns explicit model when no global config."""
        config = resolve_model_config(
            model="openai/gpt-4",
        )

        assert config.model == "openai/gpt-4"
        assert config.api_base is None
        assert config.api_key is None
        assert config.temperature is None
        assert config.max_tokens is None
        assert config.extra_params == {}

    def test_no_model_raises_error(self):
        """Resolver raises ValueError when no model configured."""
        with pytest.raises(ValueError, match="Model must be configured"):
            resolve_model_config()

    def test_all_explicit_params(self):
        """Resolver returns all explicit params when provided."""
        config = resolve_model_config(
            model="openai/gpt-4",
            api_base="https://custom.com",
            api_key="sk-test",
            temperature=0.5,
            max_tokens=1000,
            timeout=60,
            headers={"X-Test": "value"},
        )

        assert config.model == "openai/gpt-4"
        assert config.api_base == "https://custom.com"
        assert config.api_key == "sk-test"
        assert config.temperature == 0.5
        assert config.max_tokens == 1000
        assert config.extra_params == {"timeout": 60, "headers": {"X-Test": "value"}}


class TestResolveModelConfigWithStringConfig:
    """Test resolution when default_model is a string."""

    def setup_method(self):
        """Set up string config before each test."""
        ScenarioConfig.default_config = ScenarioConfig(default_model="openai/gpt-4")

    def teardown_method(self):
        """Clean up after each test."""
        ScenarioConfig.default_config = None

    def test_uses_string_config_model(self):
        """Resolver uses string config when no explicit model."""
        config = resolve_model_config()

        assert config.model == "openai/gpt-4"

    def test_explicit_model_overrides_string_config(self):
        """Explicit model takes precedence over string config."""
        config = resolve_model_config(
            model="anthropic/claude-3",
        )

        assert config.model == "anthropic/claude-3"

    def test_string_config_with_extra_params(self):
        """Extra params pass through with string config."""
        config = resolve_model_config(
            timeout=30,
        )

        assert config.extra_params == {"timeout": 30}


class TestResolveModelConfigWithModelConfig:
    """Test resolution when default_model is a ModelConfig object."""

    def setup_method(self):
        """Set up ModelConfig before each test."""
        ScenarioConfig.default_config = ScenarioConfig(
            default_model=ModelConfig(
                model="openai/gpt-4",
                api_base="https://config.com",
                api_key="sk-config",
                temperature=0.7,
                max_tokens=2000,
            )
        )

    def teardown_method(self):
        """Clean up after each test."""
        ScenarioConfig.default_config = None

    def test_uses_all_config_defaults(self):
        """Resolver uses all ModelConfig defaults when nothing explicit."""
        config = resolve_model_config()

        assert config.model == "openai/gpt-4"
        assert config.api_base == "https://config.com"
        assert config.api_key == "sk-config"
        assert config.temperature == 0.7
        assert config.max_tokens == 2000
        assert config.extra_params == {}

    def test_explicit_params_override_config(self):
        """Explicit params override ModelConfig defaults."""
        config = resolve_model_config(
            model="anthropic/claude-3",
            api_base="https://override.com",
            api_key="sk-override",
            temperature=0.1,
            max_tokens=500,
        )

        assert config.model == "anthropic/claude-3"
        assert config.api_base == "https://override.com"
        assert config.api_key == "sk-override"
        assert config.temperature == 0.1
        assert config.max_tokens == 500

    def test_partial_override_uses_config_for_rest(self):
        """Partial overrides use config defaults for unspecified values."""
        config = resolve_model_config(
            temperature=0.3,  # Only override temperature
        )

        assert config.model == "openai/gpt-4"  # From config
        assert config.api_base == "https://config.com"  # From config
        assert config.api_key == "sk-config"  # From config
        assert config.temperature == 0.3  # Overridden
        assert config.max_tokens == 2000  # From config


class TestResolveModelConfigFalsyValues:
    """Test that falsy values (0, 0.0, '') are handled correctly."""

    def setup_method(self):
        """Set up ModelConfig with non-zero defaults."""
        ScenarioConfig.default_config = ScenarioConfig(
            default_model=ModelConfig(
                model="openai/gpt-4",
                temperature=0.7,
                max_tokens=2000,
            )
        )

    def teardown_method(self):
        """Clean up after each test."""
        ScenarioConfig.default_config = None

    def test_zero_temperature_overrides_config(self):
        """Explicit temperature=0.0 should override config, not be treated as falsy."""
        config = resolve_model_config(
            temperature=0.0,  # Critical: 0.0 is valid!
        )

        assert config.temperature == 0.0  # Should be 0.0, NOT 0.7 from config

    def test_zero_max_tokens_overrides_config(self):
        """Explicit max_tokens=0 should override config."""
        config = resolve_model_config(
            max_tokens=0,  # Edge case: 0 tokens
        )

        assert config.max_tokens == 0  # Should be 0, NOT 2000 from config

    def test_none_temperature_uses_config(self):
        """temperature=None (not specified) should use config default."""
        config = resolve_model_config()

        assert config.temperature == 0.7  # Should use config default

    def test_empty_string_api_base_falls_back_to_config(self):
        """Empty string api_base falls back to config (edge case).

        NOTE: renamed from test_empty_string_api_base_overrides_config — the
        old name contradicted the assertion. Because string fields are merged
        with `or`, a falsy "" does NOT override the config value; that is the
        intended behavior for strings.
        """
        ScenarioConfig.default_config = ScenarioConfig(
            default_model=ModelConfig(
                model="openai/gpt-4",
                api_base="https://config.com",
            )
        )

        config = resolve_model_config(
            api_base="",  # Empty string (falsy)
        )

        # Empty string is falsy, so `or` will use config
        assert config.api_base == "https://config.com"


class TestResolveModelConfigExtraParams:
    """Test extra_params merging behavior."""

    def setup_method(self):
        """Set up ModelConfig with extra params."""
        ScenarioConfig.default_config = ScenarioConfig(
            default_model=ModelConfig(
                model="openai/gpt-4",
                timeout=30,  # type: ignore # Extra param
                headers={"X-Config": "config-value"},  # type: ignore # Extra param
                max_retries=3,  # type: ignore # Extra param
            )
        )

    def teardown_method(self):
        """Clean up after each test."""
        ScenarioConfig.default_config = None

    def test_config_extra_params_pass_through(self):
        """Config extra params are included when no explicit params."""
        config = resolve_model_config()

        assert config.extra_params["timeout"] == 30
        assert config.extra_params["headers"] == {"X-Config": "config-value"}
        assert config.extra_params["max_retries"] == 3

    def test_explicit_extra_params_override_config(self):
        """Explicit extra_params override config extra params."""
        config = resolve_model_config(
            timeout=60,  # Override
            new_param="value",  # New param
        )

        assert config.extra_params["timeout"] == 60  # Overridden
        assert config.extra_params["headers"] == {
            "X-Config": "config-value"
        }  # From config
        assert config.extra_params["max_retries"] == 3  # From config
        assert config.extra_params["new_param"] == "value"  # New param

    def test_extra_params_only_from_explicit_when_no_config(self):
        """Only explicit extra_params when no ModelConfig."""
        ScenarioConfig.default_config = ScenarioConfig(
            default_model="openai/gpt-4"  # String config, no extra params
        )

        config = resolve_model_config(
            custom="param",
        )

        assert config.extra_params == {"custom": "param"}


class TestResolveModelConfigEdgeCases:
    """Test edge cases and error conditions."""

    def teardown_method(self):
        """Clean up after each test."""
        ScenarioConfig.default_config = None

    def test_explicit_temperature_overrides_config_default(self):
        """Explicit temperature should override config's default 0.0."""
        ScenarioConfig.default_config = ScenarioConfig(
            default_model=ModelConfig(
                model="openai/gpt-4",
                # temperature defaults to 0.0 in ModelConfig
            )
        )

        config = resolve_model_config(
            temperature=0.5,
        )

        assert config.temperature == 0.5

    def test_none_temperature_uses_config_default(self):
        """temperature=None should use config's default temperature."""
        ScenarioConfig.default_config = ScenarioConfig(
            default_model=ModelConfig(
                model="openai/gpt-4",
                # temperature defaults to 0.0 in ModelConfig
            )
        )

        config = resolve_model_config()

        assert config.temperature == 0.0  # ModelConfig default

    def test_extra_params_not_mutated(self):
        """Resolver should not mutate the caller's extra params.

        Strengthened: the original version made no assertion at all. Pin down
        both that the caller's dict is untouched and that the resolved result
        carries an equal (but independent) copy.
        """
        # Isolate from whatever global config a previous test left behind.
        ScenarioConfig.default_config = None

        extras = {"param": "value"}
        config = resolve_model_config(
            model="openai/gpt-4",
            **extras,
        )

        assert extras == {"param": "value"}  # Caller's dict unchanged
        assert config.extra_params == {"param": "value"}
        assert config.extra_params is not extras  # Independent copy