unify data loading from HF and from disk
ghstack-source-id: 932e7cce828a15c788b34f07c264e119068777fe
Pull Request resolved: #287
tianyu-l committed Apr 30, 2024
1 parent d442743 commit b24ed42
Showing 5 changed files with 13 additions and 59 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -10,6 +10,5 @@ dist/*
 data
 out
 wandb
-*.json
 
 torchtitan/datasets/**/*.model
Binary file not shown.
20 changes: 0 additions & 20 deletions torchtitan/datasets/c4_mini/dataset_info.json

This file was deleted.

13 changes: 0 additions & 13 deletions torchtitan/datasets/c4_mini/state.json

This file was deleted.

38 changes: 13 additions & 25 deletions torchtitan/datasets/hf_datasets.py
@@ -12,9 +12,11 @@
 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging_utils import logger
 
-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
 
+# map from dataset name to a local directory, or
+# a dataset repository on the HF hub
 _supported_datasets = {
     "c4_mini": "torchtitan/datasets/c4_mini",
     "c4": "allenai/c4",
@@ -66,7 +68,7 @@ def __init__(
         rank: int = 0,
         infinite: bool = False,
     ) -> None:
-        # allow user to pass in a local path to use unsupported datasets
+        # allow user to pass in a (local or HF hub) path to use unsupported datasets
         if dataset_name not in _supported_datasets:
             if dataset_path:
                 logger.warning(
@@ -79,32 +81,18 @@ def __init__(
                     f"Supported datasets are: {list(_supported_datasets.keys())}."
                 )
 
-        # special case to auto-load c4_mini (and any future datasets) from local dir
-        if dataset_name == "c4_mini":
-            dataset_path = f"torchtitan/datasets/{dataset_name}"
+        if not dataset_path:
+            dataset_path = _supported_datasets[dataset_name]
+        logger.info(f"Preparing {dataset_name} dataset from {dataset_path}")
 
-        # TODO: This is a temporary solution for small datasets.
-        # For large datasets we need to use a more scalable approach,
-        # and support shuffling and checkpointing.
-        if dataset_path:
-            logger.info(f"Loading {dataset_name} dataset locally from {dataset_path}")
-            ds = load_from_disk(dataset_path)
+        if dataset_name == "c4":
+            # c4 is huge, and requires both streaming and language selection
+            # (we default to en)
+            ds = load_dataset(dataset_path, name="en", split="train", streaming=True)
         else:
-            logger.info(f"Preparing {dataset_name} dataset from HuggingFace")
-            # Setting `streaming=True` works for large dataset, but is slightly
-            # slower and unstable.
-            if dataset_name == "c4":
-                # c4 is huge, and requires both streaming and language selection
-                # (we default to en).
-                ds = load_dataset(
-                    _supported_datasets[dataset_name],
-                    "en",
-                    split="train",
-                    streaming=True,
-                )
-            else:
-                ds = load_dataset(_supported_datasets[dataset_name], split="train")
+            ds = load_dataset(dataset_path, split="train")
 
+        # TODO: support shuffling and checkpointing
         self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
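After this change, both branches reduce to a single load_dataset() call whose first argument is either a local directory (e.g. torchtitan/datasets/c4_mini) or a Hugging Face hub repository id (e.g. allenai/c4). Below is a minimal standalone sketch of that unified path; the helper name load_sharded_dataset and its signature are illustrative rather than part of the commit, and it assumes the local directory contains data files the datasets library can auto-detect.

from typing import Optional

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# name -> local directory or HF hub repository (taken from the diff above)
_supported_datasets = {
    "c4_mini": "torchtitan/datasets/c4_mini",
    "c4": "allenai/c4",
}

def load_sharded_dataset(
    dataset_name: str,
    dataset_path: Optional[str] = None,
    rank: int = 0,
    world_size: int = 1,
):
    # illustrative helper, not part of the commit:
    # fall back to the registered path when the caller does not supply one
    path = dataset_path or _supported_datasets[dataset_name]
    if dataset_name == "c4":
        # c4 is huge: stream it and select the English subset
        ds = load_dataset(path, name="en", split="train", streaming=True)
    else:
        # small datasets load eagerly; load_dataset() infers the format
        # from the files it finds in a local directory
        ds = load_dataset(path, split="train")
    # shard the (possibly streaming) dataset across data-parallel ranks
    return split_dataset_by_node(ds, rank, world_size)

# e.g. load_sharded_dataset("c4_mini", rank=0, world_size=2)

Because load_from_disk is no longer used, the save_to_disk artifacts for c4_mini (dataset_info.json and state.json above) are presumably no longer needed, which is consistent with their deletion in this commit.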
