Skip to content

Commit

Permalink
Fixing Pydantic validation in Python<3.12 (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Sep 12, 2024
1 parent b045db3 commit 0a37ca8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
if: github.event_name == 'pull_request' # pre-commit-ci/lite-action only runs here
strategy:
matrix:
python-version: [3.12] # Our min and max supported Python versions
python-version: [3.11, 3.12] # Our min and max supported Python versions
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
Expand All @@ -39,7 +39,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.12] # Our min and max supported Python versions
python-version: [3.11, 3.12] # Our min and max supported Python versions
steps:
- uses: actions/checkout@v4
- name: Set up uv
Expand Down
24 changes: 16 additions & 8 deletions paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Sequence
from enum import StrEnum
from inspect import signature
from sys import version_info
from typing import Any

import numpy as np
Expand Down Expand Up @@ -350,9 +351,12 @@ async def _run_completion(
},
]

_DeploymentTypedDictValidator = TypeAdapter(
list[DeploymentTypedDict], config=ConfigDict(arbitrary_types_allowed=True)
)

IS_PYTHON_BELOW_312 = version_info < (3, 12)
if not IS_PYTHON_BELOW_312:
_DeploymentTypedDictValidator = TypeAdapter(
list[DeploymentTypedDict], config=ConfigDict(arbitrary_types_allowed=True)
)


class LiteLLMModel(LLMModel):
Expand Down Expand Up @@ -390,11 +394,15 @@ def maybe_set_config_attribute(cls, data: dict[str, Any]) -> dict[str, Any]:
"router_kwargs": {"num_retries": 3, "retry_after": 5},
}
# we only support one "model name" for now, here we validate
_DeploymentTypedDictValidator.validate_python(data["config"]["model_list"])
if (
"config" in data
and len({m["model_name"] for m in data["config"]["model_list"]}) > 1
):
model_list = data["config"]["model_list"]
if IS_PYTHON_BELOW_312:
if not isinstance(model_list, list):
# Work around https://github.com/BerriAI/litellm/issues/5664
raise TypeError(f"model_list must be a list, not a {type(model_list)}.")
else:
# pylint: disable-next=possibly-used-before-assignment
_DeploymentTypedDictValidator.validate_python(model_list)
if "config" in data and len({m["model_name"] for m in model_list}) > 1:
raise ValueError("Only one model name per router is supported for now.")
return data

Expand Down

0 comments on commit 0a37ca8

Please sign in to comment.