Skip to content

Commit

Permalink
feat: Add support for jinja based template rendering of the dataset (#…
Browse files Browse the repository at this point in the history
…438)

Signed-off-by: Abhishek <[email protected]>
  • Loading branch information
Abhishek-TAMU authored Feb 5, 2025
1 parent a6e4f9d commit 5c03aa8
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/advanced-data-preprocessing.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ This library currently supports the following [preexisting data handlers](https:
Formats a dataset by appending an EOS token to a specified field.
- `apply_custom_data_formatting_template`:
Applies a custom template (e.g., Alpaca style) to format dataset elements.
- `apply_custom_data_formatting_jinja_template`:
Applies a custom Jinja template (e.g., Alpaca style) to format dataset elements.
- `apply_tokenizer_chat_template`:
Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.

Expand Down
3 changes: 3 additions & 0 deletions tests/artifacts/predefined_data_configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
# Data config exercising the placeholder-based custom template handler.
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
)
# Data config exercising the Jinja-template custom formatting handler.
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "apply_custom_jinja_template.yaml"
)
# Data config for datasets that are already tokenized (input_ids/labels).
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Data config exercising the Jinja-template data handler.
# The quoted values below are placeholders that the tests replace at runtime.
dataprocessor:
  type: default
datasets:
  - name: apply_custom_data_jinja_template
    data_paths:
      - "FILE_PATH"  # replaced with the actual dataset path by the test
    data_handlers:
      - name: apply_custom_data_formatting_jinja_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            dataset_text_field: "dataset_text_field"  # replaced by the test
            template: "dataset_template"  # replaced by the test
46 changes: 46 additions & 0 deletions tests/data/test_data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

# Local
from tuning.data.data_handlers import (
apply_custom_data_formatting_jinja_template,
apply_custom_data_formatting_template,
combine_sequence,
)
Expand Down Expand Up @@ -57,6 +58,32 @@ def test_apply_custom_formatting_template():
assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response


def test_apply_custom_formatting_jinja_template():
    """Render a Jinja template over each row and append the tokenizer EOS token."""
    dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    jinja_template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    output_field = "formatted_data_field"

    result = dataset.map(
        apply_custom_data_formatting_jinja_template,
        fn_kwargs={
            "tokenizer": tokenizer,
            "dataset_text_field": output_field,
            "template": jinja_template,
        },
    )

    # Expected rendering of the first record in the data file.
    expected = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaint"
        + tokenizer.eos_token
    )

    first_row = result["train"][0]
    assert output_field in first_row
    assert first_row[output_field] == expected


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
"""Tests that the formatting function will throw error if wrong keys are passed to template"""
json_dataset = datasets.load_dataset(
Expand All @@ -76,6 +103,25 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
)


def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys():
    """The Jinja handler must raise KeyError when the template references a
    field that the dataset does not contain."""
    dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    bad_template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
    output_field = "formatted_data_field"

    with pytest.raises(KeyError):
        dataset.map(
            apply_custom_data_formatting_jinja_template,
            fn_kwargs={
                "tokenizer": tokenizer,
                "dataset_text_field": output_field,
                "template": bad_template,
            },
        )


