From 173944bb64108a72543c4f33a02bc280c91ed6dc Mon Sep 17 00:00:00 2001 From: Dacheng Li Date: Sat, 20 May 2023 19:45:44 +0400 Subject: [PATCH 1/5] update doc + fix dist data process in T5 --- README.md | 30 ++++++++++++++++++++++++++++++ fastchat/train/train_flant5.py | 34 ++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index d143eb0ae..1784f4fb3 100644 --- a/README.md +++ b/README.md @@ -292,6 +292,36 @@ torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ If you meet out-of-memory during model saving, see solutions [here](https://github.com/pytorch/pytorch/issues/98823). +### Fine-tuning FastChat-T5 with Local GPUs +You can use the following command to train FastChat-T5 with 4 x A100 (40GB). +```bash +torchrun --nproc_per_node=4 --master_port=9778 fastchat/train/train_flant5.py \ + --model_name_or_path google/flan-t5-xl \ + --data_path data_path \ + --bf16 True \ + --output_dir ./checkpoints_flant5_3b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 300 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap T5Block \ + --tf32 True \ + --model_max_length 2048 \ + --preprocessed_path ./preprocessed_data/processed.json \ + --gradient_checkpointing True +``` +After training, please use our post-processing [function](https://github.com/lm-sys/FastChat/blob/main/fastchat/utils.py#L164) to update the saved model weight. Additional discussions can be found [here](https://github.com/lm-sys/FastChat/issues/643). + ### Fine-tuning on Any Cloud with SkyPilot [SkyPilot](https://github.com/skypilot-org/skypilot) is a framework built by UC Berkeley for easily and cost effectively running ML workloads on any cloud (AWS, GCP, Azure, Lambda, etc.). To use SkyPilot, install it with the following command and setup the cloud credentials locally following the instructions [here](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html). diff --git a/fastchat/train/train_flant5.py b/fastchat/train/train_flant5.py index 5200a5152..9829c204c 100755 --- a/fastchat/train/train_flant5.py +++ b/fastchat/train/train_flant5.py @@ -24,6 +24,7 @@ from typing import Dict, Optional, Sequence import torch +import torch.distributed as dist import transformers from torch.utils.data import Dataset @@ -275,15 +276,23 @@ def __init__( ): super(SupervisedDataset, self).__init__() - # save to file + # Make sure only the first process is processing the dataset + if dist.get_rank() != 0: + dist.barrier() + self.preprocessed_path = preprocessed_path - if not os.path.exists("./preprocessed_data/"): - os.mkdir("preprocessed_data/") + if not os.path.exists("preprocessed_data"): + os.mkdir("preprocessed_data") if os.path.exists(self.preprocessed_path): - print("loading from preprocessed data") - data_dict = json.load(open(self.preprocessed_path, "r")) + print(f"loading from preprocessed data at {self.preprocessed_path}") + with open(self.preprocessed_path, "r") as f: + data_dict = json.load(f) + print(len(data_dict["input_ids"])) + if dist.get_rank() == 0: + dist.barrier() else: - logging.warning("Loading data...") + assert dist.get_rank() == 0, "Only the first process should process" + logging.warning("Loading raw data...") list_data_dict = json.load(open(data_path, "r")) logging.warning("Formatting inputs...") @@ -294,11 +303,12 @@ def __init__( data_dict = preprocess(sources, tokenizer) json_data_dict = json.dumps(data_dict) - # open file for writing, "w" - f = open(self.preprocessed_path, "w") - - # write json object to file - f.write(json_data_dict) + # open file for writing, "w" + with open(self.preprocessed_path,"w") as f: + f.write(json_data_dict) + + # Release barrier + dist.barrier() if num_data != -1: data_dict["input_ids"] = data_dict["input_ids"][:num_data] @@ -406,7 +416,7 @@ def train(): smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), - other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], + other_tokens= ["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], tokenizer=tokenizer, model=model, ) From 625d0a7c92c6e62056680fffadd05d4b96b5ddc7 Mon Sep 17 00:00:00 2001 From: Dacheng Li Date: Sat, 20 May 2023 19:56:57 +0400 Subject: [PATCH 2/5] data process t5 --- fastchat/train/train_flant5.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fastchat/train/train_flant5.py b/fastchat/train/train_flant5.py index 9829c204c..3fe266858 100755 --- a/fastchat/train/train_flant5.py +++ b/fastchat/train/train_flant5.py @@ -284,10 +284,9 @@ def __init__( if not os.path.exists("preprocessed_data"): os.mkdir("preprocessed_data") if os.path.exists(self.preprocessed_path): - print(f"loading from preprocessed data at {self.preprocessed_path}") + logging.warning(f"loading from preprocessed data at {self.preprocessed_path}") with open(self.preprocessed_path, "r") as f: data_dict = json.load(f) - print(len(data_dict["input_ids"])) if dist.get_rank() == 0: dist.barrier() else: @@ -373,7 +372,6 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) - torch.set_printoptions(profile="full") return ret @@ -416,7 +414,7 @@ def train(): smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), - other_tokens= ["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], + other_tokens = ["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], tokenizer=tokenizer, model=model, ) From cacbc64cb44189a46ede009deeffa08f5458b3ff Mon Sep 17 00:00:00 2001 From: Dacheng Li Date: Sun, 21 May 2023 22:04:07 +0400 Subject: [PATCH 3/5] dist data processing --- fastchat/train/train_flant5.py | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/fastchat/train/train_flant5.py b/fastchat/train/train_flant5.py index 3fe266858..657402517 100755 --- a/fastchat/train/train_flant5.py +++ b/fastchat/train/train_flant5.py @@ -24,7 +24,6 @@ from typing import Dict, Optional, Sequence import torch -import torch.distributed as dist import transformers from torch.utils.data import Dataset @@ -276,22 +275,15 @@ def __init__( ): super(SupervisedDataset, self).__init__() - # Make sure only the first process is processing the dataset - if dist.get_rank() != 0: - dist.barrier() - + # save to file self.preprocessed_path = preprocessed_path - if not os.path.exists("preprocessed_data"): - os.mkdir("preprocessed_data") if os.path.exists(self.preprocessed_path): - logging.warning(f"loading from preprocessed data at {self.preprocessed_path}") - with open(self.preprocessed_path, "r") as f: - data_dict = json.load(f) - if dist.get_rank() == 0: - dist.barrier() + print("loading from preprocessed data") + data_dict = json.load(open(self.preprocessed_path, "r")) else: - assert dist.get_rank() == 0, "Only the first process should process" - logging.warning("Loading raw data...") + if not os.path.exists("./preprocessed_data/"): + os.mkdir("preprocessed_data/") + logging.warning("Loading data...") list_data_dict = json.load(open(data_path, "r")) logging.warning("Formatting inputs...") @@ -302,12 +294,11 @@ def __init__( data_dict = preprocess(sources, tokenizer) json_data_dict = json.dumps(data_dict) - # open file for writing, "w" - with open(self.preprocessed_path,"w") as f: - f.write(json_data_dict) - - # Release barrier - dist.barrier() + # open file for writing, "w" + f = open(self.preprocessed_path, "w") + + # write json object to file + f.write(json_data_dict) if num_data != -1: data_dict["input_ids"] = data_dict["input_ids"][:num_data] @@ -372,6 +363,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) + torch.set_printoptions(profile="full") return ret @@ -414,7 +406,7 @@ def train(): smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), - other_tokens = ["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], + other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], tokenizer=tokenizer, model=model, ) From 387528dd9903b2c5ddc114b6577e9cd38e3343a6 Mon Sep 17 00:00:00 2001 From: Dacheng Li Date: Sun, 21 May 2023 22:08:20 +0400 Subject: [PATCH 4/5] Update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1784f4fb3..bf0ab0284 100644 --- a/README.md +++ b/README.md @@ -297,7 +297,7 @@ You can use the following command to train FastChat-T5 with 4 x A100 (40GB). ```bash torchrun --nproc_per_node=4 --master_port=9778 fastchat/train/train_flant5.py \ --model_name_or_path google/flan-t5-xl \ - --data_path data_path \ + --data_path playground/data/dummy.json \ --bf16 True \ --output_dir ./checkpoints_flant5_3b \ --num_train_epochs 3 \ From ec43174f894e2f37b073a1e4a5b4ac2b160b36b2 Mon Sep 17 00:00:00 2001 From: Dacheng Li Date: Sun, 21 May 2023 22:22:47 +0400 Subject: [PATCH 5/5] tested load --- fastchat/train/train_flant5.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/fastchat/train/train_flant5.py b/fastchat/train/train_flant5.py index 657402517..2dc4b4860 100755 --- a/fastchat/train/train_flant5.py +++ b/fastchat/train/train_flant5.py @@ -24,6 +24,7 @@ from typing import Dict, Optional, Sequence import torch +import torch.distributed as dist import transformers from torch.utils.data import Dataset @@ -76,7 +77,6 @@ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: st state_dict = trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} - # potential bug for T5 model del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa @@ -276,13 +276,20 @@ def __init__( super(SupervisedDataset, self).__init__() # save to file + # Make sure only the first process is processing the dataset + if dist.get_rank() != 0: + dist.barrier() self.preprocessed_path = preprocessed_path if os.path.exists(self.preprocessed_path): - print("loading from preprocessed data") - data_dict = json.load(open(self.preprocessed_path, "r")) + logging.warning("loading from preprocessed data") + with open(self.preprocessed_path, "r") as f: + data_dict = json.load(f) + if dist.get_rank() == 0: + dist.barrier() else: - if not os.path.exists("./preprocessed_data/"): - os.mkdir("preprocessed_data/") + if not os.path.exists("preprocessed_data"): + os.mkdir("preprocessed_data") + assert dist.get_rank() == 0, "Only the first process should process" logging.warning("Loading data...") list_data_dict = json.load(open(data_path, "r")) @@ -294,11 +301,12 @@ def __init__( data_dict = preprocess(sources, tokenizer) json_data_dict = json.dumps(data_dict) - # open file for writing, "w" - f = open(self.preprocessed_path, "w") - - # write json object to file - f.write(json_data_dict) + # Remember to close file to avoid concurrent r/w + with open(self.preprocessed_path,"w") as f: + f.write(json_data_dict) + + # Release barrier + dist.barrier() if num_data != -1: data_dict["input_ids"] = data_dict["input_ids"][:num_data]