Eurus prime data #183

Merged · 18 commits · Feb 5, 2025
Changes from 13 commits

3 changes: 1 addition & 2 deletions conf/rl_eurus.yaml
@@ -10,9 +10,8 @@ llm:
   # CoT are much longer, but the model only has 4096 tokens context
   max_tokens: 3072
 
-  # EURUS already apply this template: {task}\n\nPresent the answer in LaTex format: \\boxed{Your answer}
   task_template: |-
-    {task}
+    {task}\n\nPresent the answer in LaTex format: \\boxed{{Your answer}}
   # https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/eval/system_prompt.md?plain=1
   # but note that sometimes they do not include the newline at the beginning
   # https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/data_preprocessing/sft_prompt.py#L1
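Note on the doubled braces in the new template: if the config is rendered with Python's `str.format` (an assumption here; the PR only shows the YAML), `{task}` is a substitution slot while `{{Your answer}}` escapes to the literal `{Your answer}` that the EURUS prompt format expects. A minimal sketch:

```python
# Sketch only: assumes task_template is rendered via str.format.
template = "{task}\n\nPresent the answer in LaTex format: \\boxed{{Your answer}}"
prompt = template.format(task="What is $1 + 1$?")
print(prompt)
# What is $1 + 1$?
#
# Present the answer in LaTex format: \boxed{Your answer}
```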
13 changes: 8 additions & 5 deletions examples/rl_gsm8k/deepseek_math_eval/process_utils.py
@@ -1,18 +1,19 @@
 # https://github.com/deepseek-ai/DeepSeek-Math/blob/b8b0f8ce093d80bf8e9a641e44142f06d092c305/evaluation/data_processing/process_utils.py
 import regex
 
-from examples.rl_gsm8k.deepseek_math_eval.answer_extraction import extract_math_answer, strip_string
+from examples.rl_gsm8k.deepseek_math_eval.answer_extraction import (
+    extract_math_answer, strip_string)
 from examples.rl_gsm8k.deepseek_math_eval.eval_utils import parse_ground_truth
 
 
 def process_eurus_test(item):
     if "ability" not in item:
         # math 500 test set
-        answer = [item["expected_answer"]]
+        answer = [item["answer"]]
         return {
             "dataset": "math500",
             # Same prompt as https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/README.md?plain=1#L93
-            "task": item["problem"] + "\n\nPresent the answer in LaTex format: \\boxed{Your answer}",
+            "task": item["problem"],
             "answer": answer
         }
     else:
@@ -25,9 +26,11 @@ def process_eurus_test(item):
         answer = answer.replace("\n", "")
         answer = "\\boxed{" + answer + "}"
         answer = extract_math_answer(item["prompt"][1]["content"], answer, task="cot")
+        task = item["prompt"][1]["content"]
+        task = task.replace("\n\nPresent the answer in LaTex format: \\boxed{Your answer}", "")
         return {
             "dataset": item["data_source"],
-            "task": item["prompt"][1]["content"],
+            "task": task,
             "answer": answer
         }

@@ -47,7 +50,7 @@ def process_math_test(item):
         answer = extract_math_answer(question, item["solution"], task="cot")
     except Exception:
         return
-    sample = {"dataset": "math-cot", "level": item["level"], "type": item["type"], "task": question, "answer": answer}
+    sample = {"dataset": "math-cot", "level": item["level"], "type": item.get("type", ""), "task": question, "answer": answer}
     return sample


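Taken together with the template change above, the intent appears to be that the boxed-answer instruction lives in exactly one place. A sketch of the updated `process_eurus_test` on a MATH-500-style record (hypothetical values; real records also carry fields such as `solution` and `level`, which this branch ignores):

```python
# Hypothetical record in the HuggingFaceH4/MATH-500 style: no "ability" key,
# so the first branch of process_eurus_test is taken.
item = {"problem": "What is $1 + 1$?", "answer": "2"}

sample = process_eurus_test(item)
# {"dataset": "math500", "task": "What is $1 + 1$?", "answer": ["2"]}
# The task no longer ends with "Present the answer in LaTex format: ...";
# that suffix is now appended once, by task_template in conf/rl_eurus.yaml.
```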
25 changes: 15 additions & 10 deletions examples/rl_gsm8k/orchestrate_rl.py
@@ -36,6 +36,9 @@
 
 
 def load_datasets(cfg: DictConfig) -> Tuple[list, list]:
+    """
+    Deprecated, use load_datasets from rl/load_datasets.py
+    """
     match cfg.dataset_name:
         case "math":
             train_dataset_long_name = "hendrycks/competition_math"
@@ -50,7 +53,7 @@ def load_datasets(cfg: DictConfig) -> Tuple[list, list]:
             builder_config = "main"
         case "eurus":
             train_dataset_long_name = "PRIME-RL/Eurus-2-RL-Data"
-            test_dataset_long_name = "alexpiche/math_test_cleaned"
+            test_dataset_long_name = "HuggingFaceH4/MATH-500"
             process_fn = process_eurus_test
             test_builder_config = None
             builder_config = "default"
@@ -117,7 +120,7 @@ def convert_problems_to_tapes(problems: list, cfg: DictConfig) -> list[RLMathTape]:
 
 
 def extract_tape_training_samples(
-    new_tape: RLMathTape, agent: CoTMathAgent, cfg: DictConfig
+    new_tape: RLMathTape, agent: CoTMathAgent, cfg: DictConfig, dataset_name: str = ""
 ) -> Tuple[List[TrainingText], Dict[str, int]]:
     """
     Process a single tape to extract training samples and statistics.
@@ -127,8 +130,7 @@ def extract_tape_training_samples(
         agent: CoTMathAgent
         tapes_dir: Directory to save processed tapes
         cfg: Configuration
-        llm_calls: List of LLM calls
-        strict: check that every token matches between the vLLM and the HF tokenizer otherwise just compare their lengths
+        dataset_name: Name of the dataset (optional). If not provided, defaults to cfg.dataset_name
 
     Returns:
         Tuple containing:
@@ -137,18 +139,21 @@
     """
     tape_prompt_tokens = 0
     tape_output_tokens = 0
-    match cfg.dataset_name:
-        case "math":
+    dataset_name = dataset_name or cfg.dataset_name
+    assert dataset_name, "Dataset name must be provided"
+
+    match dataset_name:
+        case name if name.startswith("math") or name.startswith("eurus"):
             eval_fn = eval_math
             extract_fn = extract_math_answer
-        case "gsm8k":
+        case "gsm8k_test":
             eval_fn = eval_last_single_answer
             extract_fn = extract_last_single_answer
-        case "eurus":
+        case _:
+            # Default to math dataset
+            logger.debug(f"MATH dataset will be used for evaluation and extracting answer: {dataset_name}")
             eval_fn = eval_math
             extract_fn = extract_math_answer
-        case _:
-            raise ValueError(f"Unknown dataset: {cfg.dataset_name}")
 
     if "\\boxed" not in new_tape.steps[-1].reasoning:
         # LLM did not respect the formatting
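The first case above uses a match guard (Python 3.10+): `name` is a capture pattern that binds the match subject, and the arm is taken only when the guard expression is true; the trailing `case _` now falls back to the MATH evaluators instead of raising. A self-contained sketch of the dispatch, with `pick_eval_fn` a hypothetical stand-in and the real eval/extract functions stubbed as strings:

```python
def pick_eval_fn(dataset_name: str) -> str:
    # Mirrors the dispatch in extract_tape_training_samples.
    match dataset_name:
        case name if name.startswith("math") or name.startswith("eurus"):
            return "eval_math"
        case "gsm8k_test":
            return "eval_last_single_answer"
        case _:
            return "eval_math"  # default instead of ValueError

assert pick_eval_fn("math500") == "eval_math"
assert pick_eval_fn("eurus") == "eval_math"
assert pick_eval_fn("gsm8k_test") == "eval_last_single_answer"
assert pick_eval_fn("anything_else") == "eval_math"  # previously raised
```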
12 changes: 10 additions & 2 deletions tapeagents/finetune/rl/__init__.py
@@ -2,7 +2,7 @@
 import os
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Callable
+from typing import Literal
 
 import numpy as np
 import pandas as pd
@@ -60,6 +60,10 @@ class RLConfig(StepConfig):
         default=10,
         metadata={"help": "Clamp the log ratio ref new value"},
     )
+    aggregate_loss: Literal["mean", "sum"] = field(
+        default="mean",
+        metadata={"help": "How to aggregate the loss within a batch (when batch size is 1, there is no difference)"},
+    )
 
 
 def make_rl_data_callback(args, current_dir, rl_config, model):
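The practical difference between the two `aggregate_loss` settings: `sum` lets long sequences contribute proportionally more gradient, while `mean` normalizes each sequence by its own unmasked token count before averaging over the batch. A numeric sketch (standalone; the expressions inline what `masked_sum`/`masked_mean` in `tapeagents/finetune/rl/utils.py` compute):

```python
import torch

# Two sequences with the same per-token loss: 4 vs. 2 unmasked tokens.
loss = torch.tensor([[0.5, 0.5, 0.5, 0.5],
                     [0.5, 0.5, 0.0, 0.0]])
masks = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                      [1.0, 1.0, 0.0, 0.0]])

# aggregate_loss == "sum": one total over all unmasked tokens.
agg_sum = (loss * masks).sum()                              # tensor(3.)
# aggregate_loss == "mean": per-sequence token mean, then batch mean.
agg_mean = ((loss * masks).sum(-1) / masks.sum(-1)).mean()  # tensor(0.5000)
```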
@@ -147,7 +151,10 @@ def rl_step(model: PreTrainedModel, batch: dict, config: RLConfig) -> tuple[torc
         raise ValueError(f"Unknown algorithm {config.algo}")
 
     num_nans = torch.isnan(loss).sum()
-    loss = -masked_sum(loss, masks_)
+    if config.aggregate_loss == "mean":
+        loss = -masked_mean(loss, masks_, axis=-1).mean()
+    else:
+        loss = -masked_sum(loss, masks_)
     assert torch.isfinite(loss).all(), f"Loss is not finite: {loss}"
 
     # normalize the loss by the micro batch size
@@ -160,6 +167,7 @@ def rl_step(model: PreTrainedModel, batch: dict, config: RLConfig) -> tuple[torc
         "reward": masked_mean(rewards, masks_).item(),
         "max_reward": rewards[masks_].max().item(),
         "min_reward": rewards[masks_].min().item(),
+        "mean_entropy": -masked_sum(old_logprobs * torch.log(old_logprobs + 1e-9), masks_, axis=-1).mean().item(),
         "mean_old_logprobs": masked_mean(old_logprobs, masks_).item(),
         "mean_new_logprobs": masked_mean(new_log_probs, masks_).item(),
         "mean_new_logprobs_positive_log_p_weights": masked_mean(
4 changes: 2 additions & 2 deletions tapeagents/finetune/rl/utils.py
@@ -30,15 +30,15 @@ def get_avg_rl_stats(rl_stats):
     return avg_rl_stats
 
 
-def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
+def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
     """Compute sum of tensor with a masked values."""
     if axis is not None:
         return (values * mask).nan_to_num(0).sum(axis=axis)  # type: ignore
     else:
         return (values * mask).nan_to_num(0).sum()
 
 
-def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
+def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
     """Compute mean of tensor with a masked values."""
     if axis is not None:
         return (values * mask).nan_to_num(0).sum(axis=axis) / mask.sum(axis=axis)  # type: ignore
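The annotation change is a genuine fix rather than cosmetics: `axis` is forwarded to `torch.Tensor.sum`, which expects an integer dimension such as `-1`, so `Optional[bool]` was simply the wrong type. A quick usage sketch, assuming the two helpers above are in scope:

```python
import torch

values = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0],
                     [1.0, 0.0, 0.0]])

masked_sum(values, mask)            # tensor(7.)  -> 1 + 2 + 4
masked_sum(values, mask, axis=-1)   # tensor([3., 4.])        per-row sums
masked_mean(values, mask, axis=-1)  # tensor([1.5000, 4.0000]) per-row means
```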