@pytest.mark.parametrize(
"input_element,output_element,expected_res",
[
Expand Down
15 changes: 13 additions & 2 deletions tests/data/test_data_preprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

# First Party
from tests.artifacts.predefined_data_configs import (
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
Expand Down Expand Up @@ -693,6 +694,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
Expand Down Expand Up @@ -731,7 +736,10 @@ def test_process_dataconfig_file(data_config_path, data_path):

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
if datasets_name in (
"apply_custom_data_template",
"apply_custom_data_jinja_template",
):
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
Expand All @@ -753,7 +761,10 @@ def test_process_dataconfig_file(data_config_path, data_path):
assert set(train_set.column_names) == column_names
elif datasets_name == "pretokenized_dataset":
assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
elif datasets_name == "apply_custom_data_template":
elif datasets_name in (
"apply_custom_data_template",
"apply_custom_data_jinja_template",
):
assert formatted_dataset_field in set(train_set.column_names)


Expand Down
40 changes: 40 additions & 0 deletions tuning/data/data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import re

# Third Party
from jinja2 import Environment, StrictUndefined
from transformers import AutoTokenizer

# Local
from tuning.utils.config_utils import process_jinja_placeholders


### Utils for custom masking / manipulating input / output strs, etc
def combine_sequence(input_element: str, output_element: str, eos_token: str = ""):
Expand Down Expand Up @@ -108,6 +112,8 @@ def apply_custom_data_formatting_template(
Expects to be run as a HF Map API function.
Args:
element: the HF Dataset element loaded from a JSON or DatasetDict object.
tokenizer: Tokenizer to be used for the EOS token, which will be appended
when formatting the data into a single sequence. Defaults to empty.
template: Template to format data with. Features of Dataset
should be referred to by {{key}}
formatted_dataset_field: Dataset_text_field
Expand Down Expand Up @@ -137,6 +143,39 @@ def replace_text(match_obj):
}


def apply_custom_data_formatting_jinja_template(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    dataset_text_field: str,
    template: str,
    **kwargs,
):
    """Format a dataset element by rendering a Jinja template over its fields.

    Expects to be run as a HF Map API function.

    Args:
        element: the HF Dataset element loaded from a JSON or DatasetDict object.
        tokenizer: Tokenizer whose EOS token is appended to the rendered text.
            If the tokenizer defines no EOS token, nothing is appended.
        dataset_text_field: Name of the output column that will hold the
            rendered text.
        template: Jinja template to format data with. Features of the dataset
            should be referred to by {{key}}.

    Returns:
        dict mapping ``dataset_text_field`` to the rendered string.

    Raises:
        KeyError: if the template references a field the element does not
            contain, or the template otherwise fails to render.
    """
    # Guard against tokenizers with no EOS token (eos_token is None),
    # which would raise TypeError on string concatenation.
    template += tokenizer.eos_token or ""
    # Rewrite {{key with spaces}} placeholders into dictionary-style access
    # that Jinja can actually compile.
    template = process_jinja_placeholders(template)
    # StrictUndefined makes references to missing fields raise at render time
    # instead of silently producing empty strings.
    # NOTE(review): templates come from user-supplied config; plain Environment
    # executes arbitrary template expressions — consider
    # jinja2.sandbox.SandboxedEnvironment if templates can be untrusted.
    env = Environment(undefined=StrictUndefined)
    jinja_template = env.from_string(template)

    try:
        rendered_text = jinja_template.render(element=element, **element)
    except Exception as e:
        # Normalize all render failures to KeyError so callers have a single
        # exception type for "template does not match the dataset schema".
        raise KeyError(f"Dataset does not contain field in template. {e}") from e

    return {dataset_text_field: rendered_text}


def apply_tokenizer_chat_template(
element: Dict[str, str],
tokenizer: AutoTokenizer,
Expand All @@ -157,5 +196,6 @@ def apply_tokenizer_chat_template(
"tokenize_and_apply_input_masking": tokenize_and_apply_input_masking,
"apply_dataset_formatting": apply_dataset_formatting,
"apply_custom_data_formatting_template": apply_custom_data_formatting_template,
"apply_custom_data_formatting_jinja_template": apply_custom_data_formatting_jinja_template,
"apply_tokenizer_chat_template": apply_tokenizer_chat_template,
}
32 changes: 32 additions & 0 deletions tuning/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import os
import pickle
import re

# Third Party
from peft import LoraConfig, PromptTuningConfig
Expand Down Expand Up @@ -135,3 +136,34 @@ def txt_to_obj(txt):
except UnicodeDecodeError:
# Otherwise the bytes are a pickled python dictionary
return pickle.loads(message_bytes)


def process_jinja_placeholders(template: str) -> str:
    """Rewrite space-containing ``{{...}}`` placeholders into valid Jinja.

    Detects all placeholders of the form ``{{...}}``:
      - If the inside has a space (e.g. ``{{Tweet text}}``), rewrite it to
        ``{{ element['Tweet text'] }}`` so Jinja resolves it as a dictionary
        lookup — a bare name with a space is not valid Jinja syntax.
      - If it has no space (e.g. ``{{text_label}}``), leave it as is.
      - If it already uses dictionary-style access (``{{ element['xyz'] }}``),
        do nothing.

    Args:
        template: Jinja template string possibly containing such placeholders.

    Returns:
        The template with space-containing placeholders rewritten.
    """
    pattern = r"\{\{([^}]+)\}\}"

    for match in re.findall(pattern, template):
        trimmed = match.strip()

        # Already dictionary-style access -- leave untouched.
        if trimmed.startswith("element["):
            continue

        # A bare name containing a space is invalid Jinja; rewrite it to a
        # dictionary lookup on the `element` mapping.
        if " " in trimmed:
            # Escape backslashes and single quotes so keys containing them
            # still form a valid Jinja string literal.
            escaped = trimmed.replace("\\", "\\\\").replace("'", "\\'")
            original_placeholder = f"{{{{{match}}}}}"
            new_placeholder = f"{{{{ element['{escaped}'] }}}}"
            template = template.replace(original_placeholder, new_placeholder)

    return template

0 comments on commit 5c03aa8

Please sign in to comment.