diff --git a/src/fast_agent/core/direct_factory.py b/src/fast_agent/core/direct_factory.py index dc0197c95..6bef11887 100644 --- a/src/fast_agent/core/direct_factory.py +++ b/src/fast_agent/core/direct_factory.py @@ -19,7 +19,7 @@ from fast_agent.core import Core from fast_agent.core.exceptions import AgentConfigError, ModelConfigError from fast_agent.core.logging.logger import get_logger -from fast_agent.core.model_resolution import resolve_model_spec +from fast_agent.core.model_resolution import HARDCODED_DEFAULT_MODEL, resolve_model_spec from fast_agent.core.validation import get_dependencies_groups from fast_agent.event_progress import ProgressAction from fast_agent.interfaces import ( @@ -82,8 +82,6 @@ async def __call__( ) -> AgentDict: ... -HARDCODED_DEFAULT_MODEL = "gpt-5-mini.low" - def get_model_factory( context, diff --git a/src/fast_agent/core/model_resolution.py b/src/fast_agent/core/model_resolution.py index 3f3552828..cfac551db 100644 --- a/src/fast_agent/core/model_resolution.py +++ b/src/fast_agent/core/model_resolution.py @@ -5,6 +5,8 @@ import os from typing import Any +HARDCODED_DEFAULT_MODEL = "gpt-5-mini.low" + def resolve_model_spec( context: Any, diff --git a/src/fast_agent/mcp/mcp_aggregator.py b/src/fast_agent/mcp/mcp_aggregator.py index fadcd3d43..5ccc2364e 100644 --- a/src/fast_agent/mcp/mcp_aggregator.py +++ b/src/fast_agent/mcp/mcp_aggregator.py @@ -29,7 +29,7 @@ from fast_agent.context_dependent import ContextDependent from fast_agent.core.exceptions import ServerSessionTerminatedError from fast_agent.core.logging.logger import get_logger -from fast_agent.core.model_resolution import resolve_model_spec +from fast_agent.core.model_resolution import HARDCODED_DEFAULT_MODEL, resolve_model_spec from fast_agent.event_progress import ProgressAction from fast_agent.mcp.common import SEP, create_namespaced_name, is_namespaced_name from fast_agent.mcp.gen_client import gen_client @@ -357,7 +357,7 @@ def session_factory(read_stream, write_stream, read_timeout, **kwargs): self.context, model=self.config.model, cli_model=cli_model_override, - fallback_to_hardcoded=False, + hardcoded_default=HARDCODED_DEFAULT_MODEL, ) if model_source: logger.info( diff --git a/src/fast_agent/mcp/sampling.py b/src/fast_agent/mcp/sampling.py index 107d65ccd..2211d947e 100644 --- a/src/fast_agent/mcp/sampling.py +++ b/src/fast_agent/mcp/sampling.py @@ -17,6 +17,7 @@ from fast_agent.agents.agent_types import AgentConfig from fast_agent.agents.llm_agent import LlmAgent from fast_agent.core.logging.logger import get_logger +from fast_agent.core.model_resolution import HARDCODED_DEFAULT_MODEL, resolve_model_spec from fast_agent.interfaces import FastAgentLLMProtocol from fast_agent.llm.sampling_converter import SamplingConverter from fast_agent.mcp.helpers.server_config_helpers import get_server_config @@ -106,6 +107,7 @@ async def sample( model: str | None = None api_key: str | None = None + app_context: Any | None = None try: # Extract model from server config using type-safe helper server_config = get_server_config(context) @@ -144,11 +146,20 @@ async def sample( # Fall back to system default model if model is None: try: - if app_context and app_context.config and app_context.config.default_model: - model = app_context.config.default_model - logger.debug(f"Using system default model for sampling: {model}") + cli_model_override = None + if app_context and app_context.config: + cli_model_override = getattr( + app_context.config, "cli_model_override", None + ) + model, model_source = resolve_model_spec( + app_context, + cli_model=cli_model_override, + hardcoded_default=HARDCODED_DEFAULT_MODEL, + ) + if model: + logger.debug(f"Using {model_source} model for sampling: {model}") except Exception as e: - logger.debug(f"Could not get system default model: {e}") + logger.debug(f"Could not resolve default model for sampling: {e}") if model is None: raise ValueError(