feat: Allow hf dataset id to be passed via training_data_path #431

Merged
6 changes: 4 additions & 2 deletions README.md
@@ -62,13 +62,13 @@ pip install fms-hf-tuning[aim]
For more details on how to enable and use the trackers, please see [the experiment tracking section below](#experiment-tracking).

## Data Support
Users can pass training data in a single file using the `--training_data_path` argument along with other arguments required for various [use cases](#use-cases-supported-with-training_data_path-argument) (see details below) and the file can be in any of the [supported formats](#supported-data-formats). Alternatively, you can use our powerful [data preprocessing backend](./docs/advanced-data-preprocessing.md) to preprocess datasets on the fly.
Users can pass training data as either a single file or a Hugging Face dataset ID using the `--training_data_path` argument, along with other arguments required for various [use cases](#use-cases-supported-with-training_data_path-argument) (see details below). If a user chooses to pass a file, it can be in any of the [supported formats](#supported-data-formats). Alternatively, you can use our powerful [data preprocessing backend](./docs/advanced-data-preprocessing.md) to preprocess datasets on the fly.


Below, we list the supported data use cases for the `--training_data_path` argument. For details on our advanced data preprocessing, see [Advanced Data Preprocessing](./docs/advanced-data-preprocessing.md).

## Supported Data Formats
We support the following data formats via `--training_data_path` argument
We support the following file formats via the `--training_data_path` argument

Data Format | Tested Support
------------|---------------
@@ -77,6 +77,8 @@
JSONL | ✅
PARQUET | ✅
ARROW | ✅

As noted above, we also support passing a Hugging Face dataset ID directly via the `--training_data_path` argument.
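
For illustration, here is a minimal sketch of the data arguments for the Hugging Face dataset ID path, mirroring the unit test added in this PR; the `tuning.config.configs` import path is an assumption based on how the test references `configs.DataArguments`:

```python
# Sketch only: "lhoestq/demo1" is the public Hub dataset used in the PR's test.
from tuning.config import configs  # assumed import path for configs.DataArguments

data_args = configs.DataArguments(
    training_data_path="lhoestq/demo1",  # Hub dataset ID instead of a local file
    data_formatter_template="### Text:{{review}} \n\n### Stars: {{star}}",
    response_template="\n### Stars:",
)
# data_args is then passed to sft_trainer.train(...) alongside the usual model
# and training arguments, exactly as in the existing file-based flow.
```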

## Use cases supported with `training_data_path` argument

### 1. Data formats with a single sequence and a specified response_template to use for masking on completion.
34 changes: 32 additions & 2 deletions tests/test_sft_trainer.py
@@ -25,7 +25,7 @@
import tempfile

# Third Party
from datasets.exceptions import DatasetGenerationError
from datasets.exceptions import DatasetGenerationError, DatasetNotFoundError
from transformers.trainer_callback import TrainerCallback
import pytest
import torch
@@ -326,7 +326,7 @@ def test_run_train_fails_training_data_path_not_exist():
"""Check fails when data path not found."""
updated_data_path_args = copy.deepcopy(DATA_ARGS)
updated_data_path_args.training_data_path = "fake/path"
with pytest.raises(ValueError):
with pytest.raises(DatasetNotFoundError):
sft_trainer.train(MODEL_ARGS, updated_data_path_args, TRAIN_ARGS, None)
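
The switch from `ValueError` to `DatasetNotFoundError` matches what the `datasets` library surfaces for a path that is neither a local file nor a Hub dataset. A minimal standalone illustration, assuming a recent `datasets` release that exposes `DatasetNotFoundError`:

```python
# Sketch: an identifier that is neither a local path nor a Hub dataset ID
# is expected to raise DatasetNotFoundError from the datasets library.
from datasets import load_dataset
from datasets.exceptions import DatasetNotFoundError

try:
    load_dataset("fake/path")
except DatasetNotFoundError as err:
    print(f"caught: {err}")
```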


@@ -998,6 +998,36 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
assert 'Provide two rhyming words for the word "love"' in output_inference


@pytest.mark.parametrize(
"data_args",
[
(
# sample hugging face dataset id
configs.DataArguments(
training_data_path="lhoestq/demo1",
data_formatter_template="### Text:{{review}} \n\n### Stars: {{star}}",
response_template="\n### Stars:",
)
)
],
)
def test_run_e2e_with_hf_dataset_id(data_args):
"""
Check if we can run an e2e test with a hf dataset id as training_data_path.
"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

sft_trainer.train(MODEL_ARGS, data_args, train_args)

# validate ft tuning configs
_validate_training(tempdir)

# validate inference
_test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))


############################# Helper functions #############################
def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
train_args = copy.deepcopy(training_args)
2 changes: 1 addition & 1 deletion tuning/data/data_handlers.py
@@ -130,7 +130,7 @@ def replace_text(match_obj):
if index_object not in element:
raise KeyError("Requested template string is not a valid key in dict")

return element[index_object]
return str(element[index_object])
Contributor Author:
The dataset I tested in my unit test is lhoestq/demo1, used with the formatting template "### Text:{{review}} \n\n### Stars: {{star}}".
If you look at the dataset, the star field is an int, which would cause an error if this element were not converted to a string.

So this change also supports dataset formatting when row values are not strings.
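
A small standalone reproduction of the failure mode the `str()` cast fixes; the regex and template are the ones from this diff, and the sample row mimics a lhoestq/demo1 record:

```python
import re

row = {"review": "Great movie", "star": 5}  # "star" is an int, as in lhoestq/demo1
template = "### Text:{{review}} \n\n### Stars: {{star}}"

def replace_text(match_obj):
    key = match_obj.group(1)
    # Without str(), re.sub raises a TypeError because the replacement
    # callable must return a string, and row["star"] is an int.
    return str(row[key])

print(re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template))
```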


return {
dataset_text_field: re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template)
70 changes: 42 additions & 28 deletions tuning/data/data_processors.py
@@ -130,42 +130,56 @@ def _load_dataset(data_path=None, builder=None, data_files=None, data_dir=None):
f"Failed to generate the dataset from the provided {context}."
) from e

if datafile:
loader = get_loader_for_filepath(file_path=datafile)
if loader in (None, ""):
raise ValueError(f"data path is invalid [{datafile}]")
return _load_dataset(builder=loader, data_files=[datafile])
def _try_load_dataset(dataset_path, dataset_builder):
"""
Helper function to call load dataset on case by case basis to ensure we handle
directories and files (with or without builders) and hf datasets.

data_paths = datasetconfig.data_paths
builder = datasetconfig.builder
all_datasets = []
Args:
dataset_path: Path of directory/file, pattern, or hf dataset id.
dataset_builder: Optional builder to use if provided.
Returns: dataset
"""
if not dataset_path:
raise ValueError("Invalid dataset path")

for data_path in data_paths:
# CASE 1: User passes directory
if os.path.isdir(data_path): # Checks if path exists and isdirectory
if os.path.isdir(dataset_path): # Checks if path exists and it is a dir
# Directory case
if builder:
if dataset_builder:
# Load using a builder with a data_dir
dataset = _load_dataset(builder=builder, data_dir=data_path)
else:
# Load directly from the directory
dataset = _load_dataset(data_path=data_path)
else:
# Non-directory (file, pattern, HF dataset name)
# If no builder provided, attempt to infer one
effective_builder = (
builder if builder else get_loader_for_filepath(data_path)
return _load_dataset(builder=dataset_builder, data_dir=dataset_path)

# If no builder then load directly from the directory
return _load_dataset(data_path=dataset_path)

# Non-directory (file, pattern, HF dataset name)
# If no builder provided, attempt to infer one
effective_builder = (
dataset_builder
if dataset_builder
else get_loader_for_filepath(dataset_path)
)

if effective_builder:
# CASE 2: Files passed with builder. Load using the builder and specific files
return _load_dataset(
builder=effective_builder, data_files=[dataset_path]
)

if effective_builder:
# CASE 2: Files passed with builder. Load using the builder and specific files
dataset = _load_dataset(
builder=effective_builder, data_files=[data_path]
)
else:
# CASE 3: User passes files/folder/pattern/HF_dataset which has no builder
dataset = _load_dataset(data_path=data_path)
# CASE 3: User passes files/folder/pattern/HF_dataset which has no builder
# Still no builder, try if this is a dataset id
return _load_dataset(data_path=dataset_path)

if datafile:
return _try_load_dataset(datafile, None)

data_paths = datasetconfig.data_paths
builder = datasetconfig.builder
all_datasets = []

for data_path in data_paths:
dataset = _try_load_dataset(data_path, builder)
all_datasets.append(dataset)

# Logs warning if datasets have different columns
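
For readers less familiar with the `datasets` API, here is a simplified standalone sketch of the fallback order `_try_load_dataset` implements, written directly against `datasets.load_dataset` (which `_load_dataset` is assumed to wrap); everything except the Hub ID is a placeholder:

```python
import os
from datasets import load_dataset

def try_load(dataset_path: str, builder: str | None = None):
    """Simplified illustration of the fallback order in _try_load_dataset."""
    if os.path.isdir(dataset_path):
        # CASE 1: a directory -- use the builder with data_dir, or load the dir directly
        if builder:
            return load_dataset(builder, data_dir=dataset_path)
        return load_dataset(dataset_path)
    if builder:
        # CASE 2: a file or pattern with a provided/inferred builder
        return load_dataset(builder, data_files=[dataset_path])
    # CASE 3: no builder could be inferred -- treat the string as a HF dataset ID
    return load_dataset(dataset_path)

ds = try_load("lhoestq/demo1")  # Hub dataset ID from the PR's test
print(ds)
```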