Skip to content

Commit

Permalink
Run black
Browse files Browse the repository at this point in the history
  • Loading branch information
nik committed Aug 19, 2024
1 parent 29b240b commit 338522f
Show file tree
Hide file tree
Showing 17 changed files with 580 additions and 301 deletions.
36 changes: 26 additions & 10 deletions adala/runtimes/_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from openai import NotFoundError
from pydantic import ConfigDict, field_validator, BaseModel
from rich import print
from tenacity import AsyncRetrying, Retrying, retry_if_not_exception_type, stop_after_attempt
from tenacity import (
AsyncRetrying,
Retrying,
retry_if_not_exception_type,
stop_after_attempt,
)
from pydantic_core._pydantic_core import ValidationError

from .base import AsyncRuntime, Runtime
Expand Down Expand Up @@ -131,7 +136,9 @@ def record_to_record(
input_template: str,
instructions_template: str,
response_model: Type[BaseModel],
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
output_template: Optional[
str
] = None, # TODO: deprecated in favor of response_model, can be removed
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = False,
Expand All @@ -156,7 +163,9 @@ def record_to_record(
extra_fields = extra_fields or {}

if not response_model:
raise ValueError('You must explicitly specify the `response_model` in runtime.')
raise ValueError(
"You must explicitly specify the `response_model` in runtime."
)

messages = get_messages(
input_template.format(**record, **extra_fields),
Expand All @@ -165,7 +174,8 @@ def record_to_record(
)

retries = Retrying(
retry=retry_if_not_exception_type((ValidationError)), stop=stop_after_attempt(3)
retry=retry_if_not_exception_type((ValidationError)),
stop=stop_after_attempt(3),
)

try:
Expand All @@ -192,7 +202,7 @@ def record_to_record(
return dct
except Exception as e:
# Catch case where the model does not return a properly formatted output
if type(e).__name__ == 'ValidationError' and 'Invalid JSON' in str(e):
if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
e = ConstrainedGenerationError()
# the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
dct = _format_error_dict(e)
Expand Down Expand Up @@ -273,24 +283,30 @@ async def batch_to_batch(
input_template: str,
instructions_template: str,
response_model: Type[BaseModel],
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
output_template: Optional[
str
] = None, # TODO: deprecated in favor of response_model, can be removed
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
) -> InternalDataFrame:
"""Execute batch of requests with async calls to OpenAI API"""

if not response_model:
raise ValueError('You must explicitly specify the `response_model` in runtime.')
raise ValueError(
"You must explicitly specify the `response_model` in runtime."
)

extra_fields = extra_fields or {}
user_prompts = batch.apply(
# TODO: remove "extra_fields" to avoid name collisions
lambda row: input_template.format(**row, **extra_fields), axis=1
lambda row: input_template.format(**row, **extra_fields),
axis=1,
).tolist()

retries = AsyncRetrying(
retry=retry_if_not_exception_type((ValidationError)), stop=stop_after_attempt(3)
retry=retry_if_not_exception_type((ValidationError)),
stop=stop_after_attempt(3),
)

tasks = [
Expand Down Expand Up @@ -333,7 +349,7 @@ async def batch_to_batch(
elif isinstance(response, Exception):
e = response
# Catch case where the model does not return a properly formatted output
if type(e).__name__ == 'ValidationError' and 'Invalid JSON' in str(e):
if type(e).__name__ == "ValidationError" and "Invalid JSON" in str(e):
e = ConstrainedGenerationError()
# the only other instructor error that would be thrown is IncompleteOutputException due to max_tokens reached
dct = _format_error_dict(e)
Expand Down
18 changes: 13 additions & 5 deletions adala/runtimes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def record_to_record(
extra_fields: Optional[Dict[str, Any]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
output_template: Optional[
str
] = None, # TODO: deprecated in favor of response_model, can be removed
) -> Dict[str, str]:
"""
Processes a record.
Expand Down Expand Up @@ -80,7 +82,9 @@ def batch_to_batch(
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
output_template: Optional[
str
] = None, # TODO: deprecated in favor of response_model, can be removed
) -> InternalDataFrame:
"""
Processes a batch.
Expand Down Expand Up @@ -153,7 +157,9 @@ def record_to_batch(
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
output_template: Optional[
str
] = None, # TODO: deprecated in favor of response_model, can be removed
) -> InternalDataFrame:
"""
Processes a record and returns a batch.
Expand Down Expand Up @@ -182,7 +188,7 @@ def record_to_batch(
extra_fields=extra_fields,
field_schema=field_schema,
instructions_first=instructions_first,
response_model=response_model
response_model=response_model,
)


Expand Down Expand Up @@ -231,7 +237,9 @@ async def batch_to_batch(
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
output_template: Optional[
str
] = None, # TODO: deprecated in favor of response_model, can be removed
) -> InternalDataFrame:
"""
Processes a batch.
Expand Down
42 changes: 28 additions & 14 deletions adala/skills/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import logging
import string
from pydantic import BaseModel, Field, field_validator, model_validator, field_serializer
from pydantic import (
BaseModel,
Field,
field_validator,
model_validator,
field_serializer,
)
from typing import List, Optional, Any, Dict, Tuple, Union, ClassVar, Type
from abc import ABC, abstractmethod
from adala.utils.internal_data import (
Expand Down Expand Up @@ -51,7 +57,7 @@ class Skill(BaseModelInRegistry):
"Can use templating to refer to input fields.",
examples=["Label the input text with the following labels: {labels}"],
# TODO: instructions can be deprecated in favor of using `input_template` to specify the instructions
default=''
default="",
)
input_template: str = Field(
title="Input template",
Expand All @@ -65,7 +71,7 @@ class Skill(BaseModelInRegistry):
"Can use templating to refer to input parameters and perform data transformations",
examples=["Output: {output}", "{predictions}"],
# TODO: output_template can be deprecated in favor of using `response_model` to specify the output
default=''
default="",
)
description: Optional[str] = Field(
default="",
Expand Down Expand Up @@ -161,17 +167,21 @@ def _create_response_model_from_field_schema(self):
assert self.field_schema, "field_schema is required to create a response model"
if self.response_model:
return
self.response_model = field_schema_to_pydantic_class(self.field_schema, self.name, self.description)
self.response_model = field_schema_to_pydantic_class(
self.field_schema, self.name, self.description
)

@model_validator(mode='after')
@model_validator(mode="after")
def validate_response_model(self):
if self.response_model:
# if response_model, we use it right away
return self

if not self.field_schema:
# if field_schema is not provided, extract it from `output_template`
logger.info(f"Parsing output_template to generate the response model: {self.output_template}")
logger.info(
f"Parsing output_template to generate the response model: {self.output_template}"
)
self.field_schema = {}
chunks = parse_template(self.output_template)

Expand All @@ -187,7 +197,9 @@ def validate_response_model(self):
# if description is not provided, use the text preceding the field;
# if there is no preceding text, fall back to the field name with underscores replaced by spaces
field_description = previous_text or field_name.replace("_", " ")
field_description = field_description.strip(string.punctuation).strip()
field_description = field_description.strip(
string.punctuation
).strip()
previous_text = ""

# create default JSON schema entry for the field
Expand All @@ -201,14 +213,14 @@ def validate_response_model(self):

# When serializing the agent, ensure `response_model` is excluded.
# It will be restored from `field_schema` during deserialization.
@field_serializer('response_model')
@field_serializer("response_model")
def serialize_response_model(self, value):
return None

# remove `response_model` from the pickle serialization
def __getstate__(self):
state = super().__getstate__()
state['__dict__']['response_model'] = None
state["__dict__"]["response_model"] = None
return state

def __setstate__(self, state):
Expand Down Expand Up @@ -286,7 +298,7 @@ async def aapply(
field_schema=self.field_schema,
extra_fields=self._get_extra_fields(),
instructions_first=self.instructions_first,
response_model=self.response_model
response_model=self.response_model,
)

def improve(
Expand Down Expand Up @@ -330,8 +342,10 @@ def improve(
# TODO: self.output_template can be missing or incompatible with the field_schema
# we need to redefine how we create examples for learn()
if not self.output_template:
raise ValueError("`output_template` is required for improve() method and must contain "
"the output fields from `field_schema`")
raise ValueError(
"`output_template` is required for improve() method and must contain "
"the output fields from `field_schema`"
)
examples.append(
f"### Example #{i}\n\n"
f"{self.input_template.format(**row)}\n\n"
Expand Down Expand Up @@ -505,7 +519,7 @@ def apply(
field_schema=self.field_schema,
extra_fields=self._get_extra_fields(),
instructions_first=self.instructions_first,
response_model=self.response_model
response_model=self.response_model,
)

def improve(self, **kwargs):
Expand Down Expand Up @@ -574,7 +588,7 @@ def apply(
instructions_template=self.instructions,
extra_fields=extra_fields,
instructions_first=self.instructions_first,
response_model=self.response_model
response_model=self.response_model,
)
outputs.append(InternalSeries(output))
output = InternalDataFrame(outputs)
Expand Down
16 changes: 10 additions & 6 deletions adala/skills/collection/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ def validate_schema(schema: Dict[str, Any]):
"enum": {
"type": "array",
"items": {"type": "string"},
"minItems": 1
"minItems": 1,
},
"description": {"type": "string"}
"description": {"type": "string"},
},
"required": ["type", "enum"],
"additionalProperties": False
"additionalProperties": False,
}
},
"minProperties": 1,
"maxProperties": 1,
"additionalProperties": False
"additionalProperties": False,
}

try:
Expand Down Expand Up @@ -70,7 +70,9 @@ class ClassificationSkill(TransformSkill):
def validate_response_model(self):

if self.response_model:
raise NotImplementedError("Classification skill does not support custom response model yet.")
raise NotImplementedError(
"Classification skill does not support custom response model yet."
)

if self.field_schema:
# in case field_schema is already provided, we don't need to parse output template and validate labels
Expand Down Expand Up @@ -106,5 +108,7 @@ def validate_response_model(self):
"enum": self.labels,
}

self.response_model = field_schema_to_pydantic_class(self.field_schema, self.name, self.description)
self.response_model = field_schema_to_pydantic_class(
self.field_schema, self.name, self.description
)
return self
Loading

0 comments on commit 338522f

Please sign in to comment.