-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add support for jinja based template rendering of the dataset #438
Changes from 1 commit
902af4f
d65f759
1a4ef2e
0e9ad3f
8d3e77f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
dataprocessor: | ||
type: default | ||
datasets: | ||
- name: apply_custom_data_jinja_template | ||
data_paths: | ||
- "FILE_PATH" | ||
data_handlers: | ||
- name: apply_custom_data_formatting_jinja_template | ||
arguments: | ||
remove_columns: all | ||
batched: false | ||
fn_kwargs: | ||
dataset_text_field: "dataset_text_field" | ||
template: "dataset_template" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
# https://spdx.dev/learn/handling-license-info/ | ||
|
||
# Third Party | ||
from jinja2.exceptions import TemplateSyntaxError | ||
from transformers import AutoTokenizer | ||
import datasets | ||
import pytest | ||
|
@@ -25,6 +26,7 @@ | |
|
||
# Local | ||
from tuning.data.data_handlers import ( | ||
apply_custom_data_formatting_jinja_template, | ||
apply_custom_data_formatting_template, | ||
combine_sequence, | ||
) | ||
|
@@ -57,6 +59,32 @@ def test_apply_custom_formatting_template(): | |
assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response | ||
|
||
|
||
def test_apply_custom_formatting_jinja_template():
    """Tests that the jinja handler renders the template and appends the EOS token."""
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    formatted_dataset_field = "formatted_data_field"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    json_dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )

    handler_kwargs = {
        "tokenizer": tokenizer,
        "dataset_text_field": formatted_dataset_field,
        "template": template,
    }
    formatted_dataset = json_dataset.map(
        apply_custom_data_formatting_jinja_template,
        fn_kwargs=handler_kwargs,
    )

    # First response from the data file that is read.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaint"
        + tokenizer.eos_token
    )
    first_record = formatted_dataset["train"][0]

    assert formatted_dataset_field in first_record
    assert first_record[formatted_dataset_field] == expected_response
|
||
|
||
def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): | ||
"""Tests that the formatting function will throw error if wrong keys are passed to template""" | ||
json_dataset = datasets.load_dataset( | ||
|
@@ -76,6 +104,25 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): | |
) | ||
|
||
|
||
def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys():
    """Tests that the jinja formatting function will throw error if wrong keys are passed to template"""
    # "not found" is not a column of the dataset, so rendering must fail.
    template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
    formatted_dataset_field = "formatted_data_field"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    json_dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )

    handler_kwargs = {
        "tokenizer": tokenizer,
        "dataset_text_field": formatted_dataset_field,
        "template": template,
    }
    with pytest.raises((KeyError, TemplateSyntaxError)):
        json_dataset.map(
            apply_custom_data_formatting_jinja_template,
            fn_kwargs=handler_kwargs,
        )
|
||
|
||
@pytest.mark.parametrize( | ||
"input_element,output_element,expected_res", | ||
[ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
import re | ||
|
||
# Third Party | ||
from jinja2 import Environment, StrictUndefined | ||
from transformers import AutoTokenizer | ||
|
||
|
||
|
@@ -137,6 +138,65 @@ def replace_text(match_obj): | |
} | ||
|
||
|
||
def apply_custom_data_formatting_jinja_template(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    dataset_text_field: str,
    template: str,
    **kwargs,
):
    """Function to format a dataset element with a jinja template.
       Expects to be run as a HF Map API function.
    Args:
        element: the HF Dataset element loaded from a JSON or DatasetDict object.
            Its fields are exposed to the template both as top-level variables
            and through the ``element`` mapping.
        tokenizer: Tokenizer whose ``eos_token`` is appended to the rendered text.
        dataset_text_field: Name of the output field that holds the rendered text.
        template: Template to format data with. Features of Dataset
            should be referred to by {{key}}.
    Returns:
        Dict with a single entry mapping ``dataset_text_field`` to the
        rendered text followed by the tokenizer EOS token.
    Raises:
        KeyError: if rendering the template fails for any reason, e.g. the
            template refers to a field the element does not have.
    """

    # Placeholders whose names contain spaces are rewritten to dictionary-style
    # access before the template is compiled.
    # Compilation happens outside the try block, so template compilation
    # errors are not converted to KeyError.
    jinja_env = Environment(undefined=StrictUndefined)
    compiled_template = jinja_env.from_string(transform_placeholders(template))

    try:
        formatted_text = compiled_template.render(element=element, **element)
    except Exception as e:
        raise KeyError(f"Dataset does not contain field in template. {e}") from e

    return {dataset_text_field: formatted_text + tokenizer.eos_token}
|
||
|
||
# Matches a {{...}} placeholder and captures everything between the braces.
_PLACEHOLDER_PATTERN = re.compile(r"\{\{([^}]+)\}\}")


def transform_placeholders(template: str) -> str:
    """
    Function to detect all placeholders of the form {{...}}.
    - If the inside has a space (e.g. {{Tweet text}}),
      rewrite to {{ element['Tweet text'] }}.
    - If it doesn't have a space (e.g. {{text_label}}), leave it as is.
    - If it is already using dictionary-style access ({{ element['xyz'] }}), do nothing.

    Backslashes and single quotes inside a rewritten name are escaped so that
    the generated single-quoted string literal stays valid
    (e.g. {{user's name}} becomes {{ element['user\\'s name'] }}).

    Args:
        template: Template string that may contain {{...}} placeholders.
    Returns:
        The template with space-containing placeholders rewritten to
        dictionary-style access; all other text is returned unchanged.
    """

    def rewrite_placeholder(match_obj):
        name = match_obj.group(1).strip()

        # Already dictionary-style, or a plain identifier: keep verbatim.
        if name.startswith("element[") or " " not in name:
            return match_obj.group(0)

        # Escape the backslash first so the escapes added for quotes
        # are not themselves doubled.
        escaped = name.replace("\\", "\\\\").replace("'", "\\'")
        return f"{{{{ element['{escaped}'] }}}}"

    # Single pass: each placeholder is rewritten exactly once.
    return _PLACEHOLDER_PATTERN.sub(rewrite_placeholder, template)
|
||
|
||
def apply_tokenizer_chat_template( | ||
element: Dict[str, str], | ||
tokenizer: AutoTokenizer, | ||
|
@@ -157,5 +217,6 @@ def apply_tokenizer_chat_template( | |
"tokenize_and_apply_input_masking": tokenize_and_apply_input_masking, | ||
"apply_dataset_formatting": apply_dataset_formatting, | ||
"apply_custom_data_formatting_template": apply_custom_data_formatting_template, | ||
"apply_custom_data_formatting_jinja_template": apply_custom_data_formatting_jinja_template, | ||
"apply_tokenizer_chat_template": apply_tokenizer_chat_template, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you remove this import, as it's not used?