-
Notifications
You must be signed in to change notification settings - Fork 147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Tasks to replicate APIGen
#925
Merged
Changes from 42 commits
Commits
Show all changes
74 commits
Select commit
Hold shift + click to select a range
edc1e06
Add apigen task module
plaguss 4b372fb
Add tests for apigen
plaguss 3a0da42
Fix default name for dataset info when requesting the number of examples
plaguss 01b43ab
checkpoint
plaguss 5058f65
Add tests for apigen generator
plaguss 8ee19e9
Create jinja template, split methods and add docstrings
plaguss d95375b
Update string format
plaguss 02d6803
Simplify function setting and move it to load method
plaguss 3371a37
Add tests for semantic checker
plaguss 19b9576
Add prompt template for semantic checker
plaguss 9f90191
Redirect import for semantic checker
plaguss ef7f263
Fix docstrins for output columns
plaguss 76a6da0
Add semantic checker task from apigen
plaguss 050a744
Add notes for execution checker
plaguss e4be16d
Merge with develop and fix conflicts
plaguss b8c356c
Remove extra jump of line
plaguss 94ef973
Add first version of data sampler, step helper for apigen
plaguss 5cffc3f
Add tests for data sampler
plaguss 952a640
Add integration test to check the sampler can be mixed with another g…
plaguss f5994d8
Draft tests for new execution checker
plaguss 5c8974a
Move helper functions
plaguss dab8a8b
Draft for execution checker functionality
plaguss d17cbde
Add first version of execution checker and tests
plaguss a2ae5f2
Add tests for utils module of apigen
plaguss 18be0b8
Remove unnecessary step for transformation and rename files for clarity
plaguss 71c0729
Fix import
plaguss cc25c8f
Change function results name to show the original results from the ex…
plaguss c5fae66
Remove print when the url for a reference doesn't contain https://arxiv
plaguss 127e377
first working version
plaguss a2279cd
Merge branch 'develop' of https://github.com/argilla-io/distilabel in…
plaguss bdbd8b3
Fix tests including previous columns
plaguss 7648535
Go back to previous name for dummy llm
plaguss 421125c
Change dummy llm names on tests
plaguss 051de38
Read the answers from the model parsed instead of dumped string
plaguss 614a817
Add option to include the tools if available for few shot
plaguss fa3cbf4
Allow extra checks for the parameter types and tests for those
plaguss 0cd8bc6
Add docs for the execution checker
plaguss f27186d
Add new icon for execution
plaguss 70260f4
Fix return type for outputs column
plaguss c49092a
Fix docstrings
plaguss 93fb319
Redirect imports to top level
plaguss 3e64236
Update docstrings to render on components gallery
plaguss 1560e0c
Improve docstrings for fields in the data sampler
plaguss 0fdde41
Remove unnecesary data from docstrings and remove TODO
plaguss 8d76bbe
Add missing data variable in example
plaguss cf74fae
Update src/distilabel/steps/tasks/apigen/execution_checker.py
plaguss ae3e4e2
Refactor to return formatted json string instead of dict to simplify …
plaguss 0ea95a7
Draft tutorial to replicate paper
plaguss d7c6a64
Allow number to be a dict with values and probabilities
plaguss 21e0757
Update pipeline run call
plaguss 82aa352
Add functionality to load functions from a folder with .py files
plaguss e70a258
Fix comment for arg
plaguss cbc288c
Add example implementation
plaguss 71a3517
Add dependency for vllm
plaguss 8dceb11
Fix dependency name
plaguss f363292
Add setuptools-scm in the script with the dependencies to install it …
plaguss a43b8e9
Another attempt with system
plaguss 7325cef
Add tests to take into account casting methods
plaguss 2f7418a
Avoid casting and update prompt to ensure argument order is respected
plaguss 4ac735c
Inform error type on generator
plaguss 60c1cd9
Add extra checks and safeguards for failed answer generation
plaguss bf1baed
Ensure the error is of the expected type
plaguss 740d3fe
Fix unstructured generation
plaguss 841d985
Remove json fences and fix semantic checker
plaguss dcded6a
Control case of functions without arguments
plaguss 2a76812
Add additional checks to run the execution checker
plaguss 58b92be
Remove additional dependency
plaguss f2eb160
Merge branch 'develop' of https://github.com/argilla-io/distilabel in…
plaguss 8a9743f
Try fixing CI error with dependencies
plaguss c26cca4
Install dependency for the system
plaguss 9c756b2
Undo fix attempt
plaguss 55ccc1e
Try fixing llvmlite dependency issue
plaguss c5ccf5a
Remove additional dependency as it breaks other tests
plaguss 776b36c
Merge with develop and fix conflict
plaguss File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import TYPE_CHECKING, Any, List | ||
|
||
from distilabel.llms.base import LLM, AsyncLLM | ||
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin | ||
|
||
if TYPE_CHECKING: | ||
from distilabel.llms.typing import GenerateOutput | ||
from distilabel.steps.tasks.typing import FormattedInput | ||
|
||
|
||
class DummyAsyncLLM(AsyncLLM): | ||
structured_output: Any = None | ||
|
||
def load(self) -> None: | ||
pass | ||
|
||
@property | ||
def model_name(self) -> str: | ||
return "test" | ||
|
||
async def agenerate( # type: ignore | ||
self, input: "FormattedInput", num_generations: int = 1 | ||
) -> "GenerateOutput": | ||
return ["output" for _ in range(num_generations)] | ||
|
||
|
||
class DummySyncLLM(LLM): | ||
structured_output: Any = None | ||
|
||
def load(self) -> None: | ||
super().load() | ||
|
||
@property | ||
def model_name(self) -> str: | ||
return "test" | ||
|
||
def generate( # type: ignore | ||
self, inputs: "FormattedInput", num_generations: int = 1 | ||
) -> "GenerateOutput": | ||
return [["output" for _ in range(num_generations)] for _ in range(len(inputs))] | ||
|
||
|
||
class DummyMagpieLLM(LLM, MagpieChatTemplateMixin): | ||
def load(self) -> None: | ||
pass | ||
|
||
@property | ||
def model_name(self) -> str: | ||
return "test" | ||
|
||
def generate( | ||
self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any | ||
) -> List["GenerateOutput"]: | ||
return [ | ||
["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs)) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import random | ||
from itertools import islice | ||
from typing import TYPE_CHECKING, Any, Dict, List | ||
|
||
from pydantic import Field | ||
from typing_extensions import override | ||
|
||
from distilabel.steps.base import GeneratorStep | ||
|
||
if TYPE_CHECKING: | ||
from distilabel.steps.base import GeneratorStepOutput | ||
|
||
|
||
class DataSampler(GeneratorStep): | ||
"""Step to sample from a dataset. | ||
|
||
`GeneratorStep` that samples from a dataset and yields it in batches. | ||
|
||
Output columns: | ||
- dynamic (based on the keys found on the first dictionary of the list): The columns | ||
of the dataset. | ||
|
||
Categories: | ||
- load | ||
|
||
Attributes: | ||
data: The list of dictionaries to sample from. | ||
size: The number of samples per example. | ||
samples: The number of examples. | ||
plaguss marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Examples: | ||
Sample data from a list of dictionaries: | ||
|
||
```python | ||
from distilabel.steps import DataSampler | ||
|
||
sampler = DataSampler( | ||
data=[{"sample": f"sample {i}"} for i in range(30)], | ||
samples=10, | ||
size=2, | ||
batch_size=4 | ||
) | ||
sampler.load() | ||
|
||
result = next(sampler.process()) | ||
# >>> result | ||
# ([{'sample': ['sample 7', 'sample 0']}, {'sample': ['sample 2', 'sample 21']}, {'sample': ['sample 17', 'sample 12']}, {'sample': ['sample 2', 'sample 14']}], False) | ||
``` | ||
|
||
Pipeline with a loader and a sampler combined in a single stream: | ||
|
||
```python | ||
from datasets import load_dataset | ||
|
||
from distilabel.steps import LoadDataFromDicts, DataSampler | ||
from distilabel.steps.tasks.apigen.utils import PrepareExamples | ||
from distilabel.pipeline import Pipeline | ||
|
||
ds = ( | ||
load_dataset("Salesforce/xlam-function-calling-60k", split="train") | ||
.shuffle(seed=42) | ||
.select(range(500)) | ||
.to_list() | ||
) | ||
|
||
with Pipeline(name="APIGenPipeline") as pipeline: | ||
loader_seeds = LoadDataFromDicts(data=data) | ||
plaguss marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sampler = DataSampler( | ||
data=ds, | ||
size=2, | ||
samples=len(data), | ||
batch_size=8, | ||
) | ||
prep_examples = PrepareExamples() | ||
|
||
sampler >> prep_examples | ||
( | ||
[loader_seeds, prep_examples] | ||
>> combine_steps | ||
) | ||
# Now we have a single stream of data with the loader and the sampler data | ||
``` | ||
""" | ||
|
||
data: List[Dict[str, Any]] = Field(default_factory=list, exclude=True) | ||
size: int = 2 # Number of samples per example | ||
samples: int = 100 # Number of examples | ||
|
||
@override | ||
def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore | ||
"""Yields batches from a list of dictionaries. | ||
|
||
Args: | ||
offset: The offset to start the generation from. Defaults to `0`. | ||
|
||
Yields: | ||
A list of Python dictionaries as read from the inputs (propagated in batches) | ||
and a flag indicating whether the yield batch is the last one. | ||
""" | ||
|
||
total_samples = 0 | ||
|
||
while total_samples < self.samples: | ||
batch = [] | ||
bs = min(self.batch_size, self.samples - total_samples) | ||
for _ in range(self.batch_size): | ||
choices = random.choices(self.data, k=self.size) | ||
choices = self._transform_data(choices) | ||
batch.extend(choices) | ||
total_samples += bs | ||
batch = list(islice(batch, bs)) | ||
yield (batch, True if total_samples >= self.samples else False) | ||
batch = [] | ||
|
||
@staticmethod | ||
def _transform_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | ||
if not data: | ||
return [] | ||
|
||
result = {key: [] for key in data[0].keys()} | ||
|
||
for item in data: | ||
for key, value in item.items(): | ||
result[key].append(value) | ||
|
||
return [result] | ||
|
||
@property | ||
def outputs(self) -> List[str]: | ||
return list(self.data[0].keys()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not at all, I added this one because I thought it could simplify testing examples with a dummy LLM without having to access to the tests. WDYT? No problem removing it