From 6f40a17ef1507dea1b73df4645bfbb04b582a54f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sun, 2 Jun 2024 18:41:01 +0200 Subject: [PATCH] Fix typo in DPOTrainer's warnings --- trl/trainer/dpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index fb79ef0e39..0ff1dfce04 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -167,7 +167,7 @@ def __init__( ): if model_init_kwargs is not None: warnings.warn( - "You passed `model_init_kwargs` to the SFTTrainer, the value you passed will override the one in the `SFTConfig`." + "You passed `model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) args.model_init_kwargs = model_init_kwargs @@ -187,7 +187,7 @@ def __init__( if ref_model_init_kwargs is not None: warnings.warn( - "You passed `ref_model_kwargs` to the SFTTrainer, the value you passed will override the one in the `SFTConfig`." + "You passed `ref_model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) args.ref_model_init_kwargs = ref_model_init_kwargs