Add tasks to replicate Math-shepherd #1052

Merged
merged 52 commits on Dec 4, 2024
Commits
9757b5a
Add draft for math-shepherd generator
plaguss Oct 25, 2024
83ab0ea
First draft of step by step generator
plaguss Oct 25, 2024
b7c2df5
Merge branch 'develop' of https://github.com/argilla-io/distilabel in…
plaguss Oct 25, 2024
ee10a6f
First working version of the math shepherd generator
plaguss Oct 29, 2024
5f5d823
Add helper function to parse the solutions
plaguss Nov 6, 2024
3fbb18b
Add passing tests
plaguss Nov 6, 2024
80a76b3
First version of completer working decently enough
plaguss Nov 6, 2024
6d63567
Initial version of the generator
plaguss Nov 6, 2024
9e3f1ce
Update prompt to be similar across generator and completer
plaguss Nov 11, 2024
1454218
Add an example of how to implement the math-shepherd recipe
plaguss Nov 11, 2024
6353ecf
Add example pipeline
plaguss Nov 11, 2024
04d8241
Include the implementation as a paper in the docs
plaguss Nov 11, 2024
3a24aa2
Add the label category for the completer
plaguss Nov 11, 2024
ed0d351
Add docs and redirect imports
plaguss Nov 11, 2024
619ee8f
Update ExpandColumns to allow decoding json encoded lists
plaguss Nov 12, 2024
2bee8e0
Add FormatPRM step to prepare the data for training
plaguss Nov 12, 2024
a7be8bb
Update example with FormatPRM
plaguss Nov 12, 2024
fb71c63
Add tutorial to reproduce Math-Shepherd
plaguss Nov 12, 2024
3ca7b7d
Redirect import
plaguss Nov 12, 2024
986f281
Update docs/sections/pipeline_samples/papers/math_shepherd.md
plaguss Nov 18, 2024
7a63129
Add comment per code review
plaguss Nov 18, 2024
02e38fc
Update src/distilabel/steps/tasks/math_shepherd/utils.py
plaguss Nov 18, 2024
d48598a
Merge branch 'develop' of https://github.com/argilla-io/distilabel in…
plaguss Nov 25, 2024
8403444
Update dummy llm to fix tests
plaguss Nov 25, 2024
9011ad1
Return list instead of json encoded list
plaguss Nov 25, 2024
3bb074d
Update function to deal with the new output generated by the LLMs
plaguss Nov 25, 2024
b8267e3
Update docs with the Expand updated
plaguss Nov 26, 2024
75638c2
Add a new argument to expand columns to account the split of distilab…
plaguss Nov 26, 2024
ab3b624
Update the code to return a list instead of json encoded list and upd…
plaguss Nov 26, 2024
d6e7ee3
Fix possible missing data in serialization
plaguss Nov 26, 2024
f272e71
Update argilla test
plaguss Nov 27, 2024
5ee6115
Extra control for unexpected values in distilabel_metadata statistics
plaguss Nov 27, 2024
9d661ce
Fix update of distilabel metadata
plaguss Nov 27, 2024
4145f80
Add metadata to the completer and take care of not removing previous …
plaguss Nov 27, 2024
5cff76d
Add structured generation to the math shepherd generator
plaguss Nov 28, 2024
122d597
Add tests for structured generation
plaguss Nov 28, 2024
fe4ad1a
Add safeguard for not found solutions
plaguss Nov 28, 2024
81b5edb
Add extra controls to prevent errors with the generator
plaguss Nov 28, 2024
0fc3bf7
Add control to generator outputs
plaguss Nov 28, 2024
fa4328f
Fix error with format_output
plaguss Nov 29, 2024
7a137af
Add TRL format
plaguss Nov 29, 2024
ee0f0f1
Add structured output to the completer
plaguss Nov 29, 2024
a02e9eb
Add extra control on the formatter
plaguss Nov 29, 2024
acd0d24
Let statistics in the completer with the same format as the outputs
plaguss Nov 29, 2024
b283934
Remove comment
plaguss Dec 1, 2024
72e26cc
Add extra control on the completers metadata
plaguss Dec 2, 2024
faa08a1
Fix types and default value on error
plaguss Dec 2, 2024
1883abe
Update the docs with the new pipeline version
plaguss Dec 2, 2024
4329214
Fix extra types and add new example
plaguss Dec 2, 2024
171975c
Fix refactor on variable names
plaguss Dec 4, 2024
92a44bd
Merge branch 'develop' of https://github.com/argilla-io/distilabel in…
plaguss Dec 4, 2024
750f8d6
Add noqa to pass CI tests
plaguss Dec 4, 2024
Binary file added docs/assets/tutorials-assets/math-sheperd.png
8 changes: 8 additions & 0 deletions docs/sections/pipeline_samples/index.md
@@ -107,6 +107,14 @@ hide: toc

    [:octicons-arrow-right-24: Paper](papers/clair.md)

- __Math Shepherd__

    ---

    Learn about Math-Shepherd, a framework to generate datasets to train process reward models (PRMs), which assign reward scores to each step of math problem solutions.

    [:octicons-arrow-right-24: Paper](papers/math_shepherd.md)

</div>

## Examples
119 changes: 119 additions & 0 deletions docs/sections/pipeline_samples/papers/math_shepherd.md
@@ -0,0 +1,119 @@
---
hide: toc
---

# Create datasets to train a Process Reward Model using Math-Shepherd

This example will introduce [Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations](https://arxiv.org/abs/2312.08935), an innovative math process reward model (PRM) which assigns reward scores to each step of math problem solutions. Specifically, we will present a recipe to create datasets to train such models.

## Replica

Unlike traditional models that only look at the final answer (Outcome Reward Models, or ORMs), this system evaluates each step of a mathematical solution and assigns a reward score to every individual step. Figure 2 from the paper summarises the labelling approach presented in their work.

![Math-Shepherd framework](../../../assets/tutorials-assets/math-sheperd.png)

In the traditional ORM approach, the annotation depends only on the final outcome, while a Process Reward Model (PRM) labels each of the steps that lead to a solution, yielding a much richer source of supervision.
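
To make the distinction concrete, here is a purely illustrative sketch of the two labelling schemes. The `ки` step tag and the `+`/`-` step labels mirror the format of the dataset published with the paper, while the concrete problem and numbers are made up for the example.

```python
# Illustrative only: a GSM8K-style solution labelled under both schemes.
solution = [
    "Step 1: Janet's ducks lay 16 eggs, she eats 3 and bakes with 4, so she sells 16 - 3 - 4 = 9 eggs. ки",
    "Step 2: She earns 9 * 2 = $18 every day. The answer is: 18 ки",
]

# ORM: a single label for the whole solution, judged only by the final answer.
orm_label = "+"

# PRM: one label per step, so a flawed intermediate step can be penalised even
# when the final answer happens to be correct.
prm_labels = ["+", "+"]
```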

### Steps involved

- [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/): This step is in charge of generating solutions for the instruction. Depending on whether `M` is set, it generates either the `golden_solution`, used as a reference for the labeller, or the set of `solutions` to be labelled. For the `solutions` column we want some diversity, so that the model produces both good and bad solutions and the labeller gets a representative sample; for that reason it may be better to use a "weaker" model.

- [`MathShepherdCompleter`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdcompleter/). This task plays the role of the `completer` in the paper, generating completions as presented in Figure 2, section 3.3.2. It doesn't generate a column of its own; instead, it updates the steps in the `solutions` column generated by the [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/), using the `golden_solution` as the reference to label the data. For this step to work, we need both of these columns in our dataset. Depending on the dataset, we may already have access to the `golden_solution` (possibly under a different name), but that is rarely the case for the `solutions`.

- [`FormatPRM`](https://distilabel.argilla.io/dev/components-gallery/task/formatprm/). This auxiliary step prepares the data to follow the format defined in the paper: two columns, `input` and `label`. After running the [`MathShepherdCompleter`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdcompleter/), we have raw data that can be formatted however the user wants. Using [`ExpandColumns`](https://distilabel.argilla.io/latest/components-gallery/steps/expandcolumns/) and this step, one can directly obtain the same format presented in the dataset shared with the paper: [peiyi9979/Math-Shepherd](https://huggingface.co/datasets/peiyi9979/Math-Shepherd?row=0). A schematic sketch of how a row's columns evolve through these steps is shown right after this list.
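
For a rough mental model of the column flow described above, the following purely schematic sketch may help. The concrete values are invented, and the exact in-place representation used by the tasks (for example, how the step labels are attached) may differ slightly.

```python
# Schematic, illustrative sketch of how a row evolves through the steps above.
row = {"instruction": "Janet has 3 apples and buys 2 more. How many apples does she have?"}

# MathShepherdGenerator (stronger model, no M) -> reference solution for the labeller.
row["golden_solution"] = ["Step 1: Janet has 3 + 2 = 5 apples.", "The answer is: 5"]

# MathShepherdGenerator (weaker model, here M=2) -> candidate solutions to be labelled.
row["solutions"] = [
    ["Step 1: Janet has 3 + 2 = 5 apples.", "The answer is: 5"],
    ["Step 1: Janet has 3 * 2 = 6 apples.", "The answer is: 6"],
]

# MathShepherdCompleter (N completions per step) -> labels every step in place,
# using `golden_solution` as the reference ("+" for a good step, "-" for a bad one).
row["solutions"] = [
    ["Step 1: Janet has 3 + 2 = 5 apples. +", "The answer is: 5 +"],
    ["Step 1: Janet has 3 * 2 = 6 apples. -", "The answer is: 6 -"],
]

# ExpandColumns followed by FormatPRM -> one row per candidate solution, with
# `input`/`label` columns in the format of the original Math-Shepherd dataset.
```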

## Data preparation

For this example, just as in the original paper, we use the [openai/gsm8k](https://huggingface.co/datasets/openai/gsm8k) dataset. We only need a dataset with instructions to be solved (here, the `question` column); everything else can be generated with the predefined steps.

## Building the pipeline

The pipeline uses `openai/gsm8k` as a reference, but it can be applied to other datasets; keep in mind that the prompts can be adapted to the dataset at hand by tweaking the `extra_rules` and `few_shots` attributes of each task:

```python
from datasets import load_dataset

from distilabel.steps.tasks import MathShepherdCompleter, MathShepherdGenerator, FormatPRM
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, ExpandColumns

ds_name = "openai/gsm8k"

ds = load_dataset(ds_name, "main", split="test").rename_column("question", "instruction").select(range(3))  # (1)

with Pipeline(name="Math-Shepherd") as pipe:
    model_id_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    model_id_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    llm_70B = InferenceEndpointsLLM(
        model_id=model_id_70B,
        tokenizer_id=model_id_70B,
        generation_kwargs={"max_new_tokens": 1024, "temperature": 0.6},
    )
    llm_8B = InferenceEndpointsLLM(
        model_id=model_id_8B,
        tokenizer_id=model_id_8B,
        generation_kwargs={"max_new_tokens": 2048, "temperature": 0.6},
    )  # (2)

    generator_golden = MathShepherdGenerator(
        name="golden_generator",
        llm=llm_70B,
    )  # (3)
    generator = MathShepherdGenerator(
        name="generator",
        llm=llm_8B,
        M=5,
    )  # (4)
    completer = MathShepherdCompleter(
        name="completer",
        llm=llm_8B,
        N=4,
    )  # (5)

    combine = CombineOutputs()

    expand = ExpandColumns(
        name="expand_columns",
        columns=["solutions"],
        encoded=True,
    )  # (6)
    formatter = FormatPRM(name="format_prm")  # (7)

    [generator_golden, generator] >> combine >> completer >> expand >> formatter  # (8)
```

1. We use just 3 rows from the sample dataset and rename the "question" column to "instruction", which is the input expected by the [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/).

2. We use two different LLMs: `meta-llama/Meta-Llama-3.1-70B-Instruct` (a stronger model for the `golden_solution`) and `meta-llama/Meta-Llama-3.1-8B-Instruct` (a weaker one to generate the candidate solutions and the completions).

3. This [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/) task, which uses the *stronger* model, generates the `golden_solution` for us: the step-by-step reference solution for the task.

4. Another [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/) task, in this case using the *weaker* model, generates the candidate `solutions` (`M=5` in total).

5. The [`MathShepherdCompleter`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdcompleter/) task then generates `N=4` *completions* for each step of each candidate solution in the `solutions` column and labels them using the `golden_solution` as the reference, as shown in Figure 2 of the paper. The labels (`+` and `-` tags following the implementation in the paper, although these values can be modified) are added to the `solutions` column in place rather than in an additional column, and the intermediate completions are not kept in the final output.

6. The [`ExpandColumns`](https://distilabel.argilla.io/latest/components-gallery/steps/expandcolumns/) step expands the solutions to match the instruction, so with `M=5` we now have 5 instruction-solution pairs per original row. Both this and the following step can be omitted if you prefer to process the data for training in another way.

7. Finally, [`FormatPRM`](https://distilabel.argilla.io/dev/components-gallery/task/formatprm/) generates two columns, `input` and `label`, which prepare the data for training in the same format as the original Math-Shepherd dataset.

8. Both `generator_golden` and `generator` can run in parallel since there is no dependency between them; their results are then combined and passed to the `completer`. Finally, `expand` and `formatter` prepare the data in the format expected to train a Process Reward Model, as defined in the original paper.

## Script and final dataset

To see all the pieces in place, take a look at the full pipeline:

??? Run

    ```bash
    python examples/pipe_math_shepherd.py
    ```

??? "Full pipeline"

    ```python title="pipe_math_shepherd.py"
    --8<-- "examples/pipe_math_shepherd.py"
    ```

The resulting dataset can be seen at: [plaguss/test_math_shepherd_prm](https://huggingface.co/datasets/plaguss/test_math_shepherd_prm).
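
For a quick look at the generated data, a minimal sketch follows; the configuration and split names (`"default"`/`"train"`) are assumptions and may need adjusting depending on how the distiset was pushed.

```python
from datasets import load_dataset

# Load and print one row of the published dataset.
prm_ds = load_dataset("plaguss/test_math_shepherd_prm", "default", split="train")
print(prm_ds[0])
```
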
74 changes: 74 additions & 0 deletions examples/pipe_math_shepherd.py
@@ -0,0 +1,74 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset

from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, ExpandColumns
from distilabel.steps.tasks import (
FormatPRM,
MathShepherdCompleter,
MathShepherdGenerator,
)

ds_name = "openai/gsm8k"

ds = (
    load_dataset(ds_name, "main", split="test")
    .rename_column("question", "instruction")
    .select(range(3))
)


with Pipeline(name="Math-Shepherd") as pipe:
    model_id_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    model_id_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    llm_70B = InferenceEndpointsLLM(
        model_id=model_id_70B,
        tokenizer_id=model_id_70B,
        generation_kwargs={"max_new_tokens": 1024, "temperature": 0.5},
    )
    llm_8B = InferenceEndpointsLLM(
        model_id=model_id_8B,
        tokenizer_id=model_id_8B,
        generation_kwargs={"max_new_tokens": 2048, "temperature": 0.7},
    )

    generator_golden = MathShepherdGenerator(
        name="golden_generator",
        llm=llm_70B,
    )
    generator = MathShepherdGenerator(
        name="generator",
        llm=llm_8B,
        M=5,
    )
    completer = MathShepherdCompleter(name="completer", llm=llm_8B, N=4)

    combine = CombineOutputs()

    expand = ExpandColumns(
        name="expand_columns",
        columns=["solutions"],
        encoded=True,
    )
    formatter = FormatPRM(name="format_prm")
    [generator_golden, generator] >> combine >> completer >> expand >> formatter


if __name__ == "__main__":
    distiset = pipe.run(use_cache=False, dataset=ds)
    distiset.push_to_hub("plaguss/test_math_shepherd_prm")
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -212,6 +212,7 @@ nav:
        - UltraFeedback: "sections/pipeline_samples/papers/ultrafeedback.md"
        - APIGen: "sections/pipeline_samples/papers/apigen.md"
        - CLAIR: "sections/pipeline_samples/papers/clair.md"
        - Math Shepherd: "sections/pipeline_samples/papers/math_shepherd.md"
      - Examples:
        - Benchmarking with distilabel: "sections/pipeline_samples/examples/benchmarking_with_distilabel.md"
        - Structured generation with outlines: "sections/pipeline_samples/examples/llama_cpp_with_outlines.md"
54 changes: 53 additions & 1 deletion src/distilabel/steps/columns/expand.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from itertools import zip_longest
from typing import TYPE_CHECKING, Any, Dict, List, Union

from pydantic import field_validator
from pydantic import field_validator, model_validator
from typing_extensions import Self

from distilabel.steps.base import Step, StepInput

@@ -34,6 +36,10 @@ class ExpandColumns(Step):
        columns: A dictionary that maps the column to be expanded to the new column name
            or a list of columns to be expanded. If a list is provided, the new column name
            will be the same as the column name.
        encoded: A bool to inform whether the columns are JSON encoded lists. If this value is
            set to True, the columns will be decoded before expanding. Alternatively, a list
            of columns that are encoded can be provided. In this case, the column names
            provided must be a subset of the columns selected for expansion.

    Input columns:
        - dynamic (determined by `columns` attribute): The columns to be expanded into
@@ -68,9 +74,34 @@ class ExpandColumns(Step):
        # >>> result
        # [{'instruction': 'instruction 1', 'generation': 'generation 1'}, {'instruction': 'instruction 1', 'generation': 'generation 2'}]
        ```

        Expand the selected columns which are JSON encoded into multiple rows:

        ```python
        from distilabel.steps import ExpandColumns

        expand_columns = ExpandColumns(
            columns=["generation"],
            encoded=True,  # It can also be a list of columns that are encoded, i.e. ["generation"]
        )
        expand_columns.load()

        result = next(
            expand_columns.process(
                [
                    {
                        "instruction": "instruction 1",
                        "generation": '["generation 1", "generation 2"]',
                    }
                ],
            )
        )
        # >>> result
        # [{'instruction': 'instruction 1', 'generation': 'generation 1'}, {'instruction': 'instruction 1', 'generation': 'generation 2'}]
        ```
    """

    columns: Union[Dict[str, str], List[str]]
    encoded: Union[bool, List[str]] = False

    @field_validator("columns")
    @classmethod
@@ -88,6 +119,22 @@ def always_dict(cls, value: Union[Dict[str, str], List[str]]) -> Dict[str, str]:

        return value

    @model_validator(mode="after")
    def is_subset(self) -> Self:
        """Ensure the "encoded" column names are a subset of the "columns" selected.

        Returns:
            The "encoded" attribute updated to work internally.
        """
        if isinstance(self.encoded, list):
            if not set(self.encoded).issubset(set(self.columns.keys())):
                raise ValueError(
                    "The 'encoded' columns must be a subset of the 'columns' selected for expansion."
                )
        if isinstance(self.encoded, bool):
            self.encoded = list(self.columns.keys()) if self.encoded else []
        return self

    @property
    def inputs(self) -> "StepColumns":
        """The columns to be expanded."""
@@ -110,6 +157,11 @@ def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
Yields:
The expanded rows.
"""
if self.encoded:
for input in inputs:
for column in self.encoded:
input[column] = json.loads(input[column])

yield [row for input in inputs for row in self._expand_columns(input)]

    def _expand_columns(self, input: Dict[str, Any]) -> List[Dict[str, Any]]:
6 changes: 6 additions & 0 deletions src/distilabel/steps/tasks/__init__.py
@@ -43,6 +43,9 @@
)
from distilabel.steps.tasks.magpie.base import Magpie
from distilabel.steps.tasks.magpie.generator import MagpieGenerator
from distilabel.steps.tasks.math_shepherd.completer import MathShepherdCompleter
from distilabel.steps.tasks.math_shepherd.generator import MathShepherdGenerator
from distilabel.steps.tasks.math_shepherd.utils import FormatPRM
from distilabel.steps.tasks.pair_rm import PairRM
from distilabel.steps.tasks.prometheus_eval import PrometheusEval
from distilabel.steps.tasks.quality_scorer import QualityScorer
@@ -81,6 +84,9 @@
    "InstructionBacktranslation",
    "Magpie",
    "MagpieGenerator",
    "MathShepherdGenerator",
    "MathShepherdCompleter",
    "FormatPRM",
    "PairRM",
    "PrometheusEval",
    "QualityScorer",
14 changes: 14 additions & 0 deletions src/distilabel/steps/tasks/math_shepherd/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
