Custom Dataset refactoring + docs (#715)
EDIT: removed the specific new functions from `hf_datasets.py`, kept most
of the doc changes, and will not go with a registration-based API.

Fixes #311

This PR documents the status quo of how new datasets are added today: the
implicit assumption is that people install torchtitan from source and
update `hf_datasets.py` to support new datasets. As an example, I added
the Wikipedia dataset.

The main "nice" thing about this PR is that `class HuggingFaceDataset`
is now agnostic to the C4 dataset, which makes it easier for new people
to add datasets without reading the rest of the file.

There's another direction this PR could have gone: allowing custom dataset
registration. The benefit is that people could support new datasets without
installing torchtitan from source, but registration APIs can feel somewhat
"bureaucratic", and presumably people would still need to register the
dataset somewhere, probably `train.py`?

I'm not totally sure which approach is more in line with the repo's goals,
so I'm opening this PR to discuss. A sketch of what the registration API
might have looked like:

```python
from typing import Any, Callable, Dict, Optional

from datasets import load_dataset

# Hypothetical registries mapping dataset names to loaders, sample processors,
# and default paths; these would live in hf_datasets.py.
DATASET_LOADERS: Dict[str, Callable[..., Any]] = {}
DATASET_TEXT_PROCESSORS: Dict[str, Callable[[Dict[str, Any]], str]] = {}
DATASET_PATHS: Dict[str, str] = {}


def register_dataset(
    name: str,
    loader: Callable[..., Any],
    processor: Callable[[Dict[str, Any]], str],
    path: Optional[str] = None,
) -> None:
    """Register a dataset's loader, sample processor, and optional default path."""
    DATASET_LOADERS[name] = loader
    DATASET_TEXT_PROCESSORS[name] = processor
    if path is not None:
        DATASET_PATHS[name] = path


def wikipedia_loader(dataset_path: str, **kwargs):
    # Stream the English Wikipedia snapshot from the HF hub.
    return load_dataset(
        dataset_path,
        name="20220301.en",
        split="train",
        streaming=True,
        trust_remote_code=True,
    )


def wikipedia_processor(sample: Dict[str, Any]) -> str:
    # Combine the title and body into a single training string.
    return f"{sample['title']}\n\n{sample['text']}"


register_dataset(
    name="wikipedia",
    loader=wikipedia_loader,
    processor=wikipedia_processor,
    path="wikipedia",
)
```
msaroufim authored Dec 5, 2024
1 parent afa8229 commit 7281e0b
Showing 4 changed files with 147 additions and 83 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -58,7 +58,7 @@ You may want to see how the model is defined or how parallelism techniques are a
4. `torch.compile` support
5. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md))
6. DDP and HSDP
7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries)
7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)
8. Learning rate scheduler, meta-init, (optional) fused RMSNorm kernel
9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
10. Debugging tools including CPU/GPU profiling, [memory profiling](docs/memory_profiler.md), [Flight Recorder](#debugging), etc.
74 changes: 74 additions & 0 deletions docs/datasets.md
@@ -0,0 +1,74 @@
# Custom Datasets in TorchTitan

TorchTitan is designed to work seamlessly with most HuggingFace datasets. While we provide the C4 dataset for numerics and convergence testing, you can easily add support for your own datasets. Here's how to do it using Wikipedia as an example.

## Quick Start
Locate the dataset configuration file:
```
torchtitan/datasets/hf_datasets.py
```

## Adding Your Dataset
You'll need to add three components:
1. A dataset loader function
2. A sample processor function
3. A dataset configuration entry

### 1. Define Dataset Loader
Create a function that specifies how to load your dataset:

```python
def load_wikipedia_dataset(dataset_path: str, **kwargs):
"""Load Wikipedia dataset with specific configuration."""
logger.info("Loading Wikipedia dataset...")
return load_dataset(
dataset_path,
name="20220301.en",
split="train",
streaming=True,
trust_remote_code=True,
)
```

### 2. Define Sample Processor
Create a function that processes individual samples from your dataset:

```python
def process_wikipedia_text(sample: Dict[str, Any]) -> str:
"""Process Wikipedia dataset sample text."""
return f"{sample['title']}\n\n{sample['text']}"
```
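
For instance, given a raw sample with `title` and `text` fields (the field names come from the loader above; the values here are made up), the processor turns it into a single training string:

```python
# Hypothetical Wikipedia sample; only the fields used by the processor are shown.
sample = {
    "title": "PyTorch",
    "text": "PyTorch is an open source machine learning framework...",
}

print(process_wikipedia_text(sample))
# PyTorch
#
# PyTorch is an open source machine learning framework...
```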

### 3. Register Your Dataset
Add your dataset configuration to the DATASETS dictionary:

```python
DATASETS = {
# ... existing datasets ...
"wikipedia": DatasetConfig(
path="wikipedia", # default HuggingFace dataset path
loader=load_wikipedia_dataset,
text_processor=process_wikipedia_text,
),
}
```

### 4. Configure Your Training
In your training configuration file (`.toml`), select your dataset under the `[training]` section:

```toml
[training]
dataset = "wikipedia"
```

That's it! Your custom dataset is now ready to use with TorchTitan.
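
If you want to sanity-check the dataset wiring outside a full training run, a minimal sketch along these lines should work. It assumes the `build_hf_data_loader` signature from `torchtitan/datasets/hf_datasets.py` in this PR, and the import path and tokenizer setup mirror what `train.py` does; treat the exact names as illustrative rather than a supported entry point.

```python
from torchtitan.datasets import build_hf_data_loader, build_tokenizer

# Hypothetical smoke test: build a single-rank loader for the new dataset and
# pull one batch to confirm loading and tokenization work end to end.
tokenizer = build_tokenizer("tiktoken", "path/to/tokenizer.model")  # type/path depend on your model
data_loader = build_hf_data_loader(
    dataset_name="wikipedia",
    dataset_path=None,   # None falls back to DatasetConfig.path ("wikipedia")
    tokenizer=tokenizer,
    batch_size=2,
    seq_len=512,
    world_size=1,
    rank=0,
)
inputs, labels = next(iter(data_loader))
print(inputs.shape, labels.shape)  # expect torch.Size([2, 512]) for both
```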

## Key Points
- The DatasetConfig contains all necessary components for a dataset:
- `path`: The default path to the dataset (can be overridden during training; see the example after this list)
- `loader`: Function to load the dataset
- `text_processor`: Function to process individual samples
- The loader function should return a HuggingFace dataset object
- The processor function should return a string that combines the relevant fields from your dataset
- Use `streaming=True` for large datasets to manage memory efficiently
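
For example, to point the registered dataset at a local copy instead of the default HuggingFace path, you can override the path in your training config. The `dataset_path` field below is an assumption based on the arguments `train.py` passes to `build_hf_data_loader`; adjust it to your setup.

```toml
[training]
dataset = "wikipedia"
# Hypothetical local override; when omitted, DatasetConfig.path is used
dataset_path = "/data/wikipedia"
```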

Now you can start training with your custom dataset!
153 changes: 72 additions & 81 deletions torchtitan/datasets/hf_datasets.py
@@ -5,12 +5,12 @@
# LICENSE file in the root directory of this source tree.

import pickle
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
@@ -19,49 +19,56 @@
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

# map from dataset name to a local directory, or
# a dataset repository on the HF hub
_supported_datasets = {
"c4_test": "test/assets/c4_test",
"c4": "allenai/c4",

def _load_c4_dataset(dataset_path: str):
"""Load C4 dataset with default configuration."""
return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: Dict[str, Any]) -> str:
"""Process C4 dataset sample text."""
return sample["text"]


@dataclass
class DatasetConfig:
path: str
loader: Callable
text_processor: Callable


# Add your dataset here - more information at docs/datasets.md
DATASETS = {
"c4": DatasetConfig(
path="allenai/c4",
loader=_load_c4_dataset,
text_processor=_process_c4_text,
),
"c4_test": DatasetConfig(
path="test/assets/c4_test",
loader=lambda path: load_dataset(path, split="train"),
text_processor=_process_c4_text,
),
}


class HuggingFaceDataset(IterableDataset, Stateful):
"""PyTorch Representation of the HuggingFace Dataset.
Args:
dataset_name (str): name of the dataset to load
dataset_path (Optional[str]):
Path to the dataset in the file system. If provided, data will be loaded
from this path instead of downloaded.
tokenizer (Tokenizer):
Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
seq_len (int): max sequence length
world_size (int): number of data parallel processes participating in training
rank (int): rank of the current data parallel process
infinite (bool): whether to loop infinitely over the dataset
We currently support the c4 dataset, and a subset of it for testing purposes:
c4_test (2K training entries)
c4 (177M training entries - this dataset is streamed due to the size)
>> c4 (EN) <<:
c4 cleaned, English version
Data input format (c4):
{
'url': 'https://klyq.com/beginners-bbq-class-taking-place-in-missoula/',
'text': 'Beginners BBQ Class Taking Place in Missoula!\nDo you want to get better at ...',
'timestamp': '2019-04-25T12:57:54Z'
}
Example use (c4):
>>> ds = HuggingFaceDataset(dataset_name="c4", dataset_path=None, tokenizer=tokenizer)
>>> for batch in Dataloader(ds, batch_size=8):
print(f"Batch size: {len(batch)}")
Batch size: 8
"""
def _validate_dataset(
dataset_name: str, dataset_path: str = None
) -> tuple[str, Callable, Callable]:
"""Validate dataset name and path."""
if dataset_name not in DATASETS:
raise ValueError(
f"Dataset {dataset_name} is not supported. "
f"Supported datasets are: {list(DATASETS.keys())}"
)

config = DATASETS[dataset_name]
path = dataset_path or config.path
logger.info(f"Preparing {dataset_name} dataset from {path}")
return path, config.loader, config.text_processor


class HuggingFaceDataset(IterableDataset, Stateful):
def __init__(
self,
dataset_name: str,
@@ -72,47 +79,41 @@ def __init__(
rank: int = 0,
infinite: bool = False,
) -> None:
# allow user to pass in a (local or HF hub) path to use unsupported datasets
if dataset_name not in _supported_datasets:
if dataset_path:
logger.warning(
f"Dataset {dataset_name} is not tested or verfied. "
f"Recommended datasets are: {list(_supported_datasets.keys())}"
)
else:
raise ValueError(
f"Dataset {dataset_name} is not supported. "
f"Supported datasets are: {list(_supported_datasets.keys())}"
)

if not dataset_path:
dataset_path = _supported_datasets[dataset_name]
logger.info(f"Preparing {dataset_name} dataset from {dataset_path}")

if dataset_name == "c4":
# c4 is huge, and requires both streaming and language selection
# (we default to en)
ds = load_dataset(dataset_path, name="en", split="train", streaming=True)
else:
ds = load_dataset(dataset_path, split="train")

# TODO: support shuffling
# Force lowercase for consistent comparison
dataset_name = dataset_name.lower()

path, dataset_loader, text_processor = _validate_dataset(
dataset_name, dataset_path
)
ds = dataset_loader(path)

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite
self._text_processor = text_processor

# variables for checkpointing
# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

return iter(self._data.skip(self._sample_idx))

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len

while True:
for sample in self._get_data_iter():
sample_text = sample["text"]
# Use the dataset-specific text processor
sample_text = self._text_processor(sample)
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1
@@ -133,16 +134,6 @@ def __iter__(self):
self._sample_idx = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# As skipping to the end throws an error in case of map-style dataset, return an empty iterator
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]
@@ -184,12 +175,12 @@ def build_hf_data_loader(
tokenizer: Tokenizer,
batch_size: int,
seq_len: int,
world_size,
rank,
world_size: int,
rank: int,
infinite: bool = True,
):
"""Build a data loader for HuggingFace datasets."""
hf_ds = HuggingFaceDataset(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size)
1 change: 0 additions & 1 deletion train.py
@@ -86,7 +86,6 @@ def main(job_config: JobConfig):
# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# build dataloader
data_loader = build_hf_data_loader(
job_config.training.dataset,
