Litellm refactor (#172)
matt-bernstein authored Aug 2, 2024
1 parent 3b6d651 commit 0d536db
Showing 12 changed files with 715 additions and 941 deletions.
1 change: 1 addition & 0 deletions adala/runtimes/__init__.py
@@ -1,2 +1,3 @@
from .base import Runtime, AsyncRuntime
from ._openai import OpenAIChatRuntime, OpenAIVisionRuntime, AsyncOpenAIChatRuntime
from ._litellm import LiteLLMChatRuntime, AsyncLiteLLMChatRuntime
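
For context, a minimal usage sketch of the newly exported runtime (model name and settings are illustrative; OPENAI_API_KEY is assumed to be set in the environment):

# Illustrative sketch only -- not part of this diff.
from adala.runtimes import LiteLLMChatRuntime

runtime = LiteLLMChatRuntime(model="gpt-4o-mini", max_tokens=256, temperature=0.0)
runtime.init_runtime()  # smoke-tests the model with the current credentials
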
250 changes: 181 additions & 69 deletions adala/runtimes/_litellm.py
@@ -1,68 +1,113 @@
import asyncio
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

import litellm
from litellm.exceptions import AuthenticationError
import instructor
import traceback
from adala.utils.internal_data import InternalDataFrame
from adala.utils.logs import print_error
from adala.utils.parse import (
parse_template,
partial_str_format,
parse_template_to_pydantic_class,
)
from adala.utils.llm import (
parallel_async_get_llm_response,
get_llm_response,
ConstrainedLLMResponse,
UnconstrainedLLMResponse,
ErrorLLMResponse,
LiteLLMInferenceSettings,
)
from openai import NotFoundError
from pydantic import ConfigDict, field_validator
from rich import print

from .base import AsyncRuntime, Runtime

instructor_client = instructor.from_litellm(litellm.completion)
async_instructor_client = instructor.from_litellm(litellm.acompletion)

logger = logging.getLogger(__name__)


class LiteLLMChatRuntime(LiteLLMInferenceSettings, Runtime):
def get_messages(
user_prompt: str,
system_prompt: Optional[str] = None,
instruction_first: bool = True,
):
messages = [{"role": "user", "content": user_prompt}]
if system_prompt:
if instruction_first:
messages.insert(0, {"role": "system", "content": system_prompt})
else:
messages[0]["content"] += system_prompt
return messages
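
For illustration, the message orderings this helper produces (prompt text is made up):

# Illustrative only -- not part of this diff.
get_messages("Classify: 'great movie'", system_prompt="You are a classifier.")
# -> [{"role": "system", "content": "You are a classifier."},
#     {"role": "user", "content": "Classify: 'great movie'"}]
get_messages("Classify: 'great movie'", system_prompt="You are a classifier.", instruction_first=False)
# -> [{"role": "user", "content": "Classify: 'great movie'You are a classifier."}]

Note that with instruction_first=False the system prompt is concatenated directly onto the user content, with no separator.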


class LiteLLMChatRuntime(Runtime):
"""
Runtime that uses [LiteLLM API](https://litellm.vercel.app/docs) and chat
completion models to perform the skill.
The default model provider is [OpenAI](https://openai.com/), using the OPENAI_API_KEY environment variable. Other providers [can be chosen](https://litellm.vercel.app/docs/set_keys) through environment variables or passed parameters.
Attributes:
inference_settings (LiteLLMInferenceSettings): Common inference settings for LiteLLM.
model: Model name. Refer to the litellm docs for supported models and how to
specify them: https://litellm.vercel.app/docs/providers
max_tokens: Maximum tokens to generate.
temperature: Temperature for sampling.
seed: Integer seed to reduce nondeterminism in generation.
Extra parameters passed to this class will be used for inference. See `litellm.types.completion.CompletionRequest` for a full list. Some common ones are:
api_key: API key, optional. If provided, will be used to authenticate
with the provider of your specified model.
base_url (Optional[str]): Base URL, optional. If provided, will be used to talk to an OpenAI-compatible API provider besides OpenAI.
api_version (Optional[str]): API version, optional except for Azure.
timeout: Timeout in seconds.
"""

model_config = ConfigDict(arbitrary_types_allowed=True) # for @computed_field
model: str = "gpt-4o-mini"
max_tokens: int = 1000
temperature: float = 0.0
seed: Optional[int] = 47

model_config = ConfigDict(extra="allow")
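
As the docstring above describes, extra constructor arguments are forwarded to litellm at inference time. A hedged sketch with placeholder values:

# Placeholder values only -- any litellm CompletionRequest parameter can be passed this way.
runtime = LiteLLMChatRuntime(
    model="gpt-4o-mini",
    api_key="sk-...",                        # placeholder provider key
    base_url="https://example.internal/v1",  # placeholder OpenAI-compatible endpoint
    timeout=30,
)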

def init_runtime(self) -> "Runtime":
# check model availability
# extension of litellm.check_valid_key for non-openai deployments
try:
if self.api_key:
litellm.check_valid_key(model=self.model, api_key=self.api_key)
except NotFoundError:
messages = [{"role": "user", "content": "Hey, how's it going?"}]
litellm.completion(
messages=messages,
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
# extra inference params passed to this runtime
**self.model_extra,
)
except AuthenticationError:
raise ValueError(
f'Requested model "{self.model}" is not available with your api_key.'
f'Requested model "{self.model}" is not available with your api_key and settings.'
)
except Exception as e:
raise ValueError(
f'Failed to check availability of requested model "{self.model}": {e}'
)
return self

def get_llm_response(self, messages: List[Dict[str, str]]) -> str:
# TODO: sunset this method in favor of record_to_record
if self.verbose:
print(f"**Prompt content**:\n{messages}")
response: Union[ErrorLLMResponse, UnconstrainedLLMResponse] = get_llm_response(
completion = litellm.completion(
messages=messages,
inference_settings=LiteLLMInferenceSettings(
**self.dict(include=LiteLLMInferenceSettings.model_fields.keys())
),
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
# extra inference params passed to this runtime
**self.model_extra,
)
if isinstance(response, ErrorLLMResponse):
raise ValueError(f"{response.adala_message}\n{response.adala_details}")
completion_text = completion.choices[0].message.content
if self.verbose:
print(f"**Response**:\n{response.text}")
return response.text
print(f"**Response**:\n{completion_text}")
return completion_text

def record_to_record(
self,
@@ -95,35 +140,93 @@ def record_to_record(
response_model = parse_template_to_pydantic_class(
output_template, provided_field_schema=field_schema
)

response: Union[ConstrainedLLMResponse, ErrorLLMResponse] = get_llm_response(
user_prompt=input_template.format(**record, **extra_fields),
system_prompt=instructions_template,
instruction_first=instructions_first,
response_model=response_model,
inference_settings=LiteLLMInferenceSettings(
**self.dict(include=LiteLLMInferenceSettings.model_fields.keys())
),
messages = get_messages(
input_template.format(**record, **extra_fields),
instructions_template,
instructions_first,
)

if isinstance(response, ErrorLLMResponse):
try:
# returns a pydantic model named Output
response = instructor_client.chat.completions.create(
messages=messages,
response_model=response_model,
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
# extra inference params passed to this runtime
**self.model_extra,
)
except Exception as e:
error_message = type(e).__name__
# error_details = str(e)
error_details = traceback.format_exc()
if self.verbose:
print_error(response.adala_message, response.adala_details)
return response.model_dump(by_alias=True)
print_error(error_message, error_details)
# TODO change this format
error_dct = {
"_adala_error": True,
"_adala_message": error_message,
"_adala_details": error_details,
}
return error_dct

return response.data
return response.dict()
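
The refactor delegates constrained generation to the module-level instructor client. A self-contained sketch of that pattern, with an illustrative response model standing in for the one adala derives from the output template:

# Illustrative sketch only -- not part of this diff.
import instructor
import litellm
from pydantic import BaseModel

class Output(BaseModel):
    sentiment: str  # stand-in for fields parsed from the output template

client = instructor.from_litellm(litellm.completion)
result = client.chat.completions.create(
    messages=[{"role": "user", "content": "Classify the sentiment: 'great movie'"}],
    response_model=Output,
    model="gpt-4o-mini",
    max_tokens=100,
)
print(result.dict())  # e.g. {'sentiment': 'positive'}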


class AsyncLiteLLMChatRuntime(LiteLLMInferenceSettings, AsyncRuntime):
class AsyncLiteLLMChatRuntime(AsyncRuntime):
"""
Runtime that uses [LiteLLM API](https://litellm.vercel.app/docs) and chat
completion models to perform the skill. It makes async calls to the underlying model provider.
The default model provider is [OpenAI](https://openai.com/), using the OPENAI_API_KEY environment variable. Other providers [can be chosen](https://litellm.vercel.app/docs/set_keys) through environment variables or passed parameters.
Attributes:
inference_settings (LiteLLMInferenceSettings): Common inference settings for LiteLLM.
model: Model name. Refer to the litellm docs for supported models and how to
specify them: https://litellm.vercel.app/docs/providers
max_tokens: Maximum tokens to generate.
temperature: Temperature for sampling.
seed: Integer seed to reduce nondeterminism in generation.
Extra parameters passed to this class will be used for inference. See `litellm.types.completion.CompletionRequest` for a full list. Some common ones are:
api_key: API key, optional. If provided, will be used to authenticate
with the provider of your specified model.
base_url (Optional[str]): Base URL, optional. If provided, will be used to talk to an OpenAI-compatible API provider besides OpenAI.
api_version (Optional[str]): API version, optional except for Azure.
timeout: Timeout in seconds.
"""

model_config = ConfigDict(arbitrary_types_allowed=True) # for @computed_field
model: str = "gpt-4o-mini"
max_tokens: int = 1000
temperature: float = 0.0
seed: Optional[int] = 47

model_config = ConfigDict(extra="allow")

def init_runtime(self) -> "Runtime":
# check model availability
# extension of litellm.check_valid_key for non-openai deployments
try:
messages = [{"role": "user", "content": "Hey, how's it going?"}]
litellm.completion(
messages=messages,
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
# extra inference params passed to this runtime
**self.model_extra,
)
except AuthenticationError:
raise ValueError(
f'Requested model "{self.model}" is not available with your api_key and settings.'
)
except Exception as e:
raise ValueError(
f'Failed to check availability of requested model "{self.model}": {e}'
)
return self

@field_validator("concurrency", mode="before")
def check_concurrency(cls, value) -> int:
@@ -135,17 +238,6 @@ def check_concurrency(cls, value) -> int:
)
return value

def init_runtime(self) -> "Runtime":
# check model availability
try:
if self.api_key:
litellm.check_valid_key(model=self.model, api_key=self.api_key)
except NotFoundError:
raise ValueError(
f'Requested model "{self.model}" is not available in your OpenAI account.'
)
return self

async def batch_to_batch(
self,
batch: InternalDataFrame,
@@ -167,27 +259,45 @@ async def batch_to_batch(
lambda row: input_template.format(**row, **extra_fields), axis=1
).tolist()

responses: List[Union[ConstrainedLLMResponse, ErrorLLMResponse]] = (
await parallel_async_get_llm_response(
user_prompts=user_prompts,
system_prompt=instructions_template,
instruction_first=instructions_first,
response_model=response_model,
inference_settings=LiteLLMInferenceSettings(
**self.dict(include=LiteLLMInferenceSettings.model_fields.keys())
),
tasks = [
asyncio.ensure_future(
async_instructor_client.chat.completions.create(
messages=get_messages(
user_prompt,
instructions_template,
instructions_first,
),
response_model=response_model,
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
# extra inference params passed to this runtime
**self.model_extra,
)
)
)
for user_prompt in user_prompts
]
responses = await asyncio.gather(*tasks, return_exceptions=True)

# convert the list of responses (pydantic models or exceptions) to dataframe records
df_data = []
for response in responses:
if isinstance(response, ErrorLLMResponse):
if isinstance(response, Exception):
error_message = type(response).__name__
# error_details = str(response)
error_details = traceback.format_exc()
if self.verbose:
print_error(response.adala_message, response.adala_details)
df_data.append(response.model_dump(by_alias=True))
print_error(error_message, error_details)
# TODO change this format
error_dct = {
"_adala_error": True,
"_adala_message": error_message,
"_adala_details": error_details,
}
df_data.append(error_dct)
else:
df_data.append(response.data)
df_data.append(response.dict())

output_df = InternalDataFrame(df_data)
return output_df.set_index(batch.index)
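
A simplified, self-contained sketch of the concurrency pattern used above: one task per prompt, with gather(return_exceptions=True) turning per-item failures into values instead of aborting the whole batch. The coroutine body is a stand-in for the instructor call:

# Illustrative sketch only -- not part of this diff.
import asyncio

async def call_one(prompt: str) -> str:
    # stand-in for async_instructor_client.chat.completions.create(...)
    if "fail" in prompt:
        raise ValueError("simulated provider error")
    return f"ok: {prompt}"

async def run_batch(prompts):
    tasks = [asyncio.ensure_future(call_one(p)) for p in prompts]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    rows = []
    for r in results:
        if isinstance(r, Exception):
            rows.append({"_adala_error": True, "_adala_message": type(r).__name__})
        else:
            rows.append({"result": r})
    return rows

print(asyncio.run(run_batch(["good prompt", "this one will fail"])))
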
@@ -302,9 +412,11 @@ def record_to_record(

completion = litellm.completion(
messages=[{"role": "user", "content": content}],
inference_settings=LiteLLMInferenceSettings(
**self.dict(include=LiteLLMInferenceSettings.model_fields.keys())
),
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
# extra inference params passed to this runtime
**self.model_extra,
)

completion_text = completion.choices[0].message.content
50 changes: 4 additions & 46 deletions adala/runtimes/_openai.py
@@ -1,49 +1,7 @@
import os

from pydantic import Field

from ._litellm import AsyncLiteLLMChatRuntime, LiteLLMChatRuntime, LiteLLMVisionRuntime


class OpenAIChatRuntime(LiteLLMChatRuntime):
"""
Runtime that uses [OpenAI API](https://openai.com/) and chat completion
models to perform the skill.
Attributes:
inference_settings (LiteLLMInferenceSettings): Common inference settings for LiteLLM.
"""

# TODO does it make any sense for this to be optional?
api_key: str = Field(default=os.getenv("OPENAI_API_KEY"))


class AsyncOpenAIChatRuntime(AsyncLiteLLMChatRuntime):
"""
Runtime that uses [OpenAI API](https://openai.com/) and chat completion
models to perform the skill. It uses async calls to OpenAI API.
Attributes:
inference_settings (LiteLLMInferenceSettings): Common inference settings for LiteLLM.
"""

api_key: str = Field(default=os.getenv("OPENAI_API_KEY"))


class OpenAIVisionRuntime(LiteLLMVisionRuntime):
"""
Runtime that uses [OpenAI API](https://openai.com/) and vision models to
perform the skill.
Only compatible with OpenAI API version 1.0.0 or higher.
"""

api_key: str = Field(default=os.getenv("OPENAI_API_KEY"))
# NOTE this check used to exist in OpenAIVisionRuntime.record_to_record,
# but doesn't seem to have a definition
# def init_runtime(self) -> 'Runtime':
# if not check_if_new_openai_version():
# raise NotImplementedError(
# f'{self.__class__.__name__} requires OpenAI API version 1.0.0 or higher.'
# )
# super().init_runtime()
# litellm already reads the OPENAI_API_KEY env var, which was the reason for this class
OpenAIChatRuntime = LiteLLMChatRuntime
AsyncOpenAIChatRuntime = AsyncLiteLLMChatRuntime
OpenAIVisionRuntime = LiteLLMVisionRuntime
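
Because the OpenAI runtimes are now plain aliases, existing imports keep working unchanged; a quick sketch:

# Illustrative sketch only -- not part of this diff.
from adala.runtimes import OpenAIChatRuntime, LiteLLMChatRuntime

assert OpenAIChatRuntime is LiteLLMChatRuntime  # same class under a different name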
