Commit

Merge pull request #181 from ServiceNow/now-reasoner-branch-reward
Now reasoner branch reward
ehsk authored Jan 30, 2025
2 parents d085b9c + 2ee467c commit 38d1539
Showing 3 changed files with 21 additions and 9 deletions.
5 changes: 5 additions & 0 deletions conf/rl_gsm8k.yaml
@@ -41,6 +41,11 @@ task_template: |-
 overflow_reward: 0
 max_prompt_length: 1024
 
+rewards:
+  unparsable: 0
+  wrong_answer: 0
+  correct_answer: 1
+
 vllm_config:
   vllm_kwargs:
     --download-dir: /mnt/llmd/base_models/
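
The new rewards block moves the branch rewards out of the code and into the config. A minimal sketch of how such a block resolves at runtime, assuming only OmegaConf (the DictConfig type already used by load_datasets in orchestrate_rl.py); the project's own config loading is not shown here:

    from omegaconf import OmegaConf

    # Stand-in for the rewards section added above; values mirror the defaults in the diff.
    cfg = OmegaConf.create({"rewards": {"unparsable": 0, "wrong_answer": 0, "correct_answer": 1}})
    assert cfg.rewards.correct_answer == 1  # read as cfg.rewards.<name>, as orchestrate_rl.py does
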
3 changes: 3 additions & 0 deletions examples/rl_gsm8k/deepseek_math_eval/process_utils.py
@@ -40,6 +40,9 @@ def process_gsm8k_test(item):

 def process_math_test(item):
     question = item["problem"]
+    if "subject" in item and "type" not in item:
+        item["type"] = item["subject"]
+
     try:
         answer = extract_math_answer(question, item["solution"], task="cot")
     except Exception:
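
A likely motivation (an assumption based on the two datasets' published schemas): MATH-500 records label the problem category under a "subject" key, whereas hendrycks/competition_math uses "type", so the added guard lets the new test set flow through code that expects item["type"]. A small self-contained illustration:

    # Hypothetical MATH-500-style record; only the keys matter here.
    item = {"problem": "What is 2 + 2?", "solution": "\\boxed{4}", "subject": "Prealgebra"}
    if "subject" in item and "type" not in item:
        item["type"] = item["subject"]
    assert item["type"] == "Prealgebra"  # downstream code can keep reading item["type"]
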
22 changes: 13 additions & 9 deletions examples/rl_gsm8k/orchestrate_rl.py
@@ -21,7 +21,7 @@

 import wandb
 from tapeagents.agent import Agent
-from tapeagents.core import LLMCall, LLMOutputParsingFailureAction, StepMetadata, TrainingText
+from tapeagents.core import LLMCall, StepMetadata, TrainingText
 from tapeagents.finetune.data import MASKED_TOKEN_ID
 from tapeagents.finetune.logging_ import flatten_dict_config, init_wandb
 from tapeagents.llms import TrainableLLM
@@ -38,8 +38,10 @@
 def load_datasets(cfg: DictConfig) -> Tuple[list, list]:
     match cfg.dataset_name:
         case "math":
-            train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math"
+            train_dataset_long_name = "hendrycks/competition_math"
+            test_dataset_long_name = "HuggingFaceH4/MATH-500"
             process_fn = process_math_test
+            test_builder_config = "default"
             builder_config = "main"
         case "gsm8k":
             train_dataset_long_name = test_dataset_long_name = "openai/gsm8k"
@@ -53,8 +55,9 @@ def load_datasets(cfg: DictConfig) -> Tuple[list, list]:
         case _:
             raise ValueError(f"Unknown dataset: {cfg.dataset_name}")
 
+    test_builder_config = test_builder_config or builder_config
     train_dataset = load_dataset(train_dataset_long_name, builder_config, split="train", trust_remote_code=True)
-    test_dataset = load_dataset(test_dataset_long_name, builder_config, split="test", trust_remote_code=True)
+    test_dataset = load_dataset(test_dataset_long_name, test_builder_config, split="test", trust_remote_code=True)
     train_samples = [
         process_fn(s) for s in tqdm(train_dataset, desc="Processing train samples") if process_fn(s) is not None
     ]
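
With this hunk and the previous one, the "math" case trains on hendrycks/competition_math but evaluates on HuggingFaceH4/MATH-500, each with its own builder config, while the added fallback lets cases that set only builder_config reuse it for the test split. A sketch of the resulting calls for the "math" case (illustrative; it simply restates the load_dataset calls above with the values from this diff):

    from datasets import load_dataset

    builder_config = "main"          # set in the "math" case above
    test_builder_config = "default"  # new: the MATH-500 test set uses its own config
    test_builder_config = test_builder_config or builder_config  # fallback when only builder_config is set

    train_dataset = load_dataset("hendrycks/competition_math", builder_config, split="train", trust_remote_code=True)
    test_dataset = load_dataset("HuggingFaceH4/MATH-500", test_builder_config, split="test", trust_remote_code=True)
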
@@ -145,10 +148,11 @@ def extract_tape_training_samples(
         case _:
             raise ValueError(f"Unknown dataset: {cfg.dataset_name}")
 
-    if any([isinstance(step, LLMOutputParsingFailureAction) for step in new_tape.steps]):
-        # LLM produced a step that was unparsable. Negative reward.
-        no_error, reward, success = 0, -1, 0
+    if "\\boxed" not in new_tape.steps[-1].reasoning:
+        # LLM did not respect the formatting
+        no_error, success, reward = 0, 0, cfg.rewards.unparsable
     else:
+        # LLM did respect the formatting
         no_error = 1
         prediction = extract_fn(new_tape.steps[0].task, new_tape.steps[-1].reasoning, "cot") # type: ignore
         answer = new_tape.steps[0].metadata.other["value"]
@@ -159,10 +163,10 @@
             }
         ):
             # Correct answer
-            reward, success = 1, 1
+            reward, success = cfg.rewards.correct_answer, 1
         else:
             # Incorrect answer or no answer
-            reward, success = 0, 0
+            reward, success = cfg.rewards.wrong_answer, 0
 
     training_samples: list[TrainingText] = []
     # For each LLM interaction in the tape:
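
Taken together with the previous hunk, the parsing check now looks for a \boxed{...} answer in the final reasoning step instead of scanning for an LLMOutputParsingFailureAction, and all three outcomes read their reward from the new cfg.rewards block. A condensed, self-contained sketch of the branch (assign_reward is a hypothetical helper; the real answer-matching test is more involved than a string comparison):

    from types import SimpleNamespace

    def assign_reward(reasoning: str, prediction: str, answer: str, rewards) -> tuple[int, float, int]:
        """Return (no_error, reward, success), mirroring the branch in the diff."""
        if "\\boxed" not in reasoning:
            # model did not respect the \boxed{...} output format
            return 0, rewards.unparsable, 0
        if prediction == answer:  # placeholder for the real answer check
            return 1, rewards.correct_answer, 1
        return 1, rewards.wrong_answer, 0

    rewards = SimpleNamespace(unparsable=0, wrong_answer=0, correct_answer=1)
    assert assign_reward("... \\boxed{4}", "4", "4", rewards) == (1, 1, 1)
    assert assign_reward("no boxed answer", "", "4", rewards) == (0, 0, 0)
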
@@ -194,7 +198,7 @@ def extract_tape_training_samples(

         # check if the last produced token is the end of sequence token
         overflow = False if input_ids[-1] == agent.llm.tokenizer.eos_token_id else True
-        trace.reward = cfg.overflow_reward if overflow else reward
+        trace.reward = cfg.rewards.unparsable if overflow else reward
         overflows.append(overflow)
         trace.logprobs = [lp.logprob for lp in llm_call.logprobs if lp.generated]
         trace.group_id = new_tape.metadata.parent_id
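
The last hunk swaps cfg.overflow_reward for cfg.rewards.unparsable at this call site: a completion that runs past the generation budget (no EOS token at the end) now gets the same reward as an unparsable one. A self-contained illustration (final_reward is a hypothetical helper introduced only for this sketch):

    def final_reward(input_ids: list[int], eos_token_id: int, reward: float, unparsable_reward: float) -> float:
        # Overflow means the completion did not end with the EOS token.
        overflow = input_ids[-1] != eos_token_id
        return unparsable_reward if overflow else reward

    assert final_reward([5, 7, 2], eos_token_id=2, reward=1.0, unparsable_reward=0.0) == 1.0  # ended cleanly
    assert final_reward([5, 7, 9], eos_token_id=2, reward=1.0, unparsable_reward=0.0) == 0.0  # overflowed
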
