diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py
index f244a8bf..9da01e8d 100644
--- a/adala/runtimes/_litellm.py
+++ b/adala/runtimes/_litellm.py
@@ -1,17 +1,11 @@
 import asyncio
 import logging
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, List, Optional, Union, Type
 import litellm
-from litellm.exceptions import (
-    AuthenticationError,
-    ContentPolicyViolationError,
-    BadRequestError,
-    NotFoundError,
-)
-from litellm.types.utils import Usage
+from litellm.exceptions import AuthenticationError
 import instructor
-from instructor.exceptions import InstructorRetryException, IncompleteOutputException
+from instructor.exceptions import InstructorRetryException
 import traceback
 from adala.utils.exceptions import ConstrainedGenerationError
 from adala.utils.internal_data import InternalDataFrame
@@ -20,6 +14,7 @@
     parse_template,
     partial_str_format,
 )
+from openai import NotFoundError
 from pydantic import ConfigDict, field_validator, BaseModel
 from rich import print
 from tenacity import (
@@ -27,7 +22,6 @@
     Retrying,
     retry_if_not_exception_type,
     stop_after_attempt,
-    wait_random_exponential,
 )
 from pydantic_core._pydantic_core import ValidationError
@@ -39,25 +33,6 @@

 logger = logging.getLogger(__name__)

-# basically only retrying on timeout, incomplete output, or rate limit
-# https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list
-# NOTE: token usage is only correctly calculated if we only use instructor retries, not litellm retries
-# https://github.com/jxnl/instructor/pull/763
-retry_policy = dict(
-    retry=retry_if_not_exception_type(
-        (
-            ValidationError,
-            ContentPolicyViolationError,
-            AuthenticationError,
-            BadRequestError,
-        )
-    ),
-    # should stop earlier on ValidationError and later on other errors, but couldn't figure out how to do that cleanly
-    stop=stop_after_attempt(3),
-    wait=wait_random_exponential(multiplier=1, max=60),
-)
-
-
 def get_messages(
     user_prompt: str,
     system_prompt: Optional[str] = None,
@@ -84,35 +59,6 @@ def _format_error_dict(e: Exception) -> dict:
     return error_dct


-def _log_llm_exception(e) -> dict:
-    dct = _format_error_dict(e)
-    base_error = f"Inference error {dct['_adala_message']}"
-    tb = traceback.format_exc()
-    logger.error(f"{base_error}\nTraceback:\n{tb}")
-    return dct
-
-
-def _update_with_usage(data: Dict, usage: Usage, model: str) -> None:
-    data["_prompt_tokens"] = usage.prompt_tokens
-    # will not exist if there is no completion
-    data["_completion_tokens"] = usage.get("completion_tokens", 0)
-    # can't use litellm.completion_cost bc it only takes the most recent completion, and .usage is summed over retries
-    # TODO make sure this is calculated correctly after we turn on caching
-    # litellm will register the cost of an azure model on first successful completion. If there hasn't been a successful completion, the model will not be registered
-    try:
-        prompt_cost, completion_cost = litellm.cost_per_token(
-            model, usage.prompt_tokens, usage.get("completion_tokens", 0)
-        )
-        data["_prompt_cost_usd"] = prompt_cost
-        data["_completion_cost_usd"] = completion_cost
-        data["_total_cost_usd"] = prompt_cost + completion_cost
-    except NotFoundError:
-        logger.error(f"Failed to get cost for model {model}")
-        data["_prompt_cost_usd"] = None
-        data["_completion_cost_usd"] = None
-        data["_total_cost_usd"] = None
-
-
 class LiteLLMChatRuntime(Runtime):
     """
     Runtime that uses [LiteLLM API](https://litellm.vercel.app/docs) and chat
@@ -227,57 +173,45 @@ def record_to_record(
             instructions_first,
         )

-        retries = Retrying(**retry_policy)
+        retries = Retrying(
+            retry=retry_if_not_exception_type((ValidationError)),
+            stop=stop_after_attempt(3),
+        )

         try:
             # returns a pydantic model named Output
-            response, completion = (
-                instructor_client.chat.completions.create_with_completion(
-                    messages=messages,
-                    response_model=response_model,
-                    model=self.model,
-                    max_tokens=self.max_tokens,
-                    temperature=self.temperature,
-                    seed=self.seed,
-                    max_retries=retries,
-                    # extra inference params passed to this runtime
-                    **self.model_extra,
-                )
+            response = instructor_client.chat.completions.create(
+                messages=messages,
+                response_model=response_model,
+                model=self.model,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+                seed=self.seed,
+                max_retries=retries,
+                # extra inference params passed to this runtime
+                **self.model_extra,
             )
-            usage = completion.usage
-            dct = response.dict()
-        except IncompleteOutputException as e:
-            usage = e.total_usage
-            dct = _log_llm_exception(e)
         except InstructorRetryException as e:
-            usage = e.total_usage
             # get root cause error from retries
             n_attempts = e.n_attempts
             e = e.__cause__.last_attempt.exception()
-            dct = _log_llm_exception(e)
+            dct = _format_error_dict(e)
+            print_error(f"Inference error {dct['_adala_message']} after {n_attempts=}")
+            tb = traceback.format_exc()
+            logger.debug(tb)
+            return dct
         except Exception as e:
-            # usage = e.total_usage
-            # not available here, so have to approximate by hand, assuming the same error occurred each time
-            n_attempts = retries.stop.max_attempt_number
-            prompt_tokens = n_attempts * litellm.token_counter(
-                model=self.model, messages=messages[:-1]
-            )  # response is appended as the last message
-            # TODO a pydantic validation error may be appended as the last message, don't know how to get the raw response in this case
-            completion_tokens = 0
-            usage = Usage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=(prompt_tokens + completion_tokens),
-            )
-
             # Catch case where the model does not return a properly formatted output
             if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
                 e = ConstrainedGenerationError()
-            # there are no other known errors to catch
-            dct = _log_llm_exception(e)
+            # the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
+            dct = _format_error_dict(e)
+            print_error(f"Inference error {dct['_adala_message']}")
+            tb = traceback.format_exc()
+            logger.debug(tb)
+            return dct

-        _update_with_usage(dct, usage, model=self.model)
-        return dct
+        return response.dict()


 class AsyncLiteLLMChatRuntime(AsyncRuntime):
@@ -370,11 +304,14 @@ async def batch_to_batch(
             axis=1,
         ).tolist()

-        retries = AsyncRetrying(**retry_policy)
+        retries = AsyncRetrying(
+            retry=retry_if_not_exception_type((ValidationError)),
+            stop=stop_after_attempt(3),
+        )

         tasks = [
             asyncio.ensure_future(
-                async_instructor_client.chat.completions.create_with_completion(
+                async_instructor_client.chat.completions.create(
                     messages=get_messages(
                         user_prompt,
                         instructions_template,
@@ -397,46 +334,31 @@
         # convert list of LLMResponse objects to the dataframe records
         df_data = []
         for response in responses:
-            if isinstance(response, IncompleteOutputException):
-                e = response
-                usage = e.total_usage
-                dct = _log_llm_exception(e)
-            elif isinstance(response, InstructorRetryException):
+            if isinstance(response, InstructorRetryException):
                 e = response
-                usage = e.total_usage
                 # get root cause error from retries
                 n_attempts = e.n_attempts
                 e = e.__cause__.last_attempt.exception()
-                dct = _log_llm_exception(e)
+                dct = _format_error_dict(e)
+                print_error(
+                    f"Inference error {dct['_adala_message']} after {n_attempts=}"
+                )
+                tb = traceback.format_exc()
+                logger.debug(tb)
+                df_data.append(dct)
             elif isinstance(response, Exception):
                 e = response
-                # usage = e.total_usage
-                # not available here, so have to approximate by hand, assuming the same error occurred each time
-                n_attempts = retries.stop.max_attempt_number
-                messages = []  # TODO how to get these?
-                prompt_tokens = n_attempts * litellm.token_counter(
-                    model=self.model, messages=messages[:-1]
-                )  # response is appended as the last message
-                # TODO a pydantic validation error may be appended as the last message, don't know how to get the raw response in this case
-                completion_tokens = 0
-                usage = Usage(
-                    prompt_tokens,
-                    completion_tokens,
-                    total_tokens=(prompt_tokens + completion_tokens),
-                )
-
                 # Catch case where the model does not return a properly formatted output
                 if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
                     e = ConstrainedGenerationError()
                 # the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
-                dct = _log_llm_exception(e)
+                dct = _format_error_dict(e)
+                print_error(f"Inference error {dct['_adala_message']}")
+                tb = traceback.format_exc()
+                logger.debug(tb)
+                df_data.append(dct)
             else:
-                resp, completion = response
-                usage = completion.usage
-                dct = resp.dict()
-
-            _update_with_usage(dct, usage, model=self.model)
-            df_data.append(dct)
+                df_data.append(response.dict())

         output_df = InternalDataFrame(df_data)
         return output_df.set_index(batch.index)
diff --git a/server/handlers/result_handlers.py b/server/handlers/result_handlers.py
index 102d00b0..784a0a89 100644
--- a/server/handlers/result_handlers.py
+++ b/server/handlers/result_handlers.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, Dict
+from typing import Optional, Any, List, Dict
 import json
 from abc import abstractmethod
 from pydantic import BaseModel, Field, computed_field, ConfigDict, model_validator
@@ -80,14 +80,6 @@ class LSEBatchItem(BaseModel):
     message: Optional[str] = Field(None, alias="_adala_message")
     details: Optional[str] = Field(None, alias="_adala_details")

-    prompt_tokens: int = Field(alias="_prompt_tokens")
-    completion_tokens: int = Field(alias="_completion_tokens")
-
-    # these can fail to calculate
-    prompt_cost_usd: Optional[float] = Field(alias="_prompt_cost_usd")
-    completion_cost_usd: Optional[float] = Field(alias="_completion_cost_usd")
-    total_cost_usd: Optional[float] = Field(alias="_total_cost_usd")
-
     @model_validator(mode="after")
     def check_error_consistency(self):
         has_error = self.error
diff --git a/tests/test_llm.py b/tests/test_llm.py
index c6479bec..aa7dd481 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -7,6 +7,7 @@

 @pytest.mark.vcr
 def test_llm_sync():
+
     runtime = LiteLLMChatRuntime()

     # test plaintext success
@@ -33,15 +34,7 @@ class Output(BaseModel):
     )

     # note age coerced to string
-    expected_result = {
-        "name": "Carla",
-        "age": "25",
-        "_prompt_tokens": 86,
-        "_completion_tokens": 10,
-        "_prompt_cost_usd": 1.29e-05,
-        "_completion_cost_usd": 6e-06,
-        "_total_cost_usd": 1.89e-05,
-    }
+    expected_result = {"name": "Carla", "age": "25"}
     assert result == expected_result

     # test structured failure
@@ -59,17 +52,13 @@ class Output(BaseModel):
         "_adala_error": True,
         "_adala_message": "AuthenticationError",
         "_adala_details": "litellm.AuthenticationError: AuthenticationError: OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: fake_api_key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}",
-        "_prompt_tokens": 9,
-        "_completion_tokens": 0,
-        "_prompt_cost_usd": 1.35e-06,
-        "_completion_cost_usd": 0.0,
-        "_total_cost_usd": 1.35e-06,
     }
     assert result == expected_result


 @pytest.mark.vcr
 def test_llm_async():
+
     # test success

     runtime = AsyncLiteLLMChatRuntime()
@@ -90,20 +79,9 @@ class Output(BaseModel):
     )

     # note age coerced to string
-    expected_result = pd.DataFrame.from_records(
-        [
-            {
-                "name": "Carla",
-                "age": "25",
-                "_prompt_tokens": 86,
-                "_completion_tokens": 10,
-                "_prompt_cost_usd": 1.29e-05,
-                "_completion_cost_usd": 6e-06,
-                "_total_cost_usd": 1.89e-05,
-            }
-        ]
-    )
-    pd.testing.assert_frame_equal(result, expected_result)
+    expected_result = pd.DataFrame.from_records([{"name": "Carla", "age": "25"}])
+    # need 2 all() for row and column axis
+    assert (result == expected_result).all().all()

     # test failure

@@ -125,14 +103,10 @@
                 "_adala_error": True,
                 "_adala_message": "AuthenticationError",
                 "_adala_details": "litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: fake_api_key. You can find your API key at https://platform.openai.com/account/api-keys.",
-                "_prompt_tokens": 9,
-                "_completion_tokens": 0,
-                "_prompt_cost_usd": 1.35e-06,
-                "_completion_cost_usd": 0.0,
-                "_total_cost_usd": 1.35e-06,
             }
         ]
     )
-    pd.testing.assert_frame_equal(result, expected_result)
+    # need 2 all() for row and column axis
+    assert (result == expected_result).all().all()

 # TODO test batch with successes and failures, figure out how to inject a particular error into LiteLLM
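
Not part of the patch: a minimal standalone sketch of how the simplified retry policy introduced above behaves. It uses only tenacity and pydantic (both already imported by _litellm.py); the Output model and the exceptions raised here are illustrative assumptions, not adala code.

    from pydantic import BaseModel, ValidationError
    from tenacity import Retrying, RetryError, retry_if_not_exception_type, stop_after_attempt

    class Output(BaseModel):
        age: int

    # Mirrors the patch: skip retries only for pydantic ValidationError, give up after 3 attempts.
    retries = Retrying(
        retry=retry_if_not_exception_type((ValidationError,)),
        stop=stop_after_attempt(3),
    )

    # A ValidationError is re-raised immediately: only one attempt is made.
    attempts = 0
    try:
        for attempt in retries:
            with attempt:
                attempts += 1
                Output(age="not a number")  # raises ValidationError
    except ValidationError:
        pass
    assert attempts == 1

    # Any other exception is retried until stop_after_attempt(3) is exhausted,
    # after which tenacity raises RetryError wrapping the last failure.
    attempts = 0
    try:
        for attempt in retries:
            with attempt:
                attempts += 1
                raise RuntimeError("transient failure")
    except RetryError:
        pass
    assert attempts == 3

The same Retrying/AsyncRetrying object is handed to instructor via max_retries, so this is the retry behavior the runtimes get for every completion call.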