Skip to content

Commit

Permalink
fix: DIA-1534: prompts containing {} throw an error (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-bernstein authored Oct 24, 2024
1 parent d6d39e5 commit b0f1e1d
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 244 deletions.
4 changes: 2 additions & 2 deletions adala/runtimes/_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def record_to_record(
)

messages = get_messages(
input_template.format(**record, **extra_fields),
partial_str_format(input_template, **record, **extra_fields),
instructions_template,
instructions_first,
)
Expand Down Expand Up @@ -404,7 +404,7 @@ async def batch_to_batch(
extra_fields = extra_fields or {}
user_prompts = batch.apply(
# TODO: remove "extra_fields" to avoid name collisions
lambda row: input_template.format(**row, **extra_fields),
lambda row: partial_str_format(input_template, **row, **extra_fields),
axis=1,
).tolist()

Expand Down
7 changes: 4 additions & 3 deletions adala/skills/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ def improve(
)
examples.append(
f"### Example #{i}\n\n"
f"{self.input_template.format(**row)}\n\n"
f"{self.output_template.format(**row)}\n\n"
f"{partial_str_format(self.input_template, **row)}\n\n"
f"{partial_str_format(self.output_template, **row)}\n\n"
f'User feedback: {row[f"{train_skill_output}__fb"]}\n\n'
)

Expand Down Expand Up @@ -625,7 +625,8 @@ def _iter_over_chunks(self, input: InternalDataFrame, chunk_size: Optional[int]
agg_chunk = chunk\
.reset_index()\
.apply(
lambda row: self.input_template.format(
lambda row: partial_str_format(
self.input_template,
**row, **extra_fields, i=int(row.name) + 1
),
axis=1,
Expand Down
7 changes: 4 additions & 3 deletions adala/skills/collection/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional
from adala.skills._base import TransformSkill
from adala.utils.internal_data import InternalDataFrame
from adala.utils.parse import partial_str_format
from adala.runtimes.base import Runtime
from adala.memories import Memory
from adala.memories.vectordb import VectorDBMemory
Expand Down Expand Up @@ -66,13 +67,13 @@ def apply(
If instructions are given, the output field contains the generated output.
"""
input_strings = input.apply(
lambda r: self.input_template.format(**r), axis=1
lambda r: partial_str_format(self.input_template, **r), axis=1
).tolist()
rag_input_data = self.memory.retrieve_many(
input_strings, num_results=self.num_results
)
rag_input_strings = [
"\n\n".join(self.rag_input_template.format(**i) for i in rag_items)
"\n\n".join(partial_str_format(self.rag_input_template, **i) for i in rag_items)
for rag_items in rag_input_data
]
output_fields = self.get_output_fields()
Expand Down Expand Up @@ -122,7 +123,7 @@ def improve(
indices = feedback.match.index
inputs = predictions.loc[indices]
input_strings = inputs.apply(
lambda r: self.input_template.format(**r), axis=1
lambda r: partial_str_format(self.input_template, **r), axis=1
).tolist()
fb = feedback.feedback.loc[indices].rename(columns=lambda c: f"{c}__fb")
inputs = inputs.join(fb)
Expand Down
8 changes: 8 additions & 0 deletions adala/utils/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def get_value(self, key, args, kwds):
return "{" + key + "}"
else:
Formatter.get_value(key, args, kwds)

def format_field(self, value, format_spec):
try:
return super().format_field(value, format_spec)
except ValueError:
# HACK: the value was an unfilled variable or not a variable at all, so the format spec should be considered part of the variable name
if value.startswith("{") and value.endswith("}"):
return value[:-1] + ":" + format_spec + "}"


PartialStringFormat = PartialStringFormatter()
Expand Down
Loading

0 comments on commit b0f1e1d

Please sign in to comment.