From b8c9d9c7bc999d06f2a48cba5b688de6d8e8beab Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 15 Nov 2024 15:49:43 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=96=EF=B8=8F=20Add=20`use=5Fsoft=5Fjudge`?= =?UTF-8?q?=20option=20to=20`WinRateCallback`=20(#2347)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add `use_soft_judge` option to WinRateCallback * formatting * Update trl/trainer/callbacks.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * renamed soft_win_rate to avg_win_prob * Update trl/trainer/callbacks.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * fix tests * keep orignal * formatting * Update tests/test_callbacks.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update trl/trainer/callbacks.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update tests/test_callbacks.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update tests/test_callbacks.py * fix test --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- tests/test_callbacks.py | 49 ++++++++++++++++++++++++++++++++++++++-- trl/trainer/callbacks.py | 32 ++++++++++++++++++++++---- 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 6b9c89c1e9..874d8b22f0 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -30,11 +30,13 @@ class HalfPairwiseJudge(BasePairwiseJudge): - """Naive pairwise judge that always returns [1, 0]""" + """Naive pairwise judge that always returns [1, 0] for two prompts""" - def judge(self, prompts, completions, shuffle_order=True): + def judge(self, prompts, completions, shuffle_order=True, return_scores=False): # just check that the batch size is 2 assert len(prompts) == 2 + if return_scores: + return [0.3, 0.9] return [1, 0] @@ -132,6 +134,49 @@ def test_without_ref_model(self): winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] self.assertListEqual(winrate_history, self.expected_winrates) + def test_soft_judge(self): + """Test that the soft judge functionality works correctly""" + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="none", + ) + trainer = TrainerWithRefModel( + model=self.model, + ref_model=self.ref_model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + win_rate_callback = WinRateCallback( + judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True + ) + trainer.add_callback(win_rate_callback) + trainer.train() + + # Expected values based on judge returning [0.3, 0.9] for each pair + expected_soft_winrates = [ + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8}, + {"eval_avg_win_prob": 0.4, 
"eval_win_rate": 0.5, "epoch": 2.5, "step": 10}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12}, + ] + + winrate_history = [ + {k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]} + for h in trainer.state.log_history + if "eval_avg_win_prob" in h + ] + self.assertListEqual(winrate_history, expected_soft_winrates) + @require_peft def test_lora(self): with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 24e33a7566..5ec62afd49 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -230,6 +230,9 @@ class WinRateCallback(TrainerCallback): in the evaluation dataset. shuffle_order (`bool`, *optional*, defaults to `True`): Whether to shuffle the order of the completions before judging. + use_soft_judge (`bool`, *optional*, defaults to `False`): + Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the + second. """ def __init__( @@ -239,12 +242,14 @@ def __init__( generation_config: Optional[GenerationConfig] = None, num_prompts: Optional[int] = None, shuffle_order: bool = True, + use_soft_judge: bool = False, ): self.judge = judge self.trainer = trainer self.shuffle_order = shuffle_order self.generation_config = generation_config self.ref_completions = [] + self.use_soft_judge = use_soft_judge if self.trainer.eval_dataset is None: raise ValueError("Trainer must have an evaluation dataset to use the WinRateCallback.") @@ -281,7 +286,12 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: ) # Compute initial win rate as a reference point completions = list(zip(self.ref_completions, self.ref_completions)) - winner_indices = self.judge.judge(prompts, completions, self.shuffle_order) + if self.use_soft_judge: + ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True) + winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs] + ref_win_probs = gather_object(ref_win_probs) + else: + winner_indices = self.judge.judge(prompts, completions, self.shuffle_order) prompts = gather_object(prompts) completions = gather_object(completions) winner_indices = gather_object(winner_indices) @@ -289,7 +299,11 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: # Logging if self.trainer.accelerator.is_main_process: win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) - self.trainer.log({"eval_win_rate": win_rate}) + if self.use_soft_judge: + avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs) + self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate}) + else: + self.trainer.log({"eval_win_rate": win_rate}) if "wandb" in args.report_to: import wandb @@ -323,7 +337,13 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra ) completions = list(zip(self.ref_completions, completions)) - winner_indices = self.judge.judge(prompts, completions, self.shuffle_order) + + if self.use_soft_judge: + ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True) + winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs] + ref_win_probs = gather_object(ref_win_probs) + else: + winner_indices = self.judge.judge(prompts, completions, self.shuffle_order) prompts = gather_object(prompts) completions = gather_object(completions) winner_indices = gather_object(winner_indices) @@ -331,7 +351,11 @@ def 
on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra # Logging if self.trainer.accelerator.is_main_process: win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) - self.trainer.log({"eval_win_rate": win_rate}) + if self.use_soft_judge: + avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs) + self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate}) + else: + self.trainer.log({"eval_win_rate": win_rate}) if "wandb" in args.report_to: import wandb
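
A minimal usage sketch of the new flag (not part of the patch): the judge below is a hypothetical stand-in that accepts `return_scores=True` in the same way as the test's `HalfPairwiseJudge`, and `attach_soft_win_rate_logging` is an illustrative helper rather than a TRL API. It assumes a TRL-compatible trainer that already has an evaluation dataset (the callback raises otherwise); import paths mirror the ones used in the tests.

from transformers import GenerationConfig

from trl import BasePairwiseJudge, WinRateCallback


class LengthBasedSoftJudge(BasePairwiseJudge):
    """Hypothetical judge: with return_scores=True it returns, per prompt, the
    probability that the first completion wins; otherwise hard 0/1 winner indices."""

    def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
        # Toy heuristic (the longer completion "wins"); a real judge would query
        # a reward model or an LLM. shuffle_order is ignored in this sketch.
        scores = []
        for first, second in completions:
            total = len(first) + len(second)
            scores.append(len(first) / total if total else 0.5)
        if return_scores:
            return scores
        # Same thresholding that WinRateCallback applies to soft scores.
        return [0 if score > 0.5 else 1 for score in scores]


def attach_soft_win_rate_logging(trainer):
    """Illustrative helper: log eval_avg_win_prob alongside eval_win_rate."""
    callback = WinRateCallback(
        judge=LengthBasedSoftJudge(),
        trainer=trainer,
        generation_config=GenerationConfig(max_new_tokens=32),
        use_soft_judge=True,
    )
    trainer.add_callback(callback)
    return callback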