diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index df7ae9b54e..f771670fe4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -9,7 +9,7 @@ import base64 import warnings from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Callable, Iterator from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass, field, replace from datetime import datetime @@ -47,7 +47,7 @@ ) from ..output import OutputMode from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec -from ..providers import infer_provider +from ..providers import Provider, infer_provider from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from ..usage import RequestUsage @@ -677,8 +677,10 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] -def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 - """Infer the model from the name.""" +def infer_model( # noqa: C901 + model: Model | KnownModelName | str, provider_factory: Callable[[str], Provider[Any]] = infer_provider +) -> Model: + """Infer the model from the name. May optionally pass a callable that setup a custom provider for the model.""" if isinstance(model, Model): return model elif model == 'test': @@ -713,7 +715,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 ) provider_name = 'google-vertex' - provider = infer_provider(provider_name) + provider = provider_factory(provider_name) model_kind = provider_name if model_kind.startswith('gateway/'): diff --git a/tests/models/test_model.py b/tests/models/test_model.py index b824fcbfe6..7290ed29b0 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -251,6 +251,17 @@ def test_infer_model( assert m2 is m +def test_infer_model_with_provider(): + from pydantic_ai.providers import openai + + provider_class = openai.OpenAIProvider(api_key='1234', base_url='http://test') + m = infer_model('openai:gpt-5', lambda x: provider_class) + + assert isinstance(m, OpenAIChatModel) + assert m._provider is provider_class # type: ignore + assert m._provider.base_url == 'http://test' # type: ignore + + def test_infer_str_unknown(): with pytest.raises(UserError, match='Unknown model: foobar'): infer_model('foobar')