From 92b9322db78306e614928cb42f8a80aab4b485ae Mon Sep 17 00:00:00 2001
From: Jonery
Date: Wed, 19 Jun 2024 12:29:40 +0800
Subject: [PATCH] Cleaner integration.

---
 src/llamafactory/hparams/parser.py      |  7 ++-----
 src/llamafactory/train/dpo/trainer.py   | 12 +++---------
 src/llamafactory/train/kto/trainer.py   | 12 +++---------
 src/llamafactory/train/ppo/trainer.py   | 12 +++---------
 src/llamafactory/train/pt/trainer.py    | 12 +++---------
 src/llamafactory/train/rm/trainer.py    | 12 +++---------
 src/llamafactory/train/sft/trainer.py   | 11 +++--------
 src/llamafactory/train/trainer_utils.py | 10 ++++------
 8 files changed, 24 insertions(+), 64 deletions(-)

diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 8adab01c0c..f2ccd5e655 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -218,11 +218,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         if finetuning_args.badam_mode == "ratio":
             raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
-        if (finetuning_args.badam_mode == "layer"
-            and training_args.deepspeed_plugin is not None
-            and training_args.deepspeed_plugin.zero_stage < 3
-        ):
-            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage, got stage {training_args.deepspeed_plugin.zero_stage}")
+        if finetuning_args.badam_mode == "layer" and (not is_deepspeed_zero3_enabled()):
+            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")
 
     if (finetuning_args.use_galore) and training_args.deepspeed is not None:
         raise ValueError("GaLore are incompatible with DeepSpeed yet.")
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 284bf41af9..a3e0e9610b 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -96,15 +96,9 @@ def __init__(
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index d8b609e0d1..0d50987fb1 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -91,15 +91,9 @@ def __init__(
                 self.ref_model.eval()
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 261ef757e5..2d5d7ffc4f 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -166,15 +166,9 @@ def __init__(
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index 1e5e9f6a3f..d3516b4185 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -48,15 +48,9 @@ def __init__(
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 5d0e626398..433251cf22 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -72,15 +72,9 @@ def __init__(
         self.processor = processor
        self.can_return_loss = True  # override property to return eval_loss
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 9446d245a8..45799b9651 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -56,14 +56,9 @@ def __init__(
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 5d1a486361..0206dcb691 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -372,11 +372,8 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
-    ds_zero3_enabled = False
-    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
-        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
-        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
-        ds_zero3_enabled = True
+    from transformers.integrations import is_deepspeed_zero3_enabled
+    ds_zero3_enabled = is_deepspeed_zero3_enabled()
 
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
@@ -401,6 +398,7 @@
     elif finetuning_args.badam_mode == "ratio":
         from badam import BlockOptimizerRatio
 
+        assert not ds_zero3_enabled, "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
             param_groups=param_groups,
@@ -412,7 +410,7 @@
             **optim_kwargs,
         )
         logger.info(
-            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
             f"mask mode is {finetuning_args.badam_mask_mode}"
         )
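
Note (outside the patch): every trainer's __init__ now installs the same BAdam wiring, and the ZeRO-3 decision moves from inspecting training_args.deepspeed_plugin to transformers' is_deepspeed_zero3_enabled(). The sketch below restates that shared pattern in one place; it is not code from the repository. The helper name _setup_badam is hypothetical, while clip_grad_norm_old_version, BAdamCallback, and is_deepspeed_zero3_enabled are taken from the added lines above.

    # Minimal sketch of the wiring this patch installs in each trainer (assumptions noted above).
    from types import MethodType

    from transformers.integrations import is_deepspeed_zero3_enabled


    def _setup_badam(trainer, finetuning_args) -> None:
        """Hypothetical helper mirroring the repeated block in each trainer's __init__."""
        if not finetuning_args.use_badam:
            return

        # Names as imported in the diff's added lines.
        from badam import BAdamCallback, clip_grad_norm_old_version

        # Swap Accelerate's gradient clipping for BAdam's variant and register the
        # unified callback; the diff drops the conditional BAdamZeRO3Callback, so
        # BAdamCallback is presumably expected to cover the ZeRO-3 case as well.
        trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
        trainer.callback_handler.add_callback(BAdamCallback)

        # Mirrors the checks in parser.py and trainer_utils.py: under DeepSpeed ZeRO-3,
        # only the layer-wise update mode is supported.
        if is_deepspeed_zero3_enabled() and finetuning_args.badam_mode == "ratio":
            raise ValueError("Ratio-based BAdam does not support DeepSpeed ZeRO-3; use --badam_mode layer.")

Relying on is_deepspeed_zero3_enabled() keeps the trainers free of deepspeed_plugin introspection, which is what allows the per-trainer blocks to shrink from nine lines to three.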