Skip to content

Commit

Permalink
run black formatter everywhere (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-bernstein authored Nov 7, 2024
1 parent 591c0ad commit c6da54b
Show file tree
Hide file tree
Showing 13 changed files with 135 additions and 82 deletions.
2 changes: 1 addition & 1 deletion adala/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ async def arefine_skill(
predictions = await self.skills.aapply(inputs, runtime=runtime)
else:
predictions = inputs

response = await skill.aimprove(
predictions=predictions,
teacher_runtime=teacher_runtime,
Expand Down
6 changes: 3 additions & 3 deletions adala/environments/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ async def initialize(self):
self.kafka_input_topic,
bootstrap_servers=self.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
enable_auto_commit=False, # True by default which causes messages to be missed when using getmany()
enable_auto_commit=False, # True by default which causes messages to be missed when using getmany()
auto_offset_reset="earliest",
group_id=self.kafka_input_topic, # ensuring unique group_id to not mix up offsets between topics
group_id=self.kafka_input_topic, # ensuring unique group_id to not mix up offsets between topics
)
await self.consumer.start()

self.producer = AIOKafkaProducer(
bootstrap_servers=self.kafka_bootstrap_servers,
value_serializer=lambda v: json.dumps(v).encode("utf-8"),
acks='all' # waits for all replicas to respond that they have written the message
acks="all", # waits for all replicas to respond that they have written the message
)
await self.producer.start()

Expand Down
8 changes: 5 additions & 3 deletions adala/skills/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def _iter_over_chunks(
input = InternalDataFrame([input])

extra_fields = self._get_extra_fields()

# if chunk_size is specified, split the input into chunks and process each chunk separately
if self.chunk_size is not None:
chunks = (
Expand All @@ -640,10 +640,12 @@ def _iter_over_chunks(
)
else:
chunks = [input]

# define the row preprocessing function
def row_preprocessing(row):
return partial_str_format(self.input_template, **row, **extra_fields, i=int(row.name) + 1)
return partial_str_format(
self.input_template, **row, **extra_fields, i=int(row.name) + 1
)

total = input.shape[0] // self.chunk_size if self.chunk_size is not None else 1
for chunk in tqdm(chunks, desc="Processing chunks", total=total):
Expand Down
32 changes: 20 additions & 12 deletions adala/skills/collection/entity_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
logger = logging.getLogger(__name__)


def validate_output_format_for_ner_tag(df: InternalDataFrame, input_field_name: str, output_field_name: str):
'''
def validate_output_format_for_ner_tag(
df: InternalDataFrame, input_field_name: str, output_field_name: str
):
"""
The output format for Labels is:
{
"start": start_idx,
Expand All @@ -23,30 +25,30 @@ def validate_output_format_for_ner_tag(df: InternalDataFrame, input_field_name:
"labels": [label1, label2, ...]
}
Sometimes the model cannot populate "text" correctly, but this can be fixed deterministically.
'''
"""
for i, row in df.iterrows():
if row.get("_adala_error"):
logger.warning(f"Error in row {i}: {row['_adala_message']}")
continue
text = row[input_field_name]
entities = row[output_field_name]
for entity in entities:
corrected_text = text[entity["start"]:entity["end"]]
corrected_text = text[entity["start"] : entity["end"]]
if entity.get("text") is None:
entity["text"] = corrected_text
elif entity["text"] != corrected_text:
# this seems to happen rarely if at all in testing, but could lead to invalid predictions
logger.warning(f"text and indices disagree for a predicted entity")
return df


def extract_indices(
df,
input_field_name,
output_field_name,
quote_string_field_name='quote_string',
labels_field_name='label'
):
df,
input_field_name,
output_field_name,
quote_string_field_name="quote_string",
labels_field_name="label",
):
"""
Give the input dataframe with "text" column and "entities" column of the format
```
Expand Down Expand Up @@ -354,7 +356,13 @@ def extract_indices(self, df):
"""
input_field_name = self._get_input_field_name()
output_field_name = self._get_output_field_name()
df = extract_indices(df, input_field_name, output_field_name, self._quote_string_field_name, self._labels_field_name)
df = extract_indices(
df,
input_field_name,
output_field_name,
self._quote_string_field_name,
self._labels_field_name,
)
return df

def apply(
Expand Down
36 changes: 22 additions & 14 deletions adala/skills/collection/label_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

from label_studio_sdk.label_interface import LabelInterface
from label_studio_sdk.label_interface.control_tags import ControlTag
from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import json_schema_to_pydantic
from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import (
json_schema_to_pydantic,
)

from .entity_extraction import extract_indices, validate_output_format_for_ner_tag

Expand All @@ -23,7 +25,9 @@ class LabelStudioSkill(TransformSkill):
input_template: str = "Annotate the input data according to the provided schema."
# TODO: remove output_template, fix calling @model_validator(mode='after') in the base class
output_template: str = "Output: {field_name}"
response_model: Type[BaseModel] = BaseModel # why validate_response_model is called in the base class?
response_model: Type[BaseModel] = (
BaseModel # why validate_response_model is called in the base class?
)
# ------------------------------
label_config: str = "<View></View>"

Expand All @@ -33,21 +37,21 @@ def ner_tags(self) -> Iterator[ControlTag]:
# check if the input config has NER tag (<Labels> + <Text>), and return its `from_name` and `to_name`
interface = LabelInterface(self.label_config)
for tag in interface.controls:
#TODO: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config, but confirm this - maybe move this logic to LSE
if tag.tag == 'Labels':
# TODO: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config, but confirm this - maybe move this logic to LSE
if tag.tag == "Labels":
yield tag
@model_validator(mode='after')

@model_validator(mode="after")
def validate_response_model(self):

interface = LabelInterface(self.label_config)
logger.debug(f'Read labeling config {self.label_config}')
logger.debug(f"Read labeling config {self.label_config}")

self.field_schema = interface.to_json_schema()
logger.debug(f'Converted labeling config to json schema: {self.field_schema}')
logger.debug(f"Converted labeling config to json schema: {self.field_schema}")

return self

def _create_response_model_from_field_schema(self):
pass

Expand All @@ -56,7 +60,7 @@ def apply(
input: InternalDataFrame,
runtime: Runtime,
) -> InternalDataFrame:

with json_schema_to_pydantic(self.field_schema) as ResponseModel:
return runtime.batch_to_batch(
input,
Expand All @@ -81,10 +85,14 @@ async def aapply(
response_model=ResponseModel,
)
for ner_tag in self.ner_tags():
input_field_name = ner_tag.objects[0].value.lstrip('$')
input_field_name = ner_tag.objects[0].value.lstrip("$")
output_field_name = ner_tag.name
quote_string_field_name = 'text'
quote_string_field_name = "text"
df = pd.concat([input, output], axis=1)
output = validate_output_format_for_ner_tag(df, input_field_name, output_field_name)
output = extract_indices(output, input_field_name, output_field_name, quote_string_field_name)
output = validate_output_format_for_ner_tag(
df, input_field_name, output_field_name
)
output = extract_indices(
output, input_field_name, output_field_name, quote_string_field_name
)
return output
43 changes: 29 additions & 14 deletions adala/skills/collection/prompt_improvement.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import json
import logging
from pydantic import BaseModel, field_validator, Field, ConfigDict, model_validator, AfterValidator
from pydantic import (
BaseModel,
field_validator,
Field,
ConfigDict,
model_validator,
AfterValidator,
)
from adala.skills import Skill
from typing import Any, Dict, List, Optional, Union
from typing_extensions import Annotated
Expand All @@ -15,7 +22,9 @@
def validate_used_variables(value: str) -> str:
templates = parse_template(value, include_texts=False)
if not templates:
raise ValueError("At least one input variable must be used in the prompt, formatted with curly braces like this: {input_variable}")
raise ValueError(
"At least one input variable must be used in the prompt, formatted with curly braces like this: {input_variable}"
)
return value


Expand Down Expand Up @@ -52,34 +61,38 @@ class PromptImprovementSkill(AnalysisSkill):

name: str = "prompt_improvement"
instructions: str = "Improve current prompt"
input_template: str = "" # Used to provide a few shot examples of input-output pairs
input_template: str = (
"" # Used to provide a few shot examples of input-output pairs
)
input_prefix: str = "" # Used to provide additional context for the input
input_separator: str = "\n"

response_model = PromptImprovementSkillResponseModel

@model_validator(mode="after")
def validate_prompts(self):

def get_json_template(fields):
json_body = ", ".join([f'"{field}": "{{{field}}}"' for field in fields])
return "{" + json_body + "}"

if isinstance(self.skill_to_improve, LabelStudioSkill):
model_json_schema = self.skill_to_improve.field_schema
else:
model_json_schema = self.skill_to_improve.response_model.model_json_schema()

# TODO: can remove this when only LabelStudioSkill is supported
label_config = getattr(self.skill_to_improve, 'label_config', '<View>Not available</View>')
label_config = getattr(
self.skill_to_improve, "label_config", "<View>Not available</View>"
)

input_variables = self.input_variables
output_variables = list(model_json_schema['properties'].keys())
output_variables = list(model_json_schema["properties"].keys())
input_json_template = get_json_template(input_variables)
output_json_template = get_json_template(output_variables)
self.input_template = f'{input_json_template} --> {output_json_template}'
self.input_prefix = f'''
self.input_template = f"{input_json_template} --> {output_json_template}"

self.input_prefix = f"""
## Current prompt:
```
{self.skill_to_improve.input_template}
Expand All @@ -102,10 +115,12 @@ def get_json_template(fields):
## Input-Output Examples:
'''
"""

# TODO: deprecated, leave self.output_template for compatibility
self.output_template = output_json_template

logger.debug(f'Instructions: {self.instructions}\nInput template: {self.input_template}\nInput prefix: {self.input_prefix}')

logger.debug(
f"Instructions: {self.instructions}\nInput template: {self.input_template}\nInput prefix: {self.input_prefix}"
)
return self
4 changes: 3 additions & 1 deletion adala/skills/collection/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def apply(
input_strings, num_results=self.num_results
)
rag_input_strings = [
"\n\n".join(partial_str_format(self.rag_input_template, **i) for i in rag_items)
"\n\n".join(
partial_str_format(self.rag_input_template, **i) for i in rag_items
)
for rag_items in rag_input_data
]
output_fields = self.get_output_fields()
Expand Down
46 changes: 28 additions & 18 deletions adala/utils/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_value(self, key, args, kwds):
return "{" + key + "}"
else:
Formatter.get_value(key, args, kwds)

def format_field(self, value, format_spec):
try:
return super().format_field(value, format_spec)
Expand All @@ -25,14 +25,16 @@ def format_field(self, value, format_spec):
if value.startswith("{") and value.endswith("}"):
return value[:-1] + ":" + format_spec + "}"

def _vformat(self, format_string, args, kwargs, used_args, recursion_depth,
auto_arg_index=0):
def _vformat(
self, format_string, args, kwargs, used_args, recursion_depth, auto_arg_index=0
):
# copied verbatim from parent class except for the # HACK
if recursion_depth < 0:
raise ValueError('Max string recursion exceeded')
raise ValueError("Max string recursion exceeded")
result = []
for literal_text, field_name, format_spec, conversion in \
self.parse(format_string):
for literal_text, field_name, format_spec, conversion in self.parse(
format_string
):

# output the literal text
if literal_text:
Expand All @@ -44,18 +46,22 @@ def _vformat(self, format_string, args, kwargs, used_args, recursion_depth,
# the formatting

# handle arg indexing when empty field_names are given.
if field_name == '':
if field_name == "":
if auto_arg_index is False:
raise ValueError('cannot switch from manual field '
'specification to automatic field '
'numbering')
raise ValueError(
"cannot switch from manual field "
"specification to automatic field "
"numbering"
)
field_name = str(auto_arg_index)
auto_arg_index += 1
elif field_name.isdigit():
if auto_arg_index:
raise ValueError('cannot switch from manual field '
'specification to automatic field '
'numbering')
raise ValueError(
"cannot switch from manual field "
"specification to automatic field "
"numbering"
)
# disable auto arg incrementing, if it gets
# used later on, then an exception will be raised
auto_arg_index = False
Expand All @@ -70,19 +76,23 @@ def _vformat(self, format_string, args, kwargs, used_args, recursion_depth,

# expand the format spec, if needed
format_spec, auto_arg_index = self._vformat(
format_spec, args, kwargs,
used_args, recursion_depth-1,
auto_arg_index=auto_arg_index)
format_spec,
args,
kwargs,
used_args,
recursion_depth - 1,
auto_arg_index=auto_arg_index,
)

# format the object and append to the result
# HACK: if the format_spec is invalid, assume this field_name was not meant to be a variable, and don't substitute anything
formatted_field = self.format_field(obj, format_spec)
if formatted_field is None:
result.append('{' + ':'.join([field_name, format_spec]) + '}')
result.append("{" + ":".join([field_name, format_spec]) + "}")
else:
result.append(formatted_field)

return ''.join(result), auto_arg_index
return "".join(result), auto_arg_index


PartialStringFormat = PartialStringFormatter()
Expand Down
Loading

0 comments on commit c6da54b

Please sign in to comment.