diff --git a/paperqa/llms.py b/paperqa/llms.py
index 5fb5ff21..6c9b7401 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -6,6 +6,7 @@
 from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Sequence
 from enum import StrEnum
 from inspect import signature
+from sys import version_info
 from typing import Any
 
 import numpy as np
@@ -350,9 +351,12 @@ async def _run_completion(
     },
 ]
 
-_DeploymentTypedDictValidator = TypeAdapter(
-    list[DeploymentTypedDict], config=ConfigDict(arbitrary_types_allowed=True)
-)
+
+IS_PYTHON_BELOW_312 = version_info < (3, 12)
+if not IS_PYTHON_BELOW_312:
+    _DeploymentTypedDictValidator = TypeAdapter(
+        list[DeploymentTypedDict], config=ConfigDict(arbitrary_types_allowed=True)
+    )
 
 
 class LiteLLMModel(LLMModel):
@@ -390,11 +394,15 @@ def maybe_set_config_attribute(cls, data: dict[str, Any]) -> dict[str, Any]:
                 "router_kwargs": {"num_retries": 3, "retry_after": 5},
             }
         # we only support one "model name" for now, here we validate
-        _DeploymentTypedDictValidator.validate_python(data["config"]["model_list"])
-        if (
-            "config" in data
-            and len({m["model_name"] for m in data["config"]["model_list"]}) > 1
-        ):
+        model_list = data["config"]["model_list"]
+        if IS_PYTHON_BELOW_312:
+            if not isinstance(model_list, list):
+                # Work around https://github.com/BerriAI/litellm/issues/5664
+                raise TypeError(f"model_list must be a list, not a {type(model_list)}.")
+        else:
+            # pylint: disable-next=possibly-used-before-assignment
+            _DeploymentTypedDictValidator.validate_python(model_list)
+        if "config" in data and len({m["model_name"] for m in model_list}) > 1:
             raise ValueError("Only one model name per router is supported for now.")
         return data
 
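
For context, here is a minimal standalone sketch of the same gating pattern, outside the diff. It assumes pydantic is installed and that pydantic refuses to build a schema from a `typing.TypedDict` on Python < 3.12 (the behavior behind https://github.com/BerriAI/litellm/issues/5664 for litellm's `DeploymentTypedDict`); the names `_Deployment` and `check_model_list` are hypothetical stand-ins and not part of this change.

```python
import sys
from typing import Any, TypedDict

from pydantic import TypeAdapter


class _Deployment(TypedDict):
    """Hypothetical stand-in for litellm's DeploymentTypedDict."""

    model_name: str
    litellm_params: dict[str, Any]


# On Python < 3.12, pydantic cannot build a validator from a typing.TypedDict,
# so the TypeAdapter is only constructed when it is safe to do so.
IS_PYTHON_BELOW_312 = sys.version_info < (3, 12)
if not IS_PYTHON_BELOW_312:
    _validator = TypeAdapter(list[_Deployment])


def check_model_list(model_list: Any) -> None:
    """Validate strictly on Python 3.12+, fall back to a shape check otherwise."""
    if IS_PYTHON_BELOW_312:
        # Fallback: shallow type check instead of full schema validation
        if not isinstance(model_list, list):
            raise TypeError(f"model_list must be a list, not a {type(model_list)}.")
    else:
        _validator.validate_python(model_list)  # full per-item validation


check_model_list([{"model_name": "gpt-4o", "litellm_params": {"model": "gpt-4o"}}])
```

The design choice mirrors the diff: the expensive, strict validator exists only where it can be constructed, and older interpreters get a cheaper sanity check rather than losing validation entirely.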