FastChat-T5 doc+fix data processing #1430

Merged: 5 commits, May 22, 2023
README.md (30 additions, 0 deletions)
@@ -292,6 +292,36 @@ torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \

If you run into out-of-memory errors during model saving, see the solutions [here](https://github.com/pytorch/pytorch/issues/98823).

### Fine-tuning FastChat-T5 with Local GPUs
You can use the following command to train FastChat-T5 on 4 x A100 (40GB) GPUs.
```bash
torchrun --nproc_per_node=4 --master_port=9778 fastchat/train/train_flant5.py \
--model_name_or_path google/flan-t5-xl \
--data_path playground/data/dummy.json \
--bf16 True \
--output_dir ./checkpoints_flant5_3b \
--num_train_epochs 3 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 4 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 300 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap T5Block \
--tf32 True \
--model_max_length 2048 \
--preprocessed_path ./preprocessed_data/processed.json \
--gradient_checkpointing True
```
After training, please use our post-processing [function](https://github.com/lm-sys/FastChat/blob/main/fastchat/utils.py#L164) to update the saved model weights. Additional discussion can be found [here](https://github.com/lm-sys/FastChat/issues/643).
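
For reference, a minimal sketch of that post-processing step. It assumes the helper is named `clean_flant5_ckpt` and lives in `fastchat/utils.py`; verify the exact name and location against the link above, since it may have moved since this PR.

```python
# Minimal sketch, assuming the post-processing helper is
# `clean_flant5_ckpt` in fastchat/utils.py (verify against the link above).
from fastchat.utils import clean_flant5_ckpt

# Update the saved weights in place in the output directory
# produced by the training command above.
clean_flant5_ckpt("./checkpoints_flant5_3b")
```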

### Fine-tuning on Any Cloud with SkyPilot
[SkyPilot](https://github.com/skypilot-org/skypilot) is a framework built by UC Berkeley for running ML workloads easily and cost-effectively on any cloud (AWS, GCP, Azure, Lambda, etc.).
To use SkyPilot, install it with the following command and set up your cloud credentials locally by following the instructions [here](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html).
fastchat/train/train_flant5.py (18 additions, 10 deletions)
@@ -24,6 +24,7 @@
 from typing import Dict, Optional, Sequence

 import torch
+import torch.distributed as dist

 import transformers
 from torch.utils.data import Dataset
@@ -76,7 +77,6 @@ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: st
     state_dict = trainer.model.state_dict()
     if trainer.args.should_save:
         cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
-        # potential bug for T5 model
         del state_dict
         trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

@@ -276,13 +276,20 @@ def __init__(
         super(SupervisedDataset, self).__init__()

         # save to file
+        # Make sure only the first process is processing the dataset
+        if dist.get_rank() != 0:
+            dist.barrier()
         self.preprocessed_path = preprocessed_path
-        if not os.path.exists("./preprocessed_data/"):
-            os.mkdir("preprocessed_data/")
         if os.path.exists(self.preprocessed_path):
-            print("loading from preprocessed data")
-            data_dict = json.load(open(self.preprocessed_path, "r"))
+            logging.warning("loading from preprocessed data")
+            with open(self.preprocessed_path, "r") as f:
+                data_dict = json.load(f)
+            if dist.get_rank() == 0:
+                dist.barrier()
         else:
+            if not os.path.exists("preprocessed_data"):
+                os.mkdir("preprocessed_data")
+            assert dist.get_rank() == 0, "Only the first process should process"
             logging.warning("Loading data...")
             list_data_dict = json.load(open(data_path, "r"))

@@ -294,11 +301,12 @@
             data_dict = preprocess(sources, tokenizer)
             json_data_dict = json.dumps(data_dict)

-            # open file for writing, "w"
-            f = open(self.preprocessed_path, "w")
-
-            # write json object to file
-            f.write(json_data_dict)
+            # Remember to close file to avoid concurrent r/w
+            with open(self.preprocessed_path, "w") as f:
+                f.write(json_data_dict)
+
+            # Release barrier
+            dist.barrier()

         if num_data != -1:
             data_dict["input_ids"] = data_dict["input_ids"][:num_data]
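
The dataset change above follows a common rank-0-first pattern: rank 0 preprocesses and caches the data while the other ranks wait at a barrier, then every rank loads the cached file. Below is a minimal standalone sketch of the same pattern, assuming the process group has already been initialized (e.g. by `torchrun`); the function and its arguments are illustrative, not part of FastChat.

```python
import json
import os

import torch.distributed as dist


def load_or_build_cache(cache_path, build_fn):
    # Non-zero ranks wait here until rank 0 has written the cache.
    if dist.get_rank() != 0:
        dist.barrier()
    if os.path.exists(cache_path):
        with open(cache_path, "r") as f:
            data = json.load(f)
        if dist.get_rank() == 0:
            # Rank 0 found an existing cache, so release the waiting ranks.
            dist.barrier()
    else:
        # Only rank 0 can reach this branch; the other ranks are still
        # waiting at the barrier above and will find the file afterwards.
        data = build_fn()
        with open(cache_path, "w") as f:
            json.dump(data, f)
        # Cache written: release the other ranks, which then load the file.
        dist.barrier()
    return data
```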