Support iterative GRPO #2684
Comments
What about using a callback that the trainer would internally add?
@qgallouedec Do you mean
Not really. Like an arg in the config:

```python
if args.sync_ref_steps is not None:
    sync_ref_callback = SyncRefCallback(args.sync_ref_steps)
    self.add_callback(sync_ref_callback)
```
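For concreteness, a minimal sketch of what such a callback could look like. The name `SyncRefCallback` comes from the snippet above; the constructor signature (taking the reference model explicitly) and the hard-copy behaviour are assumptions for illustration, not an existing trl API:

```python
import torch
from transformers import TrainerCallback


class SyncRefCallback(TrainerCallback):
    """Hypothetical callback: every `sync_ref_steps` optimizer steps, copy the
    current policy weights into the frozen reference model (a hard reset)."""

    def __init__(self, ref_model, sync_ref_steps):
        self.ref_model = ref_model
        self.sync_ref_steps = sync_ref_steps

    def on_step_end(self, args, state, control, model=None, **kwargs):
        # `model` is the policy being trained; the Trainer passes it to callbacks.
        if state.global_step > 0 and state.global_step % self.sync_ref_steps == 0:
            with torch.no_grad():
                self.ref_model.load_state_dict(model.state_dict())
```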
Note that we already have such a callback in trl; I think it makes sense to reuse it.
Nice! I guess you mean `SyncRefModelCallback` (line 96 in 801582e).
With the callbacks currently existing in trl, it would be something like:

```python
if args.sync_ref_model:
    self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
```
Yep
I'll take care of it!
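From the user side, the approach discussed above would amount to flipping a config flag. A rough usage sketch, assuming the GRPO config exposes the same `sync_ref_model` / `ref_model_sync_steps` / `ref_model_mixup_alpha` arguments that `SyncRefModelCallback` uses elsewhere in trl (exact names and defaults depend on the eventual PR); the dataset, reward function, and model choice are toy placeholders:

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Toy prompt dataset and reward function, just to make the example self-contained.
dataset = Dataset.from_dict({"prompt": ["Write a haiku about the sea."] * 64})

def reward_len(completions, **kwargs):
    # Dummy reward: prefer completions close to 50 characters.
    return [-abs(50 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="grpo-iterative",
    # Assumed argument names, mirroring SyncRefModelCallback's options:
    sync_ref_model=True,        # enable periodic reference-model updates
    ref_model_sync_steps=64,    # how often to sync, in optimizer steps
    ref_model_mixup_alpha=0.6,  # soft update: ref = alpha * policy + (1 - alpha) * ref
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```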
Feature request
Hi,
The GRPO paper also describes an iterative version of GRPO that allows for periodic updates of the reference model (see Algorithm 1). This has shown good performance for cold-start models (see Figure 6). For DeepSeek-R1-Zero, since it is not SFT'ed, they very likely used this iterative version of GRPO.
Currently, in the GRPO trainer, the reference model cannot be updated: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L200-L209
It would be nice to support periodically updating the reference model in the trainer, e.g. after each epoch or after a certain number of steps.
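To make the requested behaviour concrete, here is a toy, framework-agnostic sketch of the core operation from Algorithm 1 — periodically resetting the reference policy to the current policy — using small `torch.nn` stand-ins rather than the actual trainer; the interval and the placeholder loss are made up for illustration:

```python
import copy

import torch
import torch.nn as nn

# Stand-ins for the policy and its frozen reference copy; in the real trainer
# these would be the causal LM and the reference model created in the lines linked above.
policy = nn.Linear(4, 4)
ref_model = copy.deepcopy(policy).requires_grad_(False)

optimizer = torch.optim.SGD(policy.parameters(), lr=0.1)
ref_update_steps = 100  # made-up interval ("after each epoch or certain steps")

for step in range(1, 501):
    # In GRPO this is where the policy loss (with a KL term against `ref_model`)
    # would be computed; a placeholder loss keeps the sketch runnable.
    loss = policy(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Iterative GRPO (Algorithm 1): periodically reset the reference policy
    # to the current policy.
    if step % ref_update_steps == 0:
        ref_model.load_state_dict(policy.state_dict())
```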
Motivation
Might be important for reproducing DeepSeek-R1-Zero and DeepSeek-R1
Your contribution
Ideas: the reference model could be updated inside the `compute_loss` function, but that does not seem to be an elegant way to do it.