diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index c7e443433..bfe366ef8 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -16,7 +16,6 @@ # https://spdx.dev/learn/handling-license-info/ # Third Party -from jinja2.exceptions import TemplateSyntaxError from transformers import AutoTokenizer import datasets import pytest diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index 640ebcbd6..d993dee31 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -23,7 +23,7 @@ from transformers import AutoTokenizer # Local -from tuning.utils.config_utils import transform_placeholders +from tuning.utils.config_utils import process_jinja_placeholders ### Utils for custom masking / manipulating input / output strs, etc @@ -112,6 +112,8 @@ def apply_custom_data_formatting_template( Expects to be run as a HF Map API function. Args: element: the HF Dataset element loaded from a JSON or DatasetDict object. + tokenizer: Tokenizer to be used for the EOS token, which will be appended + when formatting the data into a single sequence. Defaults to empty. template: Template to format data with. Features of Dataset should be referred to by {{key}} formatted_dataset_field: Dataset_text_field @@ -152,6 +154,8 @@ def apply_custom_data_formatting_jinja_template( Expects to be run as a HF Map API function. Args: element: the HF Dataset element loaded from a JSON or DatasetDict object. + tokenizer: Tokenizer to be used for the EOS token, which will be appended + when formatting the data into a single sequence. Defaults to empty. dataset_text_field: formatted_dataset_field. template: Template to format data with. Features of Dataset should be referred to by {{key}}. @@ -160,7 +164,7 @@ def apply_custom_data_formatting_jinja_template( """ template += tokenizer.eos_token - template = transform_placeholders(template) + template = process_jinja_placeholders(template) env = Environment(undefined=StrictUndefined) jinja_template = env.from_string(template) diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index 8448c437b..061d6017b 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -138,7 +138,7 @@ def txt_to_obj(txt): return pickle.loads(message_bytes) -def transform_placeholders(template: str) -> str: +def process_jinja_placeholders(template: str) -> str: """ Function to detect all placeholders of the form {{...}}. - If the inside has a space (e.g. {{Tweet text}}),