Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][Feature] DPO #434

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open

[WIP][Feature] DPO #434

wants to merge 19 commits into from

Conversation

amulil
Copy link
Contributor

@amulil amulil commented Feb 25, 2024

@pppppM 佬,按你说的,初步想法是在 dataset 目录下实现 DPODataset,在 model 目录下实现 DPO,其他 hook 暂时和 sft 一致的,不用修改,但是有一个疑问,DPO 里有 model 和 ref_model 两个 model,deepspeed 相关的部分用修改嘛?

self.use_varlen_attn = use_varlen_attn

# TODO: Add ref model and ref model config
self.ref_llm = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ref_llm, 也支持 api model

@amulil
Copy link
Contributor Author

amulil commented Mar 6, 2024

更新了 dpo 的实现,使用 sft 的数据,可以跑通流程,但是存在两个问题:
NPROC_PER_NODE=8 xtuner train internlm2_chat_1_8b_qlora_dpo_ultra_e3 --deepspeed deepspeed_zero2

  1. loss 为 nan
截屏2024-03-06 15 32 06
  1. deepcopy 的方式不支持量化加载,只有 lora 和不量化加载,流程可以跑通

@xiaohangguo @pppppM 佬们,看下这两个问题是为啥呀

@pppppM
Copy link
Collaborator

pppppM commented Mar 7, 2024

ref_model 要不直接用 llm 的 config 重新 build ?

loss 为 nan 可能要 @xiaohangguo 帮忙看下公式细节

@xiaohangguo
Copy link
Contributor

ref_model 要不直接用 llm 的 config 重新 build ?

loss 为 nan 可能要 @xiaohangguo 帮忙看下公式细节

好,今晚我切到这个分支复现一下,debug看看

@amulil
Copy link
Contributor Author

amulil commented Mar 7, 2024

ref_model 要不直接用 llm 的 config 重新 build ?

loss 为 nan 可能要 @xiaohangguo 帮忙看下公式细节

可以 我试试改成 用 llm 的 config 重新 build

@xiaohangguo
Copy link
Contributor

写了个Mock 数据pytest来验证算法,目前测试结果,loss计算应该是没有问题。

import torch
import torch.nn.functional as F
from unittest import TestCase, main
# from utils import print


class MockModelOutput:
    def __init__(self, logits):
        self.logits = logits


