diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index bfe366ef8..4a3736e6d 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -103,15 +103,23 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): ) -def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(): +@pytest.mark.parametrize( + "template", + [ + "### Input: {{ not found }} \n\n ### Response: {{ text_label }}", + "### Input: }} Tweet text {{ \n\n ### Response: {{ text_label }}", + "### Input: {{ Tweet text }} \n\n ### Response: {{ ''.__class__ }}", + "### Input: {{ Tweet text }} \n\n ### Response: {{ undefined_variable.split() }}", + ], +) +def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(template): """Tests that the jinja formatting function will throw error if wrong keys are passed to template""" json_dataset = datasets.load_dataset( "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL ) - template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" formatted_dataset_field = "formatted_data_field" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - with pytest.raises(KeyError): + with pytest.raises((KeyError, ValueError)): json_dataset.map( apply_custom_data_formatting_jinja_template, fn_kwargs={ diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index d993dee31..cdfb263b2 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -19,7 +19,8 @@ import re # Third Party -from jinja2 import Environment, StrictUndefined +from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError +from jinja2.sandbox import SandboxedEnvironment, SecurityError from transformers import AutoTokenizer # Local @@ -165,13 +166,29 @@ def apply_custom_data_formatting_jinja_template( template += tokenizer.eos_token template = process_jinja_placeholders(template) - env = Environment(undefined=StrictUndefined) - jinja_template = env.from_string(template) + env = SandboxedEnvironment(undefined=StrictUndefined) + + try: + jinja_template = env.from_string(template) + except TemplateSyntaxError as e: + raise ValueError( + f"Invalid template syntax in provided Jinja template. {e.message}" + ) from e try: rendered_text = jinja_template.render(element=element, **element) + except UndefinedError as e: + raise KeyError( + f"The dataset does not contain the key used in the provided Jinja template. {e.message}" + ) from e + except SecurityError as e: + raise ValueError( + f"Unsafe operation detected in the provided Jinja template. {e.message}" + ) from e except Exception as e: - raise KeyError(f"Dataset does not contain field in template. {e}") from e + raise ValueError( + f"Error occurred while rendering the provided Jinja template. {e.message}" + ) from e return {dataset_text_field: rendered_text}