Commit

Revert "feat: RND-119: Add token usage in inference output (#183)" (#200)

Co-authored-by: nik <[email protected]>
niklub and nik authored Sep 4, 2024
1 parent 9bbd9c2 commit a0e7f6e
Showing 3 changed files with 58 additions and 170 deletions.
176 changes: 49 additions & 127 deletions adala/runtimes/_litellm.py
@@ -1,17 +1,11 @@
import asyncio
import logging
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Union, Type

import litellm
from litellm.exceptions import (
AuthenticationError,
ContentPolicyViolationError,
BadRequestError,
NotFoundError,
)
from litellm.types.utils import Usage
from litellm.exceptions import AuthenticationError
import instructor
from instructor.exceptions import InstructorRetryException, IncompleteOutputException
from instructor.exceptions import InstructorRetryException
import traceback
from adala.utils.exceptions import ConstrainedGenerationError
from adala.utils.internal_data import InternalDataFrame
@@ -20,14 +14,14 @@
parse_template,
partial_str_format,
)
from openai import NotFoundError
from pydantic import ConfigDict, field_validator, BaseModel
from rich import print
from tenacity import (
AsyncRetrying,
Retrying,
retry_if_not_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from pydantic_core._pydantic_core import ValidationError

@@ -39,25 +33,6 @@
logger = logging.getLogger(__name__)


# basically only retrying on timeout, incomplete output, or rate limit
# https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list
# NOTE: token usage is only correctly calculated if we only use instructor retries, not litellm retries
# https://github.com/jxnl/instructor/pull/763
retry_policy = dict(
retry=retry_if_not_exception_type(
(
ValidationError,
ContentPolicyViolationError,
AuthenticationError,
BadRequestError,
)
),
# should stop earlier on ValidationError and later on other errors, but couldn't figure out how to do that cleanly
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=60),
)


def get_messages(
user_prompt: str,
system_prompt: Optional[str] = None,
@@ -84,35 +59,6 @@ def _format_error_dict(e: Exception) -> dict:
return error_dct


def _log_llm_exception(e) -> dict:
dct = _format_error_dict(e)
base_error = f"Inference error {dct['_adala_message']}"
tb = traceback.format_exc()
logger.error(f"{base_error}\nTraceback:\n{tb}")
return dct


def _update_with_usage(data: Dict, usage: Usage, model: str) -> None:
data["_prompt_tokens"] = usage.prompt_tokens
# will not exist if there is no completion
data["_completion_tokens"] = usage.get("completion_tokens", 0)
# can't use litellm.completion_cost bc it only takes the most recent completion, and .usage is summed over retries
# TODO make sure this is calculated correctly after we turn on caching
# litellm will register the cost of an azure model on first successful completion. If there hasn't been a successful completion, the model will not be registered
try:
prompt_cost, completion_cost = litellm.cost_per_token(
model, usage.prompt_tokens, usage.get("completion_tokens", 0)
)
data["_prompt_cost_usd"] = prompt_cost
data["_completion_cost_usd"] = completion_cost
data["_total_cost_usd"] = prompt_cost + completion_cost
except NotFoundError:
logger.error(f"Failed to get cost for model {model}")
data["_prompt_cost_usd"] = None
data["_completion_cost_usd"] = None
data["_total_cost_usd"] = None


class LiteLLMChatRuntime(Runtime):
"""
Runtime that uses [LiteLLM API](https://litellm.vercel.app/docs) and chat
@@ -227,57 +173,45 @@ def record_to_record(
instructions_first,
)

retries = Retrying(**retry_policy)
retries = Retrying(
retry=retry_if_not_exception_type((ValidationError)),
stop=stop_after_attempt(3),
)

try:
# returns a pydantic model named Output
response, completion = (
instructor_client.chat.completions.create_with_completion(
messages=messages,
response_model=response_model,
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
max_retries=retries,
# extra inference params passed to this runtime
**self.model_extra,
)
response = instructor_client.chat.completions.create(
messages=messages,
response_model=response_model,
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
max_retries=retries,
# extra inference params passed to this runtime
**self.model_extra,
)
usage = completion.usage
dct = response.dict()
except IncompleteOutputException as e:
usage = e.total_usage
dct = _log_llm_exception(e)
except InstructorRetryException as e:
usage = e.total_usage
# get root cause error from retries
n_attempts = e.n_attempts
e = e.__cause__.last_attempt.exception()
dct = _log_llm_exception(e)
dct = _format_error_dict(e)
print_error(f"Inference error {dct['_adala_message']} after {n_attempts=}")
tb = traceback.format_exc()
logger.debug(tb)
return dct
except Exception as e:
# usage = e.total_usage
# not available here, so have to approximate by hand, assuming the same error occurred each time
n_attempts = retries.stop.max_attempt_number
prompt_tokens = n_attempts * litellm.token_counter(
model=self.model, messages=messages[:-1]
) # response is appended as the last message
# TODO a pydantic validation error may be appended as the last message, don't know how to get the raw response in this case
completion_tokens = 0
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=(prompt_tokens + completion_tokens),
)

# Catch case where the model does not return a properly formatted output
if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
e = ConstrainedGenerationError()
# there are no other known errors to catch
dct = _log_llm_exception(e)
# the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
dct = _format_error_dict(e)
print_error(f"Inference error {dct['_adala_message']}")
tb = traceback.format_exc()
logger.debug(tb)
return dct

_update_with_usage(dct, usage, model=self.model)
return dct
return response.dict()


class AsyncLiteLLMChatRuntime(AsyncRuntime):
@@ -370,11 +304,14 @@ async def batch_to_batch(
axis=1,
).tolist()

retries = AsyncRetrying(**retry_policy)
retries = AsyncRetrying(
retry=retry_if_not_exception_type((ValidationError)),
stop=stop_after_attempt(3),
)

tasks = [
asyncio.ensure_future(
async_instructor_client.chat.completions.create_with_completion(
async_instructor_client.chat.completions.create(
messages=get_messages(
user_prompt,
instructions_template,
@@ -397,46 +334,31 @@ async def batch_to_batch(
# convert list of LLMResponse objects to the dataframe records
df_data = []
for response in responses:
if isinstance(response, IncompleteOutputException):
e = response
usage = e.total_usage
dct = _log_llm_exception(e)
elif isinstance(response, InstructorRetryException):
if isinstance(response, InstructorRetryException):
e = response
usage = e.total_usage
# get root cause error from retries
n_attempts = e.n_attempts
e = e.__cause__.last_attempt.exception()
dct = _log_llm_exception(e)
dct = _format_error_dict(e)
print_error(
f"Inference error {dct['_adala_message']} after {n_attempts=}"
)
tb = traceback.format_exc()
logger.debug(tb)
df_data.append(dct)
elif isinstance(response, Exception):
e = response
# usage = e.total_usage
# not available here, so have to approximate by hand, assuming the same error occurred each time
n_attempts = retries.stop.max_attempt_number
messages = [] # TODO how to get these?
prompt_tokens = n_attempts * litellm.token_counter(
model=self.model, messages=messages[:-1]
) # response is appended as the last message
# TODO a pydantic validation error may be appended as the last message, don't know how to get the raw response in this case
completion_tokens = 0
usage = Usage(
prompt_tokens,
completion_tokens,
total_tokens=(prompt_tokens + completion_tokens),
)

# Catch case where the model does not return a properly formatted output
if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
e = ConstrainedGenerationError()
# the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
dct = _log_llm_exception(e)
dct = _format_error_dict(e)
print_error(f"Inference error {dct['_adala_message']}")
tb = traceback.format_exc()
logger.debug(tb)
df_data.append(dct)
else:
resp, completion = response
usage = completion.usage
dct = resp.dict()

_update_with_usage(dct, usage, model=self.model)
df_data.append(dct)
df_data.append(response.dict())

output_df = InternalDataFrame(df_data)
return output_df.set_index(batch.index)
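
For reference, here is a minimal self-contained sketch of the call pattern this revert restores in `record_to_record`. The retry setup and the instructor call mirror the added lines above; the helper name `run_completion` and the error-dict construction are illustrative approximations of `_format_error_dict`, not the exact file contents:

```python
import logging
import traceback

from instructor.exceptions import InstructorRetryException
from pydantic_core._pydantic_core import ValidationError
from tenacity import Retrying, retry_if_not_exception_type, stop_after_attempt

logger = logging.getLogger(__name__)


def run_completion(instructor_client, messages, response_model, model, **extra):
    # Give up immediately on validation errors; retry other failures up to 3 times.
    retries = Retrying(
        retry=retry_if_not_exception_type(ValidationError),
        stop=stop_after_attempt(3),
    )
    try:
        # instructor accepts a tenacity Retrying object via max_retries
        response = instructor_client.chat.completions.create(
            messages=messages,
            response_model=response_model,
            model=model,
            max_retries=retries,
            **extra,
        )
        return response.dict()
    except InstructorRetryException as e:
        # Surface the root-cause exception from the final retry attempt
        root = e.__cause__.last_attempt.exception()
        logger.debug(traceback.format_exc())
        return {
            "_adala_error": True,
            "_adala_message": type(root).__name__,
            "_adala_details": str(root),
        }
```

The async runtime applies the same pattern with `AsyncRetrying`, collecting per-record exceptions from the gathered tasks and converting them to the same error dicts.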
10 changes: 1 addition & 9 deletions server/handlers/result_handlers.py
@@ -1,4 +1,4 @@
from typing import Optional, List, Dict
from typing import Optional, Any, List, Dict
import json
from abc import abstractmethod
from pydantic import BaseModel, Field, computed_field, ConfigDict, model_validator
@@ -80,14 +80,6 @@ class LSEBatchItem(BaseModel):
message: Optional[str] = Field(None, alias="_adala_message")
details: Optional[str] = Field(None, alias="_adala_details")

prompt_tokens: int = Field(alias="_prompt_tokens")
completion_tokens: int = Field(alias="_completion_tokens")

# these can fail to calculate
prompt_cost_usd: Optional[float] = Field(alias="_prompt_cost_usd")
completion_cost_usd: Optional[float] = Field(alias="_completion_cost_usd")
total_cost_usd: Optional[float] = Field(alias="_total_cost_usd")

@model_validator(mode="after")
def check_error_consistency(self):
has_error = self.error
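
With the token and cost fields removed, the handler only needs the error fields that the runtime emits. A rough sketch of how such a pydantic model maps the `_adala_*` keys via aliases; the `error` field is assumed from `has_error = self.error`, and the real `LSEBatchItem` carries additional fields and the consistency validator not shown here:

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class BatchItemSketch(BaseModel):
    # populate_by_name lets callers use either the field name or the "_adala_*" alias
    model_config = ConfigDict(populate_by_name=True)

    error: bool = Field(False, alias="_adala_error")  # assumed field name
    message: Optional[str] = Field(None, alias="_adala_message")
    details: Optional[str] = Field(None, alias="_adala_details")


# A failed-inference record from the runtime parses via the aliases:
item = BatchItemSketch.model_validate(
    {
        "_adala_error": True,
        "_adala_message": "AuthenticationError",
        "_adala_details": "litellm.AuthenticationError: ...",
    }
)
print(item.error, item.message)
```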
42 changes: 8 additions & 34 deletions tests/test_llm.py
@@ -7,6 +7,7 @@

@pytest.mark.vcr
def test_llm_sync():

runtime = LiteLLMChatRuntime()

# test plaintext success
@@ -33,15 +34,7 @@ class Output(BaseModel):
)

# note age coerced to string
expected_result = {
"name": "Carla",
"age": "25",
"_prompt_tokens": 86,
"_completion_tokens": 10,
"_prompt_cost_usd": 1.29e-05,
"_completion_cost_usd": 6e-06,
"_total_cost_usd": 1.89e-05,
}
expected_result = {"name": "Carla", "age": "25"}
assert result == expected_result

# test structured failure
@@ -59,17 +52,13 @@ class Output(BaseModel):
"_adala_error": True,
"_adala_message": "AuthenticationError",
"_adala_details": "litellm.AuthenticationError: AuthenticationError: OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: fake_api_key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}",
"_prompt_tokens": 9,
"_completion_tokens": 0,
"_prompt_cost_usd": 1.35e-06,
"_completion_cost_usd": 0.0,
"_total_cost_usd": 1.35e-06,
}
assert result == expected_result


@pytest.mark.vcr
def test_llm_async():

# test success

runtime = AsyncLiteLLMChatRuntime()
@@ -90,20 +79,9 @@ class Output(BaseModel):
)

# note age coerced to string
expected_result = pd.DataFrame.from_records(
[
{
"name": "Carla",
"age": "25",
"_prompt_tokens": 86,
"_completion_tokens": 10,
"_prompt_cost_usd": 1.29e-05,
"_completion_cost_usd": 6e-06,
"_total_cost_usd": 1.89e-05,
}
]
)
pd.testing.assert_frame_equal(result, expected_result)
expected_result = pd.DataFrame.from_records([{"name": "Carla", "age": "25"}])
# need 2 all() for row and column axis
assert (result == expected_result).all().all()

# test failure

@@ -125,14 +103,10 @@ class Output(BaseModel):
"_adala_error": True,
"_adala_message": "AuthenticationError",
"_adala_details": "litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: fake_api_key. You can find your API key at https://platform.openai.com/account/api-keys.",
"_prompt_tokens": 9,
"_completion_tokens": 0,
"_prompt_cost_usd": 1.35e-06,
"_completion_cost_usd": 0.0,
"_total_cost_usd": 1.35e-06,
}
]
)
pd.testing.assert_frame_equal(result, expected_result)
# need 2 all() for row and column axis
assert (result == expected_result).all().all()

# TODO test batch with successes and failures, figure out how to inject a particular error into LiteLLM
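
As a side note on the assertion style restored above: the comment "need 2 all() for row and column axis" refers to the pandas reduction pattern, since an elementwise DataFrame comparison yields a DataFrame of booleans that must be reduced once per axis. A quick standalone illustration, not part of the commit:

```python
import pandas as pd

a = pd.DataFrame([{"name": "Carla", "age": "25"}])
b = pd.DataFrame([{"name": "Carla", "age": "25"}])

eq = a == b            # elementwise comparison -> DataFrame of booleans
print(eq.all())        # first .all() reduces over rows -> one boolean per column
print(eq.all().all())  # second .all() reduces over columns -> a single bool
assert (a == b).all().all()
```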
