From 1f92be8dcf2809767abe2550cd7975e83e565ef6 Mon Sep 17 00:00:00 2001
From: Hakan Erol <47988814+hakan458@users.noreply.github.com>
Date: Fri, 16 Aug 2024 10:35:32 -0700
Subject: [PATCH] fix: DIA-1334: Catch and report pydantic ValidationError
 (#184)

---
 adala/runtimes/_litellm.py | 19 +++++++++++++++++++
 adala/utils/exceptions.py  |  6 ++++++
 2 files changed, 25 insertions(+)
 create mode 100644 adala/utils/exceptions.py

diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py
index 14738457..d88989af 100644
--- a/adala/runtimes/_litellm.py
+++ b/adala/runtimes/_litellm.py
@@ -7,6 +7,7 @@
 import instructor
 from instructor.exceptions import InstructorRetryException
 import traceback
+from adala.utils.exceptions import ConstrainedGenerationError
 from adala.utils.internal_data import InternalDataFrame
 from adala.utils.logs import print_error
 from adala.utils.parse import (
@@ -17,6 +18,8 @@
 from openai import NotFoundError
 from pydantic import ConfigDict, field_validator, BaseModel
 from rich import print
+from tenacity import AsyncRetrying, Retrying, retry_if_not_exception_type, stop_after_attempt
+from pydantic_core._pydantic_core import ValidationError
 
 from .base import AsyncRuntime, Runtime
 
@@ -163,6 +166,10 @@
             instructions_first,
         )
 
+        retries = Retrying(
+            retry=retry_if_not_exception_type((ValidationError)), stop=stop_after_attempt(3)
+        )
+
         try:
             # returns a pydantic model named Output
             response = instructor_client.chat.completions.create(
@@ -172,6 +179,7 @@
                 max_tokens=self.max_tokens,
                 temperature=self.temperature,
                 seed=self.seed,
+                max_retries=retries,
                 # extra inference params passed to this runtime
                 **self.model_extra,
             )
@@ -185,6 +193,9 @@
             logger.debug(tb)
             return dct
         except Exception as e:
+            # Catch case where the model does not return a properly formatted output
+            if type(e).__name__ == 'ValidationError' and 'Invalid JSON' in str(e):
+                e = ConstrainedGenerationError()
             # the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
             dct = _format_error_dict(e)
             print_error(f"Inference error {dct['_adala_message']}")
@@ -281,6 +292,10 @@
             lambda row: input_template.format(**row, **extra_fields), axis=1
         ).tolist()
 
+        retries = AsyncRetrying(
+            retry=retry_if_not_exception_type((ValidationError)), stop=stop_after_attempt(3)
+        )
+
         tasks = [
             asyncio.ensure_future(
                 async_instructor_client.chat.completions.create(
@@ -294,6 +309,7 @@
                     max_tokens=self.max_tokens,
                     temperature=self.temperature,
                     seed=self.seed,
+                    max_retries=retries,
                     # extra inference params passed to this runtime
                     **self.model_extra,
                 )
@@ -319,6 +335,9 @@
                 df_data.append(dct)
             elif isinstance(response, Exception):
                 e = response
+                # Catch case where the model does not return a properly formatted output
+                if type(e).__name__ == 'ValidationError' and 'Invalid JSON' in str(e):
+                    e = ConstrainedGenerationError()
                 # the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
                 dct = _format_error_dict(e)
                 print_error(f"Inference error {dct['_adala_message']}")
diff --git a/adala/utils/exceptions.py b/adala/utils/exceptions.py
new file mode 100644
index 00000000..ff9d50df
--- /dev/null
+++ b/adala/utils/exceptions.py
@@ -0,0 +1,6 @@
+
+class ConstrainedGenerationError(Exception):
+    def __init__(self):
+        self.message = "The selected provider model could not generate a properly-formatted response"
+
+        super(ConstrainedGenerationError, self).__init__(self.message)