
πŸ‘¨β€πŸ‘¨β€πŸ‘§β€πŸ‘§ GRPO #2565

Draft: qgallouedec wants to merge 51 commits into main
Conversation

@qgallouedec (Member) commented Jan 13, 2025

What does this PR do?

This PR adds GRPOTrainer and GRPOConfig, an implementation of GRPO (Group Relative Policy Optimization). Minimal usage example:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification

# Load the models and the tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
ref_model = AutoModelForCausalLM.from_pretrained("trl-lib/Qwen2-0.5B-ORPO")  # different from the policy, so that the KL is not 0
reward_model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=1)
tokenizer.padding_side = "left"  # decoder-only generation requires left padding
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10, gradient_accumulation_steps=8, per_device_train_batch_size=2)
trainer = GRPOTrainer(
    model=model, reward_model=reward_model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)

trainer.train()

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member, Author) commented Jan 15, 2025

From the paper:

$$\mathcal{J}_{\text{GRPO}}(\theta) =\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\min \left(\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} | q, o_{i,< t})} \hat{A}_{i,t}, \text{clip}\left(\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} | q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon\right) \hat{A}_{i,t}\right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]\right].$$

where:

  • $G$ is the number of generations per prompt
  • $o_i$ is the $i$-th generation of the prompt and $|o_i|$ is the number of tokens in $o_i$
  • $q$ is the prompt
  • $\pi_\theta$ is the policy model
  • $\pi_{\theta_{\text{old}}}$ is the policy model before the update
  • $\pi_{\text{ref}}$ is the reference policy
  • $\hat{A}_{i,t}$ is the advantage estimate for the $t$-th token in the $i$-th generation (see below)
  • $\epsilon$ and $\beta$ are hyperparameters
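
For intuition, here is a minimal PyTorch sketch of this per-token objective (padding, batching and generation are omitted, and all names are hypothetical rather than this PR's actual API):

```python
import torch

def grpo_objective(logp, logp_old, logp_ref, advantages, eps=0.2, beta=0.04):
    """Sketch of the per-token GRPO objective J to maximize. eps and beta values are illustrative.

    logp, logp_old, logp_ref: (G, T) per-token log-probs under pi_theta, pi_theta_old, pi_ref.
    advantages: (G, 1), broadcast over tokens since A_hat_{i,t} does not depend on t.
    """
    ratio = torch.exp(logp - logp_old)                      # pi_theta / pi_theta_old
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
    surrogate = torch.min(ratio * advantages, clipped * advantages)
    # Unbiased per-token estimator of KL(pi_theta || pi_ref) used in the paper:
    # pi_ref/pi_theta - log(pi_ref/pi_theta) - 1  (always >= 0)
    kl = torch.exp(logp_ref - logp) - (logp_ref - logp) - 1
    return (surrogate - beta * kl).mean(dim=1).mean()       # mean over tokens, then over the group
```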

In Section 4.2 we can read:

The policy model only has a single update following each exploration stage.

This implies that $\pi_{\theta_{\text{old}}} = \pi_\theta$: the probability ratio equals 1 for every token, so the clipping has no effect and the surrogate term reduces to $\hat{A}_{i,t}$.
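
To make that step explicit (not from the paper, just expanding the min/clip with a ratio of 1):

$$\min\left(1 \cdot \hat{A}_{i,t},\ \text{clip}\left(1, 1 - \epsilon, 1 + \epsilon\right) \hat{A}_{i,t}\right) = \min\left(\hat{A}_{i,t}, \hat{A}_{i,t}\right) = \hat{A}_{i,t}.$$

Consequently,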

$$\mathcal{J}_{\text{GRPO}}(\theta) = \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[\hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]\right]$$

$$\mathcal{J}_{\text{GRPO}}(\theta) = \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \hat{A}_{i,t} - \beta \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]$$

From Section 4.1.2, $\hat{A}_{i,t}$ is defined as

$$\hat{A}_{i,t} = \tilde{r}_i = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})},$$

where $\mathbf{r} = \{r_1, r_2, \ldots, r_G\}$ collects the rewards of the $G$ generations for the same prompt. This implies that $\hat{A}_{i,t}$ doesn't depend on $t$; let's write it as $\tilde{r}_i$ for clarity.
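
In code, this advantage is just a per-group standardization of the rewards; a minimal torch sketch with made-up numbers:

```python
import torch

rewards = torch.tensor([1.0, 0.2, -0.5, 0.8])             # hypothetical r_i, one per generation in the group
advantages = (rewards - rewards.mean()) / rewards.std()   # tilde_r_i, shared by every token of o_i
```

Substituting $\tilde{r}_i$ into the objective: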

$$\mathcal{J}_{\text{GRPO}}(\theta) = \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \tilde{r}_i - \beta \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]$$

$$\mathcal{J}_{\text{GRPO}}(\theta) = \frac{1}{G} \sum_{i=1}^G \tilde{r}_i - \beta \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]$$

We know that $\left\{ \tilde{r}_1, \tilde{r}_2, \ldots, \tilde{r}_G \right\}$ is standardized, so its mean over the group is 0 and the first term vanishes.
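
Written out (just expanding the definition of $\tilde{r}_i$):

$$\frac{1}{G} \sum_{i=1}^G \tilde{r}_i = \frac{1}{G} \sum_{i=1}^G \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})} = \frac{\text{mean}(\mathbf{r}) - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})} = 0.$$

It follows that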

$$\mathcal{J}_{\text{GRPO}}(\theta) = - \beta \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|}\mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]$$

As a result, the GRPO objective seems to just minimize the KL divergence between the policy model and the reference policy. Is that really the intent?

@qgallouedec (Member, Author) commented:
@ZhihongShao if you can help by any chance

@qgallouedec (Member, Author) commented Jan 16, 2025

I have the answer to the question above.

The math is correct: in terms of value, the loss is indeed equal to the KL term.

However, in terms of differentiation, we cannot drop the ratio

$$\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} | q, o_{i,< t})}$$

on the grounds that it equals 1: its value is 1, but its gradient with respect to $\theta$ is not 0.

Let’s denote the stop-gradient operator as $\left[\cdot\right]_\cancel{\nabla}$.

We must retain the term

$$\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} | q, o_{i,< t})}$$

in the equation, rewriting it as

$$\frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} | q, o_{i,< t})} = \frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} | q, o_{i,< t})\right]_\cancel{\nabla}}.$$

Finally, the objective is written as

$$\mathcal{J}_{\text{GRPO}}(\theta) = \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} | q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} | q, o_{i,< t})\right]_\cancel{\nabla}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right].$$

In the end, the value remains equal to the KL divergence, as initially stated. However, implemented this way, the gradient of the first term is $\nabla_\theta \log \pi_\theta(o_{i,t} \mid q, o_{i,< t}) \, \hat{A}_{i,t}$ rather than 0, so the gradient can propagate through the equation and the policy updates effectively.
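
A toy check of this (hypothetical numbers; `exp(logp - logp.detach())` is one way to write $\pi_\theta / \left[\pi_\theta\right]_\cancel{\nabla}$ in PyTorch):

```python
import torch

logp = torch.tensor([-1.2, -0.7], requires_grad=True)   # hypothetical per-token log-probs
ratio = torch.exp(logp - logp.detach())                  # value is exactly 1 for every token
advantages = torch.tensor([0.5, -0.5])                   # hypothetical advantages

loss = -(ratio * advantages).mean()
loss.backward()
print(ratio)      # tensor([1., 1.], grad_fn=...)  -> the value carries no information
print(logp.grad)  # tensor([-0.2500,  0.2500])     -> but the gradient is non-zero
```

The loss value is unchanged, but the backward pass now yields $\nabla_\theta \log \pi_\theta \cdot \hat{A}_{i,t}$, i.e. the usual policy-gradient term.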

Thanks @edbeeching for helping me with this!!!
