Skip to content

Commit

Permalink
[RL] Fix PPO and add GRPO (#9925)
Browse files Browse the repository at this point in the history
* fix ppo and grpo v1

* update grpo

* delete notes and modify argument (#10)

* [RL] Fix PPO and add GRPO  (#11)

* delete notes and modify argument

* delete ppo_config.json

* modify format

* lint

* fix model config set

* fix grpo (#12)

* [New Features] support json file data (#13)

* delete notes and modify argument

* delete ppo_config.json

* modify format

* support json data

* modify argument

* fix

* fix ci

* fix

* fix datapath (#14)

* delete notes and modify argument

* delete ppo_config.json

* modify format

* support json data

* modify argument

* fix data

---------

Co-authored-by: greycooker <[email protected]>
Co-authored-by: gongel <[email protected]>
  • Loading branch information
3 people authored Feb 27, 2025
1 parent 85e1238 commit 40c3530
Show file tree
Hide file tree
Showing 38 changed files with 5,017 additions and 1,164 deletions.
1 change: 1 addition & 0 deletions docs/llm/alignment/ppo/README.md
87 changes: 87 additions & 0 deletions llm/alignment/ppo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# PPO

PPO(Proximal Policy Optimization,近端策略优化)是一种强化学习算法,旨在通过优化策略来最大化累积奖励。PPO 算法结合了 Policy Gradient 和‌TRPO 的优点,通过使用随机梯度上升优化一个“替代”目标函数,实现小批量更新,而不是每个数据样本只进行一次梯度更新。

以下是详细的使用文档和示例:

## 环境依赖

* 训练环境:在 python3.9的环境下安装, 可以使用如下脚本安装
```bash
bash -x scripts/install_train_env.sh gpu
```

## 数据协议

数据格式以`data/rlhf_train_data_test.jsonl`为例。

### 字段说明

- src (list(str)): 用户对话内容,可能会包含 markup 内容,如 [<search-res>]
- tgt (list(str)): 除了最后一轮的系统多轮回复内容,以对话轮次排列,可能会包含 markup 内容,如 [<search>];注意:len(tgt)==len(src)-1

### 数据示例

```json
{
"src": [
"需要你帮我写几个有创意的广告语来打开市场。",
"目标用户是年轻人,追求时尚、个性和自我。"
],
"tgt": [
"当然!我很乐意帮助你创作几个有创意的广告语来推广你的新洗发露。请告诉我一些关于你的产品的特点,目标受众以及你希望传达的核心信息,我会根据这些信息为你提供几个创意的广告语。"
]
}
```

## 训练

```shell
bash scripts/ppo.sh
```

其中参数释义如下:

- `train_task_config`: 训练数据 config, 请以`config/task_ppo.json`为例
- `eval_task_config`: 评估数据 config, 请以`config/task_ppo.json`为例
- `ptx_task_config`: SFT 辅助数据, 请以`config/task_sft.json`为例,默认为""
- `actor_model_name_or_path`: PPO 中 actor-model 和 reference-model 模型本地的模型路径
- `reward_model_name_or_path`: PPO 中 reward-model 和 critic-model 模型本地的模型路径
- `use_fusemt`: 是否通过 FustMT 加速生成,默认为 True
- `use_flash_attention`: 是否启用 FlashAttention-2,默认为 False
- `output_dir`: 模型参数保存目录
- `max_seq_len`: 输入数据的最大长度,默认为 4096
- `max_dec_len`: 最大生成长度
- `min_dec_len`: 最小生成长度
- `top_p`: 生成解码超参数
- `temperature`: 生成解码超参数
- `repetition_penalty`: 生成解码超参数
- `num_return_sequences`: 生成解码超参数
- `min_learning_rate`: Actor 模型的最小学习率
- `critic_learning_rate`: Critic 模型的最小学习率
- `recompute`: Actor 模型是否使用重计算策略,开启后可节省训练显存
- `critic_recompute`: Critic 模型是否使用重计算策略,开启后可节省训练显存
- `recompute_granularity` Actor 模型的重计算的粒度,可选项为`core_attn``full`. `core_attn`速度快但是显存占用,`full`速度慢但是显存占用低
- `critic_recompute_granularity` Critic 模型重计算的粒度,可选项为`core_attn``full`. `core_attn`速度快但是显存占用,`full`速度慢但是显存占用低
- `warmup_ratio`: Actor 模型用于从 0 到 `learning_rate` 的线性 warmup 的总训练步骤的比例
- `critic_warmup_ratio`: Critic 模型用于从 0 到 `critic_learning_rate` 的线性 warmup 的总训练步骤的比例
- `lr_scheduler_type`: Actor 模型要使用的学习率调度策略。 (`str`, 可选, 默认为 `"linear"`)
- `critic_lr_scheduler_type`: Critic 模型要使用的学习率调度策略。 (`str`, 可选, 默认为 `"linear"`)
- `weight_decay`: Actor 模型除了所有 bias 和 LayerNorm 权重之外,应用于所有层的权重衰减数值。(`float`,可选,默认为 0.0)
- `critic_weight_decay`: Critic 模型除了所有 bias 和 LayerNorm 权重之外,应用于所有层的权重衰减数值。(`float`,可选,默认为 0.0)
- `max_prompt_len`: 生成样本时的最大生成长度, max_length 调大会增加生成时间,并且增加显存占用。注意:
max_dec_len + max_prompt_len 应当小于 max_seq_len。
- `per_device_prompt_batch_size`: PPO 生成样本时的批处理大小,同 micro batch size,即满足 global_batch_size = dp(data parallel)* sharding * micro batch size。batch_size 调大会增加生成时间,并且增加显存占用
- `per_device_train_batch_size`: 训练 batch 大小, 当前为了优化性能设为1,请避免更改
- `per_device_eval_batch_size`: 评估 batch 大小。
- `max_steps`: 总的训练步数
- `eval_steps`: 模型评估的间隔步数
- `max_evaluate_steps`: 模型单次评估的最大步数
- `logging_steps`: 训练日志打印的间隔步数
- `save_steps`: 模型参数保存的间隔步数
- `weight_decay`: 权重衰减数值
- `do_train`: 是否进行训练任务
- `do_eval`: 是否进行评估任务
- `fp16`: 使用 float16 精度进行模型训练和推理。
- `bf16`: 使用 bfloat16 精度进行模型训练和推理。
- `fp16_opt_level`: float16 精度训练模式,`O2`表示纯 float16 训练
37 changes: 37 additions & 0 deletions llm/alignment/ppo/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import requests

CHAT_URL = "http://127.0.0.1:8731"

data = {
"src": [
"Natalia sold clips to 48 of her friends in April, ",
"Weng earns $12 an hour for babysitting. Yesterday",
],
"tgt": [
"Natalia sold 48/2 = 24 clips in May. #### 72",
"She earned 0.2 x 50 = $10. #### 10",
],
"response": [
"Natalia sold 48+24 = 72 clips altogether in April and May. #### 72",
"2",
],
}
res = requests.post(CHAT_URL, json=data)
result = json.loads(res.text)
print("result:", result, result["score"])
Loading

0 comments on commit 40c3530

Please sign in to comment.