Skip to content

Commit

Permalink
fix: DIA-1334: Catch and report pydantic ValidationError (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
hakan458 authored Aug 16, 2024
1 parent 2c8214d commit 1f92be8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
19 changes: 19 additions & 0 deletions adala/runtimes/_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import instructor
from instructor.exceptions import InstructorRetryException
import traceback
from adala.utils.exceptions import ConstrainedGenerationError
from adala.utils.internal_data import InternalDataFrame
from adala.utils.logs import print_error
from adala.utils.parse import (
Expand All @@ -17,6 +18,8 @@
from openai import NotFoundError
from pydantic import ConfigDict, field_validator, BaseModel
from rich import print
from tenacity import AsyncRetrying, Retrying, retry_if_not_exception_type, stop_after_attempt
from pydantic_core._pydantic_core import ValidationError

from .base import AsyncRuntime, Runtime

Expand Down Expand Up @@ -163,6 +166,10 @@ def record_to_record(
instructions_first,
)

retries = Retrying(
retry=retry_if_not_exception_type((ValidationError)), stop=stop_after_attempt(3)
)

try:
# returns a pydantic model named Output
response = instructor_client.chat.completions.create(
Expand All @@ -172,6 +179,7 @@ def record_to_record(
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
max_retries=retries,
# extra inference params passed to this runtime
**self.model_extra,
)
Expand All @@ -185,6 +193,9 @@ def record_to_record(
logger.debug(tb)
return dct
except Exception as e:
# Catch case where the model does not return a properly formatted output
if type(e).__name__ == 'ValidationError' and 'Invalid JSON' in str(e):
e = ConstrainedGenerationError()
# the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
dct = _format_error_dict(e)
print_error(f"Inference error {dct['_adala_message']}")
Expand Down Expand Up @@ -281,6 +292,10 @@ async def batch_to_batch(
lambda row: input_template.format(**row, **extra_fields), axis=1
).tolist()

retries = AsyncRetrying(
retry=retry_if_not_exception_type((ValidationError)), stop=stop_after_attempt(3)
)

tasks = [
asyncio.ensure_future(
async_instructor_client.chat.completions.create(
Expand All @@ -294,6 +309,7 @@ async def batch_to_batch(
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
max_retries=retries,
# extra inference params passed to this runtime
**self.model_extra,
)
Expand All @@ -319,6 +335,9 @@ async def batch_to_batch(
df_data.append(dct)
elif isinstance(response, Exception):
e = response
# Catch case where the model does not return a properly formatted output
if type(e).__name__ == 'ValidationError' and 'Invalid JSON' in str(e):
e = ConstrainedGenerationError()
# the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
dct = _format_error_dict(e)
print_error(f"Inference error {dct['_adala_message']}")
Expand Down
6 changes: 6 additions & 0 deletions adala/utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

class ConstrainedGenerationError(Exception):
    """Raised when the provider model cannot produce a properly-formatted
    (e.g. schema/JSON-constrained) response.

    Carries a fixed human-readable message; callers format it into their
    error reporting (see the runtime's `_format_error_dict` usage).
    """

    def __init__(self):
        # Message text is part of the reported error surface — keep stable.
        self.message = "The selected provider model could not generate a properly-formatted response"
        # Modern zero-argument super() (Python 3) replaces the legacy
        # super(ConstrainedGenerationError, self) form.
        super().__init__(self.message)

0 comments on commit 1f92be8

Please sign in to comment.