🔁 🦈 Support iterative GRPO #2700

Merged

15 commits merged on Feb 4, 2025
32 changes: 32 additions & 0 deletions tests/test_grpo_trainer.py
@@ -423,3 +423,35 @@ def test_training_torch_compile(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_training_with_sync_ref_model(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
sync_ref_model=True,
ref_model_sync_steps=2, # reduce sync steps to ensure a sync happens
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
4 changes: 4 additions & 0 deletions trl/trainer/callbacks.py
@@ -94,6 +94,10 @@ def _generate_completions(


class SyncRefModelCallback(TrainerCallback):
"""
Callback to synchronize the model with a reference model.
"""

def __init__(
self,
ref_model: Union[PreTrainedModel, torch.nn.Module],
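The callback's body is collapsed in the diff, so the following is only a rough sketch (not TRL's actual implementation) of how a TR-DPO-style sync could be hooked into the trainer: every `ref_model_sync_steps` optimizer steps, the reference weights are blended with the current policy using `ref_model_mixup_alpha`. The class name is hypothetical; the two `GRPOConfig` fields it reads are introduced in the next file of this PR.

```python
import torch
from transformers import TrainerCallback


class SyncRefModelSketch(TrainerCallback):
    """Hypothetical stand-in for SyncRefModelCallback: periodically refresh the reference model."""

    def __init__(self, ref_model, accelerator=None):
        self.ref_model = ref_model
        self.accelerator = accelerator  # kept for parity with the real callback's signature

    def on_step_end(self, args, state, control, model=None, **kwargs):
        # Sync every `ref_model_sync_steps` optimizer steps (τ in the TR-DPO paper).
        if self.ref_model is None or state.global_step % args.ref_model_sync_steps != 0:
            return
        alpha = args.ref_model_mixup_alpha  # α in π_ref = α·π_θ + (1 − α)·π_ref_prev
        with torch.no_grad():
            for param, ref_param in zip(model.parameters(), self.ref_model.parameters()):
                # In-place convex combination of current policy and previous reference weights.
                ref_param.data.mul_(1.0 - alpha).add_(param.data, alpha=alpha)
```

The real callback also receives the `accelerator` (as seen in the trainer hunk below), presumably to unwrap distributed wrappers before copying weights.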
35 changes: 35 additions & 0 deletions trl/trainer/grpo_config.py
@@ -81,6 +81,19 @@ class GRPOConfig(TrainingArguments):
Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient.
sync_ref_model (`bool`, *optional*, defaults to `False`):
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
        the `ref_model_mixup_alpha` parameter. This synchronization originates from the
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
between the current policy and the previous reference policy during updates. The reference policy is
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
must set `sync_ref_model=True`.
ref_model_sync_steps (`int`, *optional*, defaults to `64`):
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
set `sync_ref_model=True`.
"""

# Parameters that control the model and reference model
@@ -184,3 +197,25 @@ class GRPOConfig(TrainingArguments):
default=0.04,
metadata={"help": "KL coefficient."},
)
sync_ref_model: bool = field(
default=False,
metadata={
"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
},
)
ref_model_mixup_alpha: float = field(
default=0.9,
metadata={
"help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
"previous reference policy during updates. The reference policy is updated according to the equation: "
"`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
},
)
ref_model_sync_steps: int = field(
default=64,
metadata={
"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
},
)
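For reference, a minimal sketch of opting into the new behavior from user code; the output directory is a placeholder and all other hyperparameters are left at their documented defaults:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="my-grpo-run",    # placeholder path
    sync_ref_model=True,         # required for the two fields below to take effect
    ref_model_mixup_alpha=0.9,   # α: weight given to the current policy in the mix
    ref_model_sync_steps=64,     # τ: re-sync the reference model every 64 steps
)
```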
4 changes: 4 additions & 0 deletions trl/trainer/grpo_trainer.py
@@ -44,6 +44,7 @@
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
from .grpo_config import GRPOConfig
from .utils import generate_model_card, get_comet_experiment_url, pad

@@ -357,6 +358,9 @@ def data_collator(features): # No data collation is needed in GRPO
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

if args.sync_ref_model:
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
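Because the sync is implemented as a standard `TrainerCallback`, the line above is roughly equivalent to attaching the callback by hand after constructing the trainer; a hedged sketch, assuming a `trainer` built as in the new test:

```python
from trl.trainer.callbacks import SyncRefModelCallback

# Manual wiring; the trainer already does this internally when sync_ref_model=True.
trainer.add_callback(SyncRefModelCallback(ref_model=trainer.ref_model, accelerator=trainer.accelerator))
```

Keeping the callback behind `sync_ref_model` (default `False`) means the existing GRPO behavior with a frozen reference model is unchanged unless the user opts in.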