Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for jinja based template rendering of the dataset #438

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/advanced-data-preprocessing.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ This library currently supports the following [preexisting data handlers](https:
Formats a dataset by appending an EOS token to a specified field.
- `apply_custom_data_formatting_template`:
Applies a custom template (e.g., Alpaca style) to format dataset elements.
- `apply_custom_data_formatting_jinja_template`:
    Applies a custom Jinja template (e.g., Alpaca style) to format dataset elements.
- `apply_tokenizer_chat_template`:
Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.

Expand Down
3 changes: 3 additions & 0 deletions tests/artifacts/predefined_data_configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
)
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML = os.path.join(
PREDEFINED_DATA_CONFIGS, "apply_custom_jinja_template.yaml"
)
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join(
PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Data config exercising the jinja-template data handler;
# FILE_PATH / dataset_text_field / dataset_template are placeholders
# that the tests substitute before use.
dataprocessor:
  type: default
datasets:
  - name: apply_custom_data_jinja_template
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_custom_data_formatting_jinja_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            dataset_text_field: "dataset_text_field"
            template: "dataset_template"
47 changes: 47 additions & 0 deletions tests/data/test_data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# https://spdx.dev/learn/handling-license-info/

# Third Party
from jinja2.exceptions import TemplateSyntaxError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove this import as its not used

from transformers import AutoTokenizer
import datasets
import pytest
Expand All @@ -25,6 +26,7 @@

# Local
from tuning.data.data_handlers import (
apply_custom_data_formatting_jinja_template,
apply_custom_data_formatting_template,
combine_sequence,
)
Expand Down Expand Up @@ -57,6 +59,32 @@ def test_apply_custom_formatting_template():
assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response


def test_apply_custom_formatting_jinja_template():
    """Checks that the jinja handler renders each element and appends EOS."""
    dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )
    jinja_template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    output_field = "formatted_data_field"
    rendered_dataset = dataset.map(
        apply_custom_data_formatting_jinja_template,
        fn_kwargs={
            "tokenizer": tokenizer,
            "dataset_text_field": output_field,
            "template": jinja_template,
        },
    )
    # Expected rendering of the first record in the data file.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        " \n\n ### Response: no complaint" + tokenizer.eos_token
    )

    assert output_field in rendered_dataset["train"][0]
    assert rendered_dataset["train"][0][output_field] == expected_response


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
"""Tests that the formatting function will throw error if wrong keys are passed to template"""
json_dataset = datasets.load_dataset(
Expand All @@ -76,6 +104,25 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
)


def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys():
    """Tests that the jinja formatting function will throw error if wrong keys are passed to template"""
    dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )
    bad_template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
    output_field = "formatted_data_field"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Rendering must fail because the dataset has no "not found" column.
    with pytest.raises(KeyError):
        dataset.map(
            apply_custom_data_formatting_jinja_template,
            fn_kwargs={
                "tokenizer": tokenizer,
                "dataset_text_field": output_field,
                "template": bad_template,
            },
        )


