Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import base64
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncIterator, Callable, Iterator
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field, replace
from datetime import datetime
Expand Down Expand Up @@ -47,7 +47,7 @@
)
from ..output import OutputMode
from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
from ..providers import infer_provider
from ..providers import Provider, infer_provider
from ..settings import ModelSettings, merge_model_settings
from ..tools import ToolDefinition
from ..usage import RequestUsage
Expand Down Expand Up @@ -677,8 +677,10 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]:
ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition]


def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
"""Infer the model from the name."""
def infer_model( # noqa: C901
model: Model | KnownModelName | str, provider_factory: Callable[[str], Provider[Any]] = infer_provider
) -> Model:
"""Infer the model from the name. May optionally pass a callable that setup a custom provider for the model."""
if isinstance(model, Model):
return model
elif model == 'test':
Expand Down Expand Up @@ -713,7 +715,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
)
provider_name = 'google-vertex'

provider = infer_provider(provider_name)
provider = provider_factory(provider_name)

model_kind = provider_name
if model_kind.startswith('gateway/'):
Expand Down