
Support iterative GRPO #2684

Open
howardzhou opened this issue Jan 29, 2025 · 8 comments

@howardzhou

Feature request

Hi,

The GRPO paper also mentions an iterative version of GRPO which allows for periodic updates of the reference model (see Algorithm 1). This has shown good performance for cold-start models (see Figure 6). For DeepSeek-R1-Zero, since it is not SFT'ed, they very likely used this iterative version of GRPO.

Currently, in the GRPO trainer, the reference model cannot be updated: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L200-L209

        # Reference model
        if is_deepspeed_zero3_enabled():
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        elif peft_config is None:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None

It would be nice to support periodically updating the reference model in the trainer, e.g. after each epoch or after a certain number of steps.
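
For context, Algorithm 1 essentially amounts to periodically doing a hard refresh of the reference model from the current policy. A minimal sketch of that step, assuming the plain case (no PEFT, no ZeRO-3 sharding) where both models have identically shaped parameters; hard_sync_ref_model is an illustrative name, not an existing TRL function:

import torch

@torch.no_grad()
def hard_sync_ref_model(ref_model, model):
    # Copy the current policy weights into the frozen reference model.
    for ref_param, param in zip(ref_model.parameters(), model.parameters()):
        ref_param.data.copy_(param.data)

Calling something like this every N optimizer steps (or once per epoch) would give the periodic-update behaviour described above.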

Motivation

This might be important for reproducing DeepSeek-R1-Zero and DeepSeek-R1.

Your contribution

Ideas:

  • one option might be to change the compute_loss function, but that does not seem to be an elegant way.
  • another might be to override the train function in the base trainer and add that option.
@github-actions github-actions bot added ✨ enhancement New feature or request 🏋 GRPO Related to GRPO ⚡ PEFT Related to PEFT labels Jan 29, 2025
@qgallouedec
Member

What about using a callback that the trainer would add internally?

@shirinyamani

@qgallouedec Do you mean callbacks: Optional[list[TrainerCallback]]?

@qgallouedec
Member

Not really. More like an arg in GRPOConfig, let's say sync_ref_steps, and having something like this in GRPOTrainer.__init__:

if args.sync_ref_steps is not None:
    sync_ref_callback = SyncRefCallback(args.sync_ref_steps)
    self.add_callback(sync_ref_callback)
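
For illustration, a rough sketch of what such a callback could look like (this is not existing TRL code; SyncRefCallback and sync_ref_steps are just the placeholder names from the snippet above, and a PEFT or sharded setup would need extra handling):

import torch
from transformers import TrainerCallback

class SyncRefCallback(TrainerCallback):
    # Illustrative sketch: hard-copies the policy weights into the reference
    # model every `sync_ref_steps` optimizer steps.
    def __init__(self, sync_ref_steps, ref_model):
        self.sync_ref_steps = sync_ref_steps
        self.ref_model = ref_model

    def on_step_end(self, args, state, control, model=None, **kwargs):
        if self.ref_model is None or state.global_step % self.sync_ref_steps != 0:
            return
        with torch.no_grad():
            for ref_param, param in zip(self.ref_model.parameters(), model.parameters()):
                ref_param.data.copy_(param.data)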

@qgallouedec
Member

Note that we already have such a callback in TRL; I think it makes sense to reuse it.

@howardzhou
Author

Nice! I guess you mean SyncRefModelCallback:

class SyncRefModelCallback(TrainerCallback):
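
Roughly, that callback performs a mixup-style (soft) update of the reference model toward the current policy at a fixed step interval; the associated config fields used by other TRL trainers (e.g. DPO) are ref_model_sync_steps and ref_model_mixup_alpha, which should be double-checked against the installed version. A simplified paraphrase of the update rule, not the exact implementation:

import torch

@torch.no_grad()
def soft_sync(ref_model, model, alpha):
    # Blend the reference weights toward the current policy weights
    # (alpha=1.0 reduces to a hard copy).
    for ref_param, param in zip(ref_model.parameters(), model.parameters()):
        ref_param.data.mul_(1.0 - alpha).add_(param.data, alpha=alpha)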

@shirinyamani

With the callbacks currently existing in TRL, it would be something like:

if args.sync_ref_model:
    self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
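
If it is wired up the same way as in the DPO trainer, usage could end up looking something like the config below; the field names sync_ref_model, ref_model_sync_steps, and ref_model_mixup_alpha are assumptions borrowed from the existing DPOConfig options, not settled GRPO API:

from trl import GRPOConfig

# Hypothetical GRPOConfig usage once the option is wired in; the three
# sync/ref_model fields below are assumed names mirroring DPOConfig.
training_args = GRPOConfig(
    output_dir="grpo-iterative",
    sync_ref_model=True,        # enable periodic reference-model updates
    ref_model_sync_steps=64,    # sync every 64 optimizer steps
    ref_model_mixup_alpha=0.9,  # soft-update coefficient (1.0 = hard copy)
)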

@qgallouedec
Member

Yep

@shirinyamani

I'll take care of it!
