Add c4 dataset (177M, streaming), update multi-node support for latest job configs #124

Merged · 6 commits · Mar 9, 2024
8 changes: 5 additions & 3 deletions multinode_trainer.slurm
@@ -7,9 +7,9 @@

#SBATCH --job-name=torchtrain_multi_node

#SBATCH --ntasks=2
#SBATCH --ntasks=4

#SBATCH --nodes=2
#SBATCH --nodes=4

#SBATCH --gpus-per-task=8

@@ -48,9 +48,11 @@ export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
export NCCL_BUFFSIZE=2097152
#export TORCH_DIST_INIT_BARRIER=1
export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
#export USE_LIBUV=1
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/llama_13b.toml"}

dcgmi profile --pause
# adjust sbatch --ntasks and sbatch --nodes above and --nnodes below
# to your specific node count, and update target launch file.
srun torchrun --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./train.py --steps 10
srun torchrun --nnodes 4 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./train.py --job.config_file ${CONFIG_FILE}
dcgmi profile --resume
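
For reference, a minimal sketch (not part of this PR, assuming torchrun's standard environment variables) of how the node and GPU counts above translate into the distributed environment seen inside ./train.py:

import os

nnodes = 4          # matches --nodes / --ntasks in the sbatch header and --nnodes below
nproc_per_node = 8  # matches --gpus-per-task and --nproc_per_node

# torchrun exports these for every worker it spawns on every node:
world_size = int(os.environ.get("WORLD_SIZE", nnodes * nproc_per_node))
rank = int(os.environ.get("RANK", 0))              # global rank, 0 .. world_size - 1
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on the current node

assert world_size == nnodes * nproc_per_node  # 4 nodes x 8 GPUs = 32 workers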
1 change: 1 addition & 0 deletions torchtrain/datasets/__init__.py
@@ -12,4 +12,5 @@
dataloader_fn = {
"alpaca": build_hf_data_loader,
"minipile": build_hf_data_loader,
"c4": build_hf_data_loader,
}
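
For illustration only, a sketch of how a dataset-name registry like this one is typically consumed; the config field and the builder's remaining arguments are assumptions, not the exact torchtrain call signature:

from torchtrain.datasets import dataloader_fn

dataset_name = "c4"  # e.g. read from the job config's `dataset` field

try:
    build_fn = dataloader_fn[dataset_name]  # all three names map to build_hf_data_loader
except KeyError:
    raise ValueError(
        f"unsupported dataset '{dataset_name}', expected one of {list(dataloader_fn)}"
    )

# data_loader = build_fn(dataset_name, ...)  # remaining args depend on the job config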
34 changes: 30 additions & 4 deletions torchtrain/datasets/hf_datasets.py
@@ -16,6 +16,7 @@
_supported_datasets = {
"alpaca": "tatsu-lab/alpaca",
"minipile": "JeanKaddour/minipile",
"c4": "allenai/c4",
}


@@ -31,9 +32,10 @@ class HuggingFaceDataset(IterableDataset):
rank (int): rank of the current data parallel process
infinite (bool): whether to loop infinitely over the dataset

We currently support two datasets:
We currently support three datasets:
alpaca (52K training entries)
minipile (1M training entries)
c4 (177M training entries - this dataset is streamed due to the size)

>> Alpaca <<:
Data input format (alpaca):
@@ -54,11 +56,22 @@ class HuggingFaceDataset(IterableDataset):
for example in German Patent Publications"
}

Example:
>> c4 (EN) <<:
c4 cleaned, English version
Data input format (c4):
{
'url': 'https://klyq.com/beginners-bbq-class-taking-place-in-missoula/',
'text': 'Beginners BBQ Class Taking Place in Missoula!\nDo you want to get better at making delicious BBQ? You will have the opportunity, put this on your calendar now. Thursday, September 22nd join World Class BBQ Champion, Tony Balay from Lonestar Smoke Rangers. He will be teaching a beginner level class for everyone who wants to get better with their culinary skills.\nHe will teach you everything you need to know to compete in a KCBS BBQ competition, including techniques, recipes, timelines, meat selection and trimming, plus smoker and fire information.\nThe cost to be in the class is $35 per person, and for spectators it is free. Included in the cost will be either a t-shirt or apron and you will be tasting samples of each meat that is prepared.',
'timestamp': '2019-04-25T12:57:54Z'
}

Example use (alpaca):
>>> alpaca_ds = HuggingFaceDataset(dataset_name="alpaca", dataset_path=None, tokenizer=tokenizer)
>>> for batch in Dataloader(alpaca_ds, batch_size=8):
print(f"Batch size: {len(batch)}")
Batch size: 8


"""

def __init__(
@@ -85,10 +98,19 @@ def __init__(
ds = load_from_disk(dataset_path)
else:
rank0_log(
f"{Color.green}Downloading '{dataset_name}' dataset from HuggingFace...{Color.reset}"
f"{Color.green}Preparing '{dataset_name}' dataset from HuggingFace...{Color.reset}"
)
# Setting `streaming=True` works for large dataset, but the speed is slow.
ds = load_dataset(_supported_datasets[dataset_name], split="train")
# c4 is huge, and requires both streaming and language selection (we default to en)
if dataset_name == "c4":
ds = load_dataset(
_supported_datasets[dataset_name],
"en",
split="train",
streaming=True,
)
else:
ds = load_dataset(_supported_datasets[dataset_name], split="train")

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
@@ -114,6 +136,10 @@ def __iter__(self):
label = x[1:]
yield input, label
if not self.infinite:
rank0_log(
f"{Color.red}WARNING:{Color.reset} dataset {Color.yellow}'{self.dataset_name}'{Color.reset} has "
f"run out of data.{Color.reset}"
)
break
else:
# we are re-looping on the same dataset, warn user
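
As a standalone illustration (not part of the diff) of what the c4 branch above does with the HuggingFace datasets library: stream the English split instead of downloading all ~177M documents, then shard the stream across data-parallel ranks.

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# Stream c4/en; nothing is downloaded up front beyond dataset metadata.
ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

# Give each data-parallel rank a disjoint shard of the stream. In torchtrain the
# rank and world size come from the distributed setup; they are hard-coded here.
ds = split_dataset_by_node(ds, rank=0, world_size=32)

for i, example in enumerate(ds):
    print(example["text"][:80])  # samples carry 'url', 'text', 'timestamp' as shown above
    if i == 2:
        break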
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -37,4 +37,4 @@ compile = false
checkpoint_interval = 3600
checkpoint_interval_type = "steps"
checkpoint_folder = ""
dataset = "alpaca" # supported datasets = minipile (1M), alpaca (52K)
dataset = "alpaca" # supported datasets = alpaca (52K), minipile (1M), c4 (177M)