Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tasks to replicate Math-shepherd #1052

Merged
merged 52 commits into from
Dec 4, 2024
Merged

Add tasks to replicate Math-shepherd #1052

merged 52 commits into from
Dec 4, 2024

Conversation

plaguss
Copy link
Contributor

@plaguss plaguss commented Nov 6, 2024

Description

This task Integrates the tasks to replicate:
Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations

Examples:

It integrates structured outputs as I found while testing that it highly improves the success rate of the generations.

Example pipeline:

from datasets import load_dataset

from distilabel.steps.tasks.math_shepherd.generator import MathShepherdGenerator
from distilabel.steps.tasks.math_shepherd.completer import MathShepherdCompleter
from distilabel.steps.tasks.math_shepherd.utils import FormatPRM
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, ExpandColumns

ds_name = "openai/gsm8k"

ds = load_dataset(ds_name, "main", split="test").rename_column("question", "instruction").select(range(3))


with Pipeline(name="Math-Shepherd") as pipe:
    model_id_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    model_id_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    llm_70B = InferenceEndpointsLLM(
        model_id=model_id_8B,
        tokenizer_id=model_id_8B,
        generation_kwargs={"max_new_tokens": 1024, "temperature": 0.6},
    )
    llm_8B = InferenceEndpointsLLM(
        model_id=model_id_8B,
        tokenizer_id=model_id_8B,
        generation_kwargs={"max_new_tokens": 2048, "temperature": 0.6},
    )

    generator_golden = MathShepherdGenerator(
        name="golden_generator",
        llm=llm_70B,
    )
    generator = MathShepherdGenerator(
        name="generator",
        llm=llm_8B,
        M=5  # Generate 5 sample solutions
    )
    completer = MathShepherdCompleter(
        name="completer",
        llm=llm_8B,
        N=4  # Each solution will be tested with 4 completions during labelling
    )

    combine = CombineOutputs()
    expand = ExpandColumns(
        name="expand_columns",
        columns=["solutions"],
        split_statistics=True,
    )
    formatter = FormatPRM(name="format_prm", format="trl")
    [generator_golden, generator] >> combine >> completer >> expand >> formatter


if __name__ == "__main__":
    distiset = pipe.run(use_cache=False, dataset=ds)
    distiset.push_to_hub("plaguss/test_math_shepherd_prm")

A sample dataset can be seen at plaguss/test_math_shepherd_prm

@plaguss plaguss added the enhancement New feature or request label Nov 6, 2024
@plaguss plaguss added this to the 1.5.0 milestone Nov 6, 2024
@plaguss plaguss self-assigned this Nov 6, 2024
Copy link

github-actions bot commented Nov 6, 2024

Documentation for this PR has been built. You can view it at: https://distilabel.argilla.io/pr-1052/

Copy link

codspeed-hq bot commented Nov 6, 2024

CodSpeed Performance Report

Merging #1052 will not alter performance

Comparing math-shepherd (750f8d6) with develop (f8e41cd)

Summary

✅ 1 untouched benchmarks

@plaguss plaguss marked this pull request as ready for review November 12, 2024 12:02
@plaguss plaguss requested a review from gabrielmbmb November 12, 2024 12:03
…ate to work with the statistics from the llm
@plaguss plaguss merged commit 6bb61d1 into develop Dec 4, 2024
8 checks passed
@plaguss plaguss deleted the math-shepherd branch December 4, 2024 10:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants