Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Use a sandboxed environment in the handler that renders Jinja templates #456

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions tests/data/test_data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,23 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
)


def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys():
@pytest.mark.parametrize(
"template",
[
"### Input: {{ not found }} \n\n ### Response: {{ text_label }}",
"### Input: }} Tweet text {{ \n\n ### Response: {{ text_label }}",
"### Input: {{ Tweet text }} \n\n ### Response: {{ ''.__class__ }}",
"### Input: {{ Tweet text }} \n\n ### Response: {{ undefined_variable.split() }}",
],
)
def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(template):
"""Tests that the jinja formatting function will throw error if wrong keys are passed to template"""
json_dataset = datasets.load_dataset(
"json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
)
template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
formatted_dataset_field = "formatted_data_field"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
with pytest.raises(KeyError):
with pytest.raises((KeyError, ValueError)):
json_dataset.map(
apply_custom_data_formatting_jinja_template,
fn_kwargs={
Expand Down
25 changes: 21 additions & 4 deletions tuning/data/data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import re

# Third Party
from jinja2 import Environment, StrictUndefined
from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError
from jinja2.sandbox import SandboxedEnvironment, SecurityError
from transformers import AutoTokenizer

# Local
Expand Down Expand Up @@ -165,13 +166,29 @@ def apply_custom_data_formatting_jinja_template(

template += tokenizer.eos_token
template = process_jinja_placeholders(template)
env = Environment(undefined=StrictUndefined)
jinja_template = env.from_string(template)
env = SandboxedEnvironment(undefined=StrictUndefined)

try:
jinja_template = env.from_string(template)
except TemplateSyntaxError as e:
raise ValueError(
f"Invalid template syntax in provided Jinja template. {e.message}"
) from e

try:
rendered_text = jinja_template.render(element=element, **element)
except UndefinedError as e:
raise KeyError(
f"The dataset does not contain the key used in the provided Jinja template. {e.message}"
) from e
except SecurityError as e:
raise ValueError(
f"Unsafe operation detected in the provided Jinja template. {e.message}"
) from e
except Exception as e:
raise KeyError(f"Dataset does not contain field in template. {e}") from e
raise ValueError(
f"Error occurred while rendering the provided Jinja template. {e.message}"
) from e

return {dataset_text_field: rendered_text}

Expand Down