feat: RND-105: Support instructor in LLM calls (#162)
Co-authored-by: nik <[email protected]>
niklub and nik authored Jul 26, 2024
1 parent 273f35a commit 41f0270
Showing 34 changed files with 8,504 additions and 4,954 deletions.
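At a high level, this commit moves Adala's LiteLLM runtimes from free-form completions (parsed afterwards with match_options) to instructor-style constrained generation, where each call is validated against a Pydantic response model. Below is a minimal sketch of that pattern outside Adala, assuming a recent instructor release that exposes from_litellm; the Sentiment model and prompt are illustrative only and are not part of this commit.

import instructor
import litellm
from pydantic import BaseModel

# Illustrative response model; in Adala the models are built from output
# templates by parse_template_to_pydantic_class (see _litellm.py below).
class Sentiment(BaseModel):
    sentiment: str

client = instructor.from_litellm(litellm.completion)

result = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Classify the sentiment: 'I love this product!'"}],
    response_model=Sentiment,  # instructor validates the completion into this model
)
print(result.sentiment)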
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
@@ -54,7 +54,8 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-latest, windows-latest ]
python-version: ['3.9', '3.10', '3.11']
# python-version: ['3.9', '3.10', '3.11']
python-version: ['3.11']

steps:
- uses: actions/checkout@v4
3 changes: 3 additions & 0 deletions adala/memories/vectordb.py
@@ -36,6 +36,9 @@ def remember(self, observation: str, data: Any):
self.remember_many([observation], [data])

def remember_many(self, observations: List[str], data: List[Dict]):
# filter None values from each item in `data`
data = [{k: v for k, v in d.items() if v is not None} for d in data]

self._collection.add(
ids=[self.create_unique_id(o) for o in observations],
documents=observations,
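The new filtering step above matters because vector stores such as Chroma typically reject None metadata values. A tiny illustration of what the comprehension does (the sample dicts are made up):

records = [
    {"label": "cat", "score": 0.9, "source": None},
    {"label": None, "score": 0.1, "source": "manual"},
]
# Drop None-valued keys from each metadata dict, mirroring remember_many
cleaned = [{k: v for k, v in d.items() if v is not None} for d in records]
print(cleaned)  # [{'label': 'cat', 'score': 0.9}, {'score': 0.1, 'source': 'manual'}]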
284 changes: 62 additions & 222 deletions adala/runtimes/_litellm.py
@@ -1,111 +1,24 @@
import asyncio
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import litellm
from adala.utils.internal_data import InternalDataFrame
from adala.utils.logs import print_error
from adala.utils.matching import match_options
from adala.utils.parse import parse_template, partial_str_format
from adala.utils.parse import parse_template, partial_str_format, parse_template_to_pydantic_class
from adala.utils.llm import (
parallel_async_get_llm_response, get_llm_response, ConstrainedLLMResponse,
UnconstrainedLLMResponse, ErrorLLMResponse
)
from openai import NotFoundError
from pydantic import ConfigDict, field_validator
from rich import print
from tenacity import retry, stop_after_attempt, wait_random_exponential

from .base import AsyncRuntime, Runtime

logger = logging.getLogger(__name__)


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def async_create_completion(
model: str,
user_prompt: str,
timeout: int,
system_prompt: Optional[str] = None,
api_key: Optional[str] = None,
instruction_first: bool = True,
max_tokens: int = 1000,
temperature: float = 0.0,
) -> Dict[str, Any]:
"""
Async version of create_completion function with error handling and session timeout.
Args:
model: model name. Refer to litellm supported models for how to pass
this: https://litellm.vercel.app/docs/providers
user_prompt: User prompt.
system_prompt: System prompt.
api_key: API key, optional. If provided, will be used to authenticate
with the provider of your specified model.
instruction_first: Whether to put instructions first.
max_tokens: Maximum tokens to generate.
temperature: Temperature for sampling.
Returns:
Dict[str, Any]: OpenAI response or error message.
"""
messages = [{'role': 'user', 'content': user_prompt}]
if system_prompt:
if instruction_first:
messages.insert(0, {'role': 'system', 'content': system_prompt})
else:
messages[0]['content'] += system_prompt

try:
completion = await litellm.acompletion(
api_key=api_key,
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
timeout=timeout,
)
completion_text = completion.choices[0].message.content
return {
'text': completion_text,
'_adala_error': False,
'_adala_message': None,
'_adala_details': None,
}
except Exception as e:
# Handle other exceptions
return {
'text': None,
'_adala_error': True,
'_adala_message': type(e).__name__,
'_adala_details': str(e),
}


async def async_concurrent_create_completion(
prompts: List[Dict],
instruction_first: bool,
model: str,
max_tokens: int,
temperature: float,
timeout: int,
api_key: Optional[str] = None,
):
tasks = [
asyncio.ensure_future(
async_create_completion(
user_prompt=prompt['user'],
system_prompt=prompt['system'],
model=model,
api_key=api_key,
max_tokens=max_tokens,
temperature=temperature,
timeout=timeout,
instruction_first=instruction_first,
)
)
for prompt in prompts
]
responses = await asyncio.gather(*tasks)
return responses


class LiteLLMChatRuntime(Runtime):
"""
Runtime that uses [LiteLLM API](https://litellm.vercel.app/docs) and chat
@@ -144,25 +57,22 @@ def init_runtime(self) -> 'Runtime':
)
return self

def execute(self, messages: List):
"""
Execute LiteLLM request given list of messages in OpenAI API format
"""
def get_llm_response(self, messages: List[Dict[str, str]]) -> str:
# TODO: sunset this method in favor of record_to_record
if self.verbose:
print(f'LiteLLM request: {messages}')

completion = litellm.completion(
print(f'**Prompt content**:\n{messages}')
response: Union[ErrorLLMResponse, UnconstrainedLLMResponse] = get_llm_response(
messages=messages,
model=self.model,
api_key=self.api_key,
messages=messages,
max_tokens=self.max_tokens,
temperature=self.temperature,
temperature=self.temperature
)
completion_text = completion.choices[0].message.content

if isinstance(response, ErrorLLMResponse):
raise ValueError(f'{response.adala_message}\n{response.adala_details}')
if self.verbose:
print(f'LiteLLM response: {completion_text}')
return completion_text
print(f'**Response**:\n{response.text}')
return response.text

def record_to_record(
self,
@@ -191,43 +101,29 @@ def record_to_record(
"""

extra_fields = extra_fields or {}
field_schema = field_schema or {}

options = {}
for field, schema in field_schema.items():
if schema.get('type') == 'array':
options[field] = schema.get('items', {}).get('enum', [])
response_model = parse_template_to_pydantic_class(
output_template,
provided_field_schema=field_schema
)

output_fields = parse_template(
partial_str_format(output_template, **extra_fields),
include_texts=True,
response: Union[ConstrainedLLMResponse, ErrorLLMResponse] = get_llm_response(
user_prompt=input_template.format(**record, **extra_fields),
system_prompt=instructions_template,
instruction_first=instructions_first,
model=self.model,
api_key=self.api_key,
max_tokens=self.max_tokens,
temperature=self.temperature,
response_model=response_model
)
system_prompt = instructions_template
user_prompt = input_template.format(**record, **extra_fields)
messages = [{'role': 'system', 'content': system_prompt}]

outputs = {}
for output_field in output_fields:
if output_field['type'] == 'text':
if user_prompt is not None:
user_prompt += f"\n{output_field['text']}"
else:
user_prompt = output_field['text']
elif output_field['type'] == 'var':
name = output_field['text']
messages.append({'role': 'user', 'content': user_prompt})
completion_text = self.execute(messages)
if name in options:
completion_text = match_options(
completion_text, options[name]
)
outputs[name] = completion_text
messages.append(
{'role': 'assistant', 'content': completion_text}
)
user_prompt = None

return outputs
if isinstance(response, ErrorLLMResponse):
if self.verbose:
print_error(response.adala_message, response.adala_details)
return response.model_dump(by_alias=True)

return response.data


class AsyncLiteLLMChatRuntime(AsyncRuntime):
@@ -280,20 +176,6 @@ def init_runtime(self) -> 'Runtime':
)
return self

def _prepare_prompt(
self,
row,
input_template: str,
instructions_template: str,
suffix: str,
extra_fields: dict,
) -> Dict[str, str]:
"""Prepare input prompt for OpenAI API from the row of the dataframe"""
return {
'system': instructions_template,
'user': input_template.format(**row, **extra_fields) + suffix,
}

async def batch_to_batch(
self,
batch: InternalDataFrame,
@@ -306,79 +188,37 @@ async def batch_to_batch(
) -> InternalDataFrame:
"""Execute batch of requests with async calls to OpenAI API"""

extra_fields = extra_fields or {}
field_schema = field_schema or {}
response_model = parse_template_to_pydantic_class(
output_template,
provided_field_schema=field_schema
)

options = {}
for field, schema in field_schema.items():
if schema.get('type') == 'array':
options[field] = schema.get('items', {}).get('enum', [])
extra_fields = extra_fields or {}
user_prompts = batch.apply(lambda row: input_template.format(**row, **extra_fields), axis=1).tolist()

output_fields = parse_template(
partial_str_format(output_template, **extra_fields),
include_texts=True,
responses: List[Union[ConstrainedLLMResponse, ErrorLLMResponse]] = await parallel_async_get_llm_response(
user_prompts=user_prompts,
system_prompt=instructions_template,
instruction_first=instructions_first,
max_tokens=self.max_tokens,
temperature=self.temperature,
model=self.model,
api_key=self.api_key,
timeout=self.timeout,
response_model=response_model
)

if len(output_fields) > 2:
raise NotImplementedError('Only one output field is supported')

suffix = ''
outputs = []
for output_field in output_fields:
if output_field['type'] == 'text':
suffix += output_field['text']

elif output_field['type'] == 'var':
name = output_field['text']
# prepare prompts
prompts = batch.apply(
lambda row: self._prepare_prompt(
row,
input_template,
instructions_template,
suffix,
extra_fields,
),
axis=1,
).tolist()

# TODO refactor to remove async_concurrent_create_completion and async_create_completion
responses = await async_concurrent_create_completion(
prompts=prompts,
instruction_first=instructions_first,
max_tokens=self.max_tokens,
temperature=self.temperature,
model=self.model,
api_key=self.api_key,
timeout=self.timeout,
)

# parse responses, optionally match it with options
for prompt, response in zip(prompts, responses):
completion_text = response.pop('text')
if self.verbose:
if response['error'] is not None:
print_error(
f'Prompt: {prompt}\nLiteLLM API error: {response}'
)
else:
print(
f'Prompt: {prompt}\nLiteLLM API response: {completion_text}'
)
if name in options and completion_text is not None:
completion_text = match_options(
completion_text, options[name]
)
# still technically possible to have a name collision here
# with the error, message, details fields `name in options`
# is only `True` for categorical variables, but is never
# `True` for freeform text generation
response[name] = completion_text
outputs.append(response)

# TODO: note that this doesn't work for multiple output fields e.g.
# `Output {output1} and Output {output2}`
output_df = InternalDataFrame(outputs)
# convert the list of LLMResponse objects to dataframe records
df_data = []
for response in responses:
if isinstance(response, ErrorLLMResponse):
if self.verbose:
print_error(response.adala_message, response.adala_details)
df_data.append(response.model_dump(by_alias=True))
else:
df_data.append(response.data)

output_df = InternalDataFrame(df_data)
return output_df.set_index(batch.index)

async def record_to_record(
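The core of this refactor is that parse_template_to_pydantic_class(output_template, provided_field_schema=...) builds a response model up front, so categorical outputs are constrained at generation time instead of being matched afterwards with match_options. The helper's implementation is not shown in this diff; the following is a hypothetical stand-in for the kind of model it could produce for a template like "Output: {sentiment}" with an enum field schema.

from enum import Enum
from pydantic import create_model

# Field schema in the shape the runtimes receive (illustrative values)
field_schema = {
    "sentiment": {
        "type": "array",
        "items": {"enum": ["positive", "negative", "neutral"]},
    }
}

# Constrain the output field to the enum choices, so the LLM response is
# validated into one of them rather than free text matched post hoc
choices = field_schema["sentiment"]["items"]["enum"]
SentimentEnum = Enum("SentimentEnum", {c: c for c in choices})
ResponseModel = create_model("ResponseModel", sentiment=(SentimentEnum, ...))

print(ResponseModel(sentiment="positive").model_dump(mode="json"))  # {'sentiment': 'positive'}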
4 changes: 2 additions & 2 deletions adala/skills/_base.py
@@ -315,7 +315,7 @@ def improve(
Summarize your analysis about incorrect predictions and suggest changes to the prompt.""",
}
]
reasoning = runtime.execute(messages)
reasoning = runtime.get_llm_response(messages)

messages += [
{"role": "assistant", "content": reasoning},
@@ -353,7 +353,7 @@ def improve(
# display dialogue:
for message in messages:
print(f'"{{{message["role"]}}}":\n{message["content"]}')
new_prompt = runtime.execute(messages)
new_prompt = runtime.get_llm_response(messages)
self.instructions = new_prompt


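With execute() replaced, skills now obtain raw completion text through get_llm_response(messages), which raises ValueError when the underlying call returns an ErrorLLMResponse. A rough usage sketch; the import path, constructor arguments, and model name are assumptions for illustration, not taken from this commit.

from adala.runtimes import LiteLLMChatRuntime  # assumed public import path

runtime = LiteLLMChatRuntime(model="gpt-3.5-turbo")  # illustrative model name
messages = [
    {"role": "system", "content": "You improve labeling prompts."},
    {"role": "user", "content": "Suggest a clearer instruction for sentiment labeling."},
]
# Returns the completion text; raises ValueError on an ErrorLLMResponse
print(runtime.get_llm_response(messages))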