Skip to content

Commit

Permalink
feat: DIA-1322: Use field_schema in skills as primary output format (#…
Browse files Browse the repository at this point in the history
…185)

Co-authored-by: nik <[email protected]>
Co-authored-by: Matt Bernstein <[email protected]>
  • Loading branch information
3 people authored Aug 19, 2024
1 parent 1f92be8 commit 4626c38
Show file tree
Hide file tree
Showing 13 changed files with 933 additions and 670 deletions.
23 changes: 10 additions & 13 deletions adala/runtimes/_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from adala.utils.parse import (
parse_template,
partial_str_format,
parse_template_to_pydantic_class,
)
from openai import NotFoundError
from pydantic import ConfigDict, field_validator, BaseModel
Expand Down Expand Up @@ -131,11 +130,11 @@ def record_to_record(
record: Dict[str, str],
input_template: str,
instructions_template: str,
output_template: str,
response_model: Type[BaseModel],
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = False,
response_model: Optional[Type[BaseModel]] = None,
) -> Dict[str, str]:
"""
Execute OpenAI request given record and templates for input, instructions and output.
Expand All @@ -144,11 +143,11 @@ def record_to_record(
record: Record to be used for input, instructions and output templates.
input_template: Template for input message.
instructions_template: Template for instructions message.
output_template: Template for output message.
output_template: Template for output message (deprecated, not used).
extra_fields: Extra fields to be used in templates.
field_schema: Field schema to be used for parsing templates.
instructions_first: If True, instructions will be sent before input.
response_model: Pydantic model for response. If set, `output_template` and `field_schema` are ignored.
response_model: Pydantic model for response.
Returns:
Dict[str, str]: Output record.
Expand All @@ -157,9 +156,8 @@ def record_to_record(
extra_fields = extra_fields or {}

if not response_model:
response_model = parse_template_to_pydantic_class(
output_template, provided_field_schema=field_schema
)
raise ValueError('You must explicitly specify the `response_model` in runtime.')

messages = get_messages(
input_template.format(**record, **extra_fields),
instructions_template,
Expand Down Expand Up @@ -274,21 +272,20 @@ async def batch_to_batch(
batch: InternalDataFrame,
input_template: str,
instructions_template: str,
output_template: str,
response_model: Type[BaseModel],
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
response_model: Optional[Type[BaseModel]] = None,
) -> InternalDataFrame:
"""Execute batch of requests with async calls to OpenAI API"""

if not response_model:
response_model = parse_template_to_pydantic_class(
output_template, provided_field_schema=field_schema
)
raise ValueError('You must explicitly specify the `response_model` in runtime.')

extra_fields = extra_fields or {}
user_prompts = batch.apply(
# TODO: remove "extra_fields" to avoid name collisions
lambda row: input_template.format(**row, **extra_fields), axis=1
).tolist()

Expand Down
35 changes: 16 additions & 19 deletions adala/runtimes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def record_to_record(
record: Dict[str, str],
input_template: str,
instructions_template: str,
output_template: str,
response_model: Type[BaseModel],
extra_fields: Optional[Dict[str, Any]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
response_model: Optional[Type[BaseModel]] = None,
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
) -> Dict[str, str]:
"""
Processes a record.
Expand All @@ -60,14 +60,12 @@ def record_to_record(
record (Dict[str, str]): The record to process.
input_template (str): The input template.
instructions_template (str): The instructions template.
output_template (str): The output template.
response_model (Type[BaseModel]): The response model to use for processing records.
extra_fields (Optional[Dict[str, str]]): Extra fields to use in the templates. Defaults to None.
field_schema (Optional[Dict]): Field JSON schema to use in the templates. Defaults to all fields are strings,
i.e. analogous to {"field_n": {"type": "string"}}.
instructions_first (bool): Whether to put instructions first. Defaults to True.
response_model (Optional[Type[BaseModel]]): The response model to use for processing records. Defaults to None.
If set, the response will be generated according to this model and `output_template` and `field_schema` fields will be ignored.
Note, explicitly providing ResponseModel will be the default behavior for all runtimes in the future.
output_template (str): The output template. Deprecated.
Returns:
Dict[str, str]: The processed record.
Expand All @@ -78,11 +76,11 @@ def batch_to_batch(
batch: InternalDataFrame,
input_template: str,
instructions_template: str,
output_template: str,
response_model: Type[BaseModel],
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
response_model: Optional[Type[BaseModel]] = None,
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
) -> InternalDataFrame:
"""
Processes a record.
Expand All @@ -96,11 +94,12 @@ def batch_to_batch(
batch (InternalDataFrame): The batch to process.
input_template (str): The input template.
instructions_template (str): The instructions' template.
output_template (str): The output template.
response_model (Type[BaseModel]): The response model to use for processing records.
extra_fields (Optional[Dict[str, str]]): Extra fields to use in the templates. Defaults to None.
field_schema (Optional[Dict]): Field JSON schema to use in the templates. Defaults to all fields are strings,
i.e. analogous to {"field_n": {"type": "string"}}.
instructions_first (bool): Whether to put instructions first. Defaults to True.
output_template (str): The output template. Deprecated.
Returns:
InternalDataFrame: The processed batch.
"""
Expand Down Expand Up @@ -149,12 +148,12 @@ def record_to_batch(
record: Dict[str, str],
input_template: str,
instructions_template: str,
output_template: str,
response_model: Type[BaseModel],
output_batch_size: int = 1,
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
response_model: Optional[Type[BaseModel]] = None,
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
) -> InternalDataFrame:
"""
Processes a record and return a batch.
Expand All @@ -163,15 +162,13 @@ def record_to_batch(
record (Dict[str, str]): The record to process.
input_template (str): The input template.
instructions_template (str): The instructions template.
output_template (str): The output template.
response_model (Type[BaseModel]): The response model to use for processing records.
output_batch_size (int): The batch size for the output. Defaults to 1.
extra_fields (Optional[Dict[str, str]]): Extra fields to use in the templates. Defaults to None.
field_schema (Optional[Dict]): Field JSON schema to use in the templates. Defaults to all fields are strings,
i.e. analogous to {"field_n": {"type": "string"}}.
instructions_first (bool): Whether to put instructions first. Defaults to True.
response_model (Optional[Type[BaseModel]]): The response model to use for processing records. Defaults to None.
If set, the response will be generated according to this model and `output_template` and `field_schema` fields will be ignored.
Note, explicitly providing ResponseModel will be the default behavior for all runtimes in the future.
output_template (str): The output template. Deprecated.
Returns:
InternalDataFrame: The processed batch.
Expand Down Expand Up @@ -230,11 +227,11 @@ async def batch_to_batch(
batch: InternalDataFrame,
input_template: str,
instructions_template: str,
output_template: str,
response_model: Type[BaseModel],
extra_fields: Optional[Dict[str, str]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
response_model: Optional[Type[BaseModel]] = None,
output_template: Optional[str] = None, # TODO: deprecated in favor of response_model, can be removed
) -> InternalDataFrame:
"""
Processes a record.
Expand All @@ -243,12 +240,12 @@ async def batch_to_batch(
batch (InternalDataFrame): The batch to process.
input_template (str): The input template.
instructions_template (str): The instructions template.
output_template (str): The output template.
response_model (Type[BaseModel]): The response model to use for processing records.
extra_fields (Optional[Dict[str, str]]): Extra fields to use in the templates. Defaults to None.
field_schema (Optional[Dict]): Field JSON schema to use in the templates. Defaults to all fields are strings,
i.e. analogous to {"field_n": {"type": "string"}}.
instructions_first (bool): Whether to put instructions first. Defaults to True.
response_model (Optional[Type[BaseModel]]): The response model to use for processing records. Defaults to None.
output_template (str): The output template. Deprecated.
Returns:
InternalDataFrame: The processed batch.
Expand Down
54 changes: 53 additions & 1 deletion adala/skills/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from pydantic import BaseModel, Field, field_validator
import logging
import string
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import List, Optional, Any, Dict, Tuple, Union, ClassVar, Type
from abc import ABC, abstractmethod
from adala.utils.internal_data import (
Expand All @@ -7,11 +9,14 @@
InternalSeries,
)
from adala.utils.parse import parse_template, partial_str_format
from adala.utils.pydantic_generator import field_schema_to_pydantic_class
from adala.utils.logs import print_dataframe, print_text
from adala.utils.registry import BaseModelInRegistry
from adala.runtimes.base import Runtime, AsyncRuntime
from tqdm import tqdm

logger = logging.getLogger(__name__)


class Skill(BaseModelInRegistry):
"""
Expand Down Expand Up @@ -139,6 +144,11 @@ def get_output_fields(self):
Returns:
List[str]: A list of output fields.
"""
if self.response_model:
return list(self.response_model.__fields__.keys())
if self.field_schema:
return list(self.field_schema.keys())

extra_fields = self._get_extra_fields()
# TODO: input fields are not considered - shall we disallow input fields in output template?
output_fields = parse_template(
Expand All @@ -147,6 +157,42 @@ def get_output_fields(self):
)
return [f["text"] for f in output_fields]

@model_validator(mode='after')
def validate_response_model(self):
    """
    Make sure the skill ends up with a `response_model`.

    Resolution order:
      1. An explicitly provided `response_model` is used as is.
      2. Otherwise, a provided `field_schema` is converted to a pydantic class.
      3. Otherwise, `field_schema` is first derived from `output_template`:
         every template variable becomes a string field whose description
         is the text chunk immediately preceding it, or the field name with
         underscores replaced by spaces when there is no usable text.

    Returns:
        The validated skill instance (`self`), as required by pydantic
        "after" model validators.
    """
    if self.response_model:
        # an explicit response model always takes precedence
        return self

    if not self.field_schema:
        # no schema given: extract it from `output_template`
        logger.info(
            "Parsing output_template to generate the response model: %s",
            self.output_template,
        )
        self.field_schema = {}

        # Strip whitespace and punctuation in ONE pass. Stripping them
        # sequentially leaves punctuation behind whenever it is followed
        # by whitespace, e.g. "Output: " would become "Output:".
        strip_chars = string.punctuation + string.whitespace

        previous_text = ""
        for chunk in parse_template(self.output_template):
            if chunk["type"] == "text":
                previous_text = chunk["text"]
            elif chunk["type"] == "var":
                field_name = chunk["text"]

                # Prefer the text right before the variable as description.
                field_description = previous_text.strip(strip_chars)
                if not field_description:
                    # Nothing meaningful precedes the variable (no text, or
                    # only separators such as "\n"): fall back to the field
                    # name so the description is never empty.
                    field_description = field_name.replace("_", " ").strip(strip_chars)
                # a text chunk describes only the variable that follows it
                previous_text = ""

                # default JSON schema entry: all fields are strings
                self.field_schema[field_name] = {
                    "type": "string",
                    "description": field_description,
                }

    self.response_model = field_schema_to_pydantic_class(
        self.field_schema, self.name, self.description
    )
    return self

@abstractmethod
def apply(self, input, runtime):
"""
Expand Down Expand Up @@ -257,6 +303,12 @@ def improve(
# if fb marked as NaN, skip
if not row[f"{train_skill_output}__fb"]:
continue

# TODO: self.output_template can be missed or incompatible with the field_schema
# we need to redefine how we create examples for learn()
if not self.output_template:
raise ValueError("`output_template` is required for improve() method and must contain "
"the output fields from `field_schema`")
examples.append(
f"### Example #{i}\n\n"
f"{self.input_template.format(**row)}\n\n"
Expand Down
Loading

0 comments on commit 4626c38

Please sign in to comment.