From 902af4f179629699fee4e44a055ac2b6006bec92 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Wed, 15 Jan 2025 18:09:04 -0500 Subject: [PATCH 1/3] Add support for jinja template Signed-off-by: Abhishek --- .../predefined_data_configs/__init__.py | 3 + .../apply_custom_jinja_template.yaml | 14 +++++ tests/data/test_data_handlers.py | 47 ++++++++++++++ tests/data/test_data_preprocessing_utils.py | 15 ++++- tuning/data/data_handlers.py | 61 +++++++++++++++++++ 5 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py index c199406c6..8ecc07b8f 100644 --- a/tests/artifacts/predefined_data_configs/__init__.py +++ b/tests/artifacts/predefined_data_configs/__init__.py @@ -22,6 +22,9 @@ DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join( PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml" ) +DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML = os.path.join( + PREDEFINED_DATA_CONFIGS, "apply_custom_jinja_template.yaml" +) DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join( PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml" ) diff --git a/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml new file mode 100644 index 000000000..474068fe8 --- /dev/null +++ b/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml @@ -0,0 +1,14 @@ +dataprocessor: + type: default +datasets: + - name: apply_custom_data_jinja_template + data_paths: + - "FILE_PATH" + data_handlers: + - name: apply_custom_data_formatting_jinja_template + arguments: + remove_columns: all + batched: false + fn_kwargs: + dataset_text_field: "dataset_text_field" + template: "dataset_template" \ No newline at end of file diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index d2a390fe9..e7228ef5f 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -16,6 +16,7 @@ # https://spdx.dev/learn/handling-license-info/ # Third Party +from jinja2.exceptions import TemplateSyntaxError from transformers import AutoTokenizer import datasets import pytest @@ -25,6 +26,7 @@ # Local from tuning.data.data_handlers import ( + apply_custom_data_formatting_jinja_template, apply_custom_data_formatting_template, combine_sequence, ) @@ -57,6 +59,32 @@ def test_apply_custom_formatting_template(): assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response +def test_apply_custom_formatting_jinja_template(): + json_dataset = datasets.load_dataset( + "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL + ) + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + formatted_dataset_field = "formatted_data_field" + formatted_dataset = json_dataset.map( + apply_custom_data_formatting_jinja_template, + fn_kwargs={ + "tokenizer": tokenizer, + "dataset_text_field": formatted_dataset_field, + "template": template, + }, + ) + # First response from the data file that is read. + expected_response = ( + "### Input: @HMRCcustomers No this is my first job" + + " \n\n ### Response: no complaint" + + tokenizer.eos_token + ) + + assert formatted_dataset_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response + + def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): """Tests that the formatting function will throw error if wrong keys are passed to template""" json_dataset = datasets.load_dataset( @@ -76,6 +104,25 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): ) +def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(): + """Tests that the jinja formatting function will throw error if wrong keys are passed to template""" + json_dataset = datasets.load_dataset( + "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL + ) + template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" + formatted_dataset_field = "formatted_data_field" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + with pytest.raises((KeyError, TemplateSyntaxError)): + json_dataset.map( + apply_custom_data_formatting_jinja_template, + fn_kwargs={ + "tokenizer": tokenizer, + "dataset_text_field": formatted_dataset_field, + "template": template, + }, + ) + + @pytest.mark.parametrize( "input_element,output_element,expected_res", [ diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 8de5dfc36..95153d1a4 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -29,6 +29,7 @@ # First Party from tests.artifacts.predefined_data_configs import ( + DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, @@ -693,6 +694,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET), (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW), + (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON), + (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), + (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET), + (DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW), (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON), (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL), (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET), @@ -731,7 +736,10 @@ def test_process_dataconfig_file(data_config_path, data_path): # Modify dataset_text_field and template according to dataset formatted_dataset_field = "formatted_data_field" - if datasets_name == "apply_custom_data_template": + if datasets_name in ( + "apply_custom_data_template", + "apply_custom_data_jinja_template", + ): template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { "dataset_text_field": formatted_dataset_field, @@ -753,7 +761,10 @@ def test_process_dataconfig_file(data_config_path, data_path): assert set(train_set.column_names) == column_names elif datasets_name == "pretokenized_dataset": assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) - elif datasets_name == "apply_custom_data_template": + elif datasets_name in ( + "apply_custom_data_template", + "apply_custom_data_jinja_template", + ): assert formatted_dataset_field in set(train_set.column_names) diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index 5b80dc4bb..4654edc08 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -19,6 +19,7 @@ import re # Third Party +from jinja2 import Environment, StrictUndefined from transformers import AutoTokenizer @@ -137,6 +138,65 @@ def replace_text(match_obj): } +def apply_custom_data_formatting_jinja_template( + element: Dict[str, str], + tokenizer: AutoTokenizer, + dataset_text_field: str, + template: str, + **kwargs, +): + """Function to format datasets with jinja templates. + Expects to be run as a HF Map API function. + Args: + element: the HF Dataset element loaded from a JSON or DatasetDict object. + dataset_text_field: formatted_dataset_field. + template: Template to format data with. Features of Dataset + should be referred to by {{key}}. + Returns: + Formatted HF Dataset + """ + + template = transform_placeholders(template) + env = Environment(undefined=StrictUndefined) + jinja_template = env.from_string(template) + + try: + rendered_text = jinja_template.render(element=element, **element) + except Exception as e: + raise KeyError(f"Dataset does not contain field in template. {e}") from e + + rendered_text += tokenizer.eos_token + + return {dataset_text_field: rendered_text} + + +def transform_placeholders(template: str) -> str: + """ + Function to detect all placeholders of the form {{...}}. + - If the inside has a space (e.g. {{Tweet text}}), + rewrite to {{ element['Tweet text'] }}. + - If it doesn't have a space (e.g. {{text_label}}), leave it as is. + - If it is already using dictionary-style access ({{ element['xyz'] }}), do nothing. + """ + + pattern = r"\{\{([^}]+)\}\}" + matches = re.findall(pattern, template) + + for match in matches: + original_placeholder = f"{{{{{match}}}}}" + trimmed = match.strip() + + if trimmed.startswith("element["): + continue + + # If there's a space in the placeholder name, rewrite it to dictionary-style + if " " in trimmed: + new_placeholder = f"{{{{ element['{trimmed}'] }}}}" + template = template.replace(original_placeholder, new_placeholder) + + return template + + def apply_tokenizer_chat_template( element: Dict[str, str], tokenizer: AutoTokenizer, @@ -157,5 +217,6 @@ def apply_tokenizer_chat_template( "tokenize_and_apply_input_masking": tokenize_and_apply_input_masking, "apply_dataset_formatting": apply_dataset_formatting, "apply_custom_data_formatting_template": apply_custom_data_formatting_template, + "apply_custom_data_formatting_jinja_template": apply_custom_data_formatting_jinja_template, "apply_tokenizer_chat_template": apply_tokenizer_chat_template, } From 0e9ad3fc43a52bf3e3ae46abe2604fed6142f8f6 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Mon, 3 Feb 2025 12:28:13 -0500 Subject: [PATCH 2/3] Suggested PR changes Signed-off-by: Abhishek --- docs/advanced-data-preprocessing.md | 2 ++ tests/data/test_data_handlers.py | 2 +- tuning/data/data_handlers.py | 33 ++++------------------------- tuning/utils/config_utils.py | 32 ++++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 30 deletions(-) diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md index dd22a99d9..2e53dc61e 100644 --- a/docs/advanced-data-preprocessing.md +++ b/docs/advanced-data-preprocessing.md @@ -210,6 +210,8 @@ This library currently supports the following [preexisting data handlers](https: Formats a dataset by appending an EOS token to a specified field. - `apply_custom_data_formatting_template`: Applies a custom template (e.g., Alpaca style) to format dataset elements. + - `apply_custom_data_formatting_jinja_template`: + Applies a custom jinja template (e.g., Alpaca style) to format dataset elements. - `apply_tokenizer_chat_template`: Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates. diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index e7228ef5f..c7e443433 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -112,7 +112,7 @@ def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(): template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" formatted_dataset_field = "formatted_data_field" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - with pytest.raises((KeyError, TemplateSyntaxError)): + with pytest.raises(KeyError): json_dataset.map( apply_custom_data_formatting_jinja_template, fn_kwargs={ diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index 4654edc08..640ebcbd6 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -22,6 +22,9 @@ from jinja2 import Environment, StrictUndefined from transformers import AutoTokenizer +# Local +from tuning.utils.config_utils import transform_placeholders + ### Utils for custom masking / manipulating input / output strs, etc def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): @@ -156,6 +159,7 @@ def apply_custom_data_formatting_jinja_template( Formatted HF Dataset """ + template += tokenizer.eos_token template = transform_placeholders(template) env = Environment(undefined=StrictUndefined) jinja_template = env.from_string(template) @@ -165,38 +169,9 @@ def apply_custom_data_formatting_jinja_template( except Exception as e: raise KeyError(f"Dataset does not contain field in template. {e}") from e - rendered_text += tokenizer.eos_token - return {dataset_text_field: rendered_text} -def transform_placeholders(template: str) -> str: - """ - Function to detect all placeholders of the form {{...}}. - - If the inside has a space (e.g. {{Tweet text}}), - rewrite to {{ element['Tweet text'] }}. - - If it doesn't have a space (e.g. {{text_label}}), leave it as is. - - If it is already using dictionary-style access ({{ element['xyz'] }}), do nothing. - """ - - pattern = r"\{\{([^}]+)\}\}" - matches = re.findall(pattern, template) - - for match in matches: - original_placeholder = f"{{{{{match}}}}}" - trimmed = match.strip() - - if trimmed.startswith("element["): - continue - - # If there's a space in the placeholder name, rewrite it to dictionary-style - if " " in trimmed: - new_placeholder = f"{{{{ element['{trimmed}'] }}}}" - template = template.replace(original_placeholder, new_placeholder) - - return template - - def apply_tokenizer_chat_template( element: Dict[str, str], tokenizer: AutoTokenizer, diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index b5dede937..8448c437b 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -18,6 +18,7 @@ import json import os import pickle +import re # Third Party from peft import LoraConfig, PromptTuningConfig @@ -135,3 +136,34 @@ def txt_to_obj(txt): except UnicodeDecodeError: # Otherwise the bytes are a pickled python dictionary return pickle.loads(message_bytes) + + +def transform_placeholders(template: str) -> str: + """ + Function to detect all placeholders of the form {{...}}. + - If the inside has a space (e.g. {{Tweet text}}), + rewrite to {{ element['Tweet text'] }}. + - If it doesn't have a space (e.g. {{text_label}}), leave it as is. + - If it is already using dictionary-style access ({{ element['xyz'] }}), do nothing. + + Args: + template: str + Return: template: str + """ + + pattern = r"\{\{([^}]+)\}\}" + matches = re.findall(pattern, template) + + for match in matches: + original_placeholder = f"{{{{{match}}}}}" + trimmed = match.strip() + + if trimmed.startswith("element["): + continue + + # If there's a space in the placeholder name, rewrite it to dictionary-style + if " " in trimmed: + new_placeholder = f"{{{{ element['{trimmed}'] }}}}" + template = template.replace(original_placeholder, new_placeholder) + + return template From 8d3e77fdfe5c015c4f778382c108f325883e42eb Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 4 Feb 2025 16:39:11 -0500 Subject: [PATCH 3/3] PR Changes Signed-off-by: Abhishek --- tests/data/test_data_handlers.py | 1 - tuning/data/data_handlers.py | 8 ++++++-- tuning/utils/config_utils.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index c7e443433..bfe366ef8 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -16,7 +16,6 @@ # https://spdx.dev/learn/handling-license-info/ # Third Party -from jinja2.exceptions import TemplateSyntaxError from transformers import AutoTokenizer import datasets import pytest diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index 640ebcbd6..d993dee31 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -23,7 +23,7 @@ from transformers import AutoTokenizer # Local -from tuning.utils.config_utils import transform_placeholders +from tuning.utils.config_utils import process_jinja_placeholders ### Utils for custom masking / manipulating input / output strs, etc @@ -112,6 +112,8 @@ def apply_custom_data_formatting_template( Expects to be run as a HF Map API function. Args: element: the HF Dataset element loaded from a JSON or DatasetDict object. + tokenizer: Tokenizer to be used for the EOS token, which will be appended + when formatting the data into a single sequence. Defaults to empty. template: Template to format data with. Features of Dataset should be referred to by {{key}} formatted_dataset_field: Dataset_text_field @@ -152,6 +154,8 @@ def apply_custom_data_formatting_jinja_template( Expects to be run as a HF Map API function. Args: element: the HF Dataset element loaded from a JSON or DatasetDict object. + tokenizer: Tokenizer to be used for the EOS token, which will be appended + when formatting the data into a single sequence. Defaults to empty. dataset_text_field: formatted_dataset_field. template: Template to format data with. Features of Dataset should be referred to by {{key}}. @@ -160,7 +164,7 @@ def apply_custom_data_formatting_jinja_template( """ template += tokenizer.eos_token - template = transform_placeholders(template) + template = process_jinja_placeholders(template) env = Environment(undefined=StrictUndefined) jinja_template = env.from_string(template) diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index 8448c437b..061d6017b 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -138,7 +138,7 @@ def txt_to_obj(txt): return pickle.loads(message_bytes) -def transform_placeholders(template: str) -> str: +def process_jinja_placeholders(template: str) -> str: """ Function to detect all placeholders of the form {{...}}. - If the inside has a space (e.g. {{Tweet text}}),