feat: DIA-1270: allow Azure inference (#169)
Co-authored-by: pakelley <[email protected]>
Co-authored-by: nik <[email protected]>
3 people authored Aug 1, 2024
1 parent 34ce544 commit 3b6d651
Showing 42 changed files with 5,179 additions and 5,114 deletions.
8 changes: 3 additions & 5 deletions adala/agents/base.py
@@ -2,13 +2,12 @@
from pydantic import (
BaseModel,
Field,
SkipValidation,
field_validator,
model_validator,
SerializeAsAny,
)
from abc import ABC, abstractmethod
from typing import Any, Optional, List, Dict, Union, Tuple
from abc import ABC
from typing import Optional, Dict, Union, Tuple
from rich import print
import yaml

@@ -215,8 +214,7 @@ def get_teacher_runtime(self, runtime: Optional[str] = None) -> Runtime:
return runtime

def run(
self, input: InternalDataFrame = None, runtime: Optional[str] = None,
**kwargs
self, input: InternalDataFrame = None, runtime: Optional[str] = None, **kwargs
) -> InternalDataFrame:
"""
Runs the agent on the specified dataset.
14 changes: 9 additions & 5 deletions adala/environments/static_env.py
@@ -102,11 +102,15 @@ def get_feedback(
[gt_pred_match.rename("match"), gt], axis=1
)
pred_feedback[pred_column] = match_concat.apply(
lambda row: "Prediction is correct."
if row["match"]
else f'Prediction is incorrect. Correct answer: "{row[gt_column]}"'
if not pd.isna(row["match"])
else np.nan,
lambda row: (
"Prediction is correct."
if row["match"]
else (
f'Prediction is incorrect. Correct answer: "{row[gt_column]}"'
if not pd.isna(row["match"])
else np.nan
)
),
axis=1,
)

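Aside (not part of the diff): a tiny standalone pandas sketch of what the restructured feedback lambda produces for matching vs. non-matching rows. Column names are invented, and the NaN guard from the original expression is omitted for brevity.

```python
# Matching rows map to "Prediction is correct."; non-matching rows quote the
# ground-truth value. Column names are invented for this illustration.
import pandas as pd

gt_column = "gt"
match_concat = pd.DataFrame({"match": [True, False], gt_column: ["cat", "dog"]})
feedback = match_concat.apply(
    lambda row: (
        "Prediction is correct."
        if row["match"]
        else f'Prediction is incorrect. Correct answer: "{row[gt_column]}"'
    ),
    axis=1,
)
print(feedback.tolist())
# ['Prediction is correct.', 'Prediction is incorrect. Correct answer: "dog"']
```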
1 change: 0 additions & 1 deletion adala/memories/base.py
@@ -6,7 +6,6 @@


class Memory(BaseModel, ABC):

"""
Base class for memories.
"""
177 changes: 73 additions & 104 deletions adala/runtimes/_litellm.py
@@ -4,11 +4,18 @@
import litellm
from adala.utils.internal_data import InternalDataFrame
from adala.utils.logs import print_error
from adala.utils.matching import match_options
from adala.utils.parse import parse_template, partial_str_format, parse_template_to_pydantic_class
from adala.utils.parse import (
parse_template,
partial_str_format,
parse_template_to_pydantic_class,
)
from adala.utils.llm import (
parallel_async_get_llm_response, get_llm_response, ConstrainedLLMResponse,
UnconstrainedLLMResponse, ErrorLLMResponse
parallel_async_get_llm_response,
get_llm_response,
ConstrainedLLMResponse,
UnconstrainedLLMResponse,
ErrorLLMResponse,
LiteLLMInferenceSettings,
)
from openai import NotFoundError
from pydantic import ConfigDict, field_validator
@@ -19,34 +26,18 @@
logger = logging.getLogger(__name__)


class LiteLLMChatRuntime(Runtime):
class LiteLLMChatRuntime(LiteLLMInferenceSettings, Runtime):
"""
Runtime that uses [LiteLLM API](https://litellm.vercel.app/docs) and chat
completion models to perform the skill.
Attributes:
model: Model name, refer to LiteLLM's supported provider docs for
how to pass this for your model: https://litellm.vercel.app/docs/providers
api_key: API key, optional. If provided, will be used to authenticate
with the provider of your specified model.
base_url: Points to the endpoint where your model is hosted
max_tokens: Maximum number of tokens to generate. Defaults to 1000.
splitter: Splitter to use for splitting messages. Defaults to None.
temperature: Temperature for sampling, between 0 and 1.
inference_settings (LiteLLMInferenceSettings): Common inference settings for LiteLLM.
"""

model_config = ConfigDict(
arbitrary_types_allowed=True
) # for @computed_field

model: str
api_key: Optional[str]
base_url: Optional[str] = None
max_tokens: Optional[int] = 1000
splitter: Optional[str] = None
temperature: Optional[float] = 0.0
model_config = ConfigDict(arbitrary_types_allowed=True) # for @computed_field

def init_runtime(self) -> 'Runtime':
def init_runtime(self) -> "Runtime":
# check model availability
try:
if self.api_key:
@@ -60,18 +51,17 @@ def init_runtime(self) -> 'Runtime':
def get_llm_response(self, messages: List[Dict[str, str]]) -> str:
# TODO: sunset this method in favor of record_to_record
if self.verbose:
print(f'**Prompt content**:\n{messages}')
print(f"**Prompt content**:\n{messages}")
response: Union[ErrorLLMResponse, UnconstrainedLLMResponse] = get_llm_response(
messages=messages,
model=self.model,
api_key=self.api_key,
max_tokens=self.max_tokens,
temperature=self.temperature
inference_settings=LiteLLMInferenceSettings(
**self.dict(include=LiteLLMInferenceSettings.model_fields.keys())
),
)
if isinstance(response, ErrorLLMResponse):
raise ValueError(f'{response.adala_message}\n{response.adala_details}')
raise ValueError(f"{response.adala_message}\n{response.adala_details}")
if self.verbose:
print(f'**Response**:\n{response.text}')
print(f"**Response**:\n{response.text}")
return response.text

def record_to_record(
@@ -103,19 +93,17 @@ def record_to_record(
extra_fields = extra_fields or {}

response_model = parse_template_to_pydantic_class(
output_template,
provided_field_schema=field_schema
output_template, provided_field_schema=field_schema
)

response: Union[ConstrainedLLMResponse, ErrorLLMResponse] = get_llm_response(
user_prompt=input_template.format(**record, **extra_fields),
system_prompt=instructions_template,
instruction_first=instructions_first,
model=self.model,
api_key=self.api_key,
max_tokens=self.max_tokens,
temperature=self.temperature,
response_model=response_model
response_model=response_model,
inference_settings=LiteLLMInferenceSettings(
**self.dict(include=LiteLLMInferenceSettings.model_fields.keys())
),
)

if isinstance(response, ErrorLLMResponse):
@@ -126,46 +114,28 @@ def record_to_record(
return response.data


class AsyncLiteLLMChatRuntime(AsyncRuntime):
class AsyncLiteLLMChatRuntime(LiteLLMInferenceSettings, AsyncRuntime):
"""
Runtime that uses [OpenAI API](https://openai.com/) and chat completion
models to perform the skill. It uses async calls to OpenAI API.
Attributes:
model: OpenAI model name.
api_key: API key, optional. If provided, will be used to authenticate
with the provider of your specified model.
base_url: Points to the endpoint where your model is hosted
max_tokens: Maximum number of tokens to generate. Defaults to 1000.
temperature: Temperature for sampling, between 0 and 1. Higher values
means the model will take more risks. Try 0.9 for more
creative applications, and 0 (argmax sampling) for ones
with a well-defined answer.
Defaults to 0.0.
inference_settings (LiteLLMInferenceSettings): Common inference settings for LiteLLM.
"""

model_config = ConfigDict(
arbitrary_types_allowed=True
) # for @computed_field

model: str
api_key: Optional[str] = None
base_url: Optional[str] = None
max_tokens: Optional[int] = 1000
temperature: Optional[float] = 0.0
splitter: Optional[str] = None
timeout: Optional[int] = 10
model_config = ConfigDict(arbitrary_types_allowed=True) # for @computed_field

@field_validator("concurrency", mode="before")
def check_concurrency(cls, value) -> int:
value = value or -1
if value < 1:
raise NotImplementedError(
"You must explicitly specify the number of concurrent clients for AsyncOpenAIChatRuntime. "
"Set `AsyncOpenAIChatRuntime(concurrency=10, ...)` or any other positive integer. ")
"Set `AsyncOpenAIChatRuntime(concurrency=10, ...)` or any other positive integer. "
)
return value

def init_runtime(self) -> 'Runtime':
def init_runtime(self) -> "Runtime":
# check model availability
try:
if self.api_key:
@@ -189,26 +159,27 @@ async def batch_to_batch(
"""Execute batch of requests with async calls to OpenAI API"""

response_model = parse_template_to_pydantic_class(
output_template,
provided_field_schema=field_schema
output_template, provided_field_schema=field_schema
)

extra_fields = extra_fields or {}
user_prompts = batch.apply(lambda row: input_template.format(**row, **extra_fields), axis=1).tolist()

responses: List[Union[ConstrainedLLMResponse, ErrorLLMResponse]] = await parallel_async_get_llm_response(
user_prompts=user_prompts,
system_prompt=instructions_template,
instruction_first=instructions_first,
max_tokens=self.max_tokens,
temperature=self.temperature,
model=self.model,
api_key=self.api_key,
timeout=self.timeout,
response_model=response_model
user_prompts = batch.apply(
lambda row: input_template.format(**row, **extra_fields), axis=1
).tolist()

responses: List[Union[ConstrainedLLMResponse, ErrorLLMResponse]] = (
await parallel_async_get_llm_response(
user_prompts=user_prompts,
system_prompt=instructions_template,
instruction_first=instructions_first,
response_model=response_model,
inference_settings=LiteLLMInferenceSettings(
**self.dict(include=LiteLLMInferenceSettings.model_fields.keys())
),
)
)

# conver list of LLMResponse objects to the dataframe records
# convert list of LLMResponse objects to the dataframe records
df_data = []
for response in responses:
if isinstance(response, ErrorLLMResponse):
@@ -231,7 +202,7 @@ async def record_to_record(
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
) -> Dict[str, str]:
raise NotImplementedError('record_to_record is not implemented')
raise NotImplementedError("record_to_record is not implemented")
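Aside (not part of the diff): both runtimes now forward their inference settings with `LiteLLMInferenceSettings(**self.dict(include=LiteLLMInferenceSettings.model_fields.keys()))`. Below is a minimal, self-contained Pydantic sketch of that re-packing pattern; all class and field names are invented for illustration.

```python
# Illustration of the settings-passthrough pattern used in this commit:
# a runtime class mixes in a settings model, and only the settings fields
# are re-packed when calling the LLM helpers. All names here are invented.
from typing import Optional
from pydantic import BaseModel


class InferenceSettings(BaseModel):
    model: str = "gpt-4o-mini"
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    max_tokens: int = 1000
    temperature: float = 0.0


class ChatRuntime(InferenceSettings):
    verbose: bool = False  # runtime-only field, not an inference setting

    def settings(self) -> InferenceSettings:
        # Mirrors LiteLLMInferenceSettings(**self.dict(include=...)) above.
        return InferenceSettings(
            **self.model_dump(include=set(InferenceSettings.model_fields.keys()))
        )


runtime = ChatRuntime(model="azure/my-deployment", verbose=True)
print(runtime.settings())  # only the inference fields; verbose is excluded
```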


class LiteLLMVisionRuntime(LiteLLMChatRuntime):
@@ -284,58 +255,56 @@ def record_to_record(

if len(output_fields) > 1:
raise NotImplementedError(
f'{self.__class__.__name__} does not support multiple output fields. '
f'Found: {output_fields}'
f"{self.__class__.__name__} does not support multiple output fields. "
f"Found: {output_fields}"
)
output_field = output_fields[0]
output_field_name = output_field['text']
output_field_name = output_field["text"]

input_fields = parse_template(input_template)

# split input template into text and image parts
input_text = ''
input_text = ""
content = [
{
'type': 'text',
'text': instructions_template,
"type": "text",
"text": instructions_template,
}
]
for field in input_fields:
if field['type'] == 'text':
input_text += field['text']
elif field['type'] == 'var':
if field['text'] not in field_schema:
input_text += record[field['text']]
elif field_schema[field['text']]['type'] == 'string':
if field_schema[field['text']].get('format') == 'uri':
if field["type"] == "text":
input_text += field["text"]
elif field["type"] == "var":
if field["text"] not in field_schema:
input_text += record[field["text"]]
elif field_schema[field["text"]]["type"] == "string":
if field_schema[field["text"]].get("format") == "uri":
if input_text:
content.append(
{'type': 'text', 'text': input_text}
)
input_text = ''
content.append({"type": "text", "text": input_text})
input_text = ""
content.append(
{
'type': 'image_url',
'image_url': record[field['text']],
"type": "image_url",
"image_url": record[field["text"]],
}
)
else:
input_text += record[field['text']]
input_text += record[field["text"]]
else:
raise ValueError(
f'Unsupported field type: {field_schema[field["text"]]["type"]}'
)
if input_text:
content.append({'type': 'text', 'text': input_text})
content.append({"type": "text", "text": input_text})

if self.verbose:
print(f'**Prompt content**:\n{content}')
print(f"**Prompt content**:\n{content}")

completion = litellm.completion(
model=self.model,
api_key=self.api_key,
messages=[{'role': 'user', 'content': content}],
max_tokens=self.max_tokens,
messages=[{"role": "user", "content": content}],
inference_settings=LiteLLMInferenceSettings(
**self.dict(include=LiteLLMInferenceSettings.model_fields.keys())
),
)

completion_text = completion.choices[0].message.content
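Aside (not part of the diff): a hedged usage sketch of how the reworked runtime might be pointed at an Azure deployment, assuming LiteLLM's `azure/<deployment>` model naming and that `model`, `api_key`, and `base_url` are among the `LiteLLMInferenceSettings` fields (mirroring the per-class fields this commit removed). The deployment name, endpoint, and key are placeholders.

```python
# Hypothetical sketch, not part of the commit: pointing LiteLLMChatRuntime at
# an Azure OpenAI deployment via LiteLLM's "azure/<deployment>" model naming.
# Depending on the provider, an API version setting may also be required.
from adala.runtimes._litellm import LiteLLMChatRuntime

runtime = LiteLLMChatRuntime(
    model="azure/my-gpt-4o-deployment",                # assumed deployment name
    api_key="<AZURE_OPENAI_API_KEY>",                  # placeholder key
    base_url="https://my-resource.openai.azure.com/",  # placeholder endpoint
    max_tokens=1000,
    temperature=0.0,
    verbose=True,
)
runtime.init_runtime()  # checks that the model is reachable
print(runtime.get_llm_response([{"role": "user", "content": "Say hello."}]))
```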
(Remaining changed files not shown.)
