add data loading option to load from local file system #117

Merged 4 commits on Mar 7, 2024
8 changes: 8 additions & 0 deletions torchtrain/config_manager.py
@@ -151,6 +151,14 @@ def init_args_from_command_line(
parser.add_argument(
"--training.dataset", type=str, default="alpaca", help="dataset to use"
)
parser.add_argument(
    "--training.dataset_path",
    type=str,
    help=(
        "Path to the dataset in the file system. If provided, data will be "
        "loaded from this path instead of downloaded."
    ),
)
parser.add_argument(
"--training.batch_size", type=int, default=8, help="batch size"
)
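A note for anyone wiring this up by hand: because the option name contains a dot, argparse stores the value under the literal key training.dataset_path, so it has to be read back with getattr or vars(). A minimal, standalone sketch, not part of this diff; the real parser in config_manager.py has many more options, and the path below is made up:

import argparse

# Standalone sketch mirroring only the two options shown above; the path
# passed to parse_args is hypothetical.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--training.dataset", type=str, default="alpaca", help="dataset to use"
)
parser.add_argument(
    "--training.dataset_path",
    type=str,
    help=(
        "Path to the dataset in the file system. If provided, data will be "
        "loaded from this path instead of downloaded."
    ),
)

args = parser.parse_args(["--training.dataset_path", "/data/alpaca_local"])

# The dest keeps the dot, so plain attribute access does not work.
print(getattr(args, "training.dataset"))       # -> alpaca
print(vars(args)["training.dataset_path"])     # -> /data/alpaca_local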
57 changes: 34 additions & 23 deletions torchtrain/datasets/hf_datasets.py
@@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import List
from typing import List, Optional

import torch
from torch.utils.data import DataLoader, IterableDataset
@@ -10,7 +10,7 @@
from torchtrain.logging_utils import rank0_log
from torchtrain.utils import Color

from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from datasets.distributed import split_dataset_by_node

_supported_datasets = {
@@ -20,30 +20,20 @@


class HuggingFaceDataset(IterableDataset):
"""PyTorch Representation of a Dataset from Hugging Face.

We currently support two datasets:
minipile (1M training entries)
alpaca (52K training entries)

>> MiniPile <<:
MiniPile dataset is detailed in the following paper:
https://arxiv.org/abs/2304.08442
"""PyTorch Representation of the HuggingFace Dataset.

Args:
dataset_name (str): name of the dataset to load
dataset_path (Optional[str]): Path to the dataset in the file system. If provided, data will be loaded from this path instead of downloaded.
tokenizer (TokenizerIf): Tokenizer used to encode data. Tokenizer must implement `encode` and `decode` methods.
seq_len (int): max sequence length
world_size (int): number of data parallel processes participating in training
rank (int): rank of the current data parallel process
infinite (bool): whether to loop infinitely over the dataset

Data input format (minipile):
{
"text": "Open-end spinning devices with such rotor bearing arrangements are known in
various different embodiments, and have been extensively described,
for example in German Patent Publications"
}
We currently support two datasets:
alpaca (52K training entries)
minipile (1M training entries)

>> Alpaca <<:
Data input format (alpaca):
@@ -55,8 +45,17 @@ class HuggingFaceDataset(IterableDataset):
Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples", # noqa: B950
}

>> MiniPile <<:
MiniPile dataset is detailed in the paper: https://arxiv.org/abs/2304.08442
Data input format (minipile):
{
"text": "Open-end spinning devices with such rotor bearing arrangements are known in
various different embodiments, and have been extensively described,
for example in German Patent Publications"
}

Example:
>>> alpaca_ds = HuggingFaceDataset(tokenizer=tokenizer)
>>> alpaca_ds = HuggingFaceDataset(dataset_name="alpaca", dataset_path=None, tokenizer=tokenizer)
>>> for batch in DataLoader(alpaca_ds, batch_size=8):
print(f"Batch size: {len(batch)}")
Batch size: 8
@@ -65,21 +64,32 @@ class HuggingFaceDataset(IterableDataset):
def __init__(
self,
dataset_name: str,
dataset_path: Optional[str],
tokenizer: TokenizerIf,
seq_len: int = 2048,
world_size: int = 1,
rank: int = 0,
infinite: bool = False,
) -> None:
# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
# Setting `streaming=True` works for large dataset, but the speed is slow.
if dataset_name not in _supported_datasets:
raise ValueError(
f"Dataset {dataset_name} is not supported. Supported datasets are: {_supported_datasets.keys()}"
)

ds = load_dataset(_supported_datasets[dataset_name], split="train")
# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
if dataset_path:
rank0_log(
f"{Color.green}Loading '{dataset_name}' dataset locally from {dataset_path}...{Color.reset}"
)
ds = load_from_disk(dataset_path)
else:
rank0_log(
f"{Color.green}Downloading '{dataset_name}' dataset from HuggingFace...{Color.reset}"
)
# Setting `streaming=True` works for large dataset, but the speed is slow.
ds = load_dataset(_supported_datasets[dataset_name], split="train")

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._tokenizer = tokenizer
@@ -115,6 +125,7 @@ def __iter__(self):

def build_hf_data_loader(
dataset_name: str,
dataset_path: Optional[str],
tokenizer: TokenizerIf,
batch_size: int,
seq_len: int,
@@ -123,7 +134,7 @@ def build_hf_data_loader(
infinite: bool = True,
):
hf_ds = HuggingFaceDataset(
dataset_name, tokenizer, seq_len, world_size, rank, infinite
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return DataLoader(hf_ds, batch_size=batch_size)
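One detail worth calling out for users of the new option: load_from_disk reads a dataset directory previously written with save_to_disk, not a raw downloaded file. A rough sketch of preparing such a directory ahead of time; the dataset id and path are assumptions for illustration, and the real name-to-id mapping lives in _supported_datasets:

from datasets import load_dataset, load_from_disk

local_path = "/data/alpaca_local"  # hypothetical location

# One-time step on a machine with network access (dataset id assumed here;
# use whatever _supported_datasets maps your dataset name to).
ds = load_dataset("tatsu-lab/alpaca", split="train")
ds.save_to_disk(local_path)

# Later, offline: this mirrors what HuggingFaceDataset does when
# --training.dataset_path is set.
ds = load_from_disk(local_path)
print(len(ds))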
1 change: 1 addition & 0 deletions train.py
@@ -115,6 +115,7 @@ def main(job_config: JobConfig):
dp_degree, dp_rank = 1, 0
data_loader = build_dataloader_fn(
job_config.training.dataset,
job_config.training.dataset_path,
tokenizer,
job_config.training.batch_size,
job_config.training.seq_len,
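End to end, the new option flows from the CLI flag through the job config into build_hf_data_loader. A sketch of the resulting call, mirroring the snippet above; the arguments after seq_len are collapsed in this diff, so the data-parallel degree and rank shown here are assumptions, and tokenizer is built earlier in main():

# Sketch only: mirrors the call site in train.py after this change.
data_loader = build_hf_data_loader(
    job_config.training.dataset,       # e.g. "alpaca"
    job_config.training.dataset_path,  # None -> download from HuggingFace
    tokenizer,                         # TokenizerIf instance, built earlier
    job_config.training.batch_size,
    job_config.training.seq_len,
    dp_degree,                         # assumed: data-parallel world size
    dp_rank,                           # assumed: data-parallel rank
)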