@pytest.mark.parametrize(
"input_element,output_element,expected_res",
[
Expand Down
15 changes: 13 additions & 2 deletions tests/data/test_data_preprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

# First Party
from tests.artifacts.predefined_data_configs import (
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
Expand Down Expand Up @@ -693,6 +694,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
(DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
Expand Down Expand Up @@ -731,7 +736,10 @@ def test_process_dataconfig_file(data_config_path, data_path):

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
if datasets_name in (
"apply_custom_data_template",
"apply_custom_data_jinja_template",
):
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
Expand All @@ -753,7 +761,10 @@ def test_process_dataconfig_file(data_config_path, data_path):
assert set(train_set.column_names) == column_names
elif datasets_name == "pretokenized_dataset":
assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
elif datasets_name == "apply_custom_data_template":
elif datasets_name in (
"apply_custom_data_template",
"apply_custom_data_jinja_template",
):
assert formatted_dataset_field in set(train_set.column_names)


Expand Down
36 changes: 36 additions & 0 deletions tuning/data/data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import re

# Third Party
from jinja2 import Environment, StrictUndefined
from transformers import AutoTokenizer

# Local
from tuning.utils.config_utils import transform_placeholders


### Utils for custom masking / manipulating input / output strs, etc
def combine_sequence(input_element: str, output_element: str, eos_token: str = ""):
Expand Down Expand Up @@ -137,6 +141,37 @@ def replace_text(match_obj):
}


def apply_custom_data_formatting_jinja_template(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    dataset_text_field: str,
    template: str,
    **kwargs,
):
    """Function to format datasets with jinja templates.
    Expects to be run as a HF Map API function.
    Args:
        element: the HF Dataset element loaded from a JSON or DatasetDict object.
        tokenizer: Tokenizer to be used; its EOS token is appended to every
            rendered sample.
        dataset_text_field: Field name under which the rendered text is stored
            in the returned element (e.g. formatted_dataset_field).
        template: Template to format data with. Features of Dataset
            should be referred to by {{key}}.
    Returns:
        Formatted HF Dataset element, i.e. a dict mapping `dataset_text_field`
        to the rendered text.
    Raises:
        KeyError: if the template references a field that the element does
            not contain (or rendering fails for any other reason).
    """

    # Append EOS so every sample terminates the sequence, then rewrite
    # {{key with spaces}} placeholders into valid jinja dictionary access.
    template += tokenizer.eos_token
    template = transform_placeholders(template)
    # StrictUndefined makes references to missing fields raise instead of
    # silently rendering as empty strings.
    env = Environment(undefined=StrictUndefined)
    jinja_template = env.from_string(template)

    try:
        rendered_text = jinja_template.render(element=element, **element)
    # Broad catch is deliberate: any render failure (UndefinedError, KeyError
    # from element[...], etc.) is surfaced uniformly as KeyError, which is
    # what callers and tests expect for a template/dataset mismatch.
    except Exception as e:
        raise KeyError(f"Dataset does not contain field in template. {e}") from e

    return {dataset_text_field: rendered_text}


def apply_tokenizer_chat_template(
element: Dict[str, str],
tokenizer: AutoTokenizer,
Expand All @@ -157,5 +192,6 @@ def apply_tokenizer_chat_template(
"tokenize_and_apply_input_masking": tokenize_and_apply_input_masking,
"apply_dataset_formatting": apply_dataset_formatting,
"apply_custom_data_formatting_template": apply_custom_data_formatting_template,
"apply_custom_data_formatting_jinja_template": apply_custom_data_formatting_jinja_template,
"apply_tokenizer_chat_template": apply_tokenizer_chat_template,
}
32 changes: 32 additions & 0 deletions tuning/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import os
import pickle
import re

# Third Party
from peft import LoraConfig, PromptTuningConfig
Expand Down Expand Up @@ -135,3 +136,34 @@ def txt_to_obj(txt):
except UnicodeDecodeError:
# Otherwise the bytes are a pickled python dictionary
return pickle.loads(message_bytes)


def transform_placeholders(template: str) -> str:
    """
    Function to detect all placeholders of the form {{...}}.
    - If the inside has a space (e.g. {{Tweet text}}),
      rewrite to {{ element['Tweet text'] }}.
    - If it doesn't have a space (e.g. {{text_label}}), leave it as is.
    - If it is already using dictionary-style access ({{ element['xyz'] }}), do nothing.

    Args:
        template: str
    Return: template: str
    """

    def _rewrite(match_obj: "re.Match") -> str:
        # The raw text between the braces, and its whitespace-trimmed form.
        name = match_obj.group(1).strip()

        # Already dictionary-style access — leave the placeholder untouched.
        if name.startswith("element["):
            return match_obj.group(0)

        # A space means the key cannot be a plain jinja identifier,
        # so switch to dictionary-style access on `element`.
        if " " in name:
            return f"{{{{ element['{name}'] }}}}"

        # Plain identifier: keep as-is.
        return match_obj.group(0)

    return re.sub(r"\{\{([^}]+)\}\}", _rewrite, template)