From 910b45993911ab1ddf35f804897ac29c2c169729 Mon Sep 17 00:00:00 2001 From: Alexandre Piche Date: Mon, 27 Jan 2025 19:21:06 +0000 Subject: [PATCH 1/9] reward --- conf/rl_gsm8k.yaml | 5 +++++ examples/rl_gsm8k/orchestrate_rl.py | 14 ++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/conf/rl_gsm8k.yaml b/conf/rl_gsm8k.yaml index 2ec6df29..cde7f052 100644 --- a/conf/rl_gsm8k.yaml +++ b/conf/rl_gsm8k.yaml @@ -41,6 +41,11 @@ task_template: |- overflow_reward: 0 max_prompt_length: 1024 +rewards: + unparsable: -1 + wrong_answer: -0.5 + correct_answer: 1 + vllm_config: vllm_kwargs: --download-dir: /mnt/llmd/base_models/ diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py index d3c05846..6eaf08af 100644 --- a/examples/rl_gsm8k/orchestrate_rl.py +++ b/examples/rl_gsm8k/orchestrate_rl.py @@ -21,7 +21,7 @@ import wandb from tapeagents.agent import Agent -from tapeagents.core import LLMCall, LLMOutputParsingFailureAction, StepMetadata, TrainingText +from tapeagents.core import LLMCall, StepMetadata, TrainingText from tapeagents.finetune.data import MASKED_TOKEN_ID from tapeagents.finetune.logging_ import flatten_dict_config, init_wandb from tapeagents.llms import TrainableLLM @@ -145,10 +145,11 @@ def extract_tape_training_samples( case _: raise ValueError(f"Unknown dataset: {cfg.dataset_name}") - if any([isinstance(step, LLMOutputParsingFailureAction) for step in new_tape.steps]): - # LLM produced a step that was unparsable. Negative reward. 
- no_error, reward, success = 0, -1, 0 + if "\\boxed" not in new_tape.steps[-1].reasoning: + # LLM did not respect the formatting + no_error, success, reward = 0, 0, cfg.rewards.unparsable else: + # LLM did respect the formatting no_error = 1 prediction = extract_fn(new_tape.steps[0].task, new_tape.steps[-1].reasoning, "cot") # type: ignore answer = new_tape.steps[0].metadata.other["value"] @@ -159,10 +160,11 @@ def extract_tape_training_samples( } ): # Correct answer - reward, success = 1, 1 + reward, success = cfg.rewards.right_answer, 1 + reward, success = cfg.rewards.correct_answer, 1 else: # Incorrect answer or no answer - reward, success = 0, 0 + reward, success = cfg.rewards.wrong_answer, 0 training_samples: list[TrainingText] = [] # For each LLM interaction in the tape: From 210ce2e2e30d1d7a91702e197d59b7a85cdd43ec Mon Sep 17 00:00:00 2001 From: Alexandre Piche Date: Mon, 27 Jan 2025 19:23:25 +0000 Subject: [PATCH 2/9] typo --- examples/rl_gsm8k/orchestrate_rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py index 6eaf08af..c017fba1 100644 --- a/examples/rl_gsm8k/orchestrate_rl.py +++ b/examples/rl_gsm8k/orchestrate_rl.py @@ -160,7 +160,6 @@ def extract_tape_training_samples( } ): # Correct answer - reward, success = cfg.rewards.right_answer, 1 reward, success = cfg.rewards.correct_answer, 1 else: # Incorrect answer or no answer From e086dc4b4d78506e672ea57bf3fd2b6cd730f2f7 Mon Sep 17 00:00:00 2001 From: ehsk Date: Mon, 27 Jan 2025 19:31:02 +0000 Subject: [PATCH 3/9] builder_config for loading MATH changed --- examples/rl_gsm8k/orchestrate_rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py index c017fba1..2b3e9500 100644 --- a/examples/rl_gsm8k/orchestrate_rl.py +++ b/examples/rl_gsm8k/orchestrate_rl.py @@ -40,7 +40,7 @@ def load_datasets(cfg: DictConfig) -> Tuple[list, list]: case 
"math": train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math" process_fn = process_math_test - builder_config = "main" + builder_config = "default" case "gsm8k": train_dataset_long_name = test_dataset_long_name = "openai/gsm8k" process_fn = process_gsm8k_test From a46c84c6bbdbc95db9d88431073bdf07355e2199 Mon Sep 17 00:00:00 2001 From: ehsk Date: Mon, 27 Jan 2025 19:36:45 +0000 Subject: [PATCH 4/9] last change rolled back --- examples/rl_gsm8k/orchestrate_rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py index 2b3e9500..c017fba1 100644 --- a/examples/rl_gsm8k/orchestrate_rl.py +++ b/examples/rl_gsm8k/orchestrate_rl.py @@ -40,7 +40,7 @@ def load_datasets(cfg: DictConfig) -> Tuple[list, list]: case "math": train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math" process_fn = process_math_test - builder_config = "default" + builder_config = "main" case "gsm8k": train_dataset_long_name = test_dataset_long_name = "openai/gsm8k" process_fn = process_gsm8k_test From 07de544c00b6d3dc58af94c6f0fa0f4b941fb950 Mon Sep 17 00:00:00 2001 From: ehsk Date: Mon, 27 Jan 2025 19:59:25 +0000 Subject: [PATCH 5/9] minor issue resolved --- examples/rl_gsm8k/orchestrate_rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py index c017fba1..18259b7b 100644 --- a/examples/rl_gsm8k/orchestrate_rl.py +++ b/examples/rl_gsm8k/orchestrate_rl.py @@ -195,7 +195,7 @@ def extract_tape_training_samples( # check if the last produced token is the end of sequence token overflow = False if input_ids[-1] == agent.llm.tokenizer.eos_token_id else True - trace.reward = cfg.overflow_reward if overflow else reward + trace.reward = cfg.rewards.unparsable if overflow else reward overflows.append(overflow) trace.logprobs = [lp.logprob for lp in llm_call.logprobs if lp.generated] 
trace.group_id = new_tape.metadata.parent_id From ca51e08560911dcb442ad0c532cae89346927083 Mon Sep 17 00:00:00 2001 From: Alexandre Piche Date: Wed, 29 Jan 2025 17:03:16 +0000 Subject: [PATCH 6/9] backward compatible reward values --- conf/rl_gsm8k.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conf/rl_gsm8k.yaml b/conf/rl_gsm8k.yaml index cde7f052..a5c00c3d 100644 --- a/conf/rl_gsm8k.yaml +++ b/conf/rl_gsm8k.yaml @@ -42,8 +42,8 @@ overflow_reward: 0 max_prompt_length: 1024 rewards: - unparsable: -1 - wrong_answer: -0.5 + unparsable: 0 + wrong_answer: 0 correct_answer: 1 vllm_config: From e2993862d88e0b7fb893933e486dfe6178a034b8 Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 29 Jan 2025 21:55:35 +0000 Subject: [PATCH 7/9] test_dataset for math set to MATH-500 --- examples/rl_gsm8k/orchestrate_rl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py index 18259b7b..239d49c5 100644 --- a/examples/rl_gsm8k/orchestrate_rl.py +++ b/examples/rl_gsm8k/orchestrate_rl.py @@ -38,7 +38,8 @@ def load_datasets(cfg: DictConfig) -> Tuple[list, list]: match cfg.dataset_name: case "math": - train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math" + train_dataset_long_name = "hendrycks/competition_math" + test_dataset_long_name = "HuggingFaceH4/MATH-500" process_fn = process_math_test builder_config = "main" case "gsm8k": From 568090c884bb1062c6cb757f9dd11071a12a139b Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 29 Jan 2025 22:08:28 +0000 Subject: [PATCH 8/9] builder_config added for test dataset --- examples/rl_gsm8k/orchestrate_rl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py index 239d49c5..0240e672 100644 --- a/examples/rl_gsm8k/orchestrate_rl.py +++ b/examples/rl_gsm8k/orchestrate_rl.py @@ -41,6 +41,7 @@ def load_datasets(cfg: DictConfig)
-> Tuple[list, list]: train_dataset_long_name = "hendrycks/competition_math" test_dataset_long_name = "HuggingFaceH4/MATH-500" process_fn = process_math_test + test_builder_config = "default" builder_config = "main" case "gsm8k": train_dataset_long_name = test_dataset_long_name = "openai/gsm8k" @@ -54,8 +55,9 @@ def load_datasets(cfg: DictConfig) -> Tuple[list, list]: case _: raise ValueError(f"Unknown dataset: {cfg.dataset_name}") + test_builder_config = test_builder_config or builder_config train_dataset = load_dataset(train_dataset_long_name, builder_config, split="train", trust_remote_code=True) - test_dataset = load_dataset(test_dataset_long_name, builder_config, split="test", trust_remote_code=True) + test_dataset = load_dataset(test_dataset_long_name, test_builder_config, split="test", trust_remote_code=True) train_samples = [ process_fn(s) for s in tqdm(train_dataset, desc="Processing train samples") if process_fn(s) is not None ] From 2ee467c0610df9b790d559e6b709df5b28868579 Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 29 Jan 2025 22:27:22 +0000 Subject: [PATCH 9/9] MATH-500 issue resolved --- examples/rl_gsm8k/deepseek_math_eval/process_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/rl_gsm8k/deepseek_math_eval/process_utils.py b/examples/rl_gsm8k/deepseek_math_eval/process_utils.py index cf3b2044..534b9387 100644 --- a/examples/rl_gsm8k/deepseek_math_eval/process_utils.py +++ b/examples/rl_gsm8k/deepseek_math_eval/process_utils.py @@ -40,6 +40,9 @@ def process_gsm8k_test(item): def process_math_test(item): question = item["problem"] + if "subject" in item and "type" not in item: + item["type"] = item["subject"] + try: answer = extract_math_answer(question, item["solution"], task="cot") except Exception: