diff --git a/README.md b/README.md index e73da98fe..bc3e7070d 100644 --- a/README.md +++ b/README.md @@ -292,6 +292,36 @@ torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ If you meet out-of-memory during model saving, see solutions [here](https://github.com/pytorch/pytorch/issues/98823). +### Fine-tuning FastChat-T5 with Local GPUs +You can use the following command to train FastChat-T5 with 4 x A100 (40GB). +```bash +torchrun --nproc_per_node=4 --master_port=9778 fastchat/train/train_flant5.py \ + --model_name_or_path google/flan-t5-xl \ + --data_path playground/data/dummy.json \ + --bf16 True \ + --output_dir ./checkpoints_flant5_3b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 300 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap T5Block \ + --tf32 True \ + --model_max_length 2048 \ + --preprocessed_path ./preprocessed_data/processed.json \ + --gradient_checkpointing True +``` +After training, please use our post-processing [function](https://github.com/lm-sys/FastChat/blob/main/fastchat/utils.py#L164) to update the saved model weight. Additional discussions can be found [here](https://github.com/lm-sys/FastChat/issues/643). + ### Fine-tuning on Any Cloud with SkyPilot [SkyPilot](https://github.com/skypilot-org/skypilot) is a framework built by UC Berkeley for easily and cost effectively running ML workloads on any cloud (AWS, GCP, Azure, Lambda, etc.). To use SkyPilot, install it with the following command and setup the cloud credentials locally following the instructions [here](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html). diff --git a/fastchat/train/train_flant5.py b/fastchat/train/train_flant5.py index 5200a5152..2dc4b4860 100755 --- a/fastchat/train/train_flant5.py +++ b/fastchat/train/train_flant5.py @@ -24,6 +24,7 @@ from typing import Dict, Optional, Sequence import torch +import torch.distributed as dist import transformers from torch.utils.data import Dataset @@ -76,7 +77,6 @@ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: st state_dict = trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} - # potential bug for T5 model del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa @@ -276,13 +276,20 @@ def __init__( super(SupervisedDataset, self).__init__() # save to file + # Make sure only the first process is processing the dataset + if dist.get_rank() != 0: + dist.barrier() self.preprocessed_path = preprocessed_path - if not os.path.exists("./preprocessed_data/"): - os.mkdir("preprocessed_data/") if os.path.exists(self.preprocessed_path): - print("loading from preprocessed data") - data_dict = json.load(open(self.preprocessed_path, "r")) + logging.warning("loading from preprocessed data") + with open(self.preprocessed_path, "r") as f: + data_dict = json.load(f) + if dist.get_rank() == 0: + dist.barrier() else: + if not os.path.exists("preprocessed_data"): + os.mkdir("preprocessed_data") + assert dist.get_rank() == 0, "Only the first process should process" logging.warning("Loading data...") list_data_dict = json.load(open(data_path, "r")) @@ -294,11 +301,12 @@ def __init__( data_dict = preprocess(sources, tokenizer) json_data_dict = json.dumps(data_dict) - # open file for writing, "w" - f = open(self.preprocessed_path, "w") - - # write json object to file - f.write(json_data_dict) + # Remember to close file to avoid concurrent r/w + with open(self.preprocessed_path,"w") as f: + f.write(json_data_dict) + + # Release barrier + dist.barrier() if num_data != -1: data_dict["input_ids"] = data_dict["input_ids"][:num_data]