visualize rm prediction #1636

Merged · 5 commits · May 10, 2024
20 changes: 12 additions & 8 deletions examples/scripts/reward_modeling.py
@@ -15,16 +15,17 @@
 python examples/scripts/reward_modeling.py \
     --model_name_or_path=facebook/opt-350m \
     --output_dir="reward_modeling_anthropic_hh" \
-    --per_device_train_batch_size=64 \
+    --per_device_train_batch_size=16 \
     --num_train_epochs=1 \
-    --gradient_accumulation_steps=16 \
+    --gradient_accumulation_steps=2 \
     --gradient_checkpointing=True \
     --learning_rate=1.41e-5 \
     --report_to="wandb" \
     --remove_unused_columns=False \
     --optim="adamw_torch" \
     --logging_steps=10 \
     --evaluation_strategy="steps" \
+    --eval_steps=500 \
     --max_length=512 \
 """
 import warnings
@@ -42,8 +43,8 @@

 if __name__ == "__main__":
     parser = HfArgumentParser((RewardConfig, ModelConfig))
-    reward_config, model_config = parser.parse_args_into_dataclasses()
-    reward_config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    config, model_config = parser.parse_args_into_dataclasses()
+    config.gradient_checkpointing_kwargs = dict(use_reentrant=False)

     ################
     # Model & Tokenizer
@@ -103,8 +104,7 @@ def preprocess_function(examples):
         num_proc=4,
     )
     raw_datasets = raw_datasets.filter(
-        lambda x: len(x["input_ids_chosen"]) <= reward_config.max_length
-        and len(x["input_ids_rejected"]) <= reward_config.max_length
+        lambda x: len(x["input_ids_chosen"]) <= config.max_length and len(x["input_ids_rejected"]) <= config.max_length
     )
     train_dataset = raw_datasets["train"]
     eval_dataset = raw_datasets["test"]
@@ -115,10 +115,14 @@ def preprocess_function(examples):
     trainer = RewardTrainer(
         model=model,
         tokenizer=tokenizer,
-        args=reward_config,
+        args=config,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         peft_config=get_peft_config(model_config),
     )
     trainer.train()
-    trainer.save_model(reward_config.output_dir)
+    trainer.save_model(config.output_dir)
+    trainer.push_to_hub()
+    metrics = trainer.evaluate()
+    trainer.log_metrics("eval", metrics)
+    print(metrics)
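
Note that the recommended settings in the docstring above also cut the effective batch size sharply. A quick check of the arithmetic (per-device figures; an illustrative aside, not part of the diff):

# Effective batch size per optimizer step = micro-batch size x accumulation steps
# (per device; multiply by the number of GPUs for the global figure).
old_effective = 64 * 16  # previous docstring: 1024 samples per update
new_effective = 16 * 2   # updated docstring: 32 samples per update
print(old_effective, new_effective)  # -> 1024 32
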
41 changes: 40 additions & 1 deletion trl/trainer/reward_trainer.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 import inspect
 import warnings
+from collections import defaultdict
 from dataclasses import FrozenInstanceError, replace
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+import pandas as pd
 import torch
 import torch.nn as nn
+from accelerate.utils import gather_object
 from datasets import Dataset
 from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
 from transformers.trainer_callback import TrainerCallback
@@ -26,7 +29,7 @@

 from ..import_utils import is_peft_available
 from .reward_config import RewardConfig
-from .utils import RewardDataCollatorWithPadding, compute_accuracy
+from .utils import RewardDataCollatorWithPadding, compute_accuracy, print_rich_table


 if is_peft_available():
@@ -279,3 +282,39 @@ def prediction_step(
         labels = self._prepare_inputs(labels)

         return loss, logits, labels
+
+    def evaluate(self, *args, **kwargs):
+        num_print_samples = kwargs.pop("num_print_samples", 4)
+        self.visualize_samples(num_print_samples)
+        return super().evaluate(*args, **kwargs)
+
+    def visualize_samples(self, num_print_samples: int):
+        """
+        Visualize the reward model's logit predictions.
+
+        Args:
+            num_print_samples (`int`, defaults to `4`):
+                The number of samples to print. Set to `-1` to print all samples.
+        """
+        eval_dataloader = self.get_eval_dataloader()
+        table = defaultdict(list)
+        for _, inputs in enumerate(eval_dataloader):
+            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
+            chosen_text = self.tokenizer.batch_decode(inputs["input_ids_chosen"], skip_special_tokens=True)
+            rejected_text = self.tokenizer.batch_decode(inputs["input_ids_rejected"], skip_special_tokens=True)
+            # Gather the decoded pairs and rounded logits from every process.
+            table["chosen_text"].extend(gather_object(chosen_text))
+            table["rejected_text"].extend(gather_object(rejected_text))
+            table["logits"].extend(
+                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
+            )
+            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
+                break
+        df = pd.DataFrame(table)
+        # Print (and optionally log) only once, on the main process.
+        if self.accelerator.process_index == 0:
+            print_rich_table(df if num_print_samples < 0 else df[:num_print_samples])
+            if "wandb" in self.args.report_to:
+                import wandb
+
+                if wandb.run is not None:
+                    wandb.log({"completions": wandb.Table(dataframe=df)})
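
For reference, a short usage sketch of the new hook (it assumes a constructed RewardTrainer instance named trainer; not part of the diff):

# `evaluate` pops `num_print_samples` before delegating to `Trainer.evaluate`,
# so callers control how many rows are rendered per evaluation run.
metrics = trainer.evaluate()                      # default: prints 4 sample rows
metrics = trainer.evaluate(num_print_samples=8)   # print 8 rows instead
metrics = trainer.evaluate(num_print_samples=-1)  # print every evaluation sample
trainer.visualize_samples(num_print_samples=2)    # or call the visualizer directly
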
18 changes: 17 additions & 1 deletion trl/trainer/utils.py
@@ -18,15 +18,21 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
+import pandas as pd
 import torch
 from accelerate import PartialState
 from rich.console import Console, Group
 from rich.live import Live
 from rich.panel import Panel
 from rich.progress import Progress
+from rich.table import Table
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import IterableDataset
-from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
+from transformers import (
+    BitsAndBytesConfig,
+    DataCollatorForLanguageModeling,
+    PreTrainedTokenizerBase,
+)
 from transformers.trainer import TrainerCallback
 from transformers.trainer_utils import has_length
@@ -815,3 +821,13 @@ def on_train_end(self, args, state, control, **kwargs):
         self.rich_console = None
         self.training_status = None
         self.current_step = None
+
+
+def print_rich_table(df: pd.DataFrame) -> None:
+    console = Console()
+    table = Table(show_lines=True)
+    for column in df.columns:
+        table.add_column(column)
+    for _, row in df.iterrows():
+        table.add_row(*row.astype(str).tolist())
+    console.print(table)
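
Since print_rich_table is a generic helper, it can be exercised on its own. A minimal sketch with toy rows mirroring the columns that visualize_samples produces (the values are illustrative, not from the diff):

import pandas as pd

from trl.trainer.utils import print_rich_table

# Toy frame with the same columns visualize_samples builds.
df = pd.DataFrame(
    {
        "chosen_text": ["The capital of France is Paris."],
        "rejected_text": ["The capital of France is Rome."],
        "logits": [[0.7312]],
    }
)
print_rich_table(df)  # renders one rich-formatted row per DataFrame row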