
Commit

Cleaner integration.
Ledzy authored and xtchen96 committed Jul 17, 2024
1 parent 3f28a70 commit 313c9a1
Showing 8 changed files with 24 additions and 64 deletions.
7 changes: 2 additions & 5 deletions src/llamafactory/hparams/parser.py
@@ -218,11 +218,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         if finetuning_args.badam_mode == "ratio":
             raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
-        if (finetuning_args.badam_mode == "layer"
-            and training_args.deepspeed_plugin is not None
-            and training_args.deepspeed_plugin.zero_stage < 3
-        ):
-            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage, got stage {training_args.deepspeed_plugin.zero_stage}")
+        if finetuning_args.badam_mode == "layer" and (not is_deepspeed_zero3_enabled()):
+            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")

     if (finetuning_args.use_galore) and training_args.deepspeed is not None:
         raise ValueError("GaLore are incompatible with DeepSpeed yet.")
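The parser-level check now defers ZeRO-3 detection to transformers instead of inspecting training_args.deepspeed_plugin directly. A minimal standalone sketch of the same gate (the validate_badam_distributed name and its argument are illustrative, not part of this commit):

from transformers.integrations import is_deepspeed_zero3_enabled


def validate_badam_distributed(badam_mode: str) -> None:
    """Illustrative stand-alone version of the distributed-training guard."""
    if badam_mode == "ratio":
        # Ratio-based BAdam has no distributed-training support yet.
        raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
    if badam_mode == "layer" and not is_deepspeed_zero3_enabled():
        # is_deepspeed_zero3_enabled() reads the active DeepSpeed configuration,
        # so the check no longer depends on deepspeed_plugin being populated.
        raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")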
12 changes: 3 additions & 9 deletions src/llamafactory/train/dpo/trainer.py
@@ -96,15 +96,9 @@ def __init__(
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))

         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)

     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
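The trainer-side change relies on types.MethodType to swap the gradient-clipping routine on a single Accelerator instance. A toy illustration of that binding trick, using a dummy class rather than accelerate's real Accelerator:

from types import MethodType


class DummyAccelerator:
    """Stand-in for accelerate.Accelerator, just to show per-instance rebinding."""

    def clip_grad_norm_(self, parameters, max_norm):
        return "default clipping"


def badam_aware_clip(self, parameters, max_norm):
    # Plays the role of badam's clip_grad_norm_old_version: a plain function
    # that becomes a bound method of one instance when wrapped in MethodType.
    return "badam-aware clipping"


acc = DummyAccelerator()
acc.clip_grad_norm_ = MethodType(badam_aware_clip, acc)   # affects only `acc`
print(acc.clip_grad_norm_(parameters=[], max_norm=1.0))   # -> badam-aware clipping
print(DummyAccelerator().clip_grad_norm_([], 1.0))        # -> default clipping

Because the replacement is bound per instance, other Accelerator objects (and the class itself) keep their original behavior.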
12 changes: 3 additions & 9 deletions src/llamafactory/train/kto/trainer.py
@@ -91,15 +91,9 @@ def __init__(
                 self.ref_model.eval()

         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)

     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
12 changes: 3 additions & 9 deletions src/llamafactory/train/ppo/trainer.py
@@ -166,15 +166,9 @@ def __init__(
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)

         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)

     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
12 changes: 3 additions & 9 deletions src/llamafactory/train/pt/trainer.py
@@ -48,15 +48,9 @@ def __init__(
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))

         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)

     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
12 changes: 3 additions & 9 deletions src/llamafactory/train/rm/trainer.py
@@ -72,15 +72,9 @@ def __init__(
         self.processor = processor
         self.can_return_loss = True  # override property to return eval_loss
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)

     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
11 changes: 3 additions & 8 deletions src/llamafactory/train/sft/trainer.py
@@ -56,14 +56,9 @@ def __init__(
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))

         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)

     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
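The same three added lines appear in the DPO, KTO, PPO, PT, RM, and SFT trainers above. Purely as a sketch (this helper does not exist in the commit), the shared hook-up could be factored out like so:

from types import MethodType


def _setup_badam(trainer) -> None:
    """Hypothetical helper mirroring the three lines added to each trainer."""
    from badam import BAdamCallback, clip_grad_norm_old_version

    # Bind BAdam's gradient clipping onto this trainer's accelerator instance
    # and register BAdam's callback with the trainer's callback handler.
    trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
    trainer.callback_handler.add_callback(BAdamCallback)

Each trainer's __init__ could then call _setup_badam(self) under its existing `if finetuning_args.use_badam:` branch.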
10 changes: 4 additions & 6 deletions src/llamafactory/train/trainer_utils.py
@@ -372,11 +372,8 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]

-    ds_zero3_enabled = False
-    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
-        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
-        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
-        ds_zero3_enabled = True
+    from transformers.integrations import is_deepspeed_zero3_enabled
+    ds_zero3_enabled = is_deepspeed_zero3_enabled()

     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
@@ -401,6 +398,7 @@
     elif finetuning_args.badam_mode == "ratio":
         from badam import BlockOptimizerRatio

+        assert not ds_zero3_enabled, "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
             param_groups=param_groups,
@@ -412,7 +410,7 @@
             **optim_kwargs,
         )
         logger.info(
-            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
             f"mask mode is {finetuning_args.badam_mask_mode}"
         )

