From 4dbbe7dda55c0bb9b5ab55dd8144376986caeaa7 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 23 May 2024 16:10:39 +0200 Subject: [PATCH 1/4] Update modeling_base.py --- trl/models/modeling_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py index f7894ddedb..a3a80e461a 100644 --- a/trl/models/modeling_base.py +++ b/trl/models/modeling_base.py @@ -100,6 +100,9 @@ def __init__( if hasattr(pretrained_model, "gradient_checkpointing_enable"): self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable + if hasattr(pretrained_model, "enable_input_require_grads"): + self.enable_input_require_grads = pretrained_model.enable_input_require_grads + self.supports_rm_adapter = supports_rm_adapter self.rm_adapter_name = rm_adapter_name self.policy_adapter_name = "default" From 1a9110b0ae547eed3e13c1ffa7d6b861af2c58f5 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 23 May 2024 16:18:15 +0200 Subject: [PATCH 2/4] Update ppo_config.py --- trl/trainer/ppo_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index b70610d950..34998ba256 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -124,7 +124,9 @@ class PPOConfig: """Score clipping""" whiten_rewards: bool = False """Whiten the rewards before compute advantages""" - + gradient_checkpointing: bool = False + """Enable gradient checkpointing""" + # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None """TO BE FILLED In RUNTIME: Whether the model is an encoder-decoder model""" From cc967eccbf7816190b150609914170a97f888327 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 23 May 2024 16:20:03 +0200 Subject: [PATCH 3/4] Update ppo_trainer.py --- trl/trainer/ppo_trainer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 0a60dcaffb..e961383463 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -319,6 +319,18 @@ def __init__( self.accelerator.state, "deepspeed_plugin" ) + if config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + if hasattr(self.model, "enable_input_require_grads"): + self.model.enable_input_require_grads() + else: + # For backward compatibility with older versions of transformers + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + self.model.pretrained_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + ( self.model, self.optimizer, From a106baa40843cb35da4a9b3ed1c3d2ee7ee13bc6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 23 May 2024 16:21:34 +0200 Subject: [PATCH 4/4] style --- trl/trainer/ppo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index 34998ba256..786d9d0d38 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -126,7 +126,7 @@ class PPOConfig: """Whiten the rewards before compute advantages""" gradient_checkpointing: bool = False """Enable gradient checkpointing""" - + # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None """TO BE FILLED In RUNTIME: Whether the model is an encoder-decoder model"""