
This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec).

## Quick start

This example demonstrates how to train a model with the GRPO method. We use the [Qwen2 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the model to train, the [RM-Gemma-2B model](https://huggingface.co/weqweasdas/RM-Gemma-2B) as the reward model, and the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr). You can preview the dataset below:

<iframe
src="https://huggingface.co/datasets/trl-lib/tldr/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>

Below is the script to train the model. We use PEFT to reduce the memory requirements.

```python
# train_grpo.py
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Load the dataset
dataset = load_dataset("trl-lib/tldr", split="train")

# Configure the training run
training_args = GRPOConfig(
output_dir="Qwen2-0.5B-GRPO",
learning_rate=1e-5,
logging_steps=10,
gradient_accumulation_steps=16,
max_completion_length=128,
)
# Set up the trainer: model to train, reward model, dataset, and a LoRA config (PEFT)
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_model="weqweasdas/RM-Gemma-2B",
args=training_args,
train_dataset=dataset,
peft_config=LoraConfig(task_type="CAUSAL_LM"),
)

# Start training
trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_grpo.py
```

Distributed across 8 GPUs, the training takes approximately 1 day.

![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_curves.png)

## Looking deeper into the GRPO method

GRPO is an online learning algorithm, meaning it improves iteratively by using data generated by the model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **Computing the advantage**, **Computing the KL divergence**, and **Computing the loss**.

![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png)

### Generating completions

At each training step, we sample a batch of prompts and generate a set of $G$ completions for each prompt (denoted as $o_i$).
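
As a rough illustration (not the trainer's internal implementation), sampling $G$ completions per prompt could look like the sketch below, using the Transformers `generate` API. The model name, the example prompt, and the value $G = 8$ are assumptions made for the example; `max_new_tokens=128` mirrors `max_completion_length` from the quick start.

```python
# Illustrative sketch only: sampling G completions per prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2-0.5B-Instruct"  # assumed; matches the quick-start example
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

G = 8  # number of completions per prompt (hypothetical value)
prompts = ["Summarize: The quick brown fox jumps over the lazy dog."]

inputs = tokenizer(prompts, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        do_sample=True,          # sampling is needed to obtain G distinct completions
        num_return_sequences=G,  # G completions per prompt
        max_new_tokens=128,
    )

# Drop the prompt tokens and keep only the generated completions o_i
completions = tokenizer.batch_decode(
    outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
```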

### Computing the advantage

For each of the $G$ sequences, we compute the reward using a reward model. To reduce the variance of the gradient estimate, we use the advantage, which is the difference between the reward and a baseline. Ideally, the baseline would be the value function, i.e. the expected reward under the policy, but since we don’t have access to it, we approximate it with the mean reward of the $G$ completions generated for the same prompt. Normalizing this difference by the standard deviation of the group rewards gives the advantage:

$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
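
Note that the right-hand side depends only on $i$, so every token $t$ of completion $i$ receives the same advantage. As a minimal sketch of this computation (the reward values below are made up for illustration):

```python
# Minimal sketch: group-relative advantages for one prompt's G completions.
# The reward values are made up for illustration.
import torch

rewards = torch.tensor([0.2, 1.3, 0.7, -0.5])  # r_i for G = 4 completions of the same prompt
advantages = (rewards - rewards.mean()) / rewards.std()  # (r_i - mean(r)) / std(r)
print(advantages)  # completions scoring above the group mean get a positive advantage
```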

### Computing the KL divergence

...

### Computing the loss

...

## Logged metrics

The GRPO Trainer logs the following metrics:

- `reward`: The mean reward of the completions.
- ...

## GRPOTrainer

[[autodoc]] GRPOTrainer