class TestModel:
    def __init__(self, beta):
        self.beta = beta

    def llm(self, **kwargs):
        return MockModelOutput(logits=torch.randn(10, 5, 20))

    def ref_model(self, **kwargs):
        return MockModelOutput(logits=torch.randn(10, 5, 20))

    def compute_loss(self, data, data_samples=None):
        len_chosen = data["input_ids"].shape[0] // 2
        assert len_chosen != 0
        all_logits = self.llm(**data).logits
        all_ref_logits = self.ref_model(**data).logits

        print("all_logits:", all_logits)
        print("all_ref_logits:", all_ref_logits)

        labels = data["labels"]
        labels[labels == -100] = 0
        loss_mask = labels != 0

        print("labels:", labels)
        print("loss_mask:", loss_mask)

        per_token_logps = torch.gather(
            all_logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        per_ref_token_logps = torch.gather(
            all_ref_logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        print("per_token_logps:", per_token_logps)
        print("per_ref_token_logps:", per_ref_token_logps)

        epsilon = 0
        all_logps = (per_token_logps * loss_mask).sum(-1) / \
            (loss_mask.sum(-1) + epsilon)
        all_ref_logps = (per_ref_token_logps * loss_mask).sum(-1) / \
            (loss_mask.sum(-1) + epsilon)
        print("loss_mask.sum(-1)", loss_mask.sum(-1))
        print("all_logps:", all_logps)
        print("all_ref_logps:", all_ref_logps)

        policy_chosen_logps = all_logps[:len_chosen]
        policy_rejected_logps = all_logps[len_chosen:]
        reference_chosen_logps = all_ref_logps[:len_chosen]
        reference_rejected_logps = all_ref_logps[len_chosen:]

        print("policy_chosen_logps:", policy_chosen_logps)
        print("policy_rejected_logps:", policy_rejected_logps)
        print("reference_chosen_logps:", reference_chosen_logps)
        print("reference_rejected_logps:", reference_rejected_logps)

        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps

        print("pi_logratios:", pi_logratios)
        print("ref_logratios:", ref_logratios)

        logits = pi_logratios - ref_logratios
        loss = -F.logsigmoid(self.beta * logits)

        print("logits:", logits)
        print("loss:", loss)

        chosen_rewards = self.beta * \
            (policy_chosen_logps - reference_chosen_logps)
        rejected_rewards = self.beta * \
            (policy_rejected_logps - reference_rejected_logps)

        print("chosen_rewards:", chosen_rewards)
        print("rejected_rewards:", rejected_rewards)

        loss_dict = {
            'loss': loss,
            'chosen_rewards': chosen_rewards,
            'rejected_rewards': rejected_rewards
        }
        return loss_dict


class LossComputationTest(TestCase):
    def test_compute_loss(self):
        model = TestModel(beta=0.1)
        data = {
            "input_ids": torch.randint(0, 20, (10, 5)),
            "labels": torch.randint(-100, 20, (10, 5))
        }

        # 确保所有标签值非负
        data["labels"] = torch.where(
            data["labels"] < 0, torch.tensor(0), data["labels"])

        loss_dict = model.compute_loss(data)
        loss, chosen_rewards, rejected_rewards = loss_dict['loss'], loss_dict[
            'chosen_rewards'], loss_dict['rejected_rewards']
        # print("Loss values:", loss)
        # print("chosen_rewards values:", chosen_rewards)
        # print("rejected_rewards values:", rejected_rewards)
        self.assertTrue(torch.all(loss >= 0))
        # self.assertTrue(torch.all(chosen_rewards <= 0))
        # self.assertTrue(torch.all(rejected_rewards >= 0))


if __name__ == "__main__":
    main()

下一步需要适配Class DPOdataset ,一条batch中格式保持(prompt chosen reject)

@xiaohangguo
Copy link
Contributor

xiaohangguo commented Mar 9, 2024

把item_fn 搞了一下,但感觉还是有问题,单个conversation,应该是可以的,不知道能否和原来的encode_fn 结合,对于整个数据集处理好,正常走packer。
@LZHgrla ZH哥,麻烦帮忙看下看行不行

@amulil
Copy link
Contributor Author

amulil commented Apr 2, 2024

NPROC_PER_NODE=8 xtuner train internlm2_chat_1_8b_full_dpo_ultra_e3 --deepspeed deepspeed_zero2
目前 full dpo loss 正常了:
截屏2024-04-02 23 17 57
接下来按照 trl 文档里的说明添加 qlora dpo:
https://moon-ci-docs.huggingface.co/docs/trl/pr_1193/en/dpo_trainer#downsides-to-merging-qlora-before-dpo-approach-2

@xiaohangguo
Copy link
Contributor

NPROC_PER_NODE=8 xtuner train internlm2_chat_1_8b_full_dpo_ultra_e3 --deepspeed deepspeed_zero2 目前 full dpo loss 正常了: 截屏2024-04-02 23 17 57 接下来按照 trl 文档里的说明添加 qlora dpo: https://moon-ci-docs.huggingface.co/docs/trl/pr_1193/en/dpo_trainer#downsides-to-merging-qlora-before-dpo-approach-2

太强了!

@KooSung
Copy link
Contributor

KooSung commented Apr 7, 2024

@amulil 请问现在有DPO训练的模型指标对比吗?我想参考这个实现RLHF-V
code: https://github.com/RLHF-V/RLHF-V, https://github.com/thunlp/Muffin

@amulil
Copy link
Contributor Author

amulil commented Apr 7, 2024

@amulil 请问现在有DPO训练的模型指标对比吗?我想参考这个实现RLHF-V code: https://github.com/RLHF-V/RLHF-V, https://github.com/thunlp/Muffin

@KooSung 目前暂时没有,后面会参考 https://github.com/huggingface/alignment-handbook/blob/main/recipes/zephyr-7b-beta/README.md 提到的 zephyr-7b-dpo-qlora 模型来看指标对比。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants