diff --git a/.gitignore b/.gitignore
index 82f9275..ab001d9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+data/
\ No newline at end of file
diff --git a/recipes/README.md b/recipes/README.md
index 7591cd8..0e9b669 100644
--- a/recipes/README.md
+++ b/recipes/README.md
@@ -1,11 +1,13 @@
 # Recipes

-Here we include yaml configs to run the three test time compute variants detailed in the blog post:
+Here we include YAML configs to run the three test time compute variants detailed in the blog post:
+
 - Best of N: [`recipes/Llama-3.2-1B-Instruct/best_of_n.yaml`](Llama-3.2-1B-Instruct/best_of_n.yaml)
 - Beam Search: [`recipes/Llama-3.2-1B-Instruct/beam_search.yaml`](Llama-3.2-1B-Instruct/beam_search.yaml)
 - Diverse Verifier Beam Search (DVTS): [`recipes/Llama-3.2-1B-Instruct/dvts.yaml`](Llama-3.2-1B-Instruct/dvts.yaml)

-Each approach can be launched by specifying the associated yaml file:
+Each approach can be launched by specifying the associated YAML file:
+
 ```
 python scripts/test_time_compute.py
 # for example:
@@ -13,18 +15,19 @@ python scripts/test_time_compute.py recipes/Llama-3.2-1B-Instruct/best_of_n.yaml
 ```

-The configs shown here are for the `Llama-3.2-1B-Instruct` model, you can override the size of the llama model evaluated by including it in the command line arguments:
+The configs shown here are for the `Llama-3.2-1B-Instruct` model, you can override the choice of model by including it in the command line arguments:

 ```shell
 python scripts/test_time_compute.py recipes/Llama-3.2-1B-Instruct/best_of_n.yaml --model_path=Llama-3.2-3B-Instruct --hub_dataset_id=/Llama-3.2-3B-Instruct-bon-completions
 ```

 > [!WARNING]
-> __best of n__ and __DVTS__ can be run at `n=256` and then subsampled for get complarable solutions for running at `n=4,16,64` etc. The beam search variant **must** be run at the correct `n` in order to make a valid comparison.
+> __best of n__ and __DVTS__ can be run at `n=256` and then subsampled to get comparable solutions for running at `n=4,16,64` etc. The beam search variant **must** be run at the correct `n` in order to make a valid comparison.
+
+## Reproducing results on the MATH-500 dataset

-## Reproducing results on the MATH-500 dataset:
-We provide slurm scripts to configure array jobs to parallelize the evaluation of the three methods:
+We provide Slurm scripts to configure array jobs to parallelize the evaluation of the three methods:


 ```shell
@@ -41,11 +44,13 @@ sbatch recipes/launch_array.slurm recipes/Llama-3.2-1B-Instruct/dvts.yaml --n=16

 By default this will shard the dataset into 20 chunks in order to run the algorithm in parallel, the dataset will be pushed to the Hugging Face hub. The full dataset can then be recontructed with:
+
 ```shell
 python scripts/merge_chunks.py --dataset_name=/Llama-3.2-1B-Instruct-bon-completions
 ```

-## Exacting the MATH-500 accuracy numbers:
+## Extracting the MATH-500 accuracy numbers
+
 To get the final numbers for the evalations, we use the [Qwen2.5-Math evaluation repo](https://github.com/QwenLM/Qwen2.5-Math), their codebase is well documented, so please refer to their instuctions.
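To make the warning in the README diff above concrete: a best-of-n or DVTS run at `n=256` already contains the candidate pool for every smaller budget, so results for `n=4,16,64` can be emulated by subsampling the stored completions instead of regenerating them. Below is a minimal sketch of that subsampling idea; the helper name and the placeholder candidates are illustrative and not part of the repo:

```python
import random


def subsample_run(completions: list[str], k: int, seed: int = 0) -> list[str]:
    """Hypothetical helper: emulate a run at budget k by subsampling k
    candidates from a larger completion set (e.g. one generated at n=256)."""
    rng = random.Random(seed)
    return rng.sample(completions, k)


# Placeholder candidates standing in for one problem's n=256 completions.
run_n256 = [f"candidate_{i}" for i in range(256)]
for k in (4, 16, 64):
    subset = subsample_run(run_n256, k)
    print(k, len(subset))
```

Beam search offers no such shortcut: the search tree it expands depends on `n`, which is why the README requires rerunning it at each `n`.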
diff --git a/setup.py b/setup.py
index 4962878..17d5f7a 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@
     "pebble",  # for parallel processing
     "latex2sympy2==1.9.1",  # for MATH answer parsing
     "word2number",  # for MATH answer parsing
-    "transformers>=4.47.0",
+    "transformers>=4.47.0",
     "fastapi",
 ]
diff --git a/src/sal/search/best_of_n.py b/src/sal/search/best_of_n.py
index d1d08c6..dbd268d 100644
--- a/src/sal/search/best_of_n.py
+++ b/src/sal/search/best_of_n.py
@@ -29,17 +29,20 @@ def best_of_n(x, config: Config, llm: LLM, prm: PRM):
             {"role": "system", "content": config.system_prompt},
             {"role": "user", "content": prompt},
         ]
-        for prompt in x["problem"] * config.n
+        for prompt in x["problem"]
     ]
     tokenizer = llm.get_tokenizer()
     # TODO: set the augmented template from a file
     if config.custom_chat_template is not None:
         tokenizer.chat_template = config.custom_chat_template
     templated_convs = tokenizer.apply_chat_template(
-        convs,
-        tokenize=False,
+        convs, tokenize=False, add_generation_prompt=True
     )
+    # Duplicate convs to generate config.n completions per prompt so we can do continuous batching
+    # This makes [p1, p2, p3, p4] become [p1, p1, p2, p2, p3, p3, p4, p4] for e.g. config.n=2
+    templated_convs = [c for conv in templated_convs for c in [conv] * config.n]
+
     # Initialize empty lists for completions and completion tokens
     completions = [[] for _ in range(len(x["problem"]))]
     completion_tokens = [[] for _ in range(len(x["problem"]))]
diff --git a/src/sal/utils/data.py b/src/sal/utils/data.py
index d06f9d2..aa39a77 100644
--- a/src/sal/utils/data.py
+++ b/src/sal/utils/data.py
@@ -72,5 +72,9 @@ def save_dataset(dataset, config):
     if config.output_dir is None:
         config.output_dir = f"data/{config.model_path}"
     Path(config.output_dir).mkdir(parents=True, exist_ok=True)
-    dataset.to_json(f"{config.output_dir}/{config.approach}_completions.jsonl", lines=True)
-    logger.info(f"Saved completions to {config.output_dir}/{config.approach}_completions.jsonl")
+    dataset.to_json(
+        f"{config.output_dir}/{config.approach}_completions.jsonl", lines=True
+    )
+    logger.info(
+        f"Saved completions to {config.output_dir}/{config.approach}_completions.jsonl"
+    )
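The key change in `src/sal/search/best_of_n.py` above is where the `config.n` duplication happens: each problem is templated once, and the templated conversations are then repeated `n` times each so the engine can generate all candidates for a batch of problems in one continuously batched call. A standalone sketch of just that interleaving step, using illustrative placeholder strings rather than real conversations:

```python
# Repeat each templated conversation n times, keeping the copies of a prompt
# adjacent: [p1, p2, p3, p4] -> [p1, p1, p2, p2, p3, p3, p4, p4] for n=2.
n = 2
templated_convs = ["p1", "p2", "p3", "p4"]
duplicated = [c for conv in templated_convs for c in [conv] * n]
assert duplicated == ["p1", "p1", "p2", "p2", "p3", "p3", "p4", "p4"]

# Outputs produced in this order can be regrouped per original prompt by
# slicing in strides of n (mirroring the per-problem completions lists above).
grouped = [duplicated[i * n : (i + 1) * n] for i in range(len(templated_convs))]
assert grouped == [["p1", "p1"], ["p2", "p2"], ["p3", "p3"], ["p4", "p4"]]
```

Keeping the copies of a prompt adjacent makes the later bookkeeping simple: in this sketch, completion `i * n + j` belongs to problem `i`.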