diff --git a/docs/assets/tutorials-assets/overview-apigen.jpg b/docs/assets/tutorials-assets/overview-apigen.jpg
new file mode 100644
index 000000000..61deefac9
Binary files /dev/null and b/docs/assets/tutorials-assets/overview-apigen.jpg differ
diff --git a/docs/sections/pipeline_samples/index.md b/docs/sections/pipeline_samples/index.md
index ed2f51686..0d9605c58 100644
--- a/docs/sections/pipeline_samples/index.md
+++ b/docs/sections/pipeline_samples/index.md
@@ -83,6 +83,15 @@ hide: toc
 
     [:octicons-arrow-right-24: Paper](papers/ultrafeedback.md)
 
+- __APIGen__
+
+    ---
+
+    Learn how to create verifiable high-quality datasets for function-calling applications.
+
+    [:octicons-arrow-right-24: Paper](papers/apigen.md)
+
+
 ## Examples
diff --git a/docs/sections/pipeline_samples/papers/apigen.md b/docs/sections/pipeline_samples/papers/apigen.md
new file mode 100644
index 000000000..8cb034c18
--- /dev/null
+++ b/docs/sections/pipeline_samples/papers/apigen.md
@@ -0,0 +1,239 @@
+---
+hide: toc
+---
+
+# Create Function-Calling datasets with APIGen
+
+This example will introduce [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518), a data generation pipeline designed to synthesize verifiable high-quality datasets for function-calling applications.
+
+## Replication
+
+The following figure showcases the APIGen framework:
+
+![APIGen framework](../../../assets/tutorials-assets/overview-apigen.jpg)
+
+Now, let's walk through the key steps illustrated in the figure:
+
+- [`DataSampler`](https://distilabel.argilla.io/dev/components-gallery/step/datasampler/): With the help of this step and the original [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k) dataset, we obtain the *Seed QA Data Sampler* examples for the prompt template.
+
+- [`APIGenGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/apigengenerator/): This step does the job of the *Query-Answer Generator*, and also covers *Stage 1: Format Checker* thanks to the structured output generation.
+
+- [`APIGenExecutionChecker`](https://distilabel.argilla.io/dev/components-gallery/task/apigenexecutionchecker/): This step is in charge of *Stage 2: Execution Checker*.
+
+- [`APIGenSemanticChecker`](https://distilabel.argilla.io/dev/components-gallery/task/apigensemanticchecker/): This step runs *Stage 3: Semantic Checker*. It can use the same or a different LLM; here we use the same one as in the [`APIGenGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/apigengenerator/) step.
+
+The current implementation doesn't use the *Diverse Prompt Library*. To incorporate it, one could either adjust the prompt template within the [`APIGenGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/apigengenerator/) or develop a new sampler specifically for this purpose. As for the *API Sampler*, while no specific data is shared here, we've created illustrative examples to demonstrate the pipeline's functionality. These examples represent a mix of data that could be used to replicate the sampler's output.
+
+## Data preparation
+
+The original paper describes the data that was used and gives some hints, but the data itself was never shared. In this example, we will write a handful of examples by hand to showcase how this pipeline can be built.
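+
+Before writing any data, here is a minimal sketch of the imports and the library path that the snippets below assume (they mirror the full script shown at the end of this page); the exact `libpath` location is an assumption, so point it at your own copy of `lib_apigen.py`:
+
+```python
+from pathlib import Path
+
+from datasets import load_dataset
+
+from distilabel.llms import InferenceEndpointsLLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import CombineOutputs, DataSampler, LoadDataFromDicts
+from distilabel.steps.tasks import (
+    APIGenExecutionChecker,
+    APIGenGenerator,
+    APIGenSemanticChecker,
+)
+from distilabel.steps.tasks.apigen.utils import PrepareExamples, load_module_from_path
+
+# Assumed location of the functions file; adjust it to wherever you saved lib_apigen.py.
+libpath = Path("examples") / "lib_apigen.py"
+```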
+
+Assume we have the following function names and corresponding descriptions of their behaviour:
+
+```python
+data = [
+    {
+        "func_name": "final_velocity",
+        "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+    },
+    {
+        "func_name": "permutation_count",
+        "func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
+    },
+    {
+        "func_name": "getdivision",
+        "func_desc": "Divides two numbers by making an API call to a division service.",
+    },
+    {
+        "func_name": "binary_addition",
+        "func_desc": "Adds two binary numbers and returns the result as a binary string.",
+    },
+    {
+        "func_name": "swapi_planet_resource",
+        "func_desc": "get a specific planets resource",
+    },
+    {
+        "func_name": "disney_character",
+        "func_desc": "Find a specific character using this endpoint",
+    }
+]
+```
+
+The original paper refers to both python functions and APIs, but we will make use of python functions exclusively for simplicity. In order to execute and check these functions/APIs, we need access to the code, which we have moved to a python file: [lib_apigen.py](../../../../examples/lib_apigen.py). All these functions are executable, but we also need access to their *tool* representation. For this, we will make use of transformers' *get_json_schema* function[^1].
+
+[^1]: Read this nice blog post for more information on tools and the reasoning behind `get_json_schema`: [Tool Use, Unified](https://huggingface.co/blog/unified-tool-use).
+
+We have all the machinery prepared in our libpath, except for the *tool* definition. With the help of our helper function `load_module_from_path` we will load this python module, collect all the tools, and add them to each row in our `data` variable.
+
+```python
+from distilabel.steps.tasks.apigen.utils import load_module_from_path
+
+libpath_module = load_module_from_path(libpath)
+tools = libpath_module.get_tools()  # Returns a dict of {func_name: tool (JSON schema)}
+
+for row in data:
+    # Ideally the tools would mix the correct tool with some irrelevant ones;
+    # here we only add the correct one for simplicity.
+    row.update({"tools": [tools[row["func_name"]]]})
+```
+
+Now we have all the necessary data for our prompt. Additionally, we will make use of the original dataset as few-shot examples to enhance the model:
+
+```python
+ds_og = (
+    load_dataset("Salesforce/xlam-function-calling-60k", split="train")
+    .shuffle(seed=42)
+    .select(range(500))
+    .to_list()
+)
+```
+
+We have just loaded a subset and transformed it to a list of dictionaries, as we will use it in the [`DataSampler`](https://distilabel.argilla.io/dev/components-gallery/step/datasampler/) `GeneratorStep`, grabbing random examples from the original dataset.
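+
+To get a feel for what the sampler will feed into each row, here is a small self-contained sketch adapted from the `DataSampler` docstring added in this PR; the `"sample ..."` strings are placeholder data standing in for the real dataset rows:
+
+```python
+from distilabel.steps import DataSampler
+
+sampler = DataSampler(
+    data=[{"sample": f"sample {i}"} for i in range(30)],  # toy stand-in for ds_og
+    samples=10,   # total rows the step will emit
+    size=2,       # examples packed into each emitted row
+    batch_size=4,
+)
+sampler.load()
+
+batch, is_last_batch = next(sampler.process())
+print(batch[0])
+# e.g. {'sample': ['sample 7', 'sample 0']} -> two random examples per row
+```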
+
+## Building the Pipeline
+
+Now that we've walked through each component, it's time to see how they all come together. Here's the pipeline code:
+
+```python
+with Pipeline(name="apigen-example") as pipeline:
+    loader_seeds = LoadDataFromDicts(data=data)  # (1)
+
+    sampler = DataSampler(  # (2)
+        data=ds_og,
+        size=2,
+        samples=len(data),
+        batch_size=8,
+    )
+
+    prep_examples = PrepareExamples()  # This step will add the 'examples' column
+
+    combine_steps = CombineOutputs()  # (3)
+
+    model_id = "meta-llama/Meta-Llama-3.1-70B-Instruct"
+    llm = InferenceEndpointsLLM(  # (4)
+        model_id=model_id,
+        tokenizer_id=model_id,
+        generation_kwargs={
+            "temperature": 0.7,
+            "max_new_tokens": 2048,
+        },
+    )
+    apigen = APIGenGenerator(  # (5)
+        llm=llm,
+        use_default_structured_output=True,
+    )
+
+    execution_checker = APIGenExecutionChecker(libpath=str(libpath))  # (6)
+    semantic_checker = APIGenSemanticChecker(llm=llm)  # (7)
+
+    sampler >> prep_examples
+    (
+        [loader_seeds, prep_examples]
+        >> combine_steps
+        >> apigen
+        >> execution_checker
+        >> semantic_checker
+    )
+```
+
+1. Loads the data seeds we are going to use to generate our function-calling dataset.
+
+2. The `DataSampler`, together with `PrepareExamples`, will be used to help us create the few-shot
+examples from the original dataset to be fed into our prompt.
+
+3. Combines both streams of data into a single one.
+
+4. We will reuse the same LLM for the generation and the semantic checks.
+
+5. Creates the `query` and `answers` that will be used together with the `tools` to fine-tune a new model. It generates structured outputs to ensure we have valid JSON-formatted answers.
+
+6. Adds the columns `keep_row_after_execution_check` and `execution_result`.
+
+7. Adds the columns `keep_row_after_semantic_check` and `thought`.
+
+## Script and final dataset
+
+To see all the pieces in place, take a look at the full pipeline, as well as an example row that would be generated from this pipeline.
+
+??? Run
+
+    ```bash
+    python examples/pipeline_apigen.py
+    ```
+
+```python title="pipeline_apigen.py"
+--8<-- "examples/pipeline_apigen.py"
+```
+
+Example row:
+
+```json
+{
+    "func_name": "final_velocity",
+    "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+    "tools": [
+        {
+            "function": {
+                "description": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+                "name": "final_velocity",
+                "parameters": {
+                    "properties": {
+                        "acceleration": {
+                            "description": "The acceleration of the object.",
+                            "type": "number"
+                        },
+                        "initial_velocity": {
+                            "description": "The initial velocity of the object.",
+                            "type": "number"
+                        },
+                        "time": {
+                            "description": "The time elapsed.",
+                            "type": "number"
+                        }
+                    },
+                    "required": [
+                        "initial_velocity",
+                        "acceleration",
+                        "time"
+                    ],
+                    "type": "object"
+                }
+            },
+            "type": "function"
+        }
+    ],
+    "examples": "## Query:\nRetrieve the first 15 comments for post ID '12345' from the Tokapi mobile API.\n## Answers:\n[{\"name\": \"v1_post_post_id_comments\", \"arguments\": {\"post_id\": \"12345\", \"count\": 15}}]\n\n## Query:\nRetrieve the detailed recipe for the cake with ID 'cake101'.\n## Answers:\n[{\"name\": \"detailed_cake_recipe_by_id\", \"arguments\": {\"is_id\": \"cake101\"}}]\n\n## Query:\nWhat are the frequently asked questions and their answers for Coca-Cola Company? 
Also, what are the suggested tickers based on Coca-Cola Company?\n## Answers:\n[{\"name\": \"symbols_faq\", \"arguments\": {\"ticker_slug\": \"KO\"}}, {\"name\": \"symbols_suggested\", \"arguments\": {\"ticker_slug\": \"KO\"}}]", + "query": "What would be the final velocity of an object that starts at rest and accelerates at 9.8 m/s^2 for 10 seconds.", + "answers": "[{\"arguments\": {\"acceleration\": \"9.8\", \"initial_velocity\": \"0\", \"time\": \"10\"}, \"name\": \"final_velocity\"}]", + "distilabel_metadata": { + "raw_input_a_p_i_gen_generator_0": [ + { + "content": "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively", + "role": "system" + }, + { + "content": "Here are examples of queries and the corresponding answers for similar functions:\n## Query:\nRetrieve the first 15 comments for post ID '12345' from the Tokapi mobile API.\n## Answers:\n[{\"name\": \"v1_post_post_id_comments\", \"arguments\": {\"post_id\": \"12345\", \"count\": 15}}]\n\n## Query:\nRetrieve the detailed recipe for the cake with ID 'cake101'.\n## Answers:\n[{\"name\": \"detailed_cake_recipe_by_id\", \"arguments\": {\"is_id\": \"cake101\"}}]\n\n## Query:\nWhat are the frequently asked questions and their answers for Coca-Cola Company? 
Also, what are the suggested tickers based on Coca-Cola Company?\n## Answers:\n[{\"name\": \"symbols_faq\", \"arguments\": {\"ticker_slug\": \"KO\"}}, {\"name\": \"symbols_suggested\", \"arguments\": {\"ticker_slug\": \"KO\"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\n\nBased on these examples, generate 1 diverse query and answer pairs for the function `final_velocity`.\nThe detailed function description is the following:\nCalculates the final velocity of an object given its initial velocity, acceleration, and time.\n\nThese are the available tools to help you:\n[{'type': 'function', 'function': {'name': 'final_velocity', 'description': 'Calculates the final velocity of an object given its initial velocity, acceleration, and time.', 'parameters': {'type': 'object', 'properties': {'initial_velocity': {'type': 'number', 'description': 'The initial velocity of the object.'}, 'acceleration': {'type': 'number', 'description': 'The acceleration of the object.'}, 'time': {'type': 'number', 'description': 'The time elapsed.'}}, 'required': ['initial_velocity', 'acceleration', 'time']}}}]\n\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n```json\n[\n {\n \"query\": \"The generated query.\",\n \"answers\": [\n {\n \"name\": \"api_name\",\n \"arguments\": {\n \"arg_name\": \"value\"\n ... (more arguments as required)\n }\n },\n ... (more API calls as required)\n ]\n }\n]\n```\n\nNow please generate 1 diverse query and answer pairs following the above format.", + "role": "user" + } + ], + "raw_input_a_p_i_gen_semantic_checker_0": [ + { + "content": "As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user\u2019s intentions.\n\nDo not pass if:\n1. The function call does not align with the query\u2019s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user\u2019s intentions.\n4. The execution results are irrelevant and do not match the function\u2019s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.", + "role": "system" + }, + { + "content": "Given Information:\n- All Available Functions:\nCalculates the final velocity of an object given its initial velocity, acceleration, and time.\n- User Query: What would be the final velocity of an object that starts at rest and accelerates at 9.8 m/s^2 for 10 seconds.\n- Generated Function Calls: [{\"arguments\": {\"acceleration\": \"9.8\", \"initial_velocity\": \"0\", \"time\": \"10\"}, \"name\": \"final_velocity\"}]\n- Execution Results: ['9.8']\n\nNote: The query may have multiple intentions. 
Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query's intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n```\n{\n   \"thought\": \"Concisely describe your reasoning here\",\n   \"passes\": \"yes\" or \"no\"\n}\n```\n",
+                "role": "user"
+            }
+        ],
+        "raw_output_a_p_i_gen_generator_0": "{\"pairs\": [\n    {\n        \"answers\": [\n            {\n                \"arguments\": {\n                    \"acceleration\": \"9.8\",\n                    \"initial_velocity\": \"0\",\n                    \"time\": \"10\"\n                },\n                \"name\": \"final_velocity\"\n            }\n        ],\n        \"query\": \"What would be the final velocity of an object that starts at rest and accelerates at 9.8 m/s^2 for 10 seconds.\"\n    }\n]}",
+        "raw_output_a_p_i_gen_semantic_checker_0": "{\n    \"thought\": \"\",\n    \"passes\": \"yes\"\n}"
+    },
+    "model_name": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+    "keep_row_after_execution_check": true,
+    "execution_result": [
+        "9.8"
+    ],
+    "thought": "",
+    "keep_row_after_semantic_check": true
+}
+```
diff --git a/examples/lib_apigen.py b/examples/lib_apigen.py
new file mode 100644
index 000000000..d49f414e6
--- /dev/null
+++ b/examples/lib_apigen.py
@@ -0,0 +1,146 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> float:
+    """Calculates the final velocity of an object given its initial velocity, acceleration, and time.
+
+    Args:
+        initial_velocity: The initial velocity of the object.
+        acceleration: The acceleration of the object.
+        time: The time elapsed.
+
+    Returns:
+        The final velocity.
+    """
+    # Tool:
+    # {"name": "final_velocity", "description": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.", "parameters": {"initial_velocity": {"description": "The initial velocity of the object.", "type": "float"}, "acceleration": {"description": "The acceleration of the object.", "type": "float"}, "time": {"description": "The time elapsed.", "type": "float"}}}
+    # Answer:
+    # {"name": "final_velocity", "arguments": {"initial_velocity": 5, "acceleration": 1.5, "time": 40}}
+    return initial_velocity + acceleration * time
+
+
+def permutation_count(n: int, k: int) -> int:
+    """Calculates the number of permutations of k elements from a set of n elements.
+
+    Args:
+        n: The total number of elements in the set.
+        k: The number of elements to choose for the permutation.
+
+    Returns:
+        The number of permutations.
+    """
+    # Tool:
+    # {"name": "permutation_count", "description": "Calculates the number of permutations of k elements from a set of n elements.", "parameters": {"n": {"description": "The total number of elements in the set.", "type": "int"}, "k": {"description": "The number of elements to choose for the permutation.", "type": "int"}}}
+    # Answer:
+    # {"name": "permutation_count", "arguments": {"n": 10, "k": 3}}
+    import math
+
+    # Use integer division: (n - k)! always divides n!, so the result is an exact int.
+    return math.factorial(n) // math.factorial(n - k)
+
+
+def getdivision(dividend: int, divisor: int) -> float:
+    """Divides two numbers by making an API call to a division service.
+
+    Args:
+        dividend: The dividend in the division operation.
+        divisor: The divisor in the division operation.
+
+    Returns:
+        The division of the two numbers.
+    """
+    # Tool:
+    # {"name": "getdivision", "description": "Divides two numbers by making an API call to a division service.", "parameters": {"divisor": {"description": "The divisor in the division operation.", "type": "int", "default": ""}, "dividend": {"description": "The dividend in the division operation.", "type": "int", "default": ""}}}
+    # Answer:
+    # {"name": "getdivision", "arguments": {"divisor": 25, "dividend": 100}}
+    return dividend / divisor
+
+
+def binary_addition(a: str, b: str) -> str:
+    """Adds two binary numbers and returns the result as a binary string.
+
+    Args:
+        a: The first binary number.
+        b: The second binary number.
+
+    Raises:
+        ValueError: On invalid binary number.
+
+    Returns:
+        Binary string of the sum of the two numbers.
+    """
+    # Tool:
+    # {"name": "binary_addition", "description": "Adds two binary numbers and returns the result as a binary string.", "parameters": {"a": {"description": "The first binary number.", "type": "str"}, "b": {"description": "The second binary number.", "type": "str"}}}
+    # Answer:
+    # {"name": "binary_addition", "arguments": {"a": "1010", "b": "1101"}}
+    if not set(a).issubset("01") or not set(b).issubset("01"):
+        raise ValueError("Invalid binary number")
+
+    return bin(int(a, 2) + int(b, 2))[2:]
+
+
+def _make_request(url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+    import requests
+
+    req = requests.get(url, params=params)
+    return req.json()
+
+
+def swapi_planet_resource(id: str) -> Dict[str, Any]:
+    """get a specific planets resource
+
+    Args:
+        id: identifier of the planet
+
+    Returns:
+        Information about the planet.
+    """
+    # The planet id goes in the URL path, e.g. "https://swapi.dev/api/planets/1"
+    return _make_request(f"https://swapi.dev/api/planets/{id}")
+
+
+def disney_character(name: str) -> Dict[str, Any]:
+    """Find a specific character using this endpoint
+
+    Args:
+        name: Name of the character to look for.
+
+    Returns:
+        Information about the character.
+    """
+    # Example:
+    # url = "https://api.disneyapi.dev/character"
+    # params = {"name": "mulan"}
+    return _make_request(r"https://api.disneyapi.dev/character", params={"name": name})
+
+
+def get_lib():
+    return {
+        "swapi_planet_resource": swapi_planet_resource,
+        "disney_character": disney_character,
+        "final_velocity": final_velocity,
+        "permutation_count": permutation_count,
+        "getdivision": getdivision,
+        "binary_addition": binary_addition,
+    }
+
+
+def get_tools() -> Dict[str, Dict[str, Any]]:
+    """Returns the tool representation of the functions in the library."""
+    # TODO: Improve the `get_json_schema`, it fails on a lot of examples.
+ from transformers.utils import get_json_schema + + return {name: get_json_schema(func) for name, func in get_lib().items()} diff --git a/examples/pipeline_apigen.py b/examples/pipeline_apigen.py new file mode 100644 index 000000000..e63e16e39 --- /dev/null +++ b/examples/pipeline_apigen.py @@ -0,0 +1,116 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from datasets import load_dataset + +from distilabel.llms import InferenceEndpointsLLM +from distilabel.pipeline import Pipeline +from distilabel.steps import CombineOutputs, DataSampler, LoadDataFromDicts +from distilabel.steps.tasks import ( + APIGenExecutionChecker, + APIGenGenerator, + APIGenSemanticChecker, +) +from distilabel.steps.tasks.apigen.utils import PrepareExamples, load_module_from_path + +libpath = Path(__file__).parent / "lib_apigen.py" + +data = [ + { + "func_name": "final_velocity", + "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.", + }, + { + "func_name": "permutation_count", + "func_desc": "Calculates the number of permutations of k elements from a set of n elements.", + }, + { + "func_name": "getdivision", + "func_desc": "Divides two numbers by making an API call to a division service.", + }, + { + "func_name": "binary_addition", + "func_desc": "Adds two binary numbers and returns the result as a binary string.", + }, + { + "func_name": "swapi_planet_resource", + "func_desc": "get a specific planets resource", + }, + { + "func_name": "disney_character", + "func_desc": "Find a specific character using this endpoint", + }, +] + +libpath_module = load_module_from_path(libpath) +tools = libpath_module.get_tools() # call get_tools() + +# TODO: Add in the tools between 0 and 2 extra tools to make the task more challenging. +for row in data: + # The tools should have a mix where both the correct and irrelevant tools are present. 
+ row.update({"tools": [tools[row["func_name"]]]}) + + +ds_og = ( + load_dataset("Salesforce/xlam-function-calling-60k", split="train") + .shuffle(seed=42) + .select(range(500)) + .to_list() +) + + +with Pipeline(name="APIGenPipeline") as pipeline: + loader_seeds = LoadDataFromDicts(data=data) + sampler = DataSampler( + data=ds_og, + size=2, + samples=len(data), + batch_size=8, + ) + + prep_examples = PrepareExamples() + + model_id = "meta-llama/Meta-Llama-3.1-70B-Instruct" + llm = InferenceEndpointsLLM( + model_id=model_id, + tokenizer_id=model_id, + generation_kwargs={ + "temperature": 0.7, + "max_new_tokens": 2048, + }, + ) + apigen = APIGenGenerator( + llm=llm, + use_default_structured_output=True, + ) + combine_steps = CombineOutputs() + + execution_checker = APIGenExecutionChecker(libpath=str(libpath)) + semantic_checker = APIGenSemanticChecker(llm=llm) + + sampler >> prep_examples + ( + [loader_seeds, prep_examples] + >> combine_steps + >> apigen + >> execution_checker + >> semantic_checker + ) + + +if __name__ == "__main__": + distiset = pipeline.run() + print(distiset["default"]["train"][0]) diff --git a/mkdocs.yml b/mkdocs.yml index 2ef623467..3c26a6cc1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -208,6 +208,7 @@ nav: - Instruction Backtranslation: "sections/pipeline_samples/papers/instruction_backtranslation.md" - Prometheus 2: "sections/pipeline_samples/papers/prometheus.md" - UltraFeedback: "sections/pipeline_samples/papers/ultrafeedback.md" + - APIGen: "sections/pipeline_samples/papers/apigen.md" - Examples: - Benchmarking with distilabel: "sections/pipeline_samples/examples/benchmarking_with_distilabel.md" - Structured generation with outlines: "sections/pipeline_samples/examples/llama_cpp_with_outlines.md" diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index b6b31f7a5..8e52c667d 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -695,8 +695,9 @@ def _grab_citations(dag: "DAG") -> List[str]: for ref in references.values(): try: bibtex_refs.append(get_bibtex(ref)) - except ValueError as e: - print(f"Error: {e}") + except ValueError: + # No need to inform in this case, it's noise + pass except AttributeError as e: print( f"Couldn't obtain the bibtex format for the ref: '{ref}', error: {e}" diff --git a/src/distilabel/llms/_dummy.py b/src/distilabel/llms/_dummy.py new file mode 100644 index 000000000..740f98cd4 --- /dev/null +++ b/src/distilabel/llms/_dummy.py @@ -0,0 +1,70 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import TYPE_CHECKING, Any, List
+
+from distilabel.llms.base import LLM, AsyncLLM
+from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+
+if TYPE_CHECKING:
+    from distilabel.llms.typing import GenerateOutput
+    from distilabel.steps.tasks.typing import FormattedInput
+
+
+class DummyAsyncLLM(AsyncLLM):
+    structured_output: Any = None
+
+    def load(self) -> None:
+        pass
+
+    @property
+    def model_name(self) -> str:
+        return "test"
+
+    async def agenerate(  # type: ignore
+        self, input: "FormattedInput", num_generations: int = 1
+    ) -> "GenerateOutput":
+        return ["output" for _ in range(num_generations)]
+
+
+class DummySyncLLM(LLM):
+    structured_output: Any = None
+
+    def load(self) -> None:
+        super().load()
+
+    @property
+    def model_name(self) -> str:
+        return "test"
+
+    def generate(  # type: ignore
+        self, inputs: List["FormattedInput"], num_generations: int = 1
+    ) -> List["GenerateOutput"]:
+        return [["output" for _ in range(num_generations)] for _ in range(len(inputs))]
+
+
+class DummyMagpieLLM(LLM, MagpieChatTemplateMixin):
+    def load(self) -> None:
+        pass
+
+    @property
+    def model_name(self) -> str:
+        return "test"
+
+    def generate(
+        self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any
+    ) -> List["GenerateOutput"]:
+        return [
+            ["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs))
+        ]
diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py
index 79c10a268..cc1be59f9 100644
--- a/src/distilabel/steps/__init__.py
+++ b/src/distilabel/steps/__init__.py
@@ -45,6 +45,7 @@
     FormatTextGenerationSFT,
 )
 from distilabel.steps.generators.data import LoadDataFromDicts
+from distilabel.steps.generators.data_sampler import DataSampler
 from distilabel.steps.generators.huggingface import (
     LoadDataFromDisk,
     LoadDataFromFileSystem,
@@ -83,6 +84,7 @@
     "FormatChatGenerationSFT",
     "FormatTextGenerationSFT",
     "LoadDataFromDicts",
+    "DataSampler",
     "LoadDataFromDisk",
     "LoadDataFromFileSystem",
     "LoadDataFromHub",
diff --git a/src/distilabel/steps/generators/data_sampler.py b/src/distilabel/steps/generators/data_sampler.py
new file mode 100644
index 000000000..6b2e55bf0
--- /dev/null
+++ b/src/distilabel/steps/generators/data_sampler.py
@@ -0,0 +1,179 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import TYPE_CHECKING, Any, Dict, List
+
+from pydantic import Field
+from typing_extensions import override
+
+from distilabel.steps.base import GeneratorStep
+
+if TYPE_CHECKING:
+    from distilabel.steps.base import GeneratorStepOutput
+
+
+class DataSampler(GeneratorStep):
+    """Step to sample from a dataset.
+
+    `GeneratorStep` that samples from a dataset and yields it in batches.
+    This step is useful when you have a pipeline that can benefit from using examples
+    in the prompts, for example as few-shot examples that change on each row.
+    You can pass a list of dictionaries with N examples and generate M rows
+    from it (assuming you have another step loading data, this M should have the same size
+    as the data being loaded in that step). The `size` argument S is the number of examples
+    sampled per generated row, so each row will contain S examples to be used as few-shot examples.
+
+    Attributes:
+        data: The list of dictionaries to sample from.
+        size: Number of samples per example. For example in a few-shot learning scenario,
+            the number of few-shot examples that will be generated per example. Defaults to 2.
+        samples: Number of examples that will be generated by the step in total.
+            If used with another loader step, this should be the same as the number
+            of samples in the loader step. Defaults to 100.
+
+    Output columns:
+        - dynamic (based on the keys found on the first dictionary of the list): The columns
+            of the dataset.
+
+    Categories:
+        - load
+
+    Examples:
+        Sample data from a list of dictionaries:
+
+        ```python
+        from distilabel.steps import DataSampler
+
+        sampler = DataSampler(
+            data=[{"sample": f"sample {i}"} for i in range(30)],
+            samples=10,
+            size=2,
+            batch_size=4
+        )
+        sampler.load()
+
+        result = next(sampler.process())
+        # >>> result
+        # ([{'sample': ['sample 7', 'sample 0']}, {'sample': ['sample 2', 'sample 21']}, {'sample': ['sample 17', 'sample 12']}, {'sample': ['sample 2', 'sample 14']}], False)
+        ```
+
+        Pipeline with a loader and a sampler combined in a single stream:
+
+        ```python
+        from datasets import load_dataset
+
+        from distilabel.steps import CombineOutputs, DataSampler, LoadDataFromDicts
+        from distilabel.steps.tasks.apigen.utils import PrepareExamples
+        from distilabel.pipeline import Pipeline
+
+        ds = (
+            load_dataset("Salesforce/xlam-function-calling-60k", split="train")
+            .shuffle(seed=42)
+            .select(range(500))
+            .to_list()
+        )
+        data = [
+            {
+                "func_name": "final_velocity",
+                "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
+            },
+            {
+                "func_name": "permutation_count",
+                "func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
+            },
+            {
+                "func_name": "getdivision",
+                "func_desc": "Divides two numbers by making an API call to a division service.",
+            },
+        ]
+        with Pipeline(name="APIGenPipeline") as pipeline:
+            loader_seeds = LoadDataFromDicts(data=data)
+            sampler = DataSampler(
+                data=ds,
+                size=2,
+                samples=len(data),
+                batch_size=8,
+            )
+            prep_examples = PrepareExamples()
+            combine_steps = CombineOutputs()
+
+            sampler >> prep_examples
+            (
+                [loader_seeds, prep_examples]
+                >> combine_steps
+            )
+        # Now we have a single stream of data with the loader and the sampler data
+        ```
+    """
+
+    data: List[Dict[str, Any]] = Field(default_factory=list, exclude=True)
+    size: int = Field(
+        default=2,
+        description=(
+            "Number of samples per example. For example in a few-shot learning scenario, the number "
+            "of few-shot examples that will be generated per example."
+        ),
+    )
+    samples: int = Field(
+        default=100,
+        description=(
+            "Number of examples that will be generated by the step in total. "
+            "If used with another loader step, this should be the same as the number of "
+            "samples in the loader step."
+        ),
+    )
+
+    @override
+    def process(self, offset: int = 0) -> "GeneratorStepOutput":  # type: ignore
+        """Yields batches from a list of dictionaries.
+
+        Args:
+            offset: The offset to start the generation from. Defaults to `0`.
+
+        Yields:
+            A list of Python dictionaries as read from the inputs (propagated in batches)
+            and a flag indicating whether the yielded batch is the last one.
+        """
+        total_samples = 0
+
+        while total_samples < self.samples:
+            # Don't generate more rows than the total requested.
+            bs = min(self.batch_size, self.samples - total_samples)
+            batch = []
+            for _ in range(bs):
+                choices = random.choices(self.data, k=self.size)
+                batch.extend(self._transform_data(choices))
+            total_samples += bs
+            yield (batch, total_samples >= self.samples)
+
+    @staticmethod
+    def _transform_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        if not data:
+            return []
+
+        result = {key: [] for key in data[0].keys()}
+
+        for item in data:
+            for key, value in item.items():
+                result[key].append(value)
+
+        return [result]
+
+    @property
+    def outputs(self) -> List[str]:
+        return list(self.data[0].keys())
diff --git a/src/distilabel/steps/generators/huggingface.py b/src/distilabel/steps/generators/huggingface.py
index f6e782a75..721b3d408 100644
--- a/src/distilabel/steps/generators/huggingface.py
+++ b/src/distilabel/steps/generators/huggingface.py
@@ -219,11 +219,11 @@ def _get_dataset_num_examples(self) -> int:
         Returns:
             The number of examples in the dataset.
         """
-        return (
-            self._dataset_info[self.config if self.config else "default"]
-            .splits[self.split]
-            .num_examples
-        )
+        default_config = self.config
+        if not default_config:
+            default_config = list(self._dataset_info.keys())[0]
+
+        return self._dataset_info[default_config].splits[self.split].num_examples
 
     def _get_dataset_columns(self) -> List[str]:
         """Get the columns of the dataset, based on the `config` runtime parameter provided.
diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py
index c7b7d7239..065567a57 100644
--- a/src/distilabel/steps/tasks/__init__.py
+++ b/src/distilabel/steps/tasks/__init__.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from distilabel.steps.tasks.apigen.execution_checker import APIGenExecutionChecker
+from distilabel.steps.tasks.apigen.generator import APIGenGenerator
+from distilabel.steps.tasks.apigen.semantic_checker import APIGenSemanticChecker
 from distilabel.steps.tasks.argilla_labeller import ArgillaLabeller
 from distilabel.steps.tasks.base import GeneratorTask, Task
 from distilabel.steps.tasks.complexity_scorer import ComplexityScorer
@@ -54,6 +57,9 @@
     "GeneratorTask",
     "Task",
     "ArgillaLabeller",
+    "APIGenExecutionChecker",
+    "APIGenGenerator",
+    "APIGenSemanticChecker",
     "ComplexityScorer",
     "EvolInstruct",
     "EvolComplexity",
diff --git a/src/distilabel/steps/tasks/apigen/__init__.py b/src/distilabel/steps/tasks/apigen/__init__.py
new file mode 100644
index 000000000..20ce00bda
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/distilabel/steps/tasks/apigen/execution_checker.py b/src/distilabel/steps/tasks/apigen/execution_checker.py
new file mode 100644
index 000000000..7d30dd1f7
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/execution_checker.py
@@ -0,0 +1,268 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# - Try to import the function from a given module
+# - If it is a function, try to import it and run it
+# - If it fails, track the error message and return it
+
+import inspect
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable, Union
+
+from pydantic import Field, PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.base import Step, StepInput
+from distilabel.steps.tasks.apigen.utils import (
+    execute_from_response,
+    load_module_from_path,
+)
+
+if TYPE_CHECKING:
+    from types import ModuleType
+
+    from distilabel.steps.typing import StepColumns, StepOutput
+
+
+class APIGenExecutionChecker(Step):
+    """Executes the generated function calls.
+
+    This step checks if a given answer from a model as generated by `APIGenGenerator`
+    can be executed against the given library (given by `libpath`, which is a string
+    pointing to a python .py file with functions).
+
+    Attributes:
+        libpath: The path to the library where we will retrieve the functions.
+            It can also point to a folder with the functions. In this case, the folder
+            layout should be a folder with .py files, each containing a single function,
+            the name of the function being the same as the filename.
+        check_is_dangerous: Bool to exclude some potentially dangerous functions. It contains
+            some heuristics found while testing. These functions can run subprocesses, deal with
+            the OS, or have other potentially dangerous operations. Defaults to True.
+
+    Input columns:
+        - answers (`str`): List with arguments to be passed to the function,
+            dumped as a string from a list of dictionaries. Should be loaded using
+            `json.loads`.
+
+    Output columns:
+        - keep_row_after_execution_check (`bool`): Whether the function should be kept or not.
+        - execution_result (`str`): The result from executing the function.
+
+    Categories:
+        - filtering
+        - execution
+
+    References:
+        - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
+        - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
+
+    Examples:
+        Execute a function from a given library with the answer from an LLM:
+
+        ```python
+        from distilabel.steps.tasks import APIGenExecutionChecker
+
+        # For the libpath you can use as an example the file at the tests folder:
+        # ../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py
+        task = APIGenExecutionChecker(
+            libpath="../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py",
+        )
+        task.load()
+
+        res = next(
+            task.process(
+                [
+                    {
+                        "answers": [
+                            {
+                                "arguments": {
+                                    "initial_velocity": 0.2,
+                                    "acceleration": 0.1,
+                                    "time": 0.5,
+                                },
+                                "name": "final_velocity",
+                            }
+                        ],
+                    }
+                ]
+            )
+        )
+        res
+        #[{'answers': [{'arguments': {'initial_velocity': 0.2, 'acceleration': 0.1, 'time': 0.5}, 'name': 'final_velocity'}], 'keep_row_after_execution_check': True, 'execution_result': ['0.25']}]
+        ```
+    """
+
+    libpath: str = Field(
+        default=...,
+        description=(
+            "The path to the library where we will retrieve the functions, "
+            "or a folder with python files named the same as the functions they contain."
+        ),
+    )
+    check_is_dangerous: bool = Field(
+        default=True,
+        description=(
+            "Bool to exclude some potentially dangerous functions. It contains "
+            "some heuristics found while testing. These functions can run subprocesses, "
+            "deal with the OS, or have other potentially dangerous operations."
+        ),
+    )
+
+    _toolbox: Union["ModuleType", None] = PrivateAttr(None)
+
+    def load(self) -> None:
+        """Loads the library where the functions will be extracted from."""
+        super().load()
+        if Path(self.libpath).suffix == ".py":
+            self._toolbox = load_module_from_path(self.libpath)
+
+    def unload(self) -> None:
+        self._toolbox = None
+
+    @property
+    def inputs(self) -> "StepColumns":
+        """The input for the task is the `answers` column generated by `APIGenGenerator`."""
+        return ["answers"]
+
+    @property
+    def outputs(self) -> "StepColumns":
+        """The outputs are whether the row should be kept and the result of the execution."""
+        return ["keep_row_after_execution_check", "execution_result"]
+
+    def _get_function(self, function_name: str) -> Union[Callable, None]:
+        """Retrieves the function from the toolbox.
+
+        Args:
+            function_name: The name of the function to retrieve.
+
+        Returns:
+            The function to be executed, or `None` if it cannot be found.
+        """
+        if self._toolbox:
+            return getattr(self._toolbox, function_name, None)
+        try:
+            toolbox = load_module_from_path(
+                str(Path(self.libpath) / f"{function_name}.py")
+            )
+            return getattr(toolbox, function_name, None)
+        except FileNotFoundError:
+            return None
+        except Exception as e:
+            self._logger.warning(f"Error loading function '{function_name}': {e}")
+            return None
+
+    def _is_dangerous(self, function: Callable) -> bool:
+        """Checks if a function is dangerous and should be removed.
+        Contains a list of heuristics to avoid executing possibly dangerous functions.
+        """
+        source_code = inspect.getsource(function)
+        # We don't want to execute functions that use subprocess
+        if (
+            ("subprocess." in source_code)
+            or ("os.system(" in source_code)
+            or ("input(" in source_code)
+            # Avoiding threading
+            or ("threading.Thread(" in source_code)
+            or ("exec(" in source_code)
+            # Avoiding argparse, which reads CLI arguments and can exit the process
+            or ("argparse.ArgumentParser(" in source_code)
+            # Avoiding logging changing the levels to not mess with the logs
+            or (".setLevel(" in source_code)
+            # Don't run a test battery
+            or ("unittest.main(" in source_code)
+            # Avoid exiting the program
+            or ("sys.exit(" in source_code)
+            or ("exit(" in source_code)
+            or ("raise SystemExit(" in source_code)
+            or ("multiprocessing.Pool(" in source_code)
+        ):
+            return True
+        return False
+
+    @override
+    def process(self, inputs: StepInput) -> "StepOutput":
+        """Checks the answer to see if it can be executed.
+        Captures the possible errors and returns them.
+
+        Args:
+            inputs: A list of dictionaries with the input data.
+
+        Yields:
+            A list of dictionaries with the output data.
+        """
+        for input in inputs:
+            output = []
+            if input["answers"]:
+                answers = json.loads(input["answers"])
+            else:
+                input.update(
+                    **{
+                        "keep_row_after_execution_check": False,
+                        "execution_result": ["No answers were provided."],
+                    }
+                )
+                continue
+            for answer in answers:
+                if answer is None:
+                    output.append(
+                        {
+                            "keep": False,
+                            "execution_result": "Nothing was generated for this answer.",
+                        }
+                    )
+                    continue
+
+                function_name = answer.get("name", None)
+                arguments = answer.get("arguments", None)
+
+                self._logger.debug(
+                    f"Executing function '{function_name}' with arguments: {arguments}"
+                )
+                function = self._get_function(function_name)
+
+                if self.check_is_dangerous:
+                    if function and self._is_dangerous(function):
+                        function = None
+
+                if function is None:
+                    output.append(
+                        {
+                            "keep": False,
+                            "execution_result": f"Function '{function_name}' not found.",
+                        }
+                    )
+                else:
+                    execution = execute_from_response(function, arguments)
+                    output.append(
+                        {
+                            "keep": execution["keep"],
+                            "execution_result": execution["execution_result"],
+                        }
+                    )
+            # We only consider a good response if all the answers were executed successfully,
+            # but keep the reasons for further review if needed.
+            input.update(
+                **{
+                    "keep_row_after_execution_check": all(
+                        o["keep"] is True for o in output
+                    ),
+                    "execution_result": [o["execution_result"] for o in output],
+                }
+            )
+
+        yield inputs
diff --git a/src/distilabel/steps/tasks/apigen/generator.py b/src/distilabel/steps/tasks/apigen/generator.py
new file mode 100644
index 000000000..c1c691e37
--- /dev/null
+++ b/src/distilabel/steps/tasks/apigen/generator.py
@@ -0,0 +1,448 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.resources as importlib_resources
+import json
+import random
+from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Union
+
+import orjson
+from jinja2 import Template
+from pydantic import PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.tasks.apigen.utils import remove_fences
+from distilabel.steps.tasks.base import Task
+
+if TYPE_CHECKING:
+    from distilabel.steps.tasks.typing import ChatType
+    from distilabel.steps.typing import StepColumns
+
+
+SYSTEM_PROMPT_API_GEN: Final[str] = """\
+You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.
+
+Construct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.
+
+Ensure the query:
+- Is clear and concise
+- Demonstrates typical use cases
+- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words
+- Across a variety level of difficulties, ranging from beginner and advanced use cases
+- The corresponding result's parameter types and ranges match with the function's descriptions
+
+Ensure the answer:
+- Is a list of function calls in JSON format
+- The length of the answer list should be equal to the number of requests in the query
+- Can solve all the requests in the query effectively"""
+
+
+class APIGenGenerator(Task):
+    """Generate queries and answers for the given functions in JSON format.
+
+    The `APIGenGenerator` is inspired by the APIGen pipeline, which was designed to generate
+    verifiable and diverse function-calling datasets. The task generates a set of diverse queries
+    and corresponding answers for the given functions in JSON format.
+
+    Attributes:
+        system_prompt: The system prompt to guide the user in the generation of queries and answers.
+        use_tools: Whether to use the tools available in the prompt to generate the queries and answers.
+            In case the tools are given in the input, they will be added to the prompt.
+        number: The number of queries to generate. It can be a list, where each number will be
+            chosen randomly, or a dictionary with the number of queries and the probability of each.
+            For example, `number=1`, `number=[1, 2, 3]`, and `number={1: 0.5, 2: 0.3, 3: 0.2}` are all valid inputs.
+            It corresponds to the number of parallel queries to generate.
+        use_default_structured_output: Whether to use the default structured output or not.
+
+    Input columns:
+        - examples (`str`): Examples used as few-shot examples to guide the model.
+        - func_name (`str`): Name for the function to generate.
+        - func_desc (`str`): Description of what the function should do.
+        - tools (`str`): JSON formatted string containing the tool representation of the function.
+
+    Output columns:
+        - query (`str`): The generated query.
+        - answers (`str`): JSON formatted string with the list of answers, containing the info as
+            a dictionary to be passed to the functions.
+
+    Categories:
+        - text-generation
+
+    References:
+        - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
+        - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
+
+    Examples:
+        Generate without structured output (original implementation):
+
+        ```python
+        from distilabel.steps.tasks import APIGenGenerator
+        from distilabel.llms import InferenceEndpointsLLM
+
+        llm = InferenceEndpointsLLM(
+            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+            generation_kwargs={
+                "temperature": 0.7,
+                "max_new_tokens": 1024,
+            },
+        )
+        apigen = APIGenGenerator(
+            use_default_structured_output=False,
+            llm=llm
+        )
+        apigen.load()
+
+        res = next(
+            apigen.process(
+                [
+                    {
+                        "examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+                        "func_name": "getrandommovie",
+                        "func_desc": "Returns a list of random movies from a database by calling an external API."
+                    }
+                ]
+            )
+        )
+        res
+        # [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+        # 'number': 1,
+        # 'func_name': 'getrandommovie',
+        # 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
+        # 'queries': ['I want to watch a movie tonight, can you recommend a random one from your database?',
+        # 'Give me 5 random movie suggestions from your database to plan my weekend.'],
+        # 'answers': [[{'name': 'getrandommovie', 'arguments': {}}],
+        # [{'name': 'getrandommovie', 'arguments': {}},
+        # {'name': 'getrandommovie', 'arguments': {}},
+        # {'name': 'getrandommovie', 'arguments': {}},
+        # {'name': 'getrandommovie', 'arguments': {}},
+        # {'name': 'getrandommovie', 'arguments': {}}]],
+        # 'raw_input_api_gen_generator_0': [{'role': 'system',
+        # 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
+        # {'role': 'user',
+        # 'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n```json\n[\n    {\n        "query": "The generated query.",\n        "answers": [\n            {\n                "name": "api_name",\n                "arguments": {\n                    "arg_name": "value"\n                    ... (more arguments as required)\n                }\n            },\n            ... (more API calls as required)\n        ]\n    }\n]\n```\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
+        # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+        ```
+
+        Generate with structured output:
+
+        ```python
+        from distilabel.steps.tasks import APIGenGenerator
+        from distilabel.llms import InferenceEndpointsLLM
+
+        llm = InferenceEndpointsLLM(
+            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+            tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+            generation_kwargs={
+                "temperature": 0.7,
+                "max_new_tokens": 1024,
+            },
+        )
+        apigen = APIGenGenerator(
+            use_default_structured_output=True,
+            llm=llm
+        )
+        apigen.load()
+
+        res_struct = next(
+            apigen.process(
+                [
+                    {
+                        "examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+                        "func_name": "getrandommovie",
+                        "func_desc": "Returns a list of random movies from a database by calling an external API."
+                    }
+                ]
+            )
+        )
+        res_struct
+        # [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
+        # 'number': 1,
+        # 'func_name': 'getrandommovie',
+        # 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
+        # 'queries': ["I'm bored and want to watch a movie. Can you suggest some movies?",
+        # "My family and I are planning a movie night. We can't decide on what to watch. Can you suggest some random movie titles?"],
+        # 'answers': [[{'arguments': {}, 'name': 'getrandommovie'}],
+        # [{'arguments': {}, 'name': 'getrandommovie'}]],
+        # 'raw_input_api_gen_generator_0': [{'role': 'system',
+        # 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
+        # {'role': 'user',
+        # 'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
+        # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+        ```
+    """
+
+    system_prompt: str = SYSTEM_PROMPT_API_GEN
+    use_default_structured_output: bool = False
+    number: Union[int, List[int], Dict[int, float]] = 1
+    use_tools: bool = True
+
+    _number: Union[int, None] = PrivateAttr(None)
+    _fn_parallel_queries: Union[Callable[[], str], None] = PrivateAttr(None)
+    _format_inst: Union[str, None] = PrivateAttr(None)
+
+    def load(self) -> None:
+        """Loads the template for the generator prompt."""
+        super().load()
+        _path = str(
+            importlib_resources.files("distilabel")
+            / "steps"
+            / "tasks"
+            / "templates"
+            / "apigen"
+            / "generator.jinja2"
+        )
+        self._template = Template(open(_path).read())
+        self._format_inst = self._set_format_inst()
+
+    def _parallel_queries(self, number: int) -> str:
+        """Prepares the parallel queries guide to include in the prompt.
+
+        Args:
+            number: The number of queries to generate in the call.
+
+        Returns:
+            The parallel queries guide, or an empty string when a single query is generated.
+        """
+        if number > 1:
+            return (
+                "It can contain multiple parallel queries in natural language for the given functions. "
+                "They could use either the same function with different arguments or different functions.\n"
+            )
+        return ""
+
+    def _get_number(self) -> int:
+        """Generates the number of queries to generate in a single call.
+        The number must be set to `_number` to avoid changing the original value
+        when calling `_default_error`.
+        """
+        if isinstance(self.number, list):
+            self._number = random.choice(self.number)
+        elif isinstance(self.number, dict):
+            self._number = random.choices(
+                list(self.number.keys()), list(self.number.values())
+            )[0]
+        else:
+            self._number = self.number
+        return self._number
+
+    def _set_format_inst(self) -> str:
+        """Prepares the formatted instructions for the prompt.
+
+        If the default structured output is used, an empty string is rendered in the prompt
+        because nothing else is needed; otherwise, this addition to the prompt guides the model
+        to generate a formatted JSON.
+        """
+        return (
+            "\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n"
+            "```\n"
+            "[\n"
+            "   {\n"
+            '       "query": "The generated query.",\n'
+            '       "answers": [\n'
+            "           {\n"
+            '               "name": "api_name",\n'
+            '               "arguments": {\n'
+            '                   "arg_name": "value"\n'
+            "                   ... (more arguments as required)\n"
+            "               }\n"
+            "           },\n"
+            "           ... (more API calls as required)\n"
+            "       ]\n"
+            "   }\n"
+            "]\n"
+            "```\n"
+        )
+
+    def _get_func_desc(self, input: Dict[str, Any]) -> str:
+        """If available and required, adds the tool definition to the function
+        description as extra information for the prompt. Otherwise, returns just
+        the function description.
+        """
+        if not self.use_tools:
+            return input["func_desc"]
+        extra = ""  # Extra information from the tools (if available will be added)
+        if "tools" in input:
+            extra = f"\n\nThis is the available tool to guide you (respect the order of the parameters):\n{input['tools']}"
+        return input["func_desc"] + extra
+
+    @property
+    def inputs(self) -> "StepColumns":
+        """The inputs for the task."""
+        return {
+            "examples": True,
+            "func_name": True,
+            "func_desc": True,
+            "tools": False,
+        }
+
+    def format_input(self, input: Dict[str, Any]) -> "ChatType":
+        """The input is formatted as a `ChatType`."""
+        number = self._get_number()
+        parallel_queries = self._parallel_queries(number)
+        return [
+            {"role": "system", "content": self.system_prompt},
+            {
+                "role": "user",
+                "content": self._template.render(
+                    examples=input["examples"],
+                    parallel_queries=parallel_queries,
+                    number=number,
+                    func_name=input["func_name"],
+                    func_desc=self._get_func_desc(input),
+                    format_inst=self._format_inst,
+                ),
+            },
+        ]
+
+    @property
+    def outputs(self) -> "StepColumns":
+        """The output for the task are the queries and corresponding answers."""
+        return ["query", "answers", "model_name"]
+
+    def format_output(
+        self, output: Union[str, None], input: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Parses the raw LLM output into the query and its corresponding answers.
+
+        Args:
+            output: the raw output of the LLM.
+            input: the input to the task. Used for obtaining the number of responses.
+
+        Returns:
+            A dict with the queries and answers pairs.
+            The answers are an array of answers corresponding to the query.
+            Each answer is represented as an object with the following properties:
+                - name (string): The name of the tool used to generate the answer.
+                - arguments (object): An object representing the arguments passed to the tool to generate the answer.
+            Each argument is represented as a key-value pair, where the key is the parameter name and the
+            value is the corresponding value.
+        """
+        if output is None:
+            return self._default_error(input)
+
+        if not self.use_default_structured_output:
+            output = remove_fences(output)
+
+        try:
+            pairs = orjson.loads(output)
+        except orjson.JSONDecodeError:
+            return self._default_error(input)
+
+        pairs = pairs["pairs"] if self.use_default_structured_output else pairs
+
+        return self._format_output(pairs, input)
+
+    def _format_output(
+        self, pairs: List[Dict[str, Any]], input: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Parses the response, updating the input row with the query and its answers.
+
+        Args:
+            pairs: The parsed list of query and answers pairs from the LLM's output.
+            input: The input from the `LLM`.
+
+        Returns:
+            The input row updated with a `query` (a string) and its `answers` (a JSON
+            encoded list of objects).
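+
+            For instance, a parsed `pairs` value could look like
+            `[{"query": "...", "answers": [{"name": "func", "arguments": {...}}]}]`;
+            only the first pair is used to update the row.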
+ """ + try: + input.update( + **{ + "query": pairs[0]["query"], + "answers": json.dumps(pairs[0]["answers"]), + } + ) + return input + except Exception as e: + self._logger.error(f"Error formatting output: {e}, pairs: '{pairs}'") + return self._default_error(input) + + def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]: + """Returns a default error output, to fill the responses in case of failure.""" + input.update( + **{ + "query": None, + "answers": json.dumps([None] * self._number), + } + ) + return input + + @override + def get_structured_output(self) -> Dict[str, Any]: + """Creates the json schema to be passed to the LLM, to enforce generating + a dictionary with the output which can be directly parsed as a python dictionary. + + The schema corresponds to the following: + + ```python + from typing import Dict, List + from pydantic import BaseModel + + + class Answer(BaseModel): + name: str + arguments: Dict[str, str] + + class QueryAnswer(BaseModel): + query: str + answers: List[Answer] + + class QueryAnswerPairs(BaseModel): + pairs: List[QueryAnswer] + + json.dumps(QueryAnswerPairs.model_json_schema(), indent=4) + ``` + + Returns: + JSON Schema of the response to enforce. + """ + return { + "$defs": { + "Answer": { + "properties": { + "name": {"title": "Name", "type": "string"}, + "arguments": { + "additionalProperties": {"type": "string"}, + "title": "Arguments", + "type": "object", + }, + }, + "required": ["name", "arguments"], + "title": "Answer", + "type": "object", + }, + "QueryAnswer": { + "properties": { + "query": {"title": "Query", "type": "string"}, + "answers": { + "items": {"$ref": "#/$defs/Answer"}, + "title": "Answers", + "type": "array", + }, + }, + "required": ["query", "answers"], + "title": "QueryAnswer", + "type": "object", + }, + }, + "properties": { + "pairs": { + "items": {"$ref": "#/$defs/QueryAnswer"}, + "title": "Pairs", + "type": "array", + } + }, + "required": ["pairs"], + "title": "QueryAnswerPairs", + "type": "object", + } diff --git a/src/distilabel/steps/tasks/apigen/semantic_checker.py b/src/distilabel/steps/tasks/apigen/semantic_checker.py new file mode 100644 index 000000000..5ec7cdc57 --- /dev/null +++ b/src/distilabel/steps/tasks/apigen/semantic_checker.py @@ -0,0 +1,308 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.resources as importlib_resources +from typing import TYPE_CHECKING, Any, Dict, Final, Union + +import orjson +from jinja2 import Template +from pydantic import PrivateAttr +from typing_extensions import override + +from distilabel.steps.tasks.apigen.utils import remove_fences +from distilabel.steps.tasks.base import Task + +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.typing import StepColumns + + +SYSTEM_PROMPT_SEMANTIC_CHECKER: Final[str] = """\ +As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results. 
+These function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.
+
+Do not pass if:
+1. The function call does not align with the query’s objective, or the input arguments appear incorrect.
+2. The function call and arguments are not properly chosen from the available functions.
+3. The number of function calls does not correspond to the user’s intentions.
+4. The execution results are irrelevant and do not match the function’s purpose.
+5. The execution results contain errors or reflect that the function calls were not executed successfully.
+""".rstrip()
+
+
+class APIGenSemanticChecker(Task):
+    r"""Checks the semantic correctness of generated function calls.
+
+    The `APIGenSemanticChecker` is inspired by the APIGen pipeline and corresponds to its
+    *Stage 3: Semantic Checker*: given the function description, the user query, the
+    generated function calls and their execution results, an LLM assesses whether the
+    calls and their results accurately address the user's intentions.
+
+    Attributes:
+        system_prompt: System prompt for the task. Has a default one.
+        use_default_structured_output: Whether to use the default structured output
+            or not. Defaults to False.
+
+    Input columns:
+        - func_desc (`str`): Description of what the function should do.
+        - query (`str`): Instruction from the user.
+        - answers (`str`): JSON encoded list with arguments to be passed to the function/API.
+            Should be loaded using `json.loads`.
+        - execution_result (`str`): Result of the function/API executed.
+        - keep_row_after_execution_check (`bool`): Result of running
+            `APIGenExecutionChecker` on the row.
+
+    Output columns:
+        - thought (`str`): Reasoning for the output on whether to keep this output or not.
+        - keep_row_after_semantic_check (`bool`): True or False, can be used to filter
+            afterwards.
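+
+    The verdict is obtained by asking the LLM for a free-text `thought` plus a yes/no
+    `passes` answer over the rendered information, and the answer is mapped to the
+    boolean `keep_row_after_semantic_check` column.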
+
+    Categories:
+        - filtering
+        - text-generation
+
+    References:
+        - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
+        - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
+
+    Examples:
+
+        Semantic checker for generated function calls (original implementation):
+
+        ```python
+        import json
+
+        from distilabel.steps.tasks import APIGenSemanticChecker
+        from distilabel.llms import InferenceEndpointsLLM
+
+        llm = InferenceEndpointsLLM(
+            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+            generation_kwargs={
+                "temperature": 0.7,
+                "max_new_tokens": 1024,
+            },
+        )
+        semantic_checker = APIGenSemanticChecker(
+            use_default_structured_output=False,
+            llm=llm
+        )
+        semantic_checker.load()
+
+        res = next(
+            semantic_checker.process(
+                [
+                    {
+                        "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+                        "query": "What information can be obtained about the Maine Coon cat breed?",
+                        "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
+                        "execution_result": "The Maine Coon is a big and hairy breed of cat",
+                    }
+                ]
+            )
+        )
+        res
+        # [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
+        # 'query': 'What information can be obtained about the Maine Coon cat breed?',
+        # 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
+        # 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
+        # 'thought': '',
+        # 'keep_row_after_semantic_check': True,
+        # 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
+        # 'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
+        # {'role': 'user',
+        # 'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions.
Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is whether the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n```\n{\n   "thought": "Concisely describe your reasoning here",\n   "passes": "yes" or "no"\n}\n```\n'}]},
+        # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+        ```
+
+        Semantic checker for generated function calls (structured output):
+
+        ```python
+        import json
+
+        from distilabel.steps.tasks import APIGenSemanticChecker
+        from distilabel.llms import InferenceEndpointsLLM
+
+        llm = InferenceEndpointsLLM(
+            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+            generation_kwargs={
+                "temperature": 0.7,
+                "max_new_tokens": 1024,
+            },
+        )
+        semantic_checker = APIGenSemanticChecker(
+            use_default_structured_output=True,
+            llm=llm
+        )
+        semantic_checker.load()
+
+        res = next(
+            semantic_checker.process(
+                [
+                    {
+                        "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+                        "query": "What information can be obtained about the Maine Coon cat breed?",
+                        "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
+                        "execution_result": "The Maine Coon is a big and hairy breed of cat",
+                    }
+                ]
+            )
+        )
+        res
+        # [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
+        # 'query': 'What information can be obtained about the Maine Coon cat breed?',
+        # 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
+        # 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
+        # 'keep_row_after_semantic_check': True,
+        # 'thought': '',
+        # 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
+        # 'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
+        # {'role': 'user',
+        # 'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions.
Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is whether the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n'}]},
+        # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
+        ```
+    """

+    system_prompt: str = SYSTEM_PROMPT_SEMANTIC_CHECKER
+    use_default_structured_output: bool = False
+
+    _format_inst: Union[str, None] = PrivateAttr(None)
+
+    def load(self) -> None:
+        """Loads the template for the semantic checker prompt."""
+        super().load()
+        _path = str(
+            importlib_resources.files("distilabel")
+            / "steps"
+            / "tasks"
+            / "templates"
+            / "apigen"
+            / "semantic_checker.jinja2"
+        )
+
+        self._template = Template(open(_path).read())
+        self._format_inst = self._set_format_inst()
+
+    def _set_format_inst(self) -> str:
+        """Prepares the instruction block that is appended to the prompt to guide
+        the model to generate a formatted JSON output.
+        """
+        return (
+            "\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n"
+            "```\n"
+            "{\n"
+            '   "thought": "Concisely describe your reasoning here",\n'
+            '   "passes": "yes" or "no"\n'
+            "}\n"
+            "```\n"
+        )
+
+    @property
+    def inputs(self) -> "StepColumns":
+        """The inputs for the task."""
+        return {
+            "func_desc": True,
+            "query": True,
+            "answers": True,
+            "execution_result": True,
+            "keep_row_after_execution_check": True,
+        }
+
+    def format_input(self, input: Dict[str, Any]) -> "ChatType":
+        """The input is formatted as a `ChatType`."""
+        return [
+            {"role": "system", "content": self.system_prompt},
+            {
+                "role": "user",
+                "content": self._template.render(
+                    func_desc=input["func_desc"],
+                    query=input["query"] or "",
+                    func_call=input["answers"] or "",
+                    execution_result=input["execution_result"],
+                    format_inst=self._format_inst,
+                ),
+            },
+        ]
+
+    @property
+    def outputs(self) -> "StepColumns":
+        """The output for the task is whether to keep the row, plus the reasoning."""
+        return ["keep_row_after_semantic_check", "thought"]
+
+    def format_output(
+        self, output: Union[str, None], input: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Parses the raw LLM output, adding the decision of the semantic check.
+
+        Args:
+            output: the raw output of the LLM.
+            input: the input to the task.
+
+        Returns:
+            The input row updated with `thought` (the reasoning of the LLM) and
+            `keep_row_after_semantic_check` (a bool that is True when the check passed).
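+
+            For example, a raw output of `{"thought": "", "passes": "yes"}` results in
+            the row being updated with `thought=""` and `keep_row_after_semantic_check=True`.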
+ """ + if output is None: + return self._default_error(input) + + output = remove_fences(output) + + try: + result = orjson.loads(output) + # Update the column name and change to bool + result["keep_row_after_semantic_check"] = ( + result.pop("passes").lower() == "yes" + ) + input.update(**result) + return input + except orjson.JSONDecodeError: + return self._default_error(input) + + def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]: + """Default error message for the task.""" + input.update({"thought": None, "keep_row_after_semantic_check": None}) + return input + + @override + def get_structured_output(self) -> Dict[str, Any]: + """Creates the json schema to be passed to the LLM, to enforce generating + a dictionary with the output which can be directly parsed as a python dictionary. + + The schema corresponds to the following: + + ```python + from typing import Literal + from pydantic import BaseModel + import json + + class Checker(BaseModel): + thought: str + passes: Literal["yes", "no"] + + json.dumps(Checker.model_json_schema(), indent=4) + ``` + + Returns: + JSON Schema of the response to enforce. + """ + return { + "properties": { + "thought": {"title": "Thought", "type": "string"}, + "passes": {"enum": ["yes", "no"], "title": "Passes", "type": "string"}, + }, + "required": ["thought", "passes"], + "title": "Checker", + "type": "object", + } diff --git a/src/distilabel/steps/tasks/apigen/utils.py b/src/distilabel/steps/tasks/apigen/utils.py new file mode 100644 index 000000000..85ff0b764 --- /dev/null +++ b/src/distilabel/steps/tasks/apigen/utils.py @@ -0,0 +1,194 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import re +import signal +from typing import TYPE_CHECKING, Any, Callable, Dict, TypedDict, Union + +from distilabel.steps.base import Step, StepInput + +if TYPE_CHECKING: + from types import ModuleType + + from distilabel.steps.typing import StepColumns, StepOutput + + +class PrepareExamples(Step): + r"""Helper step to create examples from `query` and `answers` pairs used as Few Shots in APIGen. + + Attributes: + template (str): The template to format the examples. + + Input columns: + - query (`str`): The query to generate examples from. + - answers (`str`): The answers to the query. + + Output columns: + - examples (`str`): The formatted examples. 
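+            Each `query`/`answers` pair in the row is rendered with `template`, and the
+            rendered pairs are joined with blank lines into a single string.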
+
+    Categories:
+        - format
+
+    Examples:
+        Generate examples for APIGen:
+
+        ```python
+        from distilabel.steps.tasks.apigen.utils import PrepareExamples
+
+        prepare_examples = PrepareExamples()
+        result = next(prepare_examples.process(
+            [
+                {
+                    "query": ['I need the area of circles with radius 2.5, 5, and 7.5 inches, please.', 'Can you provide the current locations of buses and trolleys on route 12?'],
+                    "answers": ['[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]', '[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]']
+                }
+            ]
+        ))
+        # result
+        # [{'examples': '## Query:\nI need the area of circles with radius 2.5, 5, and 7.5 inches, please.\n## Answers:\n[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]\n\n## Query:\nCan you provide the current locations of buses and trolleys on route 12?\n## Answers:\n[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]'}]
+        ```
+    """

+    template: str = "## Query:\n{query}\n## Answers:\n{answers}"
+
+    @property
+    def inputs(self) -> "StepColumns":
+        return ["query", "answers"]
+
+    @property
+    def outputs(self) -> "StepColumns":
+        return ["examples"]
+
+    def process(self, inputs: StepInput) -> "StepOutput":
+        """The process prepares the data for the `APIGenGenerator` task.
+
+        If a single example is provided, it is copied to avoid raising an error.
+
+        Args:
+            inputs: A list of dictionaries with the input data.
+
+        Yields:
+            A list of dictionaries with the output data.
+        """
+        outputs = []
+        for input in inputs:
+            example_list = []
+            for query, answers in zip(input["query"], input["answers"]):
+                example_list.append(self.template.format(query=query, answers=answers))
+            outputs.append({"examples": "\n\n".join(example_list)})
+
+        yield outputs
+
+
+def load_module_from_path(path: str) -> "ModuleType":
+    """Loads a python module from a given path.
+
+    Args:
+        path: Path pointing to the module.
+
+    Returns:
+        The loaded module.
+
+    Example:
+        ```python
+        path = "/path/to/module.py"
+        module = load_module_from_path(path)
+        # And you can load functions from the module like this:
+        function = getattr(module, "function_name")
+        function(*args, **kwargs)
+        ```
+    """
+    spec = importlib.util.spec_from_file_location("module.name", path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+class FunctionResult(TypedDict):
+    keep: bool
+    execution_result: str
+
+
+def execute_from_response(
+    function: Callable, call_answer: Union[Dict[str, Any], None]
+) -> FunctionResult:
+    """Executes a function with the given arguments as generated by `APIGenGenerator`.
+
+    Given that we cannot cast all the arguments arbitrarily, we try to evaluate them,
+    which ensures the strings can be converted to the correct type if possible (say
+    a list of lists of ints will be passed as such instead of its string representation).
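+    For example, the string "[[1, 2, 3], [4, 5, 6]]" is evaluated back into a list of
+    lists of ints before the function is called with it.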
+
+    Args:
+        function: A callable object.
+        call_answer: The arguments to call the function, as generated by the model.
+
+    Returns:
+        A container with the result of the execution and if the row should be kept.
+    """
+    if not function:
+        return FunctionResult(keep=False, execution_result="Function not found")
+
+    if call_answer:
+        for key, value in call_answer.items():
+            if isinstance(value, str):
+                try:
+                    call_answer[key] = eval(value)
+                except Exception:
+                    # Leave as is and expect the function to handle it
+                    pass
+
+    try:
+        if call_answer:
+            result = run_function_with_timeout(function, 5, *call_answer.values())
+        else:
+            # There can be functions that do not require arguments
+            result = run_function_with_timeout(function, 5)
+        return FunctionResult(keep=True, execution_result=str(result))
+    except Exception as e:
+        return FunctionResult(keep=False, execution_result=str(e))
+
+
+def remove_json_fences(text: str) -> str:
+    pattern = r"^```json\n([\s\S]*)\n```$"
+    match = re.match(pattern, text, re.MULTILINE)
+    if match:
+        return match.group(1)
+    return text
+
+
+def remove_fences(text: str) -> str:
+    pattern = r"^```\n([\s\S]*)\n```$"
+    match = re.match(pattern, text, re.MULTILINE)
+    if match:
+        return match.group(1)
+    return text
+
+
+def timeout_handler(signum, frame):
+    raise TimeoutError("Function execution timed out")
+
+
+def run_function_with_timeout(function: Callable, timeout: int = 5, *args: Any) -> Any:
+    """Run a function with a timeout, to limit the total time waiting for a result."""
+    signal.signal(signal.SIGALRM, timeout_handler)
+    signal.alarm(timeout)
+
+    try:
+        result = function(*args)
+    finally:
+        # Cancel the alarm
+        signal.alarm(0)
+
+    return result
diff --git a/src/distilabel/steps/tasks/templates/apigen/generator.jinja2 b/src/distilabel/steps/tasks/templates/apigen/generator.jinja2
new file mode 100644
index 000000000..cc92c725c
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/apigen/generator.jinja2
@@ -0,0 +1,10 @@
+Here are examples of queries and the corresponding answers for similar functions:
+{{ examples }}
+
+Note that the query could be interpreted as a combination of several independent requests.
+{{ parallel_queries }}
+Based on these examples, generate {{ number }} diverse query and answer pairs for the function `{{ func_name }}`.
+The detailed function description is the following:
+{{ func_desc }}
+{{ format_inst }}
+Now please generate {{ number }} diverse query and answer pairs following the above format.
\ No newline at end of file
diff --git a/src/distilabel/steps/tasks/templates/apigen/semantic_checker.jinja2 b/src/distilabel/steps/tasks/templates/apigen/semantic_checker.jinja2
new file mode 100644
index 000000000..8d94357e7
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/apigen/semantic_checker.jinja2
@@ -0,0 +1,13 @@
+Given Information:
+- All Available Functions:
+{{ func_desc }}
+- User Query: {{ query }}
+- Generated Function Calls: {{ func_call }}
+- Execution Results: {{ execution_result }}
+
+Note: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.
+
+The main decision factor is whether the function calls accurately reflect the query's intentions and the function descriptions.
+Provide your reasoning in the thought section and decide if the data passes (answer yes or no).
+If not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.
+{{ format_inst }}
\ No newline at end of file
diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py
index 9d5d9b59e..621f4b61d 100644
--- a/src/distilabel/utils/mkdocs/components_gallery.py
+++ b/src/distilabel/utils/mkdocs/components_gallery.py
@@ -90,6 +90,7 @@
     "filtering": ":material-filter:",
     "format": ":material-format-list-bulleted:",
     "load": ":material-file-download:",
+    "execution": ":octicons-code-16:",
     "save": ":material-content-save:",
 }

@@ -108,6 +109,7 @@
     "filtering": "Filtering steps are used to filter the data based on some criteria.",
     "format": "Format steps are used to format the data.",
     "load": "Load steps are used to load the data.",
+    "execution": "Execution steps are used to execute python functions.",
     "save": "Save steps are used to save the data.",
 }
diff --git a/tests/integration/test_generator_and_sampler.py b/tests/integration/test_generator_and_sampler.py
new file mode 100644
index 000000000..1bb0a457b
--- /dev/null
+++ b/tests/integration/test_generator_and_sampler.py
@@ -0,0 +1,55 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.llms._dummy import DummyAsyncLLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import CombineOutputs, LoadDataFromDicts
+from distilabel.steps.generators.data_sampler import DataSampler
+from distilabel.steps.tasks import TextGeneration
+
+
+def get_pipeline():
+    with Pipeline() as pipe:
+        size_dataset_1 = 10
+        loader_1 = LoadDataFromDicts(
+            data=[{"instruction": f"instruction {i}"} for i in range(size_dataset_1)]
+        )
+        sampler = DataSampler(
+            data=[{"sample": f"sample {i}"} for i in range(30)],
+            size=2,
+            samples=size_dataset_1,
+            batch_size=8,
+        )
+        text_generation = TextGeneration(llm=DummyAsyncLLM(), input_batch_size=8)
+
+        combine = CombineOutputs()
+        [loader_1, sampler] >> combine >> text_generation
+    return pipe
+
+
+def test_sampler():
+    pipe = get_pipeline()
+    distiset = pipe.run(use_cache=False)
+    assert len(distiset["default"]["train"]) == 10
+    row = distiset["default"]["train"][0]
+    assert isinstance(row["sample"], list)
+    assert len(row["sample"]) == 2
+    assert isinstance(row["instruction"], str)
+
+
+if __name__ == "__main__":
+    pipe = get_pipeline()
+    distiset = pipe.run(use_cache=False)
+    print(distiset)
+    print(distiset["default"]["train"][0])
diff --git a/tests/unit/steps/generators/test_data_sampler.py b/tests/unit/steps/generators/test_data_sampler.py
new file mode 100644
index 000000000..32882e037
--- /dev/null
+++ b/tests/unit/steps/generators/test_data_sampler.py
@@ -0,0 +1,45 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import pytest
+
+from distilabel.steps.generators.data_sampler import DataSampler
+
+
+@pytest.mark.parametrize(
+    "samples, size, batch_size, expected",
+    [
+        (10, 2, 4, [4, 4, 2]),
+        (7, 5, 6, [6, 1]),
+        (20, 5, 20, [20]),
+        (20, 50, 8, [8, 8, 4]),
+    ],
+)
+def test_generator_and_sampler(
+    samples: int, size: int, batch_size: int, expected: List[int]
+):
+    sampler = DataSampler(
+        data=[{"sample": f"sample {i}"} for i in range(30)],
+        size=size,
+        samples=samples,
+        batch_size=batch_size,
+    )
+    sampler.load()
+    results = [item[0] for item in sampler.process()]
+    assert len(results) == len(expected)
+    assert len(results[0]) == batch_size
+    for i, result in enumerate(results):
+        assert len(result) == expected[i]
diff --git a/tests/unit/steps/tasks/apigen/__init__.py b/tests/unit/steps/tasks/apigen/__init__.py
new file mode 100644
index 000000000..20ce00bda
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py b/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py
new file mode 100644
index 000000000..abcc66214
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/_sample_lib/final_velocity.py
@@ -0,0 +1,27 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> float:
+    """Calculates the final velocity of an object given its initial velocity, acceleration, and time.
+
+    Args:
+        initial_velocity: The initial velocity of the object.
+        acceleration: The acceleration of the object.
+        time: The time elapsed.
+
+    Returns:
+        The final velocity
+    """
+    return initial_velocity + acceleration * time
diff --git a/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py b/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py
new file mode 100644
index 000000000..db3bd1bcc
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/_sample_lib/get_value.py
@@ -0,0 +1,33 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+
+def get_value(matrix: List[List[int]], indices: Tuple[int, int]) -> Optional[int]:
+    """Gets the value at the specified index in the matrix.
+
+    Args:
+        matrix: A list of lists representing the matrix.
+        indices: A tuple containing the row and column indices.
+    """
+    row_index, col_index = indices
+    if (
+        row_index < 0
+        or row_index >= len(matrix)
+        or col_index < 0
+        or col_index >= len(matrix[row_index])
+    ):
+        return None
+    return matrix[row_index][col_index]
diff --git a/tests/unit/steps/tasks/apigen/_sample_module.py b/tests/unit/steps/tasks/apigen/_sample_module.py
new file mode 100644
index 000000000..6e9e08502
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/_sample_module.py
@@ -0,0 +1,47 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+
+def final_velocity(initial_velocity: float, acceleration: float, time: float) -> float:
+    """Calculates the final velocity of an object given its initial velocity, acceleration, and time.
+
+    Args:
+        initial_velocity: The initial velocity of the object.
+        acceleration: The acceleration of the object.
+        time: The time elapsed.
+
+    Returns:
+        The final velocity
+    """
+    return initial_velocity + acceleration * time
+
+
+def get_value(matrix: List[List[int]], indices: Tuple[int, int]) -> Optional[int]:
+    """Gets the value at the specified index in the matrix.
+
+    Args:
+        matrix: A list of lists representing the matrix.
+        indices: A tuple containing the row and column indices.
+    """
+    row_index, col_index = indices
+    if (
+        row_index < 0
+        or row_index >= len(matrix)
+        or col_index < 0
+        or col_index >= len(matrix[row_index])
+    ):
+        return None
+    return matrix[row_index][col_index]
diff --git a/tests/unit/steps/tasks/apigen/test_execution_checker.py b/tests/unit/steps/tasks/apigen/test_execution_checker.py
new file mode 100644
index 000000000..d70e42271
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_execution_checker.py
@@ -0,0 +1,140 @@
+# Copyright 2023-present, Argilla, Inc.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path +from typing import Any, Dict + +import pytest + +from distilabel.steps.tasks.apigen.execution_checker import APIGenExecutionChecker + +SAMPLE_LIB = Path(__file__).parent / "_sample_module.py" +SAMPLE_LIB_FOLDER = Path(__file__).parent / "_sample_lib" + + +class TestAPIGenExecutionChecker: + @pytest.mark.parametrize("lib", (SAMPLE_LIB, SAMPLE_LIB_FOLDER)) + @pytest.mark.parametrize( + "answers, expected", + [ + ( + { + "query": "Whats the velocity of X?", + "answers": json.dumps( + [ + { + "arguments": { + "initial_velocity": 0.2, + "acceleration": "0.1", + "time": 5, + }, + "name": "final_velocity", + } + ] + ), + }, + [ + { + "query": "Whats the velocity of X?", + "answers": json.dumps( + [ + { + "arguments": { + "initial_velocity": 0.2, + "acceleration": "0.1", + "time": 5, + }, + "name": "final_velocity", + } + ] + ), + "keep_row_after_execution_check": True, + "execution_result": ["0.7"], + } + ], + ), + ( + { + "query": "Other query", + "answers": json.dumps( + [ + { + "arguments": { + "initial_velocity": 0.2, + "acceleration": 0.1, + "time": 0.5, + }, + "name": "unknown_function", + } + ] + ), + }, + [ + { + "query": "Other query", + "answers": json.dumps( + [ + { + "arguments": { + "initial_velocity": 0.2, + "acceleration": 0.1, + "time": 0.5, + }, + "name": "unknown_function", + } + ] + ), + "keep_row_after_execution_check": False, + "execution_result": ["Function 'unknown_function' not found."], + } + ], + ), + ( + { + "query": "Other query", + "answers": '[{"arguments": {"matrix": "[[1, 2, 3], [4, 5, 6], [7, 8, 9]]", "indices": "[1, 2]"}, "name": "get_value"}]', + }, + [ + { + "query": "Other query", + "answers": '[{"arguments": {"matrix": "[[1, 2, 3], [4, 5, 6], [7, 8, 9]]", "indices": "[1, 2]"}, "name": "get_value"}]', + "keep_row_after_execution_check": True, + "execution_result": ["6"], + } + ], + ), + ( + { + "query": "Other query", + "answers": None, + }, + [ + { + "query": "Other query", + "answers": None, + "keep_row_after_execution_check": False, + "execution_result": ["No answers were provided."], + } + ], + ), + ], + ) + def test_process( + self, lib: str, answers: Dict[str, str], expected: Dict[str, Any] + ) -> None: + task = APIGenExecutionChecker(libpath=str(lib)) + task.load() + result = next(task.process([answers])) + assert result == expected diff --git a/tests/unit/steps/tasks/apigen/test_generator.py b/tests/unit/steps/tasks/apigen/test_generator.py new file mode 100644 index 000000000..a290666a6 --- /dev/null +++ b/tests/unit/steps/tasks/apigen/test_generator.py @@ -0,0 +1,172 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from typing import TYPE_CHECKING, List, Union + +import pytest + +from distilabel.steps.tasks.apigen.generator import APIGenGenerator +from tests.unit.conftest import DummyLLM + +if TYPE_CHECKING: + from distilabel.llms.typing import GenerateOutput + from distilabel.steps.tasks.typing import FormattedInput + +import json + + +class DummyAPIGenLLM(DummyLLM): + use_structured_output: bool = False + number: int = 1 + + def generate( + self, inputs: List["FormattedInput"], num_generations: int = 1 + ) -> "GenerateOutput": + query_answers = [ + { + "query": "What information can be obtained about the Maine Coon cat breed?", + "answers": [ + { + "name": "get_breed_information", + "arguments": {"breed": "Maine Coon"}, + } + ] + * self.number, + } + ] + if self.use_structured_output: + query_answers = {"pairs": query_answers} + return [ + [json.dumps(query_answers) for _ in range(num_generations)] + for _ in range(len(inputs)) + ] + + +# Example of 3 rows from Salesforce/xlam-function-calling-60k +SAMPLE_DATA = [ + { + "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]', + "query": "What information can be obtained about the Maine Coon cat breed?", + "id": 3493, + "tools": '[{"name": "get_breed_information", "description": "Fetch information about a specific cat breed from the Cat Breeds API.", "parameters": {"breed": {"description": "The name of the cat breed to fetch information for.", "type": "str", "default": "aegean"}}}, {"name": "country_region_cities", "description": "Fetches a list of cities within a specified region of a given country from the GeoDB API.", "parameters": {"countryid": {"description": "An ISO-3166 country code or WikiData ID.", "type": "str", "default": "US"}, "regioncode": {"description": "An ISO-3166 or FIPS region code.", "type": "str", "default": "CA"}, "limit": {"description": "The maximum number of results to retrieve. Defaults to None.", "type": "int, optional", "default": ""}, "hateoasmode": {"description": "Include HATEOAS-style links in results. Defaults to None.", "type": "bool, optional", "default": ""}, "asciimode": {"description": "Display results using ASCII characters. Defaults to None.", "type": "bool, optional", "default": ""}, "nameprefixdefaultlangresults": {"description": "Match on names in the default language if a non-default language is requested when prefix-matching. Defaults to None.", "type": "bool, optional", "default": ""}, "timezoneids": {"description": "Only include cities in these time zones. Comma-separated values. Defaults to None.", "type": "str, optional", "default": ""}, "nameprefix": {"description": "Only include cities whose names start with this prefix. If languagecode is set, the prefix will be matched on the name as it appears in that language. Defaults to None.", "type": "str, optional", "default": ""}, "types": {"description": "Only include cities of these types (comma-separated): CITY, ADM2. Defaults to None.", "type": "str, optional", "default": ""}, "minpopulation": {"description": "Only include cities with at least this population. 
Defaults to None.", "type": "int, optional", "default": ""}, "languagecode": {"description": "Display results in this language. Defaults to None.", "type": "str, optional", "default": ""}, "offset": {"description": "The zero-based offset into the results. Defaults to None.", "type": "int, optional", "default": ""}, "maxpopulation": {"description": "Only include cities with no more than this population. Defaults to None.", "type": "int, optional", "default": ""}, "includedeleted": {"description": "Whether to include any cities marked deleted. Options are: ALL, SINCE_YESTERDAY, SINCE_LAST_WEEK, NONE. Defaults to None.", "type": "str, optional", "default": ""}, "sort": {"description": "How to sort the results. Format: \\u00b1SORT_FIELD,\\u00b1SORT_FIELD where SORT_FIELD = elevation, name, population. Defaults to None.", "type": "str, optional", "default": ""}}}, {"name": "company_details", "description": "Fetch details of a company from Indeed\'s API.", "parameters": {"company_id": {"description": "The unique identifier of the company to fetch details for.", "type": "str", "default": "Microsoft"}, "locality": {"description": "The locality or country code for Indeed\'s subdomain. Default is \'us\' if not provided.", "type": "str, optional", "default": ""}}}]', + }, + { + "answers": '[{"name": "mailcheck", "arguments": {"domain": "protonmail.com"}}, {"name": "mailcheck", "arguments": {"domain": "mail.com"}}, {"name": "get_products_in_category", "arguments": {"skip": 20, "limit": 25, "category": "furniture"}}]', + "query": "Check if the email domains 'protonmail.com' and 'mail.com' are valid and not temporary. Get the products from category 'furniture' in my store, skipping the first 20 items and limiting to 25 items.", + "id": 57546, + "tools": '[{"name": "mailcheck", "description": "Checks if an email domain is valid or a disposable/temporary address.", "parameters": {"domain": {"description": "The email or domain to check for validity. 
It is recommended to enter just the domain for user privacy.", "type": "str", "default": "mailinator.com"}}}, {"name": "get_products_in_category", "description": "Fetches a list of products from a specified category in a store with pagination.", "parameters": {"skip": {"description": "The number of items to skip before starting to collect the result set.", "type": "int", "default": ""}, "limit": {"description": "The number of items to return in the result set.", "type": "int", "default": ""}, "category": {"description": "The category from which to fetch products.", "type": "str", "default": ""}}}, {"name": "product_by_id", "description": "Fetches detailed information about a specific product from the AliExpress API using the provided product ID.", "parameters": {"product_id": {"description": "The unique identifier for the product on AliExpress.", "type": "int", "default": "32841070485"}}}]', + }, + { + "answers": '[{"name": "navigations_get_node_content", "arguments": {"is_id": 8899, "cat_id": 8899, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 7766, "cat_id": 7766, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 5544, "cat_id": 5544, "language": "fr"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 3322, "cat_id": 3322, "language": "fr"}}]', + "query": "What are the node contents for category IDs 8899 and 7766 in English and for category IDs 5544 and 3322 in French?", + "id": 8815, + "tools": '[{"name": "navigations_get_node_content", "description": "Fetches the content of a node in a navigation hierarchy.", "parameters": {"is_id": {"description": "The \'id\' field value returned from the /navigations/get-root endpoint.", "type": "int", "default": "26066300130"}, "cat_id": {"description": "The \'cat_id\' field value returned from the /navigations/get-tabs endpoint.", "type": "int", "default": "2026"}, "language": {"description": "The 2-letter language code (default is \'en\').", "type": "str, optional", "default": "en"}, "currency": {"description": "The 3-letter currency code (default is \'USD\').", "type": "str, optional", "default": "USD"}, "country": {"description": "The 2-letter country code (default is \'US\').", "type": "str, optional", "default": "US"}}}, {"name": "products_get_reviews", "description": "Fetches brief reviews of a product from the Shein API.", "parameters": {"goods_spu": {"description": "The value of \'productRelationID\' returned in the /products/list or /products/search endpoints. Defaults to \'m22022854841\'.", "type": "str, optional", "default": "m22022854841"}, "cat_id": {"description": "The value of \'cat_id\' returned in the /products/list or /products/search endpoints. Defaults to \'1727\'.", "type": "str, optional", "default": "1727"}, "sku": {"description": "The value of \'goods_sn\' returned in the /products/list or /products/search endpoints. Defaults to \'rm2202285484176751\'.", "type": "str, optional", "default": "rm2202285484176751"}, "currency": {"description": "The 3-letter currency code. Defaults to \'USD\'.", "type": "str, optional", "default": "USD"}, "goods_id": {"description": "The value of \'goods_id\' field returned in the /products/list or /products/search endpoints. Defaults to \'10196865\'.", "type": "str, optional", "default": "10196865"}, "language": {"description": "The 2-letter language code. Defaults to \'en\'.", "type": "str, optional", "default": "en"}, "country": {"description": "The 2-letter country code. 
Defaults to \'US\'.", "type": "str, optional", "default": "US"}}}]', + }, +] + + +class TestApiGenGenerator: + @pytest.mark.parametrize("number", [1, 2, [3]]) + @pytest.mark.parametrize("use_default_structured_output", [True, False]) + @pytest.mark.parametrize("use_tools", [True, False]) + def test_format_input( + self, + number: Union[int, List[int]], + use_default_structured_output: bool, + use_tools: bool, + ) -> None: + random.seed(42) + task = APIGenGenerator( + llm=DummyLLM(), + number=number, + use_tools=use_tools, + use_default_structured_output=use_default_structured_output, + ) + task.load() + formatted = task.format_input( + input={ + "examples": '## Query:\nWhat information can be obtained about the Maine Coon cat breed?\n## Answer:\n[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]', + "func_name": "get_breed_information", + "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.", + "tools": '[{"name": "navigations_get_node_content", "description": "Fetches the content of a node in a navigation hierarchy.", "parameters": {"is_id": {"description": "The \'id\' field value returned from the /navigations/get-root endpoint.", "type": "int", "default": "26066300130"}, "cat_id": {"description": "The \'cat_id\' field value returned from the /navigations/get-tabs endpoint.", "type": "int", "default": "2026"}, "language": {"description": "The 2-letter language code (default is \'en\').", "type": "str, optional", "default": "en"}, "currency": {"description": "The 3-letter currency code (default is \'USD\').", "type": "str, optional", "default": "USD"}, "country": {"description": "The 2-letter country code (default is \'US\').", "type": "str, optional", "default": "US"}}}, {"name": "products_get_reviews", "description": "Fetches brief reviews of a product from the Shein API.", "parameters": {"goods_spu": {"description": "The value of \'productRelationID\' returned in the /products/list or /products/search endpoints. Defaults to \'m22022854841\'.", "type": "str, optional", "default": "m22022854841"}, "cat_id": {"description": "The value of \'cat_id\' returned in the /products/list or /products/search endpoints. Defaults to \'1727\'.", "type": "str, optional", "default": "1727"}, "sku": {"description": "The value of \'goods_sn\' returned in the /products/list or /products/search endpoints. Defaults to \'rm2202285484176751\'.", "type": "str, optional", "default": "rm2202285484176751"}, "currency": {"description": "The 3-letter currency code. Defaults to \'USD\'.", "type": "str, optional", "default": "USD"}, "goods_id": {"description": "The value of \'goods_id\' field returned in the /products/list or /products/search endpoints. Defaults to \'10196865\'.", "type": "str, optional", "default": "10196865"}, "language": {"description": "The 2-letter language code. Defaults to \'en\'.", "type": "str, optional", "default": "en"}, "country": {"description": "The 2-letter country code. 
Defaults to \'US\'.", "type": "str, optional", "default": "US"}}}]', + } + ) + + assert isinstance(formatted, list) + # Check only the user prompt, the system one should be fixed + formatted_prompt = formatted[1]["content"] + + if isinstance(number, list): + # Fix the number for the tests for simplicity + number = 3 + assert f"Now please generate {number} diverse" in formatted_prompt + + assert ( + "The output MUST strictly adhere to the following JSON format, and NO other text MUST be included:" + in formatted_prompt + ) + + tools_entry = "This is the available tool to guide you (respect the order of the parameters):" + if use_tools: + assert tools_entry in formatted_prompt + else: + assert tools_entry not in formatted_prompt + + is_parallel_check = "It can contain multiple parallel queries in natural language for the given functions. They could use either the same function with different arguments or different functions." + if number > 1: + assert is_parallel_check in formatted_prompt + else: + assert is_parallel_check not in formatted_prompt + + @pytest.mark.parametrize("number", [1, 2]) + @pytest.mark.parametrize("use_default_structured_output", [True, False]) + @pytest.mark.parametrize("use_tools", [True, False]) + def test_process( + self, + number: Union[int, List[int]], + use_default_structured_output: bool, + use_tools: bool, + ) -> None: + # Is parallel is not relevant in this case, it's only relevant for the format_input + # as it will be multiple questions in the prompt + random.seed(42) + task = APIGenGenerator( + llm=DummyAPIGenLLM( + use_structured_output=use_default_structured_output, number=number + ), + number=number, + use_tools=use_tools, + use_default_structured_output=use_default_structured_output, + ) + task.load() + result = next( + task.process( + [ + { + "examples": '## Query:\nWhat information can be obtained about the Maine Coon cat breed?\n## Answer:\n[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]', + "func_name": "get_breed_information", + "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.", + "tools": '[{"name": "navigations_get_node_content", "description": "Fetches the content of a node in a navigation hierarchy.", "parameters": {"is_id": {"description": "The \'id\' field value returned from the /navigations/get-root endpoint.", "type": "int", "default": "26066300130"}, "cat_id": {"description": "The \'cat_id\' field value returned from the /navigations/get-tabs endpoint.", "type": "int", "default": "2026"}, "language": {"description": "The 2-letter language code (default is \'en\').", "type": "str, optional", "default": "en"}, "currency": {"description": "The 3-letter currency code (default is \'USD\').", "type": "str, optional", "default": "USD"}, "country": {"description": "The 2-letter country code (default is \'US\').", "type": "str, optional", "default": "US"}}}, {"name": "products_get_reviews", "description": "Fetches brief reviews of a product from the Shein API.", "parameters": {"goods_spu": {"description": "The value of \'productRelationID\' returned in the /products/list or /products/search endpoints. Defaults to \'m22022854841\'.", "type": "str, optional", "default": "m22022854841"}, "cat_id": {"description": "The value of \'cat_id\' returned in the /products/list or /products/search endpoints. Defaults to \'1727\'.", "type": "str, optional", "default": "1727"}, "sku": {"description": "The value of \'goods_sn\' returned in the /products/list or /products/search endpoints. 
Defaults to \'rm2202285484176751\'.", "type": "str, optional", "default": "rm2202285484176751"}, "currency": {"description": "The 3-letter currency code. Defaults to \'USD\'.", "type": "str, optional", "default": "USD"}, "goods_id": {"description": "The value of \'goods_id\' field returned in the /products/list or /products/search endpoints. Defaults to \'10196865\'.", "type": "str, optional", "default": "10196865"}, "language": {"description": "The 2-letter language code. Defaults to \'en\'.", "type": "str, optional", "default": "en"}, "country": {"description": "The 2-letter country code. Defaults to \'US\'.", "type": "str, optional", "default": "US"}}}]',
+                    }
+                ]
+            )
+        )[0]
+        assert "query" in result
+        assert "answers" in result
+        query = result["query"]
+        assert isinstance(query, str)
+        answers = json.loads(result["answers"])
+        assert isinstance(answers, list)
+        assert len(answers) == number
diff --git a/tests/unit/steps/tasks/apigen/test_semantic_checker.py b/tests/unit/steps/tasks/apigen/test_semantic_checker.py
new file mode 100644
index 000000000..e73b71c3a
--- /dev/null
+++ b/tests/unit/steps/tasks/apigen/test_semantic_checker.py
@@ -0,0 +1,113 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict
+
+import pytest
+
+from distilabel.steps.tasks.apigen.semantic_checker import APIGenSemanticChecker
+from tests.unit.conftest import DummyLLM
+
+SAMPLE_DATA = [
+    # The info for the function description can be obtained from the tool itself
+    {
+        "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
+        "query": "What information can be obtained about the Maine Coon cat breed?",
+        "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]',
+        "execution_result": "Hopefully some info about the Maine Coon",
+    },
+    {
+        "func_desc": "Checks if an email domain is valid or a disposable/temporary address.",
+        "query": "Check if the email domains 'protonmail.com' and 'mail.com' are valid and not temporary. 
Get the products from category 'furniture' in my store, skipping the first 20 items and limiting to 25 items.", + "answers": '[{"name": "mailcheck", "arguments": {"domain": "protonmail.com"}}, {"name": "mailcheck", "arguments": {"domain": "mail.com"}}, {"name": "get_products_in_category", "arguments": {"skip": 20, "limit": 25, "category": "furniture"}}]', + "execution_result": "Response for the emails", + }, + { + "func_desc": "Fetches the content of a node in a navigation hierarchy.", + "query": "What are the node contents for category IDs 8899 and 7766 in English and for category IDs 5544 and 3322 in French?", + "answers": '[{"name": "navigations_get_node_content", "arguments": {"is_id": 8899, "cat_id": 8899, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 7766, "cat_id": 7766, "language": "en"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 5544, "cat_id": 5544, "language": "fr"}}, {"name": "navigations_get_node_content", "arguments": {"is_id": 3322, "cat_id": 3322, "language": "fr"}}]', + "execution_result": "Response for the node contents", + }, +] + + +class TestAPIGenSemanticChecker: + @pytest.mark.parametrize("use_default_structured_output", [True, False]) + def test_format_input(self, use_default_structured_output: bool) -> None: + task = APIGenSemanticChecker( + llm=DummyLLM(), + use_default_structured_output=use_default_structured_output, + ) + task.load() + result = task.format_input(SAMPLE_DATA[0]) + assert isinstance(result, list) + formatted_prompt = result[1]["content"] + + default_structured_output_check = "Your response MUST strictly adhere to the following JSON format, and NO other text MUST be included" + assert default_structured_output_check in formatted_prompt + assert ( + '- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]' + in formatted_prompt + ) + assert ( + "- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API." 
+ in formatted_prompt + ) + assert ( + "- Execution Results: Hopefully some info about the Maine Coon" + in formatted_prompt + ) + + @pytest.mark.parametrize( + "result, expected", + [ + ( + '{"thought": "thought", "keep_row_after_semantic_check": "no", "passes": "no"}', + { + "thought": "thought", + "keep_row_after_semantic_check": False, + "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]', + "execution_result": "Hopefully some info about the Maine Coon", + "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.", + "query": "What information can be obtained about the Maine Coon cat breed?", + }, + ), + ( + None, + { + "thought": None, + "keep_row_after_semantic_check": None, + "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]', + "execution_result": "Hopefully some info about the Maine Coon", + "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.", + "query": "What information can be obtained about the Maine Coon cat breed?", + }, + ), + ( + "wrong", + { + "thought": None, + "keep_row_after_semantic_check": None, + "answers": '[{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]', + "execution_result": "Hopefully some info about the Maine Coon", + "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.", + "query": "What information can be obtained about the Maine Coon cat breed?", + }, + ), + ], + ) + def test_format_output(self, result: str, expected: Dict[str, Any]) -> None: + task = APIGenSemanticChecker(llm=DummyLLM()) + task.load() + assert task.format_output(result, SAMPLE_DATA[0]) == expected diff --git a/tests/unit/steps/tasks/apigen/test_utils.py b/tests/unit/steps/tasks/apigen/test_utils.py new file mode 100644 index 000000000..00707f17a --- /dev/null +++ b/tests/unit/steps/tasks/apigen/test_utils.py @@ -0,0 +1,77 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
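+
+# Tests for the APIGen helper utilities: `load_module_from_path` imports a
+# python module from its file path, and `execute_from_response` calls a function
+# with the generated arguments, casting string values to the expected types when
+# possible and reporting whether the row should be kept.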
+
+from pathlib import Path
+from typing import Any, Dict
+
+import pytest
+
+from distilabel.steps.tasks.apigen.utils import (
+    execute_from_response,
+    load_module_from_path,
+)
+
+
+@pytest.mark.parametrize(
+    "function_name, answer, expected_result",
+    [
+        (
+            "final_velocity",
+            {"initial_velocity": 10, "acceleration": 5, "time": 2},
+            {"execution_result": "20", "keep": True},
+        ),
+        # The string arguments should be cast internally to the expected types
+        (
+            "final_velocity",
+            {"initial_velocity": "10", "acceleration": "5", "time": "2"},
+            {"execution_result": "20", "keep": True},
+        ),
+        # Different names for the arguments, but correctly positioned
+        (
+            "final_velocity",
+            {"v0": "10", "a": "5", "t": "2"},
+            {"execution_result": "20", "keep": True},
+        ),
+        # Casting fails for one of the values, so the row should not be kept
+        (
+            "final_velocity",
+            {"initial_velocity": "10", "acceleration": "5", "time": "1m/s"},
+            {
+                "execution_result": "unsupported operand type(s) for +: 'int' and 'str'",
+                "keep": False,
+            },
+        ),
+        # A required argument is missing, so the row should not be kept
+        (
+            "final_velocity",
+            {"initial_velocity": 10, "acceleration": 5},
+            {
+                "execution_result": "final_velocity() missing 1 required positional argument: 'time'",
+                "keep": False,
+            },
+        ),
+        # The function doesn't exist in the module, so the row should not be kept
+        (
+            "unknown_function",
+            {"initial_velocity": 10, "acceleration": 5, "time": 2},
+            {"execution_result": "Function not found", "keep": False},
+        ),
+    ],
+)
+def test_execute_from_response(
+    function_name: str, answer: Dict[str, Any], expected_result: Dict[str, Any]
+):
+    libpath = Path(__file__).parent / "_sample_module.py"
+    module = load_module_from_path(libpath)
+    # `getattr` returns None when the function is not present in the module
+    function = getattr(module, function_name, None)
+    result = execute_from_response(function, answer)
+    assert result == expected_result
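+
+
+# NOTE: `_sample_module.py` itself is not part of this diff. A minimal sketch of
+# the `final_velocity` function it is assumed to provide, consistent with the
+# expected results above (integer annotations so the cast result renders as "20"):
+#
+#     def final_velocity(initial_velocity: int, acceleration: int, time: int) -> int:
+#         """Calculates the final velocity of an object."""
+#         return initial_velocity + acceleration * time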