Skip to content

Commit

Permalink
Validating LiteLLMModel.config structure (#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Sep 12, 2024
1 parent 30cb026 commit df15471
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import numpy as np
import tiktoken
from litellm import Router, aembedding, token_counter
from pydantic import BaseModel, ConfigDict, Field, model_validator
from litellm import DeploymentTypedDict, Router, aembedding, token_counter
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator

from paperqa.prompts import default_system_prompt
from paperqa.types import Embeddable, LLMResult
Expand Down Expand Up @@ -350,6 +350,10 @@ async def _run_completion(
},
]

_DeploymentTypedDictValidator = TypeAdapter(
list[DeploymentTypedDict], config=ConfigDict(arbitrary_types_allowed=True)
)


class LiteLLMModel(LLMModel):
"""A wrapper around the litellm library.
Expand All @@ -360,7 +364,6 @@ class LiteLLMModel(LLMModel):
`router_kwargs`: kwargs for the Router class
This way users can specify routing strategies, retries, etc.
"""

config: dict = Field(default_factory=dict)
Expand All @@ -387,11 +390,12 @@ def maybe_set_config_attribute(cls, data: dict[str, Any]) -> dict[str, Any]:
"router_kwargs": {"num_retries": 3, "retry_after": 5},
}
# we only support one "model name" for now, here we validate
_DeploymentTypedDictValidator.validate_python(data["config"]["model_list"])
if (
"config" in data
and len({m["model_name"] for m in data["config"]["model_list"]}) > 1
):
raise ValueError("Only one model name per router is supported for now")
raise ValueError("Only one model name per router is supported for now.")
return data

def __getstate__(self):
Expand Down

0 comments on commit df15471

Please sign in to comment.