From 40c3530dff078f6ed7ecb76e3ff102ba28726c8c Mon Sep 17 00:00:00 2001 From: Weiguo Zhu Date: Thu, 27 Feb 2025 21:01:30 +0800 Subject: [PATCH] [RL] Fix PPO and add GRPO (#9925) * fix ppo and grpo v1 * update grpo * delete notes and modify argument (#10) * [RL] Fix PPO and add GRPO (#11) * delete notes and modify argument * delete ppo_config.json * modify format * lint * fix model config set * fix grpo (#12) * [New Features] support json file data (#13) * delete notes and modify argument * delete ppo_config.json * modify format * support json data * modify argument * fix * fix ci * fix * fix datapath (#14) * delete notes and modify argument * delete ppo_config.json * modify format * support json data * modify argument * fix data --------- Co-authored-by: greycooker <94276438+greycooker@users.noreply.github.com> Co-authored-by: gongel --- docs/llm/alignment/ppo/README.md | 1 + llm/alignment/ppo/README.md | 87 + llm/alignment/ppo/client.py | 37 + llm/alignment/ppo/comm_utils.py | 534 ++++- llm/alignment/ppo/data/__init__.py | 1 + llm/alignment/ppo/data/alpaca.py | 2 +- llm/alignment/ppo/data/base.py | 13 +- llm/alignment/ppo/data/jsondata.py | 44 + llm/alignment/ppo/data/prompt_only.py | 30 +- llm/alignment/ppo/data/safe_rlhf.py | 2 +- llm/alignment/ppo/data/supervised.py | 6 +- llm/alignment/ppo/infer_utils.py | 128 +- llm/alignment/ppo/models/infer_model_utils.py | 30 + llm/alignment/ppo/models/pp_model_utils.py | 66 +- llm/alignment/ppo/models/ppo_model.py | 50 +- llm/alignment/ppo/models/ppo_model_utils.py | 574 ++++- llm/alignment/ppo/models/score_model.py | 119 +- llm/alignment/ppo/models/score_model_utils.py | 83 +- llm/alignment/ppo/ppo_trainer.py | 2057 +++++++++++++---- llm/alignment/ppo/reward_server.py | 308 +++ llm/alignment/ppo/run_ppo.py | 727 +++--- llm/alignment/ppo/trainer_utils.py | 744 +++++- llm/alignment/rm/reward_trainer.py | 2 +- llm/alignment/rm/run_reward.py | 19 +- llm/config/llama/grpo_argument.json | 82 + llm/config/llama/ppo_argument.json | 79 +- llm/config/qwen/grpo_argument.json | 82 + llm/docs/rlhf.md | 2 +- llm/predict/predictor.py | 56 +- .../transformers/llama/modeling.py | 8 +- .../transformers/qwen2/modeling.py | 8 +- paddlenlp/transformers/auto/modeling.py | 3 +- paddlenlp/trl/llm_utils.py | 49 +- tests/fixtures/llm/ppo.yaml | 79 + tests/fixtures/llm/ppo_data/dev.jsonl | 5 + tests/fixtures/llm/ppo_data/ptx.jsonl | 5 + tests/fixtures/llm/ppo_data/train.jsonl | 9 + tests/llm/test_ppo.py | 50 + 38 files changed, 5017 insertions(+), 1164 deletions(-) create mode 120000 docs/llm/alignment/ppo/README.md create mode 100644 llm/alignment/ppo/README.md create mode 100644 llm/alignment/ppo/client.py create mode 100644 llm/alignment/ppo/data/jsondata.py create mode 100644 llm/alignment/ppo/reward_server.py create mode 100644 llm/config/llama/grpo_argument.json create mode 100644 llm/config/qwen/grpo_argument.json create mode 100644 tests/fixtures/llm/ppo.yaml create mode 100644 tests/fixtures/llm/ppo_data/dev.jsonl create mode 100644 tests/fixtures/llm/ppo_data/ptx.jsonl create mode 100644 tests/fixtures/llm/ppo_data/train.jsonl create mode 100644 tests/llm/test_ppo.py diff --git a/docs/llm/alignment/ppo/README.md b/docs/llm/alignment/ppo/README.md new file mode 120000 index 000000000000..f2231cb0ee86 --- /dev/null +++ b/docs/llm/alignment/ppo/README.md @@ -0,0 +1 @@ +../../../../llm/alignment/ppo/README.md \ No newline at end of file diff --git a/llm/alignment/ppo/README.md b/llm/alignment/ppo/README.md new file mode 100644 index 000000000000..fc5a6ad4a589 --- 
/dev/null
+++ b/llm/alignment/ppo/README.md
@@ -0,0 +1,86 @@
+# PPO
+
+PPO (Proximal Policy Optimization) is a reinforcement learning algorithm that maximizes cumulative reward by optimizing the policy. It combines the strengths of Policy Gradient and TRPO: a "surrogate" objective is optimized with stochastic gradient ascent using small mini-batch updates, rather than performing only one gradient update per data sample.
+
+The detailed usage documentation and examples are as follows:
+
+## Environment requirements
+
+* Training environment: install under Python 3.9, for example with the following script
+```bash
+bash -x scripts/install_train_env.sh gpu
+```
+
+## Data protocol
+
+The data format follows the example file `data/rlhf_train_data_test.jsonl`.
+
+### Fields
+
+- src (list(str)): the user's turns of the conversation; may contain markup such as [];
+- tgt (list(str)): the system replies for every turn except the last, in turn order; may contain markup such as []. Note: len(tgt) == len(src) - 1 (a validation sketch is given further below).
+
+### Example
+
+```json
+{
+    "src": [
+        "需要你帮我写几个有创意的广告语来打开市场。",
+        "目标用户是年轻人,追求时尚、个性和自我。"
+    ],
+    "tgt": [
+        "当然!我很乐意帮助你创作几个有创意的广告语来推广你的新洗发露。请告诉我一些关于你的产品的特点,目标受众以及你希望传达的核心信息,我会根据这些信息为你提供几个创意的广告语。"
+    ]
+}
+```
+
+## Training
+
+```shell
+bash scripts/ppo.sh
+```
+
+The parameters are described below:
+
+- `train_task_config`: training data config; see `config/task_ppo.json` for an example
+- `eval_task_config`: evaluation data config; see `config/task_ppo.json` for an example
+- `ptx_task_config`: auxiliary SFT data config; see `config/task_sft.json` for an example, defaults to ""
+- `actor_model_name_or_path`: local model path of the actor model and reference model used in PPO
+- `reward_model_name_or_path`: local model path of the reward model and critic model used in PPO
+- `use_fusemt`: whether to accelerate generation with FuseMT, defaults to True
+- `use_flash_attention`: whether to enable FlashAttention-2, defaults to False
+- `output_dir`: directory where model parameters are saved
+- `max_seq_len`: maximum input length, defaults to 4096
+- `max_dec_len`: maximum generation length
+- `min_dec_len`: minimum generation length
+- `top_p`: decoding hyperparameter
+- `temperature`: decoding hyperparameter
+- `repetition_penalty`: decoding hyperparameter
+- `num_return_sequences`: decoding hyperparameter
+- `min_learning_rate`: minimum learning rate of the actor model
+- `critic_learning_rate`: learning rate of the critic model
+- `recompute`: whether the actor model uses recomputation; enabling it saves training memory
+- `critic_recompute`: whether the critic model uses recomputation; enabling it saves training memory
+- `recompute_granularity`: recomputation granularity of the actor model; options are `core_attn` and `full`. `core_attn` is faster but uses more memory, `full` is slower but uses less memory
+- `critic_recompute_granularity`: recomputation granularity of the critic model; options are `core_attn` and `full`. `core_attn` is faster but uses more memory, `full` is slower but uses less memory
+- `warmup_ratio`: fraction of total training steps used for the actor model's linear warmup from 0 to `learning_rate`
+- `critic_warmup_ratio`: fraction of total training steps used for the critic model's linear warmup from 0 to `critic_learning_rate`
+- `lr_scheduler_type`: learning rate scheduler used by the actor model (`str`, optional, defaults to `"linear"`)
+- `critic_lr_scheduler_type`: learning rate scheduler used by the critic model (`str`, optional, defaults to `"linear"`)
+- `weight_decay`: weight decay applied to all actor model layers except bias and LayerNorm weights (`float`, optional, defaults to 0.0)
+- `critic_weight_decay`: weight decay applied to all critic model layers except bias and LayerNorm weights (`float`, optional, defaults to 0.0)
+- `max_prompt_len`: maximum prompt length when generating samples; increasing it increases generation time and memory usage. Note:
+max_dec_len + max_prompt_len should be smaller than max_seq_len.
+- `per_device_prompt_batch_size`: batch size for PPO sample generation, i.e. the micro batch size, so that global_batch_size = dp (data parallel) * sharding * micro batch size; increasing it increases generation time and memory usage
+- `per_device_train_batch_size`: training batch size; currently set to 1 for performance, please avoid changing it
+- `per_device_eval_batch_size`: evaluation batch size
+- `max_steps`: total number of training steps
+- `eval_steps`: number of steps between evaluations
+- `max_evaluate_steps`: maximum number of steps in a single evaluation
+- `logging_steps`: number of steps between training log outputs
+- `save_steps`: number of steps between checkpoint saves
+- `do_train`: whether to run training
+- `do_eval`: whether to run evaluation
+- `fp16`: train and infer with float16 precision.
+- `bf16`: train and infer with bfloat16 precision.
+- `fp16_opt_level`: float16 training mode; `O2` means pure float16 training
diff --git a/llm/alignment/ppo/client.py b/llm/alignment/ppo/client.py
new file mode 100644
index 000000000000..fbb1a285aa3b
--- /dev/null
+++ b/llm/alignment/ppo/client.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
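The `src`/`tgt` protocol above (len(tgt) == len(src) - 1) is easy to violate when converting existing dialogue logs into training data. Below is a minimal validation sketch, not part of this patch; the helper name is made up and the file path is simply the example file named in the README.

```python
import json

def check_rlhf_record(line: str) -> dict:
    """Check one JSONL record against the src/tgt protocol described in the README."""
    record = json.loads(line)
    src, tgt = record["src"], record["tgt"]
    assert isinstance(src, list) and all(isinstance(turn, str) for turn in src)
    assert isinstance(tgt, list) and all(isinstance(turn, str) for turn in tgt)
    # tgt holds the system reply for every turn except the last user turn.
    assert len(tgt) == len(src) - 1, f"len(tgt)={len(tgt)} but len(src)={len(src)}"
    return record

# Illustrative path: the example file mentioned in the README above.
with open("data/rlhf_train_data_test.jsonl", encoding="utf-8") as f:
    records = [check_rlhf_record(line) for line in f if line.strip()]
print(f"{len(records)} records passed the protocol check")
```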
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import requests + +CHAT_URL = "http://127.0.0.1:8731" + +data = { + "src": [ + "Natalia sold clips to 48 of her friends in April, ", + "Weng earns $12 an hour for babysitting. Yesterday", + ], + "tgt": [ + "Natalia sold 48/2 = 24 clips in May. #### 72", + "She earned 0.2 x 50 = $10. #### 10", + ], + "response": [ + "Natalia sold 48+24 = 72 clips altogether in April and May. #### 72", + "2", + ], +} +res = requests.post(CHAT_URL, json=data) +result = json.loads(res.text) +print("result:", result, result["score"]) \ No newline at end of file diff --git a/llm/alignment/ppo/comm_utils.py b/llm/alignment/ppo/comm_utils.py index de077c65db31..3514045d68ff 100644 --- a/llm/alignment/ppo/comm_utils.py +++ b/llm/alignment/ppo/comm_utils.py @@ -12,57 +12,246 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import sys +from enum import Enum, auto import paddle import paddle.distributed as dist +from paddle import nn -from paddlenlp.trainer.plugins.unified_checkpoint import flatten_list +from paddlenlp.trainer import strtobool from paddlenlp.trainer.trainer import Trainer, logger -from paddlenlp.trainer.utils.helper import nested_broadcast_tensor_with_empty from paddlenlp.utils.distributed import distributed_gather +from paddlenlp.utils.nested import flatten_list, nested_broadcast_tensor_with_empty global_dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device().split(":")[1]) +class ActorStages(Enum): + """ + Enum class, the stages of the actor training process. + """ + + MODEL_ENABLE_DISABLE = auto() + RL_STEP = auto() + PTX_STEP = auto() + + +class CriticStages(Enum): + """ + Enum class, the stages of the critic training process. + """ + + MODEL_ENABLE_DISABLE = auto() + CRITIC_TRAINING_STEP = auto() + + +class RolloutStages(Enum): + """ + Enum class, the stages of the rollout process. + """ + + ACTOR_MODEL_ENABLE_DISABLE = auto() + GENERATE = auto() + ROLLOUT_LOGPROB = auto() + REWARD_MODEL_ENABLE_DISABLE = auto() + ROLLOUT_REWARD_VALUE = auto() + + +def get_timer_label(stage: Enum) -> str: + """ + 获取计时器标签。 + + Args: + stage (Enum): RolloutStages/CriticStages/RolloutStages. 
+ + Returns: + str: 打印Timer时的前缀。格式为 "[prefix] stage number.description"。 + - prefix: 阶段前缀,如"actor-step"、"critic-step"等。 + - stage number: 从1开始编号。 + - description: 阶段描述,小写形式。 + """ + step_prefix = { + ActorStages.MODEL_ENABLE_DISABLE: "actor-step", + ActorStages.RL_STEP: "actor-step", + ActorStages.PTX_STEP: "actor-step", + CriticStages.MODEL_ENABLE_DISABLE: "critic-step", + CriticStages.CRITIC_TRAINING_STEP: "critic-step", + RolloutStages.ACTOR_MODEL_ENABLE_DISABLE: "rollout", + RolloutStages.GENERATE: "rollout", + RolloutStages.ROLLOUT_LOGPROB: "rollout", + RolloutStages.REWARD_MODEL_ENABLE_DISABLE: "rollout", + RolloutStages.ROLLOUT_REWARD_VALUE: "rollout", + } + # stage + prefix = step_prefix.get(stage, "unknown") + # index + stage_number = list(stage.__class__).index(stage) + 1 + # description + description = stage.name.lower() # .replace('_', ' ') + # all + return f"[{prefix}] {stage_number}.{description}" + + +@paddle.no_grad() +def _move_param(src, device=None, blocking=True): + """ + 将参数从源设备移动到目标设备,并返回目标设备上的参数。如果目标设备未指定,则使用当前设备。 + + Args: + src (Tensor): 需要移动的参数张量。 + device (Optional[Union[str, paddle.Device]], optional): 目标设备,默认为None,表示使用当前设备。可以是字符串或paddle.Device对象。默认为None。 + blocking (bool, optional): 是否阻塞等待操作完成,默认为True。 + + Returns: + Tensor: 在目标设备上的参数张量。 + """ + if isinstance(device, str): + device = paddle.device._convert_to_place(device) + dst = src._copy_to(device, blocking) + dst_tensor = dst.value().get_tensor() + src_tensor = src.value().get_tensor() + src_tensor._clear() + src_tensor._share_data_with(dst_tensor) + + def offload_tensor_to_cpu(tensors): - if isinstance(tensors, dict): - for _, v in tensors.items(): - offload_tensor_to_cpu(v) - elif isinstance(tensors, paddle.Tensor): - if tensors.place.is_gpu_place(): - cpu_tensor = tensors._copy_to(paddle.CUDAPinnedPlace(), False) - tensors.value().get_tensor()._share_data_with(cpu_tensor.value().get_tensor()) + """ + 将给定的张量迁移到CPU上。如果使用了CUDA管理内存,则该函数无效。 + + Args: + tensors (tuple, list): tuple或list,包含两个元素,第一个元素是模型或优化器,第二个元素是字符串,表示是否为模型或优化器。 + + Returns: + None, 无返回值,直接修改原有张量。 + + Raises: + None, 没有引发任何异常。 + """ + if strtobool(os.getenv("FLAGS_use_cuda_managed_memory", "False")): + logger.warning("FLAGS_use_cuda_managed_memory has been set to True, " "offloading strategy is ineffective.") + return + + pin_device = paddle.CUDAPinnedPlace() + + def clear_main_grad(model): + for param in model.parameters(): + if hasattr(param, "main_grad") and param.main_grad is not None: + param.main_grad._clear_data() + param.main_grad = None + + # optimizer + if "optimizer" in tensors[1]: + optimizer = tensors[0] + # offload moment1 + for key, value in optimizer._accumulators[optimizer._moment1_acc_str].items(): + if value._is_initialized() and not isinstance(value.place, paddle.CUDAPinnedPlace): + optimizer._accumulators[optimizer._moment1_acc_str][key] = value.pin_memory() + + # offload moment2 + for key, value in optimizer._accumulators[optimizer._moment2_acc_str].items(): + if value._is_initialized() and not isinstance(value.place, paddle.CUDAPinnedPlace): + optimizer._accumulators[optimizer._moment2_acc_str][key] = value.pin_memory() + + # offload master_weight + for key, value in optimizer._master_weights.items(): + if value._is_initialized() and not isinstance(value.place, paddle.CUDAPinnedPlace): + optimizer._master_weights[key] = value.pin_memory() + # model + elif "model" in tensors[1]: + model = tensors[0] + clear_main_grad(model) + for name, src in model.named_parameters(): + if src._is_initialized() and not 
isinstance(src.place, paddle.CUDAPinnedPlace): + _move_param(src, pin_device) + + elif "tensor" in tensors[1]: + src = tensors[0] + if src._is_initialized() and not isinstance(src.place, paddle.CUDAPinnedPlace): + _move_param(src, pin_device) else: - logger.warning(f"Can't parse for type {type(tensors)}") - return tensors + logger.debug(f"Can't parse for type {tensors[1]}") def reload_tensor_to_gpu(tensors): - if isinstance(tensors, dict): - for _, v in tensors.items(): - reload_tensor_to_gpu(v) - elif isinstance(tensors, paddle.Tensor): - if tensors._is_initialized() and not tensors.place.is_gpu_place(): - gpu_tensor = tensors._copy_to(paddle.CUDAPlace(global_dev_id), False) - tensors.value().get_tensor()._share_data_with(gpu_tensor.value().get_tensor()) + """ + 将给定的张量从CPU转移到GPU中,并返回新的张量。如果没有设置环境变量FLAGS_use_cuda_managed_memory为True,则此函数无效。 + + Args: + tensors (List[Tuple[Any, str]]): 包含两个元素的列表,第一个元素是需要转移到GPU的张量,第二个元素是字符串,用于指示张量类型("optimizer"或"model")。 + + Returns: + List[Tuple[Any, str]]: 与输入相同的列表,但所有张量已经被转移到GPU中。 + + Raises: + None. + """ + if strtobool(os.getenv("FLAGS_use_cuda_managed_memory", "False")): + logger.warning("FLAGS_use_cuda_managed_memory has been set to True, " "offloading strategy is ineffective.") + return + + # optimizer + if "optimizer" in tensors[1]: + optimizer = tensors[0] + # offload moment1 + for key, value in optimizer._accumulators[optimizer._moment1_acc_str].items(): + if value._is_initialized() and not isinstance(value.place, paddle.CUDAPlace): + optimizer._accumulators[optimizer._moment1_acc_str][key] = value.cuda() + + # offload moment2 + for key, value in optimizer._accumulators[optimizer._moment2_acc_str].items(): + if value._is_initialized() and not isinstance(value.place, paddle.CUDAPlace): + optimizer._accumulators[optimizer._moment2_acc_str][key] = value.cuda() + + # offload master_weight + for key, value in optimizer._master_weights.items(): + if value._is_initialized() and not isinstance(value.place, paddle.CUDAPlace): + optimizer._master_weights[key] = value.cuda() + # model + elif "model" in tensors[1]: + model = tensors[0] + device = paddle.device.get_device() + for name, src in model.named_parameters(): + if src._is_initialized() and not isinstance(src.place, paddle.CUDAPlace): + _move_param(src, device) else: - logger.warning(f"Can't parse for type {type(tensors)}") - return tensors + logger.debug(f"Can't parse for type {tensors[1]}") def cleanup_tensor_space(tensors): + """ + 释放张量所占的空间,包括内存和磁盘空间。如果输入是字典类型,则递归处理其中的值;如果是paddle.Tensor类型,则清除数据;否则返回原始对象。 + + Args: + tensors (Union[dict, paddle.Tensor]): 需要释放空间的张量或字典,其中字典的值为张量。 + + Returns: + Union[dict, paddle.Tensor]: 如果输入是字典,则返回一个新的字典,其中值已经被释放空间;如果输入是paddle.Tensor,则返回一个清除了数据的paddle.Tensor。否则返回原始对象。 + """ if isinstance(tensors, dict): for _, v in tensors.items(): cleanup_tensor_space(v) elif isinstance(tensors, paddle.Tensor): tensors._clear_data() else: - logger.warning(f"Can't parse for type {type(tensors)}") + logger.debug(f"Can't parse for type {type(tensors)}") return tensors def data_group_split(tensors, group): + """ + 将数据按照给定的分组进行切分,如果没有给定分组则直接返回原始数据。 + 支持列表、元组、字典和paddle.Tensor类型的数据。 + + Args: + tensors (Union[List[Any], Tuple[Any], Dict[str, Any], paddle.Tensor]): 待切分的数据,可以是任意类型。 + group (Optional[distributed.Group]): 指定要切分的分组,如果为None则直接返回原始数据。默认为None。 + + Returns: + Union[List[Any], Tuple[Any], Dict[str, Any], paddle.Tensor]: 切分后的数据,与输入数据类型一致。 + 如果输入数据为字典,则返回的新字典中的值也会被切分。 + """ if group is None: return tensors if isinstance(tensors, (list, tuple)): @@ -75,11 +264,26 @@ def 
data_group_split(tensors, group): elif isinstance(tensors, paddle.Tensor): return tensors.split(group.nranks)[group.rank] else: - logger.warning(f"Can't parse for type {type(tensors)}") + logger.debug(f"Can't parse for type {type(tensors)}") return tensors def data_group_merge(tensors, group): + """ + 将数据组合成一个新的列表或字典,如果不是None则在指定的分组中进行all_gather_nd操作。 + + Args: + tensors (Union[List[Any], Tuple[Any], Dict[str, Any], paddle.Tensor]): 需要组合的数据,可以是列表、元组、字典或张量。 + 如果是张量,则会在指定的分组中进行all_gather_nd操作,并返回一个张量。 + group (Optional[int]): 指定的分组,如果为None,则直接返回原始数据。默认为None。 + + Returns: + Union[List[Any], Tuple[Any], Dict[str, Any], paddle.Tensor]: 返回一个新的列表或字典,或者一个张量,取决于传入的数据类型。 + 如果是张量,则是在指定的分组中进行all_gather_nd操作后的结果。 + + Raises: + None + """ if group is None: return tensors @@ -95,11 +299,26 @@ def data_group_merge(tensors, group): all_gather_nd(tensor_list, tensors, group=group, padded=True) return paddle.concat(tensor_list) else: - logger.warning(f"Can't parse for type {type(tensors)}") + logger.debug(f"Can't parse for type {type(tensors)}") return tensors def group_rank_guard(group, rank=0): + """ + 用于控制某个进程组中的某个进程是否参与函数调用,并在所有进程完成后进行通信。 + 如果该进程组中的某个进程不是指定的rank,则不会调用该函数。 + + Args: + group (distributed.ProcessGroup): 进程组对象。 + rank (int, optional, default=0): 需要参与函数调用的进程的rank,默认为0。 + rank为-1时表示所有进程都参与。 + + Returns: + function: 返回一个装饰器,该装饰器接受一个函数作为参数,返回一个包装后的函数。 + 被装饰的函数将在指定的rank的进程中被调用,其他进程不会被调用。 + 在所有进程完成后,将进行通信,并广播结果到所有进程。 + """ + def decorator(func): def wrapper_func(*args, **kwargs): if group.rank == rank: @@ -117,6 +336,23 @@ def wrapper_func(*args, **kwargs): def repad_rl_batches(batches, input_lengths): + """ + 对输入的批次进行重新填充,使得每个批次的长度都是最大长度。 + 如果批次中包含了位置ID,则在未被访问到的部分填充为1。 + + Args: + batches (dict): 包含输入数据和其他信息的字典,格式为{"input_ids": Tensor, "attention_mask": Tensor, ...}。 + 其中Tensor的形状应该是(batch_size, sequence_length)。 + input_lengths (Tensor): 一个长度为batch_size的张量,表示每个批次的实际长度。 + 形状为(batch_size,)。 + + Returns: + dict: 返回一个更新后的字典,包含了重新填充后的输入数据和其他信息。 + 如果原始批次中没有包含位置ID,那么这个字段将不会出现在返回值中。 + + Raises: + None + """ if batches.get("position_ids", None) is not None: v = batches["position_ids"] for x in range(v.shape[0]): @@ -129,6 +365,62 @@ def repad_rl_batches(batches, input_lengths): return batches +def remove_input_padding(input_ids, pad_id): + """ + 从输入ID中移除填充,返回一个列表,每个元素是一个不包含pad_id的paddle.Tensor。 + + Args: + input_ids (List[paddle.Tensor]): 包含输入ID的列表,每个元素是一个1维的paddle.Tensor,dtype为int64。 + pad_id (int): 需要移除的填充ID。 + + Returns: + List[paddle.Tensor]: 包含不包含pad_id的输入ID的列表,每个元素是一个1维的paddle.Tensor,dtype为int64。 + """ + result = [] + for ids in input_ids: + ids_list = ids.tolist() + filtered_ids = [id for id in ids_list if id != pad_id] + result.append(paddle.to_tensor(filtered_ids, dtype="int64")) + return result + + +def concat_input_response_and_padding(input_ids_wo_padding, response, pad_id): + """ + 将输入和响应进行拼接,并添加适当的填充。 + + Args: + input_ids_wo_padding (List[Tensor]): 不包含填充的输入ID列表,形状为(batch_size,seq_len)。 + response (Tensor): 响应矩阵,形状为(num_return_index,batch_size,seq_len)。 + pad_id (int): 用于填充的ID。 + + Returns: + Tensor: 返回一个形状为(num_return_index,batch_size,max_seq_len)的Tensor,其中max_seq_len是所有输入和响应的最大长度。 + 每个元素都是由input_ids_wo_padding和response的对应元素拼接而成的。如果拼接后的长度小于max_seq_len,则会在末尾追加pad_id。 + """ + concat_results = [] + max_seq_len = 0 + for num_return_index in range(response.shape[0]): + batch_concat_input_response = [] + for batch_index in range(response.shape[1]): + one_input = input_ids_wo_padding[batch_index] + one_response = response[num_return_index][batch_index] + one_concat_input_response = 
paddle.concat((one_input, one_response)) + max_seq_len = max(max_seq_len, one_concat_input_response.shape[0]) + batch_concat_input_response.append(one_concat_input_response) + concat_results.append(batch_concat_input_response) + + padding_results = [] + for num_return_index in range(response.shape[0]): + batch_padding_result = [] + for batch_index in range(response.shape[1]): + difference = max_seq_len - concat_results[num_return_index][batch_index].shape[0] + one_padding_result = concat_results[num_return_index][batch_index].tolist() + difference * [pad_id] + batch_padding_result.append(paddle.to_tensor(one_padding_result, dtype="int64")) + padding_results.append(batch_padding_result) + + return paddle.to_tensor(padding_results, dtype="int64") + + # https://stackoverflow.com/questions/12594148/skipping-execution-of-with-block class SkipWithBlock(Exception): pass @@ -136,18 +428,58 @@ class SkipWithBlock(Exception): class SkipContextManager: def __init__(self, skip): + """ + Initializes the class with the given skip value. + + Args: + skip (int): The number of rows to skip in the input data. + + Returns: + None. + """ self.skip = skip def __enter__(self): + """ + 在进入上下文管理器时调用,返回自身。 + 如果需要执行一些初始化操作,可以重写此方法。 + + Returns: + TraceContextManager: 当前实例对象自身。 + """ if self.skip: sys.settrace(lambda *args, **keys: None) frame = sys._getframe(1) frame.f_trace = self.trace def trace(self, frame, event, arg): - raise SkipWithBlock() + """ + 跟踪函数执行,并在遇到指定的代码块时抛出SkipWithBlock异常。 + 当前实现只支持单个代码块,不支持多个。 + + Args: + frame (types.FrameType): 当前执行的frame对象。 + event (str): 事件类型,包括'call', 'return', 'exception_raised', 'yield'. + arg (Any): 可选参数,用于传递给event_handler函数。 + + Raises: + SkipWithBlock: 当遇到指定的代码块时抛出此异常,表示需要跳过后续的测试执行。 + """ + raise SkipWithBlock def __exit__(self, type, value, traceback): + """ + 如果退出时没有异常,则返回True。如果退出时是SkipWithBlock的子类,则返回True以抑制该异常。否则返回False。 + 如果没有异常,则返回True。如果退出时是SkipWithBlock的子类,则返回True以抑制该异常。否则返回False。 + + Args: + type (Optional[Type[BaseException]]): 可选,异常类型,如果为None,则表示没有异常。默认为None。 + value (Optional[BaseException]): 可选,异常对象,如果type不为None,则必须提供value参数。默认为None。 + traceback (Optional[traceback]): 可选,追踪信息,如果type不为None,则必须提供traceback参数。默认为None。 + + Returns: + bool: 如果没有异常或者异常是SkipWithBlock的子类,则返回True;否则返回False。 + """ if type is None: return # No exception if issubclass(type, SkipWithBlock): @@ -166,7 +498,8 @@ def all_gather_nd(tensor_list, tensor, group=None, padded=False): Returns: (Tensor): output list of tensors that can be of different sizes """ - if len(tensor.shape) == 0: + tensor_dim = tensor.dim() + if tensor_dim == 0: tensor = tensor.reshape([1]) dist.all_gather(tensor_list, tensor, group=group) return tensor_list @@ -176,23 +509,20 @@ def all_gather_nd(tensor_list, tensor, group=None, padded=False): all_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)] dist.all_gather(all_sizes, local_size, group=group) - # max_length = max(size[0] for size in all_sizes) - - # length_diff = max_length.item() - local_size[0].item() - # if length_diff: - # pad_size = (length_diff, *tensor.size()[1:]) - # padding = paddle.zeros(pad_size, place=tensor.place(), dtype=tensor.dtype) - # tensor = padle.concat((tensor, padding)) - max_length = max(size[-1] for size in all_sizes) length_diff = max_length.item() - local_size[-1].item() if length_diff: - pad_size = (*tensor.shape[:-1], length_diff) - padding = paddle.zeros(pad_size, dtype=tensor.dtype) - tensor = paddle.concat([tensor, padding], axis=-1) + if tensor_dim == 2: + pad_size = (*tensor.shape[:-1], length_diff) + padding = 
paddle.zeros(pad_size, dtype=tensor.dtype) + tensor = paddle.concat([tensor, padding], axis=-1) + elif tensor_dim == 4: + # Note(gongenlei): support attention mask + tensor = nn.Pad2D([0, length_diff, 0, length_diff], mode="constant", value=0.0)(tensor) all_tensors_padded = [] + tensor = tensor.contiguous() dist.all_gather(all_tensors_padded, tensor, group=group) # all_tensors = [] if padded: @@ -200,11 +530,43 @@ def all_gather_nd(tensor_list, tensor, group=None, padded=False): return all_tensors_padded for tensor_, size in zip(all_tensors_padded, all_sizes): - tensor_list.append(tensor_[..., : size[-1]]) + if tensor_dim == 2: + tensor_list.append(tensor_[..., : size[-1]]) + elif tensor_dim == 4: + tensor_list.append(tensor_[..., : size[-1], : size[-1]]) return tensor_list def export_evaluate_model(self: Trainer, train_model, eval_model, **kwargs): + """ + 导出评估模型。 + + Args: + self (Trainer, required): + Trainer 对象的引用。 + + train_model (nn.Layer, required): + Train 模型,需要在训练过程中使用。 + + eval_model (Optional[nn.Layer], optional): + 评估模型,如果没有提供,则返回 None。默认为 None。 + + with_offload (bool, optional): + 是否将训练模型的张量转换到 CPU 上,默认为 False。 + + kwargs (Dict, optional): + 可选参数字典,包括: + - with_offload (bool, optional): + 是否将训练模型的张量转换到 CPU 上,默认为 False。 + + Returns: + Optional[None]: + 如果 eval_model 不存在,则返回 None;否则返回 None。 + + Raises: + ValueError: + 当 eval_model 的 tensor_parallel_degree 与 train_model 的 tensor_parallel_degree 不相同时,会引发此错误。 + """ if eval_model is None: return None @@ -263,17 +625,21 @@ def export_evaluate_model(self: Trainer, train_model, eval_model, **kwargs): v.get_tensor()._share_data_with(t.get_tensor()) if with_offload: - offload_tensor_to_cpu(train_state_dict[key]) + offload_tensor_to_cpu((train_state_dict[key], "tensor")) else: # single to single # tp+pp -> single raise ValueError("Not support yet.") - def create_send_recv_table(train_keys, eval_keys): + def create_send_recv_table(train_keys, eval_keys, is_value_trainer): recv_table = [] send_table = [] if pp_group.rank == 0: for key in eval_keys: + if (not eval_model.config.weight_sharing) and is_value_trainer: + if "output_linear.out_linear" in key: + logger.debug(f"Skip: {key}") + continue recv_table.append((key, global_rank)) for key in train_keys: @@ -303,9 +669,13 @@ def create_send_recv_table(train_keys, eval_keys): # tp情况 # tp+pp->tp - self.timers and self.timers("export-merge-pp").start() + # self.timers and self.timers("export-merge-pp").start() if eval_tp_size > 1 and train_pp_size > 1: - table = create_send_recv_table(train_state_dict.keys(), eval_state_dict.keys()) + table = create_send_recv_table( + train_state_dict.keys(), + eval_state_dict.keys(), + self.trainer_type == "value", + ) for key, src_rank, dst_rank in table: # Init tensor for model is cleaned @@ -325,29 +695,32 @@ def create_send_recv_table(train_keys, eval_keys): # Offload train model if need if global_rank == src_rank and with_offload: - offload_tensor_to_cpu(train_state_dict[key]) + offload_tensor_to_cpu((train_state_dict[key], "tensor")) - self.timers and self.timers("export-merge-pp").stop() - self.timers and self.timers("export-broadcast-pp").start() + # self.timers and self.timers("export-merge-pp").stop() + # self.timers and self.timers("export-broadcast-pp").start() if pp_group.nranks > 1: paddle.distributed.parallel.sync_params_buffers( - eval_model, comm_group=pp_group, src_rank=pp_group.ranks[0], fuse_params=False + eval_model, + comm_group=pp_group, + src_rank=pp_group.ranks[0], + fuse_params=False, ) - self.timers and 
self.timers("export-broadcast-pp").stop() + # self.timers and self.timers("export-broadcast-pp").stop() else: # 其他 DP rank 的state dict, 适配 offload 和初始化 - self.timers and self.timers("export-offload-and-init").start() + # self.timers and self.timers("export-offload-and-init").start() if with_offload: for key in list(train_state_dict.keys()): - offload_tensor_to_cpu(train_state_dict[key]) + offload_tensor_to_cpu((train_state_dict[key], "tensor")) for k, v in eval_state_dict.items(): if not v._is_initialized(): t = paddle._C_ops.full_like(v, 0, v.dtype, paddle.CUDAPlace(global_dev_id)) v.get_tensor()._share_data_with(t.get_tensor()) - self.timers and self.timers("export-offload-and-init").stop() + # self.timers and self.timers("export-offload-and-init").stop() paddle.distributed.barrier() - self.timers and self.timers("export-broadcast-sd-dp").start() + # self.timers and self.timers("export-broadcast-sd-dp").start() if eval_tp_size == 1: for _, tensor in eval_state_dict.items(): paddle.distributed.broadcast(tensor, src=0, group=None, sync_op=True) @@ -355,17 +728,19 @@ def create_send_recv_table(train_keys, eval_keys): if sd_group.nranks > 1: if dp_group.rank <= 0: paddle.distributed.parallel.sync_params_buffers( - eval_model, comm_group=sd_group, src_rank=sd_group.ranks[0], fuse_params=False + eval_model, + comm_group=sd_group, + src_rank=sd_group.ranks[0], + fuse_params=False, ) if dp_group.nranks > 1: paddle.distributed.parallel.sync_params_buffers( - eval_model, comm_group=dp_group, src_rank=dp_group.ranks[0], fuse_params=False + eval_model, + comm_group=dp_group, + src_rank=dp_group.ranks[0], + fuse_params=False, ) - self.timers and self.timers("export-broadcast-sd-dp").stop() - # paddle.save(eval_state_dict, f"./tmp/eval_{sd_group.rank}_tp_{eval_tp_rank}_pp_{pp_group.rank}.pdparams") - # paddle.save(train_state_dict, f"./tmp/train_{sd_group.rank}_tp_{tp_group.rank}_pp_{pp_group.rank}.pdparams") - # paddle.distributed.barrier() - # exit(-1) + # self.timers and self.timers("export-broadcast-sd-dp").stop() old_dp_workers = self.args.world_size // (max(sd_group.nranks, 1) * max(dp_group.nranks, 1)) group_nums = self.args.logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank @@ -377,6 +752,17 @@ def create_send_recv_table(train_keys, eval_keys): def create_data_trans_group(global_rank, group_nums): + """ + 创建一个数据传输组,该组将根据给定的全局排名和组数分割。 + 该函数使用了paddle.distributed.all_gather_object进行通信,并返回一个新的分布式组对象。 + + Args: + global_rank (int): 当前全局排名。 + group_nums (List[int]): 需要分割的组数列表。 + + Returns: + paddle.distributed.Group: 返回一个新的分布式组对象,包含所有参与分割的全局排名。如果当前全局排名在任何一个组中,则返回该组。如果当前全局排名不在任何一个组中,则返回None。 + """ all_split_table = [] paddle.distributed.all_gather_object(all_split_table, [(global_rank, group_nums)]) all_split_table = flatten_list(all_split_table) @@ -400,4 +786,38 @@ def create_data_trans_group(global_rank, group_nums): return group +def new_timer_log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + + def format_dict(data): + """Format the timer log.""" + result = {} + order = [] + for key, value in data.items(): + category, detail = key.split(" ", maxsplit=1) + if category not in result: + result[category] = [] + order.append(category) + result[category].append(f"{detail}: {round(value, 2)}") + + output = "" + for category in order: + if category in result: + output += f"\n{category}" + for value in result[category]: + output += f"\n {value}" + return output + + assert normalizer > 0.0 + string = "time (ms)" + names = sorted(names) + time_dict = {} + for 
name in names: + time_dict[name] = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer + if len(time_dict) == 0: + return "skipped" + string += format_dict(time_dict) + return string + + Trainer.export_evaluate_model = export_evaluate_model diff --git a/llm/alignment/ppo/data/__init__.py b/llm/alignment/ppo/data/__init__.py index 227b120a4b68..5da71f0cebd9 100644 --- a/llm/alignment/ppo/data/__init__.py +++ b/llm/alignment/ppo/data/__init__.py @@ -16,6 +16,7 @@ from .alpaca import * from .base import * +from .jsondata import * from .preference import * from .prompt_only import * from .safe_rlhf import * diff --git a/llm/alignment/ppo/data/alpaca.py b/llm/alignment/ppo/data/alpaca.py index 2ce894b3dfbe..945df4ef4b8a 100644 --- a/llm/alignment/ppo/data/alpaca.py +++ b/llm/alignment/ppo/data/alpaca.py @@ -27,7 +27,7 @@ class AlpacaDataset(RawDataset): NAME: str = "alpaca" ALIASES: tuple[str, ...] = ("stanford-alpaca",) - def __init__(self, path: str | None = None) -> None: + def __init__(self, path: str | None = None, *args, **kwargs) -> None: self.data = load_dataset(path or "tatsu-lab/alpaca", split="train") def __getitem__(self, index: int) -> RawSample: diff --git a/llm/alignment/ppo/data/base.py b/llm/alignment/ppo/data/base.py index ab3fd5b19d2c..47c8b38456ca 100644 --- a/llm/alignment/ppo/data/base.py +++ b/llm/alignment/ppo/data/base.py @@ -51,9 +51,9 @@ ] IGNORE_INDEX: int = -100 -PROMPT_BEGIN: str = "BEGINNING OF CONVERSATION: " -PROMPT_USER: str = "USER: {input} " -PROMPT_ASSISTANT: str = "ASSISTANT:" # should not have a space at the end +PROMPT_BEGIN: str = "" +PROMPT_USER: str = "{input}" +PROMPT_ASSISTANT: str = "" # should not have a space at the end PROMPT_INPUT: str = PROMPT_BEGIN + PROMPT_USER + PROMPT_ASSISTANT @@ -326,6 +326,7 @@ def __init__( # pylint: disable=too-many-branches dataset_names_and_attributes: dict[str, float | dict[str, Any]] | Iterable[tuple[str, float | dict[str, Any]]], tokenizer: PretrainedTokenizerBase, lazy_tokenization: bool = True, + use_rm_server: bool = False, seed: int = 42, ) -> None: if not isinstance(dataset_names_and_attributes, dict): @@ -348,6 +349,8 @@ def __init__( # pylint: disable=too-many-branches raise TypeError( f"Dataset `{name}` attributes should be a float or a dict, " f"got {type(attributes).__name__}.", ) + kwargs["use_rm_server"] = use_rm_server + proportion = kwargs.pop("proportion", 1.0) if isinstance(proportion, Fraction): if not (proportion < 0 and proportion.denominator == 1): @@ -368,6 +371,7 @@ def __init__( # pylint: disable=too-many-branches self.tokenizer = tokenizer self.seed = seed + self.use_rm_server = use_rm_server merged_rawdata = self._merge_raw_datasets(seed=seed) self.rawdata = [merged_rawdata[i] for i in range(len(merged_rawdata))] @@ -510,9 +514,10 @@ def split_train_test( class CollatorBase(metaclass=abc.ABCMeta): pad_token_id: int # The id of the padding token for the tokenizer. - def __init__(self, pad_token_id: int) -> None: + def __init__(self, pad_token_id: int, use_rm_server: bool) -> None: """Initialize a collator.""" self.pad_token_id = pad_token_id + self.use_rm_server = use_rm_server @abc.abstractmethod def __call__(self, samples: list[dict[str, paddle.Tensor]]) -> dict[str, paddle.Tensor]: diff --git a/llm/alignment/ppo/data/jsondata.py b/llm/alignment/ppo/data/jsondata.py new file mode 100644 index 000000000000..a20fe8de348a --- /dev/null +++ b/llm/alignment/ppo/data/jsondata.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
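As a side note on the timer helpers added to `comm_utils.py` above: `get_timer_label` produces names like `[rollout] 2.generate`, and `new_timer_log` groups them by the bracketed prefix when printing. A standalone sketch of that rendering on made-up timings is shown below; the timing values and the small helper are illustrative only.

```python
# Made-up elapsed times (ms), keyed by get_timer_label-style names.
timings = {
    "[rollout] 2.generate": 1234.5,
    "[rollout] 3.rollout_logprob": 321.0,
    "[actor-step] 2.rl_step": 876.4,
}

def render(data):
    # Same grouping idea as format_dict inside new_timer_log: split on the first
    # space so "[prefix]" becomes the category and the rest becomes a detail line.
    groups, order = {}, []
    for key, value in data.items():
        category, detail = key.split(" ", maxsplit=1)
        if category not in groups:
            groups[category] = []
            order.append(category)
        groups[category].append(f"{detail}: {round(value, 2)}")
    return "".join(f"\n{c}" + "".join(f"\n    {v}" for v in groups[c]) for c in order)

print("time (ms)" + render(timings))
# time (ms)
# [rollout]
#     2.generate: 1234.5
#     3.rollout_logprob: 321.0
# [actor-step]
#     2.rl_step: 876.4
```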
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datasets import load_dataset + +from .base import RawDataset, RawSample + +__all__ = ["JsonDataset"] + + +class JsonDataset(RawDataset): + NAME: str = "Jsonfile" + + def __init__(self, path: str | None = None, *args, **kwargs) -> None: + self.data = load_dataset("json", data_files=path, split="train") + self.use_rm_server = kwargs.pop("use_rm_server", False) + assert "src" in self.data.column_names, "'src' should be included in jsonfile" + if self.use_rm_server: + assert "tgt" in self.data.column_names, "'tgt' should be included in jsonfile when using rm server" + + def __getitem__(self, index: int) -> RawSample: + data = self.data[index] + if self.use_rm_server: + rawdata = RawSample( + input=data["src"], + answer=data["tgt"], + ) + else: + rawdata = RawSample(input=data["src"]) + return rawdata + + def __len__(self) -> int: + return len(self.data) # dataset size diff --git a/llm/alignment/ppo/data/prompt_only.py b/llm/alignment/ppo/data/prompt_only.py index 964a55a70574..fc14741abfe6 100644 --- a/llm/alignment/ppo/data/prompt_only.py +++ b/llm/alignment/ppo/data/prompt_only.py @@ -35,23 +35,28 @@ class PromptOnlySample(TypedDict, total=True): input_ids: paddle.Tensor # size = (L,) + label_ids: paddle.Tensor # size = (L,) class PromptOnlyBatch(TypedDict, total=True): input_ids: paddle.Tensor # size = (B, L) attention_mask: paddle.Tensor # size = (B, L) + label_ids: paddle.Tensor # size = (B, L) class PromptOnlyDataset(TokenizedDataset): def preprocess(self, raw_sample: RawSample) -> PromptOnlySample: + input_dict = {} prompt = format_prompt(input=raw_sample["input"], eos_token=self.tokenizer.eos_token) - input_ids = self.tokenize(prompt) - return { - "input_ids": input_ids, # size = (L,) - } + input_dict["input_ids"] = self.tokenize(prompt) + if self.use_rm_server: + answer = format_prompt(input=raw_sample["answer"], eos_token=self.tokenizer.eos_token) + input_dict["label_ids"] = self.tokenize(answer) + + return input_dict def get_collator(self) -> Callable[[list[dict[str, paddle.Tensor]]], dict[str, paddle.Tensor]]: - return PromptOnlyCollator(self.tokenizer.pad_token_id) + return PromptOnlyCollator(self.tokenizer.pad_token_id, self.use_rm_server) def _merge_raw_datasets(self, seed: int | None = None) -> Dataset[RawSample]: """Merge multiple raw datasets into one dataset and remove duplicates.""" @@ -67,12 +72,15 @@ def to_hashable(raw_sample: RawSample) -> Hashable: class PromptOnlyCollator(CollatorBase): def __call__(self, samples: list[PromptOnlySample]) -> PromptOnlyBatch: + input_dict = {} + input_ids = [sample["input_ids"] for sample in samples] attention_mask = [np.ones(input_id.shape, dtype=bool) for input_id in input_ids] + input_dict["input_ids"] = left_padding(input_ids, padding_value=self.pad_token_id) + input_dict["attention_mask"] = left_padding(attention_mask, padding_value=0) + + if self.use_rm_server: + label_ids = [sample["label_ids"] for sample in samples] + input_dict["label_ids"] = left_padding(label_ids, 
padding_value=self.pad_token_id) - input_ids = left_padding(input_ids, padding_value=self.pad_token_id) - attention_mask = left_padding(attention_mask, padding_value=0) - return { - "input_ids": input_ids, # size = (B, L) - "attention_mask": attention_mask, # size = (B, L) - } + return input_dict diff --git a/llm/alignment/ppo/data/safe_rlhf.py b/llm/alignment/ppo/data/safe_rlhf.py index 427c2c7a69a1..c5c08eee472a 100644 --- a/llm/alignment/ppo/data/safe_rlhf.py +++ b/llm/alignment/ppo/data/safe_rlhf.py @@ -34,7 +34,7 @@ class SafeRLHFDataset(RawDataset): SPLIT: ClassVar[str] PATH: ClassVar[str] - def __init__(self, path: str | None = None) -> None: + def __init__(self, path: str | None = None, *args, **kwargs) -> None: self.data = load_dataset(path or self.PATH, split=self.SPLIT) def __getitem__(self, index: int) -> RawSample: diff --git a/llm/alignment/ppo/data/supervised.py b/llm/alignment/ppo/data/supervised.py index 26aa97a14377..195762c315e5 100644 --- a/llm/alignment/ppo/data/supervised.py +++ b/llm/alignment/ppo/data/supervised.py @@ -99,13 +99,13 @@ def preprocess(self, raw_sample: RawSample) -> SupervisedSample: def get_collator( self, shift: bool = False ) -> Callable[[list[dict[str, paddle.Tensor]]], dict[str, paddle.Tensor]]: - return SupervisedCollator(self.tokenizer.pad_token_id, shift) + return SupervisedCollator(self.tokenizer.pad_token_id, shift, use_rm_server=False) class SupervisedCollator(CollatorBase): - def __init__(self, pad_token_id: int, shift: bool = False) -> None: + def __init__(self, pad_token_id: int, shift: bool = False, use_rm_server: bool = False) -> None: """Initialize a collator.""" - super().__init__(pad_token_id) + super().__init__(pad_token_id, use_rm_server=use_rm_server) self._shift = shift def __call__(self, samples: list[SupervisedSample]) -> SupervisedBatch: diff --git a/llm/alignment/ppo/infer_utils.py b/llm/alignment/ppo/infer_utils.py index 62cb5a9b9fb0..2f655ced2a4d 100644 --- a/llm/alignment/ppo/infer_utils.py +++ b/llm/alignment/ppo/infer_utils.py @@ -21,14 +21,20 @@ import paddle import paddle.distributed as dist -from comm_utils import cleanup_tensor_space, offload_tensor_to_cpu, reload_tensor_to_gpu +from comm_utils import offload_tensor_to_cpu, reload_tensor_to_gpu from paddle.utils import try_import +from predict.predictor import ( + DygraphInferencePredictor, + PdArgumentParser, + PredictorArgument, +) from trainer_utils import guard_set_args import paddlenlp from paddlenlp.trainer.trainer import Trainer, logger from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer from paddlenlp.transformers.model_utils import dtype_guard +from paddlenlp.trl.llm_utils import init_dist_env class Predictor: @@ -46,17 +52,12 @@ def __init__(self, config, model: PretrainedModel = None, tokenizer: PretrainedT # 2. inputs_processer creates caches and other inputs can be shared among # multi time prediction. 
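A quick illustration of what `PromptOnlyCollator` (earlier in this patch) produces: prompts are left-padded so that every sequence ends at the same position before generation starts. The snippet below uses a minimal stand-in for the data module's `left_padding` helper; the pad id and token ids are made up.

```python
import numpy as np

def left_pad(seqs, padding_value):
    # Minimal stand-in for the repo's left_padding helper: pad on the left
    # until every sequence in the batch has the same length.
    max_len = max(len(s) for s in seqs)
    return np.stack(
        [np.concatenate([np.full(max_len - len(s), padding_value, dtype=np.int64), s]) for s in seqs]
    )

pad_id = 0  # illustrative pad token id
prompts = [np.array([5, 6, 7, 8]), np.array([9, 10])]
input_ids = left_pad(prompts, padding_value=pad_id)
attention_mask = left_pad([np.ones(len(p), dtype=np.int64) for p in prompts], padding_value=0)
print(input_ids)       # [[ 5  6  7  8]
                       #  [ 0  0  9 10]]
print(attention_mask)  # [[1 1 1 1]
                       #  [0 0 1 1]]
```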
define caches and extra inputs creation method # instead of using predictor.__init__ - from predictor import InferencePredictorMixin - self._buffer_maker = types.MethodType(InferencePredictorMixin.__init__, self) - self._inputs_processer = types.MethodType(InferencePredictorMixin._preprocess, self) + self._buffer_maker = types.MethodType(DygraphInferencePredictor.__init__, self) + self._inputs_processer = types.MethodType(DygraphInferencePredictor._preprocess, self) @staticmethod def create_predictor(trainer): - from predictor import PdArgumentParser, PredictorArgument - - from paddlenlp.trl.llm_utils import get_model_max_position_embeddings - # create infer model # NOTE: infer model use static name param_attr to create and cannot be # created multiple times. @@ -68,11 +69,12 @@ def create_infer_model(model, dtype, set_state=False): eos_token_id=trainer.tokenizer.eos_token_id, pad_token_id=trainer.tokenizer.pad_token_id ) config = copy.deepcopy(model.config) - hcg = dist.fleet.get_hybrid_communicate_group() # may differ with training - config.tensor_parallel_degree = hcg.get_model_parallel_world_size() - config.tensor_parallel_rank = hcg.get_model_parallel_rank() - config.quant_type = None + config.tensor_parallel_rank, config.tensor_parallel_degree = init_dist_env() + config.quant_type = [] + config.cachekv_int8_type = None + config.append_attn = False config.single_card_ptq = True + infer_model_cls = getattr(paddlenlp.experimental.transformers, model.__class__.__name__ + "InferenceModel") # ori_init_weights = infer_model_cls.init_weights # infer_model_cls.init_weights = lambda self: None @@ -109,12 +111,12 @@ def _create_param(self, *args, **kwargs): parser = PdArgumentParser((PredictorArgument,)) predictor_args = parser.parse_dict( { - "src_length": get_model_max_position_embeddings( # can be changed dynamically by predictor.input_length - trainer.model.config if eval_model is None else eval_model.config - ), - "max_length": trainer.args.max_length, + "model_name_or_path": trainer.args.actor_model_name_or_path, + "src_length": trainer.args.max_src_len, + "max_length": trainer.args.max_dec_len, + "total_max_length": trainer.args.max_src_len + trainer.args.max_dec_len, "dtype": trainer.amp_dtype, - "batch_size": trainer.args.per_device_train_batch_size, + "batch_size": trainer.args.per_device_prompt_batch_size * trainer.args.num_return_sequences, # infer model do not support top_k, and differ with non-infer model # generation which gets default top_K=50 using generation_config.top_k "top_p": trainer.args.top_p, @@ -122,6 +124,7 @@ def _create_param(self, *args, **kwargs): "repetition_penalty": trainer.args.repetition_penalty, } )[0] + policy_predictor = Predictor(predictor_args, model=infer_model, tokenizer=trainer.tokenizer) return policy_predictor @@ -138,10 +141,13 @@ def _create_caches(self): self.config.src_length = getattr(self, "input_length", self.config.src_length) if not hasattr(self, "_buffer_attrs"): pre_attrs = set(self.__dict__.keys()) - self.cache_k_shapes, self.cache_v_shapes = self.model.get_cache_kvs_shape( + self.cache_kvs_shapes = self.model.get_cache_kvs_shape( self.model_config, self.config.batch_size, self.config.total_max_length ) - self._buffer_maker(self.config, self.tokenizer) + + self.config.model_config = copy.deepcopy(self.model_config) + + self._buffer_maker(self.config, self.tokenizer, self.model) if not hasattr(self, "_buffer_attrs"): self._buffer_attrs = set(self.__dict__.keys()) - pre_attrs @@ -170,37 +176,15 @@ def enable(self, model, 
offload_model=True): @paddle.no_grad() def set_state_dict(self, model, offload_model=True): - offload_place = paddle.CUDAPinnedPlace() - state_dict = {} - for k, v in model.state_dict().items(): - state_dict[k] = v - - if getattr(self, "_weights_mapping", None) is None: - self._weights_mapping = self.model.get_weights_mapping() - - # non_share_params = [] - for k, v in self._weights_mapping.items(): - param, (convert_fun, args) = k, v - args = [state_dict[name] for name in args] - value = convert_fun(*args) - if offload_model: - for arg in args: - # shared params no need to offload - if value is not arg: - cpu_arg = arg._copy_to(offload_place, blocking=False) - cpu_arg._share_buffer_to(arg) - if not isinstance(value, paddle.Tensor): - param.set_value(value) - # elif isinstance(value.place, paddle.CUDAPlace): - elif value.place.is_gpu_place(): - # NOTE: _share_buffer_to seems do not work - # value._share_buffer_to(param) - # value._share_underline_tensor_to(param) - param.get_tensor()._share_data_with(value.get_tensor()) - else: - param.copy_(value, True) - - paddle.device.cuda.synchronize() + self.model.set_state_dict(model.state_dict()) + if offload_model: + offload_place = paddle.CUDAPinnedPlace() + state_dict = model.state_dict() + for k, v in state_dict.items(): + cpu_arg = v._copy_to(offload_place, blocking=False) + cpu_arg._share_buffer_to(v) + # v.value().get_tensor()._share_data_with(cpu_arg.value().get_tensor()) + paddle.device.synchronize() def _preprocess(self, source): # make cache when infer happens to get actual shape to save memory @@ -257,7 +241,7 @@ def infer_guard(trainer, offload_model=True): try: try_import("paddlenlp_ops") - except: + except ImportError: logger.warning("paddlenlp_ops does not exist, please install paddlenlp_ops for generation speedup.") yield return @@ -265,21 +249,30 @@ def infer_guard(trainer, offload_model=True): global policy_predictor if policy_predictor is None: policy_predictor = Predictor.create_predictor(trainer) - if not policy_predictor.is_available: - policy_predictor.enable(model, offload_model=offload_model) + with dtype_guard(trainer.amp_dtype): + if not policy_predictor.is_available: + policy_predictor.enable(model, offload_model=offload_model) # TODO(guosheng): patch for dist.all_recude to use tp group, fix it later - ori_all_reduce = dist.all_reduce - ori_broadcast = dist.broadcast - hcg = dist.fleet.get_hybrid_communicate_group() - dist.all_reduce = lambda x: ori_all_reduce(x, group=hcg.get_model_parallel_group()) - dist.broadcast = lambda x, rank: ori_broadcast( - x, src=hcg.get_model_parallel_group_src_rank(), group=hcg.get_model_parallel_group() - ) - yield - dist.all_reduce = ori_all_reduce - dist.broadcast = ori_broadcast + is_distributed = True + try: + hcg = dist.fleet.get_hybrid_communicate_group() + except Exception: + is_distributed = False + + if is_distributed: + ori_all_reduce = dist.all_reduce + ori_broadcast = dist.broadcast + dist.all_reduce = lambda x: ori_all_reduce(x, group=hcg.get_model_parallel_group()) + dist.broadcast = lambda x, rank: ori_broadcast( + x, src=hcg.get_model_parallel_group_src_rank(), group=hcg.get_model_parallel_group() + ) + yield + dist.all_reduce = ori_all_reduce + dist.broadcast = ori_broadcast + else: + yield policy_predictor.disable(model, onload_model=offload_model) @@ -296,20 +289,23 @@ def __init__(self, trainer: Trainer): def enable(self): trainer = self.trainer if trainer.model is not self.model: + reload_tensor_to_gpu((trainer.model, "train_model")) + 
reload_tensor_to_gpu((self.model, "freeze_model")) trainer.export_evaluate_model( trainer.model, self.model, with_offload="train_model" in trainer.args.offload_level, ) else: - reload_tensor_to_gpu(self.model.state_dict()) + reload_tensor_to_gpu((self.model, "train_model")) def disable(self): trainer = self.trainer if trainer.model is not self.model: - cleanup_tensor_space(self.model.state_dict()) + offload_tensor_to_cpu((trainer.model, "train_model")) + offload_tensor_to_cpu((self.model, "freeze_model")) else: - offload_tensor_to_cpu(self.model.state_dict()) + offload_tensor_to_cpu((self.model, "train_model")) def __getattr__(self, name): try: diff --git a/llm/alignment/ppo/models/infer_model_utils.py b/llm/alignment/ppo/models/infer_model_utils.py index 3d63fe52aa9b..e74329c66cf5 100644 --- a/llm/alignment/ppo/models/infer_model_utils.py +++ b/llm/alignment/ppo/models/infer_model_utils.py @@ -19,6 +19,17 @@ def patch_paddlenlp_ops(eos_token_id, pad_token_id): + """ + 修补 PaddleNLP Ops,用于处理 EOS 标记和填充。 + + Args: + eos_token_id (int): EOS 标记的 ID,用于推断模型中的 padding。 + 当非推断模型使用该标记进行 padding 时,需要将其更改为 pad_token_id。 + pad_token_id (int, optional): 填充标记的 ID,默认为 None。当非推断模型使用该标记进行 padding 时,需要将其更改为 eos_token_id。 + + Returns: + None. 直接在 PaddleNLP Ops 上修改函数实现。 + """ import paddlenlp_ops paddlenlp_ops.save_with_output = lambda *args, **kwargs: None @@ -69,6 +80,16 @@ def _update_model_kwargs(self, *args, **kwargs): def register_model(model_cls_name): + """ + 注册模型类名,并将其映射到对应的函数上。 + + Args: + model_cls_name (str): 模型类名,用于在映射表中进行存储。 + + Returns: + function: 返回一个装饰器,该装饰器接收一个函数作为参数,并将函数映射到传入的模型类名上。 + """ + def mark_cls_name(func): # Do not register here although we can, otherwise infer model would import # before paddlenlp_ops. @@ -79,6 +100,15 @@ def mark_cls_name(func): def patch_infer_model(): + """ + 修补 InferModel 类的 get_weights_mapping 方法,使其能够正确获取权重映射。 + + Args: + 无参数,不需要传入任何参数。 + + Returns: + None, 该函数没有返回值。 + """ import paddlenlp.experimental.transformers as infer_transformers for model_cls_name, get_weights_mapping in _model_weights_mapping_dict.items(): diff --git a/llm/alignment/ppo/models/pp_model_utils.py b/llm/alignment/ppo/models/pp_model_utils.py index 1444cdbdd2e2..8b33e7b0591b 100644 --- a/llm/alignment/ppo/models/pp_model_utils.py +++ b/llm/alignment/ppo/models/pp_model_utils.py @@ -19,6 +19,21 @@ def fwd_step_patch(func, output, self, *args, **kwargs): + """ + 前向步骤补丁函数,用于处理模型在训练过程中的梯度计算和损失记录。 + 如果当前模型是最后一个阶段并且正在进行训练,则会将输出的梯度记录到self._step_losses列表中。 + 否则,不会对输出进行任何操作。 + + Args: + func (Callable): 被调用的函数,应该是forward函数或者其他需要执行的函数。 + output (Tensor): 模型的输出,应该是一个张量。 + self (Any): 模型实例,可以是nn.Module类型或其他自定义模型类型。 + args (Tuple[Any], optional): 传递给func的可选参数,默认为None。 + kwargs (Dict[str, Any], optional): 传递给func的可选关键字参数,默认为None。 + + Returns: + None, 无返回值,直接修改了self._step_losses属性。 + """ # training patch if self.training and self.is_pipeline_last_stage(): if getattr(self, "_step_losses", None): @@ -28,6 +43,20 @@ def fwd_step_patch(func, output, self, *args, **kwargs): def make_wrapper(func, pre_patch=None, post_patch=None): + """ + 创建一个包装函数,可以在调用原始函数前后执行额外的操作。 + + Args: + func (function): 需要被包装的函数。 + pre_patch (Optional[function], optional): 在调用原始函数前执行的函数,默认为None。 + 函数签名应该是 `pre_patch(func, None, *args, **kwargs)`。 + post_patch (Optional[function], optional): 在调用原始函数后执行的函数,默认为None。 + 函数签名应该是 `post_patch(func, output, *args, **kwargs)`,其中output是原始函数的返回值。 + + Returns: + function: 包装后的函数,具有与原始函数相同的功能,但会在调用前后执行额外的操作。 + """ + def wrapper(*args, **kwargs): if pre_patch is not None: pre_patch(func, None, 
*args, **kwargs) @@ -39,7 +68,12 @@ def wrapper(*args, **kwargs): return wrapper -funcs = [(paddle.distributed.fleet.model.PipelineParallel._forward_step, fwd_step_patch)] +funcs = [ + ( + paddle.distributed.fleet.model.PipelineParallel._forward_step, + fwd_step_patch, + ) +] for func in funcs: fun, patch = func @@ -64,11 +98,28 @@ def pad_batches_inputs(inputs, padding_value=0, max_len=None, pad_len=None): # if x is None or x.shape[-1] == max_len: if not isinstance(x, paddle.Tensor) or x.shape[-1] == max_len: continue - inputs[i] = paddle.concat([x, paddle.full([x.shape[0], pad_len[i]], padding_value, dtype=x.dtype)], -1) + inputs[i] = paddle.concat( + [ + x, + paddle.full([x.shape[0], pad_len[i]], padding_value, dtype=x.dtype), + ], + -1, + ) return inputs def get_expected_keys(inputs, keys): + """ + 获取预期的键值对,如果输入中存在则返回该键值对,否则返回None。 + 如果键值对只有一个,则将其转换为单个元素。 + + Args: + inputs (dict): 包含多个键值对的字典,用于查找预期的键值对。 + keys (list[str]): 需要查找的键列表。 + + Returns: + Union[tuple, Any]: 如果键值对只有一个,则返回单个元素;否则返回包含所有键值对的元组。如果任何键不存在,则返回None。 + """ ret = tuple([inputs.get(k, None) for k in keys if k in inputs]) if len(ret) == 1: ret = ret[0] @@ -76,6 +127,17 @@ def get_expected_keys(inputs, keys): def fwd_args_to_dict(fun): + """ + 将函数的参数转换为字典,用于支持更多的参数格式在预测流程步骤中。 + 假设没有参数是inspect.Parameter.VAR_KEYWORD。 + + Args: + fun (Callable[[Any, Dict[str, Any]], Any]): 需要转换的函数,其第一个参数是非管道模型类实例,后续参数可以是任意格式的非管道模型前向传输参数,返回值是任意类型。 + + Returns: + Callable[[Any, *Any, **Any], Any]: 返回一个新的函数,接收与原函数相同的参数,但是将所有非self参数转换为字典形式,并作为第二个参数传入原函数。 + """ + def _impl(self, *args, **kwargs): try: return fun(self, *args, **kwargs) diff --git a/llm/alignment/ppo/models/ppo_model.py b/llm/alignment/ppo/models/ppo_model.py index 720009161022..c4fa3de262b8 100644 --- a/llm/alignment/ppo/models/ppo_model.py +++ b/llm/alignment/ppo/models/ppo_model.py @@ -22,6 +22,13 @@ # TODO(guosheng): create Mixin and make model classes using metaclass. class LlamaPolicyModel(LlamaForCausalLM): def __init__(self, config: PretrainedConfig, **kwargs): + """ + Initializes a RLHFPPOMixedLossWrapper instance. + + Args: + config (PretrainedConfig): The model configuration used for initialization. + kwargs (Dict[str, Any], optional): Additional keyword arguments passed along. Defaults to {}. + """ super().__init__(config) self.loss_fn = RLHFPPOMixedLoss(config, **kwargs) @@ -41,6 +48,16 @@ def forward( output_hidden_states=None, return_dict=None, ): + """ + Returns a tuple containing: + 1. the loss, calculated as the sum of the cross entropy for each token and the KL divergence between the + policy distribution and the uniform distribution. If `advantages` are provided, the loss will be + augmented with the additional term -E[log P(a|x)] where x is the input and a is the action. + 2. the model's output as a tuple of: + - the last layer's output of shape `(batch_size, sequence_length, config.vocab_size)` + - the cache used in inference for next chunk. + - the decoder's attention weights for each layer. 
+ """ outputs = super().forward( input_ids=input_ids, position_ids=position_ids, @@ -56,7 +73,10 @@ def forward( logits = outputs[0] loss = None if labels is not None or advantages is not None: - loss = self.loss_fn(logits, (labels, input_ids, log_probs, advantages, sequence_mask)) + loss = self.loss_fn( + logits, + (labels, input_ids, log_probs, advantages, sequence_mask), + ) if not return_dict: return (loss,) + outputs if loss is not None else outputs @@ -71,6 +91,16 @@ def forward( class LlamaValueModel(LlamaModelForScore): def __init__(self, config, **kwargs): + """ + Initializes the RLHFValueLossWrapper instance. + + Args: + config (DictConfig): Config dict for the model. + **kwargs (Any, optional): Keyword arguments to be passed to the parent class. Defaults to None. + + Returns: + None. + """ super().__init__(config, **kwargs) self.loss_fn = RLHFValueLoss(config, **kwargs) @@ -89,6 +119,24 @@ def forward( output_hidden_states=None, return_dict=None, ): + """ + Returns: + Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + if `return_dict` is False, a tuple of tensors is returned, containing: + - the loss, if it is not None; + - the reward values; + - the rewards; + - the past key values; + - the hidden states; + - the attentions. + if `return_dict` is True, a [`ValueOutput`] is returned, containing: + - the loss, if it is not None; + - the reward values; + - the rewards; + - the past key values; + - the hidden states; + - the attentions. + """ outputs = super().forward( input_ids=input_ids, position_ids=position_ids, diff --git a/llm/alignment/ppo/models/ppo_model_utils.py b/llm/alignment/ppo/models/ppo_model_utils.py index da8972cc5c6d..19c7b083031a 100644 --- a/llm/alignment/ppo/models/ppo_model_utils.py +++ b/llm/alignment/ppo/models/ppo_model_utils.py @@ -21,10 +21,13 @@ from typing import Optional, Tuple import paddle -import paddle.nn as nn +import paddle.distributed as dist import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.layers.mpu import mp_ops +from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy -# use LlamaPretrainingCriterion as common PretrainingCriterion from paddlenlp.transformers import LlamaPretrainingCriterion as PretrainingCriterion from paddlenlp.transformers.model_outputs import ModelOutput @@ -71,7 +74,7 @@ def loss_fwd(self, predict, labels): return loss_cls -def create_loss(loss_cls, config, extra_args, merge_labels=None): +def create_loss(loss_cls, config, extra_args, info_buffer, merge_labels=None): """ loss_cls(paddle.nn.Layer): loss class config(PratrainedConfig): model config, to be consistent with loss defined @@ -90,16 +93,32 @@ def create_loss(loss_cls, config, extra_args, merge_labels=None): # forward(self, predict, label1, label2, ...) 
loss_arg_names = list(inspect.signature(loss_cls.__init__).parameters.keys())[2:] if isinstance(extra_args, dict): - loss_kwargs = dict([(name, extra_args[name]) for name in loss_arg_names if name in extra_args]) + loss_kwargs = {name: extra_args[name] for name in loss_arg_names if name in extra_args} else: # create from TrainingArguments - loss_kwargs = dict([(name, getattr(extra_args, name)) for name in loss_arg_names if hasattr(extra_args, name)]) + loss_kwargs = {name: getattr(extra_args, name) for name in loss_arg_names if hasattr(extra_args, name)} + if "info_buffer" in loss_arg_names: + loss_kwargs["info_buffer"] = info_buffer loss = loss_cls(config, **loss_kwargs) return loss @paddle.no_grad() def make_position_ids(attention_mask, source=None): + """ + 根据attention_mask生成位置id,如果source不为空则将源端padding部分设置为0。 + 当attention_mask的形状是[B, L, H, W]时,表示causal mask,返回的position_ids是[B, H, W]; + 当attention_mask的形状是[B, L]时,表示padding mask,返回的position_ids是[B, L]。 + + Args: + attention_mask (Tensor, numpy.ndarray): 形状为[B, L, H, W]或者[B, L]的Tensor/numpy数组,其中L是序列长度,H是头数,W是宽度(可选)。 + 每个元素为0表示该位置未被mask,非0表示该位置被mask。 + source (Tensor, numpy.ndarray, optional): 形状为[B, S]的Tensor/numpy数组,其中S是源端序列长度(可选)。默认值为None。 + + Returns: + Tensor: 形状为[B, H, W]或者[B, L]的Tensor,其中H是头数,W是宽度(可选)。每个元素为对应位置的位置id。 + 如果source不为空,则在源端padding部分设置为0。 + """ if len(attention_mask.shape) == 4: # causal mask position_ids_p1 = attention_mask.cast(paddle.int64).sum(-1) position_ids = position_ids_p1 - 1 @@ -128,10 +147,34 @@ def make_position_ids(attention_mask, source=None): @paddle.no_grad() -def make_attention_mask(input_ids, pad_id, unk_id=None, past_key_values_length=0, causal_mask=True): +def make_attention_mask( + input_ids, + pad_id, + eos_id=None, + unk_id=None, + past_key_values_length=0, + causal_mask=True, +): + """ + 根据输入的`input_ids`,生成一个注意力掩码。如果`pad_id`不是`unk_id`和`eos_id`中的任何一个,则该位置将被忽略。 + 如果`causal_mask`为`False`,则返回全部为`True`的注意力掩码。否则,返回一个三角形掩码,其中每个元素都小于或等于相应位置的元素。 + + Args: + input_ids (Tensor): 输入序列的ID,形状为(batch_size, seq_len)。 + pad_id (int): 用于padding的ID。 + eos_id (int, optional): 用于表示结束的ID,默认为None。如果设置了,则会从注意力掩码中删除对应位置。 + unk_id (int, optional): 用于表示未知的ID,默认为None。如果设置了,则会从注意力掩码中删除对应位置。 + past_key_values_length (int, optional): 预先存在的键值对的长度,默认为0。 + causal_mask (bool, optional): 是否使用因果掩码,默认为True。 + + Returns: + Tensor: 注意力掩码,形状为(batch_size, 1, seq_len, seq_len + past_len)。 + """ attention_mask = input_ids != pad_id if unk_id is not None and pad_id != unk_id: attention_mask = paddle.logical_and(attention_mask, input_ids != unk_id) + if eos_id is not None and pad_id != eos_id: + attention_mask = paddle.logical_and(attention_mask, input_ids != eos_id) if not causal_mask: return attention_mask @@ -139,7 +182,13 @@ def make_attention_mask(input_ids, pad_id, unk_id=None, past_key_values_length=0 mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) if past_key_values_length > 0: # [tgt_len, tgt_len + past_len] - mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + mask = paddle.concat( + [ + paddle.ones([target_length, past_key_values_length], dtype="bool"), + mask, + ], + axis=-1, + ) # [bs, 1, tgt_len, tgt_len + past_len] causal_mask = mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) @@ -157,14 +206,45 @@ def gather_log_probabilities(logits: paddle.Tensor, labels: paddle.Tensor) -> pa class RLHFPPOLoss(nn.Layer): def __init__(self, config, clip_range_ratio=0.2): + """ + Initialize the 
`ClipRewardRange` object. + + Args: + config (dict): A dictionary containing environment configuration parameters. + See :class:`~rllib.agents.Agent` for more information. + clip_range_ratio (float, optional): The ratio of the range to which the reward is clipped. + Defaults to 0.2. + + Raises: + None. + + Returns: + None. + """ super().__init__() self.clip_range_ratio = clip_range_ratio self.config = config def actor_loss_fn( - self, log_probs: paddle.Tensor, old_log_probs: paddle.Tensor, advantages: paddle.Tensor, mask: paddle.Tensor + self, + log_probs: paddle.Tensor, + old_log_probs: paddle.Tensor, + advantages: paddle.Tensor, + mask: paddle.Tensor, ) -> paddle.Tensor: + """ + 计算演员的策略损失函数。该函数接受以下参数: + Args: + log_probs (paddle.Tensor): 当前状态下每个演员的对数产生概率,形状为[B, A],其中B是批量大小,A是演员数量。 + old_log_probs (paddle.Tensor): 上一时间步骤的每个演员的对数产生概率,形状与log_probs相同。 + advantages (paddle.Tensor): 每个演员在当前状态下获得的价值函数估计值,形状为[B, A]。 + mask (paddle.Tensor): 用于过滤已完成或无效的轨迹,形状为[B, A],其中B是批量大小,A是演员数量。 + 如果轨迹已经完成(即reward不为None),则mask为1;否则为0。 + 返回值 (paddle.Tensor): + PG_loss (paddle.Tensor): 演员的策略损失,形状为[1]。 + """ # policy gradient loss + ratio = paddle.exp(log_probs - old_log_probs) pg_loss1 = -advantages * ratio pg_loss2 = -advantages * paddle.clip( @@ -174,21 +254,24 @@ def actor_loss_fn( ) return paddle.sum(paddle.maximum(pg_loss1, pg_loss2) * mask) / mask.sum() - def forward(self, logits, input_ids, old_log_probs, reward_advantages, sequence_mask): - log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:]) - if log_probs.shape[1] == old_log_probs.shape[1]: - # labels (old_log_probs, reward_advantages, sequence_mask) has - # src+tgt-1 length, valid length is determined by sequence_mask - pass - elif log_probs.shape[1] < old_log_probs.shape[1]: - # labels (old_log_probs, reward_advantages, sequence_mask) has - # src+tgt length and the last one is a padding to be consistent - # with input_ids - assert log_probs.shape[1] == old_log_probs.shape[1] - 1 - log_probs = paddle.concat([log_probs, paddle.zeros([log_probs.shape[0], 1], dtype=log_probs.dtype)], -1) - else: - # labels (old_log_probs, reward_advantages, sequence_mask) has tgt length - log_probs = log_probs[:, -old_log_probs.shape[1] :] + def forward(self, log_probs, old_log_probs, reward_advantages, sequence_mask): + """ + Calculate the loss of the actor network. + + Args: + logits (Tensor, shape [batch_size, seq_len, vocab_size]): The output logits of the model. + input_ids (Tensor, shape [batch_size, seq_len]): The input ids of the batch. + old_log_probs (Tensor, shape [batch_size, seq_len]): The previous log probabilities of the batch. + reward_advantages (Tensor, shape [batch_size, seq_len]): The rewards or advantages of the batch. + sequence_mask (Tensor, shape [batch_size, seq_len]): A mask indicating which elements are valid. + Valid elements are those where sequence_mask is True. + + Returns: + Tensor, shape [1], the loss of the actor network. + + Raises: + None. + """ actor_loss = self.actor_loss_fn( log_probs, old_log_probs, @@ -202,13 +285,57 @@ def forward(self, logits, input_ids, old_log_probs, reward_advantages, sequence_ class RLHFPPOMixedLoss(nn.Layer): """provide two losses, one for PPO loss, the other for SFT loss.""" - def __init__(self, config, ptx_coeff=16, clip_range_ratio=0.2): + def __init__( + self, config, ptx_coeff=16, clip_range_ratio=0.2, kl_loss_coeff=0.001, clip_range_score=10, info_buffer=None + ): + """ + Args: + config (Config): configuration object containing hyperparameters and options for the agent. 
+ ptx_coeff (int, optional): coefficient to use in the PTX loss calculation. Defaults to 16. + clip_range_ratio (float, optional): ratio of clipped range to unclipped range. Defaults to 0.2. + """ super(RLHFPPOMixedLoss, self).__init__() + self.config = config self.ptx_coeff = ptx_coeff + # if self.config.use_fused_head_and_loss_fn: + # self.ppo_criterion = FusedPPOLoss(config, clip_range_ratio) + # else: + # self.ppo_criterion = RLHFPPOLoss(config, clip_range_ratio) self.ppo_criterion = RLHFPPOLoss(config, clip_range_ratio) self.sft_criterion = PretrainingCriterion(config) + self.kl_loss_coeff = kl_loss_coeff + self.clip_range_score = clip_range_score + self.info_buffer = info_buffer - def forward(self, logits, labels, input_ids, old_log_probs, reward_advantages, sequence_mask): + def forward( + self, + logits, + labels, + input_ids, + old_log_probs, + reward_advantages, + sequence_mask, + ref_log_probs=None, + ): + """ + 计算损失函数,包含两部分:soft target loss和PPO loss。 + 如果labels不为None,则计算soft target loss;否则计算PPO loss。 + + Args: + logits (paddle.Tensor or List[paddle.Tensor]): 输入的预测结果,可以是单个tensor或list中的多个tensor。 + 如果是单个tensor,表示对应的输出logits;如果是list,表示每个时间步的logits。 + labels (paddle.Tensor, optional): 真实标签,shape与logits相同。默认为None。 + input_ids (paddle.Tensor, optional): 输入序列的id,shape为(batch_size, max_len)。默认为None。 + old_log_probs (paddle.Tensor, optional): 上一个时间步的log probabilities,shape为(batch_size, max_len)。默认为None。 + reward_advantages (paddle.Tensor, optional): 回报优势,shape为(batch_size, max_len)。默认为None。 + sequence_mask (paddle.Tensor, optional): 序列掩码,shape为(batch_size, max_len)。默认为None。 + + Returns: + paddle.Tensor: 返回损失函数,如果labels不为None,则为soft target loss;否则为PPO loss。 + """ + + if not self.config.use_fused_head_and_loss_fn: + logits = logits if isinstance(logits, paddle.Tensor) else logits[0] logits = logits if isinstance(logits, paddle.Tensor) else logits[0] loss = None # sft, pt loss @@ -216,7 +343,47 @@ def forward(self, logits, labels, input_ids, old_log_probs, reward_advantages, s loss = self.ptx_coeff * self.sft_criterion(logits, labels) # ppo loss if reward_advantages is not None: - loss = self.ppo_criterion(logits, input_ids, old_log_probs, reward_advantages, sequence_mask) + if self.config.tensor_parallel_degree > 1 and self.config.tensor_parallel_output: + log_probs = ( + -ParallelCrossEntropy()(logits[:, :-1].astype("float32"), input_ids[:, 1:]) + .squeeze(axis=-1) + .astype(logits.dtype) + ) + else: + log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:]) + if log_probs.shape[1] == old_log_probs.shape[1]: + # labels (old_log_probs, reward_advantages, sequence_mask) has + # src+tgt-1 length, valid length is determined by sequence_mask + pass + elif log_probs.shape[1] < old_log_probs.shape[1]: + # labels (old_log_probs, reward_advantages, sequence_mask) has + # src+tgt length and the last one is a padding to be consistent + # with input_ids + assert log_probs.shape[1] == old_log_probs.shape[1] - 1 + log_probs = paddle.concat( + [ + log_probs, + paddle.zeros([log_probs.shape[0], 1], dtype=log_probs.dtype), + ], + -1, + ) + else: + # labels (old_log_probs, reward_advantages, sequence_mask) has tgt length + log_probs = log_probs[:, -old_log_probs.shape[1] :] + + # TODO:support fused head and loss fn + loss = self.ppo_criterion(log_probs, old_log_probs, reward_advantages, sequence_mask) + + if ref_log_probs is not None: + kl_divergence_estimate = paddle.clip( + paddle.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1, + min=-self.clip_range_score, + 
max=self.clip_range_score, + ) + kl_loss = paddle.sum(kl_divergence_estimate * sequence_mask) / sequence_mask.sum() + self.info_buffer["kl_loss"] = kl_loss.detach() + self.info_buffer["pure_policy_loss"] = loss.detach() + loss += self.kl_loss_coeff * kl_loss return loss @@ -224,6 +391,20 @@ def forward(self, logits, labels, input_ids, old_log_probs, reward_advantages, s @merge_fwd_labels class RLHFValueLoss(nn.Layer): def __init__(self, config, clip_range_value=5.0): + """ + Initializes the `ClipRewardRange` object. + + Args: + config (dict): The configuration dictionary for the environment. + See :ref:`rllib-spaces` for more information. + clip_range_value (Optional[float]): The value to which the rewards will be clipped. Defaults to 5.0. + + Raises: + None. + + Returns: + None. + """ super().__init__() self.clip_range_value = clip_range_value self.config = config @@ -239,13 +420,32 @@ def critic_loss_fn( # TODO(guosheng): use paddle.clip when its min/max can support more than # 0D Tensor values_clipped = paddle.minimum( - paddle.maximum(values, old_values - self.clip_range_value), old_values + self.clip_range_value + paddle.maximum(values, old_values - self.clip_range_value), + old_values + self.clip_range_value, ) vf_loss1 = paddle.square(values - returns) vf_loss2 = paddle.square(values_clipped - returns) return 0.5 * paddle.sum(paddle.maximum(vf_loss1, vf_loss2) * mask) / mask.sum() def forward(self, reward_values, old_reward_values, reward_returns, sequence_mask): + """ + 计算奖励值的损失函数。 + 如果输入的奖励值和旧奖励值的长度相同,则使用给定的序列掩码来确定有效长度。 + 如果输入的奖励值的长度比旧奖励值少一个,则将最后一个元素视为与输入IDs一致的填充,并删除它。 + 否则,奖励值只有tgt长度。 + + Args: + reward_values (paddle.Tensor, list of paddle.Tensor or None, optional): 奖励值,可以是单个张量或列表中的多个张量。默认为None。 + old_reward_values (paddle.Tensor, optional): 旧奖励值。 + reward_returns (paddle.Tensor, optional): 奖励返回值。 + sequence_mask (paddle.Tensor, optional): 序列掩码。 + + Returns: + paddle.Tensor, float32: 奖励值的损失函数。 + + Raises: + ValueError: 当奖励值和旧奖励值的长度不匹配时引发。 + """ reward_values = reward_values if isinstance(reward_values, paddle.Tensor) else reward_values[0] reward_values = reward_values.squeeze(axis=-1)[:, :-1] if reward_values.shape[1] == old_reward_values.shape[1]: @@ -258,7 +458,11 @@ def forward(self, reward_values, old_reward_values, reward_returns, sequence_mas # with input_ids assert reward_values.shape[1] == old_reward_values.shape[1] - 1 reward_values = paddle.concat( - [reward_values, paddle.zeros([reward_values.shape[0], 1], dtype=reward_values.dtype)], -1 + [ + reward_values, + paddle.zeros([reward_values.shape[0], 1], dtype=reward_values.dtype), + ], + -1, ) else: # labels (old_reward_values, reward_returns, sequence_mask) has @@ -272,3 +476,317 @@ def forward(self, reward_values, old_reward_values, reward_returns, sequence_mas ) return reward_critic_loss + + +class ActorFusedLoss(paddle.autograd.PyLayer): + """Fused Actor Loss""" + + @staticmethod + def forward( + ctx, + hidden_states: paddle.Tensor, + lm_head_weight: paddle.Tensor, + lm_head_bias: paddle.Tensor, + labels: paddle.Tensor, + mask: paddle.Tensor, + transpose_y: bool, + num_embeddings: int, + tensor_parallel_degree: int, + tensor_parallel_output: bool, + fused_linear: bool, + loop_chunk_size: int, + ignore_index: int, + old_log_probs: paddle.Tensor, + advantages: paddle.Tensor, + clip_range_ratio: float, + ): + """ + forward function of ActorFusedLoss + + Args: + ctx (paddle.autograd.PyLayerContext): context. + hidden_states (paddle.Tensor): hidden_states, [batch_size, seq_len-1, hidden_size]. 
+ lm_head_weight (paddle.Tensor): lm_head_weight, [hidden_size, vocab_size / tensor_parallel_degree]. + lm_head_bias (paddle.Tensor, optional): lm_head_bias, [vocab_size / tensor_parallel_degree]. + labels (paddle.Tensor): labels, [batch_size, seq_len-1]. + mask (paddle.Tensor): mask, [batch_size, seq_len-1]. + transpose_y (bool): whether to transpose lm_head_weight. + num_embeddings (int): vocab_size. + tensor_parallel_degree (int): tensor_parallel_degree. + tensor_parallel_output (bool): tensor_parallel_output, set True in ppo_main.py. + fused_linear (bool): Flag for using fused linear, always False. + loop_chunk_size (int): chunk_size. + ignore_index (int): not used now. + old_log_probs (paddle.Tensor): old_log_probs, [batch_size, seq_len-1]. + advantages (paddle.Tensor): advantages, [batch_size, seq_len-1]. + clip_range_ratio (float): The clipping range for ratio. + + Returns: + paddle.Tensor: loss + + """ + if fused_linear: + # print("Cannot support fused_linear while using use_fused_head_and_loss_fn now!") + fused_linear = False + if tensor_parallel_degree > 1: + assert tensor_parallel_output, ( + "When tensor_parallel_degree > 1 and use_fused_head_and_loss_fn, " + "tensor_parallel_output needs to be set to True." + ) + dtype = hidden_states.dtype + # Parallel Configuration + if tensor_parallel_degree > 1 and tensor_parallel_output: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + + # reshape + original_shape = hidden_states.shape + hidden_states_stop_grad = hidden_states.stop_gradient # original stop_gradient + hidden_states = hidden_states.reshape([-1, original_shape[-1]]) + labels = labels.reshape([-1]) + old_log_probs = old_log_probs.reshape([-1]) + advantages = advantages.reshape([-1]) + loss_mask = mask.reshape([-1]).astype("float32") # .astype(dtype) + + n_tokens = hidden_states.shape[0] + n_classes = lm_head_weight.shape[0] if transpose_y else lm_head_weight.shape[1] + + # convert dtype of weights and biases of lm_head + lm_head_weight_cast = lm_head_weight.astype(dtype) + if lm_head_bias is not None: + lm_head_bias_cast = lm_head_bias.astype(dtype) + + # use indices to distinguish the devices. 
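# --- Illustrative sketch (annotation, not part of the patch). Under tensor
# parallelism each model-parallel rank owns a contiguous slice of the vocabulary,
# so the label/one-hot comparison below needs rank-shifted class indices.
# The helper name and toy numbers are hypothetical.
import paddle

def _vocab_slice_indices(rank, num_embeddings, tensor_parallel_degree, n_classes):
    per_part_size = num_embeddings // tensor_parallel_degree
    return paddle.arange(
        rank * per_part_size, rank * per_part_size + n_classes, dtype="int64"
    ).unsqueeze(0)

# _vocab_slice_indices(rank=1, num_embeddings=8, tensor_parallel_degree=2, n_classes=4)
# -> Tensor([[4, 5, 6, 7]])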
+ if tensor_parallel_degree > 1 and tensor_parallel_output: + rank = hcg.get_model_parallel_rank() + per_part_size = num_embeddings // tensor_parallel_degree + indices = paddle.arange( + rank * per_part_size, + rank * per_part_size + n_classes, + dtype=labels.dtype, + ).unsqueeze(0) + else: + indices = paddle.arange(num_embeddings, dtype=labels.dtype).unsqueeze(0) + + # initialize total_loss and divisor + total_loss = paddle.zeros([1], dtype=dtype) + divisor = loss_mask.sum() + + # initialize grads + if not lm_head_weight.stop_gradient: + grad_lm_head_weight = paddle.zeros_like(lm_head_weight) + else: + grad_lm_head_weight = None + if lm_head_weight is not None and not lm_head_weight.stop_gradient: + grad_lm_head_bias = paddle.zeros_like(lm_head_bias) + else: + grad_lm_head_bias = None + if not hidden_states_stop_grad: + grad_hidden_states = paddle.zeros_like(hidden_states) + else: + grad_hidden_states = None + + for i in range(0, n_tokens, loop_chunk_size): + token_start_idx = i + token_end_idx = min(i + loop_chunk_size, n_tokens) + hidden_states_chunk = hidden_states[token_start_idx:token_end_idx] + labels_chunk = labels[token_start_idx:token_end_idx] + old_log_probs_chunk = old_log_probs[token_start_idx:token_end_idx] + advantages_chunk = advantages[token_start_idx:token_end_idx] + mask_chunk = loss_mask[token_start_idx:token_end_idx] + + # Calculate the current logits_chunk, not fused linear + logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y) + if lm_head_bias is not None: + logits_chunk_cast += lm_head_bias_cast + # logits_chunk_cast = paddle.nn.functional.linear(hidden_states_chunk, lm_head_weight_cast, lm_head_bias) + + logits_chunk = logits_chunk_cast.astype("float32") + labels_one_hot = labels_chunk.unsqueeze(1) == indices + # rewritten as cross entropy + if tensor_parallel_degree > 1 and tensor_parallel_output: + token_loss_chunk, softmax_output_chunk = mp_ops._c_softmax_with_cross_entropy( + logits_chunk, + labels_chunk, + group=model_parallel_group, + return_softmax=True, + ) + else: + token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none") + softmax_output_chunk = F.softmax(logits_chunk, axis=-1) + + log_probs_chunk = -token_loss_chunk.squeeze(axis=-1) + # calculate gradient, note sign + grad_logits_chunk = labels_one_hot.astype("float32") - softmax_output_chunk + grad_logits_chunk = grad_logits_chunk.astype(dtype) + + # ratio + ratio_chunk = paddle.exp(log_probs_chunk - old_log_probs_chunk) + clipped_ratio_chunk = paddle.clip(ratio_chunk, min=1.0 - clip_range_ratio, max=1.0 + clip_range_ratio) + + # final loss + pg_loss1_chunk = -advantages_chunk * ratio_chunk + pg_loss2_chunk = -advantages_chunk * clipped_ratio_chunk + pg_loss_chunk = paddle.maximum(pg_loss1_chunk, pg_loss2_chunk) + + # mask + pg_loss_chunk = pg_loss_chunk * mask_chunk + masked_loss_sum = paddle.sum(pg_loss_chunk) + # add + total_loss += masked_loss_sum + + # grads + # direction + I1_chunk = (pg_loss1_chunk >= pg_loss2_chunk).astype(dtype) + I2_chunk = 1.0 - I1_chunk + + # clip + clip_mask_chunk = ( + (ratio_chunk >= 1.0 - clip_range_ratio) & (ratio_chunk <= 1.0 + clip_range_ratio) + ).astype(dtype) + + # ∂loss1/∂log_probs, ∂loss2/∂log_probs + d_ratio_d_log_probs_chunk = ratio_chunk + d_pg_loss1_d_log_probs_chunk = -advantages_chunk * d_ratio_d_log_probs_chunk + d_pg_loss2_d_log_probs_chunk = -advantages_chunk * clip_mask_chunk * d_ratio_d_log_probs_chunk + + # ∂loss/∂log_probs + d_loss_d_log_probs_chunk = ( + I1_chunk * 
d_pg_loss1_d_log_probs_chunk + I2_chunk * d_pg_loss2_d_log_probs_chunk + ) + d_loss_d_log_probs_chunk = d_loss_d_log_probs_chunk * mask_chunk / divisor + + # ∂log_probs/∂logits, just take the previous one. + d_log_probs_d_logits_chunk = grad_logits_chunk + # ∂loss/∂logits + d_loss_d_logits_chunk = d_loss_d_log_probs_chunk.unsqueeze(-1) * d_log_probs_d_logits_chunk + + # grads + if grad_hidden_states is not None: + grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul( + d_loss_d_logits_chunk, lm_head_weight_cast, transpose_y=not transpose_y + ) + if grad_lm_head_weight is not None: + if transpose_y: + grad_lm_head_weight += paddle.matmul(d_loss_d_logits_chunk, hidden_states_chunk, transpose_x=True) + else: + grad_lm_head_weight += paddle.matmul(hidden_states_chunk, d_loss_d_logits_chunk, transpose_x=True) + if grad_lm_head_bias is not None: + grad_lm_head_bias += d_loss_d_logits_chunk.astype("float32").sum(axis=0).astype(dtype) + + final_loss = total_loss / divisor + ctx.hidden_states_has_grad = grad_hidden_states is not None + ctx.lm_head_weight_has_grad = grad_lm_head_weight is not None + ctx.lm_head_bias_has_grad = grad_lm_head_bias is not None + + grad_args = [] + if ctx.hidden_states_has_grad: + if tensor_parallel_degree > 1: + dist.all_reduce(grad_hidden_states, op=dist.ReduceOp.SUM, group=model_parallel_group) + grad_args.append(grad_hidden_states.reshape(original_shape)) + if ctx.lm_head_weight_has_grad: + grad_args.append(grad_lm_head_weight) + if ctx.lm_head_bias_has_grad: + grad_args.append(grad_lm_head_bias) + + ctx.save_for_backward(*grad_args) + return final_loss + + @staticmethod + def backward(ctx, grad_output): + """ + backward function of ActorFusedLoss + + Args: + ctx: Context. + grad_output(paddle.Tensor): Gradient. + Returns: + tuple: + - Gradient tensors for hidden_states, lm_head_weight, and lm_head_bias, + None values are used for inputs not requiring gradients. + """ + grad_args = ctx.saved_tensor() + idx = 0 + if ctx.hidden_states_has_grad: + grad_hidden_states = grad_args[idx] * grad_output.astype(grad_args[idx].dtype) + idx += 1 + else: + grad_hidden_states = None + + if ctx.lm_head_weight_has_grad: + grad_lm_head_weight = grad_args[idx] * grad_output.astype(grad_args[idx].dtype) + idx += 1 + else: + grad_lm_head_weight = None + + if ctx.lm_head_bias_has_grad: + grad_lm_head_bias = grad_args[idx] * grad_output.astype(grad_args[idx].dtype) + idx += 1 + else: + grad_lm_head_bias = None + return grad_hidden_states, grad_lm_head_weight, grad_lm_head_bias, None, None + + +class FusedPPOLoss(nn.Layer): + """Fused PPOLoss""" + + def __init__(self, config, clip_range_ratio=0.2): + """Initialize FusedPPOLoss class.""" + super().__init__() + self.clip_range_ratio = clip_range_ratio + self.config = config + + def forward( + self, + hidden_states: paddle.Tensor, + lm_head_weight: paddle.Tensor, + lm_head_bias: paddle.Tensor, + input_ids: paddle.Tensor, + old_log_probs: paddle.Tensor, + reward_advantages: paddle.Tensor, + sequence_mask: paddle.Tensor, + transpose_y: bool, + ): + """ + forward function of FusedPPOLoss + + Args: + hidden_states (paddle.Tensor): hidden_states, [batch_size, seq_len, hidden_size]. + lm_head_weight (paddle.Tensor): lm_head_weight, [hidden_size, vocab_size / tensor_parallel_degree]. + lm_head_bias (paddle.Tensor, optional): lm_head_bias, [vocab_size / tensor_parallel_degree]. + input_ids (paddle.Tensor): input_ids, [batch_size, seq_len]. + old_log_probs (paddle.Tensor): old_log_probs, [batch_size, seq_len-1]. 
+ reward_advantages (paddle.Tensor): advantages, [batch_size, seq_len-1]. + sequence_mask (paddle.Tensor): mask, [batch_size, seq_len-1]. + transpose_y (bool): whether to transpose lm_head_weight. + + Returns: + paddle.Tensor: loss + + """ + logits_next = hidden_states[:, :-1, :] + labels_next = input_ids[:, 1:] + + if old_log_probs.shape[1] != labels_next.shape[1]: + # labels(old_log_probs,reward_advantages,sequence_mask)的长度为 src + tgt - 1,实际长度由 sequence_mask 确定 + raise ValueError("old_log_probs and reward_advantages should have the same length") + + actor_loss = ActorFusedLoss.apply( + hidden_states=logits_next, + lm_head_weight=lm_head_weight, + lm_head_bias=lm_head_bias, + labels=labels_next, + mask=sequence_mask, + transpose_y=transpose_y, + num_embeddings=self.config.vocab_size, + tensor_parallel_degree=self.config.tensor_parallel_degree, + tensor_parallel_output=self.config.tensor_parallel_output, + fused_linear=False, + loop_chunk_size=1024, # 128, + ignore_index=0, + old_log_probs=old_log_probs, + advantages=reward_advantages, + clip_range_ratio=self.clip_range_ratio, + ) + return actor_loss diff --git a/llm/alignment/ppo/models/score_model.py b/llm/alignment/ppo/models/score_model.py index aa0f50977945..e158146a427e 100644 --- a/llm/alignment/ppo/models/score_model.py +++ b/llm/alignment/ppo/models/score_model.py @@ -39,6 +39,17 @@ class LlamaModelForScore(ScoreModelMixin, LlamaPretrainedModel): _keys_to_ignore_on_load_missing = ["lm_head.weight"] def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: + """ + Initializes a `LlamaForSequenceClassification` model. + + Args: + config (PretrainedConfig): Model configuration class with all the parameters of the model. + kwargs (Any, optional): Additional keyword arguments passed along to the `__init__` of the parent class. + This is necessary because of how `transformers.AutoModelWithHead` is designed. Defaults to `None`. + + Raises: + TypeError: If the config is not an instance of `PretrainedConfig`. + """ super().__init__(config) self.llama = LlamaModel(config) @@ -46,15 +57,46 @@ def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) def get_input_embeddings(self) -> nn.Embedding: + """ + 返回输入嵌入的nn.Embedding对象,该对象用于将输入序列转换为嵌入向量。 + 如果模型没有使用嵌入,则返回None。 + + Returns: + Optional[nn.Embedding]: 输入嵌入的nn.Embedding对象,或者None(如果没有使用嵌入)。 + """ return self.llama.embed_tokens def set_input_embeddings(self, value: nn.Embedding) -> None: + """ + Set the input embeddings to be used for the model. + + Args: + value (nn.Embedding): The embedding layer to use. + + Returns: + NoneType: No return value is returned. Instead, the input embeddings are updated in-place. 
+ """ self.llama.embed_tokens = value def get_decoder(self) -> PretrainedModel: + """ + 获取解码器模型。 + + Returns: + PretrainedModel (Pytorch): 返回解码器模型,类型为Pytorch的PretrainedModel。 + """ return self.llama def set_decoder(self, decoder: PretrainedModel) -> None: + """ + 设置解码器,用于进行文本生成。 + + Args: + decoder (PretrainedModel): 预训练的模型对象,需要是一个有效的解码器。 + + Returns: + None; 无返回值。 + """ self.llama = decoder def forward( # pylint: disable=too-many-arguments @@ -69,6 +111,36 @@ def forward( # pylint: disable=too-many-arguments output_hidden_states: bool | None = None, return_dict: bool | None = None, ) -> tuple[paddle.Tensor, paddle.Tensor] | ScoreModelOutput: + """ + 句子的前向传播过程。 + + Args: + input_ids (paddle.Tensor): + 输入序列的ID,形状为(batch_size, sequence_length)。 + attention_mask (paddle.Tensor): + 用于区分padding和非padding元素的mask,形状为(batch_size, sequence_length),值为0或1。 + position_ids (paddle.Tensor, optional): + input_ids对应的位置ID,形状为(batch_size, sequence_length),默认为None。 + past_key_values (list[paddle.Tensor], optional): + 包含所有预处理器的键和值,默认为None。 + inputs_embeds (paddle.Tensor, optional): + 输入序列的嵌入,形状为(batch_size, sequence_length, embedding_dimension),默认为None。 + use_cache (bool, optional): + 是否使用缓存,默认为None。 + output_attentions (bool, optional): + 是否返回注意力张量,默认为None。 + output_hidden_states (bool, optional): + 是否返回隐藏状态,默认为None。 + return_dict (bool, optional): + 是否返回字典格式的结果,默认为None。 + + Returns: + tuple[paddle.Tensor, paddle.Tensor] or ScoreModelOutput: + 如果`return_dict`为True,则返回一个ScoreModelOutput类型的元组,其中包含两个元素:得分和附加信息;否则,返回一个tuple,其中包含得分和附加信息。 + Raises: + AssertionError: + 当`attention_mask`不为None时引发。 + """ assert attention_mask is not None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -97,6 +169,17 @@ def forward( # pylint: disable=too-many-arguments @classmethod def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: + """ + 获取模型的名称映射列表,包括模型参数和额外的映射。 + 如果配置中没有"LlamaModel",则将基础模型前缀添加到每个映射中。 + + Args: + config (LlamaConfig): 配置对象,其中包含模型参数。 + + Returns: + list[StateDictNameMapping]: 一个包含模型参数和额外映射的名称映射列表。 + 每个元素是一个三元组(原始名称,转换后的名称,转换类型),其中转换类型可以为None、"transpose"或者"add_prefix"。 + """ mappings: list[StateDictNameMapping] = [] model_mappings = [ ["embed_tokens.weight"], @@ -104,13 +187,37 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: ] for layer_index in range(config.num_hidden_layers): layer_mappings = [ - [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], - [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], - [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], - [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [ + f"layers.{layer_index}.self_attn.q_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.k_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.v_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.o_proj.weight", + None, + "transpose", + ], [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], - [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], - [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [ + f"layers.{layer_index}.mlp.gate_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.mlp.down_proj.weight", + None, + "transpose", + ], [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], [f"layers.{layer_index}.input_layernorm.weight"], 
[f"layers.{layer_index}.post_attention_layernorm.weight"], diff --git a/llm/alignment/ppo/models/score_model_utils.py b/llm/alignment/ppo/models/score_model_utils.py index 5515d56fbc20..a3dd5ccd1c43 100644 --- a/llm/alignment/ppo/models/score_model_utils.py +++ b/llm/alignment/ppo/models/score_model_utils.py @@ -17,7 +17,6 @@ from __future__ import annotations import importlib -import io import json from abc import abstractmethod from collections import OrderedDict @@ -50,8 +49,28 @@ class AutoModelForScore(_BaseAutoModelClass): @classmethod def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file_path, config=None): + """ + Get the model class from config file path. If no config is provided, it will try to load the config from + the given config_file_path. The config should be a dict containing at least one of the following keys: + - "architectures": A list of strings indicating the architecture names. The last element in the list will + be used as the model class. + - "init_class": A string indicating the model class name. + + Args: + pretrained_model_name_or_path (str): The pretrained model name or path. This argument is only used to + infer the model class when no config is provided. + config_file_path (str): The path to the config file. + config (Optional, dict, optional): The config dictionary. Defaults to None. + + Raises: + AttributeError: Unable to parse 'architectures' or 'init_class' from config_file_path. Also unable to + infer model class from 'pretrained_model_name_or_path'. + + Returns: + type: The model class. + """ if config is None: - with io.open(config_file_path, encoding="utf-8") as f: + with open(config_file_path, encoding="utf-8") as f: config = json.load(f) # Get class name corresponds to this configuration @@ -72,7 +91,8 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file break if model_name is None: raise AttributeError( - f"Unable to parse 'architectures' or 'init_class' from {config_file_path}. Also unable to infer model class from 'pretrained_model_name_or_path'" + f"Unable to parse 'architectures' or 'init_class' from {config_file_path}." + "Also unable to infer model class from 'pretrained_model_name_or_path'" ) init_class = cls._name_mapping[model_name + "_Import_Class"] # module_name = cls._name_mapping[init_class] @@ -84,6 +104,30 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + """ + Instantiate a PreTrainedModel from a pre-trained model file. + + Args: + pretrained_model_name_or_path (str): + This can be either: + - a string with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path to a pre-trained model configuration file (e.g., `.json`) or a key to find such + file in a PyTorch state dictionary (e.g., returned by torch.save()). + *model_args (tuple): + Additional positional arguments that will be passed to the underlying model's `__init__` method. + **kwargs (dict): + Additional keyword arguments that will be passed to the underlying model's `__init__` method. + + Returns: + PreTrainedModel: A model with weights loaded from the specified pre-trained model file. 
+ """ return cls._from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @@ -107,7 +151,7 @@ class ScoreModelMixin: """Base class for score models.""" score_head: nn.Linear - normalizer: Normalizer + # normalizer: Normalizer do_normalize: bool = False normalize_function: NormalizeFunction = "affine" _initialized: bool = False @@ -142,7 +186,11 @@ def init_score_head(self, config: PretrainedConfig, hidden_size: int, **kwargs: "normalizer_type", getattr(config, "normalizer_type", None), ) - if config.normalizer_type not in {"RunningMeanStd", "ExponentialMovingAverage", None}: + if config.normalizer_type not in { + "RunningMeanStd", + "ExponentialMovingAverage", + None, + }: raise ValueError( f"Invalid norm type: {config.normalizer_type}." "Expected one of 'RunningMeadStd', 'ExponentialMovingAverage', or None.", @@ -174,6 +222,8 @@ def get_score( ) -> ScoreModelOutput: """Forward pass of the score model.""" scores = self.score_head(hidden_state) # size = (B, L, D) + if scores.dtype != hidden_state.dtype: # EB rm cast to float32 + scores = scores.cast(hidden_state.dtype) if position_ids is not None: first_pos = paddle.arange(hidden_state.shape[0]).unsqueeze(-1) @@ -205,7 +255,6 @@ def get_score( end_score = paddle.stack(end_score, axis=0) # size = (B, D) if self.training and self.do_normalize: - if dist.is_initialized(): gathered_end_score_list = [] try: @@ -235,6 +284,19 @@ def get_score( ) def set_normalize(self, mode: bool = True) -> None: + """ + 设置是否对输入数据进行归一化处理,默认为True。 + 如果mode为True,则对输入数据进行归一化处理;如果mode为False,则不对输入数据进行归一化处理。 + + Args: + mode (bool, optional): 是否对输入数据进行归一化处理,默认为True. Defaults to True. + + Returns: + None: 无返回值,直接修改了实例的do_normalize属性和config中的do_normalize属性。 + + Raises: + None: 没有异常抛出。 + """ if self.do_normalize == mode: return @@ -333,7 +395,7 @@ def instantiate( shape: tuple[int, ...], device: str | None = None, **kwargs: Any, - ) -> Normalizer: + ): """Get a normalizer.""" if normalizer_type == "RunningMeanStd": return RunningMeanStd( @@ -395,6 +457,13 @@ def __init__( device: str | None = None, momentum: float = 0.9, ) -> None: + """ + Args: + normalize_function (NormalizeFunction): Function to normalize the input tensor. + shape (tuple[int, ...]): Shape of the output tensor. + device (str, optional): Device where the tensor will be allocated. Defaults to None. + momentum (float, optional): Momentum for the moving average. Defaults to 0.9. + """ super().__init__(normalize_function, shape=shape, device=device) self.momentum = momentum diff --git a/llm/alignment/ppo/ppo_trainer.py b/llm/alignment/ppo/ppo_trainer.py index bdec462411e0..ec1253d89e85 100644 --- a/llm/alignment/ppo/ppo_trainer.py +++ b/llm/alignment/ppo/ppo_trainer.py @@ -11,23 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import contextlib import copy import itertools +import json import math import os import sys import time +import types +import uuid +from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import paddle import paddle.distributed as dist -import paddle.nn as nn -from comm_utils import ( # noqa - cleanup_tensor_space, +import requests +from comm_utils import ( + ActorStages, + CriticStages, + RolloutStages, create_data_trans_group, data_group_merge, data_group_split, + get_timer_label, + new_timer_log, offload_tensor_to_cpu, reload_tensor_to_gpu, ) @@ -40,7 +49,9 @@ make_attention_mask, make_position_ids, ) +from paddle import nn from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy from paddle.io import DataLoader, Dataset, DistributedBatchSampler from paddle.utils import map_structure from rich.console import Console @@ -51,6 +62,7 @@ batch_retokenize, guard_set_args, is_same_tokenizer, + process_row, ) from paddlenlp.data import DataCollator @@ -65,8 +77,16 @@ logger, speed_metrics, ) -from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer -from paddlenlp.utils import empty_device_cache +from paddlenlp.trainer.trainer_utils import TrainOutput +from paddlenlp.trainer.utils import distributed_concat +from paddlenlp.transformers import ( + CosineAnnealingWithWarmupDecay, + LinearAnnealingWithWarmupDecay, + PretrainedModel, + PretrainedTokenizer, +) +from paddlenlp.transformers.model_utils import _add_variant +from paddlenlp.utils.env import PADDLE_WEIGHTS_NAME class StepTrainer(Trainer): @@ -96,7 +116,7 @@ def __init__( compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None), - preprocess_logits_for_metrics: Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor] = None, + preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None, ): super().__init__( model, @@ -113,6 +133,7 @@ def __init__( ) # criterion is only used for non-PipelineParallel models. criterion is # included in model for PipelineParallel. + self.info_buffer = {} if getattr(self, "loss_cls", None) and self.criterion is None: self.criterion = self.create_criterion() @@ -122,6 +143,8 @@ def __init__( self.shard_ema = getattr(args, "shard_ema", False) self.offload_ema = getattr(args, "offload_ema", True) self.ema_beta = getattr(args, "ema_beta", 0.992) + # if self.timers: + # self.timers.log = types.MethodType(new_timer_log, self.timers) def create_criterion(self): """ @@ -129,7 +152,7 @@ def create_criterion(self): whose label arguments are merged into one argument, this is useful to PipelineParallel and trainer.criterion which limit loss format. """ - criterion = create_loss(self.loss_cls, self.model.config, self.args, merge_labels=True) + criterion = create_loss(self.loss_cls, self.model.config, self.args, self.info_buffer, merge_labels=True) return criterion def loss_identifier(self, inputs: Dict) -> str: @@ -192,7 +215,7 @@ def get_model(self, train=False): self._eval_model = model return model - def get_train_step_vars(self, vars: Dict = None) -> Dict: + def get_train_step_vars(self, vars: Optional[Dict] = None) -> Dict: """ NOTE: This is transparent to users. 
When using multiple instances of StepTrainer collaborate to do one training @@ -213,7 +236,8 @@ def get_train_step_vars(self, vars: Dict = None) -> Dict: # should be called after model is wrapped since the model field should # use model_wrapped. - assert self.model is not self.model_wrapped + if paddle.distributed.get_world_size() > 1: + assert self.model is not self.model_wrapped self.train_step_vars = { # meaningless vars can pass from outter, dummy value is enough "epoch": 0, # meaningless for step training @@ -229,6 +253,13 @@ def get_train_step_vars(self, vars: Dict = None) -> Dict: @property def loss_names(self): + """ + 返回所有损失项的名称列表,只在第一次调用时计算。 + 如果没有损失项,则返回空列表。 + + Returns: + List[str]: 损失项的名称列表,每个名称以"_loss"结尾。 + """ if not hasattr(self, "_loss_names"): self._loss_names = [var_name for var_name in self.get_train_step_vars() if var_name.endswith("_loss")] assert len(self._loss_names) > 0 @@ -266,12 +297,16 @@ def full_training_step(self, **inputs) -> paddle.Tensor: loss_var = train_step_vars[loss_name] train_step_vars["tr_loss"] = loss_var + # train_step_vars["timer_name"] = self.__class__.__name__ new_train_step_vars = super().full_training_step(inputs, **train_step_vars) # minimally update train_step_vars = self.get_train_step_vars( - {"step_control": new_train_step_vars["step_control"], loss_name: new_train_step_vars["tr_loss"]} + { + "step_control": new_train_step_vars["step_control"], + loss_name: new_train_step_vars["tr_loss"], + } ) if loss_name != "tr_loss": train_step_vars.pop("tr_loss") @@ -281,7 +316,11 @@ def full_training_step(self, **inputs) -> paddle.Tensor: if self.use_ema and self.is_accumulation_step: # TODO(guosheng): assume rollout next thus make ema weights on gpu, # but may not, maybe need a way to specify it. - self.ema_update(beta=self.ema_beta, offload_ema=self.offload_ema, offload_model=not self.offload_ema) + self.ema_update( + beta=self.ema_beta, + offload_ema=self.offload_ema, + offload_model=not self.offload_ema, + ) return train_step_vars[loss_name] @@ -298,7 +337,7 @@ def _prepare_inputs(self, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> Dict[ # criterion created by create_loss has `label_names` and `label_default_values` label_names = self.criterion.__class__.label_names # some data fields are used both in model and loss - shared_fields = set(["input_ids", "attention_mask"]) + shared_fields = {"input_ids", "attention_mask"} labels = [] for name in label_names: if name not in inputs: @@ -373,7 +412,10 @@ def get_step_loss(self, loss_prefix: str = "", loss_accumulator: Dict = {}) -> D # model.accumulate_steps in training_pipeline_step to use # trainer.args.accu_steps. The dtype is fp32(to be check), # thus no need to broadcast. 
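# --- Illustrative sketch (annotation, not part of the patch). _prepare_inputs above
# splits one flat batch dict into model inputs plus a label tuple ordered by the
# criterion's label_names, keeping shared fields available to both. This is a
# simplified stand-in (the real method also falls back to the criterion's
# label_default_values); the helper name is hypothetical.
def split_model_inputs_and_labels(batch, label_names, shared_fields=("input_ids", "attention_mask")):
    labels = tuple(batch.get(name) for name in label_names)
    inputs = {k: v for k, v in batch.items() if k in shared_fields or k not in label_names}
    return inputs, labels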
- mix_loss = paddle.empty(shape=[self.args.gradient_accumulation_steps], dtype=paddle.float32) + mix_loss = paddle.empty( + shape=[self.args.gradient_accumulation_steps], + dtype=paddle.float32, + ) paddle.distributed.broadcast(mix_loss, src=model.pp_group.ranks[-1], group=model.pp_group) for loss_name in self.loss_names: # We assume a static order of multi-losses and mark the loss @@ -397,6 +439,13 @@ def is_accumulation_step(self): return self.get_train_step_vars()["step_control"] == 0 def get_sharding_master_weight_structured_names(self, model, optimizer): + """ + 获取分片主机权重的结构化名称列表。 + 参数: + model (torch.nn.Module) - 模型对象,包含需要进行权重分片的参数。 + optimizer (torch.optim.Optimizer) - 优化器对象,包含需要进行权重分片的参数。 + 返回值(list[str])- 一个包含所有参数的结构化名称列表,这些参数在当前分片主机上被训练。 + """ rank_param_names = [p.name for p in optimizer._rank2params[optimizer._sharding_rank]] structured_names = [] # for pipeline model, use `model.state_dict()` would auto map param name @@ -407,6 +456,18 @@ def get_sharding_master_weight_structured_names(self, model, optimizer): return structured_names def get_master_weight_state_dict(self, model, optimizer): + """ + 获取模型的权重状态字典,如果使用了AMP且支持pipeline并且存在master weights,则返回master weights。 + 否则返回model.state_dict()。 + + Args: + model (nn.Module): 待获取权重状态字典的模型。 + optimizer (Optimizer): 与模型关联的优化器,可选参数,默认为None。 + + Returns: + Union[Dict[str, Tensor], Dict[str, Any]]: 返回一个包含模型权重状态的字典,字典中的键是参数名称,值是对应的Tensor或Any类型的值。 + 如果使用了AMP且支持pipeline并且存在master weights,则返回的字典只包含master weights。 + """ if self.amp_dtype in ["float16", "bfloat16"] and hasattr(optimizer, "_master_weights"): master_weights = dict(optimizer._master_weights) result = {} @@ -494,6 +555,13 @@ def ema_apply(self): value._share_buffer_to(v) def ema_restore(self): + """ + 将EMA的权重值还原到模型中,并且将其移动到GPU上。 + 如果在初始化时设置了offload_ema=True,则会将EMA的权重值移动到GPU上。 + + Returns: + None, 无返回值,直接修改模型的权重值。 + """ for k, v in self.bak_state_dict.items(): value = v.cuda() value._share_buffer_to(v) @@ -505,47 +573,98 @@ def ema_restore(self): class ema(paddle.no_grad.__mro__[1]): def __init__(self, trainer: StepTrainer): + """ + Args: + trainer (StepTrainer): Trainer object to be used for training. 
+ """ self.trainer = trainer def __enter__(self): + """ + 在进入上下文管理器时,如果使用了EMA,则初始化它。 + 如果模型和优化器已经创建并包装,则调用ema_init。 + 如果使用了EMA,则应用它。 + + Returns: + None, 无返回值。 + """ trainer = self.trainer if trainer.use_ema and not hasattr(trainer, "ema_state_dict"): # call ema_init here since it should be called after model and # optimizer are created and wrapped trainer.ema_init( - offload_ema=trainer.offload_ema, offload_model=not trainer.offload_ema, shard_ema=trainer.shard_ema + offload_ema=trainer.offload_ema, + offload_model=not trainer.offload_ema, + shard_ema=trainer.shard_ema, ) if self.trainer.use_ema: self.trainer.ema_apply() def __exit__(self, *args): + """ + 如果使用了EMA,则恢复EMA状态。 + 参数: + args (tuple) - 可选,不填或为空元组,默认值为None。 + 返回值: + None - 无返回值。 + """ if self.trainer.use_ema: self.trainer.ema_restore() -class enable(paddle.no_grad.__mro__[1]): +class Enable(paddle.no_grad.__mro__[1]): """offload""" - def __init__(self, *args): + def __init__(self, args): + """ + 初始化函数,用于设置类属性objs为传入的参数args。 + Args: + args (Any): 需要传入的参数,将作为类属性objs。 + """ self.objs = args def __enter__(self): + """ + 在进入上下文管理器时,将所有的对象都启用。 + 如果对象没有 enable 方法,则使用 reload_tensor_to_gpu 来重新加载到 GPU。 + + Returns: + None, 无返回值。 + """ for obj in self.objs: - if hasattr(obj, "enable"): - obj.enable() + if hasattr(obj[0], "enable"): + obj[0].enable() else: - reload_tensor_to_gpu(obj.state_dict()) + if obj[1] != "": + reload_tensor_to_gpu(obj) # offload_tensor_to_cpu/reload_tensor_to_gpu use non-blocking copy # maybe overlap with compute later if len(self.objs) > 0: paddle.device.synchronize() def __exit__(self, *args): + """ + 当with语句结束时,调用该方法。 + 关闭所有的对象,并将其中的张量转换为CPU内存。 + + Args: + args (tuple, optional): 可选参数,默认为None。 + + - 第一个元素是错误类型的对象(如果有)。 + - 第二个元素是错误信息(如果有)。 + - 第三个元素是错误的traceback(如果有)。 + + 这些参数与Python标准库中的__exit__方法相同。 + + Returns: + None: 无返回值。 + """ for obj in self.objs: - if hasattr(obj, "disable"): - obj.disable() + if hasattr(obj[0], "disable"): + obj[0].disable() else: - offload_tensor_to_cpu(obj.state_dict()) + if obj[1] != "": + offload_tensor_to_cpu(obj) # offload_tensor_to_cpu/reload_tensor_to_gpu use non-blocking copy # maybe overlap with compute later if len(self.objs) > 0: @@ -554,8 +673,15 @@ def __exit__(self, *args): class PolicyTrainer(StepTrainer): loss_cls = RLHFPPOMixedLoss + trainer_type = "policy" def loss_identifier(self, inputs: Dict) -> str: + """ + 根据输入的字典,判断是否使用ptx损失函数和演员损失函数。如果有标签(labels),则返回"ptx_loss";否则返回"actor_loss"。 + 参数: + inputs (Dict): 包含两个键值对,分别为"inputs"和"labels",其中"inputs"是模型的输入,"labels"是可选的,表示是否使用ptx损失函数。默认值为None。 + 返回值 (str): 返回一个字符串,分别为"ptx_loss"或"actor_loss",表示是否使用ptx损失函数和演员损失函数。 + """ labels = inputs.get("labels", None) if labels is not None: # use ptx loss_name = "ptx_loss" @@ -566,31 +692,74 @@ def loss_identifier(self, inputs: Dict) -> str: class ValueTrainer(StepTrainer): loss_cls = RLHFValueLoss + trainer_type = "value" # define loss name for logging loss_identifier = lambda self, inputs: "reward_critic_loss" class PPOMetric: def set_metric_meta(self, use_ptx=True): + """ + 设置指标的元信息,包括指标名称和运算方式。 + 如果不使用PTX(即不需要计算策略网络的损失),则会从指标名称中移除对应项。 + + Args: + use_ptx (bool, optional): 是否使用PTX(默认为True). Defaults to True. 
+ + Returns: + None: 无返回值,直接修改了类属性。 + """ self.metric_names = [ - "train/" + name - for name in [ - "actor_loss", - "ptx_loss", - "reward_critic_loss", - "reward", - "kl_divergence", - "mean_generated_length", - "max_generated_length", - ] + "train_" + name + for name in ( + [ + "policy_loss", + "ptx_loss", + "value_loss", + "reward", + "norm_reward", + "kl_reward", + "norm_reward_with_kl", + "values", + "returns", + "kl_divergence", + "mean_generated_length", + "max_generated_length", + "min_generated_length", + ] + if self.args.rl_algorithm == "ppo" + else [ + "policy_loss", + "ptx_loss", + "pure_policy_loss", + "kl_loss", + "reward", + "kl_divergence", + "mean_generated_length", + "max_generated_length", + "min_generated_length", + ] + ) ] - self.metric_ops = ["mean", "mean", "mean", "mean", "mean", "mean", "max"] + self.metric_ops = ( + ["mean"] * 10 + ["max", "min"] if self.args.rl_algorithm == "ppo" else ["mean"] * 7 + ["max", "min"] + ) if not use_ptx: self.metric_names.pop(1) self.metric_ops.pop(1) - def __init__(self, freq, use_stack=True, use_ptx=True): + def __init__(self, freq, args, use_stack=True, use_ptx=True): + """ + Args: + freq (int): frequency of metrics collection. + use_stack (bool, optional): whether to stack the metrics into a single tensor. Defaults to True. + use_ptx (bool, optional): whether to use ptx or not. Defaults to True. + + Raises: + ValueError: when freq is less than 1. + """ + self.args = args self.set_metric_meta(use_ptx=use_ptx) self.freq = freq self.counter = 0 @@ -619,17 +788,21 @@ def update(self, metrics: Dict[str, paddle.Tensor]) -> Union[None, Dict[str, flo else: for i, name in enumerate(self.metric_names): self.metrics[i][self.counter] = metrics[name] - if self.counter + 1 == self.freq: - from paddlenlp.trainer.utils import distributed_concat - metrics = distributed_concat(self.metrics) + self.counter += 1 + if self.counter == self.freq: + metrics = distributed_concat(self.metrics) if paddle.distributed.get_world_size() > 1 else self.metrics + out_metrics = {} if self.use_stack: mean_metric = metrics.mean(0) max_metric = metrics.max(0) + min_metric = metrics.min(0) for i, (name, op) in enumerate(zip(self.metric_names, self.metric_ops)): if op == "max": out_metrics[name] = max_metric[i].item() if self.use_stack else metrics[i].max().item() + elif op == "min": + out_metrics[name] = min_metric[i].item() if self.use_stack else metrics[i].min().item() else: out_metrics[name] = mean_metric[i].item() if self.use_stack else metrics[i].mean().item() @@ -644,6 +817,17 @@ def update(self, metrics: Dict[str, paddle.Tensor]) -> Union[None, Dict[str, flo def data_dispatch(fun): + """ + 用于将函数转换为一个可以处理数据的函数,该函数会根据策略训练器中的数据分组参数进行数据切分和合并。 + 如果策略训练器没有设置数据分组参数,则不进行任何操作。 + + Args: + fun (Callable[[Any, Any], Any]): 需要被转换的函数,接受两个参数:第一个是当前对象,第二个是需要处理的数据。返回值为任意类型。 + + Returns: + Callable[[Any, Any], Any]: 返回一个新的函数,接受两个参数:第一个是当前对象,第二个是需要处理的数据。返回值为任意类型。 + """ + def _impl(self, data): gp = getattr(self.policy_trainer, "_data_trans_group", None) data = data_group_split(data, group=gp) @@ -668,8 +852,32 @@ def __init__( compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None), - preprocess_logits_for_metrics: Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor] = None, + preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None, ): + """ + Args: + model 
(Union[PretrainedModel, nn.Layer], optional): The model to be trained. If not provided, it will be + initialized based on the values of ``args``. Defaults to None. + criterion (nn.Layer, optional): The loss function used for training. Defaults to None. + args (TrainingArguments, optional): Training arguments. If not provided, it will be initialized with + default values. Defaults to None. + data_collator (Optional[DataCollator], optional): The function to batch data samples together into + mini-batches. If not provided, a simple batching function that drops remaining samples will be used. + Defaults to None. + train_dataset (Optional[Dataset], optional): The dataset to be used for training. Defaults to None. + ptx_dataset (Optional[Dataset], optional): The dataset to be used for ptx. Defaults to None. + eval_dataset (Union[Dataset, Dict[str, Dataset]], optional): The dataset to be used for evaluation. + Defaults to None. + tokenizer (Optional[PretrainedTokenizer], optional): The tokenizer used for encoding. Defaults to None. + compute_metrics (Optional[Callable[[EvalPrediction], Dict]], optional): The function to compute metrics + during evaluation. Defaults to None. + callbacks (Optional[List[TrainerCallback]], optional): A list of callbacks to customize the training + process. Defaults to None. + optimizers (Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler], optional): The tuple of + optimizer and learning rate scheduler. Defaults to (None, None). + preprocess_logits_for_metrics (Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor], optional): The + function to preprocess logits before computing metrics. Defaults to None. + """ with guard_set_args( args, { @@ -699,18 +907,32 @@ def __init__( self.ptx_dataset = ptx_dataset self.eval_dataset = eval_dataset - (policy_model, reference_model, reward_model, value_model, policy_model_eval, value_model_eval) = model + ( + policy_model, + reference_model, + reward_model, + value_model, + policy_model_eval, + value_model_eval, + ) = model self._model_config = policy_model.config # use this to change flash attention dynamicly self._policy_model_eval = policy_model_eval - self._value_model_eval = value_model_eval + if args.rl_algorithm == "ppo": + self._value_model_eval = value_model_eval # policy_tokenizer and value_tokenizer should be same - (policy_tokenizer, reference_tokenizer, reward_tokenizer, value_tokenizer) = tokenizer + ( + policy_tokenizer, + reference_tokenizer, + reward_tokenizer, + value_tokenizer, + ) = tokenizer policy_training_args = copy.deepcopy(args) self.use_ptx = self.ptx_dataset is not None if self.use_ptx: policy_training_args.gradient_accumulation_steps *= 2 + lr_scheduler = self.get_scheduler(policy_training_args) self.policy_trainer = PolicyTrainer( policy_model, criterion, @@ -721,37 +943,51 @@ def __init__( policy_tokenizer, compute_metrics, callbacks, - optimizers, - preprocess_logits_for_metrics, - ) - value_training_args = copy.deepcopy(args) - for attr_name in [ - "critic_learning_rate", - "critic_weight_decay", - "critic_lr_scheduler_type", - "critic_warmup_ratio", - "critic_recompute", - ]: - if getattr(value_training_args, attr_name, None) is not None: - setattr(value_training_args, attr_name[len("critic_") :], getattr(value_training_args, attr_name)) - self.value_trainer = ValueTrainer( - value_model, - criterion, - value_training_args, - data_collator, - train_dataset, - eval_dataset, - value_tokenizer, - compute_metrics, - callbacks, - optimizers, + [None, lr_scheduler], 
preprocess_logits_for_metrics, ) + if args.rl_algorithm == "ppo": + value_training_args = copy.deepcopy(args) + for attr_name in [ + "critic_learning_rate", + "critic_weight_decay", + "critic_lr_scheduler_type", + "critic_warmup_ratio", + "critic_recompute", + ]: + if getattr(value_training_args, attr_name, None) is not None: + setattr( + value_training_args, + attr_name[len("critic_") :], + getattr(value_training_args, attr_name), + ) + lr_scheduler = self.get_scheduler(value_training_args) + self.value_trainer = ValueTrainer( + value_model, + criterion, + value_training_args, + data_collator, + train_dataset, + eval_dataset, + value_tokenizer, + compute_metrics, + callbacks, + [None, lr_scheduler], + preprocess_logits_for_metrics, + ) self.policy_trainer.set_eval_model(policy_model_eval) - self.value_trainer.set_eval_model(value_model_eval) + if args.rl_algorithm == "ppo": + self.value_trainer.set_eval_model(value_model_eval) # disable inner trainers' callback/state/control self.policy_trainer.add_callback(MuteDefaultFlowCallback) - self.value_trainer.add_callback(MuteDefaultFlowCallback) + if args.rl_algorithm == "ppo": + self.value_trainer.add_callback(MuteDefaultFlowCallback) + if not self.args.disable_tqdm: + from paddlenlp.trainer import ProgressCallback + + self.policy_trainer.pop_callback(ProgressCallback) + if args.rl_algorithm == "ppo": + self.value_trainer.pop_callback(ProgressCallback) # use trainer for reference_model/reward_model to enable sharding stage-3 # and PipelineParallel. maybe we should allow models to use different dist @@ -765,12 +1001,11 @@ def __init__( { "recompute": False, # "fp16_opt_level": "O1", - "pipeline_parallel_degree": args.pipeline_parallel_degree - if isinstance(reference_model, PipelineLayer) - else 1, # workaround for pipeline parallel model check + "pipeline_parallel_degree": ( + args.pipeline_parallel_degree if isinstance(reference_model, PipelineLayer) else 1 + ), # workaround for pipeline parallel model check }, ): - self.reference_trainer = StepTrainer( reference_model, criterion, @@ -784,29 +1019,34 @@ def __init__( optimizers, preprocess_logits_for_metrics, ) - self.reward_trainer = StepTrainer( - reward_model, - criterion, - copy.deepcopy(args), - data_collator, - train_dataset, - eval_dataset, - reward_tokenizer, - compute_metrics, - callbacks, - optimizers, - preprocess_logits_for_metrics, - ) + if isinstance(reward_model, PretrainedModel): + self.reward_trainer = StepTrainer( + reward_model, + criterion, + copy.deepcopy(args), + data_collator, + train_dataset, + eval_dataset, + reward_tokenizer, + compute_metrics, + callbacks, + optimizers, + preprocess_logits_for_metrics, + ) + else: + self.reward_server = reward_model # TODO(guosheng): sharding stage3 should create master weight optionally # instead of creation and clear. 
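# --- Illustrative sketch (annotation, not part of the patch). The critic/value trainer
# above reuses a deep copy of the shared TrainingArguments and lets each critic_*
# override replace its un-prefixed counterpart (critic_learning_rate -> learning_rate,
# and so on). The helper name is hypothetical.
import copy

def build_critic_training_args(args):
    critic_args = copy.deepcopy(args)
    for name in ("critic_learning_rate", "critic_weight_decay",
                 "critic_lr_scheduler_type", "critic_warmup_ratio", "critic_recompute"):
        value = getattr(critic_args, name, None)
        if value is not None:
            setattr(critic_args, name[len("critic_"):], value)
    return critic_args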
from paddlenlp.trainer.trainer_utils import ShardingOption if args.pipeline_parallel_degree > 1 or ShardingOption.FULL_SHARD in args.sharding: self.reference_trainer.init_train_model_opt(100, None, clear_master_weight=True) # dummy max_steps - self.reward_trainer.init_train_model_opt(100, None, clear_master_weight=True) # dummy max_steps + if isinstance(reward_model, PretrainedModel): + self.reward_trainer.init_train_model_opt(100, None, clear_master_weight=True) # dummy max_steps self.reference_model.eval() - self.reward_model.eval() + if isinstance(reward_model, PretrainedModel): + self.reward_model.eval() self.reward_tokenizer = reward_tokenizer self.tokenizer = policy_tokenizer @@ -814,17 +1054,18 @@ def __init__( self.reward_tokenizer = self.tokenizer self.generation_config = GenerationConfig( - max_new_tokens=self.args.max_length, + max_new_tokens=self.args.max_dec_len, num_return_sequences=self.args.num_return_sequences, temperature=self.args.temperature, top_p=self.args.top_p, top_k=0, # to disable top_k sampling, default is 50 repetition_penalty=self.args.repetition_penalty, + min_length=self.args.min_dec_len, do_sample=True, # allow generation output to contain input trunc_input=False, bos_token_id=self.tokenizer.bos_token_id, - eos_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.cls_token_id, pad_token_id=self.tokenizer.pad_token_id, ) # Those value can be changed @@ -834,27 +1075,68 @@ def __init__( self.gamma = 1.0 self.gae_lambda = 0.95 + # for reward norm + self.reward_mean = 0.0 + self.reward_var = 1.0 + self.sample_batch_num = 0 + # dummy class and object for model to be compaible with methods of # Trainer, such as evaluation_loop self.DummyPPOModel = type( - "DummyPPOModel", (object,), {"eval": lambda _: self.set_eval(), "train": lambda _: self.set_train()} + "DummyPPOModel", + (object,), + { + "eval": lambda _: self.set_eval(), + "train": lambda _: self.set_train(), + }, ) self.model = self.model_wrapped = self.DummyPPOModel() + if self.timers: + self.timers.log = types.MethodType(new_timer_log, self.timers) @property def reference_model(self): + """ + 获取参考模型,如果没有则返回None。 + 该方法只能在初始化后使用,否则会引发异常。 + + Returns: + torch.nn.Module, optional - 参考模型,如果没有则返回None。 + + Raises: + Exception - 当调用此方法前未初始化reference_trainer时,将引发异常。 + """ return self.reference_trainer.get_model(train=False) @property def reward_model(self): - return self.reward_trainer.get_model(train=False) + """ + 获取奖励模型,如果没有则创建一个。 + 返回值:tf.keras.models.Model,奖励模型。 + """ + if hasattr(self, "reward_trainer"): + return self.reward_trainer.get_model(train=False) + else: + return self.reward_server @property def actor_model(self): + """ + 获取当前的actor模型,如果在训练中则返回训练后的模型,否则返回eval时使用的模型。 + + Returns: + torch.nn.Module, torch.jit.ScriptModule: Actor模型,可以是torch.nn.Module或者torch.jit.ScriptModule类型。 + """ return self.policy_trainer.get_model(train=self.training) @property def reward_critic_model(self): + """ + 获取 critic model,仅在使用 value-based 策略时有效。 + + Returns: + tf.keras.Model, optional: critic model,如果没有设置则返回 None。 + """ return self.value_trainer.get_model(train=self.training) def set_train(self, mode: bool = True) -> None: @@ -862,16 +1144,56 @@ def set_train(self, mode: bool = True) -> None: if mode: self.training = True self.actor_model.train() - self.reward_critic_model.train() + if self.args.rl_algorithm == "ppo": + self.reward_critic_model.train() else: self.training = False self.actor_model.eval() - self.reward_critic_model.eval() + if self.args.rl_algorithm == "ppo": + self.reward_critic_model.eval() 
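The `DummyPPOModel` created in `__init__` above is a dynamically built stand-in whose `eval()`/`train()` calls dispatch back to `set_eval()`/`set_train()`, which lets the generic `Trainer.evaluation_loop` drive this trainer without a real `nn.Layer`. A toy illustration of the same pattern; the `Coordinator` class below is only a stand-in for the PPO/GRPO trainer itself:

```python
# Toy version of the DummyPPOModel trick: eval()/train() on the stand-in
# object forward to the coordinator's own mode switches.
class Coordinator:
    def __init__(self):
        self.training = False
        # dynamically created class whose methods close over `self`
        self.DummyModel = type(
            "DummyModel",
            (object,),
            {"eval": lambda _: self.set_eval(), "train": lambda _: self.set_train()},
        )
        self.model = self.DummyModel()

    def set_train(self):
        self.training = True
        print("actor (and critic, under PPO) switched to train mode")

    def set_eval(self):
        self.training = False
        print("actor (and critic, under PPO) switched to eval mode")

coord = Coordinator()
coord.model.eval()     # dispatches to coord.set_eval()
coord.model.train()    # dispatches to coord.set_train()
print(coord.training)  # True
```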
def set_eval(self) -> None: """Set model to evaluation mode.""" self.set_train(mode=False) + def get_scheduler(self, args): + """ + 获取学习率调度器,如果没有设置最小学习率则返回None。 + 支持两种类型的学习率调度器:"cosine"和"linear"。 + + Args: + args (argparse.Namespace): 命令行参数,包含了学习率相关的参数。 + + Returns: + torch.optim.lr_scheduler._LRScheduler or None, optional: 学习率调度器或者None,默认为None。 + """ + if args.decay_steps is None: + args.decay_steps = args.max_steps + if args.warmup_steps > 0: + warmup_steps = args.warmup_steps + else: + warmup_steps = args.warmup_ratio * args.max_steps + lr_scheduler = None + if args.min_learning_rate is not None: + if args.lr_scheduler_type == "cosine": + lr_scheduler = CosineAnnealingWithWarmupDecay( + max_lr=args.learning_rate, + min_lr=args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=args.decay_steps, + last_epoch=0, + ) + elif args.lr_scheduler_type == "linear": + lr_scheduler = LinearAnnealingWithWarmupDecay( + max_lr=args.learning_rate, + min_lr=args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=args.decay_steps, + last_epoch=0, + ) + return lr_scheduler + + @paddle.no_grad() def prediction_step( self, model: nn.Layer, @@ -879,60 +1201,106 @@ def prediction_step( prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None, ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[paddle.Tensor]]: + """ + 预测步骤,用于生成下一个输入序列。 + + Args: + model (nn.Layer): 模型实例,需要是 `paddle.nn.Layer` 的子类。 + inputs (Dict[str, Union[paddle.Tensor, Any]]): 包含输入数据的字典,其中包含以下键: + - "input_ids" (paddle.Tensor, optional): 输入序列的编号 ID,默认为None。 + - "attention_mask" (paddle.Tensor, optional): 输入序列的注意力掩码,默认为None。 + - "position_ids" (paddle.Tensor, optional): 输入序列的位置ID,默认为None。 + prediction_loss_only (bool): 仅返回预测损失,不返回其他任何值。 + ignore_keys (Optional[List[str]], optional): 忽略的键列表,默认为None。 + + Returns: + Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[paddle.Tensor]]: + 三元组,包含以下元素: + - Optional[paddle.Tensor]: 如果 `prediction_loss_only` 为False,则为预测得分,否则为None。 + - Optional[paddle.Tensor]: 当前未定义,始终为None。 + - Optional[paddle.Tensor]: 当前未定义,始终为None。 + + Raises: + ValueError: 如果 `ignore_keys` 不是可选参数或者不是一个列表。 + """ inputs = self._prepare_inputs(inputs) + with self.enable(self.actor_model, self.reference_model, self.policy_trainer): + with infer_guard(self.policy_trainer): + position_ids = inputs.get("position_ids", make_position_ids(inputs["attention_mask"])) + prompt_only_batch = { + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + "position_ids": position_ids, + **({"label_ids": inputs["label_ids"]} if self.args.use_rm_server else {}), + } + generated_seq = self.generate(prompt_only_batch, do_eval=True)[0]["input_ids"] + + if self._model_config.sequence_parallel: + # pad to max_sequence_length + seq = self.tokenizer.pad( + {"input_ids": [s for s in generated_seq]}, + padding="max_length", + max_length=self._model_config.max_sequence_length, + return_attention_mask=False, + )["input_ids"] + else: + seq = generated_seq - with paddle.no_grad(): - with self.autocast_smart_context_manager(): - seq = self.actor_model.generate( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - position_ids=inputs["position_ids"] - if "position_ids" in inputs - else make_position_ids(inputs["attention_mask"]), - generation_config=self.generation_config, - synced_gpus=ShardingOption.FULL_SHARD in self.policy_trainer.args.sharding, - )[0] + if not self.args.use_rm_server: if self.reward_tokenizer is not self.tokenizer: reward_tokenize_output 
= batch_retokenize( input_ids=seq, src_tokenizer=self.tokenizer, dest_tokenizer=self.reward_tokenizer, - skip_special_tokens=True, - device=self.args.device, ) reward_input_ids = reward_tokenize_output["input_ids"] + reward_attention_mask = reward_tokenize_output["attention_mask"] + reward_position_ids = reward_tokenize_output["position_ids"] else: reward_input_ids = seq - reward_attention_mask = make_attention_mask( - seq, - pad_id=self.reward_tokenizer.pad_token_id, - unk_id=self.reward_tokenizer.unk_token_id, - causal_mask=False, - ) - reward_position_ids = make_position_ids(reward_attention_mask) + reward_attention_mask = make_attention_mask( + seq, + pad_id=self.reward_tokenizer.pad_token_id, + eos_id=self.reward_tokenizer.eos_token_id, + unk_id=self.reward_tokenizer.unk_token_id, + causal_mask=True, + ) + reward_position_ids = make_position_ids(reward_attention_mask) - # unify PP with others since PP always return tuple + # .end_scores reward_score = self.reward_model( reward_input_ids, attention_mask=reward_attention_mask, position_ids=reward_position_ids, # return_dict=True, - )[ - 1 - ] # .end_scores - reward_score = reward_score.squeeze(axis=-1).cast(paddle.float32) - + )[1] + else: + prompt_len = inputs["input_ids"].shape[-1] + if "label_ids" not in inputs: + raise ValueError("Rule-based reward needs labels.") + src = self.tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True) + tgt = self.tokenizer.batch_decode(inputs["label_ids"], skip_special_tokens=True) + response = self.tokenizer.batch_decode(generated_seq[:, prompt_len:], skip_special_tokens=True) + reward_score = self.request_reward_server(src, tgt, response) + + reward_score = reward_score.squeeze(axis=-1).cast(paddle.float32) # keep the first batch of eval output sequence to print and check prompt = self.tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True) - generated = self.tokenizer.batch_decode(seq, skip_special_tokens=True) + generated = self.tokenizer.batch_decode(generated_seq, skip_special_tokens=True) # no padding + reward_score_list = reward_score.tolist() for i, text in enumerate(generated): - self._eval_out_file.write(text + "\n") + item = { + "Prompt": text[: len(prompt[i]) - 1], + "Generated": text[len(prompt[i]) :], + "Reward": reward_score_list[i], + } + self._eval_out_file.write(json.dumps(item, ensure_ascii=False) + "\n") + if getattr(self, "_eval_seq", None) is None: generated = [text[len(prompt[i]) :] for i, text in enumerate(generated)] # prompts.extend(prompt) # generateds.extend(generated) - self._eval_seq = (prompt, generated, reward_score.tolist()) - + self._eval_seq = (prompt, generated, reward_score_list) return reward_score.mean(), None, None def evaluation_loop( @@ -944,27 +1312,50 @@ def evaluation_loop( metric_key_prefix: str = "eval", max_eval_iters: Optional[int] = -1, ) -> EvalLoopOutput: + """ + 循环访问数据集,并对模型进行评估。 + + Args: + dataloader (DataLoader, optional): 用于评估的数据加载器。默认为None。 + description (str, optional): 描述评估过程的字符串。默认为''. + prediction_loss_only (Optional[bool], optional): 是否只计算预测损失。默认为None。 + ignore_keys (Optional[List[str]], optional): 要忽略的键列表。默认为None。 + metric_key_prefix (str, optional): 指标键前缀。默认为'eval'. 
+ max_eval_iters (Optional[int], optional): 最大评估次数。默认为-1,表示无限制。 + + Returns: + EvalLoopOutput: 包含评估结果和指标的类实例。 + + Raises: + ValueError: 如果`prediction_loss_only`不是布尔值,则引发ValueError异常。 + """ # to save eval generated sequence eval_out_file = os.path.join( - self.args.output_dir, f"eval_out-step{self.state.global_step}-rank{self.args.local_rank}.txt" + self.args.output_dir, + f"eval_out-step{self.state.global_step}-rank{self.args.local_rank}.jsonl", ) - self._eval_out_file = open(eval_out_file, "w") + self._eval_out_file = open(eval_out_file, "w", encoding="utf-8") # TODO(guosheng): use _inner_eval_model (if trainer has one) instead of # original trainer model to eval, especially when using sharded EMA # NOTE: use here rather than in prediction_step since actor_model would # be set to eval out of prediction_step - with guard_set_args( - self.policy_trainer, # disable _inner_eval_model - { - "_eval_model": None, # otherwise would use cached _eval_model - "_inner_eval_model": None, # otherwise would use _inner_eval_model to create _eval_model - }, - ): - output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix, max_eval_iters - ) - output.metrics[f"{metric_key_prefix}/reward"] = output.metrics.pop(f"{metric_key_prefix}_loss") + # with guard_set_args( + # self.policy_trainer, # disable _inner_eval_model + # { + # "_eval_model": None, # otherwise would use cached _eval_model + # "_inner_eval_model": None, # otherwise would use _inner_eval_model to create _eval_model + # }, + # ): + output = super().evaluation_loop( + dataloader, + description, + prediction_loss_only, + ignore_keys, + metric_key_prefix, + max_eval_iters, + ) + output.metrics[f"{metric_key_prefix}_reward"] = output.metrics.pop(f"{metric_key_prefix}_loss") columns = ["Prompt", "Generated", "Reward"] rows = list(zip(*self._eval_seq)) @@ -983,38 +1374,178 @@ def evaluation_loop( return output def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + """ + 获取用于评估模型的数据加载器。如果未提供`eval_dataset`,则使用`self.eval_dataset`。 + 该函数会设置一个名为"data_collator"的参数,并将其传递给`super().get_eval_dataloader()`。 + + Args: + eval_dataset (Optional[Dataset], optional): 用于评估的数据集. Defaults to None. + + Returns: + DataLoader: 包含用于评估的数据的DataLoader实例。 + """ with guard_set_args(self, {"data_collator": self.eval_dataset.get_collator()}): return super().get_eval_dataloader(eval_dataset) def _save_checkpoint(self, model, metrics=None): + """ + 保存模型和指标到两个不同的 checkpoint,一个是 policy 模型,另一个是 value 模型。 + 这里使用了 `guard_set_args` 来防止在调用 `_save_checkpoint` 时修改了原始参数。 + + Args: + model (nn.Module): 需要保存的模型。 + metrics (Optional[Dict], optional): 可选的指标字典,默认为 None。 + key 是指标名称,value 是对应的指标值。 + + Returns: + None. 
+ """ # maybe change args.output_dir of policy_trainer/value_trainer directly - with guard_set_args(self.policy_trainer.args, {"output_dir": os.path.join(self.args.output_dir, "policy")}): + self.runtime_timer.start("checkpoint saving time") + with guard_set_args( + self.policy_trainer.args, + {"output_dir": os.path.join(self.args.output_dir, "policy")}, + ): + if self.policy_trainer.args.unified_checkpoint: + if "train_model" in self.policy_trainer.args.offload_level: + reload_tensor_to_gpu((self.policy_trainer.model, "train_model")) + if ( + "optimizer" in self.policy_trainer.args.offload_level + and not self.policy_trainer.args.ignore_save_lr_and_optim + ): + reload_tensor_to_gpu((self.policy_trainer.optimizer, "optimizer")) self.policy_trainer._save_checkpoint(model, metrics) - with guard_set_args(self.value_trainer.args, {"output_dir": os.path.join(self.args.output_dir, "value")}): - self.value_trainer._save_checkpoint(model, metrics) + if self.args.rl_algorithm == "ppo": + with guard_set_args( + self.value_trainer.args, + {"output_dir": os.path.join(self.args.output_dir, "value")}, + ): + if self.value_trainer.args.unified_checkpoint: + if "train_model" in self.value_trainer.args.offload_level: + reload_tensor_to_gpu((self.value_trainer.model, "train_model")) + if ( + "optimizer" in self.value_trainer.args.offload_level + and not self.value_trainer.args.ignore_save_lr_and_optim + ): + reload_tensor_to_gpu((self.value_trainer.optimizer, "optimizer")) + self.value_trainer._save_checkpoint(model, metrics) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + metrics = { + "policy": self.policy_trainer.state.best_model_checkpoint, + **( + {"value": self.value_trainer.state.best_model_checkpoint} + if self.args.rl_algorithm == "ppo" + else {} + ), + } + self.state.best_model_checkpoint = json.dumps(metrics) + + def save_model( + self, + output_dir: Optional[str] = None, + merge_tensor_parallel: Optional[bool] = False, + ): + """ + 保存模型。 + + Args: + output_dir (Optional[str], optional): 输出目录,默认为None,使用命令行参数--output-dir。 Defaults to None. + merge_tensor_parallel (Optional[bool], optional): 是否合并tensor parallel,默认为False。 Defaults to False. 
+ + Raises: + ValueError: 如果output_dir不在当前工作目录下,则会引发ValueError异常。 + """ + if output_dir is None: + output_dir = self.args.output_dir + + if "train_model" in self.args.offload_level: + reload_tensor_to_gpu((self.policy_trainer.model, "model")) + if self.args.rl_algorithm == "ppo": + reload_tensor_to_gpu((self.value_trainer.model, "model")) + self.policy_trainer.save_model(os.path.join(output_dir, "policy"), merge_tensor_parallel) + if self.args.rl_algorithm == "ppo": + self.value_trainer.save_model(os.path.join(output_dir, "value"), merge_tensor_parallel) def init_train_model_opt( - self: Trainer, max_steps: int, resume_from_checkpoint: bool = False, clear_master_weight: bool = False + self: Trainer, + max_steps: int, + resume_from_checkpoint: bool = False, + clear_master_weight: bool = False, ) -> PretrainedModel: + """ + 初始化训练模型和优化器。 + 如果`resume_from_checkpoint`为字符串,则将其作为路径,并在该路径下恢复模型和优化器状态;否则,将其视为布尔值,表示是否从最后一个保存的检查点中恢复。 + 如果`clear_master_weight`为True,则清除主要权重。 + + Args: + max_steps (int): 最大训练步数。 + resume_from_checkpoint (bool, optional): 是否从检查点中恢复模型和优化器状态(默认为False)。 + 如果为字符串,则将其作为路径,并在该路径下恢复模型和优化器状态。 + clear_master_weight (bool, optional): 是否清除主要权重(默认为False)。 + + Returns: + Tuple[PretrainedModel, PretrainedModel]: 返回两个元组,分别包含策略模型和价值函数模型。 + """ # resume should be triggered here # maybe change args.output_dir of policy_trainer/value_trainer directly - with guard_set_args(self.policy_trainer.args, {"output_dir": os.path.join(self.args.output_dir, "policy")}): + with guard_set_args( + self.policy_trainer.args, + {"output_dir": os.path.join(self.args.output_dir, "policy")}, + ): policy_model = self.policy_trainer.init_train_model_opt( max_steps, - os.path.join(resume_from_checkpoint, "policy") - if isinstance(resume_from_checkpoint, str) - else resume_from_checkpoint, - ) - with guard_set_args(self.value_trainer.args, {"output_dir": os.path.join(self.args.output_dir, "value")}): - value_model = self.value_trainer.init_train_model_opt( - max_steps, - os.path.join(resume_from_checkpoint, "value") - if isinstance(resume_from_checkpoint, str) - else resume_from_checkpoint, + ( + os.path.join(resume_from_checkpoint, "policy") + if isinstance(resume_from_checkpoint, str) + else resume_from_checkpoint + ), ) + if self.args.rl_algorithm == "ppo": + with guard_set_args( + self.value_trainer.args, + {"output_dir": os.path.join(self.args.output_dir, "value")}, + ): + value_model = self.value_trainer.init_train_model_opt( + max_steps, + ( + os.path.join(resume_from_checkpoint, "value") + if isinstance(resume_from_checkpoint, str) + else resume_from_checkpoint + ), + ) + else: + value_model = None return policy_model, value_model def get_epoch_iterator(self): + """ + 获取一个迭代器,该迭代器将生成一个批次的数据。每个批次包含两部分:一个是提示仅批次(prompt only batch),另一个是PTX批次(PTX batch)。 + 如果使用了PTX,则PTX批次会在每个RL批次之后进行轮换。 + + Args: + 无参数。 + + Returns: + EpochIterator (class): 返回一个类,该类包含一个__iter__方法和一个__len__方法。__iter__方法可以生成一个批次的数据,__len__方法返回总共有多少个批次。 + + Raises: + 无异常抛出。 + """ + def gen_epoch_data(): for prompt_only_batch, ptx_batch in zip( self.prompt_only_dataloader, @@ -1023,17 +1554,19 @@ def gen_epoch_data(): # generate batches self.set_eval() - with ema(self.policy_trainer), ema(self.value_trainer): - rl_batches = self.split_rl_micro_batches(prompt_only_batch) + with ( + ema(self.policy_trainer), + ema(self.value_trainer) if self.args.rl_algorithm == "ppo" else contextlib.nullcontext(), + ): + with guard_set_args(self._model_config, {"use_fused_head_and_loss_fn": False}): + rl_batches = self.split_rl_micro_batches(prompt_only_batch) - 
self.timers and self.timers("ptx-batch").start() if self.use_ptx: ptx_batches = self.split_ptx_micro_batches(ptx_batch) else: ptx_batches = [None for _ in range(len(rl_batches))] - self.timers and self.timers("ptx-batch").stop() - empty_device_cache() + paddle.device.cuda.empty_cache() self.set_train() for _ in range(self.args.update_iters): @@ -1055,6 +1588,24 @@ def __len__(self): return EpochIterator() def init_train_num(self: Trainer, train_dataloader: DataLoader): + """ + 初始化训练数据的批次大小,以及相关参数。 + + Args: + self (Trainer): Trainer实例。 + train_dataloader (DataLoader): 用于训练的DataLoader对象。 + + Returns: + tuple (int, Optional[int], int, int, int, int, int): + 返回一个元组,包含: + 1. total_train_batch_size (int) - 总训练批次大小。 + 2. len_dataloader (Optional[int]) - 如果不是可迭代的数据集,则为DataLoader长度;否则为None。 + 3. max_steps (int) - 最大训练步数。 + 4. num_train_epochs (int) - 训练的最大轮数。 + 5. num_update_steps_per_epoch (int) - 每个epoch中更新模型的次数。 + 6. num_examples (int) - 训练数据的样本数量。 + 7. num_train_samples (int) - 训练数据的样本总数。 + """ args = self.args total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size @@ -1098,15 +1649,32 @@ def init_train_num(self: Trainer, train_dataloader: DataLoader): ) def is_step_end(self): + """ + 判断是否到达了步数结尾,当累加步数等于args.gradient_accumulation_steps时返回True。 + 返回值:bool,如果到达了步数结尾则返回True,否则返回False。 + """ # reach accumulation_steps, value trainer has the same step_control and # gradient_accumulation_steps as PPO trainer. # if (step_control + 1) % args.gradient_accumulation_steps == 0 - return self.value_trainer.is_accumulation_step + if self.args.rl_algorithm == "ppo": + return self.value_trainer.is_accumulation_step + return self.policy_trainer.is_accumulation_step def get_step_loss(self, loss_prefix: str = "") -> Dict: + """ + 获取当前步骤的损失,包括策略训练和价值函数训练的损失。 + 如果提供了loss_prefix参数,则将损失名称加上该前缀。 + + Args: + loss_prefix (str, optional): 损失名称的前缀字符串,默认为"". + + Returns: + Dict[str, float]: 返回一个字典,包含两个损失项:rl_loss(策略训练的损失)和value_loss(价值函数训练的损失)。 + """ rl_loss = self.policy_trainer.get_step_loss(loss_prefix) - value_loss = self.value_trainer.get_step_loss(loss_prefix) - rl_loss.update(value_loss) + if self.args.rl_algorithm == "ppo": + value_loss = self.value_trainer.get_step_loss(loss_prefix) + rl_loss.update(value_loss) return rl_loss def train( @@ -1114,29 +1682,65 @@ def train( resume_from_checkpoint: Optional[Union[str, bool]] = None, ignore_keys_for_eval: Optional[List[str]] = None, ) -> None: + """ + Main training entry point. + + Args: + resume_from_checkpoint (Optional[Union[str, bool]], optional): + Checkpoint path from which training should be resumed. If a + path is given, training will restart from this checkpoint. If + set to ``True``, the last checkpoint in ``output_dir`` will be + loaded. If ``False`` or ``None`` (default), training will + start from scratch. Defaults to ``None``. + + ignore_keys_for_eval (Optional[List[str]], optional): + List of keys to ignore when computing the metrics during + evaluation. Defaults to ``None``. + + Returns: + None: + Training process is finished, no return value. 
+ """ # ##### The following code try to keep same as the Trainer.train ##### args = self.args self.is_in_train = True # ##### trainging data and related num setting ##### # TODO(guosheng): remove the binding method get_collator of dataset - with guard_set_args( - args, {"per_device_train_batch_size": self.args.per_device_prompt_batch_size} - ), guard_set_args( - self, {"train_dataset": self.train_dataset, "data_collator": self.train_dataset.get_collator()} + with ( + guard_set_args( + args, + {"per_device_train_batch_size": self.args.per_device_prompt_batch_size}, + ), + guard_set_args( + self, + { + "train_dataset": self.train_dataset, + "data_collator": self.train_dataset.get_collator(), + }, + ), ): train_dataloader = self.prompt_only_dataloader = self.get_train_dataloader() if self.use_ptx: - with guard_set_args( - args, - { - "per_device_train_batch_size": 1 - if getattr(self.ptx_dataset, "is_intokens", False) - else self.args.per_device_prompt_batch_size * self.args.num_return_sequences - }, - ), guard_set_args( - self, {"train_dataset": self.ptx_dataset, "data_collator": self.ptx_dataset.get_collator()} + with ( + guard_set_args( + args, + { + "per_device_train_batch_size": ( + 1 + if getattr(self.ptx_dataset, "is_intokens", False) + else self.args.per_device_prompt_batch_size * self.args.num_return_sequences + ) + }, + ), + guard_set_args( + self, + { + "train_dataset": self.ptx_dataset, + "data_collator": self.ptx_dataset.get_collator(), + }, + ), ): self.ptx_dataloader = self.get_train_dataloader() else: @@ -1153,19 +1757,28 @@ def train( # ##### model and optimizer related setting ##### policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint) - empty_device_cache() + paddle.device.cuda.empty_cache() # ##### traing statistic logging ##### # Number of trainable parameters only account for policy_model self.init_train_log( - num_examples, num_train_epochs, total_train_batch_size, max_steps, num_train_samples, policy_model + num_examples, + num_train_epochs, + total_train_batch_size, + max_steps, + num_train_samples, + policy_model, ) # ##### set training state and resume ##### # consumed_samples used to set train_dataloader.batch_sampler may not be # correct. Thus, data cannot be resumed perfectly when not breaking at epoch end. 
- epochs_trained, steps_trained_in_current_epoch, steps_trained_progress_bar = self.init_train_state( - resume_from_checkpoint, train_dataloader, max_steps, num_train_epochs, num_update_steps_per_epoch + (epochs_trained, steps_trained_in_current_epoch, steps_trained_progress_bar,) = self.init_train_state( + resume_from_checkpoint, + train_dataloader, + max_steps, + num_train_epochs, + num_update_steps_per_epoch, ) epoch_iterator = self.get_epoch_iterator() @@ -1183,7 +1796,7 @@ def train( self.control = self.callback_handler.on_train_begin(args, self.state, self.control) self._globalstep_last_logged = self.state.global_step - metric = PPOMetric(freq=self.args.logging_steps, use_ptx=self.use_ptx) + metric = PPOMetric(freq=self.args.logging_steps, args=self.args, use_ptx=self.use_ptx) start_time = time.time() self._globalstep_last_start_time = start_time @@ -1203,42 +1816,53 @@ def train( # self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs) rl_batch, ptx_batch = inputs # TODO(guosheng): make rl_step/ptx_step run with autocast_smart_context_manager - logger.info("Doing rl step...") - self.timers and self.timers("rl_step").start() + # logger.info("Doing rl step...") + self.timers and self.timers(get_timer_label(ActorStages.MODEL_ENABLE_DISABLE)).start() with self.enable(self.actor_model, self.policy_trainer.optimizer): - # with self.enable(self.value_trainer.optimizer): - with self.enable(): # put value optimizer guard in rl_step - rl_info = self.rl_step(rl_batch) - empty_device_cache() - self.timers and self.timers("rl_step").stop() - + self.timers and self.timers(get_timer_label(ActorStages.RL_STEP)).start() + rl_info = self.rl_step(rl_batch) + self.timers and self.timers(get_timer_label(ActorStages.RL_STEP)).stop() if self.use_ptx: logger.info("Doing ptx step...") - self.timers and self.timers("ptx_step").start() + self.timers and self.timers(get_timer_label(ActorStages.PTX_STEP)).start() with guard_set_args( self._model_config, { # "set_attn_func": True, - # "use_flash_attention": True + "use_flash_attention": True }, ): ptx_info = self.ptx_step(ptx_batch) rl_info.update(ptx_info) - self.timers and self.timers("ptx_step").stop() - empty_device_cache() - - self.state.global_step += 1 - self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.timers and self.timers(get_timer_label(ActorStages.PTX_STEP)).stop() + if self.timers: + self.timers(get_timer_label(ActorStages.MODEL_ENABLE_DISABLE)).stop() + self.timers(get_timer_label(ActorStages.MODEL_ENABLE_DISABLE)).elapsed_ -= self.timers( + get_timer_label(ActorStages.RL_STEP) + ).elapsed_ + if self.use_ptx: + self.timers(get_timer_label(ActorStages.MODEL_ENABLE_DISABLE)).elapsed_ -= self.timers( + get_timer_label(ActorStages.PTX_STEP) + ).elapsed_ + + paddle.device.cuda.empty_cache() + if self.args.rl_algorithm == "ppo": + rl_critic_info = self.rl_critic_step(rl_batch) + rl_info.update(rl_critic_info) if self.is_step_end(): - rl_info.update(self.get_step_loss(loss_prefix="train/")) + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1) / steps_in_epoch + rl_info.update(self.get_step_loss(loss_prefix="train_")) rl_info = metric.update(rl_info) # on_step_end self.control = self.callback_handler.on_step_end(args, self.state, self.control) else: # on_sub_step_end self.control = self.callback_handler.on_substep_end(args, self.state, self.control) - self._maybe_log_save_evaluate(rl_info, None, epoch, ignore_keys_for_eval, inputs=inputs) self._print_timer() + 
self._maybe_log_save_evaluate(rl_info, None, epoch, ignore_keys_for_eval, inputs=inputs) + paddle.device.cuda.empty_cache() + if self.control.should_epoch_stop or self.control.should_training_stop: break @@ -1253,26 +1877,109 @@ def train( self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) # argument model is not used in _maybe_log_save_evaluate, thus use None self._maybe_log_save_evaluate(rl_info, None, epoch, ignore_keys_for_eval, inputs=inputs) - self._print_timer() if self.control.should_training_stop: break # TODO(guosheng): add epilogue of training + logger.info("\nTraining completed. \n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + if args.local_rank != -1: + dist.barrier() - def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): - if self.control.should_log: + best_model_checkpoint = json.loads(self.state.best_model_checkpoint) + + logger.info(f"Loading best model from {best_model_checkpoint['value']}(score: {self.state.best_metric}).") + self.load_best_ckpt(best_model_checkpoint["value"], self.value_trainer) + + logger.info(f"Loading best model from {best_model_checkpoint['policy']}(score: {self.state.best_metric}).") + self.load_best_ckpt(best_model_checkpoint["policy"], self.policy_trainer) + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + ) + + self.is_in_train = False + self.log(metrics) + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + tr_loss = 0.0 + for history in self.state.log_history: + if "train_policy_loss" in history: + tr_loss += history["train_policy_loss"] + tr_loss = tr_loss / self.state.global_step + return TrainOutput(self.state.global_step, tr_loss, metrics) + + def load_best_ckpt(self, model_path, trainer, **kwargs): + """ + Load the best checkpoint from the given path into the specified trainer. + + Args: + args (TrainingArguments): The arguments object containing the configuration settings. + model_path (str): The path to the directory where the best checkpoint is located. + trainer (Trainer): The trainer instance that will receive the loaded weights. + kwargs (Any, optional): Additional keyword arguments passed to the `load_unified_checkpoint` function. + """ + from paddlenlp.trainer.utils.helper import broadcast_dataset_rank0_model + if trainer.args.unified_checkpoint: + trainer.unified_checkpoint_handler.load_unified_checkpoint( + trainer.model, + model_path, + ) + if trainer.args.sharding_parallel_degree > 1 or trainer.args.data_parallel_degree > 1: + broadcast_dataset_rank0_model(trainer.model) + else: + weight_name = PADDLE_WEIGHTS_NAME + best_model_path = os.path.join( + model_path, + _add_variant(weight_name, trainer.args.weight_name_suffix), + ) + if os.path.exists(best_model_path): + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = paddle.load(best_model_path, return_numpy=True) + # If the model is on the GPU, it still works! + trainer._set_state_dict_in_model(state_dict) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." 
+ ) + + def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): + """ + 记录、保存和评估,如果需要。 + 如果控制变量指示应该记录,则记录损失,并将模型保存到磁盘上。 + 如果控制变量指示应该评估,则评估模型并将结果保存到磁盘上。 + + Args: + tr_loss (Optional[Dict[str, float]]): 字典形式的训练损失,包含键'train_policy_loss'和'train_ptx_loss'。 + 如果为None,则不记录任何内容。默认为None。 + model (Model): 用于评估的模型。 + epoch (int): 当前迭代次数。 + ignore_keys_for_eval (List[str]): 在评估时要忽略的键列表。默认为空列表。 + kwargs (Any, optional): 其他可选参数,将被传递给`log()`和`save()`方法。默认为空字典。 + + Returns: + None. + + Raises: + None. + """ + if self.control.should_log and tr_loss is not None: logs: Dict[str, float] = {} # use_ptx would double the gradient_accumulation_steps which causes - # actor_loss and ptx_loss reduced by half. Moreover, ptx_loss should + # policy_loss and ptx_loss reduced by half. Moreover, ptx_loss should # be divided by ptx_coeff for logging. - if "train/ptx_loss" in tr_loss: - tr_loss["train/actor_loss"] = tr_loss["train/actor_loss"] * 2 - tr_loss["train/ptx_loss"] = tr_loss["train/ptx_loss"] * 2 / self.ptx_coeff + if "train_ptx_loss" in tr_loss: + tr_loss["train_policy_loss"] = tr_loss["train_policy_loss"] * 2 + tr_loss["train_ptx_loss"] = tr_loss["train_ptx_loss"] * 2 / self.ptx_coeff logs.update(tr_loss) logs["global_step"] = int(self.state.global_step) - logs["train/actor_lr"] = float("{0:.3e}".format(self.policy_trainer._get_learning_rate())) - logs["train/reward_critic_lr"] = float("{0:.3e}".format(self.value_trainer._get_learning_rate())) + logs["train_actor_lr"] = float(f"{self.policy_trainer._get_learning_rate():.3e}") + if self.args.rl_algorithm == "ppo": + logs["train_reward_critic_lr"] = float(f"{self.value_trainer._get_learning_rate():.3e}") total_train_batch_size = ( self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size @@ -1304,6 +2011,18 @@ def add_kl_divergence_regularization( reward_score: paddle.Tensor, # size = (B,) sequence_mask: paddle.Tensor, # size = (B, L) ) -> paddle.Tensor: + """ + 计算KL散度迭代增益,并将其添加到回报中。 + 参数: + prompt (paddle.Tensor, shape=(B, S)): 输入序列的prompt,未使用。 + log_probs (paddle.Tensor, shape=(B, L)): 当前预测的log概率分布。 + ref_log_probs (paddle.Tensor, shape=(B, L)): 基线预测的log概率分布。 + reward_score (paddle.Tensor, shape=(B,)): 基于prompt和输出序列的基本奖励得分。 + sequence_mask (paddle.Tensor, shape=(B, L)): 序列的mask,用于确定序列的长度。 + 返回值(paddle.Tensor, shape=(B, L)}: + 包含KL散度迭代增益的向量。 + """ + kl_divergence_estimate = -self.kl_coeff * (log_probs - ref_log_probs) # size = (B, L) rewards = kl_divergence_estimate # size = (B, L) reward_clip = paddle.clip( # size = (B,) @@ -1313,15 +2032,15 @@ def add_kl_divergence_regularization( ) # TODO(guosheng): use scatter_add/put_along_axis index = paddle.cumsum(sequence_mask.cast(paddle.int64), axis=-1).argmax(-1, keepdim=True) - rewards = paddle.put_along_axis(rewards, index, reward_clip.unsqueeze(axis=-1), axis=-1, reduce="add") - # batch_size = log_probs.shape[0] - # for i in range(batch_size): - # # print("="*20, sequence_mask[i]) - # end_index = sequence_mask[i].nonzero()[-1] - # # rewards[i, end_index] += reward_clip[i] - # rewards[i, end_index] = rewards[i, end_index] + reward_clip[i] - return rewards + rewards = paddle.put_along_axis( + rewards, + index, + reward_clip.unsqueeze(axis=-1), + axis=-1, + reduce="add", + ) + return rewards, kl_divergence_estimate def get_advantages_and_returns( self, @@ -1354,37 +2073,35 @@ def get_advantages_and_returns( last_gae_lambda = delta + self.gamma * self.gae_lambda * last_gae_lambda advantages_reversed.append(last_gae_lambda) advantages = 
paddle.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values[:, start:] + returns = advantages + values[:, start:].contiguous() + if not use_tgt_len_return: advantages = paddle.concat( - [paddle.zeros([advantages.shape[0], start], dtype=advantages.dtype), advantages], -1 + [ + paddle.zeros([advantages.shape[0], start], dtype=advantages.dtype), + advantages, + ], + -1, ) - returns = paddle.concat([paddle.zeros([returns.shape[0], start], dtype=returns.dtype), returns], -1) + returns = paddle.concat( + [ + paddle.zeros([returns.shape[0], start], dtype=returns.dtype), + returns, + ], + -1, + ) + return advantages.detach(), returns def rl_step(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]: # inputs shared by policy and value trainer - input_ids = rl_batch["input_ids"] # length: src+tgt + input_ids = rl_batch["input_ids"].contiguous() # length: src+tgt attention_mask = rl_batch["attention_mask"] # length: src+tgt position_ids = rl_batch["position_ids"] # length: src+tgt sequence_mask = rl_batch["sequence_mask"] # length: src+tgt(-1) # inputs used by policy trainer old_log_probs = rl_batch["log_probs"] # length: src+tgt(-1) reward_advantages = rl_batch["reward_advantages"] # length: src+tgt(-1) - # inputs used by value trainer - old_reward_values = rl_batch["reward_values"] # length: src+tgt(-1) - reward_returns = rl_batch["reward_returns"] # length: src+tgt(-1) - - value_trainer_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "old_reward_values": old_reward_values, - "reward_returns": reward_returns, - "sequence_mask": sequence_mask, - } - with self.enable(self.reward_critic_model, self.value_trainer.optimizer): - reward_critic_loss = self.value_trainer.full_training_step(**value_trainer_inputs) policy_trainer_inputs = { "input_ids": input_ids, @@ -1394,28 +2111,103 @@ def rl_step(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]: "reward_advantages": reward_advantages, "sequence_mask": sequence_mask, } + + if self.args.rl_algorithm == "grpo": + policy_trainer_inputs.update({"ref_log_probs": rl_batch["ref_log_probs"]}) + actor_loss = self.policy_trainer.full_training_step(**policy_trainer_inputs) # metric with paddle.no_grad(): - rewards = rl_batch["rewards"] - rewards = rewards.mean() + rewards = rl_batch["rewards"].mean() + ori_rewards = rl_batch["ori_rewards"].mean() + mask_cast = sequence_mask.cast(paddle.float32) + if self.args.rl_algorithm == "ppo": + kl_rewards = (rl_batch["kl_rewards"] * mask_cast).sum() / mask_cast.sum() + rewards_with_kl = (rl_batch["rewards_with_kl"] * mask_cast).sum() / mask_cast.sum() + values = (rl_batch["reward_values"] * mask_cast).sum() / mask_cast.sum() + returns = (rl_batch["reward_returns"] * mask_cast).sum() / mask_cast.sum() ref_log_probs = rl_batch["ref_log_probs"] - kl_divergence = ((old_log_probs - ref_log_probs) * sequence_mask).sum(axis=-1).mean() - mean_generated_length = sequence_mask.cast(paddle.float32).sum(axis=-1).mean() - max_generated_length = sequence_mask.cast(paddle.float32).sum(axis=-1).max() + # kl_divergence = ((old_log_probs - ref_log_probs) * sequence_mask).sum(axis=-1).mean() + kl_divergence = ((old_log_probs - ref_log_probs) * mask_cast).sum() / mask_cast.sum() + mean_generated_length = mask_cast.sum(axis=-1).mean() + max_generated_length = mask_cast.sum(axis=-1).max() + min_generated_length = mask_cast.sum(axis=-1).min() return { # when using PipelienParallel, the loss returned is 0 when not reach # accumulated step and the loss 
returned at accumulated step is a # mixed loss. - "train/actor_loss": actor_loss, - "train/reward_critic_loss": reward_critic_loss, - "train/reward": rewards, - "train/kl_divergence": kl_divergence, - "train/mean_generated_length": mean_generated_length, - "train/max_generated_length": max_generated_length, + "train_policy_loss": actor_loss, + **( + { + "train_pure_policy_loss": self.policy_trainer.info_buffer.get("pure_policy_loss"), + "train_kl_loss": self.policy_trainer.info_buffer.get("kl_loss"), + } + if self.args.rl_algorithm == "grpo" + else {} + ), + "train_reward": ori_rewards, # use original reward to log + **( + { + "train_norm_reward": rewards, + "train_kl_reward": kl_rewards, + "train_norm_reward_with_kl": rewards_with_kl, + "train_values": values, + "train_returns": returns, + } + if self.args.rl_algorithm == "ppo" + else {} + ), + "train_kl_divergence": kl_divergence, + "train_mean_generated_length": mean_generated_length, + "train_max_generated_length": max_generated_length, + "train_min_generated_length": min_generated_length, + } + + def rl_critic_step(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]: + """ + 更新评价函数(奖励函数)的参数。 + 该函数需要接收一个字典类型的参数,包括以下键值对: + - input_ids (paddle.Tensor): 输入序列的ID,形状为(src+tgt, batch)。 + - attention_mask (paddle.Tensor): 输入序列的注意力掩码,形状为(src+tgt, batch)。 + - position_ids (paddle.Tensor): 输入序列的位置ID,形状为(src+tgt, batch)。 + - old_reward_values (paddle.Tensor): 上一时间步的奖励值,形状为(src+tgt-1, batch)。 + - reward_returns (paddle.Tensor): 回报返回值,形状为(src+tgt-1, batch)。 + - sequence_mask (paddle.Tensor): 序列掩码,形状为(src+tgt-1, batch)。 + 返回值(Dict[str, Any]): + - train_value_loss (float): 评价函数(奖励函数)的训练损失。 + """ + # inputs shared by policy and value trainer + input_ids = rl_batch["input_ids"].contiguous() # length: src+tgt + attention_mask = rl_batch["attention_mask"] # length: src+tgt + position_ids = rl_batch["position_ids"] # length: src+tgt + sequence_mask = rl_batch["sequence_mask"] # length: src+tgt(-1) + # inputs used by value trainer + old_reward_values = rl_batch["reward_values"] # length: src+tgt(-1) + reward_returns = rl_batch["reward_returns"] # length: src+tgt(-1) + + value_trainer_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "old_reward_values": old_reward_values, + "reward_returns": reward_returns, + "sequence_mask": sequence_mask, } + self.timers and self.timers(get_timer_label(CriticStages.MODEL_ENABLE_DISABLE)).start() + with self.enable(self.reward_critic_model, self.value_trainer.optimizer): + self.timers and self.timers(get_timer_label(CriticStages.CRITIC_TRAINING_STEP)).start() + reward_critic_loss = self.value_trainer.full_training_step(**value_trainer_inputs) + self.timers and self.timers(get_timer_label(CriticStages.CRITIC_TRAINING_STEP)).stop() + + if self.timers: + self.timers and self.timers(get_timer_label(CriticStages.MODEL_ENABLE_DISABLE)).stop() + self.timers(get_timer_label(CriticStages.MODEL_ENABLE_DISABLE)).elapsed_ -= self.timers( + get_timer_label(CriticStages.CRITIC_TRAINING_STEP) + ).elapsed_ + + return {"train_value_loss": reward_critic_loss} def ptx_step(self, ptx_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]: """Perform a single update step with PTX loss.""" @@ -1424,29 +2216,49 @@ def ptx_step(self, ptx_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]: # "position_ids", make_position_ids(ptx_batch["attention_mask"])) ptx_loss = self.policy_trainer.full_training_step(**ptx_batch) return { - "train/ptx_loss": ptx_loss, + "train_ptx_loss": ptx_loss, } 
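Under the PPO branch, `add_kl_divergence_regularization` above folds a per-token KL penalty into the rewards and adds the clipped scalar reward at the last generated token, and `get_advantages_and_returns` turns the result into the GAE advantages consumed by `rl_step`. A toy, list-based sketch of both steps; the prompt-offset (`start`) handling and masking details are omitted, and the `kl_coeff`/`clip_range_score` defaults are made up for the example:

```python
# Plain-Python sketch (lists instead of paddle tensors) of KL reward shaping
# followed by GAE, matching the recurrences in the methods above.
def shape_rewards(log_probs, ref_log_probs, reward_score, mask,
                  kl_coeff=0.02, clip_range_score=5.0):
    # per-token KL penalty: -kl_coeff * (log_probs - ref_log_probs)
    kl_est = [-kl_coeff * (lp - rlp) for lp, rlp in zip(log_probs, ref_log_probs)]
    rewards = list(kl_est)
    clipped = max(-clip_range_score, min(clip_range_score, reward_score))
    last = max(i for i, m in enumerate(mask) if m)  # last non-masked position
    rewards[last] += clipped
    return rewards, kl_est

def gae(rewards, values, gamma=1.0, lam=0.95):
    advantages, last_gae = [], 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages.append(last_gae)
    advantages = advantages[::-1]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

log_probs     = [-1.2, -0.8, -1.5, -0.9]
ref_log_probs = [-1.0, -1.0, -1.3, -1.1]
values        = [0.1, 0.2, 0.15, 0.05]
mask          = [1, 1, 1, 0]  # last position is padding
rewards, _ = shape_rewards(log_probs, ref_log_probs, reward_score=7.3, mask=mask)
advantages, returns = gae(rewards, values)
print([round(r, 4) for r in rewards])
print([round(a, 4) for a in advantages])
```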
def enable(self, *args): + """ + 启用指定的对象或方法。 + 如果指定的对象是模型,则会将其设置为训练状态;如果是优化器,则会将其设置为训练状态。 + 如果指定的方法是"train_model",则会将所有需要训练的模型设置为训练状态。 + 如果指定的方法是"freeze_model",则会将所有不需要训练的模型设置为非训练状态。 + 如果指定的方法是"optimizer",则会将所有需要训练的优化器设置为训练状态。 + 如果指定的方法是"",则会返回一个包含所有需要训练的对象和方法的元组列表。 + + Args: + args (Tuple[Any], optional): 可选参数,默认值为空元组,表示需要启用所有需要训练的对象和方法。支持多个参数,每个参数只能是一个模型、优化器或方法。 + + Returns: + Union[Tuple[Tuple[Any, str]], Enable]: 如果传入了参数,则返回一个包含所有需要训练的对象和方法的元组列表;否则返回一个Enable实例,用于启用所有需要训练的对象和方法。 + """ # note: must keep the same model since actor_model, reward_model etc. # are property enable_map = { # maybe use `model: (pattern, enable_method, disable_method)`` self.actor_model: "train_model", - self.reward_critic_model: "train_model", self.reference_model: "freeze_model", - self.reward_model: "freeze_model", + **({self.reward_model: "freeze_model"} if not self.args.use_rm_server else {}), self.policy_trainer.optimizer: "optimizer", - self.value_trainer.optimizer: "optimizer", } + if self.args.rl_algorithm == "ppo": + enable_map.update( + { + self.reward_critic_model: "train_model", + self.value_trainer.optimizer: "optimizer", + } + ) # if use an extra eval model to do eval/generation, switch on actor_model # and reward_critic_model; otherwise no need to switch if getattr(self.policy_trainer, "_inner_eval_model", None) is not None: - enable_map.pop(self.actor_model) - if getattr(self.value_trainer, "_inner_eval_model", None) is not None: - enable_map.pop(self.reward_critic_model) - objs = [arg for arg in args if enable_map.get(arg, "") in self.args.offload_level] - return enable(*objs) + enable_map.update({self.policy_trainer._inner_eval_model: "freeze_model"}) + if self.args.rl_algorithm == "ppo" and getattr(self.value_trainer, "_inner_eval_model", None) is not None: + enable_map.update({self.value_trainer._inner_eval_model: "freeze_model"}) + # NOTE(GONGENLEI): new offload + objs = [(arg, enable_map.get(arg, "")) for arg in args if enable_map.get(arg, "") in self.args.offload_level] + return Enable(objs) def split_ptx_micro_batches( self, @@ -1458,24 +2270,12 @@ def split_ptx_micro_batches( micro_batch_size = self.args.per_device_train_batch_size for i in range(0, total_batch_size, micro_batch_size): micro_batch = map_structure( - # pylint: disable-next=cell-var-from-loop - lambda tensor: tensor[i : i + micro_batch_size], # noqa: B023 + lambda tensor: tensor[i : i + micro_batch_size], ptx_batch, ) micro_batches.append(micro_batch) return micro_batches - # @staticmethod - # def data_dispatch(fun): - # def _impl(self, data): - # gp = getattr(self.policy_trainer, "_data_trans_group", None) - # data = data_group_split(data, group=gp) - # data = fun(self, data) - # data = data_group_merge(data, group=gp) - # return data - - # return _impl - @paddle.no_grad() @data_dispatch # 3.10 static methods are now callable as regular functions. def split_rl_micro_batches( @@ -1484,96 +2284,264 @@ def split_rl_micro_batches( ) -> List[Dict]: """Split a batch of RL samples into micro-batches.""" total_batch_size = prompt_only_batch["input_ids"].shape[0] - micro_batch_size = self.args.per_device_train_batch_size + # micro_batch_size = self.args.per_device_train_batch_size + per_device_rollout_batch_size = self.args.per_device_rollout_batch_size + per_device_train_batch_size = self.args.per_device_train_batch_size micro_batches = [] # TODO(guosheng): clean get_epoch_iterator: # 1. scope guard for offload, we would split post_rollout into multiple # sub-methods to offload in-time # 2. 
decorate split_rl_micro_batches to automatically split/merge data + + self.timers and self.timers(get_timer_label(RolloutStages.ACTOR_MODEL_ENABLE_DISABLE)).start() with self.enable(self.actor_model, self.reference_model): # generate for multi batches and then disable FuseMT model + cleanup_batches = [] + indices = [] + if self.args.use_rm_server: + label_ids_batches = [] + self.timers and self.timers(get_timer_label(RolloutStages.GENERATE)).start() with infer_guard(self.policy_trainer): - # dist.barrier() - # print("="*20, "begin generate") - for i in range(0, total_batch_size, micro_batch_size): + for i in range(0, total_batch_size, per_device_rollout_batch_size): micro_batch = {} micro_batch = map_structure( - lambda tensor: tensor[i : i + micro_batch_size], + lambda tensor: tensor[i : i + per_device_rollout_batch_size], prompt_only_batch, ) - micro_batches.extend(self.generate(micro_batch)) - # dist.barrier() - # paddle.device.cuda.synchronize() + generated_batches = self.generate(micro_batch) + + for batch in generated_batches: + cleanup_batches.extend( + [ + process_row( + row, + remove_value=self.tokenizer.pad_token_id, + remove_side="right", + ) + for row in batch["input_ids"] + ] + ) + if self.args.use_rm_server: + label_ids_batches.extend( + [ + process_row( + row, + remove_value=self.tokenizer.pad_token_id, + remove_side="right", + ) + for row in batch["label_ids"] + ] + ) + indices.append(batch["index"]) + indices = np.concatenate(indices) + self.timers and self.timers(get_timer_label(RolloutStages.GENERATE)).stop() # get log_probs for multi batches and then disable actor/refer rmodel - for micro_batch in micro_batches: + origin_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = "right" + self.timers and self.timers(get_timer_label(RolloutStages.ROLLOUT_LOGPROB)).start() + for i in range(0, len(cleanup_batches), per_device_train_batch_size): # position_ids is necessary for non-right padding # If using right padding source + left padding target, make padding positions # in source be 0, since reward model use position_ids plus with padding size # (number of 0s) in source to calculate end offsets. 
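Each generated row collected in `cleanup_batches` has its right padding stripped by `process_row`, and the rows are later re-padded per training micro-batch, either to the longest row or to a fixed length when sequence parallelism needs equal shapes. A toy sketch of that trim-then-repad flow; plain lists stand in for `process_row` and `tokenizer.pad`:

```python
# Illustrative trim-then-repad flow; the real code uses process_row and
# tokenizer.pad(padding="longest" or "max_length").
PAD = 0

def trim_right_padding(row, pad_id=PAD):
    end = len(row)
    while end > 0 and row[end - 1] == pad_id:
        end -= 1
    return row[:end]

def pad_to_longest(rows, pad_id=PAD, max_length=None):
    target = max_length if max_length is not None else max(len(r) for r in rows)
    return [r + [pad_id] * (target - len(r)) for r in rows]

generated = [
    [11, 12, 13, PAD, PAD, PAD],
    [21, 22, PAD, PAD, PAD, PAD],
    [31, 32, 33, 34, 35, PAD],
]
cleaned = [trim_right_padding(r) for r in generated]
micro_batch = pad_to_longest(cleaned)             # padding="longest"
print([len(r) for r in micro_batch])              # [5, 5, 5]
sp_batch = pad_to_longest(cleaned, max_length=8)  # padding="max_length" (sequence parallel)
print([len(r) for r in sp_batch])                 # [8, 8, 8]
```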
- micro_batch["position_ids"] = make_position_ids(micro_batch["attention_mask"]) + + padding_strategy = "longest" + padding_max_len = None + + if self._model_config.sequence_parallel: + padding_strategy = "max_length" + padding_max_len = self._model_config.max_sequence_length + + truncate_max_len = self._model_config.max_position_embeddings + + cur_batch = [] + for batch in cleanup_batches[i : i + per_device_train_batch_size]: + if len(batch) > truncate_max_len: + cur_batch.append( + self.tokenizer.truncate_sequences( + batch, + num_tokens_to_remove=len(batch) - truncate_max_len, + truncation_strategy="longest_first", + )[0] + ) + else: + cur_batch.append(batch) + + input_ids = self.tokenizer.pad( + {"input_ids": cur_batch}, + padding=padding_strategy, + max_length=padding_max_len, + return_attention_mask=False, + )["input_ids"] + + sequence_mask = make_attention_mask( + input_ids, + pad_id=self.tokenizer.pad_token_id, + eos_id=None, + unk_id=self.tokenizer.unk_token_id, + causal_mask=False, + ).cast(self._model_config.dtype) + attention_mask = make_attention_mask( + input_ids, + pad_id=self.tokenizer.pad_token_id, + eos_id=None, + unk_id=self.tokenizer.unk_token_id, + causal_mask=True, + ).cast(self._model_config.dtype) + position_ids = make_position_ids(attention_mask) + prompt = prompt_only_batch["input_ids"][i : i + per_device_train_batch_size] + + micro_batch = { + "prompt": prompt, + "input_ids": input_ids, + "sequence_mask": sequence_mask, + "attention_mask": attention_mask, + "position_ids": position_ids, + "index": indices[i : i + per_device_train_batch_size], + **( + {"label_ids": label_ids_batches[i : i + per_device_train_batch_size]} + if self.args.use_rm_server + else {} + ), + } micro_batch.update(self.rollout_logprob(**micro_batch)) - # print("="*20, "micro_batch", micro_batch) + micro_batches.append(micro_batch) + self.timers and self.timers(get_timer_label(RolloutStages.ROLLOUT_LOGPROB)).stop() + self.tokenizer.padding_side = origin_padding_side + if self.timers: + self.timers(get_timer_label(RolloutStages.ACTOR_MODEL_ENABLE_DISABLE)).stop() + self.timers(get_timer_label(RolloutStages.ACTOR_MODEL_ENABLE_DISABLE)).elapsed_ -= self.timers( + get_timer_label(RolloutStages.GENERATE) + ).elapsed_ + self.timers(get_timer_label(RolloutStages.ACTOR_MODEL_ENABLE_DISABLE)).elapsed_ -= self.timers( + get_timer_label(RolloutStages.ROLLOUT_LOGPROB) + ).elapsed_ # get reward/value for multi batches and then disable reward/value model - with self.enable(self.reward_critic_model, self.reward_model): + self.timers and self.timers(get_timer_label(RolloutStages.REWARD_MODEL_ENABLE_DISABLE)).start() + with self.enable( + self.reward_critic_model if self.args.rl_algorithm == "ppo" else None, + self.reward_model if not self.args.use_rm_server else None, + ): + self.timers and self.timers(get_timer_label(RolloutStages.ROLLOUT_REWARD_VALUE)).start() for micro_batch in micro_batches: micro_batch.update(self.rollout_reward_value(**micro_batch)) + self.timers and self.timers(get_timer_label(RolloutStages.ROLLOUT_REWARD_VALUE)).stop() + if self.timers: + self.timers and self.timers(get_timer_label(RolloutStages.REWARD_MODEL_ENABLE_DISABLE)).stop() + self.timers(get_timer_label(RolloutStages.REWARD_MODEL_ENABLE_DISABLE)).elapsed_ -= self.timers( + get_timer_label(RolloutStages.ROLLOUT_REWARD_VALUE) + ).elapsed_ + + micro_batches = self.normalize_batch_data(micro_batches, use_tgt_len_value=self.args.use_tgt_len_value) - # - micro_batches = [self.normalize_data(micro_batch, use_tgt_len_value=False) for 
micro_batch in micro_batches] # size of micro_batches (num of training batch) would be: # per_device_prompt_batch_size * num_return_sequences // per_device_train_batch_size # micro_batches = [self.post_rollout(**micro_batch) for micro_batch in micro_batches] return micro_batches @paddle.no_grad() - def generate(self, prompt_only_batch: Dict) -> List[Dict[str, Any]]: + def generate(self, prompt_only_batch: Dict, do_eval=False) -> List[Dict[str, Any]]: """Rollout a batch of experiences.""" input_ids = prompt_only_batch["input_ids"] attention_mask = prompt_only_batch["attention_mask"] + if do_eval: + train_num_return_sequences = self.args.num_return_sequences + self.args.num_return_sequences = 1 + + position_ids = ( + prompt_only_batch["position_ids"] + if "position_ids" in prompt_only_batch + else make_position_ids(attention_mask) + ) + + if self.args.num_return_sequences > 1: + input_ids = input_ids.repeat_interleave(self.args.num_return_sequences, axis=0) + raw_dtype = attention_mask.dtype + attention_mask = ( + attention_mask.cast("int32").repeat_interleave(self.args.num_return_sequences, axis=0).cast(raw_dtype) + ) + position_ids = position_ids.repeat_interleave(self.args.num_return_sequences, axis=0) - self.timers and self.timers("actor-model-generate").start() sequences = self.actor_model.generate( input_ids=input_ids, attention_mask=attention_mask, - position_ids=prompt_only_batch["position_ids"] - if "position_ids" in prompt_only_batch - else make_position_ids(attention_mask), + position_ids=position_ids, generation_config=self.generation_config, synced_gpus=ShardingOption.FULL_SHARD in self.policy_trainer.args.sharding, )[0] - self.timers and self.timers("actor-model-generate").stop() - sequences = sequences.reshape([input_ids.shape[0], self.args.num_return_sequences, -1]).transpose([1, 0, 2]) + if self.args.use_rm_server: + label_ids = prompt_only_batch["label_ids"] + if self.args.num_return_sequences > 1: + label_ids = label_ids.repeat_interleave(self.args.num_return_sequences, axis=0) + sequences = sequences.reshape( + [input_ids.shape[0] // self.args.num_return_sequences, self.args.num_return_sequences, -1] + ) + if do_eval: + self.args.num_return_sequences = train_num_return_sequences + sequences = sequences.transpose([1, 0, 2]) # prompt, sequence, attention_mask return [ { "prompt": input_ids, - "input_ids": seq, # "sequence": + "input_ids": seq, + **({"label_ids": label_ids[idx * len(seq) : (idx + 1) * len(seq)]} if self.args.use_rm_server else {}), + "index": np.array([str(uuid.uuid4())] * len(seq), dtype=object), "attention_mask": make_attention_mask( seq, pad_id=self.tokenizer.pad_token_id, + eos_id=None, + unk_id=self.tokenizer.unk_token_id, + causal_mask=True, + ).cast(self._model_config.dtype), + "sequence_mask": make_attention_mask( + seq, + pad_id=self.tokenizer.pad_token_id, + eos_id=None, unk_id=self.tokenizer.unk_token_id, causal_mask=False, - ), - # "sequence_mask": make_attention_mask( - # seq, - # pad_id=self.tokenizer.pad_token_id, - # unk_id=self.tokenizer.unk_token_id, - # causal_mask=False, - # ), + ).cast(self._model_config.dtype), } - for seq in sequences + for idx, seq in enumerate(sequences) ] @paddle.no_grad() def rollout_logprob( - self, input_ids: paddle.Tensor, attention_mask: paddle.Tensor, position_ids: paddle.Tensor = None, **kwargs + self, + input_ids: paddle.Tensor, + attention_mask: paddle.Tensor, + position_ids: paddle.Tensor = None, + **kwargs, ) -> Dict[str, paddle.Tensor]: + """ + 计算rollout过程中每个token的log probability。 + + Args: + input_ids 
(paddle.Tensor, shape [batch_size, sequence_length]): + 输入序列,其中每个元素都是一个int,表示各自token的ID。 + attention_mask (paddle.Tensor, shape [batch_size, sequence_length]): + 输入序列的attention mask,其中每个元素为0或1,用于指示哪些tokens应该被模型考虑。 + position_ids (paddle.Tensor, optional, shape [batch_size, sequence_length], defaults to None): + 输入序列中每个token的位置ID,默认为None。 + kwargs (Dict[str, Any], optional, defaults to {}): + 可选参数,目前未使用。 + + Returns: + Dict[str, paddle.Tensor]: + - log_probs (paddle.Tensor, shape [batch_size, sequence_length - 1]): + 每个token在rollout过程中的log probability。 + - ref_log_probs (paddle.Tensor, shape [batch_size, sequence_length - 1]): + 每个token在rollout过程中的reference log probability。 + + Raises: + None. + """ # pipe model outputs a logits tensor with LMHead, while non-pipe model # outputs a tuple with logits tensor as the only one element. + logits = self.actor_model( input_ids, attention_mask=attention_mask, @@ -1581,223 +2549,356 @@ def rollout_logprob( # return_dict=True, ) # .logits if not isinstance(logits, paddle.Tensor): - logits = logits[0] + logits = logits[0] # [2, 355, 12544] ref_logits = self.reference_model( input_ids, attention_mask=attention_mask, position_ids=position_ids, # return_dict=True, ) # .logits + if not isinstance(ref_logits, paddle.Tensor): - ref_logits = ref_logits[0] - log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:]) - ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], input_ids[:, 1:]) - return {"log_probs": log_probs, "ref_log_probs": ref_log_probs} + ref_logits = ref_logits[0] # [2, 355, 12544] - @paddle.no_grad() - def rollout_reward_value( - self, input_ids: paddle.Tensor, attention_mask: paddle.Tensor, position_ids: paddle.Tensor = None, **kwargs - ) -> Dict[str, paddle.Tensor]: - if self.reward_tokenizer is not self.tokenizer: - # right padding - reward_tokenize_output = batch_retokenize( - input_ids, - src_tokenizer=self.tokenizer, - dest_tokenizer=self.reward_tokenizer, - skip_special_tokens=True, - ) - reward_input_ids = reward_tokenize_output["input_ids"] - reward_attention_mask = make_attention_mask( - reward_input_ids, - pad_id=self.reward_tokenizer.pad_token_id, - unk_id=self.reward_tokenizer.unk_token_id, - causal_mask=False, + if self.actor_model.config.tensor_parallel_degree > 1 and self.actor_model.config.tensor_parallel_output: + log_probs = ( + -ParallelCrossEntropy()(logits[:, :-1].astype("float32"), input_ids[:, 1:]) + .squeeze(axis=-1) + .astype(logits.dtype) ) - reward_position_ids = make_position_ids(reward_attention_mask) else: - # for text in self.tokenizer.batch_decode(input_ids, skip_special_tokens=False): - # print(text) - reward_input_ids = input_ids - reward_attention_mask = attention_mask - reward_position_ids = position_ids - reward_score = self.reward_model( - reward_input_ids, - attention_mask=reward_attention_mask, - position_ids=reward_position_ids, - # return_dict=True, - )[ - 1 - ] # .end_scores + log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:]) - reward_value = self.reward_critic_model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - # return_dict=True, - )[ - 0 - ] # .scores - reward_score = reward_score.squeeze(axis=-1) - reward_value = reward_value.squeeze(axis=-1) + if ( + self.reference_model.config.tensor_parallel_degree > 1 + and self.reference_model.config.tensor_parallel_output + ): + ref_log_probs = ( + -ParallelCrossEntropy()(ref_logits[:, :-1].astype("float32"), input_ids[:, 1:]) + .squeeze(axis=-1) + .astype(ref_logits.dtype) + ) + else: 
+ ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], input_ids[:, 1:]) - reward_value = reward_value[:, :-1] - return {"rewards": reward_score, "reward_values": reward_value} + return {"log_probs": log_probs, "ref_log_probs": ref_log_probs} @paddle.no_grad() - def post_rollout( + def rollout_reward_value( self, - prompt: paddle.Tensor, - sequence: paddle.Tensor, + input_ids: paddle.Tensor, attention_mask: paddle.Tensor, - ) -> Dict[str, Any]: - if self.reward_tokenizer is not self.tokenizer: - # right padding - reward_tokenize_output = batch_retokenize( - sequence, - src_tokenizer=self.tokenizer, - dest_tokenizer=self.reward_tokenizer, - skip_special_tokens=True, - ) - reward_seq = reward_tokenize_output["input_ids"] - reward_attention_mask = reward_tokenize_output["attention_mask"] + position_ids: paddle.Tensor = None, + **kwargs, + ) -> Dict[str, paddle.Tensor]: + """ + 根据输入的序列,计算每个时间步骤的奖励值和奖励得分。如果模型使用了不同的tokenizer,则先将输入序列转换为目标tokenizer的格式。 + + Args: + input_ids (paddle.Tensor): shape=[batch_size, seq_len], 输入序列的ID,取值范围是[0, vocabulary_size - 1]。 + attention_mask (paddle.Tensor): shape=[batch_size, seq_len], 输入序列的注意力掩码,取值范围是{0, 1}。 + position_ids (Optional, paddle.Tensor, optional): shape=[batch_size, seq_len], 输入序列的位置ID,默认为None。 + kwargs (Dict, optional): 其他可选参数,包括: + reward_tokenizer (Tokenizer, optional): 奖励tokenizer,默认为None,表示使用与模型相同的tokenizer。 + + Returns: + Dict[str, paddle.Tensor]: 返回一个字典,包含两个键值对: + rewards (paddle.Tensor): shape=[batch_size, seq_len], 每个时间步骤的奖励得分,取值范围是[-inf, inf]。 + reward_values (paddle.Tensor): shape=[batch_size, seq_len-1], 每个时间步骤的奖励值,取值范围是[0, inf]。 + """ + if not self.args.use_rm_server: + if self.reward_tokenizer is not self.tokenizer: + # right padding + reward_tokenize_output = batch_retokenize( + input_ids, + src_tokenizer=self.tokenizer, + dest_tokenizer=self.reward_tokenizer, + ) + reward_input_ids = reward_tokenize_output["input_ids"] + reward_attention_mask = reward_tokenize_output["attention_mask"] + reward_position_ids = reward_tokenize_output["position_ids"] + else: + reward_input_ids = input_ids + reward_attention_mask = attention_mask + reward_position_ids = position_ids + + # .end_scores + reward_score = self.reward_model( + reward_input_ids, + attention_mask=reward_attention_mask, + position_ids=reward_position_ids, + # return_dict=True, + )[1] else: - # actor_model_in_use gen - # for text in self.tokenizer.batch_decode(sequence, skip_special_tokens=True): - # print(text) - reward_seq = sequence - reward_attention_mask = attention_mask - # position_ids is necessary for non-right padding - # If using right padding source + left padding target, make padding positions - # in source be 0, since reward model use position_ids plus with padding size - # (number of 0s) in source to calculate end offsets. - position_ids = make_position_ids(attention_mask) + prompt_len = kwargs["prompt"].shape[-1] + if "label_ids" not in kwargs: + raise ValueError("Rule-based reward needs labels.") + src = self.tokenizer.batch_decode(input_ids[:, :prompt_len], skip_special_tokens=True) + tgt = self.tokenizer.batch_decode(kwargs["label_ids"], skip_special_tokens=True) + response = self.tokenizer.batch_decode(input_ids[:, prompt_len:], skip_special_tokens=True) + reward_score = self.request_reward_server(src, tgt, response) - # pipe model outputs a logits tensor with LMHead, while non-pipe model - # outputs a tuple with logits tensor as the only one element. 
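[Editor's note] Both branches of `rollout_logprob` above compute the same per-token quantity: the log-probability the actor (or reference) model assigns to each realized next token, either via `ParallelCrossEntropy` when tensor-parallel output is enabled or via `gather_log_probabilities` otherwise. A minimal standalone sketch of the fallback path, with toy shapes (not part of the patch; in the trainer it is applied to `logits[:, :-1]` against `input_ids[:, 1:]`):

```python
import paddle
import paddle.nn.functional as F


def gather_log_probabilities_sketch(logits, labels):
    """log p(labels[t] | context) at every position: log_softmax over vocab, then gather."""
    log_probs = F.log_softmax(logits, axis=-1)  # [batch, seq, vocab]
    return paddle.take_along_axis(log_probs, labels.unsqueeze(-1), axis=-1).squeeze(-1)


logits = paddle.randn([2, 5, 11])               # toy [batch, seq, vocab]
next_ids = paddle.randint(0, 11, shape=[2, 5])  # toy next-token ids
print(gather_log_probabilities_sketch(logits, next_ids).shape)  # [2, 5]
```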
- self.timers and self.timers("actor-model-logit").start() - logits = self.actor_model( - sequence, - attention_mask=attention_mask, - position_ids=position_ids, - # return_dict=True, - ) # .logits - self.timers and self.timers("actor-model-logit").stop() - if not isinstance(logits, paddle.Tensor): - logits = logits[0] - self.timers and self.timers("reference-model-logit").start() - ref_logits = self.reference_model( - sequence, - attention_mask=attention_mask, - position_ids=position_ids, - # return_dict=True, - ) # .logits - self.timers and self.timers("reference-model-logit").stop() - if not isinstance(ref_logits, paddle.Tensor): - ref_logits = ref_logits[0] + reward_score = reward_score.squeeze(axis=-1) - self.timers and self.timers("reward-model-score").start() - reward_score = self.reward_model( - reward_seq, - attention_mask=reward_attention_mask, - position_ids=position_ids, - # return_dict=True, - )[ - 1 - ] # .end_scores + if self.args.rl_algorithm == "grpo": + return {"rewards": reward_score} + # .scores reward_value = self.reward_critic_model( - sequence, + input_ids, attention_mask=attention_mask, position_ids=position_ids, # return_dict=True, - )[ - 0 - ] # .scores - reward_score = reward_score.squeeze(axis=-1) + )[0] reward_value = reward_value.squeeze(axis=-1) - - self.timers and self.timers("reward-model-score").stop() reward_value = reward_value[:, :-1] - log_probs = gather_log_probabilities(logits[:, :-1], sequence[:, 1:]) - ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], sequence[:, 1:]) - rollout_data = { - "prompt": prompt, - "input_ids": sequence, - "position_ids": position_ids, - "attention_mask": attention_mask, - "rewards": reward_score, - "reward_values": reward_value, - "log_probs": log_probs, - "ref_log_probs": ref_log_probs, - } - rollout_data = self.normalize_data(rollout_data, use_tgt_len_value=False) - return rollout_data + + return {"rewards": reward_score, "reward_values": reward_value} + + def request_reward_server(self, src, tgt, response): + data = {"src": src, "tgt": tgt, "response": response} + + def post(): + try: + res = requests.post(self.reward_server, json=data) + result = json.loads(res.text) + reward_score = paddle.to_tensor(result["score"], dtype=self._model_config.dtype) + except: + logger.warning("Request reward server failed and rewards_score will be set zero.") + reward_score = paddle.zeros(len(response), dtype=self._model_config.dtype) + return reward_score + + try: + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + nranks = tp_group.nranks + tp_rank = hcg.get_model_parallel_rank() + except: + nranks = 1 + tp_rank = 0 + + if nranks == 1: + reward_score = post() + else: + if tp_rank == 0: + reward_score = post() + else: + reward_score = paddle.empty(shape=[len(response)], dtype=self._model_config.dtype) + paddle.distributed.barrier(tp_group) + paddle.distributed.broadcast(reward_score, src=tp_group.ranks[0], group=tp_group) + + return reward_score.unsqueeze(-1) @paddle.no_grad() - def normalize_data( + def normalize_batch_data( self, - rl_batch: Dict[str, paddle.Tensor], + rl_batches: List[Dict[str, paddle.Tensor]], use_tgt_len_value: bool = False, ) -> Dict[str, Any]: """ data dispatch comm among devices needs padding, while the lengths of all data fields are different and related, and it's hard to pad. 
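[Editor's note] `request_reward_server` above posts the decoded prompts, labels and rollouts to the external rule-based reward service; under tensor parallelism only rank 0 issues the request and the resulting scores are broadcast to the other ranks. A minimal client-side sketch of the same HTTP contract (hypothetical payload; assumes a `reward_server.py` instance is listening on its default port 8731):

```python
import requests

# The three parallel lists mirror the Request model of reward_server.py.
payload = {
    "src": ["Is Zoey a knight or a knave?"],                       # prompts (illustrative only)
    "tgt": ["Zoey is a knight"],                                   # ground-truth labels
    "response": ["... sampled completion from the actor ..."],     # rollouts to be scored
}

res = requests.post("http://127.0.0.1:8731/", json=payload, timeout=60)
result = res.json()
print(result["error_code"], result["score"])  # score: one float per response; error_code != 0 on failure
```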
""" - prompt = rl_batch["prompt"] # length: src - attention_mask = rl_batch["attention_mask"] # length: src + tgt - if len(attention_mask.shape) == 4: - # use padding mask instead of causal mask - attention_mask = rl_batch["sequence_mask"] # length: src + tgt - old_log_probs = rl_batch["log_probs"] # length: src + tgt -1 - ref_log_probs = rl_batch["ref_log_probs"] # length: src + tgt -1 - rewards = rl_batch["rewards"] # length: 1 - old_reward_values = rl_batch["reward_values"] # length: src + tgt -1 - - # Beaver uses label data with target length, while we do not slice from - # inputs and use label data with target length: - # 1. Sometimes we cannot use label data with target length, mostly because - # it is hard to pad acorss batches. Think in some cases one batch might - # have the longest prompt+target length but the shortest target lengh, which - # might cause mismatch between inputs with prompt+target length and labels - # with target length. Padding acorss batches is needed in PP and data comm. - # 2. Additionally, when using flash_attn with casual mask and right padding - # we cannot use label data with target length. - start = prompt.shape[-1] - 1 - # sequence_mask is for label masking, make source be masked out - # clone to avoid to change attention_mask - sequence_mask = attention_mask[:, 1:].clone() # length: src + tgt -1 - sequence_mask[:, :start] = False - if use_tgt_len_value: - ref_log_probs = ref_log_probs[:, start:] - old_log_probs = old_log_probs[:, start:] - old_reward_values = old_reward_values[:, start:] - sequence_mask = sequence_mask[:, start:] - old_rewards = self.add_kl_divergence_regularization( - None, # prompt, - old_log_probs, - ref_log_probs, - rewards, - sequence_mask, - ) # length: tgt if use_tgt_len_value src + tgt -1 - reward_advantages, reward_returns = self.get_advantages_and_returns( - old_reward_values, - old_rewards, - sequence_mask, - start=0 if use_tgt_len_value else start, - use_tgt_len_return=use_tgt_len_value, - ) # length: tgt if use_tgt_len_value src + tgt -1 - - rl_batch.update( - { - "log_probs": old_log_probs, - "reward_values": old_reward_values, - "reward_advantages": reward_advantages, - "reward_returns": reward_returns, - "sequence_mask": sequence_mask, - "ref_log_probs": ref_log_probs, - "rewards": rewards, - } - ) - # pop out to reduce data dispatch comm overhead - rl_batch.pop("prompt") - return rl_batch + for rl_batch in rl_batches: + rl_batch["ori_rewards"] = rl_batch["rewards"].clone() + + use_reward_normalization = self.args.normalize_reward + use_advantage_normalization = self.args.normalize_advantage + + if use_reward_normalization: + batch_rewards_list = [rl_batch["rewards"] for rl_batch in rl_batches] + batch_rewards = paddle.concat(batch_rewards_list, axis=0) + batch_rewards = batch_rewards.cast(paddle.float32) + + try: + hcg = fleet.get_hybrid_communicate_group() + sd_group = hcg.get_sharding_parallel_group() + dp_group = hcg.get_data_parallel_group() + + if sd_group.nranks > 1: + all_gather_batch_rewards = [] + dist.all_gather(all_gather_batch_rewards, batch_rewards, group=sd_group) + batch_rewards = paddle.flatten(paddle.stack(all_gather_batch_rewards)) + if dp_group.nranks > 1: + all_gather_batch_rewards = [] + dist.all_gather(all_gather_batch_rewards, batch_rewards, group=dp_group) + batch_rewards = paddle.flatten(paddle.stack(all_gather_batch_rewards)) + except AttributeError: + pass + + batch_rewards_mean = batch_rewards.mean() + # batch_rewards_std = batch_rewards.std() + batch_rewards_var = batch_rewards.var() + + 
current_batch_num = batch_rewards.shape[0] + delta = batch_rewards_mean - self.reward_mean + total_batch_num = self.sample_batch_num + current_batch_num + + new_mean = self.reward_mean + delta * current_batch_num / total_batch_num + m_a = self.reward_var * self.sample_batch_num + m_b = batch_rewards_var * current_batch_num + m2 = m_a + m_b + paddle.square(delta) * (self.sample_batch_num * current_batch_num / total_batch_num) + new_var = m2 / total_batch_num + + self.reward_mean = new_mean + self.reward_var = new_var + self.sample_batch_num = total_batch_num + + for rl_batch in rl_batches: + reward_mean = self.reward_mean.cast(paddle.bfloat16) + reward_std = self.reward_var.sqrt().cast(paddle.bfloat16) + rl_batch["rewards"] = (rl_batch["rewards"] - reward_mean) / (reward_std + 1e-8) + + for rl_batch in rl_batches: + prompt = rl_batch["prompt"] # length: src + attention_mask = rl_batch["attention_mask"] # length: src + tgt + if len(attention_mask.shape) == 4: + # use padding mask instead of causal mask + attention_mask = rl_batch["sequence_mask"] # length: src + tgt + old_log_probs = rl_batch["log_probs"] # length: src + tgt -1 + ref_log_probs = rl_batch["ref_log_probs"] # length: src + tgt -1 + rewards = rl_batch["rewards"] # length: 1 + if self.args.rl_algorithm == "ppo": + old_reward_values = rl_batch["reward_values"] # length: src + tgt -1 + + start = prompt.shape[-1] - 1 + # sequence_mask is for label masking, make source be masked out + # clone to avoid to change attention_mask + sequence_mask = attention_mask[:, 1:].clone() # length: src + tgt -1 + sequence_mask[:, :start] = False + if use_tgt_len_value: + ref_log_probs = ref_log_probs[:, start:].contiguous() + old_log_probs = old_log_probs[:, start:].contiguous() + if self.args.rl_algorithm == "ppo": + old_reward_values = old_reward_values[:, start:].contiguous() + sequence_mask = sequence_mask[:, start:].contiguous() + if self.args.rl_algorithm == "grpo": + eos_mask = (rl_batch["input_ids"] != self.tokenizer.pad_token_id)[:, 1:].to(old_log_probs.dtype) + if use_tgt_len_value: + eos_mask = eos_mask[:, start:].contiguous() + reward_advantages = compute_grpo_advantages( + rewards, rl_batch["index"], eos_mask, old_log_probs.shape[-1] + ) + elif self.args.rl_algorithm == "ppo": + rewards_with_kl, kl_rewards = self.add_kl_divergence_regularization( + None, # prompt, + old_log_probs, + ref_log_probs, + rewards, + sequence_mask, + ) # length: tgt if use_tgt_len_value src + tgt -1 + reward_advantages, reward_returns = self.get_advantages_and_returns( + old_reward_values, + rewards_with_kl, + sequence_mask, + start=0 if use_tgt_len_value else start, + use_tgt_len_return=use_tgt_len_value, + ) # length: tgt if use_tgt_len_value src + tgt -1 + else: + raise ValueError(f"Unknown rl_algorithm: {self.args.rl_algorithm}") + + rl_batch.update( + { + "log_probs": old_log_probs, + "reward_advantages": reward_advantages, + "sequence_mask": sequence_mask, + "ref_log_probs": ref_log_probs, + "rewards": rewards, + } + ) + if self.args.rl_algorithm == "ppo": + rl_batch.update( + { + "reward_values": old_reward_values, + "reward_returns": reward_returns, + "kl_rewards": kl_rewards, + "rewards_with_kl": rewards_with_kl, + } + ) + + # pop out to reduce data dispatch comm overhead + rl_batch.pop("prompt") + + if use_advantage_normalization: + all_advantages_list = [] + for rl_batch in rl_batches: + sequence_mask = rl_batch["sequence_mask"].cast(paddle.int64) # length: src + tgt + advantages = rl_batch["reward_advantages"] + 
all_advantages_list.append(advantages[sequence_mask != 0]) + all_advantages = paddle.concat(all_advantages_list, axis=0) + all_advantages = all_advantages.cast(paddle.float32) + + try: + hcg = fleet.get_hybrid_communicate_group() + sd_group = hcg.get_sharding_parallel_group() + dp_group = hcg.get_data_parallel_group() + + if sd_group.nranks > 1: + object_list = [] + dist.all_gather_object(object_list, all_advantages.tolist(), group=sd_group) + flattened_data = [item for sublist in object_list for item in sublist] + all_advantages = paddle.to_tensor(flattened_data, dtype="float32") + if dp_group.nranks > 1: + object_list = [] + dist.all_gather_object(object_list, all_advantages.tolist(), group=dp_group) + flattened_data = [item for sublist in object_list for item in sublist] + all_advantages = paddle.to_tensor(flattened_data, dtype="float32") + except AttributeError: + pass + all_advantages_mean = all_advantages.mean() + all_advantages_std = all_advantages.std() + for rl_batch in rl_batches: + all_advantages_mean = all_advantages_mean.cast(paddle.bfloat16) + all_advantages_std = all_advantages_std.cast(paddle.bfloat16) + rl_batch["reward_advantages"] = (rl_batch["reward_advantages"] - all_advantages_mean) / ( + all_advantages_std + 1e-8 + ) + rl_batch["reward_advantages"] = rl_batch["reward_advantages"] * rl_batch["sequence_mask"] + + return rl_batches + + +@paddle.no_grad() +def compute_grpo_advantages( + rewards: paddle.Tensor, + index: np.ndarray, + sequence_mask: paddle.Tensor, + response_length: int, + epsilon: float = 1e-6, +): + """ + 计算每个prompt的GRPO优势。 + + Args: + rewards (paddle.Tensor, shape=[batch_size]): 回报,单位为float。 + index (np.ndarray, shape=[batch_size]): 每个样本对应的prompt索引,类型为int。 + sequence_mask (paddle.Tensor, shape=[batch_size, response_length]): 序列掩码,用于标记每个时间步是否有效,类型为bool。 + response_length (int): 每个样本的响应长度。 + epsilon (float, optional, default=1e-6): 避免除以0的值,默认为1e-6。 + + Returns: + rewards (paddle.Tensor, shape=[batch_size, response_length]): GRPO优势,单位为float。 + + Raises: + ValueError (ValueError): 如果没有在给定的prompt索引中有分数。 + """ + id2score = defaultdict(list) + id2mean = {} + id2std = {} + batch_size = rewards.shape[0] + + for i in range(batch_size): + id2score[index[i]].append(rewards[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = paddle.to_tensor(0.0, dtype=rewards.dtype) + id2std[idx] = paddle.to_tensor(1.0, dtype=rewards.dtype) + elif len(id2score[idx]) > 1: + id2mean[idx] = paddle.mean(paddle.stack(id2score[idx])) + id2std[idx] = paddle.std(paddle.stack(id2score[idx])) + else: + raise ValueError(f"No score in prompt index: {idx}") + for i in range(batch_size): + rewards[i] = (rewards[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + rewards = rewards.unsqueeze(-1).tile([1, response_length]) * sequence_mask + return rewards diff --git a/llm/alignment/ppo/reward_server.py b/llm/alignment/ppo/reward_server.py new file mode 100644 index 000000000000..29495f600a4d --- /dev/null +++ b/llm/alignment/ppo/reward_server.py @@ -0,0 +1,308 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
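[Editor's note] `compute_grpo_advantages` above normalizes every reward against the mean and standard deviation of the other samples drawn for the same prompt, then broadcasts the result over the response tokens and masks it. A toy walkthrough (hypothetical numbers; assumes the function is importable from `ppo_trainer`):

```python
import numpy as np
import paddle

from ppo_trainer import compute_grpo_advantages

rewards = paddle.to_tensor([1.0, 3.0, 2.0, 2.0])  # 4 sampled responses
index = np.array([0, 0, 1, 1])                    # two responses per prompt group
mask = paddle.ones([4, 3])                        # pretend every response has 3 valid tokens

adv = compute_grpo_advantages(rewards, index, mask, response_length=3)
# group 0: mean=2, std≈1.414 -> per-token advantages ≈ -0.707 and +0.707
# group 1: mean=2, std=0     -> advantages 0 (epsilon keeps the division finite)
print(adv)
```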
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Launch Reward HTTP Server.""" + +import argparse +import json +import logging +import re +import threading +import traceback +from typing import Dict, List, Optional, Tuple + +import uvicorn +from fastapi import FastAPI +from pydantic import BaseModel + + +class Request(BaseModel): + """The request for RM server.""" + + src: List[str] + tgt: List[str] + response: List[str] + + +class Response(BaseModel): + """The response for RM server.""" + + error_code: int = 0 + error_msg: str = "Success" + score: List[float] = None + + +def extract_solution(solution_str: str) -> Tuple[Optional[str], str]: + """Extracts the final answer from the model's response string. + + Args: + solution_str: Raw response string from the language model + + Returns: + Tuple containing (extracted_answer, processed_string) + """ + # Split response to isolate assistant output + if "Assistant:" in solution_str: + processed_str = solution_str.split("Assistant:", 1)[1] + elif "<|im_start|>assistant" in solution_str: + processed_str = solution_str.split("<|im_start|>assistant", 1)[1] + else: + print("[Error] Failed to locate model response header") + return None, solution_str + + # Extract final answer using XML-style tags + answer_pattern = r"(.*?)" + matches = list(re.finditer(answer_pattern, processed_str, re.DOTALL)) + + if not matches: + print("[Error] No valid answer tags found") + return None, processed_str + + final_answer = matches[-1].group(1).strip() + return final_answer, processed_str + + +def parse_solution_text_format(solution_text: str) -> Dict[str, str]: + """Parses ground truth solution text into status dictionary. + + Args: + solution_text: Formatted solution text from dataset + + Returns: + Dictionary mapping character names to their roles (knight/knave) + """ + status_dict = {} + print("\n[Ground Truth Parsing]") + + for line in solution_text.split("\n"): + line = line.strip() + if not line: + continue + + match = re.search(r"\b([A-Za-z]+)\b.*?\b(knight|knave)\b", line, re.IGNORECASE) + if match: + name, role = match.groups() + status_dict[name] = role.lower() + print(f" Found: {name} → {role}") + else: + print(f" [Warning] Unparseable line: '{line}'") + + return status_dict + + +def parse_model_answer(answer_text: str, expected_names: list) -> Optional[Dict[str, str]]: + """Parses model's answer text into status dictionary. 
+ + Args: + answer_text: Text extracted from model's tags + expected_names: List of character names requiring identification + + Returns: + Dictionary mapping character names to predicted roles, or None if incomplete + """ + status_dict = {} + print("\n[Model Answer Parsing]") + print(f" Expected characters: {expected_names}") + + knight_count = answer_text.lower().count("knight") + knave_count = answer_text.lower().count("knave") + + print(f" Number of predicted roles: {knight_count + knave_count}") + if knight_count + knave_count != len(expected_names): + print(f" [Error] Number of characters mismatch: {knight_count + knave_count} != {len(expected_names)}") + return None + + for name in expected_names: + pattern = re.compile(rf"\b{re.escape(name)}\b\s+is\s+a\s+\b(knight|knave)\b", re.IGNORECASE) + match = pattern.search(answer_text) + + if match: + role = match.group(1).lower() + status_dict[name] = role + print(f" Found: {name} → {role}") + else: + print(f" [Error] Missing identification for {name}") + return None + + return status_dict + + +def validate_response_structure(processed_str: str) -> bool: + """Performs comprehensive validation of response structure. + + Args: + processed_str: Processed response string from the model + + Returns: + Boolean indicating whether all formatting requirements are met + """ + print("\n[Structure Validation]") + validation_passed = True + + # Check required tags + tags = { + "think_start": ("", 1), + "think_end": ("", 1), + "answer_start": ("", 1), + "answer_end": ("", 1), + } + + positions = {} + for tag_name, (tag_str, expected_count) in tags.items(): + count = processed_str.count(tag_str) + positions[tag_name] = pos = processed_str.find(tag_str) + + print(f" {tag_str}: count={count}, position={pos}") + + if count != expected_count: + print(f" [Error] {tag_str} appears {count} times (expected {expected_count})") + validation_passed = False + + # Verify tag order + if ( + positions["think_start"] > positions["think_end"] + or positions["think_end"] > positions["answer_start"] + or positions["answer_start"] > positions["answer_end"] + ): + print(" [Error] Incorrect tag order: Expected ......") + validation_passed = False + else: + print(" Tag sequence validation passed") + + return validation_passed + + +def compute_score( + solution_str: str, ground_truth: str, query=None, format_reward: int = 1, answer_reward: float = 1.0 +): + """Computes comprehensive score for model response. 
+ + Args: + solution_str: Raw model response string + ground_truth: Dictionary containing ground truth data + format_reward: Points awarded/deducted for format correctness + answer_reward: Points awarded/deducted for answer correctness + + Returns: + Total score (sum of format and answer rewards) + """ + print("\n" + "=" * 80) + print(" Processing New Sample ".center(80, "=")) + + if "\n<|im_start|>assistant\n" not in solution_str: + solution_str = "\n<|im_start|>assistant\n" + solution_str + + # Parse ground truth data + solution_text = ground_truth + gt_status = parse_solution_text_format(solution_text) + expected_names = list(gt_status.keys()) + print(f"[Ground Truth] Final identities: {gt_status}") + + # Extract model answer + answer_text, processed_str = extract_solution(solution_str) + print(f"\n[Model Response]\n{processed_str}") + + # Validate response structure + format_correct = validate_response_structure(processed_str) + format_score = format_reward if format_correct else -abs(format_reward) + print(f"\n Format validation: {'PASS' if format_correct else 'FAIL'}") + print(f" Format score: {format_score}") + + # Validate answer content + answer_score = 0 + if format_correct and answer_text: + pred_status = parse_model_answer(answer_text, expected_names) + if pred_status: + print("\n[Content Validation]") + print(f" Expected: {gt_status}") + print(f" Predicted: {pred_status}") + + if pred_status == gt_status: + answer_score = 2 + print(" Content validation: FULL MATCH") + else: + answer_score = -1.5 + print(" Content validation: MISMATCH") + else: + answer_score = -2 + print("Fail to parse answer") + else: + answer_score = -2 + print("\n[Content Validation] Skipped due to format errors or missing answer") + + total_score = format_score + answer_score + print("\n" + "-" * 80) + print(" Final Score ".center(80, "-")) + print(f" Format: {format_score}") + print(f" Answer: {answer_score}") + print(f" Total: {total_score}") + print("=" * 80 + "\n") + + return float(total_score) + + +def setup_args(): + """Setup inerance server arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8731) + parser.add_argument("--log_file", type=str, default="rm_server.log") + args = parser.parse_args() + return args + + +def server(args): + """Launch RM server.""" + app = FastAPI() + lock = threading.Lock() + + logging.basicConfig( + level=logging.INFO, + filename=args.log_file, + filemode="w", + format="%(asctime)s - %(message)s", + ) + + @app.post("/") + async def _server(request: Request) -> Response: + lock.acquire() + logging.info(f"Request: {request}") + try: + all_result = [] + if len(request.tgt) != len(request.response) or len(request.tgt) != len(request.src): + raise ValueError("The length of response, tgt, and src should be equal.") + for i in range(len(request.response)): + reward = compute_score(request.response[i], request.tgt[i], request.src[i]) + all_result.append(reward) + output = { + "error_code": 0, + "error_msg": "Success", + "score": all_result, + } + except Exception as err: + logging.error(f"Server error: when process {request}\n{traceback.format_stack()}") + output = { + "error_code": 500, + "error_msg": f"{err}", + "score": [0] * len(request.tgt), + } + logging.info(f"Response: {json.dumps(output, indent=2, ensure_ascii=False)}") + lock.release() + return output + + uvicorn.run(app, host="0.0.0.0", port=args.port) + + +if __name__ == "__main__": + args = setup_args() + server(args) diff --git a/llm/alignment/ppo/run_ppo.py 
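[Editor's note] `compute_score` in `reward_server.py` combines a structural check on the response (the tag literals appear to have been stripped in this rendering, but the surrounding comments indicate XML-style `<think>...</think>` and `<answer>...</answer>` tags) with a knights-and-knaves answer match: +1/-1 for format, and +2 / -1.5 / -2 for the answer. A hypothetical well-formed sample that would earn the maximum 3.0 under that reading (note that `compute_score` prepends the assistant header itself when it is missing; assumes `reward_server.py` is on the import path):

```python
from reward_server import compute_score

ground_truth = "Zoey is a knight\nOliver is a knave"
response = (
    "<think>If Zoey told the truth she must be a knight, which forces Oliver to lie.</think>"
    "<answer>Zoey is a knight. Oliver is a knave.</answer>"
)
print(compute_score(response, ground_truth))  # expected: 3.0 (= +1 format, +2 full answer match)
```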
b/llm/alignment/ppo/run_ppo.py index 6d9c36b31496..609985b008a8 100644 --- a/llm/alignment/ppo/run_ppo.py +++ b/llm/alignment/ppo/run_ppo.py @@ -12,220 +12,102 @@ # See the License for the specific language governing permissions and # limitations under the License. + import copy import os import sys import types -from dataclasses import dataclass, field from functools import partial -from typing import Any, Dict, Tuple import paddle -from data import PromptOnlyDataset, SupervisedDataset, parse_dataset -from models import AutoModelForScore +from comm_utils import offload_tensor_to_cpu +from data import PromptOnlyDataset, SupervisedDataset from models.score_model import LlamaModelForScore # noqa -from ppo_trainer import PPOTrainer, cleanup_tensor_space, offload_tensor_to_cpu - -from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint -from paddlenlp.transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - LlamaTokenizer, -) -from paddlenlp.utils.log import logger - - -@dataclass -class TrainingArguments(TrainingArguments): - kl_coeff: float = field( - default=0.02, - metadata={"help": "The coefficient for the KL divergence between the reference and actor policy."}, - ) - clip_range_ratio: float = field( - default=0.2, - metadata={ - "help": "The clipping range for ratio between the old and new policy. This is the epsilon parameter in the PPO algorithm." - }, - ) - clip_range_score: float = field( - default=50.0, - metadata={ - "help": "The clipping range for the output of the score model. The reward is clipped into [-clip_range_score, clip_range_score]." - }, - ) - clip_range_value: float = field( - default=5.0, - metadata={ - "help": "The clipping range for the value function. The value is clipped into [value_estimate - clip_range_value, value_estimate + clip_range_value] during training." - }, - ) - ptx_coeff: float = field( - default=0.0, - metadata={"help": "The coefficient for the ptx loss."}, - ) - update_iters: int = field( - default=1, - metadata={"help": "The number of repeated updates on a generated batch."}, - ) - critic_learning_rate: float = field( - default=None, - metadata={"help": "Initial learning rate (after the potential warmup period) for the critic model training."}, - ) - critic_weight_decay: float = field( - default=None, - metadata={"help": "Weight decay to for the critic model training."}, - ) - critic_lr_scheduler_type: str = field( - default=None, - metadata={"help": "The scheduler type for critic model."}, - ) - critic_warmup_ratio: float = field( - default=None, - metadata={"help": "Ratio of warm steps over total training steps for the critic lr scheduler."}, - ) - critic_recompute: bool = field( - default=None, - metadata={"help": "Enable gradient checkpointing for critic model."}, - ) - normalize_reward: bool = field( - default=None, - metadata={"help": "Whether to normalize the reward during RL training."}, - ) - temperature: float = field( - default=1.0, - metadata={"help": "The value used to module the next token probabilities."}, - ) - # top_k: int = field( - # default=1, - # metadata={"help": "top_k"}, - # ) - top_p: float = field( - default=0.8, - metadata={ - "help": "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to`top_p` or higher are kept for generation." 
- }, - ) - num_return_sequences: int = field( - default=1, - metadata={"help": "The number of independently computed returned sequences for each element in the batch."}, - ) - repetition_penalty: float = field( - default=1.0, - metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}, - ) - per_device_prompt_batch_size: int = field( - default=16, - metadata={"help": "Batch size (per device) for the training dataloader."}, - ) - eval_mode: str = field( - default=None, - metadata={ - "help": "eval mode for actor model and reward_critic_model, optional for: None, single, tensor_parallel." - }, - ) - - offload_level: str = field( - default="", - metadata={"help": "Offload model, optional for: eval, reward, optimizer, train_model"}, - ) - use_fusemt: bool = field( - default=True, - metadata={"help": "use inference model to speedup in rollout generation"}, - ) +from ppo_trainer import PPOTrainer +from trainer_utils import DataArgument, ModelArgument, TrainingArguments - # save_generation_output: bool = field( - # default=False, - # metadata={"help": "Whether to save generated text to file when eval"}, - # ) - - -@dataclass -class ModelArgument: - actor_model_name_or_path: str = field( - default=None, metadata={"help": "Build-in pretrained model name or the path to local model."} - ) - reward_model_name_or_path: str = field( - default=None, metadata={"help": "Build-in pretrained model name or the path to local model."} - ) - reward_critic_model_name_or_path: str = field( - default=None, metadata={"help": "Build-in pretrained model name or the path to local model."} - ) - use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"}) - - # # LoRA related parameters - # lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"}) - # lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."}) - # lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"}) - - # # prefix tuning related parameters - # prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"}) - # num_prefix_tokens: int = field(default=128, metadata={"help": "Number of prefix tokens"}) - - -@dataclass -class DataArgument: - train_datasets: str = field(default=None, metadata={"help": "Dataset name(s) registered in the raw dataset."}) - eval_datasets: str = field(default=None, metadata={"help": "Dataset name(s) registered in the raw dataset."}) - eval_split_ratio: float = field(default=None, metadata={"help": "Ratio of eval data to train data"}) - ptx_datasets: str = field(default=None, metadata={"help": "Dataset name(s) registered in the raw dataset."}) - max_length: int = field( - default=2048, - metadata={ - "help": "The maximum length that model input tokens can have. 
When intokens is set to True, it's also the maximum length for InTokens data stream" - }, - ) +from paddlenlp.trainer import PdArgumentParser, RuntimeTimer, get_last_checkpoint +from paddlenlp.trainer.trainer_utils import ShardingOption +from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from paddlenlp.trl import llm_utils +from paddlenlp.utils.log import logger - @property - def parsed_train_datasets(self) -> Tuple[str, Dict[str, Any]]: - """Parse dataset path and its proportion and optionally additional arguments from `train_datasets`.""" - return [parse_dataset(string) for string in self.train_datasets.split(",")] - @property - def parsed_eval_datasets(self) -> Tuple[str, Dict[str, Any]]: - """Parse dataset path and its proportion and optionally additional arguments from `eval_datasets`.""" - if self.eval_datasets is None: - return None - return [parse_dataset(string) for string in self.eval_datasets.split(",")] +def main(): + """ + 主函数,用于运行训练。 - @property - def parsed_ptx_datasets(self) -> Tuple[str, Dict[str, Any]]: - """Parse dataset path and its proportion and optionally additional arguments from `ptx_datasets`.""" - if self.ptx_datasets is None: - return None - return [parse_dataset(string) for string in self.ptx_datasets.split(",")] + Args: + 无参数。 + Returns: + None: 该函数没有返回值。 -def main(): + Raises: + 无异常抛出。 + """ # Arguments parser = PdArgumentParser((ModelArgument, DataArgument, TrainingArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + # 参数解析是不是改错了 + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + training_args.max_src_len = data_args.max_prompt_len + training_args.actor_model_name_or_path = model_args.actor_model_name_or_path + + if training_args.sequence_parallel: + if training_args.tensor_parallel_degree <= 1: + training_args.sequence_parallel = False + logger.info("Tensor_parallel_degree = 1. Set sequence_parallel to False.") + + if training_args.tensor_parallel_degree <= 1: + training_args.tensor_parallel_output = False + logger.info("Tensor_parallel_degree = 1. Set tensor_parallel_output to False.") + + if training_args.sharding_parallel_degree > 1: + if ( + ShardingOption.SHARD_GRAD_OP in training_args.sharding + or ShardingOption.FULL_SHARD in training_args.sharding + ): + if training_args.release_grads is True: + training_args.release_grads = False + + if training_args.unified_checkpoint and "async_save" in training_args.unified_checkpoint_config: + training_args.unified_checkpoint_config.remove("async_save") + logger.warning( + "PPO training currently does not support asynchronous saving! " + "Remove `async_save` from unified_checkpoint_config." 
+ ) + + training_args.offload_level = training_args.offload_level.split() training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") + runtime_timer = RuntimeTimer("Training") + if training_args.eval_mode is not None and len(training_args.eval_mode) == 0: training_args.eval_mode = None - if training_args.eval_mode is None and training_args.offload_level is not None: - training_args.offload_level = training_args.offload_level.replace("eval", "") + # if training_args.eval_mode is None and training_args.offload_level is not None: + # training_args.offload_level = training_args.offload_level.replace("eval", "") # Setup GPU & distributed training paddle.set_device(training_args.device) logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + f"Process rank: {training_args.local_rank}, device: {training_args.device}, " + f"world_size: {training_args.world_size}, " + f"distributed training: {bool(training_args.local_rank != -1)}, " + f"16-bits training: {training_args.fp16 or training_args.bf16}" ) # Detecting last checkpoint. last_checkpoint = None if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1: - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) + # if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1: + # raise ValueError( + # f"Output directory ({training_args.output_dir}) already exists and is not empty. " + # "Use --overwrite_output_dir to overcome." + # ) if last_checkpoint is not None and training_args.resume_from_checkpoint is None: logger.info( f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " @@ -242,9 +124,30 @@ def main(): raise ValueError("Please specific dtype: --fp16 or --bf16") else: dtype = "float32" + training_args.max_length = data_args.max_length - model_class_lm, model_class_score = AutoModelForCausalLM, AutoModelForScore + if training_args.decay_steps is None: + training_args.decay_steps = training_args.max_steps + + if training_args.use_rm_server: + if model_args.reward_server is None: + raise ValueError("Please specify reward_server when use_rm_server is true.") + logger.info(f"Use reward server: {model_args.reward_server} for training.") + if training_args.rl_algorithm == "ppo" and model_args.reward_critic_model_name_or_path is None: + raise ValueError("Please specify reward_critic_model_name_or_path when use_rm_server is true.") + else: + if model_args.reward_model_name_or_path is None: + raise ValueError("Please specify reward_model_name_or_path when use_rm_server is false.") + + if training_args.rl_algorithm != "ppo" and training_args.use_fused_head_and_loss_fn: + logger.warning( + f"Fused_head_and_loss_fn currently does not support {training_args.rl_algorithm}. " + "Reset `use_fused_head_and_loss_fn` to False." 
+ ) + training_args.use_fused_head_and_loss_fn = False + + model_class_lm, model_class_score = AutoModelForCausalLM, LlamaModelForScore if training_args.pipeline_parallel_degree > 1: from models.model_pp import LlamaPolicyPipe, LlamaValuePipe @@ -259,252 +162,340 @@ def main(): # (StepTrainer.create_criterion) to set hyper-parameters extra_args = {} + common_config = dict( + use_flash_attention=model_args.use_flash_attention, + sequence_parallel=training_args.sequence_parallel, + fused_rotary=False, + max_sequence_length=data_args.max_length, + ) + + runtime_timer.start("Actor model loading time") + # actor model - model_config = AutoConfig.from_pretrained( + actor_model_config = AutoConfig.from_pretrained( model_args.actor_model_name_or_path, - tensor_parallel_output=False, + tensor_parallel_output=training_args.tensor_parallel_output, tensor_parallel_degree=training_args.tensor_parallel_degree, tensor_parallel_rank=training_args.tensor_parallel_rank, + recompute_granularity=model_args.recompute_granularity, dtype=dtype, + recompute=training_args.recompute, + recompute_use_reentrant=training_args.recompute_use_reentrant, + **common_config, ) - if hasattr(model_config, "use_flash_attention"): - model_config.use_flash_attention = model_args.use_flash_attention - # model_config.num_hidden_layers = 2 + actor_model_config.use_fused_head_and_loss_fn = training_args.use_fused_head_and_loss_fn + actor_model_config.set_attn_func = True + actor_model_config.max_position_embeddings = data_args.max_length + actor_model_config.use_sparse_head_and_loss_fn = False + actor_model_config.fused_linear = model_args.fused_linear + print(f"Loading Actor model with config:\n\t{actor_model_config}\n") + + if not training_args.autotuner_benchmark: + actor_model = model_class_lm.from_pretrained( + model_args.actor_model_name_or_path, + config=actor_model_config, + **extra_args, + # ptx_coeff=training_args.ptx_coeff, + # clip_range_ratio=training_args.clip_range_ratio, + ) + else: + actor_model = model_class_lm.from_config( + actor_model_config, + **extra_args, + # ptx_coeff=training_args.ptx_coeff, + # clip_range_ratio=training_args.clip_range_ratio, + ) + + logger.info(f"{runtime_timer.log()}") - actor_model = model_class_lm.from_pretrained( - model_args.actor_model_name_or_path, - config=model_config, - **extra_args, - # ptx_coeff=training_args.ptx_coeff, - # clip_range_ratio=training_args.clip_range_ratio, - ) if training_args.eval_mode is not None: config = copy.deepcopy(actor_model.config) + config.use_fused_head_and_loss_fn = False if training_args.eval_mode == "single": config.tensor_parallel_degree = -1 config.tensor_parallel_rank = 0 + runtime_timer.start("Actor eval model loading time") actor_eval_model = AutoModelForCausalLM.from_config(config) + logger.info(f"{runtime_timer.log()}") # TODO(guosheng): AutoModel (in `_get_model_class_from_config`) pop out # architecture which is necessary for infer predictor currently config.architectures = actor_model.config.architectures # actor_eval_model = AutoModelForCausalLM.from_pretrained(model_args.actor_model_name_or_path, config=config) + # cleanup_tensor_space(actor_eval_model.state_dict()) else: actor_eval_model = None + runtime_timer.start("Actor reference model loading time") # todo reference model if training_args.eval_mode is not None: - config = copy.deepcopy(model_config) + config = copy.deepcopy(actor_model_config) + config.use_fused_head_and_loss_fn = False if training_args.eval_mode == "single": config.tensor_parallel_degree = -1 
config.tensor_parallel_rank = 0 - actor_reference_model = AutoModelForCausalLM.from_pretrained( - model_args.actor_model_name_or_path, - config=config, - ) + if not training_args.autotuner_benchmark: + actor_reference_model = AutoModelForCausalLM.from_pretrained( + model_args.actor_model_name_or_path, + config=config, + ) + else: + actor_reference_model = AutoModelForCausalLM.from_config( + config, + dtype=dtype, + ) else: - actor_reference_model = model_class_lm.from_pretrained( - model_args.actor_model_name_or_path, - config=model_config, + actor_reference_model = model_class_lm.from_config( + actor_model_config, + dtype=dtype, ) + if not training_args.autotuner_benchmark: + actor_reference_model.set_state_dict(actor_model.state_dict()) + logger.info(f"{runtime_timer.log()}") actor_tokenizer = AutoTokenizer.from_pretrained( - model_args.actor_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" + model_args.actor_model_name_or_path, + model_max_length=data_args.max_length, + padding_side="left", + tokenizer_alpha=model_args.actor_tokenizer_alpha, ) + llm_utils.init_chat_template(actor_tokenizer, model_args.actor_model_name_or_path, model_args.chat_template) - # reward model - model_config = AutoConfig.from_pretrained( - model_args.reward_model_name_or_path, - tensor_parallel_output=False, - tensor_parallel_degree=training_args.tensor_parallel_degree, - tensor_parallel_rank=training_args.tensor_parallel_rank, - dtype=dtype, - ) - if hasattr(model_config, "use_flash_attention"): - model_config.use_flash_attention = model_args.use_flash_attention - # model_config.num_hidden_layers = 2 - # todo - if training_args.eval_mode is not None: - config = copy.deepcopy(model_config) - if training_args.eval_mode == "single": - config.tensor_parallel_degree = -1 - config.tensor_parallel_rank = 0 - reward_model = AutoModelForScore.from_pretrained( + training_args.autotuner_benchmark = True + if not training_args.use_rm_server and model_args.reward_model_name_or_path is not None: + runtime_timer.start("Reward model loading time") + # reward model + reward_model_config = AutoConfig.from_pretrained( model_args.reward_model_name_or_path, - config=config, - score_type="reward", - do_normalize=training_args.normalize_reward, + tensor_parallel_output=False, + tensor_parallel_degree=training_args.tensor_parallel_degree, + tensor_parallel_rank=training_args.tensor_parallel_rank, + dtype=dtype, + recompute=training_args.critic_recompute, + recompute_granularity=model_args.critic_recompute_granularity, + recompute_use_reentrant=training_args.recompute_use_reentrant, + **common_config, ) - else: - reward_model = model_class_score.from_pretrained( + reward_model_config.num_hidden_layers = 2 + reward_model_config.max_position_embeddings = data_args.max_length + reward_model_config.use_sparse_head_and_loss_fn = False + reward_model_config.fused_linear = model_args.fused_linear + print(f"Loading Reward model with config:\n\t{reward_model_config}\n") + + if training_args.eval_mode is not None: + config = copy.deepcopy(reward_model_config) + if training_args.eval_mode == "single": + config.tensor_parallel_degree = -1 + config.tensor_parallel_rank = 0 + if not training_args.autotuner_benchmark: + reward_model = LlamaModelForScore.from_pretrained( + model_args.reward_model_name_or_path, + config=config, + score_type="reward", + do_normalize=False, + ) + else: + reward_model = LlamaModelForScore.from_config( + config, + score_type="reward", + do_normalize=False, + ) + else: + if not 
training_args.autotuner_benchmark: + reward_model = model_class_score.from_pretrained( + model_args.reward_model_name_or_path, + config=reward_model_config, + score_type="reward", + do_normalize=False, + ) + else: + reward_model = model_class_score.from_config( + reward_model_config, + score_type="reward", + do_normalize=False, + ) + + logger.info(f"{runtime_timer.log()}") + reward_tokenizer = AutoTokenizer.from_pretrained( model_args.reward_model_name_or_path, - config=model_config, - score_type="reward", - do_normalize=training_args.normalize_reward, + model_max_length=data_args.max_length, + padding_side="right", + tokenizer_alpha=model_args.reward_tokenizer_alpha, ) - reward_tokenizer = AutoTokenizer.from_pretrained( - model_args.reward_model_name_or_path, model_max_length=data_args.max_length, padding_side="right" - ) - # critic model - if model_args.reward_critic_model_name_or_path is None: - model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path - reward_critic_model = model_class_score.from_pretrained( - model_args.reward_critic_model_name_or_path, - config=model_config, - score_type="critic", - do_normalize=False, - clip_range_value=training_args.clip_range_value, - ) - reward_critic_tokenizer = AutoTokenizer.from_pretrained( - model_args.reward_critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" - ) - if training_args.eval_mode is not None: - config = copy.deepcopy(reward_critic_model.config) - if training_args.eval_mode == "single": - config.tensor_parallel_degree = -1 - config.tensor_parallel_rank = 0 - reward_critic_eval_model = AutoModelForScore.from_config(config) - # reward_critic_eval_model = AutoModelForScore.from_pretrained( - # model_args.reward_critic_model_name_or_path,config=model_config - # ) + llm_utils.init_chat_template(reward_tokenizer, model_args.reward_model_name_or_path, model_args.chat_template) else: - reward_critic_eval_model = None - - # # actor model - # model_config = AutoConfig.from_pretrained( - # model_args.actor_model_name_or_path, - # tensor_parallel_output=False, - # tensor_parallel_degree=training_args.tensor_parallel_degree, - # tensor_parallel_rank=training_args.tensor_parallel_rank, - # dtype=dtype, - # ) - # model_config.num_hidden_layers = 2 - # if hasattr(model_config, "use_flash_attention"): - # model_config.use_flash_attention = model_args.use_flash_attention - # actor_model = AutoModelForCausalLM.from_pretrained( - # model_args.actor_model_name_or_path, - # config=model_config, - # ) - # - # if training_args.eval_mode is not None: - # config = copy.deepcopy(actor_model.config) - # if training_args.eval_mode == "single": - # config.tensor_parallel_degree = -1 - # config.tensor_parallel_rank = 0 - # actor_eval_model = AutoModelForCausalLM.from_config(config) - # else: - # actor_eval_model = None - # - # # reference model - # actor_reference_model = AutoModelForCausalLM.from_pretrained( - # model_args.actor_model_name_or_path, - # config=model_config, - # ) - # actor_tokenizer = AutoTokenizer.from_pretrained( - # model_args.actor_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" - # ) - # - # # reward model - # model_config = AutoConfig.from_pretrained( - # model_args.reward_model_name_or_path, - # tensor_parallel_output=False, - # tensor_parallel_degree=training_args.tensor_parallel_degree, - # tensor_parallel_rank=training_args.tensor_parallel_rank, - # dtype=dtype, - # ) - # model_config.num_hidden_layers = 2 - # if hasattr(model_config, 
"use_flash_attention"): - # model_config.use_flash_attention = model_args.use_flash_attention - # reward_model = AutoModelForScore.from_pretrained( - # model_args.reward_model_name_or_path, - # config=model_config, - # score_type="reward", - # do_normalize=training_args.normalize_reward, - # ) - # reward_tokenizer = AutoTokenizer.from_pretrained( - # model_args.reward_model_name_or_path, model_max_length=data_args.max_length, padding_side="right" - # ) - # - # # critic model - # if model_args.reward_critic_model_name_or_path is None: - # model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path - # reward_critic_model = AutoModelForScore.from_pretrained( - # model_args.reward_critic_model_name_or_path, config=model_config, score_type="critic", do_normalize=False - # ) - # reward_critic_tokenizer = AutoTokenizer.from_pretrained( - # model_args.reward_critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left" - # ) - # - # if training_args.eval_mode is not None: - # config = copy.deepcopy(reward_critic_model.config) - # if training_args.eval_mode == "single": - # config.tensor_parallel_degree = -1 - # config.tensor_parallel_rank = 0 - # reward_critic_eval_model = AutoModelForScore.from_config(config) - # else: - # reward_critic_eval_model = None - - for tokenizer in [actor_tokenizer, reward_tokenizer, reward_critic_tokenizer]: - if isinstance(tokenizer, LlamaTokenizer) and tokenizer.pad_token_id is None: + reward_tokenizer = actor_tokenizer + reward_model = model_args.reward_server + if training_args.rl_algorithm == "ppo": + # critic model + runtime_timer.start("Reward critic model loading time") + if model_args.reward_critic_model_name_or_path is None: + model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path + reward_critic_model = model_class_score.from_config( + reward_model_config, + dtype=dtype, + score_type="critic", + do_normalize=False, + clip_range_value=training_args.clip_range_value, + ) + if not training_args.autotuner_benchmark: + reward_critic_model.set_state_dict(reward_model.state_dict()) + else: + if not training_args.autotuner_benchmark: + reward_critic_model = model_class_score.from_pretrained( + model_args.reward_critic_model_name_or_path, + config=reward_model_config, + score_type="critic", + do_normalize=False, + clip_range_value=training_args.clip_range_value, + ) + else: + reward_critic_model = model_class_score.from_config( + reward_model_config, + score_type="critic", + do_normalize=False, + clip_range_value=training_args.clip_range_value, + ) + logger.info(f"{runtime_timer.log()}") + reward_critic_tokenizer = AutoTokenizer.from_pretrained( + model_args.reward_critic_model_name_or_path, + model_max_length=data_args.max_length, + padding_side="left", + tokenizer_alpha=model_args.reward_critic_tokenizer_alpha, + ) + llm_utils.init_chat_template( + reward_critic_tokenizer, model_args.reward_critic_model_name_or_path, model_args.chat_template + ) + if training_args.eval_mode is not None: + config = copy.deepcopy(reward_critic_model.config) + if training_args.eval_mode == "single": + config.tensor_parallel_degree = -1 + config.tensor_parallel_rank = 0 + runtime_timer.start("Reward critic eval model loading time") + reward_critic_eval_model = LlamaModelForScore.from_config(config) + logger.info(f"{runtime_timer.log()}") + # reward_critic_eval_model = AutoModelForScore.from_pretrained( + # model_args.reward_critic_model_name_or_path,config=model_config + # ) + # 
cleanup_tensor_space(reward_critic_eval_model.state_dict()) + else: + reward_critic_eval_model = None + + for tokenizer in [ + actor_tokenizer, + reward_tokenizer, + reward_critic_tokenizer if training_args.rl_algorithm == "ppo" else None, + ]: + if tokenizer and tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id - train_ds = PromptOnlyDataset(data_args.parsed_train_datasets, tokenizer=actor_tokenizer) - if data_args.eval_datasets is None and data_args.eval_split_ratio: - train_ds, dev_ds = train_ds.split_train_test(split_ratio=data_args.eval_split_ratio) - elif data_args.eval_datasets is not None: - dev_ds = PromptOnlyDataset(data_args.parsed_eval_datasets, tokenizer=actor_tokenizer) - else: - dev_ds = None + if training_args.should_load_dataset: + train_ds = PromptOnlyDataset( + data_args.parsed_train_datasets, tokenizer=actor_tokenizer, use_rm_server=training_args.use_rm_server + ) + if data_args.eval_datasets is None and data_args.eval_split_ratio: + train_ds, dev_ds = train_ds.split_train_test(split_ratio=data_args.eval_split_ratio) + elif data_args.eval_datasets is not None: + dev_ds = PromptOnlyDataset( + data_args.parsed_eval_datasets, tokenizer=actor_tokenizer, use_rm_server=training_args.use_rm_server + ) + else: + dev_ds = None - ptx_ds = ( - SupervisedDataset(data_args.parsed_ptx_datasets, tokenizer=actor_tokenizer) - if data_args.ptx_datasets is not None - else None - ) - if ptx_ds is not None: - # PretrainingCriterion requires shifted inputs and labels - ptx_ds.get_collator = types.MethodType(partial(ptx_ds.get_collator.__func__, shift=True), ptx_ds) - - # offload - # cleanup actor_eval_model, reward_critic_eval_model - # offload actor_reference_model reward_model - - if training_args.offload_level is not None: - if "eval" in training_args.offload_level: - cleanup_tensor_space(actor_eval_model.state_dict()) - cleanup_tensor_space(reward_critic_eval_model.state_dict()) - if "reward" in training_args.offload_level: - # if pp mode, should lazy offload - offload_tensor_to_cpu(actor_reference_model.state_dict()) - offload_tensor_to_cpu(reward_model.state_dict()) + ptx_ds = ( + SupervisedDataset(data_args.parsed_ptx_datasets, tokenizer=actor_tokenizer) + if data_args.ptx_datasets is not None + else None + ) + if ptx_ds is not None: + # PretrainingCriterion requires shifted inputs and labels + ptx_ds.get_collator = types.MethodType(partial(ptx_ds.get_collator.__func__, shift=True), ptx_ds) + + if "freeze_model" in training_args.offload_level: + offload_tensor_to_cpu((actor_reference_model, "freeze_model")) + if training_args.rl_algorithm == "ppo": + offload_tensor_to_cpu((reward_model, "freeze_model")) + if actor_eval_model is not None: + offload_tensor_to_cpu((actor_eval_model, "freeze_model")) + if training_args.rl_algorithm == "ppo" and reward_critic_eval_model is not None: + offload_tensor_to_cpu((reward_critic_eval_model, "freeze_model")) + # NOTE(gongenlei): release memory_reserved_size to equal to memory_allocated_size + paddle.device.cuda.empty_cache() trainer = PPOTrainer( # (policy_model, reference_model, reward_model, value_model) # policy_model, sft_model, reward_model, value_model # (policy_model, reference_model, reward_model, value_model, # (policy_model, reference_model, reward_model, value_model, policy_eval_model, value_eval_model - # (actor_model, actor_reference_model, reward_model, reward_critic_model, actor_eval_model, reward_critic_eval_model + # (actor_model, actor_reference_model, reward_model, reward_critic_model, 
actor_eval_model, + # reward_critic_eval_model model=( actor_model, actor_reference_model, reward_model, - reward_critic_model, + reward_critic_model if training_args.rl_algorithm == "ppo" else None, actor_eval_model, - reward_critic_eval_model, + reward_critic_eval_model if training_args.rl_algorithm == "ppo" else None, ), args=training_args, - train_dataset=train_ds, - eval_dataset=dev_ds, + train_dataset=(train_ds if training_args.do_train and training_args.should_load_dataset else None), + eval_dataset=(dev_ds if training_args.do_eval and training_args.should_load_dataset else None), ptx_dataset=ptx_ds, - tokenizer=(actor_tokenizer, actor_tokenizer, reward_tokenizer, reward_critic_tokenizer), + tokenizer=( + actor_tokenizer, + actor_tokenizer, + reward_tokenizer, + reward_critic_tokenizer if training_args.rl_algorithm == "ppo" else None, + ), data_collator=train_ds.get_collator(), ) + + # TODO(gongenlei) resume_from_checkpoint is not ready checkpoint = None if training_args.resume_from_checkpoint is not None: checkpoint = training_args.resume_from_checkpoint elif last_checkpoint is not None: checkpoint = last_checkpoint - trainer.train(resume_from_checkpoint=checkpoint) + + # The early-stopping callback. + if training_args.early_stopping: + from paddlenlp.trainer import EarlyStoppingCallback + + early_stopping_info = ( + f"Early stopping is enabled, " + f"patience={training_args.early_stopping_patience}, " + f"threshold={training_args.early_stopping_threshold}, " + f"metric={training_args.metric_for_best_model}, " + f"greater_is_better={training_args.greater_is_better}" + ) + logger.info(early_stopping_info) + trainer.add_callback( + EarlyStoppingCallback( + early_stopping_patience=training_args.early_stopping_patience, + early_stopping_threshold=training_args.early_stopping_threshold, + ) + ) + + # if training_args.hidden_dropout_prob or training_args.attention_probs_dropout_prob: + # trainer.add_callback(LayerwiseDropoutCallback()) + + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + if not training_args.autotuner_benchmark: + runtime_timer.start("Model saving time") + trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + logger.info(f"{runtime_timer.log()}") + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + if training_args.do_eval: + eval_result = trainer.evaluate() + trainer.log_metrics("eval", eval_result) + # NOTE(gongenlei): set combined=False to avoid overwriting errors on AFS + trainer.save_metrics("eval", eval_result, combined=False) if __name__ == "__main__": diff --git a/llm/alignment/ppo/trainer_utils.py b/llm/alignment/ppo/trainer_utils.py index e10a339851fe..fd7636fa351b 100644 --- a/llm/alignment/ppo/trainer_utils.py +++ b/llm/alignment/ppo/trainer_utils.py @@ -18,15 +18,19 @@ import os import time from contextlib import contextmanager -from typing import Dict +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Tuple import numpy as np import paddle import tqdm +from data import parse_dataset +from models.ppo_model_utils import make_attention_mask, make_position_ids from paddle.distributed import fleet from paddle.io import DataLoader from paddlenlp.generation.utils import GenerationMixin +from paddlenlp.trainer import IntervalStrategy from paddlenlp.trainer.trainer import ( TRAINER_STATE_NAME, 
HybridParallelOptimizer, @@ -48,16 +52,383 @@ from paddlenlp.transformers import BatchEncoding, PretrainedModel, PretrainedTokenizer from paddlenlp.transformers.configuration_utils import PretrainedConfig from paddlenlp.transformers.model_outputs import ModelOutput -from paddlenlp.transformers.tokenizer_utils_base import ( - PaddingStrategy, - TruncationStrategy, -) +from paddlenlp.transformers.tokenizer_utils_base import PaddingStrategy + + +@dataclass +class TrainingArguments(TrainingArguments): + kl_coeff: float = field( + default=0.02, + metadata={"help": "The coefficient for the KL divergence between the reference and actor policy."}, + ) + kl_loss_coeff: float = field( + default=0.001, + metadata={"help": "The coefficient for the KL loss for GRPO."}, + ) + clip_range_ratio: float = field( + default=0.2, + metadata={ + "help": "The clipping range for ratio between the old and new policy. " + "This is the epsilon parameter in the PPO algorithm." + }, + ) + clip_range_score: float = field( + default=50.0, + metadata={ + "help": "The clipping range for the output of the score model. " + "The reward is clipped into [-clip_range_score, clip_range_score]." + }, + ) + clip_range_value: float = field( + default=5.0, + metadata={ + "help": "The clipping range for the value function. The value is clipped into [value_estimate - " + "clip_range_value, value_estimate + clip_range_value] during training." + }, + ) + ptx_coeff: float = field( + default=0.0, + metadata={"help": "The coefficient for the ptx loss."}, + ) + update_iters: int = field( + default=1, + metadata={"help": "The number of repeated updates on a generated batch."}, + ) + critic_learning_rate: float = field( + default=None, + metadata={"help": "Initial learning rate (after the potential warmup period) for the critic model training."}, + ) + critic_weight_decay: float = field( + default=None, + metadata={"help": "Weight decay to for the critic model training."}, + ) + critic_lr_scheduler_type: str = field( + default=None, + metadata={"help": "The scheduler type for critic model."}, + ) + critic_warmup_ratio: float = field( + default=None, + metadata={"help": "Ratio of warm steps over total training steps for the critic lr scheduler."}, + ) + critic_recompute: bool = field( + default=None, + metadata={"help": "Enable gradient checkpointing for critic model."}, + ) + normalize_reward: bool = field( + default=None, + metadata={"help": "Whether to normalize the reward during RL training."}, + ) + normalize_advantage: bool = field( + default=None, + metadata={"help": "Whether to normalize the advantage during RL training."}, + ) + temperature: float = field( + default=1.0, + metadata={"help": "The value used to module the next token probabilities."}, + ) + top_p: float = field( + default=1.0, + metadata={ + "help": "If set to float < 1, only the smallest set of most probable tokens " + "with probabilities that add up to`top_p` or higher are kept for generation." + }, + ) + num_return_sequences: int = field( + default=1, + metadata={"help": "The number of independently computed returned sequences for each element in the batch."}, + ) + repetition_penalty: float = field( + default=1.0, + metadata={"help": "The parameter for repetition penalty. 
1.0 means no penalty."}, + ) + per_device_prompt_batch_size: int = field( + default=16, + metadata={"help": "Batch size (per device) for the prompt-only (rollout) dataloader."}, + ) + eval_mode: str = field( + default=None, + metadata={ + "help": "Eval mode for the actor model and reward_critic_model, optional values: None, single, tensor_parallel." + }, + ) + + offload_level: str = field( + default="", + metadata={"help": "Offload model, optional values: freeze_model, optimizer, train_model"}, + ) + + max_dec_len: int = field(default=512, metadata={"help": "Maximum output length."}) + + min_dec_len: int = field(default=1, metadata={"help": "Minimum output length."}) + + max_src_len: int = field(default=3072, metadata={"help": "Maximum length of src."}) + + eos_token: str = field( + default="", + metadata={"help": "Used as the eos_token if set to a non-empty string."}, + ) + + use_fusemt: bool = field( + default=True, + metadata={"help": "Use the fused inference model to speed up rollout generation."}, + ) + + recompute_use_reentrant: bool = field( + default=True, + metadata={"help": "Whether to use the reentrant implementation for recompute."}, + ) + + critic_min_learning_rate: float = field( + default=None, + metadata={"help": "Minimum learning rate decayed to for the critic model."}, + ) + + critic_decay_steps: int = field( + default=None, + metadata={ + "help": "The steps used to control the learning rate for the critic model. If the step > decay_steps, " + "will use the min_learning_rate." + }, + ) + + min_learning_rate: float = field( + default=None, + metadata={"help": "Minimum learning rate decayed to."}, + ) + + decay_steps: int = field( + default=None, + metadata={ + "help": "The steps used to control the learning rate. If the step > decay_steps, " + "will use the min_learning_rate." + }, + ) + unified_checkpoint: bool = field( + default=True, + metadata={ + "help": "Whether to unify hybrid parallel checkpoint saving and loading." + }, + ) + unified_checkpoint_config: Optional[str] = field( + default="", + metadata={ + "help": ( + "Configs to unify hybrid parallel checkpoint.\n" + "The following options are supported:\n" + "- skip_save_model_weight: do not save model weights when the master weights exist\n" + "- master_weight_compatible: 1. if the master weights exist, only load them when needed\n" + " 2. if the master weights do not exist, convert model weights" + " to master weights when needed\n" + "- async_save: enable asynchronous saving of checkpoints to disk\n" + "- enable_all_options: enable all optimization configurations\n" + ) + }, + ) + autotuner_benchmark: bool = field( + default=False, + metadata={"help": "Whether to run the benchmark by autotuner; 
if True, training runs from scratch."}, + ) + early_stopping: bool = field( + default=False, + metadata={"help": "Whether to apply the early stopping strategy."}, + ) + early_stopping_patience: int = field( + default=4, + metadata={ + "help": "Stop training when the specified metric " "worsens for early_stopping_patience evaluation calls." + }, + ) + early_stopping_threshold: float = field( + default=0.0, + metadata={"help": "How much the specified metric must improve to satisfy early stopping conditions."}, + ) + use_fused_head_and_loss_fn: bool = field( + default=False, + metadata={"help": "Whether to use fused_head_and_loss_fn."}, + ) + tensor_parallel_output: bool = field( + default=True, + metadata={"help": "Whether to use tensor_parallel_output."}, + ) + per_device_rollout_batch_size: int = field( + default=-1, + metadata={"help": "Batch size per GPU core/CPU for rollout."}, + ) + # save_generation_output: bool = field( + # default=False, + # metadata={"help": "Whether to save generated text to file when eval"}, + # ) + dropout_warmup_steps: int = field( + default=0, + metadata={"help": "Dropout warmup steps."}, + ) + hidden_dropout_prob: float = field( + default=0.0, + metadata={"help": "Dropout probability for hidden layers."}, + ) + attention_probs_dropout_prob: float = field( + default=0.0, + metadata={"help": "Dropout probability for attention layers."}, + ) + rl_algorithm: str = field( + default="ppo", + metadata={"help": "RL algorithm (supports PPO and GRPO)."}, + ) + use_tgt_len_value: bool = field( + default=False, + metadata={"help": "Whether to use tgt for KL."}, + ) + use_rm_server: bool = field(default=False, metadata={"help": "Use reward server instead of reward model."}) + + def __post_init__(self): + """ + Post-initialization hook used to set defaults and validate arguments. + If autotuner_benchmark is True, the related arguments are reset to benchmark defaults and all other behaviors are disabled. + + Args: + None. + + Returns: + None. + + Raises: + None. 
+ """ + super().__post_init__() + if self.autotuner_benchmark: + self.num_train_epochs = 1 + self.max_steps = 5 + self.do_train = True + self.do_export = False + self.do_predict = False + self.do_eval = False + self.overwrite_output_dir = True + self.load_best_model_at_end = False + self.report_to = [] + self.save_strategy = IntervalStrategy.NO + self.evaluation_strategy = IntervalStrategy.NO + self.per_device_prompt_batch_size = self.per_device_train_batch_size + self.min_dec_len = self.max_dec_len + # self.skip_profile_timer = False + + if not self.disable_tqdm: + self.logging_steps = 1 + self.logging_strategy = IntervalStrategy.STEPS + if self.per_device_rollout_batch_size < 0: + self.per_device_rollout_batch_size = self.per_device_train_batch_size + assert self.rl_algorithm in ["ppo", "grpo"], 'self.rl_algorithm should be one of ["ppo", "grpo"]' + if self.rl_algorithm == "grpo": + self.normalize_reward = False + self.normalize_advantage = False + + +@dataclass +class ModelArgument: + actor_model_name_or_path: str = field( + default=None, + metadata={"help": "Build-in pretrained model name or the path to local model."}, + ) + reward_model_name_or_path: str = field( + default=None, + metadata={"help": "Build-in pretrained model name or the path to local model."}, + ) + reward_server: str = field( + default=None, + metadata={"help": "Reward server address."}, + ) + reward_critic_model_name_or_path: str = field( + default=None, + metadata={"help": "Build-in pretrained model name or the path to local model."}, + ) + actor_tokenizer_alpha: float = field(default=None, metadata={"help": "Tokenizer will tokenize randomly"}) + reward_tokenizer_alpha: float = field(default=None, metadata={"help": "Tokenizer will tokenize randomly"}) + reward_critic_tokenizer_alpha: float = field(default=None, metadata={"help": "Tokenizer will tokenize randomly"}) + use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"}) + use_attn_mask_start_row_indices: bool = field(default=False, metadata={"help": "Should in data args"}) + stage: str = field(default="PPO", metadata={"help": "The type of training."}) + fused_linear: bool = field(default=True, metadata={"help": "Whether to use fused_gemm_epilogue"}) + recompute_granularity: str = field( + default="full", + metadata={ + "help": "The granularity of recompute in policy model, " + "can be selected as `full` or `full_attn` or `core_attn`. " + }, + ) + critic_recompute_granularity: str = field( + default="full", + metadata={ + "help": "The granularity of recompute in critic model, " + "can be selected as `full` or `full_attn` or `core_attn`. " + }, + ) + chat_template: str = field( + default="none", + metadata={ + "help": "the path of `chat_template.json` file to handle multi-rounds conversation. " + "If is None(do not set --chat_template argument), it will use the default `chat_template.json`;" + "If is equal with `model_name_or_path`, it will use the default loading; " + "If is directory, it will find the `chat_template.json` under the directory; If is file, it will load it." + "If is none string, it will not use chat_template.json." 
+ }, + ) + + +@dataclass +class DataArgument: + train_datasets: str = field(default=None, metadata={"help": "Dataset name(s) registered in the raw dataset."}) + eval_datasets: str = field(default=None, metadata={"help": "Dataset name(s) registered in the raw dataset."}) + eval_split_ratio: float = field(default=None, metadata={"help": "Ratio of eval data to train data."}) + ptx_datasets: str = field(default=None, metadata={"help": "Dataset name(s) registered in the raw dataset."}) + max_length: int = field( + default=2048, + metadata={ + "help": "The maximum length that model input tokens can have. When intokens is set to True, it's also the maximum length for the InTokens data stream." + }, + ) + max_prompt_len: int = field(default=4096, metadata={"help": "Maximum prompt length."}) + + @property + def parsed_train_datasets(self) -> Tuple[str, Dict[str, Any]]: + """Parse dataset path and its proportion and optionally additional arguments from `train_datasets`.""" + return [parse_dataset(string) for string in self.train_datasets.split(",")] + + @property + def parsed_eval_datasets(self) -> Tuple[str, Dict[str, Any]]: + """Parse dataset path and its proportion and optionally additional arguments from `eval_datasets`.""" + if self.eval_datasets is None: + return None + return [parse_dataset(string) for string in self.eval_datasets.split(",")] + + @property + def parsed_ptx_datasets(self) -> Tuple[str, Dict[str, Any]]: + """Parse dataset path and its proportion and optionally additional arguments from `ptx_datasets`.""" + if self.ptx_datasets is None: + return None + return [parse_dataset(string) for string in self.ptx_datasets.split(",")] # ########## patches for Trainer ########## def init_train_model_opt( - self: Trainer, max_steps: int, resume_from_checkpoint: bool = False, clear_master_weight: bool = False + self: Trainer, + max_steps: int, + resume_from_checkpoint: bool = False, + clear_master_weight: bool = False, ) -> PretrainedModel: + """ + Initialize the training model and optimizer, and return the wrapped model. + + Args: + self (Trainer): The Trainer instance. + max_steps (int): Maximum number of training steps. + resume_from_checkpoint (bool, optional): Whether to resume training from a checkpoint. + Defaults to False. + clear_master_weight (bool, optional): Clear the master weights when using the Trainer's distributed acceleration. + Defaults to False. + + Returns: + PretrainedModel: The wrapped model. + """ # Copy of model/optimizer init and resuming related code in `Trainer.train`. # NOTE: this `_load_from_checkpoint` is indeed to load model states in the # following elif-else branches, though they are far apart in `Trainer.train`. @@ -77,7 +448,8 @@ def init_train_model_opt( model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint) elif self.args.should_save_sharding_stage1_model: # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model. - # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks. + # In this mode, rank0 loads all params and _wrap_model implicitly broadcasts + # params from rank0 to the other ranks. 
model = self._wrap_model(self.model_wrapped) if self.sharding_io is not None: assert delay_optimizer_creation is False, "delay_optimizer_creation should be False" @@ -119,6 +491,23 @@ def init_train_state( num_train_epochs: int, num_update_steps_per_epoch: int, ): + """ + Initialize the training state. + + Args: + self (Trainer): The Trainer instance used to record the training state. + resume_from_checkpoint (bool, optional): Whether to resume training from a checkpoint. Defaults to None. + train_dataloader (DataLoader, optional): The training dataloader. Defaults to None. + max_steps (int, optional): Maximum number of training steps. Defaults to -1. + num_train_epochs (int, optional): Maximum number of training epochs. Defaults to 3. + num_update_steps_per_epoch (int, optional): Number of model update steps per epoch. Defaults to 1. + + Returns: + Tuple[int, int, Optional[tqdm]]: + - epochs_trained (int): Number of epochs already trained. + - steps_trained_in_current_epoch (int): Number of batches already trained in the current epoch when data skipping is not ignored; otherwise 0. + - steps_trained_progress_bar (Optional[tqdm]): A tqdm progress bar shown while skipping the first batches when data skipping is not ignored; otherwise None. + """ args = self.args self.state = TrainerState() @@ -171,7 +560,11 @@ def init_train_state( self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_world_process_zero = self.is_world_process_zero() - return epochs_trained, steps_trained_in_current_epoch, steps_trained_progress_bar + return ( + epochs_trained, + steps_trained_in_current_epoch, + steps_trained_progress_bar, + ) def init_train_log( @@ -183,6 +576,21 @@ def init_train_log( num_train_samples: int, model: PretrainedModel, ): + """ + Initialize the training log. + + Args: + self (Trainer): The Trainer instance holding the arguments and information needed for training. + num_examples (int): Total number of examples in the training set. + num_train_epochs (int): Number of training epochs. + total_train_batch_size (int): Total training batch size summed over devices. + max_steps (int): Maximum number of training steps. + num_train_samples (int): Total number of training samples. + model (PretrainedModel): The model being trained. + + Returns: + None. + """ args = self.args logger.info("***** Running training *****") @@ -196,7 +604,7 @@ def init_train_log( # per_device_trainable_numel = sum(p.numel().item() for p in model.parameters() if not p.stop_gradient) # TODO: Temporary fix since Tensor.numel() not supported in distributed mode per_device_trainable_numel = sum(np.prod(p.shape) for p in model.parameters() if not p.stop_gradient) - logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") + logger.debug(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)") if self.args.use_hybrid_parallel: # todo fix for pipeline_parallel_degree parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1) @@ -210,7 +618,7 @@ def init_train_log( trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size # the numel is rough, because each tensor parallel rank still holds its own bias or layer_norm weight without splitting # so, the trainable numel is a little bigger than the real value. 
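As the comments above note, `Tensor.numel()` is avoided in distributed mode and the per-device count is derived from parameter shapes instead. A tiny standalone illustration of that counting approach (not code from this patch):

```python
import numpy as np
import paddle

layer = paddle.nn.Linear(8, 4)  # weight: [8, 4], bias: [4]
trainable_numel = sum(int(np.prod(p.shape)) for p in layer.parameters() if not p.stop_gradient)
print(f"Number of trainable parameters = {trainable_numel:,}")  # 8 * 4 + 4 = 36
```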
- logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") + logger.debug(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)") def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs): @@ -236,6 +644,7 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs steps_trained_progress_bar = kwargs.get("steps_trained_progress_bar", None) # for eval output ignore to gather ignore_keys_for_eval = kwargs.get("ignore_keys_for_eval", None) + # timer_name = kwargs.get("timer_name", "") tr_loss = kwargs.get("tr_loss", 0.0) model = kwargs.get("model", self.model_wrapped) # needed in _maybe_log_save_evaluate @@ -281,7 +690,7 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs if step_control % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - self.timers and self.timers("forward-backward").start() + # self.timers and self.timers(f"{timer_name}: forward-backward").start() dp_enabled = self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1 forbidden_no_sync = False @@ -322,13 +731,14 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss(): tr_loss /= self.args.gradient_accumulation_steps - self.timers and self.timers("forward-backward").stop() + # self.timers and self.timers(f"{timer_name}: forward-backward").stop() + # Maunally collect gradients # Case 1: Use recompute and dp # Case 2: Hack dp with master_grad # Case 3: Pipeline or sharding overlap # local_rank != -1 don't means dp in networks. - self.timers and self.timers("all-reduce").start() + # self.timers and self.timers(f"{timer_name}: all-reduce").start() # Case 1: Use recompute and dp / sharding stage1, # manualy collect gradient for dp. 
@@ -340,11 +750,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs fused_allreduce_gradients(list(model.parameters()), None) # Pipeline parallel mode, handle gradient reduce here to overlap - enable_dp_comm_overlap = False - enable_release_grads = False - if args.pipeline_parallel_degree > 1: - enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config - enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config + pipeline_parallel_config = ( + set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set() + ) + enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config + enable_release_grads = "enable_release_grads" in pipeline_parallel_config # Case 3: Pipeline parallel mode, overlap with dp if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling: @@ -358,8 +768,8 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) - self.timers and self.timers("all-reduce").stop() - self.timers and self.timers("optimizer-step").start() + # self.timers and self.timers(f"{timer_name}: all-reduce").stop() + # self.timers and self.timers(f"{timer_name}: optimizer-step").start() if self.args.gradient_accumulation_steps > 1 and self._enable_delay_scale_loss(): for p in model._layers.parameters(): @@ -372,7 +782,10 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs # Optimizer step self.callback_handler.on_optimizer_begin( - args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None + args, + self.state, + self.control, + scaler=self.scaler if self.do_grad_scaling else None, ) optimizer_was_run = True if self.do_grad_scaling: @@ -396,7 +809,7 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs else: self.optimizer.step() - self.timers and self.timers("optimizer-step").stop() + # self.timers and self.timers(f"{timer_name}: optimizer-step").stop() if optimizer_was_run: self.lr_scheduler.step() @@ -410,7 +823,10 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs self.optimizer.clear_grad(set_to_zero=False) self.callback_handler.on_optimizer_end( - args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None + args, + self.state, + self.control, + scaler=self.scaler if self.do_grad_scaling else None, ) self.state.global_step += 1 @@ -452,7 +868,28 @@ class MuteDefaultFlowCallback(TrainerCallback): Use this when having multi trainer. 
""" - def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """ + 在一个步骤结束时调用,可以用来更新控制流程。 + + Args: + args (TrainingArguments): 训练参数对象。 + state (TrainerState): 训练器状态对象。 + control (TrainerControl): 训练控制对象,包含了训练过程中的控制信息,如是否保存模型、是否进行评估和是否记录日志等。 + kwargs (dict, optional): 其他关键字参数,默认为None,没有使用。 + + Returns: + TrainerControl: 返回一个TrainerControl对象,包含了训练过程中的控制信息,如是否保存模型、是否进行评估和是否记录日志等。 + + Raises: + None + """ control.should_save = False control.should_evaluate = False control.should_log = False @@ -461,6 +898,24 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra @contextmanager def guard_set_args(args, arg_name_values): + """ + 在一个上下文中,设置给定的参数名称和值,并在上下文结束后将其还原。 + + Args: + args (object): 需要修改参数的对象,通常是命令行解析器的实例。 + arg_name_values (dict[str, Any]): 包含参数名称和新值的字典,该函数会在上下文中修改这些参数。 + key (str): 参数名称。 + value (Any): 参数的新值。 + + Yields: + None: 无返回值,只是用于上下文管理。 + + Returns: + None: 无返回值,只是用于上下文管理。 + + Raises: + None: 不会引发任何异常。 + """ for k, v in arg_name_values.items(): old_value = getattr(args, k, None) setattr(args, k, v) @@ -479,6 +934,12 @@ class PipeEvalModel(GenerationMixin): """ def __init__(self, trainer: Trainer): + """ + Args: + trainer (Trainer): Trainer object. + The trainer should have a attribute named `_inner_eval_model` which is the model used for evaluation. + If it does not exist, then the model in `trainer.model_wrapped` will be used. + """ eval_model = getattr(trainer, "_inner_eval_model", None) self.model: fleet.model.PipelineParallel = trainer.model_wrapped if eval_model is None else eval_model self.config: PretrainedConfig = trainer.model.config @@ -489,21 +950,62 @@ def __init__(self, trainer: Trainer): @property def pp_group(self): + """ + 获取当前模型的属性分组,返回值为str类型。 + 如果模型没有设置属性分组,则返回None。 + + Returns: + str, optional: 当前模型的属性分组,默认为None。 + """ return self.model.pp_group def eval(self): + """ + 将模型置于评估模式,禁用梯度计算和 dropout。 + 返回:None + """ self.model.eval() def train(self): + """ + 将模型设置为训练模式。 + 在调用任何前向传播函数之前,必须先调用此函数。 + + Returns: + None, 无返回值。 + """ self.model.train() def __getattr__(self, name): + """ + 如果在当前类中没有找到对应的属性,则尝试从模型中获取。 + 如果在模型中也没有找到对应的属性,则会引发AttributeError异常。 + + Args: + name (str): 要查询的属性名称。 + + Returns: + Any: 返回属性值,如果在当前类和模型中都没有找到该属性,则会引发AttributeError异常。 + + Raises: + AttributeError: 如果在当前类和模型中都没有找到对应的属性。 + """ try: return super().__getattr__(name) except AttributeError: return getattr(self.model, name) def _broadcast_outputs(self, outputs): + """ + 将输出广播到所有进程中,如果不是最后一个阶段则返回元组,否则返回ModelOutput或者paddle.Tensor。 + 如果不是最后一个阶段,会对输入的每个张量创建一个与其形状、类型相同但内容为空的新张量,并广播这些张量。 + + Args: + outputs (Union[paddle.Tensor, Tuple[paddle.Tensor], ModelOutput]): 模型的输出,可以是单个张量或张量元组,也可以是ModelOutput。 + + Returns: + Union[paddle.Tensor, Tuple[paddle.Tensor], ModelOutput]: 如果不是最后一个阶段,返回元组;否则返回ModelOutput或者paddle.Tensor。 + """ # outputs is PipelineParallel.eval_batch which is a list of batches. 
out = [] outputs = (outputs,) if isinstance(outputs, paddle.Tensor) else outputs @@ -512,16 +1014,19 @@ def _broadcast_outputs(self, outputs): tensor = tensors if isinstance(tensors, paddle.Tensor) else tensors[0] head_out_meta = ( (self.model._layers.head_out_meta,) - if isinstance(self.model._layers.head_out_meta, paddle.static.InputSpec) + if isinstance( + self.model._layers.head_out_meta, + paddle.static.InputSpec, + ) else self.model._layers.head_out_meta ) tensors = tuple( paddle.empty( shape=[ - tensor.shape[i] if (meta.shape[i] is None or meta.shape[i] < 0) else meta.shape[i] + (tensor.shape[i] if (meta.shape[i] is None or meta.shape[i] < 0) else meta.shape[i]) for i in range(len(meta.shape)) ], - dtype=tensor.dtype if meta.dtype is None else meta.dtype, + dtype=(tensor.dtype if meta.dtype is None else meta.dtype), ) for meta in head_out_meta ) @@ -531,18 +1036,32 @@ def _broadcast_outputs(self, outputs): tensors = ( (tensors,) if isinstance(tensors, paddle.Tensor) - else tensors.to_tuple() - if isinstance(tensors, ModelOutput) - else tensors + else (tensors.to_tuple() if isinstance(tensors, ModelOutput) else tensors) ) # use map_structure seems hung for tensor in tensors: - paddle.distributed.broadcast(tensor, src=self.model.pp_group.ranks[-1], group=self.model.pp_group) + paddle.distributed.broadcast( + tensor, + src=self.model.pp_group.ranks[-1], + group=self.model.pp_group, + ) out.append(tensors[0] if len(tensors) == 1 else tensors) return out[0] if len(out) == 1 else out def __call__(self, *args, **kwargs): + """ + Call the method to generate output from given input. + + Args: + *args (tuple, optional): Input arguments to the method. Defaults to (). + **kwargs (dict, optional): Keyword arguments to the method. Defaults to {}. + + Returns: + Union[List[Any], Tuple[Any]]: Output generated from the input. If the method is + called multiple times, each call returns one output. The type of the output + depends on the implementation of the method. + """ model = self.model assert self.model.training is False if self._is_gen: @@ -563,9 +1082,11 @@ def __call__(self, *args, **kwargs): # next_tokens though logits are broadcasted since pp ranks' seeds differs. # Currently, just slice the last token to reduce comm overhead. 
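The list comprehension that follows keeps only the logits of the last position before broadcasting them across pipeline ranks. A small illustration (not code from this patch) of why that slice is sufficient for decoding and how much it shrinks the broadcast payload:

```python
import paddle

logits = paddle.randn([2, 128, 1000])              # [batch, seq_len, vocab] from the last pp stage
last_logits = logits[:, -1, :].unsqueeze(1)        # [2, 1, 1000]: all that greedy/sampling decoding needs
next_tokens = paddle.argmax(last_logits, axis=-1)  # [2, 1]
print(logits.shape, "->", last_logits.shape)
```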
outputs = [ - micro_batch_output[:, -1, :].unsqueeze(1) - if isinstance(micro_batch_output, paddle.Tensor) - else micro_batch_output[0][:, -1, :].unsqueeze(1) + ( + micro_batch_output[:, -1, :].unsqueeze(1).contiguous() + if isinstance(micro_batch_output, paddle.Tensor) + else micro_batch_output[0][:, -1, :].unsqueeze(1).contiguous() + ) for micro_batch_output in outputs ] outputs = self._broadcast_outputs(outputs) @@ -581,20 +1102,44 @@ def __call__(self, *args, **kwargs): return outputs def generate(self, *args, **kwargs): + """ + 重写父类的方法,在生成文本时使用缓存。 + 首先将self._is_gen设置为True,然后修改DecoderLayerPipe以使用缓存。 + 接下来,调用super().generate(*args, **kwargs)进行文本生成。 + 最后,清除所有层中的缓存(包括子层),并将self._has_cache设置为False。 + + Args: + args (Tuple[Any], optional): 可变参数列表,默认为空元组。 + kwargs (Dict[str, Any], optional): 关键字参数字典,默认为空字典。 + + Returns: + Tuple[Any]: 返回一个元组,其中包含了生成的文本和相应的概率分布。 + + Raises: + 无。 + """ self._is_gen = True # patch DecoderLayerPipe to use cache, DecoderLayerPipe is subclass of # DecoderLayer, and would call super().forward ori_decoder_layer_forward = self.model._layers._non_pipe_decoder_layer_class.forward def decoder_layer_forward(layer_self, *args, **kwargs): - kwargs.update({"use_cache": True, "past_key_value": getattr(layer_self, "_cache", None)}) + kwargs.update( + { + "use_cache": True, + "cache": getattr(layer_self, "_cache", None), + } + ) outputs = ori_decoder_layer_forward(layer_self, *args, **kwargs) output = outputs[0] layer_self._cache = outputs[1] self._has_cache = True return output - with guard_set_args(self.model._layers._non_pipe_decoder_layer_class, {"forward": decoder_layer_forward}): + with guard_set_args( + self.model._layers._non_pipe_decoder_layer_class, + {"forward": decoder_layer_forward}, + ): outputs = super().generate(*args, **kwargs) self._is_gen = False # clear cache of decoder layers, sublayers is incursive thus suitable @@ -606,6 +1151,27 @@ def decoder_layer_forward(layer_self, *args, **kwargs): return outputs def prepare_inputs_for_generation(self, *args, **kwargs): + """ + Prepare the input for generation. This method is used by + :meth:`~transformers.Pipeline.__call__` to generate text from prompts. + + Args: + *args (tuple, optional): Arguments passed to :meth:`~transformers.Pipeline.__call__`. + **kwargs (dict, optional): Keyword arguments passed to :meth:`~transformers.Pipeline.__call__`. + + Returns: + dict: A dictionary containing the prepared inputs for generation. The keys are: + + - "prompt" (:obj:`str`, `optional`, defaults to :obj:`None`): + Text to be decoded. If not provided, the pipeline will try to use the cached prompts. + - "cache" (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use the cached past key values. If not provided, it will be set to :obj:`True` when + the pipeline has cache. + - Other keyword arguments are passed to :meth:`~transformers.Pipeline.__call__`. + + Raises: + ValueError: If both ``prompt`` and ``cache`` are not provided. + """ arg_bind = inspect.signature(self.model._layers._non_pipe_model_class.prepare_inputs_for_generation).bind( *((self,) + args), **kwargs ) @@ -617,13 +1183,13 @@ def prepare_inputs_for_generation(self, *args, **kwargs): else: arg_dict[last_arg_name] = last_arg_value arg_dict.pop("self") - past_key_values = arg_dict.get("past_key_values", None) - # prepare_inputs_for_generation use past_key_values to discrimate prefill + cache = arg_dict.get("cache", None) + # prepare_inputs_for_generation use cache to discrimate prefill # or decode and slice inputs accordingly. 
if getattr(self, "_has_cache", False): - arg_dict.update({"past_key_values": True}) + arg_dict.update({"cache": True}) model_inputs = self.model._layers._non_pipe_model_class.prepare_inputs_for_generation(self, **arg_dict) - model_inputs.update({"past_key_values": past_key_values}) + model_inputs.update({"cache": cache}) return model_inputs @@ -637,26 +1203,100 @@ def is_same_tokenizer( ) +def retokenize(src_tokenizer, dest_tokenizer, token_ids, skip_special_tokens): + """Retokenize a sequence of token ids from one tokenizer to another.""" + tokens = src_tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + part_tokens = [] + result_ids = [] + for token in tokens: + if token in src_tokenizer.all_special_tokens: + if part_tokens: + decoded_text = src_tokenizer.decode( + src_tokenizer.convert_tokens_to_ids(part_tokens), + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=False, + ) + tmp_tokens = dest_tokenizer.tokenize(decoded_text) + result_ids.extend(dest_tokenizer.convert_tokens_to_ids(tmp_tokens)) + part_tokens = [] # 清空 + # 转换当前特殊 token + special_token = dest_tokenizer.convert_tokens_to_ids(token) + result_ids.append(special_token) + else: + part_tokens.append(token) + # 如果有,处理最后一段(一般不应该走到, 应该以special token结尾) + if part_tokens: + decoded_text = src_tokenizer.decode( + src_tokenizer.convert_tokens_to_ids(part_tokens), + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=False, + ) + tmp_tokens = dest_tokenizer.tokenize(decoded_text) + result_ids.extend(dest_tokenizer.convert_tokens_to_ids(tmp_tokens)) + return result_ids + + def batch_retokenize( input_ids: paddle.Tensor, src_tokenizer: PretrainedTokenizer, dest_tokenizer: PretrainedTokenizer, *, padding: bool | str | PaddingStrategy = PaddingStrategy.LONGEST, - truncation: bool | str | TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - skip_special_tokens: bool = True, + skip_special_tokens: bool = False, ) -> BatchEncoding: """Re-tokenize a batch of input ids from one tokenizer to another.""" - output = dest_tokenizer( - [ - text + dest_tokenizer.eos_token - for text in src_tokenizer.batch_decode( - input_ids, - skip_special_tokens=skip_special_tokens, - ) - ], + all_ids = [] + for token_ids in input_ids: + tmp_ids = retokenize(src_tokenizer, dest_tokenizer, token_ids, skip_special_tokens) + all_ids.append(tmp_ids) + output = {} + + output["input_ids"] = dest_tokenizer.pad( + {"input_ids": all_ids}, padding=padding, - truncation=truncation, + return_attention_mask=False, return_tensors="pd", - ) + )["input_ids"] + output["attention_mask"] = make_attention_mask( + output["input_ids"], + pad_id=dest_tokenizer.pad_token_id, + eos_id=dest_tokenizer.eos_token_id, + unk_id=dest_tokenizer.unk_token_id, + causal_mask=True, + ).cast(paddle.bfloat16) + output["position_ids"] = make_position_ids(output["attention_mask"]) return output + + +def process_row(row, remove_value=0, remove_side="both"): + """ + 从张量中去除前导/尾随的特定值。 + + Args: + row (paddle.Tensor): 待处理的张量,一维。 + remove_value (int, optional): 要去除的值,默认为0。 + remove_side (str, optional): 去除的位置,可选"left"(只去除前导)、"right"(只去除尾随)、"both"(去除前导和尾随),默认为"both"。 + + Returns: + paddle.Tensor: 处理后的张量,一维。 + + """ + non_zero_indices = paddle.nonzero(row != remove_value).flatten() + if non_zero_indices.shape[0] == 0: + # 行全为0,警告,不处理 + logger.warning("Row is all zeros, no trimming will be performed.") + return row + start_index = non_zero_indices[0] + end_index = non_zero_indices[-1] + # 切取中间的非零部分 + if remove_side == "left": + 
trimmed_row = row[start_index:] + elif remove_side == "right": + trimmed_row = row[: end_index + 1] + elif remove_side == "both": + trimmed_row = row[start_index : end_index + 1] + else: + logger.warning("unknown remove_side, using both remove_side.") + trimmed_row = row[start_index : end_index + 1] + + return trimmed_row diff --git a/llm/alignment/rm/reward_trainer.py b/llm/alignment/rm/reward_trainer.py index a542d55942a2..ac0a2a71a7ac 100644 --- a/llm/alignment/rm/reward_trainer.py +++ b/llm/alignment/rm/reward_trainer.py @@ -37,7 +37,7 @@ speed_metrics = trainer.speed_metrics -def patch_speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None): +def patch_speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None, model_flops=None): # split: interval, train, eval, test result = speed_metrics(split, start_time, num_samples, num_steps, seq_length) if split not in ["train", "interval"]: diff --git a/llm/alignment/rm/run_reward.py b/llm/alignment/rm/run_reward.py index 32237592d864..d429af643c26 100644 --- a/llm/alignment/rm/run_reward.py +++ b/llm/alignment/rm/run_reward.py @@ -19,11 +19,6 @@ import paddle -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - from data import PreferenceDataset, parse_dataset from models import AutoModelForScore from reward_trainer import RewardTrainer @@ -35,11 +30,12 @@ @dataclass class TrainingArguments(TrainingArguments): - loss_type: Literal["token-wise", "sequence-wise"] = field( + loss_type: str = field( default="sequence-wise", metadata={ - "help": "Calculate ranking loss with all token-wise reward outputs in the sequence or the " - "sequence-wise reward output only (the reward of the last token in each sequence)." + "help": "Calculate ranking loss using either 'token-wise' (all token-wise reward outputs in the sequence) " + "or 'sequence-wise' (reward of the last token in each sequence). " + "Allowed values: ['token-wise', 'sequence-wise']." }, ) # regularization @@ -57,8 +53,11 @@ class ModelArgument: normalize_score_during_training: bool = field( default=False, metadata={"help": "Whether to normalize score during training."} ) - normalizer_type: Literal["RunningMeanStd", "ExponentialMovingAverage"] = field( - default=None, metadata={"help": "The type of the reward normalizer."} + normalizer_type: str = field( + default="ExponentialMovingAverage", + metadata={ + "help": "The type of the reward normalizer. Allowed values: ['RunningMeanStd', 'ExponentialMovingAverage']." 
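For reference, the trimming behavior of `process_row` defined in `trainer_utils.py` above (usage sketch; assumes the function is imported from that module):

```python
import paddle

row = paddle.to_tensor([0, 0, 5, 7, 0, 9, 0, 0])
print(process_row(row).tolist())                       # [5, 7, 0, 9]   - interior zeros are kept
print(process_row(row, remove_side="left").tolist())   # [5, 7, 0, 9, 0, 0]
print(process_row(row, remove_side="right").tolist())  # [0, 0, 5, 7, 0, 9]
```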
+ }, ) normalizer_momentum: float = field( default=None, diff --git a/llm/config/llama/grpo_argument.json b/llm/config/llama/grpo_argument.json new file mode 100644 index 000000000000..fa3dbe8ead2f --- /dev/null +++ b/llm/config/llama/grpo_argument.json @@ -0,0 +1,82 @@ +{ + "train_datasets": "PKU-SafeRLHF/train", + "eval_datasets": "PKU-SafeRLHF/test", + "ptx_datasets": "alpaca", + "actor_model_name_or_path": "PKU-Alignment/alpaca-7b-reproduced", + "reward_model_name_or_path": "PKU-Alignment/beaver-7b-v1.0-reward", + "output_dir": "checkpoints/llama-grpo", + "logging_dir": "log", + "max_length": 2048, + "use_fusemt": 1, + "use_flash_attention": 1, + "max_dec_len": 1024, + "min_dec_len": 1, + "top_p": 0.8, + "temperature": 1.0, + "num_return_sequences": 1, + "repetition_penalty": 1.0, + "num_train_epochs": 1, + "max_steps": 17, + "update_iters": 1, + "per_device_prompt_batch_size": 2, + "per_device_train_batch_size": 2, + "gradient_accumulation_steps": 1, + "learning_rate": 2e-6, + "min_learning_rate": 2e-7, + "weight_decay": 0.01, + "lr_scheduler_type": "cosine", + "warmup_ratio": 0.03, + "recompute": 1, + "recompute_granularity": "full", + "recompute_use_reentrant": 1, + "critic_learning_rate": 2e-6, + "critic_min_learning_rate": 2e-7, + "critic_weight_decay": 0.01, + "critic_lr_scheduler_type": "cosine", + "critic_warmup_ratio": 0.03, + "critic_recompute": 1, + "critic_recompute_granularity": "full", + "normalize_reward": 1, + "normalize_advantage": 1, + "kl_coeff": 0.02, + "clip_range_ratio": 0.2, + "clip_range_score": 10.0, + "clip_range_value": 5.0, + "ptx_coeff": 16.0, + "logging_steps": 1, + "logging_dir": "vdl_log", + "evaluation_strategy": "no", + "per_device_eval_batch_size": 16, + "eval_steps": 10000, + "save_strategy": "steps", + "save_steps": 400, + "save_total_limit": 5, + "bf16": 1, + "fp16": 0, + "fp16_opt_level": "O2", + "do_train": 1, + "do_eval": 0, + "disable_tqdm": 1, + "sharding_parallel_degree": 1, + "sharding": "stage1", + "tensor_parallel_degree": 8, + "tensor_parallel_output": 0, + "pipeline_parallel_degree": 1, + "pipeline_parallel_config": "disable_p2p_cache_shape", + "sequence_parallel": 0, + "max_grad_norm": 1.0, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "dataloader_drop_last": 0, + "eval_mode": "", + "offload_level": "freeze_model optimizer train_model", + "release_grads": 1, + "seed": 23, + "use_fused_head_and_loss_fn": 0, + "fused_linear":1, + "autotuner_benchmark": 0, + "skip_profile_timer": 1, + "use_rm_server": true, + "reward_server": "http://10.174.146.80:8048", + "rl_algorithm": "grpo" +} diff --git a/llm/config/llama/ppo_argument.json b/llm/config/llama/ppo_argument.json index 442f78433562..f9058a7d3f11 100644 --- a/llm/config/llama/ppo_argument.json +++ b/llm/config/llama/ppo_argument.json @@ -4,54 +4,75 @@ "ptx_datasets": "alpaca", "actor_model_name_or_path": "PKU-Alignment/alpaca-7b-reproduced", "reward_model_name_or_path": "PKU-Alignment/beaver-7b-v1.0-reward", - "output_dir": "checkpoints/llm_ppo", - "max_length": 512, + "output_dir": "checkpoints/llama-ppo", + "max_length": 2048, + "use_fusemt": 1, + "use_flash_attention": 1, + "max_dec_len": 1024, + "min_dec_len": 1, "top_p": 0.8, "temperature": 1.0, - "num_return_sequences":1, + "num_return_sequences": 1, "repetition_penalty": 1.0, "num_train_epochs": 1, + "max_steps": 17, "update_iters": 1, - "per_device_prompt_batch_size": 16, - "per_device_train_batch_size": 16, + "per_device_prompt_batch_size": 2, + "per_device_train_batch_size": 2, "gradient_accumulation_steps": 1, - "learning_rate": 
1e-5, + "learning_rate": 2e-6, + "min_learning_rate": 2e-7, "weight_decay": 0.01, "lr_scheduler_type": "cosine", "warmup_ratio": 0.03, - "recompute": true, - "critic_learning_rate": 5e-6, - "critic_weight_decay": 0.0, - "critic_lr_scheduler_type": "constant", + "recompute": 1, + "recompute_granularity": "full", + "recompute_use_reentrant": 1, + "critic_learning_rate": 2e-6, + "critic_min_learning_rate": 2e-7, + "critic_weight_decay": 0.01, + "critic_lr_scheduler_type": "cosine", "critic_warmup_ratio": 0.03, - "critic_recompute": true, - "normalize_reward": false, + "critic_recompute": 1, + "critic_recompute_granularity": "full", + "normalize_reward": 1, + "normalize_advantage": 1, "kl_coeff": 0.02, "clip_range_ratio": 0.2, - "clip_range_score": 50.0, + "clip_range_score": 10.0, "clip_range_value": 5.0, "ptx_coeff": 16.0, - "per_device_eval_batch_size": 16, "logging_steps": 1, - "evaluation_strategy": "steps", - "eval_steps": 100, - "save_strategy": "epoch", - "save_steps": 100000, - "bf16": true, + "logging_dir": "vdl_log", + "evaluation_strategy": "no", + "per_device_eval_batch_size": 16, + "eval_steps": 10000, + "save_strategy": "steps", + "save_steps": 400, + "save_total_limit": 5, + "bf16": 1, + "fp16": 0, "fp16_opt_level": "O2", - "do_train": true, - "do_eval": true, - "disable_tqdm": true, - "save_total_limit": 1, - "sharding_parallel_degree": 4, + "do_train": 1, + "do_eval": 0, + "disable_tqdm": 1, + "sharding_parallel_degree": 1, "sharding": "stage1", - "tensor_parallel_degree": 2, + "tensor_parallel_degree": 8, + "tensor_parallel_output": 0, "pipeline_parallel_degree": 1, "pipeline_parallel_config": "disable_p2p_cache_shape", - "max_grad_norm": 1.0, + "sequence_parallel": 0, + "max_grad_norm": 1.0, "adam_beta1": 0.9, "adam_beta2": 0.95, - "dataloader_drop_last": false, + "dataloader_drop_last": 0, "eval_mode": "", - "offload_level": "freeze_model" -} + "offload_level": "freeze_model optimizer train_model", + "release_grads": 1, + "seed": 23, + "use_fused_head_and_loss_fn": 0, + "fused_linear":1, + "autotuner_benchmark": 0, + "skip_profile_timer": 1 +} \ No newline at end of file diff --git a/llm/config/qwen/grpo_argument.json b/llm/config/qwen/grpo_argument.json new file mode 100644 index 000000000000..e56d186a7ac3 --- /dev/null +++ b/llm/config/qwen/grpo_argument.json @@ -0,0 +1,82 @@ +{ + "train_datasets": "PKU-SafeRLHF/train", + "eval_datasets": "PKU-SafeRLHF/test", + "ptx_datasets": "alpaca", + "actor_model_name_or_path": "/path/to/actor/model", + "reward_model_name_or_path": "path/to/reward/model", + "output_dir": "checkpoints/qwen-grpo", + "logging_dir": "log", + "max_length": 2048, + "use_fusemt": 1, + "use_flash_attention": 1, + "max_dec_len": 1024, + "min_dec_len": 1, + "top_p": 0.8, + "temperature": 1.0, + "num_return_sequences": 1, + "repetition_penalty": 1.0, + "num_train_epochs": 1, + "max_steps": 17, + "update_iters": 1, + "per_device_prompt_batch_size": 2, + "per_device_train_batch_size": 2, + "gradient_accumulation_steps": 1, + "learning_rate": 2e-6, + "min_learning_rate": 2e-7, + "weight_decay": 0.01, + "lr_scheduler_type": "cosine", + "warmup_ratio": 0.03, + "recompute": 1, + "recompute_granularity": "full", + "recompute_use_reentrant": 1, + "critic_learning_rate": 2e-6, + "critic_min_learning_rate": 2e-7, + "critic_weight_decay": 0.01, + "critic_lr_scheduler_type": "cosine", + "critic_warmup_ratio": 0.03, + "critic_recompute": 1, + "critic_recompute_granularity": "full", + "normalize_reward": 1, + "normalize_advantage": 1, + "kl_coeff": 0.02, + 
"clip_range_ratio": 0.2, + "clip_range_score": 10.0, + "clip_range_value": 5.0, + "ptx_coeff": 16.0, + "logging_steps": 1, + "logging_dir": "vdl_log", + "evaluation_strategy": "no", + "per_device_eval_batch_size": 16, + "eval_steps": 10000, + "save_strategy": "steps", + "save_steps": 400, + "save_total_limit": 5, + "bf16": 1, + "fp16": 0, + "fp16_opt_level": "O2", + "do_train": 1, + "do_eval": 0, + "disable_tqdm": 1, + "sharding_parallel_degree": 1, + "sharding": "stage1", + "tensor_parallel_degree": 8, + "tensor_parallel_output": 0, + "pipeline_parallel_degree": 1, + "pipeline_parallel_config": "disable_p2p_cache_shape", + "sequence_parallel": 0, + "max_grad_norm": 1.0, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "dataloader_drop_last": 0, + "eval_mode": "", + "offload_level": "freeze_model optimizer train_model", + "release_grads": 1, + "seed": 23, + "use_fused_head_and_loss_fn": 0, + "fused_linear":1, + "autotuner_benchmark": 0, + "skip_profile_timer": 1, + "use_rm_server": true, + "reward_server": "http://10.174.146.80:8048", + "rl_algorithm": "grpo" +} diff --git a/llm/docs/rlhf.md b/llm/docs/rlhf.md index 567b0de03cbd..2dc9f70880b2 100644 --- a/llm/docs/rlhf.md +++ b/llm/docs/rlhf.md @@ -47,7 +47,7 @@ - PaddlePaddle >= 2.6.0 - PaddleNLP 最新版本 -如需使用生成加速功能,需要安装 [paddlenlp_ops](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/csrc) ,请使用 `git clone https://github.com/PaddlePaddle/PaddleNLP.git` 克隆 PaddleNLP 代码库并且将 PaddleNLP/llm 目录的路径加入 PYTHONPATH(后续将进行完善)。安装 paddlenlp_ops 后训练时将直接开启生成加速(开启流水线并行时不支持生成加速),否则使用原生动态图进行生成。 +如需使用生成加速功能,需要安装 [paddlenlp_ops](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/csrc) ,请使用 `git clone https://github.com/PaddlePaddle/PaddleNLP.git` 克隆 PaddleNLP 代码库并且将 PaddleNLP/llm 与PaddleNLP/llm/alignment/ppo 目录的路径加入 PYTHONPATH(后续将进行完善)。安装 paddlenlp_ops 后训练时将直接开启生成加速(开启流水线并行时不支持生成加速),否则使用原生动态图进行生成。 ### 数据准备 diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index ab715172a2d6..4edd05911bae 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -49,6 +49,7 @@ Llama3Tokenizer, LlamaTokenizer, PretrainedConfig, + PretrainedModel, PretrainedTokenizer, ) from paddlenlp.trl import llm_utils @@ -204,8 +205,14 @@ def batchfy_text(texts, batch_size): class BasePredictor: - def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None): - self.model_config = AutoConfig.from_pretrained(config.model_name_or_path) + def __init__( + self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, model: PretrainedModel = None + ): + if model is not None and hasattr(model, "config"): + self.model_config = model.config + else: + self.model_config = AutoConfig.from_pretrained(config.model_name_or_path) + self.config: PredictorArgument = config if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path, padding_side="left") @@ -278,9 +285,11 @@ def predict(self, input_texts: str | list[str], return_tokens=False): class DygraphPredictor(BasePredictor): - def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs): - super().__init__(config, tokenizer) - self.model = kwargs.get("model", None) + def __init__( + self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, model: PretrainedModel = None, **kwargs + ): + super().__init__(config, tokenizer, model) + self.model = model if config.lora_path is not None: lora_config = LoRAConfig.from_pretrained(config.lora_path) dtype = lora_config.dtype @@ -357,8 +366,10 @@ def 
stream_predict(self, inputs: dict[str, paddle.Tensor]): class StaticGraphPredictor(BasePredictor): - def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs): - super().__init__(config, tokenizer) + def __init__( + self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, model: PretrainedModel = None, **kwargs + ): + super().__init__(config, tokenizer, model) inference_config = paddle.inference.Config(self.config.model_name_or_path, self.config.model_prefix) @@ -409,9 +420,8 @@ def _infer(self, inputs: dict[str, np.ndarray]): class InferencePredictorMixin(BasePredictor): - def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): - BasePredictor.__init__(self, config, tokenizer) - + def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer, model: PretrainedModel = None): + BasePredictor.__init__(self, config, tokenizer, model) self.architectures = self.model_config.architectures[0].lower() self.dtype = config.dtype or self.model_config.dtype @@ -660,12 +670,13 @@ def __init__( self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, + model: PretrainedModel = None, **kwargs, ): self.cache_kvs_shape = kwargs.get("cache_kvs_shape", None) if self.cache_kvs_shape is None: raise ValueError("cache_kvs_shape should be provided for StaticGraphInferencePredictor") - InferencePredictorMixin.__init__(self, config, tokenizer) + InferencePredictorMixin.__init__(self, config, tokenizer, model) self.predictor = self._create_predictor(config) @@ -737,13 +748,13 @@ def __init__( self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, + model: PretrainedModel = None, **kwargs, ): - model = kwargs.get("model", None) if model is None: raise ValueError("model should be provided for DygraphInferencePredictor") self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size, config.total_max_length) - InferencePredictorMixin.__init__(self, config, tokenizer) + InferencePredictorMixin.__init__(self, config, tokenizer, model) self.model = model @paddle.no_grad() @@ -765,8 +776,13 @@ def _infer(self, inputs: dict[str, paddle.Tensor]): class BlockInferencePredictorMixin(BasePredictor): - def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): - BasePredictor.__init__(self, config, tokenizer) + def __init__( + self, + config: PredictorArgument, + tokenizer: PretrainedTokenizer = None, + model: PretrainedModel = None, + ): + BasePredictor.__init__(self, config, tokenizer, model) self.num_layers = len(self.cache_k_shapes) self.num_key_value_heads = self.cache_k_shapes[0][-3] @@ -1027,14 +1043,15 @@ def _preprocess(self, input_text: list[str]): class DygraphBlockInferencePredictor(BlockInferencePredictorMixin): - def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, **kwargs): - model = kwargs.get("model", None) + def __init__( + self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, model: PretrainedModel = None, **kwargs + ): self.return_full_hidden_states = config.return_full_hidden_states self.full_hidden_states = None if model is None: raise ValueError("model should be provided for DygraphBlockInferencePredictor") self.cache_k_shapes, self.cache_v_shapes = model.get_cache_kvs_shape(model.config, config.batch_size) - BlockInferencePredictorMixin.__init__(self, config, tokenizer) + BlockInferencePredictorMixin.__init__(self, config, tokenizer, model) cachekv_dtype = self.dtype if config.cachekv_int8_type is 
None else "uint8" @@ -1148,6 +1165,7 @@ def __init__( self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, + model: PretrainedModel = None, **kwargs, ): self.cache_k_shapes = kwargs.get("cache_k_shapes", None) @@ -1324,6 +1342,7 @@ def create_predictor( config: PretrainedConfig, model_args: ModelArgument, tokenizer: PretrainedTokenizer = None, + model: PretrainedModel = None, **kwargs, ): """ @@ -1338,7 +1357,6 @@ def create_predictor( Returns: Predictor: The predictor. """ - model = kwargs.pop("model", None) cache_kvs_shape = None # used for not block_attn/append_attn cache_k_shapes = None # used for block_attn/append_attn cache_v_shapes = None # used for block_attn/append_attn diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index b0ccf2dcd314..82ffebec9bb7 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -411,7 +411,7 @@ def __init__(self, config: LlamaConfig): elif config.quant_type == "weight_only_int4": self.use_weight_only = True self.quant_algo = "weight_only_int4" - elif "a8w8" in config.quant_type: + elif config.quant_type and "a8w8" in config.quant_type: self.quant_model_path = config.model_name_or_path self.shift = config.quantization_config.shift self.smooth = config.quantization_config.smooth @@ -672,6 +672,8 @@ def __init__(self, config: LlamaConfig): self.gradient_checkpointing = False + self._weights_initialized = False + def set_transformer_block(self, transformer_config): if self.use_weight_only: self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config) @@ -989,7 +991,9 @@ def set_quant_scale(self): @paddle.no_grad() def set_state_dict(self, state_dict, is_eagle=False): self.set_quant_scale() - self.transformer_block.init_weight() + if not self._weights_initialized: + self.transformer_block.init_weight() + self._weights_initialized = True split_fn = split_param_func() self.embed_tokens.weight.set_value( paddle.to_tensor(state_dict["llama.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype) diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index 8a14ca199e20..d2a9a4b6eb2b 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -363,6 +363,8 @@ def __init__(self, config: Qwen2Config, base_model_prefix: str): self.cache_kvs = None self.head_dim_shape_tensor = paddle.ones((self.hidden_size // self.num_attention_heads), dtype="int8") + self._weights_initialized = False + def set_transformer_block(self, transformer_config): if self.use_weight_only: self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config) @@ -554,7 +556,9 @@ def set_quant_scale(self): @paddle.no_grad() def set_state_dict(self, state_dict): self.set_quant_scale() - self.transformer_block.init_weight() + if not self._weights_initialized: + self.transformer_block.init_weight() + self._weights_initialized = True split_fn = split_param_func() self.embed_tokens.weight.set_value( paddle.to_tensor(state_dict[f"{self.base_model_prefix}.embed_tokens.weight"]).cast( @@ -567,7 +571,7 @@ def set_state_dict(self, state_dict): for idx in range(self.num_layers): model_prefix = self.base_model_prefix + f".layers.{idx}" - logger.info(f"set state for layer {idx}") + # logger.info(f"set state for layer {idx}") ln_scale = 
paddle.to_tensor(state_dict[f"{model_prefix}.input_layernorm.weight"]).cast( self.transformer_block.ln_scales[idx].dtype diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py index 38e773f56bb4..88ee24e37640 100644 --- a/paddlenlp/transformers/auto/modeling.py +++ b/paddlenlp/transformers/auto/modeling.py @@ -17,6 +17,7 @@ import json import os from collections import OrderedDict +from copy import deepcopy from ...utils.download import resolve_file_path from ...utils.log import logger @@ -224,7 +225,7 @@ def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file # Get class name corresponds to this configuration if is_standard_config(config): - architectures = config["architectures"] + architectures = deepcopy(config["architectures"]) init_class = architectures.pop() if len(architectures) > 0 else None else: init_class = config.pop("init_class", None) diff --git a/paddlenlp/trl/llm_utils.py b/paddlenlp/trl/llm_utils.py index c19496909295..2bb00499e82c 100644 --- a/paddlenlp/trl/llm_utils.py +++ b/paddlenlp/trl/llm_utils.py @@ -22,7 +22,6 @@ import numpy as np import paddle import paddle.distributed as dist -import paddle.distributed.fleet.base.topology as tp import paddle.incubate.multiprocessing as mp from paddle.distributed import fleet from sklearn.metrics import accuracy_score @@ -744,24 +743,46 @@ def get_rotary_position_embedding(position_ids, head_dim, rope_theta=10000.0, ro def init_dist_env(): - tensor_parallel_degree = paddle.distributed.get_world_size() - tensor_parallel_rank = paddle.distributed.get_rank() + """ + Initialize the distributed environment and obtain tensor parallel degree and rank. + + Returns: + tuple: A tuple containing tensor parallel rank and degree. + """ + world_size = paddle.distributed.get_world_size() # Get the total number of distributed nodes + + if world_size > 1: + is_fleet_init = True + try: + # Try to get the hybrid communicate group to check if Fleet has been initialized + hcg = fleet.get_hybrid_communicate_group() + except AttributeError: + is_fleet_init = False # Fleet has not been initialized - if tensor_parallel_degree > 1: - # refer to: https://github.com/PaddlePaddle/Paddle/blob/4abea956ee852ce52791a1e08fa92ed4d3be150d/python/paddle/distributed/fleet/fleet.py#L298C23-L298C45 - hcg = tp._HYBRID_PARALLEL_GROUP - if hcg is None: + if is_fleet_init: + # If Fleet is already initialized, get tensor parallel degree and rank + tensor_parallel_degree = hcg.get_model_parallel_world_size() + tensor_parallel_rank = hcg.get_model_parallel_rank() + else: + # If Fleet is not initialized, set up the distributed strategy and initialize Fleet strategy = fleet.DistributedStrategy() strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": tensor_parallel_degree, - "pp_degree": 1, - "sharding_degree": 1, + "dp_degree": 1, # Data parallelism degree + "mp_degree": world_size, # Model parallelism degree (to be determined or set) + "pp_degree": 1, # Pipeline parallelism degree + "sharding_degree": 1, # Sharding parallelism degree } - fleet.init(is_collective=True, strategy=strategy) - hcg = fleet.get_hybrid_communicate_group() + fleet.init(is_collective=True, strategy=strategy) # Initialize Fleet + hcg = fleet.get_hybrid_communicate_group() # Get the hybrid communicate group after initialization + + # Get tensor parallel degree and rank after Fleet initialization + tensor_parallel_degree = hcg.get_model_parallel_world_size() + tensor_parallel_rank = hcg.get_model_parallel_rank() + else: + # If not 
in a distributed environment, set tensor parallel degree and rank to 1 and 0 respectively + tensor_parallel_degree = 1 + tensor_parallel_rank = 0 - tensor_parallel_rank = hcg.get_model_parallel_rank() return tensor_parallel_rank, tensor_parallel_degree diff --git a/tests/fixtures/llm/ppo.yaml b/tests/fixtures/llm/ppo.yaml new file mode 100644 index 000000000000..adfa38c54689 --- /dev/null +++ b/tests/fixtures/llm/ppo.yaml @@ -0,0 +1,79 @@ +ppo: + base: + train_datasets: "Jsonfile::./tests/fixtures/llm/ppo_data/train.jsonl" + eval_datasets: "Jsonfile::./tests/fixtures/llm/ppo_data/dev.jsonl" + ptx_datasets: "Jsonfile::./tests/fixtures/llm/ppo_data/ptx.jsonl" + max_length: 2048 + use_fusemt: 1 + use_flash_attention: 1 + max_dec_len: 1024 + min_dec_len: 1 + top_p: 0.8 + temperature: 1.0 + num_return_sequences: 1 + repetition_penalty: 1.0 + num_train_epochs: 1 + max_steps: 5 + update_iters: 1 + per_device_prompt_batch_size: 1 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 1 + learning_rate: 2e-6 + min_learning_rate: 2e-7 + weight_decay: 0.01 + lr_scheduler_type: "cosine" + warmup_ratio: 0.03 + recompute: 1 + recompute_granularity: "full" + recompute_use_reentrant: 1 + critic_learning_rate: 2e-6 + critic_min_learning_rate: 2e-7 + critic_weight_decay: 0.01 + critic_lr_scheduler_type: "cosine" + critic_warmup_ratio: 0.03 + critic_recompute: 1 + critic_recompute_granularity: "full" + normalize_reward: 1 + normalize_advantage: 1 + kl_coeff: 0.02 + clip_range_ratio: 0.2 + clip_range_score: 10.0 + clip_range_value: 5.0 + ptx_coeff: 16.0 + logging_steps: 1 + evaluation_strategy: "no" + per_device_eval_batch_size: 16 + eval_steps: 10000 + save_strategy: "steps" + save_steps: 400 + save_total_limit: 5 + bf16: 1 + fp16: 0 + fp16_opt_level: O2 + do_train: 1 + do_eval: 0 + disable_tqdm: 1 + sharding_parallel_degree: 1 + sharding: stage1 + tensor_parallel_degree: 8 + tensor_parallel_output: 0 + pipeline_parallel_degree: 1 + pipeline_parallel_config: "disable_p2p_cache_shape" + sequence_parallel: 0 + max_grad_norm: 1.0 + adam_beta1: 0.9 + adam_beta2: 0.95 + dataloader_drop_last: 0 + eval_mode: "" + offload_level: "freeze_model optimizer train_model" + release_grads: 1 + seed: 23 + use_fused_head_and_loss_fn: 0 + autotuner_benchmark: 0 + skip_profile_timer: 1 + fused_linear: 1 + + default: + llama: + actor_model_name_or_path: __internal_testing__/tiny-random-llama + reward_model_name_or_path: __internal_testing__/tiny-random-llama diff --git a/tests/fixtures/llm/ppo_data/dev.jsonl b/tests/fixtures/llm/ppo_data/dev.jsonl new file mode 100644 index 000000000000..bfe814674314 --- /dev/null +++ b/tests/fixtures/llm/ppo_data/dev.jsonl @@ -0,0 +1,5 @@ +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Aria, Jackson, Ethan, Owen, and Henry. Aria commented, \"Owen is a knight or Aria is a knight\". Jackson commented, \"If Ethan is a knight then Aria is a knight\". 
\"Aria is a knave if and only if Owen is a knave,\" Ethan declared. Owen was heard saying, \"If Henry is a knight then Henry is a knave\". \"Jackson is a knave if and only if Owen is a knave,\" Henry mentioned. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Aria is a knave\n(2) Jackson is a knave\n(3) Ethan is a knight\n(4) Owen is a knave\n(5) Henry is a knight"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Liam, Abigail, Oliver, Charlotte, and Joseph. According to Liam, \"Charlotte is a knight and Joseph is a knave\". Abigail said, \"Charlotte is a knave or Liam is a knight.\" According to Oliver, \"Charlotte is a knight\". Charlotte told you that Liam is a knave. In Joseph's words: \"Liam is not a knight\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Liam is a knave\n(2) Abigail is a knave\n(3) Oliver is a knight\n(4) Charlotte is a knight\n(5) Joseph is a knight"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Ella, Victoria, Riley, Avery, and Sofia. Ella asserted: \"Ella is a knight or Victoria is a knave\". Victoria asserted: \"If Victoria is a knight then Riley is a knight\". Riley asserted: \"Ella is not a knight\". \"Sofia is a knight or Ella is a knave,\" Avery declared. Sofia expressed that If Victoria is a knave then Riley is a knave. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Ella is a knave\n(2) Victoria is a knight\n(3) Riley is a knight\n(4) Avery is a knight\n(5) Sofia is a knight"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Joseph, Michael, Elizabeth, Lucas, and Aria. 
\"Lucas is a knave and Elizabeth is a knave\" - Joseph. Michael said, \"Elizabeth is a knave and Elizabeth is a knight.\" In Elizabeth's words: \"Michael is not a knave\". Lucas remarked, \"Michael is not a knight\". Aria asserted: \"Aria is a knight and Joseph is a knight\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Joseph is a knave\n(2) Michael is a knave\n(3) Elizabeth is a knave\n(4) Lucas is a knight\n(5) Aria is a knave"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Alexander, James, Ava, Logan, and Grace. Alexander told you that If Grace is a knight then Logan is a knight. James noted, \"Ava is a knave or Logan is a knight\". According to Ava, \"James is a knave if and only if Alexander is a knight\". \"Alexander is a knave if and only if James is a knight,\" Logan declared. \"Ava is a knave and Ava is a knight,\" Grace declared. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Alexander is a knight\n(2) James is a knight\n(3) Ava is a knave\n(4) Logan is a knave\n(5) Grace is a knave"} \ No newline at end of file diff --git a/tests/fixtures/llm/ppo_data/ptx.jsonl b/tests/fixtures/llm/ppo_data/ptx.jsonl new file mode 100644 index 000000000000..bfe814674314 --- /dev/null +++ b/tests/fixtures/llm/ppo_data/ptx.jsonl @@ -0,0 +1,5 @@ +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Aria, Jackson, Ethan, Owen, and Henry. Aria commented, \"Owen is a knight or Aria is a knight\". Jackson commented, \"If Ethan is a knight then Aria is a knight\". \"Aria is a knave if and only if Owen is a knave,\" Ethan declared. Owen was heard saying, \"If Henry is a knight then Henry is a knave\". \"Jackson is a knave if and only if Owen is a knave,\" Henry mentioned. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Aria is a knave\n(2) Jackson is a knave\n(3) Ethan is a knight\n(4) Owen is a knave\n(5) Henry is a knight"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . 
Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Liam, Abigail, Oliver, Charlotte, and Joseph. According to Liam, \"Charlotte is a knight and Joseph is a knave\". Abigail said, \"Charlotte is a knave or Liam is a knight.\" According to Oliver, \"Charlotte is a knight\". Charlotte told you that Liam is a knave. In Joseph's words: \"Liam is not a knight\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Liam is a knave\n(2) Abigail is a knave\n(3) Oliver is a knight\n(4) Charlotte is a knight\n(5) Joseph is a knight"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Ella, Victoria, Riley, Avery, and Sofia. Ella asserted: \"Ella is a knight or Victoria is a knave\". Victoria asserted: \"If Victoria is a knight then Riley is a knight\". Riley asserted: \"Ella is not a knight\". \"Sofia is a knight or Ella is a knave,\" Avery declared. Sofia expressed that If Victoria is a knave then Riley is a knave. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Ella is a knave\n(2) Victoria is a knight\n(3) Riley is a knight\n(4) Avery is a knight\n(5) Sofia is a knight"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Joseph, Michael, Elizabeth, Lucas, and Aria. \"Lucas is a knave and Elizabeth is a knave\" - Joseph. Michael said, \"Elizabeth is a knave and Elizabeth is a knight.\" In Elizabeth's words: \"Michael is not a knave\". Lucas remarked, \"Michael is not a knight\". Aria asserted: \"Aria is a knight and Joseph is a knight\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Joseph is a knave\n(2) Michael is a knave\n(3) Elizabeth is a knave\n(4) Lucas is a knight\n(5) Aria is a knave"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. 
The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Alexander, James, Ava, Logan, and Grace. Alexander told you that If Grace is a knight then Logan is a knight. James noted, \"Ava is a knave or Logan is a knight\". According to Ava, \"James is a knave if and only if Alexander is a knight\". \"Alexander is a knave if and only if James is a knight,\" Logan declared. \"Ava is a knave and Ava is a knight,\" Grace declared. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Alexander is a knight\n(2) James is a knight\n(3) Ava is a knave\n(4) Logan is a knave\n(5) Grace is a knave"} \ No newline at end of file diff --git a/tests/fixtures/llm/ppo_data/train.jsonl b/tests/fixtures/llm/ppo_data/train.jsonl new file mode 100644 index 000000000000..54c155fe9090 --- /dev/null +++ b/tests/fixtures/llm/ppo_data/train.jsonl @@ -0,0 +1,9 @@ +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Aurora, Ethan, Noah, Aria, and Abigail. \"Abigail is not a knave\" - Aurora. Ethan remarked, \"Abigail is a knave or Aria is a knave\". According to Noah, \"Aria is a knave\". Aria told you that Aurora is a knave if and only if Noah is a knight. Abigail said that Noah is a knight. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Aurora is a knight\n(2) Ethan is a knight\n(3) Noah is a knight\n(4) Aria is a knave\n(5) Abigail is a knight"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Abigail, Mason, Elizabeth, Emily, and Sophia. \"Sophia is a knave or Elizabeth is a knight,\" Abigail declared. \"Sophia is a knight,\" Mason mentioned. \"Abigail is not a knight,\" Elizabeth claimed. In a statement by Emily: \"Abigail is a knight or Elizabeth is a knight\". As Sophia put it, \"If Emily is a knight then Emily is a knave\". 
So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Abigail is a knight\n(2) Mason is a knave\n(3) Elizabeth is a knave\n(4) Emily is a knight\n(5) Sophia is a knave"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Penelope, Lucas, Amelia, Emily, and Zoey. Penelope asserted: \"Lucas is a knight or Amelia is a knight\". In a statement by Lucas: \"If Lucas is a knight then Zoey is a knave\". Amelia remarked, \"Emily is a knave\". Emily noted, \"Penelope is a knave and Zoey is a knave\". Zoey told you that Emily is a knight and Emily is a knave. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Penelope is a knight\n(2) Lucas is a knight\n(3) Amelia is a knight\n(4) Emily is a knave\n(5) Zoey is a knave"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: William, Jacob, Aiden, Aria, and Logan. In a statement by William: \"If Aiden is a knight then Aria is a knave\". Jacob was heard saying, \"Logan is a knave and Jacob is a knight\". Aiden said, \"Aria is a knave or William is a knight.\" Aria told you that William is a knave. Logan stated, \"Jacob is a knight\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) William is a knight\n(2) Jacob is a knave\n(3) Aiden is a knight\n(4) Aria is a knave\n(5) Logan is a knave"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Lucas, Alexander, Henry, Aiden, and Sebastian. Lucas commented, \"Henry is a knight\". \"Aiden is a knight,\" Alexander mentioned. As Henry put it, \"Sebastian is a knight and Henry is a knight\". \"Lucas is not a knight,\" Aiden declared. 
Sebastian asserted: \"Lucas is a knave or Aiden is a knight\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Lucas is a knave\n(2) Alexander is a knight\n(3) Henry is a knave\n(4) Aiden is a knight\n(5) Sebastian is a knight"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Amelia, Zoey, Owen, Samuel, and Aria. As Amelia put it, \"Zoey is a knave and Owen is a knave\". Zoey expressed that Amelia is a knight and Aria is a knight. Owen was heard saying, \"Samuel is a knight and Amelia is a knave\". \"If Aria is a knight then Zoey is a knight,\" Samuel declared. Aria remarked, \"Zoey is a knight if and only if Owen is a knave\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Amelia is a knight\n(2) Zoey is a knave\n(3) Owen is a knave\n(4) Samuel is a knight\n(5) Aria is a knave"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Daniel, Mason, Aria, Liam, and Michael. Daniel said, \"Michael is a knave.\" Mason asserted: \"Liam is a knave\". \"If Mason is a knave then Michael is a knave\" - Aria. Liam said, \"Mason is a knave if and only if Mason is a knight.\" \"Aria is not a knight\" - Michael. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Daniel is a knight\n(2) Mason is a knight\n(3) Aria is a knight\n(4) Liam is a knave\n(5) Michael is a knave"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: Ava, Abigail, Harper, Penelope, and Ethan. Ava told you that Ethan is a knight and Abigail is a knave. Abigail was heard saying, \"Abigail is a knight or Penelope is a knave\". 
According to Harper, \"If Harper is a knight then Abigail is a knight\". \"Abigail is not a knave,\" Penelope claimed. Ethan told you that If Ava is a knight then Abigail is a knave. So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) Ava is a knave\n(2) Abigail is a knight\n(3) Harper is a knight\n(4) Penelope is a knight\n(5) Ethan is a knight"} +{"src": "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within tags. i.e., (1) Zoey is a knight\n(2) ... .\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: David, Evelyn, James, Grace, and Mason. David noted, \"Evelyn is a knave or Evelyn is a knight\". Evelyn expressed that Grace is a knave and Mason is a knight. James expressed that David is not a knight. In Grace's words: \"If David is a knave then James is a knave\". Mason commented, \"Grace is a knight\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n", "tgt": "(1) David is a knight\n(2) Evelyn is a knave\n(3) James is a knave\n(4) Grace is a knight\n(5) Mason is a knight"} \ No newline at end of file diff --git a/tests/llm/test_ppo.py b/tests/llm/test_ppo.py new file mode 100644 index 000000000000..eb4eda10ed75 --- /dev/null +++ b/tests/llm/test_ppo.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sys +import unittest + +from parameterized import parameterized_class + +from tests.testing_utils import argv_context_guard, load_test_config + +from .testing_utils import LLMTest + + +@parameterized_class( + ["model_dir"], + [["llama"]], +) +class FinetuneTest(LLMTest, unittest.TestCase): + config_path: str = "./tests/fixtures/llm/ppo.yaml" + model_dir: str = None + + def setUp(self) -> None: + LLMTest.setUp(self) + sys.path.insert(0, "./llm/alignment/ppo") + sys.path.insert(0, self.model_dir) + + def tearDown(self) -> None: + LLMTest.tearDown(self) + + def test_finetune(self): + ppo_config = load_test_config(self.config_path, "ppo", self.model_dir) + + ppo_config["output_dir"] = self.output_dir + with argv_context_guard(ppo_config): + from alignment.ppo.run_ppo import main + + main()
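
For reference, a minimal usage sketch of the reworked `init_dist_env()` in `paddlenlp/trl/llm_utils.py` introduced by this patch. The script name and launch command below are illustrative assumptions; the function itself returns `(tensor_parallel_rank, tensor_parallel_degree)`, falling back to `(0, 1)` in a single-process run and, when launched distributed without Fleet initialized, initializing Fleet with `mp_degree` equal to the world size.

```python
# check_dist_env.py -- illustrative sketch, not part of the patch
import paddle
from paddlenlp.trl.llm_utils import init_dist_env

# Works both in a plain single-process run and under a distributed launch, e.g.
#   python -m paddle.distributed.launch --gpus "0,1,2,3" check_dist_env.py
tensor_parallel_rank, tensor_parallel_degree = init_dist_env()

# Single-process run: prints rank=0, degree=1.
# 4-GPU launch: each rank prints its model-parallel rank and degree=4, since the
# fallback initialization sets mp_degree to the world size.
print(
    f"rank={tensor_parallel_rank}, degree={tensor_parallel_degree}, "
    f"world_size={paddle.distributed.get_world_size()}"
)
```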
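
The `Jsonfile::` datasets referenced in `tests/fixtures/llm/ppo.yaml` follow the record layout of the fixtures above, where each line is a JSON object with string `src` and `tgt` fields. Below is a small sanity-check sketch for such a file; the field names and fixture path come from this patch, while the helper itself is an assumption rather than an API shipped with it.

```python
# validate_ppo_jsonl.py -- illustrative sanity check for the fixture schema
import json
from pathlib import Path


def validate_ppo_jsonl(path: str) -> None:
    """Check that every record has non-empty string `src` and `tgt` fields."""
    for lineno, line in enumerate(
        Path(path).read_text(encoding="utf-8").splitlines(), start=1
    ):
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)
        for field in ("src", "tgt"):
            value = record.get(field)
            if not isinstance(value, str) or not value:
                raise ValueError(f"{path}:{lineno}: missing or empty '{field}' field")


if __name__ == "__main__":
    validate_ppo_jsonl("./tests/fixtures/llm/ppo_data/train.jsonl")
```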